module Language.Hakaru.Evaluation.Types
(
Head(..), fromHead, toHead, viewHeadDatum
, Whnf(..), fromWhnf, toWhnf, caseWhnf, viewWhnfDatum
, Lazy(..), fromLazy, caseLazy
, getLazyVariable, isLazyVariable
, getLazyLiteral, isLazyLiteral
, Purity(..), Statement(..), statementVars, isBoundBy
, Index, indVar, indSize
#ifdef __TRACE_DISINTEGRATE__
, ppList
, ppInds
, ppStatement
, pretty_Statements
, pretty_Statements_withTerm
, prettyAssocs
#endif
, EvaluationMonad(..)
, freshVar
, freshenVar
, Hint(..), freshVars
, freshenVars
, freshInd
, push
, pushes
) where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (Monoid(..))
import Data.Functor ((<$>))
import Control.Applicative (Applicative(..))
#endif
import Control.Arrow ((***))
import qualified Data.Foldable as F
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.Text as T
import Data.Text (Text)
import Data.Proxy (KProxy(..))
import Language.Hakaru.Syntax.IClasses
import Data.Number.Nat
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing(..))
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.AST.Eq (alphaEq)
import Language.Hakaru.Syntax.ABT
import qualified Language.Hakaru.Syntax.Prelude as P
#ifdef __TRACE_DISINTEGRATE__
import qualified Text.PrettyPrint as PP
import Language.Hakaru.Pretty.Haskell
#endif
-- | Head-normal forms. Each constructor corresponds to one 'Term' form
-- (see 'fromHead', which rebuilds the term, and 'toHead', which recognizes
-- it); subterms are kept as plain @abt@ terms.
data Head :: ([Hakaru] -> Hakaru -> *) -> Hakaru -> * where
    -- | A literal ('Literal_').
    WLiteral :: !(Literal a) -> Head abt a
    -- | A datum value ('Datum_').
    WDatum :: !(Datum (abt '[]) (HData' t)) -> Head abt (HData' t)
    -- | An empty array, annotated with its type ('Empty_').
    WEmpty :: !(Sing ('HArray a)) -> Head abt ('HArray a)
    -- | An array built from a 'HNat' term and a body binding one 'HNat'
    -- variable ('Array_'); presumably the size and the per-index element.
    WArray :: !(abt '[] 'HNat) -> !(abt '[ 'HNat] a) -> Head abt ('HArray a)
    -- | A lambda abstraction ('Lam_').
    WLam :: !(abt '[ a ] b) -> Head abt (a ':-> b)
    -- | A primitive measure operator applied to its arguments ('MeasureOp_').
    WMeasureOp
        :: (typs ~ UnLCs args, args ~ LCs typs)
        => !(MeasureOp typs a)
        -> !(SArgs abt args)
        -> Head abt ('HMeasure a)
    -- | 'Dirac' applied to a term.
    WDirac :: !(abt '[] a) -> Head abt ('HMeasure a)
    -- | Monadic bind of measures ('MBind'): a measure and a body binding
    -- its result.
    WMBind
        :: !(abt '[] ('HMeasure a))
        -> !(abt '[ a ] ('HMeasure b))
        -> Head abt ('HMeasure b)
    -- | 'Plate': a 'HNat' term and a body binding one 'HNat' variable,
    -- producing a measure over arrays.
    WPlate
        :: !(abt '[] 'HNat)
        -> !(abt '[ 'HNat ] ('HMeasure a))
        -> Head abt ('HMeasure ('HArray a))
    -- | 'Chain': a 'HNat' term, a term of state type @s@ (presumably the
    -- initial state), and a body binding an @s@ that yields a measure over
    -- (value, state) pairs.
    WChain
        :: !(abt '[] 'HNat)
        -> !(abt '[] s)
        -> !(abt '[ s ] ('HMeasure (HPair a s)))
        -> Head abt ('HMeasure (HPair ('HArray a) s))
    -- | A non-empty superposition of (weight, measure) pairs ('Superpose_').
    WSuperpose
        :: !(NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a)))
        -> Head abt ('HMeasure a)
    -- | A typed rejection ('Reject_').
    WReject
        :: !(Sing ('HMeasure a)) -> Head abt ('HMeasure a)
    -- | A safe coercion applied to a nested head ('CoerceTo_').
    WCoerceTo :: !(Coercion a b) -> !(Head abt a) -> Head abt b
    -- | An unsafe (inverse) coercion applied to a nested head ('UnsafeFrom_').
    WUnsafeFrom :: !(Coercion a b) -> !(Head abt b) -> Head abt a
    -- | 'Integrate': two 'HReal' terms (presumably the bounds) and a body
    -- binding one 'HReal' variable.
    WIntegrate
        :: !(abt '[] 'HReal)
        -> !(abt '[] 'HReal)
        -> !(abt '[ 'HReal ] 'HProb)
        -> Head abt 'HProb
-- | Convert a head-normal form back into the term it stands for.
--
-- Coercion heads ('WCoerceTo', 'WUnsafeFrom') wrap a nested 'Head', which
-- is converted recursively. Every other constructor maps directly onto
-- the matching 'Term' form.
fromHead :: (ABT Term abt) => Head abt a -> abt '[] a
fromHead h =
    case h of
        WLiteral lit            -> syn $ Literal_ lit
        WDatum dat              -> syn $ Datum_ dat
        WEmpty typ              -> syn $ Empty_ typ
        WArray n body           -> syn $ Array_ n body
        WLam body               -> syn $ Lam_ :$ body :* End
        WMeasureOp op args      -> syn $ MeasureOp_ op :$ args
        WDirac x                -> syn $ Dirac :$ x :* End
        WMBind m k              -> syn $ MBind :$ m :* k :* End
        WPlate n body           -> syn $ Plate :$ n :* body :* End
        WChain n s0 body        -> syn $ Chain :$ n :* s0 :* body :* End
        WSuperpose pes          -> syn $ Superpose_ pes
        WReject typ             -> syn $ Reject_ typ
        WCoerceTo c inner       -> syn $ CoerceTo_ c :$ fromHead inner :* End
        WUnsafeFrom c inner     -> syn $ UnsafeFrom_ c :$ fromHead inner :* End
        WIntegrate lo up body   -> syn $ Integrate :$ lo :* up :* body :* End
-- | Recognize a term whose outermost form is a head-normal form.
--
-- Returns 'Nothing' for variables and for every term form that has no
-- corresponding 'Head' constructor. A coercion ('CoerceTo_' or
-- 'UnsafeFrom_') is recognized only when the coerced subterm is itself
-- a head.
--
-- This is the partial inverse of 'fromHead'. Every 'Head' constructor
-- that 'fromHead' produces is recognized here, including 'WReject'.
toHead :: (ABT Term abt) => abt '[] a -> Maybe (Head abt a)
toHead e =
    caseVarSyn e (const Nothing) $ \t ->
        case t of
            Literal_ v                          -> Just $ WLiteral v
            Datum_ d                            -> Just $ WDatum d
            Empty_ typ                          -> Just $ WEmpty typ
            Array_ e1 e2                        -> Just $ WArray e1 e2
            Lam_ :$ e1 :* End                   -> Just $ WLam e1
            MeasureOp_ o :$ es                  -> Just $ WMeasureOp o es
            Dirac :$ e1 :* End                  -> Just $ WDirac e1
            MBind :$ e1 :* e2 :* End            -> Just $ WMBind e1 e2
            Plate :$ e1 :* e2 :* End            -> Just $ WPlate e1 e2
            Chain :$ e1 :* e2 :* e3 :* End      -> Just $ WChain e1 e2 e3
            Superpose_ pes                      -> Just $ WSuperpose pes
            -- 'fromHead' maps 'WReject' to 'Reject_', so recognize it here
            -- too; otherwise the two functions disagree on rejection.
            Reject_ typ                         -> Just $ WReject typ
            CoerceTo_ c :$ e1 :* End            -> WCoerceTo c <$> toHead e1
            UnsafeFrom_ c :$ e1 :* End          -> WUnsafeFrom c <$> toHead e1
            Integrate :$ e1 :* e2 :* e3 :* End  -> Just $ WIntegrate e1 e2 e3
            _                                   -> Nothing
-- | Map a function over every immediate subterm of a head. Coercion heads
-- recurse into the nested 'Head'. Literals, empty arrays and rejections
-- have no subterms and are rebuilt unchanged.
instance Functor21 Head where
    fmap21 f h =
        case h of
            WLiteral lit          -> WLiteral lit
            WDatum dat            -> WDatum (fmap11 f dat)
            WEmpty typ            -> WEmpty typ
            WArray n body         -> WArray (f n) (f body)
            WLam body             -> WLam (f body)
            WMeasureOp op args    -> WMeasureOp op (fmap21 f args)
            WDirac x              -> WDirac (f x)
            WMBind m k            -> WMBind (f m) (f k)
            WPlate n body         -> WPlate (f n) (f body)
            WChain n s0 body      -> WChain (f n) (f s0) (f body)
            WSuperpose pes        -> WSuperpose (fmap (f *** f) pes)
            WReject typ           -> WReject typ
            WCoerceTo c inner     -> WCoerceTo c (fmap21 f inner)
            WUnsafeFrom c inner   -> WUnsafeFrom c (fmap21 f inner)
            WIntegrate lo up body -> WIntegrate (f lo) (f up) (f body)
-- | Fold every immediate subterm of a head into a monoid, left to right.
-- Heads without subterms contribute 'mempty'.
instance Foldable21 Head where
    foldMap21 f h =
        case h of
            WLiteral _            -> mempty
            WDatum dat            -> foldMap11 f dat
            WEmpty _              -> mempty
            WArray n body         -> f n `mappend` f body
            WLam body             -> f body
            WMeasureOp _ args     -> foldMap21 f args
            WDirac x              -> f x
            WMBind m k            -> f m `mappend` f k
            WPlate n body         -> f n `mappend` f body
            WChain n s0 body      -> f n `mappend` f s0 `mappend` f body
            WSuperpose pes        -> foldMapPairs f pes
            WReject _             -> mempty
            WCoerceTo _ inner     -> foldMap21 f inner
            WUnsafeFrom _ inner   -> foldMap21 f inner
            WIntegrate lo up body -> f lo `mappend` f up `mappend` f body
-- | Traverse every immediate subterm of a head with an applicative
-- action, running the effects left to right and rebuilding the same
-- constructor.
instance Traversable21 Head where
    traverse21 f h =
        case h of
            WLiteral lit          -> pure (WLiteral lit)
            WDatum dat            -> WDatum <$> traverse11 f dat
            WEmpty typ            -> pure (WEmpty typ)
            WArray n body         -> WArray <$> f n <*> f body
            WLam body             -> WLam <$> f body
            WMeasureOp op args    -> WMeasureOp op <$> traverse21 f args
            WDirac x              -> WDirac <$> f x
            WMBind m k            -> WMBind <$> f m <*> f k
            WPlate n body         -> WPlate <$> f n <*> f body
            WChain n s0 body      -> WChain <$> f n <*> f s0 <*> f body
            WSuperpose pes        -> WSuperpose <$> traversePairs f pes
            WReject typ           -> pure (WReject typ)
            WCoerceTo c inner     -> WCoerceTo c <$> traverse21 f inner
            WUnsafeFrom c inner   -> WUnsafeFrom c <$> traverse21 f inner
            WIntegrate lo up body -> WIntegrate <$> f lo <*> f up <*> f body
-- | Weak-head normal forms: either a 'Head' or a term stored as-is under
-- 'Neutral'. NOTE(review): by name, 'Neutral' presumably holds terms whose
-- evaluation is stuck (e.g. on a free variable). Confirm against the
-- evaluator.
data Whnf (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
    = Head_ !(Head abt a)
    | Neutral !(abt '[] a)
-- | Turn a weak-head normal form back into a term. Heads are rebuilt with
-- 'fromHead'; neutral terms are returned unchanged.
fromWhnf :: (ABT Term abt) => Whnf abt a -> abt '[] a
fromWhnf w = caseWhnf w fromHead id
-- | Recognize a term as a weak-head normal form. Only heads (per 'toHead')
-- are recognized; this never produces 'Neutral'.
toWhnf :: (ABT Term abt) => abt '[] a -> Maybe (Whnf abt a)
toWhnf = fmap Head_ . toHead
-- | Eliminator for 'Whnf': the first continuation handles heads, the
-- second handles neutral terms.
caseWhnf :: Whnf abt a -> (Head abt a -> r) -> (abt '[] a -> r) -> r
caseWhnf w onHead onNeutral =
    case w of
        Head_ h   -> onHead h
        Neutral t -> onNeutral t
-- | Map over the subterms of a head, or over the neutral term itself.
instance Functor21 Whnf where
    fmap21 f w =
        case w of
            Head_ h   -> Head_ (fmap21 f h)
            Neutral t -> Neutral (f t)
-- | Fold over the subterms of a head, or over the neutral term itself.
instance Foldable21 Whnf where
    foldMap21 f w =
        case w of
            Head_ h   -> foldMap21 f h
            Neutral t -> f t
-- | Traverse the subterms of a head, or the neutral term itself.
instance Traversable21 Whnf where
    traverse21 f w =
        case w of
            Head_ h   -> Head_ <$> traverse21 f h
            Neutral t -> Neutral <$> f t
-- | Extract the datum from a weak-head normal form of a data type.
-- Returns 'Nothing' for neutral terms. For heads, defers to
-- 'viewHeadDatum', which calls 'error' on a non-'WDatum' head.
viewWhnfDatum
    :: (ABT Term abt)
    => Whnf abt (HData' t)
    -> Maybe (Datum (abt '[]) (HData' t))
viewWhnfDatum w = caseWhnf w (Just . viewHeadDatum) (const Nothing)
-- | Extract the datum from a head of a data type. Any head other than
-- 'WDatum' is treated as impossible and raises 'error'.
viewHeadDatum
    :: (ABT Term abt)
    => Head abt (HData' t)
    -> Datum (abt '[]) (HData' t)
viewHeadDatum h =
    case h of
        WDatum dat -> dat
        _          -> error "viewHeadDatum: the impossible happened"
-- | Coercions on weak-head normal forms.
--
-- For a 'Neutral' term: a literal is coerced directly, and a coercion of
-- the same direction already at the top of the term is fused with the new
-- one by composition. Any other term is wrapped in a new coercion node.
--
-- For a 'Head_': literals are coerced directly, a matching coercion head is
-- fused, and anything else is wrapped in 'WCoerceTo' / 'WUnsafeFrom'.
instance (ABT Term abt) => Coerce (Whnf abt) where
    coerceTo c (Neutral e) =
        -- When no simplification applies, fall back to wrapping 'e'.
        Neutral . maybe (P.coerceTo_ c e) id $
            caseVarSyn e (const Nothing) $ \t ->
                case t of
                    Literal_ x -> Just $ P.literal_ (coerceTo c x)
                    CoerceTo_ c' :$ es' ->
                        case es' of
                            e' :* End -> Just $ P.coerceTo_ (c . c') e'
                            _ -> error "coerceTo@Whnf: the impossible happened"
                    _ -> Nothing
    coerceTo c (Head_ (WLiteral x))      = Head_ $ WLiteral (coerceTo c x)
    coerceTo c (Head_ (WCoerceTo c' h')) = Head_ $ WCoerceTo (c . c') h'
    coerceTo c (Head_ h)                 = Head_ $ WCoerceTo c h

    coerceFrom c (Neutral e) =
        -- When no simplification applies, fall back to wrapping 'e'.
        Neutral . maybe (P.unsafeFrom_ c e) id $
            caseVarSyn e (const Nothing) $ \t ->
                case t of
                    Literal_ x -> Just $ P.literal_ (coerceFrom c x)
                    UnsafeFrom_ c' :$ es' ->
                        case es' of
                            e' :* End -> Just $ P.unsafeFrom_ (c' . c) e'
                            _ -> error "unsafeFrom@Whnf: the impossible happened"
                    _ -> Nothing
    coerceFrom c (Head_ (WLiteral x))        = Head_ $ WLiteral (coerceFrom c x)
    coerceFrom c (Head_ (WUnsafeFrom c' h')) = Head_ $ WUnsafeFrom (c' . c) h'
    coerceFrom c (Head_ h)                   = Head_ $ WUnsafeFrom c h
-- | A possibly-unevaluated term: either an already computed weak-head
-- normal form ('Whnf_') or the term itself, not yet evaluated ('Thunk').
data Lazy (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
    = Whnf_ !(Whnf abt a)
    | Thunk !(abt '[] a)
-- | Turn a lazy value back into a term. Evaluated forms go through
-- 'fromWhnf'; thunks are returned unchanged.
fromLazy :: (ABT Term abt) => Lazy abt a -> abt '[] a
fromLazy lz = caseLazy lz fromWhnf id
-- | Eliminator for 'Lazy': the first continuation handles evaluated forms,
-- the second handles thunks.
caseLazy :: Lazy abt a -> (Whnf abt a -> r) -> (abt '[] a -> r) -> r
caseLazy lz onWhnf onThunk =
    case lz of
        Whnf_ w -> onWhnf w
        Thunk t -> onThunk t
-- | Map over the subterms of an evaluated form, or over the thunk's term.
instance Functor21 Lazy where
    fmap21 f lz =
        case lz of
            Whnf_ w -> Whnf_ (fmap21 f w)
            Thunk t -> Thunk (f t)
-- | Fold over the subterms of an evaluated form, or over the thunk's term.
instance Foldable21 Lazy where
    foldMap21 f lz =
        case lz of
            Whnf_ w -> foldMap21 f w
            Thunk t -> f t
-- | Traverse the subterms of an evaluated form, or the thunk's term.
instance Traversable21 Lazy where
    traverse21 f lz =
        case lz of
            Whnf_ w -> Whnf_ <$> traverse21 f w
            Thunk t -> Thunk <$> f t
-- | If the lazy value is syntactically a variable (either as a neutral
-- term or as an unevaluated thunk), return that variable. Heads are never
-- variables.
getLazyVariable :: (ABT Term abt) => Lazy abt a -> Maybe (Variable a)
getLazyVariable (Whnf_ (Head_ _))   = Nothing
getLazyVariable (Whnf_ (Neutral t)) = caseVarSyn t Just (const Nothing)
getLazyVariable (Thunk t)           = caseVarSyn t Just (const Nothing)
-- | Whether 'getLazyVariable' finds a variable.
isLazyVariable :: (ABT Term abt) => Lazy abt a -> Bool
isLazyVariable lz =
    case getLazyVariable lz of
        Just _  -> True
        Nothing -> False
-- | If the lazy value is a literal, return it. Only 'WLiteral' heads and
-- thunks whose term is syntactically 'Literal_' qualify; neutral terms
-- never do.
getLazyLiteral :: (ABT Term abt) => Lazy abt a -> Maybe (Literal a)
getLazyLiteral (Whnf_ (Head_ (WLiteral lit))) = Just lit
getLazyLiteral (Whnf_ _)                      = Nothing
getLazyLiteral (Thunk t)                      =
    caseVarSyn t (const Nothing) $ \term ->
        case term of
            Literal_ lit -> Just lit
            _            -> Nothing
-- | Whether 'getLazyLiteral' finds a literal.
isLazyLiteral :: (ABT Term abt) => Lazy abt a -> Bool
isLazyLiteral lz =
    case getLazyLiteral lz of
        Just _  -> True
        Nothing -> False
-- | Purity levels, used (promoted) as the index of 'Statement'.
-- 'SLet' may occur at any level. 'SBind', 'SWeight' and 'SGuard' occur
-- only at 'Impure', and 'SStuff0' / 'SStuff1' only at 'ExpectP'.
data Purity = Pure | Impure | ExpectP
    deriving (Eq, Read, Show)
-- | An index: a 'HNat' index variable paired with a 'HNat' term, accessed
-- via 'indVar' and 'indSize' (the term is presumably the index's range).
data Index ast = Ind (Variable 'HNat) (ast 'HNat)
-- | Two indices are equal when their variables are equal and their size
-- terms are alpha-equivalent. Sizes are compared only if the variables
-- already match.
instance (ABT Term abt) => Eq (Index (abt '[])) where
    (==) (Ind v1 n1) (Ind v2 n2)
        | v1 == v2  = alphaEq n1 n2
        | otherwise = False
-- | Indices are ordered by their index variable alone; sizes are ignored.
--
-- NOTE(review): the 'Eq' instance also compares sizes (via 'alphaEq'), so
-- @compare i j == EQ@ does not imply @i == j@. Confirm that ignoring sizes
-- here is intended before relying on Ord/Eq agreement (e.g. in ordered
-- containers).
instance (ABT Term abt) => Ord (Index (abt '[])) where
    compare (Ind v1 _) (Ind v2 _) = v1 `compare` v2
-- | The index variable of an 'Index'.
indVar :: Index ast -> Variable 'HNat
indVar ind = case ind of Ind v _ -> v
-- | The 'HNat' term paired with the index variable.
indSize :: Index ast -> ast 'HNat
indSize ind = case ind of Ind _ n -> n
-- | A statement in the evaluation context, indexed by the 'Purity' level at
-- which it may occur. Every statement carries a list of 'Index'es.
-- NOTE(review): presumably these are the enclosing loop/array indices the
-- statement lives under. Confirm.
data Statement :: ([Hakaru] -> Hakaru -> *) -> Purity -> * where
    -- | Bind a variable of type @a@ against a measure of type
    -- @'HMeasure a@. Impure only.
    SBind
        :: forall abt (a :: Hakaru)
        .  !(Variable a)
        -> !(Lazy abt ('HMeasure a))
        -> [Index (abt '[])]
        -> Statement abt 'Impure
    -- | Bind a variable to a (possibly unevaluated) value. Allowed at any
    -- purity level.
    SLet
        :: forall abt p (a :: Hakaru)
        .  !(Variable a)
        -> !(Lazy abt a)
        -> [Index (abt '[])]
        -> Statement abt p
    -- | A weight of type 'HProb'. Impure only.
    SWeight
        :: forall abt
        .  !(Lazy abt 'HProb)
        -> [Index (abt '[])]
        -> Statement abt 'Impure
    -- | Match a scrutinee against a pattern, binding the pattern's
    -- variables. Impure only.
    SGuard
        :: forall abt (xs :: [Hakaru]) (a :: Hakaru)
        .  !(List1 Variable xs)
        -> !(Pattern xs a)
        -> !(Lazy abt a)
        -> [Index (abt '[])]
        -> Statement abt 'Impure
    -- | A function on 'HProb' terms, binding nothing. 'ExpectP' only.
    -- NOTE(review): presumably a transformation of the expectation being
    -- built. Confirm.
    SStuff0
        :: forall abt
        .  (abt '[] 'HProb -> abt '[] 'HProb)
        -> [Index (abt '[])]
        -> Statement abt 'ExpectP
    -- | Like 'SStuff0', but also binds one variable. 'ExpectP' only.
    SStuff1
        :: forall abt (a :: Hakaru)
        .  !(Variable a)
        -> (abt '[] 'HProb -> abt '[] 'HProb)
        -> [Index (abt '[])]
        -> Statement abt 'ExpectP
-- | The set of variables a statement binds. Weights and 'SStuff0' bind
-- none; guards bind all their pattern variables.
statementVars :: Statement abt p -> VarSet ('KProxy :: KProxy Hakaru)
statementVars stmt =
    case stmt of
        SBind v _ _      -> singletonVarSet v
        SLet v _ _       -> singletonVarSet v
        SWeight _ _      -> emptyVarSet
        SGuard vs _ _ _  -> toVarSet1 vs
        SStuff0 _ _      -> emptyVarSet
        SStuff1 v _ _    -> singletonVarSet v
-- | @Just ()@ if the given variable is one of those bound by the
-- statement, @Nothing@ otherwise. Single binders are compared with
-- 'varEq'; guard binders are checked for set membership.
isBoundBy :: Variable (a :: Hakaru) -> Statement abt p -> Maybe ()
isBoundBy x stmt =
    case stmt of
        SBind y _ _     -> fmap (const ()) (varEq x y)
        SLet y _ _      -> fmap (const ()) (varEq x y)
        SWeight _ _     -> Nothing
        SGuard ys _ _ _
            | memberVarSet x (toVarSet1 ys) -> Just ()
            | otherwise                     -> Nothing
        SStuff0 _ _     -> Nothing
        SStuff1 y _ _   -> fmap (const ()) (varEq x y)
#ifdef __TRACE_DISINTEGRATE__
-- | Debug rendering of a weak-head normal form as a constructor applied
-- to the underlying term. Heads are first rebuilt with 'fromHead'.
instance (ABT Term abt) => Pretty (Whnf abt) where
    prettyPrec_ p w =
        case w of
            Head_ h   -> ppApply1 p "Head_" (fromHead h)
            Neutral t -> ppApply1 p "Neutral" t
-- | Debug rendering of a lazy value as its constructor applied to the
-- payload.
instance (ABT Term abt) => Pretty (Lazy abt) where
    prettyPrec_ p lz =
        case lz of
            Whnf_ w -> ppFun p "Whnf_" [PP.sep (prettyPrec_ 11 w)]
            Thunk t -> ppApply1 p "Thunk" t
-- | Render a named constructor applied to one term argument. The result
-- is parenthesized when the surrounding precedence exceeds 9.
ppApply1 :: (ABT Term abt) => Int -> String -> abt '[] a -> [PP.Doc]
ppApply1 p f e1
    | p > 9     = [PP.parens (PP.nest 1 doc)]
    | otherwise = [doc]
  where
    doc = PP.text f PP.<+> PP.nest (1 + length f) (prettyPrec 11 e1)
-- | Render a named function applied to already-rendered arguments. With
-- no arguments, only the name is printed. Otherwise the application is
-- parenthesized when the surrounding precedence exceeds 9.
ppFun :: Int -> String -> [PP.Doc] -> [PP.Doc]
ppFun p f ds =
    case ds of
        [] -> [name]
        _  -> parens (p > 9) [name PP.<+> PP.nest (1 + length f) (PP.sep ds)]
  where
    name = PP.text f
-- | Optionally wrap a list of documents in parentheses, joining them with
-- 'PP.sep' when wrapping.
parens :: Bool -> [PP.Doc] -> [PP.Doc]
parens wrap ds
    | wrap      = [PP.parens (PP.nest 1 (PP.sep ds))]
    | otherwise = ds
-- | Render documents as a bracketed, comma-separated, fill-wrapped list.
ppList :: [PP.Doc] -> PP.Doc
ppList ds = PP.sep [PP.brackets (PP.nest 1 (PP.fsep (PP.punctuate PP.comma ds)))]
-- | Render a list of indices, showing only their index variables.
ppInds :: (ABT Term abt) => [Index (abt '[])] -> PP.Doc
ppInds inds = ppList [ ppVariable (indVar i) | i <- inds ]
-- | Debug rendering of a statement at precedence @p@. Indices are shown by
-- variable only. 'SStuff0' and 'SStuff1' hold functions, so they are
-- rendered as placeholder text.
ppStatement :: (ABT Term abt) => Int -> Statement abt p -> PP.Doc
ppStatement p s =
    case s of
        SBind x e inds ->
            node "SBind" [ppVariable x, PP.sep (prettyPrec_ 11 e), ppInds inds]
        SLet x e inds ->
            node "SLet" [ppVariable x, PP.sep (prettyPrec_ 11 e), ppInds inds]
        SWeight e inds ->
            node "SWeight" [PP.sep (prettyPrec_ 11 e), ppInds inds]
        SGuard xs pat e inds ->
            node "SGuard"
                [ PP.sep (ppVariables xs)
                , PP.sep (prettyPrec_ 11 pat)
                , PP.sep (prettyPrec_ 11 e)
                , ppInds inds
                ]
        SStuff0 _ _ ->
            node "SStuff0" [PP.text "TODO: ppStatement{SStuff0}"]
        SStuff1 _ _ _ ->
            node "SStuff1" [PP.text "TODO: ppStatement{SStuff1}"]
  where
    -- Every constructor is rendered as a function application at
    -- precedence p.
    node name docs = PP.sep (ppFun p name docs)
-- | Render a list of statements one per line, in leading-comma list style
-- (@[ s1@ / @, s2@ / ... / @]@). The empty list renders as @[]@.
pretty_Statements :: (ABT Term abt) => [Statement abt p] -> PP.Doc
pretty_Statements [] = PP.text "[]"
pretty_Statements (s:ss) = body PP.$+$ PP.text "]"
  where
    body = foldl addStmt (PP.text "[" PP.<+> ppStatement 0 s) ss
    addStmt acc s' = acc PP.$+$ (PP.comma PP.<+> ppStatement 0 s')
-- | Render a list of statements followed, on the next line, by a term.
pretty_Statements_withTerm
    :: (ABT Term abt) => [Statement abt p] -> abt '[] a -> PP.Doc
pretty_Statements_withTerm ss e = stmtsDoc PP.$+$ termDoc
  where
    stmtsDoc = pretty_Statements ss
    termDoc  = pretty e
-- | Render an association list one @variable -> term@ entry per line.
prettyAssocs
    :: (ABT Term abt)
    => Assocs (abt '[])
    -> PP.Doc
prettyAssocs assocs = PP.vcat (map ppAssoc (fromAssocs assocs))
  where
    ppAssoc (Assoc x e) = ppVariable x PP.<+> PP.text "->" PP.<+> pretty e
#endif
-- | Monads in which evaluation runs. They supply fresh names and maintain
-- a context of 'Statement's at purity level @p@.
class (Functor m, Applicative m, Monad m, ABT Term abt)
    => EvaluationMonad abt m p | m -> abt p
    where
    -- | Produce a fresh natural number. 'freshVar' and 'freshenVar' use it
    -- as the ID of the variable they create.
    freshNat :: m Nat

    -- | Rename every variable a statement binds to a fresh one. Returns the
    -- renamed statement together with the old-to-new renaming. Statements
    -- that bind nothing come back unchanged with an empty renaming.
    freshenStatement
        :: Statement abt p
        -> m (Statement abt p, Assocs (Variable :: Hakaru -> *))
    freshenStatement stmt@(SWeight _ _) = return (stmt, mempty)
    freshenStatement stmt@(SStuff0 _ _) = return (stmt, mempty)
    freshenStatement (SBind v rhs inds) = do
        v' <- freshenVar v
        return (SBind v' rhs inds, singletonAssocs v v')
    freshenStatement (SLet v rhs inds) = do
        v' <- freshenVar v
        return (SLet v' rhs inds, singletonAssocs v v')
    freshenStatement (SGuard vs pat scrutinee inds) = do
        vs' <- freshenVars vs
        return (SGuard vs' pat scrutinee inds, toAssocs1 vs vs')
    freshenStatement (SStuff1 v f inds) = do
        v' <- freshenVar v
        return (SStuff1 v' f inds, singletonAssocs v v')

    -- | The current indices. The default is none.
    getIndices :: m [Index (abt '[])]
    getIndices = return []

    -- | Add a statement to the context as-is, without freshening its
    -- binders ('push_' freshens first and then calls this).
    unsafePush :: Statement abt p -> m ()

    -- | 'unsafePush' each statement, in list order.
    unsafePushes :: [Statement abt p] -> m ()
    unsafePushes stmts = mapM_ unsafePush stmts

    -- | NOTE(review): by its type, presumably looks for a statement
    -- concerning the given variable and runs the callback on it. The
    -- behavior is defined by each instance; confirm there.
    select
        :: Variable (a :: Hakaru)
        -> (Statement abt p -> Maybe (m r))
        -> m (Maybe r)
-- | Create a variable with the given name hint and type, using a fresh
-- ID from 'freshNat'.
freshVar
    :: (EvaluationMonad abt m p)
    => Text
    -> Sing (a :: Hakaru)
    -> m (Variable a)
freshVar hint typ = do
    n <- freshNat
    return (Variable hint n typ)
-- | A name hint and a type, from which 'freshVars' creates a fresh variable.
data Hint (a :: Hakaru) = Hint !Text !(Sing a)
-- | Create one fresh variable per hint, left to right, preserving the
-- type-level list of types.
freshVars
    :: (EvaluationMonad abt m p)
    => List1 Hint xs
    -> m (List1 Variable xs)
freshVars Nil1 = return Nil1
freshVars (Cons1 (Hint hint typ) hints) = do
    v  <- freshVar hint typ
    vs <- freshVars hints
    return (Cons1 v vs)
-- | Copy a variable, keeping its name and type but giving it a fresh ID.
freshenVar
    :: (EvaluationMonad abt m p)
    => Variable (a :: Hakaru)
    -> m (Variable a)
freshenVar x = do
    n <- freshNat
    return x{varID=n}
-- | 'freshenVar' each variable of a list, left to right.
freshenVars
    :: (EvaluationMonad abt m p)
    => List1 Variable (xs :: [Hakaru])
    -> m (List1 Variable xs)
freshenVars Nil1 = return Nil1
freshenVars (Cons1 v vs) = do
    v'  <- freshenVar v
    vs' <- freshenVars vs
    return (Cons1 v' vs')
-- | Create an index whose variable is a fresh, unnamed 'HNat' variable
-- paired with the given 'HNat' term.
freshInd :: (EvaluationMonad abt m p)
    => abt '[] 'HNat
    -> m (Index (abt '[]))
freshInd s = fmap (\v -> Ind v s) (freshVar T.empty SNat)
-- | Freshen a statement's binders, add the renamed statement to the
-- context, and return the old-to-new renaming.
push_
    :: (ABT Term abt, EvaluationMonad abt m p)
    => Statement abt p
    -> m (Assocs (Variable :: Hakaru -> *))
push_ s =
    freshenStatement s >>= \(s', rho) ->
        unsafePush s' >> return rho
-- | Push a statement (see 'push_') and continue with the given term,
-- renamed so that it refers to the statement's freshened binders.
push
    :: (ABT Term abt, EvaluationMonad abt m p)
    => Statement abt p
    -> abt xs a
    -> (abt xs a -> m r)
    -> m r
push s e k = push_ s >>= \rho -> k (renames rho e)
-- | Push several statements in list order, then continue with the term
-- renamed by the combined renaming. Renamings are combined as
-- @earlier `mappend` later@.
--
-- NOTE(review): each statement's renaming is applied only to the final
-- term, not to the statements that follow it. If a later statement
-- mentions a variable bound by an earlier one, it will still refer to the
-- old name. Confirm callers never rely on that.
pushes
    :: (ABT Term abt, EvaluationMonad abt m p)
    => [Statement abt p]
    -> abt xs a
    -> (abt xs a -> m r)
    -> m r
pushes ss e k = go mempty ss
  where
    go rho []         = k (renames rho e)
    go rho (s : rest) = do
        rho' <- push_ s
        go (rho `mappend` rho') rest