{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.LowerIntoStm
(
lowerUpdateInKernel
, lowerUpdateKernels
, LowerUpdate
, DesiredUpdate (..)
) where
import Control.Monad
import Control.Monad.Writer
import Data.List (find)
import Data.Maybe (mapMaybe)
import Data.Either
import qualified Data.Set as S
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.Aliases
import Futhark.Representation.Kernels
import Futhark.Construct
import Futhark.Optimise.InPlaceLowering.SubstituteIndices
import Futhark.Tools (fullSlice)
-- | An in-place update that we would like to lower into the statement
-- that produces the value being written.  Conceptually it reads
--
-- > updateName = updateSource with [updateIndices] <- updateValue
data DesiredUpdate attr =
  DesiredUpdate
  { updateName :: VName
    -- ^ Name bound to the result of the update.
  , updateType :: attr
    -- ^ Attribute (carrying the type) of 'updateName'.
  , updateCertificates :: Certificates
    -- ^ Certificates attached to the update.
  , updateSource :: VName
    -- ^ The array being updated.
  , updateIndices :: Slice SubExp
    -- ^ Where in 'updateSource' the value is written.
  , updateValue :: VName
    -- ^ The value that is written.
  }
  deriving Show
-- | Maps over the attribute of the updated binding ('updateType'),
-- leaving every other field untouched.
instance Functor DesiredUpdate where
  fmap f (DesiredUpdate name t cs src is val) =
    DesiredUpdate name (f t) cs src is val
-- | Is the given name the value that this update writes?
updateHasValue :: VName -> DesiredUpdate attr -> Bool
updateHasValue name update = name == updateValue update
-- | An in-place lowering strategy.  Given a statement and the desired
-- updates for it, the result is either 'Nothing' (the updates cannot
-- be lowered into this statement) or an action producing the
-- statements to use in its place.
type LowerUpdate lore m =
  Stm (Aliases lore)
  -> [DesiredUpdate (LetAttr (Aliases lore))]
  -> Maybe (m [Stm (Aliases lore)])
-- | Generic in-place lowering, shared by the representation-specific
-- strategies.
--
-- Two statement shapes are handled:
--
--   * a @DoLoop@: the work is done by 'lowerUpdateIntoLoop', and the
--     rewritten loop is placed between the statements it needs before
--     and after it;
--
--   * a copy @src = v@ with exactly one desired update whose source is
--     @src@: this becomes an in-place update of @v@ itself.
--
-- Any other statement yields 'Nothing'.
lowerUpdate :: (MonadFreshNames m, Bindable lore,
                LetAttr lore ~ Type, CanBeAliased (Op lore)) => LowerUpdate lore m
lowerUpdate (Let pat aux (DoLoop ctx val form body)) updates =
  fmap wrapLoop <$> lowerUpdateIntoLoop updates pat ctx val body
  where wrapLoop (pre, post, ctxpat, valpat, ctx', val', body') =
          pre ++
          certify (stmAuxCerts aux) (mkLet ctxpat valpat $ DoLoop ctx' val' form body') :
          post
lowerUpdate
  (Let pat aux (BasicOp (SubExp (Var v))))
  [DesiredUpdate bindee_nm bindee_attr cs src is val]
  | patternNames pat == [src] =
      Just $ return [certify (stmAuxCerts aux <> cs) $
                     mkLet [] [Ident bindee_nm bindee_t] $
                     BasicOp $ Update v (fullSlice bindee_t is) $ Var val]
  where bindee_t = typeOf bindee_attr
lowerUpdate _ _ =
  Nothing
-- | In-place lowering for host-level 'Kernels' code.
--
-- A single desired update whose value is the sole result of a 'Kernel'
-- is pushed into the kernel body by 'lowerUpdateIntoKernel'.  The
-- kernel then binds the update's name.  The original value is rebound
-- afterwards by indexing that array at the update's indices.  All
-- other statements are delegated to 'lowerUpdate'.
lowerUpdateKernels :: MonadFreshNames m => LowerUpdate Kernels m
lowerUpdateKernels
  (Let (Pattern [] [PatElem v v_attr]) aux (Op (Kernel debug kspace ts kbody)))
  [update@(DesiredUpdate bindee_nm bindee_attr cs _src is val)]
  | v == val =
      lowerWith <$> lowerUpdateIntoKernel update kspace kbody
  where
    bindee_t = typeOf bindee_attr
    lowerWith kbody' =
      return [ certify (stmAuxCerts aux <> cs) $
                 mkLet [] [Ident bindee_nm bindee_t] $
                 Op $ Kernel debug kspace ts kbody'
             , mkLet [] [Ident v $ typeOf v_attr] $
                 BasicOp $ Index bindee_nm $ fullSlice bindee_t is
             ]
lowerUpdateKernels stm updates = lowerUpdate stm updates
-- | In-place lowering for 'InKernel' code.  Only the generic cases
-- handled by 'lowerUpdate' apply here.
lowerUpdateInKernel :: MonadFreshNames m => LowerUpdate InKernel m
lowerUpdateInKernel stm updates = lowerUpdate stm updates
-- | Rewrite a kernel body so that it writes its result straight into
-- the update's source array.
--
-- Succeeds only when the body has exactly one @ThreadsReturn
-- ThreadsInSpace@ result and every index of the update is a fixed
-- (non-slice) index.  The result is replaced by a 'WriteReturn' into
-- the update's source.  Each thread's destination index is the
-- update's fixed indices followed by the kernel's global thread ids.
lowerUpdateIntoKernel :: DesiredUpdate (LetAttr (Aliases Kernels))
                      -> KernelSpace -> KernelBody (Aliases InKernel)
                      -> Maybe (KernelBody (Aliases InKernel))
lowerUpdateIntoKernel update kspace kbody =
  case kernelBodyResult kbody of
    [ThreadsReturn ThreadsInSpace se] -> do
      fixed <- mapM dimFix is
      let thread_ids = map (Var . fst) $ spaceDimensions kspace
          ret = WriteReturn (arrayDims $ snd bindee_attr) src
                [(fixed ++ thread_ids, se)]
      Just $ kbody { kernelBodyResult = [ret] }
    _ -> Nothing
  where DesiredUpdate _ bindee_attr _ src is _ = update
-- | Try to lower the desired updates into a @DoLoop@.
--
-- The arguments are:
--
--   * the desired updates;
--   * the pattern the loop is bound to;
--   * the context merge parameters, each with its initial value;
--   * the value merge parameters, each with its initial value;
--   * the loop body.
--
-- Returns 'Nothing' when 'summariseLoop' rejects the loop.  Otherwise
-- it returns an action producing, in order:
--
--   * statements to place before the loop.  For each lowered result,
--     the merge parameter's initial value is written into the update's
--     source array.
--
--   * statements to place after the loop.  Each update's 'updateValue'
--     is rebound by indexing the updated array.
--
--   * the new context and value patterns of the loop;
--
--   * the context merge parameters (passed through unchanged) and the
--     new value merge parameters;
--
--   * the rewritten loop body.
--
-- Throughout, results without a related update come first and lowered
-- results come last.  This ordering is the same for merge parameters,
-- pattern and body result (see also 'manipulateResult').
lowerUpdateIntoLoop :: (Bindable lore, BinderOps lore,
                        Aliased lore, LetAttr lore ~ (als, Type),
                        MonadFreshNames m) =>
                       [DesiredUpdate (LetAttr lore)]
                    -> Pattern lore
                    -> [(FParam lore, SubExp)]
                    -> [(FParam lore, SubExp)]
                    -> Body lore
                    -> Maybe (m ([Stm lore],
                                 [Stm lore],
                                 [Ident],
                                 [Ident],
                                 [(FParam lore, SubExp)],
                                 [(FParam lore, SubExp)],
                                 Body lore))
lowerUpdateIntoLoop updates pat ctx val body = do
  -- Classify every value result; this is the only step that can make
  -- the lowering fail.
  mk_in_place_map <- summariseLoop updates usedInBody resmap val
  Just $ do
    in_place_map <- mk_in_place_map
    -- New merge parameters plus the statements needed before/after the loop.
    (val',prebnds,postbnds) <- mkMerges in_place_map
    let (ctxpat,valpat) = mkResAndPat in_place_map
        idxsubsts = indexSubstitutions in_place_map
    -- Rewrite the body's statements according to the index substitutions.
    (idxsubsts', newbnds) <- substituteIndices idxsubsts $ bodyStms body
    -- Compute the new body result, possibly emitting extra in-place updates.
    (body_res, res_bnds) <- manipulateResult in_place_map idxsubsts'
    let body' = mkBody (newbnds<>res_bnds) body_res
    return (prebnds, postbnds, ctxpat, valpat, ctx, val', body')
  where usedInBody = freeInBody body
        -- Each body result paired with the value pattern identifier it binds.
        resmap = zip (bodyResult body) $ patternValueIdents pat

        -- Build the new value merge parameters.  Unrelated parameters are
        -- kept first; lowered ones follow.  The statements emitted by
        -- 'mkMerge' are collected into (before-loop, after-loop) lists.
        mkMerges :: (MonadFreshNames m, Bindable lore) =>
                    [LoopResultSummary (als, Type)]
                 -> m ([(Param DeclType, SubExp)], [Stm lore], [Stm lore])
        mkMerges summaries = do
          ((origmerge, extramerge), (prebnds, postbnds)) <-
            runWriterT $ partitionEithers <$> mapM mkMerge summaries
          return (origmerge ++ extramerge, prebnds, postbnds)

        -- For a result with a related update, emit two statements.
        -- Before the loop, a fresh "modified_source" is bound to
        -- updateSource with the merge parameter's initial value written at
        -- the update's indices.  After the loop, updateValue is rebound by
        -- indexing updateName at those indices.  The new unique merge
        -- parameter (named after the lowered array) starts from
        -- modified_source.  Other results keep their merge parameter (Left).
        mkMerge summary
          | Just (update, mergename, mergeattr) <- relatedUpdate summary = do
            source <- newVName "modified_source"
            let source_t = snd $ updateType update
                elmident = Ident (updateValue update) $ rowType source_t
            tell ([mkLet [] [Ident source source_t] $ BasicOp $ Update
                   (updateSource update)
                   (fullSlice source_t $ updateIndices update) $
                   snd $ mergeParam summary],
                  [mkLet [] [elmident] $ BasicOp $ Index
                   (updateName update) (fullSlice (typeOf $ updateType update) $ updateIndices update)])
            return $ Right (Param
                            mergename
                            (toDecl (typeOf mergeattr) Unique),
                            Var source)
          | otherwise = return $ Left $ mergeParam summary

        -- The context pattern is unchanged.  In the value pattern, unrelated
        -- results keep their identifier; lowered results are instead bound
        -- to the update's name, and are placed last.
        mkResAndPat summaries =
          let (origpat,extrapat) = partitionEithers $ map mkResAndPat' summaries
          in (patternContextIdents pat,
              origpat ++ extrapat)

        mkResAndPat' summary
          | Just (update, _, _) <- relatedUpdate summary =
              Right (Ident (updateName update) (snd $ updateType update))
          | otherwise =
              Left (inPatternAs summary)
-- | Summarise each value result of a loop body for in-place lowering.
--
-- @resmap@ pairs each body result with the pattern identifier it is
-- bound to.  @merge@ lists the value merge parameters with their
-- initial values.  The two are traversed in lock-step.
--
-- If a result's pattern identifier is the 'updateValue' of a desired
-- update, that result gets a 'relatedUpdate' together with a fresh
-- @lowered_array@ name.  Every other result is carried through with
-- @relatedUpdate = Nothing@, and the rest of the lowering keeps it as
-- an ordinary merge parameter and result.
--
-- Returns 'Nothing' (the loop cannot be lowered) when any of these
-- hold:
--
--   * some desired update does not correspond to any value result of
--     the loop; it would otherwise be reported as lowered while being
--     silently dropped;
--
--   * the source array of a matched update is used in the loop body;
--
--   * the merge parameter of a matched result has a shape that is not
--     loop-invariant, i.e. some dimension is itself a merge parameter.
summariseLoop :: MonadFreshNames m =>
                 [DesiredUpdate (als, Type)]
              -> Names
              -> [(SubExp, Ident)]
              -> [(Param DeclType, SubExp)]
              -> Maybe (m [LoopResultSummary (als, Type)])
summariseLoop updates usedInBody resmap merge = do
  guard $ all matchesSomeResult updates
  sequence <$> zipWithM summariseLoopResult resmap merge
  where
    matchesSomeResult update =
      any (\(_, v) -> updateHasValue (identName v) update) resmap

    summariseLoopResult (se, v) (fparam, mergeinit)
      | Just update <- find (updateHasValue $ identName v) updates =
          if updateSource update `S.member` usedInBody ||
             not (hasLoopInvariantShape fparam)
            then Nothing
            else Just $ do
              lowered_array <- newVName "lowered_array"
              return LoopResultSummary { resultSubExp = se
                                       , inPatternAs = v
                                       , mergeParam = (fparam, mergeinit)
                                       , relatedUpdate = Just (update,
                                                               lowered_array,
                                                               updateType update)
                                       }
    -- A result that no update targets is not an obstacle to lowering.
    -- It is kept unchanged (previously this case rejected the whole loop,
    -- which left the 'Nothing'/'Left' handling elsewhere unreachable).
    summariseLoopResult (se, v) (fparam, mergeinit) =
      Just $ return LoopResultSummary { resultSubExp = se
                                      , inPatternAs = v
                                      , mergeParam = (fparam, mergeinit)
                                      , relatedUpdate = Nothing
                                      }

    hasLoopInvariantShape = all loopInvariant . arrayDims . paramType

    merge_param_names = map (paramName . fst) merge

    loopInvariant (Var v) = v `notElem` merge_param_names
    loopInvariant Constant{} = True
-- | How a single value result of a loop body is treated by the
-- in-place lowering.
data LoopResultSummary attr =
  LoopResultSummary
  { resultSubExp :: SubExp
    -- ^ The result as returned by the loop body.
  , inPatternAs :: Ident
    -- ^ The pattern identifier that the result is bound to.
  , mergeParam :: (Param DeclType, SubExp)
    -- ^ The corresponding merge parameter and its initial value.
  , relatedUpdate :: Maybe (DesiredUpdate attr, VName, attr)
    -- ^ The desired update targeting this result, if any, together
    -- with the name and attribute of the lowered merge parameter
    -- that replaces it.
  }
  deriving Show
-- | One index substitution per summary that has a related update.
--
-- Each substitution maps the original merge parameter's name to a
-- tuple of the update's certificates, the lowered array's name and
-- attribute, and the update's indices.  Summaries without a related
-- update contribute nothing.
indexSubstitutions :: [LoopResultSummary attr]
                   -> IndexSubstitutions attr
indexSubstitutions = mapMaybe substitutionFor
  where
    substitutionFor summary =
      case relatedUpdate summary of
        Just (DesiredUpdate _ _ cs _ is _, nm, attr) ->
          Just (paramName (fst (mergeParam summary)), (cs, nm, attr, is))
        Nothing ->
          Nothing
-- | Build the new result of the loop body.
--
-- Results without a related update come first and are unchanged.  They
-- are followed by one result per index substitution, paired
-- positionally with the results that do have a related update.  There
-- are two cases for each such result:
--
--   * it is exactly the substituted variable: the substitution's array
--     name is returned directly;
--
--   * otherwise: an in-place update writing the result into that array
--     (at the substitution's indices, under its certificates) is
--     emitted, and the fresh name it binds is returned.
--
-- The emitted statements are returned alongside the new result.
manipulateResult :: (Bindable lore, MonadFreshNames m) =>
                    [LoopResultSummary (LetAttr lore)]
                 -> IndexSubstitutions (LetAttr lore)
                 -> m (Result, Stms lore)
manipulateResult summaries substs = do
  (lowered_ses, extra_stms) <- runWriterT $ zipWithM lowerResult updated_ses substs
  return (untouched_ses ++ lowered_ses, stmsFromList extra_stms)
  where
    (untouched_ses, updated_ses) = partitionEithers $ map classify summaries

    classify summary
      | Just _ <- relatedUpdate summary = Right $ resultSubExp summary
      | otherwise = Left $ resultSubExp summary

    lowerResult se (subst_v, (cs, arr, attr, is))
      | Var v <- se, v == subst_v = return $ Var arr
      | otherwise = do
          let arr_t = typeOf attr
          v' <- newIdent' (++ "_updated") $ Ident arr arr_t
          tell [certify cs $ mkLet [] [v'] $ BasicOp $
                Update arr (fullSlice arr_t is) se]
          return $ Var $ identName v'