{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
#if __GLASGOW_HASKELL__ >= 707
{-# LANGUAGE RoleAnnotations #-}
#endif
module System.Random.PCG.Pure
(
Gen, GenIO, GenST
, create, createSystemRandom, initialize, withSystemRandom
, Variate (..)
, advance, retract
, FrozenGen, save, restore, seed, initFrozen
, uniformW8, uniformW16, uniformW32, uniformW64
, uniformI8, uniformI16, uniformI32, uniformI64
, uniformF, uniformD, uniformBool
, uniformRW8, uniformRW16, uniformRW32, uniformRW64
, uniformRI8, uniformRI16, uniformRI32, uniformRI64
, uniformRF, uniformRD, uniformRBool
, uniformBW8, uniformBW16, uniformBW32, uniformBW64
, uniformBI8, uniformBI16, uniformBI32, uniformBI64
, uniformBF, uniformBD, uniformBBool
, SetSeq
, next'
, advanceSetSeq
) where
import Control.Monad.Primitive
import Data.Bits
import Data.Data
import Data.Primitive.ByteArray
import Foreign
import GHC.Generics
import System.Random.PCG.Class
import System.Random
-- | Mutable generator specialised to the 'IO' state token.
type GenIO = Gen RealWorld

-- | Alias of 'Gen'. Because @PrimState (ST s) ~ s@, @GenST s@ is the
-- generator type usable from @ST s@.
type GenST = Gen

-- | Mutable PCG32 generator. The byte array holds two 'Word64' slots:
-- element 0 is the current LCG state and element 1 is the stream increment
-- (this layout is what 'save' reads and 'restore' writes).
newtype Gen s = G (MutableByteArray s)

-- | Immutable snapshot of a generator; simply an alias for 'SetSeq'.
type FrozenGen = SetSeq
-- | Pure PCG32 generator state: the 64-bit LCG state followed by the
-- increment that selects the stream. Generators built with 'initFrozen'
-- always have an odd increment.
data SetSeq = SetSeq
  {-# UNPACK #-} !Word64 -- ^ current LCG state
  {-# UNPACK #-} !Word64 -- ^ stream increment
  deriving (Show, Ord, Eq, Data, Typeable, Generic)
-- | Marshals a 'SetSeq' as two consecutive 'Word64's: the state at offset 0
-- and the increment at offset 1 (16 bytes, 8-byte aligned).
instance Storable SetSeq where
  sizeOf _ = 16
  {-# INLINE sizeOf #-}
  alignment _ = 8
  {-# INLINE alignment #-}
  poke p (SetSeq st inc) = do
    let w = castPtr p :: Ptr Word64
    pokeElemOff w 0 st
    pokeElemOff w 1 inc
  {-# INLINE poke #-}
  peek p =
    let w = castPtr p :: Ptr Word64
    in SetSeq <$> peekElemOff w 0 <*> peekElemOff w 1
  {-# INLINE peek #-}
-- | Fixed default generator state used by 'create'.
seed :: SetSeq
seed = SetSeq initialState initialInc
  where
    initialState = 9600629759793949339
    initialInc = 15726070495360670683
-- | Strict result of one generator step: the successor state paired with
-- the 32-bit output drawn during that step (see 'pair').
data Pair = Pair
  {-# UNPACK #-} !Word64 -- ^ successor state
  {-# UNPACK #-} !Word32 -- ^ 32-bit output
-- | 64-bit LCG multiplier (6364136223846793005 in decimal).
multiplier :: Word64
multiplier = 0x5851F42D4C957F2D
-- | Advance the LCG by one step: @inc + multiplier * s@ (mod 2^64).
state :: SetSeq -> Word64
state (SetSeq s inc) = inc + multiplier * s
{-# INLINE state #-}
-- | PCG32 XSH-RR output permutation: xorshift the high bits of the 64-bit
-- state, truncate to 32 bits, then rotate right by an amount taken from the
-- top five bits of the state.
output :: Word64 -> Word32
output s = xorshifted `rotateR` rot
  where
    -- The top 5 bits select a rotation in 0..31.
    rot = fromIntegral (s `shiftR` 59) :: Int
    -- Mix the high bits down, then keep bits 27..58 as the 32-bit payload.
    xorshifted = fromIntegral ((s `xor` (s `shiftR` 18)) `shiftR` 27) :: Word32
{-# INLINE output #-}
-- | One generator step: the successor state together with the output
-- computed from the /current/ (pre-step) state.
pair :: SetSeq -> Pair
pair g@(SetSeq old _) = Pair new (output old)
  where
    new = state g
{-# INLINE pair #-}
-- | Draw a value uniformly from @[0, b)@ by rejection sampling.
--
-- Raw 32-bit outputs below @2^32 `mod` b@ (computed as @negate b `mod` b@)
-- are discarded, so the final @`mod` b@ is unbiased. Returns the state after
-- the accepted draw together with the reduced value.
--
-- Precondition: @b /= 0@; with @b == 0@ the @mod@ throws a divide-by-zero
-- exception.
bounded :: Word32 -> SetSeq -> Pair
bounded b (SetSeq s0 inc) = loop s0
  where
    threshold = negate b `mod` b
    loop !cur =
      let Pair nxt r = pair (SetSeq cur inc)
      in if r < threshold
           then loop nxt
           else Pair nxt (r `mod` b)
{-# INLINE bounded #-}
-- | Jump a 64-bit LCG forward by @d0@ steps in at most 64 iterations.
--
-- Returns the state reached by applying @x -> x * m0 + p0@ (mod 2^64)
-- @d0@ times starting from @s@. One LCG step is an affine map, and a
-- composition of affine maps is again affine. The loop therefore squares the
-- per-step map while scanning the bits of @d0@ from least significant
-- upwards, and folds in the current power whenever a bit is set.
--
-- All arithmetic wraps modulo 2^64. For a full-period LCG (odd increment and
-- @m0 `mod` 4 == 1@), advancing by @negate d@ therefore undoes advancing
-- by @d@.
advancing
  :: Word64 -- ^ number of steps @d0@
  -> Word64 -- ^ starting state @s@
  -> Word64 -- ^ LCG multiplier @m0@
  -> Word64 -- ^ LCG increment @p0@
  -> Word64 -- ^ state after @d0@ steps
advancing d0 s m0 p0 = go d0 m0 p0 1 0
  where
    -- Invariant: (am, ap) is the affine map for the low bits of d0 consumed so
    -- far, and (cm, cp) is the map for 2^k steps, where k bits have been
    -- consumed. Bang patterns keep the accumulators from building thunks.
    go !d !cm !cp !am !ap
      | d == 0    = am * s + ap -- Word64 is unsigned: d == 0 is the only stop
      | odd d     = go d' cm' cp' (am * cm) (ap * cm + cp)
      | otherwise = go d' cm' cp' am ap
      where
        -- Squaring x -> cm*x + cp gives x -> cm^2*x + (cm+1)*cp.
        cm' = cm * cm
        cp' = (cm + 1) * cp
        d'  = d `div` 2
-- | Jump a frozen generator forward by @d@ steps; the increment (stream) is
-- left unchanged.
advanceSetSeq :: Word64 -> FrozenGen -> FrozenGen
advanceSetSeq d (SetSeq s inc) = SetSeq jumped inc
  where
    jumped = advancing d s multiplier inc
-- | Like 'advanceSetSeq' but returns only the new 64-bit state word.
advanceSetSeq' :: Word64 -> FrozenGen -> Word64
advanceSetSeq' d g = case g of
  SetSeq s inc -> advancing d s multiplier inc
-- | Build a frozen generator from a seed @a@ and a stream selector @b@.
--
-- The increment is @2b + 1@, which is always odd (the top bit of @b@ is
-- shifted out). The state is one LCG step taken from @a + increment@.
start :: Word64 -> Word64 -> SetSeq
start a b = SetSeq (state (SetSeq (a + inc) inc)) inc
  where
    inc = 1 .|. (b `shiftL` 1)
{-# INLINE start #-}
-- | Pure single step: a 32-bit output and the advanced generator.
next' :: SetSeq -> (Word32, SetSeq)
next' g@(SetSeq _ inc) =
  let Pair s' r = pair g
  in (r, SetSeq s' inc)
{-# INLINE next' #-}
-- | Snapshot a mutable generator: reads the state (slot 0), then the
-- increment (slot 1).
save :: PrimMonad m => Gen (PrimState m) -> m SetSeq
save (G a) = SetSeq <$> readByteArray a 0 <*> readByteArray a 1
{-# INLINE save #-}
-- | Allocate a fresh 16-byte mutable generator holding the given frozen
-- state: the state goes in slot 0 and the increment in slot 1.
restore :: PrimMonad m => FrozenGen -> m (Gen (PrimState m))
restore (SetSeq s inc) = do
  arr <- newByteArray 16
  sequence_ [writeByteArray arr 0 s, writeByteArray arr 1 inc]
  return $! G arr
{-# INLINE restore #-}
-- | Seed a frozen generator from an initial state and a stream selector.
initFrozen :: Word64 -> Word64 -> SetSeq
initFrozen a b = start a b
-- | Create a mutable generator initialised from the fixed default 'seed'.
-- Each call allocates a new array, so each result is an independent copy
-- that starts the same stream.
create :: PrimMonad m => m (Gen (PrimState m))
create = restore seed
-- | Create a mutable generator from an initial state and a stream selector
-- (see 'initFrozen').
initialize :: PrimMonad m => Word64 -> Word64 -> m (Gen (PrimState m))
initialize a = restore . initFrozen a
-- | Seed a generator from two system-provided random words (state first,
-- then stream), then run the action with it.
withSystemRandom :: (GenIO -> IO a) -> IO a
withSystemRandom f =
  sysRandom >>= \a ->
    sysRandom >>= \b ->
      initialize a b >>= f
-- | Return a generator seeded from system randomness.
createSystemRandom :: IO GenIO
createSystemRandom = withSystemRandom pure
-- | Jump a mutable generator forward by @u@ steps without generating the
-- intermediate outputs. Only the state slot is rewritten.
advance :: PrimMonad m => Word64 -> Gen (PrimState m) -> m ()
advance u g@(G a) = save g >>= writeByteArray a 0 . advanceSetSeq' u
{-# INLINE advance #-}
-- | Move a mutable generator back by @u@ steps by advancing @negate u@
-- (i.e. @2^64 - u@) steps. This relies on the underlying LCG having period
-- 2^64, which presumably holds because the increment is odd — confirm for
-- any hand-built 'SetSeq'.
retract :: PrimMonad m => Word64 -> Gen (PrimState m) -> m ()
retract u = advance (negate u)
{-# INLINE retract #-}
-- | Mutable PCG32 generation: reads the state/increment slots and writes the
-- successor state back into slot 0. The increment slot is never modified.
instance (PrimMonad m, s ~ PrimState m) => Generator (Gen s) m where
  -- One step: emit the output of the current state, store its successor.
  uniform1 f (G arr) = do
    cur <- readByteArray arr 0
    inc <- readByteArray arr 1
    let !nxt = cur * multiplier + inc
    writeByteArray arr 0 nxt
    return $! f (output cur)
  {-# INLINE uniform1 #-}
  -- Two steps: emit the outputs of the current state and of its successor,
  -- and store the state two steps ahead.
  uniform2 f (G arr) = do
    s0 <- readByteArray arr 0
    inc <- readByteArray arr 1
    let step x = x * multiplier + inc
    let !s1 = step s0
    let !s2 = step s1
    writeByteArray arr 0 s2
    return $! f (output s0) (output s1)
  {-# INLINE uniform2 #-}
  -- Bounded draw: run the rejection loop on a frozen snapshot, then write
  -- back only the final state.
  uniform1B f b g@(G arr) = do
    frozen <- save g
    case bounded b frozen of
      Pair nxt r -> do
        writeByteArray arr 0 nxt
        return $! f r
  {-# INLINE uniform1B #-}
-- | Pure 'RandomGen' interface over a frozen PCG32 state.
instance RandomGen FrozenGen where
  -- Two consecutive 32-bit outputs are combined into one value.
  next g@(SetSeq _ inc) = (wordsTo64Bit w1 w2, SetSeq s2 inc)
    where
      Pair s1 w1 = pair g
      Pair s2 w2 = pair (SetSeq s1 inc)
  {-# INLINE next #-}
  -- Four outputs are consumed. The first result continues this stream; the
  -- second is a new generator seeded with the outputs as (state, stream).
  split g@(SetSeq _ inc) = (SetSeq s4 inc, child)
    where
      draw st = pair (SetSeq st inc)
      Pair s1 w1 = pair g
      Pair s2 w2 = draw s1
      Pair s3 w3 = draw s2
      Pair s4 w4 = draw s3
      child = start (wordsTo64Bit w1 w2) (wordsTo64Bit w3 w4)
  {-# INLINE split #-}