{-# LANGUAGE CPP #-}
module FloatOut ( floatOutwards ) where
import GhcPrelude
import CoreSyn
import CoreUtils
import MkCore
import CoreArity ( etaExpand )
import CoreMonad ( FloatOutSwitches(..) )
import DynFlags
import ErrUtils ( dumpIfSet_dyn )
import Id ( Id, idArity, idType, isBottomingId,
isJoinId, isJoinId_maybe )
import SetLevels
import UniqSupply ( UniqSupply )
import Bag
import Util
import Maybes
import Outputable
import Type
import qualified Data.IntMap as M
import Data.List ( partition )
#include "HsVersions.h"
-- | Entry point of the float-out pass.
--
-- Annotates the program with levels ('setLevels'), floats every top-level
-- binding with 'floatTopBind', and concatenates the resulting bags of
-- bindings into the new program.
--
-- Two debug dumps are produced through 'dumpIfSet_dyn':
--
--   * the level-annotated program, under @Opt_D_verbose_core2core@;
--   * the summed float statistics, under @Opt_D_dump_simpl_stats@.
floatOutwards :: FloatOutSwitches
              -> DynFlags
              -> UniqSupply
              -> CoreProgram -> IO CoreProgram
floatOutwards float_sws dflags us pgm = do
    let annotated_w_levels = setLevels float_sws pgm us
        (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)

    dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
                  (vcat (map ppr annotated_w_levels))

    -- The three counters are the fields of the summed 'FloatStats'.
    let (tlets, ntlets, lams) = get_stats (sum_stats fss)

    dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
                  (hcat [ int tlets,  text " Lets floated to top level; "
                        , int ntlets, text " Lets floated elsewhere; from "
                        , int lams,   text " Lambda groups" ])

    return (bagToList (unionManyBags binds_s'))
-- | Float one top-level binding.
--
-- Runs 'floatBind' and merges the floats it produced (flattened with
-- 'flattenTopFloats') with the rebuilt binding:
--
--   * for a single @Rec@ group the floats are added into the group by
--     'addTopFloatPairs';
--   * for a single @NonRec@ binding the floats come first and the binding
--     is appended with 'snocBag'.
--
-- Any other result from 'floatBind' (zero or several bindings) is treated
-- as impossible and panics with @"floatTopBind"@.
floatTopBind :: LevelledBind -> (FloatStats, Bag CoreBind)
floatTopBind bind
  = case (floatBind bind) of { (fs, floats, bind') ->
    let float_bag = flattenTopFloats floats
    in case bind' of
         [Rec prs]    -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
         [NonRec b e] -> (fs, float_bag `snocBag` NonRec b e)
         _            -> pprPanic "floatTopBind" (ppr bind') }
-- | Float the right-hand side(s) of a binding.
--
-- Returns the statistics, the bindings that floated out of the
-- right-hand side(s), and the rebuilt binding(s) with level tags removed.
--
-- For @NonRec@ the result list holds exactly one binding.  For @Rec@ it
-- holds zero or more @NonRec@ bindings (the unlifted pairs reported by
-- 'splitRecFloats') followed by one or two @Rec@ groups: join points are
-- kept in a group separate from the other pairs.
floatBind :: LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
floatBind (NonRec (TB var _) rhs)
  = case (floatRhs var rhs) of { (fs, rhs_floats, rhs') ->

        -- Binders flagged by 'isBottomingId' get their right-hand side
        -- eta-expanded to the binder's arity.
    let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
              | otherwise         = rhs'

    in (fs, rhs_floats, [NonRec var rhs'']) }

floatBind (Rec pairs)
  = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
    let (new_ul_pairss, new_other_pairss) = unzip new_pairs

        -- Separate join points from everything else.
        (new_join_pairs, new_l_pairs)     = partition (isJoinId . fst)
                                                      (concat new_other_pairss)

        -- Emit one Rec group per non-empty class; if both are present the
        -- non-join group comes first.
        new_rec_binds | null new_join_pairs = [ Rec new_l_pairs    ]
                      | null new_l_pairs    = [ Rec new_join_pairs ]
                      | otherwise           = [ Rec new_l_pairs
                                              , Rec new_join_pairs ]

        -- The unlifted pairs become individual non-recursive bindings.
        new_non_rec_binds = [ NonRec b e | (b, e) <- concat new_ul_pairss ]
    in
    (fs, rhs_floats, new_non_rec_binds ++ new_rec_binds) }
  where
    -- Process one (binder, rhs) pair of the group.  The pair of lists in
    -- the result is (unlifted pairs, all other pairs), in the sense of
    -- 'splitRecFloats'.
    do_pair :: (LevelledBndr, LevelledExpr)
            -> (FloatStats, FloatBinds,
                ([(Id,CoreExpr)],
                 [(Id,CoreExpr)]))
    do_pair (TB name spec, rhs)
      | isTopLvl dest_lvl
        -- Destination is top level: every float from the rhs is folded
        -- into the pair list, and no floats are passed outwards.
      = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
        (fs, emptyFloats, ([], addTopFloatPairs (flattenTopFloats rhs_floats)
                                                [(name, rhs')]))}
      | otherwise
        -- Otherwise keep back the floats that 'partitionByLevel' assigns
        -- to dest_lvl: the let-bound ones join this group, and whatever
        -- 'splitRecFloats' leaves over is installed under the lambdas of
        -- the rhs.
      = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
        case (partitionByLevel dest_lvl rhs_floats) of { (rhs_floats', heres) ->
        case (splitRecFloats heres) of { (ul_pairs, pairs, case_heres) ->
        let pairs' = (name, installUnderLambdas case_heres rhs') : pairs in
        (fs, rhs_floats', (ul_pairs, pairs')) }}}
      where
        dest_lvl = floatSpecLevel spec
-- | Split the leading run of let-floats off a bag of floats.
--
-- The bag is walked in 'bagToList' order for as long as the head is a
-- @FloatLet@:
--
--   * a @NonRec@ binding whose binder has an unlifted type and is not a
--     join point goes to the first result list (encounter order is kept);
--   * any other @NonRec@ binding is consed onto the second result list;
--   * the pairs of a @Rec@ group are prepended to the second result list.
--
-- The walk stops at the first element that is not a @FloatLet@; that
-- element and everything after it are returned as the third component.
splitRecFloats :: Bag FloatBind
               -> ([(Id,CoreExpr)],
                   [(Id,CoreExpr)],
                   Bag FloatBind)
splitRecFloats fs
  = go [] [] (bagToList fs)
  where
    go ul_prs prs (FloatLet (NonRec b r) : rest)
      | isUnliftedType (idType b)
      , not (isJoinId b)
      = go ((b,r):ul_prs) prs rest
      | otherwise
      = go ul_prs ((b,r):prs) rest
    go ul_prs prs (FloatLet (Rec prs') : rest)
      = go ul_prs (prs' ++ prs) rest
    go ul_prs prs rest
      = (reverse ul_prs, prs, listToBag rest)
-- | Install a bag of floats beneath the leading lambdas of an expression.
--
-- With an empty bag the expression is returned untouched.  Otherwise the
-- leading @Lam@ binders are preserved and 'install' is applied to the
-- first sub-expression that is not a @Lam@.
installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
installUnderLambdas floats e
  | isEmptyBag floats = e
  | otherwise         = go e
  where
    go (Lam b body) = Lam b (go body)
    go body         = install floats body
-- | Apply a floating function to each list element, left to right,
-- combining the statistics with 'add_stats' and the floats with
-- 'plusFloats', and collecting the per-element results in order.
--
-- An empty list yields 'zeroStats', 'emptyFloats' and @[]@.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList _ [] = (zeroStats, emptyFloats, [])
floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
                     case floatList f as of { (fs_as, binds_as, bs) ->
                     (fs_a `add_stats` fs_as, binds_a `plusFloats` binds_as, b:bs) }}
-- | Float an expression, then stop at the given level.
--
-- The floats produced by 'floatExpr' are split with 'partitionByLevel':
-- the second component of the split is wrapped around the expression with
-- 'install', and the first component is returned to the caller.
--
-- Used in this file for lambda bodies, case alternatives, the body of a
-- let that stays put, and join-point right-hand sides.
floatBody :: Level
          -> LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
floatBody lvl arg
  = case (floatExpr arg) of { (fsa, floats, arg') ->
    case (partitionByLevel lvl floats) of { (floats', heres) ->
    (fsa, floats', install heres arg') }}
-- | Float bindings out of a level-annotated expression.
--
-- Returns the statistics gathered, the bindings that are floating out of
-- the expression, and the expression rebuilt over plain binders (the
-- @TB@ level tags are dropped).
floatExpr :: LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)

-- Atoms: nothing floats and nothing is counted.
floatExpr (Var v)       = (zeroStats, emptyFloats, Var v)
floatExpr (Type ty)     = (zeroStats, emptyFloats, Type ty)
floatExpr (Coercion co) = (zeroStats, emptyFloats, Coercion co)
floatExpr (Lit lit)     = (zeroStats, emptyFloats, Lit lit)

-- Application: function and argument are each floated and passed through
-- 'atJoinCeiling'; statistics and floats are combined.
floatExpr (App e a)
  = case (atJoinCeiling $ floatExpr e) of { (fse, floats_e, e') ->
    case (atJoinCeiling $ floatExpr a) of { (fsa, floats_a, a') ->
    (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}

-- Lambda: a whole group of nested lambdas is handled at once.  The level
-- comes from the outermost binder, via 'asJoinCeilLvl'.  This is the only
-- place that calls 'add_to_stats', so each lambda group bumps the
-- lambda-group counter by one.
floatExpr lam@(Lam (TB _ lam_spec) _)
  = let (bndrs_w_lvls, body) = collectBinders lam
        bndrs                = [b | TB b _ <- bndrs_w_lvls]
        bndr_lvl             = asJoinCeilLvl (floatSpecLevel lam_spec)
    in
    case (floatBody bndr_lvl body) of { (fs, floats, body') ->
    (add_to_stats fs floats, floats, mkLams bndrs body') }

floatExpr (Tick tickish expr)
  -- Soft-scoped tick: floats pass through unchanged.
  | tickish `tickishScopesLike` SoftScope
  = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Tick tickish expr') }

  -- Non-counting or splittable tick: the floats are wrapped (by
  -- 'wrapTick') with the 'mkNoCount' version of the tick.
  | not (tickishCounts tickish) || tickishCanSplit tickish
  = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
    let annotated_defns = wrapTick (mkNoCount tickish) floating_defns
    in
    (fs, annotated_defns, Tick tickish expr') }

  -- Breakpoint: floats pass through unchanged; note that this branch
  -- does not go through 'atJoinCeiling'.
  | Breakpoint{} <- tickish
  = case (floatExpr expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Tick tickish expr') }

  | otherwise
  = pprPanic "floatExpr tick" (ppr tickish)

floatExpr (Cast expr co)
  = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
    (fs, floating_defns, Cast expr' co) }

floatExpr (Let bind body)
  = case bind_spec of
      -- The binding floats: it disappears from the expression and is
      -- added to the floats at dest_lvl, ordered between the floats from
      -- its own rhs and the floats from the body.
      FloatMe dest_lvl
        -> case (floatBind bind) of { (fsb, bind_floats, binds') ->
           case (floatExpr body) of { (fse, body_floats, body') ->
           let new_bind_floats = foldr plusFloats emptyFloats
                                   (map (unitLetFloat dest_lvl) binds') in
           ( add_stats fsb fse
           , bind_floats `plusFloats` new_bind_floats
                         `plusFloats` body_floats
           , body') }}

      -- The binding stays: the body is floated with 'floatBody' at the
      -- binding's level and the rebuilt bindings are wrapped around it.
      StayPut bind_lvl
        -> case (floatBind bind)          of { (fsb, bind_floats, binds') ->
           case (floatBody bind_lvl body) of { (fse, body_floats, body') ->
           ( add_stats fsb fse
           , bind_floats `plusFloats` body_floats
           , foldr Let body' binds' ) }}
  where
    -- The spec of a Rec group is read off its first binder; an empty
    -- group is treated as impossible.
    bind_spec = case bind of
                  NonRec (TB _ s) _     -> s
                  Rec ((TB _ s, _) : _) -> s
                  Rec []                -> panic "floatExpr:rec"

floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
  = case case_spec of
      -- The case itself floats.  Only a case with exactly one DataAlt
      -- alternative is accepted; anything else panics.
      FloatMe dest_lvl
        | [(con@(DataAlt {}), bndrs, rhs)] <- alts
        -> case atJoinCeiling $ floatExpr scrut of { (fse, fde, scrut') ->
           case                 floatExpr rhs   of { (fsb, fdb, rhs') ->
           let
             float = unitCaseFloat dest_lvl scrut'
                          case_bndr con [b | TB b _ <- bndrs]
           in
           (add_stats fse fsb, fde `plusFloats` float `plusFloats` fdb, rhs') }}
        | otherwise
        -> pprPanic "Floating multi-case" (ppr alts)

      -- The case stays: each alternative is floated with 'floatBody' at
      -- the case's own level.
      StayPut bind_lvl
        -> case atJoinCeiling $ floatExpr scrut of { (fse, fde, scrut') ->
           case floatList (float_alt bind_lvl) alts of { (fsa, fda, alts')  ->
           (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
           }}
  where
    float_alt bind_lvl (con, bs, rhs)
        = case (floatBody bind_lvl rhs) of { (fs, rhs_floats, rhs') ->
          (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
-- | Float the right-hand side of a binding for the given binder.
--
-- When the binder is a join point of arity @n@ and the rhs begins with at
-- least @n@ lambdas, those lambdas are peeled off and the body is floated
-- with 'floatBody' at the level of the first lambda binder, after which
-- the lambdas are put back.  With @n == 0@ there is nothing to peel and
-- the rhs goes straight to 'floatExpr'.
--
-- In every other case the result is @atJoinCeiling $ floatExpr rhs@.
floatRhs :: CoreBndr
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)
floatRhs bndr rhs
  | Just join_arity <- isJoinId_maybe bndr
  , Just (bndrs, body) <- try_collect join_arity rhs []
  = case bndrs of
      []                -> floatExpr rhs
      (TB _ lam_spec):_ ->
        let lvl = floatSpecLevel lam_spec in
        case floatBody lvl body of { (fs, floats, body') ->
        (fs, floats, mkLams [b | TB b _ <- bndrs] body') }
  | otherwise
  = atJoinCeiling $ floatExpr rhs
  where
    -- Peel exactly n leading lambdas; Nothing if there are fewer than n.
    -- Binders are accumulated in reverse and restored at the end.
    try_collect :: Int -> Expr b -> [b] -> Maybe ([b], Expr b)
    try_collect 0 expr      acc = Just (reverse acc, expr)
    try_collect n (Lam b e) acc = try_collect (n-1) e (b:acc)
    try_collect _ _         _   = Nothing
-- | Counters reported in the "FloatOut stats:" dump.
data FloatStats
  = FlS Int  -- Bumped by the number of top-level floats at each lambda group
        Int  -- Bumped by the number of all other floats at each lambda group
        Int  -- Number of lambda groups seen
-- | Unpack the three counters of a 'FloatStats', in field order.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats (FlS a b c) = (a, b, c)
-- | Statistics with every counter at zero; the identity for 'add_stats'.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0
-- | Add up a list of statistics with 'add_stats', starting from
-- 'zeroStats'.
sum_stats :: [FloatStats] -> FloatStats
sum_stats xs = foldr add_stats zeroStats xs
-- | Field-wise sum of two sets of statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
  = FlS (a1 + a2) (b1 + b2) (c1 + c2)
-- | Account for one lambda group and the floats leaving it.
--
-- The first counter grows by the number of top-level floats, the second
-- by the number of ceiling floats plus all floats in the major/minor
-- environment, and the third by one.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS a b c) (FB tops ceils others)
  = FlS (a + lengthBag tops)
        (b + lengthBag ceils + lengthBag (flattenMajor others))
        (c + 1)
-- A floated let-binding is represented directly as a Core binding.
type FloatLet = CoreBind
-- Outer map of levelled floats.  'unitLetFloat' and 'unitCaseFloat' key it by
-- the major component of a 'Level'.
type MajorEnv = M.IntMap MinorEnv
-- Inner map, keyed by the minor component of a 'Level'; each entry holds the
-- floats recorded for that (major, minor) pair.
type MinorEnv = M.IntMap (Bag FloatBind)
-- | The collection of bindings being floated outwards.  All three fields are
-- strict.
data FloatBinds = FB !(Bag FloatLet)   -- bindings destined for top level
                     !(Bag FloatBind)  -- floats recorded at a 'JoinCeilLvl' level
                     !MajorEnv         -- all other floats, indexed by major
                                       -- and then minor level
-- | Debug rendering: the literal @FB@ followed by a braced, vertically
-- stacked listing of the three components.
instance Outputable FloatBinds where
  -- No method signature here: that would need InstanceSigs, which this
  -- module does not enable.
  ppr (FB fbs ceils defs)
    = text "FB" <+> (braces $ vcat
           [ text "tops =" <+> ppr fbs
           , text "ceils =" <+> ppr ceils
           , text "non-tops =" <+> ppr defs ])
-- | Extract the top-level bindings from a 'FloatBinds'.
--
-- The two ASSERT2 macros (from HsVersions.h) check that nothing is left in
-- the levelled environment or in the ceiling bag; only the top-level bag is
-- returned, so any such leftovers would otherwise be dropped.
flattenTopFloats :: FloatBinds -> Bag CoreBind
flattenTopFloats (FB tops ceils defs)
  = ASSERT2( isEmptyBag (flattenMajor defs), ppr defs )
    ASSERT2( isEmptyBag ceils, ppr ceils )
    tops
-- | Prepend every (binder, rhs) pair found in a bag of bindings to an
-- existing list of pairs.  A 'NonRec' contributes one pair, a 'Rec'
-- contributes all of its pairs; the bag is traversed with 'foldrBag', so the
-- pairs appear in bag order ahead of the original list.
addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
addTopFloatPairs float_bag prs
  = foldrBag add prs float_bag
  where
    add :: Bind a -> [(a, Expr a)] -> [(a, Expr a)]
    add (NonRec b r) acc = (b,r) : acc
    add (Rec prs1)   acc = prs1 ++ acc
-- | Collapse a whole 'MajorEnv' into a single bag by flattening each inner
-- 'MinorEnv' and taking the union of the results.
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor = M.foldr (unionBags . flattenMinor) emptyBag
-- | Union together all the bags stored in a 'MinorEnv'.
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor = M.foldr unionBags emptyBag
-- | A 'FloatBinds' with no top-level, ceiling or levelled floats.
emptyFloats :: FloatBinds
emptyFloats = FB emptyBag emptyBag M.empty
-- | Build a 'FloatBinds' holding exactly one floated case ('FloatCase').
--
-- When the level's type is 'JoinCeilLvl' the float goes into the ceiling
-- bag; otherwise it is filed in the levelled environment under the level's
-- major and minor numbers.  The top-level bag is always left empty.
unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
unitCaseFloat (Level major minor t) e b con bs
  | t == JoinCeilLvl
  = FB emptyBag floats M.empty
  | otherwise
  = FB emptyBag emptyBag (M.singleton major (M.singleton minor floats))
  where
    floats :: Bag FloatBind
    floats = unitBag (FloatCase e b con bs)
-- | Build a 'FloatBinds' holding exactly one floated let-binding.
--
-- The guards are tried in order:
--
-- * a top level (per 'isTopLvl') puts the raw binding in the top-level bag;
-- * a 'JoinCeilLvl' level puts it, wrapped in 'FloatLet', in the ceiling bag;
-- * anything else files the wrapped binding in the levelled environment
--   under the level's major and minor numbers.
unitLetFloat :: Level -> FloatLet -> FloatBinds
unitLetFloat lvl@(Level major minor t) b
  | isTopLvl lvl     = FB (unitBag b) emptyBag M.empty
  | t == JoinCeilLvl = FB emptyBag floats M.empty
  | otherwise        = FB emptyBag emptyBag (M.singleton major
                                              (M.singleton minor floats))
  where
    floats :: Bag FloatBind
    floats = unitBag (FloatLet b)
-- | Combine two 'FloatBinds' component-wise: the top-level bags and the
-- ceiling bags are unioned (left operand first), and the levelled
-- environments are merged with 'plusMajor'.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB t1 c1 l1) (FB t2 c2 l2)
  = FB (t1 `unionBags` t2) (c1 `unionBags` c2) (l1 `plusMajor` l2)
-- | Merge two 'MajorEnv's; where both have an entry for the same major
-- level, the two inner environments are merged with 'plusMinor'.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor = M.unionWith plusMinor
-- | Merge two 'MinorEnv's; where both have an entry for the same minor
-- level, the two bags are unioned.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor = M.unionWith unionBags
-- | Wrap an expression in every float from the bag, using 'wrapFloat'.
-- The bag is folded from the right, so its first float ends up outermost.
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install defn_groups expr
  = foldrBag wrapFloat expr defn_groups
-- | Split a 'FloatBinds' around a given level.
--
-- The first component keeps the top-level floats and every levelled float
-- whose (major, minor) key is strictly below the given level.  The second
-- component is a flat bag of everything at or beyond it: the floats at
-- exactly this level, those at larger minor levels within the same major
-- level, and those at larger major levels.
--
-- The ceiling floats join the second component only when the level's type is
-- 'JoinCeilLvl'; otherwise they stay in the first component.
partitionByLevel
        :: Level
        -> FloatBinds
        -> (FloatBinds,
            Bag FloatBind)
partitionByLevel (Level major minor typ) (FB tops ceils defns)
  = (FB tops ceils' (outer_maj `plusMajor` M.singleton major outer_min),
     here_min `unionBags` here_ceil
              `unionBags` flattenMinor inner_min
              `unionBags` flattenMajor inner_maj)
  where
    -- Entries with key < major, the entry at major (if any), and entries
    -- with key > major.
    (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns

    -- The same three-way split within the current major level, on minor.
    (outer_min, mb_here_min, inner_min)
      = case mb_here_maj of
          Nothing        -> (M.empty, Nothing, M.empty)
          Just min_defns -> M.splitLookup minor min_defns

    here_min :: Bag FloatBind
    here_min = mb_here_min `orElse` emptyBag

    -- Ceiling floats are released here only at a join ceiling.
    (here_ceil, ceils') | typ == JoinCeilLvl = (ceils, emptyBag)
                        | otherwise          = (emptyBag, ceils)
-- | Pull the ceiling floats out of a 'FloatBinds'.  The result pairs the
-- remaining floats (top-level and levelled, with an empty ceiling bag) with
-- the extracted ceiling bag.
partitionAtJoinCeiling :: FloatBinds -> (FloatBinds, Bag FloatBind)
partitionAtJoinCeiling (FB tops ceils defs)
  = (FB tops emptyBag defs, ceils)
-- | Discharge the ceiling floats of a result triple: they are removed from
-- the 'FloatBinds' and wrapped around the expression with 'install'.  The
-- statistics are passed through unchanged.
atJoinCeiling :: (FloatStats, FloatBinds, CoreExpr)
              -> (FloatStats, FloatBinds, CoreExpr)
atJoinCeiling (fs, floats, expr')
  = (fs, floats', install ceils expr')
  where
    (floats', ceils) = partitionAtJoinCeiling floats
-- | Apply a tick to every right-hand side and case scrutinee held in a
-- 'FloatBinds': the top-level bindings, the ceiling floats, and every bag in
-- the levelled environment.  Binders, alternatives' constructors and bound
-- variables are left untouched.
wrapTick :: Tickish Id -> FloatBinds -> FloatBinds
wrapTick t (FB tops ceils defns)
  = FB (mapBag wrap_bind tops) (wrap_defns ceils)
       (M.map (M.map wrap_defns) defns)
  where
    wrap_defns :: Bag FloatBind -> Bag FloatBind
    wrap_defns = mapBag wrap_one

    wrap_bind :: CoreBind -> CoreBind
    wrap_bind (NonRec binder rhs) = NonRec binder (maybe_tick rhs)
    wrap_bind (Rec pairs)         = Rec (mapSnd maybe_tick pairs)

    wrap_one :: FloatBind -> FloatBind
    wrap_one (FloatLet bind)      = FloatLet (wrap_bind bind)
    wrap_one (FloatCase e b c bs) = FloatCase (maybe_tick e) b c bs

    -- Expressions that 'exprIsHNF' accepts are handed to 'tickHNFArgs';
    -- everything else is wrapped directly with 'mkTick'.
    maybe_tick :: CoreExpr -> CoreExpr
    maybe_tick e | exprIsHNF e = tickHNFArgs t e
                 | otherwise   = mkTick t e