{-# LANGUAGE FlexibleContexts #-}
module Raaz.Random.ChaCha20PRG
( reseedMT, fillRandomBytesMT, RandomState(..)
) where
import Control.Applicative
import Control.Monad
import Foreign.Ptr (Ptr, castPtr)
import Prelude
import Raaz.Core
import Raaz.Cipher.ChaCha20.Internal
import Raaz.Cipher.ChaCha20.Recommendation
import Raaz.Entropy
-- | Threshold for the ChaCha20 block counter kept in the generator
-- state.  Once the counter exceeds this value, 'seedIfReq' reseeds the
-- generator from the system entropy source.  The value is
-- @1024 * 1024 * 1024@, i.e. @2^30@.
maxCounterVal :: Counter
maxCounterVal = 1024 * 1024 * 1024
-- | Mutable state of the ChaCha20-based pseudo-random generator.
data RandomState = RandomState
  { chacha20State  :: ChaCha20Mem
    -- ^ ChaCha20 memory (key, IV and block counter cells) used to
    -- produce keystream.
  , auxBuffer      :: RandomBuf
    -- ^ Buffer of generated bytes that have not yet been handed out.
  , remainingBytes :: MemoryCell (BYTES Int)
    -- ^ How many bytes at the front of 'auxBuffer' are still unused.
  }
-- | Run an action on the auxiliary buffer's pointer.  The pointer is
-- cast to whatever pointer type the action expects.
withAuxBuffer :: (Ptr something -> MT RandomState a) -> MT RandomState a
withAuxBuffer action = do
  bufPtr <- onSubMemory auxBuffer getBufferPointer
  action (castPtr bufPtr)
-- | Read the number of unused bytes left in the auxiliary buffer.
getRemainingBytes :: MT RandomState (BYTES Int)
getRemainingBytes = onSubMemory remainingBytes extract
-- | Record the number of unused bytes left in the auxiliary buffer.
setRemainingBytes :: BYTES Int -> MT RandomState ()
setRemainingBytes n = onSubMemory remainingBytes (initialise n)
-- | Each field is allocated with its own 'memoryAlloc', and the results
-- are combined applicatively.  The pointer of the whole state is the
-- pointer of its ChaCha20 component.
instance Memory RandomState where
  memoryAlloc     = RandomState <$> memoryAlloc <*> memoryAlloc <*> memoryAlloc
  unsafeToPointer = unsafeToPointer . chacha20State
-- | Refill the auxiliary buffer with fresh ChaCha20 output.
--
-- The steps, in order:
--
-- 1. Reseed first if the block counter has passed 'maxCounterVal'.
-- 2. Write 'randomBufferSize' blocks of keystream into the buffer.
-- 3. Mark the whole buffer as available.
-- 4. Overwrite the key and IV cells with bytes taken from that fresh
--    buffer.  'fillExistingBytes' also zeroes those bytes in the buffer,
--    so they are never handed out to callers.
newSample :: MT RandomState ()
newSample = do
  seedIfReq
  withAuxBuffer $ \buf ->
    onSubMemory chacha20State (chacha20Random buf randomBufferSize)
  setRemainingBytes (inBytes randomBufferSize)
  fillKeyIVWith fillExistingBytes
-- | Seed the generator from the system entropy source.  This resets the
-- block counter to zero and fills the key and IV using 'getEntropy'.
seed :: MT RandomState ()
seed = do
  onSubMemory (counterCell . chacha20State) $ initialise (0 :: Counter)
  -- NOTE(review): 'fillKeyIVWith' discards the byte count that
  -- 'getEntropy' returns, so a short read would go unnoticed.  Confirm
  -- that 'getEntropy' always fills the full request.
  fillKeyIVWith getEntropy
-- | Reseed from system entropy, but only once the ChaCha20 block
-- counter has grown beyond 'maxCounterVal'.
seedIfReq :: MT RandomState ()
seedIfReq = do
  counter <- onSubMemory (counterCell . chacha20State) extract
  when (counter > maxCounterVal) seed
-- | Overwrite the ChaCha20 key cell, and then the IV cell, using
-- @filler@.
--
-- @filler n ptr@ is asked to write @n@ bytes at @ptr@, where @n@ is the
-- storable size of the key (or of the IV).  Whatever @filler@ returns
-- is discarded.
fillKeyIVWith :: (BYTES Int -> Pointer -> MT RandomState a) -- ^ The function used to fill the buffer
              -> MT RandomState ()
fillKeyIVWith filler = do
  keyPtr <- onSubMemory (keyCell . chacha20State) getCellPointer
  void $ filler keySize (castPtr keyPtr)
  ivPtr  <- onSubMemory (ivCell . chacha20State) getCellPointer
  void $ filler ivSize (castPtr ivPtr)
  where
    keySize = sizeOf (undefined :: KEY)
    ivSize  = sizeOf (undefined :: IV)
-- | Reseed the generator from system entropy, then immediately refill
-- the auxiliary buffer.  Any bytes still in the buffer were generated
-- before the reseed, and this refill replaces them.
reseedMT :: MT RandomState ()
reseedMT = do
  seed
  newSample
-- | Write @l@ worth of pseudo-random bytes at the given pointer.
--
-- Bytes are served from the auxiliary buffer first.  Whenever the
-- buffer is empty, it is refilled with 'newSample'.
fillRandomBytesMT :: LengthUnit l => l -> Pointer -> MT RandomState ()
fillRandomBytesMT l = go (inBytes l)
  where
    go :: BYTES Int -> Pointer -> MT RandomState ()
    go m ptr
      | m <= 0    = return ()                -- Nothing left to fill.
      | otherwise = do
          got <- fillExistingBytes m ptr    -- Serve from the buffer.
          when (got <= 0) newSample         -- Buffer was empty: refill it.
          go (m - got) (movePtr ptr got)    -- Continue after what we wrote.
-- | Copy already-generated bytes from the auxiliary buffer to @ptr@.
--
-- Let @r@ be the number of unused bytes.  This function copies
-- @min r req@ bytes and returns that count.  The bytes are taken from
-- the end of the unused region.  The copied bytes are then zeroed in
-- the buffer, and the unused count is lowered accordingly.
fillExistingBytes :: BYTES Int -> Pointer -> MT RandomState (BYTES Int)
fillExistingBytes req ptr = withAuxBuffer $ \bufPtr -> do
  r <- getRemainingBytes
  let m       = min r req          -- Bytes actually handed out.
      l       = r - m              -- Bytes left unused afterwards.
      tailPtr = movePtr bufPtr l   -- Start of the last m unused bytes.
  memcpy (destination ptr) (source tailPtr) m
  -- Wipe the bytes just handed out so they cannot be served again.
  memset tailPtr 0 m
  setRemainingBytes l
  return m