{-# LANGUAGE BangPatterns #-}
module Stochastic.Uniform(xorshift128plus,
                          UniformRandom,
                          nWayAllocate,
                          splitAllocate,
                          RandomGen(..)) where

import Stochastic.Generator
import Data.Word
import Data.Bits
import Data.Typeable
import Control.Exception(throw, Exception)
import System.Random(RandomGen(..))

-- | State of an xorshift128+ generator: two 64-bit state words followed by
-- the remaining entropy budget, i.e. how many more draws 'next' will allow
-- before throwing.
--
-- Fields are strict so that repeatedly advancing a generator never leaves
-- unevaluated arithmetic sitting inside the state.
data UniformRandom = XorShift128Plus {-# UNPACK #-} !Word64 {-# UNPACK #-} !Word64 !Integer

-- | Thrown by 'next' when a generator has no entropy budget left.
-- The derived 'Show' renders it as @EntropyExhausted@.
data EntropyExhausted = EntropyExhausted
  deriving (Eq, Show, Typeable)

instance Exception EntropyExhausted

{- |
Create an xorshift128+ generator from a seed.

Both 64-bit state words are initialised from the low 64 bits of @seed@; any
higher bits are ignored. An xorshift state of all zeros only ever produces
zeros, so a seed whose low 64 bits are zero is replaced by a fixed non-zero
constant. Every other seed yields exactly the state it always did.

The generator starts with an entropy budget of 2^127 draws.

For the design and quality of the xorshift128+ PRNG, see the
<http://vigna.di.unimi.it/ftp/papers/xorshiftplus.pdf paper by S. Vigna>.
-}
xorshift128plus :: Integer -> UniformRandom
xorshift128plus seed = XorShift128Plus stateWord stateWord budget
  where
    -- fromInteger keeps only the low 64 bits of the seed.
    truncated = fromInteger seed :: Word64
    -- Avoid the degenerate all-zero state (fixed point of the xorshift step).
    stateWord
      | truncated == 0 = 0x9E3779B97F4A7C15
      | otherwise      = truncated
    budget = 2 ^ (127 :: Int)

-- | Allocate @n@ child generators from @g0@, each limited to @size@ draws,
-- by calling 'splitAllocate' @n@ times in sequence.
--
-- Returns the children in allocation order together with the parent after
-- it has been advanced past every child's block. A non-positive @n@
-- allocates nothing and returns @g0@ unchanged. The bang patterns make the
-- whole list be built eagerly (recursion depth @n@).
nWayAllocate :: Integer -> Integer -> UniformRandom -> ([UniformRandom], UniformRandom)
-- Separate equation (not a guard) so the strict where-bindings below are
-- never forced for the base case.
nWayAllocate _ n g0 | n <= 0 = ([], g0)
nWayAllocate size n g0 = (g1 : gs, g3)
  where
    !(g1, g2) = splitAllocate size g0
    !(gs, g3) = nWayAllocate size (n - 1) g2

-- | Split off a child generator limited to @count@ draws.
--
-- The child starts from the parent's current state words. The returned
-- parent has been advanced @count@ steps with 'next', so its later draws
-- begin where the child's budget ends. Advancing goes through 'next', so
-- this raises 'EntropyExhausted' if the parent has fewer than @count@ draws
-- left. It costs @count@ calls to 'next'.
splitAllocate :: Integer -> UniformRandom -> (UniformRandom, UniformRandom)
splitAllocate count parent =
  case parent of
    XorShift128Plus hi lo _ ->
      let !advanced = step next count parent
       in (XorShift128Plus hi lo count, advanced)


-- | xorshift128+ exposed through 'RandomGen'.
--
-- 'next' produces one 64-bit output (shift triple 23/17/26), reinterpreted
-- as an 'Int', and spends one unit of the entropy budget. It throws
-- 'EntropyExhausted' once the budget is no longer positive.
--
-- 'split' returns a child with the parent's current state words and a budget
-- of 2^32 draws, and advances the parent by 2^32 steps. The child's draws
-- therefore do not overlap the parent's later draws.
instance RandomGen UniformRandom where
  next (XorShift128Plus high low entropy)
    -- `<= 0` rather than `== 0`: a negative budget (e.g. from a negative
    -- allocation count) would otherwise never be exhausted.
    | entropy <= 0 = throw EntropyExhausted
    | otherwise    =
        -- New state is forced eagerly so no thunk chain is retained.
        let x         = low `xor` (low `shiftL` 23)
            !high'    = x `xor` high `xor` (x `shiftR` 17) `xor` (high `shiftR` 26)
            !entropy' = entropy - 1
            -- Wraps the unsigned 64-bit sum into Int's full signed range.
            !ret      = fromIntegral (high' + high) :: Int
        in  (ret, XorShift128Plus high' high entropy')
  split g@(XorShift128Plus high low _) =
    (XorShift128Plus high low childBudget, g')
    where
      -- NOTE(review): the parent is advanced by calling 'next' 2^32 times,
      -- so every split is very expensive. An xorshift jump polynomial would
      -- make this O(1), but would produce a different stream.
      !childBudget = 2 ^ (32 :: Int)
      !g'          = step next childBudget g

-- | Apply a state-transition function @n@ times, discarding the produced
-- values and keeping only the final state.
--
-- Each intermediate state is forced to WHNF so no chain of suspended calls
-- builds up. A non-positive count returns the state unchanged; a negative
-- count would otherwise never reach the stopping case.
step :: (g -> (a, g)) -> Integer -> g -> g
step f = go
  where
    go n g
      | n <= 0    = g
      | otherwise = go (n - 1) $! snd (f g)