-- |
-- Module      : Control.Monad.Bayes.Weighted
-- Description : Probability monad accumulating the likelihood
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'Weighted' is an instance of 'MonadCond'. Apply a 'MonadSample' transformer to
-- obtain a 'MonadInfer' that can execute probabilistic models.
module Control.Monad.Bayes.Weighted
  ( Weighted,
    withWeight,
    runWeighted,
    extractWeight,
    prior,
    flatten,
    applyWeight,
    hoist,
  )
where

import Control.Monad.Bayes.Class
import Control.Monad.Trans (MonadIO, MonadTrans (..))
import Control.Monad.Trans.State (StateT (..), mapStateT, modify, modify')
import Numeric.Log (Log)

-- | Execute the program using the prior distribution, while accumulating likelihood.
--
-- The likelihood is carried as @'Log' 'Double'@ state on top of the
-- underlying monad @m@. Every instance in the deriving clause is inherited
-- from the wrapped @'StateT' ('Log' 'Double') m@ via newtype deriving; only
-- conditioning ('MonadCond') is given a hand-written instance.
newtype Weighted m a = Weighted (StateT (Log Double) m a)
  -- StateT is more efficient than WriterT
  deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadSample)

-- | Conditioning multiplies the accumulated likelihood by the given factor.
--
-- The update uses 'modify'' rather than 'modify': the state monad here is the
-- lazy 'StateT', so a lazy update would leave the weight as a growing chain of
-- unevaluated multiplications until the result is finally demanded.
instance Monad m => MonadCond (Weighted m) where
  score w = Weighted (modify' (* w))

-- | No method bodies are given here, so this instance is assembled entirely
-- from the 'MonadSample' and 'MonadCond' instances for @'Weighted' m@
-- (assumes 'MonadInfer' only bundles those two classes — confirm in
-- "Control.Monad.Bayes.Class").
instance (MonadSample m) => MonadInfer (Weighted m)

-- | Obtain an explicit value of the likelihood for a given value.
--
-- Accumulation starts from a weight of @1@, the multiplicative identity, so
-- the returned weight is the product of every factor applied while running.
runWeighted :: (Functor m) => Weighted m a -> m (a, Log Double)
runWeighted (Weighted m) = runStateT m 1

-- | Compute the sample and discard the weight.
--
-- This operation introduces bias.
prior :: Functor m => Weighted m a -> m a
prior = fmap fst . runWeighted

-- | Compute the weight and discard the sample.
extractWeight :: Functor m => Weighted m a -> m (Log Double)
extractWeight = fmap snd . runWeighted

-- | Embed a random variable with explicitly given likelihood.
--
-- The likelihood produced by the inner computation is multiplied into the
-- weight accumulated so far.
--
-- > runWeighted . withWeight = id
withWeight :: (Monad m) => m (a, Log Double) -> Weighted m a
withWeight m = Weighted $ do
  (x, w) <- lift m
  -- Strict update: evaluate the running weight now instead of leaving a
  -- thunk in the lazy StateT.
  modify' (* w)
  return x

-- | Combine weights from two different levels.
--
-- The weight accumulated by the outer layer and the one accumulated by the
-- inner 'Weighted' layer are multiplied into a single weight.
flatten :: Monad m => Weighted (Weighted m) a -> Weighted m a
flatten m = withWeight $ combine <$> runWeighted (runWeighted m)
  where
    -- p: weight from the outer layer; q: weight from the inner layer.
    combine ((x, p), q) = (x, p * q)

-- | Use the weight as a factor in the transformed monad.
--
-- Runs the computation, hands its accumulated weight to 'factor' in @m@, and
-- then returns the sampled value.
applyWeight :: MonadCond m => Weighted m a -> m a
applyWeight m = do
  (x, w) <- runWeighted m
  factor w
  return x

-- | Apply a transformation to the transformed monad.
--
-- The transformation is applied to the underlying computation via
-- 'mapStateT'. Because @t@ is polymorphic in its result type, it cannot
-- inspect or change the accumulated weight.
hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
hoist t (Weighted m) = Weighted $ mapStateT t m