module Language.Hakaru.Evaluation.DisintegrationMonad
(
getStatements, putStatements
, ListContext(..), Ans, Dis(..), runDis
, bot
, emit
, emitMBind
, emitLet
, emitLet'
, emitUnpair
, emit_
, emitMBind_
, emitGuard
, emitWeight
, emitFork_
, emitSuperpose
, choose
, pushPlate
, getIndices
, withIndices
, extendIndices
, extendLocInds
, statementInds
, sizeInnermostInd
, Loc(..)
, getLocs
, putLocs
, insertLoc
, adjustLoc
, mkLoc
, freeLocError
, apply
#ifdef __TRACE_DISINTEGRATE__
, prettyLoc
, prettyLocs
#endif
) where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (Monoid(..))
import Data.Functor ((<$>))
import Control.Applicative (Applicative(..))
#endif
import Data.Maybe
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import Control.Applicative (Alternative(..))
import Control.Monad (MonadPlus(..),foldM,guard)
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Number.Nat
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing(..), sUnMeasure, sUnPair, sUnit)
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumABT
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.ABT
import qualified Language.Hakaru.Syntax.Prelude as P
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.PEvalMonad (ListContext(..))
import Language.Hakaru.Evaluation.Lazy (reifyPair)
#ifdef __TRACE_DISINTEGRATE__
import Debug.Trace (trace, traceM)
import qualified Text.PrettyPrint as PP
import Language.Hakaru.Pretty.Haskell (ppVariable, pretty)
#endif
-- | Return the statement list stored in the current 'ListContext'.
-- Neither the context nor the location table is changed.
getStatements :: Dis abt [Statement abt 'Impure]
getStatements = Dis $ \_ k ctx -> k (statements ctx) ctx
-- | Replace the statement list of the current 'ListContext' with @ss@.
-- The first field of the context and the location table are passed on
-- unchanged.
putStatements :: [Statement abt 'Impure] -> Dis abt ()
putStatements ss =
    Dis $ \_ k (ListContext next _) locs ->
        k () (ListContext next ss) locs
-- | @plug x e t@ walks the term @t@ and returns @e@ at every variable node
-- that 'varEq' identifies with @x@. Other variables are returned as they
-- are, syntax nodes are rebuilt with the traversal applied to their
-- children, and binders are re-bound around their traversed bodies.
plug
    :: forall abt a xs b
    .  (ABT Term abt)
    => Variable a
    -> abt '[] a
    -> abt xs b
    -> abt xs b
plug x e = descend
    where
    -- Entry point for every subterm: inspect its view and dispatch.
    descend :: forall xs' b'. abt xs' b' -> abt xs' b'
    descend t = go t (viewABT t)

    go :: forall xs' b'. abt xs' b' -> View (Term abt) xs' b' -> abt xs' b'
    go _ (Syn node)   = syn $! fmap21 descend node
    go t (Var z)      =
        case varEq x z of
            Just Refl -> e
            Nothing   -> t
    go t (Bind _ _)   =
        caseBind t $ \z body ->
            bind z (descend body)
-- | Apply every association of @rho0@ to @e0@ by calling 'plug' once per
-- association, threading the term through from left to right.
--
-- A strict left fold is used so that the intermediate terms are not
-- accumulated as a chain of suspended 'plug' applications.
plugs
    :: forall abt xs a
    .  (ABT Term abt)
    => Assocs (abt '[])
    -> abt xs a
    -> abt xs a
plugs rho0 e0 =
    F.foldl' (\e (Assoc x v) -> plug x v e) e0 (unAssocs rho0)
-- | Wrap the statements of a 'ListContext' around a measure term.
--
-- The substitution @rho@ is first applied to @e0@ with 'plugs'. The
-- statements are then folded over starting from the head of the list, each
-- one wrapping the term built so far:
--
-- * 'SBind' becomes an 'MBind' whose continuation binds the statement
--   variable;
-- * 'SLet' is skipped when its variable is not free in the term, is
--   substituted directly when its body is a variable or a literal, and
--   otherwise becomes a 'Let_';
-- * 'SGuard' becomes a two-branch 'Case_' whose wildcard branch is
--   'P.reject';
-- * 'SWeight' becomes a one-component 'Superpose_'.
--
-- The index lists carried by the statements are ignored here.
residualizeListContext
    :: forall abt a
    .  (ABT Term abt)
    => ListContext abt 'Impure
    -> Assocs (abt '[])
    -> abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure a)
residualizeListContext ss rho e0 =
#ifdef __TRACE_DISINTEGRATE__
    trace ("e0: " ++ show (pretty e0) ++ "\n"
          ++ show (pretty_Statements (statements ss))) $
#endif
    -- Strict left fold: avoids building a chain of suspended 'step' calls.
    F.foldl' step (plugs rho e0) (statements ss)
    where
    step
        :: abt '[] ('HMeasure a)
        -> Statement abt 'Impure
        -> abt '[] ('HMeasure a)
    step e s =
#ifdef __TRACE_DISINTEGRATE__
        trace ("wrapping " ++ show (ppStatement 0 s) ++ "\n"
              ++ "around term " ++ show (pretty e)) $
#endif
        case s of
            SBind x body _ ->
                syn (MBind :$ plugs rho (fromLazy body) :* bind x e :* End)
            SLet x body _
                -- The bound variable is unused in the term: drop the let.
                | not (x `memberVarSet` freeVars e) ->
#ifdef __TRACE_DISINTEGRATE__
                    trace ("could not find location" ++ show x ++ "\n"
                          ++ "in term " ++ show (pretty e) ++ "\n"
                          ++ "given rho " ++ show (prettyAssocs rho)) $
#endif
                    e
                | otherwise ->
                    case getLazyVariable body of
                        Just y  -> plug x (plugs rho (var y)) e
                        Nothing ->
                            case getLazyLiteral body of
                                Just v  -> plug x (syn $ Literal_ v) e
                                Nothing ->
                                    syn (Let_ :$ plugs rho (fromLazy body) :* bind x e :* End)
            SGuard xs pat scrutinee _ ->
                syn $ Case_ (plugs rho $ fromLazy scrutinee)
                    [ Branch pat   (binds_ xs e)
                    , Branch PWild (P.reject $ typeOf e)
                    ]
            SWeight body _ ->
                syn $ Superpose_ ((plugs rho $ fromLazy body, e) :| [])
-- | A location: a variable together with a list of 'HNat' index variables.
--
-- 'Loc' keeps the type of the variable it wraps; 'MultiLoc' wraps a variable
-- of type @a@ but is itself indexed at @'HArray a@.
data Loc :: (Hakaru -> *) -> Hakaru -> * where
    Loc      :: Variable a -> [Variable 'HNat] -> Loc ast a
    MultiLoc :: Variable a -> [Variable 'HNat] -> Loc ast ('HArray a)
-- | The index variables stored in a location, whichever constructor built it.
locIndices :: Loc ast a -> [Variable 'HNat]
locIndices loc =
    case loc of
        Loc      _ inds -> inds
        MultiLoc _ inds -> inds
-- | Put one more index variable at the front of a list of index variables.
extendLocInds :: Variable 'HNat -> [Variable 'HNat] -> [Variable 'HNat]
extendLocInds i inds = i : inds
#ifdef __TRACE_DISINTEGRATE__
-- | Render a location as its constructor name, its variable and its list of
-- index variables.
prettyLoc :: Loc ast (a :: Hakaru) -> PP.Doc
prettyLoc (Loc l inds) =
    PP.hsep [PP.text "Loc", ppVariable l, ppList (map ppVariable inds)]
prettyLoc (MultiLoc l inds) =
    PP.hsep [PP.text "MultiLoc", ppVariable l, ppList (map ppVariable inds)]

-- | Render a location table, one @variable -> location@ entry per line.
prettyLocs
    :: (ABT Term abt)
    => Assocs (Loc (abt '[]))
    -> PP.Doc
prettyLocs a = PP.vcat (map entry (fromAssocs a))
    where
    entry (Assoc x l) = PP.hsep [ppVariable x, PP.text "->", prettyLoc l]
#endif
-- | The answer type of the 'Dis' continuation monad: a function from the
-- current 'ListContext' and location table to a list of measure terms.
type Ans abt a =
       ListContext abt 'Impure
    -> Assocs (Loc (abt '[]))
    -> [abt '[] ('HMeasure a)]
-- | A continuation-passing computation. It receives a list of indices (the
-- value handed out by 'getIndices') and a continuation expecting an @x@,
-- and produces an 'Ans' that is polymorphic in the final measure type.
newtype Dis abt x = Dis
    { unDis
        :: forall a
        .  [Index (abt '[])]
        -> (x -> Ans abt a)
        -> Ans abt a
    }
-- | Run a 'Dis' computation to a list of measure terms.
--
-- The computation is followed by 'residualizeLocs' and started with an empty
-- index list, an empty statement list, an empty location table and a counter
-- obtained by applying 'maxNextFree' to @es@. Each result that reaches the
-- final continuation is wrapped in 'Dirac' and passed to
-- 'residualizeListContext' together with the final context.
runDis
    :: (ABT Term abt, F.Foldable f)
    => Dis abt (abt '[] a)
    -> f (Some2 abt)
    -> [abt '[] ('HMeasure a)]
runDis d es =
    unDis (d >>= residualizeLocs) [] finish initialContext emptyAssocs
    where
    initialContext = ListContext (maxNextFree es) []

    -- Final continuation; the location table is no longer needed here.
    finish (e, rho) ss _ =
        [residualizeListContext ss rho (syn (Dirac :$ e :* End))]
-- | Run 'residualizeLoc' over every statement in the context, store the
-- resulting statements back with 'putStatements', and return @e@ together
-- with the substitution that 'convertLocs' builds from the collected names.
residualizeLocs
    :: forall a abt
    .  (ABT Term abt)
    => abt '[] a
    -> Dis abt (abt '[] a, Assocs (abt '[]))
residualizeLocs e = do
    ss               <- getStatements
    (revSs, newlocs) <- foldM collect ([], emptyAssocs) ss
    rho              <- convertLocs newlocs
    -- 'collect' conses onto its accumulator, so undo the reversal here.
    putStatements (reverse revSs)
#ifdef __TRACE_DISINTEGRATE__
    trace ("residualizeLocs: old heap:\n" ++ show (pretty_Statements ss )) $ return ()
    trace ("residualizeLocs: new heap:\n" ++ show (pretty_Statements revSs)) $ return ()
    locs <- getLocs
    traceM ("oldlocs:\n" ++ show (prettyLocs locs) ++ "\n")
    traceM ("new assoc for renaming:\n" ++ show (prettyAssocs rho))
#endif
    return (e, rho)
    where
    collect (acc, names) s = do
        (s', names') <- residualizeLoc s
        return (s' : acc, insertAssocs names' names)
-- | The hint and numeric identifier of a variable, without its type
-- singleton. The type index @a@ is phantom.
data Name (a :: Hakaru) = Name
    { nameHint :: Text
    , nameID   :: Nat
    }
-- | Copy the hint and identifier of a variable into a 'Name'. The result may
-- be used at any type index, not only the variable's own.
varName :: Variable a -> Name b
varName x = Name { nameHint = varHint x, nameID = varID x }
-- | Turn one statement into its residual form and report which 'Name's its
-- bound variables should be known by afterwards.
--
-- 'SBind' and 'SLet' go through 'reifyStatement'. 'SWeight' is first turned
-- into an 'SBind' of a fresh unit-typed variable to @P.weight@ of the weight
-- and then reified the same way. 'SGuard' is returned unchanged when it has
-- no indices, and is an error otherwise.
residualizeLoc
    :: (ABT Term abt)
    => Statement abt 'Impure
    -> Dis abt (Statement abt 'Impure, Assocs Name)
residualizeLoc s =
    case s of
        SBind l _ _ -> do
            (reified, name) <- reifyStatement s
            return (reified, singletonAssocs l name)
        SLet l _ _ -> do
            (reified, name) <- reifyStatement s
            return (reified, singletonAssocs l name)
        SWeight w inds -> do
            l <- freshVar Text.empty sUnit
            (reified, name) <-
                reifyStatement (SBind l (Thunk $ P.weight (fromLazy w)) inds)
            return (reified, singletonAssocs l name)
        SGuard ls _ _ ixs ->
            if null ixs
                then return (s, toAssocs1 ls (fmap11 varName ls))
                else error "undefined: case statement under an array"
-- | Remove the index list of an 'SBind' or 'SLet', one index at a time.
--
-- For each index at the head of the list, the body is wrapped with
-- 'P.plateWithVar' (for 'SBind') or 'P.arrayWithVar' (for 'SLet') over that
-- index, and the statement is rebound to a fresh variable of array type with
-- the same hint. Once the list is empty the statement is returned together
-- with the 'Name' of its variable.
--
-- Calling this on 'SWeight' or 'SGuard' is an error.
reifyStatement
    :: (ABT Term abt)
    => Statement abt 'Impure
    -> Dis abt (Statement abt 'Impure, Name a)
reifyStatement s =
    case s of
        SBind l _ [] ->
            return (s, varName l)
        SBind l body (i:is) -> do
            l' <- freshVar (varHint l) (SArray (varType l))
            let plated = Thunk (P.plateWithVar (indSize i) (indVar i) (fromLazy body))
            reifyStatement (SBind l' plated is)
        SLet l _ [] ->
            return (s, varName l)
        SLet l body (i:is) -> do
            l' <- freshVar (varHint l) (SArray (varType l))
            let arrayed = Thunk (P.arrayWithVar (indSize i) (indVar i) (fromLazy body))
            reifyStatement (SLet l' arrayed is)
        SWeight _ _ ->
            error "reifyStatement called on SWeight"
        SGuard _ _ _ _ ->
            error "reifyStatement called on SGuard"
-- | Find, via 'select', the statement for @l@ and return the size of the
-- index at the head of that statement's index list.
--
-- Only 'SBind' and 'SLet' statements with a non-empty index list whose
-- variable is 'varEq' to @l@ are accepted; the accepted statement is pushed
-- back with 'unsafePush' before the size is returned. An 'SGuard' with a
-- non-empty index list is not handled yet and raises an error. If 'select'
-- finds nothing, 'freeLocError' is raised for @l@.
--
-- The head index is obtained by pattern matching rather than by measuring
-- the list and calling 'head', so there is no partial call and no full
-- traversal of the index list.
sizeInnermostInd
    :: (ABT Term abt)
    => Variable (a :: Hakaru)
    -> Dis abt (abt '[] 'HNat)
sizeInnermostInd l =
    (maybe (freeLocError l) return =<<) . select l $ \s ->
        case s of
            SBind l' _ (i:_) -> do
                Refl <- varEq l l'
                Just $ unsafePush s >> return (indSize i)
            SLet l' _ (i:_) -> do
                Refl <- varEq l l'
                Just $ unsafePush s >> return (indSize i)
            SGuard _ _ _ (_:_) ->
                error "TODO: sizeInnermostInd{SGuard}"
            -- 'SWeight', and any statement with an empty index list.
            _ -> Nothing
-- | Build the term that a name stands for under a list of index variables.
--
-- With no indices this is the variable with the given name and type. Each
-- index at the head of the list turns the result into an array lookup
-- (@P.!@) at that index, into the term built from the remaining indices at
-- the corresponding array type.
fromLoc
    :: (ABT Term abt)
    => Name b
    -> Sing a
    -> [Variable 'HNat]
    -> abt '[] a
fromLoc name typ inds =
    case inds of
        [] ->
            var Variable
                { varHint = nameHint name
                , varID   = nameID name
                , varType = typ
                }
        (i:is) ->
            fromLoc name (SArray typ) is P.! var i
-- | Turn the current location table into a substitution.
--
-- Every entry @x -> loc@ is mapped to the term that 'fromLoc' builds from the
-- name recorded in @newlocs@ for the variable inside @loc@, the type of @x@,
-- and the indices of @loc@. A location whose variable has no entry in
-- @newlocs@ raises 'freeLocError'.
convertLocs
    :: (ABT Term abt)
    => Assocs Name
    -> Dis abt (Assocs (abt '[]))
convertLocs newlocs = do
    locs <- getLocs
    return (F.foldr step emptyAssocs (fromAssocs locs))
    where
    build
        :: (ABT Term abt)
        => Assoc (Loc (abt '[]))
        -> Name a
        -> Assoc (abt '[])
    build (Assoc x loc) name =
        Assoc x (fromLoc name (varType x) (locIndices loc))

    step assoc@(Assoc _ loc) = insertAssoc converted
        where
        converted =
            case loc of
                Loc l _ ->
                    maybe (freeLocError l) (build assoc) (lookupAssoc l newlocs)
                MultiLoc l _ ->
                    maybe (freeLocError l) (build assoc) (lookupAssoc l newlocs)
-- | Abort with an error message that names the given variable as a free
-- location.
freeLocError :: Variable (a :: Hakaru) -> b
freeLocError l = error ("Found a free location " ++ show l)
-- | Rename @e@ after replacing indices according to the pairs in @ijs@.
--
-- The pairs are read as a map from the variable of the first index to the
-- variable of the second. For every location in the table that mentions at
-- least one mapped index variable, a new location with the replaced indices
-- is created ('mkLoc' or 'mkMultiLoc'), and the table key is associated with
-- the new variable. The index map extended with those associations is then
-- applied to @e@ with 'renames'.
apply
    :: (ABT Term abt)
    => [(Index (abt '[]), Index (abt '[]))]
    -> abt '[] a
    -> Dis abt (abt '[] a)
apply ijs e = do
    locs <- fromAssocs <$> getLocs
    rho' <- foldM step rho locs
    return (renames rho' e)
    where
    rho = toAssocs [ Assoc (indVar i) (indVar j) | (i, j) <- ijs ]

    step acc (Assoc x loc)
        | any (isJust . replacement) oldInds = do
            x' <- case loc of
                      Loc      l _ -> mkLoc      Text.empty l newInds
                      MultiLoc l _ -> mkMultiLoc Text.empty l newInds
            return (insertAssoc (Assoc x x') acc)
        | otherwise = return acc
        where
        oldInds       = locIndices loc
        replacement i = lookupAssoc i rho
        newInds       = map (\i -> fromMaybe i (replacement i)) oldInds
-- | Add an index to the front of a list of indices.
--
-- Raises an error when the index is already an element of the list.
extendIndices
    :: (ABT Term abt)
    => Index (abt '[])
    -> [Index (abt '[])]
    -> [Index (abt '[])]
extendIndices j js
    | j `elem` js =
        -- The previous message stopped mid-sentence; say which function
        -- failed and why.
        error "extendIndices: duplicate index: the index being added is already in the list"
    | otherwise = j : js
-- | The index list carried in the last field of a statement.
statementInds :: Statement abt p -> [Index (abt '[])]
statementInds s =
    case s of
        SBind   _ _   inds -> inds
        SLet    _ _   inds -> inds
        SWeight _     inds -> inds
        SGuard  _ _ _ inds -> inds
        SStuff0 _     inds -> inds
        SStuff1 _ _   inds -> inds
-- | Return the current location table, leaving all state unchanged.
getLocs
    :: (ABT Term abt)
    => Dis abt (Assocs (Loc (abt '[])))
getLocs = Dis $ \_ k ctx locs -> k locs ctx locs
-- | Replace the current location table with the given one.
putLocs
    :: (ABT Term abt)
    => Assocs (Loc (abt '[]))
    -> Dis abt ()
putLocs newLocs = Dis $ \_ k ctx _ -> k () ctx newLocs
-- | Add the association @v -> loc@ to the location table with 'insertAssoc'.
insertLoc
    :: (ABT Term abt)
    => Variable a
    -> Loc (abt '[]) a
    -> Dis abt ()
insertLoc v loc =
    Dis $ \_ k ctx locs ->
        k () ctx (insertAssoc (Assoc v loc) locs)
-- | Update the location table by passing @x@ and @f@ to 'adjustAssoc'.
adjustLoc
    :: (ABT Term abt)
    => Variable (a :: Hakaru)
    -> (Assoc (Loc (abt '[])) -> Assoc (Loc (abt '[])))
    -> Dis abt ()
adjustLoc x f = getLocs >>= \locs -> putLocs (adjustAssoc x f locs)
-- | Create a fresh variable with the given hint and the type of @s@,
-- register it in the location table as @Loc s inds@, and return it.
mkLoc
    :: (ABT Term abt)
    => Text
    -> Variable (a :: Hakaru)
    -> [Variable 'HNat]
    -> Dis abt (Variable a)
mkLoc hint s inds = do
    fresh <- freshVar hint (varType s)
    insertLoc fresh (Loc s inds)
    return fresh
-- | Call 'mkLoc' with an empty hint and the same indices on every variable
-- of a 'List1', from head to tail, collecting the new variables.
mkLocs
    :: (ABT Term abt)
    => List1 Variable (xs :: [Hakaru])
    -> [Variable 'HNat]
    -> Dis abt (List1 Variable xs)
mkLocs Nil1 _ = return Nil1
mkLocs (Cons1 x xs) inds = do
    x'  <- mkLoc Text.empty x inds
    xs' <- mkLocs xs inds
    return (Cons1 x' xs')
-- | Create a fresh variable with the given hint whose type is the array type
-- over the type of @s@, register it in the location table as
-- @MultiLoc s inds@, and return it.
mkMultiLoc
    :: (ABT Term abt)
    => Text
    -> Variable a
    -> [Variable 'HNat]
    -> Dis abt (Variable ('HArray a))
mkMultiLoc hint s inds = do
    fresh <- freshVar hint (SArray (varType s))
    insertLoc fresh (MultiLoc s inds)
    return fresh
-- | Mapping over a 'Dis' applies the function to the value before it reaches
-- the continuation.
instance Functor (Dis abt) where
    fmap f m = Dis $ \inds k -> unDis m inds (\x -> k (f x))
-- | 'pure' hands its value straight to the continuation; '<*>' runs the
-- function computation first and the argument computation second, both with
-- the same index list.
instance Applicative (Dis abt) where
    pure x = Dis $ \_ k -> k x
    mf <*> mx =
        Dis $ \inds k ->
            unDis mf inds $ \f ->
            unDis mx inds $ \x ->
                k (f x)
-- | Sequencing runs the first computation and feeds its result to the
-- second, which receives the same index list and the original continuation.
instance Monad (Dis abt) where
    return = pure
    m >>= f =
        Dis $ \inds k ->
            unDis m inds (\x -> unDis (f x) inds k)
-- | 'empty' yields no results; '<|>' runs both alternatives from the same
-- state and concatenates their result lists, left alternative first.
instance Alternative (Dis abt) where
    empty = Dis $ \_ _ _ _ -> []
    m <|> n =
        Dis $ \inds k ctx locs ->
            unDis m inds k ctx locs ++ unDis n inds k ctx locs
-- | 'MonadPlus' is defined by the 'Alternative' instance.
instance MonadPlus (Dis abt) where
    mzero     = empty
    mplus m n = m <|> n
-- | 'Dis' as an impure evaluation monad over a 'ListContext'.
instance (ABT Term abt) => EvaluationMonad abt (Dis abt) 'Impure where
    -- Return the counter stored in the context and store the counter plus one.
    freshNat =
        Dis $ \_ k (ListContext next stmts) ->
            k next (ListContext (next + 1) stmts)

    -- Give the variables bound by a statement fresh names, and create a
    -- location for each of them over the statement's index variables. The
    -- returned association maps each original variable to its new location
    -- variable. 'SWeight' binds nothing and is returned as it is.
    freshenStatement stmt =
        case stmt of
            SWeight _ _ ->
                return (stmt, mempty)
            SBind x body inds -> do
                l  <- freshenVar x
                x' <- mkLoc (varHint x) l (map indVar inds)
                return (SBind l body inds, singletonAssocs x x')
            SLet x body inds -> do
                l  <- freshenVar x
                x' <- mkLoc (varHint x) l (map indVar inds)
                return (SLet l body inds, singletonAssocs x x')
            SGuard xs pat scrutinee inds -> do
                ls  <- freshenVars xs
                xs' <- mkLocs ls (map indVar inds)
                return (SGuard ls pat scrutinee inds, toAssocs1 xs xs')

    -- The index list is the first argument every 'Dis' computation receives.
    getIndices = Dis $ \inds k -> k inds

    -- Cons one statement onto the front of the statement list.
    unsafePush s =
        Dis $ \_ k (ListContext next stmts) ->
            k () (ListContext next (s : stmts))

    -- Put several statements in front of the statement list, reversing them.
    unsafePushes ss =
        Dis $ \_ k (ListContext next stmts) ->
            k () (ListContext next (reverse ss ++ stmts))

    -- Pop statements until one is bound by @l@ and accepted by @p@, run the
    -- action @p@ returned, then push the skipped statements back. When the
    -- statement list runs out, the skipped statements are pushed back and
    -- the result is 'Nothing'.
    select l p = search []
        where
        search skipped = do
            popped <- unsafePop
            case popped of
                Nothing ->
                    unsafePushes skipped >> return Nothing
                Just s ->
                    case l `isBoundBy` s >> p s of
                        Nothing     -> search (s : skipped)
                        Just action -> do
                            result <- action
                            unsafePushes skipped
                            return (Just result)
-- | Run a computation with @inds@ as its index list, ignoring the index list
-- of the surrounding computation.
withIndices :: [Index (abt '[])] -> Dis abt a -> Dis abt a
withIndices inds m = Dis $ \_ k -> unDis m inds k
-- | Remove and return the statement at the head of the statement list, or
-- return 'Nothing' and leave the context alone when the list is empty.
unsafePop :: Dis abt (Maybe (Statement abt 'Impure))
unsafePop =
    Dis $ \_ k ctx@(ListContext next stmts) locs ->
        case stmts of
            s : rest -> k (Just s) (ListContext next rest) locs
            []       -> k Nothing ctx locs
-- | Push the body of a plate as an 'SBind' statement.
--
-- A fresh index of size @n@ replaces the variable bound in @e@, and the
-- statement is stored with that index added (via 'extendIndices') to the
-- current index list. The result is a new 'MultiLoc' variable that refers to
-- the bound variable under the current indices, not the extended ones.
pushPlate
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> abt '[ 'HNat ] ('HMeasure a)
    -> Dis abt (Variable ('HArray a))
pushPlate n e =
    caseBind e $ \x body -> do
        inds <- getIndices
        i    <- freshInd n
        p    <- freshVar Text.empty (sUnMeasure (typeOf body))
        let renamed = Thunk (rename x (indVar i) body)
        unsafePush (SBind p renamed (extendIndices i inds))
        mkMultiLoc Text.empty p (map indVar inds)
-- | The computation with no results; the same as 'empty' from the
-- 'Alternative' instance.
bot :: (ABT Term abt) => Dis abt a
bot = empty
-- | Create a fresh variable of the given hint and type and pass it to the
-- continuation. Every term the continuation produces has that variable bound
-- around it and is then transformed by @f@.
emit
    :: (ABT Term abt)
    => Text
    -> Sing a
    -> (forall r. abt '[a] ('HMeasure r) -> abt '[] ('HMeasure r))
    -> Dis abt (Variable a)
emit hint typ f = do
    x <- freshVar hint typ
    Dis $ \_ k ctx locs ->
        [ f (bind x body) | body <- k x ctx locs ]
-- | Emit an 'MBind' of the measure @m@ around the rest of the computation and
-- return the variable bound to its result.
emitMBind :: (ABT Term abt) => abt '[] ('HMeasure a) -> Dis abt (Variable a)
emitMBind m =
    emit Text.empty (sUnMeasure (typeOf m))
        (\rest -> syn (MBind :$ m :* rest :* End))
-- | Return a variable standing for @e@. When @e@ already is a variable it is
-- returned directly; otherwise a 'Let_' binding of @e@ is emitted around the
-- rest of the computation and the bound variable is returned.
emitLet :: (ABT Term abt) => abt '[] a -> Dis abt (Variable a)
emitLet e =
    caseVarSyn e return $ \_ ->
        emit Text.empty (typeOf e)
            (\rest -> syn (Let_ :$ e :* rest :* End))
-- | Like 'emitLet', but returns a term. Variables and literals are returned
-- as they are; any other term is let-bound and the bound variable is
-- returned as a term.
emitLet' :: (ABT Term abt) => abt '[] a -> Dis abt (abt '[] a)
emitLet' e =
    caseVarSyn e (\_ -> return e) $ \t ->
        case t of
            Literal_ _ -> return e
            _          -> do
                x <- emit Text.empty (typeOf e)
                         (\rest -> syn (Let_ :$ e :* rest :* End))
                return (var x)
-- | Split a pair in weak head normal form into its components. A head value
-- is taken apart with 'reifyPair'; a neutral term is handed to 'emitUnpair_'
-- with two fresh variables of the component types.
emitUnpair
    :: (ABT Term abt)
    => Whnf abt (HPair a b)
    -> Dis abt (abt '[] a, abt '[] b)
emitUnpair whnf =
    case whnf of
        Head_ w   -> return (reifyPair w)
        Neutral e -> do
            let (typA, typB) = sUnPair (typeOf e)
            x <- freshVar Text.empty typA
            y <- freshVar Text.empty typB
            emitUnpair_ x y e
-- | Split a pair-typed term into components, using @x@ and @y@ as the
-- variable names when a case expression has to be emitted.
--
-- A 'Datum_' is taken apart directly with 'reifyPair'. For a 'Case_' the
-- search continues in every branch through 'emitCaseWith'. Anything else,
-- including a bare variable, is matched against a pair pattern in an emitted
-- 'Case_' that binds @x@ and @y@ around the rest of the computation.
emitUnpair_
    :: forall abt a b
    .  (ABT Term abt)
    => Variable a
    -> Variable b
    -> abt '[] (HPair a b)
    -> Dis abt (abt '[] a, abt '[] b)
emitUnpair_ x y = loop
    where
    -- Emit the pair-pattern case around every term the continuation yields.
    done :: abt '[] (HPair a b) -> Dis abt (abt '[] a, abt '[] b)
    done e =
#ifdef __TRACE_DISINTEGRATE__
        trace "-- emitUnpair: done (term is not Datum_ nor Case_)" $
#endif
        Dis $ \_ k ctx locs ->
            map (\body ->
                    syn (Case_ e
                            [Branch (pPair PVar PVar) (bind x (bind y body))]))
                (k (var x, var y) ctx locs)

    loop :: abt '[] (HPair a b) -> Dis abt (abt '[] a, abt '[] b)
    loop e0 =
        caseVarSyn e0 (\v -> done (var v)) $ \t ->
            case t of
                Datum_ d -> do
#ifdef __TRACE_DISINTEGRATE__
                    trace "-- emitUnpair: found Datum_" $ return ()
#endif
                    return $ reifyPair (WDatum d)
                Case_ e bs -> do
#ifdef __TRACE_DISINTEGRATE__
                    trace "-- emitUnpair: going under Case_" $ return ()
#endif
                    emitCaseWith loop e bs
                _ -> done e0
-- | Apply @f@ to every term produced by the rest of the computation, without
-- binding anything.
emit_
    :: (ABT Term abt)
    => (forall r. abt '[] ('HMeasure r) -> abt '[] ('HMeasure r))
    -> Dis abt ()
emit_ f = Dis $ \_ k ctx locs -> map f (k () ctx locs)
-- | Sequence the unit-valued measure @m@ (with @P.>>@) in front of the rest
-- of the computation.
emitMBind_ :: (ABT Term abt) => abt '[] ('HMeasure HUnit) -> Dis abt ()
emitMBind_ m = emit_ (\rest -> m P.>> rest)
-- | Wrap the rest of the computation in 'P.withGuard' on the condition @b@.
emitGuard :: (ABT Term abt) => abt '[] HBool -> Dis abt ()
emitGuard b = emit_ (\rest -> P.withGuard b rest)
-- | Wrap the rest of the computation in 'P.withWeight' with the weight @w@.
emitWeight :: (ABT Term abt) => abt '[] 'HProb -> Dis abt ()
emitWeight w = emit_ (\rest -> P.withWeight w rest)
-- | Run every computation in @ms@ from the same state and with the same
-- continuation, then combine one result of each with @f@. Because the
-- results are lists, the traversal yields every combination of picks.
emitFork_
    :: (ABT Term abt, T.Traversable t)
    => (forall r. t (abt '[] ('HMeasure r)) -> abt '[] ('HMeasure r))
    -> t (Dis abt a)
    -> Dis abt a
emitFork_ f ms =
    Dis $ \inds k ctx locs ->
        f <$> T.for ms (\m -> unDis m inds k ctx locs)
-- | Bind the result of a list of measures. A single measure is bound
-- directly; two or more are combined with 'P.superpose', each with weight
-- 'P.one', and the combination is bound. The empty list is an error.
emitSuperpose
    :: (ABT Term abt)
    => [abt '[] ('HMeasure a)]
    -> Dis abt (Variable a)
emitSuperpose es =
    case es of
        []         -> error "TODO: emitSuperpose[]"
        [e]        -> emitMBind e
        (e : rest) ->
            emitMBind (P.superpose (NE.map (\m -> (P.one, m)) (e :| rest)))
-- | Run a list of alternative computations. A single alternative is run as
-- it is; two or more are forked with 'emitFork_' and their results combined
-- with 'P.superpose', each with weight 'P.one'. The empty list is an error.
choose :: (ABT Term abt) => [Dis abt a] -> Dis abt a
choose alternatives =
    case alternatives of
        []  -> error "TODO: choose[]"
        [m] -> m
        ms  ->
            emitFork_
                (\es -> P.superpose (NE.map (\e -> (P.one, e)) (NE.fromList es)))
                ms
-- | Emit a 'Case_' on @e@ whose branch bodies are produced by running @f@ on
-- each original branch body.
--
-- The variables bound by every branch are first replaced by fresh ones. Each
-- branch computation is then run from the same state and with the same
-- continuation, and the results are reassembled into branches.
emitCaseWith
    :: (ABT Term abt)
    => (abt '[] b -> Dis abt r)
    -> abt '[] a
    -> [Branch a abt b]
    -> Dis abt r
emitCaseWith f e bs = do
    gms <- T.for bs $ \(Branch pat body) -> do
        let (vars, body') = caseBinds body
        vars' <- freshenVars vars
        return (GBranch pat vars' (f (renames (toAssocs1 vars vars') body')))
    Dis $ \inds k ctx locs ->
        let run m = unDis m inds k ctx locs
        in  (syn . Case_ e) <$>
                T.traverse (\gm -> fromGBranch <$> T.traverse run gm) gms