module Control.Monad.Bayes.Sequential (
Sequential,
suspend,
finish,
advance,
finished,
hoistFirst,
hoist,
sis
) where
import Control.Monad.Trans
import Control.Monad.Coroutine hiding (suspend)
import Control.Monad.Coroutine.SuspensionFunctors
import Data.Either
import Control.Monad.Bayes.Class
-- | A probabilistic program that can be executed one step at a time.
--
-- It wraps a 'Coroutine' whose suspension functor is @'Await' ()@, so every
-- suspension is resumed by supplying @()@. Within this module, suspensions
-- are introduced by 'suspend', which the 'MonadCond' instance calls after
-- each 'score'.
--
-- NOTE(review): deriving 'MonadTrans' / 'MonadIO' on a newtype presumably
-- relies on GeneralizedNewtypeDeriving being enabled outside this view
-- (no LANGUAGE pragma is visible) — confirm.
newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
  deriving(Functor,Applicative,Monad,MonadTrans,MonadIO)
-- | Resume an 'Await' suspension by handing it the unit value. The result is
-- the continuation the suspension was holding.
extract :: Await () a -> a
extract (Await resumeWith) = resumeWith ()
-- | Sampling is delegated to the underlying monad via 'lift'. Drawing a
-- sample never introduces a suspension point. Only 'random', 'bernoulli'
-- and 'categorical' are overridden here; any other methods of the class
-- use the class's own definitions.
instance MonadSample m => MonadSample (Sequential m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)
-- | Conditioning is where the program pauses. The weight is first passed to
-- the underlying monad's 'score', and control is then handed back with
-- 'suspend'. Each 'score' therefore ends one step of the program.
instance MonadCond m => MonadCond (Sequential m) where
  score w = do
    lift (score w)
    suspend
-- | Full inference interface for 'Sequential'. No methods are defined here.
-- NOTE(review): presumably 'MonadInfer' only bundles 'MonadSample' and
-- 'MonadCond' (whose instances are defined in this module) — confirm
-- against Control.Monad.Bayes.Class.
instance MonadInfer m => MonadInfer (Sequential m)
-- | Insert a suspension point into the program. 'advance' resumes through
-- such a point one at a time, while 'finish' resumes through all of them.
suspend :: Monad m => Sequential m ()
suspend = Sequential pause
  where
    -- Suspend with the 'Await' functor; resuming it supplies @()@.
    pause :: Monad m => Coroutine (Await ()) m ()
    pause = await
-- | Run the program to completion in the underlying monad. Every suspension
-- point is resumed immediately with 'extract'.
finish :: Monad m => Sequential m a -> m a
finish (Sequential coroutine) = pogoStick extract coroutine
-- | Remove the program's first suspension point by resuming through it with
-- 'extract'. The first step of the result then covers what were the first
-- two steps of the argument.
--
-- No effects are run by 'advance' itself; it only rebuilds the computation.
-- The rebuilding is done by monad-coroutine's 'bounce'.
advance :: Monad m => Sequential m a -> Sequential m a
advance (Sequential coroutine) = Sequential (bounce extract coroutine)
-- | Return 'True' when the program runs to completion without reaching a
-- suspension point, and 'False' when it suspends.
--
-- NOTE(review): this executes the underlying effects up to the first
-- suspension and then throws away the resumed computation. Those effects
-- are not shared with any later run of the same program.
finished :: Monad m => Sequential m a -> m Bool
finished (Sequential coroutine) =
  either (const False) (const True) <$> resume coroutine
-- | Transform only the first step of the program, i.e. the underlying
-- action that runs up to the first suspension point. The transformation is
-- applied to that action as a whole; steps after the first suspension are
-- not touched.
hoistFirst :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a
hoistFirst f (Sequential coroutine) =
  Sequential (Coroutine (f (resume coroutine)))
-- | Change the underlying monad by applying a natural transformation to
-- every step of the program. Suspension points are preserved.
hoist :: (Monad m, Monad n)
      => (forall x. m x -> n x)
      -> Sequential m a
      -> Sequential n a
hoist f (Sequential coroutine) = Sequential (mapMonad f coroutine)
-- | @composeCopies k f@ composes @f@ with itself @k@ times. When @k <= 0@ the
-- result is 'id'.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k f = go k
  where
    go n
      | n <= 0 = id
      | otherwise = f . go (n - 1)
-- | Step-wise inference driver (the name suggests sequential importance
-- sampling).
--
-- The following is repeated @k@ times: apply the transformation to the
-- program's current first step with 'hoistFirst', then resume through that
-- step's suspension point with 'advance'. Whatever is left is finally run
-- with 'finish'.
--
-- If the program has fewer than @k@ suspension points, the remaining
-- iterations still wrap the rest of the computation with the
-- transformation.
sis :: Monad m
    => (forall x. m x -> m x) -- ^ transformation applied to the first step on each iteration
    -> Int                    -- ^ number of iterations @k@
    -> Sequential m a         -- ^ program to run
    -> m a
sis f k model = finish (composeCopies k step model)
  where
    step = advance . hoistFirst f