module Futhark.Optimise.Fusion.Composing
( fuseMaps
, fuseRedomap
, mergeReduceOps
)
where
import Data.List
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import qualified Futhark.Analysis.HORepresentation.SOAC as SOAC
import Futhark.Representation.AST
import Futhark.Binder (Bindable(..), insertStm, insertStms, mkLet)
import Futhark.Construct (mapResult)
import Futhark.Util (splitAt3, takeLast, dropLast)
-- | Compose the lambda @lam1@ (with inputs @inp1@ and outputs @out1@)
-- with the lambda @lam2@ (with inputs @inp2@).
--
-- The returned lambda is @lam2@ with its parameters and body replaced:
--
-- * the parameters are the leading non-array parameters of @lam2@
--   followed by the keys of the merged input map from 'fuseInputs';
--
-- * the body is the body of @lam1@ whose result is rewritten (through
--   'mapResult') into statements that bind every result to the
--   identifier chosen by 'fuseInputs', followed by the body of @lam2@;
--
-- * the identifiers of the outputs listed in @unfus_nms@ are appended
--   to the body result.
--
-- Note that 'lambdaReturnType' is left exactly as it is in @lam2@.
-- The second component of the result is the list of inputs of the
-- composed lambda, in the order of the map's keys.
fuseMaps :: Bindable lore =>
            Names
         -> Lambda lore
         -> [SOAC.Input]
         -> [(VName,Ident)]
         -> Lambda lore
         -> [SOAC.Input]
         -> (Lambda lore, [SOAC.Input])
fuseMaps unfus_nms lam1 inp1 out1 lam2 inp2 =
  (fused_lam, M.elems fused_inputs)
  where
    (acc_params, unfus_idents, out_idents, fused_inputs, dedupOuter, dedupInner) =
      fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2

    -- Parameters of the composed lambda.
    fused_params =
      map (\(Ident name t) -> Param name t) (acc_params ++ M.keys fused_inputs)

    -- A single statement copying one result of lam1 into the identifier
    -- under which lam2 expects to find it.
    bindOne p e = mkLet [] [p] $ BasicOp $ SubExp e

    -- Replacement for the result of lam1: bind the results, then run
    -- the (de-duplicated) body of lam2.
    bindResults res =
      stmsFromList (zipWith bindOne out_idents res)
      `insertStms` dedupInner (lambdaBody lam2)

    composed_body = dedupOuter $ mapResult bindResults (lambdaBody lam1)

    -- Outputs that could not be fused away go at the end of the result.
    fused_body =
      composed_body
        { bodyResult = bodyResult composed_body
                       ++ map (Var . identName) unfus_idents
        }

    fused_lam =
      lam2 { lambdaParams = fused_params
           , lambdaBody = fused_body
           }
-- | Work out how the inputs of two lambdas are combined when the first
-- (@lam1@) is composed with the second (@lam2@).  The components of
-- the returned tuple are, in order:
--
-- 1. the leading parameters of @lam2@ that have no matching entry in
--    @inp2@;
--
-- 2. for every output name in @out1@ that is also in @unfus_nms@, the
--    identifier associated with it;
--
-- 3. one identifier per entry of @out1@, as chosen by
--    'filterOutParams';
--
-- 4. the merged input map: the inputs of both lambdas, minus those
--    inputs of @lam2@ that are outputs of @lam1@, with duplicates
--    removed;
--
-- 5. the body transformation produced when de-duplicating the merged
--    input map;
--
-- 6. the body transformation produced when de-duplicating the inputs
--    of @lam2@ on their own.
fuseInputs :: Bindable lore =>
              Names
           -> Lambda lore -> [SOAC.Input] -> [(VName,Ident)]
           -> Lambda lore -> [SOAC.Input]
           -> ([Ident], [Ident], [Ident],
               M.Map Ident SOAC.Input,
               Body lore -> Body lore, Body lore -> Body lore)
fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2 =
  (acc_params2, unfus_vars, out_binds, fused_inputs, dedupOuter, dedupInner)
  where
    params1 = map paramIdent $ lambdaParams lam1
    params2 = map paramIdent $ lambdaParams lam2

    -- The last @length inp2@ parameters of lam2 are paired with inp2.
    (acc_params2, arr_params2) =
      splitAt (length params2 - length inp2) params2

    inputs1 = M.fromList $ zip params1 inp1
    (inputs2, dedupInner) =
      removeDuplicateInputs $ M.fromList $ zip arr_params2 inp2

    -- Inputs of lam2 that are variable inputs naming an output of lam1.
    produced =
      outParams (map fst out1) (M.keys inputs2) (M.elems inputs2)

    out_binds = filterOutParams out1 produced

    (fused_inputs, dedupOuter) =
      removeDuplicateInputs $
      (inputs1 `M.union` inputs2) `M.difference` produced

    -- From the name of an output of lam1 to the lam2 parameter reading it.
    produced_by_name =
      M.fromList [ (nm, par)
                 | (par, inp) <- M.toList produced
                 , Just nm <- [SOAC.isVarInput inp]
                 ]

    -- The lam2 parameter takes precedence over the identifier in out1.
    elem_idents = produced_by_name `M.union` M.fromList out1

    unfusible outname
      | outname `S.member` unfus_nms = M.lookup outname elem_idents
      | otherwise = Nothing

    unfus_vars = mapMaybe (unfusible . fst) out1
-- | Pair up the given parameters with the given inputs and keep only
-- those pairs whose input is a plain variable input (according to
-- 'SOAC.isVarInput') naming one of the variables in the first
-- argument.
outParams :: [VName] -> [Ident] -> [SOAC.Input]
          -> M.Map Ident SOAC.Input
outParams out1 lam2arrparams inp2 =
  M.fromList
    [ (p, inp)
    | (p, inp) <- zip lam2arrparams inp2
    , Just a <- [SOAC.isVarInput inp]
    , a `elem` out1
    ]
-- | Choose one identifier for every entry of @out1@.  If some
-- parameter in @outins@ has a variable input naming that output, one
-- such parameter is used, and it is then no longer available to a
-- later entry with the same name.  Otherwise the identifier that
-- accompanies the output in @out1@ is used.
filterOutParams :: [(VName,Ident)]
                -> M.Map Ident SOAC.Input
                -> [Ident]
filterOutParams out1 outins =
  snd $ mapAccumL pick readers out1
  where
    -- For each variable, the parameters whose input is that variable.
    readers =
      M.fromListWith (++)
        [ (v, [p])
        | (p, inp) <- M.toList outins
        , Just v <- [SOAC.isVarInput inp]
        ]

    pick remaining (arr, fallback) =
      case M.lookup arr remaining of
        Just (p:rest) -> (M.insert arr rest remaining, p)
        _             -> (remaining, fallback)
-- | Remove parameters that are bound to an input already seen under
-- another parameter.  The map is traversed in key order; the first
-- parameter found for each distinct input is kept.  For every dropped
-- parameter, the returned function inserts into a body a statement
-- that binds the dropped parameter to the kept one.
removeDuplicateInputs :: Bindable lore =>
                         M.Map Ident SOAC.Input
                      -> (M.Map Ident SOAC.Input, Body lore -> Body lore)
removeDuplicateInputs inputs =
  fst $ M.foldlWithKey' step ((M.empty, id), M.empty) inputs
  where
    -- The accumulator is ((kept parameters, copy inserter), seen inputs).
    step ((kept, copies), seen) par arr =
      case M.lookup arr seen of
        Just first_par ->
          ((kept, copies . forward par first_par), seen)
        Nothing ->
          ( (M.insert par arr kept, copies)
          , M.insert arr (identName par) seen
          )

    -- Put a statement @to = from@ in front of the body.
    forward to from body =
      mkLet [] [to] (BasicOp $ SubExp $ Var from) `insertStm` body
-- | Compose two lambdas whose results are laid out as scan results,
-- then reduction results, then map results.  The @p_@ arguments belong
-- to the lambda that is fused into the other one, the @c_@ arguments
-- to the lambda it is fused with.  Only the lengths of the
-- neutral-element lists are used.
--
-- The work is delegated to 'fuseMaps' in three steps:
--
-- 1. The scan and reduction parts are stripped from @p_lam@: only its
--    last @length p_inparr@ parameters and its map results and types
--    are kept.
--
-- 2. The stripped lambda is fused with @c_lam@ by 'fuseMaps', after
--    dropping the scan and reduction entries from @outPairs@.
--
-- 3. The stripped parameters, results and types of @p_lam@ are put
--    back, interleaved with those of the fused lambda, and the types
--    of the map outputs named in @unfus_nms@ are appended to the
--    return type.
fuseRedomap :: Bindable lore =>
               Names -> [VName]
            -> Lambda lore -> [SubExp] -> [SubExp] -> [SOAC.Input]
            -> [(VName,Ident)]
            -> Lambda lore -> [SubExp] -> [SubExp] -> [SOAC.Input]
            -> (Lambda lore, [SOAC.Input])
fuseRedomap unfus_nms outVars p_lam p_scan_nes p_red_nes p_inparr outPairs
            c_lam c_scan_nes c_red_nes c_inparr =
  (fused_lam, fused_inps)
  where
    p_num_scans = length p_scan_nes
    p_num_reds = length p_red_nes
    p_num_accs = p_num_scans + p_num_reds
    p_num_arrs = length p_inparr
    c_num_scans = length c_scan_nes
    c_num_reds = length c_red_nes

    unfus_arrs = filter (`S.member` unfus_nms) outVars
    unfus_set = S.fromList unfus_arrs

    -- Step 1: split up p_lam and build its map-only version.
    p_body = lambdaBody p_lam
    p_params = lambdaParams p_lam
    p_ret = lambdaReturnType p_lam
    (p_scan_ts, p_red_ts, p_map_ts) =
      splitAt3 p_num_scans p_num_reds p_ret
    (p_scan_res, p_red_res, p_map_res) =
      splitAt3 p_num_scans p_num_reds $ bodyResult p_body
    p_map_lam =
      p_lam { lambdaParams = takeLast p_num_arrs p_params
            , lambdaBody = p_body { bodyResult = p_map_res }
            , lambdaReturnType = p_map_ts
            }

    -- Step 2: ordinary fusion of the map-only lambda with c_lam.
    (map_lam, fused_inps) =
      fuseMaps unfus_set p_map_lam p_inparr
               (drop p_num_accs outPairs) c_lam c_inparr
    (c_scan_ts, c_red_ts, c_map_ts) =
      splitAt3 c_num_scans c_num_reds $ lambdaReturnType map_lam
    map_body = lambdaBody map_lam
    (c_scan_res, c_red_res, c_map_res) =
      splitAt3 c_num_scans c_num_reds $ bodyResult map_body

    -- Types of the map outputs of p_lam that must still be returned.
    extra_map_ts =
      [ t
      | (nm, t) <- zip (drop p_num_accs outVars) (drop p_num_accs p_ret)
      , nm `S.member` unfus_set
      ]

    -- Step 3: reinstate what was stripped from p_lam in step 1.
    p_acc_params = dropLast p_num_arrs p_params
    fused_body =
      map_body { bodyResult = p_scan_res ++ c_scan_res ++
                              p_red_res ++ c_red_res ++
                              c_map_res
               }
    fused_lam =
      map_lam { lambdaParams = p_acc_params ++ lambdaParams map_lam
              , lambdaBody = fused_body
              , lambdaReturnType = p_scan_ts ++ c_scan_ts ++
                                   p_red_ts ++ c_red_ts ++
                                   c_map_ts ++ extra_map_ts
              }
-- | Combine two lambdas into a single one.  The statements, results
-- and return types of the first lambda come before those of the
-- second, and the body attribute is taken from the first lambda only.
--
-- Each parameter list is split after as many parameters as the lambda
-- has return types.  The combined parameter list holds the two
-- leading groups followed by the two trailing groups.
mergeReduceOps :: Lambda lore -> Lambda lore -> Lambda lore
mergeReduceOps (Lambda par1 bdy1 rtp1) (Lambda par2 bdy2 rtp2) =
  Lambda merged_params merged_body (rtp1 ++ rtp2)
  where
    (front1, back1) = splitAt (length rtp1) par1
    (front2, back2) = splitAt (length rtp2) par2
    merged_params = front1 ++ front2 ++ back1 ++ back2
    merged_body =
      Body (bodyAttr bdy1)
           (bodyStms bdy1 <> bodyStms bdy2)
           (bodyResult bdy1 ++ bodyResult bdy2)