{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.Kernels.Simplify
( simplifyKernels
, simplifyLambda
, simplifyKernelOp
, simplifyKernelExp
)
where
import Control.Monad
import Data.Either
import Data.Foldable
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Representation.Kernels
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Lore
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass
import qualified Futhark.Optimise.Simplify as Simplify
import Futhark.Optimise.Simplify.Rule
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Analysis.Rephrase (castStm)
-- | Simplification operations for the 'Kernels' lore.  Operations are
-- handled by 'simplifyKernelOp', which is given 'simpleInKernel' for
-- constructing the operations used inside kernels and 'inKernelEnv' as
-- the environment for them.
simpleKernels :: Simplify.SimpleOps Kernels
simpleKernels =
  Simplify.bindableSimpleOps $
  simplifyKernelOp simpleInKernel inKernelEnv
-- | Simplification operations for the 'InKernel' lore, given the space
-- of the enclosing kernel.  Operations are handled by
-- 'simplifyKernelExp' applied to that space.
simpleInKernel :: KernelSpace -> Simplify.SimpleOps InKernel
simpleInKernel kspace =
  Simplify.bindableSimpleOps (simplifyKernelExp kspace)
-- | Simplify an entire 'Kernels' program using 'simpleKernels' and
-- 'kernelRules', with no extra hoist blockers.
simplifyKernels :: Prog Kernels -> PassM (Prog Kernels)
simplifyKernels prog =
  Simplify.simplifyProg
    simpleKernels
    kernelRules
    Simplify.noExtraHoistBlockers
    prog
-- | Simplify a single 'InKernel' lambda.  The 'KernelSpace' is used to
-- build the simplification operations (via 'simpleInKernel'); the rules
-- applied are 'inKernelRules' and no extra hoist blockers are used.  The
-- lambda and the @[Maybe VName]@ argument are passed on unchanged to
-- 'Simplify.simplifyLambda'.
simplifyLambda :: (HasScope InKernel m, MonadFreshNames m) =>
                  KernelSpace -> Lambda InKernel -> [Maybe VName] -> m (Lambda InKernel)
simplifyLambda kspace lam arrs =
  Simplify.simplifyLambda ops inKernelRules Engine.noExtraHoistBlockers lam arrs
  where ops = simpleInKernel kspace
-- | Simplify a kernel-level operation whose bodies are in lore @lore@,
-- while running in the simplifier for @outerlore@.  The result is the
-- simplified operation together with the statements returned as hoisted
-- by the nested simplification, each converted to @outerlore@ with
-- 'processHoistedStm'.
--
-- The first argument builds the simplification operations for the inner
-- lore from the kernel space; the second is the environment handed to
-- 'Engine.subSimpleM'.
simplifyKernelOp :: (Engine.SimplifiableLore lore,
                     Engine.SimplifiableLore outerlore,
                     BodyAttr outerlore ~ (), BodyAttr lore ~ (),
                     ExpAttr lore ~ ExpAttr outerlore,
                     SameScope lore outerlore,
                     RetType lore ~ RetType outerlore,
                     BranchType lore ~ BranchType outerlore) =>
                    (KernelSpace -> Engine.SimpleOps lore) -> Engine.Env lore
                 -> Kernel lore -> Engine.SimpleM outerlore (Kernel (Wise lore), Stms (Wise outerlore))

simplifyKernelOp mk_ops env (Kernel desc space ts kbody) = do
  space' <- Engine.simplify space
  ts' <- mapM Engine.simplify ts
  vtable <- Engine.askVtable

  -- The body is simplified through 'Engine.subSimpleM', with operations
  -- built from the original (not yet simplified) kernel space.
  ((stms, res), hoisted) <-
    Engine.subSimpleM (mk_ops space) env vtable $ do
      par_blocker <- Engine.asksEngineEnv $
                     Engine.blockHoistPar . Engine.envHoistBlockers
      -- Predicate handed to 'Engine.blockIf'; among other things it
      -- covers anything with a free variable bound by the kernel space.
      let blocker = Engine.hasFree bound_here
                    `Engine.orIf` Engine.isOp
                    `Engine.orIf` par_blocker
                    `Engine.orIf` Engine.isConsumed
      Engine.localVtable (<>scope_vtable) $
        Engine.blockIf blocker $
        simplifyKernelBodyM kbody

  hoisted' <- mapM processHoistedStm hoisted
  let kbody' = mkWiseKernelBody () stms res
  return (Kernel desc space' ts' kbody', hoisted')
  where space_scope = scopeOfKernelSpace space
        scope_vtable = ST.fromScope space_scope
        bound_here = S.fromList $ M.keys space_scope

simplifyKernelOp mk_ops env (SegRed space comm red_op nes ts body) = do
  space' <- Engine.simplify space
  nes' <- mapM Engine.simplify nes
  ts' <- mapM Engine.simplify ts
  vtable <- Engine.askVtable

  -- First the reduction operator; it is given two 'Nothing' entries per
  -- neutral element.
  (red_op', op_hoisted) <-
    Engine.subSimpleM (mk_ops space) env vtable $
    Engine.localVtable (<>scope_vtable) $
    Engine.simplifyLambda red_op $ replicate (length nes * 2) Nothing
  op_hoisted' <- mapM processHoistedStm op_hoisted

  -- Then the body, with the same hoisting predicate as used for
  -- ordinary kernels.
  ((stms, res), body_hoisted) <-
    Engine.subSimpleM (mk_ops space) env vtable $ do
      par_blocker <- Engine.asksEngineEnv $
                     Engine.blockHoistPar . Engine.envHoistBlockers
      let blocker = Engine.hasFree bound_here
                    `Engine.orIf` Engine.isOp
                    `Engine.orIf` par_blocker
                    `Engine.orIf` Engine.isConsumed
      Engine.localVtable (<>scope_vtable) $
        Engine.blockIf blocker $
        Engine.simplifyBody (replicate (length ts) Observe) body
  body_hoisted' <- mapM processHoistedStm body_hoisted

  let body' = mkWiseBody () stms res
  return (SegRed space' comm red_op' nes' ts' body',
          op_hoisted' <> body_hoisted')
  where space_scope = scopeOfKernelSpace space
        scope_vtable = ST.fromScope space_scope
        bound_here = S.fromList $ M.keys space_scope

-- The size-related operations contain no bodies; at most a single
-- operand is simplified, and nothing is hoisted.
simplifyKernelOp _ _ (GetSize key size_class) =
  pure (GetSize key size_class, mempty)

simplifyKernelOp _ _ (GetSizeMax size_class) =
  pure (GetSizeMax size_class, mempty)

simplifyKernelOp _ _ (CmpSizeLe key size_class x) = do
  x' <- Engine.simplify x
  pure (CmpSizeLe key size_class x', mempty)
-- | Convert a statement from one lore to another using 'castStm'.  If
-- the cast is not possible, 'fail' is invoked with a message containing
-- the pretty-printed statement.
processHoistedStm :: (Monad m,
                      PrettyLore from,
                      ExpAttr from ~ ExpAttr to,
                      BodyAttr from ~ BodyAttr to,
                      RetType from ~ RetType to,
                      BranchType from ~ BranchType to,
                      LetAttr from ~ LetAttr to,
                      FParamAttr from ~ FParamAttr to,
                      LParamAttr from ~ LParamAttr to) =>
                     Stm from -> m (Stm to)
processHoistedStm stm =
  case castStm stm of
    Just stm' -> return stm'
    Nothing   -> fail $ "Cannot hoist binding: " ++ pretty stm
-- | Construct a 'KernelBody' in the 'Wise' lore.  The body attribute is
-- taken from what 'mkWiseBody' computes for the same statements, using
-- one 'SubExp' per kernel result as the body result.
mkWiseKernelBody :: (Attributes lore, CanBeWise (Op lore)) =>
                    BodyAttr lore -> Stms (Wise lore) -> [KernelResult] -> KernelBody (Wise lore)
mkWiseKernelBody attr stms res = KernelBody wise_attr stms res
  where -- Only the attribute of the ordinary body is needed.
        Body wise_attr _ _ = mkWiseBody attr stms $ map resultSubExp res

        -- The operand that each kind of kernel result refers to.
        resultSubExp kres =
          case kres of
            ThreadsReturn _ se        -> se
            WriteReturn _ arr _       -> Var arr
            ConcatReturns _ _ _ _ v   -> Var v
            KernelInPlaceReturn v     -> Var v
-- | The simplifier environment used for 'InKernel' code: an empty
-- environment with 'inKernelRules' and no extra hoist blockers.
inKernelEnv :: Engine.Env InKernel
inKernelEnv =
  Engine.emptyEnv
    inKernelRules
    Simplify.noExtraHoistBlockers
-- | Only the stride of a strided split has anything to simplify.
instance Engine.Simplifiable SplitOrdering where
  simplify ordering =
    case ordering of
      SplitContiguous     -> return SplitContiguous
      SplitStrided stride -> SplitStrided <$> Engine.simplify stride
-- | Simplifies the scatter component first, then the second component
-- of every element in the combine dimensions.
instance Engine.Simplifiable CombineSpace where
  simplify (CombineSpace scatter cspace) = do
    scatter' <- mapM Engine.simplify scatter
    cspace' <- mapM (traverse Engine.simplify) cspace
    return $ CombineSpace scatter' cspace'
-- | Simplify an in-kernel operation, given the space of the enclosing
-- kernel.  Returns the simplified operation and the statements hoisted
-- out of any body or lambda it contains.
simplifyKernelExp :: Engine.SimplifiableLore lore =>
                     KernelSpace -> KernelExp lore
                  -> Engine.SimpleM lore (KernelExp (Wise lore), Stms (Wise lore))

simplifyKernelExp _ (Barrier se) = do
  se' <- Engine.simplify se
  return (Barrier se', mempty)

simplifyKernelExp _ (SplitSpace o w i elems_per_thread) = do
  o' <- Engine.simplify o
  w' <- Engine.simplify w
  i' <- Engine.simplify i
  elems_per_thread' <- Engine.simplify elems_per_thread
  return (SplitSpace o' w' i' elems_per_thread', mempty)

simplifyKernelExp kspace (Combine cspace ts active body) = do
  -- The body is simplified before the remaining components.
  ((body_stms, body_res), hoisted) <-
    wrapbody $
    Engine.blockIf (Engine.hasFree bound_here `Engine.orIf` maybeBlockUnsafe) $
    localScope (scopeOfCombineSpace cspace) $
    Engine.simplifyBody (Observe <$ ts) body
  body' <- Engine.constructBody body_stms body_res

  cspace' <- Engine.simplify cspace
  ts' <- mapM Engine.simplify ts
  active' <- mapM Engine.simplify active
  return (Combine cspace' ts' active' body', hoisted)
  where bound_here = S.fromList $ M.keys $ scopeOfCombineSpace cspace

        -- Re-emit the hoisted statements.  If any of them is not a safe
        -- expression, the activity condition is computed and every
        -- statement goes through 'Engine.protectIf'; otherwise the
        -- statements are added unchanged.
        protectCombineHoisted checkIfActive m = do
          (x, stms) <- m
          runBinder $ do
            if all (safeExp . stmExp) stms
              then addStms stms
              else do is_active <- checkIfActive
                      mapM_ (Engine.protectIf (not . safeExp) is_active) stms
            return x

        -- With exactly one combine dimension that equals the group size
        -- of the kernel, unsafe statements are not blocked but the
        -- hoisted ones are protected afterwards.  In every other case
        -- unsafe statements are blocked and nothing further is done.
        (maybeBlockUnsafe, wrapbody)
          | [d] <- map snd $ cspaceDims cspace,
            d == spaceGroupSize kspace =
              (Engine.isFalse True,
               protectCombineHoisted $
                 letSubExp "active" =<<
                 foldBinOp LogAnd (constant True) =<<
                 mapM (uncurry check) active)
          | otherwise =
              (Engine.isNotSafe, id)

        check v se =
          letSubExp "is_active" $ BasicOp $ CmpOp (CmpSlt Int32) (Var v) se

simplifyKernelExp _ (GroupReduce w lam input) = do
  arrs' <- mapM Engine.simplify arrs
  nes' <- mapM Engine.simplify nes
  w' <- Engine.simplify w
  (lam', hoisted) <- Engine.simplifyLambdaSeq lam (Nothing <$ arrs')
  return (GroupReduce w' lam' (zip nes' arrs'), hoisted)
  where (nes, arrs) = unzip input

simplifyKernelExp _ (GroupScan w lam input) = do
  w' <- Engine.simplify w
  nes' <- mapM Engine.simplify nes
  arrs' <- mapM Engine.simplify arrs
  (lam', hoisted) <- Engine.simplifyLambdaSeq lam (Nothing <$ arrs')
  return (GroupScan w' lam' (zip nes' arrs'), hoisted)
  where (nes, arrs) = unzip input

simplifyKernelExp _ (GroupGenReduce w dests op bucket vs locks) = do
  w' <- Engine.simplify w
  dests' <- mapM Engine.simplify dests
  (op', hoisted) <- Engine.simplifyLambdaSeq op (Nothing <$ vs)
  bucket' <- Engine.simplify bucket
  vs' <- mapM Engine.simplify vs
  locks' <- Engine.simplify locks
  return (GroupGenReduce w' dests' op' bucket' vs' locks', hoisted)

simplifyKernelExp _ (GroupStream w maxchunk lam accs arrs) = do
  w' <- Engine.simplify w
  maxchunk' <- Engine.simplify maxchunk
  accs' <- mapM Engine.simplify accs
  arrs' <- mapM Engine.simplify arrs
  (lam', hoisted) <- simplifyGroupStreamLambda lam w' maxchunk' arrs'
  return (GroupStream w' maxchunk' lam' accs' arrs', hoisted)
-- | Simplify the statements and results of a kernel body.  The usage
-- information attached to the result is derived from the free variables
-- of the simplified results.
simplifyKernelBodyM :: Engine.SimplifiableLore lore =>
                       KernelBody lore
                    -> Engine.SimpleM lore (Engine.SimplifiedBody lore [KernelResult])
simplifyKernelBodyM (KernelBody _ stms res) =
  Engine.simplifyStms stms $ do
    res' <- mapM Engine.simplify res
    let res_usage = UT.usages $ freeIn res'
    return ((res', res_usage), mempty)
-- | Simplify the lambda of a group stream.  The 'SubExp' arguments are
-- the stream width and the maximum chunk size (used as the bounds of
-- the chunk offset and chunk size loop variables, respectively), and the
-- 'VName's are the arrays that the array parameters are chunks of.
-- Returns the simplified lambda and the statements hoisted out of it.
simplifyGroupStreamLambda :: Engine.SimplifiableLore lore =>
                             GroupStreamLambda lore
                          -> SubExp -> SubExp -> [VName]
                          -> Engine.SimpleM lore (GroupStreamLambda (Wise lore), Stms (Wise lore))
simplifyGroupStreamLambda lam w max_chunk arrs = do
  ((stms, res), hoisted) <-
    Engine.enterLoop $
    Engine.bindLoopVar chunk_size Int32 max_chunk $
    Engine.bindLoopVar chunk_offset Int32 w $
    Engine.bindLParams acc_params $
    Engine.bindChunkLParams chunk_offset (zip arr_params arrs) $
    Engine.blockIf (Engine.hasFree bound_here `Engine.orIf` Engine.isConsumed) $
    Engine.simplifyBody (replicate (length (bodyResult body)) Observe) body

  acc_params' <- mapM (Engine.simplifyParam Engine.simplify) acc_params
  arr_params' <- mapM (Engine.simplifyParam Engine.simplify) arr_params
  body' <- Engine.constructBody stms res
  return (GroupStreamLambda chunk_size chunk_offset acc_params' arr_params' body',
          hoisted)
  where GroupStreamLambda chunk_size chunk_offset acc_params arr_params body = lam

        -- Every name that the lambda itself brings into scope.
        bound_here = S.fromList $
                     chunk_size : chunk_offset :
                     map paramName (acc_params ++ arr_params)
-- | The three thread/group identifiers are kept as they are; the sizes
-- and the space structure are simplified.
instance Engine.Simplifiable KernelSpace where
  simplify (KernelSpace gtid ltid gid num_threads num_groups group_size structure) = do
    num_threads' <- Engine.simplify num_threads
    num_groups' <- Engine.simplify num_groups
    group_size' <- Engine.simplify group_size
    structure' <- Engine.simplify structure
    return $ KernelSpace gtid ltid gid num_threads' num_groups' group_size' structure'
-- | Only the dimension sizes are simplified; the index names are
-- retained unchanged.
instance Engine.Simplifiable SpaceStructure where
  simplify (FlatThreadSpace dims) = do
    gdims' <- mapM Engine.simplify gdims
    return $ FlatThreadSpace $ zip gtids gdims'
    where (gtids, gdims) = unzip dims

  simplify (NestedThreadSpace dims) = do
    gdims' <- mapM Engine.simplify gdims
    ldims' <- mapM Engine.simplify ldims
    return $ NestedThreadSpace $ zip4 gtids gdims' ltids ldims'
    where (gtids, gdims, ltids, ldims) = unzip4 dims
-- | Every component of a kernel result is simplified, in field order.
instance Engine.Simplifiable KernelResult where
  simplify (ThreadsReturn threads what) = do
    threads' <- Engine.simplify threads
    what' <- Engine.simplify what
    return $ ThreadsReturn threads' what'

  simplify (WriteReturn ws a res) = do
    ws' <- Engine.simplify ws
    a' <- Engine.simplify a
    res' <- Engine.simplify res
    return $ WriteReturn ws' a' res'

  simplify (ConcatReturns o w pte moffset what) = do
    o' <- Engine.simplify o
    w' <- Engine.simplify w
    pte' <- Engine.simplify pte
    moffset' <- Engine.simplify moffset
    what' <- Engine.simplify what
    return $ ConcatReturns o' w' pte' moffset' what'

  simplify (KernelInPlaceReturn what) = do
    what' <- Engine.simplify what
    return $ KernelInPlaceReturn what'
-- | Only 'ThreadsPerGroup' carries anything to simplify.
instance Engine.Simplifiable WhichThreads where
  simplify which =
    case which of
      AllThreads            -> pure AllThreads
      OneResultPerGroup     -> pure OneResultPerGroup
      ThreadsInSpace        -> pure ThreadsInSpace
      ThreadsPerGroup limit -> ThreadsPerGroup <$> mapM Engine.simplify limit
-- | All binder operations for the wise 'Kernels' lore use the
-- @bindable@ implementations.
instance BinderOps (Wise Kernels) where
  mkLetNamesB = bindableMkLetNamesB
  mkBodyB     = bindableMkBodyB
  mkExpAttrB  = bindableMkExpAttrB
-- | All binder operations for the wise 'InKernel' lore use the
-- @bindable@ implementations.
instance BinderOps (Wise InKernel) where
  mkLetNamesB = bindableMkLetNamesB
  mkBodyB     = bindableMkBodyB
  mkExpAttrB  = bindableMkExpAttrB
-- | The rules used when simplifying 'Kernels' programs: the standard
-- rules plus one top-down and two bottom-up kernel-specific rules.
kernelRules :: RuleBook (Wise Kernels)
kernelRules = standardRules <> ruleBook topDownRules bottomUpRules
  where topDownRules =
          [ RuleOp removeInvariantKernelResults ]
        bottomUpRules =
          [ RuleOp distributeKernelResults
          , RuleBasicOp removeUnnecessaryCopy
          ]
-- | If exactly one of the input arrays of a group stream is known to be
-- produced by an 'Iota', remove it as an input and instead compute the
-- corresponding chunk at the top of the lambda body, as an 'Iota' of
-- the chunk size starting at @chunk_offset * stride + start@.  The rule
-- does not apply when zero or several inputs are iotas.
fuseStreamIota :: TopDownRuleOp (Wise InKernel)
fuseStreamIota vtable pat _ (GroupStream w max_chunk lam accs arrs)
  | ([(cs, iota_p, first, stride, it)], others) <-
      partitionEithers $ zipWith (isIota vtable) (groupStreamArrParams lam) arrs = do
      let (other_params, other_arrs) = unzip others
          chunk_size = groupStreamChunkSize lam
          chunk_offset = groupStreamChunkOffset lam

      body' <- insertStmsM $ inScopeOf lam $ certifying cs $ do
        -- The offset is first converted to the integer type of the iota.
        offset_t <- asIntS it $ Var chunk_offset
        scaled <- letSubExp "offset_by_stride" $
                  BasicOp $ BinOp (Mul it) offset_t stride
        chunk_first <- letSubExp "iota_start" $
                       BasicOp $ BinOp (Add it) scaled first
        letBindNames_ [paramName iota_p] $
          BasicOp $ Iota (Var chunk_size) chunk_first stride it
        return $ groupStreamLambdaBody lam

      let lam' = lam { groupStreamArrParams = other_params
                     , groupStreamLambdaBody = body'
                     }
      letBind_ pat $ Op $ GroupStream w max_chunk lam' accs other_arrs
fuseStreamIota _ _ _ _ = cannotSimplify
-- | Classify an array paired with some value.  If the symbol table
-- binds the array to an 'Iota', the result is 'Left' with the
-- certificates, the value, and the second, third and fourth operands of
-- the 'Iota'.  Otherwise the result is 'Right' with the value and the
-- array name.
isIota :: ST.SymbolTable lore -> a -> VName
       -> Either (Certificates, a, SubExp, SubExp, IntType) (a, VName)
isIota vtable chunk arr =
  case ST.lookupExp arr vtable of
    Just (BasicOp (Iota _ x s it), cs) -> Left (cs, chunk, x, s, it)
    _                                  -> Right (chunk, arr)
-- | Remove kernel results that are either constants or variables found
-- in the symbol table passed to the rule.  Such a result is removed
-- when it is returned by 'AllThreads' or 'ThreadsInSpace', and its
-- pattern element is bound to a replicate instead.  All other results
-- are kept.
removeInvariantKernelResults :: TopDownRuleOp (Wise Kernels)
removeInvariantKernelResults vtable (Pattern [] kpes) attr
                             (Kernel desc space ts (KernelBody _ kstms kres)) = do
  -- 'keepResult' binds the replacement as a side effect whenever it
  -- decides to drop a result.
  (ts', kpes', kres') <-
    unzip3 <$> filterM keepResult (zip3 ts kpes kres)

  -- If no result was removed, the rule did not apply.
  when (kres == kres')
    cannotSimplify

  let kbody' = mkWiseKernelBody () kstms kres'
  addStm $ Let (Pattern [] kpes') attr $ Op $ Kernel desc space ts' kbody'
  where
    isInvariant se =
      case se of
        Constant{} -> True
        Var v      -> isJust $ ST.lookup v vtable

    num_threads = spaceNumThreads space
    space_dims = map snd $ spaceDimensions space

    -- Wrap the expression in one more replicate of size @d@.
    replicateBy e d = do
      v <- letSubExp "rep" e
      return $ BasicOp $ Replicate (Shape [d]) v

    keepResult (_, pe, ThreadsReturn threads se)
      | isInvariant se =
          case threads of
            AllThreads -> do
              letBindNames_ [patElemName pe] $
                BasicOp $ Replicate (Shape [num_threads]) se
              return False
            ThreadsInSpace -> do
              -- Replicate from the last space dimension to the first.
              e <- foldM replicateBy (BasicOp (SubExp se)) (reverse space_dims)
              letBindNames_ [patElemName pe] e
              return False
            _ ->
              return True
    keepResult _ =
      return True
removeInvariantKernelResults _ _ _ _ = cannotSimplify
-- | Move certain array indexing statements out of kernels.  A statement
-- qualifies when it indexes an array with the thread indices of the
-- kernel space as the leading indices, when the array and the remaining
-- indices consist only of names present in the symbol table passed to
-- the rule, and when its result is returned by exactly one
-- @ThreadsReturn ThreadsInSpace@ kernel result.  That result is then
-- removed from the kernel and replaced by a slice of the array, bound
-- outside the kernel.
distributeKernelResults :: BottomUpRuleOp (Wise Kernels)
distributeKernelResults (vtable, used)
  (Pattern [] kpes) attr (Kernel desc kspace kts (KernelBody _ kstms kres)) = do
  -- Go through the statements of the kernel in order; the ones that
  -- remain in the kernel are accumulated in reverse.
  (kpes', kts', kres', kstms_rev) <-
    localScope (scopeOfKernelSpace kspace) $
    foldM distribute (kpes, kts, kres, []) kstms

  -- If the pattern is unchanged, nothing was moved.
  when (kpes' == kpes)
    cannotSimplify

  let kbody' = mkWiseKernelBody () (stmsFromList $ reverse kstms_rev) kres'
  addStm $ Let (Pattern [] kpes') attr $ Op $ Kernel desc kspace kts' kbody'
  where
    free_in_kstms = foldMap freeInStm kstms
    space_dims = spaceDimensions kspace
    isKnown v = isJust $ ST.lookup v vtable

    distribute (kpes', kts', kres', kstms_rev) stm
      | Let (Pattern [] [pe]) _ (BasicOp (Index arr slice)) <- stm,
        kspace_slice <- map (DimFix . Var . fst) space_dims,
        kspace_slice `isPrefixOf` slice,
        remaining_slice <- drop (length kspace_slice) slice,
        all isKnown $ S.toList $ freeIn arr <> freeIn remaining_slice,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
          let outer_slice =
                [ DimSlice (constant (0::Int32)) d (constant (1::Int32))
                | (_, d) <- space_dims ]
              index kpe' =
                letBind_ (Pattern [] [kpe']) $
                BasicOp $ Index arr $ outer_slice <> remaining_slice

          -- If the kernel output is consumed, the slice is bound under
          -- a fresh name and the output becomes a copy of it.
          if patElemName kpe `UT.isConsumed` used
            then do precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
                    index kpe { patElemName = precopy }
                    letBind_ (Pattern [] [kpe]) $ BasicOp $ Copy precopy
            else index kpe

          -- The statement stays in the kernel only if its result is
          -- free in some statement of the kernel.
          let kstms_rev'
                | patElemName pe `S.member` free_in_kstms = stm : kstms_rev
                | otherwise = kstms_rev
          return (kpes'', kts'', kres'', kstms_rev')

    distribute (kpes', kts', kres', kstms_rev) stm =
      return (kpes', kts', kres', stm : kstms_rev)

    -- Succeeds only if exactly one result returns the given pattern
    -- element; the pattern element of that result is returned along
    -- with everything that is left.
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe, _, _)], rest) ->
          let (kpes'', kts'', kres'') = unzip3 rest
          in Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where wanted = ThreadsReturn ThreadsInSpace (Var $ patElemName pe)
            matches (_, _, kre) = kre == wanted
distributeKernelResults _ _ _ _ = cannotSimplify
-- | Inline a group stream whose width is a constant satisfying
-- 'oneIsh'.  The chunk size is bound to 1 and the chunk offset to 0,
-- the accumulator parameters are bound to the initial accumulators, the
-- array parameters to slices of the input arrays, and the body is then
-- inserted directly, with its results bound to the pattern.
simplifyKnownIterationStream :: TopDownRuleOp (Wise InKernel)
simplifyKnownIterationStream _ pat _ (GroupStream (Constant v) _ lam accs arrs)
  | oneIsh v = do
      let GroupStreamLambda chunk_size chunk_offset acc_params arr_params body = lam
          bindTo name e = letBindNames_ [name] $ BasicOp e
          chunk_slice = DimSlice (Var chunk_offset) (Var chunk_size)
                                 (constant (1::Int32))

      bindTo chunk_size $ SubExp $ constant (1::Int32)
      bindTo chunk_offset $ SubExp $ constant (0::Int32)

      zipWithM_ (\p a -> bindTo (paramName p) $ SubExp a) acc_params accs

      zipWithM_ (\p a -> bindTo (paramName p) $
                         Index a $ fullSlice (paramType p) [chunk_slice])
                arr_params arrs

      res <- bodyBind body
      zipWithM_ (\pe r -> bindTo (patElemName pe) $ SubExp r)
                (patternElements pat) res
simplifyKnownIterationStream _ _ _ _ = cannotSimplify
-- | Remove the array inputs of a group stream whose corresponding
-- parameter is not free in the body of the stream lambda.  The rule
-- does not apply if every array parameter is used.
removeUnusedStreamInputs :: TopDownRuleOp (Wise InKernel)
removeUnusedStreamInputs _ pat _ (GroupStream w maxchunk lam accs arrs)
  | (kept, dropped) <- partition paramIsUsed $ zip arr_params arrs,
    not $ null dropped = do
      let (arr_params', arrs') = unzip kept
          lam' = GroupStreamLambda chunk_size chunk_offset acc_params arr_params' body
      letBind_ pat $ Op $ GroupStream w maxchunk lam' accs arrs'
  where GroupStreamLambda chunk_size chunk_offset acc_params arr_params body = lam
        free_in_body = freeInBody body
        paramIsUsed (p, _) = paramName p `S.member` free_in_body
removeUnusedStreamInputs _ _ _ _ = cannotSimplify
-- | The rules used when simplifying 'InKernel' code: the standard rules
-- plus three top-down rules for group streams.  There are no extra
-- bottom-up rules.
inKernelRules :: RuleBook (Wise InKernel)
inKernelRules = standardRules <> ruleBook topDownRules []
  where topDownRules =
          [ RuleOp fuseStreamIota
          , RuleOp simplifyKnownIterationStream
          , RuleOp removeUnusedStreamInputs
          ]