module Language.Hakaru.Evaluation.Lazy
(
TermEvaluator
, MeasureEvaluator
, CaseEvaluator
, VariableEvaluator
, evaluate
, update
, defaultCaseEvaluator
, toStatements
, Interp(..), reifyPair
) where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
#endif
import Control.Monad ((<=<))
import Control.Monad.Identity (Identity, runIdentity)
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import qualified Data.Text as Text
import Language.Hakaru.Syntax.IClasses
import Data.Number.Nat
import Data.Number.Natural
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumCase (DatumEvaluator, MatchResult(..), matchBranches, MatchState(..), matchTopPattern)
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Evaluation.Types
import qualified Language.Hakaru.Syntax.Prelude as P
#ifdef __TRACE_DISINTEGRATE__
import Language.Hakaru.Pretty.Haskell (pretty)
import Debug.Trace (trace)
#endif
-- | Evaluate a term to weak-head normal form in the monad @m@.
type TermEvaluator abt m =
    forall a. abt '[] a -> m (Whnf abt a)

-- | Turn a measure-typed term into a WHNF of the measure's result type.
-- ('update' uses this to perform the measure bound by an 'SBind'.)
type MeasureEvaluator abt m =
    forall a. abt '[] ('HMeasure a) -> m (Whnf abt a)

-- | Evaluate a @case@ expression, given its scrutinee and its branches.
type CaseEvaluator abt m =
    forall a b. abt '[] a -> [Branch a abt b] -> m (Whnf abt b)

-- | Evaluate a variable to weak-head normal form (see 'update').
type VariableEvaluator abt m =
    forall a. Variable a -> m (Whnf abt a)
-- | Lazy partial evaluator: reduce a term to weak-head normal form.
--
-- @evaluate perform evaluateCase@ returns a term evaluator that
--
-- * returns terms that are already heads (literals, data, arrays,
--   lambdas, measure constructors, @Integrate@) as 'Head_' without
--   evaluating their subterms;
-- * evaluates variables with 'update', which uses @perform@ for
--   variables bound by an 'SBind' statement;
-- * handles @let@, and application of a term that evaluates to a
--   lambda, by pushing an unevaluated 'Thunk' binding ('SLet') and
--   evaluating the body (call-by-name);
-- * delegates @case@ to @evaluateCase@, which is handed this same
--   evaluator;
-- * returns @Summate@ unchanged as a 'Neutral' term.
--
-- @Expect@ is not supported and raises an error.
evaluate
    :: forall abt m p
    .  (ABT Term abt, EvaluationMonad abt m p)
    => MeasureEvaluator abt m
    -> (TermEvaluator abt m -> CaseEvaluator abt m)
    -> TermEvaluator abt m
evaluate perform evaluateCase = evaluate_
    where
    -- The case evaluator is parameterised by the term evaluator; tie the
    -- knot here so nested cases use this evaluator.
    evaluateCase_ :: CaseEvaluator abt m
    evaluateCase_ = evaluateCase evaluate_

    evaluate_ :: TermEvaluator abt m
    evaluate_ e0 =
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- evaluate_: " ++ show (pretty e0)) $
#endif
        caseVarSyn e0 (update perform evaluate_) $ \t ->
            case t of
                -- Terms that are already heads: return them as-is.
                Literal_ v -> return . Head_ $ WLiteral v
                Datum_ d -> return . Head_ $ WDatum d
                Empty_ typ -> return . Head_ $ WEmpty typ
                Array_ e1 e2 -> return . Head_ $ WArray e1 e2
                Lam_ :$ e1 :* End -> return . Head_ $ WLam e1
                Dirac :$ e1 :* End -> return . Head_ $ WDirac e1
                MBind :$ e1 :* e2 :* End -> return . Head_ $ WMBind e1 e2
                Plate :$ e1 :* e2 :* End -> return . Head_ $ WPlate e1 e2
                MeasureOp_ o :$ es -> return . Head_ $ WMeasureOp o es
                Superpose_ pes -> return . Head_ $ WSuperpose pes
                Reject_ typ -> return . Head_ $ WReject typ
                Integrate :$ e1 :* e2 :* e3 :* End ->
                    return . Head_ $ WIntegrate e1 e2 e3
                -- Summations are not evaluated; the whole term is neutral.
                Summate h1 h2 :$ e1 :* e2 :* e3 :* End ->
                    return . Neutral $ syn t
                -- Application: evaluate the function position first.  A
                -- neutral function yields a neutral application.
                App_ :$ e1 :* e2 :* End -> do
                    w1 <- evaluate_ e1
                    case w1 of
                        Neutral e1' -> return . Neutral $ P.app e1' e2
                        Head_ v1 -> evaluateApp v1
                  where
                    -- Beta-reduce by binding the lambda's parameter to the
                    -- unevaluated argument e2 and evaluating the body.
                    evaluateApp (WLam f) =
                        caseBind f $ \x f' -> do
                            i <- getIndices
                            push (SLet x (Thunk e2) i) f' evaluate_
                    -- The only head of function type is WLam.
                    evaluateApp _ = error "evaluate{App_}: the impossible happened"
                -- let: bind the variable lazily and evaluate the body.
                Let_ :$ e1 :* e2 :* End -> do
                    i <- getIndices
                    caseBind e2 $ \x e2' ->
                        push (SLet x (Thunk e1) i) e2' evaluate_
                -- Coercions are applied to the evaluated argument.
                CoerceTo_ c :$ e1 :* End -> coerceTo c <$> evaluate_ e1
                UnsafeFrom_ c :$ e1 :* End -> coerceFrom c <$> evaluate_ e1
                -- Operators are handled by the dedicated helpers below.
                NaryOp_ o es -> evaluateNaryOp evaluate_ o es
                ArrayOp_ o :$ es -> evaluateArrayOp evaluate_ o es
                PrimOp_ o :$ es -> evaluatePrimOp evaluate_ o es
                Expect :$ e1 :* e2 :* End ->
                    error "TODO: evaluate{Expect}: unclear how to handle this without cyclic dependencies"
                Case_ e bs -> evaluateCase_ e bs
                _ :$ _ -> error "evaluate: the impossible happened"
-- | The default 'CaseEvaluator'.  It matches the scrutinee against the
-- branches with 'matchBranches'.  Whenever matching needs to inspect a
-- value, the given term evaluator is run on it and the result is viewed
-- with 'viewWhnfDatum'.
--
-- * No branch matches: an error is raised.
-- * Matching reports 'GotStuck': the whole @case@ expression is
--   returned as a 'Neutral' term.
-- * A branch matches: its pattern-variable bindings are pushed as lazy
--   let-statements (see 'toStatements') and the branch body is evaluated.
defaultCaseEvaluator
    :: forall abt m p
    .  (ABT Term abt, EvaluationMonad abt m p)
    => TermEvaluator abt m
    -> CaseEvaluator abt m
defaultCaseEvaluator evaluate_ = evaluateCase_
    where
    -- Datum evaluator for the pattern matcher: evaluate, then view the
    -- WHNF as a datum.
    evaluateDatum :: DatumEvaluator (abt '[]) m
    evaluateDatum e = viewWhnfDatum <$> evaluate_ e

    evaluateCase_ :: CaseEvaluator abt m
    evaluateCase_ e bs = do
        match <- matchBranches evaluateDatum e bs
        case match of
            Nothing ->
                error "defaultCaseEvaluator: non-exhaustive patterns in case!"
            Just GotStuck ->
                return . Neutral . syn $ Case_ e bs
            Just (Matched ss body) ->
                pushes (toStatements ss) body evaluate_
-- | Turn the variable bindings produced by pattern matching into lazy
-- let-statements.  Each right-hand side is wrapped in a 'Thunk' and given
-- an empty index list.
toStatements :: Assocs (abt '[]) -> [Statement abt p]
toStatements assocs = map assocToLet (fromAssocs assocs)
    where
    -- One binding becomes one unevaluated let-statement.
    assocToLet (Assoc v rhs) = SLet v (Thunk rhs) []
-- | Evaluate a variable by finding the statement that binds it in the
-- current context (via 'select').  The evaluated result is re-pushed as
-- an @SLet x (Whnf_ w) i@, so later lookups see the WHNF.
--
-- * 'SBind': the bound measure expression is run with @perform@.
-- * 'SLet': the bound expression is evaluated with @evaluate_@ (via
--   'caseLazy'; presumably an existing 'Whnf_' is returned directly and
--   a 'Thunk' is evaluated).
-- * 'SStuff1' and 'SGuard': the variable is returned as 'Neutral'.
-- * 'SWeight' and 'SStuff0' are never selected (the predicate gives
--   'Nothing').
-- * If 'select' finds nothing, the variable is treated as free and
--   returned as 'Neutral'.
update
    :: forall abt m p
    .  (ABT Term abt, EvaluationMonad abt m p)
    => MeasureEvaluator abt m
    -> TermEvaluator abt m
    -> VariableEvaluator abt m
update perform evaluate_ = \x ->
    -- A 'Nothing' from select means no statement binds x: x is free.
    fmap (maybe (Neutral $ var x) id) . select x $ \s ->
        case s of
            SBind y e i -> do
                Refl <- varEq x y
                Just $ do
                    w <- perform $ caseLazy e fromWhnf id
                    -- Replace the measure binding with a let of its result.
                    unsafePush (SLet x (Whnf_ w) i)
#ifdef __TRACE_DISINTEGRATE__
                    trace ("-- updated "
                        ++ show (ppStatement 11 s)
                        ++ " to "
                        ++ show (ppStatement 11 (SLet x (Whnf_ w) i))
                        ) $ return ()
#endif
                    return w
            SLet y e i -> do
                Refl <- varEq x y
                Just $ do
                    w <- caseLazy e return evaluate_
                    -- Memoise: replace the binding with its evaluated form.
                    unsafePush (SLet x (Whnf_ w) i)
                    return w
            SWeight _ _ -> Nothing
            SStuff0 _ _ -> Nothing
            -- NOTE(review): the SBind/SLet branches re-push a statement
            -- after evaluation, which suggests 'select' removes the
            -- statement it finds.  These two branches push nothing back.
            -- Confirm against 'select' that the SStuff1/SGuard statement
            -- is not lost from the context.
            SStuff1 _ _ _ -> Just . return . Neutral $ var x
            SGuard ys pat scrutinee i -> Just . return . Neutral $ var x
-- | Conversion between heads of a Hakaru type @a@ and ordinary Haskell
-- values of type @a'@.  It is used to constant-fold operators once their
-- arguments are heads.  The functional dependency makes the Haskell
-- representation determined by the Hakaru type.
--
-- 'reify' is partial: in the instances in this module, heads such as
-- coercions raise a TODO error.
class Interp a a' | a -> a' where
    -- Read a Haskell value out of a head.
    reify   :: (ABT Term abt) => Head abt a -> a'
    -- Build a head from a Haskell value.
    reflect :: (ABT Term abt) => a' -> Head abt a
-- | Natural-number heads are 'LNat' literals; coercion heads are not yet
-- supported by 'reify'.
instance Interp 'HNat Natural where
    reflect n = WLiteral (LNat n)
    reify v =
        case v of
            WLiteral (LNat n) -> n
            WCoerceTo _ _     -> error "TODO: reify{WCoerceTo}"
            WUnsafeFrom _ _   -> error "TODO: reify{WUnsafeFrom}"
-- | Integer heads are 'LInt' literals; coercion heads are not yet
-- supported by 'reify'.
instance Interp 'HInt Integer where
    reflect i = WLiteral (LInt i)
    reify v =
        case v of
            WLiteral (LInt i) -> i
            WCoerceTo _ _     -> error "TODO: reify{WCoerceTo}"
            WUnsafeFrom _ _   -> error "TODO: reify{WUnsafeFrom}"
-- | Probability heads are 'LProb' literals.  Coercion heads and
-- unevaluated integrals are not yet supported by 'reify'.
instance Interp 'HProb NonNegativeRational where
    reflect p = WLiteral (LProb p)
    reify v =
        case v of
            WLiteral (LProb p) -> p
            WCoerceTo _ _      -> error "TODO: reify{WCoerceTo}"
            WUnsafeFrom _ _    -> error "TODO: reify{WUnsafeFrom}"
            WIntegrate _ _ _   -> error "TODO: reify{WIntegrate}"
-- | Real heads are 'LReal' literals; coercion heads are not yet supported
-- by 'reify'.
instance Interp 'HReal Rational where
    reflect r = WLiteral (LReal r)
    reify v =
        case v of
            WLiteral (LReal r) -> r
            WCoerceTo _ _      -> error "TODO: reify{WCoerceTo}"
            WUnsafeFrom _ _    -> error "TODO: reify{WUnsafeFrom}"
-- | A pure datum evaluator for pattern matching outside the evaluation
-- monad.  It runs in 'Identity', so it can only use 'toWhnf' on the term
-- and then view the result with 'viewWhnfDatum'.
identifyDatum :: (ABT Term abt) => DatumEvaluator (abt '[]) Identity
identifyDatum e = return (toWhnf e >>= viewWhnfDatum)
-- | The unit type is a datum.  'reify' checks the head against 'pUnit'
-- (using 'matchTopPattern' with 'identifyDatum') and returns @()@.
instance Interp HUnit () where
    reflect () = WDatum dUnit
    reify v = runIdentity $ do
        match <- matchTopPattern identifyDatum (fromHead v) pUnit Nil1
        case match of
            Just (Matched_ _ss Nil1) -> return ()
            _ -> error "reify{HUnit}: the impossible happened"
-- | Booleans are data values built from 'dTrue' or 'dFalse'.  'reify'
-- matches the head against 'pTrue' first and, if that does not match,
-- against 'pFalse'.
instance Interp HBool Bool where
    reflect = WDatum . (\b -> if b then dTrue else dFalse)
    reify v = runIdentity $ do
        matchT <- matchTopPattern identifyDatum (fromHead v) pTrue Nil1
        case matchT of
            Just (Matched_ _ss Nil1) -> return True
            Just GotStuck_ -> error "reify{HBool}: the impossible happened"
            Nothing -> do
                -- Not true, so it must be false.
                matchF <- matchTopPattern identifyDatum (fromHead v) pFalse Nil1
                case matchF of
                    Just (Matched_ _ss Nil1) -> return False
                    _ -> error "reify{HBool}: the impossible happened"
-- | Split a head of pair type into its two component terms.
--
-- The head is matched against @pPair PVar PVar@ using two fresh
-- variables, numbered @nextFree e0@ and @nextFree e0 + 1@, and the terms
-- bound to them are read back.  Any other outcome raises
-- @"reifyPair: the impossible happened"@.
reifyPair
    :: (ABT Term abt) => Head abt (HPair a b) -> (abt '[] a, abt '[] b)
reifyPair v =
    let impossible = error "reifyPair: the impossible happened"
        e0    = fromHead v
        n     = nextFree e0
        (a,b) = sUnPair $ typeOf e0
        -- Fresh pattern variables for the two components.
        x     = Variable Text.empty n a
        y     = Variable Text.empty (1 + n) b
    in runIdentity $ do
        match <- matchTopPattern identifyDatum e0 (pPair PVar PVar) (Cons1 x (Cons1 y Nil1))
        case match of
            Just (Matched_ ss Nil1) ->
                case ss [] of
                    [Assoc x' e1, Assoc y' e2] ->
                        -- Recover the type equalities erased by Assoc.
                        maybe impossible id $ do
                            Refl <- varEq x x'
                            Refl <- varEq y y'
                            Just $ return (e1, e2)
                    _ -> impossible
            _ -> impossible
-- | Boolean connectives used to constant-fold the corresponding primitive
-- operators.  Each is strict in its first argument and looks at the
-- second only when the first does not decide the result.
impl, diff, nand, nor :: Bool -> Bool -> Bool
-- Implication: false only when the premise holds and the conclusion fails.
impl p q = if p then q else True
-- Difference: the first holds and the second does not.
diff p q = if p then not q else False
-- Negated conjunction.
nand p q = not p || not q
-- Negated disjunction.
nor p q = not p && not q
-- | @natRoot x n@ is the @n@-th root of @x@, computed as @x ** recip n@.
natRoot :: (Floating a) => a -> Nat -> a
natRoot x y =
    let power = recip (fromIntegral (fromNat y))
    in  x ** power
-- | Evaluate an n-ary operator application.
--
-- Arguments are evaluated left to right.  An argument that evaluates to
-- a neutral application of the same operator is flattened: its own
-- arguments are queued for evaluation ('matchNaryOp').  Adjacent head
-- arguments are combined with 'evalOp' as they are accumulated
-- ('snocLoop').
--
-- The result is:
--
-- * the operator's 'identityElement' if nothing was accumulated;
-- * the single accumulated value if there is exactly one;
-- * otherwise a 'Neutral' application of the operator to the residual
--   values.
evaluateNaryOp
    :: (ABT Term abt, EvaluationMonad abt m p)
    => TermEvaluator abt m
    -> NaryOp a
    -> Seq (abt '[] a)
    -> m (Whnf abt a)
evaluateNaryOp evaluate_ = \o es -> mainLoop o (evalOp o) Seq.empty es
    where
    -- ws: evaluated (and partially folded) arguments so far;
    -- es: arguments still to evaluate.
    mainLoop o op ws es =
        case Seq.viewl es of
            Seq.EmptyL -> return $
                case Seq.viewl ws of
                    Seq.EmptyL -> identityElement o
                    w Seq.:< ws'
                        | Seq.null ws' -> w
                        | otherwise ->
                            Neutral . syn . NaryOp_ o $ fmap fromWhnf ws
            e Seq.:< es' -> do
                w <- evaluate_ e
                case matchNaryOp o w of
                    Nothing -> mainLoop o op (snocLoop op ws w) es'
                    -- Nested use of the same operator: splice its
                    -- arguments in front of the remaining ones.
                    Just es2 -> mainLoop o op ws (es2 Seq.>< es')
-- | Append a newly evaluated value to the accumulated arguments of an
-- n-ary operator, constant-folding as it goes.
--
-- While both the new value and the last accumulated value are heads,
-- they are combined with @op@ and the result is folded further left.
-- As soon as either side is neutral, the value is simply appended.
--
-- The combined head is @op new old@ (argument order reversed relative to
-- position).  This only matters for non-commutative operators.
snocLoop
    :: (ABT syn abt)
    => (Head abt a -> Head abt a -> Head abt a)
    -> Seq (Whnf abt a)
    -> Whnf abt a
    -> Seq (Whnf abt a)
snocLoop op ws w1 =
    case Seq.viewr ws of
        Seq.EmptyR -> Seq.singleton w1
        ws' Seq.:> w2 ->
            case (w1,w2) of
                (Head_ v1, Head_ v2) -> snocLoop op ws' (Head_ (op v1 v2))
                _ -> ws Seq.|> w1
-- | If a value is a neutral application of the given n-ary operator,
-- return its arguments so the caller can flatten it.  Otherwise return
-- 'Nothing'; this includes heads, variables and applications of a
-- different operator.
matchNaryOp
    :: (ABT Term abt)
    => NaryOp a
    -> Whnf abt a
    -> Maybe (Seq (abt '[] a))
matchNaryOp o w =
    case w of
        Head_ _ -> Nothing
        Neutral e ->
            caseVarSyn e (const Nothing) $ \t ->
                case t of
                    NaryOp_ o' es | o' == o -> Just es
                    _ -> Nothing
-- | The unit of an n-ary operator, returned when no arguments remain.
-- 'Min' and 'Max' have no literal unit here, so the empty application
-- itself is returned as a neutral term.
identityElement :: (ABT Term abt) => NaryOp a -> Whnf abt a
identityElement And                  = Head_ (WDatum dTrue)
identityElement Iff                  = Head_ (WDatum dTrue)
identityElement Or                   = Head_ (WDatum dFalse)
identityElement Xor                  = Head_ (WDatum dFalse)
identityElement o@(Min _)            = Neutral (syn (NaryOp_ o Seq.empty))
identityElement o@(Max _)            = Neutral (syn (NaryOp_ o Seq.empty))
identityElement (Sum HSemiring_Nat)  = Head_ (WLiteral (LNat 0))
identityElement (Sum HSemiring_Int)  = Head_ (WLiteral (LInt 0))
identityElement (Sum HSemiring_Prob) = Head_ (WLiteral (LProb 0))
identityElement (Sum HSemiring_Real) = Head_ (WLiteral (LReal 0))
identityElement (Prod HSemiring_Nat)  = Head_ (WLiteral (LNat 1))
identityElement (Prod HSemiring_Int)  = Head_ (WLiteral (LInt 1))
identityElement (Prod HSemiring_Prob) = Head_ (WLiteral (LProb 1))
identityElement (Prod HSemiring_Real) = Head_ (WLiteral (LReal 1))
-- | Combine two head arguments of an n-ary operator into one head.
--
-- * Boolean operators, 'Min' and 'Max' work through 'Interp': each head
--   is reified to a Haskell value, the values are combined, and the
--   result is reflected back.  'Min'\/'Max' are supported for every
--   ordered type with an 'Interp' instance (Nat, Int, Prob, Real, Bool,
--   Unit), using Haskell's 'Ord' as 'rrLess' does.  Other ordered types
--   still raise a TODO error.
-- * 'Sum' and 'Prod' combine the underlying literals directly.  Their
--   lambdas only match 'WLiteral' heads.
evalOp
    :: (ABT Term abt)
    => NaryOp a
    -> Head abt a
    -> Head abt a
    -> Head abt a
evalOp And = \v1 v2 -> reflect (reify v1 && reify v2)
evalOp Or  = \v1 v2 -> reflect (reify v1 || reify v2)
evalOp Xor = \v1 v2 -> reflect (reify v1 /= reify v2)
evalOp Iff = \v1 v2 -> reflect (reify v1 == reify v2)
evalOp (Min theOrd) = evalOrdOp "Min" min theOrd
evalOp (Max theOrd) = evalOrdOp "Max" max theOrd
evalOp (Sum theSemi) =
    \(WLiteral v1) (WLiteral v2) -> WLiteral $ evalSum theSemi v1 v2
evalOp (Prod theSemi) =
    \(WLiteral v1) (WLiteral v2) -> WLiteral $ evalProd theSemi v1 v2

-- | Combine two heads of an ordered type with an 'Ord'-polymorphic
-- selector (such as 'min' or 'max').  Each head is reified via 'Interp'
-- and the chosen value is reflected back.
--
-- @name@ is used only in the error message for ordered types that have
-- no 'Interp' instance (arrays, pairs, sums).
evalOrdOp
    :: (ABT Term abt)
    => String
    -> (forall x. Ord x => x -> x -> x)
    -> HOrd a
    -> Head abt a
    -> Head abt a
    -> Head abt a
evalOrdOp name pick theOrd =
    case theOrd of
        HOrd_Nat  -> \v1 v2 -> reflect (pick (reify v1) (reify v2))
        HOrd_Int  -> \v1 v2 -> reflect (pick (reify v1) (reify v2))
        HOrd_Prob -> \v1 v2 -> reflect (pick (reify v1) (reify v2))
        HOrd_Real -> \v1 v2 -> reflect (pick (reify v1) (reify v2))
        HOrd_Bool -> \v1 v2 -> reflect (pick (reify v1) (reify v2))
        HOrd_Unit -> \v1 v2 -> reflect (pick (reify v1) (reify v2))
        _         -> error ("TODO: evalOp{" ++ name ++ "}")
-- | Add ('evalSum') or multiply ('evalProd') two literals of a semiring
-- type.  The semiring witness selects which literal constructor is
-- expected.
evalSum, evalProd :: HSemiring a -> Literal a -> Literal a -> Literal a
evalSum theSemi =
    case theSemi of
        HSemiring_Nat  -> \(LNat x)  (LNat y)  -> LNat  (x + y)
        HSemiring_Int  -> \(LInt x)  (LInt y)  -> LInt  (x + y)
        HSemiring_Prob -> \(LProb x) (LProb y) -> LProb (x + y)
        HSemiring_Real -> \(LReal x) (LReal y) -> LReal (x + y)
evalProd theSemi =
    case theSemi of
        HSemiring_Nat  -> \(LNat x)  (LNat y)  -> LNat  (x * y)
        HSemiring_Int  -> \(LInt x)  (LInt y)  -> LInt  (x * y)
        HSemiring_Prob -> \(LProb x) (LProb y) -> LProb (x * y)
        HSemiring_Real -> \(LReal x) (LReal y) -> LReal (x * y)
-- | Evaluate an array operator application.
--
-- * Index: a neutral array gives a neutral index term.  Indexing into an
--   array head is not implemented yet (TODO error).
-- * Size: a neutral array gives a neutral size term.  An empty array has
--   size 0.  Otherwise the array's size expression is evaluated.
-- * Reduce: not implemented yet (TODO error).
evaluateArrayOp
    :: ( ABT Term abt, EvaluationMonad abt m p
       , typs ~ UnLCs args, args ~ LCs typs)
    => TermEvaluator abt m
    -> ArrayOp typs a
    -> SArgs abt args
    -> m (Whnf abt a)
evaluateArrayOp evaluate_ = go
    where
    go o@(Index _) = \(e1 :* e2 :* End) -> do
        w1 <- evaluate_ e1
        case w1 of
            Neutral e1' ->
                return . Neutral $ syn (ArrayOp_ o :$ e1' :* e2 :* End)
            Head_ v1 ->
                error "TODO: evaluateArrayOp{Index}{Head_}"
    go o@(Size _) = \(e1 :* End) -> do
        w1 <- evaluate_ e1
        case w1 of
            Neutral e1' -> return . Neutral $ syn (ArrayOp_ o :$ e1' :* End)
            Head_ v1 ->
                case head2array v1 of
                    WAEmpty -> return . Head_ $ WLiteral (LNat 0)
                    -- The size is the first component of the array head.
                    WAArray e3 _ -> evaluate_ e3
    go (Reduce _) = \(e1 :* e2 :* e3 :* End) ->
        error "TODO: evaluateArrayOp{Reduce}"
-- | A uniform view of an array-typed head (see 'head2array').
data ArrayHead :: ([Hakaru] -> Hakaru -> *) -> Hakaru -> * where
    -- | The empty array.
    WAEmpty :: ArrayHead abt a
    -- | An array given by its size and a body with one bound 'HNat'
    -- variable (presumably the element index).
    WAArray
        :: !(abt '[] 'HNat)
        -> !(abt '[ 'HNat] a)
        -> ArrayHead abt a
-- | View an array head as either empty or a size and body.
head2array :: Head abt ('HArray a) -> ArrayHead abt a
head2array h =
    case h of
        WEmpty _          -> WAEmpty
        WArray size body  -> WAArray size body
-- | Evaluate a primitive operator application.
--
-- Some operators are constant-folded through 'Interp' when every
-- argument evaluates to a head: boolean connectives, equality and
-- less-than on the simple types, natural powers, negate, abs, signum
-- and recip.  Otherwise the operator is rebuilt as a 'Neutral' term over
-- the evaluated arguments.
--
-- Transcendental and similar operators (sin, exp, log, real powers,
-- gamma, beta, nth roots, ...) are never folded: their arguments are
-- evaluated and the term is returned as 'Neutral'.  Unhandled operators
-- raise a TODO error.
evaluatePrimOp
    :: forall abt m p typs args a
    .  ( ABT Term abt, EvaluationMonad abt m p
       , typs ~ UnLCs args, args ~ LCs typs)
    => TermEvaluator abt m
    -> PrimOp typs a
    -> SArgs abt args
    -> m (Whnf abt a)
evaluatePrimOp evaluate_ = go
    where
    -- Evaluate the argument, then rebuild the term as neutral without
    -- trying to compute a value.
    neu1 :: forall b c
        .  (abt '[] b -> abt '[] c)
        -> abt '[] b
        -> m (Whnf abt c)
    neu1 f e = (Neutral . f . fromWhnf) <$> evaluate_ e

    -- Binary analogue of neu1.
    neu2 :: forall b c d
        .  (abt '[] b -> abt '[] c -> abt '[] d)
        -> abt '[] b
        -> abt '[] c
        -> m (Whnf abt d)
    neu2 f e1 e2 = do
        e1' <- fromWhnf <$> evaluate_ e1
        e2' <- fromWhnf <$> evaluate_ e2
        return . Neutral $ f e1' e2'

    -- Unary folding: if the argument is a head, apply the Haskell
    -- function f' to its reified value and reflect the result.  Otherwise
    -- rebuild the term with f.
    rr1 :: forall b b' c c'
        .  (Interp b b', Interp c c')
        => (b' -> c')
        -> (abt '[] b -> abt '[] c)
        -> abt '[] b
        -> m (Whnf abt c)
    rr1 f' f e = do
        w <- evaluate_ e
        return $
            case w of
                Neutral e' -> Neutral $ f e'
                Head_ v -> Head_ . reflect $ f' (reify v)

    -- Binary folding: both arguments are evaluated (left first).  The
    -- result is folded only when both are heads.
    rr2 :: forall b b' c c' d d'
        .  (Interp b b', Interp c c', Interp d d')
        => (b' -> c' -> d')
        -> (abt '[] b -> abt '[] c -> abt '[] d)
        -> abt '[] b
        -> abt '[] c
        -> m (Whnf abt d)
    rr2 f' f e1 e2 = do
        w1 <- evaluate_ e1
        w2 <- evaluate_ e2
        return $
            case w1 of
                Neutral e1' -> Neutral $ f e1' (fromWhnf w2)
                Head_ v1 ->
                    case w2 of
                        Neutral e2' -> Neutral $ f (fromWhnf w1) e2'
                        Head_ v2 -> Head_ . reflect $ f' (reify v1) (reify v2)

    -- Build a binary primitive-operator term.
    primOp2_
        :: forall b c d
        .  PrimOp '[ b, c ] d -> abt '[] b -> abt '[] c -> abt '[] d
    primOp2_ o e1 e2 = syn (PrimOp_ o :$ e1 :* e2 :* End)

    -- Boolean connectives.
    go Not (e1 :* End) = rr1 not P.not e1
    go Impl (e1 :* e2 :* End) = rr2 impl (primOp2_ Impl) e1 e2
    go Diff (e1 :* e2 :* End) = rr2 diff (primOp2_ Diff) e1 e2
    go Nand (e1 :* e2 :* End) = rr2 nand P.nand e1 e2
    go Nor (e1 :* e2 :* End) = rr2 nor P.nor e1 e2
    -- Real-valued functions: never folded.
    go Pi End = return $ Neutral P.pi
    go Sin (e1 :* End) = neu1 P.sin e1
    go Cos (e1 :* End) = neu1 P.cos e1
    go Tan (e1 :* End) = neu1 P.tan e1
    go Asin (e1 :* End) = neu1 P.asin e1
    go Acos (e1 :* End) = neu1 P.acos e1
    go Atan (e1 :* End) = neu1 P.atan e1
    go Sinh (e1 :* End) = neu1 P.sinh e1
    go Cosh (e1 :* End) = neu1 P.cosh e1
    go Tanh (e1 :* End) = neu1 P.tanh e1
    go Asinh (e1 :* End) = neu1 P.asinh e1
    go Acosh (e1 :* End) = neu1 P.acosh e1
    go Atanh (e1 :* End) = neu1 P.atanh e1
    go RealPow (e1 :* e2 :* End) = neu2 (P.**) e1 e2
    go Exp (e1 :* End) = neu1 P.exp e1
    go Log (e1 :* End) = neu1 P.log e1
    go (Infinity h) End =
        case h of
            HIntegrable_Nat -> return . Neutral $ P.primOp0_ (Infinity h)
            HIntegrable_Prob -> return $ Neutral P.infinity
    go GammaFunc (e1 :* End) = neu1 P.gammaFunc e1
    go BetaFunc (e1 :* e2 :* End) = neu2 P.betaFunc e1 e2
    -- Comparisons.
    go (Equal theEq) (e1 :* e2 :* End) = rrEqual theEq e1 e2
    go (Less theOrd) (e1 :* e2 :* End) = rrLess theOrd e1 e2
    -- Arithmetic on literals, folded through Interp.
    go (NatPow theSemi) (e1 :* e2 :* End) =
        case theSemi of
            HSemiring_Nat -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
            HSemiring_Int -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
            HSemiring_Prob -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
            HSemiring_Real -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
    go (Negate theRing) (e1 :* End) =
        case theRing of
            HRing_Int -> rr1 negate P.negate e1
            HRing_Real -> rr1 negate P.negate e1
    -- abs lands in the non-negative type (Int -> Nat, Real -> Prob).
    go (Abs theRing) (e1 :* End) =
        case theRing of
            HRing_Int -> rr1 (unsafeNatural . abs) P.abs_ e1
            HRing_Real -> rr1 (unsafeNonNegativeRational . abs) P.abs_ e1
    go (Signum theRing) (e1 :* End) =
        case theRing of
            HRing_Int -> rr1 signum P.signum e1
            HRing_Real -> rr1 signum P.signum e1
    go (Recip theFractional) (e1 :* End) =
        case theFractional of
            HFractional_Prob -> rr1 recip P.recip e1
            HFractional_Real -> rr1 recip P.recip e1
    go (NatRoot theRadical) (e1 :* e2 :* End) =
        case theRadical of
            HRadical_Prob -> neu2 (flip P.thRootOf) e1 e2
    -- Everything else is not implemented yet.
    go op _ = error $ "TODO: evaluatePrimOp{" ++ show op ++ "}"

    -- Equality: folded for the simple types.  Pairs are compared
    -- componentwise on the reified components, short-circuiting on a
    -- known False.
    rrEqual
        :: forall b. HEq b -> abt '[] b -> abt '[] b -> m (Whnf abt HBool)
    rrEqual theEq =
        case theEq of
            HEq_Nat -> rr2 (==) (P.==)
            HEq_Int -> rr2 (==) (P.==)
            HEq_Prob -> rr2 (==) (P.==)
            HEq_Real -> rr2 (==) (P.==)
            HEq_Array aEq -> error "TODO: rrEqual{HEq_Array}"
            HEq_Bool -> rr2 (==) (P.==)
            HEq_Unit -> rr2 (==) (P.==)
            HEq_Pair aEq bEq ->
                \e1 e2 -> do
                    w1 <- evaluate_ e1
                    w2 <- evaluate_ e2
                    case w1 of
                        Neutral e1' ->
                            return . Neutral
                                $ P.primOp2_ (Equal theEq) e1' (fromWhnf w2)
                        Head_ v1 ->
                            case w2 of
                                Neutral e2' ->
                                    return . Neutral
                                        $ P.primOp2_ (Equal theEq) (fromHead v1) e2'
                                Head_ v2 -> do
                                    let (v1a, v1b) = reifyPair v1
                                    let (v2a, v2b) = reifyPair v2
                                    wa <- rrEqual aEq v1a v2a
                                    wb <- rrEqual bEq v1b v2b
                                    -- Combine the component results as a
                                    -- conjunction, simplifying known values.
                                    return $
                                        case wa of
                                            Neutral ea ->
                                                case wb of
                                                    Neutral eb -> Neutral (ea P.&& eb)
                                                    Head_ vb
                                                        | reify vb -> wa
                                                        | otherwise -> Head_ $ WDatum dFalse
                                            Head_ va
                                                | reify va -> wb
                                                | otherwise -> Head_ $ WDatum dFalse
            HEq_Either aEq bEq -> error "TODO: rrEqual{HEq_Either}"

    -- Less-than: folded for the simple types.  Pairs, arrays and sums are
    -- not implemented yet.
    rrLess
        :: forall b. HOrd b -> abt '[] b -> abt '[] b -> m (Whnf abt HBool)
    rrLess theOrd =
        case theOrd of
            HOrd_Nat -> rr2 (<) (P.<)
            HOrd_Int -> rr2 (<) (P.<)
            HOrd_Prob -> rr2 (<) (P.<)
            HOrd_Real -> rr2 (<) (P.<)
            HOrd_Array aOrd -> error "TODO: rrLess{HOrd_Array}"
            HOrd_Bool -> rr2 (<) (P.<)
            HOrd_Unit -> rr2 (<) (P.<)
            HOrd_Pair aOrd bOrd ->
                \e1 e2 -> do
                    w1 <- evaluate_ e1
                    w2 <- evaluate_ e2
                    case w1 of
                        Neutral e1' ->
                            return . Neutral
                                $ P.primOp2_ (Less theOrd) e1' (fromWhnf w2)
                        Head_ v1 ->
                            case w2 of
                                Neutral e2' ->
                                    return . Neutral
                                        $ P.primOp2_ (Less theOrd) (fromHead v1) e2'
                                Head_ v2 -> do
                                    let (v1a, v1b) = reifyPair v1
                                    let (v2a, v2b) = reifyPair v2
                                    error "TODO: rrLess{HOrd_Pair}"
            HOrd_Either aOrd bOrd -> error "TODO: rrLess{HOrd_Either}"