{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE UnicodeSyntax #-}

-- -------------------------------------------------------------------------- --
-- |
-- Module: Data.ByteString.Random.Internal
-- Copyright: (c) Lars Kuhtz <lakuhtz@gmail.com> 2017
-- License: MIT
-- Maintainer: lakuhtz@gmail.com
-- Stability: experimental

module Data.ByteString.Random.Internal
( generate
, RandomWords(..)
) where

import Control.Exception (bracketOnError)

import Data.ByteString (ByteString)
import Data.ByteString.Unsafe (unsafePackAddressLen, unsafePackMallocCStringLen)
import Data.Word (Word8, Word64)

import Foreign (mallocBytes, poke, plusPtr, free, castPtr)

import GHC.Ptr (Ptr(..))

import Numeric.Natural (Natural)

-- -------------------------------------------------------------------------- --

-- | Sources of uniformly distributed random words. 'generate' draws 64 bit
-- words for the bulk of its output and 8 bit words for the remaining tail.
--
class RandomWords g where
    uniformW8 :: g -> IO Word8
        -- ^ function that generates uniformly distributed random 8 bit words
    uniformW64 :: g -> IO Word64
        -- ^ function that generates uniformly distributed random 64 bit words

-- A type class is used instead of passing the IO functions directly to
-- 'generate' so that a SPECIALIZE pragma can force GHC to inline these
-- methods.

-- -------------------------------------------------------------------------- --

-- | Generates uniformly distributed random bytestrings of length n using the
-- given PRNG.
--
-- The first @n `div` 8@ groups of eight bytes are filled with words from
-- 'uniformW64'; the remaining @n `mod` 8@ bytes are filled with words from
-- 'uniformW8'.
--
-- The result owns its buffer: the memory carries a @free@ finalizer and is
-- released once the 'ByteString' becomes unreachable.
--
-- Throws an 'IOError' if @n@ does not fit into an 'Int'.
--
generate
    :: RandomWords g
    => g
        -- ^ PRNG
    -> Natural
        -- ^ Length of the result bytestring in bytes
    -> IO ByteString
generate g n
    -- fromIntegral :: Natural -> Int silently wraps, so reject lengths that
    -- cannot be represented instead of returning a wrongly sized result.
    | n > fromIntegral (maxBound :: Int) = ioError $ userError
        "Data.ByteString.Random.Internal.generate: requested length exceeds maxBound :: Int"
    -- mallocBytes 0 may throw on platforms where malloc(0) returns NULL.
    | len8 == 0 = return mempty
    | otherwise = bracketOnError (mallocBytes len8) free $ \ptr -> do
        {-# SCC "go" #-} go ptr
        -- Hand ownership of the buffer to the ByteString together with a
        -- free(3) finalizer. (unsafePackAddressLen never frees the memory,
        -- which leaked the whole buffer on every call.)
        {-# SCC "pack" #-} unsafePackMallocCStringLen (castPtr ptr, len8)
  where
    -- Kept lazy so that they are only evaluated after the overflow guard.
    len8, len64 :: Int
    len8 = fromIntegral n
    len64 = len8 `div` 8

    -- Fill the buffer that starts at startPtr with len8 random bytes.
    go :: Ptr Word64 -> IO ()
    go !startPtr = loop64 startPtr
      where
        -- Would it help to add more manual unrolling levels?
        -- How smart is the compiler about unrolling loops?

        -- Generate 64bit values. mallocBytes returns memory that is suitably
        -- aligned for any basic type, so these pokes are aligned.
        fin64Ptr :: Ptr Word64
        !fin64Ptr = startPtr `plusPtr` (len64 * 8)

        loop64 :: Ptr Word64 -> IO ()
        loop64 !curPtr
            | curPtr < fin64Ptr = {-# SCC "loop64" #-} do
                !b <- uniformW64 g
                {-# SCC "poke64" #-} poke curPtr b
                loop64 $ curPtr `plusPtr` 8
            | otherwise = loop8 $ castPtr curPtr

        -- Generate 8bit values for the remaining (len8 `mod` 8) bytes.
        fin8Ptr :: Ptr Word8
        !fin8Ptr = startPtr `plusPtr` len8

        loop8 :: Ptr Word8 -> IO ()
        loop8 !curPtr
            | curPtr < fin8Ptr = {-# SCC "loop8" #-} do
                !b <- uniformW8 g
                poke curPtr b
                loop8 $ curPtr `plusPtr` 1
            | otherwise = return ()
{-# INLINEABLE generate #-}