{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program.  Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    askDefaultSpace,
    Allocable,
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    arraySizeInBytesExp,
    mkLetNamesB',
    mkLetNamesB'',

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.List (foldl', transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.SymbolTable (IndexOp)
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (maybeNth, splitAt3)

-- | The constraints required of a source representation @fromrep@ and a
-- target representation @torep@ for inserting allocations.  @inner@ is
-- the representation-specific operation type that @torep@ carries
-- alongside its memory operations (see the 'Mem' constraint).
type Allocable fromrep torep inner =
  ( -- What we require of the input representation.
    PrettyRep fromrep,
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    -- What we require of the output representation.
    PrettyRep torep,
    Mem torep inner,
    BuilderOps torep,
    LetDec torep ~ LetDecMem,
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    SizeSubst (inner torep)
  )

-- | The read-only environment of the 'AllocM' monad.
data AllocEnv fromrep torep = AllocEnv
  { -- | Aggressively try to reuse memory in do-loops -
    -- should be True inside kernels, False outside.
    aggressiveReuse :: Bool,
    -- | When allocating memory, put it in this memory space.
    -- This is primarily used to ensure that group-wide
    -- statements store their results in local memory.
    allocSpace :: Space,
    -- | The set of names that are known to be constants at
    -- kernel compile time.
    envConsts :: S.Set VName,
    -- | Translate an operation of the source representation into an
    -- operation of the target representation, running in 'AllocM'.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | Compute the allocation hints for the results of a
    -- target-representation expression.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }

-- | Monad for adding allocations to an entire program.
--
-- A 'BuilderT' over @torep@ (which collects emitted statements and
-- tracks scope), with read-only access to an 'AllocEnv' and a
-- 'VNameSource' threaded as state for fresh-name generation.
newtype AllocM fromrep torep a
  = AllocM (BuilderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )

-- | Building statements in 'AllocM' produces memory-annotated code.
-- 'mkLetNamesM' constructs its pattern with 'patWithAllocations',
-- which may emit allocation statements before the binding.
instance (Allocable fromrep torep inner) => MonadBuilder (AllocM fromrep torep) where
  type Rep (AllocM fromrep torep) = torep

  -- Expression decorations are trivial ('ExpDec torep ~ ()').
  mkExpDecM _ _ = pure ()

  mkLetNamesM names e = do
    def_space <- askDefaultSpace
    hints <- expHints e
    pat <- patWithAllocations def_space names e hints
    pure $ Let pat (defAux ()) e

  -- Body decorations are trivial ('BodyDec torep ~ ()').
  mkBodyM stms res = pure $ Body () stms res

  -- Statement collection is delegated to the underlying 'BuilderT'.
  addStms = AllocM . addStms

  collectStms (AllocM m) = AllocM $ collectStms m

-- | Compute allocation hints for an expression using the
-- 'envExpHints' function stored in the environment.
expHints :: Exp torep -> AllocM fromrep torep [ExpHint]
expHints e = do
  hintsFor <- asks envExpHints
  hintsFor e

-- | The space in which we allocate memory if we have no other
-- preferences or constraints (the environment's 'allocSpace').
askDefaultSpace :: AllocM fromrep torep Space
askDefaultSpace = asks allocSpace

-- | Run an 'AllocM' computation.
--
-- The environment uses the given default allocation space, operation
-- handler and expression-hint function.  Aggressive reuse is disabled,
-- no names are known to be constants, and the initial scope is empty.
--
-- Any statements the computation adds at top level (outside a
-- 'collectStms') are discarded; only the result value is returned.
runAllocM ::
  (MonadFreshNames m) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM space handleOp hints (AllocM m) =
  -- 'runBuilderT' yields (result, stms); 'fst' keeps only the result.
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBuilderT m mempty) env
  where
    env =
      AllocEnv
        { aggressiveReuse = False,
          allocSpace = space,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }

-- | The size in bytes of the primitive element type of the given type.
elemSize :: (Num a) => Type -> a
elemSize = primByteSize . elemType

-- | A symbolic expression for the number of bytes needed by a value of
-- the given type: the element size multiplied by every array dimension.
--
-- Unlike 'arraySizeInBytesExpM', the result is not clamped to be
-- non-negative.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  untyped $ foldl' (*) (elemSize t) $ map pe64 (arrayDims t)

-- | A symbolic expression for the number of bytes needed by a value of
-- the given type (product of the dimensions times the element size).
-- The result is clamped from below at zero with @SMax Int64@, so a
-- negative dimension product cannot yield a negative allocation size.
arraySizeInBytesExpM :: (MonadBuilder m) => Type -> m (PrimExp VName)
arraySizeInBytesExpM t = do
  let dim_prod_i64 = product $ map pe64 (arrayDims t)
      elm_size_i64 = elemSize t
  pure $
    BinOpExp (SMax Int64) (ValueExp $ IntValue $ Int64Value 0) $
      untyped $
        dim_prod_i64 * elm_size_i64

-- | Emit code computing the byte size of a value of the given type
-- (see 'arraySizeInBytesExpM'), bound via 'letSubExp' with the name
-- hint @bytes@.
arraySizeInBytes :: (MonadBuilder m) => Type -> m SubExp
arraySizeInBytes = letSubExp "bytes" <=< toExp <=< arraySizeInBytesExpM

-- | Emit an 'Alloc' in the given space, large enough for a value of the
-- given type, and return the name (hinted @mem@) of the new memory
-- block.  Works in any builder whose operations are 'MemOp's.
allocForArray' ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Type ->
  Space ->
  m VName
allocForArray' t space = do
  size <- arraySizeInBytes t
  letExp "mem" $ Op $ Alloc size space

-- | Allocate memory for a value of the given type.
--
-- This is 'allocForArray'' specialised to 'AllocM'.
allocForArray ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  AllocM fromrep torep VName
allocForArray = allocForArray'

-- | Repair an expression that cannot be assigned an index function.
-- There is a simple remedy for this: normalise the input arrays and
-- try again.
--
-- Only 'Reshape' is handled.  Its array argument is replaced by a
-- direct copy (via 'ensureDirectArray') in the same memory space as the
-- original.  Any other expression is an 'error'.
repairExpression ::
  (Allocable fromrep torep inner) =>
  Exp torep ->
  AllocM fromrep torep (Exp torep)
repairExpression (BasicOp (Reshape k shape v)) = do
  v_mem <- fst <$> lookupArraySummary v
  space <- lookupMemSpace v_mem
  v' <- snd <$> ensureDirectArray (Just space) v
  pure $ BasicOp $ Reshape k shape v'
repairExpression e =
  error $ "repairExpression:\n" <> prettyString e

-- | Compute the memory returns of an expression, together with the
-- expression they describe.
--
-- If 'expReturns' cannot determine them, the expression is first
-- repaired with 'repairExpression' and analysed again.  The returned
-- expression is the repaired one in that case.  If the repaired
-- expression still has no returns, the returns are an 'error' thunk
-- that fails when forced.
expReturns' ::
  (Allocable fromrep torep inner) =>
  Exp torep ->
  AllocM fromrep torep ([ExpReturns], Exp torep)
expReturns' e = do
  maybe_rts <- expReturns e
  case maybe_rts of
    Just rts -> pure (rts, e)
    Nothing -> do
      e' <- repairExpression e
      let bad =
            error . unlines $
              [ "expReturns': impossible index transformation",
                prettyString e,
                prettyString e'
              ]
      rts <- fromMaybe bad <$> expReturns e'
      pure (rts, e')

-- | Build a statement binding the given identifiers to an expression,
-- with each pattern element annotated with memory information.
--
-- Allocation hints are computed from the original expression.  The
-- statement contains the expression returned by 'expReturns'', which
-- may be a repaired version of the input.
allocsForStm ::
  (Allocable fromrep torep inner) =>
  [Ident] ->
  Exp torep ->
  AllocM fromrep torep (Stm torep)
allocsForStm idents e = do
  def_space <- askDefaultSpace
  hints <- expHints e
  (rts, e') <- expReturns' e
  pes <- allocsForPat def_space idents rts hints
  dec <- mkExpDecM (Pat pes) e'
  pure $ Let (Pat pes) (defAux dec) e'

-- | Build a memory-annotated pattern binding the given names to the
-- results of an expression.
--
-- The names' types are obtained by instantiating the expression's
-- existential result shapes with the names themselves.  Each element
-- is then annotated by 'allocsForPat'.  Calls 'error' if the expression
-- has no memory returns.
patWithAllocations ::
  (MonadBuilder m, Mem (Rep m) inner) =>
  Space ->
  [VName] ->
  Exp (Rep m) ->
  [ExpHint] ->
  m (Pat LetDecMem)
patWithAllocations def_space names e hints = do
  ts' <- instantiateShapes' names <$> expExtType e
  let idents = zipWith Ident names ts'
  rts <- fromMaybe (error "patWithAllocations: ill-typed") <$> expReturns e
  Pat <$> allocsForPat def_space idents rts hints

-- | Produce one identifier per expression return.
--
-- The given identifiers are aligned with the /trailing/ returns (the
-- alignment is done from the end, hence the reversals).  Each leading
-- return without an identifier gets a fresh one:
--
-- * @ext_mem@ of memory type for a memory return;
-- * otherwise @ext@ of type @i64@.
--
-- The result has the same length as the list of returns.
mkMissingIdents :: (MonadFreshNames m) => [Ident] -> [ExpReturns] -> m [Ident]
mkMissingIdents idents rts =
  reverse <$> zipWithM pick (reverse rts) (map Just (reverse idents) ++ repeat Nothing)
  where
    pick _ (Just ident) = pure ident
    pick (MemMem space) Nothing = newIdent "ext_mem" $ Mem space
    pick _ Nothing = newIdent "ext" $ Prim int64

-- | Annotate pattern elements with memory information, given the
-- expression's returns and allocation hints.
--
-- Identifiers missing for leading (existential) returns are created by
-- 'mkMissingIdents'.  Then, per (identifier, return, hint):
--
-- * primitive results, and arrays of fully known shape with no memory
--   return, are summarised by 'summaryForBindage';
-- * memory results keep their space;
-- * arrays returned in an existing block are placed in that block, with
--   existentials in the index function replaced by pattern names;
-- * arrays returned in a new block are placed in the memory named by
--   the corresponding existential pattern element;
-- * accumulators pass through unchanged.
--
-- Any other return calls 'error'.
--
-- NOTE(review): 'zip3' truncates to the shortest of idents, returns and
-- hints; presumably callers guarantee equal lengths.  Confirm.
allocsForPat ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m [PatElem LetDecMem]
allocsForPat def_space some_idents rts hints = do
  idents <- mkMissingIdents some_idents rts

  forM (zip3 idents rts hints) $ \(ident, rt, hint) -> do
    let ident_shape = arrayShape $ identType ident
    case rt of
      MemPrim _ -> do
        summary <- summaryForBindage def_space (identType ident) hint
        pure $ PatElem (identName ident) summary
      MemMem space ->
        pure $ PatElem (identName ident) $ MemMem space
      MemArray bt _ u (Just (ReturnsInBlock mem extixfun)) -> do
        let ixfn = instantiateExtIxFun idents extixfun
        pure . PatElem (identName ident) . MemArray bt ident_shape u $
          ArrayIn mem ixfn
      MemArray _ extshape _ Nothing
        | Just _ <- knownShape extshape -> do
            summary <- summaryForBindage def_space (identType ident) hint
            pure $ PatElem (identName ident) summary
      MemArray bt _ u (Just (ReturnsNewBlock _ i extixfn)) -> do
        let ixfn = instantiateExtIxFun idents extixfn
        pure . PatElem (identName ident) . MemArray bt ident_shape u $
          ArrayIn (getIdent idents i) ixfn
      MemAcc acc ispace ts u ->
        pure $ PatElem (identName ident) $ MemAcc acc ispace ts u
      _ -> error "Impossible case reached in allocsForPat!"
  where
    -- All dimensions, if none of them is existential.
    knownShape = mapM known . shapeDims
    known (Free v) = Just v
    known Ext {} = Nothing

    -- The name of the i'th pattern identifier (existential index i).
    getIdent idents i =
      case maybeNth i idents of
        Just ident -> identName ident
        Nothing ->
          error $
            "getIdent: Ext "
              <> show i
              <> " but pattern has "
              <> show (length idents)
              <> " elements: "
              <> prettyString idents

    -- Replace existential references in an index function by the
    -- corresponding pattern names.
    instantiateExtIxFun idents = fmap $ fmap inst
      where
        inst (Free v) = v
        inst (Ext i) = getIdent idents i

-- | Turn an existential index function into an ordinary one, provided
-- it refers only to free (non-existential) variables.
--
-- Existential references are not supported yet; encountering one calls
-- 'error'.  Only 'Applicative' is required, since the conversion is a
-- nested 'traverse'.
instantiateIxFun :: (Applicative m) => ExtIxFun -> m IxFun
instantiateIxFun = traverse $ traverse inst
  where
    inst Ext {} = error "instantiateIxFun: not yet"
    inst (Free x) = pure x

-- | Compute the memory annotation for a pattern element of the given type,
-- allocating memory for it when it is an array.
--
-- Primitive, memory and accumulator types are annotated directly, without
-- any allocation.  For arrays the 'ExpHint' selects the layout:
--
-- * 'NoHint': allocate a block in the given default space (via
--   'allocForArray'') and use a row-major 'IxFun.iota' index function over
--   the array's dimensions.
--
-- * 'Hint' @ixfun space@: allocate a block in @space@ whose size in bytes is
--   the product of @ixfun@'s /base/ shape times the element size, and use
--   @ixfun@ unchanged as the index function.
summaryForBindage ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage _ (Prim bt) _ =
  pure $ MemPrim bt
summaryForBindage _ (Mem space) _ =
  pure $ MemMem space
summaryForBindage _ (Acc acc ispace ts u) _ =
  pure $ MemAcc acc ispace ts u
summaryForBindage def_space t@(Array pt shape u) NoHint = do
  m <- allocForArray' t def_space
  pure $ MemArray pt shape u $ ArrayIn m $ IxFun.iota $ map pe64 $ arrayDims t
summaryForBindage _ t@(Array pt _ _) (Hint ixfun space) = do
  -- The block is sized from the index function's base shape, not from the
  -- shape of 't' itself.
  bytes <-
    letSubExp "bytes" <=< toExp . untyped $
      product
        [ product $ IxFun.base ixfun,
          fromIntegral (primByteSize pt :: Int64)
        ]
  m <- letExp "mem" $ Op $ Alloc bytes space
  pure $ MemArray pt (arrayShape t) NoUniqueness $ ArrayIn m ixfun

-- | Annotate function parameters with memory information, then run the
-- continuation with the annotated parameters in scope.
--
-- Each parameter is paired with the space in which its memory block should
-- be placed if it is an array (see 'allocInFParam').  The list passed to the
-- continuation, and brought into scope, is ordered as: memory-block
-- parameters, then context parameters, then the annotated value parameters.
allocInFParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params m = do
  (valparams, (memparams, ctxparams)) <-
    runWriterT $ mapM (uncurry allocInFParam) params
  let params' = memparams <> ctxparams <> valparams
      summary = scopeOfFParams params'
  localScope summary $ m params'

-- | Give a single function parameter its memory annotation.
--
-- An array parameter gets a fresh memory-block parameter in @pspace@, named
-- after the parameter with a @_mem@ suffix, plus a row-major ('IxFun.iota')
-- index function over its declared shape.  The new memory parameter is
-- emitted in the first component of the writer output.  Primitive, memory
-- and accumulator parameters are annotated directly and emit nothing.
allocInFParam ::
  (Allocable fromrep torep inner) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace =
  case paramDeclType param of
    Array pt shape u -> do
      let memname = baseString (paramName param) <> "_mem"
          ixfun = IxFun.iota $ map pe64 $ shapeDims shape
      mem <- lift $ newVName memname
      -- The memory parameter inherits the attributes of the array parameter.
      tell ([Param (paramAttrs param) mem $ MemMem pspace], [])
      pure param {paramDec = MemArray pt shape u $ ArrayIn mem ixfun}
    Prim pt ->
      pure param {paramDec = MemPrim pt}
    Mem space ->
      pure param {paramDec = MemMem space}
    Acc acc ispace ts u ->
      pure param {paramDec = MemAcc acc ispace ts u}

-- | Make sure the array @v@ has an acceptable layout, copying it if needed.
--
-- @v@ is returned as-is, together with its memory block, when its index
-- function's base has as many dimensions as the index function's rank and,
-- if @space_ok@ is 'Just', its memory lives in that space.  Otherwise @v@ is
-- copied with 'allocLinearArray' into @space_ok@, or into the default space
-- if @space_ok@ is 'Nothing'.  The result is @(memory block, array name)@.
--
-- NOTE(review): the only layout test performed is "base rank == rank"; an
-- index function of matching rank (e.g. a permuted one) is accepted without
-- copying.
ensureRowMajorArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureRowMajorArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let space = fromMaybe default_space space_ok
  if length (IxFun.base ixfun) == IxFun.rank ixfun
    && maybe True (== mem_space) space_ok
    then pure (mem, v)
    else allocLinearArray space (baseString v) v

-- | Prepare an array-typed 'SubExp' for use as a loop argument or result in
-- @space@.
--
-- The array is first passed through 'ensureRowMajorArray' with @space@ as
-- the required space.  Its memory block is then emitted in the first writer
-- component, and each element of its index function is bound to a fresh
-- @ixfun_arg@ and emitted in the second component.  The (possibly copied)
-- array is returned.  A constant cannot be an array, so it is rejected with
-- 'error'.
ensureArrayIn ::
  (Allocable fromrep torep inner) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn _ (Constant v) =
  error $ "ensureArrayIn: " ++ prettyString v ++ " cannot be an array."
ensureArrayIn space (Var v) = do
  (mem', v') <- lift $ ensureRowMajorArray (Just space) v
  (_, ixfun) <- lift $ lookupArraySummary v'
  ctx <- lift $ mapM (letSubExp "ixfun_arg" <=< toExp) (toList ixfun)
  tell ([Var mem'], ctx)
  pure $ Var v'

-- | Annotate loop merge parameters with memory information and run the
-- continuation on the result.
--
-- The continuation receives:
--
-- 1. The new merge parameters paired with their initial arguments.  They are
--    ordered as memory-block parameters, then index-function context
--    parameters, then the annotated value parameters.
--
-- 2. A function that takes the loop body's value results and returns the
--    matching memory/context results (first component) and the possibly
--    copied value results (second component), in the same order.
--
-- An array parameter whose initial argument is a variable is handled
-- according to the space of that argument's memory block:
--
-- * 'ScalarSpace' with a shape that does not mention any merge parameter:
--   only the memory block becomes a parameter and the index function stays
--   fixed.  Loop results are copied into that exact layout when their space
--   or index function differs.
--
-- * 'ScalarSpace' with a shape that mentions a merge parameter: the argument
--   is first copied to the default space and then handled as below.
--
-- * Any other space: the argument is passed through 'ensureRowMajorArray',
--   and every element of its index function becomes an existential @int64@
--   context parameter.
--
-- Every other parameter is annotated like a function parameter
-- ('allocInFParam') in the default space.
allocInMergeParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, SubExp)] ->
  ( [(FParam torep, SubExp)] ->
    ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
    AllocM fromrep torep a
  ) ->
  AllocM fromrep torep a
allocInMergeParams merge m = do
  ((valparams, valargs, handle_loop_subexps), (mem_params, ctx_params)) <-
    runWriterT $ unzip3 <$> mapM allocInMergeParam merge

  let mergeparams' = mem_params <> ctx_params <> valparams
      summary = scopeOfFParams mergeparams'

      -- Apply each parameter's result handler to the matching value result,
      -- collecting the extra memory and context results it emits.
      mk_loop_res ses = do
        (ses', (memargs, ctxargs)) <-
          runWriterT $ zipWithM ($) handle_loop_subexps ses
        pure (memargs <> ctxargs, ses')

  -- The initial arguments go through the same handlers as the loop results,
  -- so they line up one-to-one with 'mergeparams''.
  (valctx_args, valargs') <- mk_loop_res valargs
  let merge' = zip mergeparams' (valctx_args <> valargs')
  localScope summary $ m merge' mk_loop_res
  where
    param_names = namesFromList $ map (paramName . fst) merge
    anyIsLoopParam names = names `namesIntersect` param_names

    -- Result handler for array parameters kept in scalar space.
    scalarRes param_t v_mem_space v_ixfun (Var res) = do
      -- Try really hard to avoid copying needlessly, but the result
      -- _must_ be in ScalarSpace and have the right index function.
      (res_mem, res_ixfun) <- lift $ lookupArraySummary res
      res_mem_space <- lift $ lookupMemSpace res_mem
      (res_mem', res') <-
        if (res_mem_space, res_ixfun) == (v_mem_space, v_ixfun)
          then pure (res_mem, res)
          else lift $ arrayWithIxFun v_mem_space v_ixfun (fromDecl param_t) res
      tell ([Var res_mem'], [])
      pure $ Var res'
    scalarRes _ _ _ se = pure se

    allocInMergeParam ::
      (Allocable fromrep torep inner) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam torep], [FParam torep])
        (AllocM fromrep torep)
        ( FParam torep,
          SubExp,
          SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
        )
    allocInMergeParam (mergeparam, Var v)
      | param_t@(Array pt shape u) <- paramDeclType mergeparam = do
          (v_mem, v_ixfun) <- lift $ lookupArraySummary v
          v_mem_space <- lift $ lookupMemSpace v_mem

          -- Loop-invariant array parameters that are in scalar space
          -- are special - we do not wish to existentialise their index
          -- function at all (but the memory block is still existential).
          case v_mem_space of
            ScalarSpace {} ->
              if anyIsLoopParam (freeIn shape)
                then do
                  -- Arrays with loop-variant shape cannot be in scalar
                  -- space, so copy them elsewhere and try again.
                  space <- lift askDefaultSpace
                  (_, v') <- lift $ allocLinearArray space (baseString v) v
                  allocInMergeParam (mergeparam, Var v')
                else do
                  p <- newParam "mem_param" $ MemMem v_mem_space
                  tell ([p], [])

                  pure
                    ( mergeparam {paramDec = MemArray pt shape u $ ArrayIn (paramName p) v_ixfun},
                      Var v,
                      scalarRes param_t v_mem_space v_ixfun
                    )
            _ -> do
              (v_mem', v') <- lift $ ensureRowMajorArray Nothing v
              (_, v_ixfun') <- lift $ lookupArraySummary v'
              v_mem_space' <- lift $ lookupMemSpace v_mem'

              -- One existential context parameter per element of the index
              -- function.
              ctx_params <-
                replicateM (length v_ixfun') $
                  newParam "ctx_param_ext" (MemPrim int64)

              param_ixfun <-
                instantiateIxFun $
                  IxFun.substituteInIxFun
                    ( M.fromList . zip (fmap Ext [0 ..]) $
                        map (le64 . Free . paramName) ctx_params
                    )
                    (IxFun.existentialize v_ixfun')

              mem_param <- newParam "mem_param" $ MemMem v_mem_space'
              tell ([mem_param], ctx_params)
              pure
                ( mergeparam {paramDec = MemArray pt shape u $ ArrayIn (paramName mem_param) param_ixfun},
                  Var v',
                  ensureArrayIn v_mem_space'
                )
    allocInMergeParam (mergeparam, se) = doDefault mergeparam se =<< lift askDefaultSpace

    doDefault mergeparam se space = do
      mergeparam' <- allocInFParam mergeparam space
      pure (mergeparam', se, linearFuncallArg (paramType mergeparam) space)

-- | Copy the array @v@ (of type @v_t@) into a freshly allocated block in
-- @space@, laid out according to @ixfun@.
--
-- The copy is bound to a new name derived from @v@ with a @_scalcopy@
-- suffix.  It is expressed as a 'Replicate' with an empty shape of @v@,
-- whose pattern element carries the requested memory annotation.  Returns
-- @(new memory block, name of the copy)@.  @v_t@ must be an array type;
-- anything else is reported with an explicit 'error'.
arrayWithIxFun ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m), LetDec (Rep m) ~ LetDecMem) =>
  Space ->
  IxFun ->
  Type ->
  VName ->
  m (VName, VName)
arrayWithIxFun space ixfun v_t v =
  case v_t of
    Array pt shape u -> do
      mem <- allocForArray' v_t space
      v_copy <- newVName $ baseString v <> "_scalcopy"
      let pe = PatElem v_copy $ MemArray pt shape u $ ArrayIn mem ixfun
      letBind (Pat [pe]) $ BasicOp $ Replicate mempty $ Var v
      pure (mem, v_copy)
    -- Previously a lazy irrefutable pattern; fail early and explain why.
    _ ->
      error $ "arrayWithIxFun: not an array type: " <> prettyString v_t

-- | Make sure the array @v@ has a direct index function, copying it if
-- needed.
--
-- @v@ is returned as-is, together with its memory block, when
-- 'IxFun.isDirect' holds for its index function and, if @space_ok@ is
-- 'Just', its memory lives in that space.  Otherwise @v@ is copied with
-- 'allocLinearArray' into @space_ok@, or into the default space if
-- @space_ok@ is 'Nothing'.  The result is @(memory block, array name)@.
ensureDirectArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  if IxFun.isDirect ixfun && maybe True (== mem_space) space_ok
    then pure (mem, v)
    else needCopy (fromMaybe default_space space_ok)
  where
    needCopy space =
      -- We need to do a new allocation, copy 'v', and make a new
      -- binding for the size of the memory block.
      allocLinearArray space (baseString v) v

-- | Allocate a fresh memory block for the array @v@ in @space@ (via
-- 'allocForArray') and bind a new array, named from @s@ with the suffix
-- @"_desired_form"@, that is a 'Manifest' of @v@ under the permutation
-- @perm@.
--
-- The new array's index function is an iota over @v@'s dimensions, permuted
-- by @perm@.  Returns @(new memory block, new array name)@.  Calls 'error'
-- if @v@ does not have an array type.
allocPermArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  [Int] ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocPermArray space perm s v = do
  t <- lookupType v
  case t of
    Array pt shape u -> do
      mem <- allocForArray t space
      v' <- newVName (s <> "_desired_form")
      let ixfun = IxFun.permute (IxFun.iota (map pe64 (arrayDims t))) perm
          info = MemArray pt shape u (ArrayIn mem ixfun)
      addStm $
        Let (Pat [PatElem v' info]) (defAux ()) $
          BasicOp (Manifest perm v)
      pure (mem, v')
    _ -> error ("allocPermArray: " ++ prettyString t)

-- | Reuse the array @v@ as-is when its index function's base rank equals
-- its shape rank and, if a space is requested, its memory lives there.
-- Otherwise make a permuted copy with 'allocPermArray' in the requested
-- space, or in the default space when none was requested.
--
-- NOTE(review): @perm@ is only used when a copy is made; the reuse check
-- does not compare the existing layout against @perm@.  Confirm this is
-- intended.
--
-- Returns @(memory block, array name)@.
ensurePermArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  [Int] ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensurePermArray space_ok perm v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let ranks_agree = length (IxFun.base ixfun) == length (IxFun.shape ixfun)
      acceptable_space = all (== mem_space) space_ok
  if ranks_agree && acceptable_space
    then pure (mem, v)
    else
      allocPermArray
        (fromMaybe default_space space_ok)
        perm
        (baseString v)
        v

-- | Copy @v@ into a fresh block in @space@ using the identity permutation
-- @[0 .. rank - 1]@.  This is 'allocPermArray' with no reordering of
-- dimensions.
allocLinearArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocLinearArray space s v = do
  rank <- arrayRank <$> lookupType v
  allocPermArray space [0 .. rank - 1] s v

-- | Prepare the arguments of a function call for the memory representation.
--
-- Each argument goes through 'linearFuncallArg' in the default space, which
-- makes array arguments direct.  The context arguments this produces (the
-- memory blocks) are placed before the value arguments, all with diet
-- 'Observe'.
funcallArgs ::
  (Allocable fromrep torep inner) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  -- The default space does not vary between arguments, so ask for it once
  -- rather than once per argument.
  space <- askDefaultSpace
  (valargs, (ctx_args, mem_and_size_args)) <-
    runWriterT . forM args $ \(arg, d) -> do
      t <- lift $ subExpType arg
      arg' <- linearFuncallArg t space arg
      pure (arg', d)
  pure $ map (,Observe) (ctx_args <> mem_and_size_args) <> valargs

-- | Adapt a single function-call argument.
--
-- An array-typed variable is made direct in @space@ via 'ensureDirectArray'.
-- Its memory block is recorded in the first writer component, and the
-- possibly-copied array is returned.  Every other argument is returned
-- unchanged and records nothing.
linearFuncallArg ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg t space arg =
  case (t, arg) of
    (Array {}, Var v) -> do
      (mem, v') <- lift (ensureDirectArray (Just space) v)
      tell ([Var mem], mempty)
      pure (Var v')
    _ -> pure arg

-- | Offset the indices in a 'RetAls'.  The first index list is shifted by
-- the first argument and the second list by the second argument.
-- 'explicitAllocationsGeneric' uses this with the number of extra
-- parameters and extra returns introduced by allocation.
shiftRetAls :: Int -> Int -> RetAls -> RetAls
shiftRetAls param_shift ret_shift (RetAls param_als ret_als) =
  RetAls
    [i + param_shift | i <- param_als]
    [j + ret_shift | j <- ret_als]

-- | Build the explicit-allocations pass for a representation.
--
-- Top-level constants are processed with 'allocInStms'.  Each function gets
-- memory-annotated parameters ('allocInFParams') and body
-- ('allocInFunBody'), with every return value requested in @space@.  The
-- memory context returns produced by the body are prepended to the return
-- types and receive empty aliasing.  The original return aliases are shifted
-- to account for the extra parameters and returns.
explicitAllocationsGeneric ::
  (Allocable fromrep torep inner) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric space handleOp hints =
  Pass
    "explicit allocations"
    "Transform program to explicit memory representation"
    (intraproceduralTransformationWithConsts onStms allocInFun)
  where
    onStms stms =
      runAllocM space handleOp hints . collectStms_ $
        allocInStms stms (pure ())

    allocInFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM space handleOp hints $
        inScopeOf consts $
          allocInFParams [(p, space) | p <- params] $ \params' -> do
            (fbody', mem_rets) <- allocInFunBody (Just space <$ rettype) fbody
            let extra_params = length params' - length params
                extra_rets = length mem_rets
                -- Memory context returns come first and alias nothing.
                ctx_rets = [(r, RetAls mempty mempty) | r <- mem_rets]
                val_rets =
                  zip
                    (memoryInDeclExtType space extra_rets (map fst rettype))
                    [shiftRetAls extra_params extra_rets als | (_, als) <- rettype]
            pure $ FunDef entry attrs fname (ctx_rets ++ val_rets) params' fbody'

-- | Run allocation over a standalone sequence of statements, using the
-- scope of the calling monad, and return the generated statements.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep inner
  ) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric space handleOp hints stms =
  askScope >>= \scope ->
    runAllocM space handleOp hints
      . localScope scope
      . collectStms_
      $ allocInStms stms (pure ())

-- | Attach memory information to declared return types.
--
-- Scalars and accumulators map straight across.  Each array return is
-- given a 'ReturnsNewBlock' in @space@, numbered by a counter that starts at
-- 0.  Its index function is an iota over its shape.  Existential dimensions
-- in array shapes are shifted by @k@.  Presumably @k@ is the number of
-- context returns placed before these values; confirm against callers.
-- A 'Mem' type in the input is an 'error'.
memoryInDeclExtType :: Space -> Int -> [DeclExtType] -> [FunReturns]
memoryInDeclExtType space k dets = evalState (traverse memFor dets) 0
  where
    memFor t = case t of
      Prim pt -> pure $ MemPrim pt
      Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u
      Mem {} -> error "memoryInDeclExtType: too much memory"
      Array pt shape u -> do
        -- Take the current block number and advance the counter.
        i <- state $ \n -> (n, n + 1)
        let shape' = shiftDim <$> shape
            ixfun = IxFun.iota [dimExp d | d <- shapeDims shape']
        pure $ MemArray pt shape' u $ ReturnsNewBlock space i ixfun

    dimExp (Ext i) = le64 (Ext i)
    dimExp (Free se) = Free <$> pe64 se

    shiftDim (Ext i) = Ext (i + k)
    shiftDim (Free se) = Free se

-- | Compute the memory context to return alongside a body result.
--
-- For a variable whose memory info is an array, this is its memory block
-- paired with @MemMem space@.  Constants, scalars, accumulators and memory
-- blocks contribute nothing.  Calls 'error' if an array's memory is not
-- itself a memory block.
bodyReturnMemCtx ::
  (Allocable fromrep torep inner) =>
  SubExpRes ->
  AllocM fromrep torep [(SubExpRes, MemInfo ExtSize u MemReturn)]
bodyReturnMemCtx (SubExpRes _ se) = case se of
  Constant {} -> pure []
  Var v -> do
    info <- lookupMemInfo v
    case info of
      MemArray _ _ _ (ArrayIn mem _) -> do
        mem_info <- lookupMemInfo mem
        case mem_info of
          MemMem space -> pure [(subExpRes $ Var mem, MemMem space)]
          _ -> error $ "bodyReturnMemCtx: not a memory block: " ++ prettyString mem
      -- MemPrim and MemAcc have no memory context.  MemMem should not
      -- happen here and is likewise given none.
      _ -> pure []

-- | Allocate memory within a function body.
--
-- Every result is made direct with 'ensureDirect'.  @space_oks@ lines up
-- with the trailing results; any extra leading results get no space
-- requirement.  The memory context of the results ('bodyReturnMemCtx') is
-- prepended to the result list and also returned as the extra return types.
allocInFunBody ::
  (Allocable fromrep torep inner) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [FunReturns])
allocInFunBody space_oks (Body _ stms res) =
  buildBody $
    allocInStms stms $ do
      res' <- zipWithM ensureDirect padded_space_oks res
      mem_ctx <- concat <$> traverse bodyReturnMemCtx res'
      pure (map fst mem_ctx <> res', map snd mem_ctx)
  where
    -- Pad on the left with 'Nothing' so the list matches 'res' in length.
    padded_space_oks =
      replicate (length res - length space_oks) Nothing ++ space_oks

-- | Make a single result direct.  Array variables go through
-- 'ensureDirectArray' with the given space requirement.  Everything else is
-- returned unchanged.  The certificates are preserved.
ensureDirect ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  SubExpRes ->
  AllocM fromrep torep SubExpRes
ensureDirect space_ok (SubExpRes cs se) = do
  se_info <- subExpMemInfo se
  se' <- case se of
    Var v
      | MemArray {} <- se_info ->
          Var . snd <$> ensureDirectArray space_ok v
    _ -> pure se
  pure $ SubExpRes cs se'

-- | Allocate memory in each statement, in order, and emit the resulting
-- statements.  The continuation @m@ is then run.
--
-- After each statement, the rest of the work (the later statements and
-- finally @m@) runs with 'envConsts' extended by the 'stmConsts' of the
-- statements just generated.  The extension therefore accumulates.
allocInStms ::
  (Allocable fromrep torep inner) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m =
  foldr
    ( \stm rest -> do
        allocstms <- collectStms_ . auxing (stmAux stm) $ allocInStm stm
        addStms allocstms
        let new_consts = foldMap stmConsts allocstms
        local (\env -> env {envConsts = new_consts <> envConsts env}) rest
    )
    m
    (stmsToList origstms)

-- | Allocate memory in a single statement.  The expression is translated
-- with 'allocInExp', then 'allocsForStm' builds the annotated statement
-- from the pattern's identifiers, and that statement is added.
allocInStm ::
  (Allocable fromrep torep inner) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pat pes) _ e) = do
  e' <- allocInExp e
  stm <- allocsForStm (map patElemIdent pes) e'
  addStm stm

-- | Build a lambda over the given (already memory-annotated) parameters.
-- Its body is the original body's statements, with memory allocated via
-- 'allocInStms', and the original body's result.
allocInLambda ::
  (Allocable fromrep torep inner) =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params body =
  let stms = bodyStms body
      res = bodyResult body
   in mkLambda params (allocInStms stms (pure res))

-- | A memory requirement for an array.
data MemReq
  = -- | A known space together with a 'Rank'.  In 'contextRets' this rank
    -- determines how many base-shape context values are produced.
    MemReq Space Rank
  | -- | Requirements that could not be reconciled (see 'combMemReqs').
    NeedsNormalisation Space
  deriving (Show, Eq)

-- | Combine two memory requirements.
--
-- A 'NeedsNormalisation' on either side wins, with the left one preferred.
-- Two equal 'MemReq's stay as they are.  Differing 'MemReq's become
-- 'NeedsNormalisation' in the left requirement's space.
combMemReqs :: MemReq -> MemReq -> MemReq
combMemReqs x y = case (x, y) of
  (NeedsNormalisation {}, _) -> x
  (_, NeedsNormalisation {}) -> y
  (MemReq x_space _, MemReq {})
    | x == y -> x
    | otherwise -> NeedsNormalisation x_space

-- | The memory information of a single match result.  Array results
-- carry a 'MemReq' rather than a concrete memory return.
type MemReqType = MemInfo (Ext SubExp) NoUniqueness MemReq

-- | Combine the requirements for the same result from two branches.
-- Only array requirements are merged (via 'combMemReqs').  For any other
-- combination the first argument is returned unchanged.
combMemReqTypes :: MemReqType -> MemReqType -> MemReqType
combMemReqTypes x y =
  case (x, y) of
    (MemArray pt shape u x_req, MemArray _ _ _ y_req) ->
      MemArray pt shape u (combMemReqs x_req y_req)
    _ -> x

-- | The types of the context results required by a match result with
-- the given requirement.  An array needs a memory block followed by
-- @int64@ index-function components.  Non-array results need no
-- context.
contextRets :: MemReqType -> [MemInfo d u r]
contextRets (MemArray _ shape _ req) =
  case req of
    -- Memory + offset + base_rank + (stride,size)*rank.
    MemReq space (Rank base_rank) ->
      MemMem space : i64 : replicate base_rank i64 ++ replicate (2 * rank) i64
    -- Memory + offset + (base,stride,size)*rank.
    NeedsNormalisation space ->
      MemMem space : i64 : replicate (3 * rank) i64
  where
    rank = shapeRank shape
    i64 = MemPrim int64
contextRets _ = []

-- Add memory information to the body, but do not return memory/ixfun
-- information.  Instead, return restrictions on what the index
-- function should look like.  We will then (crudely) unify these
-- restrictions across all bodies.
allocInMatchBody ::
  (Allocable fromrep torep inner) =>
  [ExtType] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody rets (Body _ stms res) =
  buildBody . allocInStms stms $
    (,) res <$> zipWithM restriction rets (map resSubExp res)
  where
    -- The requirement imposed by one result, given its declared type
    -- and the memory information of the value actually returned.  For
    -- an array we record its memory space and the rank of the base of
    -- its index function.  Memory, scalar and accumulator results keep
    -- their memory information.  Any other combination is an internal
    -- error.
    restriction t se = do
      info <- subExpMemInfo se
      case (t, info) of
        (Array pt shape u, MemArray _ _ _ (ArrayIn mem ixfun)) -> do
          space <- lookupMemSpace mem
          let base_rank = length (IxFun.base ixfun)
          pure $ MemArray pt shape u $ MemReq space $ Rank base_rank
        (_, MemMem space) -> pure $ MemMem space
        (_, MemPrim pt) -> pure $ MemPrim pt
        (_, MemAcc acc ispace ts u) -> pure $ MemAcc acc ispace ts u
        _ -> error $ "allocInMatchBody: mismatch: " ++ show (t, info)

-- | Compute the return types of a match from the unified requirements
-- of its results.  The context results demanded by each requirement
-- (see 'contextRets') come first, in result order.  They are followed
-- by one return type per result.
mkBranchRet :: [MemReqType] -> [BranchTypeMem]
mkBranchRet reqs =
  -- Built directly rather than by appending onto accumulators in a left
  -- fold, which took time quadratic in the number of results.
  concatMap contextRets reqs ++ zipWith inspect offsets reqs
  where
    numCtxNeeded = length . contextRets

    -- Index of the first context result belonging to each requirement.
    offsets = scanl (+) 0 $ map numCtxNeeded reqs
    -- Total number of context results placed before the results.
    num_new_ctx = sum $ map numCtxNeeded reqs

    -- Memory space and base rank of the index function to return.  For
    -- 'NeedsNormalisation' the base rank is the rank of the shape.
    arrayInfo rank (NeedsNormalisation space) =
      (space, rank)
    arrayInfo _ (MemReq space (Rank base_rank)) =
      (space, base_rank)

    -- Return type of a result whose context starts at ctx_offset.
    -- An array is returned in a new block, which is the context result
    -- at ctx_offset.  Its index function is existential in the context
    -- results that follow.  Existential sizes in the shape are shifted
    -- past all of the new context results.
    inspect ctx_offset (MemArray pt shape u req) =
      let shape' = fmap (adjustExt num_new_ctx) shape
          (space, base_rank) = arrayInfo (shapeRank shape) req
       in MemArray pt shape' u . ReturnsNewBlock space ctx_offset $
            convert
              <$> IxFun.mkExistential base_rank (shapeRank shape) (ctx_offset + 1)
    inspect _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    inspect _ (MemPrim pt) = MemPrim pt
    inspect _ (MemMem space) = MemMem space

    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

    adjustExt :: Int -> Ext a -> Ext a
    adjustExt _ (Free v) = Free v
    adjustExt k (Ext i) = Ext (k + i)

-- | Re-bind the body's results.  Any array result whose requirement is
-- 'NeedsNormalisation' is first passed through 'ensurePermArray'.  In
-- front of the results, the body then returns each array result's
-- memory block and index-function components as context.
addCtxToMatchBody ::
  (Allocable fromrep torep inner) =>
  [MemReqType] ->
  Body torep ->
  AllocM fromrep torep (Body torep)
addCtxToMatchBody reqs body = buildBody_ $ do
  vals <- bodyBind body
  res <- zipWithM normaliseIfNeeded reqs vals
  ctxs <- mapM resCtx res
  pure $ concat ctxs ++ res
  where
    normaliseIfNeeded req r
      | MemArray _ shape _ (NeedsNormalisation space) <- req,
        SubExpRes cs (Var v) <- r = do
          v' <- snd <$> ensurePermArray (Just space) [0 .. shapeRank shape - 1] v
          pure $ SubExpRes cs $ Var v'
      | otherwise = pure r

    resCtx (SubExpRes _ (Var v)) = do
      info <- lookupMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn mem ixfun) -> do
          ixfun_exts <- mapM (letSubExp "ixfun_ext" <=< toExp) $ toList ixfun
          pure $ subExpRes (Var mem) : subExpsRes ixfun_exts
        -- Scalars, accumulators and memory blocks have no context (a
        -- memory block should not occur here anyway).
        _ -> pure []
    resCtx (SubExpRes _ Constant {}) = pure []

-- Do a simple form of invariance analysis to simplify a Match.  It
-- is unfortunate that we have to do it here, but functions such as
-- scalarRes will look carefully at the index functions before the
-- simplifier has a chance to run.  In a perfect world we would
-- simplify away those copies afterwards.  XXX: this should be fixed by
-- a more general copy-removal pass.  See
-- Futhark.Optimise.EntryPointMem for a very specialised version of
-- the idea, but which could perhaps be generalised.
-- Results that are not bound inside any branch, and that are the same
-- in every branch, are removed from the bodies.  Their return types
-- are dropped, and existential references to them are replaced by the
-- invariant value (via 'fixExt').
simplifyMatch ::
  (Mem rep inner) =>
  [Case (Body rep)] ->
  Body rep ->
  [BranchTypeMem] ->
  ( [Case (Body rep)],
    Body rep,
    [BranchTypeMem]
  )
simplifyMatch cases defbody ts =
  let case_reses = map (bodyResult . caseBody) cases
      defbody_res = bodyResult defbody
      -- For each result position, the results of every case at it.
      per_result_reses = transposeN (length defbody_res) case_reses
      (ctx_fixes, variant) =
        partitionEithers . map branchInvariant $
          zip4 [0 ..] per_result_reses defbody_res ts
      (cases_reses, defbody_reses, ts') = unzip3 variant
   in ( zipWith onCase cases (transposeN (length cases) cases_reses),
        onBody defbody defbody_reses,
        foldr (uncurry fixExt) ts' ctx_fixes
      )
  where
    -- Like 'transpose', but always yields exactly @n@ columns.  Plain
    -- 'transpose' returns @[]@ when there are no rows.  Here that would
    -- silently drop every result of a match with no cases, and every
    -- case of a match whose results are all invariant.
    transposeN :: Int -> [[a]] -> [[a]]
    transposeN n = foldr (zipWith (:)) (replicate n [])

    bound_in_branches =
      namesFromList . concatMap (patNames . stmPat) $
        foldMap (bodyStms . caseBody) cases <> bodyStms defbody

    onCase c res = fmap (`onBody` res) c
    onBody body res = body {bodyResult = res}

    branchInvariant (i, case_reses, defres, t)
      -- If even one branch has a variant result, then we give up.
      | namesIntersect bound_in_branches $ freeIn $ defres : case_reses =
          Right (case_reses, defres, t)
      -- Do all branches return the same value?
      | all ((== resSubExp defres) . resSubExp) case_reses =
          Left (i, resSubExp defres)
      | otherwise =
          Right (case_reses, defres, t)

allocInExp ::
  (Allocable fromrep torep inner) =>
  Exp fromrep ->
  AllocM fromrep torep (Exp torep)
allocInExp :: forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
Exp fromrep -> AllocM fromrep torep (Exp torep)
allocInExp (Loop [(FParam fromrep, SubExp)]
merge LoopForm
form (Body () Stms fromrep
bodystms Result
bodyres)) =
  forall fromrep torep (inner :: * -> *) a.
Allocable fromrep torep inner =>
[(FParam fromrep, SubExp)]
-> ([(FParam torep, SubExp)]
    -> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
    -> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInMergeParams [(FParam fromrep, SubExp)]
merge forall a b. (a -> b) -> a -> b
$ \[(FParam torep, SubExp)]
merge' [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_val -> do
    forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep. LoopForm -> Scope rep
scopeOfLoopForm LoopForm
form) forall a b. (a -> b) -> a -> b
$ do
      Body torep
body' <-
        forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall fromrep torep (inner :: * -> *) a.
Allocable fromrep torep inner =>
Stms fromrep -> AllocM fromrep torep a -> AllocM fromrep torep a
allocInStms Stms fromrep
bodystms forall a b. (a -> b) -> a -> b
$ do
          ([SubExp]
valctx, [SubExp]
valres') <- [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_val forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
bodyres
          forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes [SubExp]
valctx forall a. Semigroup a => a -> a -> a
<> forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Certs -> SubExp -> SubExpRes
SubExpRes (forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> Certs
resCerts Result
bodyres) [SubExp]
valres'
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(FParam torep, SubExp)]
merge' LoopForm
form Body torep
body'
allocInExp (Apply Name
fname [(SubExp, Diet)]
args [(RetType fromrep, RetAls)]
rettype (Safety, SrcLoc, [SrcLoc])
loc) = do
  [(SubExp, Diet)]
args' <- forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
[(SubExp, Diet)] -> AllocM fromrep torep [(SubExp, Diet)]
funcallArgs [(SubExp, Diet)]
args
  Space
space <- forall fromrep torep. AllocM fromrep torep Space
askDefaultSpace
  -- We assume that every array is going to be in its own memory.
  let num_extra_args :: Int
num_extra_args = forall (t :: * -> *) a. Foldable t => t a -> Int
length [(SubExp, Diet)]
args' forall a. Num a => a -> a -> a
- forall (t :: * -> *) a. Foldable t => t a -> Int
length [(SubExp, Diet)]
args
      rettype' :: [(RetTypeMem, RetAls)]
rettype' =
        Space -> [(RetTypeMem, RetAls)]
mems Space
space
          forall a. [a] -> [a] -> [a]
++ forall a b. [a] -> [b] -> [(a, b)]
zip
            (Space -> Int -> [DeclExtType] -> [RetTypeMem]
memoryInDeclExtType Space
space Int
num_arrays (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(RetType fromrep, RetAls)]
rettype))
            (forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> RetAls -> RetAls
shiftRetAls Int
num_extra_args Int
num_arrays forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd) [(RetType fromrep, RetAls)]
rettype)
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep.
Name
-> [(SubExp, Diet)]
-> [(RetType rep, RetAls)]
-> (Safety, SrcLoc, [SrcLoc])
-> Exp rep
Apply Name
fname [(SubExp, Diet)]
args' [(RetTypeMem, RetAls)]
rettype' (Safety, SrcLoc, [SrcLoc])
loc
  where
    mems :: Space -> [(RetTypeMem, RetAls)]
mems Space
space = forall a. Int -> a -> [a]
replicate Int
num_arrays (forall d u ret. Space -> MemInfo d u ret
MemMem Space
space, [Int] -> [Int] -> RetAls
RetAls forall a. Monoid a => a
mempty forall a. Monoid a => a
mempty)
    num_arrays :: Int
num_arrays = forall (t :: * -> *) a. Foldable t => t a -> Int
length forall a b. (a -> b) -> a -> b
$ forall a. (a -> Bool) -> [a] -> [a]
filter ((forall a. Ord a => a -> a -> Bool
> Int
0) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall t. DeclExtTyped t => t -> DeclExtType
declExtTypeOf forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(RetType fromrep, RetAls)]
rettype
allocInExp (Match [SubExp]
ses [Case (Body fromrep)]
cases Body fromrep
defbody (MatchDec [BranchType fromrep]
rets MatchSort
ifsort)) = do
  (Body torep
defbody', [MemReqType]
def_reqs) <- forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
[ExtType]
-> Body fromrep -> AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody [BranchType fromrep]
rets Body fromrep
defbody
  ([Case (Body torep)]
cases', [[MemReqType]]
cases_reqs) <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> m (b, c)) -> [a] -> m ([b], [c])
mapAndUnzipM Case (Body fromrep)
-> AllocM fromrep torep (Case (Body torep), [MemReqType])
onCase [Case (Body fromrep)]
cases
  let reqs :: [MemReqType]
reqs = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl MemReqType -> MemReqType -> MemReqType
combMemReqTypes) [MemReqType]
def_reqs (forall a. [[a]] -> [[a]]
transpose [[MemReqType]]
cases_reqs)
  Body torep
defbody'' <- forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
[MemReqType] -> Body torep -> AllocM fromrep torep (Body torep)
addCtxToMatchBody [MemReqType]
reqs Body torep
defbody'
  [Case (Body torep)]
cases'' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall a b. (a -> b) -> a -> b
$ forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
[MemReqType] -> Body torep -> AllocM fromrep torep (Body torep)
addCtxToMatchBody [MemReqType]
reqs) [Case (Body torep)]
cases'
  let ([Case (Body torep)]
cases''', Body torep
defbody''', [BranchTypeMem]
rets') =
        forall rep (inner :: * -> *).
Mem rep inner =>
[Case (Body rep)]
-> Body rep
-> [BranchTypeMem]
-> ([Case (Body rep)], Body rep, [BranchTypeMem])
simplifyMatch [Case (Body torep)]
cases'' Body torep
defbody'' forall a b. (a -> b) -> a -> b
$ [MemReqType] -> [BranchTypeMem]
mkBranchRet [MemReqType]
reqs
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep.
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
ses [Case (Body torep)]
cases''' Body torep
defbody''' forall a b. (a -> b) -> a -> b
$ forall rt. [rt] -> MatchSort -> MatchDec rt
MatchDec [BranchTypeMem]
rets' MatchSort
ifsort
  where
    onCase :: Case (Body fromrep)
-> AllocM fromrep torep (Case (Body torep), [MemReqType])
onCase (Case [Maybe PrimValue]
vs Body fromrep
body) = forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first (forall body. [Maybe PrimValue] -> body -> Case body
Case [Maybe PrimValue]
vs) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
[ExtType]
-> Body fromrep -> AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody [BranchType fromrep]
rets Body fromrep
body
allocInExp (WithAcc [WithAccInput fromrep]
inputs Lambda fromrep
bodylam) =
  forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {rep} {fromrep} {inner :: * -> *} {t :: * -> *} {a} {b}.
(BodyDec rep ~ (), BodyDec fromrep ~ (),
 FParamInfo fromrep ~ DeclType, FParamInfo rep ~ FParamMem,
 RetType fromrep ~ DeclExtType, RetType rep ~ RetTypeMem,
 LParamInfo fromrep ~ Type, LParamInfo rep ~ LParamMem,
 BranchType fromrep ~ ExtType, BranchType rep ~ BranchTypeMem,
 LetDec rep ~ LParamMem, ExpDec rep ~ (), OpC rep ~ MemOp inner,
 Traversable t, ArrayShape a, HasLetDecMem (LetDec rep),
 BuilderOps rep, OpReturns (inner rep), RephraseOp inner,
 SizeSubst (inner rep), PrettyRep fromrep) =>
(a, [VName], t (Lambda fromrep, b))
-> AllocM fromrep rep (a, [VName], t (Lambda rep, b))
onInput [WithAccInput fromrep]
inputs forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall {torep} {fromrep} {inner :: * -> *}.
(LetDec torep ~ LParamMem, RetType fromrep ~ DeclExtType,
 RetType torep ~ RetTypeMem, ExpDec torep ~ (),
 BodyDec fromrep ~ (), BodyDec torep ~ (),
 FParamInfo fromrep ~ DeclType, FParamInfo torep ~ FParamMem,
 LParamInfo fromrep ~ Type, LParamInfo torep ~ LParamMem,
 BranchType fromrep ~ ExtType, BranchType torep ~ BranchTypeMem,
 OpC torep ~ MemOp inner, PrettyRep fromrep,
 HasLetDecMem (LetDec torep), OpReturns (inner torep),
 RephraseOp inner, SizeSubst (inner torep), BuilderOps torep) =>
Lambda fromrep -> AllocM fromrep torep (Lambda torep)
onLambda Lambda fromrep
bodylam
  where
    onLambda :: Lambda fromrep -> AllocM fromrep torep (Lambda torep)
onLambda Lambda fromrep
lam = do
      [Param LParamMem]
params <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda fromrep
lam) forall a b. (a -> b) -> a -> b
$ \(Param Attrs
attrs VName
pv Type
t) ->
        case Type
t of
          Prim PrimType
Unit -> forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
pv forall a b. (a -> b) -> a -> b
$ forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
Unit
          Acc VName
acc Shape
ispace [Type]
ts NoUniqueness
u -> forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
pv forall a b. (a -> b) -> a -> b
$ forall d u ret. VName -> Shape -> [Type] -> u -> MemInfo d u ret
MemAcc VName
acc Shape
ispace [Type]
ts NoUniqueness
u
          Type
_ -> forall a. HasCallStack => String -> a
error forall a b. (a -> b) -> a -> b
$ String
"Unexpected WithAcc lambda param: " forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> String
prettyString (forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
pv Type
t)
      forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
[LParam torep]
-> Body fromrep -> AllocM fromrep torep (Lambda torep)
allocInLambda [Param LParamMem]
params (forall rep. Lambda rep -> Body rep
lambdaBody Lambda fromrep
lam)

    onInput :: (a, [VName], t (Lambda fromrep, b))
-> AllocM fromrep rep (a, [VName], t (Lambda rep, b))
onInput (a
shape, [VName]
arrs, t (Lambda fromrep, b)
op) =
      (a
shape,[VName]
arrs,) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall {rep} {fromrep} {inner :: * -> *} {a} {b}.
(BodyDec rep ~ (), BodyDec fromrep ~ (),
 FParamInfo fromrep ~ DeclType, FParamInfo rep ~ FParamMem,
 RetType fromrep ~ DeclExtType, RetType rep ~ RetTypeMem,
 LParamInfo fromrep ~ Type, LParamInfo rep ~ LParamMem,
 BranchType fromrep ~ ExtType, BranchType rep ~ BranchTypeMem,
 ExpDec rep ~ (), LetDec rep ~ LParamMem, OpC rep ~ MemOp inner,
 ArrayShape a, HasLetDecMem (LetDec rep), BuilderOps rep,
 PrettyRep fromrep, OpReturns (inner rep), RephraseOp inner,
 SizeSubst (inner rep)) =>
a
-> [VName]
-> (Lambda fromrep, b)
-> AllocM fromrep rep (Lambda rep, b)
onOp a
shape [VName]
arrs) t (Lambda fromrep, b)
op

    onOp :: a
-> [VName]
-> (Lambda fromrep, b)
-> AllocM fromrep rep (Lambda rep, b)
onOp a
accshape [VName]
arrs (Lambda fromrep
lam, b
nes) = do
      let num_vs :: Int
num_vs = forall (t :: * -> *) a. Foldable t => t a -> Int
length (forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda fromrep
lam)
          num_is :: Int
num_is = forall a. ArrayShape a => a -> Int
shapeRank a
accshape
          ([Param Type]
i_params, [Param Type]
x_params, [Param Type]
y_params) =
            forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 Int
num_is Int
num_vs forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda fromrep
lam
          i_params' :: [Param LParamMem]
i_params' = forall a b. (a -> b) -> [a] -> [b]
map (\(Param Attrs
attrs VName
v Type
_) -> forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
v forall a b. (a -> b) -> a -> b
$ forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
int64) [Param Type]
i_params
          is :: [DimIndex SubExp]
is = forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> VName
paramName) [Param LParamMem]
i_params'
      [Param LParamMem]
x_params' <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (forall {rep} {inner :: * -> *} {f :: * -> *} {u}.
(FParamInfo rep ~ FParamMem, LParamInfo rep ~ LParamMem,
 RetType rep ~ RetTypeMem, BranchType rep ~ BranchTypeMem,
 OpC rep ~ MemOp inner, Monad f, HasLetDecMem (LetDec rep),
 ASTRep rep, OpReturns (inner rep), RephraseOp inner,
 HasScope rep f, Pretty u) =>
[DimIndex SubExp]
-> Param (TypeBase Shape u)
-> VName
-> f (Param (MemInfo SubExp u MemBind))
onXParam [DimIndex SubExp]
is) [Param Type]
x_params [VName]
arrs
      [Param LParamMem]
y_params' <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (forall {rep} {fromrep} {inner :: * -> *} {u}.
(ExpDec rep ~ (), BodyDec fromrep ~ (), BodyDec rep ~ (),
 FParamInfo fromrep ~ DeclType, FParamInfo rep ~ FParamMem,
 RetType fromrep ~ DeclExtType, RetType rep ~ RetTypeMem,
 LetDec rep ~ LParamMem, LParamInfo fromrep ~ Type,
 LParamInfo rep ~ LParamMem, BranchType fromrep ~ ExtType,
 BranchType rep ~ BranchTypeMem, OpC rep ~ MemOp inner,
 PrettyRep fromrep, HasLetDecMem (LetDec rep),
 OpReturns (inner rep), RephraseOp inner, SizeSubst (inner rep),
 BuilderOps rep, Pretty u) =>
[DimIndex SubExp]
-> Param (TypeBase Shape u)
-> VName
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
onYParam [DimIndex SubExp]
is) [Param Type]
y_params [VName]
arrs
      Lambda rep
lam' <-
        forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
[LParam torep]
-> Body fromrep -> AllocM fromrep torep (Lambda torep)
allocInLambda
          ([Param LParamMem]
i_params' forall a. Semigroup a => a -> a -> a
<> [Param LParamMem]
x_params' forall a. Semigroup a => a -> a -> a
<> [Param LParamMem]
y_params')
          (forall rep. Lambda rep -> Body rep
lambdaBody Lambda fromrep
lam)
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda rep
lam', b
nes)

    mkP :: Attrs
-> VName
-> PrimType
-> Shape
-> u
-> VName
-> IxFun (TPrimExp Int64 VName)
-> [DimIndex SubExp]
-> Param (MemInfo SubExp u MemBind)
mkP Attrs
attrs VName
p PrimType
pt Shape
shape u
u VName
mem IxFun (TPrimExp Int64 VName)
ixfun [DimIndex SubExp]
is =
      forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
p forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape u
u forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> IxFun (TPrimExp Int64 VName) -> MemBind
ArrayIn VName
mem forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun (TPrimExp Int64 VName)
ixfun forall a b. (a -> b) -> a -> b
$
        forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 forall a b. (a -> b) -> a -> b
$
          forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$
            [DimIndex SubExp]
is forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
sliceDim (forall d. ShapeBase d -> [d]
shapeDims Shape
shape)

    onXParam :: [DimIndex SubExp]
-> Param (TypeBase Shape u)
-> VName
-> f (Param (MemInfo SubExp u MemBind))
onXParam [DimIndex SubExp]
_ (Param Attrs
attrs VName
p (Prim PrimType
t)) VName
_ =
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
p (forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
t)
    onXParam [DimIndex SubExp]
is (Param Attrs
attrs VName
p (Array PrimType
pt Shape
shape u
u)) VName
arr = do
      (VName
mem, IxFun (TPrimExp Int64 VName)
ixfun) <- forall rep (inner :: * -> *) (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
arr
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {u}.
Attrs
-> VName
-> PrimType
-> Shape
-> u
-> VName
-> IxFun (TPrimExp Int64 VName)
-> [DimIndex SubExp]
-> Param (MemInfo SubExp u MemBind)
mkP Attrs
attrs VName
p PrimType
pt Shape
shape u
u VName
mem IxFun (TPrimExp Int64 VName)
ixfun [DimIndex SubExp]
is
    onXParam [DimIndex SubExp]
_ Param (TypeBase Shape u)
p VName
_ =
      forall a. HasCallStack => String -> a
error forall a b. (a -> b) -> a -> b
$ String
"Cannot handle MkAcc param: " forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> String
prettyString Param (TypeBase Shape u)
p

    onYParam :: [DimIndex SubExp]
-> Param (TypeBase Shape u)
-> VName
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
onYParam [DimIndex SubExp]
_ (Param Attrs
attrs VName
p (Prim PrimType
t)) VName
_ =
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
p forall a b. (a -> b) -> a -> b
$ forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
t
    onYParam [DimIndex SubExp]
is (Param Attrs
attrs VName
p (Array PrimType
pt Shape
shape u
u)) VName
arr = do
      Type
arr_t <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
      Space
space <- forall fromrep torep. AllocM fromrep torep Space
askDefaultSpace
      VName
mem <- forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
Type -> Space -> AllocM fromrep torep VName
allocForArray Type
arr_t Space
space
      let base_dims :: [TPrimExp Int64 VName]
base_dims = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
arr_t
          ixfun :: IxFun (TPrimExp Int64 VName)
ixfun = forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [TPrimExp Int64 VName]
base_dims
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {u}.
Attrs
-> VName
-> PrimType
-> Shape
-> u
-> VName
-> IxFun (TPrimExp Int64 VName)
-> [DimIndex SubExp]
-> Param (MemInfo SubExp u MemBind)
mkP Attrs
attrs VName
p PrimType
pt Shape
shape u
u VName
mem IxFun (TPrimExp Int64 VName)
ixfun [DimIndex SubExp]
is
    onYParam [DimIndex SubExp]
_ Param (TypeBase Shape u)
p VName
_ =
      forall a. HasCallStack => String -> a
error forall a b. (a -> b) -> a -> b
$ String
"Cannot handle MkAcc param: " forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> String
prettyString Param (TypeBase Shape u)
p
allocInExp Exp fromrep
e = forall (m :: * -> *) frep trep.
Monad m =>
Mapper frep trep m -> Exp frep -> m (Exp trep)
mapExpM forall {fromrep} {trep}. Mapper fromrep trep (AllocM fromrep trep)
alloc Exp fromrep
e
  where
    alloc :: Mapper fromrep trep (AllocM fromrep trep)
alloc =
      forall rep (m :: * -> *). Monad m => Mapper rep rep m
identityMapper
        { mapOnBody :: Scope trep -> Body fromrep -> AllocM fromrep trep (Body trep)
mapOnBody = forall a. HasCallStack => String -> a
error String
"Unhandled Body in ExplicitAllocations",
          mapOnRetType :: RetType fromrep -> AllocM fromrep trep (RetType trep)
mapOnRetType = forall a. HasCallStack => String -> a
error String
"Unhandled RetType in ExplicitAllocations",
          mapOnBranchType :: BranchType fromrep -> AllocM fromrep trep (BranchType trep)
mapOnBranchType = forall a. HasCallStack => String -> a
error String
"Unhandled BranchType in ExplicitAllocations",
          mapOnFParam :: FParam fromrep -> AllocM fromrep trep (FParam trep)
mapOnFParam = forall a. HasCallStack => String -> a
error String
"Unhandled FParam in ExplicitAllocations",
          mapOnLParam :: LParam fromrep -> AllocM fromrep trep (LParam trep)
mapOnLParam = forall a. HasCallStack => String -> a
error String
"Unhandled LParam in ExplicitAllocations",
          mapOnOp :: Op fromrep -> AllocM fromrep trep (Op trep)
mapOnOp = \Op fromrep
op -> do
            Op fromrep -> AllocM fromrep trep (Op trep)
handle <- forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks forall fromrep torep.
AllocEnv fromrep torep
-> Op fromrep -> AllocM fromrep torep (Op torep)
allocInOp
            Op fromrep -> AllocM fromrep trep (Op trep)
handle Op fromrep
op
        }

-- | Operations that can be recognised as constant.  'stmConsts'
-- uses this to collect the names bound by constant operations.
class SizeSubst op where
  -- | Is this operation constant?  Unless an instance says otherwise,
  -- no operation is.
  opIsConst :: op -> Bool
  opIsConst _ = False

-- | 'NoOp' relies on the default: it is never considered constant.
instance SizeSubst (NoOp rep)

-- | A memory operation is constant exactly when it is an 'Inner'
-- operation that is itself constant; every other memory operation
-- (such as 'Alloc') is not.
instance (SizeSubst (op rep)) => SizeSubst (MemOp op rep) where
  opIsConst memop = case memop of
    Inner op -> opIsConst op
    _ -> False

-- | The names bound by a statement whose expression is an operation
-- that 'opIsConst' deems constant.  Every other statement yields the
-- empty set.
stmConsts :: (SizeSubst (Op rep)) => Stm rep -> S.Set VName
stmConsts (Let pat _ e) = case e of
  Op op | opIsConst op -> S.fromList $ patNames pat
  _ -> mempty

-- | Construct a statement binding @names@ to @e@.  The pattern is
-- produced by 'patWithAllocations' in the given memory space, with
-- 'NoHint' for every name.  The expression decoration @dec@ is
-- wrapped with 'defAux'.  This function only returns the statement
-- and does not add it to the builder itself.
mkLetNamesB' ::
  ( LetDec (Rep m) ~ LetDecMem,
    Mem (Rep m) inner,
    MonadBuilder m,
    ExpDec (Rep m) ~ ()
  ) =>
  Space ->
  ExpDec (Rep m) ->
  [VName] ->
  Exp (Rep m) ->
  m (Stm (Rep m))
mkLetNamesB' space dec names e = do
  -- One hint per bound name; none of them expresses a preference.
  pat <- patWithAllocations space names e (NoHint <$ names)
  pure $ Let pat (defAux dec) e

-- | Like 'mkLetNamesB'', but for a representation wrapped in
-- 'Engine.Wise'.  The pattern from 'patWithAllocations' (with
-- 'NoHint' for every name) is annotated via 'Engine.addWisdomToPat'.
-- The expression decoration is computed by 'Engine.mkWiseExpDec'
-- from the annotated pattern and a trivial @()@.  The statement is
-- returned, not added to the builder, by this function.
mkLetNamesB'' ::
  ( Mem rep inner,
    LetDec rep ~ LetDecMem,
    OpReturns (inner (Engine.Wise rep)),
    ExpDec rep ~ (),
    Rep m ~ Engine.Wise rep,
    HasScope (Engine.Wise rep) m,
    MonadBuilder m,
    AliasedOp (inner (Engine.Wise rep)),
    RephraseOp (MemOp inner),
    Engine.CanBeWise inner
  ) =>
  Space ->
  [VName] ->
  Exp (Engine.Wise rep) ->
  m (Stm (Engine.Wise rep))
mkLetNamesB'' space names e = do
  pat <- patWithAllocations space names e (NoHint <$ names)
  let wisePat = Engine.addWisdomToPat pat e
  pure $ Let wisePat (defAux $ Engine.mkWiseExpDec wisePat () e) e

-- | Simplify a memory operation.
--
-- * For an 'Alloc', only the size is simplified, and nothing is
--   hoisted.
-- * For an 'Inner' operation, the work is delegated to the supplied
--   inner simplifier.  The statements it hoists are passed through
--   unchanged.
simplifyMemOp ::
  (Engine.SimplifiableRep rep) =>
  ( inner (Engine.Wise rep) ->
    Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep))
  ) ->
  MemOp inner (Engine.Wise rep) ->
  Engine.SimpleM rep (MemOp inner (Engine.Wise rep), Stms (Engine.Wise rep))
simplifyMemOp simplifyInner op = case op of
  Alloc size space -> do
    size' <- Engine.simplify size
    pure (Alloc size' space, mempty)
  Inner innerOp -> do
    (innerOp', hoisted) <- simplifyInner innerOp
    pure (Inner innerOp', hoisted)

-- | The simplifier operations for a memory representation.
--
-- The first argument computes the usage table of an inner operation.
-- The second simplifies an inner operation, returning any hoisted
-- statements.  Expression and body decorations are the trivial @()@
-- annotated with wisdom.  Memory operations are simplified with
-- 'simplifyMemOp'.
simplifiable ::
  ( Engine.SimplifiableRep rep,
    LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    Mem (Engine.Wise rep) inner,
    Engine.CanBeWise inner,
    RephraseOp inner,
    IsOp (inner rep),
    OpReturns (inner (Engine.Wise rep)),
    AliasedOp (inner (Engine.Wise rep)),
    IndexOp (inner (Engine.Wise rep))
  ) =>
  (inner (Engine.Wise rep) -> UT.UsageTable) ->
  ( inner (Engine.Wise rep) ->
    Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep))
  ) ->
  SimpleOps rep
simplifiable innerUsage simplifyInnerOp =
  SimpleOps wiseExpDec wiseBody guardAlloc usageOf (simplifyMemOp simplifyInnerOp)
  where
    wiseExpDec _ pat e = pure $ Engine.mkWiseExpDec pat () e

    wiseBody _ stms res = pure $ mkWiseBody () stms res

    -- Only allocations can be protected.  The replacement allocation
    -- has size @size@ when @taken@ is true and 0 otherwise.  Every
    -- other operation gets no protection.
    guardAlloc taken pat op = case op of
      Alloc size space -> Just $ do
        tbody <- resultBodyM [size]
        fbody <- resultBodyM [intConst Int64 0]
        let guardedSize =
              Match [taken] [Case [Just $ BoolValue True] tbody] fbody $
                MatchDec [MemPrim int64] MatchFallback
        size' <- letSubExp "hoisted_alloc_size" guardedSize
        letBind pat $ Op $ Alloc size' space
      _ -> Nothing

    -- A variable allocation size counts as a size usage.  Constant
    -- sizes contribute nothing.  Inner operations defer to the
    -- caller-supplied usage function.
    usageOf op = case op of
      Alloc (Var size) _ -> UT.sizeUsage size
      Alloc _ _ -> mempty
      Inner innerOp -> innerUsage innerOp

-- | A per-result hint given to 'patWithAllocations' alongside the
-- names being bound.  'defaultExpHints' produces one 'NoHint' for each
-- result type of an expression.
data ExpHint
  = NoHint
  -- ^ No preference for this result.
  | Hint IxFun Space
  -- ^ An index function and a memory space for this result.
  -- NOTE(review): presumably the layout the allocator should use for
  -- the result; confirm against 'patWithAllocations'.

-- | The default hints for an expression: 'NoHint' for each of the
-- result types reported by 'expExtType'.
defaultExpHints :: (ASTRep rep, HasScope rep m) => Exp rep -> m [ExpHint]
defaultExpHints e = fmap (NoHint <$) (expExtType e)