-- |A small selection of utilities that might be of use to others working
-- with bytestring/number combinations.
module Crypto.Util where

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeIndex, unsafeUseAsCStringLen)
import Data.Bits (setBit, shiftL, shiftR, xor)
import Control.Exception (Exception, throw)
import Data.Tagged
import System.IO.Unsafe
import Foreign.C.Types
import Foreign.Ptr

-- |@incBS bs@ computes the value @i2bs (8 * B.length bs) (bs2i bs + 1)@.
--
-- That is, @bs@ is incremented as a fixed-width, big-endian unsigned
-- number. If every byte is @0xFF@ the result wraps around to all zero
-- bytes of the same length. The empty string maps to itself.
--
-- Runs in time linear in the length of @bs@.
incBS :: B.ByteString -> B.ByteString
incBS bs
  | B.null prefix = zeros
  | otherwise     = B.concat [B.init prefix, B.singleton (B.last prefix + 1), zeros]
  where
    -- Every trailing 0xFF byte becomes 0x00 and carries into the byte
    -- before it. The last byte of 'prefix' is not 0xFF, so it can be
    -- incremented without overflowing.
    (prefix, carried) = B.spanEnd (== 0xFF) bs
    zeros             = B.replicate (B.length carried) 0
{-# INLINE incBS #-}

-- |@i2bs bitLen i@ converts @i@ to a 'ByteString' of @bitLen@ bits (must be
-- a multiple of 8).
--
-- Bytes are emitted big-endian (most significant first). Bits of @i@
-- above @bitLen@ are discarded.
i2bs :: Int -> Integer -> B.ByteString
i2bs l i =
  B.unfoldr
    (\l' -> if l' < 0
              then Nothing
              else Just (fromIntegral (i `shiftR` l'), l' - 8))
    (l - 8)
{-# INLINE i2bs #-}

-- |@i2bs_unsized i@ converts @i@ to a 'ByteString' of sufficient bytes to
-- express the integer, big-endian.
--
-- The integer must be non-negative (a negative input yields the empty
-- string), and zero is encoded in one byte.
i2bs_unsized :: Integer -> B.ByteString
i2bs_unsized 0 = B.singleton 0
i2bs_unsized i =
  B.reverse $
    B.unfoldr
      (\i' -> if i' <= 0 then Nothing else Just (fromIntegral i', i' `shiftR` 8))
      i
{-# INLINE i2bs_unsized #-}

-- | Useful utility to extract the result of a generator operation
-- and translate error results to exceptions.
--
-- The exception is raised with 'throw', so it surfaces when the result
-- is evaluated.
throwLeft :: Exception e => Either e a -> a
throwLeft (Left e)  = throw e
throwLeft (Right a) = a

-- |Obtain a tagged value for a particular instantiated type.
--
-- The second argument is never inspected; it only fixes the type @a@.
for :: Tagged a b -> a -> b
for t _ = unTagged t

-- |Infix `for` operator
(.::.) :: Tagged a b -> a -> b
(.::.) = for

-- | Checks two bytestrings for equality without breaches for
-- timing attacks.
--
-- Semantically, @constTimeEq = (==)@. However, @x == y@ takes less
-- time when the first byte is different than when the first byte
-- is equal. This side channel allows an attacker to mount a
-- timing attack. On the other hand, @constTimeEq@ always takes the
-- same time regardless of the bytestrings' contents, unless they are
-- of different sizes.
--
-- You should always use @constTimeEq@ when comparing secrets,
-- otherwise you may leave a significant security hole
-- (cf. <http://codahale.com/a-lesson-in-timing-attacks/>).
constTimeEq :: B.ByteString -> B.ByteString -> Bool
constTimeEq s1 s2 = unsafePerformIO compareBuffers
  where
    -- A length mismatch returns False immediately, without calling into C.
    -- Equal-length buffers go to the C routine, and a result of 0 from it
    -- is read as "equal".
    -- NOTE(review): the length is narrowed from Int to CInt, so buffers
    -- longer than maxBound :: CInt would wrap; confirm callers never pass
    -- inputs that large.
    compareBuffers =
      unsafeUseAsCStringLen s1 $ \(ptr1, len1) ->
        unsafeUseAsCStringLen s2 $ \(ptr2, len2) ->
          if len1 == len2
            then fmap (== 0) (c_constTimeEq ptr1 ptr2 (fromIntegral len1))
            else return False

-- | Comparison primitive implemented in C outside this module.
-- 'constTimeEq' treats a return value of 0 as "the buffers are equal".
foreign import ccall unsafe
  c_constTimeEq :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt

-- | Decode a 'B.ByteString' as a big-endian unsigned 'Integer': the first
-- byte is the most significant, and the empty string decodes to 0.
bs2i :: B.ByteString -> Integer
bs2i bytes = B.foldl' step 0 bytes
  where
    step acc byte = (acc `shiftL` 8) + fromIntegral byte
{-# INLINE bs2i #-}

-- | XOR two strict bytestrings byte by byte (@zipWith xor@ followed by
-- @pack@). The result is as long as the shorter input.
--
-- Written in this shape so that bytestring's rewrite rules should turn
-- it, at compile time, into a call to the library's 'zipWith'' function.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' a b = B.pack (B.zipWith xor a b)
{-# INLINE zwp' #-}

-- |zipWith xor + Pack, for lazy bytestrings.
--
-- This is written intentionally to take advantage of the bytestring
-- library's 'zipWith'' rewrite rule (through 'zwp''), at the extra cost
-- of the resulting lazy bytestring being more fragmented than either of
-- the two inputs.
zwp :: L.ByteString -> L.ByteString -> L.ByteString
zwp a b = L.fromChunks (go (L.toChunks a) (L.toChunks b))
  where
    -- Walk both chunk lists in lock step, XORing the overlapping prefix of
    -- the current chunks. Any unconsumed remainder is pushed back so it is
    -- paired with the other side's next chunk. The result ends as soon as
    -- either input runs out.
    go [] _ = []
    go _ [] = []
    go (x:xs) (y:ys) =
      let n        = min (B.length x) (B.length y)
          (x', xr) = B.splitAt n x
          (y', yr) = B.splitAt n y
          xs'      = if B.null xr then xs else xr : xs
          ys'      = if B.null yr then ys else yr : ys
      in zwp' x' y' : go xs' ys'
{-# INLINEABLE zwp #-}

-- | @collect n chunks@ gathers the first @n@ bytes from @chunks@, returned
-- as a list of pieces whose last piece is truncated as needed.
--
-- If the chunks hold fewer than @n@ bytes, all of them are returned.
-- A non-positive @n@ yields no pieces.
collect :: Int -> [B.ByteString] -> [B.ByteString]
collect i _ | i <= 0 = []
collect _ [] = []
collect i (b:bs)
  | len < i   = b : collect (i - len) bs
  | otherwise = [B.take i b]
  where
    len = B.length b
{-# INLINE collect #-}