{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program.  Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    Allocable,
    Allocator (..),
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    arraySizeInBytesExp,
    mkLetNamesB',
    mkLetNamesB'',

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (foldl', partition, sort, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (splitAt3, splitFromEnd, takeLast)

-- | A statement that the allocation machinery needs to emit alongside
-- the translated program.
data AllocStm
  = -- | Bind the name to the given size expression, coerced to a
    -- 64-bit integer.
    SizeComputation VName (PrimExp VName)
  | -- | Bind the name to a fresh allocation of the given size in the
    -- given memory space.
    Allocation VName SubExp Space
  | -- | Bind the first name to a copy of the array named by the second.
    ArrayCopy VName VName
  deriving (Eq, Ord, Show)

-- | Emit the statement corresponding to an 'AllocStm' in any binder
-- monad whose operations are 'MemOp'.  Allocations are built directly
-- with the 'Alloc' constructor.
bindAllocStm ::
  (MonadBinder m, Op (Lore m) ~ MemOp inner) =>
  AllocStm ->
  m ()
bindAllocStm stm =
  case stm of
    SizeComputation name pe ->
      toExp (coerceIntPrimExp Int64 pe) >>= letBindNames [name]
    Allocation name size space ->
      letBindNames [name] (Op (Alloc size space))
    ArrayCopy name src ->
      letBindNames [name] (BasicOp (Copy src))

-- | Monads in which allocation-related statements ('AllocStm') can be
-- emitted while producing code in the memory-annotated lore @lore@.
class
  (MonadFreshNames m, LocalScope lore m, Mem lore) =>
  Allocator lore m
  where
  -- | Emit an allocation-related statement.
  addAllocStm :: AllocStm -> m ()

  -- | The memory space in which allocations should be placed.
  askDefaultSpace :: m Space

  default addAllocStm ::
    ( Allocable fromlore lore,
      m ~ AllocM fromlore lore
    ) =>
    AllocStm ->
    m ()
  addAllocStm stm =
    case stm of
      SizeComputation name se ->
        toExp (coerceIntPrimExp Int64 se) >>= letBindNames [name]
      Allocation name size space ->
        letBindNames [name] (Op (allocOp size space))
      ArrayCopy name src ->
        letBindNames [name] (BasicOp (Copy src))

  -- | The subexpression giving the number of elements we should
  -- allocate space for.  See 'ChunkMap' comment.
  dimAllocationSize :: SubExp -> m SubExp
  default dimAllocationSize ::
    m ~ AllocM fromlore lore =>
    SubExp ->
    m SubExp
  dimAllocationSize size =
    case size of
      Var v -> do
        -- It is important to recurse here, as the substitution may
        -- itself be a chunk size.
        subst <- asks (M.lookup v . chunkMap)
        maybe (pure size) dimAllocationSize subst
      _ -> pure size

  -- | Get those names that are known to be constants at run-time.
  askConsts :: m (S.Set VName)

  -- | Produce 'ExpHint's for the results of an expression.  Defaults
  -- to 'defaultExpHints'.
  expHints :: Exp lore -> m [ExpHint]
  expHints = defaultExpHints

-- | Emit an allocation of the given size in the given space, returning
-- the fresh name (derived from the description) bound to it.
allocateMemory ::
  Allocator lore m =>
  String ->
  SubExp ->
  Space ->
  m VName
allocateMemory desc size space = do
  mem <- newVName desc
  addAllocStm (Allocation mem size space)
  pure mem

-- | Bind a size expression to a fresh name (derived from the
-- description) and return it as a 'SubExp'.
computeSize ::
  Allocator lore m =>
  String ->
  PrimExp VName ->
  m SubExp
computeSize desc se = do
  sizeName <- newVName desc
  addAllocStm (SizeComputation sizeName se)
  pure (Var sizeName)

-- | The constraints this module places on a source lore @fromlore@ and
-- the memory-annotated target lore @tolore@ it is translated into.
type Allocable fromlore tolore =
  ( -- Source lore: plain type annotations, no body decoration.
    PrettyLore fromlore,
    FParamInfo fromlore ~ DeclType,
    LParamInfo fromlore ~ Type,
    BranchType fromlore ~ ExtType,
    RetType fromlore ~ DeclExtType,
    BodyDec fromlore ~ (),
    -- Target lore: a memory lore with trivial body/expression
    -- decorations.
    PrettyLore tolore,
    Mem tolore,
    BodyDec tolore ~ (),
    ExpDec tolore ~ (),
    SizeSubst (Op tolore),
    BinderOps tolore
  )

-- | A mapping from chunk names to their maximum size.  The default
-- 'dimAllocationSize' replaces a variable found in this map by its
-- mapped size, recursively.
--
-- XXX FIXME HACK: This is part of a hack to add loop-invariant
-- allocations to reduce kernels, because memory expansion does not
-- use range analysis yet (it should).
type ChunkMap = M.Map VName SubExp

-- | Read-only environment of the 'AllocM' monad.
data AllocEnv fromlore tolore = AllocEnv
  { chunkMap :: ChunkMap
  -- ^ Sizes to substitute for chunk names when computing allocation
  -- sizes.  See the 'ChunkMap' comment.
  , aggressiveReuse :: Bool
  -- ^ Aggressively try to reuse memory in do-loops -
  -- should be True inside kernels, False outside.
  , allocSpace :: Space
  -- ^ When allocating memory, put it in this memory space.
  -- This is primarily used to ensure that group-wide
  -- statements store their results in local memory.
  , envConsts :: S.Set VName
  -- ^ The set of names that are known to be constants at
  -- kernel compile time.
  , allocInOp :: Op fromlore -> AllocM fromlore tolore (Op tolore)
  -- ^ Translation of source-lore operations into the target lore.
  , envExpHints :: Exp tolore -> AllocM fromlore tolore [ExpHint]
  -- ^ The 'expHints' implementation used by 'AllocM'.
  }

-- | Monad for adding allocations to an entire program: a binder over
-- the target lore with read access to an 'AllocEnv' and a supply of
-- fresh names.
newtype AllocM fromlore tolore a
  = AllocM
      ( BinderT
          tolore
          (ReaderT (AllocEnv fromlore tolore) (State VNameSource))
          a
      )
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadFreshNames,
      HasScope tolore,
      LocalScope tolore,
      MonadReader (AllocEnv fromlore tolore)
    )

-- | 'AllocM' builds statements in the target lore; patterns built via
-- 'mkLetNamesM' come from 'patternWithAllocations'.
instance
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  MonadBinder (AllocM fromlore tolore)
  where
  type Lore (AllocM fromlore tolore) = tolore

  -- 'ExpDec tolore' is () (required by 'Allocable').
  mkExpDecM _ _ = pure ()

  mkLetNamesM names e =
    patternWithAllocations names e >>= \pat ->
      pure (Let pat (defAux ()) e)

  mkBodyM stms res = pure (Body () stms res)

  addStms stms = AllocM (addStms stms)

  collectStms (AllocM m) = AllocM (collectStms m)

-- | 'AllocM' reads its hints, default space and constants from the
-- 'AllocEnv'; 'addAllocStm' and 'dimAllocationSize' use the class
-- defaults.
instance
  (Allocable fromlore tolore) =>
  Allocator tolore (AllocM fromlore tolore)
  where
  expHints e = asks envExpHints >>= \hint -> hint e

  askDefaultSpace = asks allocSpace

  askConsts = asks envConsts

-- | Run an 'AllocM' computation in an empty scope, threading the name
-- source of the surrounding monad.  The environment starts with no
-- chunk substitutions, no aggressive reuse, the default memory space
-- and no known constants; the statements produced are discarded.
runAllocM ::
  MonadFreshNames m =>
  (Op fromlore -> AllocM fromlore tolore (Op tolore)) ->
  (Exp tolore -> AllocM fromlore tolore [ExpHint]) ->
  AllocM fromlore tolore a ->
  m a
runAllocM handleOp hints (AllocM m) =
  fst <$> modifyNameSource (runState (runReaderT (runBinderT m mempty) env))
  where
    env =
      AllocEnv
        { chunkMap = mempty
        , aggressiveReuse = False
        , allocSpace = DefaultSpace
        , envConsts = mempty
        , allocInOp = handleOp
        , envExpHints = hints
        }

-- | Monad for adding allocations to a single pattern: reads a scope,
-- accumulates the 'AllocStm's it needs, and threads a name source.
newtype PatAllocM lore a
  = PatAllocM (RWS (Scope lore) [AllocStm] VNameSource a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope lore,
      LocalScope lore,
      MonadWriter [AllocStm],
      MonadFreshNames
    )

-- | Allocator used while building patterns: allocation statements are
-- collected in the writer output of 'PatAllocM'.
instance Mem lore => Allocator lore (PatAllocM lore) where
  -- Emit the statement as a singleton chunk of the writer log.
  addAllocStm stm = tell [stm]

  -- Dimension sizes are passed through unchanged.
  dimAllocationSize = pure

  askDefaultSpace = pure DefaultSpace

  -- This allocator reports an empty set of constants.
  askConsts = pure mempty

-- | Run a 'PatAllocM' computation under the given scope, threading the
-- name source of the surrounding monad through it.  The result is
-- paired with the allocation statements the computation emitted.
runPatAllocM ::
  MonadFreshNames m =>
  PatAllocM lore a ->
  Scope lore ->
  m (a, [AllocStm])
runPatAllocM (PatAllocM m) mems =
  modifyNameSource $ \src ->
    -- 'runRWS' yields (result, state, log); reshape it into the
    -- ((result, log), state) form that 'modifyNameSource' expects.
    case runRWS m mems src of
      (x, src', stms) -> ((x, stms), src')

-- | Byte size of the element type of the given type.
elemSize :: Num a => Type -> a
elemSize t = primByteSize (elemType t)

-- | Symbolic byte size of a value of the given type: the element size
-- multiplied, left to right, by each of the array dimensions.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t = untyped total
  where
    dims = map pe64 (arrayDims t)
    -- The element size is the fold seed, so the expression is built
    -- as ((elem * d0) * d1) * ...
    total = foldl' (*) (elemSize t) dims

-- | Like 'arraySizeInBytesExp', but every dimension is first passed
-- through the allocator's 'dimAllocationSize'.  The result is the
-- product of the (adjusted) dimensions times the element size.
arraySizeInBytesExpM :: Allocator lore m => Type -> m (PrimExp VName)
arraySizeInBytesExpM t = do
  alloc_dims <- traverse dimAllocationSize (arrayDims t)
  let bytes = product (map pe64 alloc_dims) * elemSize t
  pure (untyped bytes)

-- | The byte size of a value of the given type as a 'SubExp'.  The
-- symbolic size from 'arraySizeInBytesExpM' is passed to 'computeSize'
-- with the description @"bytes"@.
arraySizeInBytes :: Allocator lore m => Type -> m SubExp
arraySizeInBytes t = computeSize "bytes" =<< arraySizeInBytesExpM t

-- | Allocate memory for a value of the given type.
--
-- The requested size is the value's byte size (see
-- 'arraySizeInBytes'), and the block is requested in the given space
-- through 'allocateMemory'.  Returns the name of the memory block.
allocForArray ::
  Allocator lore m =>
  Type ->
  Space ->
  m VName
allocForArray t space =
  arraySizeInBytes t >>= \bytes -> allocateMemory "mem" bytes space

-- | Build a 'Let' statement binding the expression.  The pattern's
-- context and value elements, together with their memory annotations,
-- are computed by 'allocsForPattern' from the expression's returns and
-- hints.
allocsForStm ::
  (Allocator lore m, ExpDec lore ~ ()) =>
  [Ident] ->
  [Ident] ->
  Exp lore ->
  m (Stm lore)
allocsForStm sizeidents validents e = do
  returns <- expReturns e
  exp_hints <- expHints e
  (ctx, vals) <- allocsForPattern sizeidents validents returns exp_hints
  let pat = Pattern ctx vals
  pure $ Let pat (defAux ()) e

-- | Construct a memory-annotated pattern binding the given names to the
-- result of the expression.  Existential dimensions of the expression's
-- type are instantiated with the size idents from 'instantiateShapes'',
-- and the pattern is taken from the statement built by 'allocsForStm'.
patternWithAllocations ::
  (Allocator lore m, ExpDec lore ~ ()) =>
  [VName] ->
  Exp lore ->
  m (Pattern lore)
patternWithAllocations names e = do
  ext_ts <- expExtType e
  (ts', sizes) <- instantiateShapes' ext_ts
  -- Pair each name with its instantiated type.
  let vals = zipWith Ident names ts'
  stmPattern <$> allocsForStm sizes vals e

-- | Compute the context and value pattern elements for binding values
-- described by the given 'ExpReturns' (with matching hints).
--
-- The first component of the result is the context. It consists of
-- one @int64@ element per size ident, then any extra scalars needed to
-- instantiate existential index functions, then any memory blocks
-- created for 'ReturnsNewBlock' results.  The second component holds
-- one element per value ident.
allocsForPattern ::
  Allocator lore m =>
  [Ident] ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m
    ( [PatElem lore],
      [PatElem lore]
    )
allocsForPattern sizeidents validents rts hints = do
  let sizes' = [PatElem size $ MemPrim int64 | size <- map identName sizeidents]
  -- The writer collects (index-function scalars, new memory blocks).
  (vals, (exts, mems)) <-
    runWriterT $
      forM (zip3 validents rts hints) $ \(ident, rt, hint) -> do
        let ident_shape = arrayShape $ identType ident
        case rt of
          MemPrim _ -> do
            summary <- lift $ summaryForBindage (identType ident) hint
            return $ PatElem (identName ident) summary
          MemMem space ->
            return $
              PatElem (identName ident) $
                MemMem space
          MemArray bt _ u (Just (ReturnsInBlock mem extixfun)) -> do
            (patels, ixfn) <- instantiateExtIxFun ident extixfun
            tell (patels, [])

            return $
              PatElem (identName ident) $
                MemArray bt ident_shape u $
                  ArrayIn mem ixfn
          MemArray _ extshape _ Nothing
            | Just _ <- knownShape extshape -> do
              summary <- lift $ summaryForBindage (identType ident) hint
              return $ PatElem (identName ident) summary
          MemArray bt _ u (Just (ReturnsNewBlock space _ extixfn)) -> do
            -- treat existential index function first
            (patels, ixfn) <- instantiateExtIxFun ident extixfn
            tell (patels, [])

            memid <- lift $ mkMemIdent ident space
            tell ([], [PatElem (identName memid) $ MemMem space])
            return $
              PatElem (identName ident) $
                MemArray bt ident_shape u $
                  ArrayIn (identName memid) ixfn
          MemAcc acc ispace ts u ->
            return $ PatElem (identName ident) $ MemAcc acc ispace ts u
          _ -> error "Impossible case reached in allocsForPattern!"

  return
    ( sizes' <> exts <> mems,
      vals
    )
  where
    -- The dimensions of a shape, if none of them is existential.
    knownShape = mapM known . shapeDims
    known (Free v) = Just v
    known Ext {} = Nothing

    -- A fresh identifier for a memory block in the given space, named
    -- after the array it will hold.
    mkMemIdent :: (MonadFreshNames m) => Ident -> Space -> m Ident
    mkMemIdent ident space = do
      let memname = baseString (identName ident) <> "_mem"
      newIdent memname $ Mem space

    -- Replace the existentials of an index function with concrete
    -- names.  It returns the new scalar pattern elements that must be
    -- added to the context, and the instantiated index function.
    instantiateExtIxFun ::
      MonadFreshNames m =>
      Ident ->
      ExtIxFun ->
      m ([PatElemT (MemInfo d u ret)], IxFun)
    instantiateExtIxFun idd ext_ixfn = do
      let isAndPtps =
            S.toList $
              foldMap onlyExts $
                foldMap (leafExpTypes . untyped) ext_ixfn

      -- Find the existentials that reuse the sizeidents, and
      -- those that need new pattern elements.  Assumes that the
      -- Exts needing new pattern elements form a contiguous interval
      -- of integers.
      let (size_exts, new_exts) =
            span ((< length sizeidents) . fst) $ sort isAndPtps
      (new_substs, patels) <-
        fmap unzip $
          forM new_exts $ \(i, t) -> do
            v <- newVName $ baseString (identName idd) <> "_ixfn"
            return
              ( (Ext i, LeafExp (Free v) t),
                PatElem v $ MemPrim t
              )
      -- An existential 'Ext i' below @length sizeidents@ denotes the
      -- i'th size ident, because the size elements come first in the
      -- context.  Look it up by index rather than pairing positionally:
      -- an index function need not mention every leading size (it may
      -- reference only, e.g., @Ext 1@).
      let sizes_by_index = M.fromList $ zip [0 ..] sizeidents
          size_substs =
            [ (Ext i, LeafExp (Free (identName ident)) t)
              | (i, t) <- size_exts,
                Just ident <- [M.lookup i sizes_by_index]
            ]
          substs = M.fromList $ new_substs <> size_substs
      ixfn <- instantiateIxFun $ IxFun.substituteInIxFun (fmap isInt64 substs) ext_ixfn

      return (patels, ixfn)

-- | Keep only existentially quantified sizes.  An 'Ext' position is
-- returned as a singleton set together with its primitive type; a
-- 'Free' variable contributes the empty set.
onlyExts :: (Ext a, PrimType) -> S.Set (Int, PrimType)
onlyExts (ext, pt) =
  case ext of
    Ext idx -> S.singleton (idx, pt)
    Free _ -> S.empty

-- | Turn an existential index function into a concrete one by
-- stripping the 'Free' wrapper from every leaf variable.  Leaves that
-- are still existential ('Ext') are not supported here and result in a
-- call to 'error'.
instantiateIxFun :: Monad m => ExtIxFun -> m IxFun
instantiateIxFun = traverse (traverse instantiate)
  where
    instantiate :: Applicative n => Ext a -> n a
    instantiate (Free x) = pure x
    instantiate Ext {} = error "instantiateIxFun: not yet"

-- | Produce the memory summary for a freshly bound value of the given
-- type.
--
-- Primitive, memory and accumulator types are translated directly to
-- 'MemPrim', 'MemMem' and 'MemAcc'.  Arrays always get a new memory
-- block:
--
-- * with 'NoHint', the block is obtained from 'allocForArray' in the
--   space returned by 'askDefaultSpace', and the summary is built by
--   'directIxFun';
--
-- * with a 'Hint', the block size in bytes is the product of the hinted
--   index function's base shape and the element byte size, it is
--   allocated with 'allocateMemory' in the hinted space, and the hinted
--   index function is used as-is.
summaryForBindage ::
  Allocator lore m =>
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage t hint =
  case (t, hint) of
    (Prim bt, _) -> pure $ MemPrim bt
    (Mem space, _) -> pure $ MemMem space
    (Acc acc ispace ts u, _) -> pure $ MemAcc acc ispace ts u
    (Array pt shape u, NoHint) -> do
      defaultSpace <- askDefaultSpace
      mem <- allocForArray t defaultSpace
      pure $ directIxFun pt shape u mem t
    (Array pt _ _, Hint ixfun space) -> do
      let numElems = product (IxFun.base ixfun)
          elemBytes = fromIntegral (primByteSize pt :: Int64)
      bytes <- computeSize "bytes" $ untyped $ product [numElems, elemBytes]
      mem <- allocateMemory "mem" bytes space
      pure $ MemArray pt (arrayShape t) NoUniqueness $ ArrayIn mem ixfun

-- | Find the memory space of a variable that is expected to be bound to
-- a memory block in the current scope.  If the variable's type is not
-- 'Mem', the returned action is a call to 'error'.
lookupMemSpace :: (HasScope lore m, Monad m) => VName -> m Space
lookupMemSpace v = lookupType v >>= asMemSpace
  where
    asMemSpace (Mem space) = return space
    asMemSpace _ =
      error $ "lookupMemSpace: " ++ pretty v ++ " is not a memory block."

-- | Build an array summary that places the array in the given memory
-- block, using an 'IxFun.iota' index function over the dimensions of
-- the supplied type.
directIxFun :: PrimType -> Shape -> u -> VName -> Type -> MemBound u
directIxFun bt shape u mem t =
  MemArray bt shape u (ArrayIn mem ixf)
  where
    ixf = IxFun.iota (pe64 <$> arrayDims t)

-- | Give memory information to a list of function parameters, each
-- paired with the space its memory should live in.  The continuation
-- receives the context parameters, then the memory parameters, then the
-- value parameters, and runs with all of them added to the scope.
allocInFParams ::
  (Allocable fromlore tolore) =>
  [(FParam fromlore, Space)] ->
  ([FParam tolore] -> AllocM fromlore tolore a) ->
  AllocM fromlore tolore a
allocInFParams params m = do
  (valparams, (ctxparams, memparams)) <-
    runWriterT $ traverse (uncurry allocInFParam) params
  let allParams = concat [ctxparams, memparams, valparams]
  localScope (scopeOfFParams allParams) (m allParams)

-- | Attach memory information to a single function parameter.
--
-- Non-array parameters get 'MemPrim', 'MemMem' or 'MemAcc' directly.
-- An array parameter gets a fresh memory parameter named after it
-- (suffix @_mem@) in the requested space.  That memory parameter is
-- emitted through the second component of the writer output, and the
-- array is placed in it with an 'IxFun.iota' index function over its
-- shape.
allocInFParam ::
  (Allocable fromlore tolore) =>
  FParam fromlore ->
  Space ->
  WriterT
    ([FParam tolore], [FParam tolore])
    (AllocM fromlore tolore)
    (FParam tolore)
allocInFParam param pspace =
  case paramDeclType param of
    Prim pt -> withDec $ MemPrim pt
    Mem space -> withDec $ MemMem space
    Acc acc ispace ts u -> withDec $ MemAcc acc ispace ts u
    Array pt shape u -> do
      mem <- lift $ newVName $ baseString (paramName param) <> "_mem"
      tell ([], [Param mem $ MemMem pspace])
      let ixfun = IxFun.iota $ pe64 <$> shapeDims shape
      withDec $ MemArray pt shape u $ ArrayIn mem ixfun
  where
    withDec dec = return param {paramDec = dec}

allocInMergeParams ::
  ( Allocable fromlore tolore,
    Allocator tolore (AllocM fromlore tolore)
  ) =>
  [(FParam fromlore, SubExp)] ->
  ( [FParam tolore] ->
    [FParam tolore] ->
    ([SubExp] -> AllocM fromlore tolore ([SubExp], [SubExp])) ->
    AllocM fromlore tolore a
  ) ->
  AllocM fromlore tolore a
allocInMergeParams :: forall fromlore tolore a.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
[(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a
allocInMergeParams [(FParam fromlore, SubExp)]
merge [FParam tolore]
-> [FParam tolore]
-> (Result -> AllocM fromlore tolore (Result, Result))
-> AllocM fromlore tolore a
m = do
  (([Param FParamMem]
valparams, [SubExp
 -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp]
handle_loop_subexps), ([Param FParamMem]
ctx_params, [Param FParamMem]
mem_params)) <-
    WriterT
  ([Param FParamMem], [Param FParamMem])
  (AllocM fromlore tolore)
  ([Param FParamMem],
   [SubExp
    -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp])
-> AllocM
     fromlore
     tolore
     (([Param FParamMem],
       [SubExp
        -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp]),
      ([Param FParamMem], [Param FParamMem]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([Param FParamMem], [Param FParamMem])
   (AllocM fromlore tolore)
   ([Param FParamMem],
    [SubExp
     -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp])
 -> AllocM
      fromlore
      tolore
      (([Param FParamMem],
        [SubExp
         -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp]),
       ([Param FParamMem], [Param FParamMem])))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     ([Param FParamMem],
      [SubExp
       -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp])
-> AllocM
     fromlore
     tolore
     (([Param FParamMem],
       [SubExp
        -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp]),
      ([Param FParamMem], [Param FParamMem]))
forall a b. (a -> b) -> a -> b
$ [(Param FParamMem,
  SubExp
  -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)]
-> ([Param FParamMem],
    [SubExp
     -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Param FParamMem,
   SubExp
   -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)]
 -> ([Param FParamMem],
     [SubExp
      -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp]))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     [(Param FParamMem,
       SubExp
       -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     ([Param FParamMem],
      [SubExp
       -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Param DeclType, SubExp)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromlore tolore)
      (Param FParamMem,
       SubExp
       -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp))
-> [(Param DeclType, SubExp)]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     [(Param FParamMem,
       SubExp
       -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Param DeclType, SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     (Param FParamMem,
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
(Param DeclType, SubExp)
-> WriterT
     ([FParam tolore], [FParam tolore])
     (AllocM fromlore tolore)
     (FParam tolore,
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
allocInMergeParam [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
merge
  let mergeparams' :: [Param FParamMem]
mergeparams' = [Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
mem_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
valparams
      summary :: Scope tolore
summary = [Param FParamMem] -> Scope tolore
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param FParamMem]
mergeparams'

      mk_loop_res :: Result -> AllocM fromlore tolore (Result, Result)
mk_loop_res Result
ses = do
        (Result
valargs, (Result
ctxargs, Result
memargs)) <-
          WriterT (Result, Result) (AllocM fromlore tolore) Result
-> AllocM fromlore tolore (Result, (Result, Result))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT (Result, Result) (AllocM fromlore tolore) Result
 -> AllocM fromlore tolore (Result, (Result, Result)))
-> WriterT (Result, Result) (AllocM fromlore tolore) Result
-> AllocM fromlore tolore (Result, (Result, Result))
forall a b. (a -> b) -> a -> b
$ ((SubExp
  -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
 -> SubExp
 -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
-> [SubExp
    -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp]
-> Result
-> WriterT (Result, Result) (AllocM fromlore tolore) Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (SubExp
 -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
-> SubExp
-> WriterT (Result, Result) (AllocM fromlore tolore) SubExp
forall a b. (a -> b) -> a -> b
($) [SubExp
 -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp]
handle_loop_subexps Result
ses
        (Result, Result) -> AllocM fromlore tolore (Result, Result)
forall (m :: * -> *) a. Monad m => a -> m a
return (Result
ctxargs Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
memargs, Result
valargs)

  Scope tolore
-> AllocM fromlore tolore a -> AllocM fromlore tolore a
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope Scope tolore
summary (AllocM fromlore tolore a -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a -> AllocM fromlore tolore a
forall a b. (a -> b) -> a -> b
$ [FParam tolore]
-> [FParam tolore]
-> (Result -> AllocM fromlore tolore (Result, Result))
-> AllocM fromlore tolore a
m ([Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
mem_params) [FParam tolore]
[Param FParamMem]
valparams Result -> AllocM fromlore tolore (Result, Result)
mk_loop_res
  where
    allocInMergeParam ::
      (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam tolore], [FParam tolore])
        (AllocM fromlore tolore)
        (FParam tolore, SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromlore tolore) SubExp)
    allocInMergeParam :: forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
(Param DeclType, SubExp)
-> WriterT
     ([FParam tolore], [FParam tolore])
     (AllocM fromlore tolore)
     (FParam tolore,
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
allocInMergeParam (Param DeclType
mergeparam, Var VName
v)
      | Array PrimType
pt Shape
shape Uniqueness
u <- Param DeclType -> DeclType
forall dec. DeclTyped dec => Param dec -> DeclType
paramDeclType Param DeclType
mergeparam = do
        (VName
mem', IxFun
_) <- AllocM fromlore tolore (VName, IxFun)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     (VName, IxFun)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromlore tolore (VName, IxFun)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromlore tolore)
      (VName, IxFun))
-> AllocM fromlore tolore (VName, IxFun)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     (VName, IxFun)
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromlore tolore (VName, IxFun)
forall lore (m :: * -> *).
(Mem lore, HasScope lore m, Monad m) =>
VName -> m (VName, IxFun)
lookupArraySummary VName
v
        Space
mem_space <- AllocM fromlore tolore Space
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromlore tolore Space
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromlore tolore)
      Space)
-> AllocM fromlore tolore Space
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromlore tolore Space
forall lore (m :: * -> *).
(HasScope lore m, Monad m) =>
VName -> m Space
lookupMemSpace VName
mem'

        (SubExp
_, ExtIxFun
ext_ixfun, [TPrimExp Int64 VName]
substs, VName
_) <- AllocM
  fromlore tolore (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM
   fromlore tolore (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromlore tolore)
      (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName))
-> AllocM
     fromlore tolore (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall a b. (a -> b) -> a -> b
$ Space
-> VName
-> AllocM
     fromlore tolore (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Space
-> VName
-> AllocM
     fromlore tolore (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
existentializeArray Space
mem_space VName
v

        ([Param FParamMem]
ctx_params, [TPrimExp Int64 (Ext VName)]
param_ixfun_substs) <-
          [(Param FParamMem, TPrimExp Int64 (Ext VName))]
-> ([Param FParamMem], [TPrimExp Int64 (Ext VName)])
forall a b. [(a, b)] -> ([a], [b])
unzip
            ([(Param FParamMem, TPrimExp Int64 (Ext VName))]
 -> ([Param FParamMem], [TPrimExp Int64 (Ext VName)]))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     [(Param FParamMem, TPrimExp Int64 (Ext VName))]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     ([Param FParamMem], [TPrimExp Int64 (Ext VName)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TPrimExp Int64 VName
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromlore tolore)
      (Param FParamMem, TPrimExp Int64 (Ext VName)))
-> [TPrimExp Int64 VName]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     [(Param FParamMem, TPrimExp Int64 (Ext VName))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM
              ( \TPrimExp Int64 VName
e -> do
                  let e_t :: PrimType
e_t = PrimExp VName -> PrimType
forall v. PrimExp v -> PrimType
primExpType (PrimExp VName -> PrimType) -> PrimExp VName -> PrimType
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
e
                  VName
vname <- AllocM fromlore tolore VName
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     VName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromlore tolore VName
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromlore tolore)
      VName)
-> AllocM fromlore tolore VName
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     VName
forall a b. (a -> b) -> a -> b
$ String -> AllocM fromlore tolore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"ctx_param_ext"
                  (Param FParamMem, TPrimExp Int64 (Ext VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     (Param FParamMem, TPrimExp Int64 (Ext VName))
forall (m :: * -> *) a. Monad m => a -> m a
return
                    ( VName -> FParamMem -> Param FParamMem
forall dec. VName -> dec -> Param dec
Param VName
vname (FParamMem -> Param FParamMem) -> FParamMem -> Param FParamMem
forall a b. (a -> b) -> a -> b
$ PrimType -> FParamMem
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
e_t,
                      (VName -> Ext VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> Ext VName
forall a. a -> Ext a
Free (TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName))
-> TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName)
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vname
                    )
              )
              [TPrimExp Int64 VName]
substs

        ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromlore tolore) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([Param FParamMem]
ctx_params, [])

        IxFun
param_ixfun <-
          ExtIxFun
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     IxFun
forall (m :: * -> *). Monad m => ExtIxFun -> m IxFun
instantiateIxFun (ExtIxFun
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromlore tolore)
      IxFun)
-> ExtIxFun
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     IxFun
forall a b. (a -> b) -> a -> b
$
            Map (Ext VName) (TPrimExp Int64 (Ext VName))
-> ExtIxFun -> ExtIxFun
forall a t.
Ord a =>
Map a (TPrimExp t a)
-> IxFun (TPrimExp t a) -> IxFun (TPrimExp t a)
IxFun.substituteInIxFun
              ([(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(Ext VName, TPrimExp Int64 (Ext VName))]
 -> Map (Ext VName) (TPrimExp Int64 (Ext VName)))
-> [(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall a b. (a -> b) -> a -> b
$ [Ext VName]
-> [TPrimExp Int64 (Ext VName)]
-> [(Ext VName, TPrimExp Int64 (Ext VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Int -> Ext VName) -> [Int] -> [Ext VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Int -> Ext VName
forall a. Int -> Ext a
Ext [Int
0 ..]) [TPrimExp Int64 (Ext VName)]
param_ixfun_substs)
              ExtIxFun
ext_ixfun

        VName
mem_name <- String
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"mem_param"
        ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromlore tolore) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([], [VName -> FParamMem -> Param FParamMem
forall dec. VName -> dec -> Param dec
Param VName
mem_name (FParamMem -> Param FParamMem) -> FParamMem -> Param FParamMem
forall a b. (a -> b) -> a -> b
$ Space -> FParamMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
mem_space])

        (Param FParamMem,
 SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromlore tolore)
     (Param FParamMem,
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return
          ( Param DeclType
mergeparam {paramDec :: FParamMem
paramDec = PrimType -> Shape -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape Uniqueness
u (MemBind -> FParamMem) -> MemBind -> FParamMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
mem_name IxFun
param_ixfun},
            Space
-> SubExp
-> WriterT (Result, Result) (AllocM fromlore tolore) SubExp
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Space
-> SubExp
-> WriterT (Result, Result) (AllocM fromlore tolore) SubExp
ensureArrayIn Space
mem_space
          )
    allocInMergeParam (Param DeclType
mergeparam, SubExp
_) = Param DeclType
-> Space
-> WriterT
     ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
     (AllocM fromlore tolore)
     (Param (FParamInfo tolore),
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
forall {tolore} {fromlore} {tolore} {fromlore}.
(PrettyLore fromlore, PrettyLore fromlore, AllocOp (Op tolore),
 AllocOp (Op tolore), OpReturns tolore, OpReturns tolore,
 SizeSubst (Op tolore), SizeSubst (Op tolore), BinderOps tolore,
 BinderOps tolore, LetDec tolore ~ LParamMem, BodyDec tolore ~ (),
 LParamInfo fromlore ~ Type, LetDec tolore ~ LParamMem,
 BodyDec tolore ~ (), LParamInfo fromlore ~ Type,
 BranchType fromlore ~ ExtType, ExpDec tolore ~ (),
 RetType tolore ~ RetTypeMem, BranchType fromlore ~ ExtType,
 ExpDec tolore ~ (), RetType tolore ~ RetTypeMem,
 LParamInfo tolore ~ LParamMem, BodyDec fromlore ~ (),
 FParamInfo fromlore ~ DeclType, LParamInfo tolore ~ LParamMem,
 BodyDec fromlore ~ (), FParamInfo fromlore ~ DeclType,
 RetType fromlore ~ DeclExtType, FParamInfo tolore ~ FParamMem,
 BranchType tolore ~ BranchTypeMem, RetType fromlore ~ DeclExtType,
 FParamInfo tolore ~ FParamMem,
 BranchType tolore ~ BranchTypeMem) =>
Param DeclType
-> Space
-> WriterT
     ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
     (AllocM fromlore tolore)
     (Param (FParamInfo tolore),
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
doDefault Param DeclType
mergeparam (Space
 -> WriterT
      ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
      (AllocM fromlore tolore)
      (Param FParamMem,
       SubExp
       -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp))
-> WriterT
     ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
     (AllocM fromlore tolore)
     Space
-> WriterT
     ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
     (AllocM fromlore tolore)
     (Param FParamMem,
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< AllocM fromlore tolore Space
-> WriterT
     ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
     (AllocM fromlore tolore)
     Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift AllocM fromlore tolore Space
forall lore (m :: * -> *). Allocator lore m => m Space
askDefaultSpace

    doDefault :: Param DeclType
-> Space
-> WriterT
     ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
     (AllocM fromlore tolore)
     (Param (FParamInfo tolore),
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
doDefault Param DeclType
mergeparam Space
space = do
      Param (FParamInfo tolore)
mergeparam' <- FParam fromlore
-> Space
-> WriterT
     ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
     (AllocM fromlore tolore)
     (Param (FParamInfo tolore))
forall fromlore tolore.
Allocable fromlore tolore =>
FParam fromlore
-> Space
-> WriterT
     ([FParam tolore], [FParam tolore])
     (AllocM fromlore tolore)
     (FParam tolore)
allocInFParam Param DeclType
FParam fromlore
mergeparam Space
space
      (Param (FParamInfo tolore),
 SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
-> WriterT
     ([Param (FParamInfo tolore)], [Param (FParamInfo tolore)])
     (AllocM fromlore tolore)
     (Param (FParamInfo tolore),
      SubExp -> WriterT (Result, Result) (AllocM fromlore tolore) SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (Param (FParamInfo tolore)
mergeparam', Type
-> Space
-> SubExp
-> WriterT (Result, Result) (AllocM fromlore tolore) SubExp
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Type
-> Space
-> SubExp
-> WriterT (Result, Result) (AllocM fromlore tolore) SubExp
linearFuncallArg (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
mergeparam) Space
space)

-- | Make the array @v@ available with an existentialized index function
-- in memory of the given space.
--
-- Returns, in order: the array to use in place of @v@ (either @v@
-- itself or a fresh linear copy), the existentialized index function,
-- the list of values substituted out by 'IxFun.existentialize' (its
-- final state, starting from @[]@), and the memory block holding the
-- array.
--
-- For 'ScalarSpace' the array is returned as-is, with every variable of
-- its index function wrapped in 'Free' and no substitutions.  Otherwise,
-- if the index function can be existentialized and the array already
-- lives in @space@, no copy is made; in all other cases the array is
-- copied with 'allocLinearArray' into @space@.
existentializeArray ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  Space ->
  VName ->
  AllocM fromlore tolore (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
existentializeArray ScalarSpace {} v = do
  (mem', ixfun) <- lookupArraySummary v
  return (Var v, fmap (fmap Free) ixfun, mempty, mem')
existentializeArray space v = do
  (mem', ixfun) <- lookupArraySummary v
  sp <- lookupMemSpace mem'

  let (ext_ixfun', substs') = existentialize ixfun

  case (ext_ixfun', sp == space) of
    (Just x, True) -> return (Var v, x, substs', mem')
    _ -> do
      -- Either the index function could not be existentialized or the
      -- array is in the wrong space: copy it into a fresh linear array
      -- in the requested space and existentialize that instead.
      (mem, subexp) <- allocLinearArray space (baseString v) v
      ixfun' <-
        fromMaybe
          (error $ "existentializeArray: no index function for linear copy of " ++ pretty v)
          <$> subExpIxFun subexp
      let (ext_ixfun, substs) = existentialize ixfun'
          -- Left lazy: this error only fires if a caller actually demands
          -- the index function ('ensureArrayIn', for one, ignores it).
          ext_ixfun_or_fail =
            fromMaybe
              ( error $
                  "existentializeArray: cannot existentialize index function of linear copy of "
                    ++ pretty v
              )
              ext_ixfun
      return (subexp, ext_ixfun_or_fail, substs, mem)
  where
    -- Existentialize an index function, starting from no substitutions.
    existentialize ix = runState (IxFun.existentialize ix) []

-- | Ensure that the array @v@ is stored in @space@ with an
-- existentialized index function, copying it if necessary (see
-- 'existentializeArray').  Returns the array to use in place of @v@.
--
-- Writer output: the first component receives one freshly bound
-- @ctx_val@ variable per value substituted out of the index function;
-- the second receives the memory block holding the array.
--
-- Calls 'error' on a constant, since a constant cannot be an array.
ensureArrayIn ::
  ( Allocable fromlore tolore,
    Allocator tolore (AllocM fromlore tolore)
  ) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromlore tolore) SubExp
ensureArrayIn _ (Constant v) =
  error $ "ensureArrayIn: " ++ pretty v ++ " cannot be an array."
ensureArrayIn space (Var v) = do
  (sub_exp, _, substs, mem) <- lift $ existentializeArray space v
  -- Bind each substituted value to a variable so it can be passed as a
  -- context argument.
  ctx_vals <- lift $ forM substs $ \s -> Var <$> (letExp "ctx_val" =<< toExp s)
  tell (ctx_vals, [Var mem])
  return sub_exp

-- | Ensure that the array @v@ has a direct index function ('IxFun.isDirect')
-- and, when a space is given, resides in memory of that space.
--
-- If both conditions already hold, returns @v@'s memory block together
-- with @v@ itself.  Otherwise a new allocation is made and @v@ is copied
-- into it via 'allocLinearArray', in the requested space (or the
-- default space when given 'Nothing'); the new memory block and array
-- are returned.
ensureDirectArray ::
  ( Allocable fromlore tolore,
    Allocator tolore (AllocM fromlore tolore)
  ) =>
  Maybe Space ->
  VName ->
  AllocM fromlore tolore (VName, SubExp)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let -- With no requested space, any space is acceptable.
      in_wanted_space = all (== mem_space) space_ok
      usable_as_is = IxFun.isDirect ixfun && in_wanted_space
      target_space = fromMaybe default_space space_ok
  if usable_as_is
    then pure (mem, Var v)
    else allocLinearArray target_space (baseString v) v

-- | Copy the array @v@ into a freshly allocated memory block in @space@.
--
-- The copy is bound (via a 'Copy' statement) to a new variable whose
-- name is derived from @s@ with a @_linear@ suffix, and is given a
-- direct index function ('directIxFun').  Returns the new memory block
-- and the new array.  Calls 'error' if @v@ is not of array type.
allocLinearArray ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  Space ->
  String ->
  VName ->
  AllocM fromlore tolore (VName, SubExp)
allocLinearArray space s v = do
  arr_t <- lookupType v
  case arr_t of
    Array elem_t arr_shape uniq -> do
      block <- allocForArray arr_t space
      linear <- newIdent (s <> "_linear") arr_t
      let linear_name = identName linear
          linear_dec = directIxFun elem_t arr_shape uniq block arr_t
          linear_pat = Pattern [] [PatElem linear_name linear_dec]
      addStm . Let linear_pat (defAux ()) . BasicOp $ Copy v
      pure (block, Var linear_name)
    _ -> error ("allocLinearArray: " <> pretty arr_t)

-- | Prepare the arguments of a function call for the memory representation.
--
-- Each argument goes through 'linearFuncallArg' with the default space.
-- That makes array arguments direct, possibly by copying them, and
-- collects their memory blocks.  The result is the collected context
-- and memory arguments, all with diet 'Observe', followed by the
-- (possibly replaced) value arguments with their original diets.
funcallArgs ::
  ( Allocable fromlore tolore,
    Allocator tolore (AllocM fromlore tolore)
  ) =>
  [(SubExp, Diet)] ->
  AllocM fromlore tolore [(SubExp, Diet)]
funcallArgs args = do
  (valargs, (ctx_args, mem_and_size_args)) <-
    runWriterT . forM args $ \(arg, d) -> do
      (t, space) <- lift $ (,) <$> subExpType arg <*> askDefaultSpace
      arg' <- linearFuncallArg t space arg
      pure (arg', d)
  let extra_args = zip (ctx_args <> mem_and_size_args) (repeat Observe)
  pure $ extra_args <> valargs

-- | Prepare a single function-call argument of the given type.
--
-- If the argument is an array variable, it is made direct in @space@
-- with 'ensureDirectArray' (which may copy it).  Its memory block is
-- then written to the second writer component, and the possibly
-- replaced array is returned.  Any other argument is returned unchanged,
-- with no writer output.
linearFuncallArg ::
  ( Allocable fromlore tolore,
    Allocator tolore (AllocM fromlore tolore)
  ) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromlore tolore) SubExp
linearFuncallArg t space arg =
  case (t, arg) of
    (Array {}, Var v) -> do
      (mem, arg') <- lift $ ensureDirectArray (Just space) v
      tell ([], [Var mem])
      pure arg'
    _ -> pure arg

-- | Build the generic explicit-allocations pass.
--
-- @handleOp@ translates representation-specific operations, and @hints@
-- supplies allocation hints for expressions; both are handed to
-- 'runAllocM'.  Top-level constants are processed with 'allocInStms'.
-- Each function is processed in the scope of the already-translated
-- constants: its parameters are allocated in 'DefaultSpace', and its
-- body is allocated with every result expected in 'DefaultSpace'.
explicitAllocationsGeneric ::
  ( Allocable fromlore tolore,
    Allocator tolore (AllocM fromlore tolore)
  ) =>
  (Op fromlore -> AllocM fromlore tolore (Op tolore)) ->
  (Exp tolore -> AllocM fromlore tolore [ExpHint]) ->
  Pass fromlore tolore
explicitAllocationsGeneric handleOp hints =
  Pass
    "explicit allocations"
    "Transform program to explicit memory representation"
    (intraproceduralTransformationWithConsts onStms allocInFun)
  where
    onStms stms =
      runAllocM handleOp hints . collectStms_ . allocInStms stms $ pure ()

    allocInFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM handleOp hints . inScopeOf consts $
        allocInFParams [(p, DefaultSpace) | p <- params] $ \params' -> do
          let ret_spaces = Just DefaultSpace <$ rettype
          fbody' <- allocInFunBody ret_spaces fbody
          pure $ FunDef entry attrs fname (memoryInDeclExtType rettype) params' fbody'

-- | Insert allocations into a sequence of statements.
--
-- The scope of the calling monad is made available to the allocator.
-- The statements produced by 'allocInStms' are collected and returned.
-- The continuation run after them does nothing further.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope tolore m,
    Allocable fromlore tolore
  ) =>
  (Op fromlore -> AllocM fromlore tolore (Op tolore)) ->
  (Exp tolore -> AllocM fromlore tolore [ExpHint]) ->
  Stms fromlore ->
  m (Stms tolore)
explicitAllocationsInStmsGeneric handleOp hints stms =
  askScope >>= \outer_scope ->
    runAllocM handleOp hints . localScope outer_scope . collectStms_ $
      allocInStms stms (return ())

-- | Attach memory information to a list of declared return types.
--
-- * Scalars ('Prim') and accumulators ('Acc') carry over directly.
-- * Every array is returned in a new block ('ReturnsNewBlock') in
--   'DefaultSpace', with an 'IxFun.iota' index function over its shape.
--
-- The new blocks are numbered consecutively, starting at
-- 'startOfFreeIDRange' of the input types. Memory types in the input
-- are rejected with an 'error'.
memoryInDeclExtType :: [DeclExtType] -> [FunReturns]
memoryInDeclExtType dets =
  evalState (mapM addMem dets) (startOfFreeIDRange dets)
  where
    addMem t = case t of
      Prim pt ->
        pure $ MemPrim pt
      Mem {} ->
        error "memoryInDeclExtType: too much memory"
      Array pt shape u -> do
        -- Claim the current block number and advance the counter.
        i <- get
        modify (+ 1)
        let ixfun = IxFun.iota $ map convert $ shapeDims shape
        pure $ MemArray pt shape u $ ReturnsNewBlock DefaultSpace i ixfun
      Acc acc ispace ts u ->
        pure $ MemAcc acc ispace ts u

    -- Existential dimensions stay existential.
    -- Known sizes become 'Free' variables in the index function.
    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

-- | Count the distinct existential size indices (as collected by
-- 'shapeContext') that appear in the given types.
--
-- 'memoryInDeclExtType' numbers its new memory blocks starting from
-- this value.
startOfFreeIDRange :: [TypeBase ExtShape u] -> Int
startOfFreeIDRange ts = S.size (shapeContext ts)

-- | The extra memory context that accompanies a returned value.
--
-- An array variable contributes the memory block it resides in.
-- Everything else contributes nothing: constants, scalars,
-- accumulators, and memory blocks.
bodyReturnMemCtx ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  SubExp ->
  AllocM fromlore tolore [SubExp]
bodyReturnMemCtx se = case se of
  Constant {} -> pure []
  Var v -> memCtx <$> lookupMemInfo v
  where
    memCtx (MemArray _ _ _ (ArrayIn mem _)) = [Var mem]
    memCtx MemPrim {} = []
    memCtx MemAcc {} = []
    memCtx MemMem {} = [] -- should not happen

-- | Insert allocations into a function body.
--
-- The trailing @length space_oks@ results are the value results; any
-- results before them are context results. Every result is passed
-- through 'ensureDirect':
--
-- * value results use their matching entry of @space_oks@;
-- * context results use 'Nothing'.
--
-- The memory blocks of the value results (from 'bodyReturnMemCtx')
-- are placed between the context results and the value results.
allocInFunBody ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  [Maybe Space] ->
  Body fromlore ->
  AllocM fromlore tolore (Body tolore)
allocInFunBody space_oks (Body _ bnds res) =
  buildBody_ $ allocInStms bnds $ do
    direct_res <- zipWithM ensureDirect padded_oks res
    let (ctx_res, val_res) = splitFromEnd num_vals direct_res
    mem_ctx_res <- fmap concat $ mapM bodyReturnMemCtx val_res
    pure $ ctx_res ++ mem_ctx_res ++ val_res
  where
    num_vals = length space_oks
    -- Pad the front so that context results get no space requirement.
    padded_oks = replicate (length res - num_vals) Nothing ++ space_oks

-- | Ensure that a result naming an array refers to it via
-- 'ensureDirectArray', passing along the optional space requirement.
--
-- Only the second component of 'ensureDirectArray''s result is kept.
-- Any subexpression that is not an array variable is returned
-- unchanged.
ensureDirect ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  Maybe Space ->
  SubExp ->
  AllocM fromlore tolore SubExp
ensureDirect space_ok se = do
  se_info <- subExpMemInfo se
  case se_info of
    MemArray {}
      | Var v <- se -> snd <$> ensureDirectArray space_ok v
    _ -> pure se

-- | Insert allocations into a sequence of statements, then run the
-- given continuation.
--
-- Statements are processed front to back. For each one:
--
-- * the statements produced by 'allocInStm' (under the original
--   statement's aux via 'auxing') are added to the current binding;
-- * their 'sizeSubst' entries are prepended to 'chunkMap', and their
--   'stmConsts' to 'envConsts'.
--
-- These environment updates are visible to the remaining statements
-- and to the continuation.
allocInStms ::
  (Allocable fromlore tolore) =>
  Stms fromlore ->
  AllocM fromlore tolore a ->
  AllocM fromlore tolore a
allocInStms origstms m = foldr allocOne m $ stmsToList origstms
  where
    allocOne ::
      Allocable fromlore tolore =>
      Stm fromlore ->
      AllocM fromlore tolore b ->
      AllocM fromlore tolore b
    allocOne stm rest = do
      allocstms <- collectStms_ . auxing (stmAux stm) $ allocInStm stm
      addStms allocstms
      let extend env =
            env
              { chunkMap = foldMap sizeSubst allocstms <> chunkMap env,
                envConsts = foldMap stmConsts allocstms <> envConsts env
              }
      local extend rest

-- | Insert allocations for a single statement.
--
-- First allocate inside its expression. Then let 'allocsForStm' build
-- the new statement from the identifiers of the pattern's context
-- elements and value elements, and add that statement to the current
-- binding.
allocInStm ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  Stm fromlore ->
  AllocM fromlore tolore ()
allocInStm (Let (Pattern sizeElems valElems) _ e) =
  allocInExp e
    >>= allocsForStm (map patElemIdent sizeElems) (map patElemIdent valElems)
    >>= addStm

-- | Build a lambda over the given parameters.
--
-- The lambda's body consists of the original body's statements with
-- allocations inserted by 'allocInStms'. Its result is the original
-- body's result, unchanged.
allocInLambda ::
  Allocable fromlore tolore =>
  [LParam tolore] ->
  Body fromlore ->
  AllocM fromlore tolore (Lambda tolore)
allocInLambda params (Body _ stms res) =
  mkLambda params $ allocInStms stms $ return res

allocInExp ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  Exp fromlore ->
  AllocM fromlore tolore (Exp tolore)
allocInExp :: forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Exp fromlore -> AllocM fromlore tolore (Exp tolore)
allocInExp (DoLoop [(FParam fromlore, SubExp)]
ctx [(FParam fromlore, SubExp)]
val LoopForm fromlore
form (Body () Stms fromlore
bodybnds Result
bodyres)) =
  [(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore (ExpT tolore))
-> AllocM fromlore tolore (ExpT tolore)
forall fromlore tolore a.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
[(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a
allocInMergeParams [(FParam fromlore, SubExp)]
ctx (([FParam tolore]
  -> [FParam tolore]
  -> (Result -> AllocM fromlore tolore (Result, Result))
  -> AllocM fromlore tolore (ExpT tolore))
 -> AllocM fromlore tolore (ExpT tolore))
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore (ExpT tolore))
-> AllocM fromlore tolore (ExpT tolore)
forall a b. (a -> b) -> a -> b
$ \[FParam tolore]
_ [FParam tolore]
ctxparams' Result -> AllocM fromlore tolore (Result, Result)
_ ->
    [(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore (ExpT tolore))
-> AllocM fromlore tolore (ExpT tolore)
forall fromlore tolore a.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
[(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a
allocInMergeParams [(FParam fromlore, SubExp)]
val (([FParam tolore]
  -> [FParam tolore]
  -> (Result -> AllocM fromlore tolore (Result, Result))
  -> AllocM fromlore tolore (ExpT tolore))
 -> AllocM fromlore tolore (ExpT tolore))
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore (ExpT tolore))
-> AllocM fromlore tolore (ExpT tolore)
forall a b. (a -> b) -> a -> b
$
      \[FParam tolore]
new_ctx_params [FParam tolore]
valparams' Result -> AllocM fromlore tolore (Result, Result)
mk_loop_val -> do
        LoopForm tolore
form' <- LoopForm fromlore -> AllocM fromlore tolore (LoopForm tolore)
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
LoopForm fromlore -> AllocM fromlore tolore (LoopForm tolore)
allocInLoopForm LoopForm fromlore
form
        Scope tolore
-> AllocM fromlore tolore (ExpT tolore)
-> AllocM fromlore tolore (ExpT tolore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (LoopForm tolore -> Scope tolore
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm tolore
form') (AllocM fromlore tolore (ExpT tolore)
 -> AllocM fromlore tolore (ExpT tolore))
-> AllocM fromlore tolore (ExpT tolore)
-> AllocM fromlore tolore (ExpT tolore)
forall a b. (a -> b) -> a -> b
$ do
          (Result
valinit_ctx, Result
valinit') <- Result -> AllocM fromlore tolore (Result, Result)
mk_loop_val Result
valinit
          BodyT tolore
body' <-
            AllocM fromlore tolore Result
-> AllocM fromlore tolore (BodyT tolore)
forall (m :: * -> *).
MonadBinder m =>
m Result -> m (Body (Lore m))
buildBody_ (AllocM fromlore tolore Result
 -> AllocM fromlore tolore (BodyT tolore))
-> (AllocM fromlore tolore Result -> AllocM fromlore tolore Result)
-> AllocM fromlore tolore Result
-> AllocM fromlore tolore (BodyT tolore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms fromlore
-> AllocM fromlore tolore Result -> AllocM fromlore tolore Result
forall fromlore tolore a.
Allocable fromlore tolore =>
Stms fromlore
-> AllocM fromlore tolore a -> AllocM fromlore tolore a
allocInStms Stms fromlore
bodybnds (AllocM fromlore tolore Result
 -> AllocM fromlore tolore (BodyT tolore))
-> AllocM fromlore tolore Result
-> AllocM fromlore tolore (BodyT tolore)
forall a b. (a -> b) -> a -> b
$ do
              (Result
val_ses, Result
valres') <- Result -> AllocM fromlore tolore (Result, Result)
mk_loop_val Result
valres
              Result -> AllocM fromlore tolore Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> AllocM fromlore tolore Result)
-> Result -> AllocM fromlore tolore Result
forall a b. (a -> b) -> a -> b
$ Result
ctxres Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
val_ses Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
valres'
          ExpT tolore -> AllocM fromlore tolore (ExpT tolore)
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpT tolore -> AllocM fromlore tolore (ExpT tolore))
-> ExpT tolore -> AllocM fromlore tolore (ExpT tolore)
forall a b. (a -> b) -> a -> b
$
            [(FParam tolore, SubExp)]
-> [(FParam tolore, SubExp)]
-> LoopForm tolore
-> BodyT tolore
-> ExpT tolore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop
              ([Param FParamMem] -> Result -> [(Param FParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([FParam tolore]
[Param FParamMem]
ctxparams' [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. [a] -> [a] -> [a]
++ [FParam tolore]
[Param FParamMem]
new_ctx_params) (Result
ctxinit Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
valinit_ctx))
              ([Param FParamMem] -> Result -> [(Param FParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam tolore]
[Param FParamMem]
valparams' Result
valinit')
              LoopForm tolore
form'
              BodyT tolore
body'
  where
    ([Param DeclType]
_ctxparams, Result
ctxinit) = [(Param DeclType, SubExp)] -> ([Param DeclType], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
ctx
    ([Param DeclType]
_valparams, Result
valinit) = [(Param DeclType, SubExp)] -> ([Param DeclType], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
val
    (Result
ctxres, Result
valres) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt ([(Param DeclType, SubExp)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
ctx) Result
bodyres
allocInExp (Apply Name
fname [(SubExp, Diet)]
args [RetType fromlore]
rettype (Safety, SrcLoc, [SrcLoc])
loc) = do
  [(SubExp, Diet)]
args' <- [(SubExp, Diet)] -> AllocM fromlore tolore [(SubExp, Diet)]
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
[(SubExp, Diet)] -> AllocM fromlore tolore [(SubExp, Diet)]
funcallArgs [(SubExp, Diet)]
args
  ExpT tolore -> AllocM fromlore tolore (ExpT tolore)
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpT tolore -> AllocM fromlore tolore (ExpT tolore))
-> ExpT tolore -> AllocM fromlore tolore (ExpT tolore)
forall a b. (a -> b) -> a -> b
$ Name
-> [(SubExp, Diet)]
-> [RetType tolore]
-> (Safety, SrcLoc, [SrcLoc])
-> ExpT tolore
forall lore.
Name
-> [(SubExp, Diet)]
-> [RetType lore]
-> (Safety, SrcLoc, [SrcLoc])
-> ExpT lore
Apply Name
fname [(SubExp, Diet)]
args' ([DeclExtType] -> [RetTypeMem]
memoryInDeclExtType [DeclExtType]
[RetType fromlore]
rettype) (Safety, SrcLoc, [SrcLoc])
loc
allocInExp (If SubExp
cond BodyT fromlore
tbranch0 BodyT fromlore
fbranch0 (IfDec [BranchType fromlore]
rets IfSort
ifsort)) = do
  let num_rets :: Int
num_rets = [ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
[BranchType fromlore]
rets
  -- switch to the explicit-mem rep, but do nothing about results
  (BodyT tolore
tbranch, [Maybe IxFun]
tm_ixfs) <- Int
-> BodyT fromlore
-> AllocM fromlore tolore (BodyT tolore, [Maybe IxFun])
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Int
-> Body fromlore
-> AllocM fromlore tolore (Body tolore, [Maybe IxFun])
allocInIfBody Int
num_rets BodyT fromlore
tbranch0
  (BodyT tolore
fbranch, [Maybe IxFun]
fm_ixfs) <- Int
-> BodyT fromlore
-> AllocM fromlore tolore (BodyT tolore, [Maybe IxFun])
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Int
-> Body fromlore
-> AllocM fromlore tolore (Body tolore, [Maybe IxFun])
allocInIfBody Int
num_rets BodyT fromlore
fbranch0
  [Maybe Space]
tspaces <- Int -> BodyT tolore -> AllocM fromlore tolore [Maybe Space]
forall tolore (m :: * -> *).
(Mem tolore, LocalScope tolore m) =>
Int -> Body tolore -> m [Maybe Space]
mkSpaceOks Int
num_rets BodyT tolore
tbranch
  [Maybe Space]
fspaces <- Int -> BodyT tolore -> AllocM fromlore tolore [Maybe Space]
forall tolore (m :: * -> *).
(Mem tolore, LocalScope tolore m) =>
Int -> Body tolore -> m [Maybe Space]
mkSpaceOks Int
num_rets BodyT tolore
fbranch
  -- try to generalize (antiunify) the index functions of the then and else bodies
  let sp_substs :: [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
sp_substs = ((Maybe Space, Maybe IxFun)
 -> (Maybe Space, Maybe IxFun)
 -> (Maybe Space,
     Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])))
-> [(Maybe Space, Maybe IxFun)]
-> [(Maybe Space, Maybe IxFun)]
-> [(Maybe Space,
     Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (Maybe Space, Maybe IxFun)
-> (Maybe Space, Maybe IxFun)
-> (Maybe Space,
    Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))
generalize ([Maybe Space] -> [Maybe IxFun] -> [(Maybe Space, Maybe IxFun)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Maybe Space]
tspaces [Maybe IxFun]
tm_ixfs) ([Maybe Space] -> [Maybe IxFun] -> [(Maybe Space, Maybe IxFun)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Maybe Space]
fspaces [Maybe IxFun]
fm_ixfs)
      ([Maybe Space]
spaces, [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs) = [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
-> ([Maybe Space],
    [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
sp_substs
      tsubs :: [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
tsubs = (Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
 -> Maybe (ExtIxFun, [TPrimExp Int64 VName]))
-> [Maybe
      (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map (((TPrimExp Int64 VName, TPrimExp Int64 VName)
 -> TPrimExp Int64 VName)
-> Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
-> Maybe (ExtIxFun, [TPrimExp Int64 VName])
forall a.
((a, a) -> a)
-> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
selectSub (TPrimExp Int64 VName, TPrimExp Int64 VName)
-> TPrimExp Int64 VName
forall a b. (a, b) -> a
fst) [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs
      fsubs :: [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
fsubs = (Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
 -> Maybe (ExtIxFun, [TPrimExp Int64 VName]))
-> [Maybe
      (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map (((TPrimExp Int64 VName, TPrimExp Int64 VName)
 -> TPrimExp Int64 VName)
-> Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
-> Maybe (ExtIxFun, [TPrimExp Int64 VName])
forall a.
((a, a) -> a)
-> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
selectSub (TPrimExp Int64 VName, TPrimExp Int64 VName)
-> TPrimExp Int64 VName
forall a b. (a, b) -> b
snd) [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs
  (BodyT tolore
tbranch', [BranchTypeMem]
trets) <- [ExtType]
-> BodyT tolore
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromlore tolore (BodyT tolore, [BranchTypeMem])
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
[ExtType]
-> Body tolore
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromlore tolore (Body tolore, [BranchTypeMem])
addResCtxInIfBody [ExtType]
[BranchType fromlore]
rets BodyT tolore
tbranch [Maybe Space]
spaces [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
tsubs
  (BodyT tolore
fbranch', [BranchTypeMem]
frets) <- [ExtType]
-> BodyT tolore
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromlore tolore (BodyT tolore, [BranchTypeMem])
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
[ExtType]
-> Body tolore
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromlore tolore (Body tolore, [BranchTypeMem])
addResCtxInIfBody [ExtType]
[BranchType fromlore]
rets BodyT tolore
fbranch [Maybe Space]
spaces [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
fsubs
  if [BranchTypeMem]
frets [BranchTypeMem] -> [BranchTypeMem] -> Bool
forall a. Eq a => a -> a -> Bool
/= [BranchTypeMem]
trets
    then String -> AllocM fromlore tolore (ExpT tolore)
forall a. HasCallStack => String -> a
error String
"In allocInExp, IF case: antiunification of then/else produce different ExtInFn!"
    else do
      -- above is a sanity check; implementation continues on else branch
      let res_then :: Result
res_then = BodyT tolore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT tolore
tbranch'
          res_else :: Result
res_else = BodyT tolore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT tolore
fbranch'
          size_ext :: Int
size_ext = Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
res_then Int -> Int -> Int
forall a. Num a => a -> a -> a
- [BranchTypeMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchTypeMem]
trets
          ([(SubExp, SubExp, Int)]
ind_ses0, [(SubExp, SubExp, Int)]
r_then_else) =
            ((SubExp, SubExp, Int) -> Bool)
-> [(SubExp, SubExp, Int)]
-> ([(SubExp, SubExp, Int)], [(SubExp, SubExp, Int)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (\(SubExp
r_then, SubExp
r_else, Int
_) -> SubExp
r_then SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
r_else) ([(SubExp, SubExp, Int)]
 -> ([(SubExp, SubExp, Int)], [(SubExp, SubExp, Int)]))
-> [(SubExp, SubExp, Int)]
-> ([(SubExp, SubExp, Int)], [(SubExp, SubExp, Int)])
forall a b. (a -> b) -> a -> b
$
              Result -> Result -> [Int] -> [(SubExp, SubExp, Int)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
res_then Result
res_else [Int
0 .. Int
size_ext Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]
          (Result
r_then_ext, Result
r_else_ext, [Int]
_) = [(SubExp, SubExp, Int)] -> (Result, Result, [Int])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(SubExp, SubExp, Int)]
r_then_else
          ind_ses :: [(Int, SubExp)]
ind_ses =
            ((SubExp, SubExp, Int) -> Int -> (Int, SubExp))
-> [(SubExp, SubExp, Int)] -> [Int] -> [(Int, SubExp)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
              (\(SubExp
se, SubExp
_, Int
i) Int
k -> (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
k, SubExp
se))
              [(SubExp, SubExp, Int)]
ind_ses0
              [Int
0 .. [(SubExp, SubExp, Int)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(SubExp, SubExp, Int)]
ind_ses0 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]
          rets'' :: [BranchTypeMem]
rets'' = ([BranchTypeMem] -> (Int, SubExp) -> [BranchTypeMem])
-> [BranchTypeMem] -> [(Int, SubExp)] -> [BranchTypeMem]
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\[BranchTypeMem]
acc (Int
i, SubExp
se) -> Int -> SubExp -> [BranchTypeMem] -> [BranchTypeMem]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt Int
i SubExp
se [BranchTypeMem]
acc) [BranchTypeMem]
trets [(Int, SubExp)]
ind_ses
          tbranch'' :: BodyT tolore
tbranch'' = BodyT tolore
tbranch' {bodyResult :: Result
bodyResult = Result
r_then_ext Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
size_ext Result
res_then}
          fbranch'' :: BodyT tolore
fbranch'' = BodyT tolore
fbranch' {bodyResult :: Result
bodyResult = Result
r_else_ext Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
size_ext Result
res_else}
          res_if_expr :: ExpT tolore
res_if_expr = SubExp
-> BodyT tolore
-> BodyT tolore
-> IfDec (BranchType tolore)
-> ExpT tolore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond BodyT tolore
tbranch'' BodyT tolore
fbranch'' (IfDec (BranchType tolore) -> ExpT tolore)
-> IfDec (BranchType tolore) -> ExpT tolore
forall a b. (a -> b) -> a -> b
$ [BranchTypeMem] -> IfSort -> IfDec BranchTypeMem
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchTypeMem]
rets'' IfSort
ifsort
      ExpT tolore -> AllocM fromlore tolore (ExpT tolore)
forall (m :: * -> *) a. Monad m => a -> m a
return ExpT tolore
res_if_expr
  where
    -- Try to antiunify the (memory space, index function) pair of a
    -- then-branch result with the corresponding else-branch result.
    --
    -- The space in the result is always the then-branch's space.  A
    -- generalisation is only attempted when both sides carry a space and an
    -- index function and the two spaces are equal.  On success, it yields the
    -- generalised index function (from 'IxFun.leastGeneralGeneralization')
    -- together with the (then, else) expression pairs that were abstracted
    -- over.  In every other case, no generalisation is returned.
    generalize ::
      (Maybe Space, Maybe IxFun) ->
      (Maybe Space, Maybe IxFun) ->
      (Maybe Space, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))
    generalize (Just sp1, Just ixf1) (Just sp2, Just ixf2)
      | sp1 /= sp2 = (Just sp1, Nothing)
      | otherwise =
          (Just sp1,) $
            retype
              <$> IxFun.leastGeneralGeneralization (untyped <$> ixf1) (untyped <$> ixf2)
      where
        -- Re-attach the Int64 phantom type that 'untyped' stripped off.
        retype (ixf, pairs) =
          (TPrimExp <$> ixf, [(TPrimExp a, TPrimExp b) | (a, b) <- pairs])
    generalize (mbsp1, _) _ = (mbsp1, Nothing)

    -- Given an optional generalised index function paired with a list of
    -- (then, else) substitution pairs, keep the index function and project
    -- each pair to a single element with the supplied selector (e.g. 'fst'
    -- for the then-branch, 'snd' for the else-branch).  'Nothing' is passed
    -- through unchanged.
    selectSub ::
      ((a, a) -> a) ->
      Maybe (ExtIxFun, [(a, a)]) ->
      Maybe (ExtIxFun, [a])
    selectSub pick = fmap $ \(ixfn, pairs) -> (ixfn, map pick pairs)
    -- Allocate inside one branch body of the if.  The body's result list is
    -- kept as-is.  Alongside the rewritten body, this returns the index
    -- function (as reported by 'subExpIxFun') of each of the last @num_vals@
    -- results, which 'splitFromEnd' selects.
    allocInIfBody ::
      (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
      Int ->
      Body fromlore ->
      AllocM fromlore tolore (Body tolore, [Maybe IxFun])
    allocInIfBody num_vals (Body _ stms res) =
      buildBody $ allocInStms stms $ (res,) <$> traverse subExpIxFun trailing
      where
        trailing = snd $ splitFromEnd num_vals res
-- WithAcc: attach memory information to the parameters of the body lambda
-- and of each input's optional combining operator.  Then run allocation
-- inside those lambdas via 'allocInLambda'.
allocInExp (WithAcc inputs bodylam) =
  WithAcc <$> mapM onInput inputs <*> onLambda bodylam
  where
    -- Body lambda: only unit and accumulator parameters are accepted.  They
    -- are annotated directly as MemPrim/MemAcc; any other parameter type is
    -- an internal compiler error.
    onLambda lam = do
      params <- forM (lambdaParams lam) $ \(Param pv t) ->
        case t of
          Prim Unit -> pure $ Param pv $ MemPrim Unit
          Acc acc ispace ts u -> pure $ Param pv $ MemAcc acc ispace ts u
          _ -> error $ "Unexpected WithAcc lambda param: " ++ pretty (Param pv t)
      allocInLambda params (lambdaBody lam)

    -- An input keeps its shape and arrays.  Only the optional
    -- (operator, neutral elements) pair is transformed.
    onInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (onOp shape arrs) op

    -- The operator's parameters are split (via 'splitAt3') into:
    -- one index parameter per dimension of the accumulator shape, then one
    -- x parameter per operator result, then the remaining y parameters.
    onOp accshape arrs (lam, nes) = do
      let num_vs = length (lambdaReturnType lam)
          num_is = shapeRank accshape
          (i_params, x_params, y_params) =
            splitAt3 num_is num_vs $ lambdaParams lam
          -- Index parameters become int64 scalars ...
          i_params' = map ((`Param` MemPrim int64) . paramName) i_params
          -- ... and are used to fix the leading dimensions of the
          -- x/y array parameters' index functions.
          is = map (DimFix . Var . paramName) i_params'
      x_params' <- zipWithM (onXParam is) x_params arrs
      y_params' <- zipWithM (onYParam is) y_params arrs
      lam' <-
        allocInLambda
          (i_params' <> x_params' <> y_params')
          (lambdaBody lam)
      pure (lam', nes)

    -- Build an array parameter located in @mem@.  Its index function is
    -- @ixfun@ sliced by the fixed indices @is@, followed by a 'sliceDim'
    -- for every dimension of the parameter's own @shape@.
    mkP p pt shape u mem ixfun is =
      Param p . MemArray pt shape u . ArrayIn mem . IxFun.slice ixfun $
        fmap (fmap pe64) $ is ++ map sliceDim (shapeDims shape)

    -- Shared failure for operator parameters that are neither scalars nor
    -- arrays.  The message names WithAcc, the construct actually handled
    -- here (it previously referred to "MkAcc").
    unsupportedOpParam p =
      error $ "Cannot handle WithAcc operator param: " ++ pretty p

    -- x parameters: scalars stay scalars; arrays are placed in the existing
    -- memory block of the corresponding accumulator array (found with
    -- 'lookupArraySummary'), using that array's index function.
    onXParam _ (Param p (Prim t)) _ =
      pure $ Param p $ MemPrim t
    onXParam is (Param p (Array pt shape u)) arr = do
      (mem, ixfun) <- lookupArraySummary arr
      pure $ mkP p pt shape u mem ixfun is
    onXParam _ p _ =
      unsupportedOpParam p

    -- y parameters: scalars stay scalars.  Arrays get a fresh allocation in
    -- DefaultSpace, made by 'allocForArray' at the type of the whole
    -- accumulator array, with an 'IxFun.iota' index function over that
    -- array's dimensions.
    onYParam _ (Param p (Prim t)) _ =
      pure $ Param p $ MemPrim t
    onYParam is (Param p (Array pt shape u)) arr = do
      arr_t <- lookupType arr
      mem <- allocForArray arr_t DefaultSpace
      let base_dims = map pe64 $ arrayDims arr_t
          ixfun = IxFun.iota base_dims
      pure $ mkP p pt shape u mem ixfun is
    onYParam _ p _ =
      unsupportedOpParam p
allocInExp ExpT fromlore
e = Mapper fromlore tolore (AllocM fromlore tolore)
-> ExpT fromlore -> AllocM fromlore tolore (ExpT tolore)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
Mapper flore tlore m -> Exp flore -> m (Exp tlore)
mapExpM Mapper fromlore tolore (AllocM fromlore tolore)
alloc ExpT fromlore
e
  where
    alloc :: Mapper fromlore tolore (AllocM fromlore tolore)
alloc =
      Mapper Any Any (AllocM fromlore tolore)
forall (m :: * -> *) lore. Monad m => Mapper lore lore m
identityMapper
        { mapOnBody :: Scope tolore
-> BodyT fromlore -> AllocM fromlore tolore (BodyT tolore)
mapOnBody = String
-> Scope tolore
-> BodyT fromlore
-> AllocM fromlore tolore (BodyT tolore)
forall a. HasCallStack => String -> a
error String
"Unhandled Body in ExplicitAllocations",
          mapOnRetType :: RetType fromlore -> AllocM fromlore tolore (RetType tolore)
mapOnRetType = String
-> RetType fromlore -> AllocM fromlore tolore (RetType tolore)
forall a. HasCallStack => String -> a
error String
"Unhandled RetType in ExplicitAllocations",
          mapOnBranchType :: BranchType fromlore -> AllocM fromlore tolore (BranchType tolore)
mapOnBranchType = String
-> BranchType fromlore
-> AllocM fromlore tolore (BranchType tolore)
forall a. HasCallStack => String -> a
error String
"Unhandled BranchType in ExplicitAllocations",
          mapOnFParam :: FParam fromlore -> AllocM fromlore tolore (FParam tolore)
mapOnFParam = String -> FParam fromlore -> AllocM fromlore tolore (FParam tolore)
forall a. HasCallStack => String -> a
error String
"Unhandled FParam in ExplicitAllocations",
          mapOnLParam :: LParam fromlore -> AllocM fromlore tolore (LParam tolore)
mapOnLParam = String -> LParam fromlore -> AllocM fromlore tolore (LParam tolore)
forall a. HasCallStack => String -> a
error String
"Unhandled LParam in ExplicitAllocations",
          mapOnOp :: Op fromlore -> AllocM fromlore tolore (Op tolore)
mapOnOp = \Op fromlore
op -> do
            Op fromlore -> AllocM fromlore tolore (Op tolore)
handle <- (AllocEnv fromlore tolore
 -> Op fromlore -> AllocM fromlore tolore (Op tolore))
-> AllocM
     fromlore tolore (Op fromlore -> AllocM fromlore tolore (Op tolore))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AllocEnv fromlore tolore
-> Op fromlore -> AllocM fromlore tolore (Op tolore)
forall fromlore tolore.
AllocEnv fromlore tolore
-> Op fromlore -> AllocM fromlore tolore (Op tolore)
allocInOp
            Op fromlore -> AllocM fromlore tolore (Op tolore)
handle Op fromlore
op
        }

-- | The index function of the memory binding of a subexpression, if any.
--
-- Constants never have one.  A variable yields 'Just' its index function
-- when its memory information is a 'MemArray' bound with 'ArrayIn', and
-- 'Nothing' for any other kind of memory information.
subExpIxFun ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  SubExp ->
  AllocM fromlore tolore (Maybe IxFun)
subExpIxFun (Constant _) = pure Nothing
subExpIxFun (Var name) = do
  summary <- lookupMemInfo name
  case summary of
    MemArray _ _ _ (ArrayIn _ ixfun) -> pure (Just ixfun)
    _ -> pure Nothing

-- | Add memory context to the result of one branch of an @if@.
--
-- The body's result is split (with 'splitFromEnd') into context results
-- followed by one value result per element of @ifrets@.  For each value
-- result, @spaces@ gives an optional memory space, and @substs@ an
-- optional generalised index function together with the expressions that
-- instantiate its existentials.
--
-- * Without a substitution, the result is passed through 'ensureDirect'
--   and gets a return type built by @inspect@.  For arrays that type uses
--   a fresh 'ReturnsNewBlock' with an @iota@ index function.
-- * With a substitution, the existential expressions are bound to new
--   subexpressions.  The index function's existentials are shifted past
--   those already allocated, and the result gets a 'ReturnsNewBlock'
--   with that index function.
--
-- The new body's result is ordered as
--
-- > original context ++ index-function existentials ++ memory context ++ values
--
-- The returned list holds the branch return types of the value results.
addResCtxInIfBody ::
  (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
  [ExtType] ->
  Body tolore ->
  [Maybe Space] ->
  [Maybe (ExtIxFun, [TPrimExp Int64 VName])] ->
  AllocM fromlore tolore (Body tolore, [BodyReturns])
addResCtxInIfBody ifrets (Body _ bnds res) spaces substs = do
  let num_vals = length ifrets
      (ctx_res, val_res) = splitFromEnd num_vals res
  ((res', bodyrets'), all_body_stms) <- collectStms $ do
    mapM_ addStm bnds
    (val_res', ext_ses_res, mem_ctx_res, bodyrets, total_existentials) <-
      foldM helper ([], [], [], [], length ctx_res) (zip4 ifrets val_res substs spaces)
    return
      ( ctx_res <> ext_ses_res <> mem_ctx_res <> val_res',
        -- We need to adjust the ReturnsNewBlock existentials, because they
        -- should always be numbered _after_ all other existentials in the
        -- return values.  A strict left fold avoids accumulating a chain
        -- of thunks in the (list, counter) accumulator.
        reverse $ fst $ foldl' adjustNewBlockExistential ([], total_existentials) bodyrets
      )
  body' <- mkBodyM all_body_stms res'
  return (body', bodyrets')
  where
    -- Process one value result.  The accumulator holds, in order: the
    -- value results, the subexpressions bound for index-function
    -- existentials, the memory context from 'bodyReturnMemCtx', the
    -- branch return types, and the number of existentials used so far.
    helper (res_acc, ext_acc, ctx_acc, br_acc, k) (ifr, r, mbixfsub, sp) =
      case mbixfsub of
        Nothing -> do
          -- does NOT generalize/antiunify; ensure direct
          r' <- ensureDirect sp r
          mem_ctx_r <- bodyReturnMemCtx r'
          let body_ret = inspect ifr sp
          return
            ( res_acc ++ [r'],
              ext_acc,
              ctx_acc ++ mem_ctx_r,
              br_acc ++ [body_ret],
              k
            )
        Just (ixfn, m) -> do
          -- generalizes
          let i = length m
          ext_ses <- mapM (toSubExp "ixfn_exist") m
          mem_ctx_r <- bodyReturnMemCtx r
          let sp' = fromMaybe DefaultSpace sp
              -- Shift the index function's existentials past the k
              -- existentials that precede this result.
              ixfn' = fmap (adjustExtPE k) ixfn
              exttp = case ifr of
                Array pt shp' u ->
                  MemArray pt shp' u $
                    ReturnsNewBlock sp' 0 ixfn'
                _ -> error "Impossible case reached in addResCtxInIfBody"
          return
            ( res_acc ++ [r],
              ext_acc ++ ext_ses,
              ctx_acc ++ mem_ctx_r,
              br_acc ++ [exttp],
              k + i
            )

    -- Number each ReturnsNewBlock consecutively, starting from the
    -- counter.  The result list is built in reverse.
    adjustNewBlockExistential :: ([BodyReturns], Int) -> BodyReturns -> ([BodyReturns], Int)
    adjustNewBlockExistential (acc, k) (MemArray pt shp u (ReturnsNewBlock space _ ixfun)) =
      (MemArray pt shp u (ReturnsNewBlock space k ixfun) : acc, k + 1)
    adjustNewBlockExistential (acc, k) x = (x : acc, k)

    -- Return type for a result that is not generalised.  Arrays get a new
    -- block (in the given space, or DefaultSpace) with an iota index
    -- function over their shape.
    inspect (Array pt shape u) space =
      let space' = fromMaybe DefaultSpace space
          bodyret =
            MemArray pt shape u $
              ReturnsNewBlock space' 0 $
                IxFun.iota $ map convert $ shapeDims shape
       in bodyret
    inspect (Acc acc ispace ts u) _ = MemAcc acc ispace ts u
    inspect (Prim pt) _ = MemPrim pt
    inspect (Mem space) _ = MemMem space

    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

    adjustExtV :: Int -> Ext VName -> Ext VName
    adjustExtV _ (Free v) = Free v
    adjustExtV k (Ext i) = Ext (k + i)

    adjustExtPE :: Int -> TPrimExp t (Ext VName) -> TPrimExp t (Ext VName)
    adjustExtPE k = fmap (adjustExtV k)

-- | For each of the last @num_vals@ results of a body, the memory space of
-- the memory block holding that result, if any.
--
-- Lookups are done in a scope extended with the body's own statements.
-- A result yields 'Nothing' in three cases: it is a constant, it is not an
-- array bound with 'ArrayIn', or its memory block's information is not
-- 'MemMem'.
mkSpaceOks ::
  (Mem tolore, LocalScope tolore m) =>
  Int ->
  Body tolore ->
  m [Maybe Space]
mkSpaceOks num_vals (Body _ stms res) =
  inScopeOf stms $ mapM resultSpace (takeLast num_vals res)
  where
    resultSpace (Constant _) = return Nothing
    resultSpace (Var v) = do
      v_info <- lookupMemInfo v
      case v_info of
        MemArray _ _ _ (ArrayIn mem _) -> memBlockSpace mem
        _ -> return Nothing

    -- The space recorded for a memory block, if it is known to be one.
    memBlockSpace mem = do
      mem_info <- lookupMemInfo mem
      case mem_info of
        MemMem space -> return (Just space)
        _ -> return Nothing

-- | Translate a loop form into the target representation.
--
-- A @while@ loop is returned unchanged.  In a @for@ loop, each loop
-- variable @p@ ranges over an array @arr@ and gets a memory annotation:
--
-- * array-typed parameters reuse @arr@'s memory block, with @arr@'s index
--   function sliced at the loop index;
-- * primitive, memory and accumulator parameters are mapped to
--   'MemPrim', 'MemMem' and 'MemAcc' respectively.
allocInLoopForm ::
  ( Allocable fromlore tolore,
    Allocator tolore (AllocM fromlore tolore)
  ) =>
  LoopForm fromlore ->
  AllocM fromlore tolore (LoopForm tolore)
allocInLoopForm (WhileLoop cond) = pure $ WhileLoop cond
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> mapM annotate loopvars
  where
    annotate (p, arr) = do
      (mem, ixfun) <- lookupArraySummary arr
      dec <- case paramType p of
        Array pt shape u -> do
          arr_dims <- map pe64 . arrayDims <$> lookupType arr
          -- Fix the outermost dimension at the current loop index.
          let row_ixfun = IxFun.slice ixfun $ fullSliceNum arr_dims [DimFix $ le64 i]
          pure $ MemArray pt shape u $ ArrayIn mem row_ixfun
        Prim bt -> pure $ MemPrim bt
        Mem space -> pure $ MemMem space
        Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u
      pure (p {paramDec = dec}, arr)

-- | Operations that can report information about what they bind, for
-- use by the allocator.
class SizeSubst op where
  -- | A 'ChunkMap' (a map from names to subexpressions) contributed by
  -- binding the operation to the given pattern.
  opSizeSubst :: PatternT dec -> op -> ChunkMap

  -- | Whether the operation is considered constant.  The default says
  -- it is not.
  opIsConst :: op -> Bool
  opIsConst _ = False

-- | The trivial operation type provides no substitutions and is never
-- constant (the class default for 'opIsConst').
instance SizeSubst () where
  opSizeSubst _ _ = mempty

-- | Memory operations: 'Inner' operations delegate to the wrapped
-- representation's instance.  Every other 'MemOp' constructor (such as
-- 'Alloc') provides no substitutions and is never constant.
instance SizeSubst op => SizeSubst (MemOp op) where
  opSizeSubst pat (Inner op) = opSizeSubst pat op
  opSizeSubst _ _ = mempty

  opIsConst (Inner op) = opIsConst op
  opIsConst _ = False

-- | The substitutions made available by a statement.  Only statements
-- whose expression is an 'Op' can contribute any (via 'opSizeSubst');
-- every other statement yields an empty map.
sizeSubst :: SizeSubst (Op lore) => Stm lore -> ChunkMap
sizeSubst (Let pat _ (Op op)) = opSizeSubst pat op
sizeSubst _ = mempty

-- | All names bound by a statement whose 'Op' is considered constant
-- according to 'opIsConst'.  Any other statement, including an 'Op'
-- statement whose operation is not constant, contributes no names.
stmConsts :: SizeSubst (Op lore) => Stm lore -> S.Set VName
stmConsts (Let pat _ (Op op))
  | opIsConst op = S.fromList $ patternNames pat
stmConsts _ = mempty

-- | Build a let-statement binding @names@ to the expression @e@, with
-- the given expression decoration.
--
-- The pattern (including its memory annotations) is computed with
-- 'bindPatternWithAllocations' in the current scope.  That also adds
-- any allocation statements the pattern needs to the current binder.
-- The returned statement itself is /not/ added; that is left to the
-- caller.
mkLetNamesB' ::
  ( Op (Lore m) ~ MemOp inner,
    MonadBinder m,
    ExpDec (Lore m) ~ (),
    Allocator (Lore m) (PatAllocM (Lore m))
  ) =>
  ExpDec (Lore m) ->
  [VName] ->
  Exp (Lore m) ->
  m (Stm (Lore m))
mkLetNamesB' dec names e = do
  scope <- askScope
  pat <- bindPatternWithAllocations scope names e
  pure $ Let pat (defAux dec) e

-- | Like 'mkLetNamesB'', but for a representation carrying simplifier
-- wisdom ('Engine.Wise').
--
-- Wisdom is stripped from the scope and the expression so that the
-- pattern and its allocations can be computed for the underlying
-- representation.  This uses 'bindPatternWithAllocations', which also
-- adds the required allocation statements to the current binder.
-- Wisdom is then added back to the pattern and expression decoration.
-- The returned statement itself is not added to the binder.
mkLetNamesB'' ::
  ( Op (Lore m) ~ MemOp inner,
    ExpDec lore ~ (),
    HasScope (Engine.Wise lore) m,
    Allocator lore (PatAllocM lore),
    MonadBinder m,
    Engine.CanBeWise (Op lore)
  ) =>
  [VName] ->
  Exp (Engine.Wise lore) ->
  m (Stm (Engine.Wise lore))
mkLetNamesB'' names e = do
  scope <- Engine.removeScopeWisdom <$> askScope
  pat <- bindPatternWithAllocations scope names $ Engine.removeExpWisdom e
  let pat' = Engine.addWisdomToPattern pat e
      dec = Engine.mkWiseExpDec pat' () e
  pure $ Let pat' (defAux dec) e

-- | Build the simplifier operations ('SimpleOps') for a memory-annotated
-- representation.  Its expression and body decorations are @()@ and its
-- operations are 'MemOp's.
--
-- 'Alloc' operations are handled here directly.  The two arguments
-- say how to handle representation-specific 'Inner' operations:
--
-- * the first computes the usage table of an inner operation;
--
-- * the second simplifies an inner operation, returning the simplified
--   operation together with statements to be hoisted out of it.
simplifiable ::
  ( Engine.SimplifiableLore lore,
    ExpDec lore ~ (),
    BodyDec lore ~ (),
    Op lore ~ MemOp inner,
    Allocator lore (PatAllocM lore)
  ) =>
  (Engine.OpWithWisdom inner -> UT.UsageTable) ->
  (inner -> Engine.SimpleM lore (Engine.OpWithWisdom inner, Stms (Engine.Wise lore))) ->
  SimpleOps lore
simplifiable innerUsage simplifyInnerOp =
  SimpleOps mkExpDecS' mkBodyS' protectOp opUsage simplifyOp
  where
    -- Expression and body decorations are unit, so only wisdom is added.
    mkExpDecS' _ pat e =
      pure $ Engine.mkWiseExpDec pat () e

    mkBodyS' _ stms res = pure $ mkWiseBody () stms res

    -- Guard an allocation by the condition @taken@.  The size is replaced
    -- by @if taken then size else 0@ (bound as "hoisted_alloc_size"), so
    -- the allocation only requests a non-zero size when @taken@ holds.
    -- No other operation can be protected.
    protectOp taken pat (Alloc size space) = Just $ do
      tbody <- resultBodyM [size]
      fbody <- resultBodyM [intConst Int64 0]
      size' <-
        letSubExp "hoisted_alloc_size" $
          If taken tbody fbody $ IfDec [MemPrim int64] IfFallback
      letBind pat $ Op $ Alloc size' space
    protectOp _ _ _ = Nothing

    -- A variable allocation size is recorded as a size usage; constant
    -- sizes contribute nothing.  Inner operations use the supplied
    -- function.
    opUsage (Alloc (Var size) _) = UT.sizeUsage size
    opUsage (Alloc _ _) = mempty
    opUsage (Inner inner) = innerUsage inner

    -- Allocations only have their size simplified and never hoist any
    -- statements; inner operations use the supplied simplifier.
    simplifyOp (Alloc size space) =
      (\size' -> (Alloc size' space, mempty)) <$> Engine.simplify size
    simplifyOp (Inner k) = do
      (k', hoisted) <- simplifyInnerOp k
      pure (Inner k', hoisted)

-- | Compute a memory-annotated pattern binding @names@ to the result of
-- @e@.
--
-- 'patternWithAllocations' is run in 'PatAllocM' under the given scope.
-- Every allocation statement it produces is added to the current binder
-- with 'bindAllocStm', in order, and the pattern is returned.
bindPatternWithAllocations ::
  ( MonadBinder m,
    ExpDec lore ~ (),
    Op (Lore m) ~ MemOp inner,
    Allocator lore (PatAllocM lore)
  ) =>
  Scope lore ->
  [VName] ->
  Exp lore ->
  m (Pattern lore)
bindPatternWithAllocations types names e = do
  (pat, prestms) <- runPatAllocM (patternWithAllocations names e) types
  mapM_ bindAllocStm prestms
  pure pat

-- | A hint attached to one result of an expression.  'defaultExpHints'
-- produces one 'NoHint' per result.
data ExpHint
  = -- | No hint for this result.
    NoHint
  | -- | A hint carrying an 'IxFun' and a 'Space'.  These are presumably
    -- the suggested index function and memory space for the result.
    -- NOTE(review): confirm against the allocators that consume hints.
    Hint IxFun Space

-- | The default hints for an expression: 'NoHint' for each of its
-- results, as counted by 'expExtTypeSize'.
defaultExpHints :: (Monad m, ASTLore lore) => Exp lore -> m [ExpHint]
defaultExpHints e = pure $ replicate (expExtTypeSize e) NoHint