{-# LANGUAGE CPP
, DataKinds
, EmptyCase
, ExistentialQuantification
, FlexibleContexts
, GADTs
, GeneralizedNewtypeDeriving
, KindSignatures
, MultiParamTypeClasses
, OverloadedStrings
, PolyKinds
, ScopedTypeVariables
, TypeFamilies
, TypeOperators
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Syntax.CSE (cse) where
import Control.Monad.Reader
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.AST.Eq
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Types.DataKind
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
-- | A single entry of the CSE environment: a key expression together with
-- the expression that should stand in for it. Both sides share one Hakaru
-- type, which is hidden existentially so that entries of different types can
-- be stored in the same list.
data EAssoc (abt :: [Hakaru] -> Hakaru -> *) where
    EAssoc :: !(abt '[] a) -> !(abt '[] a) -> EAssoc abt

-- | The CSE environment, kept as a plain association list for now.
-- Entries are consed onto the front, so newer entries shadow older ones
-- during lookup.
newtype Env (abt :: [Hakaru] -> Hakaru -> *) = Env [EAssoc abt]

-- | An environment containing no entries.
emptyEnv :: Env a
emptyEnv = Env []
-- | Is the expression atomic, i.e. a bare variable or a literal?
-- Such expressions are never stored as environment keys, and a let whose
-- right-hand side is atomic is removed by the pass.
trivial :: (ABT Term abt) => abt '[] a -> Bool
trivial = isAtom . viewABT
    where
    isAtom :: View (Term abt) '[] b -> Bool
    isAtom (Var _)            = True
    isAtom (Syn (Literal_ _)) = True
    isAtom _                  = False
-- | Resolve an expression through the CSE environment.
--
-- The lookup scans for the first (most recently inserted) entry whose key
-- has the same type as the expression and is alpha-equivalent to it. It then
-- replaces the expression with that entry's value and searches again from
-- the front of the environment. This continues until no key matches, so the
-- environment behaves like an equivalence relation. It could be made faster
-- with path compression and a more efficient lookup structure.
--
-- Termination: nothing prevents the environment from containing a cycle.
-- For instance, 'insertEnv' given the same trivial variable for both
-- arguments stores an entry mapping that variable to itself. Unguarded
-- chaining would then loop forever.
--
-- A chain that does not cycle uses each entry at most once, so it needs at
-- most @length env@ steps. The lookup is therefore given that many steps of
-- fuel. Every terminating lookup behaves exactly as before. If the fuel runs
-- out while a key still matches, the chain is on a cycle, and the term
-- reached so far is returned.
lookupEnv
    :: forall abt a . (ABT Term abt)
    => abt '[] a
    -> Env abt
    -> abt '[] a
lookupEnv start (Env env) = go (length env) start env
    where
    go :: Int -> abt '[] a -> [EAssoc abt] -> abt '[] a
    go _ ast [] = ast
    go fuel ast (EAssoc a b : xs) =
        case jmEq1 (typeOf ast) (typeOf a) of
            Just Refl | alphaEq ast a ->
                -- A match with no fuel left can only happen on a cycle;
                -- stop here instead of looping forever.
                if fuel <= 0 then ast else go (fuel - 1) b env
            _ -> go fuel ast xs
-- | Record that two expressions are equivalent. Callers pass a let's
-- processed right-hand side first and its bound variable second.
--
-- The direction of the stored entry depends on the first expression:
--
-- * Trivial (variable or literal): the entry points the new variable at the
--   older trivial expression. This does not change the amount of work done,
--   since the second argument is always a variable. It lets the pass
--   eliminate redundant variables, because binders are the only thing CSE
--   removes.
--
-- * Anything else: the entry maps the expression to its binding variable.
--
-- The entry is placed at the front, so it is found before older entries.
insertEnv
    :: forall abt a . (ABT Term abt)
    => abt '[] a
    -> abt '[] a
    -> Env abt
    -> Env abt
insertEnv ast1 ast2 (Env env) = Env (entry : env)
    where
    entry
        | trivial ast1 = EAssoc ast2 ast1
        | otherwise    = EAssoc ast1 ast2
-- | The CSE monad: a reader over the environment of known equivalences.
-- The environment is extended with 'local' while the body of a let is
-- processed. All instances are inherited from the underlying 'Reader'.
newtype CSE (abt :: [Hakaru] -> Hakaru -> *) a =
    CSE { runCSE :: Reader (Env abt) a }
    deriving (Functor, Applicative, Monad, MonadReader (Env abt))
-- | Replace an expression with its representative in the current
-- environment. The expression is returned unchanged when no entry matches.
replaceCSE
    :: (ABT Term abt)
    => abt '[] a
    -> CSE abt (abt '[] a)
replaceCSE abt = do
    env <- ask
    return (lookupEnv abt env)
-- | Common-subexpression elimination: run the pass over a term, starting
-- from an environment with no known equivalences.
cse :: forall abt a . (ABT Term abt) => abt '[] a -> abt '[] a
cse = flip runReader emptyEnv . runCSE . cse'
-- | Run CSE on a term that may sit under binders (@xs@ need not be empty).
-- Binders are rebuilt unchanged around the processed body. Variables and
-- syntactic constructs are delegated to 'cseVar' and 'cseTerm'.
cse' :: forall abt xs a . (ABT Term abt) => abt xs a -> CSE abt (abt xs a)
cse' e = step (viewABT e)
    where
    step :: View (Term abt) ys a -> CSE abt (abt ys a)
    step node = case node of
        Var x    -> cseVar x
        Syn t    -> cseTerm t
        Bind x n -> fmap (bind x) (step n)
-- | A variable occurrence is resolved through the environment. A variable
-- recorded as equivalent to another variable or to a literal is therefore
-- rewritten to that expression.
--
-- TODO: A good sanity check would be to ensure the result in this case is
-- always a variable or constant. A variable should never be substituted for
-- a more complex expression.
cseVar
    :: (ABT Term abt)
    => Variable a
    -> CSE abt (abt '[] a)
cseVar v = replaceCSE (var v)
-- | Build the term @let v = rhs in body@.
mklet :: ABT Term abt => Variable b -> abt '[] b -> abt '[] a -> abt '[] a
mklet v rhs body = syn letNode
    where
    letNode = Let_ :$ rhs :* scoped :* End
    scoped  = bind v body
-- | CSE for a syntactic construct. Thanks to A-normalization (assumed of the
-- input), let bindings are the only case that needs care. Everything else is
-- structural recursion, after which the rebuilt term is looked up in the
-- environment.
--
-- For @let v = rhs in body@, the right-hand side is processed first. It is
-- then recorded as equivalent to @v@ while the body is processed. A trivial
-- right-hand side is substituted for occurrences of @v@ and the binder is
-- dropped. Otherwise the let is rebuilt around the processed body.
cseTerm
    :: (ABT Term abt)
    => Term abt a
    -> CSE abt (abt '[] a)
cseTerm (Let_ :$ rhs :* body :* End) = do
    rhs' <- cse' rhs
    caseBind body $ \v body' ->
        local (insertEnv rhs' (var v)) $ do
            body'' <- cse' body'
            return $
                if trivial rhs'
                    then body''
                    else mklet v rhs' body''
cseTerm term = do
    term' <- traverse21 cse' term
    replaceCSE (syn term')