{-# LANGUAGE FlexibleInstances, TypeSynonymInstances, FlexibleContexts
, GeneralizedNewtypeDeriving, MultiParamTypeClasses, UndecidableInstances #-}
-- | Classes and a monad transformer ('CRandT') for drawing random values and
-- bytes from a 'CryptoRandomGen', with generator failures reported through
-- 'MonadError'.  Only the names this module imports from "Crypto.Random"
-- ('CryptoRandomGen' and 'GenError', with their members) are re-exported.
module Control.Monad.CryptoRandom
( CRandom(..)
, CRandomR(..)
, MonadCRandom(..)
, MonadCRandomR(..)
, ContainsGenError(..)
, CRandT(..)
, CRand
, runCRandT
, evalCRandT
, runCRand
, evalCRand
, newGenCRand
, liftCRand
, liftCRandT
, module Crypto.Random
) where
import Control.Applicative
import Control.Arrow (right, left, first)
import Control.Monad (liftM)
import qualified Control.Monad.Catch as C (MonadThrow(..), MonadCatch(..))
import Control.Monad.Cont
import Control.Monad.Trans.Except
import Control.Monad.Error.Class
import Control.Monad.IO.Class
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.RWS.Lazy as Lazy
import Control.Monad.RWS.Strict as Strict
import Control.Monad.State.Lazy as Lazy
import Control.Monad.State.Strict as Strict
import Control.Monad.Writer.Class
import Control.Monad.Writer.Lazy as Lazy
import Control.Monad.Writer.Strict as Strict
import Crypto.Random (CryptoRandomGen(..), GenError(..))
import Data.Bits (xor, setBit, shiftR, shiftL, (.&.))
import Data.Int
import Data.List (foldl')
import Data.Word
import Data.Proxy
import Data.Tagged
import qualified Data.ByteString as B
-- | Monads that can produce random values and bytes.  Failures are raised
-- through 'MonadError' with an error type @e@ that can embed a 'GenError'
-- (see 'ContainsGenError').
class (ContainsGenError e, MonadError e m) => MonadCRandom e m where
  -- | Draw a value using its 'CRandom' instance.
  getCRandom :: CRandom a => m a
  -- | Draw the given number of bytes.
  getBytes :: Int -> m B.ByteString
  -- | Draw the given number of bytes, passing the extra bytes to the
  -- generator as additional entropy (see 'genBytesWithEntropy').
  getBytesWithEntropy :: Int -> B.ByteString -> m B.ByteString
  -- | Reseed the underlying generator with the given bytes.
  doReseed :: B.ByteString -> m ()
-- | Every operation is lifted from the inner monad; the threaded state is
-- not touched.
instance MonadCRandom e m => MonadCRandom e (Lazy.StateT s m) where
  getCRandom = lift getCRandom
  {-# INLINE getCRandom #-}
  getBytes n = lift (getBytes n)
  {-# INLINE getBytes #-}
  getBytesWithEntropy n entropy = lift (getBytesWithEntropy n entropy)
  {-# INLINE getBytesWithEntropy #-}
  doReseed seed = lift (doReseed seed)
  {-# INLINE doReseed #-}
-- | Every operation is lifted from the inner monad; the threaded state is
-- not touched.
instance MonadCRandom e m => MonadCRandom e (Strict.StateT s m) where
  getCRandom = lift getCRandom
  {-# INLINE getCRandom #-}
  getBytes n = lift (getBytes n)
  {-# INLINE getBytes #-}
  getBytesWithEntropy n entropy = lift (getBytesWithEntropy n entropy)
  {-# INLINE getBytesWithEntropy #-}
  doReseed seed = lift (doReseed seed)
  {-# INLINE doReseed #-}
-- | Every operation is lifted from the inner monad and writes no output.
instance (Monoid w, MonadCRandom e m) => MonadCRandom e (Strict.WriterT w m) where
  getCRandom = lift getCRandom
  {-# INLINE getCRandom #-}
  getBytes n = lift (getBytes n)
  {-# INLINE getBytes #-}
  getBytesWithEntropy n entropy = lift (getBytesWithEntropy n entropy)
  {-# INLINE getBytesWithEntropy #-}
  doReseed seed = lift (doReseed seed)
  {-# INLINE doReseed #-}
-- | Every operation is lifted from the inner monad and writes no output.
instance (Monoid w, MonadCRandom e m) => MonadCRandom e (Lazy.WriterT w m) where
  getCRandom = lift getCRandom
  {-# INLINE getCRandom #-}
  getBytes n = lift (getBytes n)
  {-# INLINE getBytes #-}
  getBytesWithEntropy n entropy = lift (getBytesWithEntropy n entropy)
  {-# INLINE getBytesWithEntropy #-}
  doReseed seed = lift (doReseed seed)
  {-# INLINE doReseed #-}
-- | Every operation is lifted from the inner monad; the environment is
-- not consulted.
instance MonadCRandom e m => MonadCRandom e (ReaderT r m) where
  getCRandom = lift getCRandom
  {-# INLINE getCRandom #-}
  getBytes n = lift (getBytes n)
  {-# INLINE getBytes #-}
  getBytesWithEntropy n entropy = lift (getBytesWithEntropy n entropy)
  {-# INLINE getBytesWithEntropy #-}
  doReseed seed = lift (doReseed seed)
  {-# INLINE doReseed #-}
-- | Every operation is lifted from the inner monad; environment, output and
-- state are untouched.
instance (Monoid w, MonadCRandom e m) => MonadCRandom e (Strict.RWST r w s m) where
  getCRandom = lift getCRandom
  {-# INLINE getCRandom #-}
  getBytes n = lift (getBytes n)
  {-# INLINE getBytes #-}
  getBytesWithEntropy n entropy = lift (getBytesWithEntropy n entropy)
  {-# INLINE getBytesWithEntropy #-}
  doReseed seed = lift (doReseed seed)
  {-# INLINE doReseed #-}
-- | Every operation is lifted from the inner monad; environment, output and
-- state are untouched.
instance (Monoid w, MonadCRandom e m) => MonadCRandom e (Lazy.RWST r w s m) where
  getCRandom = lift getCRandom
  {-# INLINE getCRandom #-}
  getBytes n = lift (getBytes n)
  {-# INLINE getBytes #-}
  getBytesWithEntropy n entropy = lift (getBytesWithEntropy n entropy)
  {-# INLINE getBytesWithEntropy #-}
  doReseed seed = lift (doReseed seed)
  {-# INLINE doReseed #-}
-- | Build a fresh generator of the requested type.  Each attempt draws
-- 'genSeedLength' bytes via 'getBytes' and hands them to 'newGen'.  When
-- 'newGen' rejects the seed, another attempt is made.  After 1000 rejected
-- attempts a 'GenErrorOther' is thrown.  Errors thrown by 'getBytes' itself
-- are not retried.
newGenCRand :: (CryptoRandomGen g, MonadCRandom GenError m, Functor m) => m g
newGenCRand = attempt (0 :: Int)
  where
    attempt 1000 = throwError (GenErrorOther "The generator instance requested by newGenCRand never instantiates.")
    attempt tries = seeded Proxy >>= either (const (attempt (tries + 1))) return

    -- The Proxy ties the seed length we request to the generator type that
    -- newGen has to produce.
    seeded :: (Functor m, CryptoRandomGen g, MonadCRandom GenError m)
           => Proxy g -> m (Either GenError g)
    seeded pr = newGen <$> getBytes (proxy genSeedLength pr)
-- | Monads that can draw a value from a range, using the value type's
-- 'CRandomR' instance.  Failures are raised through 'MonadError'.
class (ContainsGenError e, MonadError e m) => MonadCRandomR e m where
  -- | Draw a value bounded by the two components of the pair.
  getCRandomR :: CRandomR a => (a,a) -> m a
-- | Error types that can carry a 'GenError'.
class ContainsGenError e where
  -- | Recover the generator error, if this error value holds one.  The result
  -- is a 'Maybe' because an @e@ need not be a generator failure.
  toGenError :: e -> Maybe GenError
  -- | Embed a generator failure.  'CRandT' uses this to report 'GenError's.
  fromGenError :: GenError -> e
-- | 'GenError' trivially contains itself.
instance ContainsGenError GenError where
  toGenError err = Just err
  fromGenError err = err
-- | Types whose values can be drawn from a 'CryptoRandomGen' without bounds.
class CRandom a where
  -- | Draw one value.  Returns it with the advanced generator, or the
  -- generator's 'GenError'.
  crandom :: (CryptoRandomGen g) => g -> Either GenError (a, g)
  -- | Lazily draw a stream of values.  The stream ends at the first generator
  -- error, and that error is discarded.
  crandoms :: (CryptoRandomGen g) => g -> [a]
  crandoms gen = either (const []) (\(x, gen') -> x : crandoms gen') (crandom gen)
-- | Types whose values can be drawn from a 'CryptoRandomGen' within a range.
class CRandomR a where
  -- | Draw one value bounded by the pair.  Returns it with the advanced
  -- generator, or the generator's 'GenError'.
  crandomR :: (CryptoRandomGen g) => (a, a) -> g -> Either GenError (a, g)
  -- | Lazily draw a stream of values from the same range.  The stream ends at
  -- the first generator error, and that error is discarded.
  crandomRs :: (CryptoRandomGen g) => (a, a) -> g -> [a]
  crandomRs bounds gen =
    either (const []) (\(x, gen') -> x : crandomRs bounds gen') (crandomR bounds gen)
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Integer where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Int where
  crandom gen = crandomR (minBound :: Int, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Int where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Word8 where
  crandom gen = crandomR (minBound :: Word8, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Word8 where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Word16 where
  crandom gen = crandomR (minBound :: Word16, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Word16 where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Word32 where
  crandom gen = crandomR (minBound :: Word32, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Word32 where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Word64 where
  crandom gen = crandomR (minBound :: Word64, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Word64 where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Int8 where
  crandom gen = crandomR (minBound :: Int8, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Int8 where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Int16 where
  crandom gen = crandomR (minBound :: Int16, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Int16 where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Int32 where
  crandom gen = crandomR (minBound :: Int32, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Int32 where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws from the full range @(minBound, maxBound)@.
instance CRandom Int64 where
  crandom gen = crandomR (minBound :: Int64, maxBound) gen
  {-# INLINE crandom #-}
-- | Inclusive range; the bounds may be given in either order.
instance CRandomR Int64 where
  crandomR bounds gen = crandomR_Num bounds gen
  {-# INLINE crandomR #-}
-- | Draws a 'Word8' from @(0, 1)@ and maps it to 'False' / 'True' via 'toEnum'.
instance CRandom Bool where
  crandom gen = do
    (w, gen') <- crandomR (0 :: Word8, 1) gen
    return (toEnum (fromIntegral w), gen')
-- | Draw a value from the inclusive range bounded by the pair (the bounds may
-- be given in either order) by rejection sampling.
--
-- * Draw just enough bytes to cover the bit length of the range size.
-- * Mask off the bits above that length.
-- * If @low + offset@ lands past the upper bound, retry with the advanced
--   generator.
--
-- Equal bounds return that bound without consuming the generator.  Generator
-- failures are returned as 'Left'.
crandomR_Num :: (Integral a, CryptoRandomGen g) => (a,a) -> g -> Either GenError (a,g)
crandomR_Num (x, y) g0
  | low == high = Right (high, g0)
  | otherwise   = go g0
  where
    -- Normalise the bounds locally rather than by calling crandomR_Num again.
    -- A self-recursive top-level binding becomes a loop breaker, and GHC then
    -- ignores the INLINE pragma below.
    (low, high) = if y < x then (y, x) else (x, y)
    -- Number of values in the range, computed in Integer so it cannot overflow.
    range :: Integer
    range = fromIntegral high - fromIntegral low + 1
    nrBits = base2Log range
    nrBytes = (nrBits + 7) `div` 8
    mask = foldl' setBit 0 [0 .. fromIntegral nrBits - 1]
    go gen =
      case genBytes (fromIntegral nrBytes) gen of
        Left err -> Left err
        Right (bs, gen') ->
          let res = fromIntegral low + (bs2i bs .&. mask)
          in if res > fromIntegral high
               then go gen'
               else Right (fromIntegral res, gen')
{-# INLINE crandomR_Num #-}
-- | Run a pure generator step against the generator stored in the 'CRandT'
-- state.  On success the advanced generator is stored and the value
-- returned.  On failure the 'GenError' is converted with 'fromGenError' and
-- thrown.
wrap :: (Monad m, ContainsGenError e) => (g -> Either GenError (a,g)) -> CRandT g e m a
wrap f = CRandT $ do
  outcome <- liftM f get
  either (throwError . fromGenError) (\(x, g') -> put g' >> return x) outcome
{-# INLINE wrap #-}
-- | Lift a pure generator step into 'CRand'.  The step's own error type
-- becomes the computation's error type.
liftCRand :: (g -> Either e (a, g)) -> CRand g e a
liftCRand step = CRandT (Lazy.StateT (ExceptT . Identity . step))
{-# INLINE liftCRand #-}
-- | Lift a pure generator step into 'CRandT' over any monad.  The step's own
-- error type becomes the computation's error type.
liftCRandT :: (Monad m) => (g -> Either e (a, g)) -> CRandT g e m a
liftCRandT step = CRandT (Lazy.StateT (ExceptT . return . step))
{-# INLINE liftCRandT #-}
-- | Monad transformer that threads a generator @g@ through a lazy
-- 'Lazy.StateT' and reports errors of type @e@ through 'ExceptT' over the
-- base monad @m@.
newtype CRandT g e m a = CRandT { unCRandT :: Lazy.StateT g (ExceptT e m) a }
  deriving (Functor, Applicative, Monad, MonadIO, MonadError e, MonadFix)
-- | Lifts a base action through both the 'ExceptT' and 'Lazy.StateT' layers.
instance MonadTrans (CRandT g e) where
  lift m = CRandT (lift (lift m))
  {-# INLINE lift #-}
-- | 'get' / 'put' act on the base monad's state, not on the generator.
instance (MonadState s m) => MonadState s (CRandT g e m) where
  get = lift get
  {-# INLINE get #-}
  put s = lift (put s)
  {-# INLINE put #-}
-- | Reader operations delegate to the base monad's environment.
instance (MonadReader r m) => MonadReader r (CRandT g e m) where
  ask = lift ask
  {-# INLINE ask #-}
  local f (CRandT inner) = CRandT (local f inner)
  {-# INLINE local #-}
-- | Writer operations delegate to the base monad's output.
instance (MonadWriter w m) => MonadWriter w (CRandT g e m) where
  tell w = lift (tell w)
  {-# INLINE tell #-}
  listen (CRandT inner) = CRandT (listen inner)
  {-# INLINE listen #-}
  pass (CRandT inner) = CRandT (pass inner)
  {-# INLINE pass #-}
-- | 'callCC' through the wrapped transformer stack; the escape continuation
-- is re-wrapped as a 'CRandT' action.
instance (MonadCont m) => MonadCont (CRandT g e m) where
  callCC f = CRandT (callCC (\escape -> unCRandT (f (\x -> CRandT (escape x)))))
  {-# INLINE callCC #-}
-- | Throws exceptions through the wrapped transformer stack.
instance C.MonadThrow m => C.MonadThrow (CRandT g e m) where
  throwM ex = CRandT (C.throwM ex)
-- | Catches exceptions using the wrapped transformer stack's 'C.catch'.
instance C.MonadCatch m => C.MonadCatch (CRandT g e m) where
  catch action handler = CRandT (C.catch (unCRandT action) (\ex -> unCRandT (handler ex)))
-- | 'CRandT' over 'Identity': a pure computation threading a generator @g@
-- with errors of type @e@.
type CRand g e = CRandT g e Identity
-- | Run a 'CRandT' computation from the given generator.  Returns either the
-- error it threw, or its result paired with the final generator.
runCRandT :: ContainsGenError e => CRandT g e m a -> g -> m (Either e (a,g))
runCRandT m g = runExceptT (Lazy.runStateT (unCRandT m) g)
{-# INLINE runCRandT #-}
-- | Like 'runCRandT', but discards the final generator.
evalCRandT :: (ContainsGenError e, Monad m) => CRandT g e m a -> g -> m (Either e a)
evalCRandT m g = do
  outcome <- runCRandT m g
  return (fmap fst outcome)
{-# INLINE evalCRandT #-}
-- | Pure counterpart of 'runCRandT' for 'CRand'.
runCRand :: (ContainsGenError e) => CRand g e a -> g -> Either e (a, g)
runCRand m g = runIdentity (runCRandT m g)
{-# INLINE runCRand #-}
-- | Pure counterpart of 'evalCRandT'.  The error type is fixed to 'GenError'.
evalCRand :: CRand g GenError a -> g -> Either GenError a
evalCRand m g = runIdentity (evalCRandT m g)
{-# INLINE evalCRand #-}
-- | The base implementation.  Every operation runs against the generator held
-- in the 'CRandT' state, and generator failures are rethrown via
-- 'fromGenError'.
instance (ContainsGenError e, Monad m, CryptoRandomGen g) => MonadCRandom e (CRandT g e m) where
  getCRandom = wrap crandom
  {-# INLINE getCRandom #-}
  getBytes n = wrap (genBytes n)
  {-# INLINE getBytes #-}
  getBytesWithEntropy n entropy = wrap (genBytesWithEntropy n entropy)
  {-# INLINE getBytesWithEntropy #-}
  doReseed seed = CRandT $ do
    gen <- get
    either (throwError . fromGenError) put (reseed seed gen)
  {-# INLINE doReseed #-}
-- | Ranged draws run against the generator held in the 'CRandT' state.
instance (ContainsGenError e, Monad m, CryptoRandomGen g) => MonadCRandomR e (CRandT g e m) where
  getCRandomR bounds = wrap (crandomR bounds)
  {-# INLINE getCRandomR #-}
-- | Ranged draws are lifted from the inner monad.
instance MonadCRandomR e m => MonadCRandomR e (Lazy.StateT s m) where
  getCRandomR bounds = lift (getCRandomR bounds)
  {-# INLINE getCRandomR #-}
-- | Ranged draws are lifted from the inner monad.
instance MonadCRandomR e m => MonadCRandomR e (Strict.StateT s m) where
  getCRandomR bounds = lift (getCRandomR bounds)
  {-# INLINE getCRandomR #-}
-- | Ranged draws are lifted from the inner monad and write no output.
instance (MonadCRandomR e m, Monoid w) => MonadCRandomR e (Lazy.WriterT w m) where
  getCRandomR bounds = lift (getCRandomR bounds)
  {-# INLINE getCRandomR #-}
-- | Ranged draws are lifted from the inner monad and write no output.
instance (MonadCRandomR e m, Monoid w) => MonadCRandomR e (Strict.WriterT w m) where
  getCRandomR bounds = lift (getCRandomR bounds)
  {-# INLINE getCRandomR #-}
-- | Ranged draws are lifted from the inner monad.
instance MonadCRandomR e m => MonadCRandomR e (ReaderT r m) where
  getCRandomR bounds = lift (getCRandomR bounds)
  {-# INLINE getCRandomR #-}
-- | Ranged draws are lifted from the inner monad.
instance (MonadCRandomR e m, Monoid w) => MonadCRandomR e (Lazy.RWST r w s m) where
  getCRandomR bounds = lift (getCRandomR bounds)
  {-# INLINE getCRandomR #-}
-- | Ranged draws are lifted from the inner monad.
instance (MonadCRandomR e m, Monoid w) => MonadCRandomR e (Strict.RWST r w s m) where
  getCRandomR bounds = lift (getCRandomR bounds)
  {-# INLINE getCRandomR #-}
-- | Bit length of an 'Integer': the number of binary digits needed to write a
-- positive value (@base2Log 255 == 8@, @base2Log 256 == 9@).  Zero and
-- negative inputs yield 0.  Large inputs are consumed in 64/32/16/8-bit
-- chunks before falling back to one bit at a time.
base2Log :: Integer -> Integer
base2Log = countBits 0
  where
    countBits acc n
      | n >= setBit 0 64 = countBits (acc + 64) (n `shiftR` 64)
      | n >= setBit 0 32 = countBits (acc + 32) (n `shiftR` 32)
      | n >= setBit 0 16 = countBits (acc + 16) (n `shiftR` 16)
      | n >= setBit 0 8  = countBits (acc + 8)  (n `shiftR` 8)
      | n >= setBit 0 0  = countBits (acc + 1)  (n `shiftR` 1)
      | otherwise        = acc
-- | Read a 'B.ByteString' as an unsigned big-endian 'Integer': the first byte
-- is the most significant.  The empty string yields 0.
bs2i :: B.ByteString -> Integer
bs2i = B.foldl' step 0
  where
    step acc byte = acc `shiftL` 8 + fromIntegral byte
{-# INLINE bs2i #-}