{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Futhark.Optimise.Fusion ( fuseSOACs )
where
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Except
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.List as L
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.SOACS hiding (SOAC(..))
import qualified Futhark.Representation.Aliases as Aliases
import qualified Futhark.Representation.SOACS as Futhark
import Futhark.MonadFreshNames
import Futhark.Representation.SOACS.Simplify
import Futhark.Optimise.Fusion.LoopKernel
import Futhark.Construct
import qualified Futhark.Analysis.HORepresentation.SOAC as SOAC
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Pass
-- | What the fusion environment records about a variable in scope.
-- 'IsArray' entries additionally carry an alias set and a 'SOAC.Input'
-- view of the variable (possibly with array transformations attached).
data VarEntry = IsArray VName (NameInfo SOACS) Names SOAC.Input
              | IsNotArray VName (NameInfo SOACS)

-- | The type information stored in an entry, whichever kind it is.
varEntryType :: VarEntry -> NameInfo SOACS
varEntryType entry =
  case entry of
    IsArray _ info _ _ -> info
    IsNotArray _ info -> info

-- | The aliases recorded for an entry; non-array entries have none.
varEntryAliases :: VarEntry -> Names
varEntryAliases entry =
  case entry of
    IsArray _ _ als _ -> als
    IsNotArray{} -> mempty
-- | Reader environment of the fusion gatherer/rewriter.
data FusionGEnv = FusionGEnv {
    soacs :: M.Map VName [VName]
    -- ^ Maps each name bound by a SOAC pattern to all names of that pattern.
  , varsInScope:: M.Map VName VarEntry
    -- ^ Everything currently in scope.
  , fusedRes :: FusedRes
    -- ^ The gathered fusion result, consulted while rewriting.
  }

-- | The 'SOAC.Input' of @v@, if @v@ is in scope as an 'IsArray' entry.
lookupArr :: VName -> FusionGEnv -> Maybe SOAC.Input
lookupArr v env =
  case M.lookup v (varsInScope env) of
    Just (IsArray _ _ _ input) -> Just input
    _ -> Nothing
-- | An internal error raised by the fusion pass.
newtype Error = Error String

instance Show Error where
  show (Error msg) = concat ["Fusion error:\n", msg]

-- | The fusion monad: may fail with an 'Error', threads a name source,
-- and reads a 'FusionGEnv'.
newtype FusionGM a = FusionGM (ExceptT Error (StateT VNameSource (Reader FusionGEnv)) a)
  deriving (Monad, Applicative, Functor,
            MonadError Error,
            MonadState VNameSource,
            MonadReader FusionGEnv)

instance MonadFreshNames FusionGM where
  getNameSource = get
  putNameSource src = put src

-- | The scope is the type information of every variable in the environment.
instance HasScope SOACS FusionGM where
  askScope = asks $ M.map varEntryType . varsInScope
-- | Add a single variable (with its direct aliases) to the environment.
-- Array-typed variables become 'IsArray' entries. Their alias set is
-- extended with the aliases already recorded for each direct alias.
bindVar :: FusionGEnv -> (Ident, Names) -> FusionGEnv
bindVar env (Ident name t, aliases) =
  env { varsInScope = M.insert name entry $ varsInScope env }
  where entry = case t of
          Array {} -> IsArray name (LetInfo t) aliases' $ SOAC.identInput $ Ident name t
          _ -> IsNotArray name $ LetInfo t
        -- Transitive step: pull in what each alias is itself known to alias.
        expand = maybe mempty varEntryAliases . flip M.lookup (varsInScope env)
        aliases' = aliases <> mconcat (map expand $ S.toList aliases)

-- | Bind several variables, left to right.  Uses a strict left fold so
-- the environment is not built up as a chain of thunks.
bindVars :: FusionGEnv -> [(Ident, Names)] -> FusionGEnv
bindVars = L.foldl' bindVar

-- | Run an action with the given variables (and aliases) in scope.
binding :: [(Ident, Names)] -> FusionGM a -> FusionGM a
binding vs = local (`bindVars` vs)
-- | Bring the names of @pat@ into scope while running the action.
-- Context names get no aliases. Value names get the aliases computed
-- for the bound expression @e@.
gatherStmPattern :: Pattern -> Exp -> FusionGM FusedRes -> FusionGM FusedRes
gatherStmPattern pat e = binding $ zip (patternIdents pat) pat_aliases
  where ctx_aliases = map (const mempty) (patternContextNames pat)
        pat_aliases = ctx_aliases ++ expAliases (Alias.analyseExp e)

-- | Bring every name of a pattern into scope, with no aliases.
bindingPat :: Pattern -> FusionGM a -> FusionGM a
bindingPat pat = binding [ (ident, mempty) | ident <- patternIdents pat ]

-- | Bring parameters into scope, with no aliases.
bindingParams :: Typed t => [Param t] -> FusionGM a -> FusionGM a
bindingParams params = binding [ (paramIdent p, mempty) | p <- params ]
-- | Bind one output of a SOAC pattern.  The name is recorded as a member
-- of the family @faml@ (all names of that pattern) and as an 'IsArray'
-- entry with no aliases.
bindingFamilyVar :: [VName] -> FusionGEnv -> Ident -> FusionGEnv
bindingFamilyVar faml env (Ident nm t) =
  env { soacs = M.insert nm faml (soacs env)
      , varsInScope = M.insert nm entry (varsInScope env)
      }
  where entry = IsArray nm (LetInfo t) mempty (SOAC.identInput (Ident nm t))

-- | The recorded aliases of @v@, including @v@ itself.
varAliases :: VName -> FusionGM Names
varAliases v = do
  scope <- asks varsInScope
  let known = maybe mempty varEntryAliases (M.lookup v scope)
  return $ S.insert v known

-- | Union of 'varAliases' over a set of names.
varsAliases :: Names -> FusionGM Names
varsAliases names = mconcat <$> mapM varAliases (S.toList names)
-- | For an in-place 'Update' of @src@: mark @src@ and the names free in
-- the indices as infusible. Also add the aliases of @src@ to the
-- 'inplace' set of every gathered kernel. Other expressions leave the
-- result unchanged.
checkForUpdates :: FusedRes -> Exp -> FusionGM FusedRes
checkForUpdates res e =
  case e of
    BasicOp (Update src is _) -> do
      let idx_free = S.toList (mconcat $ map freeIn is)
      res' <- foldM addVarToInfusible res (src : idx_free)
      src_aliases <- varAliases src
      let markInplace ker = ker { inplace = src_aliases <> inplace ker }
      return res' { kernels = M.map markInplace (kernels res') }
    _ -> return res

-- | Bind all names of a SOAC pattern as one family (see 'bindingFamilyVar').
bindingFamily :: Pattern -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily pat =
  local $ \env -> foldl (bindingFamilyVar family) env (patternIdents pat)
  where family = patternNames pat
-- | Bind @pe@ as the result of applying the array transformation @trns@
-- to @srcname@.  If the source is an 'IsArray' entry, the new name shares
-- its origin, extends its alias set with @srcname@, and stacks @trns@ on
-- its input. Otherwise it is bound as an ordinary variable aliasing
-- itself.
bindingTransform :: PatElem -> VName -> SOAC.ArrayTransform -> FusionGM a -> FusionGM a
bindingTransform pe srcname trns = local extend
  where vname = patElemName pe
        extend env =
          case M.lookup srcname (varsInScope env) of
            Just (IsArray src' _ aliases input) ->
              let entry = IsArray src' (LetInfo (patElemAttr pe))
                                  (S.insert srcname aliases)
                                  (trns `SOAC.addTransform` input)
              in env { varsInScope = M.insert vname entry (varsInScope env) }
            _ -> bindVar env (patElemIdent pe, S.singleton vname)

-- | Run an action with @rrr@ as the environment's fusion result.
bindRes :: FusedRes -> FusionGM a -> FusionGM a
bindRes rrr = local $ \env -> env { fusedRes = rrr }

-- | Run a 'FusionGM' action in any monad with a name source.
runFusionGatherM :: MonadFreshNames m =>
                    FusionGM a -> FusionGEnv -> m (Either Error a)
runFusionGatherM (FusionGM m) env = modifyNameSource run
  where run src = runReader (runStateT (runExceptT m) src) env
-- | The fusion pass: fuse each function, then rename and simplify the program.
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass { passName = "Fuse SOACs"
       , passDescription = "Perform higher-order optimisation, i.e., fusion."
       , passFunction = \prog ->
           intraproceduralTransformation fuseFun prog >>= renameProg >>= simplifySOACS
       }

-- | Gather fusion opportunities in a function. If none succeeded, return
-- the function unchanged. Otherwise rewrite its body with the fused
-- kernels.
fuseFun :: FunDef -> PassM FunDef
fuseFun fun = do
  let env0 = FusionGEnv { soacs = M.empty
                        , varsInScope = M.empty
                        , fusedRes = mempty
                        }
  gathered <- liftEitherM $ runFusionGatherM (fusionGatherFun fun) env0
  let k = cleanFusionResult gathered
  if rsucc k
    then liftEitherM $ runFusionGatherM (fuseInFun k fun) env0
    else return fun

-- | Gathering phase over a function body, with its parameters in scope.
fusionGatherFun :: FunDef -> FusionGM FusedRes
fusionGatherFun fundec =
  bindingParams (funDefParams fundec) (fusionGatherBody mempty (funDefBody fundec))

-- | Rewriting phase: replace SOACs in the body according to @res@.
fuseInFun :: FusedRes -> FunDef -> FusionGM FunDef
fuseInFun res fundec =
  (\b -> fundec { funDefBody = b }) <$>
  bindingParams (funDefParams fundec) (bindRes res (fuseInBody (funDefBody fundec)))
-- | Name under which a fused kernel is stored in 'kernels'.
newtype KernName = KernName { unKernName :: VName }
  deriving (Eq, Ord, Show)

-- | The result of the gathering phase.
data FusedRes = FusedRes {
    rsucc :: Bool
    -- ^ Whether any fusion has succeeded.
  , outArr :: M.Map VName KernName
    -- ^ Output array name -> kernel that produces it.
  , inpArr :: M.Map VName (S.Set KernName)
    -- ^ Input array name -> kernels that read it.
  , infusible :: Names
    -- ^ Names that must not be fused away.
  , kernels :: M.Map KernName FusedKer
    -- ^ All gathered kernels.
  }

-- | Combine two results.  'outArr' and 'kernels' are left-biased unions.
instance Semigroup FusedRes where
  r1 <> r2 =
    FusedRes { rsucc = rsucc r1 || rsucc r2
             , outArr = M.union (outArr r1) (outArr r2)
             , inpArr = M.unionWith S.union (inpArr r1) (inpArr r2)
             , infusible = S.union (infusible r1) (infusible r2)
             , kernels = M.union (kernels r1) (kernels r2)
             }

instance Monoid FusedRes where
  mempty = FusedRes False M.empty M.empty S.empty M.empty
-- | Is @nm@ read by some kernel that is not in @kers@?
isInpArrInResModKers :: FusedRes -> S.Set KernName -> VName -> Bool
isInpArrInResModKers ress kers nm =
  maybe False (\users -> not $ S.null $ users `S.difference` kers) $
  M.lookup nm (inpArr ress)

-- | All kernels that read at least one of the given arrays.
getKersWithInpArrs :: FusedRes -> [VName] -> S.Set KernName
getKersWithInpArrs ress nms =
  S.unions [ users | Just users <- map (`M.lookup` inpArr ress) nms ]
-- | Replace each name bound by a SOAC pattern (as recorded in 'soacs') by
-- the full list of names of that pattern.  Other names are kept unchanged.
-- Order is preserved.  Linear in the output length.  The original
-- appended with @y ++ [nm]@ inside a fold, which is quadratic.
expandSoacInpArr :: [VName] -> FusionGM [VName]
expandSoacInpArr = fmap concat . mapM expandOne
  where expandOne nm = fromMaybe [nm] <$> asks (M.lookup nm . soacs)

-- | Split the SOAC's inputs into (untransformed input arrays, arrays of
-- transformed inputs). Both lists are expanded by 'expandSoacInpArr'.
soacInputs :: SOAC -> FusionGM ([VName], [VName])
soacInputs soac = do
  let (inp_idds, other_idds) = getIdentArr $ SOAC.inputs soac
  inp_nms <- expandSoacInpArr inp_idds
  other_nms <- expandSoacInpArr other_idds
  return (inp_nms, other_nms)
-- | Register the SOAC as a brand-new (unfused) kernel under a fresh name.
-- Its outputs are recorded in 'outArr' and its input arrays in 'inpArr'.
-- The infusible set is replaced by @ufs@.
addNewKerWithInfusible :: FusedRes -> ([Ident], Certificates, SOAC, Names) -> Names -> FusionGM FusedRes
addNewKerWithInfusible res (idd, cs, soac, consumed) ufs = do
  nm_ker <- KernName <$> newVName "ker"
  scope <- askScope
  let out_nms = map identName idd
      in_nms = map SOAC.inputArray $ SOAC.inputs soac
      new_ker = newKernel cs soac consumed out_nms scope
      outs = M.fromSet (const nm_ker) (S.fromList out_nms) `M.union` outArr res
      ins = M.unionWith S.union
              (M.fromSet (const (S.singleton nm_ker)) (S.fromList in_nms))
              (inpArr res)
  return FusedRes { rsucc = rsucc res
                  , outArr = outs
                  , inpArr = ins
                  , infusible = ufs
                  , kernels = M.insert nm_ker new_ker (kernels res)
                  }
-- | The 'SOAC.Input' recorded for @name@, if any.
lookupInput :: VName -> FusionGM (Maybe SOAC.Input)
lookupInput name = lookupArr name <$> ask

-- | If the input's array is itself a transformed view of another array,
-- refer to that array directly. The recorded transforms are placed before
-- the input's own transforms.
inlineSOACInput :: SOAC.Input -> FusionGM SOAC.Input
inlineSOACInput inp@(SOAC.Input ts v _) = do
  found <- lookupInput v
  return $ case found of
    Nothing -> inp
    Just (SOAC.Input ts2 v2 t2) -> SOAC.Input (ts2<>ts) v2 t2

-- | Apply 'inlineSOACInput' to every input of a SOAC.
inlineSOACInputs :: SOAC -> FusionGM SOAC
inlineSOACInputs soac =
  (`SOAC.setInputs` soac) <$> mapM inlineSOACInput (SOAC.inputs soac)
-- | Try to fuse the SOAC bound to @out_idds@ with kernels already
-- gathered from the statements that follow it.
--
-- Horizontal fusion ('horizontGreedyFuse') is used when the SOAC is a
-- redomap/scanomap-like Screma, or when one of its outputs is infusible.
-- Otherwise producer/consumer fusion ('prodconsGreedyFuse') is used.
-- Fusion is rejected if a candidate kernel has an in-place update on a
-- name this SOAC (or its lambdas, @lam_used_nms@) uses.  On rejection
-- the SOAC becomes a new kernel.  On success the old kernels' entries in
-- 'inpArr' are replaced by those of the fused kernels, and 'rsucc' is set.
greedyFuse :: [Stm] -> Names -> FusedRes -> (Pattern, Certificates, SOAC, Names)
           -> FusionGM FusedRes
greedyFuse rem_bnds lam_used_nms res (out_idds, cs, orig_soac, consumed) = do
  -- Look through transformed inputs to the arrays they originate from.
  soac <- inlineSOACInputs orig_soac
  (inp_nms, other_nms) <- soacInputs soac
  let out_nms = patternNames out_idds
      isInfusible = (`S.member` infusible res)
      -- A Screma that is a redomap or scanomap, but not a plain reduce/scan.
      is_screma = case orig_soac of
        SOAC.Screma _ form _ ->
          (isJust (isRedomapSOAC form) || isJust (isScanomapSOAC form)) &&
          not (isJust (isReduceSOAC form) || isJust (isScanSOAC form))
        _ -> False
  (ok_kers_compat, fused_kers, fused_nms, old_kers, oldker_nms) <-
    if is_screma || any isInfusible out_nms
    then horizontGreedyFuse rem_bnds res (out_idds, cs, soac, consumed)
    else prodconsGreedyFuse res (out_idds, cs, soac, consumed)
  -- Reject fusion if any candidate kernel performs an in-place update on
  -- a name used by this SOAC.
  let all_used_names = S.toList $ S.unions [lam_used_nms, S.fromList inp_nms, S.fromList other_nms]
      has_inplace ker = any (`S.member` inplace ker) all_used_names
      ok_inplace = not $ any has_inplace old_kers
  let fusible_ker = not (null old_kers) && ok_inplace && ok_kers_compat
  let mod_kerS = if fusible_ker then S.fromList oldker_nms else S.empty
  -- Inputs still read by kernels outside the fused set become infusible.
  let used_inps = filter (isInpArrInResModKers res mod_kerS) inp_nms
  -- Arrays of transformed inputs that are not inputs themselves are also
  -- marked infusible.
  let ufs = S.unions [infusible res, S.fromList used_inps,
                      S.fromList other_nms `S.difference`
                      S.fromList (map SOAC.inputArray $ SOAC.inputs soac)]
  let comb = M.unionWith S.union
  if not fusible_ker then
    addNewKerWithInfusible res (patternIdents out_idds, cs, soac, consumed) ufs
  else do
    -- Remove each old kernel from the reader sets of its input arrays;
    -- drop the entry entirely when no reader remains.
    let inpArr' =
          foldl (\inpa (kold, knm) ->
                   S.foldl'
                     (\inpp nm ->
                        case M.lookup nm inpp of
                          Nothing -> inpp
                          Just s -> let new_set = S.delete knm s
                                    in if S.null new_set
                                       then M.delete nm inpp
                                       else M.insert nm new_set inpp
                     )
                     inpa $ arrInputs kold
                )
                (inpArr res) (zip old_kers oldker_nms)
    -- Register the fused kernels as readers of their input arrays.
    let fused_ker_nms = zip fused_nms fused_kers
        inpArr''= foldl (\inpa' (knm, knew) ->
                           M.fromList [ (k, S.singleton knm)
                                      | k <- S.toList $ arrInputs knew ]
                           `comb` inpa'
                        )
                        inpArr' fused_ker_nms
    -- Fused kernels overwrite existing entries of the same name.
    let kernels' = M.fromList fused_ker_nms `M.union` kernels res
    return $ FusedRes True (outArr res) inpArr'' ufs kernels'
-- | Producer/consumer fusion: attempt to fuse the SOAC into every kernel
-- that reads one of its outputs. Succeeds only if all attempts succeed;
-- the fused kernels get the SOAC's certificates added.
-- Returns (success, fused kernels, their names, old kernels, old names).
-- The fused kernels keep the names of the kernels they replace.
prodconsGreedyFuse :: FusedRes -> (Pattern, Certificates, SOAC, Names)
                   -> FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
prodconsGreedyFuse res (out_idds, cs, soac, consumed) = do
  let out_nms = patternNames out_idds
      knms = S.toList $ getKersWithInpArrs res out_nms
  kers <- forM knms $ \knm ->
    maybe (throwError $ Error missing_msg) return $ M.lookup knm (kernels res)
  attempts <- mapM (attemptFusion S.empty out_nms soac consumed) kers
  let (ok, fused) = maybe (False, []) (\ks -> (True, map addCerts ks)) (sequence attempts)
  return (ok, fused, knms, kers, knms)
  where addCerts k = k { certificates = certificates k <> cs }
        missing_msg = "In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
                      ++ "kernel name not found in kernels field!"
-- | Horizontal fusion of the SOAC bound to @out_idds@.
--
-- Candidates are kernels that read one of the SOAC's non-accumulator
-- outputs or one of its inputs, plus every kernel with the same outer
-- width. Candidates are ordered by the position, in @rem_bnds@, of the
-- statement that binds their first non-accumulator output. Candidates
-- whose output is not bound there are dropped. They are then fused one
-- at a time into the SOAC's kernel. After the first failed step no
-- further fusion happens.
--
-- Returns (whether anything was fused, [the fused kernel], the name to
-- store it under (the last kernel fused), the fused old kernels and
-- their names).
--
-- Fix: the check over the statements between two fused kernels used
-- @ok && A || B@.  That parses as @(ok && A) || B@, so a later statement
-- could reset a failed check to True.  Now every intermediate statement
-- must pass.
horizontGreedyFuse :: [Stm] -> FusedRes -> (Pattern, Certificates, SOAC, Names)
                   -> FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
horizontGreedyFuse rem_bnds res (out_idds, cs, soac, consumed) = do
  (inp_nms, _) <- soacInputs soac
  let out_nms = patternNames out_idds
      infusible_nms = S.fromList $ filter (`S.member` infusible res) out_nms
      -- Outputs that are not scan/reduce/stream accumulators.
      out_arr_nms = case soac of
        SOAC.Screma _ (ScremaForm (_, scan_nes) (_, _, red_nes) _) _ ->
          drop (length scan_nes + length red_nes) out_nms
        SOAC.Stream _ frm _ _ -> drop (length $ getStreamAccums frm) out_nms
        _ -> out_nms
      to_fuse_knms1 = S.toList $ getKersWithInpArrs res (out_arr_nms++inp_nms)
      to_fuse_knms2 = getKersWithSameInpSize (SOAC.width soac) res
      to_fuse_knms = S.toList $ S.fromList $ to_fuse_knms1 ++ to_fuse_knms2
      lookupKernel k = case M.lookup k (kernels res) of
        Nothing -> throwError $ Error
                   ("In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
                    ++ "kernel name not found in kernels field!")
        Just ker -> return ker

  -- Pair each candidate with the index of the statement that binds its
  -- first non-accumulator output.
  let bnd_nms = map (patternNames . stmPattern) rem_bnds
  kernminds <- forM to_fuse_knms $ \ker_nm -> do
    ker <- lookupKernel ker_nm
    let out_nm = case fsoac ker of
          SOAC.Stream _ frm _ _
            | x:_ <- drop (length $ getStreamAccums frm) $ outNames ker ->
              x
          SOAC.Screma _ (ScremaForm (_, scan_nes) (_, _, red_nes) _) _
            | x:_ <- drop (length scan_nes + length red_nes) $ outNames ker ->
              x
          -- NOTE(review): partial; assumes every kernel has an output.
          _ -> head $ outNames ker
    case L.findIndex (elem out_nm) bnd_nms of
      Nothing -> return Nothing
      Just i -> return $ Just (ker,ker_nm,i)

  scope <- askScope
  let kernminds' = L.sortBy (\(_,_,i1) (_,_,i2)->compare i1 i2) $ catMaybes kernminds
      soac_kernel = newKernel cs soac consumed out_nms scope

      -- One fusion step.  The accumulator is (still ok, number fused so far,
      -- statement index of the previous kernel, current kernel, names that
      -- must remain unfused).
      fuseStep (cur_ok, n, prev_ind, cur_ker, ufus_nms) (ker, _ker_nm, bnd_ind) = do
        let curker_outnms = outNames cur_ker
            curker_outset = S.fromList curker_outnms
            new_ufus_nms = S.fromList $ outNames ker ++ S.toList ufus_nms
            ker_inp = SOAC.inputs $ fsoac ker
            -- Arrays the kernel reads only through a transformation.
            unfuse1 = S.fromList (map SOAC.inputArray ker_inp) `S.difference`
                      S.fromList (mapMaybe SOAC.isVarInput ker_inp)
            unfuse2 = S.intersection curker_outset ufus_nms
            out_transf_ok = S.null $ S.intersection unfuse1 unfuse2
            cons_no_out_transf = SOAC.nullTransforms $ outputTransform ker
            -- An in-between statement is acceptable if (i) it does not use
            -- the current kernel's results, or (ii) it binds one of them.
            interm_bnd_ok bnd =
              S.null (S.intersection curker_outset $ freeInExp $ stmExp bnd) ||
              not (null $ curker_outnms `L.intersect` patternNames (stmPattern bnd))
            interm_bnds = drop (prev_ind+1) $ take bnd_ind rem_bnds
        -- If the statement at bnd_ind is a SOAC, its lambda body must not
        -- refer to the current kernel's results.
        maybesoac <- SOAC.fromExp $ stmExp $ rem_bnds !! bnd_ind
        let consumer_ok = case maybesoac of
              Right conssoac -> S.null $ S.intersection curker_outset $
                                freeInBody $ lambdaBody $ SOAC.lambda conssoac
              Left _ -> True
            interm_bnds_ok = cur_ok && consumer_ok && out_transf_ok &&
                             cons_no_out_transf && all interm_bnd_ok interm_bnds
        if not interm_bnds_ok
          then return (False, n, bnd_ind, cur_ker, S.empty)
          else do
            new_ker <- attemptFusion ufus_nms (outNames cur_ker)
                       (fsoac cur_ker) (fusedConsumed cur_ker) ker
            return $ case new_ker of
              Nothing -> (False, n, bnd_ind, cur_ker, S.empty)
              Just krn -> (True, n+1, bnd_ind, krn, new_ufus_nms)

  (_, ok_ind, _, fused_ker, _) <-
    foldM fuseStep (True, 0, 0, soac_kernel, infusible_nms) kernminds'
  let (to_fuse_kers', to_fuse_knms', _) = unzip3 $ take ok_ind kernminds'
      new_kernms = drop (ok_ind-1) to_fuse_knms'
  return (ok_ind>0, [fused_ker], new_kernms, to_fuse_kers', to_fuse_knms')
  where getKersWithSameInpSize :: SubExp -> FusedRes -> [KernName]
        getKersWithSameInpSize sz ress =
          map fst $ filter (\ (_,ker) -> sz == SOAC.width (fsoac ker)) $ M.toList $ kernels ress
-- | Gathering phase over a body: walk its statements front to back,
-- recording SOAC kernels, fusion results, and infusible names.
fusionGatherBody :: FusedRes -> Body -> FusionGM FusedRes

-- A body starting with a for-loop that has array loop_vars is rewritten
-- into a sequential Stream over those arrays, and gathering continues on
-- the rewritten body.  The generated inner loop has no loop_vars, so it
-- cannot match this equation again.
--
-- NOTE(review): the offset update below uses `Add Int32` and the extra
-- result is typed `int32`, while the offset parameter has type
-- `IntType it`.  This looks correct only when `it` is Int32; confirm
-- before relying on it for other index types.
fusionGatherBody fres (Body blore (stmsToList ->
                                   Let (Pattern [] pes) bndtp
                                   (DoLoop [] merge (ForLoop i it w loop_vars) body)
                                   :bnds) res) | not $ null loop_vars = do
  let (merge_params,merge_init) = unzip merge
      (loop_params,loop_arrs) = unzip loop_vars
  chunk_size <- newVName "chunk_size"
  offset <- newVName "offset"
  let chunk_param = Param chunk_size $ Prim int32
      offset_param = Param offset $ Prim $ IntType it
  -- Stream accumulators corresponding to the loop's merge parameters.
  acc_params <- forM merge_params $ \p ->
    Param <$> newVName (baseString (paramName p) ++ "_outer") <*>
    pure (paramType p)
  -- One chunk parameter per loop_vars array.
  chunked_params <- forM loop_vars $ \(p,arr) ->
    Param <$> newVName (baseString arr ++ "_chunk") <*>
    pure (paramType p `arrayOfRow` Futhark.Var chunk_size)
  let lam_params = chunk_param : acc_params ++ [offset_param] ++ chunked_params
  -- Stream lambda body: an inner loop over the chunk, whose results are
  -- the new accumulators, followed by offset + chunk_size.
  lam_body <- runBodyBinder $ localScope (scopeOfLParams lam_params) $ do
    let merge' = zip merge_params $ map (Futhark.Var . paramName) acc_params
    j <- newVName "j"
    loop_body <- runBodyBinder $ do
      -- Index element j out of each chunk, and recompute the original
      -- loop index i as offset + j.
      forM_ (zip loop_params chunked_params) $ \(p,a_p) ->
        letBindNames_ [paramName p] $ BasicOp $ Index (paramName a_p) $
        fullSlice (paramType a_p) [DimFix $ Futhark.Var j]
      letBindNames_ [i] $ BasicOp $ BinOp (Add it) (Futhark.Var offset) (Futhark.Var j)
      return body
    eBody [pure $
           DoLoop [] merge' (ForLoop j it (Futhark.Var chunk_size) []) loop_body,
           pure $
           BasicOp $ BinOp (Add Int32) (Futhark.Var offset) (Futhark.Var chunk_size)]
  let lam = Lambda { lambdaParams = lam_params
                   , lambdaBody = lam_body
                   , lambdaReturnType = map paramType $ acc_params ++ [offset_param]
                   }
      -- The offset is an extra accumulator starting at 0.
      stream = Futhark.Stream w (Sequential $ merge_init ++ [intConst it 0]) lam loop_arrs
  -- The final offset is an extra result; bind it to a fresh, unused name.
  discard <- newVName "discard"
  let discard_pe = PatElem discard $ Prim int32
  fusionGatherBody fres $ Body blore
    (oneStm (Let (Pattern [] (pes<>[discard_pe])) bndtp (Op stream))<>stmsFromList bnds) res

-- General case for a non-empty body.  Statements after the first are
-- gathered first (inside the recursive calls), then the first statement
-- is handled:
--  * Scatter/GenReduce: their pattern names become infusible, then the
--    SOAC is offered for fusion ('mapLike').
--  * Screma/Stream: their lambdas are gathered and the SOAC is offered for
--    fusion ('reduceLike'); neutral elements/accumulators become infusible.
--  * A single-value array transformation: recorded via 'bindingTransform'.
--  * Anything else: its pattern is bound, and every name it mentions becomes
--    infusible.
fusionGatherBody fres (Body _ (stmsToList -> (bnd@(Let pat _ e):bnds)) res) = do
  maybesoac <- SOAC.fromExp e
  case maybesoac of
    Right soac@(SOAC.Scatter _len lam _ivs _as) -> do
      fres' <- addNamesToInfusible fres $ S.fromList $ patternNames pat
      mapLike fres' soac lam
    Right soac@(SOAC.GenReduce _ _ lam _) -> do
      fres' <- addNamesToInfusible fres $ S.fromList $ patternNames pat
      mapLike fres' soac lam
    Right soac@(SOAC.Screma _ (ScremaForm (scan_lam, scan_nes)
                                           (_, reduce_lam, reduce_nes)
                                           map_lam) _) ->
      reduceLike soac [scan_lam, reduce_lam, map_lam] $ scan_nes <> reduce_nes
    Right soac@(SOAC.Stream _ form lam _) -> do
      -- A parallel stream also has a reduction lambda to gather.
      let lambdas = case form of
                      Parallel _ _ lout _ -> [lout, lam]
                      _ -> [lam]
      reduceLike soac lambdas $ getStreamAccums form
    _ | [pe] <- patternValueElements pat,
        Just (src,trns) <- SOAC.transformFromExp (stmCerts bnd) e ->
          bindingTransform pe src trns $ fusionGatherBody fres body
      | otherwise -> do
          let pat_vars = map (BasicOp . SubExp . Var) $ patternNames pat
          bres <- gatherStmPattern pat e $ fusionGatherBody fres body
          bres' <- checkForUpdates bres e
          foldM fusionGatherExp bres' (e:pat_vars)
  where body = mkBody (stmsFromList bnds) res
        cs = stmCerts bnd
        rem_bnds = bnd : bnds
        consumed = consumedInExp $ Alias.analyseExp e
        -- Lambdas are gathered before the rest of the body.
        reduceLike soac lambdas nes = do
          (used_lam, lres) <- foldM fusionGatherLam (S.empty, fres) lambdas
          bres <- bindingFamily pat $ fusionGatherBody lres body
          bres' <- foldM fusionGatherSubExp bres nes
          consumed' <- varsAliases consumed
          greedyFuse rem_bnds used_lam bres' (pat, cs, soac, consumed')
        -- The rest of the body is gathered before the lambda.
        mapLike fres' soac lambda = do
          bres <- bindingFamily pat $ fusionGatherBody fres' body
          (used_lam, blres) <- fusionGatherLam (S.empty, bres) lambda
          consumed' <- varsAliases consumed
          greedyFuse rem_bnds used_lam blres (pat, cs, soac, consumed')

-- Empty body: every variable in the result becomes infusible.
fusionGatherBody fres (Body _ _ res) =
  foldM fusionGatherExp fres $ map (BasicOp . SubExp) res
-- | Gathering phase for an expression that is not itself a fusible SOAC.
--
-- * Loops: everything free in the loop header is infusible. The body is
--   gathered separately, and its input arrays are also marked infusible,
--   so nothing fuses across the loop boundary.
-- * Branches: both branches are gathered and merged with 'mergeFusionRes'.
-- * Screma/Scatter here are internal errors; such SOACs are handled by
--   'fusionGatherBody'.
-- * Anything else: all free names become infusible.
--
-- Fix: the for-loop index is now bound with its actual type
-- @Prim (IntType it)@, as in 'fuseInExp'. It was previously hard-coded
-- to @int32@.
fusionGatherExp :: FusedRes -> Exp -> FusionGM FusedRes
fusionGatherExp fres (DoLoop ctx val form loop_body) = do
  fres' <- addNamesToInfusible fres $ freeIn form <> freeIn ctx <> freeIn val
  let form_idents =
        case form of
          ForLoop i it _ loopvars ->
            Ident i (Prim (IntType it)) : map (paramIdent . fst) loopvars
          WhileLoop{} -> []
  new_res <- binding (zip (form_idents ++ map (paramIdent . fst) (ctx<>val)) $
                      repeat mempty) $
             fusionGatherBody mempty loop_body
  -- Arrays read inside the loop must not be fused with anything outside it.
  let new_res' = new_res { infusible = infusible new_res <> M.keysSet (inpArr new_res) }
  return $ new_res' <> fres'

fusionGatherExp fres (If cond e_then e_else _) = do
  then_res <- fusionGatherBody mempty e_then
  else_res <- fusionGatherBody mempty e_else
  let both_res = then_res <> else_res
  fres' <- fusionGatherSubExp fres cond
  mergeFusionRes fres' both_res

fusionGatherExp _ (Op Futhark.Screma{}) = errorIllegal "screma"
fusionGatherExp _ (Op Futhark.Scatter{}) = errorIllegal "write"

fusionGatherExp fres e =
  addNamesToInfusible fres $ freeInExp e
-- | A variable sub-expression becomes infusible; constants are ignored.
fusionGatherSubExp :: FusedRes -> SubExp -> FusionGM FusedRes
fusionGatherSubExp fres se =
  case se of
    Var v -> addVarToInfusible fres v
    _ -> return fres

-- | Mark every name in the set as infusible.
addNamesToInfusible :: FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible fres names = foldM addVarToInfusible fres (S.toList names)

-- | Mark @name@ as infusible. If it is a transformed view of another
-- array, the underlying array is marked instead.
addVarToInfusible :: FusedRes -> VName -> FusionGM FusedRes
addVarToInfusible fres name = do
  inp <- lookupArr name <$> ask
  let root = maybe name (\(SOAC.Input _ orig _) -> orig) inp
  return fres { infusible = S.insert root (infusible fres) }
-- | Gather the body of a lambda (with its parameters in scope).
-- Names inside the lambda that are infusible or used as kernel inputs,
-- and that are in scope outside it, are returned (unioned into @u_set@).
-- They also become the infusible set of the lambda's result, which is
-- then combined with @fres@.
fusionGatherLam :: (Names, FusedRes) -> Lambda -> FusionGM (S.Set VName, FusedRes)
fusionGatherLam (u_set, fres) (Lambda idds body _) = do
  new_res <- bindingParams idds $ fusionGatherBody mempty body
  outer_names <- asks $ M.keysSet . varsInScope
  let unfus = infusible new_res `S.union` M.keysSet (inpArr new_res)
      visible_unfus = unfus `S.intersection` outer_names
  return ( u_set `S.union` visible_unfus
         , new_res { infusible = visible_unfus } <> fres )
-- | Rewriting phase over a body.  The rest of the body is rewritten with
-- the current statement's pattern in scope. The statement itself is
-- replaced via 'replaceSOAC'.
fuseInBody :: Body -> FusionGM Body
fuseInBody (Body _ stms res) =
  case stmsToList stms of
    Let pat aux e : rest -> do
      body' <- bindingPat pat $ fuseInBody $ mkBody (stmsFromList rest) res
      soac_bnds <- replaceSOAC pat aux e
      return $ insertStms soac_bnds body'
    _ -> return $ Body () mempty res

-- | Rewriting phase over an expression. Loops bring their index, loop
-- variables and merge parameters into scope; everything else is traversed
-- with 'fuseIn'.
fuseInExp :: Exp -> FusionGM Exp
fuseInExp (DoLoop ctx val form loopbody) =
  binding [ (ident, mempty) | ident <- form_idents ] $
  bindingParams (map fst (ctx ++ val)) $
  DoLoop ctx val form <$> fuseInBody loopbody
  where form_idents = case form of
          ForLoop i it _ loopvars ->
            Ident i (Prim (IntType it)) : map (paramIdent . fst) loopvars
          WhileLoop{} -> []
fuseInExp e = mapExpM fuseIn e

-- | Mapper that rewrites nested bodies and SOAC lambdas.
fuseIn :: Mapper SOACS SOACS FusionGM
fuseIn = identityMapper { mapOnBody = \_scope b -> fuseInBody b
                        , mapOnOp = mapSOACM soacMapper
                        }
  where soacMapper = identitySOACMapper { mapOnSOACLambda = fuseInLambda }

-- | Rewrite the body of a lambda with its parameters in scope.
fuseInLambda :: Lambda -> FusionGM Lambda
fuseInLambda lam@(Lambda params body _) = do
  body' <- bindingParams params $ fuseInBody body
  return lam { lambdaBody = body' }
-- | Rewrite one statement.  If its first pattern element is the output of
-- a fused kernel, emit that kernel instead. It is an internal error if
-- the kernel is missing or has no fused variables. Otherwise keep the
-- statement and rewrite inside its expression.
replaceSOAC :: Pattern -> StmAux () -> Exp -> FusionGM (Stms SOACS)
replaceSOAC (Pattern _ []) _ _ = return mempty
replaceSOAC pat@(Pattern _ (patElem : _)) aux e = do
  fres <- asks fusedRes
  case M.lookup (patElemName patElem) (outArr fres) of
    Nothing ->
      oneStm . Let pat aux <$> fuseInExp e
    Just knm -> do
      ker <- maybe (missingKernel knm) return $ M.lookup knm (kernels fres)
      when (null $ fusedVars ker) $
        throwError $ Error $
        "In Fusion.hs, replaceSOAC, unfused kernel "
        ++"still in result: "++pretty (patternIdents pat)
      insertKerSOAC (outNames ker) ker
  where missingKernel knm =
          throwError $ Error $
          "In Fusion.hs, replaceSOAC, outArr in ker_name "
          ++"which is not in Res: "++pretty (unKernName knm)

-- | Emit the statements for a fused kernel. The SOAC's lambdas are
-- finalised, and newly consumed arrays are copied. The SOAC is bound to
-- fresh names, and the kernel's output transforms are then applied to
-- produce @names@.
insertKerSOAC :: [VName] -> FusedKer -> FusionGM (Stms SOACS)
insertKerSOAC names ker = do
  soac' <- finaliseSOAC $ fsoac ker
  runBinder_ $ do
    op <- SOAC.toSOAC soac'
    op' <- copyNewlyConsumed (fusedConsumed ker) (addOpAliases op)
    outs <- zipWithM newIdent (map baseString names) (SOAC.typeOf soac')
    letBind_ (basicPattern [] outs) (Op op')
    transformOutput (outputTransform ker) names outs
-- | Simplify every lambda of a fused SOAC and fuse inside it (see
-- 'simplifyAndFuseInLambda').  Screma lambdas are processed in the order
-- scan, reduce, map.
finaliseSOAC :: SOAC.SOAC SOACS -> FusionGM (SOAC.SOAC SOACS)
finaliseSOAC (SOAC.Screma w (ScremaForm (scan_lam, scan_nes) (comm, red_lam, red_nes) map_lam) arrs) = do
  scan_lam' <- simplifyAndFuseInLambda scan_lam
  red_lam' <- simplifyAndFuseInLambda red_lam
  map_lam' <- simplifyAndFuseInLambda map_lam
  let form' = ScremaForm (scan_lam', scan_nes) (comm, red_lam', red_nes) map_lam'
  return $ SOAC.Screma w form' arrs
finaliseSOAC (SOAC.Scatter w lam inps dests) =
  (\lam' -> SOAC.Scatter w lam' inps dests) <$> simplifyAndFuseInLambda lam
finaliseSOAC (SOAC.GenReduce w ops lam arrs) =
  (\lam' -> SOAC.GenReduce w ops lam' arrs) <$> simplifyAndFuseInLambda lam
finaliseSOAC (SOAC.Stream w form lam inps) =
  (\lam' -> SOAC.Stream w form lam' inps) <$> simplifyAndFuseInLambda lam

-- | Simplify a lambda (with no known parameter values), then run both
-- fusion phases on it in isolation.
simplifyAndFuseInLambda :: Lambda -> FusionGM Lambda
simplifyAndFuseInLambda lam = do
  lam' <- simplifyLambda lam $ map (const Nothing) (lambdaParams lam)
  (_, gathered) <- fusionGatherLam (S.empty, mkFreshFusionRes) lam'
  bindRes (cleanFusionResult gathered) $ fuseInLambda lam'
-- | Strip alias information from a fused SOAC.  For a Screma, copies are
-- inserted so that the fused SOAC does not consume what the original code
-- did not:
--
-- * input arrays consumed by the SOAC but not in @was_consumed@ are
--   replaced by fresh copies;
-- * variables consumed by the map lambda but not among its parameters are
--   copied in front of the lambda body, with uses renamed to the copies.
--
-- NOTE(review): the free-variable copies in the map lambda are not filtered
-- by @was_consumed@, and the scan/reduce lambdas are not inspected; other
-- SOAC forms get no copies at all.  Confirm this is intended.
copyNewlyConsumed :: Names
                  -> Futhark.SOAC (Aliases.Aliases SOACS)
                  -> Binder SOACS (Futhark.SOAC SOACS)
copyNewlyConsumed was_consumed soac =
  case soac of
    Futhark.Screma w (Futhark.ScremaForm
                      (scan_lam, scan_nes)
                      (comm, reduce_lam, reduce_nes)
                      map_lam) arrs -> do
      arrs' <- mapM copyConsumedArr arrs
      map_lam' <- copyFreeInLambda map_lam
      return $ Futhark.Screma w
        (Futhark.ScremaForm
          (Aliases.removeLambdaAliases scan_lam, scan_nes)
          (comm, Aliases.removeLambdaAliases reduce_lam, reduce_nes)
          map_lam') arrs'
    _ -> return $ removeOpAliases soac
  where consumed = consumedInOp soac
        newly_consumed = consumed `S.difference` was_consumed

        -- Copy an input array if it is newly consumed; otherwise use it as is.
        copyConsumedArr a
          | a `S.member` newly_consumed =
            letExp (baseString a <> "_copy") $ BasicOp $ Copy a
          | otherwise = return a

        -- Copy every free variable the lambda consumes, prepend the copies
        -- to its body, and rename uses to the copies.
        copyFreeInLambda lam = do
          let free_consumed = consumedByLambda lam `S.difference`
                              S.fromList (map paramName $ lambdaParams lam)
          (bnds, subst) <-
            foldM copyFree (mempty, mempty) $ S.toList free_consumed
          let lam' = Aliases.removeLambdaAliases lam
          return $ if null bnds
                   then lam'
                   else lam' { lambdaBody =
                                 insertStms bnds $
                                 substituteNames subst $ lambdaBody lam'
                             }

        -- Create one copy statement and record the renaming v -> v_copy.
        copyFree (bnds, subst) v = do
          v_copy <- newVName $ baseString v <> "_copy"
          copy <- mkLetNamesM [v_copy] $ BasicOp $ Copy v
          return (oneStm copy<>bnds, M.insert v v_copy subst)
-- | An empty fusion result (identical to 'mempty').
mkFreshFusionRes :: FusedRes
mkFreshFusionRes = FusedRes False M.empty M.empty S.empty M.empty

-- | Combine the results of two branches.  Like '<>', except that arrays
-- read by kernels in both results become infusible (after expanding each
-- to the names of its SOAC pattern).
mergeFusionRes :: FusedRes -> FusedRes -> FusionGM FusedRes
mergeFusionRes res1 res2 = do
  shared <- expandSoacInpArr $ M.keys $ M.intersection (inpArr res1) (inpArr res2)
  let merged_infusible = S.unions [infusible res1, infusible res2, S.fromList shared]
  return FusedRes { rsucc = rsucc res1 || rsucc res2
                  , outArr = M.union (outArr res1) (outArr res2)
                  , inpArr = M.unionWith S.union (inpArr res1) (inpArr res2)
                  , infusible = merged_infusible
                  , kernels = M.union (kernels res1) (kernels res2)
                  }

-- | Split inputs into (arrays of inputs without transforms, arrays of
-- transformed inputs).  Both lists come out in reverse input order.
getIdentArr :: [SOAC.Input] -> ([VName], [VName])
getIdentArr inps =
  ( reverse [ v | SOAC.Input ts v _ <- inps, SOAC.nullTransforms ts ]
  , reverse [ SOAC.inputArray inp
            | inp@(SOAC.Input ts _ _) <- inps, not (SOAC.nullTransforms ts) ] )

-- | Drop kernels that have no fused variables, together with all
-- references to them in 'outArr' and 'inpArr'.
cleanFusionResult :: FusedRes -> FusedRes
cleanFusionResult fres =
  fres { outArr = M.filter isLive (outArr fres)
       , inpArr = M.map (S.filter isLive) (inpArr fres)
       , kernels = live_kers
       }
  where live_kers = M.filter (not . null . fusedVars) (kernels fres)
        isLive = (`M.member` live_kers)

-- | Internal error for a SOAC that should not appear where it was found.
errorIllegal :: String -> FusionGM FusedRes
errorIllegal soac_name =
  throwError $ Error $
  concat ["In Fusion.hs, soac ", soac_name, " appears illegally in pgm!"]