{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}

{- | The effects for primitive distributions, sampling, and observing.
-}

module Effects.Dist (
  -- ** Address
  -- $Address
    Tag
  , Addr
  -- ** Dist effect
  , Dist(..)
  , handleDist
  -- ** Sample effect
  , Sample(..)
  , pattern Samp
  -- ** Observe effect
  , Observe(..)
  , pattern Obs
  ) where

import Data.Map (Map)
import Data.Maybe ( fromMaybe )
import Prog ( call, discharge, Member(..), Prog(..), EffectSum(..) )
import qualified Data.Map as Map
import PrimDist ( PrimDist )

{- $Address
   Run-time identifiers for probabilistic operations
-}

-- | Compile-time identifier: the observable variable name attached to a
-- primitive distribution.
type Tag = String

-- | Run-time identifier: an observable variable name paired with the index
-- of this particular run-time occurrence of it.
type Addr = (Tag, Int)

-- | The effect @Dist@ for primitive distributions.
--
-- 'handleDist' turns a @Dist@ whose 'getObs' is 'Just' into an @Observe@
-- operation, and one whose 'getObs' is 'Nothing' into a @Sample@ operation.
data Dist a = Dist
  { getPrimDist :: PrimDist a  -- ^ primitive distribution
  , getObs      :: Maybe a     -- ^ optional observed value
  , getTag      :: Maybe Tag   -- ^ optional observable variable name
  }

-- | Renders a @Dist@ as @Dist(\<distribution\>, \<observed value\>, \<tag\>)@,
-- using 'show' on each of the three fields.
instance Show a => Show (Dist a) where
  show (Dist d y tag) =
    concat ["Dist(", show d, ", ", show y, ", ", show tag, ")"]

-- | Two @Dist@s are equal exactly when their primitive distributions are
-- equal. The observed value and the tag are not compared.
instance Eq (Dist a) where
  Dist d1 _ _ == Dist d2 _ _ = d1 == d2

-- | The effect @Sample@ for sampling from distributions.
data Sample a where
  -- | An operation carrying the distribution to sample from and the
  -- address at which the operation occurs.
  Sample
    :: PrimDist a  -- ^ distribution to sample from
    -> Addr        -- ^ address of @Sample@ operation
    -> Sample a

-- | For projecting and then successfully pattern matching against @Sample@.
--
-- Matches an 'EffectSum' only when 'prj' succeeds in extracting a @Sample@
-- operation, binding its distribution and address. The synonym is
-- match-only (unidirectional); it cannot be used to build a value.
pattern Samp :: Member Sample es => PrimDist x -> Addr -> EffectSum es x
pattern Samp d α <- (prj -> Just (Sample d α))

-- | The effect @Observe@ for conditioning against observed values.
data Observe a where
  -- | An operation carrying a distribution, the value observed from it,
  -- and the address at which the operation occurs.
  Observe
    :: PrimDist a  -- ^ distribution to condition with
    -> a           -- ^ observed value
    -> Addr        -- ^ address of @Observe@ operation
    -> Observe a

-- | For projecting and then successfully pattern matching against @Observe@.
--
-- Matches an 'EffectSum' only when 'prj' succeeds in extracting an
-- @Observe@ operation, binding its distribution, observed value and
-- address. The synonym is match-only (unidirectional).
pattern Obs :: Member Observe es => PrimDist x -> x -> Addr -> EffectSum es x
pattern Obs d y α <- (prj -> Just (Observe d y α))

-- | Handle the @Dist@ effect to a @Sample@ or @Observe@ effect and assign an
-- address.
--
-- A @Dist@ carrying an observed value becomes an @Observe@ operation; one
-- without becomes a @Sample@ operation. Each such operation gets the address
-- @(tag, i)@:
--
--   * @tag@ is the user-supplied tag if present. Otherwise it is the number
--     of @Dist@ operations handled before this one, rendered with 'show'.
--   * @i@ is the number of earlier operations that were assigned the same
--     @tag@.
--
-- Operations of any other effect are forwarded unchanged.
handleDist :: (Member Sample es, Member Observe es)
  => Prog (Dist : es) a
  -> Prog es a
handleDist = loop 0 Map.empty
  where
    -- @counter@ numbers every Dist operation seen so far (used to name
    -- untagged ones); @tagMap@ records how often each tag has been used.
    loop :: (Member Sample es, Member Observe es)
         => Int -> Map Tag Int -> Prog (Dist : es) a -> Prog es a
    loop _ _ (Val x) = return x
    loop counter tagMap (Op u k) = case discharge u of
      Right (Dist d maybe_y maybe_tag) ->
        case maybe_y of
          Just y  -> call (Observe d y (tag, tagIdx)) >>= k'
          Nothing -> call (Sample d (tag, tagIdx))    >>= k'
        where
          -- NOTE(review): an auto-generated tag such as "3" can coincide
          -- with a user-supplied tag spelled the same way; both would then
          -- share one occurrence index sequence.
          tag      = fromMaybe (show counter) maybe_tag
          tagIdx   = Map.findWithDefault 0 tag tagMap
          -- Store an evaluated count: the lazy 'Map.insert' would otherwise
          -- keep a thunk that references the previous map.
          tagMap'  = let n = tagIdx + 1 in n `seq` Map.insert tag n tagMap
          counter' = counter + 1
          -- Force both accumulators before recursing, so that long runs do
          -- not build thunk chains when downstream handlers never inspect
          -- the address.
          k' x     = counter' `seq` tagMap' `seq` loop counter' tagMap' (k x)
      Left u' -> Op u' (loop counter tagMap . k)