-- | The module exposes the ChaCha20 based PRG.
{-# LANGUAGE FlexibleContexts #-}
module Raaz.Random.ChaCha20PRG
       ( reseedMT, fillRandomBytesMT, RandomState(..)
       ) where

import Control.Applicative
import Control.Monad
import Foreign.Ptr   (Ptr, castPtr)
import Prelude

import Raaz.Core
import Raaz.Cipher.ChaCha20.Internal
import Raaz.Cipher.ChaCha20.Recommendation
import Raaz.Entropy

-- | The maximum value of the counter before reseeding from the entropy
-- source. Currently set to 1024 * 1024 * 1024 blocks, which will
-- generate 64GB before reseeding.
--
-- The counter is a 32-bit quantity, which means that one can generate
-- 2^32 blocks of data before the counter rolls over and starts
-- repeating. We have chosen a conservative 2^30 blocks here.
maxCounterVal :: Counter
maxCounterVal = 1024 * 1024 * 1024

-- | Memory for storing the internal state of the PRG.
data RandomState = RandomState
  { chacha20State  :: ChaCha20Mem
    -- ^ The ChaCha20 cipher state (key, iv and counter cells) used
    -- to generate the key stream.
  , auxBuffer      :: RandomBuf
    -- ^ The auxiliary buffer into which the key stream is generated
    -- and from which random bytes are handed out.
  , remainingBytes :: MemoryCell (BYTES Int)
    -- ^ Number of bytes in the auxiliary buffer that have not yet
    -- been handed out.
  }

-------------------------- Some helper functions on random state -------------------

-- | Run an action on the auxiliary buffer. The pointer to the buffer
-- is cast to whatever pointer type the action expects.
withAuxBuffer :: (Ptr something -> MT RandomState a) -> MT RandomState a
withAuxBuffer action = onSubMemory auxBuffer getBufferPointer >>= action . castPtr

-- | Get the number of bytes in the auxiliary buffer that are yet to
-- be handed out.
getRemainingBytes :: MT RandomState (BYTES Int)
getRemainingBytes = onSubMemory remainingBytes extract

-- | Set the number of remaining bytes, i.e. overwrite the
-- 'remainingBytes' cell with the given value.
setRemainingBytes :: BYTES Int -> MT RandomState ()
setRemainingBytes = onSubMemory remainingBytes . initialise

instance Memory RandomState where
  -- Allocate the three components in the order of the constructor's
  -- fields: the chacha20 state, the auxiliary buffer, and the cell
  -- holding the count of remaining bytes.
  --
  -- No method signatures are given here because the module does not
  -- enable InstanceSigs.
  memoryAlloc     = RandomState <$> memoryAlloc <*> memoryAlloc <*> memoryAlloc

  -- The pointer of the whole state is taken to be that of the chacha20
  -- state. NOTE(review): presumably valid because that component is
  -- allocated first -- confirm against the Alloc semantics.
  unsafeToPointer = unsafeToPointer . chacha20State

-------------------------------- The PRG operations ---------------------------------------------

-- The overall idea is to generate a key stream into the auxiliary
-- buffer using chacha20 and to give out bytes from this buffer. This
-- operation we call sampling. A portion of the sample is used for
-- resetting the key and iv to make the prg safe against backward
-- prediction, i.e. even if one knows the current seed (i.e. the
-- key-iv pair) one cannot predict the random values generated before.



-- | This fills in the random block with some new randomness. The
-- steps are: reseed from entropy if the block counter has exceeded
-- its limit, generate a fresh key stream into the auxiliary buffer,
-- and finally replace the key and iv with bytes taken from that
-- fresh key stream.
newSample :: MT RandomState ()
newSample = do
  seedIfReq
  withAuxBuffer $ onSubMemory chacha20State . flip chacha20Random randomBufferSize -- keystream
  setRemainingBytes $ inBytes randomBufferSize -- Total bytes generated in one go
  -- Re-key from the sample just generated; the bytes used for this
  -- are consumed from the buffer and hence never given out.
  fillKeyIVWith fillExistingBytes


-- | Seed the PRG from system entropy. The block counter is reset to
-- zero and both the key and the iv are filled using 'getEntropy'.
seed :: MT RandomState ()
seed = do
  onSubMemory (counterCell . chacha20State) $ initialise (0 :: Counter)
  fillKeyIVWith getEntropy

-- | Seed if we have already generated more than 'maxCounterVal'
-- blocks of random bytes, i.e. if the counter of the chacha20 state
-- exceeds 'maxCounterVal'.
seedIfReq :: MT RandomState ()
seedIfReq = do
  c <- onSubMemory (counterCell . chacha20State) extract
  when (c > maxCounterVal) seed

-- | Fill the iv and key from a filling function. The filler is called
-- twice: first with the size and location of the key cell and then
-- with the size and location of the iv cell. Whatever the filler
-- returns is discarded.
fillKeyIVWith :: (BYTES Int -> Pointer -> MT RandomState a) -- ^ The function used to fill the buffer
              -> MT RandomState ()
fillKeyIVWith filler = do
  onSubMemory (keyCell . chacha20State) getCellPointer >>= void . filler keySize . castPtr
  onSubMemory (ivCell  . chacha20State) getCellPointer >>= void . filler ivSize  . castPtr
  where
    -- Only the types matter for sizeOf; the undefined values are
    -- never evaluated.
    keySize, ivSize :: BYTES Int
    keySize = sizeOf (undefined :: KEY)
    ivSize  = sizeOf (undefined :: IV)





--------------------------- DANGEROUS CODE ---------------------------------------

-- | Reseed the prg. The key and iv are first filled from system
-- entropy and then a new sample is generated immediately, which in
-- turn replaces the key and iv with bytes from that sample.
reseedMT :: MT RandomState ()
reseedMT = seed >> newSample

-- NONTRIVIALITY: Picking up the newSample is important when we first
-- reseed.

-- | The function to generate random bytes. Fills from existing bytes
-- and continues if not enough bytes are obtained. A request for a
-- non-positive number of bytes does nothing.
fillRandomBytesMT :: LengthUnit l => l -> Pointer -> MT RandomState ()
fillRandomBytesMT l = go (inBytes l)
  where
    -- Loop until the requested number of bytes has been written,
    -- advancing the destination pointer by what was obtained each
    -- round.
    go :: BYTES Int -> Pointer -> MT RandomState ()
    go m ptr
      | m > 0     = do
          mGot <- fillExistingBytes m ptr   -- Fill from the already generated buffer.
          when (mGot <= 0) newSample        -- We did not get any so sample.
          go (m - mGot) $ movePtr ptr mGot  -- Get the remaining.
      | otherwise = return ()               -- Nothing to do


-- | Fill from already existing bytes. Returns the number of bytes
-- filled. Let remaining bytes be r. Then fillExistingBytes will fill
-- min(r,m) bytes into the buffer, and return the number of bytes
-- filled. The bytes are taken from the end of the unused portion of
-- the auxiliary buffer and are wiped from it once copied.
fillExistingBytes :: BYTES Int -> Pointer -> MT RandomState (BYTES Int)
fillExistingBytes req ptr = withAuxBuffer $ \ sptr -> do
  r <- getRemainingBytes
  let m, l :: BYTES Int
      m = min r req            -- actual bytes filled.
      l = r - m                -- leftover
      tailPtr :: Pointer
      tailPtr = movePtr sptr l
  -- Fills from the end of the source buffer.
  --  sptr                tailPtr
  --   |                  |
  --   V                  V
  --   -----------------------------------------------------
  --   |   l              |    m                           |
  --   -----------------------------------------------------
  memcpy (destination ptr) (source tailPtr) m -- transfer the bytes to destination
  memset tailPtr 0 m                          -- wipe the bytes already transferred.
  setRemainingBytes l                         -- set leftover bytes.
  return m