-- |
-- Module      : Control.Monad.Bayes.Sequential
-- Description : Suspendable probabilistic computation
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'Sequential' represents a computation that can be suspended.
module Control.Monad.Bayes.Sequential
  ( Sequential,
    suspend,
    finish,
    advance,
    finished,
    hoistFirst,
    hoist,
    sis,
  )
where

import Control.Monad.Bayes.Class
import Control.Monad.Coroutine hiding (suspend)
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.Trans
import Data.Either

-- | Represents a computation that can be suspended at certain points.
-- The intermediate monadic effects can be extracted, which is particularly
-- useful for implementing Sequential Monte Carlo related methods.
-- All the probabilistic effects are lifted from the transformed monad, but
-- 'suspend' is additionally inserted after each 'score'.
newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
  -- These instances are inherited from the underlying 'Coroutine'.
  -- Deriving MonadTrans/MonadIO this way needs GeneralizedNewtypeDeriving,
  -- presumably enabled in the file pragmas or cabal file (confirm).
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO)

-- | Resume a suspended step by handing it the unit value it is awaiting.
extract :: Await () a -> a
extract (Await k) = k ()

-- | Sampling is delegated unchanged to the underlying monad via 'lift'.
-- The methods defined here therefore never introduce a suspension point.
instance MonadSample m => MonadSample (Sequential m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

-- | Execution is 'suspend'ed after each 'score'.
-- The weight is first passed on to the underlying monad, and the computation
-- then pauses.
instance MonadCond m => MonadCond (Sequential m) where
  score w = lift (score w) >> suspend

-- | The instance body is empty: @'Sequential' m@ is an instance whenever @m@ is.
instance (MonadInfer m) => MonadInfer (Sequential m)

-- | A point where the computation is paused.
-- Within this module, this is the only place where a suspension is created;
-- the 'MonadCond' instance calls it after every 'score'.
suspend :: Monad m => Sequential m ()
suspend = Sequential await

-- | Remove the remaining suspension points.
-- Runs the computation to completion in the underlying monad, resuming every
-- suspension immediately with 'extract'.
finish :: Monad m => Sequential m a -> m a
finish = pogoStick extract . runSequential

-- | Execute to the next suspension point.
-- If the computation is finished, do nothing.
--
-- Concretely, the first suspension point is removed. The first step of the
-- result therefore extends up to what used to be the second suspension point.
--
-- > finish = finish . advance
advance :: Monad m => Sequential m a -> Sequential m a
advance = Sequential . bounce extract . runSequential

-- | Return True if no more suspension points remain.
--
-- To find out, this runs the underlying monad up to the first suspension
-- point (or to completion). Those effects happen in @m@, but the resulting
-- continuation is discarded.
finished :: Monad m => Sequential m a -> m Bool
finished = fmap isRight . resume . runSequential

-- | Transform the inner monad.
-- This operation only applies to computation up to the first suspension.
-- The transformation wraps the underlying action that runs until that point.
-- Later steps of the continuation are not transformed.
hoistFirst :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a
hoistFirst f = Sequential . Coroutine . f . resume . runSequential

-- | Transform the inner monad.
-- The transformation is applied recursively through all the suspension points.
hoist ::
  (Monad m, Monad n) =>
  (forall x. m x -> n x) ->
  Sequential m a ->
  Sequential n a
hoist f = Sequential . mapMonad f . runSequential

-- | Apply a function a given number of times.
-- A count of zero or less yields 'id', because 'replicate' then produces an
-- empty list.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k f = foldr (.) id (replicate k f)

-- | Sequential importance sampling.
-- Applies a given transformation after each time step.
--
-- Each of the first @k@ steps does two things. First, the transformation is
-- wrapped around the computation up to the current suspension point. This
-- prefix includes the already-transformed effects of earlier steps. Then
-- execution advances past that point.
--
-- Whatever remains after @k@ steps is run with 'finish', without any further
-- transformation. For @k <= 0@ this is just 'finish'.
sis ::
  Monad m =>
  -- | transformation
  (forall x. m x -> m x) ->
  -- | number of time steps
  Int ->
  Sequential m a ->
  m a
sis f k = finish . composeCopies k (advance . hoistFirst f)