module Crypto.Saltine.Internal.Util where

import           Foreign.C
import           Foreign.Marshal.Alloc    (mallocBytes)
import           Foreign.Ptr
import           System.IO.Unsafe

import           Control.Applicative
import qualified Data.ByteString        as S
import           Data.ByteString          (ByteString)
import           Data.ByteString.Unsafe
import           Data.Monoid

-- | @safeSubtract x y@ computes @x - y@, returning 'Nothing' whenever the
-- result would underflow or be negative.
--
-- Two situations are rejected:
--
-- * @y > x@, which would underflow an unsigned type such as 'Word'.
--
-- * A fixed-width signed subtraction that overflows and wraps around to a
--   negative number, e.g. @safeSubtract 0 (minBound :: Int)@.
--
-- The @y > x@ check runs first, so @x - y@ is never evaluated in that case.
-- This matters for types whose subtraction throws on underflow.
safeSubtract :: (Ord a, Num a) => a -> a -> Maybe a
safeSubtract x y
  | y > x     = Nothing
  -- With y <= x the mathematical difference is non-negative, so a negative
  -- result can only come from fixed-width signed overflow wrapping around.
  | diff < 0  = Nothing
  | otherwise = Just diff
  where
    diff = x - y

-- | Successor with wraparound for a 'Bounded' 'Enum'.
--
-- @snd (cycleSucc x)@ is @succ x@, except that 'maxBound' wraps around to
-- 'minBound'. @fst (cycleSucc x)@ reports whether that wraparound happened,
-- i.e. @fst . cycleSucc == (== maxBound)@.
cycleSucc :: (Bounded a, Enum a, Eq a) => a -> (Bool, a)
cycleSucc x = (wrapped, next)
  where
    wrapped = x == maxBound
    next
      | wrapped   = minBound
      | otherwise = succ x

-- | Treats a 'ByteString' as a little-endian unsigned number and increments
-- it by one. The first byte is the least significant.
--
-- The length is preserved, so a string of all @0xFF@ bytes wraps around to
-- all zero bytes, and the empty string is returned unchanged.
nudgeBS :: ByteString -> ByteString
nudgeBS = snd . S.mapAccumL step True
  where
    -- The accumulator is the carry into the current byte; the initial 'True'
    -- carry is the "+1". Once the carry clears, later bytes are copied as-is.
    step carry byte
      | carry     = cycleSucc byte
      | otherwise = (False, byte)

-- | Computes the orbit of an endomorphism, in a very brute-force manner.
--
-- @orbit f a0@ lists @f a0@, @f (f a0)@, ... up to and including the first
-- value equal to @a0@. If iterating @f@ never returns to @a0@, the result is
-- an infinite list.
--
-- Exists just for the property below.
--
-- prop> length . orbit nudgeBS . S.pack . flip replicate 0 == (256^)
orbit :: Eq a => (a -> a) -> a -> [a]
orbit f a0 = beforeReturn ++ [a0]
  where
    (beforeReturn, _) = break (== a0) (iterate f (f a0))

-- | Zero-pads a 'ByteString' by prepending @n@ zero bytes to it.
-- 'unpad' removes such a prefix again.
pad :: Int -> ByteString -> ByteString
pad n bs = S.append (S.replicate n 0) bs

-- | Removes an @n@-byte padding prefix from a 'ByteString', undoing 'pad'.
--
-- The dropped bytes are not checked to actually be zero.
unpad :: Int -> ByteString -> ByteString
unpad n = snd . S.splitAt n

-- | Converts a C-convention return code into an 'Either'.
--
-- @0@ means success and yields @'Right' a@. @-1@ yields @'Left' "failed"@.
-- Any other code yields a 'Left' that mentions the unexpected value.
handleErrno :: CInt -> (a -> Either String a)
handleErrno code a
  | code == 0    = Right a
  | code == (-1) = Left "failed"
  | otherwise    = Left ("unexpected error code: " ++ show code)

-- | Runs an action that reports its outcome as a C-style return code and
-- tells whether that code was @0@ (success).
--
-- The action is executed through 'unsafePerformIO'. When (and whether) it
-- runs therefore depends on evaluation of the resulting 'Bool'. Only use
-- this for actions whose effects are safe to perform at an unpredictable
-- time.
unsafeDidSucceed :: IO CInt -> Bool
unsafeDidSucceed action = unsafePerformIO action == 0

-- | Convenience function for accessing several constant C strings at once.
--
-- Each 'ByteString' is exposed as a 'CStringLen', in the same order as the
-- input list, and the whole list is passed to the continuation.
--
-- The pointers come from 'unsafeUseAsCStringLen' and alias the
-- 'ByteString's' own memory. The continuation must therefore not modify the
-- pointed-to bytes, and must not use the pointers after it returns.
constByteStrings :: [ByteString] -> ([CStringLen] -> IO b) -> IO b
constByteStrings []           k = k []
constByteStrings (str : strs) k =
  unsafeUseAsCStringLen str $ \cstr ->
    constByteStrings strs $ \rest -> k (cstr : rest)

-- | Slightly safer cousin to 'buildUnsafeByteString' that remains in the
-- 'IO' monad, so the allocation and the call to @k@ happen in a
-- well-defined order.
--
-- An @n@-byte buffer is obtained with 'mallocBytes', and ownership passes
-- to a 'ByteString' via 'unsafePackMallocCStringLen'. The buffer's address
-- is then handed to @k@ so it can fill in the bytes. The result of @k@ is
-- returned together with the (now filled) 'ByteString'.
--
-- Writes past @n@ bytes are UNCHECKED. Bytes that @k@ does not write keep
-- whatever 'mallocBytes' left there.
buildUnsafeByteString' :: Int -> (Ptr CChar -> IO b) -> IO (b, ByteString)
buildUnsafeByteString' n k =
  mallocBytes n >>= \buf ->
  unsafePackMallocCStringLen (buf, n) >>= \bs ->
  fmap (\result -> (result, bs)) (unsafeUseAsCString bs k)

-- | Extremely unsafe function, use with utmost care!
--
-- Builds a new 'ByteString' by running @k@ (typically a ccall) with access
-- to the raw underlying pointer of an @n@-byte buffer. It is a pure-looking
-- wrapper around 'buildUnsafeByteString''.
--
-- Overwrites are UNCHECKED. Because 'unsafePerformIO' is used, it is
-- difficult to predict when the 'ByteString' is actually created.
buildUnsafeByteString :: Int -> (Ptr CChar -> IO b) -> (b, ByteString)
buildUnsafeByteString n k = unsafePerformIO (buildUnsafeByteString' n k)

-- | Build a random 'ByteString' of length @n@ using Sodium's
-- @randombytes_buf@ (via 'c_randombytes_buf').
--
-- The length is handed to the binding as a 'CInt'. A length outside the
-- 'CInt' range would be silently truncated or wrapped by 'fromIntegral'.
-- The buffer would then be only partly filled, and uninitialised memory
-- returned as "random" bytes, or the C side would see a bogus size.
-- Such lengths, and negative ones, are rejected with an 'IOError' instead.
randomByteString :: Int -> IO ByteString
randomByteString n
  | outOfRange =
      ioError (userError ("randomByteString: length out of range: " ++ show n))
  | otherwise  = fmap snd (buildUnsafeByteString' n fill)
  where
    outOfRange = n < 0 || toInteger n > toInteger (maxBound :: CInt)
    fill buf   = c_randombytes_buf buf (fromIntegral n)

-- | Discards the error side of an 'Either', keeping only a 'Right' value.
--
-- Defined locally to avoid a dependency on the @errors@ package.
hush :: Either s a -> Maybe a
hush (Right x) = Just x
hush (Left _)  = Nothing

-- | Raw binding to Sodium's @randombytes_buf@. 'randomByteString' uses it
-- to fill a freshly allocated buffer of the given length.
--
-- NOTE(review): libsodium documents the length parameter of
-- @randombytes_buf@ as @size_t@, while this binding declares 'CInt'.
-- Confirm against the installed headers. Switching to 'CSize' would be more
-- faithful, but it changes the type seen by any other users of this binding.
foreign import ccall "randombytes_buf"
  c_randombytes_buf
    :: Ptr CChar
    -> CInt
    -> IO ()