{-# LANGUAGE TypeFamilies #-}

-- | Sequentialise any remaining SOACs.  It is very important that
-- this is run *after* any access-pattern-related optimisation,
-- because this pass will destroy information.
--
-- This pass conceptually contains three subpasses:
--
-- 1. Sequentialise 'Stream' operations, leaving other SOACs intact.
--
-- 2. Apply whole-program simplification.
--
-- 3. Sequentialise remaining SOACs.
--
-- This is because sequentialisation of streams creates many SOACs
-- operating on single-element arrays, which can be efficiently
-- simplified away, but only *before* they are turned into loops.  In
-- principle this pass could be split into multiple passes, but for now
-- it is kept together as one.
module Futhark.Optimise.Unstream (unstreamGPU, unstreamMC) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Futhark.IR.GPU
import Futhark.IR.GPU qualified as GPU
import Futhark.IR.GPU.Simplify (simplifyGPU)
import Futhark.IR.MC
import Futhark.IR.MC qualified as MC
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT

-- | The pass for GPU kernels: 'unstream' instantiated with the host
-- operation handler 'onHostOp' and the GPU simplifier 'simplifyGPU'.
unstreamGPU :: Pass GPU GPU
unstreamGPU = unstream onHostOp simplifyGPU

-- | The pass for multicore: 'unstream' instantiated with the multicore
-- operation handler 'onMCOp' and the multicore simplifier
-- 'MC.simplifyProg'.
unstreamMC :: Pass MC MC
unstreamMC = unstream onMCOp MC.simplifyProg

-- | Which subpass is currently running.  Under 'SeqStreams' the
-- 'sequentialise' predicate accepts only 'Stream' SOACs; under 'SeqAll'
-- it accepts every SOAC.
data Stage = SeqStreams | SeqAll

-- | Construct the pass from a representation-specific operation
-- handler and a whole-program simplifier.
--
-- The resulting pass is the composition of three steps: a traversal
-- with the handler in the 'SeqStreams' stage, the supplied simplifier,
-- and a second traversal with the handler in the 'SeqAll' stage.
unstream ::
  ASTRep rep =>
  (Stage -> OnOp rep) ->
  (Prog rep -> PassM (Prog rep)) ->
  Pass rep rep
unstream onOp simplify =
  Pass "unstream" "sequentialise remaining SOACs" $
    intraproceduralTransformation (optimise SeqStreams)
      >=> simplify
      >=> intraproceduralTransformation (optimise SeqAll)
  where
    -- Run the 'UnstreamM' traversal over one sequence of statements:
    -- the reader layer is seeded with the incoming scope, and the
    -- state layer is threaded through the pass's name source.
    optimise stage scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms (onOp stage) stms) scope

-- | The monad in which the traversal runs: a reader carrying the scope
-- of the names currently bound, layered over a state holding the
-- 'VNameSource'.
type UnstreamM rep = ReaderT (Scope rep) (State VNameSource)

-- | A handler for a statement whose expression is a
-- representation-specific 'Op'.  It receives the statement's pattern,
-- its auxiliary information and the operation itself, and produces the
-- list of statements that 'optimiseStm' puts in its place.
type OnOp rep =
  Pat (LetDec rep) -> StmAux (ExpDec rep) -> Op rep -> UnstreamM rep [Stm rep]

-- | Transform every statement in a sequence with 'optimiseStm' and
-- concatenate the resulting statement lists.  The traversal runs with
-- the scope of the original statements added to the environment.
optimiseStms ::
  ASTRep rep =>
  OnOp rep ->
  Stms rep ->
  UnstreamM rep (Stms rep)
optimiseStms onOp stms =
  localScope (scopeOf stms) $
    stmsFromList . concat <$> mapM (optimiseStm onOp) (stmsToList stms)

-- | Transform the statements of a body with 'optimiseStms'; the body
-- decoration and the result are kept unchanged.
optimiseBody ::
  ASTRep rep =>
  OnOp rep ->
  Body rep ->
  UnstreamM rep (Body rep)
optimiseBody onOp (Body aux stms res) =
  Body aux <$> optimiseStms onOp stms <*> pure res

-- | Transform the statements of a kernel body; the body decoration and
-- the kernel results are kept unchanged.
--
-- The statements are handled by 'optimiseStms', which itself extends
-- the scope with the statements' bindings before traversing them.
optimiseKernelBody ::
  ASTRep rep =>
  OnOp rep ->
  KernelBody rep ->
  UnstreamM rep (KernelBody rep)
optimiseKernelBody onOp (KernelBody attr stms res) =
  KernelBody attr <$> optimiseStms onOp stms <*> pure res

-- | Transform the body of a lambda with 'optimiseBody', with the
-- lambda's parameters added to the scope.  All other fields of the
-- lambda are left as they were.
optimiseLambda ::
  ASTRep rep =>
  OnOp rep ->
  Lambda rep ->
  UnstreamM rep (Lambda rep)
optimiseLambda onOp lam =
  localScope (scopeOfLParams $ lambdaParams lam) $ do
    body <- optimiseBody onOp $ lambdaBody lam
    pure lam {lambdaBody = body}

-- | Transform a single statement, possibly into several statements.
--
-- A statement whose expression is an 'Op' is handed to the supplied
-- handler.  Any other statement is kept as a single statement, with
-- the bodies nested inside its expression transformed recursively.
optimiseStm ::
  ASTRep rep =>
  OnOp rep ->
  Stm rep ->
  UnstreamM rep [Stm rep]
optimiseStm onOp (Let pat aux (Op op)) =
  onOp pat aux op
optimiseStm onOp (Let pat aux e) =
  pure . Let pat aux <$> mapExpM optimise e
  where
    -- Only bodies are rewritten; each one is traversed with the scope
    -- that the mapper supplies for it added to the environment.
    optimise =
      identityMapper
        { mapOnBody = \scope ->
            localScope scope . optimiseBody onOp
        }

-- | Transform the kernel bodies and lambdas inside a 'SegOp', with the
-- names bound by its 'SegSpace' added to the scope.
optimiseSegOp ::
  ASTRep rep =>
  OnOp rep ->
  SegOp lvl rep ->
  UnstreamM rep (SegOp lvl rep)
optimiseSegOp onOp op =
  localScope (scopeOfSegSpace $ segSpace op) $ mapSegOpM optimise op
  where
    -- Everything except kernel bodies and lambdas is left to the
    -- identity mapper.
    optimise =
      identitySegOpMapper
        { mapOnSegOpBody = optimiseKernelBody onOp,
          mapOnSegOpLambda = optimiseLambda onOp
        }

-- | The operation handler for the multicore representation.
--
-- * For a 'ParOp', both the optional first 'SegOp' and the second
--   'SegOp' are traversed with 'optimiseSegOp', and the statement is
--   rebuilt around them.
--
-- * For a SOAC that 'sequentialise' selects in the given stage, the
--   SOAC is replaced by the statements that 'FOT.transformSOAC'
--   produces, and those statements are themselves traversed.
--
-- * Any other SOAC is kept, but its lambdas are traversed.
onMCOp :: Stage -> OnOp MC
onMCOp stage pat aux (ParOp par_op op) = do
  par_op' <- traverse (optimiseSegOp (onMCOp stage)) par_op
  op' <- optimiseSegOp (onMCOp stage) op
  pure [Let pat aux $ Op $ ParOp par_op' op']
onMCOp stage pat aux (MC.OtherOp soac)
  | sequentialise stage soac = do
      -- NOTE(review): 'aux' is not used in this branch, so whatever it
      -- carries is not attached to the generated statements -- confirm
      -- that this is intended.
      stms <- runBuilder_ $ FOT.transformSOAC pat soac
      fmap concat $
        localScope (scopeOf stms) $
          mapM (optimiseStm (onMCOp stage)) $
            stmsToList stms
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure . Let pat aux . Op . MC.OtherOp <$> mapSOACM optimise soac
  where
    -- Only the lambdas of the SOAC are rewritten.
    optimise =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda (onMCOp stage)
        }

-- | Should this SOAC be turned into sequential code in the given
-- stage?  In the 'SeqStreams' stage only 'Stream' qualifies; in the
-- 'SeqAll' stage every SOAC does.
sequentialise :: Stage -> SOAC rep -> Bool
sequentialise SeqStreams Stream {} = True
sequentialise SeqStreams _ = False
sequentialise SeqAll _ = True

-- | The operation handler for the GPU representation.
--
-- * For a SOAC that 'sequentialise' selects in the given stage, the
--   SOAC is replaced by the statements that 'FOT.transformSOAC'
--   produces, and those statements are themselves traversed.
--
-- * Any other SOAC is kept, but its lambdas are traversed.
--
-- * A 'SegOp' is kept, with its contents traversed by 'optimiseSegOp'.
--
-- * Every other host operation is returned unchanged as a single
--   statement.
onHostOp :: Stage -> OnOp GPU
onHostOp stage pat aux (GPU.OtherOp soac)
  | sequentialise stage soac = do
      -- NOTE(review): 'aux' is not used in this branch, so whatever it
      -- carries is not attached to the generated statements -- confirm
      -- that this is intended.
      stms <- runBuilder_ $ FOT.transformSOAC pat soac
      fmap concat $
        localScope (scopeOf stms) $
          mapM (optimiseStm (onHostOp stage)) $
            stmsToList stms
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure . Let pat aux . Op . GPU.OtherOp <$> mapSOACM optimise soac
  where
    -- Only the lambdas of the SOAC are rewritten.
    optimise =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda (onHostOp stage)
        }
onHostOp stage pat aux (SegOp op) =
  pure . Let pat aux . Op . SegOp <$> optimiseSegOp (onHostOp stage) op
onHostOp _ pat aux op = pure [Let pat aux $ Op op]