{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Unstream (unstream) where
import Control.Monad.State
import Control.Monad.Reader
import Futhark.MonadFreshNames
import Futhark.Representation.Kernels
import Futhark.Pass
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
-- | The @"unstream"@ pass ("sequentialise remaining SOACs").  It hands
-- 'intraproceduralTransformation' a worker that runs 'optimiseStms' over
-- the statements it is given.
unstream :: Pass Kernels Kernels
unstream =
  Pass "unstream" "sequentialise remaining SOACs" $
    intraproceduralTransformation optimise
  where
    -- Run the reader/state stack purely: the reader layer receives the
    -- scope passed in by the caller, and the state layer is threaded
    -- through the ambient name source via 'modifyNameSource'.
    optimise scope stms =
      modifyNameSource $ runState $ runReaderT (optimiseStms stms) scope
-- | Monad used by the traversal functions in this module: a reader over
-- the current 'Scope' layered on a 'State' holding the 'VNameSource'.
type UnstreamM = ReaderT (Scope Kernels) (State VNameSource)
-- | Rewrite a sequence of statements.  Each statement is passed to
-- 'optimiseStm', which may return several statements; the results are
-- concatenated in order.  The traversal runs with the scope of the
-- original statements added via 'localScope'.
optimiseStms :: Stms Kernels -> UnstreamM (Stms Kernels)
optimiseStms stms =
  localScope (scopeOf stms) $
    stmsFromList . concat <$> mapM optimiseStm (stmsToList stms)
-- | Rewrite the statements of a body with 'optimiseStms'; the result of
-- the body is returned unchanged.
optimiseBody :: Body Kernels -> UnstreamM (Body Kernels)
optimiseBody (Body () stms res) =
  Body () <$> optimiseStms stms <*> pure res
-- | Rewrite the statements of a kernel body; the kernel results are
-- returned unchanged.
--
-- The statement traversal is delegated to 'optimiseStms', which brings
-- the scope of the statements into effect and applies 'optimiseStm' to
-- each of them, so this function mirrors 'optimiseBody'.
optimiseKernelBody :: KernelBody Kernels -> UnstreamM (KernelBody Kernels)
optimiseKernelBody (KernelBody () stms res) =
  KernelBody () <$> optimiseStms stms <*> pure res
-- | Rewrite the body of a lambda with 'optimiseBody', with the lambda's
-- parameters added to the scope.  All other fields of the lambda are
-- kept as they are.
optimiseLambda :: Lambda Kernels -> UnstreamM (Lambda Kernels)
optimiseLambda lam =
  localScope (scopeOfLParams $ lambdaParams lam) $ do
    body <- optimiseBody $ lambdaBody lam
    return lam {lambdaBody = body}
-- | Rewrite a single statement, possibly producing several statements.
--
-- * An 'OtherOp' (a SOAC) is replaced by the statements that
--   'FOT.transformSOAC' emits for it; those statements are then
--   rewritten recursively.
--
-- * A 'SegOp' is rebuilt with its kernel bodies and lambdas rewritten,
--   with the scope of its 'SegSpace' in effect.
--
-- * Any other expression is traversed with 'mapExpM', rewriting every
--   body it contains.
optimiseStm :: Stm Kernels -> UnstreamM [Stm Kernels]
-- NOTE(review): the statement's auxiliary information (the second field
-- of 'Let') is discarded in this clause -- confirm that this is intended.
optimiseStm (Let pat _ (Op (OtherOp soac))) = do
  stms <- runBinder_ $ FOT.transformSOAC pat soac
  -- The generated statements are rewritten in turn, with their own
  -- scope added.
  fmap concat $
    localScope (scopeOf stms) $
      mapM optimiseStm $ stmsToList stms
optimiseStm (Let pat aux (Op (SegOp op))) =
  localScope (scopeOfSegSpace $ segSpace op) $
    pure <$> (Let pat aux . Op . SegOp <$> mapSegOpM optimise op)
  where
    -- Identity mapper, except that kernel bodies and lambdas are
    -- rewritten.
    optimise =
      identitySegOpMapper
        { mapOnSegOpBody = optimiseKernelBody,
          mapOnSegOpLambda = optimiseLambda
        }
optimiseStm (Let pat aux e) =
  pure <$> (Let pat aux <$> mapExpM optimise e)
  where
    -- Identity mapper, except that each body is rewritten inside the
    -- scope supplied by the traversal.
    optimise =
      identityMapper
        { mapOnBody = \scope -> localScope scope . optimiseBody
        }