{-|
Module      : Control.Monad.Bayes.Free
Description : Free monad transformer over random sampling
Copyright   : (c) Adam Scibior, 2015-2020
License     : MIT
Maintainer  : leonhard.markert@tweag.io
Stability   : experimental
Portability : GHC

'FreeSampler' is a free monad transformer over random sampling.
-}

module Control.Monad.Bayes.Free
  ( FreeSampler
  , hoist
  , interpret
  , withRandomness
  , withPartialRandomness
  , runWith
  ) where

import Data.Functor.Identity

import Control.Monad.Trans
import Control.Monad.Writer
import Control.Monad.State
import Control.Monad.Trans.Free.Church

import Control.Monad.Bayes.Class

-- | Random sampling functor: a single layer that waits for one 'Double'
-- and then produces the next step of the computation.
newtype SamF a = Random (Double -> a)

instance Functor SamF where
  -- Post-compose the mapped function onto the stored continuation.
  fmap g (Random next) = Random (\u -> g (next u))

-- | Free monad transformer over random sampling.
-- Uses the Church-encoded version of the free monad for efficiency.
newtype FreeSampler m a = FreeSampler (FT SamF m a)
  deriving (Functor, Applicative, Monad, MonadTrans)

-- | Strip the 'FreeSampler' wrapper, exposing the underlying 'FT' value.
runFreeSampler :: FreeSampler m a -> FT SamF m a
runFreeSampler sampler = case sampler of
  FreeSampler inner -> inner

instance Monad m => MonadFree SamF (FreeSampler m) where
  -- Unwrap every sub-computation in the layer, wrap it at the 'FT' level,
  -- then put the newtype back on.
  wrap layer = FreeSampler (wrap (fmap runFreeSampler layer))

instance Monad m => MonadSample (FreeSampler m) where
  -- A single 'Random' layer that returns the drawn value unchanged.
  random = FreeSampler (liftF (Random id))

-- | Change the monad underneath a 'FreeSampler' by applying the given
-- transformation to it.
hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> FreeSampler m a -> FreeSampler n a
hoist nat = FreeSampler . hoistFT nat . runFreeSampler

-- | Execute random sampling in the transformed monad.
-- Every 'Random' layer is resolved by calling 'random' in @m@ and passing
-- the result to the layer's continuation.
interpret :: MonadSample m => FreeSampler m a -> m a
interpret sampler = iterT draw (runFreeSampler sampler)
  where
    draw (Random next) = random >>= next

-- | Execute computation with supplied values for random choices.
-- Calls 'error' if the supplied list runs out before the computation ends.
withRandomness :: Monad m => [Double] -> FreeSampler m a -> m a
withRandomness randomness (FreeSampler m) =
  evalStateT (iterTM consume m) randomness
  where
    -- Runs in StateT [Double] m: the state is the not-yet-used part of the
    -- supplied list. Each 'Random' layer pops the head and continues with it.
    consume (Random continue) =
      get >>= \supply -> case supply of
        [] -> error "FreeSampler: the list of randomness was too short"
        u : rest -> put rest >> continue u

-- | Execute computation with supplied values for a subset of random choices.
-- Supplied values are consumed front to back; once they are exhausted, the
-- remaining choices are drawn with 'random' in the transformed monad.
-- Returns the output value together with a record of every random choice
-- used, whether taken from the input or freshly drawn.
withPartialRandomness :: MonadSample m => [Double] -> FreeSampler m a -> m (a, [Double])
withPartialRandomness randomness (FreeSampler m) =
  runWriterT (evalStateT (iterTM step (hoistFT lift m)) randomness)
  where
    -- Runs in StateT [Double] (WriterT [Double] m):
    -- the StateT layer holds the supplied values not yet consumed, and the
    -- WriterT layer logs each value actually used, in order.
    step (Random continue) = do
      supply <- get
      u <- case supply of
        v : rest -> put rest >> return v
        [] -> random
      tell [u]
      continue u

-- | Like 'withPartialRandomness', but use an arbitrary sampling monad.
-- The 'Identity'-based sampler is first moved into the target monad with
-- 'hoist' and then run with the supplied values.
runWith :: MonadSample m => [Double] -> FreeSampler Identity a -> m (a, [Double])
runWith randomness pureSampler =
  withPartialRandomness randomness (hoist (\ident -> return (runIdentity ident)) pureSampler)