module Language.Hakaru.Evaluation.ExpectMonad
( pureEvaluate
, ListContext(..), ExpectAns, Expect(..), runExpect
, residualizeExpectListContext
, emit
, emit_
) where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
import Control.Applicative (Applicative(..))
#endif
import qualified Data.Foldable as F
import Language.Hakaru.Syntax.IClasses (Some2(..))
import Language.Hakaru.Syntax.ABT (ABT(..), caseVarSyn, subst, maxNextFreeOrBind)
import Language.Hakaru.Syntax.Variable (memberVarSet)
import Language.Hakaru.Syntax.AST hiding (Expect)
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.Lazy (TermEvaluator, evaluate, defaultCaseEvaluator)
import Language.Hakaru.Evaluation.PEvalMonad (ListContext(..))
import Data.Text (Text)
import Language.Hakaru.Syntax.Variable (Variable())
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing)
#ifdef __TRACE_DISINTEGRATE__
import Debug.Trace (trace)
#endif
-- | The answer type of an 'Expect' computation: a function from the
-- current 'ListContext' (at purity 'ExpectP') to an
-- @abt '[] 'HProb@ term.
type ExpectAns abt = ListContext abt 'ExpectP -> abt '[] 'HProb
-- | A continuation-passing computation that yields a value of type
-- @x@. Running it requires a continuation from @x@ to an
-- 'ExpectAns', and the result is itself an 'ExpectAns'; i.e., both
-- the continuation and the whole computation are still waiting for
-- the 'ListContext' to run against.
newtype Expect abt x = Expect
    { unExpect :: (x -> ExpectAns abt) -> ExpectAns abt
      -- ^ Unwrap the underlying continuation-passing function.
    }
-- | Fold every statement of a 'ListContext' into the given term,
-- producing a single residual term.
--
-- The statements are consumed with a left fold, so the statement at
-- the head of the list is applied to the term first and later
-- statements are applied to that result.
residualizeExpectListContext
    :: forall abt
    .  (ABT Term abt)
    => abt '[] 'HProb
    -> ListContext abt 'ExpectP
    -> abt '[] 'HProb
residualizeExpectListContext e0 ctx =
    foldl absorb e0 (statements ctx)
    where
    -- Apply a single statement to the term accumulated so far.
    absorb :: abt '[] 'HProb -> Statement abt 'ExpectP -> abt '[] 'HProb
    absorb acc stmt =
        case stmt of
            -- Both "stuff" statements carry a term transformer; just
            -- run it on the accumulated term.
            SStuff0 f _    -> f acc
            SStuff1 _x f _ -> f acc
            SLet x rhs _
                -- The bound variable occurs free in the term, so the
                -- binding has to be materialised in some form.
                | x `memberVarSet` freeVars acc ->
                    -- The second component is only forced when the
                    -- first one turns out to be 'Nothing'.
                    case (getLazyVariable rhs, getLazyLiteral rhs) of
                        (Just y, _) ->
                            subst x (var y) acc
                        (Nothing, Just v) ->
                            subst x (syn (Literal_ v)) acc
                        (Nothing, Nothing) ->
                            syn (Let_ :$ fromLazy rhs :* bind x acc :* End)
                -- The bound variable is unused: drop the binding.
                | otherwise -> acc
-- | The term evaluator for the 'Expect' monad: 'evaluate' applied to
-- 'defaultCaseEvaluator' and to a first argument that is the
-- @brokenInvariant "perform"@ error thunk. Forcing that first argument
-- therefore aborts with the message built by 'brokenInvariant'.
--
-- NOTE(review): presumably the first argument of 'evaluate' is the
-- handler for measure terms, which pure evaluation should never need;
-- confirm against the definition of 'evaluate'.
pureEvaluate :: (ABT Term abt) => TermEvaluator abt (Expect abt)
pureEvaluate = evaluate (brokenInvariant "perform") defaultCaseEvaluator
-- | Abort with an 'error' whose message is the given location label
-- followed by a fixed explanation that an invariant of 'Expect' was
-- violated. Never returns normally.
brokenInvariant :: String -> a
brokenInvariant loc = error message
    where
    message = loc ++ ": Expect's invariant broken"
-- | Run an 'Expect' computation to completion.
--
-- The computation is started with an empty statement list and a
-- name counter equal to the larger of @nextFreeOrBind f@ and
-- @maxNextFreeOrBind es@. When the computation delivers its final
-- term @e@, that term is plugged into the one-binder body @f@ and
-- the statements left in the context are residualized around the
-- result with 'residualizeExpectListContext'.
--
-- NOTE(review): @es@ is presumably the collection of terms whose
-- variables freshly generated names must avoid; confirm with callers.
runExpect
    :: forall abt f a
    .  (ABT Term abt, F.Foldable f)
    => Expect abt (abt '[] a)
    -> abt '[a] 'HProb
    -> f (Some2 abt)
    -> abt '[] 'HProb
runExpect (Expect m) f es =
    m finish emptyContext
    where
    -- Initial context: no statements, counter seeded from @f@ and @es@.
    firstFresh   = max (nextFreeOrBind f) (maxNextFreeOrBind es)
    emptyContext = ListContext firstFresh []

    -- Final continuation: plug the delivered term into @f@, then
    -- residualize whatever statements the context holds.
    finish e =
        residualizeExpectListContext (caseVarSyn e plugVariable plugTerm)
        where
        -- @e@ is a bare variable: substitute it for the bound
        -- variable of @f@ instead of introducing a let-binding.
        plugVariable x = caseBind f (\y body -> subst y (var x) body)
        -- Any other term: bind it with an explicit @Let_@.
        plugTerm _ = syn (Let_ :$ e :* f :* End)
-- | Mapping applies the function to each produced value before it is
-- handed to the continuation.
instance Functor (Expect abt) where
    fmap g ma = Expect (\k -> unExpect ma (\a -> k (g a)))
-- | 'pure' feeds its value straight to the continuation; '<*>' runs
-- the function computation first, then the argument computation, and
-- passes the application of the two results to the continuation.
instance Applicative (Expect abt) where
    pure a = Expect (\k -> k a)
    mg <*> ma =
        Expect $ \k ->
            unExpect mg $ \g ->
            unExpect ma $ \a ->
                k (g a)
-- | Bind runs the first computation and, for each value it produces,
-- runs the computation chosen by the second argument with the same
-- outer continuation.
instance Monad (Expect abt) where
    return = pure
    ma >>= next =
        Expect $ \k ->
            unExpect ma (\a -> unExpect (next a) k)
-- | 'Expect' as an evaluation monad at purity 'ExpectP'. All state
-- lives in the 'ListContext' threaded through the continuation: a
-- counter and a list of statements used as a stack (head = top).
instance (ABT Term abt) => EvaluationMonad abt (Expect abt) 'ExpectP where
    -- Hand the current counter to the continuation and store its
    -- successor in the context.
    freshNat =
        Expect $ \k (ListContext next stmts) ->
            let bumped = ListContext (next + 1) stmts
            in  k next bumped

    -- Put one statement on top of the stack.
    unsafePush s =
        Expect $ \k (ListContext next stmts) ->
            let pushed = ListContext next (s : stmts)
            in  k () pushed

    -- Put several statements on the stack. The list is reversed
    -- before being prepended, which gives the same stack as pushing
    -- its elements one at a time, first to last, with 'unsafePush'.
    unsafePushes ss =
        Expect $ \k (ListContext next old) ->
            let stacked = reverse ss ++ old
            in  k () (ListContext next stacked)

    -- Pop statements until one is found for which both
    -- @x `isBoundBy` stmt@ and @p stmt@ succeed. The action returned
    -- by @p@ is then run, the statements skipped on the way are put
    -- back in their original order, and the action's result is
    -- returned in 'Just'. The matched statement itself is not pushed
    -- back here. If the stack runs out first, the skipped statements
    -- are restored and the result is 'Nothing'.
    select x p = go []
        where
        -- @skipped@ holds the statements popped so far, most
        -- recently popped first.
        go skipped =
            unsafePop >>= \popped ->
            case popped of
                Nothing ->
                    unsafePushes skipped >> return Nothing
                Just stmt ->
                    case x `isBoundBy` stmt >> p stmt of
                        Nothing     -> go (stmt : skipped)
                        Just action -> do
                            result <- action
                            unsafePushes skipped
                            return (Just result)
-- | Remove and return the statement at the head of the context's
-- statement list. On an empty list the context is passed on
-- unchanged and the result is 'Nothing'.
unsafePop :: Expect abt (Maybe (Statement abt 'ExpectP))
unsafePop =
    Expect $ \k ctx@(ListContext next stmts) ->
        case stmts of
            top : rest -> k (Just top) (ListContext next rest)
            []         -> k Nothing ctx
-- | Allocate a fresh variable with 'freshVar' from the given hint
-- and type, and hand it to the rest of the computation. The answer
-- term produced by the rest of the computation is bound over that
-- variable with 'bind', and the resulting one-binder term is passed
-- through the supplied function to give the final answer.
emit
    :: (ABT Term abt)
    => Text
    -> Sing a
    -> (abt '[a] 'HProb -> abt '[] 'HProb)
    -> Expect abt (Variable a)
emit hint typ f =
    freshVar hint typ >>= \x ->
        -- For a given continuation @k@: run @k x@ on the context,
        -- bind @x@ in its answer, then apply @f@.
        Expect (\k -> f . bind x . k x)
-- | Wrap the answer of the rest of the computation with the given
-- term transformer. The continuation receives @()@ and the context
-- is passed through untouched.
emit_
    :: (ABT Term abt)
    => (abt '[] 'HProb -> abt '[] 'HProb)
    -> Expect abt ()
emit_ f = Expect (\k -> f . k ())