{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Unstream
( unstream )
where
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Set as S
import qualified Data.Map as M
import Futhark.Representation.AST.Attributes.Aliases
import qualified Futhark.Analysis.Alias as Alias
import Futhark.MonadFreshNames
import Futhark.Representation.Kernels
import Futhark.Pass
import Futhark.Tools
unstream :: Pass Kernels Kernels
unstream = Pass "unstream" "Remove whole-array streams in kernels" $
intraproceduralTransformation optimiseFunDef
optimiseFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
optimiseFunDef fundec = do
body' <- modifyNameSource $ runState $
runReaderT m (scopeOfFParams (funDefParams fundec))
return fundec { funDefBody = body' }
where m = optimiseBody $ funDefBody fundec
type UnstreamM = ReaderT (Scope Kernels) (State VNameSource)
optimiseBody :: Body Kernels -> UnstreamM (Body Kernels)
optimiseBody (Body () stms res) =
localScope (scopeOf stms) $
Body () <$> (stmsFromList . concat <$> mapM optimiseStm (stmsToList stms)) <*> pure res
optimiseStm :: Stm Kernels -> UnstreamM [Stm Kernels]
optimiseStm (Let pat aux (Op (Kernel desc space ts body))) = do
inv <- S.fromList . M.keys <$> askScope
stms' <- localScope (scopeOfKernelSpace space) $
runBinder_ $ optimiseInKernelStms inv $ kernelBodyStms body
return [Let pat aux $ Op $ Kernel desc space ts $ body { kernelBodyStms = stms' }]
optimiseStm (Let pat aux e) =
pure <$> (Let pat aux <$> mapExpM optimise e)
where optimise = identityMapper { mapOnBody = \scope -> localScope scope . optimiseBody }
type Invariant = S.Set VName
type InKernelM = Binder InKernel
optimiseInKernelStms :: Invariant -> Stms InKernel -> InKernelM ()
optimiseInKernelStms inv = mapM_ (optimiseInKernelStm inv) . stmsToList
optimiseInKernelStm :: Invariant -> Stm InKernel -> InKernelM ()
optimiseInKernelStm inv (Let pat aux (Op (GroupStream w max_chunk lam accs arrs)))
| max_chunk == w || maybe False (`S.notMember` inv) (subExpVar w) = do
let GroupStreamLambda chunk_size chunk_offset acc_params arr_params body = lam
letBindNames_ [chunk_size] $ BasicOp $ SubExp $ constant (1::Int32)
loop_body <- insertStmsM $ do
forM_ (zip arr_params arrs) $ \(p,a) ->
letBindNames_ [paramName p] $
BasicOp $ Index a $ fullSlice (paramType p)
[DimSlice (Var chunk_offset) (Var chunk_size) (constant (1::Int32))]
localScope (scopeOfLParams acc_params) $ optimiseInBody inv body
let lam_consumed = consumedInBody $ Alias.analyseBody $ groupStreamLambdaBody lam
uniqueIfConsumed p | paramName p `S.member` lam_consumed =
fmap (`toDecl` Unique) p
| otherwise = fmap (`toDecl` Nonunique) p
merge = zip (map uniqueIfConsumed acc_params) accs
certifying (stmAuxCerts aux) $
letBind_ pat $ DoLoop [] merge (ForLoop chunk_offset Int32 w []) loop_body
optimiseInKernelStm inv (Let pat aux e) =
addStm =<< (Let pat aux <$> mapExpM optimise e)
where optimise = identityMapper
{ mapOnBody = \scope -> localScope scope . optimiseInBody inv }
optimiseInBody :: Invariant -> Body InKernel -> InKernelM (Body InKernel)
optimiseInBody inv body = do
stms' <- collectStms_ $ optimiseInKernelStms inv $ bodyStms body
return body { bodyStms = stms' }