-- |
-- Module      :  Mcmc.Chain.Trace
-- Description :  History of a Markov chain
-- Copyright   :  (c) Dominik Schrempf 2021
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Wed May 20 09:11:25 2020.
module Mcmc.Chain.Trace
  ( Trace,
    replicateT,
    lengthT,
    pushT,
    headT,
    takeT,
    freezeT,
    thawT,
  )
where

import Control.Monad.Primitive
import qualified Data.Stack.Circular as C
import qualified Data.Vector as VB
import Mcmc.Chain.Link

-- | A 'Trace' is a mutable circular stack of 'Link's, that is, of the states
-- visited by the chain together with their associated priors and likelihoods.
--
-- The wrapped stack is a boxed-vector 'C.MStack' living in 'RealWorld', so all
-- operations that read or modify a trace run in 'IO'.
newtype Trace a = Trace {fromTrace :: C.MStack VB.Vector RealWorld (Link a)}

-- | Initialize a trace of the given length by replicating the same 'Link'.
--
-- Be careful not to compute summary statistics before pushing enough values;
-- until then the trace still contains copies of the initial link.
--
-- Call 'error' if the maximum size is zero or negative.
--
-- See 'C.replicate'.
replicateT :: Int -> Link a -> IO (Trace a)
replicateT n l = Trace <$> C.replicate n l

-- | Get the length of the trace.
--
-- See 'C.size'.
lengthT :: Trace a -> Int
lengthT = C.size . fromTrace

-- | Push a 'Link' on the 'Trace'.
--
-- The returned trace wraps the stack handed back by 'C.push'.
--
-- NOTE(review): the underlying stack is mutable, so the argument trace should
-- presumably not be used independently after the push; confirm against
-- 'C.push'.
pushT :: Link a -> Trace a -> IO (Trace a)
pushT x t = Trace <$> C.push x (fromTrace t)
{-# INLINEABLE pushT #-}

-- | Get the most recent link of the trace.
--
-- See 'C.get'.
headT :: Trace a -> IO (Link a)
headT = C.get . fromTrace
{-# INLINEABLE headT #-}

-- | Get the @k@ most recent links of the trace as a boxed vector.
--
-- See 'C.take'.
takeT :: Int -> Trace a -> IO (VB.Vector (Link a))
takeT k = C.take k . fromTrace

-- | Freeze the mutable trace for storage, yielding an immutable circular
-- stack.
--
-- See 'C.freeze'.
freezeT :: Trace a -> IO (C.Stack VB.Vector (Link a))
freezeT = C.freeze . fromTrace

-- | Thaw an immutable circular stack into a mutable 'Trace'.
--
-- See 'C.thaw'.
thawT :: C.Stack VB.Vector (Link a) -> IO (Trace a)
thawT t = Trace <$> C.thaw t