{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.Optimise.InPlaceLowering
(
inPlaceLowering
) where
import Control.Monad.RWS
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Analysis.Alias
import Futhark.Representation.Aliases
import Futhark.Representation.Kernels
import Futhark.Optimise.InPlaceLowering.LowerIntoStm
import Futhark.MonadFreshNames
import Futhark.Binder
import Futhark.Pass
import Futhark.Tools (fullSlice)
-- | The in-place lowering pass.
--
-- Alias information is computed for the whole program first, every
-- function definition is then optimised independently by
-- 'optimiseFunDef', and finally the alias annotations are stripped again,
-- so the pass maps a 'Kernels' program to a 'Kernels' program.
inPlaceLowering :: Pass Kernels Kernels
inPlaceLowering = Pass pass_name pass_desc pass_fun
  where pass_name = "In-place lowering"
        pass_desc = "Lower in-place updates into loops"
        pass_fun  = fmap removeProgAliases
                    . intraproceduralTransformation optimiseFunDef
                    . aliasAnalysis
-- | Optimise a single function definition.
--
-- The function's parameters are entered into the symbol table (as
-- non-optimisable entries, via 'bindingFParams') before the body is
-- traversed.  Kernel operations met inside the body are handled by
-- 'onKernelOp', and updates are lowered with 'lowerUpdateKernels'.
optimiseFunDef :: MonadFreshNames m => FunDef (Aliases Kernels)
               -> m (FunDef (Aliases Kernels))
optimiseFunDef fundec =
  modifyNameSource $ runForwardingM lowerUpdateKernels onKernelOp with_body
  where with_body = bindingFParams (funDefParams fundec) $
                    set_body <$> optimiseBody (funDefBody fundec)
        set_body body' = fundec { funDefBody = body' }
-- | What the generic parts of this pass require of a lore.
--
-- The lore must be 'Bindable', because statements are built with 'mkLet'
-- (for example in 'updateStm'). Its operations must also support alias
-- annotation ('CanBeAliased').
type Constraints lore =
  ( Bindable lore
  , CanBeAliased (Op lore)
  )
-- | Optimise a body one nesting level deeper than its surroundings.
--
-- After all statements have been bound, every variable in the body's
-- result is marked as seen; constant results need no tracking.
optimiseBody :: Constraints lore =>
                Body (Aliases lore) -> ForwardingM lore (Body (Aliases lore))
optimiseBody (Body als stms res) =
  rebuild <$> deepen (optimiseStms (stmsToList stms) mark_result)
  where rebuild stms' = Body als (stmsFromList stms') res
        mark_result = mapM_ seenVar $ subExpVars res
-- | Optimise a list of statements.
--
-- The statements are processed back to front: for each statement, the
-- remaining statements are optimised first (with this statement's
-- pattern in scope), and the 'BottomUp' information they report is then
-- used to decide what to do with the statement itself.
--
-- The action @m@ is run once every statement has been bound (innermost
-- position); callers use it to mark the enclosing body's result as seen.
optimiseStms :: Constraints lore =>
                [Stm (Aliases lore)] -> ForwardingM lore ()
             -> ForwardingM lore [Stm (Aliases lore)]
optimiseStms [] m = m >> return []
optimiseStms (bnd:bnds) m = do
  -- Optimise the rest first, and capture what it reports upward
  -- (variables seen, and updates forwarded by 'maybeForward').
  (bnds', bup) <- tapBottomUp $ bindingStm bnd $ optimiseStms bnds m
  bnd' <- optimiseInStm bnd
  -- Only forwarded updates whose value is bound by this very statement
  -- are candidates for lowering into it.
  case filter ((`elem` boundHere) . updateValue) $
       forwardThese bup of
    [] -> checkIfForwardableUpdate bnd' bnds'
    updates -> do
      -- The updates as ordinary statements, in case lowering fails.
      let updateStms = map updateStm updates
      lower <- asks lowerUpdate
      case lower bnd' updates of
        -- Lowering succeeded: optimise the replacement statements.  The
        -- "seen" information from the rest of the body is re-emitted
        -- after them, but with no forwarded updates.
        Just lowering -> do
          new_bnds <- lowering
          new_bnds' <- optimiseStms new_bnds $
                       tell bup { forwardThese = [] }
          return $ new_bnds' ++ bnds'
        -- Lowering is not possible: re-materialise the updates as
        -- statements in front of the remaining ones.
        Nothing -> checkIfForwardableUpdate bnd' $
                   updateStms ++ bnds'
  where
    boundHere = patternNames $ stmPattern bnd

    -- If the statement is a single-name update @v = src with [i] <- ve@
    -- whose slice covers the whole row @src[i]@ (the remaining
    -- dimensions are full), ask 'maybeForward' to forward it upward.
    -- A forwarded statement is dropped here; otherwise it is kept in
    -- front of the remaining statements.
    checkIfForwardableUpdate bnd'@(Let (Pattern [] [PatElem v attr])
                                       (StmAux cs _) e) bnds'
      | BasicOp (Update src (DimFix i:slice) (Var ve)) <- e,
        slice == drop 1 (fullSlice (typeOf attr) [DimFix i]) = do
          forwarded <- maybeForward ve v attr cs src i
          return $ if forwarded
                   then bnds'
                   else bnd' : bnds'
    checkIfForwardableUpdate bnd' bnds' =
      return $ bnd' : bnds'
-- | Optimise the expression of a statement; its pattern and auxiliary
-- information are kept as they are.
optimiseInStm :: Constraints lore => Stm (Aliases lore) -> ForwardingM lore (Stm (Aliases lore))
optimiseInStm (Let pat aux e) = do
  e' <- optimiseExp e
  return $ Let pat aux e'
-- | Optimise the bodies nested inside an expression.
--
-- * Loops: the loop form's scope and all merge parameters (context and
--   value) are brought into scope before the loop body is optimised.
-- * Lore-specific operations: delegated to the 'onOp' handler from the
--   environment.
-- * Anything else: every nested body is optimised with 'optimiseBody'.
optimiseExp :: Constraints lore => Exp (Aliases lore) -> ForwardingM lore (Exp (Aliases lore))
optimiseExp e =
  case e of
    DoLoop ctx val form body ->
      bindingScope (scopeOf form) $
        bindingFParams (fst <$> (ctx ++ val)) $
          DoLoop ctx val form <$> optimiseBody body
    Op op -> do
      optimise_op <- asks onOp
      Op <$> optimise_op op
    _ -> mapExpM body_mapper e
  where body_mapper = identityMapper { mapOnBody = \_ body -> optimiseBody body }
-- | Optimise the body of a 'Kernel' operation.
--
-- The kernel body is processed by a separate 'ForwardingM' computation
-- over the in-kernel lore. That computation lowers with
-- 'lowerUpdateInKernel' and handles nested operations with 'onKernelExp'.
-- Its symbol table starts out with the enclosing scope (cast to the
-- in-kernel lore) plus the kernel-space names. These arrive through
-- 'bindingScope', so none of them is optimisable. Variables free in the
-- kernel result are marked as seen. Any other operation is returned
-- unchanged.
onKernelOp :: OnOp Kernels
onKernelOp (Kernel debug kspace ts kbody) = do
  outer_scope <- askScope
  let kernel_scope = castScope outer_scope <> scopeOfKernelSpace kspace
      mark_result = mapM_ seenVar $ freeIn $ kernelBodyResult kbody
      lower_body = bindingScope kernel_scope $ do
        stms' <- deepen $ optimiseStms (stmsToList $ kernelBodyStms kbody) mark_result
        return $ Kernel debug kspace ts kbody { kernelBodyStms = stmsFromList stms' }
  modifyNameSource $ runForwardingM lowerUpdateInKernel onKernelExp lower_body
onKernelOp op = return op
-- | Optimise the lambda body of a 'GroupStream' inside a kernel, with the
-- lambda's scope ('scopeOf') in effect.  Other in-kernel operations are
-- returned unchanged.
onKernelExp :: OnOp InKernel
onKernelExp (GroupStream w maxchunk lam accs arrs) = do
  body' <- bindingScope (scopeOf lam) $ optimiseBody $ groupStreamLambdaBody lam
  pure $ GroupStream w maxchunk lam { groupStreamLambdaBody = body' } accs arrs
onKernelExp op = pure op
-- | A symbol-table entry describing one variable in scope.
data Entry lore = Entry
  { entryNumber :: Int
    -- ^ The binding counter at the time the variable was brought into
    -- scope; compared by 'areAvailableBefore'.
  , entryAliases :: Names
    -- ^ Names the variable is recorded as aliasing (empty for
    -- parameters).
  , entryDepth :: Int
    -- ^ Body-nesting depth at which the variable was bound.
  , entryOptimisable :: Bool
    -- ^ Whether an update with this variable as its value may be
    -- forwarded; only names bound by statement patterns set this.
  , entryType :: NameInfo (Aliases lore)
    -- ^ The variable's scope information.
  }

-- | The symbol table: every variable in scope and its entry.
type VTable lore = M.Map VName (Entry lore)

-- | How to optimise a lore-specific operation ('Op').
type OnOp lore = Op (Aliases lore) -> ForwardingM lore (Op (Aliases lore))
-- | The environment passed downwards during the traversal.
data TopDown lore =
  TopDown { topDownCounter :: Int
            -- ^ Binding number given to the next names brought into scope.
          , topDownTable :: VTable lore
            -- ^ The variables currently in scope.
          , topDownDepth :: Int
            -- ^ Current body-nesting depth (incremented by 'deepen').
          , lowerUpdate :: LowerUpdate lore (ForwardingM lore)
            -- ^ Attempts to lower forwarded updates into a statement.
          , onOp :: OnOp lore
            -- ^ Optimiser for lore-specific operations.
          }

-- | The information passed upwards during the traversal (writer output).
data BottomUp lore =
  BottomUp { bottomUpSeen :: Names
             -- ^ Variables (and their recorded aliases) that have been
             -- used; see 'seenVar'.
           , forwardThese :: [DesiredUpdate (LetAttr (Aliases lore))]
             -- ^ Updates forwarded upward by 'maybeForward', waiting for
             -- the statement that binds their value.
           }

-- | Both components are combined pointwise (set union, list append).
instance Semigroup (BottomUp lore) where
  BottomUp s1 f1 <> BottomUp s2 f2 = BottomUp (S.union s1 s2) (f1 ++ f2)

instance Monoid (BottomUp lore) where
  mempty = BottomUp S.empty []
-- | Turn a forwarded update back into an ordinary in-place update
-- statement, of the form
--
-- > dest = src with [indices] <- value
--
-- The indices are completed to a full slice of the destination type.
-- This is used when the update could not be lowered after all.
--
-- The certificates recorded in the 'DesiredUpdate' are re-attached to the
-- rebuilt statement.  They were taken from the original update statement
-- when it was forwarded (see 'maybeForward'). Building the statement with
-- 'mkLet' alone would silently drop them, and with them whatever
-- they guard.
updateStm :: Constraints lore => DesiredUpdate (LetAttr (Aliases lore)) -> Stm (Aliases lore)
updateStm (DesiredUpdate dest_nm dest_attr cs src is v) =
  case mkLet [] [Ident dest_nm dest_t] $
       BasicOp $ Update src (fullSlice dest_t is) $ Var v of
    Let pat (StmAux _ exp_attr) e -> Let pat (StmAux cs exp_attr) e
  where dest_t = typeOf dest_attr
-- | The monad of this pass.  It reads a 'TopDown' environment, writes
-- 'BottomUp' information, and threads the name source as its state.
newtype ForwardingM lore a =
  ForwardingM (RWS (TopDown lore) (BottomUp lore) VNameSource a)
  deriving ( Functor, Applicative, Monad
           , MonadReader (TopDown lore)
           , MonadWriter (BottomUp lore)
           , MonadState VNameSource
           )

-- | Fresh names come straight from the state component.
instance MonadFreshNames (ForwardingM lore) where
  getNameSource = get
  putNameSource = put

-- | The scope is the symbol table with each entry reduced to its
-- 'NameInfo'.
instance Constraints lore => HasScope (Aliases lore) (ForwardingM lore) where
  askScope = asks $ M.map entryType . topDownTable
-- | Run a 'ForwardingM' computation with the given lowering function and
-- operation handler. It starts from an empty symbol table at depth zero
-- and threads the given name source through. The 'BottomUp' output is
-- discarded.
runForwardingM :: LowerUpdate lore (ForwardingM lore) -> OnOp lore -> ForwardingM lore a
               -> VNameSource -> (a, VNameSource)
runForwardingM f g (ForwardingM m) src = (result, src')
  where (result, src', _) = runRWS m initial src
        initial = TopDown { topDownTable   = M.empty
                          , topDownCounter = 0
                          , topDownDepth   = 0
                          , onOp           = g
                          , lowerUpdate    = f
                          }
-- | Bring parameters into scope for the wrapped action.
--
-- Each parameter gets an entry carrying the current binding number and
-- depth, no aliases, and is marked non-optimisable. Its name information
-- is produced by the given function. The binding counter is incremented
-- for the wrapped action.
bindingParams :: (attr -> NameInfo (Aliases lore))
              -> [Param attr]
              -> ForwardingM lore a
              -> ForwardingM lore a
bindingParams f params = local extend
  where
    extend (TopDown n vtable d lower on_op) =
      TopDown (n + 1) (new_entries `M.union` vtable) d lower on_op
      where
        new_entries =
          M.fromList [ (paramName p, Entry n mempty d False $ f $ paramAttr p)
                     | p <- params ]
-- | Bring function parameters or loop merge parameters into scope for the
-- wrapped action. Their name information is wrapped in 'FParamInfo'.
bindingFParams :: [FParam (Aliases lore)]
               -> ForwardingM lore a
               -> ForwardingM lore a
bindingFParams params = bindingParams FParamInfo params
-- | Bring an entire scope into the symbol table for the wrapped action.
--
-- All entries share the current binding number and depth and are marked
-- non-optimisable. Aliases are taken from let-bound name information;
-- every other kind of name gets none. New entries shadow existing ones.
bindingScope :: Scope (Aliases lore)
             -> ForwardingM lore a
             -> ForwardingM lore a
bindingScope scope = local extend
  where
    extend (TopDown n vtable d lower on_op) =
      TopDown (n + 1) (M.map to_entry scope <> vtable) d lower on_op
      where
        to_entry info = Entry n (aliases_of info) d False info
        aliases_of (LetInfo (als, _)) = unNames als
        aliases_of _                  = mempty
-- | Bring the names bound by a statement's pattern into scope for the
-- wrapped action.
--
-- Unlike parameters, these names are marked optimisable. Their aliases
-- are taken from the pattern element's annotation.
bindingStm :: Stm (Aliases lore)
           -> ForwardingM lore a
           -> ForwardingM lore a
bindingStm (Let pat _ _) = local extend
  where
    extend (TopDown n vtable d lower on_op) =
      TopDown (n + 1) (new_entries `M.union` vtable) d lower on_op
      where
        new_entries = M.fromList $ map to_entry $ patternElements pat
        to_entry pe =
          let attr@(als, _) = patElemAttr pe
          in (patElemName pe, Entry n (unNames als) d True $ LetInfo attr)
-- | Look up a variable in the symbol table and project one field out of
-- its entry.
--
-- The 'String' names the calling function and only appears in the error
-- message.  A variable missing from the table is reported with 'error'.
-- 'fail' is not used: it is no longer a method of 'Monad' (base >= 4.13),
-- and 'ForwardingM' derives no 'MonadFail' instance.
lookupEntryWith :: String -> (Entry lore -> b) -> VName -> ForwardingM lore b
lookupEntryWith caller field name = do
  res <- asks $ fmap field . M.lookup name . topDownTable
  case res of
    Just x  -> return x
    Nothing -> error $ caller ++ ": variable " ++
                       pretty name ++ " not found."

-- | The binding number of a variable (see 'areAvailableBefore').
bindingNumber :: VName -> ForwardingM lore Int
bindingNumber = lookupEntryWith "bindingNumber" entryNumber

-- | Run an action one body-nesting level deeper.
deepen :: ForwardingM lore a -> ForwardingM lore a
deepen = local $ \env -> env { topDownDepth = topDownDepth env + 1 }

-- | Were all variables among the given subexpressions bound with a
-- smaller binding number than the given variable?  Constants are
-- ignored.
areAvailableBefore :: [SubExp] -> VName -> ForwardingM lore Bool
areAvailableBefore ses point = do
  pointN <- bindingNumber point
  nameNs <- mapM bindingNumber $ subExpVars ses
  return $ all (< pointN) nameNs

-- | Was the variable bound at the current body-nesting depth?
isInCurrentBody :: VName -> ForwardingM lore Bool
isInCurrentBody name = do
  current <- asks topDownDepth
  (== current) <$> lookupEntryWith "isInCurrentBody" entryDepth name

-- | Is the variable marked optimisable in the symbol table?
isOptimisable :: VName -> ForwardingM lore Bool
isOptimisable = lookupEntryWith "isOptimisable" entryOptimisable
-- | Record that a variable has been used, together with every name the
-- symbol table lists as its aliases.  A variable that is not in the table
-- contributes only itself.
seenVar :: VName -> ForwardingM lore ()
seenVar name = do
  vtable <- asks topDownTable
  let aliases = maybe mempty entryAliases $ M.lookup name vtable
  tell $ BottomUp (S.insert name aliases) []
-- | Run an action and also return the 'BottomUp' information it produced.
--
-- This is exactly 'listen'. The information is still passed on to the
-- enclosing computation, not consumed.
tapBottomUp :: ForwardingM lore a -> ForwardingM lore (a, BottomUp lore)
tapBottomUp = listen
-- | Try to forward the in-place update
--
-- > dest_nm = src with [i] <- v
--
-- (guarded by certificates @cs@) upward, so that it can later be lowered
-- into the statement that binds @v@.  The update is forwarded only if
--
--   * @i@, @src@ and every variable in @cs@ were bound before @v@
--     (constants are trivially available),
--   * @v@ was bound at the current body-nesting depth,
--   * @v@ is marked optimisable in the symbol table, and
--   * @v@ does not have a primitive type.
--
-- When the update is forwarded, a 'DesiredUpdate' is written to the
-- 'BottomUp' output and the result is 'True'. Otherwise nothing is
-- written and the result is 'False'.
maybeForward :: Constraints lore =>
                VName
             -> VName -> LetAttr (Aliases lore) -> Certificates -> VName -> SubExp
             -> ForwardingM lore Bool
maybeForward v dest_nm dest_attr cs src i = do
  available       <- [i, Var src] `areAvailableBefore` v
  certs_available <- map Var (S.toList $ freeIn cs) `areAvailableBefore` v
  samebody        <- isInCurrentBody v
  optimisable     <- isOptimisable v
  not_prim        <- not . primType <$> lookupType v
  let forward = and [available, certs_available, samebody, optimisable, not_prim]
  if forward then tell mempty { forwardThese = [fwd] } >> return True
             else return False
  where fwd = DesiredUpdate dest_nm dest_attr cs src [DimFix i] v