module Control.Monad.Bayes.Sampler (
SamplerIO,
sampleIO,
sampleIOfixed,
sampleIOwith,
Seed,
SamplerST(SamplerST),
runSamplerST,
sampleST,
sampleSTfixed
) where
import Control.Monad.ST (ST, runST, stToIO)
import System.Random.MWC
import qualified System.Random.MWC.Distributions as MWC
import Control.Monad.State (State, state)
import Control.Monad.Trans (lift, MonadIO)
import Control.Monad.Trans.Reader (ReaderT, runReaderT, ask, mapReaderT)
import Control.Monad.Bayes.Class
-- | Sampling monad that lives in 'IO'. It is a reader over a mutable
-- @mwc-random@ generator ('GenIO'), and every draw advances that generator.
-- All class instances are the ones of the underlying 'ReaderT' stack.
newtype SamplerIO a = SamplerIO (ReaderT GenIO IO a)
  deriving (MonadIO, Monad, Applicative, Functor)
-- | Run a 'SamplerIO' computation with a fresh generator from
-- 'createSystemRandom' (seeded from system randomness), so the draws
-- generally differ from one run to the next.
sampleIO :: SamplerIO a -> IO a
sampleIO (SamplerIO sampler) = do
  gen <- createSystemRandom
  runReaderT sampler gen
-- | Run a 'SamplerIO' computation with the generator produced by
-- @mwc-random@'s 'create', which always uses the same fixed seed.
-- This makes the draws reproducible between runs.
sampleIOfixed :: SamplerIO a -> IO a
sampleIOfixed sampler = create >>= sampleIOwith sampler
-- | Run a 'SamplerIO' computation with a caller-supplied generator.
-- The generator is mutable, so its state after this call reflects the
-- draws that were made. Reusing it continues the same stream.
sampleIOwith :: SamplerIO a -> GenIO -> IO a
sampleIOwith (SamplerIO action) generator = runReaderT action generator
-- | Reuse an 'ST'-based sampler in 'IO'. The state-thread parameter is
-- instantiated to 'RealWorld', so the reader environment becomes a
-- 'GenIO'. Each step is then moved into 'IO' with 'stToIO'.
fromSamplerST :: SamplerST a -> SamplerIO a
fromSamplerST sampler = SamplerIO (mapReaderT stToIO (runSamplerST sampler))
-- | 'SamplerIO' draws by delegating every primitive to the 'SamplerST'
-- instance, run in 'IO' through 'fromSamplerST'.
--
-- Previously only 'random' was lifted, and all other methods fell back to
-- the 'MonadSample' class defaults. Meanwhile 'SamplerST' uses
-- @mwc-random@'s dedicated samplers, so the same generator produced
-- different draws in the two monads. With full delegation both monads share
-- one implementation per primitive. Note that exact values drawn under a
-- fixed seed (e.g. via 'sampleIOfixed') differ from the old
-- default-based ones.
instance MonadSample SamplerIO where
  random = fromSamplerST random
  -- Qualified: 'System.Random.MWC' is imported unqualified and also
  -- exports a function named 'uniform'.
  uniform a b = fromSamplerST (Control.Monad.Bayes.Class.uniform a b)
  normal m s = fromSamplerST (normal m s)
  gamma shape scale = fromSamplerST (gamma shape scale)
  beta a b = fromSamplerST (beta a b)
  bernoulli p = fromSamplerST (bernoulli p)
  categorical ps = fromSamplerST (categorical ps)
  geometric p = fromSamplerST (geometric p)
-- | Sampling monad that works in any 'ST' state thread. The wrapped
-- computation is polymorphic in @s@, which lets it be run purely with
-- 'runST' (see 'sampleST', 'sampleSTfixed') or in 'IO' via 'stToIO'.
newtype SamplerST a =
  SamplerST (forall s. ReaderT (GenST s) (ST s) a)
-- | Unwrap a 'SamplerST', instantiating its state thread at the caller's @s@.
runSamplerST :: SamplerST a -> ReaderT (GenST s) (ST s) a
runSamplerST sampler = case sampler of
  SamplerST body -> body
-- | Map over the result, reusing the underlying reader's 'Functor'.
instance Functor SamplerST where
  fmap g sampler = SamplerST (g <$> runSamplerST sampler)
-- | Applicative structure inherited from the underlying reader over 'ST'.
instance Applicative SamplerST where
  pure value = SamplerST (pure value)
  samplerF <*> samplerX =
    SamplerST (runSamplerST samplerF <*> runSamplerST samplerX)
-- | Sequencing runs the first sampler, then the continuation's sampler,
-- both against the same generator.
instance Monad SamplerST where
  sampler >>= k = SamplerST $ do
    x <- runSamplerST sampler
    runSamplerST (k x)
-- | Run a 'SamplerST' purely as a 'State' action over a 'Seed'.
-- A generator is restored from the incoming seed and the sampler runs
-- against it. The generator's final state is then saved as the outgoing
-- seed, so consecutive calls continue the same random stream.
sampleST :: SamplerST a -> State Seed a
sampleST sampler = state step
  where
    step seed = runST $ do
      gen <- restore seed
      -- The sampler runs before 'save', so the saved seed includes its draws.
      (,) <$> runReaderT (runSamplerST sampler) gen <*> save gen
-- | Run a 'SamplerST' purely, starting from the fixed-seed generator
-- returned by @mwc-random@'s 'create'. The result is the same on every call.
sampleSTfixed :: SamplerST a -> a
sampleSTfixed sampler = runST (create >>= runReaderT (runSamplerST sampler))
-- | Lift an @mwc-random@-style sampler (a function from a generator to an
-- 'ST' action, valid for every state thread) into 'SamplerST'. The
-- generator comes from the reader environment.
fromMWC :: (forall s. GenST s -> ST s a) -> SamplerST a
fromMWC draw = SamplerST $ do
  gen <- ask
  lift (draw gen)
-- | Each primitive maps directly onto an @mwc-random@ sampler:
-- 'random' uses 'System.Random.MWC.uniform', 'uniform' uses 'uniformR'
-- over @(lo, hi)@, and the remaining methods call the same-named functions
-- in "System.Random.MWC.Distributions". The exception is 'geometric',
-- which uses 'MWC.geometric0' (failures before the first success,
-- support starting at 0).
instance MonadSample SamplerST where
  random = fromMWC (\gen -> System.Random.MWC.uniform gen)
  uniform lo hi = fromMWC (\gen -> uniformR (lo, hi) gen)
  normal mu sigma = fromMWC (\gen -> MWC.normal mu sigma gen)
  gamma k theta = fromMWC (\gen -> MWC.gamma k theta gen)
  beta alpha beta' = fromMWC (\gen -> MWC.beta alpha beta' gen)
  bernoulli prob = fromMWC (\gen -> MWC.bernoulli prob gen)
  categorical weights = fromMWC (\gen -> MWC.categorical weights gen)
  geometric prob = fromMWC (\gen -> MWC.geometric0 prob gen)