module Futhark.Optimise.Fusion.Composing
( fuseMaps,
fuseRedomap,
)
where
import Data.List (mapAccumL)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Builder (Buildable (..), insertStm, insertStms, mkLet)
import Futhark.Construct (mapResult)
import Futhark.IR
import Futhark.Util (dropLast, splitAt3, takeLast)
-- | @fuseMaps unfus_nms lam1 inp1 out1 lam2 inp2@ fuses the producer
-- function @lam1@ into the consumer function @lam2@.
--
-- If @lam2@ has more parameters than @inp2@ has elements, the surplus
-- leading parameters have no array input and are kept unchanged at the
-- front of the fused parameter list.
--
-- The fused body runs @lam1@'s body, binds each of its results to the
-- identifier chosen for the corresponding producer output, and then
-- continues with @lam2@'s body.  One variable per producer output named
-- in @unfus_nms@ is appended to the body result.  The lambda's return
-- type is left as @lam2@'s, so callers keeping such outputs must extend
-- it themselves.
fuseMaps ::
  Buildable rep =>
  -- | Producer outputs that must still be returned after fusion.
  Names ->
  -- | Producer function.
  Lambda rep ->
  -- | Producer array inputs, matched positionally with its parameters.
  [SOAC.Input] ->
  -- | Producer outputs: the array name and an identifier that can bind a
  -- single element of that output.
  [(VName, Ident)] ->
  -- | Consumer function.
  Lambda rep ->
  -- | Consumer array inputs.
  [SOAC.Input] ->
  -- | The fused function and the array inputs it expects.
  (Lambda rep, [SOAC.Input])
fuseMaps unfus_nms lam1 inp1 out1 lam2 inp2 = (lam2', M.elems inputmap)
  where
    (lam2redparams, unfus_pat, pat, inputmap, makeCopies, makeCopiesInner) =
      fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2

    -- 'M.keys' and 'M.elems' traverse the map in the same (ascending key)
    -- order, so the new array parameters line up with the returned inputs.
    lam2' =
      lam2
        { lambdaParams =
            [ Param mempty name t
              | Ident name t <- lam2redparams ++ M.keys inputmap
            ],
          lambdaBody = new_body2'
        }

    -- lam1's body, with its results bound (certificates preserved) to
    -- 'pat', followed by lam2's body.
    new_body2 =
      let stms res =
            [ certify cs $ mkLet [p] $ BasicOp $ SubExp e
              | (p, SubExpRes cs e) <- zip pat res
            ]
          bindLambda res =
            stmsFromList (stms res) `insertStms` makeCopiesInner (lambdaBody lam2)
       in makeCopies $ mapResult bindLambda (lambdaBody lam1)

    -- Unfusible producer outputs are appended at the end of the result.
    new_body2' =
      new_body2
        { bodyResult =
            bodyResult new_body2 ++ map (varRes . identName) unfus_pat
        }
-- | Work out how the parameters and inputs of a producer function
-- (@lam1@ with inputs @inp1@ and outputs @out1@) and a consumer function
-- (@lam2@ with inputs @inp2@) combine when fused.  Returns, in order:
--
-- 1. the leading parameters of @lam2@ that have no array input;
-- 2. for every producer output named in @unfus_nms@, the identifier its
--    element is bound to in the fused body;
-- 3. for every producer output, the identifier its element is bound to:
--    a consumer parameter reading that output if there is one, otherwise
--    the identifier given in @out1@;
-- 4. the remaining deduplicated parameters of the fused function, mapped
--    to their array inputs;
-- 5. a body transformer re-binding the parameters dropped when
--    deduplicating the combined inputs;
-- 6. a body transformer re-binding the consumer parameters dropped when
--    deduplicating @lam2@'s own inputs.
fuseInputs ::
  Buildable rep =>
  Names ->
  Lambda rep ->
  [SOAC.Input] ->
  [(VName, Ident)] ->
  Lambda rep ->
  [SOAC.Input] ->
  ( [Ident],
    [Ident],
    [Ident],
    M.Map Ident SOAC.Input,
    Body rep -> Body rep,
    Body rep -> Body rep
  )
fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2 =
  (lam2redparams, unfus_vars, outstms, inputmap, makeCopies, makeCopiesInner)
  where
    lam1params = map paramIdent $ lambdaParams lam1
    lam2params = map paramIdent $ lambdaParams lam2

    -- Surplus leading parameters of lam2 have no corresponding array input.
    (lam2redparams, lam2arrparams) =
      splitAt (length lam2params - length inp2) lam2params

    lam1inputmap = M.fromList $ zip lam1params inp1
    lam2inputmap = M.fromList $ zip lam2arrparams inp2
    (lam2inputmap', makeCopiesInner) = removeDuplicateInputs lam2inputmap

    -- Consumer parameters whose input is one of the producer's outputs;
    -- these are fed from lam1's results rather than from an array input.
    outins =
      uncurry (outParams $ map fst out1) $ unzip $ M.toList lam2inputmap'
    outstms = filterOutParams out1 outins

    -- 'M.union' is left-biased: lam1's entry wins if an Ident is in both.
    originputmap = lam1inputmap `M.union` lam2inputmap'
    (inputmap, makeCopies) =
      removeDuplicateInputs $ originputmap `M.difference` outins

    -- 'outParams' only keeps variable inputs, so 'Nothing' is not
    -- expected here.
    getVarParPair (par, inp) =
      fmap (\nm -> (nm, par)) (SOAC.isVarInput inp)
    outinsrev = M.fromList $ mapMaybe getVarParPair $ M.toList outins

    -- Where each producer output is found in the fused body.  A consumer
    -- parameter reading it takes precedence over the producer's own
    -- element identifier.  Bound once here instead of being rebuilt on
    -- every call to 'unfusible'.
    outnamemap = outinsrev `M.union` M.fromList out1

    unfusible outname
      | outname `nameIn` unfus_nms = M.lookup outname outnamemap
      | otherwise = Nothing
    unfus_vars = mapMaybe (unfusible . fst) out1
-- | Pair consumer array parameters with their inputs and keep only the
-- pairs whose input is a variable named in @out1@.  These are the
-- consumer parameters that read an array produced by the function being
-- fused in.
outParams ::
  [VName] ->
  [Ident] ->
  [SOAC.Input] ->
  M.Map Ident SOAC.Input
outParams out1 lam2arrparams inp2 =
  M.fromList $ filter (readsOutput . snd) $ zip lam2arrparams inp2
  where
    -- Non-variable inputs can never be producer outputs.
    readsOutput inp = maybe False (`elem` out1) (SOAC.isVarInput inp)
-- | For each producer output @(name, ident)@ in @out1@, choose the
-- identifier its element is bound to.  If @outins@ has a not-yet-used
-- parameter whose input is the variable @name@, that parameter is
-- chosen; otherwise the choice is @ident@.  Each parameter of @outins@
-- is handed out at most once.
filterOutParams ::
  [(VName, Ident)] ->
  M.Map Ident SOAC.Input ->
  [Ident]
filterOutParams out1 outins =
  snd $ mapAccumL checkUsed outUsage out1
  where
    -- Variable input name -> parameters reading it.  'M.insertWith (++)'
    -- prepends, so parameters later in key order are handed out first.
    outUsage = M.foldlWithKey' add M.empty outins
    add m p inp =
      case SOAC.isVarInput inp of
        Just v -> M.insertWith (++) v [p] m
        Nothing -> m

    -- Take the next unused parameter for this output, removing it from
    -- the pool, or fall back to the output's own identifier.
    checkUsed m (a, ra) =
      case M.lookup a m of
        Just (p : ps) -> (M.insert a ps m, p)
        _ -> (m, ra)
-- | Collapse parameters that share the same input.
--
-- The map is folded in ascending key order.  The first parameter seen
-- for an input is kept; every later parameter with the same input is
-- dropped from the returned map.  The returned function inserts a
-- binding @dropped = kept@ into a body for each dropped parameter, so
-- its name remains defined.
removeDuplicateInputs ::
  Buildable rep =>
  M.Map Ident SOAC.Input ->
  (M.Map Ident SOAC.Input, Body rep -> Body rep)
removeDuplicateInputs = fst . M.foldlWithKey' comb ((M.empty, id), M.empty)
  where
    -- Accumulator: ((kept parameters, pending re-bindings),
    --               input -> name of the parameter kept for that input).
    comb ((parmap, inner), arrmap) par arr =
      case M.lookup arr arrmap of
        Nothing ->
          ( (M.insert par arr parmap, inner),
            M.insert arr (identName par) arrmap
          )
        Just par' ->
          ((parmap, inner . forward par par'), arrmap)

    -- Bind @to@ as a copy of the variable @from@ in the given body.
    forward to from b =
      mkLet [to] (BasicOp $ SubExp $ Var from) `insertStm` b
-- | Fuse a producer into a consumer when both functions return scan
-- results, then reduce results, then map results, in that order.
--
-- @p_lam@ returns @length p_scan_nes@ scan results, then
-- @length p_red_nes@ reduce results, then its map results.  @c_lam@ is
-- laid out the same way with @c_scan_nes@ and @c_red_nes@.  @outVars@
-- and @outPairs@ list the producer's outputs in that same order.
-- Producer outputs named in @unfus_nms@ remain results of the fused
-- function.
--
-- The map part of the producer is fused into @c_lam@ with 'fuseMaps'.
--
-- The parameters of the resulting function are the parameters of
-- @p_lam@ before its last @length p_inparr@ ones, followed by those
-- produced by 'fuseMaps'.  Its results are ordered:
--
-- 1. producer scan,
-- 2. consumer scan,
-- 3. producer reduce,
-- 4. consumer reduce,
-- 5. consumer map,
-- 6. the producer map outputs kept because of @unfus_nms@.
--
-- The array inputs of the fused function are returned alongside it.
fuseRedomap ::
  Buildable rep =>
  Names ->
  [VName] ->
  Lambda rep ->
  [SubExp] ->
  [SubExp] ->
  [SOAC.Input] ->
  [(VName, Ident)] ->
  Lambda rep ->
  [SubExp] ->
  [SubExp] ->
  [SOAC.Input] ->
  (Lambda rep, [SOAC.Input])
fuseRedomap
  unfus_nms
  outVars
  p_lam
  p_scan_nes
  p_red_nes
  p_inparr
  outPairs
  c_lam
  c_scan_nes
  c_red_nes
  c_inparr = (res_lam', new_inp)
    where
      p_num_nes = length p_scan_nes + length p_red_nes
      unfus_arrs = filter (`nameIn` unfus_nms) outVars

      p_lam_body = lambdaBody p_lam
      (p_lam_scan_ts, p_lam_red_ts, p_lam_map_ts) =
        splitAt3 (length p_scan_nes) (length p_red_nes) $ lambdaReturnType p_lam
      (p_lam_scan_res, p_lam_red_res, p_lam_map_res) =
        splitAt3 (length p_scan_nes) (length p_red_nes) $ bodyResult p_lam_body

      -- The producer restricted to its map part: only its last
      -- (length p_inparr) parameters and only its map results.
      p_lam_hacked =
        p_lam
          { lambdaParams = takeLast (length p_inparr) $ lambdaParams p_lam,
            lambdaBody = p_lam_body {bodyResult = p_lam_map_res},
            lambdaReturnType = p_lam_map_ts
          }

      (res_lam, new_inp) =
        fuseMaps
          (namesFromList unfus_arrs)
          p_lam_hacked
          p_inparr
          (drop p_num_nes outPairs)
          c_lam
          c_inparr

      -- 'fuseMaps' keeps the consumer's return type, so this splits the
      -- consumer's scan, reduce and map return types.
      (res_lam_scan_ts, res_lam_red_ts, res_lam_map_ts) =
        splitAt3 (length c_scan_nes) (length c_red_nes) $ lambdaReturnType res_lam

      -- Types of the producer map outputs that stay results.  Each
      -- candidate name comes from outVars, so testing unfus_nms directly
      -- is the same as testing membership in unfus_arrs, without the
      -- quadratic list search.
      extra_map_ts =
        [ t
          | (nm, t) <- zip (drop p_num_nes outVars) p_lam_map_ts,
            nm `nameIn` unfus_nms
        ]

      -- Producer parameters that precede its array parameters.
      accpars = dropLast (length p_inparr) $ lambdaParams p_lam

      res_body = lambdaBody res_lam
      (res_lam_scan_res, res_lam_red_res, res_lam_map_res) =
        splitAt3 (length c_scan_nes) (length c_red_nes) $ bodyResult res_body

      res_lam' =
        res_lam
          { lambdaParams = accpars ++ lambdaParams res_lam,
            lambdaBody =
              res_body
                { bodyResult =
                    p_lam_scan_res
                      ++ res_lam_scan_res
                      ++ p_lam_red_res
                      ++ res_lam_red_res
                      ++ res_lam_map_res
                },
            lambdaReturnType =
              p_lam_scan_ts
                ++ res_lam_scan_ts
                ++ p_lam_red_ts
                ++ res_lam_red_ts
                ++ res_lam_map_ts
                ++ extra_map_ts
          }