{-# LANGUAGE CPP                        #-}
{-# LANGUAGE ExplicitForAll             #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UndecidableInstances       #-}
#if __GLASGOW_HASKELL__ < 710
{-# LANGUAGE OverlappingInstances #-}
#endif
-- | Random number generation backed by an auto-reseeding 'HashDRBG' held in
-- an 'MVar', together with the 'CryptoRNGT' transformer that exposes that
-- generator through the 'CryptoRNG' class.
module Crypto.RNG (
    -- * The CryptoRNG class
    module Crypto.RNG.Class
    -- * Generator state and primitive generators
  , CryptoRNGState
  , newCryptoRNGState
  , unsafeCryptoRNGState
  , randomBytesIO
  , randomR
    -- * Values spanning a whole bounded type
  , Random(..)
  , boundedIntegralRandom
    -- * The CryptoRNGT monad transformer
  , CryptoRNGT
  , mapCryptoRNGT
  , runCryptoRNGT
  , withCryptoRNGState
  ) where
import Control.Applicative
import Control.Concurrent
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Cont
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.Trans.Control
import Crypto.Random
import Crypto.Random.DRBG
import Data.Bits
import Data.ByteString (ByteString, unpack)
import Data.Int
import Data.List
import Data.Word
import Crypto.RNG.Class
-- | Mutable generator state: a 'GenAutoReseed' generator built from two
-- 'HashDRBG's, kept in an 'MVar' so that every draw takes exclusive access
-- to it and stores the advanced generator back.
newtype CryptoRNGState
  = CryptoRNGState (MVar (GenAutoReseed HashDRBG HashDRBG))
-- | Allocate a fresh 'CryptoRNGState' whose generator is seeded by
-- crypto-api 'newGenIO'.
newCryptoRNGState :: MonadIO m => m CryptoRNGState
newCryptoRNGState = liftIO $ do
  gen <- newGenIO
  fmap CryptoRNGState (newMVar gen)
-- | Build a 'CryptoRNGState' from an explicit seed.
--
-- 'newGen' is a pure function of the seed, so the whole output stream is
-- determined by it: the result is only as unpredictable as the seed.
-- If 'newGen' rejects the seed, the error is rendered with 'show' and
-- raised via 'fail' (an 'IOError' in 'IO').
unsafeCryptoRNGState :: MonadIO m => ByteString -> m CryptoRNGState
unsafeCryptoRNGState seed = liftIO $ case newGen seed of
  Left err  -> fail (show err)
  Right gen -> fmap CryptoRNGState (newMVar gen)
-- | Generate the requested number of random bytes from the shared state.
--
-- The generator is taken out of the 'MVar' with 'modifyMVar', so concurrent
-- callers are serialised and each one stores back the advanced generator.
-- If 'genBytes' reports an error, an 'IOError' (via 'fail') is thrown that
-- includes the reported error. 'modifyMVar' then restores the previous
-- generator, so the state is left unchanged.
randomBytesIO :: ByteLength
              -> CryptoRNGState
              -> IO ByteString
randomBytesIO n (CryptoRNGState gv) = modifyMVar gv $ \g ->
  case genBytes n g of
    -- Fail in IO itself and keep the error detail. The previous code used
    -- the fail of the function monad, which discarded the error and does
    -- not exist under MonadFail.
    Left err       -> fail $ "Crypto.GlobalRandom.genBytes: " ++ show err
    Right (bs, g') -> return (g', bs)
-- | Generate a value uniformly distributed over the inclusive range spanned
-- by the two bounds. The bounds may be given in either order.
--
-- The smallest number of bytes whose value space covers the range is drawn
-- and read as a big-endian unsigned integer. Draws landing in the incomplete
-- final block of that space are rejected and retried, because reducing them
-- modulo the range would make low values more likely. Every result is
-- therefore equally likely. More than half of all draws are accepted, so
-- fewer than two draws are needed on average. A range holding a single value
-- consumes no randomness.
randomR :: (CryptoRNG m, Integral a) => (a, a) -> m a
randomR (bound1, bound2)
  | range == 1 = return (fromIntegral lo)
  | otherwise  = do
      sample <- draw
      return . fromIntegral $ lo + sample `mod` range
  where
    lo, hi, range, total, limit :: Integer
    lo = min (toInteger bound1) (toInteger bound2)
    hi = max (toInteger bound1) (toInteger bound2)
    range = hi - lo + 1
    -- Smallest n with 256^n >= range, computed exactly in Integer
    -- arithmetic rather than with floating-point logarithms.
    byteLen = length . takeWhile (< range) $ iterate (* 256) 1
    total = 256 ^ byteLen
    -- Largest multiple of range not exceeding total. Samples below it map
    -- onto every residue the same number of times.
    limit = total - total `mod` range
    draw = do
      bs <- randomBytes byteLen
      let x = foldl' (\acc w -> acc `shiftL` 8 .|. toInteger w) 0 (unpack bs)
      if x < limit then return x else draw
-- | Draw a value from the complete 'minBound' to 'maxBound' range of a
-- bounded integral type, by delegating to 'randomR'.
boundedIntegralRandom :: forall m a. (CryptoRNG m, Integral a, Bounded a) => m a
boundedIntegralRandom = randomR fullRange
  where
    fullRange = (minBound, maxBound) :: (a, a)
-- | Types whose values can be generated directly from a 'CryptoRNG'.
--
-- All instances below are bounded integral types. Each one draws from its
-- full 'minBound' to 'maxBound' range via 'boundedIntegralRandom'.
class Random a where
  random :: CryptoRNG m => m a

-- Int8 mirrors the existing Word8 instance; it was previously missing.
instance Random Int8 where
  random = boundedIntegralRandom

instance Random Int16 where
  random = boundedIntegralRandom

instance Random Int32 where
  random = boundedIntegralRandom

instance Random Int64 where
  random = boundedIntegralRandom

instance Random Int where
  random = boundedIntegralRandom

instance Random Word8 where
  random = boundedIntegralRandom

instance Random Word16 where
  random = boundedIntegralRandom

instance Random Word32 where
  random = boundedIntegralRandom

instance Random Word64 where
  random = boundedIntegralRandom

instance Random Word where
  random = boundedIntegralRandom
-- | Representation underlying 'CryptoRNGT': a reader over the generator
-- state.
type InnerCryptoRNGT = ReaderT CryptoRNGState

-- | Monad transformer that threads a 'CryptoRNGState' through the wrapped
-- computation. Every listed class is derived from the underlying reader
-- representation.
newtype CryptoRNGT m a = CryptoRNGT { unCryptoRNGT :: InnerCryptoRNGT m a }
  deriving
    ( Alternative
    , Applicative
    , Functor
    , Monad
    , MonadBase b
    , MonadCatch
    , MonadError e
    , MonadIO
    , MonadMask
    , MonadPlus
    , MonadThrow
    , MonadTrans
    )
-- | Transform the underlying computation of a 'CryptoRNGT' while keeping
-- the same generator state.
mapCryptoRNGT :: (m a -> n b) -> CryptoRNGT m a -> CryptoRNGT n b
mapCryptoRNGT g action = withCryptoRNGState (g . flip runCryptoRNGT action)
-- | Run a 'CryptoRNGT' computation, supplying the given generator state.
runCryptoRNGT :: CryptoRNGState -> CryptoRNGT m a -> m a
runCryptoRNGT st = flip runReaderT st . unCryptoRNGT
-- | Build a 'CryptoRNGT' computation from a function of the generator state.
withCryptoRNGState :: (CryptoRNGState -> m a) -> CryptoRNGT m a
withCryptoRNGState f = CryptoRNGT (ReaderT f)
-- | Control-operation lifting for 'CryptoRNGT', delegated to the instance of
-- its 'InnerCryptoRNGT' representation by wrapping and unwrapping the
-- newtype.
instance MonadTransControl CryptoRNGT where
  -- The captured state is exactly that of the underlying reader layer.
  type StT CryptoRNGT a = StT InnerCryptoRNGT a
  liftWith = defaultLiftWith CryptoRNGT unCryptoRNGT
  restoreT = defaultRestoreT CryptoRNGT
  {-# INLINE liftWith #-}
  {-# INLINE restoreT #-}
-- | Base-monad control lifting for 'CryptoRNGT'. It composes the
-- 'MonadTransControl' instance of 'CryptoRNGT' with the instance of the
-- inner monad @m@, using the monad-control default helpers.
instance MonadBaseControl b m => MonadBaseControl b (CryptoRNGT m) where
  type StM (CryptoRNGT m) a = ComposeSt CryptoRNGT m a
  liftBaseWith = defaultLiftBaseWith
  restoreM     = defaultRestoreM
  {-# INLINE liftBaseWith #-}
  {-# INLINE restoreM #-}
-- | Bytes are drawn from the 'CryptoRNGState' carried by the transformer;
-- the generator step itself runs in 'IO' and is lifted with 'liftIO'.
instance {-# OVERLAPPABLE #-} MonadIO m => CryptoRNG (CryptoRNGT m) where
  randomBytes n = do
    st <- CryptoRNGT ask
    liftIO (randomBytesIO n st)