-- |
-- Module      : Control.Monad.Bayes.Helpers
-- Description : Helper functions for working with inference monads
-- Copyright   : (c) Adam Scibior, 2015-2020
-- License     : MIT
-- Maintainer  : leonhard.markert@tweag.io
-- Stability   : experimental
-- Portability : GHC
module Control.Monad.Bayes.Helpers
  ( W,
    hoistW,
    P,
    hoistP,
    S,
    hoistS,
    F,
    hoistF,
    T,
    hoistT,
    hoistWF,
    hoistSP,
    hoistSTP,
  )
where

import Control.Monad.Bayes.Free as Free
import Control.Monad.Bayes.Population as Pop
import Control.Monad.Bayes.Sequential as Seq
import Control.Monad.Bayes.Traced as Tr
import Control.Monad.Bayes.Weighted as Weighted

-- | Short name for 'Weighted' (from "Control.Monad.Bayes.Weighted"), used to
-- keep the hoisting signatures in this module compact.
type W = Weighted

-- | Short name for 'Population' (from "Control.Monad.Bayes.Population").
type P = Population

-- | Short name for 'Sequential' (from "Control.Monad.Bayes.Sequential").
type S = Sequential

-- | Short name for 'FreeSampler' (from "Control.Monad.Bayes.Free").
type F = FreeSampler

-- | Short name for 'Traced' (from "Control.Monad.Bayes.Traced").
type T = Traced

-- | Apply a polymorphic transformation of the inner monad (@m@ to @n@)
-- underneath the 'W' layer.
--
-- This is 'Weighted.hoist' restated with the 'W' shorthand.
hoistW :: (forall x. m x -> n x) -> W m a -> W n a
hoistW = Weighted.hoist

-- | Apply a polymorphic transformation of the inner monad (@m@ to @n@)
-- underneath the 'P' layer.
--
-- This is 'Pop.hoist' restated with the 'P' shorthand. Like 'Pop.hoist', it
-- requires both inner monads to be 'Monad' instances.
hoistP ::
  (Monad m, Monad n) =>
  (forall x. m x -> n x) ->
  P m a ->
  P n a
hoistP = Pop.hoist

-- | Apply a polymorphic transformation of the inner monad underneath the
-- 'S' layer.
--
-- Unlike 'hoistW' and 'hoistP', the transformation must map @m@ back to @m@,
-- so it cannot change the inner monad. Delegates to 'Seq.hoistFirst'.
-- NOTE(review): that name suggests only part of the sequential computation
-- is transformed; confirm against 'Seq.hoistFirst' before relying on it.
hoistS :: (forall x. m x -> m x) -> S m a -> S m a
hoistS = Seq.hoistFirst

-- | Apply a polymorphic transformation of the inner monad (@m@ to @n@)
-- underneath the 'F' layer.
--
-- This is 'Free.hoist' restated with the 'F' shorthand. Like 'Free.hoist', it
-- requires both inner monads to be 'Monad' instances.
hoistF :: (Monad m, Monad n) => (forall x. m x -> n x) -> F m a -> F n a
hoistF = Free.hoist

-- | Transform the innermost monad of a 'W'-over-'F' stack.
--
-- The transformation is first lifted through 'F' with 'hoistF'. The result
-- is itself polymorphic (@forall x. F m x -> F n x@) and is then lifted
-- through 'W' with 'hoistW'.
hoistWF ::
  (Monad m, Monad n) =>
  (forall x. m x -> n x) ->
  W (F m) a ->
  W (F n) a
-- Plain application (not @$@) passes the rank-2 argument, so this does not
-- depend on GHC's special typing of @$@.
hoistWF nat = hoistW (hoistF nat)

-- | Transform the innermost monad of an 'S'-over-'P' stack.
--
-- The transformation is first lifted through 'P' with 'hoistP', instantiated
-- at the same monad on both sides. It is then lifted through 'S' with
-- 'hoistS', which only accepts transformations from a monad to itself.
hoistSP ::
  Monad m =>
  (forall x. m x -> m x) ->
  S (P m) a ->
  S (P m) a
-- Plain application (not @$@) passes the rank-2 argument, so this does not
-- depend on GHC's special typing of @$@.
hoistSP nat = hoistS (hoistP nat)

-- | Transform the innermost monad of an 'S'-over-'T'-over-'P' stack.
--
-- The transformation is lifted outward one layer at a time: through 'P'
-- with 'hoistP', then through 'T' with 'hoistT', then through 'S' with
-- 'hoistS'. Every step maps a monad back to itself.
--
-- NOTE(review): 'hoistT' has no definition in this module. It is presumably
-- the function imported from "Control.Monad.Bayes.Traced" and re-exported
-- here; confirm.
hoistSTP ::
  Monad m =>
  (forall x. m x -> m x) ->
  S (T (P m)) a ->
  S (T (P m)) a
-- Nested plain application (not @$@) passes each rank-2 argument, so this
-- does not depend on GHC's special typing of @$@.
hoistSTP nat = hoistS (hoistT (hoistP nat))