{-# LANGUAGE BangPatterns #-}
module Crypto.Saltine.Internal.Util (
module Crypto.Saltine.Internal.Util
, withCString
, allocaBytes
)
where
import Data.ByteString (ByteString)
import Data.ByteString.Unsafe
import Data.Monoid
import Foreign.C
import Foreign.Marshal.Alloc (mallocBytes, allocaBytes)
import Foreign.Ptr
import GHC.Word (Word8)
import System.IO.Unsafe
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
-- | Subtract @y@ from @x@ without going below zero.
--
-- Yields 'Nothing' when @y@ is greater than @x@ (the difference would be
-- negative, or would underflow for unsigned types); otherwise
-- @'Just' (x - y)@.
safeSubtract :: (Ord a, Num a) => a -> a -> Maybe a
safeSubtract x y
  | y > x     = Nothing
  | otherwise = Just (x - y)
-- | Successor with wraparound for a bounded enumeration.
--
-- The second component is 'succ' of the input, or 'minBound' when the
-- input is 'maxBound'. The first component reports whether that wrap
-- happened, i.e. whether the input was 'maxBound'.
--
-- The pair itself is built lazily: neither component is evaluated until
-- it is demanded.
cycleSucc :: (Bounded a, Enum a, Eq a) => a -> (Bool, a)
cycleSucc a = (wrapped, next)
  where
    wrapped = a == maxBound
    next
      | wrapped   = minBound
      | otherwise = succ a
-- | Increment a 'ByteString' by one, reading it as a little-endian
-- unsigned integer (the first byte is the least significant).
--
-- A carry ripples toward the end of the string. A string made entirely of
-- @0xFF@ bytes wraps around to all zeros. The length never changes.
nudgeBS :: ByteString -> ByteString
nudgeBS = snd . S.mapAccumL step True
  where
    -- The accumulator is the pending carry. While it is set, bump the byte
    -- (with wraparound) and keep carrying only if that byte wrapped.
    -- Once it clears, every remaining byte is copied unchanged.
    step :: Bool -> Word8 -> (Bool, Word8)
    step carry byte
      | carry     = cycleSucc byte
      | otherwise = (False, byte)
-- | Brute-force orbit of an endomorphism @f@ starting from @a0@.
--
-- Produces @f a0, f (f a0), ...@ up to and including the first return to
-- @a0@. The length of the result is therefore the cycle length of @a0@
-- under @f@. If @a0@ is never revisited, the list is infinite.
orbit :: Eq a => (a -> a) -> a -> [a]
orbit f a0 = beforeReturn ++ [a0]
  where
    (beforeReturn, _) = break (== a0) (iterate f (f a0))
-- | Prepend @n@ zero bytes to a 'ByteString'.
-- A non-positive @n@ adds nothing.
pad :: Int -> ByteString -> ByteString
pad n bs = S.replicate n 0 `S.append` bs
-- | Remove the first @n@ bytes of a 'ByteString' (the inverse of 'pad').
--
-- The dropped bytes are not checked to be zero.
unpad :: Int -> ByteString -> ByteString
unpad count bs = S.drop count bs
-- | Turn a C-style status code into an 'Either'.
--
-- @0@ is success and yields @'Right' a@. @-1@ yields @'Left' "failed"@.
-- Any other code is reported as unexpected, with its value in the
-- message.
handleErrno :: CInt -> (a -> Either String a)
handleErrno err a
  | err == 0  = Right a
  | err == -1 = Left "failed"
  | otherwise = Left ("unexpected error code: " ++ show err)
-- | Run a C-style action via 'unsafePerformIO' and report whether it
-- returned @0@.
--
-- Only sound when the action behaves like a pure function, because the
-- timing (and repetition) of the effect is up to the evaluator.
unsafeDidSucceed :: IO CInt -> Bool
unsafeDidSucceed action = unsafePerformIO action == 0
-- | Marshal several 'String's with 'withCString' and hand the resulting
-- pointers to the continuation, in the same order as the input list.
--
-- Each pointer is only valid while the continuation runs.
withCStrings :: [String] -> ([CString] -> IO a) -> IO a
withCStrings []       k = k []
withCStrings (s : ss) k =
  withCString s $ \cs ->
    -- Remember this pointer at the front and marshal the rest.
    withCStrings ss (k . (cs :))
-- | Marshal several 'String's with 'withCStringLen' and hand the resulting
-- pointer/length pairs to the continuation, in input order.
--
-- Each pointer is only valid while the continuation runs.
withCStringLens :: [String] -> ([CStringLen] -> IO a) -> IO a
withCStringLens []       k = k []
withCStringLens (s : ss) k =
  withCStringLen s $ \csl ->
    withCStringLens ss (k . (csl :))
-- | Expose the bytes of several 'ByteString's to C as 'CStringLen's,
-- without copying (via 'unsafeUseAsCStringLen'), and run the continuation
-- with them in input order.
--
-- The C side must treat the memory as read-only ("const"), since writes
-- would mutate the original immutable 'ByteString's. It must not keep the
-- pointers after the continuation returns.
constByteStrings :: [ByteString] -> ([CStringLen] -> IO b) -> IO b
constByteStrings []       k = k []
constByteStrings (b : bs) k =
  unsafeUseAsCStringLen b $ \view ->
    constByteStrings bs (k . (view :))
-- | Allocate an @n@-byte buffer with 'mallocBytes', let @k@ fill it through
-- a raw pointer, and return @k@'s result together with the buffer as a
-- 'ByteString'.
--
-- The buffer is wrapped with 'unsafePackMallocCStringLen' /before/ @k@
-- runs, so its finalizer owns the memory from then on. The contents start
-- uninitialised. @k@ must write all @n@ bytes and must not write past
-- them; neither is checked.
buildUnsafeByteString' :: Int -> (Ptr CChar -> IO b) -> IO (b, ByteString)
buildUnsafeByteString' n k = do
  raw    <- mallocBytes n
  buffer <- unsafePackMallocCStringLen (raw, n)
  result <- unsafeUseAsCString buffer k
  pure (result, buffer)
-- | Like 'buildUnsafeByteString'', but for C routines that write a
-- NUL-terminated string of unknown length into an @n@-byte buffer.
--
-- The returned 'ByteString' is the part of the buffer before the first
-- NUL byte. The search for the terminator is bounded by @n@: if @k@ writes
-- no NUL, all @n@ bytes are returned instead of reading past the
-- allocation.
--
-- The buffer is wrapped (with a freeing finalizer) /before/ @k@ runs, so
-- the memory is released even if @k@ throws.
buildUnsafeVariableByteString' :: Int -> (Ptr CChar -> IO b) -> IO (b, ByteString)
buildUnsafeVariableByteString' n k = do
  ph  <- mallocBytes n
  buf <- unsafePackMallocCStringLen (ph, n)
  out <- unsafeUseAsCString buf k
  -- Slice off the terminated prefix now, so the result is fixed when this
  -- action returns. The slice shares buf's memory and keeps it alive.
  let bs = S.takeWhile (/= 0) buf
  bs `seq` return (out, bs)
-- | Pure, 'unsafePerformIO'-backed form of 'buildUnsafeVariableByteString''.
--
-- Use with great care: @k@ receives a raw pointer to an @n@-byte buffer
-- and is expected to write a NUL-terminated string into it. Because this
-- runs inside 'unsafePerformIO', @k@ should behave like a pure function of
-- its inputs.
buildUnsafeVariableByteString :: Int -> (Ptr CChar -> IO b) -> (b, ByteString)
buildUnsafeVariableByteString n k =
  unsafePerformIO (buildUnsafeVariableByteString' n k)
-- | Pure, 'unsafePerformIO'-backed form of 'buildUnsafeByteString''.
--
-- Extremely unsafe; use with utmost care. @k@ gets a raw pointer to an
-- @n@-byte buffer, and writes past the end are not checked. Because
-- allocation and the call to @k@ happen inside 'unsafePerformIO', when the
-- 'ByteString' is actually built is hard to predict. @k@ should behave
-- like a pure function of its inputs.
buildUnsafeByteString :: Int -> (Ptr CChar -> IO b) -> (b, ByteString)
buildUnsafeByteString n k = unsafePerformIO (buildUnsafeByteString' n k)
-- | Produce @n@ random bytes, filled in by libsodium's @randombytes_buf@
-- (bound here as 'c_randombytes_buf').
randomByteString :: Int -> IO ByteString
randomByteString n = fmap snd (buildUnsafeByteString' n fill)
  where
    -- Ask libsodium to fill exactly the n bytes that were allocated.
    fill buf = c_randombytes_buf buf (fromIntegral n)
-- | Discard the error of an 'Either', keeping only a successful value.
--
-- Defined locally to avoid depending on the @errors@ package.
hush :: Either s a -> Maybe a
hush e = case e of
  Left _  -> Nothing
  Right x -> Just x
-- NOTE(review): libsodium's headers declare the length parameters of
-- randombytes_buf and sodium_memcmp as size_t, but they are bound as
-- 'CInt' below. Confirm this is sound on every target ABI before relying
-- on large lengths.

-- | libsodium @randombytes_buf@: fill the buffer with the given number of
-- random bytes. Imported as a safe call (no @unsafe@ annotation).
foreign import ccall "randombytes_buf"
  c_randombytes_buf :: CString -> CInt -> IO ()

-- | libsodium @sodium_memcmp@: compare two buffers over the given length.
-- Callers in this module treat a result of 0 as "equal".
foreign import ccall unsafe "sodium_memcmp"
  c_sodium_memcmp :: CString -> CString -> CInt -> IO CInt

-- | libsodium @sodium_malloc@: allocate the given number of bytes.
-- libsodium documents a NULL return on failure; callers must check for it.
foreign import ccall unsafe "sodium_malloc"
  c_sodium_malloc :: CSize -> IO (Ptr a)

-- | libsodium @sodium_free@: release memory obtained from
-- 'c_sodium_malloc'.
foreign import ccall unsafe "sodium_free"
  c_sodium_free :: Ptr Word8 -> IO ()
-- | Like 'buildUnsafeByteString'', but the buffer comes from libsodium's
-- @sodium_malloc@. It is released with @sodium_free@ once the resulting
-- 'ByteString' is garbage-collected.
--
-- Throws an 'IOError' if @sodium_malloc@ returns NULL, instead of letting
-- @k@ write through a null pointer. That happens, for example, when
-- allocation fails, or when @n@ is negative and wraps to a huge 'CSize'.
-- This matches 'mallocBytes', which also fails loudly rather than
-- returning NULL.
buildUnsafeScrubbedByteString' :: Int -> (Ptr CChar -> IO b) -> IO (b,ByteString)
buildUnsafeScrubbedByteString' n k = do
  p <- c_sodium_malloc (fromIntegral n)
  if p == nullPtr
    then ioError (userError "buildUnsafeScrubbedByteString': sodium_malloc failed")
    else do
      -- Wrap before running k so the finalizer owns the memory even if
      -- k throws.
      bs  <- unsafePackCStringFinalizer p n (c_sodium_free p)
      out <- unsafeUseAsCString bs k
      pure (out, bs)
-- | Pure, 'unsafePerformIO'-backed form of 'buildUnsafeScrubbedByteString''.
--
-- Same caveats as 'buildUnsafeByteString': @k@ writes through a raw
-- pointer with no bounds checking, and it should behave like a pure
-- function because it runs under 'unsafePerformIO'.
buildUnsafeScrubbedByteString :: Int -> (Ptr CChar -> IO b) -> (b, ByteString)
buildUnsafeScrubbedByteString n k =
  unsafePerformIO (buildUnsafeScrubbedByteString' n k)
-- | Equality test for 'ByteString's that delegates to libsodium's
-- @sodium_memcmp@ instead of Haskell's '=='. The choice is presumably to
-- get libsodium's constant-time comparison.
--
-- Strings of different lengths are reported unequal without calling into
-- C, so the length itself is not hidden. Equal-length inputs are compared
-- over their full length. The C call runs under 'unsafePerformIO', which
-- relies on @sodium_memcmp@ only reading the two buffers.
compare :: ByteString -> ByteString -> Bool
compare a b
  | S.length a /= S.length b = False
  | otherwise = unsafePerformIO $
      unsafeUseAsCString a $ \pa ->
        unsafeUseAsCString b $ \pb ->
          (== 0) <$> c_sodium_memcmp pa pb (fromIntegral (S.length a))
-- | libsodium @sodium_bin2hex@. The arguments are, in order: output
-- buffer, output capacity, input bytes, input length. The returned pointer
-- is ignored by 'bin2hex'.
foreign import ccall unsafe "sodium_bin2hex"
  c_sodium_bin2hex :: CString -> CInt -> CString -> CInt -> IO CString
-- | Hex-encode a 'ByteString' with libsodium's @sodium_bin2hex@,
-- producing two characters per input byte.
bin2hex :: ByteString -> String
bin2hex bs = S8.unpack (S8.init hexWithNul)
  where
    binLen :: Int
    binLen = S.length bs

    -- Two characters per input byte, plus one for the trailing NUL that
    -- is dropped again by S8.init above.
    hexCap :: Int
    hexCap = binLen * 2 + 1

    -- Only the filled buffer is kept; the pointer returned by C is
    -- discarded.
    (_, hexWithNul) = buildUnsafeByteString hexCap $ \hexPtr ->
      unsafeUseAsCStringLen bs $ \(binPtr, _) ->
        c_sodium_bin2hex hexPtr (fromIntegral hexCap) binPtr (fromIntegral binLen)
-- | Three-argument 'uncurry'.
--
-- The triple is matched lazily, so it is not forced unless @f@ demands
-- one of its components.
uncurry3 :: (a -> b -> c -> d) -> ((a, b, c) -> d)
uncurry3 f triple = f x y z
  where
    (x, y, z) = triple
-- | Five-argument 'uncurry'.
--
-- The tuple is matched lazily, so it is not forced unless @f@ demands one
-- of its components.
uncurry5 :: (a -> b -> c -> d -> e -> f) -> ((a, b, c, d, e) -> f)
uncurry5 f tuple = f v w x y z
  where
    (v, w, x, y, z) = tuple
-- | Strict logical AND: evaluates /both/ operands before combining them.
-- Unlike '&&', it does not short-circuit when the left side is 'False'.
(!&&!) :: Bool -> Bool -> Bool
x !&&! y = x `seq` y `seq` (x && y)
-- | Strict logical OR: evaluates /both/ operands before combining them.
-- Unlike '||', it does not short-circuit when the left side is 'True'.
(!||!) :: Bool -> Bool -> Bool
x !||! y = x `seq` y `seq` (x || y)