{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
module Control.Monad.Bayes.Sampler.Strict
( SamplerT (..),
SamplerIO,
SamplerST,
sampleIO,
sampleIOfixed,
sampleWith,
sampleSTfixed,
sampleMean,
sampler,
)
where
import Control.Foldl qualified as F hiding (random)
import Control.Monad.Bayes.Class
( MonadDistribution
( bernoulli,
beta,
categorical,
gamma,
geometric,
normal,
random,
uniform
),
)
import Control.Monad.Reader (MonadIO, ReaderT (..))
import Control.Monad.ST (ST)
import Control.Monad.Trans (MonadTrans)
import Numeric.Log (Log (ln))
import System.Random.MWC.Distributions qualified as MWC
import System.Random.Stateful (IOGenM (..), STGenM, StatefulGen, StdGen, initStdGen, mkStdGen, newIOGenM, newSTGenM, uniformDouble01M, uniformRM)
-- | Sampling interpretation of a probabilistic program: a computation in the
-- base monad @m@ that reads a stateful random generator @g@ from a 'ReaderT'
-- environment. In this module @m@ is instantiated to 'IO' or @'ST' s@
-- (see 'SamplerIO' and 'SamplerST').
--
-- Every class instance below is the underlying 'ReaderT' instance reused via
-- newtype deriving, so 'SamplerT' adds no runtime cost over 'ReaderT'.
newtype SamplerT g m a = SamplerT {runSamplerT :: ReaderT g m a}
  deriving newtype (Functor, Applicative, Monad, MonadIO, MonadTrans)
-- | A sampler running in 'IO' whose generator is a 'StdGen' held in an
-- 'IOGenM'.
type SamplerIO = SamplerT (IOGenM StdGen) IO
-- | A sampler running in @'ST' s@ whose generator is a 'StdGen' held in an
-- 'STGenM'.
type SamplerST s = SamplerT (STGenM StdGen s) (ST s)
-- | Interpret every 'MonadDistribution' primitive by running the matching
-- sampler from @random@ / @mwc-random@ against the generator that the
-- 'ReaderT' environment supplies. Each method is the same shape:
-- build a @g -> m a@ action and wrap it as @SamplerT . ReaderT@.
instance (StatefulGen g m) => MonadDistribution (SamplerT g m) where
  -- Standard uniform draw from 'uniformDouble01M'.
  random = SamplerT . ReaderT $ uniformDouble01M

  -- Uniform draw between the two given bounds, via 'uniformRM'.
  uniform lo hi = SamplerT . ReaderT $ uniformRM (lo, hi)

  -- Arguments are forwarded unchanged to 'MWC.normal'.
  normal mu sigma = SamplerT . ReaderT $ MWC.normal mu sigma

  -- Arguments are forwarded unchanged to 'MWC.gamma' (shape, scale).
  gamma k theta = SamplerT . ReaderT $ MWC.gamma k theta

  -- Arguments are forwarded unchanged to 'MWC.beta'.
  beta shapeA shapeB = SamplerT . ReaderT $ MWC.beta shapeA shapeB

  -- Success probability forwarded to 'MWC.bernoulli'.
  bernoulli prob = SamplerT . ReaderT $ MWC.bernoulli prob

  -- Draws an index from the weight vector via 'MWC.categorical'.
  categorical weights = SamplerT . ReaderT $ MWC.categorical weights

  -- Delegates to 'MWC.geometric0' (the variant whose support starts at 0).
  geometric prob = SamplerT . ReaderT $ MWC.geometric0 prob
-- | Run a sampler against a caller-supplied generator: strip the 'SamplerT'
-- wrapper, then run the resulting 'ReaderT' with the generator as its
-- environment.
sampleWith :: SamplerT g m a -> g -> m a
sampleWith = runReaderT . runSamplerT
-- | Run a sampler in 'IO' with a fresh 'StdGen' obtained from 'initStdGen'
-- and wrapped in an 'IOGenM'.
--
-- 'sampler' is an alias for 'sampleIO'.
sampleIO, sampler :: SamplerIO a -> IO a
sampleIO program = do
  seedGen <- initStdGen
  genRef <- newIOGenM seedGen
  sampleWith program genRef
sampler = sampleIO
-- | Like 'sampleIO', but the generator always starts from @mkStdGen 1729@.
-- Repeated runs of the same program therefore see the same random stream.
sampleIOfixed :: SamplerIO a -> IO a
sampleIOfixed program = do
  genRef <- newIOGenM (mkStdGen 1729)
  sampleWith program genRef
-- | Run a sampler inside @'ST' s@ with a generator seeded by
-- @mkStdGen 1729@. The result is a pure, reproducible computation once the
-- 'ST' block is run.
sampleSTfixed :: SamplerST s b -> ST s b
sampleSTfixed program = sampleWith program =<< newSTGenM (mkStdGen 1729)
-- | Self-normalised weighted mean of a population of @(value, weight)@ pairs,
-- where each weight is stored in the log domain as a 'Log' 'Double'.
--
-- The weights are normalised /before/ leaving the log domain:
--
-- 1. The total weight is summed as a 'Log' 'Double'.
-- 2. Each weight is divided by that total and only then converted to a plain
--    'Double'.
--
-- Converting raw weights directly (e^(log-weight)) underflows to 0 once
-- log-weights fall below roughly -745, and overflows to infinity above
-- roughly 709. Either case made the result NaN even though the weighted mean
-- was well defined.
--
-- An empty population, or one whose weights are all zero, still yields NaN
-- (0/0), as before.
--
-- The list is traversed twice (once for the total, once for the sums), so it
-- is retained in memory between the passes.
sampleMean :: [(Double, Log Double)] -> Double
sampleMean samples = F.fold ((/) <$> weightedSum <*> weightSum) samples
  where
    -- Total weight, accumulated with 'Log''s log-domain addition.
    total :: Log Double
    total = F.fold (F.premap snd F.sum) samples

    -- A weight relative to the total, as an ordinary Double.
    relWeight :: Log Double -> Double
    relWeight w = exp (ln w - ln total)

    weightedSum, weightSum :: F.Fold (Double, Log Double) Double
    weightedSum = F.premap (\(x, w) -> x * relWeight w) F.sum
    -- This sums to ~1. Dividing by it keeps the estimate exactly
    -- self-normalised despite rounding, and keeps the 0/0 = NaN result for
    -- empty or all-zero-weight input.
    weightSum = F.premap (relWeight . snd) F.sum