{-|
Module      : Control.Monad.Bayes.Class
Description : Types for probabilistic modelling
Copyright   : (c) Adam Scibior, 2015-2020
License     : MIT
Maintainer  : leonhard.markert@tweag.io
Stability   : experimental
Portability : GHC

This module defines 'MonadInfer', which can be used to represent a simple model
like the following:

@
import Control.Monad (when)
import Control.Monad.Bayes.Class

model :: MonadInfer m => m Bool
model = do
  rain <- bernoulli 0.3
  sprinkler <-
    bernoulli $
    if rain
      then 0.1
      else 0.4
  let wetProb =
        case (rain, sprinkler) of
          (True,  True)  -> 0.98
          (True,  False) -> 0.80
          (False, True)  -> 0.90
          (False, False) -> 0.00
  score wetProb
  return rain
@
-}

module Control.Monad.Bayes.Class (
  MonadSample,
  random,
  uniform,
  normal,
  gamma,
  beta,
  bernoulli,
  categorical,
  logCategorical,
  uniformD,
  geometric,
  poisson,
  dirichlet,
  MonadCond,
  score,
  factor,
  condition,
  MonadInfer,
  normalPdf
) where

import Control.Monad.Trans.Class
import Control.Monad.Trans.Identity
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.State
import Control.Monad.Trans.Writer
import Control.Monad.Trans.Reader
import Control.Monad.Trans.RWS hiding (tell)
import Control.Monad.Trans.List
import Control.Monad.Trans.Cont

import Statistics.Distribution
import Statistics.Distribution.Uniform (uniformDistr)
import Statistics.Distribution.Normal (normalDistr)
import Statistics.Distribution.Gamma (gammaDistr)
import Statistics.Distribution.Beta (betaDistr)
import Statistics.Distribution.Geometric (geometric0)
import qualified Statistics.Distribution.Poisson as Poisson

import Numeric.Log

import Data.Vector.Generic as VG
import qualified Data.Vector as V
import Control.Monad (when)

-- | Monads that can draw random variables.
--
-- 'random' is the only method without a default. Every other method has a
-- default implementation expressed, directly or through other methods of
-- this class, in terms of 'random'. An instance may therefore override any
-- of them with a more specialised implementation.
class Monad m => MonadSample m where
  -- | Draw from the standard uniform distribution.
  random :: m Double -- ^ \(\sim \mathcal{U}(0, 1)\)

  -- | Draw from a uniform distribution (default: inverse transform sampling).
  uniform ::
       Double -- ^ lower bound a
    -> Double -- ^ upper bound b
    -> m Double -- ^ \(\sim \mathcal{U}(a, b)\)
  uniform a b = draw (uniformDistr a b)

  -- | Draw from a normal distribution (default: inverse transform sampling).
  normal ::
       Double -- ^ mean μ
    -> Double -- ^ standard deviation σ
    -> m Double -- ^ \(\sim \mathcal{N}(\mu, \sigma^2)\)
  normal m s = draw (normalDistr m s)

  -- | Draw from a gamma distribution (default: inverse transform sampling).
  gamma ::
       Double -- ^ shape k
    -> Double -- ^ scale θ
    -> m Double -- ^ \(\sim \Gamma(k, \theta)\)
  gamma shape scale = draw (gammaDistr shape scale)

  -- | Draw from a beta distribution (default: inverse transform sampling).
  beta ::
       Double -- ^ shape α
    -> Double -- ^ shape β
    -> m Double -- ^ \(\sim \mathrm{Beta}(\alpha, \beta)\)
  beta a b = draw (betaDistr a b)

  -- | Draw from a Bernoulli distribution.
  --
  -- The default compares a single 'random' draw against @p@.
  bernoulli ::
       Double -- ^ probability p
    -> m Bool -- ^ \(\sim \mathrm{B}(1, p)\)
  bernoulli p = fmap (< p) random

  -- | Draw from a categorical distribution.
  --
  -- The probabilities should sum to 1. The default walks the categories in
  -- order with 'fromPMF', which rejects a running total above 1. A total
  -- below 1 can walk past the last index. The default calls 'error' on an
  -- empty vector.
  categorical ::
       Vector v Double
    => v Double -- ^ event probabilities
    -> m Int -- ^ outcome category
  categorical ps
    -- Fail with a clear message here rather than with an out-of-bounds
    -- indexing error raised from inside 'fromPMF'.
    | VG.null ps = error "categorical: empty vector of probabilities"
    | otherwise = fromPMF (ps !)

  -- | Draw from a categorical distribution whose probabilities are given in
  -- the log domain.
  --
  -- The default converts them back to plain 'Double's and delegates to
  -- 'categorical'.
  logCategorical ::
       (Vector v (Log Double), Vector v Double)
    => v (Log Double) -- ^ event probabilities
    -> m Int -- ^ outcome category
  logCategorical = categorical . VG.map (exp . ln)

  -- | Draw uniformly from a finite list of outcomes.
  --
  -- The default builds a uniform 'categorical' over the list positions and
  -- calls 'error' on an empty list.
  uniformD ::
       [a] -- ^ observable outcomes @xs@
    -> m a -- ^ \(\sim \mathcal{U}\{\mathrm{xs}\}\)
  uniformD [] = error "uniformD: empty list"
  uniformD xs = do
    let n = Prelude.length xs
    i <- categorical $ V.replicate n (1 / fromIntegral n)
    -- The index comes from a categorical over n outcomes, so it is
    -- expected to lie in [0, n).
    return (xs !! i)

  -- | Draw from a geometric distribution.
  geometric ::
       Double -- ^ success rate p
    -> m Int -- ^ \(\sim\) number of failed Bernoulli trials with success probability p before first success
  geometric = discrete . geometric0

  -- | Draw from a Poisson distribution.
  poisson ::
       Double -- ^ parameter λ
    -> m Int -- ^ \(\sim \mathrm{Pois}(\lambda)\)
  poisson = discrete . Poisson.poisson

  -- | Draw from a Dirichlet distribution.
  --
  -- The default draws an independent @gamma a 1@ for every concentration
  -- parameter @a@ and normalises the results by their sum.
  dirichlet ::
       Vector v Double
    => v Double -- ^ concentration parameters @as@
    -> m (v Double) -- ^ \(\sim \mathrm{Dir}(\mathrm{as})\)
  dirichlet as = do
    xs <- VG.mapM (`gamma` 1) as
    let s = VG.sum xs
    let ys = VG.map (/ s) xs
    return ys

-- | Sample a continuous distribution by inverse transform sampling: a draw
-- from 'random' is passed through the distribution's quantile function,
-- i.e. the inverse of its cumulative distribution function.
draw :: (ContDistr d, MonadSample m) => d -> m Double
draw d = quantile d <$> random

-- | Sample an index from a distribution over @0, 1, 2, ...@ given by its
-- probability mass function.
--
-- Indices are visited in increasing order. At index @k@, a 'bernoulli' coin
-- with probability @pmf k / rest@ decides whether to stop there. Here @rest@
-- is the mass not claimed by the indices below @k@, which makes index @k@
-- come out with probability @pmf k@.
--
-- Calls 'error' if the masses of the indices already passed add up to more
-- than 1, or if a single mass lies outside @[0, 1]@.
fromPMF :: MonadSample m => (Int -> Double) -> m Int
fromPMF pmf = go 0 1
  where
    -- k: index under consideration; rest: mass left for indices >= k.
    go k rest = do
      when (rest < 0) $ error "fromPMF: total PMF above 1"
      let mass = pmf k
      when (mass < 0 || mass > 1) $
        error "fromPMF: invalid probability value"
      accepted <- bernoulli (mass / rest)
      if accepted
        then pure k
        else go (k + 1) (rest - mass)

-- | Sample a discrete distribution by handing its probability mass
-- function ('probability') to 'fromPMF'.
discrete :: (DiscreteDistr d, MonadSample m) => d -> m Int
discrete d = fromPMF (probability d)

-- | Monads that can attach a weight to the current execution path.
class Monad m => MonadCond m where
  -- | Record the given likelihood for the current execution path.
  score :: Log Double -> m ()

-- | Alias for 'score': record the likelihood of the current execution path.
factor :: MonadCond m => Log Double -> m ()
factor likelihood = score likelihood

-- | Hard conditioning: scores the current path with likelihood 1 when the
-- argument is 'True' and with likelihood 0 when it is 'False'.
condition :: MonadCond m => Bool -> m ()
condition holds = score weight
  where
    weight = if holds then 1 else 0

-- | Monads that support both sampling and scoring.
--
-- This class declares no methods of its own. It bundles the 'MonadSample'
-- and 'MonadCond' superclasses into a single constraint.
class (MonadSample m, MonadCond m) => MonadInfer m

-- | Density of a normal distribution at a point. The logarithm of the
-- density is wrapped with 'Exp', so the result is held in the log domain.
normalPdf ::
     Double -- ^ mean μ
  -> Double -- ^ standard deviation σ
  -> Double -- ^ sample x
  -> Log Double -- ^ density of \(\mathcal{N}(\mu, \sigma^2)\) at x
normalPdf mu sigma x = Exp (logDensity gaussian x)
  where
    gaussian = normalDistr mu sigma

----------------------------------------------------------------------------
-- Instances that lift probabilistic effects through the standard transformers.

-- | 'random', 'bernoulli' and 'categorical' are delegated to the inner monad.
-- Any specialised versions it defines are used instead of the class
-- defaults. The remaining methods keep their defaults, which are built from
-- these three.
instance MonadSample m => MonadSample (IdentityT m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (IdentityT m) where
  score = lift . score

instance MonadInfer m => MonadInfer (IdentityT m)


-- | 'random', 'bernoulli' and 'categorical' are delegated to the inner monad.
-- Lifting only 'random' would make 'bernoulli' and 'categorical' fall back
-- to defaults built on 'random', silently ignoring the inner monad's own
-- versions. The remaining methods keep their defaults, which are built from
-- these three.
instance MonadSample m => MonadSample (MaybeT m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (MaybeT m) where
  score = lift . score

instance MonadInfer m => MonadInfer (MaybeT m)


-- | 'random', 'bernoulli' and 'categorical' are delegated to the inner monad.
-- Any specialised versions it defines are used instead of the class
-- defaults. The remaining methods keep their defaults, which are built from
-- these three.
instance MonadSample m => MonadSample (ReaderT r m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (ReaderT r m) where
  score = lift . score

instance MonadInfer m => MonadInfer (ReaderT r m)


-- | The sampling primitives run in the inner monad via 'lift'.
instance (Monoid w, MonadSample m) => MonadSample (WriterT w m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scoring runs in the inner monad via 'lift'.
instance (Monoid w, MonadCond m) => MonadCond (WriterT w m) where
  score w = lift (score w)

instance (Monoid w, MonadInfer m) => MonadInfer (WriterT w m)


-- | The sampling primitives run in the inner monad via 'lift'.
instance MonadSample m => MonadSample (StateT s m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scoring runs in the inner monad via 'lift'.
instance MonadCond m => MonadCond (StateT s m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (StateT s m)


-- | 'random', 'bernoulli' and 'categorical' are delegated to the inner monad.
-- Lifting only 'random' would make 'bernoulli' and 'categorical' fall back
-- to defaults built on 'random', silently ignoring the inner monad's own
-- versions. The remaining methods keep their defaults, which are built from
-- these three.
instance (MonadSample m, Monoid w) => MonadSample (RWST r w s m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance (MonadCond m, Monoid w) => MonadCond (RWST r w s m) where
  score = lift . score

instance (MonadInfer m, Monoid w) => MonadInfer (RWST r w s m)


-- | The sampling primitives run in the inner monad via 'lift'.
instance MonadSample m => MonadSample (ListT m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scoring runs in the inner monad via 'lift'.
instance MonadCond m => MonadCond (ListT m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (ListT m)


-- | 'random', 'bernoulli' and 'categorical' are delegated to the inner monad.
-- Lifting only 'random' would make 'bernoulli' and 'categorical' fall back
-- to defaults built on 'random', silently ignoring the inner monad's own
-- versions. The remaining methods keep their defaults, which are built from
-- these three.
instance MonadSample m => MonadSample (ContT r m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

instance MonadCond m => MonadCond (ContT r m) where
  score = lift . score

instance MonadInfer m => MonadInfer (ContT r m)