{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif
module System.Random.SplitMix (
SMGen,
nextWord64,
nextWord32,
nextTwoWord32,
nextInt,
nextDouble,
nextFloat,
nextInteger,
splitSMGen,
bitmaskWithRejection32,
bitmaskWithRejection32',
bitmaskWithRejection64,
bitmaskWithRejection64',
mkSMGen,
initSMGen,
newSMGen,
seedSMGen,
seedSMGen',
unseedSMGen,
) where
import Data.Bits (complement, shiftL, shiftR, xor, (.&.), (.|.))
import Data.Bits.Compat (countLeadingZeros, popCount, zeroBits)
import Data.IORef (IORef, atomicModifyIORef, newIORef)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.Word (Word32, Word64)
import System.IO.Unsafe (unsafePerformIO)
#if defined(__HUGS__) || !MIN_VERSION_base(4,8,0)
import Data.Word (Word)
#endif
#ifndef __HUGS__
import Control.DeepSeq (NFData (..))
#endif
#if !__GHCJS__
import System.CPUTime (cpuTimePrecision, getCPUTime)
#endif
-- | SplitMix generator state: a 64-bit seed and a 64-bit increment
-- (\"gamma\"), both strict.
--
-- Only the type is exported. Every way this module builds a generator
-- ('seedSMGen', 'mkSMGen', 'splitSMGen', the 'Read' instance) yields an
-- odd gamma.
data SMGen
    = SMGen !Word64 !Word64 -- seed, gamma
  deriving (Show)
#ifndef __HUGS__
-- | Both fields of 'SMGen' are strict 'Word64's, so reducing the value to
-- weak head normal form already evaluates it completely.
instance NFData SMGen where
    rnf g = g `seq` ()
#endif
-- | Parses the form produced by the derived 'Show' instance, e.g.
-- @SMGen 1 3@ (optionally parenthesised).
--
-- A candidate whose gamma is even yields no parse, so any generator
-- obtained through 'read' has an odd gamma.
instance Read SMGen where
    readsPrec d = readParen (d > 10) parseGen
      where
        parseGen s0 = do
            ("SMGen", s1) <- lex s0
            (seed, s2)    <- readsPrec 11 s1
            (gamma, s3)   <- readsPrec 11 s2
            if odd gamma
                then [(SMGen seed gamma, s3)]
                else []
-- | Produce a 'Word64' together with the successor generator.
--
-- The seed is advanced by gamma using 'plus' (addition modulo 2^64). The
-- output is 'mix64' of the advanced seed, and the successor keeps the
-- same gamma.
nextWord64 :: SMGen -> (Word64, SMGen)
nextWord64 (SMGen seed gamma) =
    let advanced = seed `plus` gamma
    in  (mix64 advanced, SMGen advanced gamma)
-- | Produce a 'Word32' from the low 32 bits of one 'nextWord64' draw.
nextWord32 :: SMGen -> (Word32, SMGen)
nextWord32 g =
    let (w64, g') = nextWord64 g
#ifdef __HUGS__
    in  (fromIntegral (w64 .&. 0xffffffff), g')
#else
    in  (fromIntegral w64, g')
#endif
-- | Produce two 'Word32's from a single 'nextWord64' draw.
--
-- The first component is the high 32 bits of the draw and the second is
-- the low 32 bits.
nextTwoWord32 :: SMGen -> (Word32, Word32, SMGen)
nextTwoWord32 g =
    let (w64, g') = nextWord64 g
        hi        = fromIntegral (w64 `shiftR` 32) :: Word32
#ifdef __HUGS__
        lo        = fromIntegral (w64 .&. 0xffffffff) :: Word32
#else
        lo        = fromIntegral w64 :: Word32
#endif
    in  (hi, lo, g')
-- | Produce an 'Int' from one 'nextWord64' draw.
--
-- On non-Hugs compilers the 64-bit draw is converted with 'fromIntegral'.
-- Under Hugs the high 32 bits of the draw are used instead.
nextInt :: SMGen -> (Int, SMGen)
nextInt g = case nextWord64 g of
    (w64, g') -> (toInt w64, g')
  where
    toInt :: Word64 -> Int
#ifdef __HUGS__
    toInt w = fromIntegral (w `shiftR` 32)
#else
    toInt = fromIntegral
#endif
-- | Produce a 'Double' in the half-open interval @[0, 1)@.
--
-- The top 53 bits of a 'nextWord64' draw are scaled by 'doubleUlp'
-- (2^-53).
nextDouble :: SMGen -> (Double, SMGen)
nextDouble g = case nextWord64 g of
    (w64, g') ->
        let top53 = w64 `shiftR` 11
        in  (fromIntegral top53 * doubleUlp, g')
-- | Produce a 'Float' in the half-open interval @[0, 1)@.
--
-- The top 24 bits of a 'nextWord32' draw are scaled by 'floatUlp'
-- (2^-24).
nextFloat :: SMGen -> (Float, SMGen)
nextFloat g = case nextWord32 g of
    (w32, g') ->
        let top24 = w32 `shiftR` 8
        in  (fromIntegral top24 * floatUlp, g')
-- | Produce an 'Integer' in the closed interval spanned by the two bounds.
--
-- The bounds may be given in either order. When they are equal, that
-- value is returned and the generator is left unchanged.
nextInteger :: Integer -> Integer -> SMGen -> (Integer, SMGen)
nextInteger lo hi g
    | lo == hi  = (lo, g)
    | otherwise =
        let (offset, width) = if lo < hi then (lo, hi - lo) else (hi, lo - hi)
            (i, g')         = nextInteger' width g
        in  (i + offset, g')
-- | Produce an 'Integer' in the closed range @[0, range]@ by rejection
-- sampling.
--
-- @range@ must be non-negative: every candidate is >= 0, so a negative
-- @range@ would reject forever. 'nextInteger' only passes a positive
-- width.
nextInteger' :: Integer -> SMGen -> (Integer, SMGen)
nextInteger' range = loop
  where
    -- View 'range' as base-2^64 digits. 'restDigits' counts the digits
    -- below the most significant one. 'leadMask' is the smallest
    -- all-ones mask covering the most significant digit.
    leadMask :: Word64
    restDigits :: Word
    (leadMask, restDigits) = go 0 range where
        go :: Word -> Integer -> (Word64, Word)
        go n x | x < two64 = (complement zeroBits `shiftR` countLeadingZeros (fromInteger x :: Word64), n)
               | otherwise = go (n + 1) (x `shiftR` 64)

    -- Draw one candidate. The first 'nextWord64' draw is masked with
    -- 'leadMask' and becomes the most significant digit. Then
    -- 'restDigits' further full 64-bit draws are appended, each shifting
    -- the accumulator left by 64 bits.
    generate :: SMGen -> (Integer, SMGen)
    generate g0 =
        let (x, g') = nextWord64 g0
            x' = x .&. leadMask
        in go (fromIntegral x') restDigits g'
      where
        go :: Integer -> Word -> SMGen -> (Integer, SMGen)
        -- Force the accumulator before returning it.
        go acc 0 g = acc `seq` (acc, g)
        go acc n g =
            let (x, g') = nextWord64 g
            in go (acc * two64 + fromIntegral x) (n - 1) g'

    -- Rejection step: a candidate above 'range' is discarded and a new
    -- one is drawn from the advanced generator.
    loop g = let (x, g') = generate g
             in if x > range
                then loop g'
                else (x, g')
-- | 2^64: one past the largest 'Word64', used as the radix when an
-- 'Integer' is assembled from 'Word64' digits.
two64 :: Integer
two64 = toInteger (maxBound :: Word64) + 1
-- | Split a generator into two.
--
-- The seed is advanced twice by gamma.
--
-- * The first result keeps the original gamma and continues from the
--   twice-advanced seed.
-- * The second result is seeded with 'mix64' of the once-advanced seed.
--   It takes a fresh gamma from 'mixGamma' of the twice-advanced seed.
splitSMGen :: SMGen -> (SMGen, SMGen)
splitSMGen (SMGen seed gamma) =
    let once  = seed `plus` gamma
        twice = once `plus` gamma
    in  (SMGen twice gamma, SMGen (mix64 once) (mixGamma twice))
-- | Increment added to the user seed by 'mkSMGen' before it is passed to
-- 'mixGamma'. In hex this is 0x9e3779b97f4a7c15.
goldenGamma :: Word64
goldenGamma = 11400714819323198485
-- | 2^-24: the factor 'nextFloat' applies to a 24-bit integer to land in
-- @[0, 1)@.
floatUlp :: Float
floatUlp = encodeFloat 1 (-24)
-- | 2^-53: the factor 'nextDouble' applies to a 53-bit integer to land in
-- @[0, 1)@.
doubleUlp :: Double
doubleUlp = encodeFloat 1 (-53)
-- | Output mixing function applied to each advanced seed.
--
-- It performs xor-shift-33 and multiply by 0xff51afd7ed558ccd, then
-- xor-shift-33 and multiply by 0xc4ceb9fe1a85ec53, then a final
-- xor-shift-33. These are the shifts and constants of the MurmurHash3
-- 64-bit finalizer.
mix64 :: Word64 -> Word64
mix64 =
    shiftXor 33
  . shiftXorMultiply 33 0xc4ceb9fe1a85ec53
  . shiftXorMultiply 33 0xff51afd7ed558ccd
-- | Alternative 64-bit mixer used by 'mixGamma'.
--
-- It performs xor-shift-30 and multiply by 0xbf58476d1ce4e5b9, then
-- xor-shift-27 and multiply by 0x94d049bb133111eb, then a final
-- xor-shift-31.
mix64variant13 :: Word64 -> Word64
mix64variant13 z =
    shiftXor 31 $
    shiftXorMultiply 27 0x94d049bb133111eb $
    shiftXorMultiply 30 0xbf58476d1ce4e5b9 z
-- | Derive a gamma (increment) for a new generator from a 64-bit value.
--
-- The candidate is 'mix64variant13' of the input with its low bit set, so
-- it is always odd. The function then computes the popCount of the
-- candidate xor (candidate shifted right by 1), which roughly counts
-- 0/1 transitions between neighbouring bits. If that count is below 24,
-- the candidate is xor'ed with 0xaaaaaaaaaaaaaaaa, whose low bit is 0, so
-- oddness is preserved.
mixGamma :: Word64 -> Word64
mixGamma z0
    | transitions >= 24 = candidate
    | otherwise         = candidate `xor` 0xaaaaaaaaaaaaaaaa
  where
    candidate   = mix64variant13 z0 .|. 1
    transitions = popCount (candidate `xor` (candidate `shiftR` 1))
-- | @shiftXor n w@ xors @w@ with itself logically shifted right by @n@
-- bits.
shiftXor :: Int -> Word64 -> Word64
shiftXor n w = xor w (shiftR w n)
-- | 'shiftXor' by @n@ bits, then multiplication by @k@ using 'mult'
-- (modulo 2^64).
shiftXorMultiply :: Int -> Word64 -> Word64 -> Word64
shiftXorMultiply n k = (`mult` k) . shiftXor n
-- | Produce a 'Word32' in the half-open range @[0, n)@.
--
-- Calls 'error' when @n@ is 0, since that range is empty. Otherwise this
-- delegates to 'bitmaskWithRejection32'' with the inclusive bound
-- @n - 1@.
bitmaskWithRejection32 :: Word32 -> SMGen -> (Word32, SMGen)
bitmaskWithRejection32 n
    | n == 0    = error "bitmaskWithRejection32 0"
    | otherwise = bitmaskWithRejection32' (n - 1)
-- | Produce a 'Word64' in the half-open range @[0, n)@.
--
-- Calls 'error' when @n@ is 0, since that range is empty. Otherwise this
-- delegates to 'bitmaskWithRejection64'' with the inclusive bound
-- @n - 1@.
bitmaskWithRejection64 :: Word64 -> SMGen -> (Word64, SMGen)
bitmaskWithRejection64 n
    | n == 0    = error "bitmaskWithRejection64 0"
    | otherwise = bitmaskWithRejection64' (n - 1)
-- | Produce a 'Word32' in the closed range @[0, range]@.
--
-- Each 'nextWord32' draw is masked to the smallest all-ones bit pattern
-- covering @range@. Masked values above @range@ are rejected and redrawn
-- from the advanced generator.
bitmaskWithRejection32' :: Word32 -> SMGen -> (Word32, SMGen)
bitmaskWithRejection32' range = loop
  where
    -- 'range .|. 1' keeps the shift amount below 32: countLeadingZeros 0
    -- would be the full bit width.
    mask :: Word32
    mask = complement zeroBits `shiftR` countLeadingZeros (range .|. 1)

    loop g =
        let (w, g')   = nextWord32 g
            candidate = w .&. mask
        in  if candidate <= range then (candidate, g') else loop g'
-- | Produce a 'Word64' in the closed range @[0, range]@.
--
-- Each 'nextWord64' draw is masked to the smallest all-ones bit pattern
-- covering @range@. Masked values above @range@ are rejected and redrawn
-- from the advanced generator.
bitmaskWithRejection64' :: Word64 -> SMGen -> (Word64, SMGen)
bitmaskWithRejection64' range = go where
    -- 'range .|. 1' keeps the shift amount below 64. For range == 0,
    -- countLeadingZeros would return 64, and Data.Bits leaves 'shiftR' by
    -- the full bit width unspecified. This matches the mask computed in
    -- bitmaskWithRejection32'.
    mask :: Word64
    mask = complement zeroBits `shiftR` countLeadingZeros (range .|. 1)
    go g = let (x, g') = nextWord64 g
               x' = x .&. mask
           in if x' > range
              then go g'
              else (x', g')
-- | Build a generator from an explicit seed and gamma.
--
-- The lowest bit of the gamma is set, so an even gamma becomes the next
-- odd value.
seedSMGen
    :: Word64 -- ^ seed
    -> Word64 -- ^ gamma (forced odd)
    -> SMGen
seedSMGen seed gamma = SMGen seed oddGamma
  where
    oddGamma = gamma .|. 1
-- | Tupled form of 'seedSMGen': takes @(seed, gamma)@.
seedSMGen' :: (Word64, Word64) -> SMGen
seedSMGen' (seed, gamma) = seedSMGen seed gamma
-- | Expose the raw @(seed, gamma)@ state of a generator.
unseedSMGen :: SMGen -> (Word64, Word64)
unseedSMGen g = case g of
    SMGen seed gamma -> (seed, gamma)
-- | Build a generator from a single 64-bit value.
--
-- The seed is 'mix64' of the value. The gamma is 'mixGamma' of the value
-- plus 'goldenGamma' (modulo 2^64).
mkSMGen :: Word64 -> SMGen
mkSMGen s = SMGen seed gamma
  where
    seed  = mix64 s
    gamma = mixGamma (s `plus` goldenGamma)
-- | Create a generator seeded from the clock via 'mkSeedTime'.
initSMGen :: IO SMGen
initSMGen = do
    seed <- mkSeedTime
    return (mkSMGen seed)
-- | Obtain a fresh generator by atomically splitting the process-wide
-- generator 'theSMGen'.
--
-- The global keeps the first half of the split, and the caller receives
-- the second half.
newSMGen :: IO SMGen
newSMGen = atomicModifyIORef theSMGen $ \current ->
    let (kept, handedOut) = splitSMGen current
    in  (kept, handedOut)
-- | Process-wide generator shared by all 'newSMGen' calls. It is seeded
-- from 'initSMGen' when first demanded.
--
-- 'unsafePerformIO' creates a top-level 'IORef'. The NOINLINE pragma
-- keeps GHC from duplicating the definition, which would create more
-- than one ref.
theSMGen :: IORef SMGen
theSMGen = unsafePerformIO (newIORef =<< initSMGen)
{-# NOINLINE theSMGen #-}
-- | Build a 64-bit seed from clock readings.
--
-- * Low 32 bits: the POSIX time, truncated to a 'Word32'.
-- * High 32 bits: 'getCPUTime' divided by 'cpuTimePrecision', converted
--   to a 'Word32'.
--
-- Under GHCJS the CPU clock is not consulted and the high half repeats
-- the low half.
mkSeedTime :: IO Word64
mkSeedTime = do
    now <- getPOSIXTime
    let lo = truncate now :: Word32
#if __GHCJS__
        hi = lo
#else
    cpu <- getCPUTime
    let hi = fromIntegral (cpu `div` cpuTimePrecision) :: Word32
#endif
    return (combine hi lo)
  where
    combine :: Word32 -> Word32 -> Word64
    combine hi lo = (fromIntegral hi `shiftL` 32) .|. fromIntegral lo
-- | Multiplication and addition on 'Word64', modulo 2^64.
--
-- On non-Hugs compilers the built-in 'Word64' operators are used
-- directly. Under Hugs the result is computed in 'Integer' and reduced
-- modulo 2^64 explicitly.
mult, plus :: Word64 -> Word64 -> Word64
#ifndef __HUGS__
mult x y = x * y
plus x y = x + y
#else
mult x y = wrap64 (toInteger x * toInteger y)
plus x y = wrap64 (toInteger x + toInteger y)

-- | Reduce an 'Integer' modulo 2^64 and convert it to 'Word64'.
wrap64 :: Integer -> Word64
wrap64 i = fromInteger (i `mod` 18446744073709551616)
#endif