-- |
-- Module      : Control.Monad.Bayes.Free
-- Description : Free monad transformer over random sampling
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
--
-- 'FreeSampler' is a free monad transformer over random sampling.
module Control.Monad.Bayes.Free
  ( FreeSampler,
    hoist,
    interpret,
    withRandomness,
    withPartialRandomness,
    runWith,
  )
where

import Control.Monad.Bayes.Class
import Control.Monad.State (evalStateT, get, put)
import Control.Monad.Trans (MonadTrans (..))
import Control.Monad.Trans.Free.Church (FT, MonadFree (..), hoistFT, iterT, iterTM, liftF)
import Control.Monad.Writer (WriterT (..), tell)
import Data.Functor.Identity (Identity, runIdentity)

-- | Random sampling functor: one request for a random 'Double', paired with
-- the continuation that consumes the drawn value.
newtype SamF a = Random (Double -> a)

instance Functor SamF where
  -- Mapping post-composes the function onto the continuation.
  fmap g (Random cont) = Random (\u -> g (cont u))

-- | Free monad transformer over random sampling.
--
-- A computation in @'FreeSampler' m@ interleaves effects of the underlying
-- monad @m@ with requests for random 'Double's. Those requests are only given
-- meaning when the computation is run, e.g. by 'interpret',
-- 'withRandomness' or 'withPartialRandomness'.
--
-- Wraps the Church-encoded free monad transformer 'FT' for efficiency.
newtype FreeSampler m a = FreeSampler
  { runFreeSampler :: FT SamF m a
  }
  deriving (Functor, Applicative, Monad, MonadTrans)

-- | Wrapping a 'SamF' layer strips the 'FreeSampler' newtype from each
-- continuation and delegates to the 'wrap' of the underlying 'FT'.
instance Monad m => MonadFree SamF (FreeSampler m) where
  wrap layer = FreeSampler (wrap (runFreeSampler <$> layer))

-- | A random draw is a single 'SamF' layer whose continuation returns the
-- drawn value unchanged.
instance Monad m => MonadSample (FreeSampler m) where
  random = FreeSampler (liftF (Random id))

-- | Hoist 'FreeSampler' through a transformation @forall x. m x -> n x@ of
-- the underlying monad. The transformation is passed to 'hoistFT', which
-- expects it to be a monad morphism. Random choices are left untouched.
hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> FreeSampler m a -> FreeSampler n a
hoist nat = FreeSampler . hoistFT nat . runFreeSampler

-- | Run a 'FreeSampler' in its underlying monad, answering every random
-- choice with a fresh draw from that monad's own 'random'.
interpret :: MonadSample m => FreeSampler m a -> m a
interpret sampler = iterT answer (runFreeSampler sampler)
  where
    answer (Random cont) = cont =<< random

-- | Run a 'FreeSampler' in its underlying monad, answering its random choices
-- in order with the supplied values. Values still unused when the
-- computation finishes are discarded.
--
-- Calls 'error' if the computation makes more random choices than there are
-- supplied values.
withRandomness :: Monad m => [Double] -> FreeSampler m a -> m a
withRandomness randomness sampler =
  evalStateT (iterTM consume (runFreeSampler sampler)) randomness
  where
    -- The state holds the supplied values that have not been used yet.
    consume (Random cont) = get >>= \remaining -> case remaining of
      [] -> error "FreeSampler: the list of randomness was too short"
      y : ys -> put ys >> cont y

-- | Run a 'FreeSampler' in its underlying monad. Random choices are answered
-- in order with the supplied values. Once those run out, fresh draws come
-- from the underlying monad's 'random'.
--
-- Returns the result together with every random choice made, in order,
-- whether it came from the supplied list or was freshly drawn.
withPartialRandomness :: MonadSample m => [Double] -> FreeSampler m a -> m (a, [Double])
withPartialRandomness randomness sampler =
  runWriterT (evalStateT (iterTM answer lifted) randomness)
  where
    -- Move the computation onto WriterT [Double] m so each choice can be logged.
    lifted = hoistFT lift (runFreeSampler sampler)
    -- Runs in StateT [Double] (WriterT [Double] m): the state holds the
    -- supplied values not yet used, and the writer log records every value
    -- used, whether supplied or freshly drawn.
    answer (Random cont) = do
      remaining <- get
      value <- case remaining of
        [] -> random
        y : ys -> put ys >> pure y
      tell [value]
      cont value

-- | Run a 'FreeSampler' over 'Identity' (one with no underlying effects) in
-- any sampling monad @m@, with the semantics of 'withPartialRandomness'.
-- The supplied values answer the first random choices, and the remaining
-- choices are drawn from @m@. The full record of choices is returned
-- alongside the result.
runWith :: MonadSample m => [Double] -> FreeSampler Identity a -> m (a, [Double])
runWith randomness = withPartialRandomness randomness . hoist (pure . runIdentity)