{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- |
-- A density monad built only from standard monad transformers ('WriterT'
-- over 'StateT'). It is slower than "Control.Monad.Bayes.Density.Free", so it
-- is not used by default, but it is more elementary to understand.
module Control.Monad.Bayes.Density.State where

import Control.Monad.Bayes.Class (MonadDistribution (random))
import Control.Monad.State (MonadState (get, put), StateT, evalStateT)
import Control.Monad.Writer

-- | Density monad transformer.
--
-- Two layers sit on top of @m@:
--
-- * a 'StateT' @[Double]@ holding the trace of values still to be replayed;
--
-- * a 'WriterT' @[Double]@ logging the values that were actually produced.
--
-- The 'MonadDistribution' instance in this module reads and writes these
-- layers.
newtype DensityT m a = DensityT {getDensityT :: WriterT [Double] (StateT [Double] m) a}
  deriving newtype (Functor, Applicative, Monad)

-- | Lift an action of the base monad through both the 'StateT' and the
-- 'WriterT' layer. The lifted action leaves the trace unchanged and logs
-- nothing.
instance MonadTrans DensityT where
  lift = DensityT . lift . lift

-- | Access to the remaining trace, which is the 'StateT' layer underneath the
-- 'WriterT' layer. The operations are lifted past 'WriterT', so they do not
-- touch the log.
instance (Monad m) => MonadState [Double] (DensityT m) where
  get = DensityT (lift get)
  put = DensityT . lift . put

-- | Access to the log of produced values, which is the outer 'WriterT' layer.
-- Every method unwraps the newtype, uses the underlying 'WriterT'
-- operation, and wraps the result again.
instance (Monad m) => MonadWriter [Double] (DensityT m) where
  tell = DensityT . tell
  listen = DensityT . listen . getDensityT
  pass = DensityT . pass . getDensityT

-- | Draw a value by replaying the trace.
--
-- If the remaining trace is non-empty, its head is returned and the tail
-- becomes the new trace. Once the trace is exhausted, a fresh value is drawn
-- from the base monad's 'random'. In both cases the returned value is
-- appended to the log with 'tell'.
instance (MonadDistribution m) => MonadDistribution (DensityT m) where
  random = do
    trace <- get
    x <- case trace of
      [] -> random
      r : xs -> put xs >> pure r
    tell [x]
    pure x

-- | Run a 'DensityT' computation against an initial trace.
--
-- Returns the result together with the log of all values produced.
-- Trace entries that were never consumed are discarded, because the final
-- state is dropped by 'evalStateT'.
runDensityT :: (Monad m) => DensityT m b -> [Double] -> m (b, [Double])
runDensityT (DensityT m) = evalStateT (runWriterT m)