{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program.  Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    askDefaultSpace,
    Allocable,
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    mkLetNamesB',
    mkLetNamesB'',

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Either (partitionEithers)
import Data.List (foldl', transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.SymbolTable (IndexOp)
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Mem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (maybeNth, splitAt3)

-- | The bundle of constraints required by the generic allocation
-- machinery in this module.  @fromrep@ is the representation being
-- transformed, @torep@ is the memory-annotated representation being
-- produced, and @inner@ is the operation type that 'Mem' associates
-- with @torep@.
--
-- Besides pretty-printing and builder support, it fixes the
-- decoration types on both sides: the source uses plain types for
-- parameters, branches and returns, while the target uses 'LetDecMem'
-- for let-bound names and unit for body and expression decorations.
type Allocable fromrep torep inner =
  ( PrettyRep fromrep,
    PrettyRep torep,
    Mem torep inner,
    LetDec torep ~ LetDecMem,
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    SizeSubst (inner torep),
    BuilderOps torep
  )

-- | The read-only environment carried by 'AllocM'.
data AllocEnv fromrep torep = AllocEnv
  { -- | When allocating memory, put it in this memory space.
    -- This is primarily used to ensure that group-wide
    -- statements store their results in shared memory.
    allocSpace :: Space,
    -- | The set of names that are known to be constants at
    -- kernel compile time.
    envConsts :: S.Set VName,
    -- | Translate an operation of the source representation into an
    -- operation of the target representation.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | Produce the list of 'ExpHint's for an expression of the
    -- target representation.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }

-- | Monad for adding allocations to an entire program.
--
-- This is a 'BuilderT' for the target representation, layered over a
-- reader of the 'AllocEnv' and a 'State' holding the 'VNameSource'.
newtype AllocM fromrep torep a
  = AllocM (BuilderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )

instance (Allocable fromrep torep inner) => MonadBuilder (AllocM fromrep torep) where
  type Rep (AllocM fromrep torep) = torep

  -- 'Allocable' fixes the expression decoration of @torep@ to unit.
  mkExpDecM _ _ = pure ()

  -- Build the pattern with 'patWithAllocations', using the default
  -- space and whatever hints the environment gives for the expression.
  mkLetNamesM names e = do
    def_space <- askDefaultSpace
    hints <- expHints e
    pat <- patWithAllocations def_space names e hints
    pure $ Let pat (defAux ()) e

  -- 'Allocable' fixes the body decoration of @torep@ to unit.
  mkBodyM stms res = pure $ Body () stms res

  -- The remaining methods just delegate to the wrapped 'BuilderT'.
  addStms = AllocM . addStms
  collectStms (AllocM m) = AllocM $ collectStms m

-- | Compute the hints for an expression by applying the
-- 'envExpHints' function stored in the environment.
expHints :: Exp torep -> AllocM fromrep torep [ExpHint]
expHints e = do
  f <- asks envExpHints
  f e

-- | The space in which we allocate memory if we have no other
-- preferences or constraints.  This is the 'allocSpace' field of the
-- environment.
askDefaultSpace :: AllocM fromrep torep Space
askDefaultSpace = asks allocSpace

-- | Run an 'AllocM' action inside any 'MonadFreshNames' monad.
--
-- The arguments are the default allocation space, the handler for
-- operations of the source representation, and the function producing
-- expression hints.  The action starts with an empty scope and an
-- empty set of known constants.  Only the result of the action is
-- returned; the statements collected by the underlying builder are
-- dropped.
runAllocM ::
  (MonadFreshNames m) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM space handleOp hints (AllocM m) =
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBuilderT m mempty) env
  where
    env =
      AllocEnv
        { allocSpace = space,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }

-- | The size in bytes of the element type of the given type, as any
-- numeric type.
elemSize :: (Num a) => Type -> a
elemSize = primByteSize . elemType

-- | An expression for the size in bytes of a value of the given
-- type: the element size multiplied by every dimension, computed as
-- a 64-bit integer expression.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  untyped $ foldl' (*) (elemSize t) $ map pe64 (arrayDims t)

-- | Convert 'arraySizeInBytesExp' of the given type to an expression
-- and bind it with 'letSubExp', using @"bytes"@ as the name hint.
arraySizeInBytes :: (MonadBuilder m) => Type -> m SubExp
arraySizeInBytes = letSubExp "bytes" <=< toExp . arraySizeInBytesExp

-- | Emit an 'Alloc' operation in the given space, sized by
-- 'arraySizeInBytes' of the given type, and return the name bound to
-- it (name hint @"mem"@).  Works in any 'MonadBuilder' whose
-- operations are 'MemOp's.
allocForArray' ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Type ->
  Space ->
  m VName
allocForArray' t space = do
  size <- arraySizeInBytes t
  letExp "mem" $ Op $ Alloc size space

-- | Allocate memory for a value of the given type.  This is
-- 'allocForArray'' specialised to 'AllocM'.
allocForArray ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  AllocM fromrep torep VName
allocForArray t space = allocForArray' t space

-- | Repair an expression that cannot be assigned an index function.
-- There is a simple remedy for this: normalise the input arrays and
-- try again.
--
-- Only 'Reshape' is handled: the reshaped array is passed through
-- 'ensureDirectArray' (in the space of its current memory block) and
-- the reshape is rebuilt on the resulting array.  Any other
-- expression is a fatal error.
repairExpression ::
  (Allocable fromrep torep inner) =>
  Exp torep ->
  AllocM fromrep torep (Exp torep)
repairExpression (BasicOp (Reshape k shape v)) = do
  v_mem <- fst <$> lookupArraySummary v
  space <- lookupMemSpace v_mem
  v' <- snd <$> ensureDirectArray (Just space) v
  pure $ BasicOp $ Reshape k shape v'
repairExpression e =
  error $ "repairExpression:\n" <> prettyString e

-- | Like 'expReturns', but if the expression has no valid return
-- information, it is passed through 'repairExpression' and tried once
-- more.  Returns the return information together with the expression
-- it belongs to (the original, or the repaired one).  It is a fatal
-- error if the repaired expression still has no return information.
expReturns' ::
  (Allocable fromrep torep inner) =>
  Exp torep ->
  AllocM fromrep torep ([ExpReturns], Exp torep)
expReturns' e = do
  maybe_rts <- expReturns e
  case maybe_rts of
    Just rts -> pure (rts, e)
    Nothing -> do
      e' <- repairExpression e
      let bad =
            error . unlines $
              [ "expReturns': impossible index transformation",
                prettyString e,
                prettyString e'
              ]
      rts <- fromMaybe bad <$> expReturns e'
      pure (rts, e')

-- | Construct a statement binding the given identifiers to the given
-- expression.  The pattern elements come from 'allocsForPat', using
-- the default space, the hints for the expression, and the return
-- information from 'expReturns''.  The bound expression is the one
-- returned by 'expReturns'', which may be a repaired version of the
-- input.
allocsForStm ::
  (Allocable fromrep torep inner) =>
  [Ident] ->
  Exp torep ->
  AllocM fromrep torep (Stm torep)
allocsForStm idents e = do
  def_space <- askDefaultSpace
  hints <- expHints e
  (rts, e') <- expReturns' e
  pes <- allocsForPat def_space idents rts hints
  dec <- mkExpDecM (Pat pes) e'
  pure $ Let (Pat pes) (defAux dec) e'

-- | Construct a memory-annotated pattern binding the given names to
-- the given expression.  The types of the names are obtained by
-- instantiating the expression's existential type with the names
-- themselves; the pattern elements are then produced by
-- 'allocsForPat'.  It is a fatal error if 'expReturns' yields nothing
-- for the expression.
patWithAllocations ::
  (MonadBuilder m, Mem (Rep m) inner) =>
  Space ->
  [VName] ->
  Exp (Rep m) ->
  [ExpHint] ->
  m (Pat LetDecMem)
patWithAllocations def_space names e hints = do
  ts' <- instantiateShapes' names <$> expExtType e
  let idents = zipWith Ident names ts'
  rts <- fromMaybe (error "patWithAllocations: ill-typed") <$> expReturns e
  Pat <$> allocsForPat def_space idents rts hints

-- | Produce one identifier per return value.  The given identifiers
-- are matched against the /last/ return values; every return value
-- before those gets a fresh identifier: a memory block named
-- @"ext_mem"@ if the return is a 'MemMem', and otherwise an @int64@
-- named @"ext"@.
mkMissingIdents :: (MonadFreshNames m) => [Ident] -> [ExpReturns] -> m [Ident]
mkMissingIdents idents rts =
  -- Both lists are reversed so that zipping aligns them from the end;
  -- the known identifiers are then padded with 'Nothing'.
  reverse <$> zipWithM f (reverse rts) (map Just (reverse idents) ++ repeat Nothing)
  where
    f _ (Just ident) = pure ident
    f (MemMem space) Nothing = newIdent "ext_mem" $ Mem space
    f _ Nothing = newIdent "ext" $ Prim int64

-- | Construct memory-annotated pattern elements for the results of an
-- expression.
--
-- @some_idents@ are the identifiers already known for the pattern;
-- 'mkMissingIdents' adds fresh ones so that there is one identifier
-- per entry of @rts@.  Each identifier is then combined with its
-- return and its hint (via 'zip3') and annotated depending on the
-- return:
--
-- * @MemPrim@: the annotation comes from 'summaryForBindage' applied
--   to the identifier's type and the hint.
--
-- * @MemMem space@: annotated as @MemMem space@.
--
-- * @MemArray@ returned in an existing block (@ReturnsInBlock@): the
--   array is placed in that block, with the existential LMAD
--   instantiated from the identifiers.
--
-- * @MemArray@ without a memory return and whose shape contains only
--   'Free' dimensions: the annotation comes from 'summaryForBindage'.
--
-- * @MemArray@ returned in a new block (@ReturnsNewBlock _ i _@): the
--   memory block is the identifier at position @i@.
--
-- * @MemAcc@: carried over unchanged.
--
-- Any other return is a fatal error.
allocsForPat ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m [PatElem LetDecMem]
allocsForPat def_space some_idents rts hints = do
  idents <- mkMissingIdents some_idents rts

  forM (zip3 idents rts hints) $ \(ident, rt, hint) -> do
    -- The shape recorded for array results is the one of the
    -- identifier, not the (possibly existential) one of the return.
    let ident_shape = arrayShape $ identType ident
    case rt of
      MemPrim _ -> do
        summary <- summaryForBindage def_space (identType ident) hint
        pure $ PatElem (identName ident) summary
      MemMem space ->
        pure $ PatElem (identName ident) $ MemMem space
      MemArray bt _ u (Just (ReturnsInBlock mem extlmad)) -> do
        let ixfn = instantiateExtLMAD idents extlmad
        pure . PatElem (identName ident) . MemArray bt ident_shape u $ ArrayIn mem ixfn
      MemArray _ extshape _ Nothing
        | Just _ <- knownShape extshape -> do
            summary <- summaryForBindage def_space (identType ident) hint
            pure $ PatElem (identName ident) summary
      MemArray bt _ u (Just (ReturnsNewBlock _ i extixfn)) -> do
        let ixfn = instantiateExtLMAD idents extixfn
        pure . PatElem (identName ident) . MemArray bt ident_shape u $
          ArrayIn (getIdent idents i) ixfn
      MemAcc acc ispace ts u ->
        pure $ PatElem (identName ident) $ MemAcc acc ispace ts u
      _ -> error "Impossible case reached in allocsForPat!"
  where
    -- The dimensions of a shape, provided none of them is existential.
    knownShape = mapM known . shapeDims
    known (Free v) = Just v
    known Ext {} = Nothing

    -- Name of the identifier at position @i@; it is a fatal error for
    -- @i@ not to denote an element of @idents@.
    getIdent idents i =
      case maybeNth i idents of
        Just ident -> identName ident
        Nothing ->
          error $ "getIdent: Ext " <> show i <> " but pattern has " <> show (length idents) <> " elements: " <> prettyString idents

    -- Replace every existential reference inside an LMAD with the
    -- name of the identifier it points to.
    instantiateExtLMAD idents = fmap $ fmap inst
      where
        inst (Free v) = v
        inst (Ext i) = getIdent idents i

-- | Turn an 'ExtLMAD' into an LMAD by unwrapping every 'Free' leaf.
--
-- This is partial: an LMAD that mentions any 'Ext' leaf results in a
-- call to 'error' ("instantiateLMAD: not yet").  The monad is used
-- only through 'pure'.
instantiateLMAD :: (Monad m) => ExtLMAD -> m LMAD
instantiateLMAD = traverse $ traverse inst
  where
    -- Existential leaves are not supported here.
    inst Ext {} = error "instantiateLMAD: not yet"
    inst (Free x) = pure x

-- | Compute the memory annotation for a binding of the given type.
--
-- Primitive, memory and accumulator types are mapped directly to
-- @MemPrim@, @MemMem@ and @MemAcc@; for these neither the default
-- space nor the hint is consulted.
--
-- For an array type the result depends on the hint:
--
-- * 'NoHint': memory is obtained from @allocForArray'@ in the default
--   space, and the array is given the index function @LMAD.iota 0@
--   over its dimensions.
--
-- * @'Hint' lmad space@: a memory block of
--   @primByteSize pt * (1 + LMAD.range lmad)@ bytes is allocated in
--   @space@ (the default space is ignored), and the array is given
--   @lmad@ as its index function.
summaryForBindage ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage _ (Prim bt) _ =
  pure $ MemPrim bt
summaryForBindage _ (Mem space) _ =
  pure $ MemMem space
summaryForBindage _ (Acc acc ispace ts u) _ =
  pure $ MemAcc acc ispace ts u
summaryForBindage def_space t@(Array pt shape u) NoHint = do
  m <- allocForArray' t def_space
  pure $ MemArray pt shape u $ ArrayIn m $ LMAD.iota 0 $ map pe64 $ arrayDims t
summaryForBindage _ t@(Array pt _ _) (Hint lmad space) = do
  -- Size of the allocation in bytes, bound as a statement named "bytes".
  bytes <-
    letSubExp "bytes" <=< toExp . untyped $
      primByteSize pt * (1 + LMAD.range lmad)
  m <- letExp "mem" $ Op $ Alloc bytes space
  pure $ MemArray pt (arrayShape t) NoUniqueness $ ArrayIn m lmad

-- | Translate function parameters to their memory-annotated form and
-- run a continuation with them.
--
-- Every parameter (paired with a space) is handled by
-- 'allocInFParam'; the memory and context parameters it emits are
-- collected through the writer.  The continuation receives the memory
-- parameters first, then the context parameters, then the translated
-- value parameters, and is run with all of them added to the scope.
allocInFParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params m = do
  (valparams, (memparams, ctxparams)) <-
    runWriterT $ mapM (uncurry allocInFParam) params
  let params' = memparams <> ctxparams <> valparams
      summary = scopeOfFParams params'
  localScope summary $ m params'

-- | Translate a single function parameter to its memory-annotated
-- form.
--
-- For an array parameter a fresh memory parameter is created, named
-- after the array parameter with the suffix @_mem@ and residing in
-- @pspace@.  It is emitted in the first component of the writer
-- output (nothing is written to the second component here), and the
-- array parameter is annotated as living in that memory with the
-- index function @LMAD.iota 0@ over its shape.
--
-- Primitive, memory and accumulator parameters are only re-annotated;
-- @pspace@ is not used for them and nothing is emitted.
allocInFParam ::
  (Allocable fromrep torep inner) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace =
  case paramDeclType param of
    Array pt shape u -> do
      let memname = baseString (paramName param) <> "_mem"
          lmad = LMAD.iota 0 $ map pe64 $ shapeDims shape
      mem <- lift $ newVName memname
      -- The memory parameter inherits the attributes of the array
      -- parameter it belongs to.
      tell ([Param (paramAttrs param) mem $ MemMem pspace], [])
      pure param {paramDec = MemArray pt shape u $ ArrayIn mem lmad}
    Prim pt ->
      pure param {paramDec = MemPrim pt}
    Mem space ->
      pure param {paramDec = MemMem space}
    Acc acc ispace ts u ->
      pure param {paramDec = MemAcc acc ispace ts u}

-- | Make sure that the array @v@ resides in an acceptable memory
-- space.
--
-- If @space_ok@ is 'Nothing', or equal to the space of the memory
-- block that @v@ currently resides in, the result is that memory
-- block paired with @v@ itself.  Otherwise the result is whatever
-- @allocLinearArray@ produces for @v@ in the requested space, using
-- the base name of @v@ as the name hint.
--
-- NOTE(review): despite the name, the only property inspected here is
-- the memory space; the index function of @v@ is looked up but
-- ignored.
ensureRowMajorArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureRowMajorArray space_ok v = do
  (mem, _) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  -- NOTE(review): the else branch below is reached only when
  -- space_ok is a Just, so default_space never determines the result.
  let space = fromMaybe default_space space_ok
  if maybe True (== mem_space) space_ok
    then pure (mem, v)
    else allocLinearArray space (baseString v) v

-- | Make sure that the array denoted by a 'SubExp' resides in the
-- given space, returning the (possibly different) array to use
-- instead.
--
-- A constant can never be an array, so that case is a fatal error.
-- For a variable, 'ensureRowMajorArray' is asked for an array in
-- @space@.  The memory block of the resulting array is emitted in the
-- first component of the writer output.  The second component
-- receives one freshly bound @lmad_arg@ subexpression for each
-- element of @LMAD.existentialized@ applied to the resulting array's
-- index function.
ensureArrayIn ::
  (Allocable fromrep torep inner) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn _ (Constant v) =
  error $ "ensureArrayIn: " ++ prettyString v ++ " cannot be an array."
ensureArrayIn space (Var v) = do
  (mem', v') <- lift $ ensureRowMajorArray (Just space) v
  -- Look up the index function of the array actually returned, which
  -- need not be the one we started with.
  (_, lmad) <- lift $ lookupArraySummary v'
  ctx <- lift $ mapM (letSubExp "lmad_arg" <=< toExp) (LMAD.existentialized lmad)
  tell ([Var mem'], ctx)
  pure $ Var v'

-- | Translate the merge parameters of a loop (each paired with its
-- initial value) to the memory representation, then run the
-- continuation @m@ with the translated parameters in scope.
--
-- Every source parameter is processed by @allocInMergeParam@, which
-- may emit extra memory-block parameters and extra context
-- parameters through the writer.  The final parameter list is
-- ordered as memory parameters, then context parameters, then the
-- translated value parameters.
--
-- The continuation receives:
--
-- * the new merge list, i.e. the parameters above zipped with their
--   initial values (themselves produced by running the original
--   initial values through the result handler below); and
--
-- * a function turning a list of loop results (one per /original/
--   parameter, in order) into a pair of (memory arguments followed
--   by context arguments, adjusted value results).  It must be
--   applied to anything that is going to be bound to the value
--   parameters, such as the results of the loop body.
allocInLoopParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, SubExp)] ->
  ( [(FParam torep, SubExp)] ->
    ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
    AllocM fromrep torep a
  ) ->
  AllocM fromrep torep a
allocInLoopParams merge m = do
  -- The result triple holds, per original parameter: the translated
  -- parameter, its (possibly replaced) initial value, and a handler
  -- for results destined for that parameter.  The writer output
  -- holds the extra memory and context parameters.
  ((valparams, valargs, handle_loop_subexps), (mem_params, ctx_params)) <-
    runWriterT $ unzip3 <$> mapM allocInMergeParam merge
  let mergeparams' = mem_params <> ctx_params <> valparams
      summary = scopeOfFParams mergeparams'

      -- Apply each per-parameter handler to the corresponding
      -- result, collecting the memory and context arguments that
      -- the handlers emit along the way.
      mk_loop_res ses = do
        (ses', (memargs, ctxargs)) <-
          runWriterT $ zipWithM ($) handle_loop_subexps ses
        pure (memargs <> ctxargs, ses')

  -- The initial values go through the very same handlers as the
  -- loop results.
  (valctx_args, valargs') <- mk_loop_res valargs
  let merge' =
        zip
          (mem_params <> ctx_params <> valparams)
          (valctx_args <> valargs')
  localScope summary $ m merge' mk_loop_res
  where
    -- Names of the original merge parameters.
    param_names = namesFromList $ map (paramName . fst) merge

    -- Does any of the given names refer to a merge parameter?
    anyIsLoopParam names = names `namesIntersect` param_names

    -- Result handler for a parameter whose memory space and LMAD are
    -- fixed to those of its initial value.  Only a memory argument
    -- is emitted; there are no context arguments.
    scalarRes param_t v_mem_space v_lmad (Var res) = do
      -- Try really hard to avoid copying needlessly, but the result
      -- _must_ be in ScalarSpace and have the right index function.
      (res_mem, res_lmad) <- lift $ lookupArraySummary res
      res_mem_space <- lift $ lookupMemSpace res_mem
      (res_mem', res') <-
        if (res_mem_space, res_lmad) == (v_mem_space, v_lmad)
          then pure (res_mem, res)
          else lift $ arrayWithLMAD v_mem_space v_lmad (fromDecl param_t) res
      tell ([Var res_mem'], [])
      pure $ Var res'
    -- Anything that is not a variable is passed through untouched.
    scalarRes _ _ _ se = pure se

    -- Translate a single merge parameter together with its initial
    -- value.  The writer output is (memory parameters, context
    -- parameters); the result is the translated parameter, the
    -- initial value to use for it, and the handler for results that
    -- are to be bound to it.
    allocInMergeParam ::
      (Allocable fromrep torep inner) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam torep], [FParam torep])
        (AllocM fromrep torep)
        ( FParam torep,
          SubExp,
          SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
        )
    allocInMergeParam (mergeparam, Var v)
      | param_t@(Array pt shape u) <- paramDeclType mergeparam = do
          (v_mem, v_lmad) <- lift $ lookupArraySummary v
          v_mem_space <- lift $ lookupMemSpace v_mem

          -- Loop-invariant array parameters that are in scalar space
          -- are special - we do not wish to existentialise their index
          -- function at all (but the memory block is still existential).
          case v_mem_space of
            ScalarSpace {} ->
              if anyIsLoopParam (freeIn shape)
                then do
                  -- Arrays with loop-variant shape cannot be in scalar
                  -- space, so copy them elsewhere and try again.
                  space <- lift askDefaultSpace
                  (_, v') <- lift $ allocLinearArray space (baseString v) v
                  allocInMergeParam (mergeparam, Var v')
                else do
                  p <- newParam "mem_param" $ MemMem v_mem_space
                  tell ([p], [])

                  pure
                    ( mergeparam {paramDec = MemArray pt shape u $ ArrayIn (paramName p) v_lmad},
                      Var v,
                      scalarRes param_t v_mem_space v_lmad
                    )
            _ -> do
              (v_mem', v') <- lift $ ensureRowMajorArray Nothing v
              -- The parameter's LMAD is the existentialised form of
              -- an iota LMAD over the parameter shape.
              let lmad_ext =
                    LMAD.existentialize 0 $ LMAD.iota 0 $ map pe64 $ shapeDims shape

              v_mem_space' <- lift $ lookupMemSpace v_mem'

              -- One fresh int64 context parameter per existential
              -- component of the LMAD.
              ctx_params <-
                replicateM (length (LMAD.existentialized lmad_ext)) $
                  newParam "ctx_param_ext" (MemPrim int64)

              -- Replace existential number i with the i'th context
              -- parameter.
              param_lmad <-
                instantiateLMAD $
                  LMAD.substitute
                    ( M.fromList . zip (fmap Ext [0 ..]) $
                        map (le64 . Free . paramName) ctx_params
                    )
                    lmad_ext

              mem_param <- newParam "mem_param" $ MemMem v_mem_space'
              tell ([mem_param], ctx_params)
              pure
                ( mergeparam {paramDec = MemArray pt shape u $ ArrayIn (paramName mem_param) param_lmad},
                  Var v',
                  ensureArrayIn v_mem_space'
                )
    -- Non-array parameters, and parameters whose initial value is
    -- not a variable, are handled in the default space.
    allocInMergeParam (mergeparam, se) = doDefault mergeparam se =<< lift askDefaultSpace

    -- Translate the parameter with 'allocInFParam' in the given
    -- space, keep the initial value as it is, and handle results
    -- with 'linearFuncallArg' for the parameter's type and space.
    doDefault mergeparam se space = do
      mergeparam' <- allocInFParam mergeparam space
      pure (mergeparam', se, linearFuncallArg (paramType mergeparam) space)

-- | Bind a fresh copy of the array @v@ (of type @v_t@) in newly
-- allocated memory in @space@, giving the copy the supplied LMAD.
--
-- The copy is produced by a 'Replicate' with an empty shape, and its
-- name is the base name of @v@ suffixed with @_scalcopy@.  Returns
-- the name of the memory block and the name of the copy.
--
-- It is an error to call this with a non-array type.
arrayWithLMAD ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m), LetDec (Rep m) ~ LetDecMem) =>
  Space ->
  LMAD ->
  Type ->
  VName ->
  m (VName, VName)
arrayWithLMAD space lmad v_t v =
  case v_t of
    Array pt shape u -> do
      mem <- allocForArray' v_t space
      v_copy <- newVName $ baseString v <> "_scalcopy"
      let pe = PatElem v_copy $ MemArray pt shape u $ ArrayIn mem lmad
      letBind (Pat [pe]) $ BasicOp $ Replicate mempty $ Var v
      pure (mem, v_copy)
    _ ->
      -- Fail with a descriptive message rather than an anonymous
      -- pattern-match failure, mirroring 'allocPermArray'.
      error $ "arrayWithLMAD: " ++ prettyString v_t

-- | Ensure that the array @v@ has a direct LMAD and, if @space_ok@
-- is 'Just', that it resides in exactly that memory space.
--
-- If both conditions already hold, the array's current memory block
-- and @v@ itself are returned.  Otherwise the array is copied with
-- 'allocLinearArray' into @space_ok@ (or the default space when
-- @space_ok@ is 'Nothing'), and the new memory block and array name
-- are returned.
ensureDirectArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureDirectArray space_ok v = do
  (mem, lmad) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  if LMAD.isDirect lmad && maybe True (== mem_space) space_ok
    then pure (mem, v)
    else needCopy (fromMaybe default_space space_ok)
  where
    needCopy space =
      -- We need to do a new allocation, copy 'v', and make a new
      -- binding for the size of the memory block.
      allocLinearArray space (baseString v) v

-- | Manifest the array @v@ into freshly allocated memory in @space@,
-- using the permutation @perm@.
--
-- The new array is named @s@ suffixed with @_desired_form@, and its
-- LMAD is a row-major ('LMAD.iota') layout over the dimensions of
-- @v@, permuted by @perm@.  Returns the new memory block and the new
-- array name.
--
-- It is an error if @v@ is not an array.
allocPermArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  [Int] ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocPermArray space perm s v = do
  t <- lookupType v
  case t of
    Array pt shape u -> do
      mem <- allocForArray t space
      v' <- newVName $ s <> "_desired_form"
      let info =
            MemArray pt shape u . ArrayIn mem $
              LMAD.permute (LMAD.iota 0 $ map pe64 $ arrayDims t) perm
          pat = Pat [PatElem v' info]
      addStm $ Let pat (defAux ()) $ BasicOp $ Manifest perm v
      pure (mem, v')
    _ ->
      error $ "allocPermArray: " ++ prettyString t

-- | Ensure that the array @v@ resides in the memory space
-- @space_ok@, if one is given.
--
-- If @space_ok@ is 'Nothing' or equals the array's current memory
-- space, the current memory block and @v@ are returned unchanged.
-- Otherwise the array is re-manifested with 'allocPermArray' using
-- @perm@ in the requested space.  Note that only the memory space is
-- checked here; the existing LMAD is not inspected.
ensurePermArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  [Int] ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensurePermArray space_ok perm v = do
  (mem, _) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  if maybe True (== mem_space) space_ok
    then pure (mem, v)
    else allocPermArray (fromMaybe default_space space_ok) perm (baseString v) v

-- | Copy the array @v@ into freshly allocated memory in @space@
-- using the identity permutation, i.e. 'allocPermArray' with
-- @[0 .. rank - 1]@.  The string @s@ is used as the base of the new
-- array's name.
allocLinearArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocLinearArray space s v = do
  t <- lookupType v
  let perm = [0 .. arrayRank t - 1]
  allocPermArray space perm s v

-- | Transform the arguments of a function call.
--
-- Every argument is passed through 'linearFuncallArg' using the
-- default memory space.  The extra arguments collected in the writer
-- (context arguments followed by memory arguments) are prepended to
-- the value arguments, each with diet 'Observe'; the value arguments
-- keep their original diets.
funcallArgs ::
  (Allocable fromrep torep inner) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  (valargs, (ctx_args, mem_and_size_args)) <- runWriterT $
    forM args $ \(arg, d) -> do
      t <- lift $ subExpType arg
      space <- lift askDefaultSpace
      arg' <- linearFuncallArg t space arg
      pure (arg', d)
  pure $ map (,Observe) (ctx_args <> mem_and_size_args) <> valargs

-- | Prepare a single function-call argument of the given type.
--
-- An array-typed variable is passed through 'ensureDirectArray' for
-- the given space; its memory block is emitted as a context argument
-- via the writer, and the (possibly copied) array is returned.  Any
-- other argument is returned unchanged and emits nothing.
linearFuncallArg ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg Array {} space (Var v) = do
  (mem, arg') <- lift $ ensureDirectArray (Just space) v
  tell ([Var mem], [])
  pure $ Var arg'
linearFuncallArg _ _ arg =
  pure arg

-- | Offset the indices stored in a 'RetAls': add @a@ to every index
-- in the first list and @b@ to every index in the second.
shiftRetAls :: Int -> Int -> RetAls -> RetAls
shiftRetAls a b (RetAls is js) =
  RetAls (map (+ a) is) (map (+ b) js)

-- | Construct the generic explicit-allocations pass.
--
-- The arguments are the default memory space, a handler that
-- translates operations from the source representation, and a
-- function producing hints for expressions.  Constants and function
-- definitions are transformed separately, each in its own 'runAllocM'
-- invocation.
explicitAllocationsGeneric ::
  (Allocable fromrep torep inner) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric space handleOp hints =
  Pass "explicit allocations" "Transform program to explicit memory representation" $
    intraproceduralTransformationWithConsts onStms allocInFun
  where
    -- Transform the top-level constant statements.
    onStms stms =
      runAllocM space handleOp hints $ collectStms_ $ allocInStms stms $ pure ()

    -- Transform one function, with the already-transformed constants
    -- in scope.
    allocInFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM space handleOp hints . inScopeOf consts $
        allocInFParams (map (,space) params) $ \params' -> do
          (fbody', mem_rets) <-
            allocInFunBody (map (const $ Just space) rettype) fbody
          let -- Parameters and results added by this pass; the
              -- original alias indices are shifted by these counts.
              num_extra_params = length params' - length params
              num_extra_rets = length mem_rets
              -- The added memory results come first, with empty
              -- alias information, followed by the original results.
              rettype' =
                map (,RetAls mempty mempty) mem_rets
                  ++ zip
                    (memoryInDeclExtType space (length mem_rets) (map fst rettype))
                    (map (shiftRetAls num_extra_params num_extra_rets . snd) rettype)
          pure $ FunDef entry attrs fname rettype' params' fbody'

-- | Add explicit allocations to a sequence of statements, outside of
-- a full pass.
--
-- The scope of the calling monad is captured and made available
-- while the statements are transformed; the resulting statements are
-- returned in the calling monad.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep inner
  ) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric space handleOp hints stms = do
  scope <- askScope
  runAllocM space handleOp hints $
    localScope scope $
      collectStms_ $
        allocInStms stms $
          pure ()

-- | Annotate a list of declared (existential) return types with
-- memory information.
--
-- Primitive and accumulator types are carried over unchanged.  Every
-- array is declared as being returned in a new memory block
-- ('ReturnsNewBlock') in @space@; the blocks are numbered 0, 1, ...
-- in the order in which arrays occur in @dets@.  Existential
-- dimension indices are shifted up by @k@ (presumably to leave room
-- for @k@ preceding context values - confirm against callers).  A
-- 'Mem' type among the inputs is an error.
memoryInDeclExtType :: Space -> Int -> [DeclExtType] -> [FunReturns]
memoryInDeclExtType space k dets = evalState (traverse withMem dets) 0
  where
    -- The state is the number to give to the next new memory block.
    withMem (Prim pt) = pure (MemPrim pt)
    withMem Mem {} = error "memoryInDeclExtType: too much memory"
    withMem (Acc acc ispace ts u) = pure (MemAcc acc ispace ts u)
    withMem (Array pt shape u) = do
      block_no <- get
      put (block_no + 1)
      let shifted = fmap bumpExt shape
          lmad = LMAD.iota 0 (map toPrimExp (shapeDims shifted))
      pure (MemArray pt shifted u (ReturnsNewBlock space block_no lmad))

    -- Turn a (possibly existential) dimension into an index expression.
    toPrimExp (Ext i) = le64 (Ext i)
    toPrimExp (Free v) = Free <$> pe64 v

    -- Shift existential indices by @k@; free dimensions are untouched.
    bumpExt (Ext i) = Ext (i + k)
    bumpExt (Free x) = Free x

-- | Compute the memory context that must accompany a single body
-- result.
--
-- For a variable bound to an array, this is a singleton list with the
-- memory block the array resides in, paired with a 'MemMem'
-- annotation carrying the space of that block.  Constants,
-- primitives, accumulators and memory blocks contribute nothing.  It
-- is an error if the name recorded as the array's memory is not bound
-- to a memory block.
bodyReturnMemCtx ::
  (Allocable fromrep torep inner) =>
  SubExpRes ->
  AllocM fromrep torep [(SubExpRes, MemInfo ExtSize u MemReturn)]
bodyReturnMemCtx (SubExpRes _ se) =
  case se of
    Constant {} -> pure []
    Var v -> do
      v_info <- lookupMemInfo v
      case v_info of
        MemArray _ _ _ (ArrayIn mem _) -> do
          block_info <- lookupMemInfo mem
          case block_info of
            MemMem space ->
              pure [(subExpRes (Var mem), MemMem space)]
            _ ->
              error ("bodyReturnMemCtx: not a memory block: " ++ prettyString mem)
        MemPrim {} -> pure []
        MemAcc {} -> pure []
        MemMem {} -> pure [] -- should not happen

-- | Add memory information to a function body.
--
-- Each result is passed through 'ensureDirect' together with its
-- entry from @space_oks@.  If the body has more results than there
-- are entries in @space_oks@, the list is padded on the left with
-- 'Nothing'.  The memory context computed by 'bodyReturnMemCtx' for
-- the resulting values is placed in front of the results, and the
-- corresponding return annotations are returned alongside the new
-- body.
allocInFunBody ::
  (Allocable fromrep torep inner) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [FunReturns])
allocInFunBody space_oks (Body _ stms res) =
  buildBody $
    allocInStms stms $ do
      direct_res <- zipWithM ensureDirect padded_space_oks res
      mem_ctx <- concat <$> mapM bodyReturnMemCtx direct_res
      let (ctx_res, ctx_rets) = unzip mem_ctx
      pure (ctx_res <> direct_res, ctx_rets)
  where
    -- Results without an entry in @space_oks@ get no space constraint.
    num_unconstrained = length res - length space_oks
    padded_space_oks = replicate num_unconstrained Nothing ++ space_oks

-- | Pass a result through 'ensureDirectArray' if it is a variable
-- bound to an array, replacing it with the second name that
-- 'ensureDirectArray' returns.  Any other result is returned
-- unchanged.  The certificates of the result are preserved.
ensureDirect ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  SubExpRes ->
  AllocM fromrep torep SubExpRes
ensureDirect space_ok (SubExpRes cs se) = do
  se_info <- subExpMemInfo se
  se' <- case (se_info, se) of
    (MemArray {}, Var v) -> Var . snd <$> ensureDirectArray space_ok v
    _ -> pure se
  pure (SubExpRes cs se')

-- | Translate the given statements one at a time, in order, adding
-- the resulting statements to the builder, and finally run @m@.
--
-- The statements produced for one source statement are scanned with
-- 'stmConsts', and the names found are added to 'envConsts' for the
-- remaining statements and for @m@.
allocInStms ::
  (Allocable fromrep torep inner) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m =
  foldr
    ( \stm rest -> do
        -- Translate the statement under its own auxiliary information
        -- and capture what was produced.
        new_stms <- collectStms_ . auxing (stmAux stm) $ allocInStm stm
        addStms new_stms
        let new_consts = foldMap stmConsts new_stms
            extend env = env {envConsts = new_consts <> envConsts env}
        local extend rest
    )
    m
    (stmsToList origstms)

-- | Translate a single statement: the expression is translated with
-- 'allocInExp', a statement binding the original pattern's
-- identifiers to it is constructed with 'allocsForStm', and that
-- statement is added to the builder.  The auxiliary information of
-- the statement is not used here.
allocInStm ::
  (Allocable fromrep torep inner) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pat pes) _ e) = do
  e' <- allocInExp e
  stm' <- allocsForStm (map patElemIdent pes) e'
  addStm stm'

-- | Construct a lambda with the given (already translated)
-- parameters, whose body consists of the translated statements of
-- @body@ and whose result is the unmodified result of @body@.
allocInLambda ::
  (Allocable fromrep torep inner) =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params body = mkLambda params lam_body
  where
    lam_body = allocInStms (bodyStms body) (pure (bodyResult body))

-- | A requirement on the memory of an array returned from a branch,
-- as combined across branches by 'combMemReqs'.
data MemReq
  = -- | The array is in a memory block in this space.
    MemReq Space
  | -- | The requirements being combined disagreed (see 'combMemReqs');
    -- the space is kept from the left-hand requirement.
    NeedsNormalisation Space
  deriving (Eq, Show)

-- | Combine two memory requirements.  A 'NeedsNormalisation' on
-- either side wins (the left one if both are).  Two equal 'MemReq's
-- combine to themselves; two different ones become
-- 'NeedsNormalisation' in the space of the left operand.
combMemReqs :: MemReq -> MemReq -> MemReq
combMemReqs x y =
  case (x, y) of
    (NeedsNormalisation {}, _) -> x
    (_, NeedsNormalisation {}) -> y
    (MemReq x_space, MemReq {})
      | x == y -> x
      | otherwise -> NeedsNormalisation x_space

-- | Memory information for a branch result: dimensions may be
-- existential, there is no uniqueness information, and arrays are
-- annotated with a 'MemReq' rather than a concrete memory return.
type MemReqType = MemInfo (Ext SubExp) NoUniqueness MemReq

-- | Combine the requirements for the same result position of two
-- branches.  When both are arrays, the element type, shape and
-- uniqueness of the left operand are kept and the memory requirements
-- are merged with 'combMemReqs'.  In every other case the left
-- operand is returned as is.
combMemReqTypes :: MemReqType -> MemReqType -> MemReqType
combMemReqTypes lhs rhs =
  case (lhs, rhs) of
    (MemArray pt shape u lreq, MemArray _ _ _ rreq) ->
      MemArray pt shape u (combMemReqs lreq rreq)
    _ -> lhs

-- | The context return types needed for a branch result.  An array
-- needs a memory block in the space of its requirement, followed by
-- @1 + rank@ values of type @int64@ (an offset and one stride per
-- dimension).  Non-array results need no context.
contextRets :: MemReqType -> [MemInfo d u r]
contextRets (MemArray _ shape _ req) =
  -- Memory + offset + stride*rank.
  MemMem space : replicate (1 + shapeRank shape) (MemPrim int64)
  where
    -- Both kinds of requirement contribute the same context.
    space = case req of
      MemReq s -> s
      NeedsNormalisation s -> s
contextRets _ = []

-- | Add memory information to the body of a branch, but do not return
-- memory/LMAD information for its results.  Instead, return for each
-- result a restriction on what its memory should look like; the
-- caller then (crudely) unifies these restrictions across all
-- branches.
--
-- The result of the body itself is left unchanged.  It is an error if
-- a result is bound to an array while the corresponding expected type
-- in @rets@ is not an array.
allocInMatchBody ::
  (Allocable fromrep torep inner) =>
  [ExtType] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody rets (Body _ stms res) =
  buildBody $
    allocInStms stms $
      (,) res <$> zipWithM restrictionFor rets (map resSubExp res)
  where
    -- The restriction for one result, given its expected type and the
    -- memory information of the value actually returned.
    restrictionFor t se = do
      se_info <- subExpMemInfo se
      case (t, se_info) of
        (Array pt shape u, MemArray _ _ _ (ArrayIn mem _)) ->
          -- Only the space of the array's memory block is recorded.
          MemArray pt shape u . MemReq <$> lookupMemSpace mem
        (_, MemMem space) ->
          pure (MemMem space)
        (_, MemPrim pt) ->
          pure (MemPrim pt)
        (_, MemAcc acc ispace ts u) ->
          pure (MemAcc acc ispace ts u)
        _ ->
          error ("allocInMatchBody: mismatch: " ++ show (t, se_info))

-- | Turn the unified per-result memory requirements of a branching
-- expression into its return types.
--
-- The returned list consists of the context return types of all
-- results (as given by 'contextRets', in result order), followed by
-- one return type per result.  Each array result is declared as
-- 'ReturnsNewBlock', referring to the position where its own context
-- begins, with an existential LMAD built by 'LMAD.mkExistential'
-- starting at the position right after its memory block.  Existential
-- dimension indices in array shapes are shifted up by the total
-- number of context values.
--
-- The lists are built with 'concat' and 'zipWith' instead of a lazy
-- left fold that appends to the right of two accumulators, so the
-- cost is linear in the number of results.  The total is computed
-- with 'sum' rather than the partial 'last'.
mkBranchRet :: [MemReqType] -> [BranchTypeMem]
mkBranchRet reqs = concat ctx_per_req ++ zipWith inspect offsets reqs
  where
    -- Context return types of each result; computed once and shared
    -- between the output and the offset computation.
    ctx_per_req :: [[BranchTypeMem]]
    ctx_per_req = map contextRets reqs

    ctx_counts :: [Int]
    ctx_counts = map length ctx_per_req

    -- Position at which the context of each result begins.  This has
    -- one more element than @reqs@; 'zipWith' ignores the final one.
    offsets :: [Int]
    offsets = scanl (+) 0 ctx_counts

    -- Total number of context values across all results.
    num_new_ctx :: Int
    num_new_ctx = sum ctx_counts

    -- The space recorded in a requirement, whichever kind it is.
    arrayInfo (NeedsNormalisation space) = space
    arrayInfo (MemReq space) = space

    inspect ctx_offset (MemArray pt shape u req) =
      let shape' = fmap (adjustExt num_new_ctx) shape
          space = arrayInfo req
          -- @ctx_offset@ is the memory block; the LMAD existentials
          -- start right after it.
          lmad = convert <$> LMAD.mkExistential (shapeDims shape') (ctx_offset + 1)
       in MemArray pt shape' u (ReturnsNewBlock space ctx_offset lmad)
    inspect _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    inspect _ (MemPrim pt) = MemPrim pt
    inspect _ (MemMem space) = MemMem space

    -- Turn a (possibly existential) dimension into an index expression.
    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

    -- Shift an existential index by @k@; free values are untouched.
    adjustExt :: Int -> Ext a -> Ext a
    adjustExt _ (Free v) = Free v
    adjustExt k (Ext i) = Ext (k + i)

-- | Rebuild a match branch body so that its result list starts with
-- the context belonging to its (possibly normalised) results.
--
-- The statements of @body@ are re-bound inside a fresh body.  Each
-- result is paired with the corresponding requirement in @reqs@; an
-- array variable whose requirement is 'NeedsNormalisation' is passed
-- through 'ensurePermArray' with the identity permutation and the
-- requested space.  For every resulting array variable the context
-- is its memory block followed by one scalar per element of
-- 'LMAD.existentialized' applied to its LMAD.
addCtxToMatchBody ::
  (Allocable fromrep torep inner) =>
  [MemReqType] ->
  Body torep ->
  AllocM fromrep torep (Body torep)
addCtxToMatchBody reqs body = buildBody_ $ do
  body_res <- bodyBind body
  -- Note that zipWithM stops at the shorter of the two lists.
  normalised <- zipWithM normaliseIfNeeded reqs body_res
  ctx_per_res <- mapM resCtx normalised
  pure $ concat ctx_per_res ++ normalised
  where
    -- Only array variables with a 'NeedsNormalisation' requirement
    -- are touched; everything else is passed through unchanged.
    normaliseIfNeeded req res =
      case (req, res) of
        (MemArray _ shape _ (NeedsNormalisation space), SubExpRes cs (Var v)) -> do
          v' <- snd <$> ensurePermArray (Just space) [0 .. shapeRank shape - 1] v
          pure $ SubExpRes cs (Var v')
        _ ->
          pure res

    -- The context contributed by a single result.  Constants and
    -- non-array variables contribute nothing.
    resCtx (SubExpRes _ se) =
      case se of
        Constant {} -> pure []
        Var v -> do
          info <- lookupMemInfo v
          case info of
            MemArray _ _ _ (ArrayIn mem lmad) -> do
              lmad_exts <-
                forM (LMAD.existentialized lmad) $ \e ->
                  toExp e >>= letSubExp "lmad_ext"
              pure $ subExpRes (Var mem) : subExpsRes lmad_exts
            MemPrim {} -> pure []
            MemAcc {} -> pure []
            MemMem {} -> pure [] -- should not happen

-- Do a simple form of invariance analysis to simplify a Match.  It
-- is unfortunate that we have to do it here, but functions such as
-- scalarRes will look carefully at the index functions before the
-- simplifier has a chance to run.  In a perfect world we would
-- simplify away those copies afterwards. XXX; this should be fixed by
-- a more general copy-removal pass. See
-- Futhark.Optimise.EntryPointMem for a very specialised version of
-- the idea, but which could perhaps be generalised.
simplifyMatch ::
  (Mem rep inner) =>
  [Case (Body rep)] ->
  Body rep ->
  [BranchTypeMem] ->
  ( [Case (Body rep)],
    Body rep,
    [BranchTypeMem]
  )
simplifyMatch cases defbody ts =
  ( zipWith withCaseResult cases (transpose kept_case_reses),
    withResult defbody kept_def_reses,
    -- Every invariant result is substituted into the remaining
    -- return types at its original position.
    foldr (uncurry fixExt) kept_ts invariant
  )
  where
    def_reses = bodyResult defbody

    -- One list per result position, holding what each case returns
    -- at that position.
    per_result = transpose [bodyResult (caseBody c) | c <- cases]

    (invariant, varying) =
      partitionEithers
        [ classify i case_rs def_r t
          | (i, case_rs, def_r, t) <- zip4 [0 ..] per_result def_reses ts
        ]

    (kept_case_reses, kept_def_reses, kept_ts) = unzip3 varying

    -- Names bound by any statement in any case body or in the
    -- default body.
    bound_in_branches =
      namesFromList $
        concatMap (patNames . stmPat) $
          foldMap bodyStms (map caseBody cases ++ [defbody])

    withResult body res = body {bodyResult = res}
    withCaseResult c res = fmap (`withResult` res) c

    -- 'Left' for a result that is the same in all branches and does
    -- not depend on anything bound inside a branch; 'Right'
    -- otherwise.
    classify i case_rs def_r t
      -- If even one branch has a variant result, or the branches do
      -- not all return the same value, then we give up.
      | uses_branch_local || not all_agree = Right (case_rs, def_r, t)
      | otherwise = Left (i, resSubExp def_r)
      where
        uses_branch_local =
          namesIntersect bound_in_branches $ freeIn (def_r : case_rs)
        all_agree =
          all ((== resSubExp def_r) . resSubExp) case_rs

-- | Convert an expression of the source representation into the
-- memory-annotated target representation.
allocInExp ::
  (Allocable fromrep torep inner) =>
  Exp fromrep ->
  AllocM fromrep torep (Exp torep)
-- Loops: the merge parameters are translated by 'allocInLoopParams',
-- which also hands back a function for producing the context and
-- value results of the loop body.
allocInExp (Loop merge form (Body () bodystms bodyres)) =
  allocInLoopParams merge $ \merge' mk_loop_val ->
    localScope (scopeOfLoopForm form) $
      Loop merge' form
        <$> buildBody_
          ( allocInStms bodystms $ do
              (valctx, valres') <- mk_loop_val [resSubExp r | r <- bodyres]
              -- The original certificates are reattached to the
              -- value results, position by position.
              pure $
                subExpsRes valctx
                  <> [SubExpRes (resCerts r) v | (r, v) <- zip bodyres valres']
          )
-- Function calls: the arguments are rewritten by 'funcallArgs', and
-- the return type gains one memory block in the default space per
-- returned array.
allocInExp (Apply fname args rettype loc) = do
  args' <- funcallArgs args
  space <- askDefaultSpace
  -- We assume that every array is going to be in its own memory.
  let (ret_ts, ret_als) = unzip rettype
      is_array t = arrayRank (declExtTypeOf t) > 0
      num_arrays = length (filter is_array ret_ts)
      num_extra_args = length args' - length args
      mem_rets = replicate num_arrays (MemMem space, RetAls mempty mempty)
      val_rets =
        zip
          (memoryInDeclExtType space num_arrays ret_ts)
          (map (shiftRetAls num_extra_args num_arrays) ret_als)
  pure $ Apply fname args' (mem_rets ++ val_rets) loc
-- Matches: every branch is converted on its own, the per-result
-- memory requirements of all branches are combined, each branch is
-- then extended with the context those requirements call for, and
-- finally branch-invariant results are simplified away.
allocInExp (Match ses cases defbody (MatchDec rets ifsort)) = do
  (defbody', def_reqs) <- allocInMatchBody rets defbody
  (cases', cases_reqs) <- mapAndUnzipM onCase cases
  -- For each result position, start from the requirement of the
  -- default body and combine it with that of every case.  The fold
  -- is strict so that no chain of thunks is built up.
  --
  -- NOTE(review): if 'cases' is empty then 'transpose' yields [] and
  -- 'zipWith' truncates 'reqs' to [] - confirm that a Match always
  -- has at least one case when it reaches this pass.
  let reqs = zipWith (foldl' combMemReqTypes) def_reqs (transpose cases_reqs)
  defbody'' <- addCtxToMatchBody reqs defbody'
  cases'' <- mapM (traverse $ addCtxToMatchBody reqs) cases'
  let (cases''', defbody''', rets') =
        simplifyMatch cases'' defbody'' $ mkBranchRet reqs
  pure $ Match ses cases''' defbody''' $ MatchDec rets' ifsort
  where
    -- Convert the body of a single case, keeping its pattern.
    onCase (Case vs body) = first (Case vs) <$> allocInMatchBody rets body
-- 'WithAcc': translate every accumulator input and the body lambda into
-- the memory-annotated representation.
allocInExp (WithAcc inputs bodylam) =
  WithAcc <$> mapM onInput inputs <*> onLambda bodylam
  where
    -- Translate the body lambda.  Its parameters must be of unit or
    -- accumulator type; anything else is reported as an internal error.
    onLambda lam = do
      params <- forM (lambdaParams lam) $ \(Param attrs pv t) ->
        case t of
          Prim Unit -> pure $ Param attrs pv $ MemPrim Unit
          Acc acc ispace ts u -> pure $ Param attrs pv $ MemAcc acc ispace ts u
          _ -> error $ "Unexpected WithAcc lambda param: " ++ prettyString (Param attrs pv t)
      allocInLambda params (lambdaBody lam)

    -- An input is a triple of shape, arrays and an optional-style
    -- container of (operator lambda, neutral elements); only the
    -- operator needs translating.
    onInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (onOp shape arrs) op

    -- Translate the operator lambda of one input.  Its parameters are
    -- split into 'shapeRank accshape' index parameters, followed by one
    -- group of "x" parameters and one group of "y" parameters, each as
    -- long as the lambda's return type.
    onOp accshape arrs (lam, nes) = do
      let num_vs = length (lambdaReturnType lam)
          num_is = shapeRank accshape
          (i_params, x_params, y_params) =
            splitAt3 num_is num_vs $ lambdaParams lam
          -- Index parameters are always annotated as int64 primitives.
          i_params' = map (\(Param attrs v _) -> Param attrs v $ MemPrim int64) i_params
          is = map (DimFix . Var . paramName) i_params'
      x_params' <- zipWithM (onXParam is) x_params arrs
      y_params' <- zipWithM (onYParam is) y_params arrs
      lam' <-
        allocInLambda
          (i_params' <> x_params' <> y_params')
          (lambdaBody lam)
      pure (lam', nes)

    -- Build an array parameter residing in @mem@ whose index function is
    -- @lmad@ sliced by the fixed indices @is@ followed by one 'sliceDim'
    -- per dimension of @shape@.
    mkP attrs p pt shape u mem lmad is =
      Param attrs p . MemArray pt shape u . ArrayIn mem . LMAD.slice lmad $
        fmap pe64 $
          Slice $
            is ++ map sliceDim (shapeDims shape)

    -- "x" parameters: primitives are passed through; arrays reuse the
    -- memory block and LMAD of the corresponding input array (as
    -- reported by 'lookupArraySummary').
    onXParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p (MemPrim t)
    onXParam is (Param attrs p (Array pt shape u)) arr = do
      (mem, lmad) <- lookupArraySummary arr
      pure $ mkP attrs p pt shape u mem lmad is
    onXParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ prettyString p

    -- "y" parameters: primitives are passed through; arrays get a fresh
    -- allocation (via 'allocForArray' in the default space) sized after
    -- the corresponding input array, with an 'LMAD.iota' layout at
    -- offset 0 over that array's dimensions.
    onYParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p $ MemPrim t
    onYParam is (Param attrs p (Array pt shape u)) arr = do
      arr_t <- lookupType arr
      space <- askDefaultSpace
      mem <- allocForArray arr_t space
      let base_dims = map pe64 $ arrayDims arr_t
          lmad = LMAD.iota 0 base_dims
      pure $ mkP attrs p pt shape u mem lmad is
    onYParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ prettyString p
-- Fallback: every remaining expression is translated with 'mapExpM'.
-- Only operations ('Op') are expected to need representation-specific
-- handling here; bodies, return/branch types and parameters reaching
-- this mapper raise an error when the corresponding field is used.
allocInExp e = mapExpM alloc e
  where
    alloc =
      identityMapper
        { mapOnBody = error "Unhandled Body in ExplicitAllocations",
          mapOnRetType = error "Unhandled RetType in ExplicitAllocations",
          mapOnBranchType = error "Unhandled BranchType in ExplicitAllocations",
          mapOnFParam = error "Unhandled FParam in ExplicitAllocations",
          mapOnLParam = error "Unhandled LParam in ExplicitAllocations",
          -- Delegate to the operation handler stored in the environment.
          mapOnOp = \op -> do
            handle <- asks allocInOp
            handle op
        }

-- | Operations that can report whether they are constant.
class SizeSubst op where
  -- | Whether the operation is considered constant.  The default
  -- implementation answers 'False' for every operation.
  opIsConst :: op -> Bool
  opIsConst = const False

-- | 'NoOp' defines no methods here, so it uses the class default for
-- 'opIsConst'.
instance SizeSubst (NoOp rep)

-- | A 'MemOp' is constant exactly when it wraps an 'Inner' operation
-- that is itself constant; every other 'MemOp' is not.
instance (SizeSubst (op rep)) => SizeSubst (MemOp op rep) where
  opIsConst (Inner op) = opIsConst op
  opIsConst _ = False

-- | The names bound by a statement whose expression is an 'Op' for which
-- 'opIsConst' holds.  Any other statement contributes the empty set.
stmConsts :: (SizeSubst (Op rep)) => Stm rep -> S.Set VName
stmConsts (Let pat _ (Op op))
  | opIsConst op = S.fromList $ patNames pat
stmConsts _ = mempty

-- | Construct a statement binding @names@ to the expression @e@.  The
-- pattern is produced by 'patWithAllocations' for the given space, with
-- 'NoHint' supplied for every name, and the statement carries the
-- default auxiliary information for @dec@.
mkLetNamesB' ::
  ( LetDec (Rep m) ~ LetDecMem,
    Mem (Rep m) inner,
    MonadBuilder m,
    ExpDec (Rep m) ~ ()
  ) =>
  Space ->
  ExpDec (Rep m) ->
  [VName] ->
  Exp (Rep m) ->
  m (Stm (Rep m))
mkLetNamesB' space dec names e = do
  pat <- patWithAllocations space names e nohints
  pure $ Let pat (defAux dec) e
  where
    -- One 'NoHint' per bound name.
    nohints = map (const NoHint) names

-- | Like @mkLetNamesB'@, but for the 'Engine.Wise' variant of a
-- representation: the pattern built by 'patWithAllocations' is extended
-- with 'Engine.addWisdomToPat', and the expression decoration is
-- computed with 'Engine.mkWiseExpDec'.
mkLetNamesB'' ::
  ( Mem rep inner,
    LetDec rep ~ LetDecMem,
    OpReturns inner,
    ExpDec rep ~ (),
    Rep m ~ Engine.Wise rep,
    HasScope (Engine.Wise rep) m,
    MonadBuilder m,
    AliasedOp inner,
    RephraseOp (MemOp inner),
    Engine.CanBeWise inner,
    ASTConstraints (inner (Engine.Wise rep))
  ) =>
  Space ->
  [VName] ->
  Exp (Engine.Wise rep) ->
  m (Stm (Engine.Wise rep))
mkLetNamesB'' space names e = do
  pat <- patWithAllocations space names e nohints
  let pat' = Engine.addWisdomToPat pat e
      dec = Engine.mkWiseExpDec pat' () e
  pure $ Let pat' (defAux dec) e
  where
    -- One 'NoHint' per bound name.
    nohints = map (const NoHint) names

-- | Simplify a 'MemOp', given a simplifier for the inner operation.
--
-- * For an 'Alloc', only the size operand is run through
--   'Engine.simplify'; the space is passed through unchanged and no
--   statements are produced (the second component is 'mempty').
--
-- * For an 'Inner' operation, the supplied function is applied to the
--   wrapped operation, and its result is re-wrapped in 'Inner'
--   together with whatever statements that function returned.
simplifyMemOp ::
  (Engine.SimplifiableRep rep) =>
  ( inner (Engine.Wise rep) ->
    Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep))
  ) ->
  MemOp inner (Engine.Wise rep) ->
  Engine.SimpleM rep (MemOp inner (Engine.Wise rep), Stms (Engine.Wise rep))
simplifyMemOp _ (Alloc size space) =
  (,) <$> (Alloc <$> Engine.simplify size <*> pure space) <*> pure mempty
simplifyMemOp onInner (Inner k) = do
  (k', hoisted) <- onInner k
  pure (Inner k', hoisted)

-- | Build the 'SimpleOps' record for a representation whose
-- operations are 'MemOp's wrapping some @inner@ operation.
--
-- The first argument computes the usage table of an inner operation;
-- the second simplifies an inner operation, returning the simplified
-- operation along with any statements it produced.  Both are only
-- consulted for 'Inner' operations; 'Alloc' is handled here.
simplifiable ::
  ( Engine.SimplifiableRep rep,
    LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    Mem (Engine.Wise rep) inner,
    Engine.CanBeWise inner,
    RephraseOp inner,
    IsOp inner,
    OpReturns inner,
    AliasedOp inner,
    IndexOp (inner (Engine.Wise rep))
  ) =>
  (inner (Engine.Wise rep) -> UT.UsageTable) ->
  ( inner (Engine.Wise rep) ->
    Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep))
  ) ->
  SimpleOps rep
simplifiable innerUsage simplifyInnerOp =
  SimpleOps mkExpDecS' mkBodyS' protectOp opUsage (simplifyMemOp simplifyInnerOp)
  where
    -- The symbol table argument is ignored; the underlying expression
    -- decoration is always unit, so only the wisdom part is computed.
    mkExpDecS' _ pat e =
      pure $ Engine.mkWiseExpDec pat () e

    -- The symbol table argument is ignored; the underlying body
    -- decoration is always unit.
    mkBodyS' _ stms res = pure $ mkWiseBody () stms res

    -- For an allocation, bind the pattern to an allocation whose size
    -- is the original size when @taken@ is true and zero otherwise.
    -- Presumably used when an allocation is hoisted out of a branch
    -- that may not be taken (the bound name is "hoisted_alloc_size");
    -- confirm against the simplifier engine's use of 'Protect'.
    protectOp taken pat (Alloc size space) = Just $ do
      tbody <- resultBodyM [size]
      fbody <- resultBodyM [intConst Int64 0]
      size' <-
        letSubExp "hoisted_alloc_size" $
          Match [taken] [Case [Just $ BoolValue True] tbody] fbody $
            MatchDec [MemPrim int64] MatchFallback
      letBind pat $ Op $ Alloc size' space
    -- No protection is offered for any other operation.
    protectOp _ _ _ = Nothing

    -- A variable allocation size is recorded via 'UT.sizeUsage'; a
    -- non-variable size contributes nothing.  Inner operations are
    -- delegated to the caller-supplied function.
    opUsage (Alloc (Var size) _) =
      UT.sizeUsage size
    opUsage (Alloc _ _) =
      mempty
    opUsage (Inner inner) =
      innerUsage inner

-- | A hint attached to one result of an expression.  In this module
-- a list of hints (one per bound name) is passed to
-- 'patWithAllocations' alongside the expression.
data ExpHint
  = -- | No hint for this result.
    NoHint
  | -- | A hint carrying an 'LMAD' and a 'Space'.  Presumably the
    -- desired index function and memory space for the result -- TODO
    -- confirm against the code that consumes hints.
    Hint LMAD Space

-- | The default hints for an expression: one 'NoHint' per element of
-- the expression's extended type, as reported by 'expExtType'.
defaultExpHints :: (ASTRep rep, HasScope rep m) => Exp rep -> m [ExpHint]
defaultExpHints e = map (const NoHint) <$> expExtType e