-- | Various internal utilities.
module Numeric.MCMC.Util 

where

import Control.Monad.Primitive (PrimMonad, PrimState)
import System.Random.MWC 
import Data.Function           (fix)
import qualified Data.Sequence as Seq
import Data.Foldable           (toList)
import Control.Monad
import Control.Monad.ST
import Data.STRef

-- | Given an 'Int' @a@, inclusive @bounds@ and a generator, draw an 'Int'
-- within @bounds@ (via 'uniformR') that differs from @a@.
--
-- Works by rejection: any draw equal to @a@ is discarded and redrawn.
-- If the range contains nothing but @a@ (i.e. @bounds == (a, a)@), no
-- acceptable value exists and the rejection loop would never terminate, so
-- that case is reported with 'error' instead.
genDiffInt :: PrimMonad m => Int -> (Int, Int) -> Gen (PrimState m) -> m Int
genDiffInt a bounds@(lo, hi) gen
    | lo == a && hi == a =
        error "genDiffInt - bounds contain no Int other than the excluded one."
    | otherwise = fix $ \loopB ->
        do  b <- uniformR bounds gen
            if a == b then loopB else return b

-- | Arithmetic mean of a list of 'Double's, computed in one tail-recursive
-- left-to-right pass that keeps a running sum and element count.
--
-- The mean of the empty list is @0 / 0@, i.e. NaN.
mean :: [Double] -> Double
mean = go 0.0 0
    where go :: Double -> Int -> [Double] -> Double
          go s l []     = s / fromIntegral l
          -- Force both accumulators at every step. Otherwise the running sum
          -- and count can grow into long chains of (+) thunks whenever the
          -- optimiser's strictness analysis does not make them strict
          -- (e.g. in GHCi or without optimisation).
          go s l (x:xs) = let s' = s + x
                              l' = l + 1
                          in s' `seq` l' `seq` go s' l' xs

-- | Apply one function to both components of a homogeneous pair,
-- preserving their order.
mapPair :: (a -> b) -> (a, a) -> (b, b)
mapPair f pair =
    case pair of
        (l, r) -> (f l, f r)

-- | Turn a list of exactly two elements into a pair, keeping their order.
-- Any other length is rejected with 'error'.
shortListToPair :: [a] -> (a, a)
shortListToPair xs =
    case xs of
        [first, second] -> (first, second)
        _               -> error "shortListToPair - list must have length 2."

-- | Knuth (Fisher-Yates) shuffle of a list.
--
-- The generator is restored from @seed@, so a given list and seed always
-- produce the same permutation. A 'Seq' held in an 'STRef' provides the
-- indexed reads and updates.
shuffle :: [a] -> Seed -> [a]
shuffle xs seed = runST $ do
    xsref <- newSTRef $ Seq.fromList xs
    gen   <- restore seed
    let n = length xs
    -- Walk i from the last index down to 1, swapping position i with a
    -- randomly chosen position j in [0, i].
    forM_ [n-1, n-2 .. 1] $ \i -> do
        j   <- uniformR (0, i) gen
        xst <- readSTRef xsref
        -- Matching on 'Seq.lookup' performs both lookups now (without
        -- forcing the elements themselves). Writing with '$!' keeps the
        -- sequence from becoming a chain of pending updates. Together these
        -- stop earlier versions of the 'Seq' from being retained as thunks.
        case (Seq.lookup i xst, Seq.lookup j xst) of
            (Just xi, Just xj) ->
                writeSTRef xsref $! Seq.update i xj (Seq.update j xi xst)
            _ -> error "shuffle - index out of range (unreachable: 0 <= j <= i < n)."
    toList <$> readSTRef xsref

-- | Draw @k@ elements from a list without replacement. The whole list is
-- permuted with 'shuffle' under @seed@, and the first @k@ elements of the
-- result are kept. Requesting more elements than the list holds yields the
-- entire shuffled list.
sample :: Int -> [a] -> Seed -> [a]
sample k xs = take k . shuffle xs