{- -*- mode: haskell; coding: utf-8-unix -*- -} {-# LANGUAGE BangPatterns #-} module System.Random.MRG.Internal ( StateVector(..) , JumpMatrix(..) , T6(..) , vecTrMod , matMulMod , matSqrMod , vecTrModW64 , matMulModW64 , matSqrModW64 ) where import Data.Word (Word64) import Data.Bits (shiftL, shiftR, (.&.), (.|.)) newtype T6 a = T6 { fromT6 :: (a,a,a,a,a,a) } instance Functor T6 where fmap f (T6 (x1,x2,x3,x4,x5,x6)) = T6 (f x1, f x2, f x3, f x4, f x5, f x6) data StateVector a = SV (a,a,a) deriving (Show,Eq) data JumpMatrix a = JM (a,a,a) (a,a,a) (a,a,a) deriving (Show,Eq) instance Functor StateVector where fmap f (SV (x1,x2,x3)) = SV (f x1, f x2, f x3) instance Functor JumpMatrix where fmap f (JM (x11,x12,x13) (x21,x22,x23) (x31,x32,x33)) = JM (f x11,f x12,f x13) (f x21,f x22,f x23) (f x31,f x32,f x33) hiW64 :: Word64 -> Word64 hiW64 x = (x `shiftR` 32) .&. 0x00000000ffffffff {-# INLINE hiW64 #-} loW64 :: Word64 -> Word64 loW64 x = x .&. 0x00000000ffffffff {-# INLINE loW64 #-} splitW64 :: Word64 -> (Word64,Word64) splitW64 x = (hiW64 x, loW64 x) {-# INLINE splitW64 #-} mulW64 :: Word64 -> Word64 -> (Word64,Word64) mulW64 x y = (vh,vl) where (xh,xl) = splitW64 x (yh,yl) = splitW64 y z0 = xl * yl z1 = xh * yl z2 = xl * yh z3 = xh * yh (z0h,z0l) = splitW64 z0 (z1h,z1l) = splitW64 z1 (z2h,z2l) = splitW64 z2 w0 = z0h + z1l + z2l w1 = loW64 w0 `shiftL` 32 w2 = hiW64 w0 !vl = w1 .|. z0l !vh = z3 + z1h + z2h + w2 {-# INLINE mulW64 #-} mulModW64 :: Word64 -> Word64 -> Word64 -> Word64 mulModW64 m x y = go $ x `mulW64` y where s = (18446744073709551615 `mod` m) + 1 s' = if s >= m then s - m else s go v = case step v of (0, t0l) -> t0l `mod` m t1 -> go t1 step (wh,wl) = (vh, vlh .|. vll) where rl = wl `mod` m rh = wh `mod` m (rlh,rll) = splitW64 rl -- b = p *m + s -- ah = qh*m + rh -- al = ql*m + rl -- ah*b + al -- = (qh*m + rh)*(p*m + s) + ql*m + rl -- = (qh*m*p + qh*s + rh*p + ql)*m + rh*s + rl (zh ,zl ) = rh `mulW64` s' (zlh,zll) = splitW64 zl t0 = rll + zll (c0 ,vll) = splitW64 t0 t1 = rlh + zlh + c0 vlh = loW64 t1 `shiftL` 32 c1 = hiW64 t1 vh = zh + c1 {-# INLINE mulModW64 #-} dotModW64 :: Word64 -> (Word64,Word64,Word64) -> StateVector Word64 -> Word64 dotModW64 m (x1,x2,x3) (SV (y1,y2,y3)) = z where u = mulModW64 m !w = ((x1 `u` y1) + (x2 `u` y2)) `mod` m !z = (w + (x3 `u` y3)) `mod` m {-# INLINE dotModW64 #-} vecTrModW64 :: Word64 -> JumpMatrix Word64 -> StateVector Word64 -> StateVector Word64 vecTrModW64 m (JM xr1 xr2 xr3) y = SV (z1,z2,z3) where u = dotModW64 m !z1 = xr1 `u` y !z2 = xr2 `u` y !z3 = xr3 `u` y {-# INLINE vecTrModW64 #-} matMulModW64 :: Word64 -> JumpMatrix Word64 -> JumpMatrix Word64 -> JumpMatrix Word64 matMulModW64 m xx (JM (y11,y12,y13) (y21,y22,y23) (y31,y32,y33)) = JM (z11,z12,z13) (z21,z22,z23) (z31,z32,z33) where u = vecTrModW64 m SV (!z11,!z21,!z31) = xx `u` (SV (y11,y21,y31)) SV (!z12,!z22,!z32) = xx `u` (SV (y12,y22,y32)) SV (!z13,!z23,!z33) = xx `u` (SV (y13,y23,y33)) {-# INLINE matMulModW64 #-} matSqrModW64 :: Word64 -> JumpMatrix Word64 -> JumpMatrix Word64 matSqrModW64 m xx = matMulModW64 m xx xx {-# INLINE matSqrModW64 #-} mulMod :: (Integral a) => a -> a -> a -> a mulMod m x y = (x * y) `mod` m {-# INLINE mulMod #-} dotMod :: (Integral a) => a -> (a,a,a) -> StateVector a -> a dotMod m (x1,x2,x3) (SV (y1,y2,y3)) = z where u = mulMod m !z = ((x1 `u` y1) + (x2 `u` y2) + (x3 `u` y3)) `mod` m {-# INLINE dotMod #-} vecTrMod :: (Integral a) => a -> JumpMatrix a -> StateVector a -> StateVector a vecTrMod m (JM xr1 xr2 xr3) y = SV (z1,z2,z3) where u = dotMod m !z1 = xr1 `u` y !z2 = xr2 `u` y !z3 = xr3 `u` y {-# INLINE vecTrMod #-} matMulMod :: (Integral a) => a -> JumpMatrix a -> JumpMatrix a -> JumpMatrix a matMulMod m xx (JM (y11,y12,y13) (y21,y22,y23) (y31,y32,y33)) = JM (z11,z12,z13) (z21,z22,z23) (z31,z32,z33) where u = vecTrMod m SV (!z11,!z21,!z31) = xx `u` (SV (y11,y21,y31)) SV (!z12,!z22,!z32) = xx `u` (SV (y12,y22,y32)) SV (!z13,!z23,!z33) = xx `u` (SV (y13,y23,y33)) matSqrMod :: (Integral a) => a -> JumpMatrix a -> JumpMatrix a matSqrMod m xx = matMulMod m xx xx -- EOF