{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}

-- |
-- Module      : Control.Monad.Bayes.Sampler.Strict
-- Description : Pseudo-random sampling monads
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'SamplerIO' and 'SamplerST' are instances of 'MonadDistribution'. Apply a 'MonadFactor'
-- transformer to obtain a 'MonadMeasure' that can execute probabilistic models.
module Control.Monad.Bayes.Sampler.Strict
  ( Sampler,
    SamplerIO,
    SamplerST,
    sampleIO,
    sampleIOfixed,
    sampleWith,
    sampleSTfixed,
    sampleMean,
    sampler,
  )
where

import Control.Foldl qualified as F hiding (random)
import Control.Monad.Bayes.Class
  ( MonadDistribution
      ( bernoulli,
        beta,
        categorical,
        gamma,
        geometric,
        normal,
        random,
        uniform
      ),
  )
import Control.Monad.Reader (MonadIO, ReaderT (..))
import Control.Monad.ST (ST)
import Numeric.Log (Log (ln))
import System.Random.MWC.Distributions qualified as MWC
import System.Random.Stateful (IOGenM (..), STGenM, StatefulGen, StdGen, initStdGen, mkStdGen, newIOGenM, newSTGenM, uniformDouble01M, uniformRM)

-- | The sampling interpretation of a probabilistic program.
--
-- A 'Sampler' is a 'ReaderT' whose environment is the random number
-- generator @g@ and whose base monad is @m@ (typically 'IO' or 'ST').
-- All instances are lifted from the underlying 'ReaderT' via newtype
-- deriving.
newtype Sampler g m a = Sampler (ReaderT g m a)
  deriving newtype (Functor, Applicative, Monad, MonadIO)

-- | 'Sampler' specialised to the 'IO' monad, using a 'StdGen' held in an
-- 'IOGenM' as its random number generator.
type SamplerIO = Sampler (IOGenM StdGen) IO

-- | 'Sampler' specialised to the 'ST' monad with state thread @s@, using a
-- 'StdGen' held in an 'STGenM' as its random number generator.
type SamplerST s = Sampler (STGenM StdGen s) (ST s)

-- | Each primitive distribution is sampled by passing the generator held in
-- the 'ReaderT' environment to the corresponding sampling routine.
instance StatefulGen g m => MonadDistribution (Sampler g m) where
  -- Drawn with 'uniformDouble01M'.
  random = Sampler (ReaderT uniformDouble01M)

  -- Continuous distributions.
  uniform a b = Sampler (ReaderT $ uniformRM (a, b))
  normal m s = Sampler (ReaderT (MWC.normal m s))
  gamma shape scale = Sampler (ReaderT $ MWC.gamma shape scale)
  beta a b = Sampler (ReaderT $ MWC.beta a b)

  -- Discrete distributions.
  bernoulli p = Sampler (ReaderT $ MWC.bernoulli p)
  categorical ps = Sampler (ReaderT $ MWC.categorical ps)
  -- Uses 'MWC.geometric0', i.e. the variant of the geometric distribution
  -- whose support starts at 0 (see the mwc-random documentation).
  geometric p = Sampler (ReaderT $ MWC.geometric0 p)

-- | Run a 'Sampler' with a random number generator of your choice, e.g. one
-- from "System.Random". The generator is supplied as the 'ReaderT'
-- environment of the wrapped computation.
--
-- >>> import Control.Monad.Bayes.Class
-- >>> import System.Random.Stateful hiding (random)
-- >>> newIOGenM (mkStdGen 1729) >>= sampleWith random
-- 4.690861245089605e-2
sampleWith :: Sampler g m a -> g -> m a
sampleWith (Sampler m) = runReaderT m

-- | Obtain a fresh 'StdGen' with 'initStdGen', wrap it in an 'IOGenM', and
-- run the sampler with it. 'sampler' is a synonym for 'sampleIO'.
sampleIO, sampler :: SamplerIO a -> IO a
sampleIO x = initStdGen >>= newIOGenM >>= sampleWith x
sampler = sampleIO

-- | Run the sampler in 'IO' with a fixed random seed (@mkStdGen 1729@), so
-- every call starts from the same generator state.
sampleIOfixed :: SamplerIO a -> IO a
sampleIOfixed x = newIOGenM (mkStdGen 1729) >>= sampleWith x

-- | Run the sampler in 'ST' with a fixed random seed (@mkStdGen 1729@), so
-- every call starts from the same generator state.
sampleSTfixed :: SamplerST s b -> ST s b
sampleSTfixed x = newSTGenM (mkStdGen 1729) >>= sampleWith x

-- | Weighted mean of a list of @(value, weight)@ pairs, where each weight is
-- a log-domain number.
--
-- The result is the sum of @value * weight@ divided by the sum of the
-- weights, both accumulated in a single pass over the list. For an empty
-- list both sums are zero, so the result is the 'Double' value of @0 / 0@.
sampleMean :: [(Double, Log Double)] -> Double
sampleMean samples =
  let -- Total weight. @ln . exp@ applies 'exp' in the log domain and then
      -- reads the stored logarithm, yielding the weight as a plain 'Double'.
      z :: F.Fold (a, Log Double) Double
      z = F.premap (ln . exp . snd) F.sum
      -- Sum of each value multiplied by its weight.
      w :: F.Fold (Double, Log Double) Double
      w = F.premap (\(x, y) -> x * ln (exp y)) F.sum
      -- Ratio of the two sums, combined applicatively into one fold.
      s :: F.Fold (Double, Log Double) Double
      s = (/) <$> w <*> z
   in F.fold s samples