{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module GHC.Stg.Lift.Monad (
decomposeStgBinding, mkStgBinding,
Env (..),
FloatLang (..), collectFloats,
LiftM, runLiftM,
startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
substOcc, isLifted, formerFreeVars, liftedIdsExpander
) where
#include "HsVersions.h"
import GHC.Prelude
import GHC.Types.Basic
import GHC.Types.CostCentre ( isCurrentCCS, dontCareCCS )
import GHC.Driver.Session
import GHC.Data.FastString
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable
import GHC.Data.OrdList
import GHC.Stg.Subst
import GHC.Stg.Syntax
import GHC.Core.Utils
import GHC.Types.Unique.Supply
import GHC.Utils.Misc
import GHC.Utils.Panic
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Core.Multiplicity
import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )
-- | Split a binding into its recursivity flag and its binder/RHS pairs.
-- A non-recursive binding is reported as a singleton list.
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding binding = case binding of
  StgRec pairs       -> (Recursive, pairs)
  StgNonRec bndr rhs -> (NonRecursive, [(bndr, rhs)])
-- | Inverse of 'decomposeStgBinding': rebuild a binding from its 'RecFlag'
-- and its binder/RHS pairs.
--
-- A 'NonRecursive' binding must be given exactly one pair. Anything else is
-- an internal invariant violation. It is reported with a descriptive panic
-- rather than the uninformative failure of a bare 'head' (on an empty list),
-- or the silent dropping of extra bindings (on a longer list).
mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive    pairs = StgRec pairs
mkStgBinding NonRecursive pairs = case pairs of
  [(bndr, rhs)] -> StgNonRec bndr rhs
  _ -> pprPanic "mkStgBinding"
         (text "NonRecursive binding needs exactly one pair, got" <+> int (length pairs))
-- | Reader context of 'LiftM'. It is consulted with 'RWS.asks' and only ever
-- extended locally (via 'RWS.local') for the extent of a continuation.
data Env
  = Env
  { e_dflags     :: !DynFlags
    -- ^ Compiler flags, exposed through 'getDynFlags'; never updated here.
  , e_subst      :: !Subst
    -- ^ Renaming of the binders brought into scope by 'withSubstBndr' and
    -- 'withLiftedBndr'; consulted by 'substOcc'.
  , e_expansions :: !(IdEnv DIdSet)
    -- ^ For every lifted binder, the set of variables it now abstracts over.
    -- Membership in this map is what 'isLifted' reports.
  }
-- | The initial environment: the given flags, no renamings, no lifted binders.
emptyEnv :: DynFlags -> Env
emptyEnv dflags = Env
  { e_dflags     = dflags
  , e_subst      = emptySubst
  , e_expansions = emptyVarEnv
  }
-- | The writer output of 'LiftM': a flat stream of top-level bindings,
-- interspersed with markers that delimit recursive groups. 'collectFloats'
-- turns this stream back into top-level bindings.
data FloatLang
  -- | Opens a group whose lifted bindings are merged into one 'StgRec'.
  = StartBindingGroup
  -- | Closes the most recently opened group.
  | EndBindingGroup
  -- | A top-level binding passed through unchanged (only legal outside groups).
  | PlainTopBinding OutStgTopBinding
  -- | A binding lifted to top level.
  | LiftedBinding OutStgBinding
-- | Compact debug rendering:
--
-- * group markers print as parentheses;
-- * string literals print as @<str>@;
-- * bindings print as @r@ (recursive) or @n@ (non-recursive), followed by
--   their binders.
instance Outputable FloatLang where
  ppr fl = case fl of
    StartBindingGroup                -> char '('
    EndBindingGroup                  -> char ')'
    PlainTopBinding StgTopStringLit{} -> text "<str>"
    PlainTopBinding (StgTopLifted b) -> ppr (LiftedBinding b)
    LiftedBinding bind ->
      let (flag, pairs) = decomposeStgBinding bind
          marker        = if isRec flag then char 'r' else char 'n'
      in marker <+> ppr (map fst pairs)
-- | Assemble the 'FloatLang' stream emitted by 'LiftM' into top-level
-- bindings.
--
-- Outside of any group:
--
-- * 'PlainTopBinding's pass through untouched;
-- * each 'LiftedBinding' becomes its own 'StgTopLifted' binding, after
--   'removeRhsCCCS' is applied to its right-hand sides.
--
-- Between a 'StartBindingGroup' and its matching 'EndBindingGroup' (groups
-- may nest), lifted bindings are collected. When the outermost group closes,
-- they are merged into a single 'StgRec'.
--
-- Panics on:
--
-- * a group left unterminated;
-- * an 'EndBindingGroup' with no open group;
-- * a 'PlainTopBinding' inside a group.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = loop (0 :: Int) []
  where
    -- @depth@ counts open groups. @pending@ holds the lifted bindings of the
    -- open outermost group, most recent first.
    loop depth pending floats = case floats of
      []
        | depth == 0 && null pending -> []
        | otherwise -> pprPanic "collectFloats" (text "unterminated group")
      fl : rest -> case fl of
        StartBindingGroup -> loop (depth + 1) pending rest
        EndBindingGroup
          | depth == 0 -> pprPanic "collectFloats" (text "no group to end")
          | depth == 1 -> StgTopLifted (mergeGroup pending) : loop 0 [] rest
          | otherwise  -> loop (depth - 1) pending rest
        PlainTopBinding top_bind
          | depth /= 0 -> pprPanic "collectFloats" (text "plain top binding inside group")
          | otherwise  -> top_bind : loop depth pending rest
        LiftedBinding bind
          | depth /= 0 -> loop depth (bind : pending) rest
          | otherwise  -> StgTopLifted (stripCccs bind) : loop depth pending rest

    -- Rewrite every right-hand side of a binding, keeping its flag and binders.
    overRhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding

    -- Apply 'removeRhsCCCS' to all right-hand sides of a binding.
    stripCccs = overRhss removeRhsCCCS

    -- Flatten a finished group into one recursive binding.
    mergeGroup grp = ASSERT( any isRecBinding grp )
                     StgRec (concatMap (snd . decomposeStgBinding . stripCccs) grp)

    isRecBinding bind = case bind of
      StgRec{}    -> True
      StgNonRec{} -> False
-- | Replace a "current" cost-centre stack ('isCurrentCCS') on a closure or
-- constructor right-hand side with 'dontCareCCS'. Every other RHS is returned
-- unchanged.
--
-- NOTE(review): presumably needed because a binding floated to top level
-- has no enclosing cost-centre stack to refer to; confirm.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS rhs = case rhs of
  StgRhsClosure ext ccs upd bndrs body
    | isCurrentCCS ccs -> StgRhsClosure ext dontCareCCS upd bndrs body
  StgRhsCon ccs con mu ts args
    | isCurrentCCS ccs -> StgRhsCon dontCareCCS con mu ts args
  _ -> rhs
-- | The lambda-lifting monad, a thin wrapper around 'RWST' with:
--
-- * an 'Env' as the reader context;
-- * an 'OrdList' of 'FloatLang' as the writer output (the emitted bindings);
-- * no state (the state component is @()@);
-- * 'UniqSM' as the base monad, providing fresh uniques (see 'withLiftedBndr').
newtype LiftM a = LiftM
  { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a
  }
  deriving (Functor, Applicative, Monad)
-- | The flags stored in the reader environment.
instance HasDynFlags LiftM where
  getDynFlags = LiftM $ RWS.asks e_dflags
-- | Unique generation is delegated to the underlying 'UniqSM'.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM $ lift getUniqueSupplyM
  getUniqueM       = LiftM $ lift getUniqueM
  getUniquesM      = LiftM $ lift getUniquesM
-- | Run a 'LiftM' computation.
--
-- The computation starts from 'emptyEnv' and draws uniques from the given
-- supply. The 'FloatLang' stream it emits is assembled into top-level
-- bindings by 'collectFloats'.
runLiftM :: DynFlags -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM dflags us action =
  let (_, _, floats) = initUs_ us (runRWST (unwrapLiftM action) (emptyEnv dflags) ())
  in collectFloats (fromOL floats)
-- | Emit a top-level string literal binding as a 'PlainTopBinding'.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit bndr lit =
  LiftM (RWS.tell (unitOL (PlainTopBinding (StgTopStringLit bndr lit))))
-- | Emit a 'StartBindingGroup' marker. Lifted bindings emitted before the
-- matching 'endBindingGroup' are merged into one 'StgRec' by 'collectFloats'.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM (RWS.tell (unitOL StartBindingGroup))
-- | Emit an 'EndBindingGroup' marker, closing the group most recently opened
-- by 'startBindingGroup'.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM (RWS.tell (unitOL EndBindingGroup))
-- | Emit a binding lifted to top level.
--
-- Inside a binding group it joins that group's 'StgRec'. Otherwise it
-- becomes a top-level binding on its own (see 'collectFloats').
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding bind = LiftM (RWS.tell (unitOL (LiftedBinding bind)))
-- | Bring a binder into scope and run a continuation with its substituted
-- version.
--
-- The binder is passed through 'substBndr'. The continuation sees the
-- extended substitution; afterwards the previous substitution is restored
-- (this is 'RWS.local').
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  ~(bndr', subst') <- RWS.asks (substBndr bndr . e_subst)
  let withNewSubst env = env { e_subst = subst' }
  RWS.local withNewSubst (unwrapLiftM (inner bndr'))
-- | 'withSubstBndr' for a whole traversable collection of binders.
--
-- Binders are brought into scope in traversal order, and the continuation
-- receives all substituted binders at once.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs bndrs =
  runContT (traverse (\bndr -> ContT (withSubstBndr bndr)) bndrs)
-- | Create a fresh top-level binder for a lifted binding and run a
-- continuation with it in scope.
--
-- The new binder:
--
-- * is named @$l@ followed by the original occurrence name;
-- * gets a fresh unique;
-- * has a type that abstracts over the variables in @abs_ids@.
--
-- Within the continuation, the original binder is substituted by the new
-- one, and is recorded in 'e_expansions' as abstracting over @abs_ids@.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  let abs_vars    = dVarSetElems abs_ids
      lifted_ty   = mkLamTypes abs_vars (idType bndr)
      lifted_name = mkFastString ("$l" ++ occNameString (getOccName bndr))
      -- NOTE(review): transferPolyIdInfo carries IdInfo over from the
      -- original binder, adjusted for the abstracted variables (presumably
      -- arity etc.; see Note [transferPolyIdInfo] in GHC.Types.Id to confirm).
      bndr'       = transferPolyIdInfo bndr abs_vars
                      (mkSysLocal lifted_name uniq Many lifted_ty)
      enter env   = env
        { e_subst      = extendSubst bndr bndr' (extendInScope bndr' (e_subst env))
        , e_expansions = extendVarEnv (e_expansions env) bndr abs_ids
        }
  LiftM (RWS.local enter (unwrapLiftM (inner bndr')))
-- | 'withLiftedBndr' for a whole traversable collection of binders.
--
-- Every binder abstracts over the same set @abs_ids@.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids bndrs =
  runContT (traverse (\bndr -> ContT (withLiftedBndr abs_ids bndr)) bndrs)
-- | Look up an occurrence in the current substitution (see 'withSubstBndr'
-- and 'withLiftedBndr').
substOcc :: Id -> LiftM Id
substOcc occ = LiftM $ RWS.asks $ \env -> lookupIdSubst occ (e_subst env)
-- | Whether the binder has been lifted, i.e. has an entry in 'e_expansions'.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM $ RWS.asks $ \env -> bndr `elemVarEnv` e_expansions env
-- | The variables a lifted binder abstracts over.
--
-- Returns the empty list for a binder that was not lifted.
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM $ RWS.asks $ \env ->
  maybe [] dVarSetElems (lookupVarEnv (e_expansions env) f)
-- | Snapshot the current environment as a free-variable expander.
--
-- For each variable in the input set:
--
-- * a lifted binder is replaced by the variables it abstracts over;
-- * any other variable is renamed through the current substitution.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  env <- RWS.ask
  let expansions = e_expansions env
      subst      = e_subst env
      -- Not lifted: keep the variable, renamed via the non-warning lookup.
      -- NOTE(review): presumably the non-warning variant is used because
      -- variables local to an RHS are not in the substitution's in-scope set.
      expandOne acc fv =
        maybe (extendDVarSet acc (noWarnLookupIdSubst fv subst))
              (unionDVarSet acc)
              (lookupVarEnv expansions fv)
  pure (\fvs -> foldl' expandOne emptyDVarSet (dVarSetElems fvs))