-- |
-- Module      : Control.Monad.Bayes.Sampler
-- Description : Pseudo-random sampling monads
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'SamplerIO' and 'SamplerST' are instances of 'MonadSample'. Apply a 'MonadCond'
-- transformer to obtain a 'MonadInfer' that can execute probabilistic models.
module Control.Monad.Bayes.Sampler
  ( SamplerIO,
    sampleIO,
    sampleIOfixed,
    sampleIOwith,
    Seed,
    SamplerST (SamplerST),
    runSamplerST,
    sampleST,
    sampleSTfixed,
  )
where

import Control.Monad.Bayes.Class
import Control.Monad.ST (ST, runST, stToIO)
import Control.Monad.State (State, state)
import Control.Monad.Trans (MonadIO, lift)
import Control.Monad.Trans.Reader (ReaderT, ask, mapReaderT, runReaderT)
import System.Random.MWC
import qualified System.Random.MWC.Distributions as MWC

-- | An 'IO' based random sampler using the @mwc-random@ package.
--
-- Internally a 'ReaderT' that threads a 'GenIO' generator through 'IO'.
-- The 'Functor', 'Applicative', 'Monad' and 'MonadIO' instances are derived
-- from the wrapped 'ReaderT'.
newtype SamplerIO a = SamplerIO (ReaderT GenIO IO a)
  deriving (Functor, Applicative, Monad, MonadIO)
-- NOTE(review): deriving these classes through the newtype needs
-- GeneralizedNewtypeDeriving; the pragma is not visible in this part of the
-- file, so confirm it is enabled.

-- | Run a 'SamplerIO' computation with a pseudo-random number generator
-- initialized from randomness supplied by the operating system.
--
-- A new generator is created on every call, so for efficiency this operation
-- should be applied at the very end, ideally once per program.
sampleIO :: SamplerIO a -> IO a
sampleIO (SamplerIO m) = createSystemRandom >>= runReaderT m

-- | Like 'sampleIO', but with a fixed random seed: the generator is obtained
-- from 'create' instead of 'createSystemRandom'.
-- Useful for reproducibility.
sampleIOfixed :: SamplerIO a -> IO a
sampleIOfixed (SamplerIO m) = create >>= runReaderT m

-- | Like 'sampleIO' but with a custom pseudo-random number generator
-- supplied by the caller as the second argument.
sampleIOwith :: SamplerIO a -> GenIO -> IO a
sampleIOwith (SamplerIO m) = runReaderT m

-- | Embed a 'SamplerST' computation into 'SamplerIO'.
--
-- The reader environment (the generator) is kept as is; only the underlying
-- 'ST' action is converted to 'IO' with 'stToIO'.
fromSamplerST :: SamplerST a -> SamplerIO a
fromSamplerST (SamplerST m) = SamplerIO $ mapReaderT stToIO m

-- | Draws are delegated to the 'SamplerST' instance via 'fromSamplerST'.
-- Only 'random' is defined here; the remaining class methods are not
-- overridden (presumably they fall back to the class defaults).
instance MonadSample SamplerIO where
  random = fromSamplerST random

-- | An 'ST' based random sampler using the @mwc-random@ package.
--
-- Wraps a 'ReaderT' over 'ST' whose environment is a 'GenST' generator.
-- The wrapped computation is universally quantified over the state thread
-- @s@.
newtype SamplerST a = SamplerST (forall s. ReaderT (GenST s) (ST s) a)

-- | Unwrap a 'SamplerST', instantiating its quantified state thread @s@ at
-- whatever type the caller requires.
runSamplerST :: SamplerST a -> ReaderT (GenST s) (ST s) a
runSamplerST (SamplerST s) = s

-- | Maps over the result of the wrapped 'ReaderT' computation.
instance Functor SamplerST where
  fmap f (SamplerST s) = SamplerST $ fmap f s

-- | Both operations are those of the wrapped 'ReaderT', re-wrapped in
-- 'SamplerST'.
instance Applicative SamplerST where
  pure x = SamplerST $ pure x
  (SamplerST f) <*> (SamplerST x) = SamplerST $ f <*> x

-- | Binds in the wrapped 'ReaderT'; the continuation's result is unwrapped
-- with 'runSamplerST' so that it runs in the same state thread.
instance Monad SamplerST where
  (SamplerST x) >>= f = SamplerST $ x >>= runSamplerST . f

-- | Run the sampler with a supplied seed.
--
-- The generator is restored from the incoming 'Seed', the sampler is run
-- against it, and the generator's final state is saved as the new 'Seed'.
-- Note that 'State Seed' is much less efficient than 'SamplerST' for composing computation.
sampleST :: SamplerST a -> State Seed a
sampleST (SamplerST s) =
  state $ \seed -> runST $ do
    gen <- restore seed
    y <- runReaderT s gen
    finalSeed <- save gen
    return (y, finalSeed)

-- | Run the sampler with a fixed random seed.
--
-- The generator comes from 'create' and is discarded once the result has
-- been computed.
sampleSTfixed :: SamplerST a -> a
sampleSTfixed (SamplerST s) = runST $ do
  gen <- create
  runReaderT s gen

-- | Convert a distribution supplied by @mwc-random@.
--
-- The argument is a sampling action that takes the generator explicitly; it
-- is applied to the generator read from the environment and lifted into the
-- 'ReaderT'.
fromMWC :: (forall s. GenST s -> ST s a) -> SamplerST a
fromMWC s = SamplerST $ ask >>= lift . s

-- | Samples are drawn with the generator primitives and distributions of
-- @mwc-random@, each wrapped through 'fromMWC'.
instance MonadSample SamplerST where
  -- Fully qualified: the unqualified name 'uniform' is also the class method
  -- defined just below.
  random = fromMWC System.Random.MWC.uniform

  uniform a b = fromMWC $ uniformR (a, b)
  normal m s = fromMWC $ MWC.normal m s
  gamma shape scale = fromMWC $ MWC.gamma shape scale
  beta a b = fromMWC $ MWC.beta a b

  bernoulli p = fromMWC $ MWC.bernoulli p
  categorical ps = fromMWC $ MWC.categorical ps
  -- NOTE(review): uses 'MWC.geometric0' (presumably the zero-based variant);
  -- confirm this matches the support expected by the class method.
  geometric p = fromMWC $ MWC.geometric0 p