{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE UnicodeSyntax #-}
module Data.ByteString.Random.Internal
( generate
, RandomWords(..)
) where
import Control.Exception (bracketOnError, mask, onException)
import Data.ByteString (ByteString)
import Data.ByteString.Unsafe (unsafePackAddressLen, unsafePackMallocCStringLen)
import Data.Word (Word8, Word64)
import Foreign (mallocBytes, poke, plusPtr, free, castPtr)
import GHC.Ptr (Ptr(..))
import Numeric.Natural (Natural)
-- | Interface to a source of words, parameterised over the generator type
-- @g@. 'generate' uses it to fill a byte buffer.
--
-- NOTE(review): the method names suggest uniformly distributed random output,
-- but nothing in this module enforces that; it is up to each instance.
class RandomWords g where
    -- | Produce one 'Word8' from the generator, in 'IO'.
    -- 'generate' calls this once per trailing byte that does not fill a
    -- whole 8-byte word.
    uniformW8 ∷ g → IO Word8
    -- | Produce one 'Word64' from the generator, in 'IO'.
    -- 'generate' calls this once per full 8-byte chunk of the output.
    uniformW64 ∷ g → IO Word64
-- | Build a strict 'ByteString' of exactly @n@ bytes from the output of the
-- generator @g@.
--
-- The buffer is filled eight bytes at a time with 'uniformW64'; the remaining
-- @n `mod` 8@ bytes are filled one at a time with 'uniformW8'.
--
-- The buffer is allocated with 'mallocBytes' and ownership is transferred to
-- the result via 'unsafePackMallocCStringLen', so it is released with @free@
-- when the 'ByteString' is garbage collected. If filling the buffer throws,
-- the buffer is freed and the exception is rethrown.
--
-- Throws an 'IOError' if @n@ does not fit in an 'Int'. A request for zero
-- bytes returns the empty 'ByteString' without allocating.
generate
    ∷ RandomWords g
    ⇒ g
    → Natural
    → IO ByteString
generate g n
    -- 'fromIntegral' from 'Natural' to 'Int' wraps silently; reject sizes
    -- that cannot be represented instead of producing a wrong-sized result.
    | n > fromIntegral (maxBound ∷ Int) =
        ioError $ userError "generate: requested length does not fit in an Int"
    -- Avoid a zero-sized malloc, whose result is implementation-defined.
    | n == 0 = pure mempty
    -- Asynchronous exceptions are masked around the allocation and the
    -- ownership hand-off, so the buffer is either freed here or owned by the
    -- resulting ByteString. Only the fill loop runs with the caller's masking
    -- state restored.
    | otherwise = mask $ \restore → do
        ptr ← mallocBytes len8
        restore ({-# SCC "go" #-} go ptr) `onException` free ptr
        -- Attaches a @free@ finalizer; from here on the ByteString owns the
        -- buffer.
        {-# SCC "pack" #-} unsafePackMallocCStringLen (castPtr ptr, len8)
  where
    -- Total size in bytes, and the number of whole 64-bit words in it.
    len8, len64 ∷ Int
    !len8 = fromIntegral n
    !len64 = len8 `div` 8

    -- Fill the buffer starting at the given pointer: whole words first,
    -- then the trailing bytes.
    go ∷ Ptr Word64 → IO ()
    go !startPtr = loop64 startPtr
      where
        -- One past the last whole 64-bit word.
        fin64Ptr ∷ Ptr Word64
        !fin64Ptr = startPtr `plusPtr` (len64 * 8)

        loop64 ∷ Ptr Word64 → IO ()
        loop64 !curPtr
            | curPtr < fin64Ptr = {-# SCC "loop64" #-} do
                !b ← uniformW64 g
                {-# SCC "poke64" #-} poke curPtr b
                loop64 $ {-# SCC "ptr_inc" #-} curPtr `plusPtr` 8
            | otherwise = loop8 $ castPtr curPtr

        -- One past the last byte of the buffer.
        fin8Ptr ∷ Ptr Word8
        !fin8Ptr = startPtr `plusPtr` len8

        loop8 ∷ Ptr Word8 → IO ()
        loop8 !curPtr
            | curPtr < fin8Ptr = {-# SCC "loop8" #-} do
                !b ← uniformW8 g
                poke curPtr b
                loop8 $ curPtr `plusPtr` 1
            | otherwise = return ()
{-# INLINEABLE generate #-}