-- |
-- Module      : Control.Monad.Bayes.Traced.Common
-- Description : Numeric code for Trace MCMC
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Traced.Common
  ( Trace (..),
    singleton,
    scored,
    bind,
    mhTrans,
    mhTransWithBool,
    mhTransFree,
    mhTrans',
    burnIn,
    MHResult (..),
  )
where

import Control.Monad.Bayes.Class
  ( MonadDistribution (bernoulli, random),
    discrete,
  )
import Control.Monad.Bayes.Density.Free qualified as Free
import Control.Monad.Bayes.Density.State qualified as State
import Control.Monad.Bayes.Weighted as WeightedT
  ( WeightedT,
    hoist,
    runWeightedT,
  )
import Control.Monad.Writer (WriterT (WriterT, runWriterT))
import Data.Functor.Identity (Identity (runIdentity))
import Numeric.Log (Log, ln)
import Statistics.Distribution.DiscreteUniform (discreteUniformAB)

-- | Outcome of a single Metropolis-Hastings step: the trace the chain is at
-- after the step, together with a flag recording whether the proposal was
-- accepted.
data MHResult a = MHResult
  { -- | 'True' when the proposed trace was accepted, 'False' when the chain
    -- stayed at its current trace.
    success :: Bool,
    -- | The trace after the step: the proposal when accepted, otherwise the
    -- trace the step started from.
    trace :: Trace a
  }

-- | Record of one execution of a probabilistic program: the random variables
-- sampled during the run, the value the run produced, and its density.
data Trace a = Trace
  { -- | Sequence of random variables sampled during the program's execution,
    -- in the order they were drawn.
    variables :: [Double],
    -- | The value the program returned for this sequence of random variables.
    output :: a,
    -- | The probability of observing this particular sequence.
    probDensity :: Log Double
  }

-- | Maps over the output only; the sampled variables and the density are
-- carried over untouched.
instance Functor Trace where
  fmap f (Trace vs x p) = Trace vs (f x) p

-- | 'pure' records no variables and has density 1. '<*>' concatenates the
-- variables of both traces (function trace first), applies the function
-- output to the argument output, and multiplies the densities.
instance Applicative Trace where
  pure x = Trace [] x 1

  -- Lazy patterns keep '<*>' non-strict in both arguments: the result
  -- constructor is available without forcing either input trace.
  ~(Trace fvs f fp) <*> ~(Trace xvs x xp) = Trace (fvs ++ xvs) (f x) (fp * xp)

-- | Feeds the output of the first trace to the continuation, then prepends
-- the first trace's variables and multiplies in its density.
instance Monad Trace where
  -- The left trace is matched lazily; only the continuation's result is
  -- forced to obtain the constructor.
  ~(Trace vs x p) >>= f = case f x of
    Trace vs' y q -> Trace (vs ++ vs') y (p * q)

-- | A trace that records exactly one random variable @u@, returns it as its
-- output, and has density 1.
singleton :: Double -> Trace Double
singleton u = Trace [u] u 1

-- | A trace that records no random variables, returns @()@, and carries the
-- given density.
scored :: Log Double -> Trace ()
scored = Trace [] ()

-- | Run @dx@, feed its output to @f@, and combine the two resulting traces.
-- The variables of the first trace come before those of the second, the
-- densities are multiplied, and the output is the second trace's output.
bind :: (Monad m) => m (Trace a) -> (a -> m (Trace b)) -> m (Trace b)
bind dx f =
  dx >>= \earlier ->
    f (output earlier) >>= \later ->
      return $
        later
          { variables = variables earlier ++ variables later,
            probDensity = probDensity earlier * probDensity later
          }

-- | A single Metropolis-corrected transition of single-site Trace MCMC.
--
-- One of the recorded random variables, chosen uniformly at random, is
-- replaced by a fresh draw from 'random'. The model is then re-run on the
-- perturbed sequence with 'State.runDensityT'. The proposal is accepted with
-- probability @min 1 (q * n \/ (p * k))@. Here @p@ and @n@ are the density and
-- number of variables of the current trace, and @q@ and @k@ those of the
-- proposal.
--
-- A trace holding no random variables has no site to perturb (the recorded
-- execution made no random choices), so it is returned unchanged rather than
-- crashing.
mhTrans :: (MonadDistribution m) => (WeightedT (State.DensityT m)) a -> Trace a -> m (Trace a)
mhTrans m t@Trace {variables = us, probDensity = p}
  | null us = return t
  | otherwise = do
      us' <- perturb
      ((b, q), vs) <- State.runDensityT (runWeightedT m) us'
      accept <- bernoulli (acceptance q (length vs))
      return $ if accept then Trace vs b q else t
  where
    n = length us
    -- Replace the variable at a uniformly chosen index in [0, n - 1] with a
    -- fresh draw. Only reached when n >= 1, so the index is always valid and
    -- take/drop need no "impossible" fallback.
    perturb = do
      i <- discrete $ discreteUniformAB 0 (n - 1)
      u' <- random
      return (take i us ++ (u' : drop (i + 1) us))
    -- Metropolis-Hastings acceptance probability, capped at 1 and converted
    -- out of the log domain to a plain Double for 'bernoulli'.
    acceptance q k = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral k))

-- | Like 'mhTransWithBool', but discards the acceptance flag and returns only
-- the trace the chain is at after the step.
mhTransFree :: (MonadDistribution m) => WeightedT (Free.DensityT m) a -> Trace a -> m (Trace a)
mhTransFree m = fmap trace . mhTransWithBool m

-- | A single Metropolis-corrected transition of single-site Trace MCMC for
-- models over 'Free.DensityT', also reporting whether the proposal was
-- accepted.
--
-- One recorded random variable, chosen uniformly at random, is replaced by a
-- fresh draw from 'random'. The model is then re-run on the perturbed sequence
-- with 'Free.runDensityT'. The proposal is accepted with probability
-- @min 1 (q * n \/ (p * k))@. Here @p@ and @n@ are the density and number of
-- variables of the current trace, and @q@ and @k@ those of the proposal.
--
-- A trace holding no random variables has no site to perturb, so the proposal
-- would coincide with the current trace. It is returned unchanged and reported
-- as accepted, rather than crashing.
mhTransWithBool :: (MonadDistribution m) => WeightedT (Free.DensityT m) a -> Trace a -> m (MHResult a)
mhTransWithBool m t@Trace {variables = us, probDensity = p}
  | null us = return (MHResult True t)
  | otherwise = do
      us' <- perturb
      ((b, q), vs) <- runWriterT $ runWeightedT $ WeightedT.hoist (WriterT . Free.runDensityT us') m
      accept <- bernoulli (acceptance q (length vs))
      return $
        if accept
          then MHResult True (Trace vs b q)
          else MHResult False t
  where
    n = length us
    -- Replace the variable at a uniformly chosen index in [0, n - 1] with a
    -- fresh draw. Only reached when n >= 1, so the index is always valid and
    -- take/drop need no "impossible" fallback.
    perturb = do
      i <- discrete $ discreteUniformAB 0 (n - 1)
      u' <- random
      return (take i us ++ (u' : drop (i + 1) us))
    -- Metropolis-Hastings acceptance probability, capped at 1 and converted
    -- out of the log domain to a plain Double for 'bernoulli'.
    acceptance q k = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral k))

-- | A variant of 'mhTransFree' for models whose 'Free.DensityT' runs over
-- 'Identity'. The model is lifted into the sampling monad @m@ (an external
-- sampling monad) and then handed to 'mhTransFree'.
mhTrans' :: (MonadDistribution m) => WeightedT (Free.DensityT Identity) a -> Trace a -> m (Trace a)
mhTrans' = mhTransFree . WeightedT.hoist (Free.hoist (return . runIdentity))

-- | Burn in an MCMC chain by discarding @n@ samples.
--
-- Burn-in samples are taken to be the last @n@ elements of the list, i.e. the
-- chain is assumed to be stored newest-first, and those elements are dropped.
-- If @n@ is at least the length of the chain the result is empty; a
-- non-positive @n@ leaves the chain unchanged.
--
-- The list is processed in a single lazy pass, so elements are produced
-- without first computing the length of the whole chain.
burnIn :: (Functor m) => Int -> m [a] -> m [a]
burnIn n = fmap dropEnd
  where
    -- Pair each element with the one n positions later: exactly the first
    -- (length - n) elements have such a partner, so they are the ones kept.
    dropEnd ls = zipWith const ls (drop n ls)