module DDC.Core.Transform.Bubble
( bubbleModule
, bubbleX)
where
import DDC.Core.Collect
import DDC.Core.Transform.BoundX
import DDC.Core.Module
import DDC.Core.Exp.Annot
import DDC.Type.Env (KindEnv, TypeEnv)
import qualified DDC.Type.Env as Env
import qualified DDC.Type.Sum as Sum
import qualified Data.Set as Set
import Data.Set (Set)
import Data.List
-- | Bubble casts through the body of a module.
--
--   The module's own kind and type environments are used as the outermost
--   scope when deciding where collected casts may be re-attached.
bubbleModule
        :: Ord n
        => Module a n -> Module a n
bubbleModule mm@ModuleCore{}
        = mm { moduleBody = bubbleX (moduleKindEnv mm)
                                    (moduleTypeEnv mm)
                                    (moduleBody    mm) }
-- | Bubble casts through an expression.
--
--   Casts are collected by 'bubble'. Any casts that reach the top are
--   re-attached around the whole result, using the annotation of the
--   rewritten expression.
bubbleX :: Ord n
        => KindEnv n -> TypeEnv n -> Exp a n -> Exp a n
bubbleX kenv tenv x
        = dropAllCasts kenv tenv (annotOfExp xBubbled) csTop xBubbled
        where (csTop, xBubbled) = bubble kenv tenv x
-- | Syntactic forms that casts can be bubbled out of.
class Bubble (c :: * -> * -> *) where
 -- | Remove casts from a term and return them, each paired with its free
 --   variables, together with the rewritten term.
 bubble :: Ord n
        => KindEnv n                    -- ^ Kind environment at this point.
        -> TypeEnv n                    -- ^ Type environment at this point.
        -> c a n                        -- ^ Term to rewrite.
        -> ([FvsCast a n], c a n)
-- | Bubbling for expressions.
--
--   Each 'XCast' is removed and recorded with the free variables of its cast.
--   Under 'XLAM' and 'XLam', every collected cast is re-attached just inside
--   the binder, so nothing escapes a lambda. At 'XLet', casts from the body
--   that mention a let-bound variable are re-attached around the body; the
--   remainder continue upwards.
instance Bubble Exp where
 bubble kenv tenv xx
  = case xx of
        XVar{}          -> leaf
        XCon{}          -> leaf
        XType{}         -> leaf
        XWitness{}      -> leaf

        XLAM a b xBody
         -> let kenvIn           = Env.extend b kenv
                (csBody, xBody') = bubble kenvIn tenv xBody
            in  ([], XLAM a b $ dropAllCasts kenvIn tenv a csBody xBody')

        XLam a b xBody
         -> let tenvIn           = Env.extend b tenv
                (csBody, xBody') = bubble kenv tenvIn xBody
            in  ([], XLam a b $ dropAllCasts kenv tenvIn a csBody xBody')

        XApp a xFun xArg
         -> let (csFun, xFun')   = bubble kenv tenv xFun
                (csArg, xArg')   = bubble kenv tenv xArg
            in  (csFun ++ csArg, XApp a xFun' xArg')

        XLet a lts xBody
         -> let (csLets, lts')   = bubble kenv tenv lts
                (bsT, bsX)       = bindsOfLets lts
                kenvIn           = Env.extends bsT kenv
                tenvIn           = Env.extends bsX tenv
                (csBody, xBody') = bubble kenvIn tenvIn xBody
                (csUp, xBody'')
                        = dropCasts kenvIn tenvIn a bsT bsX csBody xBody'
            in  (csLets ++ csUp, XLet a lts' xBody'')

        XCase a xScrut alts
         -> let (csScrut, xScrut') = bubble kenv tenv xScrut
                results            = map (bubble kenv tenv) alts
            in  ( csScrut ++ concatMap fst results
                , XCase a xScrut' (map snd results))

        XCast _ c xInner
         -> let (csInner, xInner') = bubble kenv tenv xInner
                fc = FvsCast c (freeT Env.empty c) (freeX Env.empty c)
            in  (fc : csInner, xInner')
  where
        -- Forms that contain no sub-expressions to bubble from.
        leaf = ([], xx)
-- | Bubbling for let-bindings.
--
--   For a non-recursive 'LLet', the right-hand side is bubbled in the
--   enclosing environment. Casts that mention the binder are re-attached to
--   the right-hand side, and the rest are returned. For 'LRec', each
--   right-hand side is bubbled with all recursive binders in scope, and every
--   collected cast is re-attached to it, so nothing is returned.
--   'LPrivate' is left untouched.
instance Bubble Lets where
 bubble kenv tenv lts
  = case lts of
        LLet b xRhs
         -- NOTE(review): the RHS is bubbled without extending the environment
         -- with b, yet dropCasts still lowers escaping casts by one index.
         -- Confirm this is intended when b is an anonymous binder.
         -> let (csRhs, xRhs')   = bubble kenv tenv xRhs
                (csUp, xRhs'')
                        = dropCasts kenv tenv (annotOfExp xRhs') [] [b]
                                    csRhs xRhs'
            in  (csUp, LLet b xRhs'')

        LRec bxs
         -> let tenvRec          = Env.extends (map fst bxs) tenv
                bubbleBinding (b, xRhs)
                 = let (csRhs, xRhs') = bubble kenv tenvRec xRhs
                   in  ( b
                       , dropAllCasts kenv tenvRec (annotOfExp xRhs')
                                      csRhs xRhs')
            in  ([], LRec (map bubbleBinding bxs))

        LPrivate{}
         -> ([], lts)
-- | Bubbling for case alternatives.
--
--   A default alternative binds nothing, so all of its casts are returned
--   unchanged. For any other pattern, casts that mention a pattern-bound
--   variable are re-attached to the alternative body, and the rest are
--   returned via 'dropCasts'.
instance Bubble Alt where
 bubble kenv tenv alt
  = case alt of
        AAlt PDefault xBody
         -> let (csBody, xBody') = bubble kenv tenv xBody
            in  (csBody, AAlt PDefault xBody')

        AAlt p xBody
         -> let bsPat            = bindsOfPat p
                tenvAlt          = Env.extends bsPat tenv
                (csBody, xBody') = bubble kenv tenvAlt xBody
                (csUp, xHere)
                        = dropCasts kenv tenvAlt (annotOfExp xBody') [] bsPat
                                    csBody xBody'
            in  (csUp, AAlt p xHere)
-- | A cast that has been lifted out of an expression, paired with the sets
--   of bound variables it mentions free.
data FvsCast a n
        = FvsCast
                (Cast a n)              -- ^ The lifted cast.
                (Set (Bound n))         -- ^ Free variables of the cast from 'freeT'.
                (Set (Bound n))         -- ^ Free variables of the cast from 'freeX'.
-- | Rewrite bound variables inside a lifted cast.
--
--   The function is applied to the cast itself and to every element of the
--   'freeX' variable set. The 'freeT' variable set is passed through
--   unchanged.
instance Ord n => MapBoundX (FvsCast a) n where
 mapBoundAtDepthX f d (FvsCast c fvs1 fvs0)
  = FvsCast (mapBoundAtDepthX f d c)
            fvs1
            (Set.map (mapBoundAtDepthX f d) fvs0)
-- | Discard the free-variable sets of lifted casts, then pack the bare casts
--   with 'packCasts'.
packFvsCasts
        :: Ord n
        => KindEnv n -> TypeEnv n
        -> a -> [FvsCast a n] -> [Cast a n]
packFvsCasts kenv tenv a
        = packCasts kenv tenv a . map castOf
        where castOf (FvsCast c _ _) = c
-- | Pack a list of casts into an equivalent, more compact list.
--
--   All 'CastWeakenEffect' casts are merged into one weakening whose effect
--   is the sum of their effects, and that weakening is placed first. Every
--   other cast follows, in its original relative order.
--
--   The kind environment, type environment and annotation are currently
--   unused.
packCasts :: Ord n
          => KindEnv n -> TypeEnv n -> a -> [Cast a n] -> [Cast a n]
packCasts _kenv _tenv _a vs
 = let  (effs, csOthers) = collect [] [] vs
   in   (if null effs
                then []
                else [CastWeakenEffect (TSum $ Sum.fromList kEffect effs)])
     ++ csOthers
 where
        -- Split casts into weakened effects and all remaining casts,
        -- preserving order within each group. The remaining casts must be
        -- returned: callers wrap the packed result back around the
        -- expression, so anything missing here is removed from the program.
        collect weakEffs others cc
         = case cc of
                []
                 -> (reverse weakEffs, reverse others)

                CastWeakenEffect eff : cs
                 -> collect (eff : weakEffs) others cs

                c : cs
                 -> collect weakEffs (c : others) cs
-- | Re-attach every given cast around an expression.
--
--   The casts are packed first. The first packed cast becomes the outermost
--   'XCast', and all of them use the supplied annotation.
dropAllCasts
        :: Ord n
        => KindEnv n
        -> TypeEnv n
        -> a
        -> [FvsCast a n] -> Exp a n
        -> Exp a n
dropAllCasts kenv tenv a cs x
        = foldr wrap x packed
        where   packed          = packFvsCasts kenv tenv a cs
                wrap c inner    = XCast a c inner
-- | Re-attach the casts that mention any of the given binders, and pass the
--   others upwards.
--
--   A cast is kept here if its 'freeT' set mentions one of the first binder
--   list, or its 'freeX' set mentions one of the second. Kept casts are
--   packed and wrapped around the expression. The remaining casts are
--   returned after 'lowerX' 1.
--
--   NOTE(review): the lowering amount is fixed at 1, regardless of how many
--   binders are supplied. Confirm this is correct for multi-binder patterns
--   and recursive groups.
dropCasts
        :: Ord n
        => KindEnv n -> TypeEnv n
        -> a
        -> [Bind n]
        -> [Bind n]
        -> [FvsCast a n]
        -> Exp a n
        -> ([FvsCast a n], Exp a n)
dropCasts kenv tenv a bs1 bs0 cs x
        = (map (lowerX 1) escaping, foldr (XCast a) x packedHere)
        where
                -- Two-stage split keeps casts matched via bs1 ahead of those
                -- matched only via bs0.
                (hereT, rest)           = partition (fvsCastUsesBinds1 bs1) cs
                (hereX, escaping)       = partition (fvsCastUsesBinds0 bs0) rest
                packedHere              = packFvsCasts kenv tenv a (hereT ++ hereX)
-- | Whether a lifted cast's 'freeX' set mentions any of the given binders.
fvsCastUsesBinds0 :: Ord n => [Bind n] -> FvsCast a n -> Bool
fvsCastUsesBinds0 bb fc
        = case fc of
                FvsCast _ _ fvsX -> bindsMatchBoundSet bb fvsX
-- | Whether a lifted cast's 'freeT' set mentions any of the given binders.
fvsCastUsesBinds1 :: Ord n => [Bind n] -> FvsCast a n -> Bool
fvsCastUsesBinds1 bb fc
        = case fc of
                FvsCast _ fvsT _ -> bindsMatchBoundSet bb fvsT
-- | Whether any binder in the list corresponds to a member of the set.
--
--   Binders for which 'takeSubstBoundOfBind' yields 'Nothing' never match.
--   Evaluation stops at the first binder that matches.
bindsMatchBoundSet :: Ord n => [Bind n] -> Set (Bound n) -> Bool
bindsMatchBoundSet bb fvs
        = any (`Set.member` fvs)
              [ u | Just u <- map takeSubstBoundOfBind bb ]