{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program.  Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    askDefaultSpace,
    Allocable,
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    arraySizeInBytesExp,
    mkLetNamesB',
    mkLetNamesB'',

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.List (foldl', transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.SymbolTable (IndexOp)
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (maybeNth, splitAt3)

-- | The constraints shared by the generic allocation machinery in
-- this module.  @fromrep@ is the representation being read, @torep@
-- is the representation being produced, and @inner@ is the operation
-- type constructor handed to 'Mem' and 'SizeSubst' for @torep@.
type Allocable fromrep torep inner =
  ( PrettyRep fromrep,
    PrettyRep torep,
    Mem torep inner,
    LetDec torep ~ LetDecMem,
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    SizeSubst (inner torep),
    BuilderOps torep
  )

-- | The read-only environment carried by 'AllocM'.
data AllocEnv fromrep torep = AllocEnv
  { -- | Aggressively try to reuse memory in do-loops -
    -- should be True inside kernels, False outside.
    aggressiveReuse :: Bool,
    -- | When allocating memory, put it in this memory space.
    -- This is primarily used to ensure that group-wide
    -- statements store their results in local memory.
    allocSpace :: Space,
    -- | The set of names that are known to be constants at
    -- kernel compile time.
    envConsts :: S.Set VName,
    -- | How to translate an operation of the source representation
    -- into an operation of the target representation.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | Produces the 'ExpHint's for an expression of the target
    -- representation.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }

-- | Monad for adding allocations to an entire program.
--
-- This is a 'BuilderT' for the target representation layered over a
-- reader for the 'AllocEnv' and a state monad holding the
-- 'VNameSource'.  All listed instances are derived through the
-- newtype.
newtype AllocM fromrep torep a
  = AllocM (BuilderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )

-- | Statements are built in the target representation.  Patterns for
-- new statements are annotated with memory information by
-- 'patWithAllocations'.
instance (Allocable fromrep torep inner) => MonadBuilder (AllocM fromrep torep) where
  type Rep (AllocM fromrep torep) = torep

  -- The expression decoration of the target representation is unit.
  mkExpDecM _ _ = pure ()

  -- Build the pattern with the default space and the hints registered
  -- for this expression.
  mkLetNamesM names e = do
    def_space <- askDefaultSpace
    hints <- expHints e
    pat <- patWithAllocations def_space names e hints
    pure $ Let pat (defAux ()) e

  -- The body decoration of the target representation is unit.
  mkBodyM stms res = pure $ Body () stms res

  -- Both of these simply delegate to the wrapped 'BuilderT'.
  addStms = AllocM . addStms
  collectStms (AllocM m) = AllocM $ collectStms m

-- | Look up the hint function stored in the environment
-- ('envExpHints') and apply it to the given expression.
expHints :: Exp torep -> AllocM fromrep torep [ExpHint]
expHints e = do
  f <- asks envExpHints
  f e

-- | The space in which we allocate memory if we have no other
-- preferences or constraints.  This is the 'allocSpace' field of the
-- current 'AllocEnv'.
askDefaultSpace :: AllocM fromrep torep Space
askDefaultSpace = asks allocSpace

-- | Run an 'AllocM' action inside any monad that can supply fresh
-- names.
--
-- The arguments are, in order: the default allocation space, the
-- handler used for operations of the source representation, the
-- expression hint function, and the action itself.
--
-- The action starts with an empty scope, with 'aggressiveReuse'
-- disabled, and with no known constants.  Only the result of the
-- action is returned; statements left over in the builder are
-- discarded.
runAllocM ::
  (MonadFreshNames m) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM space handleOp hints (AllocM m) =
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBuilderT m mempty) env
  where
    -- No local type signature: without ScopedTypeVariables it would
    -- introduce fresh type variables and be rejected.
    env =
      AllocEnv
        { aggressiveReuse = False,
          allocSpace = space,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }

-- | The byte size ('primByteSize') of the element type of the given
-- type.
elemSize :: (Num a) => Type -> a
elemSize = primByteSize . elemType

-- | An expression for the size in bytes of a value of the given
-- type: the element size multiplied by every array dimension.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  untyped $ foldl' (*) (elemSize t) $ map pe64 (arrayDims t)

-- | An expression for the size in bytes of a value of the given
-- type, computed as the product of the dimensions times the element
-- size and wrapped in a signed 64-bit maximum with zero, so the
-- resulting size is never negative.
arraySizeInBytesExpM :: (MonadBuilder m) => Type -> m (PrimExp VName)
arraySizeInBytesExpM t = do
  let dim_prod_i64 = product $ map pe64 (arrayDims t)
      elm_size_i64 = elemSize t
  pure $
    BinOpExp (SMax Int64) (ValueExp $ IntValue $ Int64Value 0) $
      untyped $
        dim_prod_i64 * elm_size_i64

-- | Compute the size in bytes of a value of the given type with
-- 'arraySizeInBytesExpM', convert the result to an expression, and
-- bind it to a fresh name based on @"bytes"@.
arraySizeInBytes :: (MonadBuilder m) => Type -> m SubExp
arraySizeInBytes = letSubExp "bytes" <=< toExp <=< arraySizeInBytesExpM

-- | Emit an 'Alloc' in the given space, sized by 'arraySizeInBytes'
-- for the given type, and return the name bound to the new memory
-- block.
allocForArray' ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Type ->
  Space ->
  m VName
allocForArray' t space = do
  size <- arraySizeInBytes t
  letExp "mem" $ Op $ Alloc size space

-- | Allocate memory for a value of the given type, in the given
-- space, returning the name of the memory block.  This is
-- @allocForArray'@ specialised to 'AllocM'.
allocForArray ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  AllocM fromrep torep VName
allocForArray t space = do
  allocForArray' t space

-- | Repair an expression that cannot be assigned an index function.
-- There is a simple remedy for this: normalise the input arrays and
-- try again.
--
-- Only 'Reshape' is handled: its input array is passed through
-- 'ensureDirectArray', requesting the space of the memory block in
-- which the array currently resides, and the reshape is rebuilt on
-- the resulting array.  Any other expression is a fatal error.
repairExpression ::
  (Allocable fromrep torep inner) =>
  Exp torep ->
  AllocM fromrep torep (Exp torep)
repairExpression (BasicOp (Reshape k shape v)) = do
  v_mem <- fst <$> lookupArraySummary v
  space <- lookupMemSpace v_mem
  v' <- snd <$> ensureDirectArray (Just space) v
  pure $ BasicOp $ Reshape k shape v'
repairExpression e =
  error $ "repairExpression:\n" <> prettyString e

-- | Like 'expReturns', but never yields 'Nothing'.  If the
-- expression has no return information as given, it is passed
-- through 'repairExpression' and queried again.  The returned
-- expression is the one the return information refers to, that is,
-- the repaired expression when a repair took place.
--
-- If the repaired expression still has no return information, using
-- the returned list raises an error that shows both expressions.
expReturns' ::
  (Allocable fromrep torep inner) =>
  Exp torep ->
  AllocM fromrep torep ([ExpReturns], Exp torep)
expReturns' e = do
  maybe_rts <- expReturns e
  case maybe_rts of
    Just rts -> pure (rts, e)
    Nothing -> do
      e' <- repairExpression e
      -- Lazily evaluated: only forced if the second query fails and
      -- the caller inspects the result.
      let bad =
            error . unlines $
              [ "expReturns': impossible index transformation",
                prettyString e,
                prettyString e'
              ]
      rts <- fromMaybe bad <$> expReturns e'
      pure (rts, e')

-- | Construct a statement binding the given identifiers to the given
-- expression, with pattern elements produced by 'allocsForPat' using
-- the default space and the hints for the expression.
--
-- The hints are looked up for the expression as given, while the
-- return information, the expression decoration and the resulting
-- statement use the expression returned by @expReturns'@, which may
-- be a repaired version.
allocsForStm ::
  (Allocable fromrep torep inner) =>
  [Ident] ->
  Exp torep ->
  AllocM fromrep torep (Stm torep)
allocsForStm idents e = do
  def_space <- askDefaultSpace
  hints <- expHints e
  (rts, e') <- expReturns' e
  pes <- allocsForPat def_space idents rts hints
  dec <- mkExpDecM (Pat pes) e'
  pure $ Let (Pat pes) (defAux dec) e'

-- | Construct a memory-annotated pattern binding the given names to
-- the result of the given expression.
--
-- The types of the names are the existential result types of the
-- expression, instantiated with 'instantiateShapes''.  The pattern
-- elements are then produced by 'allocsForPat' from the given
-- default space and hints.
--
-- It is a fatal error if 'expReturns' has no return information for
-- the expression.
patWithAllocations ::
  (MonadBuilder m, Mem (Rep m) inner) =>
  Space ->
  [VName] ->
  Exp (Rep m) ->
  [ExpHint] ->
  m (Pat LetDecMem)
patWithAllocations def_space names e hints = do
  ts' <- instantiateShapes' names <$> expExtType e
  let idents = zipWith Ident names ts'
  rts <- fromMaybe (error "patWithAllocations: ill-typed") <$> expReturns e
  Pat <$> allocsForPat def_space idents rts hints

-- | Extend @idents@ with freshly named identifiers so that there is
-- one identifier per element of @rts@.
--
-- The two lists are aligned from the end: the last identifier
-- corresponds to the last return, and so on backwards.  Any return
-- left without an identifier gets a new one: a memory-typed
-- @ext_mem@ for a 'MemMem' return, and an @int64@-typed @ext@ for
-- anything else.  The result is in the same order as @rts@.
mkMissingIdents :: (MonadFreshNames m) => [Ident] -> [ExpReturns] -> m [Ident]
mkMissingIdents idents rts =
  reverse <$> zipWithM identFor (reverse rts) known_then_missing
  where
    -- Existing identifiers (last first), followed by an endless
    -- supply of "no identifier" markers; 'zipWithM' stops at the
    -- length of @rts@.
    known_then_missing = map Just (reverse idents) <> repeat Nothing

    -- Reuse the identifier when there is one, otherwise invent it.
    identFor rt = maybe (freshFor rt) pure

    freshFor (MemMem space) = newIdent "ext_mem" $ Mem space
    freshFor _ = newIdent "ext" $ Prim int64

-- | Produce the pattern elements, with memory annotations, for a
-- statement whose results have the return information @rts@.
--
-- @some_idents@ is first completed by 'mkMissingIdents' so that there
-- is one identifier per return.  Each (identifier, return, hint)
-- triple is then turned into a 'PatElem':
--
-- * primitive returns, and array returns without memory information
--   whose shape contains no existential dimension, get their
--   annotation from 'summaryForBindage' (using @def_space@ and the
--   hint);
--
-- * memory and accumulator returns are copied over unchanged;
--
-- * array returns with memory information keep the element type and
--   uniqueness of the return, take the shape of the identifier, and
--   have existentials in the index function replaced by the names of
--   the corresponding identifiers.  For 'ReturnsNewBlock' the memory
--   block itself is the identifier at the given existential index.
--
-- Any other return is reported with 'error'.
allocsForPat ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m [PatElem LetDecMem]
allocsForPat def_space some_idents rts hints = do
  idents <- mkMissingIdents some_idents rts

  forM (zip3 idents rts hints) $ \(ident, rt, hint) ->
    let name = identName ident
        shape = arrayShape (identType ident)
        -- Wrap a finished annotation in a pattern element.
        bind = pure . PatElem name
        -- Let 'summaryForBindage' decide the annotation.
        viaSummary =
          PatElem name <$> summaryForBindage def_space (identType ident) hint
     in case rt of
          MemPrim {} ->
            viaSummary
          MemMem space ->
            bind (MemMem space)
          MemArray bt _ u (Just (ReturnsInBlock mem ext_ixfun)) ->
            bind . MemArray bt shape u $
              ArrayIn mem (resolveExtIxFun idents ext_ixfun)
          MemArray _ ext_shape _ Nothing
            | isJust (fullyKnown ext_shape) ->
                viaSummary
          MemArray bt _ u (Just (ReturnsNewBlock _ i ext_ixfun)) ->
            bind . MemArray bt shape u $
              ArrayIn (nthIdentName idents i) (resolveExtIxFun idents ext_ixfun)
          MemAcc acc ispace ts u ->
            bind (MemAcc acc ispace ts u)
          _ ->
            error "Impossible case reached in allocsForPat!"
  where
    -- The dimensions of a shape, provided none of them is existential.
    fullyKnown = traverse freeOnly . shapeDims

    freeOnly (Free v) = Just v
    freeOnly Ext {} = Nothing

    -- Name of the identifier at position @i@; a missing position is a
    -- fatal error.
    nthIdentName all_idents i =
      maybe missing identName (maybeNth i all_idents)
      where
        missing =
          error $
            concat
              [ "getIdent: Ext ",
                show i,
                " but pattern has ",
                show (length all_idents),
                " elements: ",
                prettyString all_idents
              ]

    -- Replace every existential in an index function by the name of
    -- the identifier it refers to.
    resolveExtIxFun all_idents = fmap (fmap resolve)
      where
        resolve (Free v) = v
        resolve (Ext i) = nthIdentName all_idents i

-- | Convert an existential index function into an ordinary one.
--
-- Only index functions without existentials are supported: every
-- 'Free' leaf is kept as is, while encountering an 'Ext' leaf raises
-- an 'error'.
instantiateIxFun :: (Monad m) => ExtIxFun -> m IxFun
instantiateIxFun ext_ixfun = traverse (traverse freeLeaf) ext_ixfun
  where
    freeLeaf (Free x) = pure x
    freeLeaf Ext {} = error "instantiateIxFun: not yet"

-- | Compute the memory annotation for a freshly bound name of the
-- given type.
--
-- Primitive, memory and accumulator types map directly to the
-- corresponding annotation and emit no statements.  For an array:
--
-- * with 'NoHint', a memory block is obtained from 'allocForArray''
--   in the default space, and the array is given the index function
--   @IxFun.iota@ over its dimensions;
--
-- * with @'Hint' ixfun space@, the number of bytes is computed as the
--   product of the base of @ixfun@ times the byte size of the element
--   type, bound as @bytes@, and an 'Alloc' of that size in @space@ is
--   bound as @mem@.  The array then uses @ixfun@ unchanged.
summaryForBindage ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage def_space t hint =
  case t of
    Prim bt ->
      pure (MemPrim bt)
    Mem space ->
      pure (MemMem space)
    Acc acc ispace ts u ->
      pure (MemAcc acc ispace ts u)
    Array pt shape u ->
      -- The hint is only inspected for arrays.
      case hint of
        NoHint -> do
          mem <- allocForArray' t def_space
          let ixfun = IxFun.iota . map pe64 $ arrayDims t
          pure $ MemArray pt shape u (ArrayIn mem ixfun)
        Hint ixfun space -> do
          let elem_size = fromIntegral (primByteSize pt :: Int64)
              num_bytes = product [product (IxFun.base ixfun), elem_size]
          bytes <- letSubExp "bytes" =<< toExp (untyped num_bytes)
          mem <- letExp "mem" $ Op $ Alloc bytes space
          pure $ MemArray pt (arrayShape t) NoUniqueness (ArrayIn mem ixfun)

-- | Add memory information to a list of function parameters, each
-- paired with the space its memory should live in, and run the
-- continuation @m@ on the resulting parameters.
--
-- Every parameter is converted with 'allocInFParam'.  The parameters
-- passed to the continuation are, in order: the memory parameters
-- and then the context parameters collected by the writer, followed
-- by the converted value parameters.  The continuation runs with all
-- of them added to the local scope.
allocInFParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params m = do
  (val_params, (mem_params, ctx_params)) <-
    runWriterT . forM params $ \(param, space) ->
      allocInFParam param space
  let all_params = concat [mem_params, ctx_params, val_params]
  localScope (scopeOfFParams all_params) (m all_params)

-- | Add memory information to a single function parameter.
--
-- For an array parameter, a fresh memory name is generated from the
-- parameter's base name with @_mem@ appended, and a parameter for
-- that memory block (in space @pspace@, carrying the attributes of
-- the original parameter) is emitted through the first component of
-- the writer.  The array itself is placed in that block with the
-- index function @IxFun.iota@ over its dimensions.
--
-- Primitive, memory and accumulator parameters only have their
-- decoration replaced by the matching memory annotation; nothing is
-- written.
allocInFParam ::
  (Allocable fromrep torep inner) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace =
  case paramDeclType param of
    Array pt shape u -> do
      mem <- lift . newVName $ baseString (paramName param) <> "_mem"
      tell ([Param (paramAttrs param) mem (MemMem pspace)], [])
      let linear = IxFun.iota . map pe64 $ shapeDims shape
      pure . withDec $ MemArray pt shape u (ArrayIn mem linear)
    Prim pt ->
      pure . withDec $ MemPrim pt
    Mem space ->
      pure . withDec $ MemMem space
    Acc acc ispace ts u ->
      pure . withDec $ MemAcc acc ispace ts u
  where
    -- The original parameter with its decoration replaced.
    withDec dec = param {paramDec = dec}

-- | Return a memory block and an array name for @v@, copying the
-- array into freshly allocated memory when needed.
--
-- The array is returned as it is, together with its current memory
-- block, when both of the following hold:
--
-- * the base of its index function has as many dimensions as the
--   index function has rank, and
--
-- * @space_ok@ is 'Nothing', or it names the space the array's
--   memory block already lives in.
--
-- Otherwise the result of 'allocLinearArray' is returned, targeting
-- @space_ok@ if given and the default space if not.
ensureRowMajorArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureRowMajorArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let same_rank = length (IxFun.base ixfun) == IxFun.rank ixfun
      -- Vacuously true when no particular space was requested.
      space_fits = all (== mem_space) space_ok
      target_space = fromMaybe default_space space_ok
  if same_rank && space_fits
    then pure (mem, v)
    else allocLinearArray target_space (baseString v) v

-- | Make sure that an array-typed loop result is placed the way
-- 'ensureRowMajorArray' demands for the given memory space, and
-- record the extra loop results this gives rise to.
--
-- The writer output is a pair: the first component receives the
-- memory block of the (possibly new) array, the second receives one
-- freshly bound @ixfun_arg@ per element of the array's index
-- function.  The returned 'SubExp' names the (possibly new) array.
--
-- Calling this with a 'Constant' is an internal error, as a constant
-- can never be an array.
ensureArrayIn ::
  (Allocable fromrep torep inner) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn target_space se =
  case se of
    Constant c ->
      error $ "ensureArrayIn: " ++ prettyString c ++ " cannot be an array."
    Var arr -> do
      (arr_mem, arr') <- lift $ ensureRowMajorArray (Just target_space) arr
      -- Look the summary up again, because arr' may be a different
      -- array than the one we started out with.
      (_, arr_ixfun) <- lift $ lookupArraySummary arr'
      ixfun_args <-
        lift . forM (toList arr_ixfun) $ \e ->
          toExp e >>= letSubExp "ixfun_arg"
      tell ([Var arr_mem], ixfun_args)
      pure (Var arr')

-- | Translate the merge parameters of a loop (with their initial
-- values) and run a continuation with the result.
--
-- Every parameter is processed by the local @allocInMergeParam@,
-- which may emit additional memory parameters and context parameters
-- through its writer.  The continuation receives:
--
-- * the new merge list, ordered as memory parameters, then context
--   parameters, then the value parameters, each paired with its
--   initial argument;
--
-- * a function that, given the value results of the loop body, runs
--   the per-parameter result handlers and returns the memory and
--   context results (first component) together with the possibly
--   rewritten value results (second component).
--
-- The continuation runs with all new parameters added to the scope.
allocInMergeParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, SubExp)] ->
  ( [(FParam torep, SubExp)] ->
    ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
    AllocM fromrep torep a
  ) ->
  AllocM fromrep torep a
allocInMergeParams merge m = do
  (per_param, (mem_params, ctx_params)) <-
    runWriterT (mapM allocInMergeParam merge)
  let (val_params, init_vals, res_handlers) = unzip3 per_param
      all_params = mem_params <> ctx_params <> val_params

      -- Apply each parameter's handler to the corresponding result,
      -- collecting the memory and context results they write.
      mk_loop_res ses = do
        (ses', (mem_args, ctx_args)) <-
          runWriterT (zipWithM ($) res_handlers ses)
        pure (mem_args <> ctx_args, ses')

  -- The initial values go through the very same handlers as the loop
  -- results, which yields the initial memory and context arguments.
  (extra_args, init_vals') <- mk_loop_res init_vals
  let merge' = zip all_params (extra_args <> init_vals')
  localScope (scopeOfFParams all_params) $ m merge' mk_loop_res
  where
    -- The names of all merge parameters, as a 'Names' set.
    param_names = namesFromList [paramName p | (p, _) <- merge]
    -- True if the given set of names has a merge parameter in it.
    anyIsLoopParam names = namesIntersect names param_names

    -- Result handler for a merge parameter whose memory space and
    -- index function are fixed up front.  A variable result is used
    -- as-is when its memory space and index function are already
    -- the expected ones; otherwise 'arrayWithIxFun' is asked for a
    -- replacement with the expected space and index function.  Only
    -- the memory block is written to the writer; no context results
    -- are produced.  Anything that is not a variable is passed
    -- through untouched.
    scalarRes param_t v_mem_space v_ixfun se =
      case se of
        Var res -> do
          (cur_mem, cur_ixfun) <- lift (lookupArraySummary res)
          cur_space <- lift (lookupMemSpace cur_mem)
          -- Avoid a needless copy when nothing has to change.
          let already_ok = cur_space == v_mem_space && cur_ixfun == v_ixfun
          (final_mem, final_res) <-
            if already_ok
              then pure (cur_mem, res)
              else
                lift
                  (arrayWithIxFun v_mem_space v_ixfun (fromDecl param_t) res)
          tell ([Var final_mem], [])
          pure (Var final_res)
        _ -> pure se

    allocInMergeParam ::
      (Allocable fromrep torep inner) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam torep], [FParam torep])
        (AllocM fromrep torep)
        ( FParam torep,
          SubExp,
          SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
        )
    allocInMergeParam :: forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
(Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
allocInMergeParam (Param DeclType
mergeparam, Var VName
v)
      | param_t :: DeclType
param_t@(Array PrimType
pt Shape
shape Uniqueness
u) <- Param DeclType -> DeclType
forall dec. DeclTyped dec => Param dec -> DeclType
paramDeclType Param DeclType
mergeparam = do
          (VName
v_mem, IxFun (TPrimExp Int64 VName)
v_ixfun) <- AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, IxFun (TPrimExp Int64 VName))
forall (m :: * -> *) a.
Monad m =>
m a -> WriterT ([Param FParamMem], [Param FParamMem]) m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (VName, IxFun (TPrimExp Int64 VName)))
-> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
forall rep (inner :: * -> *) (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
v
          Space
v_mem_space <- AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (m :: * -> *) a.
Monad m =>
m a -> WriterT ([Param FParamMem], [Param FParamMem]) m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep Space
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      Space)
-> AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep Space
forall rep (inner :: * -> *) (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
v_mem

          -- Loop-invariant array parameters that are in scalar space
          -- are special - we do not wish to existentialise their index
          -- function at all (but the memory block is still existential).
          case Space
v_mem_space of
            ScalarSpace {} ->
              if Names -> Bool
anyIsLoopParam (Shape -> Names
forall a. FreeIn a => a -> Names
freeIn Shape
shape)
                then do
                  -- Arrays with loop-variant shape cannot be in scalar
                  -- space, so copy them elsewhere and try again.
                  Space
space <- AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (m :: * -> *) a.
Monad m =>
m a -> WriterT ([Param FParamMem], [Param FParamMem]) m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift AllocM fromrep torep Space
forall fromrep torep. AllocM fromrep torep Space
askDefaultSpace
                  (VName
_, VName
v') <- AllocM fromrep torep (VName, VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, VName)
forall (m :: * -> *) a.
Monad m =>
m a -> WriterT ([Param FParamMem], [Param FParamMem]) m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, VName)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (VName, VName))
-> AllocM fromrep torep (VName, VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, VName)
forall a b. (a -> b) -> a -> b
$ Space -> String -> VName -> AllocM fromrep torep (VName, VName)
forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
Space -> String -> VName -> AllocM fromrep torep (VName, VName)
allocLinearArray Space
space (VName -> String
baseString VName
v) VName
v
                  (Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
(Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
allocInMergeParam (Param DeclType
mergeparam, VName -> SubExp
Var VName
v')
                else do
                  Param FParamMem
p <- String
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"mem_param" (FParamMem
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem))
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall a b. (a -> b) -> a -> b
$ Space -> FParamMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
v_mem_space
                  ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([Param FParamMem
p], [])

                  (Param FParamMem, SubExp,
 SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall a.
a
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
                    ( Param DeclType
mergeparam {paramDec :: FParamMem
paramDec = PrimType -> Shape -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape Uniqueness
u (MemBind -> FParamMem) -> MemBind -> FParamMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun (TPrimExp Int64 VName) -> MemBind
ArrayIn (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
p) IxFun (TPrimExp Int64 VName)
v_ixfun},
                      VName -> SubExp
Var VName
v,
                      DeclType
-> Space
-> IxFun (TPrimExp Int64 VName)
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall {m :: * -> *} {inner :: * -> *} {a}
       {t :: (* -> *) -> * -> *}.
(OpC (Rep m) ~ MemOp inner, LParamInfo (Rep m) ~ LParamMem,
 BranchType (Rep m) ~ BranchTypeMem, LetDec (Rep m) ~ LParamMem,
 RetType (Rep m) ~ RetTypeMem, FParamInfo (Rep m) ~ FParamMem,
 RephraseOp inner, MonadWriter ([SubExp], [a]) (t m),
 MonadBuilder m, MonadTrans t, OpReturns (inner (Rep m))) =>
DeclType
-> Space -> IxFun (TPrimExp Int64 VName) -> SubExp -> t m SubExp
scalarRes DeclType
param_t Space
v_mem_space IxFun (TPrimExp Int64 VName)
v_ixfun
                    )
            Space
_ -> do
              (VName
v_mem', VName
v') <- AllocM fromrep torep (VName, VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, VName)
forall (m :: * -> *) a.
Monad m =>
m a -> WriterT ([Param FParamMem], [Param FParamMem]) m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, VName)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (VName, VName))
-> AllocM fromrep torep (VName, VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, VName)
forall a b. (a -> b) -> a -> b
$ Maybe Space -> VName -> AllocM fromrep torep (VName, VName)
forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
Maybe Space -> VName -> AllocM fromrep torep (VName, VName)
ensureRowMajorArray Maybe Space
forall a. Maybe a
Nothing VName
v
              (VName
_, IxFun (TPrimExp Int64 VName)
v_ixfun') <- AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, IxFun (TPrimExp Int64 VName))
forall (m :: * -> *) a.
Monad m =>
m a -> WriterT ([Param FParamMem], [Param FParamMem]) m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (VName, IxFun (TPrimExp Int64 VName)))
-> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
forall rep (inner :: * -> *) (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
v'
              Space
v_mem_space' <- AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (m :: * -> *) a.
Monad m =>
m a -> WriterT ([Param FParamMem], [Param FParamMem]) m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep Space
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      Space)
-> AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep Space
forall rep (inner :: * -> *) (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
v_mem'

              [Param FParamMem]
ctx_params <-
                Int
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [Param FParamMem]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (IxFun (TPrimExp Int64 VName) -> Int
forall a. IxFun a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length IxFun (TPrimExp Int64 VName)
v_ixfun') (WriterT
   ([Param FParamMem], [Param FParamMem])
   (AllocM fromrep torep)
   (Param FParamMem)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [Param FParamMem]
forall a b. (a -> b) -> a -> b
$
                  String
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"ctx_param_ext" (PrimType -> FParamMem
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
int64)

              IxFun (TPrimExp Int64 VName)
param_ixfun <-
                ExtIxFun
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (IxFun (TPrimExp Int64 VName))
forall (m :: * -> *).
Monad m =>
ExtIxFun -> m (IxFun (TPrimExp Int64 VName))
instantiateIxFun (ExtIxFun
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (IxFun (TPrimExp Int64 VName)))
-> ExtIxFun
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$
                  Map (Ext VName) (TPrimExp Int64 (Ext VName))
-> ExtIxFun -> ExtIxFun
forall {k} a (t :: k).
Ord a =>
Map a (TPrimExp t a)
-> IxFun (TPrimExp t a) -> IxFun (TPrimExp t a)
IxFun.substituteInIxFun
                    ( [(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(Ext VName, TPrimExp Int64 (Ext VName))]
 -> Map (Ext VName) (TPrimExp Int64 (Ext VName)))
-> ([TPrimExp Int64 (Ext VName)]
    -> [(Ext VName, TPrimExp Int64 (Ext VName))])
-> [TPrimExp Int64 (Ext VName)]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Ext VName]
-> [TPrimExp Int64 (Ext VName)]
-> [(Ext VName, TPrimExp Int64 (Ext VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Int -> Ext VName) -> [Int] -> [Ext VName]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Int -> Ext VName
forall a. Int -> Ext a
Ext [Int
0 ..]) ([TPrimExp Int64 (Ext VName)]
 -> Map (Ext VName) (TPrimExp Int64 (Ext VName)))
-> [TPrimExp Int64 (Ext VName)]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall a b. (a -> b) -> a -> b
$
                        (Param FParamMem -> TPrimExp Int64 (Ext VName))
-> [Param FParamMem] -> [TPrimExp Int64 (Ext VName)]
forall a b. (a -> b) -> [a] -> [b]
map (Ext VName -> TPrimExp Int64 (Ext VName)
forall a. a -> TPrimExp Int64 a
le64 (Ext VName -> TPrimExp Int64 (Ext VName))
-> (Param FParamMem -> Ext VName)
-> Param FParamMem
-> TPrimExp Int64 (Ext VName)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Ext VName
forall a. a -> Ext a
Free (VName -> Ext VName)
-> (Param FParamMem -> VName) -> Param FParamMem -> Ext VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName) [Param FParamMem]
ctx_params
                    )
                    (IxFun (TPrimExp Int64 VName) -> ExtIxFun
forall a b.
IxFun (TPrimExp Int64 a) -> IxFun (TPrimExp Int64 (Ext b))
IxFun.existentialize IxFun (TPrimExp Int64 VName)
v_ixfun')

              Param FParamMem
mem_param <- String
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"mem_param" (FParamMem
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem))
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall a b. (a -> b) -> a -> b
$ Space -> FParamMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
v_mem_space'
              ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([Param FParamMem
mem_param], [Param FParamMem]
ctx_params)
              (Param FParamMem, SubExp,
 SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall a.
a
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
                ( Param DeclType
mergeparam {paramDec :: FParamMem
paramDec = PrimType -> Shape -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape Uniqueness
u (MemBind -> FParamMem) -> MemBind -> FParamMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun (TPrimExp Int64 VName) -> MemBind
ArrayIn (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
mem_param) IxFun (TPrimExp Int64 VName)
param_ixfun},
                  VName -> SubExp
Var VName
v',
                  Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn Space
v_mem_space'
                )
    allocInMergeParam (Param DeclType
mergeparam, SubExp
se) = Param (FParamInfo fromrep)
-> SubExp
-> Space
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall {torep} {fromrep} {torep} {fromrep} {inner :: * -> *}
       {inner :: * -> *} {b}.
(RetType torep ~ RetTypeMem, RetType fromrep ~ DeclExtType,
 RetType torep ~ RetTypeMem, RetType fromrep ~ DeclExtType,
 FParamInfo torep ~ FParamMem, FParamInfo torep ~ FParamMem,
 FParamInfo fromrep ~ DeclType, FParamInfo fromrep ~ DeclType,
 BodyDec torep ~ (), BodyDec torep ~ (), BodyDec fromrep ~ (),
 BodyDec fromrep ~ (), LetDec torep ~ LParamMem,
 LetDec torep ~ LParamMem, ExpDec torep ~ (), ExpDec torep ~ (),
 LParamInfo fromrep ~ Type, LParamInfo fromrep ~ Type,
 LParamInfo torep ~ LParamMem, LParamInfo torep ~ LParamMem,
 BranchType fromrep ~ ExtType, BranchType torep ~ BranchTypeMem,
 BranchType fromrep ~ ExtType, BranchType torep ~ BranchTypeMem,
 OpC torep ~ MemOp inner, OpC torep ~ MemOp inner,
 PrettyRep fromrep, PrettyRep fromrep, OpReturns (inner torep),
 OpReturns (inner torep), RephraseOp inner, RephraseOp inner,
 SizeSubst (inner torep), SizeSubst (inner torep), BuilderOps torep,
 BuilderOps torep) =>
Param (FParamInfo fromrep)
-> b
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep), b,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
doDefault Param DeclType
Param (FParamInfo fromrep)
mergeparam SubExp
se (Space
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem, SubExp,
       SubExp
       -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp))
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (m :: * -> *) a.
Monad m =>
m a -> WriterT ([Param FParamMem], [Param FParamMem]) m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift AllocM fromrep torep Space
forall fromrep torep. AllocM fromrep torep Space
askDefaultSpace

    doDefault :: Param (FParamInfo fromrep)
-> b
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep), b,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
doDefault Param (FParamInfo fromrep)
mergeparam b
se Space
space = do
      Param (FParamInfo torep)
mergeparam' <- Param (FParamInfo fromrep)
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep))
forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
FParam fromrep
-> Space
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep)
allocInFParam Param (FParamInfo fromrep)
mergeparam Space
space
      (Param (FParamInfo torep), b,
 SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep), b,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall a.
a
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param (FParamInfo torep)
mergeparam', b
se, Type
-> Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
Type
-> Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg (Param (FParamInfo fromrep) -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param (FParamInfo fromrep)
mergeparam) Space
space)

-- | Allocate a fresh memory block for an array of type @v_t@ in
-- @space@ and bind a fresh name (the base name of @v@ suffixed with
-- @_scalcopy@) to @Replicate mempty (Var v)@.  The new binding is
-- annotated as residing in the new memory block with the given index
-- function.
--
-- Returns the name of the memory block and the name of the new
-- binding.  Fails with 'error' if @v_t@ is not an array type.
arrayWithIxFun ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m), LetDec (Rep m) ~ LetDecMem) =>
  Space ->
  IxFun ->
  Type ->
  VName ->
  m (VName, VName)
arrayWithIxFun space ixfun v_t v =
  -- Match explicitly instead of using a partial @let@ pattern, so a
  -- non-array type produces a descriptive error rather than an
  -- irrefutable-pattern failure.
  case v_t of
    Array pt shape u -> do
      mem <- allocForArray' v_t space
      v_copy <- newVName $ baseString v <> "_scalcopy"
      let pe = PatElem v_copy $ MemArray pt shape u $ ArrayIn mem ixfun
      letBind (Pat [pe]) $ BasicOp $ Replicate mempty $ Var v
      pure (mem, v_copy)
    _ ->
      error $ "arrayWithIxFun: " ++ prettyString v_t

-- | Produce a memory block and array name for @v@ such that the array
-- has a direct index function (as judged by 'IxFun.isDirect') and, if
-- @space_ok@ is 'Just', resides in that memory space.
--
-- If @v@ already satisfies both conditions, its own memory block and
-- name are returned unchanged.  Otherwise a copy is made with
-- 'allocLinearArray' in the requested space, falling back to the
-- default space when @space_ok@ is 'Nothing'.
ensureDirectArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  if IxFun.isDirect ixfun && maybe True (== mem_space) space_ok
    then pure (mem, v)
    else needCopy (fromMaybe default_space space_ok)
  where
    needCopy space =
      -- We need to do a new allocation, copy 'v', and make a new
      -- binding for the size of the memory block.
      allocLinearArray space (baseString v) v

-- | Allocate memory in @space@ for the array @v@ and bind a fresh name
-- (@s@ suffixed with @_desired_form@) to @Manifest perm v@.  The new
-- binding is annotated with an index function that is the row-major
-- ('IxFun.iota') index function over the dimensions of @v@, permuted
-- by @perm@.
--
-- Returns the name of the new memory block and the name of the new
-- array.  Fails with 'error' if @v@ is not of array type.
allocPermArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  [Int] ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocPermArray space perm s v = do
  t <- lookupType v
  case t of
    Array pt shape u -> do
      mem <- allocForArray t space
      v' <- newVName $ s <> "_desired_form"
      let info =
            MemArray pt shape u . ArrayIn mem $
              IxFun.permute (IxFun.iota $ map pe64 $ arrayDims t) perm
          pat = Pat [PatElem v' info]
      addStm $ Let pat (defAux ()) $ BasicOp $ Manifest perm v
      pure (mem, v')
    _ ->
      error $ "allocPermArray: " ++ prettyString t

-- | Produce a memory block and array name for @v@, reusing @v@ itself
-- when the base of its index function has as many dimensions as the
-- index function's shape and, if @space_ok@ is 'Just', the array
-- resides in that memory space.
--
-- Otherwise a new array is created with 'allocPermArray' using
-- @perm@, in the requested space (or the default space when
-- @space_ok@ is 'Nothing').
ensurePermArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  [Int] ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensurePermArray space_ok perm v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  if length (IxFun.base ixfun) == length (IxFun.shape ixfun)
    && maybe True (== mem_space) space_ok
    then pure (mem, v)
    else allocPermArray (fromMaybe default_space space_ok) perm (baseString v) v

-- | Like 'allocPermArray', but with the identity permutation
-- @[0 .. rank - 1]@ over the dimensions of @v@.
allocLinearArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocLinearArray space s v = do
  t <- lookupType v
  let perm = [0 .. arrayRank t - 1]
  allocPermArray space perm s v

-- | Transform the arguments of a function call.  Each argument is
-- passed through 'linearFuncallArg' using the default memory space,
-- keeping its original 'Diet'.
--
-- The result consists of the context and memory arguments collected
-- by the writer (all marked 'Observe'), followed by the transformed
-- value arguments.
funcallArgs ::
  (Allocable fromrep torep inner) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  (valargs, (ctx_args, mem_and_size_args)) <- runWriterT $
    forM args $ \(arg, d) -> do
      t <- lift $ subExpType arg
      space <- lift askDefaultSpace
      arg' <- linearFuncallArg t space arg
      pure (arg', d)
  pure $ map (,Observe) (ctx_args <> mem_and_size_args) <> valargs

-- | Prepare a single function call argument.
--
-- For an array-typed variable, the array is passed through
-- 'ensureDirectArray' for the given space, the resulting memory block
-- is emitted as the first writer component, and the (possibly new)
-- array name is returned.  Any other argument is returned unchanged
-- and nothing is emitted.
linearFuncallArg ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg Array {} space (Var v) = do
  (mem, arg') <- lift $ ensureDirectArray (Just space) v
  tell ([Var mem], [])
  pure $ Var arg'
linearFuncallArg _ _ arg =
  pure arg

-- | Add @a@ to every index in the first list of a 'RetAls' and @b@ to
-- every index in the second list.
shiftRetAls :: Int -> Int -> RetAls -> RetAls
shiftRetAls a b (RetAls is js) =
  RetAls (map (+ a) is) (map (+ b) js)

explicitAllocationsGeneric ::
  (Allocable fromrep torep inner) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric :: forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
Space
-> (Op fromrep -> AllocM fromrep torep (Op torep))
-> (Exp torep -> AllocM fromrep torep [ExpHint])
-> Pass fromrep torep
explicitAllocationsGeneric Space
space Op fromrep -> AllocM fromrep torep (Op torep)
handleOp Exp torep -> AllocM fromrep torep [ExpHint]
hints =
  String
-> String
-> (Prog fromrep -> PassM (Prog torep))
-> Pass fromrep torep
forall fromrep torep.
String
-> String
-> (Prog fromrep -> PassM (Prog torep))
-> Pass fromrep torep
Pass String
"explicit allocations" String
"Transform program to explicit memory representation" ((Prog fromrep -> PassM (Prog torep)) -> Pass fromrep torep)
-> (Prog fromrep -> PassM (Prog torep)) -> Pass fromrep torep
forall a b. (a -> b) -> a -> b
$
    (Stms fromrep -> PassM (Stms torep))
-> (Stms torep -> FunDef fromrep -> PassM (FunDef torep))
-> Prog fromrep
-> PassM (Prog torep)
forall fromrep torep.
(Stms fromrep -> PassM (Stms torep))
-> (Stms torep -> FunDef fromrep -> PassM (FunDef torep))
-> Prog fromrep
-> PassM (Prog torep)
intraproceduralTransformationWithConsts Stms fromrep -> PassM (Stms torep)
onStms Stms torep -> FunDef fromrep -> PassM (FunDef torep)
allocInFun
  where
    onStms :: Stms fromrep -> PassM (Stms torep)
onStms Stms fromrep
stms =
      Space
-> (Op fromrep -> AllocM fromrep torep (Op torep))
-> (Exp torep -> AllocM fromrep torep [ExpHint])
-> AllocM fromrep torep (Stms torep)
-> PassM (Stms torep)
forall (m :: * -> *) fromrep torep a.
MonadFreshNames m =>
Space
-> (Op fromrep -> AllocM fromrep torep (Op torep))
-> (Exp torep -> AllocM fromrep torep [ExpHint])
-> AllocM fromrep torep a
-> m a
runAllocM Space
space Op fromrep -> AllocM fromrep torep (Op torep)
handleOp Exp torep -> AllocM fromrep torep [ExpHint]
hints (AllocM fromrep torep (Stms torep) -> PassM (Stms torep))
-> AllocM fromrep torep (Stms torep) -> PassM (Stms torep)
forall a b. (a -> b) -> a -> b
$ AllocM fromrep torep ()
-> AllocM fromrep torep (Stms (Rep (AllocM fromrep torep)))
forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (AllocM fromrep torep ()
 -> AllocM fromrep torep (Stms (Rep (AllocM fromrep torep))))
-> AllocM fromrep torep ()
-> AllocM fromrep torep (Stms (Rep (AllocM fromrep torep)))
forall a b. (a -> b) -> a -> b
$ Stms fromrep -> AllocM fromrep torep () -> AllocM fromrep torep ()
forall fromrep torep (inner :: * -> *) a.
Allocable fromrep torep inner =>
Stms fromrep -> AllocM fromrep torep a -> AllocM fromrep torep a
allocInStms Stms fromrep
stms (AllocM fromrep torep () -> AllocM fromrep torep ())
-> AllocM fromrep torep () -> AllocM fromrep torep ()
forall a b. (a -> b) -> a -> b
$ () -> AllocM fromrep torep ()
forall a. a -> AllocM fromrep torep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()

    allocInFun :: Stms torep -> FunDef fromrep -> PassM (FunDef torep)
allocInFun Stms torep
consts (FunDef Maybe EntryPoint
entry Attrs
attrs Name
fname [(RetType fromrep, RetAls)]
rettype [FParam fromrep]
params Body fromrep
fbody) =
      Space
-> (Op fromrep -> AllocM fromrep torep (Op torep))
-> (Exp torep -> AllocM fromrep torep [ExpHint])
-> AllocM fromrep torep (FunDef torep)
-> PassM (FunDef torep)
forall (m :: * -> *) fromrep torep a.
MonadFreshNames m =>
Space
-> (Op fromrep -> AllocM fromrep torep (Op torep))
-> (Exp torep -> AllocM fromrep torep [ExpHint])
-> AllocM fromrep torep a
-> m a
runAllocM Space
space Op fromrep -> AllocM fromrep torep (Op torep)
handleOp Exp torep -> AllocM fromrep torep [ExpHint]
hints (AllocM fromrep torep (FunDef torep) -> PassM (FunDef torep))
-> (AllocM fromrep torep (FunDef torep)
    -> AllocM fromrep torep (FunDef torep))
-> AllocM fromrep torep (FunDef torep)
-> PassM (FunDef torep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms torep
-> AllocM fromrep torep (FunDef torep)
-> AllocM fromrep torep (FunDef torep)
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Stms torep
consts (AllocM fromrep torep (FunDef torep) -> PassM (FunDef torep))
-> AllocM fromrep torep (FunDef torep) -> PassM (FunDef torep)
forall a b. (a -> b) -> a -> b
$
        [(FParam fromrep, Space)]
-> ([Param (FParamInfo torep)]
    -> AllocM fromrep torep (FunDef torep))
-> AllocM fromrep torep (FunDef torep)
forall fromrep torep (inner :: * -> *) a.
Allocable fromrep torep inner =>
[(FParam fromrep, Space)]
-> ([FParam torep] -> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInFParams ((Param DeclType -> (Param DeclType, Space))
-> [Param DeclType] -> [(Param DeclType, Space)]
forall a b. (a -> b) -> [a] -> [b]
map (,Space
space) [Param DeclType]
[FParam fromrep]
params) (([Param (FParamInfo torep)]
  -> AllocM fromrep torep (FunDef torep))
 -> AllocM fromrep torep (FunDef torep))
-> ([Param (FParamInfo torep)]
    -> AllocM fromrep torep (FunDef torep))
-> AllocM fromrep torep (FunDef torep)
forall a b. (a -> b) -> a -> b
$ \[Param (FParamInfo torep)]
params' -> do
          (Body torep
fbody', [RetTypeMem]
mem_rets) <-
            [Maybe Space]
-> Body fromrep -> AllocM fromrep torep (Body torep, [RetTypeMem])
forall fromrep torep (inner :: * -> *).
Allocable fromrep torep inner =>
[Maybe Space]
-> Body fromrep -> AllocM fromrep torep (Body torep, [RetTypeMem])
allocInFunBody (((DeclExtType, RetAls) -> Maybe Space)
-> [(DeclExtType, RetAls)] -> [Maybe Space]
forall a b. (a -> b) -> [a] -> [b]
map (Maybe Space -> (DeclExtType, RetAls) -> Maybe Space
forall a b. a -> b -> a
const (Maybe Space -> (DeclExtType, RetAls) -> Maybe Space)
-> Maybe Space -> (DeclExtType, RetAls) -> Maybe Space
forall a b. (a -> b) -> a -> b
$ Space -> Maybe Space
forall a. a -> Maybe a
Just Space
space) [(DeclExtType, RetAls)]
[(RetType fromrep, RetAls)]
rettype) Body fromrep
fbody
          let num_extra_params :: Int
num_extra_params = [Param FParamMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param (FParamInfo torep)]
[Param FParamMem]
params' Int -> Int -> Int
forall a. Num a => a -> a -> a
- [Param DeclType] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param DeclType]
[FParam fromrep]
params
              num_extra_rets :: Int
num_extra_rets = [RetTypeMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [RetTypeMem]
mem_rets
              rettype' :: [(RetTypeMem, RetAls)]
rettype' =
                (RetTypeMem -> (RetTypeMem, RetAls))
-> [RetTypeMem] -> [(RetTypeMem, RetAls)]
forall a b. (a -> b) -> [a] -> [b]
map (,[Int] -> [Int] -> RetAls
RetAls [Int]
forall a. Monoid a => a
mempty [Int]
forall a. Monoid a => a
mempty) [RetTypeMem]
mem_rets
                  [(RetTypeMem, RetAls)]
-> [(RetTypeMem, RetAls)] -> [(RetTypeMem, RetAls)]
forall a. [a] -> [a] -> [a]
++ [RetTypeMem] -> [RetAls] -> [(RetTypeMem, RetAls)]
forall a b. [a] -> [b] -> [(a, b)]
zip
                    (Space -> Int -> [DeclExtType] -> [RetTypeMem]
memoryInDeclExtType Space
space ([RetTypeMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [RetTypeMem]
mem_rets) (((DeclExtType, RetAls) -> DeclExtType)
-> [(DeclExtType, RetAls)] -> [DeclExtType]
forall a b. (a -> b) -> [a] -> [b]
map (DeclExtType, RetAls) -> DeclExtType
forall a b. (a, b) -> a
fst [(DeclExtType, RetAls)]
[(RetType fromrep, RetAls)]
rettype))
                    (((DeclExtType, RetAls) -> RetAls)
-> [(DeclExtType, RetAls)] -> [RetAls]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> RetAls -> RetAls
shiftRetAls Int
num_extra_params Int
num_extra_rets (RetAls -> RetAls)
-> ((DeclExtType, RetAls) -> RetAls)
-> (DeclExtType, RetAls)
-> RetAls
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (DeclExtType, RetAls) -> RetAls
forall a b. (a, b) -> b
snd) [(DeclExtType, RetAls)]
[(RetType fromrep, RetAls)]
rettype)
          FunDef torep -> AllocM fromrep torep (FunDef torep)
forall a. a -> AllocM fromrep torep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (FunDef torep -> AllocM fromrep torep (FunDef torep))
-> FunDef torep -> AllocM fromrep torep (FunDef torep)
forall a b. (a -> b) -> a -> b
$ Maybe EntryPoint
-> Attrs
-> Name
-> [(RetType torep, RetAls)]
-> [Param (FParamInfo torep)]
-> Body torep
-> FunDef torep
forall rep.
Maybe EntryPoint
-> Attrs
-> Name
-> [(RetType rep, RetAls)]
-> [FParam rep]
-> Body rep
-> FunDef rep
FunDef Maybe EntryPoint
entry Attrs
attrs Name
fname [(RetType torep, RetAls)]
[(RetTypeMem, RetAls)]
rettype' [Param (FParamInfo torep)]
params' Body torep
fbody'

-- | Add explicit allocations to a sequence of statements.
--
-- The scope of the surrounding monad is fetched with 'askScope' and
-- installed (via 'localScope') inside the 'AllocM' computation, so the
-- statements are transformed with the same bindings in scope as the
-- caller has.  The result is exactly the statements produced while
-- running 'allocInStms' on the input.
--
-- Arguments, in order: the space passed on to 'runAllocM', the handler
-- used to translate operations, the function producing expression
-- hints, and the statements to transform.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep inner
  ) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric space handleOp hints stms = do
  scope <- askScope
  runAllocM space handleOp hints $
    localScope scope $
      collectStms_ $
        allocInStms stms $
          pure ()

-- | Annotate a list of declared (existential) return types with memory
-- information.
--
-- * Primitive types become 'MemPrim'.
-- * Accumulator types become 'MemAcc' with the same components.
-- * Every array type is returned in a new block ('ReturnsNewBlock') in
--   the given space.  Arrays are numbered 0, 1, 2, ... in the order
--   they occur (tracked by the 'State' counter), and each array gets a
--   row-major 'IxFun.iota' index function over its (shifted) shape.
-- * A memory type in the input is an error.
--
-- The 'Int' argument @k@ is added to every existential ('Ext') index
-- occurring in an array shape; free dimensions are left unchanged.
-- NOTE(review): presumably @k@ is the number of extra memory results
-- placed in front of the value results -- confirm against the caller.
memoryInDeclExtType :: Space -> Int -> [DeclExtType] -> [FunReturns]
memoryInDeclExtType space k dets = evalState (mapM addMem dets) 0
  where
    addMem (Prim t) = pure $ MemPrim t
    addMem Mem {} = error "memoryInDeclExtType: too much memory"
    addMem (Array pt shape u) = do
      -- Use the current counter value for this array, then bump it.
      i <- get <* modify (+ 1)
      let shape' = fmap shift shape
      pure . MemArray pt shape' u . ReturnsNewBlock space i $
        IxFun.iota $
          map convert $
            shapeDims shape'
    addMem (Acc acc ispace ts u) = pure $ MemAcc acc ispace ts u

    -- Turn a (possibly existential) dimension into an index expression.
    convert (Ext i) = le64 $ Ext i
    convert (Free v) = Free <$> pe64 v

    -- Offset existential indices by @k@; free sizes are untouched.
    shift (Ext i) = Ext (i + k)
    shift (Free x) = Free x

-- | Compute the extra memory result needed for a single body result.
--
-- Constants, primitives, accumulators and memory blocks contribute
-- nothing (the empty list).  For an array variable, the memory block
-- it resides in (according to its 'ArrayIn' annotation) is looked up
-- and returned as an additional result, paired with a 'MemMem'
-- carrying the space of that block.
--
-- Calls 'error' if the name recorded in the array's 'ArrayIn'
-- annotation is not bound to a memory block.
bodyReturnMemCtx ::
  (Allocable fromrep torep inner) =>
  SubExpRes ->
  AllocM fromrep torep [(SubExpRes, MemInfo ExtSize u MemReturn)]
bodyReturnMemCtx (SubExpRes _ Constant {}) =
  pure []
bodyReturnMemCtx (SubExpRes _ (Var v)) = do
  info <- lookupMemInfo v
  case info of
    MemPrim {} -> pure []
    MemAcc {} -> pure []
    MemMem {} -> pure [] -- should not happen
    MemArray _ _ _ (ArrayIn mem _) -> do
      mem_info <- lookupMemInfo mem
      case mem_info of
        MemMem space ->
          pure [(subExpRes $ Var mem, MemMem space)]
        _ ->
          error $
            "bodyReturnMemCtx: not a memory block: " ++ prettyString mem

-- | Transform a function body, returning the new body together with
-- the return types of the memory results that were added.
--
-- After the statements have been processed by 'allocInStms', every
-- result is passed through 'ensureDirect' with its corresponding entry
-- of the space list, and 'bodyReturnMemCtx' is used to collect the
-- memory blocks of the resulting arrays.  These memory results are
-- placed /in front of/ the ordinary results of the new body.
--
-- If the body has more results than there are entries in the space
-- list, the list is padded at the front with 'Nothing' so that the
-- given entries line up with the trailing results.
allocInFunBody ::
  (Allocable fromrep torep inner) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [FunReturns])
allocInFunBody space_oks (Body _ stms res) =
  buildBody . allocInStms stms $ do
    res' <- zipWithM ensureDirect space_oks' res
    (mem_ctx_res, mem_ctx_rets) <-
      unzip . concat <$> mapM bodyReturnMemCtx res'
    pure (mem_ctx_res <> res', mem_ctx_rets)
  where
    num_vals = length space_oks
    -- Pad with 'Nothing' for the leading results not covered by
    -- 'space_oks'.
    space_oks' = replicate (length res - num_vals) Nothing ++ space_oks

-- | Pass a result through 'ensureDirectArray' if it is an array
-- variable.
--
-- The memory information of the result is looked up with
-- 'subExpMemInfo'.  When it is a 'MemArray' bound to a variable, the
-- variable is replaced by the second component returned by
-- 'ensureDirectArray' (called with the given @Maybe Space@); any other
-- result is returned unchanged.  The certificates of the result are
-- always preserved.
ensureDirect ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  SubExpRes ->
  AllocM fromrep torep SubExpRes
ensureDirect space_ok (SubExpRes cs se) = do
  se_info <- subExpMemInfo se
  SubExpRes cs <$> case (se_info, se) of
    (MemArray {}, Var v) -> do
      (_, v') <- ensureDirectArray space_ok v
      pure $ Var v'
    _ ->
      pure se

-- | Transform a sequence of statements and then run the given action.
--
-- The statements are handled one at a time, in order.  For each one,
-- the statements generated by 'allocInStm' (run under the auxiliary
-- information of the original statement, via 'auxing') are collected
-- and added to the builder.  The names reported by 'stmConsts' for the
-- generated statements are then added to 'envConsts' for the remainder
-- of the traversal, so the final action runs in an environment that
-- has seen every statement that precedes it.
allocInStms ::
  (Allocable fromrep torep inner) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m = allocInStms' $ stmsToList origstms
  where
    allocInStms' [] = m
    allocInStms' (stm : stms) = do
      allocstms <- collectStms_ $ auxing (stmAux stm) $ allocInStm stm
      addStms allocstms
      let stms_consts = foldMap stmConsts allocstms
          f env = env {envConsts = stms_consts <> envConsts env}
      local f $ allocInStms' stms

-- | Transform a single statement and add the result to the builder.
--
-- The expression is translated with 'allocInExp', and 'allocsForStm'
-- then constructs the new statement from the identifiers of the
-- original pattern elements.  The auxiliary information of the
-- original statement is ignored here; 'allocInStms' applies it around
-- the call with 'auxing'.
allocInStm ::
  (Allocable fromrep torep inner) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pat pes) _ e) =
  addStm =<< allocsForStm (map patElemIdent pes) =<< allocInExp e

-- | Build a lambda in the target representation from already
-- translated parameters and a body in the source representation.
--
-- The statements of the body are transformed with 'allocInStms'; the
-- result of the body is returned as is.
allocInLambda ::
  (Allocable fromrep torep inner) =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params body =
  mkLambda params . allocInStms (bodyStms body) $
    pure $
      bodyResult body

-- | A requirement on the memory of an array.  Two requirements are
-- merged with 'combMemReqs'.
data MemReq
  = -- | A requirement given by a space and a rank.
    MemReq Space Rank
  | -- | What 'combMemReqs' yields when two 'MemReq' requirements
    -- differ; only the space is retained.
    NeedsNormalisation Space
  deriving (Eq, Show)

-- | Combine two memory requirements.
--
-- If either side is 'NeedsNormalisation', that side is returned (the
-- left one takes priority).  Two 'MemReq's combine to themselves when
-- they are equal, and otherwise to 'NeedsNormalisation' in the space
-- of the left operand.
combMemReqs :: MemReq -> MemReq -> MemReq
combMemReqs x@NeedsNormalisation {} _ = x
combMemReqs _ y@NeedsNormalisation {} = y
combMemReqs x@(MemReq x_space _) y@MemReq {} =
  if x == y then x else NeedsNormalisation x_space

-- | Memory information for a (possibly existentially sized) value
-- where arrays carry a 'MemReq' instead of a concrete memory block
-- and index function.
type MemReqType = MemInfo (Ext SubExp) NoUniqueness MemReq

-- | Unify the requirements of two corresponding branch results.  When
-- both are arrays, the element type, shape and uniqueness of the
-- first are kept and the 'MemReq's are merged with 'combMemReqs'; in
-- every other case the first argument is returned unchanged.
combMemReqTypes :: MemReqType -> MemReqType -> MemReqType
combMemReqTypes this other =
  case (this, other) of
    (MemArray pt shape u this_req, MemArray _ _ _ other_req) ->
      MemArray pt shape u (combMemReqs this_req other_req)
    _ -> this

-- | The context (existential) return entries needed for a single
-- result with the given requirement.  Only arrays need context; every
-- other kind of result yields the empty list.
contextRets :: MemReqType -> [MemInfo d u r]
contextRets (MemArray _ shape _ req) =
  case req of
    -- Memory + offset + base_rank + (stride,size)*rank.
    MemReq space (Rank base_rank) ->
      MemMem space : offset : (scalars base_rank ++ scalars (2 * rank))
    -- Memory + offset + (base,stride,size)*rank.
    NeedsNormalisation space ->
      MemMem space : offset : scalars (3 * rank)
  where
    rank = shapeRank shape
    offset = MemPrim int64
    scalars n = replicate n (MemPrim int64)
contextRets _ = []

-- Add memory information to the body, but do not return memory/ixfun
-- information.  Instead, return restrictions on what the index
-- function should look like.  We will then (crudely) unify these
-- restrictions across all bodies.
allocInMatchBody ::
  (Allocable fromrep torep inner) =>
  [ExtType] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody rets (Body _ stms res) =
  buildBody $
    allocInStms stms $
      -- The result itself is passed through untouched; we only
      -- record one restriction per (return type, result) pair.
      (,) res <$> zipWithM restriction rets (map resSubExp res)
  where
    -- Compute the memory requirement for one result, based on the
    -- expected return type and the memory info of the result.
    restriction t se = do
      v_info <- subExpMemInfo se
      case (t, v_info) of
        (Array pt shape u, MemArray _ _ _ (ArrayIn mem ixfun)) -> do
          space <- lookupMemSpace mem
          -- The requirement records the rank of the index function's
          -- base, not the rank of the array itself.
          let base_rank = Rank . length $ IxFun.base ixfun
          pure $ MemArray pt shape u (MemReq space base_rank)
        (_, MemMem space) ->
          pure $ MemMem space
        (_, MemPrim pt) ->
          pure $ MemPrim pt
        (_, MemAcc acc ispace ts u) ->
          pure $ MemAcc acc ispace ts u
        _ ->
          error $ "allocInMatchBody: mismatch: " ++ show (t, v_info)

-- | Turn the unified requirements of the results of a @Match@ into
-- its branch return types.  The produced list consists of the context
-- entries of all results (see 'contextRets'), in order, followed by
-- one entry per result.
mkBranchRet :: [MemReqType] -> [BranchTypeMem]
mkBranchRet reqs =
  concatMap contextRets reqs ++ zipWith inspect offsets reqs
  where
    -- Number of context entries contributed by each result.
    ctx_counts = map (length . contextRets) reqs

    -- Position in the context at which the entries of each result
    -- start.  This list has one more element than 'reqs' (the final
    -- running total), which 'zipWith' simply ignores.
    offsets = scanl (+) 0 ctx_counts

    -- Total number of context entries that are put in front of the
    -- results.
    num_new_ctx = sum ctx_counts

    -- Space and base rank of an array result.  A result that needs
    -- normalisation uses the rank of the array itself as base rank.
    arrayInfo rank (NeedsNormalisation space) = (space, rank)
    arrayInfo _ (MemReq space (Rank base_rank)) = (space, base_rank)

    -- The return type of a single result whose context entries start
    -- at 'ctx_offset'.
    inspect ctx_offset (MemArray pt shape u req) =
      let rank = shapeRank shape
          (space, base_rank) = arrayInfo rank req
          -- The memory block sits at 'ctx_offset'; the existentials
          -- of the index function start right after it.
          ixfun = convert <$> IxFun.mkExistential base_rank rank (ctx_offset + 1)
          -- Existential sizes in the shape must be shifted past the
          -- newly added context.
          shape' = fmap (adjustExt num_new_ctx) shape
       in MemArray pt shape' u $ ReturnsNewBlock space ctx_offset ixfun
    inspect _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    inspect _ (MemPrim pt) = MemPrim pt
    inspect _ (MemMem space) = MemMem space

    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

    adjustExt :: Int -> Ext a -> Ext a
    adjustExt _ (Free v) = Free v
    adjustExt k (Ext i) = Ext (k + i)

addCtxToMatchBody ::
  (Allocable fromrep torep inner) =>
  [MemReqType] ->
  Body torep ->
  AllocM fromrep torep (Body torep)
addCtxToMatchBody reqs body = buildBody_ $ do
  body_res <- bodyBind body
  res <- zipWithM normaliseIfNeeded reqs body_res
  ctx <- concat <$> mapM resCtx res
  -- Context first, then the (possibly normalised) results.
  pure $ ctx ++ res
  where
    -- Array variables whose requirement is 'NeedsNormalisation' are
    -- passed through 'ensurePermArray' with the identity permutation
    -- in the required space; everything else is left alone.
    normaliseIfNeeded (MemArray _ shape _ (NeedsNormalisation space)) (SubExpRes cs (Var v)) = do
      (_, v') <- ensurePermArray (Just space) [0 .. shapeRank shape - 1] v
      pure $ SubExpRes cs (Var v')
    normaliseIfNeeded _ res =
      pure res

    -- The context values for one result: for an array, its memory
    -- block followed by the components of its index function, each
    -- bound to a name; nothing for any other result.
    resCtx (SubExpRes _ Constant {}) =
      pure []
    resCtx (SubExpRes _ (Var v)) = do
      info <- lookupMemInfo v
      case info of
        MemPrim {} -> pure []
        MemAcc {} -> pure []
        MemMem {} -> pure [] -- should not happen
        MemArray _ _ _ (ArrayIn mem ixfun) -> do
          ixfun_exts <- mapM (toExp >=> letSubExp "ixfun_ext") (toList ixfun)
          pure $ subExpRes (Var mem) : subExpsRes ixfun_exts

-- Do a simple form of invariance analysis to simplify a Match.  It
-- is unfortunate that we have to do it here, but functions such as
-- scalarRes will look carefully at the index functions before the
-- simplifier has a chance to run.  In a perfect world we would
-- simplify away those copies afterwards. XXX; this should be fixed by
-- a more general copy-removal pass. See
-- Futhark.Optimise.EntryPointMem for a very specialised version of
-- the idea, but which could perhaps be generalised.
simplifyMatch ::
  (Mem rep inner) =>
  [Case (Body rep)] ->
  Body rep ->
  [BranchTypeMem] ->
  ( [Case (Body rep)],
    Body rep,
    [BranchTypeMem]
  )
simplifyMatch cases defbody ts =
  ( zipWith onCase cases (transpose variant_case_reses),
    onBody defbody variant_def_reses,
    -- Invariant results are substituted for the corresponding
    -- existentials in the remaining return types.
    foldr (uncurry fixExt) variant_ts ctx_fixes
  )
  where
    -- Classify every result position (across all branches) as either
    -- invariant (Left) or variant (Right).
    (ctx_fixes, variant) =
      partitionEithers $
        map branchInvariant $
          zip4
            [0 ..]
            (transpose (map (bodyResult . caseBody) cases))
            (bodyResult defbody)
            ts

    -- Only the variant positions are kept as results.
    (variant_case_reses, variant_def_reses, variant_ts) = unzip3 variant

    -- All names bound by statements in any of the branches.
    bound_in_branches =
      namesFromList . concatMap (patNames . stmPat) $
        foldMap (bodyStms . caseBody) cases <> bodyStms defbody

    onCase c res = fmap (\body -> onBody body res) c
    onBody body res = body {bodyResult = res}

    branchInvariant (i, case_reses, defres, t)
      | invariant = Left (i, resSubExp defres)
      | otherwise = Right (case_reses, defres, t)
      where
        -- If even one branch has a result that depends on something
        -- bound inside a branch, then we give up.
        bound_inside =
          namesIntersect bound_in_branches $ freeIn $ defres : case_reses
        -- Do all branches return the same value?
        all_same = all ((== resSubExp defres) . resSubExp) case_reses
        invariant = not bound_inside && all_same

-- | Translate an expression from the source representation into the
-- memory-annotated target representation.  Loops, function calls,
-- matches and accumulator expressions are handled by dedicated
-- equations; everything else falls through to the final, generic
-- equation.
allocInExp ::
  (Allocable fromrep torep inner) =>
  Exp fromrep ->
  AllocM fromrep torep (Exp torep)
allocInExp (Loop merge form (Body () bodystms bodyres)) =
  allocInMergeParams merge $ \merge' mk_loop_val -> do
    form' <- allocInLoopForm form
    -- Whatever the translated loop form binds must be in scope while
    -- the loop body is being translated.
    localScope (scopeOf form') $ do
      body' <-
        buildBody_ . allocInStms bodystms $ do
          -- 'mk_loop_val' is handed to us by 'allocInMergeParams'; it
          -- turns the original result values into extra leading
          -- results ('valctx') plus the adjusted values ('valres').
          (valctx, valres') <- mk_loop_val $ map resSubExp bodyres
          -- The adjusted values keep the certificates of the original
          -- results they were derived from.
          pure $
            subExpsRes valctx
              <> zipWith SubExpRes (map resCerts bodyres) valres'
      pure $ Loop merge' form' body'
-- Function application: the argument list is rewritten by
-- 'funcallArgs', and the return type list is extended with one memory
-- block per array-typed return value.
allocInExp (Apply fname args rettype loc) = do
  args' <- funcallArgs args
  space <- askDefaultSpace
  -- We assume that every array is going to be in its own memory.
  let -- How many arguments 'funcallArgs' added to the call.
      num_extra_args = length args' - length args
      -- The new memory returns come first, followed by the original
      -- returns converted to their memory-annotated form.  The alias
      -- information of the original returns is adjusted with
      -- 'shiftRetAls' for the extra arguments and extra returns.
      rettype' =
        mems space
          ++ zip
            (memoryInDeclExtType space num_arrays (map fst rettype))
            (map (shiftRetAls num_extra_args num_arrays . snd) rettype)
  pure $ Apply fname args' rettype' loc
  where
    -- One memory return (with empty alias sets) per returned array.
    mems space = replicate num_arrays (MemMem space, RetAls mempty mempty)
    -- Number of return values that have nonzero rank, i.e. are arrays.
    num_arrays =
      length $ filter ((> 0) . arrayRank . declExtTypeOf . fst) rettype
-- Match: every branch is translated on its own, after which the
-- per-result memory requirements of all branches are merged so that
-- the branches agree on a common return type.
allocInExp (Match ses cases defbody (MatchDec rets ifsort)) = do
  (defbody', def_reqs) <- allocInMatchBody rets defbody
  (cases', cases_reqs) <- mapAndUnzipM onCase cases
  -- 'cases_reqs' holds one list per case; transposing it gives one
  -- list per result position, which is folded into the requirement of
  -- the default body with 'combMemReqTypes'.  The strict fold avoids
  -- building a chain of thunks.
  let reqs = zipWith (foldl' combMemReqTypes) def_reqs (transpose cases_reqs)
  -- Now that the combined requirements are known, extend every branch
  -- accordingly.
  defbody'' <- addCtxToMatchBody reqs defbody'
  cases'' <- mapM (traverse $ addCtxToMatchBody reqs) cases'
  let (cases''', defbody''', rets') =
        simplifyMatch cases'' defbody'' $ mkBranchRet reqs
  pure $ Match ses cases''' defbody''' $ MatchDec rets' ifsort
  where
    -- Translate the body of one case, keeping its patterns unchanged.
    onCase (Case vs body) = first (Case vs) <$> allocInMatchBody rets body
-- WithAcc: translate the operator lambdas of every accumulator input
-- as well as the body lambda.
allocInExp (WithAcc inputs bodylam) =
  WithAcc <$> mapM onInput inputs <*> onLambda bodylam
  where
    -- The body lambda may only take unit-typed and accumulator
    -- parameters; anything else is reported as an internal error.
    onLambda lam = do
      params <- forM (lambdaParams lam) $ \(Param attrs pv t) ->
        case t of
          Prim Unit -> pure $ Param attrs pv $ MemPrim Unit
          Acc acc ispace ts u -> pure $ Param attrs pv $ MemAcc acc ispace ts u
          _ -> error $ "Unexpected WithAcc lambda param: " ++ prettyString (Param attrs pv t)
      allocInLambda params (lambdaBody lam)

    -- Shape and arrays are passed through unchanged; only the
    -- (optional) combining operator is translated.
    onInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (onOp shape arrs) op

    -- Translate a combining operator.  Its parameters are split into
    -- 'shapeRank accshape' index parameters, then as many "x"
    -- parameters as the lambda has return values, then the remaining
    -- "y" parameters.
    onOp accshape arrs (lam, nes) = do
      let num_vs = length (lambdaReturnType lam)
          num_is = shapeRank accshape
          (i_params, x_params, y_params) =
            splitAt3 num_is num_vs $ lambdaParams lam
          -- Index parameters are always given type i64.
          i_params' = map (\(Param attrs v _) -> Param attrs v $ MemPrim int64) i_params
          -- The index parameters, as fixed indices into the arrays.
          is = map (DimFix . Var . paramName) i_params'
      x_params' <- zipWithM (onXParam is) x_params arrs
      y_params' <- zipWithM (onYParam is) y_params arrs
      lam' <-
        allocInLambda
          (i_params' <> x_params' <> y_params')
          (lambdaBody lam)
      pure (lam', nes)

    -- Build an array parameter located in 'mem', whose index function
    -- is 'ixfun' sliced by the fixed indices 'is' followed by a full
    -- slice of every dimension of 'shape'.
    mkP attrs p pt shape u mem ixfun is =
      Param attrs p . MemArray pt shape u . ArrayIn mem . IxFun.slice ixfun $
        fmap pe64 $
          Slice $
            is ++ map sliceDim (shapeDims shape)

    -- An array "x" parameter reuses the memory block and index
    -- function looked up for the corresponding array.
    onXParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p (MemPrim t)
    onXParam is (Param attrs p (Array pt shape u)) arr = do
      (mem, ixfun) <- lookupArraySummary arr
      pure $ mkP attrs p pt shape u mem ixfun is
    onXParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ prettyString p

    -- An array "y" parameter is placed in freshly allocated memory,
    -- sized after the type of the corresponding array and using a
    -- plain iota index function over that array's dimensions.
    onYParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p $ MemPrim t
    onYParam is (Param attrs p (Array pt shape u)) arr = do
      arr_t <- lookupType arr
      space <- askDefaultSpace
      mem <- allocForArray arr_t space
      let base_dims = map pe64 $ arrayDims arr_t
          ixfun = IxFun.iota base_dims
      pure $ mkP attrs p pt shape u mem ixfun is
    onYParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ prettyString p
allocInExp Exp fromrep
e = Mapper fromrep torep (AllocM fromrep torep)
-> Exp fromrep -> AllocM fromrep torep (Exp torep)
forall (m :: * -> *) frep trep.
Monad m =>
Mapper frep trep m -> Exp frep -> m (Exp trep)
mapExpM Mapper fromrep torep (AllocM fromrep torep)
forall {fromrep} {trep}. Mapper fromrep trep (AllocM fromrep trep)
alloc Exp fromrep
e
  where
    alloc :: Mapper fromrep trep (AllocM fromrep trep)
alloc =
      Mapper Any Any (AllocM fromrep trep)
forall rep (m :: * -> *). Monad m => Mapper rep rep m
identityMapper
        { mapOnBody :: Scope trep -> Body fromrep -> AllocM fromrep trep (Body trep)
mapOnBody = String
-> Scope trep -> Body fromrep -> AllocM fromrep trep (Body trep)
forall a. HasCallStack => String -> a
error String
"Unhandled Body in ExplicitAllocations",
          mapOnRetType :: RetType fromrep -> AllocM fromrep trep (RetType trep)
mapOnRetType = String -> RetType fromrep -> AllocM fromrep trep (RetType trep)
forall a. HasCallStack => String -> a
error String
"Unhandled RetType in ExplicitAllocations",
          mapOnBranchType :: BranchType fromrep -> AllocM fromrep trep (BranchType trep)
mapOnBranchType = String
-> BranchType fromrep -> AllocM fromrep trep (BranchType trep)
forall a. HasCallStack => String -> a
error String
"Unhandled BranchType in ExplicitAllocations",
          mapOnFParam :: FParam fromrep -> AllocM fromrep trep (FParam trep)
mapOnFParam = String -> FParam fromrep -> AllocM fromrep trep (FParam trep)
forall a. HasCallStack => String -> a
error String
"Unhandled FParam in ExplicitAllocations",
          mapOnLParam :: LParam fromrep -> AllocM fromrep trep (LParam trep)
mapOnLParam = String -> LParam fromrep -> AllocM fromrep trep (LParam trep)
forall a. HasCallStack => String -> a
error String
"Unhandled LParam in ExplicitAllocations",
          mapOnOp :: Op fromrep -> AllocM fromrep trep (Op trep)
mapOnOp = \Op fromrep
op -> do
            Op fromrep -> AllocM fromrep trep (Op trep)
handle <- (AllocEnv fromrep trep
 -> Op fromrep -> AllocM fromrep trep (Op trep))
-> AllocM
     fromrep trep (Op fromrep -> AllocM fromrep trep (Op trep))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AllocEnv fromrep trep
-> Op fromrep -> AllocM fromrep trep (Op trep)
forall fromrep torep.
AllocEnv fromrep torep
-> Op fromrep -> AllocM fromrep torep (Op torep)
allocInOp
            Op fromrep -> AllocM fromrep trep (Op trep)
handle Op fromrep
op
        }

-- | Translate a loop form into the memory-annotated representation.
--
-- A 'WhileLoop' carries no parameters and is rebuilt unchanged.  For a
-- 'ForLoop', every loop-variable parameter (a parameter paired with the
-- array it iterates over) receives a memory annotation derived from its
-- type:
--
-- * array parameters live in the memory block of the iterated array,
--   with that array's index function sliced at the loop counter;
--
-- * primitive, memory and accumulator parameters are annotated directly
--   from their type.
allocInLoopForm ::
  (Allocable fromrep torep inner) =>
  LoopForm fromrep ->
  AllocM fromrep torep (LoopForm torep)
allocInLoopForm (WhileLoop v) = pure $ WhileLoop v
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> mapM allocInLoopVar loopvars
  where
    allocInLoopVar (p, a) = do
      -- The summary of the iterated array is looked up unconditionally,
      -- before inspecting the parameter type.
      (mem, ixfun) <- lookupArraySummary a
      dec <- case paramType p of
        Array pt shape u -> do
          dims <- map pe64 . arrayDims <$> lookupType a
          -- Fix the loop counter as the first index and keep the
          -- remaining dimensions of the iterated array in full.
          let ixfun' =
                IxFun.slice ixfun $
                  fullSliceNum dims [DimFix $ le64 i]
          pure $ MemArray pt shape u $ ArrayIn mem ixfun'
        Prim bt ->
          pure $ MemPrim bt
        Mem space ->
          pure $ MemMem space
        Acc acc ispace ts u ->
          pure $ MemAcc acc ispace ts u
      pure (p {paramDec = dec}, a)

-- | Operations that can report whether they are constants.
class SizeSubst op where
  -- | Whether the operation counts as a constant.  The default answers
  -- 'False' for every operation.
  opIsConst :: op -> Bool
  opIsConst = const False

-- | 'NoOp' defines no methods and therefore uses the class default for
-- 'opIsConst'.
instance SizeSubst (NoOp rep)

-- | A 'MemOp' is constant exactly when it wraps an inner operation that
-- is itself constant; every other 'MemOp' constructor is not.
instance (SizeSubst (op rep)) => SizeSubst (MemOp op rep) where
  opIsConst (Inner op) = opIsConst op
  opIsConst _ = False

-- | The names bound by a statement whose expression is an 'Op' for
-- which 'opIsConst' holds.  Any other statement contributes the empty
-- set.
stmConsts :: (SizeSubst (Op rep)) => Stm rep -> S.Set VName
stmConsts (Let pat _ (Op op))
  | opIsConst op = S.fromList $ patNames pat
stmConsts _ = mempty

-- | Build a statement binding @names@ to the expression @e@, with a
-- pattern produced by 'patWithAllocations' for the given memory space.
--
-- No layout preference is expressed: one 'NoHint' is supplied per bound
-- name.  The statement gets default auxiliary information wrapped
-- around the supplied expression decoration.
mkLetNamesB' ::
  ( LetDec (Rep m) ~ LetDecMem,
    Mem (Rep m) inner,
    MonadBuilder m,
    ExpDec (Rep m) ~ ()
  ) =>
  Space ->
  ExpDec (Rep m) ->
  [VName] ->
  Exp (Rep m) ->
  m (Stm (Rep m))
mkLetNamesB' space dec names e = do
  pat <- patWithAllocations space names e nohints
  pure $ Let pat (defAux dec) e
  where
    -- One hint per bound name, all of them 'NoHint'.
    nohints = map (const NoHint) names

-- | Build a statement binding @names@ to @e@ in the 'Engine.Wise'
-- variant of a memory representation.
--
-- The pattern is first produced by 'patWithAllocations' (one 'NoHint'
-- per name), then extended with 'Engine.addWisdomToPat'; the expression
-- decoration is computed with 'Engine.mkWiseExpDec' from that extended
-- pattern.
mkLetNamesB'' ::
  ( Mem rep inner,
    LetDec rep ~ LetDecMem,
    OpReturns (inner (Engine.Wise rep)),
    ExpDec rep ~ (),
    Rep m ~ Engine.Wise rep,
    HasScope (Engine.Wise rep) m,
    MonadBuilder m,
    AliasedOp (inner (Engine.Wise rep)),
    RephraseOp (MemOp inner),
    Engine.CanBeWise inner
  ) =>
  Space ->
  [VName] ->
  Exp (Engine.Wise rep) ->
  m (Stm (Engine.Wise rep))
mkLetNamesB'' space names e = do
  pat <- patWithAllocations space names e nohints
  let pat' = Engine.addWisdomToPat pat e
      dec = Engine.mkWiseExpDec pat' () e
  pure $ Let pat' (defAux dec) e
  where
    -- One hint per bound name, all of them 'NoHint'.
    nohints = map (const NoHint) names

-- | Simplify a 'MemOp', returning the simplified operation together
-- with any statements produced alongside it.
--
-- * 'Alloc': the size operand is simplified, the space is kept as is,
--   and no statements are returned.
--
-- * 'Inner': the supplied simplifier handles the wrapped operation; its
--   result is re-wrapped in 'Inner' and its statements are passed on.
simplifyMemOp ::
  (Engine.SimplifiableRep rep) =>
  ( inner (Engine.Wise rep) ->
    Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep))
  ) ->
  MemOp inner (Engine.Wise rep) ->
  Engine.SimpleM rep (MemOp inner (Engine.Wise rep), Stms (Engine.Wise rep))
simplifyMemOp _ (Alloc size space) =
  (,) <$> (Alloc <$> Engine.simplify size <*> pure space) <*> pure mempty
simplifyMemOp onInner (Inner k) = do
  (k', hoisted) <- onInner k
  pure (Inner k', hoisted)

-- | Construct the 'SimpleOps' record for a memory representation.
--
-- The first argument computes the usage table of an inner operation;
-- the second simplifies an inner operation.  Both are lifted to
-- 'MemOp' here: 'Alloc' is handled locally and 'Inner' is delegated.
simplifiable ::
  ( Engine.SimplifiableRep rep,
    LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    Mem (Engine.Wise rep) inner,
    Engine.CanBeWise inner,
    RephraseOp inner,
    IsOp (inner rep),
    OpReturns (inner (Engine.Wise rep)),
    AliasedOp (inner (Engine.Wise rep)),
    IndexOp (inner (Engine.Wise rep))
  ) =>
  (inner (Engine.Wise rep) -> UT.UsageTable) ->
  ( inner (Engine.Wise rep) ->
    Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep))
  ) ->
  SimpleOps rep
simplifiable innerUsage simplifyInnerOp =
  SimpleOps mkExpDecS' mkBodyS' protectOp opUsage (simplifyMemOp simplifyInnerOp)
  where
    -- Expression decoration: the underlying decoration is unit; the
    -- symbol table argument is ignored.
    mkExpDecS' _ pat e =
      pure $ Engine.mkWiseExpDec pat () e

    -- Body construction: the underlying body decoration is unit; the
    -- symbol table argument is ignored.
    mkBodyS' _ stms res = pure $ mkWiseBody () stms res

    -- Only allocations are given a protected form.  The allocation is
    -- rebound with a size that equals the original size when @taken@ is
    -- true and 0 otherwise.
    protectOp taken pat (Alloc size space) = Just $ do
      tbody <- resultBodyM [size]
      fbody <- resultBodyM [intConst Int64 0]
      size' <-
        letSubExp "hoisted_alloc_size" $
          Match [taken] [Case [Just $ BoolValue True] tbody] fbody $
            MatchDec [MemPrim int64] MatchFallback
      letBind pat $ Op $ Alloc size' space
    protectOp _ _ _ = Nothing

    -- A variable allocation size is recorded as a size usage; any other
    -- allocation size contributes nothing.  Inner operations defer to
    -- the caller-supplied usage function.
    opUsage (Alloc (Var size) _) =
      UT.sizeUsage size
    opUsage (Alloc _ _) =
      mempty
    opUsage (Inner inner) =
      innerUsage inner

-- | A hint attached to one result of an expression.  Lists of hints
-- are handed to 'patWithAllocations', one entry per bound name.
--
-- 'NoHint' expresses no preference.  'Hint' carries an index function
-- and a memory space.
-- NOTE(review): presumably the layout and space the result should be
-- given; the code that interprets hints is outside this section, so
-- confirm there.
data ExpHint
  = NoHint
  | Hint IxFun Space

-- | The default hints for an expression: one 'NoHint' for each entry in
-- the expression's extended result type.
defaultExpHints :: (ASTRep rep, HasScope rep m) => Exp rep -> m [ExpHint]
defaultExpHints e = map (const NoHint) <$> expExtType e