module Language.Syntactic.Functional.Sharing
(
InjDict (..)
, CodeMotionInterface (..)
, defaultInterface
, defaultInterfaceT
, codeMotion
) where
import Control.Monad.State
import Data.Maybe (isNothing)
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Syntactic
import Language.Syntactic.Functional
-- | Functions for injecting the symbols needed to bind a shared expression of
-- type @a@ with a let inside an expression of type @b@. 'codeMotionM' builds
-- the binding as @injLet :$ shared :$ (injLambda v :$ body)@ and refers to
-- the shared value through @injVariable v@.
data InjDict sym a b = InjDict
    { injVariable :: Name -> sym (Full a)
      -- ^ Inject a variable symbol with the given name
    , injLambda :: Name -> sym (b :-> Full (a -> b))
      -- ^ Inject a lambda symbol that binds the given name
    , injLet :: sym (a :-> (a -> b) :-> Full b)
      -- ^ Inject a let symbol (applied to the bound value and a lambda)
    }
-- | Interface that adapts the code motion transformation ('codeMotion') to a
-- particular symbol domain.
data CodeMotionInterface sym = Interface
    { mkInjDict :: forall a b . ASTF sym a -> ASTF sym b -> Maybe (InjDict sym a b)
      -- ^ Try to produce the injection functions needed to share the first
      -- expression inside the second. 'Nothing' means the first expression
      -- is not chosen for sharing there.
    , castExprCM :: forall a b . ASTF sym a -> ASTF sym b -> Maybe (ASTF sym b)
      -- ^ Try to cast the first expression to the type of the second; used
      -- when substituting a variable for a shared sub-expression.
    , hoistOver :: forall c. ASTF sym c -> Bool
      -- ^ Whether the search for sub-expressions to share may continue
      -- into the given expression when the expression itself is not chosen.
    }
-- | Default 'CodeMotionInterface' for @'Typed' sym@ domains that use the
-- 'Binding' constructs ('Var', 'Lam') together with 'Let'.
--
-- The first argument decides whether the first expression may be shared
-- inside the second; the second argument is used directly as 'hoistOver'.
-- Casting is done with 'castExpr'.
defaultInterface :: forall sym symT
    .  ( Binding :<: sym
       , Let :<: sym
       , symT ~ Typed sym
       )
    => (forall a b . ASTF symT a -> ASTF symT b -> Bool)
    -> (forall a . ASTF symT a -> Bool)
    -> CodeMotionInterface symT
defaultInterface sharable canHoist = Interface
    { mkInjDict  = injDictFor
    , castExprCM = castExpr
    , hoistOver  = canHoist
    }
  where
    -- NOTE(review): the 'Typed' pattern matches appear to exist to bring the
    -- type evidence carried by 'Typed' into scope, which the 'Typed'
    -- constructors in the record below need -- confirm before simplifying.
    injDictFor :: ASTF symT a -> ASTF symT b -> Maybe (InjDict symT a b)
    injDictFor a b
      | sharable a b =
          simpleMatch
            (\(Typed _) _ ->
              simpleMatch
                (\(Typed _) _ -> Just InjDict
                  { injVariable = Typed . inj . Var
                  , injLambda   = Typed . inj . Lam
                  , injLet      = Typed (inj Let)
                  })
                b)
            a
      | otherwise = Nothing
-- | Default 'CodeMotionInterface' for @'Typed' sym@ domains that use the
-- 'BindingT' constructs ('VarT', 'LamT') together with 'Let'.
--
-- The first argument decides whether the first expression may be shared
-- inside the second; the second argument is used directly as 'hoistOver'.
-- Casting is done with 'castExpr'.
defaultInterfaceT :: forall sym symT
    .  ( BindingT :<: sym
       , Let :<: sym
       , symT ~ Typed sym
       )
    => (forall a b . ASTF symT a -> ASTF symT b -> Bool)
    -> (forall a . ASTF symT a -> Bool)
    -> CodeMotionInterface symT
defaultInterfaceT sharable canHoist = Interface
    { mkInjDict  = injDictFor
    , castExprCM = castExpr
    , hoistOver  = canHoist
    }
  where
    -- NOTE(review): the 'Typed' pattern matches appear to exist to bring the
    -- type evidence carried by 'Typed' into scope, which the 'Typed'
    -- constructors in the record below need -- confirm before simplifying.
    injDictFor :: ASTF symT a -> ASTF symT b -> Maybe (InjDict symT a b)
    injDictFor a b
      | sharable a b =
          simpleMatch
            (\(Typed _) _ ->
              simpleMatch
                (\(Typed _) _ -> Just InjDict
                  { injVariable = Typed . inj . VarT
                  , injLambda   = Typed . inj . LamT
                  , injLet      = Typed (inj Let)
                  })
                b)
            a
      | otherwise = Nothing
-- | @substitute iface x y a@ replaces every sub-expression of @a@ that is
-- alpha-equivalent to @x@ by @y@.
--
-- A matching sub-expression is only replaced when 'castExprCM' can cast @y@
-- to its type; otherwise the search continues into its arguments. The
-- inserted @y@ is not traversed again.
substitute :: forall sym a b
    .  (Equality sym, BindingDomain sym)
    => CodeMotionInterface sym
    -> ASTF sym a
    -> ASTF sym a
    -> ASTF sym b
    -> ASTF sym b
substitute iface x y a = case castExprCM iface y a of
    Just y' | alphaEq x a -> y'
    _                     -> goArgs a
  where
    -- Rebuild an application spine, substituting inside every argument.
    goArgs :: AST sym c -> AST sym c
    goArgs (f :$ arg) = goArgs f :$ substitute iface x y arg
    goArgs leaf       = leaf
-- | @count a b@ is the number of sub-expressions of @b@ that are
-- alpha-equivalent to @a@.
--
-- A match counts as 1 and its own sub-expressions are not inspected, so
-- occurrences nested inside a match are not counted separately.
count :: forall sym a b
    .  (Equality sym, BindingDomain sym)
    => ASTF sym a
    -> ASTF sym b
    -> Int
count a b = if alphaEq a b then 1 else inArgs b
  where
    -- Sum the counts over all arguments of an application spine.
    inArgs :: AST sym c -> Int
    inArgs (f :$ arg) = inArgs f + count a arg
    inArgs _          = 0
-- | Environment threaded through the search for a sub-expression to share
-- (see 'choose' and 'liftable').
data Env sym = Env
    { inLambda :: Bool
      -- ^ Whether the current position is under a lambda
    , counter :: EF (AST sym) -> Int
      -- ^ Number of occurrences of a sub-expression in the whole expression
      -- being processed
    , dependencies :: Set Name
      -- ^ Variables bound by the enclosing lambdas; a sub-expression with a
      -- free variable in this set is not liftable
    }
-- | Whether an expression may be lifted out into a let binding in the given
-- environment.
--
-- It must not have free variables bound by enclosing lambdas, must not itself
-- be a variable, and must either sit under a lambda or occur more than once.
liftable :: BindingDomain sym => Env sym -> ASTF sym a -> Bool
liftable env a
    | dependsOnEnv      = False
    | Just _ <- prVar a = False
    | otherwise         = inLambda env || counter env (EF a) > 1
  where
    dependsOnEnv = not $ Set.null $ Set.intersection (freeVars a) (dependencies env)
-- | A sub-expression selected for sharing, packaged with the 'InjDict' needed
-- to let-bind it inside an expression of type @a@. The type of the
-- sub-expression is existentially hidden.
data Chosen sym a
  where
    Chosen :: InjDict sym b a -> ASTF sym b -> Chosen sym a
-- | Search an expression for a sub-expression to share or hoist.
--
-- Candidates are examined in the arguments of the expression, not at the
-- root. A candidate is chosen when it is 'liftable' and 'mkInjDict' accepts
-- it. Otherwise the search descends into it only if 'hoistOver' allows.
-- Passing a lambda marks the body as being under a lambda and records the
-- bound variable as a dependency.
choose :: forall sym a
    .  (Equality sym, BindingDomain sym)
    => CodeMotionInterface sym
    -> ASTF sym a
    -> Maybe (Chosen sym a)
choose iface a = searchArgs topEnv a
  where
    topEnv = Env
      { inLambda     = False
      , counter      = \(EF e) -> count e a
      , dependencies = Set.empty
      }

    -- Try the expression itself, then (if allowed) its arguments.
    search :: Env sym -> ASTF sym b -> Maybe (Chosen sym a)
    search env b
      | liftable env b
      , Just dict <- mkInjDict iface b a = Just (Chosen dict b)
      | hoistOver iface b                = searchArgs env b
      | otherwise                        = Nothing

    -- Search the arguments of an application spine, left to right.
    searchArgs :: Env sym -> AST sym b -> Maybe (Chosen sym a)
    searchArgs env (Sym s :$ body)
      | Just v <- prLam s = search (enterLambda v env) body
    searchArgs env (f :$ x) = searchArgs env f `mplus` search env x
    searchArgs _ _          = Nothing

    enterLambda :: Name -> Env sym -> Env sym
    enterLambda v env = env
      { inLambda     = True
      , dependencies = Set.insert v (dependencies env)
      }
-- | Monadic worker for 'codeMotion'; the state is the next unused variable
-- name.
--
-- If 'choose' finds a sub-expression to share, it is transformed recursively.
-- It is then bound to a fresh variable with a let, and its occurrences in the
-- expression are replaced by that variable before the remainder is processed.
-- Otherwise the transformation recurses into the arguments of the expression.
codeMotionM :: forall sym m a
    .  ( Equality sym
       , BindingDomain sym
       , MonadState Name m
       )
    => CodeMotionInterface sym
    -> ASTF sym a
    -> m (ASTF sym a)
codeMotionM iface a = case choose iface a of
    Just (Chosen dict b) -> share dict b
    Nothing              -> descend a
  where
    share :: InjDict sym b a -> ASTF sym b -> m (ASTF sym a)
    share dict b = do
      bound <- codeMotionM iface b
      v <- get
      put (v + 1)
      let var = Sym (injVariable dict v)
      body <- codeMotionM iface (substitute iface b var a)
      let lam = Sym (injLambda dict v) :$ body
      return (Sym (injLet dict) :$ bound :$ lam)

    -- Transform every argument of an application spine.
    descend :: AST sym c -> m (AST sym c)
    descend (f :$ x) = (:$) <$> descend f <*> codeMotionM iface x
    descend leaf     = return leaf
-- | Perform common sub-expression elimination and variable hoisting on an
-- expression, as directed by the given 'CodeMotionInterface'.
--
-- Fresh variable names start just above the largest name returned by
-- 'allVars' for the expression. If that set is empty, they start at 0.
codeMotion :: forall sym m a
    .  ( Equality sym
       , BindingDomain sym
       )
    => CodeMotionInterface sym
    -> ASTF sym a
    -> ASTF sym a
codeMotion iface a = flip evalState firstFresh $ codeMotionM iface a
  where
    usedNames = allVars a
    -- 'Set.findMax' is partial. For an expression without any variables
    -- (e.g. a closed term with a repeated sub-term), the old code failed as
    -- soon as a fresh name was needed. Any name is fresh in that case.
    firstFresh
      | Set.null usedNames = 0
      | otherwise          = succ (Set.findMax usedNames)