module DDC.Core.Transform.Lambdas
(lambdasModule)
where
import DDC.Core.Fragment
import DDC.Core.Collect.Support
import DDC.Core.Transform.SubstituteXX
import DDC.Core.Module
import DDC.Core.Exp.Annot.Context
import DDC.Core.Exp.Annot.Ctx
import DDC.Core.Exp.Annot
import DDC.Type.Collect
import DDC.Base.Pretty
import DDC.Base.Name
import Data.Function
import Data.List
import Data.Set (Set)
import Data.Map (Map)
import qualified DDC.Core.Check as Check
import qualified DDC.Type.Env as Env
import qualified Data.Set as Set
import qualified Data.Map as Map
import Data.Maybe
-- | Perform lambda lifting on the body of a module.
--
--   The module body is rewritten by 'lambdasLoopX' starting from a top-level
--   context, and the names of the lifted bindings are then tidied up by
--   'beautifyModule'.
lambdasModule
        :: ( Show a, Pretty a
           , Show n, Pretty n, Ord n, CompoundName n)
        => Profile n            -- ^ Language profile, passed on to the lifter.
        -> Module a n           -- ^ Module to transform.
        -> Module a n

lambdasModule profile mm
 = beautifyModule
 $ mm { moduleBody = bodyLifted }
 where
        kenv            = moduleKindEnv mm
        tenv            = moduleTypeEnv mm

        -- Abstractly imported capabilities, as an environment.
        -- Presumably this is the global capability field of the 'Context'
        -- (it is the third positional argument) -- confirm against 'Context'.
        capsImported
         = Env.fromList
                [ BName n t
                | (n, ImportCapAbstract t) <- moduleImportCaps mm ]

        -- We start the traversal at the top of the module.
        ctxTop          = CtxTop (moduleDataDefs mm) kenv tenv
        context         = Context kenv tenv capsImported ctxTop

        bodyLifted      = lambdasLoopX profile context (moduleBody mm)
-- | Result of running the lifter over some expression.
data Result a n
        = Result
                Bool                    -- ^ Whether any progress was made,
                                        --   that is, whether something was lifted.
                [(Bind n, Exp a n)]     -- ^ Bindings that have been lifted out
                                        --   and still need a home.

-- | Results are combined by or-ing the progress flags and appending
--   the lists of lifted bindings, left before right.
--
--   NOTE(review): with base >= 4.11 a 'Monoid' instance also needs a
--   'Semigroup' instance -- confirm against the GHC version being targeted.
instance Ord n => Monoid (Result a n) where
 mempty
        = Result False []

 mappend (Result progressL liftedL) (Result progressR liftedR)
        = Result (progressL || progressR)
                 (mappend liftedL liftedR)
-- | Apply 'lambdasX' to an expression repeatedly, feeding each result back in,
--   until a pass reports that it made no progress.
lambdasLoopX
        :: (Show n, Show a, Pretty n, Pretty a, CompoundName n, Ord n)
        => Profile n            -- ^ Language profile.
        -> Context a n          -- ^ Enclosing context of the expression.
        -> Exp a n              -- ^ Expression to transform.
        -> Exp a n

lambdasLoopX p c xx
 | progress     = lambdasLoopX p c xxNext
 | otherwise    = xxNext
 where
        -- Lifted bindings in the result are ignored here,
        -- only the progress flag decides whether to go around again.
        (xxNext, Result progress _)
                = lambdasX p c xx
-- | Make a single lifting pass over an expression.
--
--   Returns the rewritten expression together with a 'Result' holding the
--   progress flag and any bindings that were lifted out of it.
lambdasX
        :: (Show n, Show a, Pretty n, Pretty a, CompoundName n, Ord n)
        => Profile n            -- ^ Language profile.
        -> Context a n          -- ^ Enclosing context.
        -> Exp a n              -- ^ Expression to perform lambda lifting on.
        -> ( Exp a n
           , Result a n)        -- ^ Rewritten expression and lifting result.

lambdasX p c xx
 = case xx of
        XVar{}          -> (xx, mempty)
        XCon{}          -> (xx, mempty)

        -- Type abstraction.
        XLAM a b x0
         -> enterLAM c a b x0 $ \c' x
         -> let (xInner, r)     = lambdasX p c' x
            in  liftOrKeep a [(True, b)]  xInner (XLAM a b xInner) r

        -- Value abstraction.
        XLam a b x0
         -> enterLam c a b x0 $ \c' x
         -> let (xInner, r)     = lambdasX p c' x
            in  liftOrKeep a [(False, b)] xInner (XLam a b xInner) r

        -- Boxed computation.
        -- This is treated like an abstraction with no parameters, where the
        -- body handed to 'liftLambda' is the whole boxed expression.
        XCast a cc@CastBox x0
         -> enterCastBody c a cc x0 $ \c' x
         -> let (xInner, r)     = lambdasX p c' x
                xBoxed          = XCast a CastBox xInner
            in  liftOrKeep a [] xBoxed xBoxed r

        XApp a x1 x2
         -> let (x1', r1)       = enterAppLeft  c a x1 x2 (lambdasX p)
                (x2', r2)       = enterAppRight c a x1 x2 (lambdasX p)
            in  ( XApp a x1' x2'
                , mappend r1 r2)

        XLet a lts x
         -> let (lts', r1)      = lambdasLets  p c a x lts
                (x',   r2)      = enterLetBody c a lts x (lambdasX p)
            in  ( foldr (XLet a) x' lts'
                , mappend r1 r2)

        XCase a x alts
         -> let (x',    r1)     = enterCaseScrut c a x alts (lambdasX p)
                (alts', r2)     = lambdasAlts p c a x [] alts
            in  ( XCase a x' alts'
                , mappend r1 r2)

        -- The boxing cast was handled above,
        -- so this deals with the remaining ones.
        XCast a cc x
         -> lambdasCast p c a cc x

        XType{}         -> (xx, mempty)
        XWitness{}      -> (xx, mempty)

 where
        -- Decide whether to lift an abstraction whose body has already been
        -- traversed, and do so if wanted.
        --
        -- We lift only when the context of the abstraction is a lifty one,
        -- and nothing was lifted out of its body during this pass. If we do
        -- not lift, the rebuilt expression and the result of traversing the
        -- body are passed through unchanged.
        --
        --   a       annotation to use for the new nodes.
        --   lams    parameters of the abstraction, with their flags.
        --   xBody   body expression handed to 'liftLambda'.
        --   xWhole  the rebuilt abstraction itself.
        --   r       result of traversing the body.
        liftOrKeep a lams xBody xWhole r
         | isLiftyContext (contextCtx c) && null bxsInner
         = let  fusFree = supportEnvFlags
                        $ support Env.empty Env.empty xWhole

                (xCall, bLifted, xLifted)
                        = liftLambda p c fusFree a lams xBody

           in   ( xCall
                , Result True (bxsInner ++ [(bLifted, xLifted)]))

         | otherwise
         = (xWhole, r)

         where  -- Matched lazily so the result is only inspected
                -- once we know the context is a lifty one.
                Result _ bxsInner = r
-- | Perform a lifting pass over some let-bindings.
--
--   A single group of bindings may come back as several, because a recursive
--   group whose members are all lifted is replaced by one non-recursive
--   binding per member.
lambdasLets
        :: (Show a, Show n, Ord n, Pretty n, Pretty a, CompoundName n)
        => Profile n -> Context a n
        -> a                    -- ^ Annotation of the enclosing let-expression.
        -> Exp a n              -- ^ Body of the enclosing let-expression.
        -> Lets a n             -- ^ Bindings to transform.
        -> ([Lets a n], Result a n)

lambdasLets p c a xBody lts
 = case lts of
        LLet b x
         -> let (x', r) = enterLetLLet c a b x xBody (lambdasX p)
            in  ([LLet b x'], r)

        LRec bxs
         -- In a lifty context, when every binding in the group is an
         -- abstraction, lift the whole group in one go.
         | isLiftyContext (contextCtx c)
         , all (isJust . takeXLamFlags . snd) bxs
         -> let (bxsCalls, r)   = lambdasLetRecLiftAll p c a bxs
            in  ( [ LLet b xCall | (b, xCall) <- bxsCalls ]
                , r)

         -- Otherwise keep the group and just traverse its members.
         | otherwise
         -> let (bxs', r)       = lambdasLetRec p c a [] bxs xBody
            in  ([LRec bxs'], r)

        LPrivate{}
         -> ([lts], mempty)
-- | Perform a lifting pass over the members of a recursive binding group,
--   working through them from left to right.
--
--   When the group sits directly at top level, the bindings lifted out of a
--   member are added to the group just before that member rather than being
--   returned in the 'Result'.
lambdasLetRec
        :: (Show a, Show n, Ord n, Pretty n, Pretty a, CompoundName n)
        => Profile n -> Context a n
        -> a                    -- ^ Annotation of the enclosing let-expression.
        -> [(Bind n, Exp a n)]  -- ^ Members already processed, latest first.
        -> [(Bind n, Exp a n)]  -- ^ Members still to process.
        -> Exp a n              -- ^ Body of the enclosing let-expression.
        -> ([(Bind n, Exp a n)], Result a n)

lambdasLetRec _ _ _ _ [] _
 = ([], mempty)

lambdasLetRec p c a bxsAcc ((b, x) : bxsMore) xBody
 = case contextCtx c of
        -- At top level the lifted bindings are dropped right here,
        -- so only the progress flag of this member is passed up.
        CtxTop{}
         -> let Result progressHere bxsLifted   = rHere
                Result progressRest bxsPending  = rRest
            in  ( bxsLifted ++ (bxHere : bxsRest)
                , Result (progressHere || progressRest) bxsPending)

        -- Anywhere else the lifted bindings keep floating upwards.
        _
         ->     ( bxHere : bxsRest
                , mappend rHere rRest)
 where
        -- Traverse the right hand side of the current member.
        (x', rHere)
                = enterLetLRec c a bxsAcc b x bxsMore xBody (lambdasX p)

        bxHere  = (b, x')

        -- Then deal with the remaining members.
        (bxsRest, rRest)
                = lambdasLetRec p c a (bxHere : bxsAcc) bxsMore xBody
-- | Lift every member of a recursive binding group at once.
--
--   Each original binder is paired with a call to its lifted version, and the
--   lifted bindings themselves are returned in the 'Result', which always
--   reports progress.
--
--   Every right hand side must be an abstraction that 'takeXLamFlags' can
--   split. The caller visible here ('lambdasLets') checks this before
--   calling; otherwise the pattern match in @liftEach@ fails.
lambdasLetRecLiftAll
        :: (Show a, Show n, Ord n, Pretty n, Pretty a, CompoundName n)
        => Profile n -> Context a n
        -> a                    -- ^ Annotation to use for the new nodes.
        -> [(Bind n, Exp a n)]  -- ^ Members of the recursive group.
        -> ([(Bind n, Exp a n)], Result a n)

lambdasLetRecLiftAll p c a bxs
 = (bxsCalls, Result True bxsSupers)
 where
        -- Free variables of all the right hand sides taken together.
        -- The same set is given to every lifted member.
        fusAll
         = Set.unions
                [ supportEnvFlags (support Env.empty Env.empty x)
                | (_, x) <- bxs ]

        -- The members of the group itself are not treated as free variables.
        fusFree         = Set.filter (not . isMember . snd) fusAll
        isMember u      = any (\(b, _) -> boundMatchesBind u b) bxs

        -- Lift each member in the context it had inside the group, pairing
        -- the original binder with what 'liftLambda' produced for it.
        liftEach _ []
         = []

        liftEach before ((b, x) : after)
         = let  Just (lams, xBody)
                        = takeXLamFlags x

                -- Enter the binding only to get hold of its context,
                -- the right hand side is reused as a stand-in for the let body.
                cInner  = enterLetLRec c a before b x after x
                                (\cEntered _ -> cEntered)

                lifted  = liftLambda p cInner fusFree a lams xBody

           in   (b, lifted) : liftEach (before ++ [(b, x)]) after

        lifteds         = liftEach [] bxs

        -- Original binders along with the calls that replace their definitions.
        bxsCalls
         =      [ (b, xCall)
                | (b, (xCall, _, _)) <- lifteds ]

        -- The lifted bindings, where references to the original binders
        -- have been replaced by the corresponding calls.
        bxsSupers
         =      [ (bLifted, callify xLifted)
                | (_, (_, bLifted, xLifted)) <- lifteds ]

        -- Substitute the calls underneath any leading abstractions.
        callify x
         = case takeXLamFlags x of
                Just (lams, xBody)
                  -> makeXLamFlags a lams (substituteXXs bxsCalls xBody)
                Nothing
                  -> substituteXXs bxsCalls x
-- | Perform a lifting pass over the right hand sides of some case
--   alternatives, working through them from left to right.
lambdasAlts
        :: (Show a, Show n, Ord n, Pretty n, Pretty a, CompoundName n)
        => Profile n -> Context a n
        -> a                    -- ^ Annotation of the enclosing case-expression.
        -> Exp a n              -- ^ Scrutinee of the enclosing case-expression.
        -> [Alt a n]            -- ^ Alternatives already processed, latest first.
        -> [Alt a n]            -- ^ Alternatives still to process.
        -> ([Alt a n], Result a n)

lambdasAlts _ _ _ _ _ []
 = ([], mempty)

lambdasAlts p c a xScrut altsAcc (AAlt w x : altsMore)
 = ( altHere : altsRest
   , mappend rHere rRest)
 where
        -- Traverse the right hand side of the current alternative.
        (x', rHere)
                = enterCaseAlt c a xScrut altsAcc w x altsMore (lambdasX p)

        altHere = AAlt w x'

        -- Then deal with the remaining alternatives.
        (altsRest, rRest)
                = lambdasAlts p c a xScrut (altHere : altsAcc) altsMore
-- | Perform a lifting pass over the body of a cast.
--
--   Every form of cast is treated the same way: the body is traversed and
--   the cast is put back around the result.
lambdasCast
        :: (Show a, Show n, Ord n, Pretty n, Pretty a, CompoundName n)
        => Profile n -> Context a n
        -> a                    -- ^ Annotation of the cast expression.
        -> Cast a n             -- ^ The cast being applied.
        -> Exp a n              -- ^ Body of the cast.
        -> (Exp a n, Result a n)

lambdasCast p c a cc x
 = case cc of
        -- The alternatives are spelled out so that a new form of cast
        -- does not silently end up with this treatment.
        CastWeakenEffect{}      -> descend
        CastPurify{}            -> descend
        CastBox                 -> descend
        CastRun                 -> descend
 where
        descend
         = let  (x', r) = enterCastBody c a cc x (lambdasX p)
           in   (XCast a cc x', r)
-- | Check whether an abstraction appearing directly in this context
--   should be lifted out.
isLiftyContext :: Ctx a n -> Bool

-- Never lift straight out of the top of the module.
isLiftyContext CtxTop{}                 = False

-- For the right hand side of a let-binding it depends on
-- whether that binding is a top-level one.
isLiftyContext ctx@CtxLetLLet{}         = not (isTopLetCtx ctx)
isLiftyContext ctx@CtxLetLRec{}         = not (isTopLetCtx ctx)

-- Directly under another abstraction we hold off,
-- the enclosing abstraction gets the chance to be lifted instead.
isLiftyContext CtxLAM{}                 = False
isLiftyContext CtxLam{}                 = False

-- Abstractions in all these positions are lifted.
isLiftyContext CtxAppLeft{}             = True
isLiftyContext CtxAppRight{}            = True
isLiftyContext CtxLetBody{}             = True
isLiftyContext CtxCaseScrut{}           = True
isLiftyContext CtxCaseAlt{}             = True
isLiftyContext CtxCastBody{}            = True
-- | Construct the lifted version of an abstraction.
--
--   The free variables of the abstraction become extra parameters of the
--   lifted binding, type variables first, and the abstraction itself is
--   replaced by a call that passes those variables back in.
--
--   Returns the replacement call, the binder of the lifted binding, and the
--   lifted binding's right hand side.
--
--   Calls 'error' when:
--
--   * no top-level name can be taken from the context,
--
--   * the lifted expression does not type check,
--
--   * the kind or type of a free variable is not in the context, or
--
--   * a free variable that must become a parameter is not a named one.
liftLambda
        :: (Show a, Show n, Pretty n, Ord n, CompoundName n, Pretty a)
        => Profile n            -- ^ Language profile, used to configure the checker.
        -> Context a n          -- ^ Context of the abstraction being lifted.
        -> Set (Bool, Bound n)  -- ^ Free variables of the abstraction.
                                --   The flag is True for type variables
                                --   and False for value variables.
        -> a                    -- ^ Annotation to use for the new nodes.
        -> [(Bool, Bind n)]     -- ^ Parameters of the abstraction, flagged
                                --   True for type and False for value binders.
        -> Exp a n              -- ^ Body of the abstraction.
        -> ( Exp a n
           , Bind n, Exp a n)

liftLambda p c fusFree a lams xBody
 = let  ctx             = contextCtx c
        kenv            = contextKindEnv c
        tenv            = contextTypeEnv c

        -- The abstraction we are lifting, put back together.
        xLambda         = makeXLamFlags a lams xBody

        -- The lifted binding is named after the top-level name of the
        -- context, extended with an encoding of the context itself.
        -- Fail with a proper message when there is no such name, rather
        -- than with an anonymous pattern match failure.
        nTop
         = case takeTopNameOfCtx ctx of
                Just n  -> n
                Nothing -> error
                        $  "ddc-core-simpl.liftLambda: "
                        ++ "cannot take top-level name of context."

        nsSuper         = takeTopLetEnvNamesOfCtx ctx
        nLifted         = extendName nTop ("Lift_" ++ encodeCtx ctx)
        uLifted         = UName nLifted

        -- Checker configuration, extended with the data type definitions
        -- from the top of the context and its global capabilities.
        (defs, _, _)    = topOfCtx ctx
        config          = Check.configOfProfile p
        config'         = config
                        { Check.configDataDefs
                                = mappend defs (Check.configDataDefs config)
                        , Check.configGlobalCaps
                                = contextGlobalCaps c }

        -- Get the type of an expression in the environment of the context.
        typeOfExp x
         = case Check.typeOfExp config' kenv tenv x of
                Left err
                 -> error $ renderIndent $ vcat
                        [ text "ddc-core-simpl.liftLambda: type error in lifted expression"
                        , ppr err]
                Right t -> t

        -- Decide whether a free variable needs to become a parameter.
        -- Named value variables are kept only when they are not in the set
        -- of top-level let names of the context; primitives and variables
        -- bound by the abstraction's own parameters are dropped.
        keepVar fu@(_, u)
         | (False, UName n) <- fu       = not $ Set.member n nsSuper
         | (_,     UPrim{}) <- fu       = False
         | any (boundMatchesBind u . snd) lams
                                        = False
         | otherwise                    = True

        fusFree_filtered
                = filter keepVar
                $ Set.toList fusFree

        -- Attach the kind of a type variable, or the type of a value variable.
        joinType (f, u)
         = case f of
                True    | Just t <- Env.lookup u kenv
                        -> ((f, u), t)

                False   | Just t <- Env.lookup u tenv
                        -> ((f, u), t)

                _       -> error $ unlines
                        [ "ddc-core-simpl.joinType: cannot find type of free var."
                        , show (f, u) ]

        futsFree_types
                = map joinType fusFree_filtered

        -- The type of a free value variable can mention type variables,
        -- and those need to be passed along as well.
        expandFree ((f, u), t)
         | False <- f   = [(f, u)]
                       ++ [(True, ut) | ut <- Set.toList
                                           $ freeVarsT Env.empty t]
         | otherwise    = [(f, u)]

        -- Type variables that are free in the type of the abstraction itself.
        fusFree_body    = [(True, ut) | ut <- Set.toList
                                            $ freeVarsT Env.empty
                                            $ typeOfExp xLambda]

        -- Going through a set removes duplicates.
        futsFree_expandFree
                = map joinType
                $ Set.toList $ Set.fromList
                $  (concatMap expandFree $ futsFree_types)
                ++ fusFree_body

        -- Sort so that type variables come before value variables.
        futsFree
                = sortBy (compare `on` (not . fst . fst))
                $ futsFree_expandFree

        -- At the original site, call the lifted binding passing the
        -- free variables as arguments.
        makeArg (True,  u)      = XType a (TVar u)
        makeArg (False, u)      = XVar  a u

        xCall   = xApps a (XVar a uLifted)
                $ map makeArg $ map fst futsFree

        -- Free variables become the leading parameters of the lifted binding.
        -- Only named variables can be turned into parameters.
        makeBind ((True,  (UName n)), t)        = (True,  BName n t)
        makeBind ((False, (UName n)), t)        = (False, BName n t)
        makeBind fut
                = error $ "ddc-core-simpl.liftLambda: unhandled binder " ++ show fut

        bsParam = map makeBind futsFree

        xLifted = makeXLamFlags a bsParam xLambda
        tLifted = typeOfExp xLifted
        bLifted = BName nLifted tLifted

   in   ( xCall
        , bLifted, xLifted)
-- | Give the lifted bindings of a module shorter names.
--
--   Binders whose name splits into a base name and a string beginning with
--   @Lift_@ are renamed to the base name extended with @L0@, @L1@ and so on,
--   counting separately for each base name within a binding group.
--
--   Only the chain of let-expressions at the head of the module body is
--   processed.
--
--   NOTE(review): the renaming is applied to the right hand sides of the
--   group being processed, but not to the body of the let-expression.
--   This looks fine as long as lifted bindings are only referred to from
--   within their own group -- confirm.
beautifyModule
        :: forall a n. (Ord n, Show n, CompoundName n)
        => Module a n -> Module a n

beautifyModule mm
 = mm { moduleBody = beautifyX $ moduleBody mm }
 where
  -- Decide on the new name for a binder, if it is a lifted one.
  -- The accumulator holds the last index used for each base name.
  -- Produces the old name along with the new name and the binder's type.
  makeRenamer
        :: Map n Int -> Bind n
        -> (Map n Int, Maybe (n, (n, Type n)))

  makeRenamer acc b
   | BName n t          <- b
   , Just (nBase, str)  <- splitName n
   , "Lift_" `isPrefixOf` str
   = let -- Start from zero the first time we see this base name.
         ix     = maybe 0 (+ 1) (Map.lookup nBase acc)
     in  ( Map.insert nBase ix acc
         , Just ( extendName nBase str
                , (extendName nBase ("L" ++ show ix), t)))

   | otherwise
   = (acc, Nothing)

  -- Rename the lifted binders of a group, and rewrite the references
  -- to them in the right hand sides of that same group.
  beautifyBXs a bxs
   = [ (b', substituteXXs bxsSubsts x)
     | (b', x) <- map renameBind bxs ]
   where
        bsRenames :: [(n, (n, Type n))]
        bsRenames
                = catMaybes $ snd
                $ mapAccumL makeRenamer Map.empty
                $ map fst bxs

        -- Replace each old name by a variable with the new name.
        bxsSubsts :: [(Bind n, Exp a n)]
        bxsSubsts
                = [ (BName nOld t, XVar a (UName nNew))
                  | (nOld, (nNew, t)) <- bsRenames ]

        renameBind (b, x)
         | BName nOld t   <- b
         , Just (nNew, _) <- lookup nOld bsRenames
         = (BName nNew t, x)

         | otherwise
         = (b, x)

  -- Walk down the let-expressions at the head of the expression.
  beautifyX xx
   = case xx of
        XLet a (LRec bxs) xBody
         -> XLet a (LRec (beautifyBXs a bxs)) (beautifyX xBody)

        XLet a (LLet b x) xBody
         -> let -- We get back exactly one binding for the one passed in.
                [(b', x')] = beautifyBXs a [(b, x)]
            in  XLet a (LLet b' x') (beautifyX xBody)

        _ -> xx