module Control.Monad.Bayes.Class
( MonadSample,
random,
uniform,
normal,
gamma,
beta,
bernoulli,
categorical,
logCategorical,
uniformD,
geometric,
poisson,
dirichlet,
MonadCond,
score,
factor,
condition,
MonadInfer,
discrete,
normalPdf,
)
where
import Control.Monad (when)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Identity
import Control.Monad.Trans.List
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.RWS hiding (tell)
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State
import Control.Monad.Trans.Writer
import qualified Data.Vector as V
import Data.Vector.Generic as VG
import Numeric.Log
import Statistics.Distribution
import Statistics.Distribution.Beta (betaDistr)
import Statistics.Distribution.Gamma (gammaDistr)
import Statistics.Distribution.Geometric (geometric0)
import Statistics.Distribution.Normal (normalDistr)
import qualified Statistics.Distribution.Poisson as Poisson
import Statistics.Distribution.Uniform (uniformDistr)
-- | Monads that can draw random values.
--
-- 'random' is the only method without a default. Every other method has a
-- default implementation built on top of it: continuous distributions go
-- through 'draw' (a 'random' draw pushed through the distribution's
-- 'quantile' function) and discrete ones through 'fromPMF'. Instances may
-- override any method with a specialised implementation.
class Monad m => MonadSample m where
  -- | Draw from the standard uniform distribution on the unit interval.
  random :: m Double

  -- | Draw from the continuous uniform distribution between two bounds.
  uniform ::
    -- | lower bound @a@
    Double ->
    -- | upper bound @b@
    Double ->
    m Double
  uniform a b = draw (uniformDistr a b)

  -- | Draw from a normal distribution.
  normal ::
    -- | mean
    Double ->
    -- | standard deviation
    Double ->
    m Double
  normal mu sigma = draw (normalDistr mu sigma)

  -- | Draw from a gamma distribution.
  gamma ::
    -- | shape
    Double ->
    -- | scale
    Double ->
    m Double
  gamma shape scale = draw (gammaDistr shape scale)

  -- | Draw from a beta distribution.
  beta ::
    -- | first shape parameter
    Double ->
    -- | second shape parameter
    Double ->
    m Double
  beta a b = draw (betaDistr a b)

  -- | Flip a biased coin. The default returns 'True' exactly when a
  -- 'random' draw is below @p@.
  bernoulli ::
    -- | probability of 'True'
    Double ->
    m Bool
  bernoulli p = fmap (< p) random

  -- | Draw an index @i@ with probability @ps ! i@.
  --
  -- Fails with an explicit error on an empty vector. The individual
  -- probabilities are validated by 'fromPMF'.
  --
  -- NOTE(review): 'fromPMF' tracks the remaining mass by repeated
  -- subtraction, so floating-point round-off can leave the acceptance
  -- probability of the last index slightly below 1. In that (very rare)
  -- case the walk steps past the end of @ps@ and '(!)' fails. Confirm
  -- whether this needs handling.
  categorical ::
    Vector v Double =>
    -- | probabilities of each index, expected to sum to 1
    v Double ->
    m Int
  categorical ps
    | VG.null ps = error "categorical: empty probability vector"
    | otherwise = fromPMF (ps !)

  -- | Like 'categorical', but the probabilities are given in log-domain.
  -- They are exponentiated before being handed to 'categorical'.
  logCategorical ::
    (Vector v (Log Double), Vector v Double) =>
    -- | log-domain probabilities of each index
    v (Log Double) ->
    m Int
  logCategorical = categorical . VG.map (exp . ln)

  -- | Draw uniformly from a finite list of values.
  -- Fails with an explicit error on an empty list.
  uniformD ::
    -- | candidate values (must be non-empty)
    [a] ->
    m a
  uniformD [] = error "uniformD: empty list"
  uniformD xs = do
    let n = Prelude.length xs
    i <- categorical (V.replicate n (1 / fromIntegral n))
    return (xs !! i)

  -- | Draw from the geometric distribution provided by 'geometric0',
  -- sampled via 'discrete'.
  geometric ::
    -- | parameter passed to 'geometric0' (success probability)
    Double ->
    m Int
  geometric = discrete . geometric0

  -- | Draw from the Poisson distribution provided by @Poisson.poisson@,
  -- sampled via 'discrete'.
  poisson ::
    -- | rate parameter passed to @Poisson.poisson@
    Double ->
    m Int
  poisson = discrete . Poisson.poisson

  -- | Draw from a Dirichlet distribution. Each component @a_i@ gets an
  -- independent @gamma a_i 1@ draw, and the draws are then normalised by
  -- their sum.
  dirichlet ::
    Vector v Double =>
    -- | concentration parameters
    v Double ->
    m (v Double)
  dirichlet as = do
    xs <- VG.mapM (`gamma` 1) as
    let s = VG.sum xs
    return (VG.map (/ s) xs)
-- | Sample a continuous distribution by inverse-transform sampling: one
-- 'random' draw is mapped through the distribution's 'quantile' function.
draw :: (ContDistr d, MonadSample m) => d -> m Double
draw dist = quantile dist <$> random
-- | Sample an index from a probability mass function over @0, 1, 2, ...@
-- by sequential search.
--
-- At index @k@ the search stops with probability @pmf k / mass@, where
-- @mass@ is the probability not yet assigned to indices @0 .. k-1@.
-- Otherwise it moves on to @k + 1@ with that mass removed.
--
-- Raises an error if the remaining mass goes negative (total PMF above 1),
-- or if some @pmf k@ lies outside @[0, 1]@. If the masses sum to less than
-- 1, the search never terminates once the support is exhausted: zero-mass
-- indices are never accepted.
fromPMF :: MonadSample m => (Int -> Double) -> m Int
fromPMF pmf = walk 0 1
  where
    walk k mass = do
      when (mass < 0) $
        error "fromPMF: total PMF above 1"
      let pk = pmf k
      when (pk < 0 || pk > 1) $
        error "fromPMF: invalid probability value"
      stop <- bernoulli (pk / mass)
      if stop
        then pure k
        else walk (k + 1) (mass - pk)
-- | Sample a discrete distribution by handing its 'probability' mass
-- function to 'fromPMF'.
discrete :: (DiscreteDistr d, MonadSample m) => d -> m Int
discrete dist = fromPMF (probability dist)
-- | Monads that can attach a log-domain score to the current execution.
class Monad m => MonadCond m where
  -- | Record a log-domain factor for the current execution. How the factor
  -- is combined is defined by each instance. NOTE(review): 'condition'
  -- scores 1 or 0, which suggests multiplicative weighting — confirm
  -- against the instances.
  score :: Log Double -> m ()
-- | Synonym for 'score'.
factor ::
  MonadCond m =>
  -- | log-domain factor, passed unchanged to 'score'
  Log Double ->
  m ()
factor weight = score weight
-- | Hard conditioning: 'score' 1 when the predicate holds and 0 otherwise.
condition :: MonadCond m => Bool -> m ()
condition b = score weight
  where
    weight = if b then 1 else 0
-- | Monads that both sample ('MonadSample') and score ('MonadCond').
-- The class has no methods of its own; it only bundles the two constraints.
class (MonadSample m, MonadCond m) => MonadInfer m
-- | Density of a normal distribution at a point, returned in log-domain.
-- The log-density from 'logDensity' is wrapped directly with 'Exp',
-- without exponentiating.
normalPdf ::
  -- | mean
  Double ->
  -- | standard deviation
  Double ->
  -- | point at which the density is evaluated
  Double ->
  Log Double
normalPdf mu sigma x = Exp (logDensity dist x)
  where
    dist = normalDistr mu sigma
-- | Lifts the sampling primitives 'random', 'bernoulli' and 'categorical'
-- to the underlying monad. This way a base monad's own implementations of
-- them are used, rather than the class defaults derived from 'random'.
instance MonadSample m => MonadSample (IdentityT m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scores are forwarded to the underlying monad.
instance MonadCond m => MonadCond (IdentityT m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (IdentityT m)
-- | Lifts the sampling primitives 'random', 'bernoulli' and 'categorical'
-- to the underlying monad. This way a base monad's own implementations of
-- them are used, rather than the class defaults derived from 'random'.
instance MonadSample m => MonadSample (MaybeT m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scores are forwarded to the underlying monad.
instance MonadCond m => MonadCond (MaybeT m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (MaybeT m)
-- | Lifts the sampling primitives 'random', 'bernoulli' and 'categorical'
-- to the underlying monad. This way a base monad's own implementations of
-- them are used, rather than the class defaults derived from 'random'.
instance MonadSample m => MonadSample (ReaderT r m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scores are forwarded to the underlying monad.
instance MonadCond m => MonadCond (ReaderT r m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (ReaderT r m)
-- | Lifts the sampling primitives 'random', 'bernoulli' and 'categorical'
-- to the underlying monad. Other methods fall back to the class defaults.
instance (Monoid w, MonadSample m) => MonadSample (WriterT w m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scores are forwarded to the underlying monad.
instance (Monoid w, MonadCond m) => MonadCond (WriterT w m) where
  score w = lift (score w)

instance (Monoid w, MonadInfer m) => MonadInfer (WriterT w m)
-- | Lifts the sampling primitives 'random', 'bernoulli' and 'categorical'
-- to the underlying monad. Other methods fall back to the class defaults.
instance MonadSample m => MonadSample (StateT s m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scores are forwarded to the underlying monad.
instance MonadCond m => MonadCond (StateT s m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (StateT s m)
-- | Lifts the sampling primitives 'random', 'bernoulli' and 'categorical'
-- to the underlying monad. This way a base monad's own implementations of
-- them are used, rather than the class defaults derived from 'random'.
instance (MonadSample m, Monoid w) => MonadSample (RWST r w s m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scores are forwarded to the underlying monad.
instance (MonadCond m, Monoid w) => MonadCond (RWST r w s m) where
  score w = lift (score w)

instance (MonadInfer m, Monoid w) => MonadInfer (RWST r w s m)
-- | Lifts the sampling primitives 'random', 'bernoulli' and 'categorical'
-- to the underlying monad. Other methods fall back to the class defaults.
--
-- NOTE(review): 'ListT' from "Control.Monad.Trans.List" is deprecated in
-- transformers 0.5.x and removed in transformers 0.6. Verify it is still
-- available for the transformers version this package builds against.
instance MonadSample m => MonadSample (ListT m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scores are forwarded to the underlying monad.
instance MonadCond m => MonadCond (ListT m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (ListT m)
-- | Lifts the sampling primitives 'random', 'bernoulli' and 'categorical'
-- to the underlying monad. This way a base monad's own implementations of
-- them are used, rather than the class defaults derived from 'random'.
instance MonadSample m => MonadSample (ContT r m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Scores are forwarded to the underlying monad.
instance MonadCond m => MonadCond (ContT r m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (ContT r m)