{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
module Control.Monad.Bayes.Sequential.Coroutine
( SequentialT,
suspend,
finish,
advance,
finished,
hoistFirst,
hoist,
sequentially,
sis,
)
where
import Control.Monad.Bayes.Class
( MonadDistribution (bernoulli, categorical, random),
MonadFactor (..),
MonadMeasure,
)
import Control.Monad.Coroutine
( Coroutine (..),
bounce,
mapMonad,
pogoStick,
)
import Control.Monad.Coroutine.SuspensionFunctors
( Await (..),
await,
)
import Control.Monad.Trans (MonadIO, MonadTrans (..))
import Data.Either (isRight)
-- | A computation in the monad @m@ that can pause at explicit suspension
-- points.
--
-- It is a thin wrapper around a coroutine whose suspension functor is
-- @'Await' ()@, so resuming it needs no input beyond @()@. Every class
-- instance listed below is inherited unchanged from the wrapped coroutine.
-- The 'MonadFactor' instance for this type inserts a 'suspend' after each
-- 'score'. That is what lets drivers such as 'sequentially' step through a
-- model one suspension at a time.
newtype SequentialT m a = SequentialT
  { runSequentialT :: Coroutine (Await ()) m a
  }
  deriving newtype
    ( Functor,
      Applicative,
      Monad,
      MonadTrans,
      MonadIO
    )
-- | Resume an @'Await' ()@ suspension by handing it the only value it can
-- ask for, @()@. Used as the resumption function for 'pogoStick' and
-- 'bounce'.
extract :: Await () a -> a
extract suspension = case suspension of
  Await resumeWith -> resumeWith ()
-- | Sampling primitives are lifted directly from @m@. Drawing a sample never
-- introduces a suspension point.
instance (MonadDistribution m) => MonadDistribution (SequentialT m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical weights = lift (categorical weights)
-- | Scoring is delegated to @m@, after which the computation pauses. As a
-- result, every 'score' becomes a suspension point.
instance (MonadFactor m) => MonadFactor (SequentialT m) where
  score w = do
    lift (score w)
    suspend
-- | No methods are defined here.
-- NOTE(review): presumably 'MonadMeasure' only bundles the
-- 'MonadDistribution' and 'MonadFactor' capabilities; confirm against
-- "Control.Monad.Bayes.Class".
instance MonadMeasure m => MonadMeasure (SequentialT m)
-- | Mark a suspension point. The computation stops here until it is resumed,
-- for example by 'advance' (one step) or 'finish' (to completion).
suspend :: (Monad m) => SequentialT m ()
suspend = SequentialT pausePoint
  where
    pausePoint = await
-- | Run the computation to completion in @m@, resuming through every
-- remaining suspension point.
finish :: (Monad m) => SequentialT m a -> m a
finish (SequentialT coroutine) = pogoStick extract coroutine
-- | Run the effects up to the next suspension point and return what is left
-- of the computation.
--
-- If no suspension points remain, the result is already available and the
-- computation is just re-wrapped. Consequently @finish . advance@ behaves
-- like @finish@.
advance :: (Monad m) => SequentialT m a -> SequentialT m a
advance (SequentialT coroutine) = SequentialT (bounce extract coroutine)
-- | 'True' when the computation has no suspension points left.
--
-- To find out, this executes the effects of @m@ up to the first suspension
-- and then throws the resumed computation away.
finished :: (Monad m) => SequentialT m a -> m Bool
finished (SequentialT coroutine) = isRight <$> resume coroutine
-- | Apply a transformation of @m@ to the first step only, that is, to the
-- effects that run before the first suspension point. Everything after that
-- suspension is left as it was.
hoistFirst :: (forall x. m x -> m x) -> SequentialT m a -> SequentialT m a
hoistFirst transform (SequentialT coroutine) =
  SequentialT (Coroutine (transform (resume coroutine)))
-- | Swap the underlying monad. The transformation is applied at every step,
-- across all suspension points (via 'mapMonad').
hoist ::
  (Monad m, Monad n) =>
  (forall x. m x -> n x) ->
  SequentialT m a ->
  SequentialT n a
hoist nat (SequentialT coroutine) = SequentialT (mapMonad nat coroutine)
-- | @composeCopies k f@ is @f@ composed with itself @k@ times. A
-- non-positive @k@ gives 'id'.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies n step
  | n <= 0 = id
  | otherwise = step . composeCopies (n - 1) step
-- | Step through the model one suspension at a time.
--
-- For each of the first @k@ steps, the given transformation is applied to
-- that step's effects ('hoistFirst') and the step is then executed
-- ('advance'). Whatever remains afterwards is run to completion ('finish')
-- without further transformation.
--
-- @sis@ is an alias for @sequentially@. The name presumably abbreviates
-- "sequential importance sampling".
sequentially,
  sis ::
    (Monad m) =>
    -- | transformation applied to each of the first @k@ steps
    (forall x. m x -> m x) ->
    -- | number of steps @k@ to transform
    Int ->
    SequentialT m a ->
    m a
sequentially transform steps model =
  finish (composeCopies steps stepOnce model)
  where
    stepOnce = advance . hoistFirst transform
sis = sequentially