{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module GHC.Stg.Lift.Monad (
decomposeStgBinding, mkStgBinding,
Env (..),
FloatLang (..), collectFloats,
LiftM, runLiftM,
getConfig,
startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
substOcc, isLifted, formerFreeVars, liftedIdsExpander
) where
import GHC.Prelude
import GHC.Types.Basic
import GHC.Types.CostCentre ( isCurrentCCS, dontCareCCS )
import GHC.Data.FastString
import GHC.Types.Id
import GHC.Types.Name
import GHC.Utils.Outputable
import GHC.Data.OrdList
import GHC.Stg.Lift.Config
import GHC.Stg.Subst
import GHC.Stg.Syntax
import GHC.Core.Utils
import GHC.Types.Unique.Supply
import GHC.Utils.Panic
import GHC.Types.Var.Env
import GHC.Types.Var.Set
import GHC.Core.Multiplicity
import Control.Arrow ( second )
import Control.Monad.Trans.Class
import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
import qualified Control.Monad.Trans.RWS.Strict as RWS
import Control.Monad.Trans.Cont ( ContT (..) )
import Data.ByteString ( ByteString )
-- | Split a binding into its recursion flag and its list of
-- (binder, right-hand side) pairs.
--
-- A 'StgRec' group yields its pairs tagged 'Recursive'; a 'StgNonRec'
-- binding yields a singleton list tagged 'NonRecursive'.
decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
decomposeStgBinding (StgRec pairs)       = (Recursive, pairs)
decomposeStgBinding (StgNonRec bndr rhs) = (NonRecursive, [(bndr, rhs)])
-- | Build a binding from a recursion flag and (binder, right-hand side)
-- pairs.
--
-- * 'Recursive' wraps all pairs in a 'StgRec'.
-- * 'NonRecursive' builds a 'StgNonRec' from the first pair; any further
--   pairs are ignored.
--
-- Panics if asked for a 'NonRecursive' binding with an empty list, since
-- there is no binder to build it from.
mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
mkStgBinding Recursive    pairs             = StgRec pairs
mkStgBinding NonRecursive ((bndr, rhs) : _) = StgNonRec bndr rhs
mkStgBinding NonRecursive []                =
  -- Explicit failure instead of the partial 'head', so the error names
  -- the function that was misused.
  panic "mkStgBinding: NonRecursive binding without any (binder, rhs) pair"
-- | Environment carried in the reader component of 'LiftM'.
data Env
  = Env
  { e_config     :: StgLiftConfig
  -- ^ Lambda-lifting configuration; returned by 'getConfig'.
  , e_subst      :: !Subst
  -- ^ Current substitution for binders. Extended by 'withSubstBndr' and
  -- 'withLiftedBndr', consulted by 'substOcc' and 'liftedIdsExpander'.
  , e_expansions :: !(IdEnv DIdSet)
  -- ^ Maps each lifted binder to the set of variables it was abstracted
  -- over. Extended by 'withLiftedBndr'; consulted by 'isLifted',
  -- 'formerFreeVars' and 'liftedIdsExpander'.
  }
-- | Initial environment: the given configuration, an empty substitution
-- and no recorded expansions.
emptyEnv :: StgLiftConfig -> Env
emptyEnv cfg = Env
  { e_config     = cfg
  , e_subst      = emptySubst
  , e_expansions = emptyVarEnv
  }
-- | Instructions written to the output of 'LiftM' and later flattened
-- into top-level bindings by 'collectFloats'.
data FloatLang
  = StartBindingGroup
  -- ^ Opens a binding group; increases the nesting depth tracked by
  -- 'collectFloats'.
  | EndBindingGroup
  -- ^ Closes a binding group; when the outermost group closes, the
  -- bindings collected inside it are merged into one recursive binding.
  | PlainTopBinding OutStgTopBinding
  -- ^ A top-level binding emitted unchanged. Only valid outside of any
  -- binding group.
  | LiftedBinding OutStgBinding
  -- ^ A binding lifted to top level. Emitted on its own outside a group,
  -- collected into the enclosing group otherwise.
-- | Compact debugging output: group markers print as parentheses, string
-- literals as @\<str\>@, and bindings as @r@ (recursive) or @n@
-- (non-recursive) followed by the list of bound binders.
instance Outputable FloatLang where
  ppr StartBindingGroup                   = char '('
  ppr EndBindingGroup                     = char ')'
  ppr (PlainTopBinding StgTopStringLit{}) = text "<str>"
  -- A plain lifted top-level binding is printed like a 'LiftedBinding'.
  ppr (PlainTopBinding (StgTopLifted b))  = ppr (LiftedBinding b)
  ppr (LiftedBinding bind) =
    (if isRec rec_flag then char 'r' else char 'n') <+> ppr (map fst pairs)
    where
      (rec_flag, pairs) = decomposeStgBinding bind
-- | Flatten a stream of 'FloatLang' instructions into top-level STG
-- bindings.
--
-- The stream is consumed left to right while tracking the nesting depth of
-- 'StartBindingGroup' \/ 'EndBindingGroup' markers:
--
-- * At depth 0, a 'PlainTopBinding' is emitted unchanged and a
--   'LiftedBinding' is emitted as its own 'StgTopLifted' binding.
-- * Inside a group (depth > 0), lifted bindings are accumulated. Nested
--   groups only adjust the depth; when the outermost group closes, all
--   accumulated bindings are merged into a single 'StgRec'.
--
-- Every emitted lifted binding has 'removeRhsCCCS' applied to its
-- right-hand sides.
--
-- Panics (via 'pprPanic') on an unterminated group, on an
-- 'EndBindingGroup' without a matching start, and on a 'PlainTopBinding'
-- inside a group. An assertion checks that each closed outermost group
-- contains at least one 'StgRec' binding.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = go (0 :: Int) []
  where
    -- go depth accumulated_group_bindings remaining_instructions
    go 0 [] [] = []
    go _ _  [] = pprPanic "collectFloats" (text "unterminated group")
    go n binds (f:rest) = case f of
      StartBindingGroup -> go (n+1) binds rest
      EndBindingGroup
        | n == 0    -> pprPanic "collectFloats" (text "no group to end")
        -- Closing the outermost group: flush everything collected so far.
        | n == 1    -> StgTopLifted (merge_binds binds) : go 0 [] rest
        | otherwise -> go (n-1) binds rest
      PlainTopBinding top_bind
        | n == 0    -> top_bind : go n binds rest
        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
      LiftedBinding bind
        | n == 0    -> StgTopLifted (rm_cccs bind) : go n binds rest
        -- Inside a group: collect (in reverse order of occurrence).
        | otherwise -> go n (bind:binds) rest

    -- Apply a function to every right-hand side of a binding.
    map_rhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding

    -- Apply 'removeRhsCCCS' to every right-hand side of a binding.
    rm_cccs = map_rhss removeRhsCCCS

    -- Merge all bindings of a group into one recursive binding.
    merge_binds binds = assert (any is_rec binds) $
      StgRec (concatMap (snd . decomposeStgBinding . rm_cccs) binds)

    is_rec StgRec{} = True
    is_rec _        = False
-- | Replace the cost-centre stack of a right-hand side by 'dontCareCCS'
-- when it is the current one (as decided by 'isCurrentCCS').
--
-- Applies to both closures and constructor applications; any other
-- right-hand side is returned unchanged.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS (StgRhsClosure ext ccs upd bndrs body typ)
  | isCurrentCCS ccs
  = StgRhsClosure ext dontCareCCS upd bndrs body typ
removeRhsCCCS (StgRhsCon ccs con mu ts args typ)
  | isCurrentCCS ccs
  = StgRhsCon dontCareCCS con mu ts args typ
removeRhsCCCS rhs = rhs
-- | The lambda-lifting monad: an 'RWST' stack over 'UniqSM' with
--
-- * reader environment 'Env',
-- * writer output @'OrdList' 'FloatLang'@ (the emitted floats),
-- * trivial state @()@.
--
-- Instances are derived from the underlying transformer via
-- @GeneralizedNewtypeDeriving@.
newtype LiftM a
  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
  deriving (Functor, Applicative, Monad)
-- | Unique generation is delegated to the underlying 'UniqSM' by lifting
-- each operation through the 'RWST' layer.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM (lift getUniqueSupplyM)
  getUniqueM       = LiftM (lift getUniqueM)
  getUniquesM      = LiftM (lift getUniquesM)
-- | Run a 'LiftM' computation and return the resulting top-level bindings.
--
-- The computation starts from 'emptyEnv' for the given configuration and
-- draws uniques from the given supply. Its result and final state are
-- discarded; only the written floats are kept and flattened with
-- 'collectFloats'.
runLiftM :: StgLiftConfig -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM cfg us (LiftM m) = collectFloats (fromOL floats)
  where
    (_, _, floats) = initUs_ us (runRWST m (emptyEnv cfg) ())
-- | Retrieve the lambda-lifting configuration from the environment.
getConfig :: LiftM StgLiftConfig
getConfig = LiftM (RWS.asks e_config)
-- | Emit a top-level string literal binding ('StgTopStringLit') for the
-- given binder as a 'PlainTopBinding'.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit id = LiftM . RWS.tell . unitOL . PlainTopBinding . StgTopStringLit id
-- | Emit a 'StartBindingGroup' marker, opening a binding group in the
-- output (see 'collectFloats').
startBindingGroup :: LiftM ()
startBindingGroup = LiftM (RWS.tell (unitOL StartBindingGroup))
-- | Emit an 'EndBindingGroup' marker, closing the most recently opened
-- binding group in the output (see 'collectFloats').
endBindingGroup :: LiftM ()
endBindingGroup = LiftM (RWS.tell (unitOL EndBindingGroup))
-- | Emit a binding that was lifted to top level as a 'LiftedBinding'.
--
-- Whether it ends up as its own top-level binding or is merged into an
-- enclosing binding group is decided later by 'collectFloats'.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding = LiftM . RWS.tell . unitOL . LiftedBinding
-- | Substitute a binder and run the continuation with the substituted
-- binder.
--
-- The substitution extended by 'substBndr' is only installed (via
-- 'RWS.local') while the continuation runs; the caller's environment is
-- unaffected afterwards.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  subst <- RWS.asks e_subst
  let (bndr', subst') = substBndr bndr subst
  RWS.local (\e -> e { e_subst = subst' }) (unwrapLiftM (inner bndr'))
-- | 'withSubstBndr' for a whole traversable container of binders.
--
-- Binders are processed in traversal order, each nested inside the scope
-- of the previous ones (expressed through 'ContT'); the continuation
-- receives the container of substituted binders.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs = runContT . traverse (ContT . withSubstBndr)
-- | Introduce a fresh binder for a binding that is being lifted and run
-- the continuation with it.
--
-- Given the set of variables to abstract over (@abs_ids@) and the original
-- binder, this
--
-- 1. creates a fresh system-local 'Id' whose name is @$l@ followed by the
--    original occurrence name and whose type is the original type
--    abstracted over @abs_ids@ ('mkLamTypes');
-- 2. copies the original binder's info onto it with 'transferPolyIdInfo';
-- 3. runs the continuation in an environment where the original binder is
--    substituted by the fresh one (which is also added to the in-scope
--    set) and where the original binder is recorded as expanding to
--    @abs_ids@.
--
-- The environment changes are scoped to the continuation ('RWS.local').
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  let abs_id_list = dVarSetElems abs_ids
      str         = fsLit "$l" `appendFS` occNameFS (getOccName bndr)
      ty          = mkLamTypes abs_id_list (idType bndr)
  bndr' <-
    transferPolyIdInfo bndr abs_id_list
      <$> mkSysLocalM str ManyTy ty
  LiftM $ RWS.local
    (\e -> e
      { e_subst      = extendSubst bndr bndr' $ extendInScope bndr' $ e_subst e
      , e_expansions = extendVarEnv (e_expansions e) bndr abs_ids
      })
    (unwrapLiftM (inner bndr'))
-- | 'withLiftedBndr' for a whole traversable container of binders, all
-- abstracted over the same set of variables.
--
-- Binders are processed in traversal order, each nested inside the scope
-- of the previous ones (expressed through 'ContT'); the continuation
-- receives the container of fresh binders.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids = runContT . traverse (ContT . withLiftedBndr abs_ids)
-- | Substitute an occurrence of a binder by looking it up in the current
-- substitution ('lookupIdSubst').
substOcc :: Id -> LiftM Id
substOcc id = LiftM (RWS.asks (lookupIdSubst id . e_subst))
-- | Whether the given binder has been lifted, i.e. whether it has an
-- entry in the expansions recorded by 'withLiftedBndr'.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM (RWS.asks (elemVarEnv bndr . e_expansions))
-- | The variables a lifted binder was abstracted over, as recorded by
-- 'withLiftedBndr'.
--
-- Returns the empty list for a binder that has no recorded expansion
-- (i.e. one that was not lifted).
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure (maybe [] dVarSetElems (lookupVarEnv expansions f))
-- | Build an /expander function/ from the current environment.
--
-- Applied to a set of variables, the returned function rebuilds the set
-- element by element:
--
-- * a variable with a recorded expansion (a lifted binder) is replaced by
--   the set of variables it was abstracted over;
-- * any other variable is replaced by its image under the current
--   substitution.
--
-- The function captures the environment at the point where
-- 'liftedIdsExpander' is run; later environment changes do not affect it.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  expansions <- RWS.asks e_expansions
  subst <- RWS.asks e_subst
  -- NOTE(review): the non-warning lookup variant is used here; presumably
  -- because some of these variables are not in the substitution's in-scope
  -- set — confirm against GHC.Stg.Subst.
  let go set fv = case lookupVarEnv expansions fv of
        Nothing   -> extendDVarSet set (noWarnLookupIdSubst fv subst) -- not lifted
        Just fvs' -> unionDVarSet set fvs'
  let expander fvs = foldl' go emptyDVarSet (dVarSetElems fvs)
  pure expander