{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.Optimise.DoubleBuffer
( doubleBuffer )
where
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List
import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
hiding (Prog, Body, Stm, Pattern, PatElem,
BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Pass
-- | The double-buffering pass.  It is intraprocedural: every function
-- definition is rewritten independently by 'optimiseFunDef'.
doubleBuffer :: Pass ExplicitMemory ExplicitMemory
doubleBuffer =
  Pass
    { passFunction = intraproceduralTransformation optimiseFunDef
    , passName = "Double buffer"
    , passDescription =
        "Perform double buffering for merge parameters of sequential loops."
    }
-- | Run the double-buffering rewrite over the body of one function,
-- threading the pass's name source through the computation.
--
-- At the top level of the function, sequential loops are left alone
-- ('doNotTouchLoop'); only kernel ops ('Inner') are descended into.
-- Inside a kernel, a fresh environment is built in which loops ARE
-- rewritten ('optimiseLoop') and group streams are traversed
-- ('optimiseInKernelOp').
optimiseFunDef :: FunDef ExplicitMemory -> PassM (FunDef ExplicitMemory)
optimiseFunDef fundec = modifyNameSource $ \src ->
  let m = runDoubleBufferM $ inScopeOf fundec $ optimiseBody $ funDefBody fundec
      (body', src') = runState (runReaderT m env) src
  in (fundec { funDefBody = body' }, src')
  where -- Top-level environment: empty initial scope (the function's own
        -- bindings are added by 'inScopeOf' above), kernels are optimised,
        -- loops are not.
        env = Env mempty optimiseKernelOp doNotTouchLoop

        -- Descend into a kernel.  The current scope is cast to the
        -- in-kernel lore and the kernel is rewritten in its own
        -- 'DoubleBufferM' computation that shares the name source.
        optimiseKernelOp (Inner k) = do
            scope <- castScope <$> askScope
            modifyNameSource $
              runState (runReaderT (runDoubleBufferM $ Inner <$> optimiseKernel k) $
                        Env scope optimiseInKernelOp optimiseLoop)
          where optimiseKernel =
                  mapKernelM identityKernelMapper
                    { mapOnKernelBody = optimiseBody
                    , mapOnKernelKernelBody = optimiseKernelBody
                    , mapOnKernelLambda = optimiseLambda
                    }
        -- Any other op (e.g. an allocation) is returned unchanged.
        optimiseKernelOp op = return op

        -- Inside a kernel, only the lambda of a 'GroupStream' is
        -- traversed further; every other op is left as is.
        optimiseInKernelOp (Inner (GroupStream w maxchunk lam accs arrs)) = do
          lam' <- optimiseGroupStreamLambda lam
          return $ Inner $ GroupStream w maxchunk lam' accs arrs
        optimiseInKernelOp op = return op

        -- Loop "optimiser" used outside kernels: no extra statements,
        -- merge parameters and body returned unchanged.
        doNotTouchLoop ctx val body = return (mempty, ctx, val, body)
-- | Read-only environment of the 'DoubleBufferM' monad.  The two
-- function-valued fields let the same traversal behave differently
-- outside and inside kernels.
data Env lore = Env { envScope :: Scope lore
                      -- ^ Bindings currently in scope.
                    , envOptimiseOp :: Op lore -> DoubleBufferM lore (Op lore)
                      -- ^ Transformation applied to every 'Op' met by
                      -- 'optimiseStm'.
                    , envOptimiseLoop :: OptimiseLoop lore
                      -- ^ Transformation applied to every 'DoLoop' after
                      -- its body has been optimised.
                    }
-- | The monad the pass runs in: a reader over 'Env' layered on a state
-- monad holding the fresh-name source.  All instances are obtained by
-- newtype deriving from that stack.
newtype DoubleBufferM lore a =
  DoubleBufferM { runDoubleBufferM :: ReaderT (Env lore) (State VNameSource) a }
  deriving (Functor, Applicative, Monad, MonadReader (Env lore), MonadFreshNames)
-- | The current scope is simply the 'envScope' field of the environment.
instance Annotations lore => HasScope lore (DoubleBufferM lore) where
  askScope = envScope <$> ask
-- | Run an action with extra bindings appended (via '<>') to the
-- environment's 'envScope'.
instance Annotations lore => LocalScope lore (DoubleBufferM lore) where
  localScope extra = local extend
    where extend env = env { envScope = envScope env <> extra }
-- | Constraints shared by the lore-polymorphic traversal functions of this
-- pass: expressions and bodies carry no annotation, the lore has explicit
-- memory, and its ops are 'MemOp's wrapping some @inner@ op type.
type LoreConstraints lore inner =
  (ExpAttr lore ~ (), BodyAttr lore ~ (),
   ExplicitMemorish lore, Op lore ~ MemOp inner)
-- | Optimise every statement of a body.  Only 'bodyStms' is replaced; the
-- body's attribute and result are kept as they are.
optimiseBody :: LoreConstraints lore inner =>
                Body lore -> DoubleBufferM lore (Body lore)
optimiseBody body =
  withStms <$> optimiseStms (stmsToList (bodyStms body))
  where withStms stms' = body { bodyStms = stmsFromList stms' }
-- | Optimise a sequence of statements in order.  Each statement may expand
-- into several; everything they bind is put in scope before the remaining
-- statements are processed.
optimiseStms :: LoreConstraints lore inner =>
                [Stm lore] -> DoubleBufferM lore [Stm lore]
optimiseStms [] = pure []
optimiseStms (stm:stms) = do
  expanded <- optimiseStm stm
  (expanded ++) <$>
    localScope (castScope (scopeOf expanded)) (optimiseStms stms)
-- | Optimise a single statement.  The result may contain extra statements
-- that must precede the rewritten one (as produced by the loop optimiser).
--
-- A 'DoLoop' first has its body optimised with the loop form and all merge
-- parameters in scope, and is then passed to the environment's
-- 'envOptimiseLoop'.  Any other expression is traversed with 'mapExpM', so
-- nested bodies and ops are optimised too.
optimiseStm :: forall lore inner.
               LoreConstraints lore inner =>
               Stm lore -> DoubleBufferM lore [Stm lore]
optimiseStm (Let pat aux (DoLoop ctx val form body)) = do
  let loop_scope = scopeOf form <> scopeOfFParams (map fst $ ctx ++ val)
  body' <- localScope loop_scope $ optimiseBody body
  optimise_loop <- asks envOptimiseLoop
  (prelude, ctx', val', body'') <- optimise_loop ctx val body'
  pure $ prelude ++ [Let pat aux $ DoLoop ctx' val' form body'']
optimiseStm (Let pat aux e) = do
  e' <- mapExpM mapper e
  pure [Let pat aux e']
  where mapper = identityMapper
          { mapOnBody = \_ b -> (optimiseBody b :: DoubleBufferM lore (Body lore))
          , mapOnOp = optimiseOp
          }
-- | Optimise an op with whatever transformation the environment currently
-- holds in 'envOptimiseOp'.
optimiseOp :: Op lore -> DoubleBufferM lore (Op lore)
optimiseOp op = asks envOptimiseOp >>= \optimise -> optimise op
-- | Optimise the statements of a kernel body; every other field of the
-- kernel body is kept unchanged.
optimiseKernelBody :: KernelBody InKernel
                   -> DoubleBufferM InKernel (KernelBody InKernel)
optimiseKernelBody kbody =
  replaceStms <$> optimiseStms (stmsToList (kernelBodyStms kbody))
  where replaceStms stms' = kbody { kernelBodyStms = stmsFromList stms' }
-- | Optimise the body of an in-kernel lambda, with the lambda's own
-- bindings in scope.
optimiseLambda :: Lambda InKernel -> DoubleBufferM InKernel (Lambda InKernel)
optimiseLambda lam =
  setBody <$> localScope (castScope $ scopeOf lam) (optimiseBody (lambdaBody lam))
  where setBody body' = lam { lambdaBody = body' }
-- | Optimise the body of a group-stream lambda, with the lambda's own
-- bindings in scope.
optimiseGroupStreamLambda :: GroupStreamLambda InKernel
                          -> DoubleBufferM InKernel (GroupStreamLambda InKernel)
optimiseGroupStreamLambda lam =
  setBody <$> localScope (scopeOf lam) (optimiseBody (groupStreamLambdaBody lam))
  where setBody body' = lam { groupStreamLambdaBody = body' }
-- | A transformation of a sequential loop.  It receives:
--
-- * the context merge parameters with their initial values;
-- * the value merge parameters with their initial values;
-- * the (already optimised) loop body.
--
-- It returns statements to place before the loop, followed by the new
-- context merge parameters, value merge parameters and loop body.
type OptimiseLoop lore =
  [(FParam lore, SubExp)] -> [(FParam lore, SubExp)] -> Body lore
  -> DoubleBufferM lore ([Stm lore],
                         [(FParam lore, SubExp)],
                         [(FParam lore, SubExp)],
                         Body lore)
-- | Double-buffer the merge parameters of a loop.
--
-- 1. 'doubleBufferMergeParams' decides, per merge parameter, whether it
--    gets a buffer.
-- 2. 'allocStms' produces the allocation/copy statements that go before
--    the loop, plus possibly updated merge parameters.
-- 3. 'doubleBufferResult' redirects the loop body's results into the
--    buffers.
optimiseLoop :: LoreConstraints lore inner => OptimiseLoop lore
optimiseLoop ctx val body = do
  buffered <- doubleBufferMergeParams ctx_and_res merge_params (boundInBody body)
  (merge', allocs) <- allocStms merge buffered
  let (ctx', val') = splitAt (length ctx) merge'
  pure (allocs, ctx', val', doubleBufferResult merge_params buffered body)
  where merge = ctx ++ val
        merge_params = map fst merge
        -- Context parameters paired with the body results that feed them.
        ctx_and_res = zip (map fst ctx) (bodyResult body)
-- | What to do for a single merge parameter.
--
-- * @BufferAlloc name size space b@: allocate a fresh memory block @name@
--   of the given size in @space@ before the loop.  @b@ is 'True' when the
--   size was usable as written (a constant, or a variable that is not a
--   context merge parameter).  Only then does 'allocStms' make the new
--   block the parameter's initial value.
--
-- * @BufferCopy mem ixfun copyname b@: an array parameter whose memory is
--   buffered in @mem@.  @ixfun@ is the parameter's index function and
--   @copyname@ is the name used for the copied loop result.  @b@ is the
--   flag of the corresponding memory buffer.
--
-- * @NoBuffer@: leave the parameter alone.
--
-- The @lore@ parameter is phantom.
data DoubleBuffer lore = BufferAlloc VName SubExp Space Bool
                       | BufferCopy VName IxFun VName Bool
                       | NoBuffer
                    deriving (Show)
-- | Decide, for each merge parameter in @val_params@, how it should be
-- double buffered.
--
-- A memory parameter is buffered when its size is known before the loop.
-- That holds for a constant, a variable that is not a context parameter,
-- or a context parameter whose initial value is a constant or a variable
-- that is not loop-variant.  An array parameter is buffered exactly when
-- its memory block was buffered by an earlier entry of @val_params@.  The
-- names of buffered memory parameters are tracked in a state map.
--
-- Parameters:
--
-- * @ctx_and_res@: context merge parameters paired with a sub-expression
--   (the caller passes the matching body results).
-- * @val_params@: the parameters to classify, in order.
-- * @bound_in_loop@: names bound inside the loop body; these count as
--   loop-variant.
doubleBufferMergeParams :: (ExplicitMemorish lore, MonadFreshNames m) =>
                           [(FParam lore,SubExp)]
                        -> [FParam lore] -> Names
                        -> m [DoubleBuffer lore]
doubleBufferMergeParams ctx_and_res val_params bound_in_loop =
  evalStateT (mapM buffer val_params) M.empty
  where ctx_names = map (paramName . fst) ctx_and_res

        isLoopVariant v =
          S.member v bound_in_loop || v `elem` ctx_names

        -- Nothing: the size changes inside the loop.  Otherwise the size to
        -- allocate with, and whether it was usable as written.
        loopInvariantSize se@Constant{} = Just (se, True)
        loopInvariantSize (Var v) =
          case snd <$> find ((== v) . paramName . fst) ctx_and_res of
            Nothing -> Just (Var v, True)
            Just se@Constant{} -> Just (se, False)
            Just (Var v')
              | isLoopVariant v' -> Nothing
              | otherwise -> Just (Var v', False)

        buffer fparam = case paramType fparam of
          Mem size space
            | Just (size', b) <- loopInvariantSize size -> do
                bufname <- lift $ newVName "double_buffer_mem"
                modify $ M.insert (paramName fparam) (bufname, b)
                pure $ BufferAlloc bufname size' space b
          Array {}
            | MemArray _ _ _ (ArrayIn mem ixfun) <- paramAttr fparam -> do
                found <- gets $ M.lookup mem
                case found of
                  Nothing -> pure NoBuffer
                  Just (bufname, b) -> do
                    copyname <- lift $ newVName "double_buffer_array"
                    pure $ BufferCopy bufname ixfun copyname b
          _ -> pure NoBuffer
-- | Produce the statements that must run before the loop, and the possibly
-- updated merge parameters, from the per-parameter decisions.
--
-- * 'BufferAlloc' always emits an 'Alloc'.  When its flag is 'True', the
--   merge parameter gets the new memory block's type and the block as its
--   initial value.
-- * 'BufferCopy' with flag 'True' on a variable initial value copies that
--   array into the buffer memory (keeping the array's own index function)
--   and uses the copy as the initial value.
-- * Everything else is passed through unchanged.
allocStms :: LoreConstraints lore inner =>
             [(FParam lore,SubExp)] -> [DoubleBuffer lore]
          -> DoubleBufferM lore ([(FParam lore, SubExp)], [Stm lore])
allocStms merge buffered = runWriterT $ zipWithM allocate merge buffered
  where allocate m@(Param pname _, _) (BufferAlloc name size space b) = do
          tell [Let (Pattern [] [PatElem name $ MemMem size space]) (defAux ()) $
                Op $ Alloc size space]
          pure $ if b then (Param pname $ MemMem size space, Var name) else m
        allocate (fparam, Var v) (BufferCopy mem _ _ True) = do
          v_copy <- lift $ newVName $ baseString v ++ "_double_buffer_copy"
          (_, v_ixfun) <- lift $ lookupArraySummary v
          let ptype = paramType fparam
              bound = MemArray (elemType ptype) (arrayShape ptype) NoUniqueness $
                      ArrayIn mem v_ixfun
          tell [Let (Pattern [] [PatElem v_copy bound]) (defAux ()) $
                BasicOp $ Copy v]
          pure (fparam, Var v_copy)
        allocate unchanged _ =
          pure unchanged
-- | Redirect a loop body's results into the double buffers.
--
-- The last @length valparams@ results correspond to @valparams@; any
-- results before them are passed through untouched.  For each parameter:
--
-- * 'BufferAlloc': the result becomes the buffer's memory block.
-- * 'BufferCopy' with a variable result: a 'Copy' of that variable into the
--   buffer memory is appended to the body, and the copy becomes the result.
--   The copy's shape is the parameter's shape, with every dimension that
--   names one of @valparams@ replaced by that parameter's result.
-- * Anything else is left as is.
doubleBufferResult :: (ExplicitMemorish lore,
                       ExpAttr lore ~ (), BodyAttr lore ~ ()) =>
                      [FParam lore] -> [DoubleBuffer lore]
                   -> Body lore -> Body lore
doubleBufferResult valparams buffered (Body () bnds res) =
  Body () (bnds <> stmsFromList (catMaybes copybnds)) $ ctx_res ++ val_res'
  where -- Results not covered by 'valparams' come first.
        (ctx_res, val_res) = splitAt (length res - length valparams) res
        (copybnds, val_res') =
          unzip $ zipWith3 buffer valparams buffered val_res

        -- Each parameter mapped to the result feeding it in the next
        -- iteration.  This is built from 'val_res', which is aligned with
        -- 'valparams'.  Zipping with the full 'res' would pair parameters
        -- with the wrong results whenever 'ctx_res' is non-empty.
        parammap = M.fromList $ zip (map paramName valparams) val_res

        buffer _ (BufferAlloc bufname _ _ _) _ =
          (Nothing, Var bufname)
        buffer fparam (BufferCopy bufname ixfun copyname _) (Var v) =
          let t = resultType $ paramType fparam
              summary = MemArray (elemType t) (arrayShape t) NoUniqueness $ ArrayIn bufname ixfun
              copybnd = Let (Pattern [] [PatElem copyname summary]) (defAux ()) $
                        BasicOp $ Copy v
          in (Just copybnd, Var copyname)
        buffer _ _ se =
          (Nothing, se)

        resultType t = t `setArrayDims` map substitute (arrayDims t)

        substitute (Var v)
          | Just replacement <- M.lookup v parammap = replacement
        substitute se =
          se