-- |
-- Module      : Control.Monad.Bayes.Class
-- Description : Types for probabilistic modelling
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- This module defines 'MonadInfer', which can be used to represent a simple model
-- like the following:
--
-- @
-- import Control.Monad (when)
-- import Control.Monad.Bayes.Class
--
-- model :: MonadInfer m => m Bool
-- model = do
--   rain <- bernoulli 0.3
--   sprinkler <-
--     bernoulli $
--     if rain
--       then 0.1
--       else 0.4
--   let wetProb =
--         case (rain, sprinkler) of
--           (True,  True)  -> 0.98
--           (True,  False) -> 0.80
--           (False, True)  -> 0.90
--           (False, False) -> 0.00
--   score wetProb
--   return rain
-- @
module Control.Monad.Bayes.Class
  ( MonadSample,
    random,
    uniform,
    normal,
    gamma,
    beta,
    bernoulli,
    categorical,
    logCategorical,
    uniformD,
    geometric,
    poisson,
    dirichlet,
    MonadCond,
    score,
    factor,
    condition,
    MonadInfer,
    discrete,
    normalPdf,
  )
where

import Control.Monad (when)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Identity
import Control.Monad.Trans.List
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.RWS hiding (tell)
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State
import Control.Monad.Trans.Writer
import qualified Data.Vector as V
import Data.Vector.Generic as VG
import Numeric.Log
import Statistics.Distribution
import Statistics.Distribution.Beta (betaDistr)
import Statistics.Distribution.Gamma (gammaDistr)
import Statistics.Distribution.Geometric (geometric0)
import Statistics.Distribution.Normal (normalDistr)
import qualified Statistics.Distribution.Poisson as Poisson
import Statistics.Distribution.Uniform (uniformDistr)

-- | Monads that can draw random variables.
--
-- Only 'random' lacks a default definition. Every other method has a default
-- built on top of it, in one of three ways:
--
-- * inverse-CDF sampling (@draw@);
-- * a sequence of Bernoulli trials (@fromPMF@);
-- * composition of other methods.
--
-- Instances may override any method with a specialised sampler.
class Monad m => MonadSample m where
  -- | Draw from a uniform distribution.
  random ::
    -- | \(\sim \mathcal{U}(0, 1)\)
    m Double

  -- | Draw from a uniform distribution.
  uniform ::
    -- | lower bound a
    Double ->
    -- | upper bound b
    Double ->
    -- | \(\sim \mathcal{U}(a, b)\).
    m Double
  uniform lo hi = draw (uniformDistr lo hi)

  -- | Draw from a normal distribution.
  normal ::
    -- | mean μ
    Double ->
    -- | standard deviation σ
    Double ->
    -- | \(\sim \mathcal{N}(\mu, \sigma^2)\)
    m Double
  normal mu sigma = draw (normalDistr mu sigma)

  -- | Draw from a gamma distribution.
  gamma ::
    -- | shape k
    Double ->
    -- | scale θ
    Double ->
    -- | \(\sim \Gamma(k, \theta)\)
    m Double
  gamma shape scale = draw (gammaDistr shape scale)

  -- | Draw from a beta distribution.
  beta ::
    -- | shape α
    Double ->
    -- | shape β
    Double ->
    -- | \(\sim \mathrm{Beta}(\alpha, \beta)\)
    m Double
  beta a b = draw (betaDistr a b)

  -- | Draw from a Bernoulli distribution.
  --
  -- The default returns 'True' exactly when a 'random' draw is below @p@.
  bernoulli ::
    -- | probability p
    Double ->
    -- | \(\sim \mathrm{B}(1, p)\)
    m Bool
  bernoulli p = fmap (< p) random

  -- | Draw from a categorical distribution.
  --
  -- The default does not normalise @ps@; the probabilities are expected to
  -- sum to 1. It scans the vector from index 0 with Bernoulli trials
  -- (@fromPMF@). If the probabilities sum to less than 1, the scan can run
  -- past the last index. An empty vector is rejected up front with a
  -- descriptive error.
  categorical ::
    Vector v Double =>
    -- | event probabilities
    v Double ->
    -- | outcome category
    m Int
  categorical ps
    -- Fail with a clear message instead of an out-of-bounds error from the
    -- first index lookup.
    | VG.null ps = error "categorical: empty vector of probabilities"
    | otherwise = fromPMF (ps VG.!)

  -- | Draw from a categorical distribution in the log domain.
  --
  -- The default converts each weight back to a plain probability and then
  -- defers to 'categorical'.
  logCategorical ::
    (Vector v (Log Double), Vector v Double) =>
    -- | event probabilities
    v (Log Double) ->
    -- | outcome category
    m Int
  logCategorical = categorical . VG.map (exp . ln)

  -- | Draw from a discrete uniform distribution.
  --
  -- Every position in @xs@ gets weight @1 / length xs@, so repeated elements
  -- are counted once per occurrence. @xs@ is expected to be non-empty.
  uniformD ::
    -- | observable outcomes @xs@
    [a] ->
    -- | \(\sim \mathcal{U}\{\mathrm{xs}\}\)
    m a
  uniformD xs = do
    let n = Prelude.length xs
    i <- categorical $ V.replicate n (1 / fromIntegral n)
    return (xs !! i)

  -- | Draw from a geometric distribution.
  geometric ::
    -- | success rate p
    Double ->
    -- | \(\sim\) number of failed Bernoulli trials with success probability p before first success
    m Int
  geometric = discrete . geometric0

  -- | Draw from a Poisson distribution.
  poisson ::
    -- | parameter λ
    Double ->
    -- | \(\sim \mathrm{Pois}(\lambda)\)
    m Int
  poisson = discrete . Poisson.poisson

  -- | Draw from a Dirichlet distribution.
  --
  -- The default draws an independent \(\Gamma(\alpha_i, 1)\) variate for each
  -- concentration parameter, then divides each one by their sum.
  dirichlet ::
    Vector v Double =>
    -- | concentration parameters @as@
    v Double ->
    -- | \(\sim \mathrm{Dir}(\mathrm{as})\)
    m (v Double)
  dirichlet alphas = do
    xs <- VG.mapM (`gamma` 1) alphas
    let total = VG.sum xs
    return (VG.map (/ total) xs)

-- | Draw from a continuous distribution by inverse-transform sampling: a
-- 'random' draw is fed through the distribution's 'quantile' function (the
-- inverse of its cumulative distribution function).
draw :: (ContDistr d, MonadSample m) => d -> m Double
draw d = fmap (quantile d) random

-- | Draw from a discrete distribution over @0, 1, 2, ...@ using a sequence of
-- Bernoulli draws.
--
-- The outcomes are visited in order, tracking the remaining mass @r@. @r@
-- starts at 1 and drops by @p i@ each time outcome @i@ is rejected. Outcome
-- @i@ is accepted with probability @p i / r@, which makes the overall
-- probability of returning @i@ equal to @p i@.
--
-- Calls 'error' if a visited probability lies outside @[0, 1]@, or if the
-- remaining mass becomes negative. The PMF is not normalised.
fromPMF :: MonadSample m => (Int -> Double) -> m Int
fromPMF p = go 0 1
  where
    -- i: current candidate outcome; r: probability mass not yet rejected.
    go i r = do
      when (r < 0) $ error "fromPMF: total PMF above 1"
      let q = p i
      when (q < 0 || q > 1) $ error "fromPMF: invalid probability value"
      accept <- bernoulli (q / r)
      if accept then pure i else go (i + 1) (r - q)

-- | Draw from a discrete distribution using its probability mass function.
-- Sampling is done by @fromPMF@, i.e. by a sequence of Bernoulli trials
-- starting at outcome 0.
discrete :: (DiscreteDistr d, MonadSample m) => d -> m Int
discrete = fromPMF . probability

-- | Monads that can score different execution paths.
class Monad m => MonadCond m where
  -- | Record a likelihood for the current execution path.
  score
    :: Log Double -- ^ likelihood of the execution path
    -> m ()

-- | Synonym for 'score'.
factor ::
  MonadCond m =>
  -- | likelihood of the execution path
  Log Double ->
  m ()
factor = score

-- | Hard conditioning: scores the current execution path with likelihood 1
-- when the condition holds and with likelihood 0 when it does not.
condition :: MonadCond m => Bool -> m ()
condition b = score $ if b then 1 else 0

-- | Monads that support both sampling and scoring.
--
-- This class declares no methods of its own. It bundles 'MonadSample' and
-- 'MonadCond' into a single constraint for writing models.
class
  ( MonadSample m,
    MonadCond m
  ) =>
  MonadInfer m

-- | Probability density function of the normal distribution.
--
-- The result is built directly from the log-density ('Exp' wraps a
-- log-domain value), so the density itself is never computed in linear
-- space.
normalPdf ::
  -- | mean μ
  Double ->
  -- | standard deviation σ
  Double ->
  -- | sample x
  Double ->
  -- | relative likelihood of observing sample x in \(\mathcal{N}(\mu, \sigma^2)\)
  Log Double
normalPdf mu sigma x = Exp $ logDensity (normalDistr mu sigma) x

----------------------------------------------------------------------------
-- Instances that lift probabilistic effects through the standard monad transformers.

instance MonadSample m => MonadSample (IdentityT m) where
  -- 'bernoulli' and 'categorical' are forwarded as well as 'random'.
  -- Otherwise the class defaults would re-derive them from 'random' and
  -- bypass any specialised implementation provided by @m@.
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (IdentityT m) where
  score = lift . score

instance MonadInfer m => MonadInfer (IdentityT m)

instance MonadSample m => MonadSample (MaybeT m) where
  -- 'bernoulli' and 'categorical' are forwarded as well as 'random'.
  -- Otherwise the class defaults would re-derive them from 'random' and
  -- bypass any specialised implementation provided by @m@.
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (MaybeT m) where
  score = lift . score

instance MonadInfer m => MonadInfer (MaybeT m)

instance MonadSample m => MonadSample (ReaderT r m) where
  -- 'bernoulli' and 'categorical' are forwarded as well as 'random'.
  -- Otherwise the class defaults would re-derive them from 'random' and
  -- bypass any specialised implementation provided by @m@.
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (ReaderT r m) where
  score = lift . score

instance MonadInfer m => MonadInfer (ReaderT r m)

instance (Monoid w, MonadSample m) => MonadSample (WriterT w m) where
  -- 'bernoulli' and 'categorical' are forwarded as well as 'random'.
  -- Otherwise the class defaults would re-derive them from 'random' and
  -- bypass any specialised implementation provided by @m@.
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance (Monoid w, MonadCond m) => MonadCond (WriterT w m) where
  score = lift . score

instance (Monoid w, MonadInfer m) => MonadInfer (WriterT w m)

instance MonadSample m => MonadSample (StateT s m) where
  -- 'bernoulli' and 'categorical' are forwarded as well as 'random'.
  -- Otherwise the class defaults would re-derive them from 'random' and
  -- bypass any specialised implementation provided by @m@.
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (StateT s m) where
  score = lift . score

instance MonadInfer m => MonadInfer (StateT s m)

instance (MonadSample m, Monoid w) => MonadSample (RWST r w s m) where
  -- 'bernoulli' and 'categorical' are forwarded as well as 'random'.
  -- Otherwise the class defaults would re-derive them from 'random' and
  -- bypass any specialised implementation provided by @m@.
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance (MonadCond m, Monoid w) => MonadCond (RWST r w s m) where
  score = lift . score

instance (MonadInfer m, Monoid w) => MonadInfer (RWST r w s m)

instance MonadSample m => MonadSample (ListT m) where
  -- 'bernoulli' and 'categorical' are forwarded as well as 'random'.
  -- Otherwise the class defaults would re-derive them from 'random' and
  -- bypass any specialised implementation provided by @m@.
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (ListT m) where
  score = lift . score

instance MonadInfer m => MonadInfer (ListT m)

instance MonadSample m => MonadSample (ContT r m) where
  -- 'bernoulli' and 'categorical' are forwarded as well as 'random'.
  -- Otherwise the class defaults would re-derive them from 'random' and
  -- bypass any specialised implementation provided by @m@.
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (ContT r m) where
  score = lift . score

instance MonadInfer m => MonadInfer (ContT r m)