{-# LANGUAGE CPP               #-}
------------------------------------------------------------------------------
-- | This module contains functionality common among multiple back-ends.
--

module Snap.Snaplet.Session.Common
  ( RNG
  , mkRNG
  , withRNG
  , randomToken
  , mkCSRFToken
  ) where

------------------------------------------------------------------------------
import           Control.Concurrent
import           Control.Monad
import           Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.Text.Encoding as T
import           Data.Text (Text)
import           Numeric
import           System.Random.MWC

#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative
#endif


------------------------------------------------------------------------------
-- | High speed, mutable random number generator state.
--
-- The generator ('GenIO') is kept inside an 'MVar'; 'withRNG' is the
-- accessor that takes it out for the duration of an action.
newtype RNG = RNG (MVar GenIO)

------------------------------------------------------------------------------
-- | Perform the given action with access to the generator held by the
-- 'RNG', mutating the RNG state.
--
-- The generator is taken out of the underlying 'MVar' with 'withMVar', so
-- concurrent callers on the same 'RNG' run one at a time, and the 'MVar' is
-- refilled when the action finishes (including when it throws).
withRNG :: RNG                -- ^ generator state to use
        -> (GenIO -> IO a)    -- ^ action to run against the generator
        -> IO a
withRNG (RNG rng) m = withMVar rng m


------------------------------------------------------------------------------
-- | Create a new 'RNG'.
--
-- A generator is obtained through 'withSystemRandom', placed in a fresh
-- 'MVar', and wrapped in the 'RNG' constructor.
mkRNG :: IO RNG
mkRNG = withSystemRandom (newMVar >=> return . RNG)


------------------------------------------------------------------------------
-- | Generate a random token made of @n@ hexadecimal digits.
--
-- Each digit is produced by drawing an 'Int' from the range @(0,15)@ and
-- rendering it with 'showHex', so the resulting 'ByteString' contains one
-- character per draw. All @n@ draws happen inside a single 'withRNG' call.
-- A non-positive @n@ yields an empty token.
randomToken :: Int -> RNG -> IO ByteString
randomToken n rng = do
    is <- withRNG rng $ \gen -> replicateM n (mk gen)
    return . B.pack . concatMap (flip showHex "") $ is
  where
    -- One draw per output character: a value between 0 and 15 renders as
    -- exactly one hex digit.
    mk :: GenIO -> IO Int
    mk = uniformR (0,15)


------------------------------------------------------------------------------
-- | Generate a randomized CSRF token.
--
-- The token is the output of @'randomToken' 40@ decoded to 'Text'. Since
-- 'randomToken' emits only hexadecimal digits, the UTF-8 decoding step
-- always receives plain ASCII input.
mkCSRFToken :: RNG -> IO Text
mkCSRFToken rng = T.decodeUtf8 <$> randomToken 40 rng