module MT19937.Internal where

import Data.Bits
import Data.Word ( Word32 )
import Data.Vector.Unboxed.Mutable qualified as VUM

-- | MT19937 output tempering.
--
-- Applies the four xor-shift-mask steps of the reference generator to a
-- raw state word:
--
-- @
-- y ^= (y >> 11) & 0xFFFFFFFF
-- y ^= (y <<  7) & 0x9D2C5680
-- y ^= (y << 15) & 0xEFC60000
-- y ^=  y >> 18
-- @
--
-- NOTE(review): 'shiftR' from "Data.Bits" sign-extends on signed types.
-- The result therefore matches reference MT19937 output only for unsigned
-- types that hold the 32-bit input exactly (e.g. 'Word32', or a 'Word64'
-- below 2^32). Confirm that callers instantiate @a@ accordingly.
temper :: (Num a, Bits a) => a -> a
temper =
    finalMix
  . mixLeft 15 0xEFC60000
  . mixLeft 7 0x9D2C5680
  . mixRight 11 0xFFFFFFFF
  where
    -- y `xor` ((y >> k) & mask)
    mixRight k mask y = y `xor` (shiftR y k .&. mask)
    -- y `xor` ((y << k) & mask)
    mixLeft k mask y = y `xor` (shiftL y k .&. mask)
    -- The last step is an unmasked right shift by 18.
    finalMix y = y `xor` shiftR y 18

-- | Regenerate an MT19937 state vector in place (the \"twist\" step).
--
-- For each index @k@ from 0 to 623, in order, word @k@ is rebuilt from:
--
-- * the top bit of word @k@;
-- * the low 31 bits of word @(k + 1) mod 624@;
-- * word @(k + 397) mod 624@.
--
-- The concatenated value is shifted right by one and xored with word
-- @(k + 397) mod 624@. When its low bit is set, the result is also xored
-- with @0x9908B0DF@. Because the writes happen in index order, later
-- indices read words that earlier iterations already rewrote.
--
-- NOTE(review): reads and writes use the unchecked 'VUM.unsafeRead' and
-- 'VUM.unsafeWrite'. The vector is presumably exactly 624 words long;
-- confirm this at call sites.
twist :: VUM.PrimMonad m => VUM.MVector (VUM.PrimState m) Word32 -> m ()
twist mt = mapM_ regenerate [0 .. stateSize - 1]
  where
    stateSize, shiftSize :: Int
    stateSize = 624
    shiftSize = 397

    matrixA :: Word32
    matrixA = 0x9908B0DF

    -- Rebuild the word at index k from its neighbours.
    regenerate k = do
      current <- VUM.unsafeRead mt k
      next    <- VUM.unsafeRead mt ((k + 1) `mod` stateSize)
      far     <- VUM.unsafeRead mt ((k + shiftSize) `mod` stateSize)
      -- The two masks select disjoint bits, so .|. equals the reference's +.
      let joined  = (current .&. 0x80000000) .|. (next .&. 0x7FFFFFFF)
          shifted = far `xor` (joined `shiftR` 1)
          mixed
            | testBit joined 0 = shifted `xor` matrixA
            | otherwise        = shifted
      VUM.unsafeWrite mt k mixed