{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Trustworthy #-}
module Crypto.NewHope.Internal.RNG where
import Codec.Crypto.AES
import Control.DeepSeq
import Data.Bits
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Char8 as BSC
import qualified Data.ByteString.Lazy as BSL
import Data.Semigroup ((<>))
import Data.Word
-- | Raw seed material for 'randomBytesInit'.  Construct it through
-- 'makeRandomSeed', which enforces the 'randomSeedBytes' length.
newtype RandomSeed = RandomSeed BS.ByteString

-- | Number of bytes a 'RandomSeed' must contain.
randomSeedBytes :: Int
randomSeedBytes = 48

-- | Unwrap a 'RandomSeed' to the bytes it carries.
getRandomSeedData :: RandomSeed -> BS.ByteString
getRandomSeedData seed =
  case seed of
    RandomSeed bytes -> bytes
-- | Types that can be converted into a 'RandomSeed'.
class RandomSeedable a where
  -- | Build a seed from the value.  The instances in this module call
  -- 'error' unless the input is exactly 'randomSeedBytes' long.
  makeRandomSeed :: a -> RandomSeed
-- | Seed from a 'String', one byte per character.  Packs with 'BSC.pack' and
-- then applies the same length check as the 'BS.ByteString' instance.  The
-- character count equals the packed byte count, so the error text is the same.
--
-- NOTE(review): 'BSC.pack' keeps only the low 8 bits of each 'Char', so
-- characters above U+00FF are silently altered rather than rejected.
instance RandomSeedable String where
  makeRandomSeed = makeRandomSeed . BSC.pack
-- | Seed from raw bytes.  Calls 'error' unless the input is exactly
-- 'randomSeedBytes' long; otherwise wraps the bytes unchanged.
instance RandomSeedable BS.ByteString where
  makeRandomSeed bs
    | len == randomSeedBytes = RandomSeed bs
    | otherwise = error $ "Invalid length for RandomSeed. Have " ++ show len ++ " and require " ++ show randomSeedBytes ++ " bytes."
    where
      len = BS.length bs
-- | AES key bytes used by the generator.  Use 'createKey' to get the length
-- check.
newtype Key = Key { getKey :: BS.ByteString } deriving (Eq)

-- | Required length of a 'Key', in bytes.
keyBytes :: Int
keyBytes = 32

-- | Wrap bytes as a 'Key'; calls 'error' unless they are exactly 'keyBytes'
-- long.
createKey :: BS.ByteString -> Key
createKey value =
  if BS.length value == keyBytes
    then Key value
    else error "Incorrect key length"
-- | The generator's counter block @V@.  Use 'createV' to get the length check.
newtype V = V { getV :: BS.ByteString } deriving (Eq)

-- | Required length of a 'V', in bytes.
vBytes :: Int
vBytes = 16

-- | Wrap bytes as a 'V'; calls 'error' (reporting the actual length) unless
-- they are exactly 'vBytes' long.
createV :: BS.ByteString -> V
createV value =
  case BS.length value of
    n | n == vBytes -> V value
      | otherwise -> error $ "Incorrect V length: " ++ show n
-- | Add one to @V@, treating its bytes as a big-endian unsigned integer.
-- Overflow wraps around (all-0xff becomes all-zero); the length is preserved.
incrementV :: V -> V
incrementV (V bytes) = V (snd (BS.mapAccumR addCarry True bytes))
  where
    -- Walk from the least-significant (rightmost) byte.  A 0xff byte that
    -- receives a carry wraps to 0 and keeps propagating it; any other byte
    -- absorbs the carry.
    addCarry carry byte
      | not carry = (False, byte)
      | otherwise = (byte == 0xff, byte + 1)
-- | State of the AES counter-mode generator.
--
-- The fields are strict.  With lazy fields, 'randomBytes' stores the reseed
-- counter as an unevaluated @ctxReseedCounter ctx + 1@ that nothing forces.
-- Every call would then chain another thunk that also keeps the previous
-- 'Context' alive, so memory would grow with the number of calls.
data Context = Context
  { ctxKey :: !Key
    -- ^ Current AES key.
  , ctxV :: !V
    -- ^ Current counter block.
  , ctxReseedCounter :: !Int
    -- ^ Set to 1 by 'randomBytesInit'; 'randomBytes' adds one per call.
  } deriving (Eq)
-- | Fully evaluate a 'Context'.
--
-- 'Key' and 'V' are newtypes over strict 'BS.ByteString's and the counter is
-- an 'Int', so weak-head evaluation of each field is already normal form.
-- The reseed counter must be forced as well; otherwise 'deepseq' leaves the
-- counter's thunk chain, and the old contexts it references, alive.
instance NFData Context where
  rnf Context { ctxKey = key, ctxV = v, ctxReseedCounter = counter } =
    key `seq` v `seq` counter `seq` ()
-- | Refresh the key and counter block of a 'Context'.
--
-- Steps:
--
-- 1. Encrypt V+1, V+2 and V+3 (via 'incrementV') under the current key.
-- 2. Concatenate the three blocks into a keystream.
-- 3. If @providedData@ is given, XOR it into the keystream byte-wise; the
--    result has the length of the shorter input.
-- 4. The first 'keyBytes' bytes become the new key and the rest become the
--    new @V@.
--
-- The reseed counter is left untouched.
--
-- Calls 'error' (from 'createKey' / 'createV') if the split pieces have the
-- wrong lengths, e.g. when @providedData@ is shorter than the keystream.
update :: Context -> Maybe BS.ByteString -> Context
update ctx providedData =
  ctx { ctxKey = createKey nextKeyData, ctxV = createV nextVData }
  where
    Key key = ctxKey ctx
    -- ECB takes no IV, but 'crypt'' still requires an argument for it.
    ecbModeDoesNotUseIV = BS.pack (replicate 16 0)
    -- V+1, V+2, V+3: successive big-endian increments of the current V.
    counters = take 3 (drop 1 (iterate incrementV (ctxV ctx)))
    keystream =
      BS.concat [ crypt' ECB key ecbModeDoesNotUseIV Encrypt (getV c) | c <- counters ]
    material = case providedData of
      Nothing -> keystream
      Just extra -> BS.pack (BS.zipWith xor keystream extra)
    (nextKeyData, nextVData) = BS.splitAt keyBytes material
-- | Create a generator 'Context' from a seed.
--
-- Starts from an all-zero key and @V@ with reseed counter 1, then runs
-- 'update' with the seed material.  The seed material is:
--
-- * the entropy input, when no personalization is given;
-- * otherwise the entropy input XORed byte-wise with the personalization
--   bytes (length of the shorter input).
--
-- The third argument (security strength) is accepted but not used.
randomBytesInit :: RandomSeed
                -> Maybe RandomSeed
                -> Integer
                -> Context
randomBytesInit (RandomSeed entropyInput) personalization _securityStrength =
    update zeroed (Just seedMaterial)
  where
    personalize (RandomSeed persData) = BS.pack (BS.zipWith xor entropyInput persData)
    seedMaterial = maybe entropyInput personalize personalization
    zeroed = Context { ctxKey = createKey (BS.replicate keyBytes 0)
                     , ctxV = createV (BS.replicate vBytes 0)
                     , ctxReseedCounter = 1
                     }
-- | Produce @count@ pseudo-random bytes and the next generator state.
--
-- The output is built as follows:
--
-- * Encrypt V+1, V+2, ... under the current key, one AES block ('vBytes'
--   bytes) per counter value.
-- * Generate just enough blocks to cover @count@ bytes.
-- * Truncate the concatenated output to @count@ bytes.
--
-- The returned state has @V@ advanced past the last counter used and the
-- reseed counter incremented, and has been passed through 'update' with no
-- additional input.
--
-- A non-positive @count@ yields an empty result; the state is still
-- advanced.
randomBytes :: Context -> Int -> (BS.ByteString, Context)
randomBytes ctx count = (output, nextCtx)
  where
    Key key = ctxKey ctx
    -- ECB takes no IV, but 'crypt'' still requires an argument for it.
    ecbModeDoesNotUseIV = BS.pack (replicate 16 0)
    -- Exact integer ceiling of count / vBytes (V is one AES block wide).
    -- This avoids the previous round-trip through 'Double'.
    blockCount = max 0 ((count + vBytes - 1) `div` vBytes)
    counters = take blockCount (drop 1 (iterate incrementV (ctxV ctx)))
    encryptBlock c = crypt' ECB key ecbModeDoesNotUseIV Encrypt (getV c)
    keystream = foldr (\c rest -> Builder.byteString (encryptBlock c) <> rest) mempty counters
    output = BS.take count (BSL.toStrict (Builder.toLazyByteString keystream))
    -- Last counter used, or the unchanged V when no block was generated.
    -- Never empty, so 'last' is safe here.
    finalV = last (ctxV ctx : counters)
    nextCtx = update (ctx { ctxV = finalV, ctxReseedCounter = ctxReseedCounter ctx + 1 }) Nothing
-- | Draw 4 bytes from the generator and decode them as a big-endian integer.
--
-- NOTE(review): despite the name and the 'Word64' result type, only 4 bytes
-- (32 bits) are consumed, so the result always lies in @[0, 2^32 - 1]@.
-- Widening it would change how many bytes each call consumes, and therefore
-- every value derived from the stream — confirm with callers first.
nextWord64 :: Context -> (Word64, Context)
nextWord64 ctx =
  let (bytes, ctx') = randomBytes ctx 4
      value = BS.foldl' (\acc byte -> (acc `shiftL` 8) .|. fromIntegral byte) 0 bytes
  in (value, ctx')
-- | Draw an 'Integer' from the inclusive range @(minValue, maxValue)@.  The
-- bounds may be given in either order.
--
-- Method:
--
-- * Treat successive 'nextWord64' draws as base-2^32 digits.
-- * Accumulate digits until the number of possible values is at least 1000
--   times the range size.
-- * Reduce the accumulated value modulo the range size.
--
-- The bias from the final @mod@ is therefore at most about 1/1000.
--
-- Each draw is masked to 32 bits because 'nextWord64' only produces 32
-- random bits.  The previous code assumed 64 bits per draw (base 2^64), so
-- the accumulator had gaps.  For large ranges, results were confined to the
-- first 2^32 values.
randomInteger :: Context -> (Integer, Integer) -> (Integer, Context)
randomInteger ctx (minValue, maxValue)
  | minValue > maxValue = randomInteger ctx (maxValue, minValue)
  | otherwise = (minValue + acc `mod` k, ctx')
  where
    (acc, ctx') = accumulate 1 0 ctx
    -- Only the low 32 bits of each draw are treated as random.
    digitMask = 0xffffffff :: Word64
    base = toInteger digitMask + 1
    -- Oversampling factor: bounds the modulo bias by roughly 1 / oversample.
    oversample = 1000
    k = maxValue - minValue + 1
    target = k * oversample
    -- mag is the number of distinct values the accumulator vv can take.
    accumulate mag vv ctx0
      | mag >= target = (vv, ctx0)
      | otherwise = v' `seq` accumulate (mag * base) v' ctx0'
      where
        (x, ctx0') = nextWord64 ctx0
        v' = vv * base + toInteger (x .&. digitMask)