{-# LANGUAGE CPP
, GADTs
, KindSignatures
, DataKinds
, PolyKinds
, TypeOperators
, Rank2Types
, BangPatterns
, FlexibleContexts
, MultiParamTypeClasses
, FunctionalDependencies
, FlexibleInstances
, UndecidableInstances
, EmptyCase
, ScopedTypeVariables
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Evaluation.Types
(
Head(..), fromHead, toHead, viewHeadDatum
, Whnf(..), fromWhnf, toWhnf, caseWhnf, viewWhnfDatum
, Lazy(..), fromLazy, caseLazy
, getLazyVariable, isLazyVariable
, getLazyLiteral, isLazyLiteral
, TermEvaluator
, MeasureEvaluator
, CaseEvaluator
, VariableEvaluator
, Purity(..), Statement(..), statementVars, isBoundBy
, Index, indVar, indSize, fromIndex
, Location(..), locEq, locHint, locType, locations1
, fromLocation, fromLocations1, freshenLoc, freshenLocs
, LAssoc, LAssocs , emptyLAssocs, singletonLAssocs
, toLAssocs1, insertLAssocs, lookupLAssoc
#ifdef __TRACE_DISINTEGRATE__
, ppList
, ppInds
, ppStatement
, pretty_Statements
, pretty_Statements_withTerm
, prettyAssocs
#endif
, EvaluationMonad(..)
, defaultCaseEvaluator
, toVarStatements
, extSubst
, extSubsts
, freshVar
, freshenVar
, Hint(..), freshVars
, freshenVars
, freshInd
, push
, pushes
) where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (Monoid(..))
import Data.Functor ((<$>))
import Control.Applicative (Applicative(..))
import Data.Traversable
#endif
import Control.Arrow ((***))
import qualified Data.Foldable as F
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.Text as T
import Data.Text (Text)
import Data.Proxy (KProxy(..))
import Language.Hakaru.Syntax.IClasses
import Data.Number.Nat
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing(..))
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumCase (DatumEvaluator,
MatchResult(..),
matchBranches)
import Language.Hakaru.Syntax.AST.Eq (alphaEq)
import Language.Hakaru.Syntax.ABT
import qualified Language.Hakaru.Syntax.Prelude as P
#ifdef __TRACE_DISINTEGRATE__
import qualified Text.PrettyPrint as PP
import Language.Hakaru.Pretty.Haskell
import Debug.Trace (trace)
#endif
-- | Weak-head normal forms that are not neutral: terms whose outermost
-- constructor is known. Each constructor corresponds to one 'Term'
-- form; 'fromHead' rebuilds that term and 'toHead' recognises it.
data Head :: ([Hakaru] -> Hakaru -> *) -> Hakaru -> * where
    -- | A literal ('Literal_').
    WLiteral :: !(Literal a) -> Head abt a
    -- | A datum ('Datum_').
    WDatum :: !(Datum (abt '[]) (HData' t)) -> Head abt (HData' t)
    -- | The empty array ('Empty_'), tagged with its type.
    WEmpty :: !(Sing ('HArray a)) -> Head abt ('HArray a)
    -- | An 'Array_' node: an 'HNat term and a body binding one 'HNat
    -- variable.
    WArray :: !(abt '[] 'HNat) -> !(abt '[ 'HNat] a) -> Head abt ('HArray a)
    -- | An array given by its list of elements ('ArrayLiteral_').
    WArrayLiteral
        :: [abt '[] a] -> Head abt ('HArray a)
    -- | A lambda abstraction ('Lam_').
    WLam :: !(abt '[ a ] b) -> Head abt (a ':-> b)
    -- | A primitive measure applied to its arguments ('MeasureOp_').
    WMeasureOp
        :: (typs ~ UnLCs args, args ~ LCs typs)
        => !(MeasureOp typs a)
        -> !(SArgs abt args)
        -> Head abt ('HMeasure a)
    -- | A 'Dirac' measure.
    WDirac :: !(abt '[] a) -> Head abt ('HMeasure a)
    -- | A monadic bind of measures ('MBind'); the second argument binds
    -- one variable.
    WMBind
        :: !(abt '[] ('HMeasure a))
        -> !(abt '[ a ] ('HMeasure b))
        -> Head abt ('HMeasure b)
    -- | A 'Plate' node: an 'HNat term and a measure body binding one
    -- 'HNat variable.
    WPlate
        :: !(abt '[] 'HNat)
        -> !(abt '[ 'HNat ] ('HMeasure a))
        -> Head abt ('HMeasure ('HArray a))
    -- | A 'Chain' node: an 'HNat term, an initial term of type @s@, and
    -- a body binding one @s@ variable.
    WChain
        :: !(abt '[] 'HNat)
        -> !(abt '[] s)
        -> !(abt '[ s ] ('HMeasure (HPair a s)))
        -> Head abt ('HMeasure (HPair ('HArray a) s))
    -- | A 'Superpose_' node over a non-empty list of ('HProb, measure)
    -- pairs.
    WSuperpose
        :: !(NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a)))
        -> Head abt ('HMeasure a)
    -- | The 'Reject_' measure, tagged with its type.
    WReject
        :: !(Sing ('HMeasure a)) -> Head abt ('HMeasure a)
    -- | A 'CoerceTo_' applied to another head.
    WCoerceTo :: !(Coercion a b) -> !(Head abt a) -> Head abt b
    -- | An 'UnsafeFrom_' applied to another head.
    WUnsafeFrom :: !(Coercion a b) -> !(Head abt b) -> Head abt a
    -- | An 'Integrate' node: two 'HReal terms and an 'HProb body
    -- binding one 'HReal variable.
    WIntegrate
        :: !(abt '[] 'HReal)
        -> !(abt '[] 'HReal)
        -> !(abt '[ 'HReal ] 'HProb)
        -> Head abt 'HProb
-- | Embed a head back into the term language by rebuilding, with 'syn',
-- the 'Term' node it stands for. Coercion heads wrap another head,
-- which is converted recursively.
fromHead :: (ABT Term abt) => Head abt a -> abt '[] a
fromHead h =
    case h of
    WLiteral lit          -> syn $ Literal_ lit
    WDatum dat            -> syn $ Datum_ dat
    WEmpty ty             -> syn $ Empty_ ty
    WArray n body         -> syn $ Array_ n body
    WArrayLiteral xs      -> syn $ ArrayLiteral_ xs
    WLam body             -> syn $ Lam_ :$ body :* End
    WMeasureOp op args    -> syn $ MeasureOp_ op :$ args
    WDirac x              -> syn $ Dirac :$ x :* End
    WMBind m k            -> syn $ MBind :$ m :* k :* End
    WPlate n body         -> syn $ Plate :$ n :* body :* End
    WChain n s body       -> syn $ Chain :$ n :* s :* body :* End
    WSuperpose wms        -> syn $ Superpose_ wms
    WReject ty            -> syn $ Reject_ ty
    WCoerceTo c inner     -> syn $ CoerceTo_ c :$ fromHead inner :* End
    WUnsafeFrom c inner   -> syn $ UnsafeFrom_ c :$ fromHead inner :* End
    WIntegrate x1 x2 body -> syn $ Integrate :$ x1 :* x2 :* body :* End
-- | View a term as a 'Head' if its outermost syntactic form is one.
--
-- Variables, and every 'Term' form without a 'Head' counterpart, give
-- 'Nothing'. A coercion ('CoerceTo_' / 'UnsafeFrom_') counts as a head
-- only when its argument is itself a head, so those cases recurse.
--
-- Every 'Head' constructor has a matching case here mirroring
-- 'fromHead', including 'WReject' for 'Reject_' terms (which, like
-- 'Empty_', carry only their type).
toHead :: (ABT Term abt) => abt '[] a -> Maybe (Head abt a)
toHead e =
    caseVarSyn e (const Nothing) $ \t ->
        case t of
        Literal_ v                         -> Just $ WLiteral v
        Datum_ d                           -> Just $ WDatum d
        Empty_ typ                         -> Just $ WEmpty typ
        Array_ e1 e2                       -> Just $ WArray e1 e2
        ArrayLiteral_ es                   -> Just $ WArrayLiteral es
        Lam_ :$ e1 :* End                  -> Just $ WLam e1
        MeasureOp_ o :$ es                 -> Just $ WMeasureOp o es
        Dirac :$ e1 :* End                 -> Just $ WDirac e1
        MBind :$ e1 :* e2 :* End           -> Just $ WMBind e1 e2
        Plate :$ e1 :* e2 :* End           -> Just $ WPlate e1 e2
        Chain :$ e1 :* e2 :* e3 :* End     -> Just $ WChain e1 e2 e3
        Superpose_ pes                     -> Just $ WSuperpose pes
        -- Previously missing: fromHead produces Reject_ from WReject,
        -- so the reverse direction must recognise it too.
        Reject_ typ                        -> Just $ WReject typ
        CoerceTo_ c :$ e1 :* End           -> WCoerceTo c <$> toHead e1
        UnsafeFrom_ c :$ e1 :* End         -> WUnsafeFrom c <$> toHead e1
        Integrate :$ e1 :* e2 :* e3 :* End -> Just $ WIntegrate e1 e2 e3
        _                                  -> Nothing
-- | Apply a function to every immediate subterm of a head. Literal,
-- empty-array and reject heads have no subterms and are rebuilt
-- unchanged. Coercion heads recurse into the wrapped head. For
-- 'WSuperpose', the function is applied to both components of each pair.
instance Functor21 Head where
    fmap21 _ (WLiteral v) = WLiteral v
    fmap21 f (WDatum d) = WDatum (fmap11 f d)
    fmap21 _ (WEmpty typ) = WEmpty typ
    fmap21 f (WArray e1 e2) = WArray (f e1) (f e2)
    fmap21 f (WArrayLiteral es) = WArrayLiteral (fmap f es)
    fmap21 f (WLam e1) = WLam (f e1)
    fmap21 f (WMeasureOp o es) = WMeasureOp o (fmap21 f es)
    fmap21 f (WDirac e1) = WDirac (f e1)
    fmap21 f (WMBind e1 e2) = WMBind (f e1) (f e2)
    fmap21 f (WPlate e1 e2) = WPlate (f e1) (f e2)
    fmap21 f (WChain e1 e2 e3) = WChain (f e1) (f e2) (f e3)
    fmap21 f (WSuperpose pes) = WSuperpose (fmap (f *** f) pes)
    fmap21 _ (WReject typ) = WReject typ
    fmap21 f (WCoerceTo c e1) = WCoerceTo c (fmap21 f e1)
    fmap21 f (WUnsafeFrom c e1) = WUnsafeFrom c (fmap21 f e1)
    fmap21 f (WIntegrate e1 e2 e3) = WIntegrate (f e1) (f e2) (f e3)
-- | Combine the results of applying a function to each immediate
-- subterm of a head, from left to right. Literal, empty-array and
-- reject heads yield 'mempty', and coercion heads recurse into the
-- wrapped head. 'WSuperpose' delegates to 'foldMapPairs' (presumably
-- visiting both components of each pair -- confirm).
instance Foldable21 Head where
    foldMap21 _ (WLiteral _) = mempty
    foldMap21 f (WDatum d) = foldMap11 f d
    foldMap21 _ (WEmpty _) = mempty
    foldMap21 f (WArray e1 e2) = f e1 `mappend` f e2
    foldMap21 f (WArrayLiteral es) = F.foldMap f es
    foldMap21 f (WLam e1) = f e1
    foldMap21 f (WMeasureOp _ es) = foldMap21 f es
    foldMap21 f (WDirac e1) = f e1
    foldMap21 f (WMBind e1 e2) = f e1 `mappend` f e2
    foldMap21 f (WPlate e1 e2) = f e1 `mappend` f e2
    foldMap21 f (WChain e1 e2 e3) = f e1 `mappend` f e2 `mappend` f e3
    foldMap21 f (WSuperpose pes) = foldMapPairs f pes
    foldMap21 _ (WReject _) = mempty
    foldMap21 f (WCoerceTo _ e1) = foldMap21 f e1
    foldMap21 f (WUnsafeFrom _ e1) = foldMap21 f e1
    foldMap21 f (WIntegrate e1 e2 e3) = f e1 `mappend` f e2 `mappend` f e3
-- | Run an effectful function over every immediate subterm of a head,
-- left to right, and rebuild the head from the results. Heads without
-- subterms are returned with 'pure'. 'WSuperpose' delegates to
-- 'traversePairs'.
instance Traversable21 Head where
    traverse21 _ (WLiteral v) = pure $ WLiteral v
    traverse21 f (WDatum d) = WDatum <$> traverse11 f d
    traverse21 _ (WEmpty typ) = pure $ WEmpty typ
    traverse21 f (WArray e1 e2) = WArray <$> f e1 <*> f e2
    traverse21 f (WArrayLiteral es) = WArrayLiteral <$> traverse f es
    traverse21 f (WLam e1) = WLam <$> f e1
    traverse21 f (WMeasureOp o es) = WMeasureOp o <$> traverse21 f es
    traverse21 f (WDirac e1) = WDirac <$> f e1
    traverse21 f (WMBind e1 e2) = WMBind <$> f e1 <*> f e2
    traverse21 f (WPlate e1 e2) = WPlate <$> f e1 <*> f e2
    traverse21 f (WChain e1 e2 e3) = WChain <$> f e1 <*> f e2 <*> f e3
    traverse21 f (WSuperpose pes) = WSuperpose <$> traversePairs f pes
    traverse21 _ (WReject typ) = pure $ WReject typ
    traverse21 f (WCoerceTo c e1) = WCoerceTo c <$> traverse21 f e1
    traverse21 f (WUnsafeFrom c e1) = WUnsafeFrom c <$> traverse21 f e1
    traverse21 f (WIntegrate e1 e2 e3) = WIntegrate <$> f e1 <*> f e2 <*> f e3
-- | Weak-head normal forms. This is either a 'Head', or a 'Neutral'
-- term kept as an ordinary term. The evaluators in this module produce
-- 'Neutral' for a variable with no usable binding and for a case
-- expression whose matching gets stuck.
data Whnf (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
    = Head_ !(Head abt a)
    | Neutral !(abt '[] a)
-- | Turn a weak-head normal form back into a term. Heads are rebuilt
-- with 'fromHead'; a neutral term is returned as it is.
fromWhnf :: (ABT Term abt) => Whnf abt a -> abt '[] a
fromWhnf w =
    case w of
    Head_ h   -> fromHead h
    Neutral t -> t
-- | Recognise a term as a head (via 'toHead') and wrap it with 'Head_'.
-- Gives 'Nothing' whenever 'toHead' does; never produces 'Neutral'.
toWhnf :: (ABT Term abt) => abt '[] a -> Maybe (Whnf abt a)
toWhnf = fmap Head_ . toHead
-- | Eliminator for 'Whnf'. The first continuation handles heads and
-- the second handles neutral terms.
caseWhnf :: Whnf abt a -> (Head abt a -> r) -> (abt '[] a -> r) -> r
caseWhnf w onHead onNeutral =
    case w of
    Head_ h   -> onHead h
    Neutral t -> onNeutral t
-- | Extract the datum from a weak-head normal form of a data type.
-- Returns 'Nothing' for neutral terms; heads are handed to
-- 'viewHeadDatum'.
viewWhnfDatum
    :: (ABT Term abt)
    => Whnf abt (HData' t)
    -> Maybe (Datum (abt '[]) (HData' t))
viewWhnfDatum w =
    case w of
    Head_ h   -> Just (viewHeadDatum h)
    Neutral _ -> Nothing
-- | Extract the datum from a head of a data type. Only 'WDatum' is
-- expected here; any other constructor is reported with 'error'.
viewHeadDatum
    :: (ABT Term abt)
    => Head abt (HData' t)
    -> Datum (abt '[]) (HData' t)
viewHeadDatum h =
    case h of
    WDatum d -> d
    _        -> error "viewHeadDatum: the impossible happened"
-- | Coercions on weak-head normal forms, simplifying where possible.
--
-- For a 'Neutral' term @e@ the result is always 'Neutral':
--
-- * if @e@ is a literal, the literal itself is coerced and re-embedded
--   with 'P.literal_';
-- * if @e@ is already a coercion in the same direction, the two
--   coercions are composed into one node;
-- * otherwise @e@ is wrapped in a new coercion node.
--
-- For a 'Head_', the result is always a 'Head_':
--
-- * a 'WLiteral' is coerced directly;
-- * a nested 'WCoerceTo' (resp. 'WUnsafeFrom') is fused;
-- * anything else is wrapped in 'WCoerceTo' (resp. 'WUnsafeFrom').
instance (ABT Term abt) => Coerce (Whnf abt) where
    coerceTo c w =
        case w of
        Neutral e ->
            -- The lambda yields Nothing when no simplification applies;
            -- then e is simply wrapped with P.coerceTo_.
            Neutral . maybe (P.coerceTo_ c e) id
            $ caseVarSyn e (const Nothing) $ \t ->
                case t of
                Literal_ x -> Just $ P.literal_ (coerceTo c x)
                CoerceTo_ c' :$ es' ->
                    case es' of
                    -- coerceTo c (coerceTo c' e') becomes coerceTo (c . c') e'
                    e' :* End -> Just $ P.coerceTo_ (c . c') e'
                    _ -> Nothing
        Head_ v ->
            case v of
            WLiteral x -> Head_ $ WLiteral (coerceTo c x)
            WCoerceTo c' v' -> Head_ $ WCoerceTo (c . c') v'
            _ -> Head_ $ WCoerceTo c v
    coerceFrom c w =
        case w of
        Neutral e ->
            -- As above, but composing in the opposite order because
            -- unsafeFrom runs the coercion backwards.
            Neutral . maybe (P.unsafeFrom_ c e) id
            $ caseVarSyn e (const Nothing) $ \t ->
                case t of
                Literal_ x -> Just $ P.literal_ (coerceFrom c x)
                UnsafeFrom_ c' :$ es' ->
                    case es' of
                    e' :* End -> Just $ P.unsafeFrom_ (c' . c) e'
                    _ -> Nothing
        Head_ v ->
            case v of
            WLiteral x -> Head_ $ WLiteral (coerceFrom c x)
            WUnsafeFrom c' v' -> Head_ $ WUnsafeFrom (c' . c) v'
            _ -> Head_ $ WUnsafeFrom c v
-- | A possibly-unevaluated term. It is either an already-evaluated
-- 'Whnf', or a 'Thunk' holding the term itself. In 'evaluateVar', a
-- 'Thunk' bound by 'SLet' is evaluated on lookup and rebound as a
-- 'Whnf_'.
data Lazy (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
    = Whnf_ !(Whnf abt a)
    | Thunk !(abt '[] a)
-- | Recover a term from a 'Lazy'. Evaluated values go through
-- 'fromWhnf'; thunks are already terms.
fromLazy :: (ABT Term abt) => Lazy abt a -> abt '[] a
fromLazy l =
    case l of
    Whnf_ w -> fromWhnf w
    Thunk t -> t
-- | Eliminator for 'Lazy'. The first continuation handles evaluated
-- values and the second handles thunks.
caseLazy :: Lazy abt a -> (Whnf abt a -> r) -> (abt '[] a -> r) -> r
caseLazy l onWhnf onThunk =
    case l of
    Whnf_ w -> onWhnf w
    Thunk t -> onThunk t
-- | The variable a 'Lazy' stands for, if it is syntactically just a
-- variable. Heads are never variables. Neutral values and thunks are
-- inspected with 'caseVarSyn'.
getLazyVariable :: (ABT Term abt) => Lazy abt a -> Maybe (Variable a)
getLazyVariable (Whnf_ (Head_ _))   = Nothing
getLazyVariable (Whnf_ (Neutral t)) = caseVarSyn t Just (const Nothing)
getLazyVariable (Thunk t)           = caseVarSyn t Just (const Nothing)
-- | Whether 'getLazyVariable' finds a variable.
isLazyVariable :: (ABT Term abt) => Lazy abt a -> Bool
isLazyVariable e =
    case getLazyVariable e of
    Nothing -> False
    Just _  -> True
-- | The literal a 'Lazy' stands for, if any. Evaluated values must be a
-- 'WLiteral' head. Thunks must be a 'Literal_' term (variables and
-- other terms give 'Nothing').
getLazyLiteral :: (ABT Term abt) => Lazy abt a -> Maybe (Literal a)
getLazyLiteral (Whnf_ (Head_ (WLiteral lit))) = Just lit
getLazyLiteral (Whnf_ _)                      = Nothing
getLazyLiteral (Thunk t)                      =
    caseVarSyn t (const Nothing) $ \term ->
        case term of
        Literal_ lit -> Just lit
        _            -> Nothing
-- | Whether 'getLazyLiteral' finds a literal.
isLazyLiteral :: (ABT Term abt) => Lazy abt a -> Bool
isLazyLiteral e =
    case getLazyLiteral e of
    Nothing -> False
    Just _  -> True
-- | Type-level tag on 'Statement' that restricts which statements may
-- appear together. 'SLet' is allowed at any purity. 'SBind', 'SWeight'
-- and 'SGuard' are 'Impure' only. 'SStuff0' and 'SStuff1' are
-- 'ExpectP' only.
data Purity = Pure | Impure | ExpectP
    deriving (Eq, Read, Show)
-- | An 'HNat variable paired with an 'HNat term ('indVar' and 'indSize'
-- project them out). 'freshInd' builds one from a fresh variable.
-- NOTE(review): presumably an array index and the array size -- confirm.
data Index ast = Ind (Variable 'HNat) (ast 'HNat)
-- | Two indices are equal when their variables are equal and their
-- size terms are alpha-equivalent.
instance (ABT Term abt) => Eq (Index (abt '[])) where
    Ind x sx == Ind y sy = x == y && alphaEq sx sy
-- | Orders indices by their variable only.
--
-- NOTE(review): the size term is ignored here, whereas the 'Eq'
-- instance also requires the sizes to be alpha-equivalent. So
-- @compare@ can return @EQ@ for indices that are not @==@. This is
-- harmless only if an index variable always carries the same size --
-- confirm before relying on Ord-based containers.
instance (ABT Term abt) => Ord (Index (abt '[])) where
    compare (Ind i _) (Ind j _) = compare i j
-- | The variable component of an index.
indVar :: Index ast -> Variable 'HNat
indVar idx =
    case idx of
    Ind x _ -> x
-- | The term component of an index.
indSize :: Index ast -> ast 'HNat
indSize idx =
    case idx of
    Ind _ size -> size
-- | The index's variable, as a term.
fromIndex :: (ABT Term abt) => Index (abt '[]) -> abt '[] 'HNat
fromIndex = var . indVar
-- | A newtype over 'Variable' used for the variables bound by
-- statements in the evaluation context. 'freshLocStatement' turns a
-- statement's 'Variable's into fresh 'Location's.
newtype Location (a :: k) = Location (Variable a)
-- | Shows a location exactly as its underlying variable is shown.
instance Show (Sing a) => Show (Location a) where
    show (Location v) = show v
-- | The name hint of the underlying variable.
locHint :: Location a -> Text
locHint = varHint . fromLocation
-- | The type of the underlying variable.
locType :: Location a -> Sing a
locType = varType . fromLocation
-- | Compare two locations with 'varEq' on their underlying variables.
-- A 'TypeEq' proof is returned when they are equal.
locEq :: (Show1 (Sing :: k -> *), JmEq1 (Sing :: k -> *))
    => Location (a :: k)
    -> Location (b :: k)
    -> Maybe (TypeEq a b)
locEq l1 l2 = varEq (fromLocation l1) (fromLocation l2)
-- | Unwrap a location to its underlying variable.
fromLocation :: Location a -> Variable a
fromLocation loc =
    case loc of
    Location x -> x
-- | Unwrap every location in a 'List1' to its variable.
fromLocations1 :: List1 Location a -> List1 Variable a
fromLocations1 locs = fmap11 fromLocation locs
-- | Wrap every variable in a 'List1' as a 'Location'.
locations1 :: List1 Variable a -> List1 Location a
locations1 vars = fmap11 Location vars
-- | A newtype wrapper around a single 'Assoc', paralleling 'LAssocs'.
newtype LAssoc ast = LAssoc (Assoc ast)
-- | Associations keyed by 'Location's. They are stored as plain 'Assocs'
-- over the underlying variables; see 'singletonLAssocs' and
-- 'lookupLAssoc'.
newtype LAssocs ast = LAssocs (Assocs ast)
-- | The association set with no entries.
emptyLAssocs :: LAssocs abt
emptyLAssocs = LAssocs emptyAssocs
-- | An association set with exactly one entry, keyed by the location's
-- variable.
singletonLAssocs :: Location a -> f a -> LAssocs f
singletonLAssocs loc e = LAssocs $ singletonAssocs (fromLocation loc) e
-- | Pair up a list of locations with a list of values, using
-- 'toAssocs1' on the underlying variables.
toLAssocs1 :: List1 Location xs -> List1 ast xs -> LAssocs ast
toLAssocs1 locs vals = LAssocs $ toAssocs1 (fromLocations1 locs) vals
-- | Combine two association sets with 'insertAssocs'.
insertLAssocs :: LAssocs ast -> LAssocs ast -> LAssocs ast
insertLAssocs (LAssocs xs) (LAssocs ys) = LAssocs $ insertAssocs xs ys
-- | Look up the value associated with a location, via 'lookupAssoc' on
-- its underlying variable.
lookupLAssoc :: (Show1 (Sing :: k -> *), JmEq1 (Sing :: k -> *))
    => Location (a :: k)
    -> LAssocs ast
    -> Maybe (ast a)
lookupLAssoc loc (LAssocs assocs) = lookupAssoc (fromLocation loc) assocs
-- | An entry in the evaluation context.
--
-- The @v@ parameter is how bound variables are represented: 'Variable'
-- before freshening and 'Location' after (see 'freshLocStatement').
-- The 'Purity' index limits where each statement can appear. Every
-- statement also carries a list of 'Index'es.
-- NOTE(review): presumably the indices the statement is nested under
-- -- confirm.
data Statement :: ([Hakaru] -> Hakaru -> *) -> (Hakaru -> *) -> Purity -> * where
    -- | Binds a variable to a measure, given lazily. When the variable
    -- is looked up, 'evaluateVar' performs the measure and records the
    -- result as an 'SLet'.
    SBind
        :: forall abt (v :: Hakaru -> *) (a :: Hakaru)
        .  {-# UNPACK #-} !(v a)
        -> !(Lazy abt ('HMeasure a))
        -> [Index (abt '[])]
        -> Statement abt v 'Impure
    -- | Binds a variable to a possibly-unevaluated term. When looked up,
    -- 'evaluateVar' evaluates a 'Thunk' and records the result as a
    -- 'Whnf_'.
    SLet
        :: forall abt p (v :: Hakaru -> *) (a :: Hakaru)
        .  {-# UNPACK #-} !(v a)
        -> !(Lazy abt a)
        -> [Index (abt '[])]
        -> Statement abt v p
    -- | A weight, as a lazy 'HProb term; binds no variables.
    SWeight
        :: forall abt (v :: Hakaru -> *)
        .  !(Lazy abt 'HProb)
        -> [Index (abt '[])]
        -> Statement abt v 'Impure
    -- | The variables of a pattern together with the pattern and a lazy
    -- scrutinee; binds all of the pattern's variables.
    SGuard
        :: forall abt (v :: Hakaru -> *) (xs :: [Hakaru]) (a :: Hakaru)
        .  !(List1 v xs)
        -> !(Pattern xs a)
        -> !(Lazy abt a)
        -> [Index (abt '[])]
        -> Statement abt v 'Impure
    -- | An opaque 'HProb-to-'HProb term transformer; binds no variables.
    SStuff0
        :: forall abt (v :: Hakaru -> *)
        .  (abt '[] 'HProb -> abt '[] 'HProb)
        -> [Index (abt '[])]
        -> Statement abt v 'ExpectP
    -- | Like 'SStuff0', but also binds one variable.
    SStuff1
        :: forall abt (v :: Hakaru -> *) (a :: Hakaru)
        .  {-# UNPACK #-} !(v a)
        -> (abt '[] 'HProb -> abt '[] 'HProb)
        -> [Index (abt '[])]
        -> Statement abt v 'ExpectP
-- | The set of variables a statement binds. 'SWeight' and 'SStuff0'
-- bind none, 'SGuard' binds all of its pattern variables, and the rest
-- bind one.
statementVars :: Statement abt Location p -> VarSet ('KProxy :: KProxy Hakaru)
statementVars s =
    case s of
    SBind x _ _     -> singletonVarSet (fromLocation x)
    SLet x _ _      -> singletonVarSet (fromLocation x)
    SWeight _ _     -> emptyVarSet
    SGuard xs _ _ _ -> toVarSet1 (fromLocations1 xs)
    SStuff0 _ _     -> emptyVarSet
    SStuff1 x _ _   -> singletonVarSet (fromLocation x)
-- | @Just ()@ when the statement binds the given location, otherwise
-- 'Nothing'. Single binders are compared with 'locEq'. 'SGuard' checks
-- membership in the set of its pattern variables.
isBoundBy :: Location (a :: Hakaru) -> Statement abt Location p -> Maybe ()
isBoundBy x s =
    case s of
    SBind y _ _     -> fmap (const ()) (locEq x y)
    SLet y _ _      -> fmap (const ()) (locEq x y)
    SWeight _ _     -> Nothing
    SGuard ys _ _ _ ->
        if memberVarSet (fromLocation x) (toVarSet1 (fromLocations1 ys))
        then Just ()
        else Nothing
    SStuff0 _ _     -> Nothing
    SStuff1 y _ _   -> fmap (const ()) (locEq x y)
#ifdef __TRACE_DISINTEGRATE__
-- | Debug rendering of a 'Whnf' as @Head_ t@ or @Neutral t@. A head is
-- rendered through the term produced by 'fromHead'.
instance (ABT Term abt) => Pretty (Whnf abt) where
    prettyPrec_ p (Head_ w) = ppApply1 p "Head_" (fromHead w)
    prettyPrec_ p (Neutral e) = ppApply1 p "Neutral" e
-- | Debug rendering of a 'Lazy' as @Whnf_ w@ (using the 'Whnf' printer)
-- or @Thunk t@.
instance (ABT Term abt) => Pretty (Lazy abt) where
    prettyPrec_ p (Whnf_ w) = ppFun p "Whnf_" [PP.sep (prettyPrec_ 11 w)]
    prettyPrec_ p (Thunk e) = ppApply1 p "Thunk" e
-- | Render @f e1@ as a single document, nesting the argument under the
-- name. The result is parenthesised when the context precedence
-- exceeds 9.
ppApply1 :: (ABT Term abt) => Int -> String -> abt '[] a -> [PP.Doc]
ppApply1 p f e1 = [wrap application]
    where
    application = PP.text f PP.<+> PP.nest (1 + length f) (prettyPrec 11 e1)
    wrap d
        | p > 9     = PP.parens (PP.nest 1 d)
        | otherwise = d
-- | Render a constructor-style application @f d1 ... dn@. The arguments
-- are nested under the name, and the whole thing is parenthesised when
-- the context precedence exceeds 9. With no arguments, just the name
-- is printed.
ppFun :: Int -> String -> [PP.Doc] -> [PP.Doc]
ppFun _ f []   = [PP.text f]
ppFun p f args = parens (p > 9) [PP.text f PP.<+> body]
    where
    body = PP.nest (1 + length f) (PP.sep args)
-- | Conditionally wrap documents in parentheses. When wrapping, the
-- documents are first joined with 'PP.sep'.
parens :: Bool -> [PP.Doc] -> [PP.Doc]
parens needed ds
    | needed    = [PP.parens (PP.nest 1 (PP.sep ds))]
    | otherwise = ds
-- | Render documents as a comma-separated, bracketed list.
ppList :: [PP.Doc] -> PP.Doc
ppList docs =
    PP.sep [PP.brackets (PP.nest 1 (PP.fsep (PP.punctuate PP.comma docs)))]
-- | Render a list of indices by showing each one's variable.
ppInds :: (ABT Term abt) => [Index (abt '[])] -> PP.Doc
ppInds inds = ppList [ ppVariable (indVar i) | i <- inds ]
-- | Debug rendering of a statement at precedence @p@. It shows the bound
-- location(s), the body or pattern, and the indices. 'SStuff0' and
-- 'SStuff1' are not rendered yet; they print a TODO placeholder.
ppStatement :: (ABT Term abt) => Int -> Statement abt Location p -> PP.Doc
ppStatement p s =
    case s of
    SBind (Location x) e inds ->
        PP.sep $ ppFun p "SBind"
            [ ppVariable x
            , PP.sep $ prettyPrec_ 11 e
            , ppInds inds
            ]
    SLet (Location x) e inds ->
        PP.sep $ ppFun p "SLet"
            [ ppVariable x
            , PP.sep $ prettyPrec_ 11 e
            , ppInds inds
            ]
    SWeight e inds ->
        PP.sep $ ppFun p "SWeight"
            [ PP.sep $ prettyPrec_ 11 e
            , ppInds inds
            ]
    SGuard xs pat e inds ->
        PP.sep $ ppFun p "SGuard"
            [ PP.sep $ ppVariables (fromLocations1 xs)
            , PP.sep $ prettyPrec_ 11 pat
            , PP.sep $ prettyPrec_ 11 e
            , ppInds inds
            ]
    SStuff0 _ _ ->
        PP.sep $ ppFun p "SStuff0"
            [ PP.text "TODO: ppStatement{SStuff0}"
            ]
    SStuff1 _ _ _ ->
        PP.sep $ ppFun p "SStuff1"
            [ PP.text "TODO: ppStatement{SStuff1}"
            ]
-- | Render a list of statements one per line, in leading-comma list
-- style:
--
-- > [ s1
-- > , s2
-- > ]
--
-- An empty list is rendered as @[]@. The accumulated document is
-- folded strictly (Data.Foldable foldl') rather than with the lazy
-- Prelude foldl, so no chain of deferred thunks builds up over long
-- statement lists.
pretty_Statements :: (ABT Term abt) => [Statement abt Location p] -> PP.Doc
pretty_Statements []     = PP.text "[]"
pretty_Statements (s:ss) =
    F.foldl'
        (\d s' -> d PP.$+$ PP.comma PP.<+> ppStatement 0 s')
        (PP.text "[" PP.<+> ppStatement 0 s)
        ss
    PP.$+$ PP.text "]"
-- | Render a statement list (via 'pretty_Statements') followed by a
-- term on the next line.
pretty_Statements_withTerm
    :: (ABT Term abt) => [Statement abt Location p] -> abt '[] a -> PP.Doc
pretty_Statements_withTerm ss e = context PP.$+$ term
    where
    context = pretty_Statements ss
    term    = pretty e
-- | Render a substitution as one @x -> e@ line per association.
prettyAssocs
    :: (ABT Term abt)
    => Assocs (abt '[])
    -> PP.Doc
prettyAssocs a = PP.vcat (fmap ppAssoc (fromAssocs a))
    where
    ppAssoc (Assoc x e) = ppVariable x PP.<+> PP.text "->" PP.<+> pretty e
#endif
-- | An evaluator taking any term to weak-head normal form in @m@.
type TermEvaluator abt m =
    forall a. abt '[] a -> m (Whnf abt a)
-- | An evaluator that takes a measure term and yields a 'Whnf' of the
-- measure's element type. 'evaluateVar' uses it to perform 'SBind'
-- measures.
type MeasureEvaluator abt m =
    forall a. abt '[] ('HMeasure a) -> m (Whnf abt a)
-- | An evaluator for case expressions, given the scrutinee and the
-- branches.
type CaseEvaluator abt m =
    forall a b. abt '[] a -> [Branch a abt b] -> m (Whnf abt b)
-- | An evaluator for variable occurrences.
type VariableEvaluator abt m =
    forall a. Variable a -> m (Whnf abt a)
-- | The monad in which evaluation takes place. It supplies fresh
-- variable IDs and maintains a context of 'Statement's. Statements are
-- added with 'push'/'pushes' (which freshen bound variables) or with
-- 'unsafePush', and are consulted with 'select'. The functional
-- dependency makes @m@ determine both the term representation @abt@
-- and the 'Purity' @p@ of its statements.
class (Functor m, Applicative m, Monad m, ABT Term abt)
    => EvaluationMonad abt m p | m -> abt p
    where
    -- | A natural number to use as a new variable ID; 'freshVar' and
    -- 'freshenVar' rely on it. Presumably never repeats a previously
    -- returned value -- confirm in instances.
    freshNat :: m Nat

    -- | Give every variable bound by the statement a new ID and wrap it
    -- as a 'Location'. Returns the renamed statement together with the
    -- mapping from old to new variables. Statements that bind nothing
    -- come back unchanged, with an empty mapping.
    freshLocStatement
        :: Statement abt Variable p
        -> m (Statement abt Location p, Assocs (Variable :: Hakaru -> *))
    freshLocStatement s =
        case s of
        SWeight w e -> return (SWeight w e, mempty)
        SBind x body i -> do
            x' <- freshenVar x
            return (SBind (Location x') body i, singletonAssocs x x')
        SLet x body i -> do
            x' <- freshenVar x
            return (SLet (Location x') body i, singletonAssocs x x')
        SGuard xs pat scrutinee i -> do
            xs' <- freshenVars xs
            return (SGuard (locations1 xs') pat scrutinee i,
                toAssocs1 xs xs')
        SStuff0 e e' -> return (SStuff0 e e', mempty)
        SStuff1 x f i -> do
            x' <- freshenVar x
            return (SStuff1 (Location x') f i, singletonAssocs x x')

    -- | The indices currently in scope. The default is the empty list.
    getIndices :: m [Index (abt '[])]
    getIndices = return []

    -- | Add a statement to the context as-is. No renaming happens here;
    -- 'push' and 'pushes' first freshen the statement with
    -- 'freshLocStatement' and then call this.
    unsafePush :: Statement abt Location p -> m ()

    -- | Add several statements. By default, 'unsafePush' is applied to
    -- each one in list order.
    unsafePushes :: [Statement abt Location p] -> m ()
    unsafePushes = mapM_ unsafePush

    -- | Look up the given location in the context, offering statements
    -- to the callback. If the callback returns an action, it is run and
    -- its result returned in 'Just'.
    -- NOTE(review): which statements are offered, and whether the
    -- selected one is removed while the action runs, is decided by
    -- instances -- confirm.
    select
        :: Location (a :: Hakaru)
        -> (Statement abt Location p -> Maybe (m r))
        -> m (Maybe r)

    -- | Called by 'extSubst' to get the function applied to variables
    -- while substituting @e@ for @x@. The default maps every variable
    -- to itself with 'var'.
    substVar :: Variable a -> abt '[] a
        -> (forall b'. Variable b' -> m (abt '[] b'))
    substVar _ _ = return . var

    -- | The free variables of a term. The default is 'freeVars'.
    extFreeVars :: abt xs a -> m (VarSet (KindOf a))
    extFreeVars e = return (freeVars e)

    -- | How case expressions are evaluated. The default is
    -- 'defaultCaseEvaluator'.
    evaluateCase :: TermEvaluator abt m -> CaseEvaluator abt m
    {-# INLINE evaluateCase #-}
    evaluateCase = defaultCaseEvaluator

    -- | Evaluate a variable using the statement that binds it:
    --
    -- * 'SBind': perform the measure with the given 'MeasureEvaluator',
    --   record the result as an 'SLet' for the same location, and return
    --   it;
    -- * 'SLet': evaluate the term if it is a 'Thunk' (a 'Whnf_' is
    --   reused), record the result as an 'SLet', and return it;
    -- * 'SStuff1' and 'SGuard': return the variable as 'Neutral';
    -- * 'SWeight' and 'SStuff0': produce no action.
    --
    -- When 'select' yields no result, the variable is returned as
    -- 'Neutral'.
    evaluateVar :: MeasureEvaluator abt m
                -> TermEvaluator abt m
                -> VariableEvaluator abt m
    evaluateVar perform evaluate_ = \x ->
        fmap (maybe (Neutral $ var x) id) . select (Location x) $ \s ->
            case s of
            SBind y e i -> do
                -- Matching Refl brings the type equality between x and y
                -- into scope; this alternative is Nothing if they differ.
                Refl <- locEq (Location x) y
                Just $ do
                    w <- perform $ caseLazy e fromWhnf id
                    unsafePush (SLet (Location x) (Whnf_ w) i)
#ifdef __TRACE_DISINTEGRATE__
                    trace ("-- updated "
                        ++ show (ppStatement 11 s)
                        ++ " to "
                        ++ show (ppStatement 11 (SLet (Location x) (Whnf_ w) i))
                        ) $ return ()
#endif
                    return w
            SLet y e i -> do
                Refl <- locEq (Location x) y
                Just $ do
                    -- Reuse an existing Whnf_; evaluate only a Thunk.
                    w <- caseLazy e return evaluate_
                    unsafePush (SLet (Location x) (Whnf_ w) i)
                    return w
            SWeight _ _ -> Nothing
            SStuff0 _ _ -> Nothing
            SStuff1 _ _ _ -> Just . return . Neutral $ var x
            SGuard _ _ _ _ -> Just . return . Neutral $ var x
-- | Case evaluation by pattern matching with 'matchBranches'. The given
-- 'TermEvaluator' is used to expose datum constructors in the
-- scrutinee as needed.
--
-- * If matching gets stuck, the whole case is returned as a 'Neutral'
--   term.
-- * On a match, the pattern's bindings are pushed as 'SLet' thunks
--   (see 'toVarStatements'), and the renamed branch body is evaluated.
-- * If no branch matches, it calls 'error'.
defaultCaseEvaluator
    :: forall abt m p
    .  (ABT Term abt, EvaluationMonad abt m p)
    => TermEvaluator abt m
    -> CaseEvaluator abt m
{-# INLINE defaultCaseEvaluator #-}
defaultCaseEvaluator evaluate_ = evaluateCase_
    where
    -- Evaluate a subterm and view it as a datum; neutral results give
    -- Nothing (see viewWhnfDatum).
    evaluateDatum :: DatumEvaluator (abt '[]) m
    evaluateDatum e = viewWhnfDatum <$> evaluate_ e

    evaluateCase_ :: CaseEvaluator abt m
    evaluateCase_ e bs = do
        match <- matchBranches evaluateDatum e bs
        case match of
            Nothing ->
                error "defaultCaseEvaluator: non-exhaustive patterns in case!"
            Just GotStuck ->
                return . Neutral . syn $ Case_ e bs
            Just (Matched ss body) ->
                pushes (toVarStatements ss) body >>= evaluate_
-- | Turn a substitution into 'SLet' statements, in the order given by
-- 'fromAssocs'. Each variable is bound to its term as an unevaluated
-- 'Thunk' with no indices.
toVarStatements :: Assocs (abt '[]) -> [Statement abt Variable p]
toVarStatements rho =
    map (\(Assoc x e) -> SLet x (Thunk e) []) (fromAssocs rho)
-- | Substitute @e@ for @x@ in a term, monadically with 'substM'. Other
-- variables are handled by the monad's 'substVar' hook.
extSubst
    :: forall abt a xs b m p. (EvaluationMonad abt m p)
    => Variable a
    -> abt '[] a
    -> abt xs b
    -> m (abt xs b)
extSubst x e body = substM x e (substVar x e) body
-- | Apply every association of a substitution with 'extSubst', one at a
-- time, from left to right over 'unAssocs'.
extSubsts
    :: forall abt a xs m p. (EvaluationMonad abt m p)
    => Assocs (abt '[])
    -> abt xs a
    -> m (abt xs a)
extSubsts rho0 e0 =
    F.foldlM (\acc (Assoc v t) -> extSubst v t acc) e0 (unAssocs rho0)
-- | Create a variable with the given hint and type, whose ID comes from
-- 'freshNat'.
freshVar
    :: (EvaluationMonad abt m p)
    => Text
    -> Sing (a :: Hakaru)
    -> m (Variable a)
freshVar hint typ = do
    n <- freshNat
    return (Variable hint n typ)
-- | A name hint and a type, from which 'freshVars' creates a variable.
data Hint (a :: Hakaru) = Hint {-# UNPACK #-} !Text !(Sing a)
-- | Create one variable per 'Hint' with 'freshVar', from left to right.
freshVars
    :: (EvaluationMonad abt m p)
    => List1 Hint xs
    -> m (List1 Variable xs)
freshVars Nil1                       = return Nil1
freshVars (Cons1 (Hint hint typ) hs) =
    Cons1 <$> freshVar hint typ <*> freshVars hs
-- | Copy a variable, keeping its hint and type but taking a new ID from
-- 'freshNat'.
freshenVar
    :: (EvaluationMonad abt m p)
    => Variable (a :: Hakaru)
    -> m (Variable a)
freshenVar x = do
    n <- freshNat
    return x{varID = n}
-- | Apply 'freshenVar' to every variable in the list, from left to
-- right.
freshenVars
    :: (EvaluationMonad abt m p)
    => List1 Variable (xs :: [Hakaru])
    -> m (List1 Variable xs)
freshenVars vars =
    case vars of
    Nil1       -> return Nil1
    Cons1 x xs -> Cons1 <$> freshenVar x <*> freshenVars xs
-- | Build an 'Index' from a fresh, hint-less 'HNat variable and the
-- given term.
freshInd :: (EvaluationMonad abt m p)
    => abt '[] 'HNat
    -> m (Index (abt '[]))
freshInd s = flip Ind s <$> freshVar T.empty SNat
-- | Give a location a new ID, via 'freshenVar' on its variable.
freshenLoc :: (EvaluationMonad abt m p)
    => Location (a :: Hakaru) -> m (Location a)
freshenLoc loc = do
    x' <- freshenVar (fromLocation loc)
    return (Location x')
-- | Apply 'freshenLoc' to every location in the list, from left to
-- right.
freshenLocs :: (EvaluationMonad abt m p)
    => List1 Location (ls :: [Hakaru])
    -> m (List1 Location ls)
freshenLocs locs =
    case locs of
    Nil1       -> return Nil1
    Cons1 l ls -> Cons1 <$> freshenLoc l <*> freshenLocs ls
-- | Freshen a statement's bound variables with 'freshLocStatement',
-- add it to the context with 'unsafePush', and return the mapping from
-- old to new variables.
push_
    :: (ABT Term abt, EvaluationMonad abt m p)
    => Statement abt Variable p
    -> m (Assocs (Variable :: Hakaru -> *))
push_ s = do
    (fresh, renaming) <- freshLocStatement s
    unsafePush fresh
    return renaming
-- | Push a statement with 'push_', then rename its bound variables in
-- the given term so that the term refers to the fresh ones.
push
    :: (ABT Term abt, EvaluationMonad abt m p)
    => Statement abt Variable p
    -> abt xs a
    -> m (abt xs a)
push s e = (\rho -> renames rho e) <$> push_ s
-- | Push several statements in list order with 'push_', combining their
-- renamings with 'mappend', then apply the combined renaming to the
-- given term.
-- NOTE(review): each statement is pushed as given; renamings from
-- earlier statements are not applied to later ones. This is fine for
-- independent bindings such as those from 'toVarStatements' -- confirm
-- for other callers.
pushes
    :: (ABT Term abt, EvaluationMonad abt m p)
    => [Statement abt Variable p]
    -> abt xs a
    -> m (abt xs a)
pushes ss e =
    (\rho -> renames rho e)
        <$> F.foldlM (\acc stmt -> mappend acc <$> push_ stmt) mempty ss