{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Sequentialise any remaining SOACs.  It is very important that
-- this is run *after* any access-pattern-related optimisation,
-- because this pass will destroy information.
module Futhark.Optimise.Unstream (unstream) where

import Control.Monad.State
import Control.Monad.Reader

import Futhark.MonadFreshNames
import Futhark.IR.Kernels
import Futhark.Pass
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT

-- | The pass definition.  Builds a 'Pass' named @"unstream"@ whose
-- program transformation is 'intraproceduralTransformation' applied to
-- a wrapper around 'optimiseStms'.
unstream :: Pass Kernels Kernels
unstream =
  Pass "unstream" "sequentialise remaining SOACs" $
    intraproceduralTransformation optimise
  where
    -- Run the 'UnstreamM' action for the given statements: the supplied
    -- scope becomes the reader environment, and the name source of the
    -- surrounding monad is threaded through as the state.
    optimise scope stms =
      modifyNameSource $ runState $ runReaderT (optimiseStms stms) scope

-- | The monad used by this pass: a reader whose environment is a
-- 'Scope' for the 'Kernels' representation, layered over a state monad
-- holding a 'VNameSource'.
type UnstreamM = ReaderT (Scope Kernels) (State VNameSource)

-- | Apply 'optimiseStm' to each statement of the sequence, in order,
-- and concatenate the resulting statement lists into a new sequence.
-- The traversal runs with the scope of @stms@ added to the
-- environment via 'localScope'.
optimiseStms :: Stms Kernels -> UnstreamM (Stms Kernels)
optimiseStms stms =
  localScope (scopeOf stms) $
    stmsFromList . concat <$> mapM optimiseStm (stmsToList stms)

-- | Rewrite the statements of a body with 'optimiseStms'.  The body
-- result is carried over unchanged.
optimiseBody :: Body Kernels -> UnstreamM (Body Kernels)
optimiseBody (Body () stms res) =
  Body () <$> optimiseStms stms <*> pure res

-- | Rewrite the statements of a kernel body.  The kernel results are
-- carried over unchanged.
--
-- The statement traversal is delegated to 'optimiseStms', which
-- extends the environment with the scope of the statements and applies
-- 'optimiseStm' to each of them.
optimiseKernelBody :: KernelBody Kernels -> UnstreamM (KernelBody Kernels)
optimiseKernelBody (KernelBody () stms res) =
  KernelBody () <$> optimiseStms stms <*> pure res

-- | Rewrite the body of a lambda with 'optimiseBody'.  The rewrite
-- runs with the lambda's parameters added to the environment; every
-- other field of the lambda is left as it was.
optimiseLambda :: Lambda Kernels -> UnstreamM (Lambda Kernels)
optimiseLambda lam =
  localScope (scopeOfLParams $ lambdaParams lam) $ do
    body <- optimiseBody $ lambdaBody lam
    return lam {lambdaBody = body}

-- | Rewrite a single statement, producing the list of statements that
-- replaces it.
--
-- * A statement whose expression is an 'OtherOp' (a SOAC) is expanded
--   with 'FOT.transformSOAC'; the statements this produces are then
--   themselves passed through 'optimiseStm'.
--
-- * A statement whose expression is a 'SegOp' is kept, with its kernel
--   bodies and lambdas rewritten.
--
-- * Any other statement is kept, with the bodies nested inside its
--   expression rewritten.
optimiseStm :: Stm Kernels -> UnstreamM [Stm Kernels]
-- NOTE(review): the statement's 'StmAux' is discarded in this case and
-- is not attached to the generated statements -- confirm this is
-- intended.
optimiseStm (Let pat _ (Op (OtherOp soac))) = do
  stms <- runBinder_ $ FOT.transformSOAC pat soac
  -- The generated statements may refer to each other, so bring them
  -- into scope before recursing into them.
  fmap concat $
    localScope (scopeOf stms) $
      mapM optimiseStm $ stmsToList stms
optimiseStm (Let pat aux (Op (SegOp op))) =
  localScope (scopeOfSegSpace $ segSpace op) $
    pure <$> (Let pat aux . Op . SegOp <$> mapSegOpM segOpMapper op)
  where
    -- Only kernel bodies and lambdas are rewritten; every other
    -- component is handled by 'identitySegOpMapper'.
    segOpMapper =
      identitySegOpMapper
        { mapOnSegOpBody = optimiseKernelBody,
          mapOnSegOpLambda = optimiseLambda
        }
optimiseStm (Let pat aux e) =
  pure <$> (Let pat aux <$> mapExpM expMapper e)
  where
    -- Only nested bodies are rewritten, each within the scope handed
    -- to the mapper; every other component is handled by
    -- 'identityMapper'.
    expMapper =
      identityMapper
        { mapOnBody = \scope -> localScope scope . optimiseBody
        }