{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}

-- |
-- Module      : Control.Monad.Bayes.Sampler.Strict
-- Description : Pseudo-random sampling monads
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'SamplerIO' and 'SamplerST' are instances of 'MonadDistribution'. Apply a 'MonadFactor'
-- transformer to obtain a 'MonadMeasure' that can execute probabilistic models.
module Control.Monad.Bayes.Sampler.Strict
  ( Sampler,
    SamplerIO,
    SamplerST,
    sampleIO,
    sampleIOfixed,
    sampleWith,
    sampleSTfixed,
    sampleMean,
    sampler,
  )
where

import Control.Foldl qualified as F hiding (random)
import Control.Monad.Bayes.Class
  ( MonadDistribution
      ( bernoulli,
        beta,
        categorical,
        gamma,
        geometric,
        normal,
        random,
        uniform
      ),
  )
import Control.Monad.Reader (MonadIO, ReaderT (..))
import Control.Monad.ST (ST)
import Numeric.Log (Log (ln))
import System.Random.MWC.Distributions qualified as MWC
import System.Random.Stateful (IOGenM (..), STGenM, StatefulGen, StdGen, initStdGen, mkStdGen, newIOGenM, newSTGenM, uniformDouble01M, uniformRM)

-- | The sampling interpretation of a probabilistic program.
--
-- A @'Sampler' g m a@ is a computation in the base monad @m@ that reads a
-- random number generator of type @g@ from its 'ReaderT' environment and
-- produces a value of type @a@. Here @m@ is typically 'IO' or 'ST'.
--
-- All structural instances are inherited from the wrapped 'ReaderT'.
newtype Sampler g m a = Sampler (ReaderT g m a)
  deriving newtype (Functor, Applicative, Monad, MonadIO)

-- | 'Sampler' specialised to the 'IO' monad, using a 'StdGen' wrapped in
-- 'IOGenM' as the random number generator. This is the type run by
-- 'sampleIO' and 'sampleIOfixed'.
type SamplerIO = Sampler (IOGenM StdGen) IO

-- | 'Sampler' specialised to the @'ST' s@ monad (state thread @s@), using a
-- 'StdGen' wrapped in 'STGenM' as the random number generator. This is the
-- type run by 'sampleSTfixed'.
type SamplerST s = Sampler (STGenM StdGen s) (ST s)

-- | Each primitive distribution reads the generator from the 'ReaderT'
-- environment and passes it to the matching sampler:
--
-- * 'uniformDouble01M' and 'uniformRM' from "System.Random.Stateful" for
--   'random' and 'uniform';
-- * "System.Random.MWC.Distributions" for all the other primitives.
instance StatefulGen g m => MonadDistribution (Sampler g m) where
  -- Uniform draw on the unit interval.
  random = Sampler (ReaderT uniformDouble01M)

  -- Uniform draw on the range between @a@ and @b@.
  uniform a b = Sampler (ReaderT $ uniformRM (a, b))
  normal m s = Sampler (ReaderT $ MWC.normal m s)
  gamma shape scale = Sampler (ReaderT $ MWC.gamma shape scale)
  beta a b = Sampler (ReaderT $ MWC.beta a b)

  bernoulli p = Sampler (ReaderT $ MWC.bernoulli p)
  categorical ps = Sampler (ReaderT $ MWC.categorical ps)
  -- 'MWC.geometric0' counts the failures before the first success, so the
  -- support of this geometric distribution starts at 0.
  geometric p = Sampler (ReaderT $ MWC.geometric0 p)

-- | Run a 'Sampler' with a random number generator of your choice, e.g. one
-- from "System.Random.Stateful".
--
-- The generator is supplied as the sampler's 'ReaderT' environment.
--
-- >>> import Control.Monad.Bayes.Class
-- >>> import System.Random.Stateful hiding (random)
-- >>> newIOGenM (mkStdGen 1729) >>= sampleWith random
-- 4.690861245089605e-2
sampleWith :: Sampler g m a -> g -> m a
sampleWith (Sampler m) = runReaderT m

-- | Run a 'SamplerIO' with a fresh 'StdGen' seeded from system entropy
-- (via 'initStdGen'). Results are therefore not reproducible across runs;
-- use 'sampleIOfixed' for a fixed seed.
sampleIO, sampler :: SamplerIO a -> IO a
sampleIO x = initStdGen >>= newIOGenM >>= sampleWith x

-- 'sampler' is an exported alias of 'sampleIO'.
sampler = sampleIO

-- | Run a 'SamplerIO' with a 'StdGen' built from the fixed seed 1729, which
-- makes results reproducible from run to run.
sampleIOfixed :: SamplerIO a -> IO a
sampleIOfixed x = newIOGenM (mkStdGen 1729) >>= sampleWith x

-- | Run a 'SamplerST' inside 'ST' with a 'StdGen' built from the fixed seed
-- 1729. Because no entropy source is consulted, this can be used from pure
-- code via 'Control.Monad.ST.runST'.
sampleSTfixed :: SamplerST s b -> ST s b
sampleSTfixed x = newSTGenM (mkStdGen 1729) >>= sampleWith x

-- | Self-normalised weighted mean of a list of weighted samples.
--
-- Each element is a pair @(x, w)@: a sample value and its unnormalised weight
-- in the log domain. The result is @sum (x_i * w_i) / sum w_i@.
--
-- Every weight is divided by the largest weight (a shift by the largest
-- log-weight) before leaving the log domain. The factor cancels in the ratio,
-- but it keeps large or very negative log-weights from overflowing to
-- Infinity or underflowing to 0, either of which would make the result NaN.
--
-- The result is NaN when the list is empty, when every weight is zero, or
-- when some weight is infinite, because the ratio is undefined in those cases.
sampleMean :: [(Double, Log Double)] -> Double
sampleMean samples =
  case F.fold (F.premap (ln . snd) F.maximum) samples of
    -- Empty input: there is no normalising constant.
    Nothing -> nan
    Just maxLogW
      -- A max log-weight of -Infinity means all weights are zero.
      -- A max log-weight of +Infinity means some weight is infinite.
      | isInfinite maxLogW -> nan
      | otherwise ->
          let -- Weight relative to the largest one, in [0, 1].
              weight (_, lw) = exp (ln lw - maxLogW)
              total = F.premap weight F.sum
              weightedSum = F.premap (\s -> fst s * weight s) F.sum
           in F.fold ((/) <$> weightedSum <*> total) samples
  where
    nan :: Double
    nan = 0 / 0