module Crypto.Util where
import qualified Data.ByteString as B
import Data.ByteString.Unsafe (unsafeIndex, unsafeUseAsCStringLen)
import Data.Bits (shiftL, shiftR)
import Data.Bits (xor, setBit, shiftR, shiftL)
import Control.Exception (Exception, throw)
import Data.Tagged
import System.IO.Unsafe
import Foreign.C.Types
import Foreign.Ptr
-- | Increment a 'B.ByteString' interpreted as an unsigned big-endian
-- integer, keeping its length unchanged.
--
-- The carry propagates from the last byte towards the first.  When every
-- byte is @0xFF@ the value wraps around to all zeros; the empty string is
-- returned unchanged.
--
-- >>> incBS (B.pack [0x00, 0xFF])
-- "\SOH\NUL"
incBS :: B.ByteString -> B.ByteString
incBS bs
  -- Nothing but 0xFF bytes (or nothing at all): wrap around to zeros.
  | B.null prefix = zeros
  -- The last byte of 'prefix' is not 0xFF, so adding one cannot overflow.
  | otherwise =
      B.concat [B.init prefix, B.singleton (B.last prefix + 1), zeros]
  where
    -- Split off the trailing run of 0xFF bytes; each of them overflows to 0
    -- and passes the carry on to the byte before it.
    (prefix, trailingFFs) = B.spanEnd (== 0xFF) bs
    zeros = B.replicate (B.length trailingFFs) 0
-- | @i2bs bitLen i@ serialises @i@ as a big-endian 'B.ByteString', emitting
-- one byte for every 8 bits of @bitLen@.
--
-- Each output byte is the low 8 bits of @i@ shifted right by the current
-- bit offset, starting at @bitLen - 8@ and stepping down by 8 until the
-- offset becomes negative.  Bits of @i@ above @bitLen@ are therefore
-- dropped, and a @bitLen@ below 8 yields the empty string.
--
-- NOTE(review): @bitLen@ is presumably meant to be a multiple of 8; other
-- values produce offsets that are not byte-aligned -- confirm with callers.
i2bs :: Int -> Integer -> B.ByteString
i2bs l i = B.unfoldr nextByte (l - 8)
  where
    -- The seed is the right-shift (in bits) for the byte emitted next.
    nextByte shift
      | shift < 0 = Nothing
      | otherwise = Just (fromIntegral (i `shiftR` shift), shift - 8)
-- | Serialise an 'Integer' as a big-endian 'B.ByteString' using the fewest
-- bytes that can hold it.
--
-- Zero is encoded as a single @0x00@ byte.  Negative inputs produce the
-- empty string, because the byte-collecting loop stops as soon as the
-- remaining value is not positive.
i2bs_unsized :: Integer -> B.ByteString
i2bs_unsized i
  | i == 0 = B.singleton 0
  | otherwise = B.pack (collect i [])
  where
    -- Peel off the least significant byte each round and prepend it, so the
    -- accumulated list ends up most-significant-byte first.
    collect remaining acc
      | remaining <= 0 = acc
      | otherwise = collect (remaining `shiftR` 8) (fromIntegral remaining : acc)
-- | Unwrap an 'Either' whose failure side is an exception: a 'Right' yields
-- its payload, a 'Left' is raised with 'throw'.
--
-- Because this is a pure function, the exception only surfaces when the
-- result is forced.
throwLeft :: Exception e => Either e a -> a
throwLeft = either throw id
-- | Extract the value carried by a 'Tagged', using the second argument only
-- to pin down the tag type @a@.  That argument is never inspected.
for :: Tagged a b -> a -> b
for tagged = const (unTagged tagged)
-- | Infix synonym for 'for': @tagged .::. witness@ returns the value inside
-- @tagged@, with @witness@ serving only to select the tag type.
(.::.) :: Tagged a b -> a -> b
tagged .::. witness = for tagged witness
-- | Compare two 'B.ByteString's for equality through the foreign routine
-- 'c_constTimeEq'.
--
-- Strings of different length are reported as unequal without calling into
-- C.  Otherwise both buffers and their common length are handed to
-- 'c_constTimeEq', and a result of @0@ is interpreted as "equal".
--
-- NOTE(review): the constant-time property rests entirely on the C
-- implementation, which is not visible here.  The use of 'unsafePerformIO'
-- presumes the C routine only reads the two buffers -- confirm.
constTimeEq :: B.ByteString -> B.ByteString -> Bool
constTimeEq s1 s2 =
  unsafePerformIO $
    unsafeUseAsCStringLen s1 $ \(ptrA, lenA) ->
      unsafeUseAsCStringLen s2 $ \(ptrB, lenB) ->
        compareBuffers ptrA lenA ptrB lenB
  where
    -- The foreign call is only made once the lengths are known to agree.
    compareBuffers ptrA lenA ptrB lenB
      | lenA == lenB = fmap (== 0) (c_constTimeEq ptrA ptrB (fromIntegral lenA))
      | otherwise = return False
-- | Binding to the C function @c_constTimeEq@ (no entity string is given,
-- so the C symbol shares the Haskell name).
--
-- Takes two character buffers and a 'CInt' and returns a 'CInt' in 'IO'.
-- 'constTimeEq' passes the shared buffer length as the third argument and
-- treats a result of @0@ as "buffers are equal".
--
-- NOTE(review): the name suggests a constant-time comparison; the C source
-- is not part of this file -- confirm against the implementation.
foreign import ccall unsafe
  c_constTimeEq :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt
-- | Read a 'B.ByteString' as an unsigned big-endian integer.
--
-- The first byte is the most significant; the empty string decodes to @0@.
bs2i :: B.ByteString -> Integer
bs2i = B.foldl' appendByte 0
  where
    -- Make room for one more byte at the low end, then add it in.
    appendByte acc byte = (acc `shiftL` 8) + fromIntegral byte
-- | Byte-wise XOR of two 'B.ByteString's.
--
-- The result is as long as the shorter input; surplus bytes of the longer
-- one are ignored.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' a b = B.pack xoredBytes
  where
    xoredBytes = B.zipWith xor a b