{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData          #-}
{-# LANGUAGE Trustworthy         #-}

module Network.Tox.Network.MonadRandomBytes where

import           Control.Applicative        (Applicative, (<$>))
import           Control.Monad.RWS          (RWST)
import           Control.Monad.Random       (RandT, getRandoms)
import           Control.Monad.Reader       (ReaderT)
import           Control.Monad.State        (StateT)
import           Control.Monad.Trans.Class  (lift)
import           Control.Monad.Writer       (WriterT)
import           Data.Binary                (get)
import           Data.Binary.Get            (Get, getWord16be, getWord32be,
                                             getWord64be, getWord8, runGet)
import           Data.ByteString            (ByteString, pack, unpack)
import           Data.ByteString.Lazy       (fromStrict)
import           Data.Monoid                (Monoid)
import           Data.Proxy                 (Proxy (..))
import           Data.Word                  (Word16, Word32, Word64, Word8)
import           System.Entropy             (getEntropy)
import           System.Random              (RandomGen)


import           Network.Tox.Crypto.Key     (Key)
import qualified Network.Tox.Crypto.Key     as Key
import           Network.Tox.Crypto.KeyPair (KeyPair)
import qualified Network.Tox.Crypto.KeyPair as KeyPair

-- | Monads that can supply random bytes.
--
-- Instances only have to provide 'randomBytes'; 'newKeyPair' has a default
-- built on top of it, which an instance may override.
class (Monad m, Applicative m) => MonadRandomBytes m where
  -- | Request the given number of random bytes.
  randomBytes :: Int -> m ByteString

  -- | Generate a fresh key pair.
  --
  -- The default draws a random secret key with 'randomKey' and derives the
  -- pair from it with 'KeyPair.fromSecretKey'.
  newKeyPair :: m KeyPair
  newKeyPair = KeyPair.fromSecretKey <$> randomKey

-- | Random bytes taken from the generator carried by 'RandT': the first @n@
-- elements of the 'getRandoms' stream, packed into a strict 'ByteString'.
instance (Monad m, Applicative m, RandomGen s) => MonadRandomBytes (RandT s m) where
  randomBytes n = pack . take n <$> getRandoms

-- | Cryptographically secure random bytes from the system source
-- ('getEntropy'). Key pairs come straight from 'KeyPair.newKeyPair' rather
-- than from the class default.
instance MonadRandomBytes IO where
  randomBytes = getEntropy
  newKeyPair = KeyPair.newKeyPair

-- | Lifts both methods from the underlying monad through 'ReaderT'.
instance MonadRandomBytes m => MonadRandomBytes (ReaderT r m) where
  randomBytes = lift . randomBytes
  newKeyPair = lift newKeyPair
-- | Lifts both methods from the underlying monad through 'WriterT'.
instance (Monoid w, MonadRandomBytes m) => MonadRandomBytes (WriterT w m) where
  randomBytes = lift . randomBytes
  newKeyPair = lift newKeyPair
-- | Lifts both methods from the underlying monad through 'StateT'.
instance MonadRandomBytes m => MonadRandomBytes (StateT s m) where
  randomBytes = lift . randomBytes
  newKeyPair = lift newKeyPair
-- | Lifts both methods from the underlying monad through 'RWST'.
instance (Monoid w, MonadRandomBytes m) => MonadRandomBytes (RWST r w s m) where
  randomBytes = lift . randomBytes
  newKeyPair = lift newKeyPair

-- | Draw @len@ random bytes and decode them with the given 'Get' decoder.
--
-- The decoder is run with 'runGet', so @len@ must cover everything the
-- decoder consumes; a decoder that fails or runs out of input results in a
-- runtime error rather than a recoverable failure.
randomBinary :: MonadRandomBytes m => Get a -> Int -> m a
randomBinary g len = runGet g . fromStrict <$> randomBytes len

-- | A random key of any 'Key.CryptoNumber' kind.
--
-- Reads exactly 'Key.encodedByteSize' random bytes for the requested key
-- type and decodes them with the key's 'Data.Binary.Binary' instance. The
-- explicit @forall@ (with @ScopedTypeVariables@) is what lets the 'Proxy'
-- below refer to the @a@ of the signature.
randomKey :: forall m a. (MonadRandomBytes m, Key.CryptoNumber a) => m (Key a)
randomKey = randomBinary get $ Key.encodedByteSize (Proxy :: Proxy a)

-- | A random nonce; this is 'randomKey' specialised to 'Key.Nonce'.
randomNonce :: MonadRandomBytes m => m Key.Nonce
randomNonce = randomKey

-- | A random 'Word64', decoded big-endian from 8 random bytes.
randomWord64 :: MonadRandomBytes m => m Word64
randomWord64 = randomBinary getWord64be 8
-- | A random 'Word32', decoded big-endian from 4 random bytes.
randomWord32 :: MonadRandomBytes m => m Word32
randomWord32 = randomBinary getWord32be 4
-- | A random 'Word16', decoded big-endian from 2 random bytes.
randomWord16 :: MonadRandomBytes m => m Word16
randomWord16 = randomBinary getWord16be 2
-- | A random 'Word8', read from a single random byte.
randomWord8 :: MonadRandomBytes m => m Word8
randomWord8 = randomBinary getWord8 1

-- | Produces an 'Int' uniformly distributed in the range @[0,bound)@.
--
-- Any @bound <= 1@ (including zero and negative values) yields @0@.
--
-- Uses rejection sampling: the smallest number of bits @k@ with
-- @2^k >= bound@ is drawn, and the draw is repeated until the value falls
-- below @bound@. Since @2^(k-1) < bound@, each attempt succeeds with
-- probability greater than one half.
randomInt :: MonadRandomBytes m => Int -> m Int
randomInt bound
  | bound <= 1 = return 0
  | otherwise  = attempt
  where
    -- Bit length of @bound - 1@, i.e. the smallest @k@ with @2^k >= bound@.
    -- Computed with integer arithmetic: going through 'Double' loses
    -- precision for bounds above 2^53.
    numBits :: Int
    numBits = length . takeWhile (> 0) . iterate (`div` 2) $ bound - 1

    -- ceil(numBits / 8). The parentheses matter: `div` binds tighter than
    -- (-), so @numBits - 1 `div` 8@ would just be @numBits@.
    numBytes :: Int
    numBytes = 1 + (numBits - 1) `div` 8

    -- Kept as an 'Integer' so that @2^63@ (needed for bounds above 2^62)
    -- does not overflow.
    limit :: Integer
    limit = 2 ^ numBits

    -- One sampling round; retries itself on rejection.
    attempt = do
      r <- (`mod` limit) . makeInteger . unpack <$> randomBytes numBytes
      if r >= toInteger bound
        then attempt
        else return (fromInteger r)

    -- Little-endian interpretation: the first byte is the least significant.
    makeInteger :: [Word8] -> Integer
    makeInteger = foldr (\w acc -> toInteger w + 256 * acc) 0

-- | Produces an 'Int' uniformly distributed in the range @[low,high]@.
--
-- Implemented as @low@ plus a 'randomInt' draw over the range size
-- @1 + high - low@. When @high <= low@ that size is at most 1, so the result
-- is simply @low@.
randomIntR :: MonadRandomBytes m => (Int,Int) -> m Int
randomIntR (low,high) = (low +) <$> randomInt (1 + high - low)

-- | Produces a uniformly random element of a list.
--
-- Partial: calls 'error' on the empty list. Use 'uniformSafe' when the list
-- may be empty.
uniform :: MonadRandomBytes m => [a] -> m a
uniform [] = error "empty list in uniform"
uniform as = (as !!) <$> randomInt (length as)

-- | Total variant of 'uniform': 'Nothing' for the empty list, otherwise a
-- uniformly random element wrapped in 'Just'.
uniformSafe :: MonadRandomBytes m => [a] -> m (Maybe a)
uniformSafe [] = return Nothing
uniformSafe as = Just <$> uniform as