{-# LANGUAGE BangPatterns #-}
module System.Random.MRG.Internal
(
StateVector(..)
, JumpMatrix(..)
, T6(..)
, vecTrMod
, matMulMod
, matSqrMod
, vecTrModW64
, matMulModW64
, matSqrModW64
) where
import Data.Word (Word64)
import Data.Bits (shiftL, shiftR, (.&.), (.|.))
-- | Six values of the same type packed into one tuple behind a newtype.
-- The 'Functor' instance maps a function over all six components.
newtype T6 a = T6 { fromT6 :: (a,a,a,a,a,a) }
instance Functor T6 where
  -- Unwrap, apply the function componentwise, rewrap. The tuple is matched
  -- strictly, exactly as a constructor pattern in the function head would.
  fmap g t = case fromT6 t of
    (a, b, c, d, e, h) -> T6 (g a, g b, g c, g d, g e, g h)
-- | A three-component vector stored as a single tuple.
data StateVector a = SV (a,a,a) deriving (Show,Eq)
-- | A 3x3 matrix stored as three row tuples (first tuple = first row).
data JumpMatrix a = JM (a,a,a) (a,a,a) (a,a,a) deriving (Show,Eq)

instance Functor StateVector where
  -- Apply the function to each of the three components.
  fmap g (SV (p, q, r)) = SV (g p, g q, g r)

instance Functor JumpMatrix where
  -- Apply the function to all nine entries, keeping the row layout.
  fmap g (JM (a11, a12, a13) (a21, a22, a23) (a31, a32, a33)) =
    JM (g a11, g a12, g a13)
       (g a21, g a22, g a23)
       (g a31, g a32, g a33)
-- | The upper 32 bits of a 'Word64', moved down into the low half.
-- 'shiftR' on an unsigned 'Word64' fills with zeros, so no mask is needed.
hiW64 :: Word64 -> Word64
hiW64 w = w `shiftR` 32
{-# INLINE hiW64 #-}
-- | The lower 32 bits of a 'Word64'; the upper half is cleared.
loW64 :: Word64 -> Word64
loW64 = (.&. 0xFFFFFFFF)
{-# INLINE loW64 #-}
-- | Split a 'Word64' into its @(high, low)@ 32-bit halves, each returned
-- zero-extended in a 'Word64'.
splitW64 :: Word64 -> (Word64,Word64)
splitW64 w = (w `shiftR` 32, w .&. 0xFFFFFFFF)
{-# INLINE splitW64 #-}
-- | Full unsigned 64x64 -> 128-bit multiplication.
--
-- Returns @(hi, lo)@ with @x * y == hi * 2^64 + lo@. Both operands are
-- split into 32-bit halves so that each of the four partial products fits
-- in a 'Word64'. Both result words are forced before the pair is returned.
mulW64 :: Word64 -> Word64 -> (Word64,Word64)
mulW64 x y = (hi, lo)
  where
    mask32 = 0xFFFFFFFF
    (xh, xl) = (x `shiftR` 32, x .&. mask32)
    (yh, yl) = (y `shiftR` 32, y .&. mask32)
    -- Partial products, named after the halves they combine (H = high, L = low).
    pLL = xl * yl
    pHL = xh * yl
    pLH = xl * yh
    pHH = xh * yh
    -- The column at bit 32: carry-out of pLL plus the low halves of both
    -- cross products. At most 3 * (2^32 - 1), so it cannot overflow.
    mid = (pLL `shiftR` 32) + (pHL .&. mask32) + (pLH .&. mask32)
    !lo = ((mid .&. mask32) `shiftL` 32) .|. (pLL .&. mask32)
    !hi = pHH + (pHL `shiftR` 32) + (pLH `shiftR` 32) + (mid `shiftR` 32)
{-# INLINE mulW64 #-}
-- | Modular multiplication on 'Word64': returns @(x * y) `mod` m@ without
-- overflow, starting from the full 128-bit product given by 'mulW64'.
--
-- With @s' = 2^64 `mod` m@, a 128-bit value @hi * 2^64 + lo@ is congruent
-- to @(hi `mod` m) * s' + (lo `mod` m)@. @step@ performs that fold as a
-- 128-bit computation. @go@ repeats it until the high word is zero and
-- then reduces the low word. Because @s' <= 2^64 - m@, each fold strictly
-- decreases a non-zero high word, so the loop terminates.
--
-- @m@ must be non-zero: every reduction uses @`mod` m@.
mulModW64 :: Word64 -> Word64 -> Word64 -> Word64
mulModW64 m x y = go $ x `mulW64` y
  where
    -- ((2^64 - 1) `mod` m) + 1, which lies in [1, m].
    s = (18446744073709551615 `mod` m) + 1
    -- Bring s into [0, m), giving s' == 2^64 `mod` m.
    s' = if s >= m then s - m else s
    -- Fold until the high word is zero, then reduce the remaining low word.
    go v = case step v of
             (0, t0l) -> t0l `mod` m
             t1 -> go t1
    -- One fold: (wh, wl) -> the 128-bit value (wh `mod` m) * s' + (wl `mod` m).
    step (wh,wl) = (vh, vlh .|. vll)
      where
        rl = wl `mod` m
        rh = wh `mod` m
        (rlh,rll) = splitW64 rl
        -- 128-bit product of the reduced high word and s'.
        (zh ,zl ) = rh `mulW64` s'
        (zlh,zll) = splitW64 zl
        -- Add rl into (zh, zl) 32 bits at a time, propagating the carries.
        t0 = rll + zll
        (c0 ,vll) = splitW64 t0
        t1 = rlh + zlh + c0
        vlh = loW64 t1 `shiftL` 32
        c1 = hiW64 t1
        vh = zh + c1
{-# INLINE mulModW64 #-}
-- | Dot product of a matrix row with a state vector, modulo @m@
-- ('Word64' version).
--
-- Each product comes from 'mulModW64' and is therefore already @< m@.
-- Partial sums use an overflow-safe modular addition. Adding two residues
-- directly can wrap past @2^64@ when @m > 2^63@, and the following
-- @`mod` m@ would then give a wrong result. For @m <= 2^63@ the result
-- equals @(a + b) `mod` m@ computed directly.
dotModW64 :: Word64 -> (Word64,Word64,Word64) -> StateVector Word64 -> Word64
dotModW64 m (x1,x2,x3) (SV (y1,y2,y3)) = z
  where
    u = mulModW64 m
    -- (a + b) `mod` m for a, b < m, with no intermediate above maxBound:
    -- a + b >= m exactly when a >= m - b, and m - b > 0 because b < m.
    addMod a b = if a >= m - b then a - (m - b) else a + b
    !w = (x1 `u` y1) `addMod` (x2 `u` y2)
    !z = w `addMod` (x3 `u` y3)
{-# INLINE dotModW64 #-}
-- | Matrix-vector product modulo @m@ ('Word64' version). Component @i@ of
-- the result is the 'dotModW64' of row @i@ of the matrix with the vector.
-- All three components are forced before the 'SV' is returned.
vecTrModW64 :: Word64 -> JumpMatrix Word64 -> StateVector Word64 -> StateVector Word64
vecTrModW64 m (JM r1 r2 r3) v =
  let rowDot row = dotModW64 m row v
      !e1 = rowDot r1
      !e2 = rowDot r2
      !e3 = rowDot r3
  in SV (e1, e2, e3)
{-# INLINE vecTrModW64 #-}
-- | Product of two jump matrices modulo @m@ ('Word64' version). Column
-- @j@ of the result is 'vecTrModW64' applied to column @j@ of the
-- right-hand matrix.
matMulModW64 :: Word64 -> JumpMatrix Word64 -> JumpMatrix Word64 -> JumpMatrix Word64
matMulModW64 m lhs (JM (b11,b12,b13) (b21,b22,b23) (b31,b32,b33))
  = JM (c11,c12,c13) (c21,c22,c23) (c31,c32,c33)
  where
    -- Multiply lhs by a single column, given top-to-bottom as a tuple.
    column = vecTrModW64 m lhs . SV
    SV (!c11,!c21,!c31) = column (b11,b21,b31)
    SV (!c12,!c22,!c32) = column (b12,b22,b32)
    SV (!c13,!c23,!c33) = column (b13,b23,b33)
{-# INLINE matMulModW64 #-}
-- | Square of a jump matrix modulo @m@ ('Word64' version).
matSqrModW64 :: Word64 -> JumpMatrix Word64 -> JumpMatrix Word64
matSqrModW64 m a = let mul = matMulModW64 m in a `mul` a
{-# INLINE matSqrModW64 #-}
-- | @(x * y) `mod` m@ for any 'Integral' type. The product is formed in the
-- type itself, so for bounded types it can wrap before the reduction.
mulMod :: (Integral a) => a -> a -> a -> a
mulMod m x y = mod (x * y) m
{-# INLINE mulMod #-}
-- | Dot product of a matrix row with a state vector, modulo @m@. Each
-- product is reduced with 'mulMod' first, then the sum is reduced.
dotMod :: (Integral a) => a -> (a,a,a) -> StateVector a -> a
dotMod m (a1,a2,a3) (SV (b1,b2,b3)) =
  let p1 = mulMod m a1 b1
      p2 = mulMod m a2 b2
      p3 = mulMod m a3 b3
      !total = (p1 + p2 + p3) `mod` m
  in total
{-# INLINE dotMod #-}
-- | Matrix-vector product modulo @m@. Component @i@ of the result is the
-- 'dotMod' of row @i@ of the matrix with the vector. All three components
-- are forced before the 'SV' is returned.
vecTrMod :: (Integral a) => a -> JumpMatrix a -> StateVector a -> StateVector a
vecTrMod m (JM r1 r2 r3) v =
  let rowDot row = dotMod m row v
      !e1 = rowDot r1
      !e2 = rowDot r2
      !e3 = rowDot r3
  in SV (e1, e2, e3)
{-# INLINE vecTrMod #-}
-- | Product of two jump matrices modulo @m@. Column @j@ of the result is
-- 'vecTrMod' applied to column @j@ of the right-hand matrix.
--
-- Marked @INLINE@ like its 'Word64' twin and every other arithmetic helper
-- in this module. This lets call sites at a concrete type specialise away
-- the 'Integral' dictionary.
matMulMod :: (Integral a) => a -> JumpMatrix a -> JumpMatrix a -> JumpMatrix a
matMulMod m xx (JM (y11,y12,y13) (y21,y22,y23) (y31,y32,y33))
  = JM (z11,z12,z13) (z21,z22,z23) (z31,z32,z33)
  where
    u = vecTrMod m
    -- Each line multiplies xx by one column of the right-hand matrix.
    SV (!z11,!z21,!z31) = xx `u` (SV (y11,y21,y31))
    SV (!z12,!z22,!z32) = xx `u` (SV (y12,y22,y32))
    SV (!z13,!z23,!z33) = xx `u` (SV (y13,y23,y33))
{-# INLINE matMulMod #-}
-- | Square of a jump matrix modulo @m@.
matSqrMod :: (Integral a) => a -> JumpMatrix a -> JumpMatrix a
matSqrMod m a = let mul = matMulMod m in a `mul` a