{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations (expandAllocations) where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (find, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Error
import Futhark.IR
import qualified Futhark.IR.Kernels.Simplify as Kernels
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Lore (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.Kernels (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToKernels (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util (mapAccumLM)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

-- | The memory expansion pass definition.
--
-- The constants of the program are transformed first, starting from an
-- empty scope.  Every function is then transformed with the scope of
-- the /transformed/ constants.  A 'Left' produced by the underlying
-- 'ExpandM' computation is handed to 'limitationOnLeft'.
expandAllocations :: Pass KernelsMem KernelsMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $
    \(Prog consts funs) -> do
      -- Run the ExpandM stack by hand so that the fresh-name source of
      -- the surrounding PassM is threaded through the transformation.
      consts' <-
        modifyNameSource $
          limitationOnLeft
            . runStateT (runReaderT (transformStms consts) mempty)
      Prog consts' <$> mapM (transformFunDef $ scopeOf consts') funs

-- Cannot use intraproceduralTransformation because it might create
-- duplicate size keys (which are not fixed by renamer, and size
-- keys must currently be globally unique).

-- | The monad in which the expansion runs: a reader over the current
-- 'Scope', layered on a state monad carrying the 'VNameSource', layered
-- on @Either String@ for reporting failure.
type ExpandM = ReaderT (Scope KernelsMem) (StateT VNameSource (Either String))

-- | Unwrap the result of an expansion.  A 'Right' value is returned
-- unchanged; the message of a 'Left' is passed to 'compilerLimitationS'
-- (presumably reporting it as a compiler limitation -- confirm against
-- "Futhark.Error").
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id

-- | Transform the body of a single function definition and then run
-- copy propagation over the resulting function.
--
-- The first argument is the scope in which the function is defined; it
-- is used both while transforming the body and to build the symbol
-- table given to 'copyPropagateInFun'.
transformFunDef ::
  Scope KernelsMem ->
  FunDef KernelsMem ->
  PassM (FunDef KernelsMem)
transformFunDef scope fundec = do
  -- The reader is started with an empty scope; 'm' itself brings
  -- 'scope' and the scope of 'fundec' into effect.
  body' <- modifyNameSource $ limitationOnLeft . runStateT (runReaderT m mempty)
  copyPropagateInFun
    simpleKernelsMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    m :: ExpandM (Body KernelsMem)
    m =
      localScope scope $
        inScopeOf fundec $
          transformBody $ funDefBody fundec

-- | Transform the statements of a body with 'transformStms'; the body
-- result is kept unchanged.
transformBody :: Body KernelsMem -> ExpandM (Body KernelsMem)
transformBody (Body () stms res) =
  Body () <$> transformStms stms <*> pure res

-- | Transform every statement with 'transformStm' and concatenate the
-- resulting statement sequences in order.  All of the given statements
-- are in scope while each individual statement is transformed.
transformStms :: Stms KernelsMem -> ExpandM (Stms KernelsMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)

-- | Transform a single statement, producing the statements that replace
-- it.
--
-- An 'If' of sort 'IfEquiv' has each branch transformed separately,
-- with failures caught.  When exactly one branch fails, the 'If' is
-- replaced by the statements of the surviving branch followed by
-- bindings of the pattern elements to that branch's results.  When both
-- fail, the error of the true branch is rethrown.
--
-- Any other statement has its nested bodies transformed with
-- 'transformBody', after which the expression itself is passed to
-- 'transformExp'; the statements returned by 'transformExp' are placed
-- before the rebuilt statement.
transformStm :: Stm KernelsMem -> ExpandM (Stms KernelsMem)
-- It is possible that we are unable to expand allocations in some
-- code versions.  If so, we can remove the offending branch.  Only if
-- both versions fail do we propagate the error.
transformStm (Let pat aux (If cond tbranch fbranch (IfDec ts IfEquiv))) = do
  tbranch' <- (Right <$> transformBody tbranch) `catchError` (return . Left)
  fbranch' <- (Right <$> transformBody fbranch) `catchError` (return . Left)
  case (tbranch', fbranch') of
    (Left _, Right fbranch'') ->
      return $ useBranch fbranch''
    (Right tbranch'', Left _) ->
      return $ useBranch tbranch''
    (Right tbranch'', Right fbranch'') ->
      return $ oneStm $ Let pat aux $ If cond tbranch'' fbranch'' (IfDec ts IfEquiv)
    (Left e, _) ->
      throwError e
  where
    -- Bind one pattern element directly to a result of the chosen branch.
    bindRes pe se = Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se

    -- Inline a branch: its statements, then one binding per pattern
    -- element, paired positionally with the branch results.
    useBranch b =
      bodyStms b
        <> stmsFromList (zipWith bindRes (patternElements pat) (bodyResult b))
transformStm (Let pat aux e) = do
  (bnds, e') <- transformExp =<< mapExpM transform e
  return $ bnds <> oneStm (Let pat aux e')
  where
    -- Only bodies are rewritten; each is transformed under the scope
    -- supplied by the mapper.
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }

-- | Rebuild a 'NameInfo' constructor by constructor, keeping the
-- carried information as is.  Source and target lore are both
-- 'KernelsMem' here.
nameInfoConv :: NameInfo KernelsMem -> NameInfo KernelsMem
nameInfoConv (LetName mem_info) = LetName mem_info
nameInfoConv (FParamName mem_info) = FParamName mem_info
nameInfoConv (LParamName mem_info) = LParamName mem_info
nameInfoConv (IndexName it) = IndexName it

transformExp :: Exp KernelsMem -> ExpandM (Stms KernelsMem, Exp KernelsMem)
transformExp :: ExpT KernelsMem -> ExpandM (Stms KernelsMem, ExpT KernelsMem)
transformExp (Op (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody KernelsMem
kbody)))) = do
  (Stms KernelsMem
alloc_stms, ([Lambda KernelsMem]
_, KernelBody KernelsMem
kbody')) <- SegLevel
-> SegSpace
-> [Lambda KernelsMem]
-> KernelBody KernelsMem
-> ExpandM
     (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed SegLevel
lvl SegSpace
space [] KernelBody KernelsMem
kbody
  (Stms KernelsMem, ExpT KernelsMem)
-> ExpandM (Stms KernelsMem, ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms KernelsMem
alloc_stms,
      Op KernelsMem -> ExpT KernelsMem
forall lore. Op lore -> ExpT lore
Op (Op KernelsMem -> ExpT KernelsMem)
-> Op KernelsMem -> ExpT KernelsMem
forall a b. (a -> b) -> a -> b
$ HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall inner. inner -> MemOp inner
Inner (HostOp KernelsMem () -> MemOp (HostOp KernelsMem ()))
-> HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel KernelsMem -> HostOp KernelsMem ())
-> SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody KernelsMem
-> SegOp SegLevel KernelsMem
forall lvl lore.
lvl
-> SegSpace
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody lore
-> SegOp lvl lore
SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody KernelsMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp KernelsMem]
reds [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody KernelsMem
kbody)))) = do
  (Stms KernelsMem
alloc_stms, ([Lambda KernelsMem]
lams, KernelBody KernelsMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda KernelsMem]
-> KernelBody KernelsMem
-> ExpandM
     (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp KernelsMem -> Lambda KernelsMem)
-> [SegBinOp KernelsMem] -> [Lambda KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda [SegBinOp KernelsMem]
reds) KernelBody KernelsMem
kbody
  let reds' :: [SegBinOp KernelsMem]
reds' = (SegBinOp KernelsMem -> Lambda KernelsMem -> SegBinOp KernelsMem)
-> [SegBinOp KernelsMem]
-> [Lambda KernelsMem]
-> [SegBinOp KernelsMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp KernelsMem
red Lambda KernelsMem
lam -> SegBinOp KernelsMem
red {segBinOpLambda :: Lambda KernelsMem
segBinOpLambda = Lambda KernelsMem
lam}) [SegBinOp KernelsMem]
reds [Lambda KernelsMem]
lams
  (Stms KernelsMem, ExpT KernelsMem)
-> ExpandM (Stms KernelsMem, ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms KernelsMem
alloc_stms,
      Op KernelsMem -> ExpT KernelsMem
forall lore. Op lore -> ExpT lore
Op (Op KernelsMem -> ExpT KernelsMem)
-> Op KernelsMem -> ExpT KernelsMem
forall a b. (a -> b) -> a -> b
$ HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall inner. inner -> MemOp inner
Inner (HostOp KernelsMem () -> MemOp (HostOp KernelsMem ()))
-> HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel KernelsMem -> HostOp KernelsMem ())
-> SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp KernelsMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody KernelsMem
-> SegOp SegLevel KernelsMem
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody lore
-> SegOp lvl lore
SegRed SegLevel
lvl SegSpace
space [SegBinOp KernelsMem]
reds' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody KernelsMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [SegBinOp KernelsMem]
scans [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody KernelsMem
kbody)))) = do
  (Stms KernelsMem
alloc_stms, ([Lambda KernelsMem]
lams, KernelBody KernelsMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda KernelsMem]
-> KernelBody KernelsMem
-> ExpandM
     (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp KernelsMem -> Lambda KernelsMem)
-> [SegBinOp KernelsMem] -> [Lambda KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda [SegBinOp KernelsMem]
scans) KernelBody KernelsMem
kbody
  let scans' :: [SegBinOp KernelsMem]
scans' = (SegBinOp KernelsMem -> Lambda KernelsMem -> SegBinOp KernelsMem)
-> [SegBinOp KernelsMem]
-> [Lambda KernelsMem]
-> [SegBinOp KernelsMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp KernelsMem
red Lambda KernelsMem
lam -> SegBinOp KernelsMem
red {segBinOpLambda :: Lambda KernelsMem
segBinOpLambda = Lambda KernelsMem
lam}) [SegBinOp KernelsMem]
scans [Lambda KernelsMem]
lams
  (Stms KernelsMem, ExpT KernelsMem)
-> ExpandM (Stms KernelsMem, ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms KernelsMem
alloc_stms,
      Op KernelsMem -> ExpT KernelsMem
forall lore. Op lore -> ExpT lore
Op (Op KernelsMem -> ExpT KernelsMem)
-> Op KernelsMem -> ExpT KernelsMem
forall a b. (a -> b) -> a -> b
$ HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall inner. inner -> MemOp inner
Inner (HostOp KernelsMem () -> MemOp (HostOp KernelsMem ()))
-> HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel KernelsMem -> HostOp KernelsMem ())
-> SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp KernelsMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody KernelsMem
-> SegOp SegLevel KernelsMem
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody lore
-> SegOp lvl lore
SegScan SegLevel
lvl SegSpace
space [SegBinOp KernelsMem]
scans' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody KernelsMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegHist SegLevel
lvl SegSpace
space [HistOp KernelsMem]
ops [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody KernelsMem
kbody)))) = do
  (Stms KernelsMem
alloc_stms, ([Lambda KernelsMem]
lams', KernelBody KernelsMem
kbody')) <- SegLevel
-> SegSpace
-> [Lambda KernelsMem]
-> KernelBody KernelsMem
-> ExpandM
     (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed SegLevel
lvl SegSpace
space [Lambda KernelsMem]
lams KernelBody KernelsMem
kbody
  let ops' :: [HistOp KernelsMem]
ops' = (HistOp KernelsMem -> Lambda KernelsMem -> HistOp KernelsMem)
-> [HistOp KernelsMem]
-> [Lambda KernelsMem]
-> [HistOp KernelsMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith HistOp KernelsMem -> Lambda KernelsMem -> HistOp KernelsMem
forall {lore} {lore}. HistOp lore -> Lambda lore -> HistOp lore
onOp [HistOp KernelsMem]
ops [Lambda KernelsMem]
lams'
  (Stms KernelsMem, ExpT KernelsMem)
-> ExpandM (Stms KernelsMem, ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms KernelsMem
alloc_stms,
      Op KernelsMem -> ExpT KernelsMem
forall lore. Op lore -> ExpT lore
Op (Op KernelsMem -> ExpT KernelsMem)
-> Op KernelsMem -> ExpT KernelsMem
forall a b. (a -> b) -> a -> b
$ HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall inner. inner -> MemOp inner
Inner (HostOp KernelsMem () -> MemOp (HostOp KernelsMem ()))
-> HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel KernelsMem -> HostOp KernelsMem ())
-> SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [HistOp KernelsMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody KernelsMem
-> SegOp SegLevel KernelsMem
forall lvl lore.
lvl
-> SegSpace
-> [HistOp lore]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody lore
-> SegOp lvl lore
SegHist SegLevel
lvl SegSpace
space [HistOp KernelsMem]
ops' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody KernelsMem
kbody'
    )
  where
    lams :: [Lambda KernelsMem]
lams = (HistOp KernelsMem -> Lambda KernelsMem)
-> [HistOp KernelsMem] -> [Lambda KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map HistOp KernelsMem -> Lambda KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp [HistOp KernelsMem]
ops
    onOp :: HistOp lore -> Lambda lore -> HistOp lore
onOp HistOp lore
op Lambda lore
lam = HistOp lore
op {histOp :: Lambda lore
histOp = Lambda lore
lam}
transformExp ExpT KernelsMem
e =
  (Stms KernelsMem, ExpT KernelsMem)
-> ExpandM (Stms KernelsMem, ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms KernelsMem
forall a. Monoid a => a
mempty, ExpT KernelsMem
e)

-- | Extract the allocations found in a kernel body and in the operator
-- lambdas that accompany it, and produce the statements that perform
-- those allocations outside of the kernel.
--
-- The result is the hoisted allocation statements, paired with the
-- rewritten lambdas and the rewritten kernel body.
--
-- An extracted allocation is /variant/ when its size is a variable
-- that is not in the scope surrounding the kernel; every other
-- allocation is /invariant/.  The function fails with an error when:
--
--   * a variant allocation has a size that is bound neither by the
--     'SegSpace' nor by a statement of the kernel body, or
--
--   * the level is 'SegGroup' and there is any variant allocation.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda KernelsMem] ->
  KernelBody KernelsMem ->
  ExpandM (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed lvl space ops kbody = do
  -- Every name in the surrounding scope counts as bound outside.
  bound_outside <- asks $ namesFromList . M.keys
  let (kbody', kbody_allocs) =
        extractKernelBodyAllocations lvl bound_outside bound_in_kernel kbody
      -- The lambdas are processed with an empty set of kernel-bound names.
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations lvl bound_outside mempty) ops
      variantAlloc (_, Var v, _) = not $ v `nameIn` bound_outside
      variantAlloc _ = False
      allocs = kbody_allocs <> mconcat ops_allocs
      (variant_allocs, invariant_allocs) = M.partition variantAlloc allocs
      -- A variant size that is not even bound in the kernel cannot be
      -- handled.
      badVariant (_, Var v, _) = not $ v `nameIn` bound_in_kernel
      badVariant _ = False

  case find badVariant $ M.elems variant_allocs of
    Just v ->
      throwError $
        "Cannot handle un-sliceable allocation size: " ++ pretty v
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      return ()

  -- The condition tests the *variant* allocations, so the message must
  -- say so as well.
  case lvl of
    SegGroup {}
      | not $ null variant_allocs ->
        throwError "Cannot handle variant allocations in SegGroup."
    _ ->
      return ()

  allocsForBody variant_allocs invariant_allocs lvl space kbody' $
    \alloc_stms kbody'' -> do
      -- Offset the memory of each lambda with its own scope in effect.
      ops'' <- forM ops' $ \op' ->
        localScope (scopeOf op') $ offsetMemoryInLambda op'
      return (alloc_stms, (ops'', kbody''))
  where
    -- Names from the segment space together with the names bound by
    -- the statements of the kernel body.
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody

-- | The names that are keys of the scope of the statements of a
-- kernel body.
boundInKernelBody :: KernelBody KernelsMem -> Names
boundInKernelBody kbody =
  namesFromList (M.keys stms_scope)
  where
    stms_scope = scopeOf (kernelBodyStms kbody)

-- | Compute the memory requirements of the given variant and
-- invariant allocations, offset the memory in the kernel body
-- accordingly, and hand the allocation statements together with the
-- rewritten body to the continuation.  An error reported while
-- running the 'OffsetM' action is re-thrown in 'ExpandM'.
allocsForBody ::
  Extraction ->
  Extraction ->
  SegLevel ->
  SegSpace ->
  KernelBody KernelsMem ->
  (Stms KernelsMem -> KernelBody KernelsMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs lvl space kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements lvl space (kernelBodyStms kbody') variant_allocs invariant_allocs

  outer_scope <- askScope
  let -- The segment space names are added to the converted outer scope.
      kernel_scope = scopeOfSegSpace space <> M.map nameInfoConv outer_scope
      offset_action = offsetMemoryInKernelBody kbody' >>= m alloc_stms

  case runOffsetM kernel_scope alloc_offsets offset_action of
    Left err -> throwError err
    Right res -> pure res

-- | Build the statements that perform the expanded allocations, along
-- with the map describing how the original memory blocks are to be
-- rebased.
--
-- The returned statements start with the computation of
-- @num_threads@ (number of groups multiplied by group size), followed
-- by the invariant allocations and then the variant allocations.
memoryRequirements ::
  SegLevel ->
  SegSpace ->
  Stms KernelsMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms KernelsMem)
memoryRequirements lvl space kstms variant_allocs invariant_allocs = do
  let num_groups = segNumGroups lvl
      group_size = segGroupSize lvl
      threads_exp =
        BasicOp $
          BinOp (Mul Int64 OverflowUndef) (unCount num_groups) (unCount group_size)

  (num_threads, num_threads_stms) <-
    runBinder (letSubExp "num_threads" threads_exp)

  -- Both kinds of allocation are expanded with @num_threads@ in scope.
  let withNumThreads = inScopeOf num_threads_stms

  (invariant_alloc_stms, invariant_alloc_offsets) <-
    withNumThreads $
      expandedInvariantAllocations
        (num_threads, num_groups, group_size)
        space
        invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    withNumThreads $
      expandedVariantAllocations num_threads space kstms variant_allocs

  let offsets = invariant_alloc_offsets <> variant_alloc_offsets
      stms = num_threads_stms <> invariant_alloc_stms <> variant_alloc_stms
  pure (offsets, stms)

-- | A description of allocations that have been extracted, and how
-- much memory (and which space) is needed.  Each entry is keyed by
-- the name bound by the extracted allocation statement; the value
-- holds the 'SegLevel' that was in effect for the extraction, the
-- size of the allocation, and the memory space it is allocated in.
type Extraction = M.Map VName (SegLevel, SubExp, Space)

-- | Extract allocations from the statements of a kernel body.  This
-- is 'extractGenericBodyAllocations' specialised to 'KernelBody',
-- reading and replacing the 'kernelBodyStms' field.
extractKernelBodyAllocations ::
  SegLevel ->
  Names ->
  Names ->
  KernelBody KernelsMem ->
  ( KernelBody KernelsMem,
    Extraction
  )
extractKernelBodyAllocations lvl bound_outside bound_kernel =
  extractGenericBodyAllocations
    lvl
    bound_outside
    bound_kernel
    kernelBodyStms
    replaceStms
  where
    replaceStms new_stms kb = kb {kernelBodyStms = new_stms}

-- | Extract allocations from the statements of an ordinary body.
-- This is 'extractGenericBodyAllocations' specialised to 'Body',
-- reading and replacing the 'bodyStms' field.
extractBodyAllocations ::
  SegLevel ->
  Names ->
  Names ->
  Body KernelsMem ->
  (Body KernelsMem, Extraction)
extractBodyAllocations lvl bound_outside bound_kernel =
  extractGenericBodyAllocations
    lvl
    bound_outside
    bound_kernel
    bodyStms
    replaceStms
  where
    replaceStms new_stms b = b {bodyStms = new_stms}

-- | Extract allocations from the body of a lambda, returning the
-- lambda with its body replaced by the rewritten one, together with
-- the extracted allocations.
extractLambdaAllocations ::
  SegLevel ->
  Names ->
  Names ->
  Lambda KernelsMem ->
  (Lambda KernelsMem, Extraction)
extractLambdaAllocations lvl bound_outside bound_kernel lam =
  let (new_body, extracted) =
        extractBodyAllocations lvl bound_outside bound_kernel (lambdaBody lam)
   in (lam {lambdaBody = new_body}, extracted)

-- | Extract allocations from any body-like value, given a function
-- that reads its statements and a function that replaces them.
--
-- Each statement is passed through 'extractStmAllocations'; the
-- statements for which that returns 'Nothing' are dropped from the
-- resulting body, and everything written to the 'Writer' is returned
-- as the extraction.  The names bound by the statements of the body
-- are added to the set of kernel-bound names before processing.
extractGenericBodyAllocations ::
  SegLevel ->
  Names ->
  Names ->
  (body -> Stms KernelsMem) ->
  (Stms KernelsMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations lvl bound_outside bound_kernel get_stms set_stms body =
  (set_stms (stmsFromList kept_stms) body, extracted)
  where
    body_stms = get_stms body
    bound_kernel' = bound_kernel <> boundByStms body_stms
    processStm = extractStmAllocations lvl bound_outside bound_kernel'
    (kept_stms, extracted) =
      runWriter (catMaybes <$> mapM processStm (stmsToList body_stms))

-- | 'expandable' is false for the space named @"local"@ and for any
-- 'ScalarSpace', and true for every other space.  'notScalar' is
-- false exactly for a 'ScalarSpace'.
expandable, notScalar :: Space -> Bool
expandable space =
  case space of
    Space "local" -> False
    ScalarSpace {} -> False
    _ -> True
notScalar space =
  case space of
    ScalarSpace {} -> False
    _ -> True

extractStmAllocations ::
  SegLevel ->
  Names ->
  Names ->
  Stm KernelsMem ->
  Writer Extraction (Maybe (Stm KernelsMem))
extractStmAllocations :: SegLevel
-> Names
-> Names
-> Stm KernelsMem
-> WriterT Extraction Identity (Maybe (Stm KernelsMem))
extractStmAllocations SegLevel
lvl Names
bound_outside Names
bound_kernel (Let (Pattern [] [PatElemT (LetDec KernelsMem)
patElem]) StmAux (ExpDec KernelsMem)
_ (Op (Alloc SubExp
size Space
space)))
  | Space -> Bool
expandable Space
space Bool -> Bool -> Bool
&& SubExp -> Bool
expandableSize SubExp
size
      -- FIXME: the '&& notScalar space' part is a hack because we
      -- don't otherwise hoist the sizes out far enough, and we
      -- promise to be super-duper-careful about not having variant
      -- scalar allocations.
      Bool -> Bool -> Bool
|| (SubExp -> Bool
boundInKernel SubExp
size Bool -> Bool -> Bool
&& Space -> Bool
notScalar Space
space) = do
    Extraction -> WriterT Extraction Identity ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell (Extraction -> WriterT Extraction Identity ())
-> Extraction -> WriterT Extraction Identity ()
forall a b. (a -> b) -> a -> b
$ VName -> (SegLevel, SubExp, Space) -> Extraction
forall k a. k -> a -> Map k a
M.singleton (PatElemT (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec KernelsMem)
PatElemT (MemInfo SubExp NoUniqueness MemBind)
patElem) (SegLevel
lvl, SubExp
size, Space
space)
    Maybe (Stm KernelsMem)
-> WriterT Extraction Identity (Maybe (Stm KernelsMem))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (Stm KernelsMem)
forall a. Maybe a
Nothing
  where
    expandableSize :: SubExp -> Bool
expandableSize (Var VName
v) = VName
v VName -> Names -> Bool
`nameIn` Names
bound_outside Bool -> Bool -> Bool
|| VName
v VName -> Names -> Bool
`nameIn` Names
bound_kernel
    expandableSize Constant {} = Bool
True
    boundInKernel :: SubExp -> Bool
boundInKernel (Var VName
v) = VName
v VName -> Names -> Bool
`nameIn` Names
bound_kernel
    boundInKernel Constant {} = Bool
False
-- Clause for statements that are kept in place.  The expression of
-- the statement is traversed with 'mapExpM'; allocations extracted
-- from nested bodies, 'SegOp' lambdas and kernel bodies are reported
-- through the writer, and the statement is returned (wrapped in
-- 'Just') with the rewritten expression.
extractStmAllocations lvl bound_outside bound_kernel stm =
  Just . withExp <$> mapExpM (expMapper lvl) (stmExp stm)
  where
    withExp e = stm {stmExp = e}

    -- Mapper for ordinary expressions: nested bodies and operations
    -- are rewritten, everything else is left alone.
    expMapper lvl' =
      identityMapper
        { mapOnBody = \_scope -> onBody lvl',
          mapOnOp = onOp
        }

    -- Mapper for the constituents of a 'SegOp'.
    opMapper lvl' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda lvl',
          mapOnSegOpBody = onKernelBody lvl'
        }

    -- Extract the allocations of a body, emitting them to the writer
    -- and returning the body that remains.
    onBody lvl' body =
      let (body', allocs) =
            extractBodyAllocations lvl' bound_outside bound_kernel body
       in body' <$ tell allocs

    -- As 'onBody', but for kernel bodies.
    onKernelBody lvl' kbody =
      let (kbody', allocs) =
            extractKernelBodyAllocations lvl' bound_outside bound_kernel kbody
       in kbody' <$ tell allocs

    -- Only the body of a lambda is rewritten.
    onLambda lvl' lam =
      (\body -> lam {lambdaBody = body}) <$> onBody lvl' (lambdaBody lam)

    -- A nested 'SegOp' is traversed at its own level (as given by
    -- 'segLevel'); any other operation is returned unchanged.
    onOp op = case op of
      Inner (SegOp segop) ->
        fmap (Inner . SegOp) $ mapSegOpM (opMapper $ segLevel segop) segop
      _ -> pure op

-- | Produce the statements that perform the expanded versions of the
-- given extracted allocations, together with a 'RebaseMap' that has
-- one entry per allocated memory block.
--
-- For every allocation two statements are produced: one computing
-- the total size as the per-user size multiplied by the number of
-- users, and one allocating a block of that total size in the
-- original space (bound to the original memory block name).  The
-- number of users is the first component of the triple for
-- 'SegThread' allocations, and the number of groups for 'SegGroup'
-- allocations.
expandedInvariantAllocations ::
  ( SubExp,
    Count NumGroups SubExp,
    Count GroupSize SubExp
  ) ->
  SegSpace ->
  Extraction ->
  ExpandM (Stms KernelsMem, RebaseMap)
expandedInvariantAllocations
  (num_threads, Count num_groups, Count group_size)
  segspace
  invariant_allocs =
    -- The per-allocation results are pairs of monoids, so combining
    -- them pairwise concatenates the statements and unions the
    -- rebase maps.
    mconcat <$> mapM expand (M.toList invariant_allocs)
    where
      expand (mem, (lvl, per_thread_size, space)) = do
        total_size <- newVName "total_size"
        let size_stm =
              Let
                (Pattern [] [PatElem total_size (MemPrim int64)])
                (defAux ())
                ( BasicOp
                    ( BinOp
                        (Mul Int64 OverflowUndef)
                        (usersAt lvl)
                        per_thread_size
                    )
                )
            alloc_stm =
              Let
                (Pattern [] [PatElem mem (MemMem space)])
                (defAux ())
                (Op (Alloc (Var total_size) space))
        pure
          ( stmsFromList [size_stm, alloc_stm],
            M.singleton mem (newBase lvl)
          )

      -- How many instances of the allocation are needed.
      usersAt SegThread {} = num_threads
      usersAt SegGroup {} = num_groups

      -- The new base fixes the outermost dimension of 'expandedRoot'
      -- to the flat index of the segmented space, and takes a full
      -- slice of every original dimension.
      newBase lvl (old_shape, _) =
        IxFun.slice (expandedRoot lvl) $
          DimFix (le64 $ segFlat segspace) :
            [DimSlice 0 d 1 | d <- old_shape]
        where
          -- For threads, the extra dimension (number of groups times
          -- group size) is appended innermost in memory, and the
          -- index function is then permuted to make it the outermost
          -- dimension.
          expandedRoot SegThread {} =
            let rank = length old_shape
             in IxFun.permute
                  ( IxFun.iota $
                      old_shape ++ [pe64 num_groups * pe64 group_size]
                  )
                  (rank : [0 .. rank - 1])
          -- For groups, the extra dimension (number of groups) is
          -- simply prepended.
          expandedRoot SegGroup {} =
            IxFun.iota (pe64 num_groups : old_shape)

-- | Produce the statements that perform the given variant
-- allocations, together with a 'RebaseMap' that has one entry per
-- allocated memory block.
--
-- When there are no allocations, nothing is produced.  Otherwise the
-- distinct sizes are passed to 'sliceKernelSizes', which yields
-- statements along with one offset and one size sum per distinct
-- size.  Those statements are put through
-- 'explicitAllocationsInStms', 'simplifyStms' and 'transformStms',
-- and precede the allocation statements in the result.  Each memory
-- block is allocated with the size sum belonging to its size.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms KernelsMem ->
  Extraction ->
  ExpandM (Stms KernelsMem, RebaseMap)
expandedVariantAllocations num_threads kspace kstms variant_allocs
  | null variant_allocs = pure (mempty, mempty)
  | otherwise = do
    let sizes_to_blocks = removeCommonSizes variant_allocs

    (slice_stms, offsets, size_sums) <-
      sliceKernelSizes num_threads (map fst sizes_to_blocks) kspace kstms

    -- Note the recursive call to expand allocations inside the newly
    -- produced kernels.
    slice_stms' <-
      explicitAllocationsInStms slice_stms
        >>= simplifyStms
        >>= transformStms . snd

    -- Every memory block is paired with the offset and size sum
    -- computed for its size.  We expand the variant allocations by
    -- allocating the sum of the sizes required by different threads.
    let variant_allocs' =
          [ (mem, (Var offset, Var total_size, space))
            | ((_, blocks), offset, total_size) <-
                zip3 sizes_to_blocks offsets size_sums,
              (mem, space) <- blocks
          ]
        (alloc_stms, rebases) = unzip (map expand variant_allocs')

    pure (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    -- The allocation statement and rebase entry for one memory block.
    expand (mem, (offset, total_size, space)) =
      ( Let
          (Pattern [] [PatElem mem (MemMem space)])
          (defAux ())
          (Op (Alloc total_size space)),
        M.singleton mem (newBase offset)
      )

    num_threads' = pe64 num_threads
    gtid = le64 (segFlat kspace)

    -- For the variant allocations, we add an inner dimension,
    -- which is then offset by a thread-specific amount.  The result
    -- is reshaped to the original shape; a single-dimensional shape
    -- is treated as a coercion rather than as new dimensions.
    newBase size_per_thread (old_shape, pt) =
      IxFun.reshape per_thread_ixfun (map dim_change old_shape)
      where
        elems_per_thread = pe64 size_per_thread `quot` primByteSize pt
        per_thread_ixfun =
          IxFun.slice
            (IxFun.iota [elems_per_thread, num_threads'])
            [DimSlice 0 num_threads' 1, DimFix gtid]
        dim_change
          | [_] <- old_shape = DimCoercion
          | otherwise = DimNew

-- | A map from memory block names to new index function bases.  Each
-- entry is a function: given a list of dimension sizes (the functions
-- in this module that build entries call it the old shape) and a
-- primitive type, it produces the index function to use as the new
-- base.
type RebaseMap = M.Map VName (([TPrimExp Int64 VName], PrimType) -> IxFun)

-- | The monad in which memory offsetting is performed.
--
-- It is a stack of two readers over 'Either':
--
-- * the outer reader carries the current 'Scope' (giving the
--   'HasScope' and 'LocalScope' instances),
--
-- * the inner reader carries the 'RebaseMap' (see 'askRebaseMap'),
--
-- * @'Either' 'String'@ carries failures, exposed through
--   @'MonadError' 'String'@.
newtype OffsetM a
  = OffsetM
      ( ReaderT
          (Scope KernelsMem)
          (ReaderT RebaseMap (Either String))
          a
      )
  deriving
    ( Applicative,
      Functor,
      Monad,
      HasScope KernelsMem,
      LocalScope KernelsMem,
      MonadError String
    )

-- | Run an 'OffsetM' action.
--
-- The first argument is the scope the action starts out with, and the
-- second is the rebase map it can consult.  The result is 'Left' with
-- an error message if the action called 'throwError', and 'Right'
-- with the result otherwise.
runOffsetM :: Scope KernelsMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope offsets (OffsetM m) =
  -- Peel the readers from the outside in: scope first, then the
  -- rebase map.
  runReaderT (runReaderT m scope) offsets

-- | Retrieve the 'RebaseMap' that the computation was run with.
askRebaseMap :: OffsetM RebaseMap
-- The rebase map lives in the inner reader, so the 'ask' must be
-- lifted past the outer (scope) reader.
askRebaseMap = OffsetM $ lift ask

-- | Look up the new index function base for a memory block.
--
-- If the rebase map has an entry for @name@, that entry is applied to
-- @x@ (a shape paired with an element type) and the resulting index
-- function is returned.  If there is no entry, the result is
-- 'Nothing'.
lookupNewBase :: VName -> ([TPrimExp Int64 VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x = do
  offsets <- askRebaseMap
  return $ ($ x) <$> M.lookup name offsets

-- | Apply memory offsetting to every statement of a kernel body.
--
-- The statements are processed in order.  Each one is rewritten by
-- 'offsetMemoryInStm' under the scope produced by the previous
-- statement, so later statements see the rewritten bindings of
-- earlier ones.  Only the statements are replaced; the remaining
-- fields of the kernel body are left as they were.
offsetMemoryInKernelBody :: KernelBody KernelsMem -> OffsetM (KernelBody KernelsMem)
offsetMemoryInKernelBody kbody = do
  scope <- askScope
  stms' <-
    stmsFromList . snd
      <$> mapAccumLM
        -- The accumulator is the scope in which to process the next
        -- statement.
        (\scope' -> localScope scope' . offsetMemoryInStm)
        scope
        (stmsToList $ kernelBodyStms kbody)
  return kbody {kernelBodyStms = stms'}

-- | Apply memory offsetting to every statement of a body.
--
-- The statements are processed in order.  Each one is rewritten by
-- 'offsetMemoryInStm' under the scope produced by the previous
-- statement, so later statements see the rewritten bindings of
-- earlier ones.  The body decoration and result are passed through
-- unchanged.
offsetMemoryInBody :: Body KernelsMem -> OffsetM (Body KernelsMem)
offsetMemoryInBody (Body dec stms res) = do
  scope <- askScope
  stms' <-
    stmsFromList . snd
      <$> mapAccumLM
        -- The accumulator is the scope in which to process the next
        -- statement.
        (\scope' -> localScope scope' . offsetMemoryInStm)
        scope
        (stmsToList stms)
  return $ Body dec stms' res

-- | Apply memory offsetting to a single statement.
--
-- Returns the rewritten statement together with a scope made up of
-- that statement's bindings and the scope that was current when the
-- statement was processed.  Callers use that scope when processing
-- the statements that follow.
--
-- The pattern is rewritten first, via 'offsetMemoryInPattern'.  The
-- expression is then rewritten with the rewritten pattern in scope.
-- Finally the memory information of each value pattern element is
-- refined from the return information of the rewritten expression,
-- where that is possible (see @pick@).
offsetMemoryInStm :: Stm KernelsMem -> OffsetM (Scope KernelsMem, Stm KernelsMem)
offsetMemoryInStm (Let pat dec e) = do
  pat' <- offsetMemoryInPattern pat
  e' <- localScope (scopeOfPattern pat') $ offsetMemoryInExp e
  scope <- askScope
  -- Try to recompute the index function.  Fall back to creating rebase
  -- operations with the RebaseMap.
  rts <- runReaderT (expReturns e') scope
  let pat'' =
        Pattern
          (patternContextElements pat')
          (zipWith pick (patternValueElements pat') rts)
      stm = Let pat'' dec e'
      scope' = scopeOf stm <> scope
  return (scope', stm)
  where
    -- Prefer the memory block and index function reported by the
    -- expression itself, provided the index function contains no
    -- existential parts.  Otherwise keep the (already rebased)
    -- pattern element as it is.
    pick ::
      PatElemT (MemInfo SubExp NoUniqueness MemBind) ->
      ExpReturns ->
      PatElemT (MemInfo SubExp NoUniqueness MemBind)
    pick
      (PatElem name (MemArray pt s u _ret))
      (MemArray _ _ _ (Just (ReturnsInBlock m extixfun)))
        | Just ixfun <- instantiateIxFun extixfun =
          PatElem name (MemArray pt s u (ArrayIn m ixfun))
    pick p _ = p

    -- Turn an existential index function into an ordinary one.
    -- Yields 'Nothing' as soon as any leaf is an 'Ext'.
    instantiateIxFun :: ExtIxFun -> Maybe IxFun
    instantiateIxFun = traverse (traverse inst)
      where
        inst Ext {} = Nothing
        inst (Free x) = return x

-- | Apply memory offsetting to a pattern.
--
-- The context elements are only checked, never modified.  If any of
-- them binds a memory block in a space for which @expandable@ holds,
-- the computation fails with an error message naming that element.
--
-- The value elements have their memory information rewritten with
-- 'offsetMemoryInMemBound'.
offsetMemoryInPattern :: Pattern KernelsMem -> OffsetM (Pattern KernelsMem)
offsetMemoryInPattern (Pattern ctx vals) = do
  mapM_ inspectCtx ctx
  Pattern ctx <$> mapM inspectVal vals
  where
    -- Rewrite the memory information attached to a value element.
    inspectVal patElem = do
      new_dec <- offsetMemoryInMemBound $ patElemDec patElem
      return patElem {patElemDec = new_dec}

    -- Reject existential memory blocks in expandable spaces.
    inspectCtx patElem
      | Mem space <- patElemType patElem,
        expandable space =
        throwError $
          unwords
            [ "Cannot deal with existential memory block",
              pretty (patElemName patElem),
              "when expanding inside kernels."
            ]
      | otherwise = return ()

-- | Apply memory offsetting to the memory information of a
-- parameter.  All other parts of the parameter are left unchanged.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam fparam = do
  fparam' <- offsetMemoryInMemBound $ paramDec fparam
  return fparam {paramDec = fparam'}

-- | Apply memory offsetting to a piece of memory information.
--
-- For an array that resides in a memory block with an entry in the
-- rebase map, the index function is rebased onto the new base.  That
-- base is obtained by calling 'lookupNewBase' with the base shape of
-- the current index function and the element type.  The memory block
-- name itself is kept.
--
-- Anything else is returned unchanged: non-array information, and
-- arrays whose memory block has no entry in the rebase map.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound summary@(MemArray pt shape u (ArrayIn mem ixfun)) = do
  new_base <- lookupNewBase mem (IxFun.base ixfun, pt)
  return $
    fromMaybe summary $ do
      new_base' <- new_base
      return $ MemArray pt shape u $ ArrayIn mem $ IxFun.rebase new_base' ixfun
offsetMemoryInMemBound summary = return summary

-- | Apply memory offsetting to a branch return type.
--
-- Only an array returned in a known memory block is affected, and
-- only when 'isStaticIxFun' can produce a static version of its index
-- function.  If the rebase map has an entry for that memory block,
-- the existential index function is rebased onto the new base.  The
-- new base is lifted into the existential setting by wrapping its
-- variables in 'Free'.
--
-- Every other return type is passed through unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns br@(MemArray pt shape u (ReturnsInBlock mem ixfun))
  | Just ixfun' <- isStaticIxFun ixfun = do
    -- The static index function is used only to compute the base
    -- shape; the rebase is applied to the original existential one.
    new_base <- lookupNewBase mem (IxFun.base ixfun', pt)
    return $
      fromMaybe br $ do
        new_base' <- new_base
        return $
          MemArray pt shape u $
            ReturnsInBlock mem $
              IxFun.rebase (fmap (fmap Free) new_base') ixfun
offsetMemoryInBodyReturns br = return br

-- | Apply memory offsetting to the body of a lambda.  The body is
-- processed with the lambda itself in scope; the rest of the lambda
-- is left unchanged.
offsetMemoryInLambda :: Lambda KernelsMem -> OffsetM (Lambda KernelsMem)
offsetMemoryInLambda lam = inScopeOf lam $ do
  body <- offsetMemoryInBody $ lambdaBody lam
  return $ lam {lambdaBody = body}

offsetMemoryInExp :: Exp KernelsMem -> OffsetM (Exp KernelsMem)
offsetMemoryInExp :: ExpT KernelsMem -> OffsetM (ExpT KernelsMem)
offsetMemoryInExp (DoLoop [(FParam KernelsMem, SubExp)]
ctx [(FParam KernelsMem, SubExp)]
val LoopForm KernelsMem
form Body KernelsMem
body) = do
  let ([Param (MemInfo SubExp Uniqueness MemBind)]
ctxparams, Result
ctxinit) = [(Param (MemInfo SubExp Uniqueness MemBind), SubExp)]
-> ([Param (MemInfo SubExp Uniqueness MemBind)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam KernelsMem, SubExp)]
[(Param (MemInfo SubExp Uniqueness MemBind), SubExp)]
ctx
      ([Param (MemInfo SubExp Uniqueness MemBind)]
valparams, Result
valinit) = [(Param (MemInfo SubExp Uniqueness MemBind), SubExp)]
-> ([Param (MemInfo SubExp Uniqueness MemBind)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam KernelsMem, SubExp)]
[(Param (MemInfo SubExp Uniqueness MemBind), SubExp)]
val
  [Param (MemInfo SubExp Uniqueness MemBind)]
ctxparams' <- (Param (MemInfo SubExp Uniqueness MemBind)
 -> OffsetM (Param (MemInfo SubExp Uniqueness MemBind)))
-> [Param (MemInfo SubExp Uniqueness MemBind)]
-> OffsetM [Param (MemInfo SubExp Uniqueness MemBind)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Param (MemInfo SubExp Uniqueness MemBind)
-> OffsetM (Param (MemInfo SubExp Uniqueness MemBind))
forall u. Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam [Param (MemInfo SubExp Uniqueness MemBind)]
ctxparams
  [Param (MemInfo SubExp Uniqueness MemBind)]
valparams' <- (Param (MemInfo SubExp Uniqueness MemBind)
 -> OffsetM (Param (MemInfo SubExp Uniqueness MemBind)))
-> [Param (MemInfo SubExp Uniqueness MemBind)]
-> OffsetM [Param (MemInfo SubExp Uniqueness MemBind)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Param (MemInfo SubExp Uniqueness MemBind)
-> OffsetM (Param (MemInfo SubExp Uniqueness MemBind))
forall u. Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam [Param (MemInfo SubExp Uniqueness MemBind)]
valparams
  Body KernelsMem
body' <- Scope KernelsMem
-> OffsetM (Body KernelsMem) -> OffsetM (Body KernelsMem)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param (MemInfo SubExp Uniqueness MemBind)] -> Scope KernelsMem
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param (MemInfo SubExp Uniqueness MemBind)]
ctxparams' Scope KernelsMem -> Scope KernelsMem -> Scope KernelsMem
forall a. Semigroup a => a -> a -> a
<> [Param (MemInfo SubExp Uniqueness MemBind)] -> Scope KernelsMem
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param (MemInfo SubExp Uniqueness MemBind)]
valparams' Scope KernelsMem -> Scope KernelsMem -> Scope KernelsMem
forall a. Semigroup a => a -> a -> a
<> LoopForm KernelsMem -> Scope KernelsMem
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm KernelsMem
form) (Body KernelsMem -> OffsetM (Body KernelsMem)
offsetMemoryInBody Body KernelsMem
body)
  ExpT KernelsMem -> OffsetM (ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpT KernelsMem -> OffsetM (ExpT KernelsMem))
-> ExpT KernelsMem -> OffsetM (ExpT KernelsMem)
forall a b. (a -> b) -> a -> b
$ [(FParam KernelsMem, SubExp)]
-> [(FParam KernelsMem, SubExp)]
-> LoopForm KernelsMem
-> Body KernelsMem
-> ExpT KernelsMem
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop ([Param (MemInfo SubExp Uniqueness MemBind)]
-> Result -> [(Param (MemInfo SubExp Uniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp Uniqueness MemBind)]
ctxparams' Result
ctxinit) ([Param (MemInfo SubExp Uniqueness MemBind)]
-> Result -> [(Param (MemInfo SubExp Uniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp Uniqueness MemBind)]
valparams' Result
valinit) LoopForm KernelsMem
form Body KernelsMem
body'
offsetMemoryInExp ExpT KernelsMem
e = Mapper KernelsMem KernelsMem OffsetM
-> ExpT KernelsMem -> OffsetM (ExpT KernelsMem)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
Mapper flore tlore m -> Exp flore -> m (Exp tlore)
mapExpM Mapper KernelsMem KernelsMem OffsetM
recurse ExpT KernelsMem
e
  where
    recurse :: Mapper KernelsMem KernelsMem OffsetM
recurse =
      Mapper KernelsMem KernelsMem OffsetM
forall (m :: * -> *) lore. Monad m => Mapper lore lore m
identityMapper
        { mapOnBody :: Scope KernelsMem -> Body KernelsMem -> OffsetM (Body KernelsMem)
mapOnBody = \Scope KernelsMem
bscope -> Scope KernelsMem
-> OffsetM (Body KernelsMem) -> OffsetM (Body KernelsMem)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope Scope KernelsMem
bscope (OffsetM (Body KernelsMem) -> OffsetM (Body KernelsMem))
-> (Body KernelsMem -> OffsetM (Body KernelsMem))
-> Body KernelsMem
-> OffsetM (Body KernelsMem)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body KernelsMem -> OffsetM (Body KernelsMem)
offsetMemoryInBody,
          mapOnBranchType :: BranchType KernelsMem -> OffsetM (BranchType KernelsMem)
mapOnBranchType = BranchType KernelsMem -> OffsetM (BranchType KernelsMem)
BranchTypeMem -> OffsetM BranchTypeMem
offsetMemoryInBodyReturns,
          mapOnOp :: Op KernelsMem -> OffsetM (Op KernelsMem)
mapOnOp = Op KernelsMem -> OffsetM (Op KernelsMem)
forall {op}.
MemOp (HostOp KernelsMem op)
-> OffsetM (MemOp (HostOp KernelsMem op))
onOp
        }
    onOp :: MemOp (HostOp KernelsMem op)
-> OffsetM (MemOp (HostOp KernelsMem op))
onOp (Inner (SegOp SegOp SegLevel KernelsMem
op)) =
      HostOp KernelsMem op -> MemOp (HostOp KernelsMem op)
forall inner. inner -> MemOp inner
Inner (HostOp KernelsMem op -> MemOp (HostOp KernelsMem op))
-> (SegOp SegLevel KernelsMem -> HostOp KernelsMem op)
-> SegOp SegLevel KernelsMem
-> MemOp (HostOp KernelsMem op)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegOp SegLevel KernelsMem -> HostOp KernelsMem op
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp
        (SegOp SegLevel KernelsMem -> MemOp (HostOp KernelsMem op))
-> OffsetM (SegOp SegLevel KernelsMem)
-> OffsetM (MemOp (HostOp KernelsMem op))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Scope KernelsMem
-> OffsetM (SegOp SegLevel KernelsMem)
-> OffsetM (SegOp SegLevel KernelsMem)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope KernelsMem
forall lore. SegSpace -> Scope lore
scopeOfSegSpace (SegOp SegLevel KernelsMem -> SegSpace
forall lvl lore. SegOp lvl lore -> SegSpace
segSpace SegOp SegLevel KernelsMem
op)) (SegOpMapper SegLevel KernelsMem KernelsMem OffsetM
-> SegOp SegLevel KernelsMem -> OffsetM (SegOp SegLevel KernelsMem)
forall (m :: * -> *) lvl flore tlore.
(Applicative m, Monad m) =>
SegOpMapper lvl flore tlore m
-> SegOp lvl flore -> m (SegOp lvl tlore)
mapSegOpM SegOpMapper SegLevel KernelsMem KernelsMem OffsetM
forall {lvl}. SegOpMapper lvl KernelsMem KernelsMem OffsetM
segOpMapper SegOp SegLevel KernelsMem
op)
      where
        segOpMapper :: SegOpMapper lvl KernelsMem KernelsMem OffsetM
segOpMapper =
          SegOpMapper lvl Any Any OffsetM
forall (m :: * -> *) lvl lore.
Monad m =>
SegOpMapper lvl lore lore m
identitySegOpMapper
            { mapOnSegOpBody :: KernelBody KernelsMem -> OffsetM (KernelBody KernelsMem)
mapOnSegOpBody = KernelBody KernelsMem -> OffsetM (KernelBody KernelsMem)
offsetMemoryInKernelBody,
              mapOnSegOpLambda :: Lambda KernelsMem -> OffsetM (Lambda KernelsMem)
mapOnSegOpLambda = Lambda KernelsMem -> OffsetM (Lambda KernelsMem)
offsetMemoryInLambda
            }
    onOp MemOp (HostOp KernelsMem op)
op = MemOp (HostOp KernelsMem op)
-> OffsetM (MemOp (HostOp KernelsMem op))
forall (m :: * -> *) a. Monad m => a -> m a
return MemOp (HostOp KernelsMem op)
op

---- Slicing allocation sizes out of a kernel.

-- | Translate a sequence of 'KernelsMem' statements into the
-- 'Kernels.Kernels' representation by stripping memory information.
--
-- Allocation statements ('Alloc') that occur directly in the given
-- sequence are dropped.  Allocation statements found in any nested
-- body, kernel body or lambda body make the whole conversion fail, as
-- does a pattern, function parameter, return type or branch type that
-- is itself memory-typed.  Failures are reported as @Left message@.
unAllocKernelsStms :: Stms KernelsMem -> Either String (Stms Kernels.Kernels)
unAllocKernelsStms = unAllocStms False
  where
    -- The flag says whether we are inside a nested body; 'Nothing'
    -- results (dropped allocations) are filtered out afterwards.
    unAllocStms nested stms =
      stmsFromList . catMaybes <$> mapM (unAllocStm nested) (stmsToList stms)

    -- An allocation is an error when nested, and silently removed
    -- otherwise.
    unAllocStm nested stm@(Let _ _ (Op Alloc {}))
      | nested = Left $ "Cannot handle nested allocation: " ++ pretty stm
      | otherwise = Right Nothing
    -- Any other statement: convert the pattern first, then the expression.
    unAllocStm _ (Let pat aux e) = do
      pat' <- unAllocPattern pat
      e' <- mapExpM unAlloc' e
      pure . Just $ Let pat' aux e'

    -- Statements of an ordinary body are always treated as nested.
    unAllocBody (Body dec stms res) = do
      stms' <- unAllocStms True stms
      pure $ Body dec stms' res

    -- Statements of a kernel body are always treated as nested.
    unAllocKernelBody (KernelBody dec stms res) = do
      stms' <- unAllocStms True stms
      pure $ KernelBody dec stms' res

    unAllocLambda (Lambda params body ret) = do
      body' <- unAllocBody body
      pure $ Lambda (unParams params) body' ret

    -- Parameters for which 'unMem' yields 'Nothing' are discarded
    -- rather than reported as errors.
    unParams = mapMaybe (traverse unMem)

    -- Both the context and the value part of the pattern must be free
    -- of memory-typed elements; the context part is checked first.
    unAllocPattern pat@(Pattern ctx val) = do
      ctx' <- unAllocPatElems ctx
      val' <- unAllocPatElems val
      pure $ Pattern ctx' val'
      where
        unAllocPatElems pes =
          maybe bad Right (mapM (rephrasePatElem unMem) pes)
        bad = Left $ "Cannot handle memory in pattern " ++ pretty pat

    -- Segmented operations are converted recursively and size
    -- operations are carried over; everything else is rejected.
    unAllocOp (Inner (SegOp op)) = SegOp <$> mapSegOpM segMapper op
      where
        segMapper =
          identitySegOpMapper
            { mapOnSegOpBody = unAllocKernelBody,
              mapOnSegOpLambda = unAllocLambda
            }
    unAllocOp (Inner (SizeOp op)) = Right (SizeOp op)
    unAllocOp (Inner OtherOp {}) = Left "unAllocOp: unhandled OtherOp"
    unAllocOp Alloc {} = Left "unAllocOp: unhandled Alloc"

    -- Strip memory information from a parameter, failing if the
    -- parameter is memory-typed.
    unParam p = case traverse unMem p of
      Just p' -> Right p'
      Nothing ->
        Left $ "Cannot handle memory-typed parameter '" ++ pretty p ++ "'"

    -- Strip memory information from a type annotation, failing if it
    -- denotes memory.
    unT t = case unMem t of
      Just t' -> Right t'
      Nothing -> Left $ "Cannot handle memory type '" ++ pretty t ++ "'"

    -- The expression mapper used for everything below statement level.
    unAlloc' =
      Mapper
        { mapOnSubExp = Right,
          mapOnVName = Right,
          mapOnBody = const unAllocBody,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = unParam,
          mapOnLParam = unParam,
          mapOnOp = unAllocOp
        }

-- | Forget the memory annotation of a 'MemInfo'.  Primitive and array
-- information is turned into the corresponding plain type (the last
-- field of 'MemArray' is discarded); memory blocks themselves have no
-- counterpart and give 'Nothing'.
unMem :: MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem info = case info of
  MemPrim pt -> Just (Prim pt)
  MemArray pt shape u _ -> Just (Array pt shape u)
  MemMem {} -> Nothing

-- | Convert a 'KernelsMem' scope into a 'Kernels.Kernels' scope.  Each
-- let-bound name and each function or lambda parameter is passed
-- through 'unMem'; names for which 'unMem' gives 'Nothing' are removed
-- from the scope.  Index names are kept unchanged.
unAllocScope :: Scope KernelsMem -> Scope Kernels.Kernels
unAllocScope = M.mapMaybe $ \info ->
  case info of
    LetName dec -> fmap LetName (unMem dec)
    FParamName dec -> fmap FParamName (unMem dec)
    LParamName dec -> fmap LParamName (unMem dec)
    IndexName it -> Just (IndexName it)

-- | Group the entries of an 'Extraction' by their size component.
-- The result has one element per distinct size, in ascending order of
-- the size key, paired with the @(name, space)@ of every entry that
-- has that size.  The first component of each extraction value is
-- ignored.  Within a group, entries folded in later appear earlier in
-- the list.
removeCommonSizes ::
  Extraction ->
  [(SubExp, [(VName, Space)])]
removeCommonSizes extraction =
  M.toList $ foldl' addEntry M.empty (M.toList extraction)
  where
    -- Prepend this entry to the list already recorded for its size.
    addEntry by_size (mem, (_, size, space)) =
      M.insertWith (++) size [(mem, space)] by_size

sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms KernelsMem ->
  ExpandM (Stms Kernels.Kernels, [VName], [VName])
sliceKernelSizes :: SubExp
-> Result
-> SegSpace
-> Stms KernelsMem
-> ExpandM (Stms Kernels, [VName], [VName])
sliceKernelSizes SubExp
num_threads Result
sizes SegSpace
space Stms KernelsMem
kstms = do
  Stms Kernels
kstms' <- ([Char]
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either [Char]))
      (Stms Kernels))
-> (Stms Kernels
    -> ReaderT
         (Scope KernelsMem)
         (StateT VNameSource (Either [Char]))
         (Stms Kernels))
-> Either [Char] (Stms Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Stms Kernels)
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either [Char]
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Stms Kernels)
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError Stms Kernels
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either [Char] (Stms Kernels)
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either [Char]))
      (Stms Kernels))
-> Either [Char] (Stms Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stms KernelsMem -> Either [Char] (Stms Kernels)
unAllocKernelsStms Stms KernelsMem
kstms
  let num_sizes :: Int
num_sizes = Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
sizes
      i64s :: [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s = Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. Int -> a -> [a]
replicate Int
num_sizes (TypeBase (ShapeBase SubExp) NoUniqueness
 -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

  Scope Kernels
kernels_scope <- (Scope KernelsMem -> Scope Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Scope Kernels)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Scope KernelsMem -> Scope Kernels
unAllocScope

  (Lambda Kernels
max_lam, Stms Kernels
_) <- (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Lambda Kernels)
 -> Scope Kernels
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either [Char]))
      (Lambda Kernels, Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Lambda Kernels, Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
  (Lambda Kernels)
-> Scope Kernels
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Lambda Kernels, Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Lambda Kernels)
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either [Char]))
      (Lambda Kernels, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Lambda Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs <- Int
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ [Char]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"x" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys <- Int
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ [Char]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"y" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms Kernels
stms) <- Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result, Stms Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope Kernels
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
 -> Scope Kernels)
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> Scope Kernels
forall a b. (a -> b) -> a -> b
$ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Result, Stms Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (Result, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result, Stms Kernels)
forall a b. (a -> b) -> a -> b
$
      BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
  Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either [Char]))))))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   Result
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (Result,
       Stms
         (Lore
            (BinderT
               Kernels
               (ReaderT
                  (Scope KernelsMem) (StateT VNameSource (Either [Char])))))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either [Char]))))))
forall a b. (a -> b) -> a -> b
$
        [(Param (TypeBase (ShapeBase SubExp) NoUniqueness),
  Param (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))
    -> BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
         SubExp)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
   Param (TypeBase (ShapeBase SubExp) NoUniqueness))
  -> BinderT
       Kernels
       (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
       SubExp)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      Result)
-> ((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))
    -> BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
         SubExp)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     Result
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x, Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y) ->
          [Char]
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"z" (Exp
   (Lore
      (BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      SubExp)
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMax IntType
Int64) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y)
    Lambda Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (Lambda Kernels))
-> Lambda Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
forall a b. (a -> b) -> a -> b
$ [LParam Kernels]
-> BodyT Kernels
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda Kernels
forall lore.
[LParam lore]
-> BodyT lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT lore
Lambda ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (Stms Kernels -> Result -> BodyT Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody Stms Kernels
stms Result
zs) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s

  Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam <- VName
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> Param (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. VName -> dec -> Param dec
Param (VName
 -> TypeBase (ShapeBase SubExp) NoUniqueness
 -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReaderT
     (Scope KernelsMem) (StateT VNameSource (Either [Char])) VName
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (TypeBase (ShapeBase SubExp) NoUniqueness
      -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char]
-> ReaderT
     (Scope KernelsMem) (StateT VNameSource (Either [Char])) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"flat_gtid" ReaderT
  (Scope KernelsMem)
  (StateT VNameSource (Either [Char]))
  (TypeBase (ShapeBase SubExp) NoUniqueness
   -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (TypeBase (ShapeBase SubExp) NoUniqueness)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (TypeBase (ShapeBase SubExp) NoUniqueness)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (IntType -> PrimType
IntType IntType
Int64))

  (Lambda Kernels
size_lam', Stms Kernels
_) <- (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Lambda Kernels)
 -> Scope Kernels
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either [Char]))
      (Lambda Kernels, Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Lambda Kernels, Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
  (Lambda Kernels)
-> Scope Kernels
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Lambda Kernels, Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Lambda Kernels)
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either [Char]))
      (Lambda Kernels, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (Lambda Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
params <- Int
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ [Char]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"x" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms Kernels
stms) <- Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result, Stms Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
      ( [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope Kernels
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
params
          Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope Kernels
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam]
      )
      (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Result, Stms Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (Result, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
  Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either [Char]))))))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   Result
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (Result,
       Stms
         (Lore
            (BinderT
               Kernels
               (ReaderT
                  (Scope KernelsMem) (StateT VNameSource (Either [Char])))))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either [Char]))))))
forall a b. (a -> b) -> a -> b
$ do
        -- Even though this SegRed is one-dimensional, we need to
        -- provide indexes corresponding to the original potentially
        -- multi-dimensional construct.
        let ([VName]
kspace_gtids, Result
kspace_dims) = [(VName, SubExp)] -> ([VName], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], Result))
-> [(VName, SubExp)] -> ([VName], Result)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
            new_inds :: [TPrimExp Int64 VName]
new_inds =
              [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                ((SubExp -> TPrimExp Int64 VName)
-> Result -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 Result
kspace_dims)
                (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam)
        ([VName]
 -> ExpT Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      ())
-> [[VName]]
-> [ExpT Kernels]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ [VName]
-> ExpT Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames ((VName -> [VName]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map VName -> [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
kspace_gtids) ([ExpT Kernels]
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      ())
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [ExpT Kernels]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (TPrimExp Int64 VName
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (ExpT Kernels))
-> [TPrimExp Int64 VName]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [ExpT Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TPrimExp Int64 VName
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (ExpT Kernels)
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp [TPrimExp Int64 VName]
new_inds

        (Stm Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      ())
-> Stms Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stms Kernels
kstms'
        Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     Result
forall (m :: * -> *) a. Monad m => a -> m a
return Result
sizes

    Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope Kernels
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   (Lambda Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (Lambda Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
forall a b. (a -> b) -> a -> b
$
      Lambda Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Lambda Kernels)
forall (m :: * -> *).
(HasScope Kernels m, MonadFreshNames m) =>
Lambda Kernels -> m (Lambda Kernels)
Kernels.simplifyLambda ([LParam Kernels]
-> BodyT Kernels
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda Kernels
forall lore.
[LParam lore]
-> BodyT lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT lore
Lambda [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
LParam Kernels
flat_gtid_lparam] (BodyDec Kernels -> Stms Kernels -> Result -> BodyT Kernels
forall lore. BodyDec lore -> Stms lore -> Result -> BodyT lore
Body () Stms Kernels
stms Result
zs) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s)

  (([VName]
maxes_per_thread, [VName]
size_sums), Stms Kernels
slice_stms) <- (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   ([VName], [VName])
 -> Scope Kernels
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either [Char]))
      (([VName], [VName]), Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ([VName], [VName])
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (([VName], [VName]), Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
  ([VName], [VName])
-> Scope Kernels
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (([VName], [VName]), Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
   ([VName], [VName])
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either [Char]))
      (([VName], [VName]), Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ([VName], [VName])
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either [Char]))
     (([VName], [VName]), Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
pat <-
      [Ident]
-> [Ident] -> PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
basicPattern []
        ([Ident] -> PatternT (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [Ident]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (PatternT (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     Ident
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [Ident]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM
          Int
num_sizes
          ([Char]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> TypeBase (ShapeBase SubExp) NoUniqueness -> m Ident
newIdent [Char]
"max_per_thread" (TypeBase (ShapeBase SubExp) NoUniqueness
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      Ident)
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)

    SubExp
w <-
      [Char]
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m SubExp
letSubExp [Char]
"size_slice_w"
        (ExpT Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      SubExp)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (ExpT Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> SubExp
-> Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Exp
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either [Char]))))))
forall (m :: * -> *).
MonadBinder m =>
BinOp -> SubExp -> Result -> m (Exp (Lore m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (SegSpace -> Result
segSpaceDims SegSpace
space)

    VName
thread_space_iota <-
      [Char]
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m VName
letExp [Char]
"thread_space_iota" (Exp
   (Lore
      (BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      VName)
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
    let red_op :: SegBinOp Kernels
red_op =
          Commutativity
-> Lambda Kernels -> Result -> ShapeBase SubExp -> SegBinOp Kernels
forall lore.
Commutativity
-> Lambda lore -> Result -> ShapeBase SubExp -> SegBinOp lore
SegBinOp
            Commutativity
Commutative
            Lambda Kernels
max_lam
            (Int -> SubExp -> Result
forall a. Int -> a -> [a]
replicate Int
num_sizes (SubExp -> Result) -> SubExp -> Result
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
            ShapeBase SubExp
forall a. Monoid a => a
mempty
    SegLevel
lvl <- [Char]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     SegLevel
forall (m :: * -> *) inner.
(MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
[Char] -> m SegLevel
segThread [Char]
"segred"

    Stms Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      ())
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (Stm Kernels))
-> Stms Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Stms Kernels)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Stm Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Stm Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm
      (Stms Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      (Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Stms Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [SegBinOp Kernels]
-> Lambda Kernels
-> [VName]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     (Stms Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> m (Stms lore)
nonSegRed SegOpLevel Kernels
SegLevel
lvl PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
Pattern Kernels
pat SubExp
w [SegBinOp Kernels
red_op] Lambda Kernels
size_lam' [VName
thread_space_iota]

    [VName]
size_sums <- [VName]
-> (VName
    -> BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
         VName)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (PatternT (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
pat) ((VName
  -> BinderT
       Kernels
       (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
       VName)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      [VName])
-> (VName
    -> BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
         VName)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     [VName]
forall a b. (a -> b) -> a -> b
$ \VName
threads_max ->
      [Char]
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Lore m) -> m VName
letExp [Char]
"size_sum" (Exp
   (Lore
      (BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
      VName)
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Var VName
threads_max) SubExp
num_threads

    ([VName], [VName])
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either [Char])))
     ([VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternT (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
pat, [VName]
size_sums)

  (Stms Kernels, [VName], [VName])
-> ExpandM (Stms Kernels, [VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels
slice_stms, [VName]
maxes_per_thread, [VName]
size_sums)