{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
module Effects.Dist (
Tag
, Addr
, Dist(..)
, handleDist
, Sample(..)
, pattern Samp
, Observe(..)
, pattern Obs
) where
import Data.Map (Map)
import Data.Maybe ( fromMaybe )
import Prog ( call, discharge, Member(..), Prog(..), EffectSum(..) )
import qualified Data.Map as Map
import PrimDist ( PrimDist )
-- | A label attached to a 'Dist' operation. 'handleDist' takes it from the
--   operation's 'getTag' field when present, and otherwise uses the shown value
--   of its running operation counter.
type Tag = String
-- | An address: a 'Tag' paired with an index. As assigned by 'handleDist', the
--   index is the number of earlier operations that were given the same tag.
type Addr = (Tag, Int)
-- | A distribution operation, as interpreted by 'handleDist'.
data Dist a = Dist
  { getPrimDist :: PrimDist a
    -- ^ The primitive distribution the operation refers to.
  , getObs :: Maybe a
    -- ^ An optional observed value. 'handleDist' issues an 'Observe' when this
    --   is 'Just' and a 'Sample' when it is 'Nothing'.
  , getTag :: Maybe Tag
    -- ^ An optional explicit tag. When absent, 'handleDist' derives a tag from
    --   its operation counter.
  }
-- | Renders a 'Dist' as @Dist(\<primitive distribution\>, \<observed value\>, \<tag\>)@,
--   using the 'Show' instances of the three fields.
instance Show a => Show (Dist a) where
  show (Dist d y tag) =
    "Dist(" ++ show d ++ ", " ++ show y ++ ", " ++ show tag ++ ")"
-- | Two 'Dist' values are equal when their primitive distributions are equal.
--   The observed value and the tag are deliberately not compared, so no 'Eq'
--   constraint on @a@ is needed.
instance Eq (Dist a) where
  Dist d1 _ _ == Dist d2 _ _ = d1 == d2
-- | Effect carrying a primitive distribution and an address. 'handleDist'
--   issues it for every 'Dist' operation that has no observed value.
data Sample a where
  Sample :: PrimDist a -> Addr -> Sample a
-- | Match-only pattern: succeeds on an 'EffectSum' whose operation projects
--   (via 'prj') to a 'Sample', binding its primitive distribution and address.
pattern Samp :: Member Sample es => PrimDist x -> Addr -> EffectSum es x
pattern Samp d α <- (prj -> Just (Sample d α))
-- | Effect carrying a primitive distribution, a value of its type, and an
--   address. 'handleDist' issues it for every 'Dist' operation that has an
--   observed value.
data Observe a where
  Observe :: PrimDist a -> a -> Addr -> Observe a
-- | Match-only pattern: succeeds on an 'EffectSum' whose operation projects
--   (via 'prj') to an 'Observe', binding its primitive distribution, the
--   carried value, and the address.
pattern Obs :: Member Observe es => PrimDist x -> x -> Addr -> EffectSum es x
pattern Obs d y α <- (prj -> Just (Observe d y α))
-- | Discharge the 'Dist' effect by translating every 'Dist' operation into a
--   call to 'Observe' (when the operation carries an observed value) or
--   'Sample' (when it does not), each annotated with an 'Addr'.
--
--   Addresses are assigned as follows:
--
--   * The tag is the operation's own 'getTag' if present; otherwise it is the
--     shown value of a counter that starts at 0 and is incremented after every
--     'Dist' operation handled.
--
--   * The index is the number of earlier 'Dist' operations that received the
--     same tag (0 for the first one).
--
--   Operations belonging to any other effect in @es@ are passed through
--   untouched; they do not advance the counter or the per-tag indices.
handleDist :: (Member Sample es, Member Observe es)
  => Prog (Dist : es) a
  -> Prog es a
handleDist = loop 0 Map.empty
  where
    -- 'loop' threads two pieces of state through the program:
    --   * the number of 'Dist' operations handled so far, and
    --   * for each tag, the index to give its next occurrence.
    loop :: (Member Sample es, Member Observe es)
         => Int -> Map Tag Int -> Prog (Dist : es) a -> Prog es a
    loop _ _ (Val x) = return x
    loop counter tagMap (Op u k) = case discharge u of
      -- The operation is a 'Dist': replace it with 'Observe' or 'Sample'.
      Right (Dist d maybe_y maybe_tag) ->
        case maybe_y of
          Just y  -> call (Observe d y (tag, tagIdx)) >>= k'
          Nothing -> call (Sample d (tag, tagIdx)) >>= k'
        where
          -- Fall back to the counter when no explicit tag was given.
          tag     = fromMaybe (show counter) maybe_tag
          -- How many times this tag has been used before.
          tagIdx  = Map.findWithDefault 0 tag tagMap
          -- Record this use so the tag's next occurrence gets a fresh index.
          tagMap' = Map.insert tag (tagIdx + 1) tagMap
          -- Continue handling the rest of the program with the updated state.
          k'      = loop (counter + 1) tagMap' . k
      -- Some other effect: re-emit it and keep handling its continuation
      -- with the state unchanged.
      Left u' -> Op u' (loop counter tagMap . k)