-- | Simple code motion transformation performing common sub-expression
-- elimination and variable hoisting. Note that the implementation is very
-- inefficient.
--
-- The code is based on an implementation by Gergely Dévai.

module Language.Syntactic.Functional.Sharing
    ( -- * Interface
      InjDict (..)
    , CodeMotionInterface (..)
    , defaultInterface
    , defaultInterfaceDecor
      -- * Code motion
    , codeMotion
    ) where



import Control.Monad (mplus)
import Control.Monad.State
import Data.Maybe (isNothing)
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Typeable

import Data.Constraint (Dict (..))

import Language.Syntactic
import Language.Syntactic.Functional



--------------------------------------------------------------------------------
-- * Interface
--------------------------------------------------------------------------------

-- | Interface for injecting binding constructs
--
-- Code motion uses these three constructors to bind a shared sub-expression of
-- type @a@ inside an expression of type @b@: the variable replaces the
-- occurrences of the sub-expression, the lambda binds that variable in the
-- body, and the @let@ symbol applies the lambda to the shared sub-expression.
data InjDict sym a b = InjDict
    { injVariable :: Name -> sym (Full a)
        -- ^ Inject a variable
    , injLambda   :: Name -> sym (b :-> Full (a -> b))
        -- ^ Inject a lambda
    , injLet      :: sym (a :-> (a -> b) :-> Full b)
        -- ^ Inject a @let@ symbol
    }

-- | Code motion interface
data CodeMotionInterface sym = Interface
    { mkInjDict   :: forall a b . ASTF sym a -> ASTF sym b -> Maybe (InjDict sym a b)
        -- ^ Try to construct an 'InjDict'. The first argument is the expression
        -- to be shared, and the second argument the expression in which it will
        -- be shared. This function can be used to transfer information (e.g.
        -- from static analysis) from the shared expression to the introduced
        -- variable. Returning 'Nothing' means that the first expression will
        -- not be chosen for sharing.
    , castExprCM  :: forall a b . ASTF sym a -> ASTF sym b -> Maybe (ASTF sym b)
        -- ^ Try to type cast an expression. The first argument is the
        -- expression to cast. The second argument can be used to construct a
        -- witness to support the casting. The resulting expression (if any)
        -- should be equal to the first argument. If this function fails to
        -- cast when it should succeed, code motion may loop forever.
    , hoistOver   :: forall c . ASTF sym c -> Bool
        -- ^ Whether a sub-expression can be hoisted over the given expression.
        -- When this returns 'False', the search for shareable expressions does
        -- not descend into the given expression.
    }

-- | Default 'CodeMotionInterface' for domains of the form
-- @`Typed` (... `:+:` `Binding` `:+:` ...)@.
--
-- The 'Typeable' evidence required by the variable and lambda constructors is
-- obtained by matching on the 'Typed' wrapper at the root of both the shared
-- expression and the expression it is shared in. Casting uses 'castExpr'.
defaultInterface :: forall binding sym symT
    .  ( binding :<: sym
       , Let     :<: sym
       , symT ~ Typed sym
       )
    => (forall a .   Typeable a => Name -> binding (Full a))
         -- ^ Variable constructor (e.g. 'Var' or 'VarT')
    -> (forall a b . Typeable a => Name -> binding (b :-> Full (a -> b)))
         -- ^ Lambda constructor (e.g. 'Lam' or 'LamT')
    -> (forall a b . ASTF symT a -> ASTF symT b -> Bool)
         -- ^ Can the expression represented by the first argument be shared in
         -- the second argument?
    -> (forall a . ASTF symT a -> Bool)
         -- ^ Can we hoist over this expression?
    -> CodeMotionInterface symT
defaultInterface var lam sharable hoistOver = Interface {..}
  where
    -- Refuse immediately when the user predicate forbids sharing. Otherwise,
    -- matching on the two 'Typed' roots brings 'Typeable' for @a@ and @b@
    -- into scope, which the injected variable, lambda and let symbols need.
    mkInjDict :: ASTF symT a -> ASTF symT b -> Maybe (InjDict symT a b)
    mkInjDict a b
        | not (sharable a b) = Nothing
        | otherwise =
            simpleMatch
              (\(Typed _) _ -> simpleMatch
                (\(Typed _) _ ->
                  let injVariable = Typed . inj . var
                      injLambda   = Typed . inj . lam
                      injLet      = Typed $ inj (Let "")
                  in  Just InjDict {..}
                ) b
              ) a

    castExprCM :: ASTF symT a -> ASTF symT b -> Maybe (ASTF symT b)
    castExprCM = castExpr

-- | Default 'CodeMotionInterface' for domains of the form
-- @(... `:&:` info)@, where @info@ can be used to witness type casting
--
-- The introduced variable is decorated with the decoration of the shared
-- expression. The lambda is decorated with @mkFunInfo@ applied to the
-- decorations of the shared expression and of the expression it is shared in.
-- The @let@ gets the decoration of the expression it is shared in. Casting
-- succeeds exactly when @teq@ relates the two root decorations.
defaultInterfaceDecor :: forall binding sym symI info
    .  ( binding :<: sym
       , Let     :<: sym
       , symI ~ (sym :&: info)
       )
    => (forall a b . info a -> info b -> Maybe (Dict (a ~ b)))
         -- ^ Construct a type equality witness
    -> (forall a b . info a -> info b -> info (a -> b))
         -- ^ Construct info for a function, given info for the argument and the
         -- result
    -> (forall a . info a -> Name -> binding (Full a))
         -- ^ Variable constructor
    -> (forall a b . info a -> info b -> Name -> binding (b :-> Full (a -> b)))
         -- ^ Lambda constructor
    -> (forall a b . ASTF symI a -> ASTF symI b -> Bool)
         -- ^ Can the expression represented by the first argument be shared in
         -- the second argument?
    -> (forall a . ASTF symI a -> Bool)
         -- ^ Can we hoist over this expression?
    -> CodeMotionInterface symI
defaultInterfaceDecor teq mkFunInfo var lam sharable hoistOver = Interface {..}
  where
    -- Refuse immediately when the user predicate forbids sharing; otherwise
    -- read the root decorations of both expressions and decorate the injected
    -- symbols with them.
    mkInjDict :: ASTF symI a -> ASTF symI b -> Maybe (InjDict symI a b)
    mkInjDict a b
        | not (sharable a b) = Nothing
        | otherwise =
            simpleMatch
              (\(_ :&: aInfo) _ -> simpleMatch
                (\(_ :&: bInfo) _ ->
                  let injVariable v = inj (var aInfo v) :&: aInfo
                      injLambda   v = inj (lam aInfo bInfo v) :&: mkFunInfo aInfo bInfo
                      injLet        = inj (Let "") :&: bInfo
                  in  Just InjDict {..}
                ) b
              ) a

    -- Return the first expression unchanged when @teq@ proves that the two
    -- root decorations have the same type index.
    castExprCM :: ASTF symI a -> ASTF symI b -> Maybe (ASTF symI b)
    castExprCM a b =
        simpleMatch
          (\(_ :&: aInfo) _ -> simpleMatch
            (\(_ :&: bInfo) _ -> case teq aInfo bInfo of
              Just Dict -> Just a
              Nothing   -> Nothing
            ) b
          ) a



--------------------------------------------------------------------------------
-- * Code motion
--------------------------------------------------------------------------------

-- | Substitute a sub-expression. Assumes that the free variables of the
-- replacing expression do not occur as binders in the whole expression (so that
-- there is no risk of capturing).
--
-- A node is replaced when it is alpha-equivalent to the expression being
-- replaced and 'castExprCM' can cast the replacement to the node's type. The
-- traversal does not enter a lambda that binds a variable free in the
-- expression being replaced.
substitute :: forall sym a b
    .  (Equality sym, BindingDomain sym)
    => CodeMotionInterface sym
    -> ASTF sym a  -- ^ Sub-expression to be replaced
    -> ASTF sym a  -- ^ Replacing sub-expression
    -> ASTF sym b  -- ^ Whole expression
    -> ASTF sym b
substitute iface x y a = subst a
  where
    fv :: Set Name
    fv = freeVars x

    -- Try to replace the current node, otherwise continue into its arguments
    subst :: ASTF sym c -> ASTF sym c
    subst e
      | Just y' <- castExprCM iface y e, alphaEq x e = y'
      | otherwise = subst' e

    subst' :: AST sym c -> AST sym c
    subst' e@(lam :$ _)
      | Just v <- prLam lam
      , Set.member v fv = e
    subst' (s :$ e) = subst' s :$ subst e
    subst' e = e

  -- Note: Since `codeMotion` only uses `substitute` to replace sub-expressions
  -- with fresh variables, the assumption above is fulfilled. However, the
  -- matching in `subst` needs to be aware of free variables, which is why the
  -- substitution stops when reaching a lambda that binds a variable that is
  -- free in the expression to be replaced.

-- | Count the number of occurrences of a sub-expression
--
-- Occurrences are compared up to alpha-equivalence. The traversal does not
-- look inside a matching node, and it does not enter a lambda that binds a
-- variable free in the expression being counted.
count :: forall sym a b
    .  (Equality sym, BindingDomain sym)
    => ASTF sym a  -- ^ Expression to count
    -> ASTF sym b  -- ^ Expression to count in
    -> Int
count a b = cnt b
  where
    fv :: Set Name
    fv = freeVars a

    cnt :: ASTF sym c -> Int
    cnt c
      | alphaEq a c = 1
      | otherwise   = cnt' c

    cnt' :: AST sym sig -> Int
    cnt' (lam :$ _)
      | Just v <- prLam lam
      , Set.member v fv = 0
          -- There can be no match under a lambda that binds a variable that is
          -- free in `a`. This case needs to be handled in order to avoid false
          -- matches.
          --
          -- Consider the following expression:
          --
          --     (\x -> f x) 0 + f x
          --
          -- The sub-expression `f x` appears twice, but `x` means different
          -- things in the two cases.
    cnt' (s :$ c) = cnt' s + cnt c
    cnt' _        = 0

-- | Environment for the expression in the 'choose' function
data Env sym = Env
    { inLambda     :: Bool
        -- ^ Whether the current expression is inside a lambda
    , counter      :: EF (AST sym) -> Int
        -- ^ Counting the number of occurrences of an expression in the
        -- environment
    , dependencies :: Set Name
        -- ^ The set of variables that are not allowed to occur in the chosen
        -- expression (in 'choose': the variables bound by enclosing lambdas)
    }

-- | Checks whether a sub-expression in a given environment can be lifted out
liftable :: BindingDomain sym => Env sym -> ASTF sym a -> Bool
liftable env a = independent && isNothing (prVar a) && heuristic
      -- Lifting dependent expressions is semantically incorrect. Lifting
      -- variables would cause `codeMotion` to loop.
  where
    -- None of the expression's free variables are forbidden by the environment
    independent :: Bool
    independent = Set.null $ Set.intersection (freeVars a) (dependencies env)

    -- Lift only when it pays off: the expression sits under a lambda (hoisting)
    -- or occurs more than once (common sub-expression elimination)
    heuristic :: Bool
    heuristic = inLambda env || counter env (EF a) > 1

-- | A sub-expression selected for sharing, packaged with an 'InjDict' showing
-- that it can be bound inside the whole expression (of type @a@) under
-- consideration. The type @b@ of the selected sub-expression is hidden.
data Chosen sym a where
    Chosen
        :: InjDict sym b a  -- evidence that the sub-expression can be shared
        -> ASTF sym b       -- the selected sub-expression
        -> Chosen sym a

-- | Choose a sub-expression to share
--
-- The root expression itself is never chosen; the search starts at its
-- arguments. A candidate is chosen when it is 'liftable' and 'mkInjDict'
-- succeeds; otherwise the search continues inside it only if 'hoistOver'
-- allows it.
choose :: forall sym a
    .  (Equality sym, BindingDomain sym)
    => CodeMotionInterface sym
    -> ASTF sym a
    -> Maybe (Chosen sym a)
choose iface a = chooseEnvSub initEnv a
  where
    initEnv :: Env sym
    initEnv = Env
        { inLambda     = False
        , counter      = \(EF b) -> count b a
        , dependencies = Set.empty
        }

    chooseEnv :: Env sym -> ASTF sym b -> Maybe (Chosen sym a)
    chooseEnv env b
        | liftable env b
        , Just dict <- mkInjDict iface b a
        = Just $ Chosen dict b
        | hoistOver iface b = chooseEnvSub env b
        | otherwise         = Nothing

    -- | Like 'chooseEnv', but does not consider the top expression for sharing
    chooseEnvSub :: Env sym -> AST sym b -> Maybe (Chosen sym a)
    chooseEnvSub env (Sym lam :$ b)
        | Just v <- prLam lam
        = chooseEnv (enterLambda v env) b
    chooseEnvSub env (s :$ b) = chooseEnvSub env s `mplus` chooseEnv env b
    chooseEnvSub _ _ = Nothing

    -- Entering the body of a lambda binding @v@: anything mentioning @v@ may no
    -- longer be lifted, and we are now inside a lambda
    enterLambda :: Name -> Env sym -> Env sym
    enterLambda v env = env
        { inLambda     = True
        , dependencies = Set.insert v (dependencies env)
        }

-- If `codeMotionM` loops forever, the reason may be that `castExprCM` is
-- broken. If `castExprCM` fails to cast even when it should, `substitute` can
-- return the expression unchanged. This in turn means that `codeMotionM` will
-- loop, since it calls itself with `codeMotionM iface $ substitute iface b x a`.

-- | Perform code motion, drawing fresh variable names from the monadic state
codeMotionM :: forall sym m a
    .  ( Equality sym
       , BindingDomain sym
       , MonadState Name m
       )
    => CodeMotionInterface sym
    -> ASTF sym a
    -> m (ASTF sym a)
codeMotionM iface a
    | Just (Chosen dict b) <- choose iface a = share dict b
    | otherwise = descend a
  where
    -- Bind the chosen sub-expression @b@ to a fresh variable @v@ and replace its
    -- occurrences in @a@ by @v@, producing @let b' (\\v -> body)@
    share :: InjDict sym b a -> ASTF sym b -> m (ASTF sym a)
    share dict b = do
        b' <- codeMotionM iface b
        v  <- get
        put (v + 1)
        let x = Sym (injVariable dict v)
        body <- codeMotionM iface $ substitute iface b x a
        return
            $  Sym (injLet dict)
            :$ b'
            :$ (Sym (injLambda dict v) :$ body)

    -- Nothing to share at this level: process the arguments from left to right
    descend :: AST sym c -> m (AST sym c)
    descend (s :$ e) = (:$) <$> descend s <*> codeMotionM iface e
    descend e        = return e

-- | Perform common sub-expression elimination and variable hoisting
--
-- Fresh variable names start at one more than the largest name returned by
-- 'allVars' for the input (or at @1@ when there is none).
--
-- (The type variable @m@ is unused; it is kept so that explicit type
-- applications against this signature keep lining up.)
codeMotion :: forall sym m a
    .  ( Equality sym
       , BindingDomain sym
       )
    => CodeMotionInterface sym
    -> ASTF sym a
    -> ASTF sym a
codeMotion iface a = evalState (codeMotionM iface a) maxVar
  where
    -- Inserting 0 makes the set non-empty, so 'Set.findMax' cannot fail
    maxVar :: Name
    maxVar = succ $ Set.findMax $ Set.insert 0 $ allVars a