-- | This module implements the pseudo-random generator using the
-- /fast key erasure technique/
-- (<https://blog.cr.yp.to/20170723-random.html>) parameterised on the
-- signatures "Implementation" and "Entropy". This technique is the
-- underlying algorithm used in systems like OpenBSD in their
-- implementation of arc4random.
--
-- __WARNING:__ These details are only for developers and reviewers of
-- the raaz library. A casual user should not be looking into this
-- module, let alone tweaking the code here.
--

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE RecordWildCards   #-}
-- {-# LANGUAGE NamedFieldPun    #-}

module PRGenerator
       ( -- * Pseudo-random generator
         -- $internals$
         RandomState, reseed, fillRandomBytes
         -- ** Information about the cryptographic generator.
       , entropySource, csprgName, csprgDescription
       ) where

import Foreign.Ptr ( castPtr )
import Entropy
import Prelude

import Raaz.Core
import Raaz.Core.Memory

import Implementation
import Context

-- $internals$
--
-- Generating an unpredictable stream of bytes is one task that has burnt
-- the fingers of a lot of programmers. Unfortunately, getting it
-- correct is something of a black art. We give the internal details
-- of the cryptographic pseudo-random generator used in raaz. Note
-- that none of the details here are accessible or tuneable by the
-- user. This is a deliberate design choice to insulate the user from
-- things that are pretty easy to mess up.
--
-- The pseudo-random generator is essentially a primitive that
-- supports the generation of multiple blocks of data once its
-- internals are set. The overall idea is to set the internals from a
-- truly random source and then use the primitive to expand the
-- internal state into pseudo-random bytes. However, there are tricky
-- issues regarding forward security that will make such a simplistic
-- algorithm insecure. Besides, where do we get our truly random seed
-- to begin the process?
--
-- We more or less follow the /fast key erasure technique/
-- (<https://blog.cr.yp.to/20170723-random.html>) which is used in the
-- arc4random implementation in OpenBSD.  The two main steps in the
-- generation of the required random bytes are the following:
--
-- [Seeding:] Setting the internal state of a primitive. We use the
-- `getEntropy` function for this purpose.
--
-- [Sampling:] Pre-computing a few blocks using `randomBlocks` that
-- will later be used to satisfy the requests for random
-- bytes.
--
-- Instead of running `randomBlocks` for every request, we
-- generate `RandomBufferSize` random blocks in an auxiliary
-- buffer and satisfy requests for random bytes from this buffer. To
-- ensure that the compromise of the PRG state does not compromise the
-- random data already generated and given out, we do the following.
--
-- 1. After generating `RandomBufferSize` blocks of data in the
--    auxiliary buffer, we immediately re-initialise the internals of
--    the primitive from the auxiliary buffer. This ensures that there
--    is no way to know which internal state was used to generate the
--    current contents in the auxiliary buffer.
--
-- 2. Every piece of data taken from the auxiliary buffer, whether it
--    is used to satisfy a request for random bytes or to reinitialise
--    the internals in step 1, is wiped out immediately.
--
-- Assuming the security of the entropy source given by
-- `getEntropy` and of the random block generator given by
-- `randomBlocks`, we have the following security guarantee.
--
-- [Security Guarantee:] At any point of time, a compromise of the
-- cipher state (i.e. key iv pair) and/or the auxiliary buffer does
-- not reveal the random data that was given out previously.
--


-- | Name of the csprg used for stretching the seed.
csprgName :: String
csprgName = name

-- | A short description of the csprg.
csprgDescription :: String
csprgDescription = description

-- | Memory for storing the csprg state. It consists of
--
-- * the context of the primitive ('randomCxt'), which holds both the
--   internals of the primitive and an auxiliary buffer of
--   'RandomBufferSize' blocks from which random bytes are handed out,
--   and
--
-- * a counter ('randomGenBlocks') of the number of blocks generated
--   since the last reseed from the entropy source.
data RandomState = RandomState
  { randomCxt       :: Cxt RandomBufferSize
  , randomGenBlocks :: MemoryCell (BlockCount Prim)
  }


-- | The two fields are allocated one after the other (context first,
-- then the block counter). The pointer of the whole state is the
-- pointer of its context.
instance Memory RandomState where
  memoryAlloc     = RandomState <$> memoryAlloc <*> memoryAlloc
  unsafeToPointer = unsafeToPointer . randomCxt

-- | Gives access into the internals of the associated cipher. Only
-- the internals held in the context are exposed. The auxiliary buffer
-- and the generated-block counter are not exposed.
instance WriteAccessible RandomState where
  writeAccess          = writeAccess . cxtInternals . randomCxt
  afterWriteAdjustment = afterWriteAdjustment . cxtInternals . randomCxt

-------------------------------- The PRG operations ---------------------------------------------

-- | Generate a new sample, i.e. refill the context with fresh
-- pseudo-random bytes. If at least 'reseedAfter' blocks have been
-- generated since the last reseed, the state is reseeded from the
-- entropy source instead. 'reseed' itself also refills the context.
sample :: RandomState -> IO ()
sample rstate@RandomState{..} = do
  genBlocks <- extract randomGenBlocks
  if genBlocks >= reseedAfter
    then reseed rstate
    else generateRandom rstate

-- | Reseed the state from the system entropy pool. The CSPRG
-- interface automatically takes care of reseeding from the entropy
-- pool at regular intervals and the user almost never needs to use
-- this.
--
-- Reseeding proceeds in three steps:
--
-- 1. Re-initialise the internals of the primitive using 'getEntropy'.
--
-- 2. Reset the generated-block counter to zero.
--
-- 3. Refill the auxiliary buffer.
reseed :: RandomState -> IO ()
reseed rstate@RandomState{..} = do
  unsafeInitWithEntropy rstate
  initialise zeroBlocks randomGenBlocks
  generateRandom rstate


-- | Generate random bytes into the context in one go. They are then
-- released to the outside world a little at a time.
--
-- We also keep track of how many blocks have been generated. 'sample'
-- uses this count to decide when to reseed the generator from system
-- entropy.
--
-- Finally, the internals are immediately re-initialised from the
-- freshly generated buffer (fast key erasure). As a result, the
-- internals that produced the current buffer no longer exist.
generateRandom :: RandomState -> IO ()
generateRandom rstate@RandomState{..} = do
  unsafeGenerateBlocks randomBlocks randomCxt
  -- Add the number of blocks the context holds to the running count.
  modifyMem (mappend $ cxtBlockCount $ pure randomCxt) randomGenBlocks
  unsafeInitFromBuffer rstate

------------------------------ DANGEROUS ACCESS manipulation ------------------------

--
-- The functions below are highly unsafe; do not export them. All
-- hell breaks loose otherwise.
--

-- | Initialise the internals from the entropy source. Every
-- write-accessible region of the state is filled using 'getEntropy'.
--
-- NOTE(review): the byte count returned by 'getEntropy' is discarded.
-- This presumably relies on 'getEntropy' always filling the full
-- @accessSize@ bytes; confirm against the "Entropy" implementations.
unsafeInitWithEntropy :: RandomState -> IO ()
unsafeInitWithEntropy = mapM_ initWithEntropy . writeAccess
  where initWithEntropy Access{..} = getEntropy accessSize accessPtr

-- | Initialise the internals from the already generated blocks. CSPRG
-- implementations should ensure that the context is large enough to
-- hold enough bytes. Then, even after initialising the internals,
-- there is enough data left to give out for subsequent calls.
-- Otherwise each sampling will result in an infinite loop.
unsafeInitFromBuffer :: RandomState -> IO ()
unsafeInitFromBuffer rstate@RandomState{..}
  = mapM_ initFromBuffer $ writeAccess rstate
  where -- Move @accessSize@ bytes out of the buffer in the context and
        -- into the region of the internals.
        initFromBuffer Access{..}
          = unsafeWriteTo accessSize (destination accessPtr) randomCxt


-- | Zero blocks of the primitive. Used to reset the generated-block
-- counter on reseed.
zeroBlocks :: BlockCount Prim
zeroBlocks = 0 `blocksOf` Proxy


-- | Write @sz@ random bytes to the destination.
--
-- Bytes are taken from the auxiliary buffer in the context. Whenever
-- fewer bytes than requested could be transferred, a fresh 'sample'
-- is generated (reseeding if necessary). The transfer then continues
-- from where it stopped. A non-positive @sz@ writes nothing.
unsafeRandomBytes :: BYTES Int
                  -> Dest (Ptr Word8)
                  -> RandomState -> IO ()
unsafeRandomBytes sz destPtr rstate@RandomState{..} = go sz destPtr
  where
    go :: BYTES Int -> Dest (Ptr Word8) -> IO ()
    go n ptr
      | n <= 0    = return ()
      | otherwise = do
          trfed <- unsafeWriteTo n ptr randomCxt
          let more    = n - trfed
              nextPtr = (`movePtr` trfed) <$> ptr
          when (more > 0) $ sample rstate >> go more nextPtr

-- | Fill a buffer pointed by the given pointer with random bytes.
-- The length @l@ is converted to bytes with 'inBytes'. The pointer is
-- converted to a raw byte pointer before writing.
fillRandomBytes :: (LengthUnit l, Pointer ptr)
                => l
                -> Dest (ptr a)
                -> RandomState
                -> IO ()
fillRandomBytes l ptr = unsafeRandomBytes (inBytes l) wptr
  where wptr = fmap (castPtr . unsafeRawPtr) ptr

-- | A 'RandomState' never runs out. Every request for @n@ bytes is
-- filled completely, and the state is always reported as 'Remaining'.
instance ByteSource RandomState where
  fillBytes n rstate ptr
    = unsafeRandomBytes n (destination (castPtr ptr)) rstate
      >> return (Remaining rstate)