{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.Pass.ExplicitAllocations
( explicitAllocationsGeneric,
explicitAllocationsInStmsGeneric,
ExpHint (..),
defaultExpHints,
Allocable,
AllocM,
AllocEnv (..),
SizeSubst (..),
allocInStms,
allocForArray,
simplifiable,
arraySizeInBytesExp,
mkLetNamesB',
mkLetNamesB'',
dimAllocationSize,
ChunkMap,
module Control.Monad.Reader,
module Futhark.MonadFreshNames,
module Futhark.Pass,
module Futhark.Tools,
)
where
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.List (foldl', transpose, zip4)
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (maybeNth, splitAt3)
-- | The size to use for an array dimension when computing an allocation
-- size.
--
-- A variable with an entry in the 'ChunkMap' is replaced by its mapped
-- size. This is done recursively, because the mapped size may itself be a
-- variable with an entry in the map. Constants and unmapped variables are
-- returned unchanged.
--
-- Note: there is no cycle check. A map in which a name (transitively) maps
-- back to itself would make this diverge.
dimAllocationSize :: ChunkMap -> SubExp -> SubExp
dimAllocationSize chunkmap (Var v)
  | Just size <- M.lookup v chunkmap = dimAllocationSize chunkmap size
dimAllocationSize _ size = size
-- | Substitutions applied to array dimensions (see 'dimAllocationSize')
-- when computing allocation sizes.
type ChunkMap = M.Map VName SubExp

-- | Constraints on the source representation (@fromrep@) and the target
-- representation (@torep@, whose operations are built from @inner@)
-- required throughout this module.
type Allocable fromrep torep inner =
  ( -- Both representations can be pretty-printed.
    PrettyRep fromrep,
    PrettyRep torep,
    -- Annotation types of the source representation.
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    -- The target representation carries memory information in its
    -- let-bound patterns and has no body/expression annotations.
    Mem torep inner,
    LetDec torep ~ LetDecMem,
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    BuilderOps torep,
    SizeSubst inner
  )
-- | The read-only environment of 'AllocM'.
data AllocEnv fromrep torep = AllocEnv
  { -- | Substitutions applied to array dimensions (via
    -- 'dimAllocationSize') when computing allocation sizes.
    chunkMap :: ChunkMap,
    -- | Request for aggressive memory reuse. 'runAllocM' sets it to
    -- 'False'. NOTE(review): its consumers are outside this view.
    aggressiveReuse :: Bool,
    -- | The memory space used by default for new allocations (read by
    -- 'askDefaultSpace').
    allocSpace :: Space,
    -- | A set of names carried by the environment, empty under
    -- 'runAllocM'. Presumably names known to be constant; confirm
    -- against its users.
    envConsts :: S.Set VName,
    -- | Translation of an operation of the source representation.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | Allocation hints for an expression. These are paired with the
    -- expression's results when its pattern is allocated.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }
-- | The monad in which allocations are introduced.
--
-- It is a 'BuilderT' for the target representation, layered over a
-- reader of the 'AllocEnv' and a 'State' holding the name source. All
-- class instances except 'MonadBuilder' are obtained by newtype deriving.
newtype AllocM fromrep torep a
  = AllocM (BuilderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )
-- | Building statements in 'AllocM'. 'mkLetNamesM' gives each new
-- statement a pattern with memory information, computed by
-- 'patWithAllocations' from the environment's default space, chunk map
-- and expression hints.
instance (Allocable fromrep torep inner) => MonadBuilder (AllocM fromrep torep) where
  type Rep (AllocM fromrep torep) = torep

  -- Expressions carry no annotation in the target representation.
  mkExpDecM _ _ = pure ()

  mkLetNamesM names e = do
    def_space <- askDefaultSpace
    chunkmap <- asks chunkMap
    hints <- expHints e
    pat <- patWithAllocations def_space chunkmap names e hints
    pure $ Let pat (defAux ()) e

  mkBodyM stms res = pure $ Body () stms res

  -- Statement collection is delegated to the underlying 'BuilderT'.
  addStms = AllocM . addStms
  collectStms (AllocM m) = AllocM $ collectStms m
-- | Allocation hints for an expression, as produced by the environment's
-- 'envExpHints' function.
expHints :: Exp torep -> AllocM fromrep torep [ExpHint]
expHints e = asks envExpHints >>= ($ e)
-- | The memory space in which allocations are placed by default
-- ('allocSpace' of the environment).
askDefaultSpace :: AllocM fromrep torep Space
askDefaultSpace = asks allocSpace
-- | Run an 'AllocM' computation in any monad with a name source.
--
-- The builder starts from an empty scope. The environment has an empty
-- chunk map, 'aggressiveReuse' off, 'DefaultSpace' as allocation space,
-- no constants, and uses the given operation handler and hint function.
--
-- Statements the computation adds at the top level are discarded. Only
-- its result is returned, so callers that need the statements should
-- collect them inside the computation.
runAllocM ::
  MonadFreshNames m =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM handleOp hints (AllocM m) =
  fmap fst . modifyNameSource . runState $
    runReaderT (runBuilderT m mempty) env
  where
    env =
      AllocEnv
        { chunkMap = mempty,
          aggressiveReuse = False,
          allocSpace = DefaultSpace,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }
-- | Size in bytes of one element of the given type (the byte size of its
-- primitive element type).
elemSize :: Num a => Type -> a
elemSize t = primByteSize (elemType t)
-- | Symbolic size in bytes of an array of the given type: the element size
-- multiplied, left to right, by each dimension.
--
-- Unlike 'arraySizeInBytesExpM', no 'ChunkMap' substitution is applied and
-- the result is not clamped at zero.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  -- The seed and fold order are kept as-is because they determine the
  -- shape of the resulting expression tree.
  untyped $ foldl' (*) (elemSize t) [pe64 d | d <- arrayDims t]
-- | Symbolic size in bytes of an array of the given type, for use in an
-- allocation.
--
-- Each dimension is first resolved through 'dimAllocationSize'. The product
-- of the dimensions is multiplied by the element size, and the result is
-- clamped to be non-negative with a signed 64-bit maximum against 0.
arraySizeInBytesExpM :: MonadBuilder m => ChunkMap -> Type -> m (PrimExp VName)
arraySizeInBytesExpM chunkmap t =
  pure . BinOpExp (SMax Int64) zero . untyped $ dim_prod_i64 * elm_size_i64
  where
    zero = ValueExp $ IntValue $ Int64Value 0
    dim_prod_i64 =
      product [pe64 (dimAllocationSize chunkmap d) | d <- arrayDims t]
    elm_size_i64 = elemSize t
-- | Emit statements computing the allocation size in bytes of an array of
-- the given type (see 'arraySizeInBytesExpM'). The size is bound with the
-- name hint @"bytes"@.
arraySizeInBytes :: MonadBuilder m => ChunkMap -> Type -> m SubExp
arraySizeInBytes chunkmap t = do
  size_exp <- arraySizeInBytesExpM chunkmap t
  letSubExp "bytes" =<< toExp size_exp
-- | Emit an 'Alloc' of enough bytes for an array of the given type in the
-- given memory space, and return the name of the memory block (bound with
-- the name hint @"mem"@).
allocForArray' ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  ChunkMap ->
  Type ->
  Space ->
  m VName
allocForArray' chunkmap t space =
  arraySizeInBytes chunkmap t >>= \size ->
    letExp "mem" $ Op $ Alloc size space
-- | Allocate memory for an array of the given type in the given space,
-- using the chunk map from the 'AllocM' environment.
allocForArray ::
  Allocable fromrep torep inner =>
  Type ->
  Space ->
  AllocM fromrep torep VName
allocForArray t space =
  asks chunkMap >>= \chunkmap -> allocForArray' chunkmap t space
-- | Build a let-statement binding the given identifiers to an expression.
-- The pattern elements get memory information from 'allocsForPat', using
-- the environment's default space, chunk map and the expression's hints.
allocsForStm ::
  (Allocable fromrep torep inner) =>
  [Ident] ->
  Exp torep ->
  AllocM fromrep torep (Stm torep)
allocsForStm idents e = do
  def_space <- askDefaultSpace
  chunkmap <- asks chunkMap
  hints <- expHints e
  rts <- expReturns e
  pat <- Pat <$> allocsForPat def_space chunkmap idents rts hints
  dec <- mkExpDecM pat e
  pure $ Let pat (defAux dec) e
-- | Compute a pattern with memory information for binding the given names
-- to an expression.
--
-- The names are paired with the expression's types, with existential
-- shapes instantiated using those same names. The result is then passed to
-- 'allocsForPat' together with the expression's return summaries and hints.
patWithAllocations ::
  (MonadBuilder m, Mem (Rep m) inner) =>
  Space ->
  ChunkMap ->
  [VName] ->
  Exp (Rep m) ->
  [ExpHint] ->
  m (Pat LetDecMem)
patWithAllocations def_space chunkmap names e hints = do
  ts' <- instantiateShapes' names <$> expExtType e
  rts <- expReturns e
  Pat <$> allocsForPat def_space chunkmap (zipWith Ident names ts') rts hints
-- | Make sure there is one identifier per return summary.
--
-- The given identifiers are aligned with the last return summaries. Each
-- leading summary without an identifier gets a fresh one:
--
-- * a memory identifier named @"ext_mem"@ in the same space for 'MemMem';
-- * an @int64@ identifier named @"ext"@ otherwise.
--
-- Because the lists are zipped, surplus leading identifiers (more
-- identifiers than summaries) are dropped. Fresh names are generated from
-- the last missing summary backwards.
mkMissingIdents :: MonadFreshNames m => [Ident] -> [ExpReturns] -> m [Ident]
mkMissingIdents idents rts =
  reverse <$> zipWithM pick (reverse rts) given_from_end
  where
    -- Right-aligned identifiers, padded with 'Nothing' for missing ones.
    given_from_end = map Just (reverse idents) ++ repeat Nothing

    pick _ (Just ident) = pure ident
    pick (MemMem space) Nothing = newIdent "ext_mem" $ Mem space
    pick _ Nothing = newIdent "ext" $ Prim int64
-- | Compute the memory-annotated pattern elements for a binding.
--
-- The given identifiers are first completed by 'mkMissingIdents'
-- (presumably adding fresh names for results not covered by @some_idents@ —
-- verify against its definition). They are then paired positionally with
-- the expression returns and hints; 'zip3' drops any surplus elements.
--
-- Per result:
--
-- * primitive results, and arrays whose shape has no existential
--   dimensions and no memory return, get a summary from 'summaryForBindage';
-- * memory blocks and accumulators are passed through unchanged;
-- * 'ReturnsInBlock' arrays live in the returned memory block, and
--   'ReturnsNewBlock' arrays live in the pattern identifier at the returned
--   position. In both cases, existential leaves of the index function are
--   replaced by pattern identifiers.
--
-- Any other combination is an internal error.
allocsForPat ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  Space ->
  ChunkMap ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m [PatElem LetDecMem]
allocsForPat def_space chunkmap some_idents rts hints = do
  idents <- mkMissingIdents some_idents rts
  mapM (patElemFor idents) (zip3 idents rts hints)
  where
    -- Build the pattern element for one (identifier, return, hint) triple.
    patElemFor idents (ident, rt, hint) =
      case rt of
        MemPrim _ -> freshSummary
        MemMem space -> named $ MemMem space
        MemArray bt _ u (Just (ReturnsInBlock mem extixfun)) ->
          named . MemArray bt ident_shape u $
            ArrayIn mem (instantiateExtIxFun idents extixfun)
        MemArray _ extshape _ Nothing
          | Just _ <- knownShape extshape -> freshSummary
        MemArray bt _ u (Just (ReturnsNewBlock _ i extixfn)) ->
          named . MemArray bt ident_shape u $
            ArrayIn (getIdent idents i) (instantiateExtIxFun idents extixfn)
        MemAcc acc ispace ts u -> named $ MemAcc acc ispace ts u
        _ -> error "Impossible case reached in allocsForPat!"
      where
        ident_shape = arrayShape $ identType ident
        named = pure . PatElem (identName ident)
        freshSummary =
          PatElem (identName ident)
            <$> summaryForBindage def_space chunkmap (identType ident) hint

    -- Just the dimensions if none of them is existential.
    knownShape = traverse known . shapeDims
    known e = case e of
      Free v -> Just v
      Ext {} -> Nothing

    -- Name of the i'th pattern identifier; crashes with a diagnostic if the
    -- pattern is too short.
    getIdent idents i =
      maybe
        ( error $
            "getIdent: Ext "
              <> show i
              <> " but pattern has "
              <> show (length idents)
              <> " elements: "
              <> pretty idents
        )
        identName
        (maybeNth i idents)

    -- Replace every existential leaf by the corresponding pattern name.
    instantiateExtIxFun idents = fmap (fmap resolve)
      where
        resolve (Free v) = v
        resolve (Ext i) = getIdent idents i
-- | Turn an existential index function into an ordinary one.
--
-- Every leaf must be 'Free'; encountering an 'Ext' leaf calls 'error'
-- (existential instantiation is not implemented here).
instantiateIxFun :: Monad m => ExtIxFun -> m IxFun
instantiateIxFun = traverse (traverse leaf)
  where
    leaf (Free x) = pure x
    leaf Ext {} = error "instantiateIxFun: not yet"
-- | Build the memory summary for a pattern-bound value of type @t@.
--
-- Primitives, memory blocks and accumulators map directly to their
-- 'MemInfo' counterparts and need no allocation. Arrays get fresh memory:
--
-- * with 'NoHint', from 'allocForArray'' in @def_space@, using an
--   'IxFun.iota' index function over the array's dimensions;
-- * with @'Hint' ixfun space@, from an 'Alloc' in @space@. Its byte size is
--   the product of the base shape of @ixfun@ and the element size, and
--   @ixfun@ itself is used as the index function. @def_space@ and the chunk
--   map are not consulted in this case.
summaryForBindage ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  Space ->
  ChunkMap ->
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage def_space chunkmap t hint =
  case t of
    Prim bt -> pure $ MemPrim bt
    Mem space -> pure $ MemMem space
    Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u
    Array pt shape u ->
      case hint of
        NoHint -> do
          mem <- allocForArray' chunkmap t def_space
          let ixfun = IxFun.iota (map pe64 (arrayDims t))
          pure . MemArray pt shape u $ ArrayIn mem ixfun
        Hint ixfun space -> do
          let elem_size = fromIntegral (primByteSize pt :: Int64)
              size_exp = product [product (IxFun.base ixfun), elem_size]
          bytes <- letSubExp "bytes" =<< toExp (untyped size_exp)
          mem <- letExp "mem" . Op $ Alloc bytes space
          pure . MemArray pt (arrayShape t) NoUniqueness $ ArrayIn mem ixfun
-- | Allocate memory for function parameters, then run a continuation with
-- the resulting parameters in scope.
--
-- Each parameter is processed by 'allocInFParam' in its paired 'Space'. The
-- parameter list handed to the continuation (and added to the local scope)
-- is ordered as: memory parameters, context parameters, value parameters.
allocInFParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params m = do
  (valparams, (memparams, ctxparams)) <-
    runWriterT $ traverse (\(param, space) -> allocInFParam param space) params
  let all_params = concat [memparams, ctxparams, valparams]
  localScope (scopeOfFParams all_params) $ m all_params
-- | Attach memory information to a single function parameter.
--
-- Primitive, memory and accumulator parameters are annotated directly, and
-- @pspace@ is ignored for them. An array parameter gets a fresh memory
-- parameter named @<param>_mem@ in @pspace@. That memory parameter is
-- emitted in the first component of the writer output (the second
-- component is never written here). The array is placed in it with an
-- 'IxFun.iota' index function over its shape.
allocInFParam ::
  (Allocable fromrep torep inner) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace =
  case paramDeclType param of
    Prim pt -> pure param {paramDec = MemPrim pt}
    Mem space -> pure param {paramDec = MemMem space}
    Acc acc ispace ts u -> pure param {paramDec = MemAcc acc ispace ts u}
    Array pt shape u -> do
      mem <- lift . newVName $ baseString (paramName param) <> "_mem"
      tell ([Param (paramAttrs param) mem (MemMem pspace)], [])
      let ixfun = IxFun.iota . map pe64 $ shapeDims shape
      pure param {paramDec = MemArray pt shape u (ArrayIn mem ixfun)}
-- | Return a memory block and array name for @v@ with a simple layout.
--
-- @(mem, v)@ is returned unchanged when @v@'s index function:
--
-- * consists of exactly one LMAD,
-- * has the identity permutation,
-- * has a base shape whose rank matches the index function's,
-- * is contiguous,
--
-- and, when @space_ok@ is given, @mem@ lives in that space.
--
-- Otherwise 'allocLinearArray' is called for @v@ in @space_ok@, or in the
-- default space when @space_ok@ is 'Nothing'. NOTE(review): presumably this
-- produces a row-major copy; confirm against 'allocLinearArray'.
ensureRowMajorArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureRowMajorArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let ixfun_rank = IxFun.rank ixfun
      usable_as_is =
        and
          [ numLMADs ixfun == 1,
            ixFunPerm ixfun == [0 .. ixfun_rank - 1],
            length (IxFun.base ixfun) == ixfun_rank,
            maybe True (== mem_space) space_ok,
            IxFun.contiguous ixfun
          ]
  if usable_as_is
    then pure (mem, v)
    else allocLinearArray (fromMaybe default_space space_ok) (baseString v) v
-- | Ensure the array argument lives in @space@ with a simple layout, and
-- record the extra arguments describing it.
--
-- The array is passed through 'ensureRowMajorArray' with @Just space@. The
-- writer output receives the resulting memory block (first component) and,
-- in the second component, one subexpression per element of the resulting
-- index function, each bound via @letSubExp "ixfun_arg"@. Returns the
-- (possibly new) array name. A constant argument is an internal error.
ensureArrayIn ::
  (Allocable fromrep torep inner) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn space se =
  case se of
    Constant c ->
      error $ "ensureArrayIn: " ++ pretty c ++ " cannot be an array."
    Var v -> do
      (linear_mem, linear_v) <- lift $ ensureRowMajorArray (Just space) v
      ctx <- lift $ do
        (_, ixfun) <- lookupArraySummary linear_v
        forM (toList ixfun) $ \e -> letSubExp "ixfun_arg" =<< toExp e
      tell ([Var linear_mem], ctx)
      pure (Var linear_v)
allocInMergeParams ::
(Allocable fromrep torep inner) =>
[(FParam fromrep, SubExp)] ->
( [(FParam torep, SubExp)] ->
([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
AllocM fromrep torep a
) ->
AllocM fromrep torep a
allocInMergeParams :: forall fromrep torep inner a.
Allocable fromrep torep inner =>
[(FParam fromrep, SubExp)]
-> ([(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInMergeParams [(FParam fromrep, SubExp)]
merge [(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep a
m = do
(([Param FParamMem]
valparams, [SubExp]
valargs, [SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]
handle_loop_subexps), ([Param FParamMem]
mem_params, [Param FParamMem]
ctx_params)) <-
WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
-> AllocM
fromrep
torep
(([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]),
([Param FParamMem], [Param FParamMem]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
-> AllocM
fromrep
torep
(([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]),
([Param FParamMem], [Param FParamMem])))
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
-> AllocM
fromrep
torep
(([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]),
([Param FParamMem], [Param FParamMem]))
forall a b. (a -> b) -> a -> b
$ [(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)]
-> ([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)]
-> ([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]))
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
[(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)]
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
([Param FParamMem], [SubExp],
[SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Param DeclType, SubExp)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp))
-> [(Param DeclType, SubExp)]
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
[(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Param DeclType, SubExp)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall fromrep torep inner.
Allocable fromrep torep inner =>
(Param DeclType, SubExp)
-> WriterT
([FParam torep], [FParam torep])
(AllocM fromrep torep)
(FParam torep, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
allocInMergeParam [(Param DeclType, SubExp)]
[(FParam fromrep, SubExp)]
merge
let mergeparams' :: [Param FParamMem]
mergeparams' = [Param FParamMem]
mem_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
valparams
summary :: Scope torep
summary = [Param FParamMem] -> Scope torep
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param FParamMem]
mergeparams'
mk_loop_res :: [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_res [SubExp]
ses = do
([SubExp]
ses', ([SubExp]
memargs, [SubExp]
ctxargs)) <-
WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) [SubExp]
-> AllocM fromrep torep ([SubExp], ([SubExp], [SubExp]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) [SubExp]
-> AllocM fromrep torep ([SubExp], ([SubExp], [SubExp])))
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) [SubExp]
-> AllocM fromrep torep ([SubExp], ([SubExp], [SubExp]))
forall a b. (a -> b) -> a -> b
$ ((SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> [SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]
-> [SubExp]
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall a b. (a -> b) -> a -> b
($) [SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]
handle_loop_subexps [SubExp]
ses
([SubExp], [SubExp]) -> AllocM fromrep torep ([SubExp], [SubExp])
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExp]
memargs [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> [SubExp]
ctxargs, [SubExp]
ses')
([SubExp]
valctx_args, [SubExp]
valargs') <- [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_res [SubExp]
valargs
let merge' :: [(Param FParamMem, SubExp)]
merge' =
[Param FParamMem] -> [SubExp] -> [(Param FParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([Param FParamMem]
mem_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
valparams) ([SubExp]
valctx_args [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> [SubExp]
valargs')
Scope torep -> AllocM fromrep torep a -> AllocM fromrep torep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope torep
summary (AllocM fromrep torep a -> AllocM fromrep torep a)
-> AllocM fromrep torep a -> AllocM fromrep torep a
forall a b. (a -> b) -> a -> b
$ [(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep a
m [(FParam torep, SubExp)]
[(Param FParamMem, SubExp)]
merge' [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_res
where
param_names :: Names
param_names = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> VName)
-> [(Param DeclType, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType -> VName
forall dec. Param dec -> VName
paramName (Param DeclType -> VName)
-> ((Param DeclType, SubExp) -> Param DeclType)
-> (Param DeclType, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst) [(Param DeclType, SubExp)]
[(FParam fromrep, SubExp)]
merge
anyIsLoopParam :: Names -> Bool
anyIsLoopParam Names
names = Names
names Names -> Names -> Bool
`namesIntersect` Names
param_names
scalarRes :: DeclType
-> Space -> IxFun (TPrimExp Int64 VName) -> SubExp -> t m SubExp
scalarRes DeclType
param_t Space
v_mem_space IxFun (TPrimExp Int64 VName)
v_ixfun (Var VName
res) = do
(VName
res_mem, IxFun (TPrimExp Int64 VName)
res_ixfun) <- m (VName, IxFun (TPrimExp Int64 VName))
-> t m (VName, IxFun (TPrimExp Int64 VName))
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (VName, IxFun (TPrimExp Int64 VName))
-> t m (VName, IxFun (TPrimExp Int64 VName)))
-> m (VName, IxFun (TPrimExp Int64 VName))
-> t m (VName, IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$ VName -> m (VName, IxFun (TPrimExp Int64 VName))
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
res
Space
res_mem_space <- m Space -> t m Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Space -> t m Space) -> m Space -> t m Space
forall a b. (a -> b) -> a -> b
$ VName -> m Space
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
res_mem
ChunkMap
chunkmap <- (AllocEnv fromrep torep -> ChunkMap) -> t m ChunkMap
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AllocEnv fromrep torep -> ChunkMap
forall fromrep torep. AllocEnv fromrep torep -> ChunkMap
chunkMap
(VName
res_mem', VName
res') <-
if (Space
res_mem_space, IxFun (TPrimExp Int64 VName)
res_ixfun) (Space, IxFun (TPrimExp Int64 VName))
-> (Space, IxFun (TPrimExp Int64 VName)) -> Bool
forall a. Eq a => a -> a -> Bool
== (Space
v_mem_space, IxFun (TPrimExp Int64 VName)
v_ixfun)
then (VName, VName) -> t m (VName, VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
res_mem, VName
res)
else m (VName, VName) -> t m (VName, VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (VName, VName) -> t m (VName, VName))
-> m (VName, VName) -> t m (VName, VName)
forall a b. (a -> b) -> a -> b
$ ChunkMap
-> Space
-> IxFun (TPrimExp Int64 VName)
-> Type
-> VName
-> m (VName, VName)
forall (m :: * -> *) inner.
(MonadBuilder m, Op (Rep m) ~ MemOp inner,
LetDec (Rep m) ~ LParamMem) =>
ChunkMap
-> Space
-> IxFun (TPrimExp Int64 VName)
-> Type
-> VName
-> m (VName, VName)
arrayWithIxFun ChunkMap
chunkmap Space
v_mem_space IxFun (TPrimExp Int64 VName)
v_ixfun (DeclType -> Type
forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl DeclType
param_t) VName
res
([SubExp], [a]) -> t m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> SubExp
Var VName
res_mem'], [])
SubExp -> t m SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> t m SubExp) -> SubExp -> t m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
res'
scalarRes DeclType
_ Space
_ IxFun (TPrimExp Int64 VName)
_ SubExp
se = SubExp -> t m SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
se
allocInMergeParam ::
(Allocable fromrep torep inner) =>
(Param DeclType, SubExp) ->
WriterT
([FParam torep], [FParam torep])
(AllocM fromrep torep)
( FParam torep,
SubExp,
SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
)
allocInMergeParam :: forall fromrep torep inner.
Allocable fromrep torep inner =>
(Param DeclType, SubExp)
-> WriterT
([FParam torep], [FParam torep])
(AllocM fromrep torep)
(FParam torep, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
allocInMergeParam (Param DeclType
mergeparam, Var VName
v)
| param_t :: DeclType
param_t@(Array PrimType
pt Shape
shape Uniqueness
u) <- Param DeclType -> DeclType
forall dec. DeclTyped dec => Param dec -> DeclType
paramDeclType Param DeclType
mergeparam = do
(VName
v_mem, IxFun (TPrimExp Int64 VName)
v_ixfun) <- AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, IxFun (TPrimExp Int64 VName))
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, IxFun (TPrimExp Int64 VName)))
-> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
v
Space
v_mem_space <- AllocM fromrep torep Space
-> WriterT
([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep Space
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
Space)
-> AllocM fromrep torep Space
-> WriterT
([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep Space
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
v_mem
case Space
v_mem_space of
ScalarSpace {} ->
if Names -> Bool
anyIsLoopParam (Shape -> Names
forall a. FreeIn a => a -> Names
freeIn Shape
shape)
then do
(VName
_, VName
v') <- AllocM fromrep torep (VName, VName)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, VName)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, VName))
-> AllocM fromrep torep (VName, VName)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, VName)
forall a b. (a -> b) -> a -> b
$ Space -> String -> VName -> AllocM fromrep torep (VName, VName)
forall fromrep torep inner.
Allocable fromrep torep inner =>
Space -> String -> VName -> AllocM fromrep torep (VName, VName)
allocLinearArray Space
DefaultSpace (VName -> String
baseString VName
v) VName
v
(Param DeclType, SubExp)
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
(Param (FParamInfo torep), SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall fromrep torep inner.
Allocable fromrep torep inner =>
(Param DeclType, SubExp)
-> WriterT
([FParam torep], [FParam torep])
(AllocM fromrep torep)
(FParam torep, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
allocInMergeParam (Param DeclType
mergeparam, VName -> SubExp
Var VName
v')
else do
Param FParamMem
p <- String
-> FParamMem
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"mem_param" (FParamMem
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem))
-> FParamMem
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem)
forall a b. (a -> b) -> a -> b
$ Space -> FParamMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
v_mem_space
([Param FParamMem], [Param FParamMem])
-> WriterT
([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([Param FParamMem
p], [])
(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
( Param DeclType
mergeparam {paramDec :: FParamMem
paramDec = PrimType -> Shape -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape Uniqueness
u (MemBind -> FParamMem) -> MemBind -> FParamMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun (TPrimExp Int64 VName) -> MemBind
ArrayIn (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
p) IxFun (TPrimExp Int64 VName)
v_ixfun},
VName -> SubExp
Var VName
v,
DeclType
-> Space
-> IxFun (TPrimExp Int64 VName)
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall {m :: * -> *} {t :: (* -> *) -> * -> *} {fromrep} {torep}
{a} {inner}.
(MonadTrans t, MonadReader (AllocEnv fromrep torep) (t m),
MonadBuilder m, MonadWriter ([SubExp], [a]) (t m),
HasLetDecMem (LetDec (Rep m)), OpReturns inner,
LetDec (Rep m) ~ LParamMem, RetType (Rep m) ~ RetTypeMem,
Op (Rep m) ~ MemOp inner, LParamInfo (Rep m) ~ LParamMem,
FParamInfo (Rep m) ~ FParamMem,
BranchType (Rep m) ~ BranchTypeMem) =>
DeclType
-> Space -> IxFun (TPrimExp Int64 VName) -> SubExp -> t m SubExp
scalarRes DeclType
param_t Space
v_mem_space IxFun (TPrimExp Int64 VName)
v_ixfun
)
Space
_ -> do
(VName
v_mem', VName
v') <- AllocM fromrep torep (VName, VName)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, VName)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, VName))
-> AllocM fromrep torep (VName, VName)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, VName)
forall a b. (a -> b) -> a -> b
$ Maybe Space -> VName -> AllocM fromrep torep (VName, VName)
forall fromrep torep inner.
Allocable fromrep torep inner =>
Maybe Space -> VName -> AllocM fromrep torep (VName, VName)
ensureRowMajorArray Maybe Space
forall a. Maybe a
Nothing VName
v
(VName
_, IxFun (TPrimExp Int64 VName)
v_ixfun') <- AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, IxFun (TPrimExp Int64 VName))
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, IxFun (TPrimExp Int64 VName)))
-> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(VName, IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
v'
Space
v_mem_space' <- AllocM fromrep torep Space
-> WriterT
([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep Space
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
Space)
-> AllocM fromrep torep Space
-> WriterT
([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep Space
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
v_mem'
[Param FParamMem]
ctx_params <-
Int
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
[Param FParamMem]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (IxFun (TPrimExp Int64 VName) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length IxFun (TPrimExp Int64 VName)
v_ixfun') (WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
[Param FParamMem])
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
[Param FParamMem]
forall a b. (a -> b) -> a -> b
$
String
-> FParamMem
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"ctx_param_ext" (PrimType -> FParamMem
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
int64)
IxFun (TPrimExp Int64 VName)
param_ixfun <-
ExtIxFun
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(IxFun (TPrimExp Int64 VName))
forall (m :: * -> *).
Monad m =>
ExtIxFun -> m (IxFun (TPrimExp Int64 VName))
instantiateIxFun (ExtIxFun
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(IxFun (TPrimExp Int64 VName)))
-> ExtIxFun
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$
Map (Ext VName) (TPrimExp Int64 (Ext VName))
-> ExtIxFun -> ExtIxFun
forall a t.
Ord a =>
Map a (TPrimExp t a)
-> IxFun (TPrimExp t a) -> IxFun (TPrimExp t a)
IxFun.substituteInIxFun
( [(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName)))
-> ([TPrimExp Int64 (Ext VName)]
-> [(Ext VName, TPrimExp Int64 (Ext VName))])
-> [TPrimExp Int64 (Ext VName)]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Ext VName]
-> [TPrimExp Int64 (Ext VName)]
-> [(Ext VName, TPrimExp Int64 (Ext VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Int -> Ext VName) -> [Int] -> [Ext VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Int -> Ext VName
forall a. Int -> Ext a
Ext [Int
0 ..]) ([TPrimExp Int64 (Ext VName)]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName)))
-> [TPrimExp Int64 (Ext VName)]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall a b. (a -> b) -> a -> b
$
(Param FParamMem -> TPrimExp Int64 (Ext VName))
-> [Param FParamMem] -> [TPrimExp Int64 (Ext VName)]
forall a b. (a -> b) -> [a] -> [b]
map (Ext VName -> TPrimExp Int64 (Ext VName)
forall a. a -> TPrimExp Int64 a
le64 (Ext VName -> TPrimExp Int64 (Ext VName))
-> (Param FParamMem -> Ext VName)
-> Param FParamMem
-> TPrimExp Int64 (Ext VName)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Ext VName
forall a. a -> Ext a
Free (VName -> Ext VName)
-> (Param FParamMem -> VName) -> Param FParamMem -> Ext VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName) [Param FParamMem]
ctx_params
)
(IxFun (TPrimExp Int64 VName) -> ExtIxFun
forall a b.
IxFun (TPrimExp Int64 a) -> IxFun (TPrimExp Int64 (Ext b))
IxFun.existentialize IxFun (TPrimExp Int64 VName)
v_ixfun')
Param FParamMem
mem_param <- String
-> FParamMem
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"mem_param" (FParamMem
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem))
-> FParamMem
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem)
forall a b. (a -> b) -> a -> b
$ Space -> FParamMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
v_mem_space'
([Param FParamMem], [Param FParamMem])
-> WriterT
([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([Param FParamMem
mem_param], [Param FParamMem]
ctx_params)
(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
([Param FParamMem], [Param FParamMem])
(AllocM fromrep torep)
(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
( Param DeclType
mergeparam {paramDec :: FParamMem
paramDec = PrimType -> Shape -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape Uniqueness
u (MemBind -> FParamMem) -> MemBind -> FParamMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun (TPrimExp Int64 VName) -> MemBind
ArrayIn (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
mem_param) IxFun (TPrimExp Int64 VName)
param_ixfun},
VName -> SubExp
Var VName
v',
Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall fromrep torep inner.
Allocable fromrep torep inner =>
Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn Space
v_mem_space'
)
allocInMergeParam (Param DeclType
mergeparam, SubExp
se) = Param DeclType
-> SubExp
-> Space
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
(Param (FParamInfo torep), SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall {torep} {fromrep} {fromrep} {torep} {inner} {inner} {b}.
(PrettyRep fromrep, PrettyRep fromrep, HasLetDecMem (LetDec torep),
HasLetDecMem (LetDec torep), OpReturns inner, OpReturns inner,
SizeSubst inner, SizeSubst inner, BuilderOps torep,
BuilderOps torep, FParamInfo torep ~ FParamMem,
LetDec torep ~ LParamMem, BodyDec fromrep ~ (),
LParamInfo fromrep ~ Type, RetType torep ~ RetTypeMem,
BodyDec fromrep ~ (), BranchType torep ~ BranchTypeMem,
BranchType fromrep ~ ExtType, FParamInfo torep ~ FParamMem,
FParamInfo fromrep ~ DeclType, ExpDec torep ~ (),
LParamInfo fromrep ~ Type, RetType torep ~ RetTypeMem,
BodyDec torep ~ (), BranchType torep ~ BranchTypeMem,
BranchType fromrep ~ ExtType, RetType fromrep ~ DeclExtType,
Op torep ~ MemOp inner, ExpDec torep ~ (),
LParamInfo torep ~ LParamMem, FParamInfo fromrep ~ DeclType,
BodyDec torep ~ (), LetDec torep ~ LParamMem,
LParamInfo torep ~ LParamMem, RetType fromrep ~ DeclExtType,
Op torep ~ MemOp inner) =>
Param DeclType
-> b
-> Space
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
(Param (FParamInfo torep), b,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
doDefault Param DeclType
mergeparam SubExp
se (Space
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp))
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
Space
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
(Param FParamMem, SubExp,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< AllocM fromrep torep Space
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift AllocM fromrep torep Space
forall fromrep torep. AllocM fromrep torep Space
askDefaultSpace
doDefault :: Param DeclType
-> b
-> Space
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
(Param (FParamInfo torep), b,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
doDefault Param DeclType
mergeparam b
se Space
space = do
Param (FParamInfo torep)
mergeparam' <- FParam fromrep
-> Space
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
(Param (FParamInfo torep))
forall fromrep torep inner.
Allocable fromrep torep inner =>
FParam fromrep
-> Space
-> WriterT
([FParam torep], [FParam torep])
(AllocM fromrep torep)
(FParam torep)
allocInFParam Param DeclType
FParam fromrep
mergeparam Space
space
(Param (FParamInfo torep), b,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
([Param (FParamInfo torep)], [Param (FParamInfo torep)])
(AllocM fromrep torep)
(Param (FParamInfo torep), b,
SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param (FParamInfo torep)
mergeparam', b
se, Type
-> Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall fromrep torep inner.
Allocable fromrep torep inner =>
Type
-> Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
mergeparam) Space
space)
-- | Copy the array @v@ (whose type is @v_t@) into a freshly allocated
-- memory block in @space@, with the copy laid out according to @ixfun@.
-- The block is obtained from 'allocForArray'' using @chunkmap@.
--
-- Returns the new memory block and the name of the copy (named after
-- @v@ with a @_scalcopy@ suffix).
--
-- Calls 'error' if @v_t@ is not an array type; matching explicitly
-- (rather than with a lazy irrefutable @let@ pattern) makes that failure
-- immediate and descriptive instead of a delayed pattern-match failure.
arrayWithIxFun ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner, LetDec (Rep m) ~ LetDecMem) =>
  ChunkMap ->
  Space ->
  IxFun ->
  Type ->
  VName ->
  m (VName, VName)
arrayWithIxFun chunkmap space ixfun v_t v =
  case v_t of
    Array pt shape u -> do
      mem <- allocForArray' chunkmap v_t space
      v_copy <- newVName $ baseString v <> "_scalcopy"
      letBind (Pat [PatElem v_copy $ MemArray pt shape u $ ArrayIn mem ixfun]) $
        BasicOp $
          Copy v
      pure (mem, v_copy)
    _ ->
      error $ "arrayWithIxFun: not an array type: " ++ pretty v_t
-- | Ensure that the array @v@ is stored with a direct index function
-- (per 'IxFun.isDirect'). When @space_ok@ is @Just s@, also ensure its
-- memory block is in space @s@.
--
-- If both conditions already hold, @v@ and its current memory block are
-- returned as-is. Otherwise @v@ is copied with 'allocLinearArray' into
-- @space_ok@, falling back to the default space when @space_ok@ is
-- 'Nothing'. The result is the (memory block, array) pair.
ensureDirectArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let -- No required space means any space is acceptable.
      in_acceptable_space = all (== mem_space) space_ok
      target_space = fromMaybe default_space space_ok
  if not (IxFun.isDirect ixfun) || not in_acceptable_space
    then allocLinearArray target_space (baseString v) v
    else pure (mem, v)
-- | Allocate a new memory block in @space@ for the array @v@ and bind
-- a fresh array (named from @s@ with a @_desired_form@ suffix) holding
-- @Manifest perm v@.
--
-- The new array's index function is an iota over @v@'s dimensions,
-- permuted by @perm@. Returns the new memory block and the new array's
-- name. Calls 'error' if @v@ is not an array.
allocPermArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  [Int] ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocPermArray space perm s v = do
  t <- lookupType v
  case t of
    Array pt shape u -> do
      mem <- allocForArray t space
      v' <- newVName (s <> "_desired_form")
      let dims = map pe64 (arrayDims t)
          ixfun = IxFun.permute (IxFun.iota dims) perm
          pe = PatElem v' (MemArray pt shape u (ArrayIn mem ixfun))
      addStm (Let (Pat [pe]) (defAux ()) (BasicOp (Manifest perm v)))
      pure (mem, v')
    _ -> error ("allocPermArray: " ++ pretty t)
-- | Copy the array @v@ into a fresh memory block in @space@ using the
-- identity permutation (@[0 .. rank - 1]@). This delegates to
-- 'allocPermArray', passing @s@ as its name prefix.
allocLinearArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocLinearArray space s v = do
  rank <- arrayRank <$> lookupType v
  let identity_perm = take rank [0 ..]
  allocPermArray space identity_perm s v
-- | Rewrite the arguments of a function call for the memory representation.
--
-- Each argument is passed to 'linearFuncallArg', using its type and the
-- default space from 'askDefaultSpace'.  The diet of each value argument is
-- kept.  Whatever 'linearFuncallArg' records in its writer output (both
-- components, in that order) is prepended to the argument list, each entry
-- tagged with the 'Observe' diet.
funcallArgs ::
  (Allocable fromrep torep inner) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  (valargs, (ctx_args, mem_and_size_args)) <-
    runWriterT . forM args $ \(arg, d) -> do
      argType <- lift $ subExpType arg
      defSpace <- lift askDefaultSpace
      (,d) <$> linearFuncallArg argType defSpace arg
  let extra_args = ctx_args <> mem_and_size_args
  pure $ map (,Observe) extra_args <> valargs
-- | Prepare a single function-call argument.
--
-- When the type is an array and the argument is a variable, the variable is
-- passed to 'ensureDirectArray' with the requested space.  The returned
-- memory block is recorded in the first writer component, and the returned
-- array name replaces the argument.  Any other argument is returned as-is
-- and nothing is recorded.
linearFuncallArg ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg t space arg =
  case (t, arg) of
    (Array {}, Var v) -> do
      (mem, v') <- lift . ensureDirectArray (Just space) $ v
      tell ([Var mem], mempty)
      pure (Var v')
    _ -> pure arg
-- | The explicit-allocations pass, parameterised by a handler for
-- representation-specific operations and by an expression hint function.
--
-- Constants are processed with 'allocInStms'.  For each function:
--
-- * every parameter is allocated in 'DefaultSpace';
-- * every return value is required to be in 'DefaultSpace';
-- * the memory-context return types produced for the body are placed
--   before the value return types built by 'memoryInDeclExtType'.
explicitAllocationsGeneric ::
  (Allocable fromrep torep inner) =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric handleOp hints =
  Pass
    "explicit allocations"
    "Transform program to explicit memory representation"
    (intraproceduralTransformationWithConsts allocConsts allocFun)
  where
    -- Top-level constants: allocate and collect the resulting statements.
    allocConsts stms =
      runAllocM handleOp hints $ collectStms_ (allocInStms stms (pure ()))

    -- Functions are processed with the (already allocated) constants in scope.
    allocFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM handleOp hints $
        inScopeOf consts $
          allocInFParams [(param, DefaultSpace) | param <- params] $ \params' -> do
            (fbody', mem_rets) <-
              allocInFunBody (Just DefaultSpace <$ rettype) fbody
            -- Existential indices in the value returns are shifted past the
            -- memory-context returns that precede them.
            let val_rets = memoryInDeclExtType (length mem_rets) rettype
            pure $ FunDef entry attrs fname (mem_rets ++ val_rets) params' fbody'
-- | Run the allocation transformation on a bare sequence of statements.
-- The scope of the calling monad is made available inside 'AllocM', and the
-- statements produced by 'allocInStms' are collected and returned.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep inner
  ) =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric handleOp hints stms =
  askScope >>= \outerScope ->
    runAllocM handleOp hints . localScope outerScope . collectStms_ $
      allocInStms stms (pure ())
-- | Attach memory information to declared return types.
--
-- Each array return gets a 'ReturnsNewBlock' in 'DefaultSpace'.  Arrays are
-- numbered consecutively from 0 in order of appearance, and each gets an
-- 'IxFun.iota' index function over its (shifted) shape.  Existential shape
-- indices are shifted by @k@.  Primitive and accumulator returns carry no
-- memory.  A 'Mem' return is rejected with 'error'.
memoryInDeclExtType :: Int -> [DeclExtType] -> [FunReturns]
memoryInDeclExtType k dets = evalState (traverse withMem dets) 0
  where
    withMem (Prim t) = pure (MemPrim t)
    withMem Mem {} = error "memoryInDeclExtType: too much memory"
    withMem (Acc acc ispace ts u) = pure (MemAcc acc ispace ts u)
    withMem (Array pt shape u) = do
      -- Take the next memory-block index and bump the counter.
      blockIdx <- state $ \n -> (n, n + 1)
      let shape' = shiftExt <$> shape
          ixfun = IxFun.iota [toPrimExp d | d <- shapeDims shape']
      pure $ MemArray pt shape' u (ReturnsNewBlock DefaultSpace blockIdx ixfun)

    toPrimExp (Ext i) = le64 (Ext i)
    toPrimExp (Free se) = Free <$> pe64 se

    shiftExt (Ext i) = Ext (k + i)
    shiftExt (Free x) = Free x
-- | Compute the extra memory-context results needed to return a body result.
--
-- If the result is a variable with array memory information, its memory
-- block is returned as an extra result.  It is paired with a 'MemMem' return
-- type in that block's space.  Constants and non-array variables need no
-- extra results.  The result's certificates are not copied to the extra
-- result.  If the array's memory is not bound as a memory block, this calls
-- 'error'.
bodyReturnMemCtx ::
  (Allocable fromrep torep inner) =>
  SubExpRes ->
  AllocM fromrep torep [(SubExpRes, MemInfo ExtSize u MemReturn)]
bodyReturnMemCtx (SubExpRes _ se) =
  case se of
    Constant {} -> pure []
    Var v -> do
      info <- lookupMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn mem _) -> do
          memInfo <- lookupMemInfo mem
          case memInfo of
            MemMem space -> pure [(subExpRes (Var mem), MemMem space)]
            _ -> error $ "bodyReturnMemCtx: not a memory block: " ++ pretty mem
        -- MemPrim, MemAcc and MemMem results carry no memory context.
        _ -> pure []
-- | Allocate a function body.
--
-- The body's statements are allocated first.  Each result is then passed
-- through 'ensureDirect' with its space requirement.  @space_oks@ is aligned
-- with the last results, and any leading results get 'Nothing'.  Finally,
-- the memory-context results from 'bodyReturnMemCtx' are placed before the
-- value results.  Also returns the return types of that memory context.
allocInFunBody ::
  (Allocable fromrep torep inner) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [FunReturns])
allocInFunBody space_oks (Body _ stms res) =
  buildBody $
    allocInStms stms $ do
      res' <- zipWithM ensureDirect paddedSpaces res
      ctx <- concat <$> mapM bodyReturnMemCtx res'
      let (ctxRes, ctxRets) = unzip ctx
      pure (ctxRes <> res', ctxRets)
  where
    paddedSpaces =
      replicate (length res - length space_oks) Nothing ++ space_oks
-- | Make sure a body result is an array with direct access, when it is one.
-- A variable with array memory information is passed to 'ensureDirectArray'
-- with the given space constraint, and the returned array name replaces it.
-- Every other result is left unchanged.  Certificates are preserved.
ensureDirect ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  SubExpRes ->
  AllocM fromrep torep SubExpRes
ensureDirect space_ok (SubExpRes cs se) = do
  info <- subExpMemInfo se
  se' <- case info of
    MemArray {}
      | Var v <- se ->
          Var . snd <$> ensureDirectArray space_ok v
    _ -> pure se
  pure $ SubExpRes cs se'
-- | Allocate a sequence of statements, then run @m@.
--
-- Each statement is processed in order, as follows:
--
-- * it is allocated with 'allocInStm' under its own 'StmAux';
-- * the resulting statements are added to the builder;
-- * the rest of the sequence (and eventually @m@) runs with 'chunkMap' and
--   'envConsts' extended by what 'sizeSubst' and 'stmConsts' report for
--   those statements.
allocInStms ::
  (Allocable fromrep torep inner) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m =
  foldr
    ( \stm continue -> do
        allocstms <- collectStms_ . auxing (stmAux stm) $ allocInStm stm
        addStms allocstms
        let extend env =
              env
                { chunkMap = foldMap sizeSubst allocstms <> chunkMap env,
                  envConsts = foldMap stmConsts allocstms <> envConsts env
                }
        local extend continue
    )
    m
    (stmsToList origstms)
-- | Allocate memory for a single statement.
--
-- The expression is first rewritten with 'allocInExp'.  Then
-- 'allocsForStm' builds the new statement for the identifiers of the
-- original pattern, and that statement is added to the builder.  The
-- statement's 'StmAux' is ignored here; 'allocInStms' applies it with
-- 'auxing'.
allocInStm ::
  (Allocable fromrep torep inner) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pat pes) _ e) = do
  e' <- allocInExp e
  stm' <- allocsForStm (map patElemIdent pes) e'
  addStm stm'
-- | Build a lambda over the given target-representation parameters.
--
-- The lambda's body is the given body with memory allocated for its
-- statements ('allocInStms').  The body's result is returned unchanged.
allocInLambda ::
  Allocable fromrep torep inner =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params body =
  let stms = bodyStms body
      res = bodyResult body
   in mkLambda params (allocInStms stms (pure res))
-- | The number of LMADs an index function consists of.
numLMADs :: IxFun -> Int
numLMADs ixfun = NE.length (IxFun.ixfunLMADs ixfun)
-- | The 'IxFun.ldPerm' entry of every dimension of the first LMAD (the
-- head of 'IxFun.ixfunLMADs') of an index function.
ixFunPerm :: IxFun -> [Int]
ixFunPerm ixfun =
  [IxFun.ldPerm dim | dim <- IxFun.lmadDims (NE.head (IxFun.ixfunLMADs ixfun))]
-- | The 'IxFun.ldMon' entry of every dimension of the first LMAD (the
-- head of 'IxFun.ixfunLMADs') of an index function.
ixFunMon :: IxFun -> [IxFun.Monotonicity]
ixFunMon ixfun =
  [IxFun.ldMon dim | dim <- IxFun.lmadDims (NE.head (IxFun.ixfunLMADs ixfun))]
-- | The memory requirement recorded for an array result of a branch.
data MemReq
  = -- | The layout is kept as-is.  The fields are, in order:
    --
    -- * the memory space;
    -- * the per-dimension permutation;
    -- * the per-dimension monotonicity;
    -- * the rank of the index function's base shape;
    -- * whether the index function is contiguous.
    MemReq Space [Int] [IxFun.Monotonicity] Rank Bool
  | -- | No common layout could be kept.  Only the memory space is
    -- recorded; 'addCtxToMatchBody' passes such results through
    -- 'ensureRowMajorArray' in this space.
    NeedsLinearisation Space
  deriving (Show, Eq)
-- | Combine two memory requirements.
--
-- * A 'NeedsLinearisation' operand wins; the left one is preferred.
-- * Two equal 'MemReq's are kept.
-- * Two differing 'MemReq's become 'NeedsLinearisation' in the space of
--   the left operand.
combMemReqs :: MemReq -> MemReq -> MemReq
combMemReqs a b =
  case (a, b) of
    (NeedsLinearisation {}, _) -> a
    (_, NeedsLinearisation {}) -> b
    (MemReq space _ _ _ _, _)
      | a == b -> a
      | otherwise -> NeedsLinearisation space
-- | The 'MemInfo' computed for each branch result by 'allocInMatchBody'.
-- It has existential sizes and no uniqueness, and arrays are annotated
-- with a 'MemReq'.
type MemReqType = MemInfo (Ext SubExp) NoUniqueness MemReq
-- | Combine two 'MemReqType's.
--
-- When both are arrays, the result keeps the element type, shape and
-- uniqueness of the first and merges the requirements with
-- 'combMemReqs'.  In every other case the first argument is returned
-- unchanged.
combMemReqTypes :: MemReqType -> MemReqType -> MemReqType
combMemReqTypes x y =
  case (x, y) of
    (MemArray pt shape u x_req, MemArray _ _ _ y_req) ->
      MemArray pt shape u (combMemReqs x_req y_req)
    _ -> x
-- | The context values needed to return a result with the given
-- requirement: a memory block in the required space followed by
-- 'int64' values.
--
-- For an array of rank @r@:
--
-- * a 'MemReq' with base rank @b@ needs @1 + b + 2*r@ 'int64' values;
-- * 'NeedsLinearisation' needs @1 + 3*r@ 'int64' values.
--
-- Non-array results need no context.
contextRets :: MemReqType -> [MemInfo d u r]
contextRets (MemArray _ shape _ req) =
  case req of
    MemReq space _ _ (Rank base_rank) _ ->
      MemMem space : i64 : replicate base_rank i64 ++ replicate (2 * rank) i64
    NeedsLinearisation space ->
      MemMem space : i64 : replicate (3 * rank) i64
  where
    rank = shapeRank shape
    i64 = MemPrim int64
contextRets _ = []
-- | Allocate memory in one branch body of a match, given the branch's
-- result types.
--
-- Returns the new body together with a 'MemReqType' per result:
--
-- * An array records its memory space and index-function layout when
--   the index function consists of exactly one LMAD.  Otherwise it is
--   marked 'NeedsLinearisation'.
-- * Memory blocks, primitives and accumulators are passed through.
--
-- Calls 'error' when a result's type and memory information do not
-- match any of these cases.
allocInMatchBody ::
  (Allocable fromrep torep inner) =>
  [ExtType] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody rets (Body _ stms res) =
  buildBody . allocInStms stms $
    (res,) <$> zipWithM restriction rets (map resSubExp res)
  where
    restriction t se = do
      v_info <- subExpMemInfo se
      case (t, v_info) of
        (Array pt shape u, MemArray _ _ _ (ArrayIn mem ixfun)) ->
          MemArray pt shape u . requirement ixfun <$> lookupMemSpace mem
        (_, MemMem space) -> pure $ MemMem space
        (_, MemPrim pt) -> pure $ MemPrim pt
        (_, MemAcc acc ispace ts u) -> pure $ MemAcc acc ispace ts u
        _ -> error $ "allocInMatchBody: mismatch: " ++ show (t, v_info)

    -- A single-LMAD index function keeps its layout.  Anything else is
    -- marked 'NeedsLinearisation'.
    requirement ixfun space
      | numLMADs ixfun == 1 =
          MemReq
            space
            (ixFunPerm ixfun)
            (ixFunMon ixfun)
            (Rank $ length $ IxFun.base ixfun)
            (IxFun.contiguous ixfun)
      | otherwise = NeedsLinearisation space
-- | Compute the branch return types for results with the given memory
-- requirements.
--
-- The context values of all results ('contextRets') come first, in
-- result order.  They are followed by one return type per result.
--
-- For each array result:
--
-- * it is returned in a new block ('ReturnsNewBlock') at the position of
--   its own first context value;
-- * its existential index function refers to the context values that
--   follow that position;
-- * existential sizes in its shape are shifted past all of the newly
--   added context values.
--
-- This is linear in the number of results.  Each result's context is
-- computed once and concatenated, instead of being appended to a
-- growing accumulator in a left fold.
mkBranchRet :: [MemReqType] -> [BranchTypeMem]
mkBranchRet reqs = concat ctx_rets ++ zipWith inspect offsets reqs
  where
    -- Context values needed by each result, in result order.
    ctx_rets :: [[BranchTypeMem]]
    ctx_rets = map contextRets reqs

    -- Index of each result's first context value among all of them.
    offsets = scanl (+) 0 $ map length ctx_rets

    -- Total number of context values (total, unlike 'last offsets').
    num_new_ctx = sum $ map length ctx_rets

    arrayInfo :: Int -> MemReq -> (Space, [Int], [IxFun.Monotonicity], Int, Bool)
    arrayInfo rank (NeedsLinearisation space) =
      (space, [0 .. rank - 1], repeat IxFun.Inc, rank, True)
    arrayInfo _ (MemReq space perm mon (Rank base_rank) contig) =
      (space, perm, mon, base_rank, contig)

    inspect :: Int -> MemReqType -> BranchTypeMem
    inspect ctx_offset (MemArray pt shape u req) =
      let shape' = fmap (adjustExt num_new_ctx) shape
          (space, perm, mon, base_rank, contig) = arrayInfo (shapeRank shape) req
       in MemArray pt shape' u . ReturnsNewBlock space ctx_offset $
            convert
              <$> IxFun.mkExistential base_rank (zip perm mon) contig (ctx_offset + 1)
    inspect _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    inspect _ (MemPrim pt) = MemPrim pt
    inspect _ (MemMem space) = MemMem space

    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

    adjustExt :: Int -> Ext a -> Ext a
    adjustExt _ (Free v) = Free v
    adjustExt k (Ext i) = Ext (k + i)
-- | Bind a branch body with 'bodyBind' and rebuild it so that its
-- result is prefixed with context values.
--
-- 1. A result whose requirement is 'NeedsLinearisation' and which is a
--    variable is replaced by the array returned by 'ensureRowMajorArray'
--    in the required space.
-- 2. For each array result, its memory block is emitted as context,
--    followed by every value of its index function.  Each value is
--    bound with 'letSubExp' under the name "ixfun_ext".
--
-- Constants and non-array variables contribute no context.
addCtxToMatchBody ::
  (Allocable fromrep torep inner) =>
  [MemReqType] ->
  Body torep ->
  AllocM fromrep torep (Body torep)
addCtxToMatchBody reqs body = buildBody_ $ do
  body_res <- bodyBind body
  res <- zipWithM linearIfNeeded reqs body_res
  ctx <- concat <$> traverse resCtx res
  pure (ctx ++ res)
  where
    linearIfNeeded req r =
      case (req, r) of
        (MemArray _ _ _ (NeedsLinearisation space), SubExpRes cs (Var v)) ->
          fmap (SubExpRes cs . Var . snd) (ensureRowMajorArray (Just space) v)
        _ -> pure r

    resCtx (SubExpRes _ (Var v)) = do
      info <- lookupMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn mem ixfun) -> do
          ixfun_exts <- traverse (letSubExp "ixfun_ext" <=< toExp) (toList ixfun)
          pure $ subExpRes (Var mem) : subExpsRes ixfun_exts
        MemPrim {} -> pure []
        MemAcc {} -> pure []
        MemMem {} -> pure []
    resCtx (SubExpRes _ Constant {}) = pure []
-- | Remove match results that are branch-invariant.
--
-- A result position is treated as invariant when both of these hold:
--
-- * every case body and the default body return the same 'SubExp' for it;
-- * that 'SubExp' mentions no name bound by a statement inside any of the
--   branches.
--
-- Invariant positions are dropped from every body's result and from the
-- return types. Each dropped position @i@ is passed, together with its
-- value, to 'fixExt', which is applied to the remaining return types.
--
-- Returns the rewritten cases, the rewritten default body and the
-- rewritten return types.
--
-- The per-result and per-case views are built with a row count fixed by
-- the number of results (respectively cases), rather than with
-- 'transpose'. This matters because @transpose []@ is @[]@, and the
-- subsequent 'zip4' / 'zipWith' would truncate against it. A match with
-- no cases would otherwise lose all of its results, and a match whose
-- results are all invariant (or that has no results) would lose all of
-- its cases.
simplifyMatch ::
  Mem rep inner =>
  [Case (Body rep)] ->
  Body rep ->
  [BranchTypeMem] ->
  ( [Case (Body rep)],
    Body rep,
    [BranchTypeMem]
  )
simplifyMatch cases defbody ts =
  let case_reses = map (bodyResult . caseBody) cases
      defbody_res = bodyResult defbody
      -- One row per result position, holding that position's value in
      -- each case (empty rows when there are no cases).
      per_result = transposeN (length defbody_res) case_reses
      (ctx_fixes, variant) =
        partitionEithers . map branchInvariant $
          zip4 [0 ..] per_result defbody_res ts
      (cases_reses, defbody_reses, ts') = unzip3 variant
      -- One row per case, holding that case's surviving results (empty
      -- rows when every result was found to be invariant).
      per_case = transposeN (length cases) cases_reses
   in ( zipWith onCase cases per_case,
        onBody defbody defbody_reses,
        -- foldr applies the fix with the highest index first.
        foldr (uncurry fixExt) ts' ctx_fixes
      )
  where
    -- Names bound by any statement in any branch. A result whose free
    -- names intersect this set is never treated as invariant.
    bound_in_branches =
      namesFromList . concatMap (patNames . stmPat) $
        foldMap (bodyStms . caseBody) cases <> bodyStms defbody

    -- Replace the result of the body stored in a case.
    onCase c res = fmap (`onBody` res) c

    -- Replace the result of a body, keeping its statements.
    onBody body res = body {bodyResult = res}

    branchInvariant (i, case_reses, defres, t)
      -- If even one branch has a variant result, then we give up.
      | namesIntersect bound_in_branches $ freeIn $ defres : case_reses =
          Right (case_reses, defres, t)
      -- Do all branches return the same value?
      | all ((== resSubExp defres) . resSubExp) case_reses =
          Left (i, resSubExp defres)
      | otherwise =
          Right (case_reses, defres, t)

    -- Transpose a list of rows into exactly @n@ columns. Unlike
    -- 'transpose', this yields @n@ empty columns when there are no rows.
    -- Rows are expected to all have length @n@; longer rows are truncated.
    transposeN :: Int -> [[a]] -> [[a]]
    transposeN n = foldr (zipWith (:)) (replicate n [])
allocInExp ::
(Allocable fromrep torep inner) =>
Exp fromrep ->
AllocM fromrep torep (Exp torep)
allocInExp :: forall fromrep torep inner.
Allocable fromrep torep inner =>
Exp fromrep -> AllocM fromrep torep (Exp torep)
allocInExp (DoLoop [(FParam fromrep, SubExp)]
merge LoopForm fromrep
form (Body () Stms fromrep
bodystms Result
bodyres)) =
[(FParam fromrep, SubExp)]
-> ([(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep (Exp torep))
-> AllocM fromrep torep (Exp torep)
forall fromrep torep inner a.
Allocable fromrep torep inner =>
[(FParam fromrep, SubExp)]
-> ([(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInMergeParams [(FParam fromrep, SubExp)]
merge (([(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep (Exp torep))
-> AllocM fromrep torep (Exp torep))
-> ([(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep (Exp torep))
-> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ \[(FParam torep, SubExp)]
merge' [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_val -> do
LoopForm torep
form' <- LoopForm fromrep -> AllocM fromrep torep (LoopForm torep)
forall fromrep torep inner.
Allocable fromrep torep inner =>
LoopForm fromrep -> AllocM fromrep torep (LoopForm torep)
allocInLoopForm LoopForm fromrep
form
Scope torep
-> AllocM fromrep torep (Exp torep)
-> AllocM fromrep torep (Exp torep)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm torep -> Scope torep
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm torep
form') (AllocM fromrep torep (Exp torep)
-> AllocM fromrep torep (Exp torep))
-> AllocM fromrep torep (Exp torep)
-> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ do
Body torep
body' <-
AllocM fromrep torep Result -> AllocM fromrep torep (Body torep)
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (AllocM fromrep torep Result -> AllocM fromrep torep (Body torep))
-> (AllocM fromrep torep Result -> AllocM fromrep torep Result)
-> AllocM fromrep torep Result
-> AllocM fromrep torep (Body torep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms fromrep
-> AllocM fromrep torep Result -> AllocM fromrep torep Result
forall fromrep torep inner a.
Allocable fromrep torep inner =>
Stms fromrep -> AllocM fromrep torep a -> AllocM fromrep torep a
allocInStms Stms fromrep
bodystms (AllocM fromrep torep Result -> AllocM fromrep torep (Body torep))
-> AllocM fromrep torep Result -> AllocM fromrep torep (Body torep)
forall a b. (a -> b) -> a -> b
$ do
([SubExp]
valctx, [SubExp]
valres') <- [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_val ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
bodyres
Result -> AllocM fromrep torep Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> AllocM fromrep torep Result)
-> Result -> AllocM fromrep torep Result
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes [SubExp]
valctx Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> (Certs -> SubExp -> SubExpRes) -> [Certs] -> [SubExp] -> Result
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Certs -> SubExp -> SubExpRes
SubExpRes ((SubExpRes -> Certs) -> Result -> [Certs]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> Certs
resCerts Result
bodyres) [SubExp]
valres'
Exp torep -> AllocM fromrep torep (Exp torep)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp torep -> AllocM fromrep torep (Exp torep))
-> Exp torep -> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ [(FParam torep, SubExp)]
-> LoopForm torep -> Body torep -> Exp torep
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(FParam torep, SubExp)]
merge' LoopForm torep
form' Body torep
body'
allocInExp (Apply Name
fname [(SubExp, Diet)]
args [RetType fromrep]
rettype (Safety, SrcLoc, [SrcLoc])
loc) = do
[(SubExp, Diet)]
args' <- [(SubExp, Diet)] -> AllocM fromrep torep [(SubExp, Diet)]
forall fromrep torep inner.
Allocable fromrep torep inner =>
[(SubExp, Diet)] -> AllocM fromrep torep [(SubExp, Diet)]
funcallArgs [(SubExp, Diet)]
args
Exp torep -> AllocM fromrep torep (Exp torep)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp torep -> AllocM fromrep torep (Exp torep))
-> Exp torep -> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ Name
-> [(SubExp, Diet)]
-> [RetType torep]
-> (Safety, SrcLoc, [SrcLoc])
-> Exp torep
forall rep.
Name
-> [(SubExp, Diet)]
-> [RetType rep]
-> (Safety, SrcLoc, [SrcLoc])
-> Exp rep
Apply Name
fname [(SubExp, Diet)]
args' ([RetTypeMem]
mems [RetTypeMem] -> [RetTypeMem] -> [RetTypeMem]
forall a. [a] -> [a] -> [a]
++ Int -> [DeclExtType] -> [RetTypeMem]
memoryInDeclExtType Int
num_arrays [DeclExtType]
[RetType fromrep]
rettype) (Safety, SrcLoc, [SrcLoc])
loc
where
mems :: [RetTypeMem]
mems = Int -> RetTypeMem -> [RetTypeMem]
forall a. Int -> a -> [a]
replicate Int
num_arrays (Space -> RetTypeMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
DefaultSpace)
num_arrays :: Int
num_arrays = [DeclExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([DeclExtType] -> Int) -> [DeclExtType] -> Int
forall a b. (a -> b) -> a -> b
$ (DeclExtType -> Bool) -> [DeclExtType] -> [DeclExtType]
forall a. (a -> Bool) -> [a] -> [a]
filter ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0) (Int -> Bool) -> (DeclExtType -> Int) -> DeclExtType -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DeclExtType -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (DeclExtType -> Int)
-> (DeclExtType -> DeclExtType) -> DeclExtType -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DeclExtType -> DeclExtType
forall t. DeclExtTyped t => t -> DeclExtType
declExtTypeOf) [DeclExtType]
[RetType fromrep]
rettype
allocInExp (Match [SubExp]
ses [Case (Body fromrep)]
cases Body fromrep
defbody (MatchDec [BranchType fromrep]
rets MatchSort
ifsort)) = do
(Body torep
defbody', [MemReqType]
def_reqs) <- [ExtType]
-> Body fromrep -> AllocM fromrep torep (Body torep, [MemReqType])
forall fromrep torep inner.
Allocable fromrep torep inner =>
[ExtType]
-> Body fromrep -> AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody [ExtType]
[BranchType fromrep]
rets Body fromrep
defbody
([Case (Body torep)]
cases', [[MemReqType]]
cases_reqs) <- [(Case (Body torep), [MemReqType])]
-> ([Case (Body torep)], [[MemReqType]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Case (Body torep), [MemReqType])]
-> ([Case (Body torep)], [[MemReqType]]))
-> AllocM fromrep torep [(Case (Body torep), [MemReqType])]
-> AllocM fromrep torep ([Case (Body torep)], [[MemReqType]])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Case (Body fromrep)
-> AllocM fromrep torep (Case (Body torep), [MemReqType]))
-> [Case (Body fromrep)]
-> AllocM fromrep torep [(Case (Body torep), [MemReqType])]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Case (Body fromrep)
-> AllocM fromrep torep (Case (Body torep), [MemReqType])
onCase [Case (Body fromrep)]
cases
let reqs :: [MemReqType]
reqs = (MemReqType -> [MemReqType] -> MemReqType)
-> [MemReqType] -> [[MemReqType]] -> [MemReqType]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith ((MemReqType -> MemReqType -> MemReqType)
-> MemReqType -> [MemReqType] -> MemReqType
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl MemReqType -> MemReqType -> MemReqType
combMemReqTypes) [MemReqType]
def_reqs ([[MemReqType]] -> [[MemReqType]]
forall a. [[a]] -> [[a]]
transpose [[MemReqType]]
cases_reqs)
Body torep
defbody'' <- [MemReqType] -> Body torep -> AllocM fromrep torep (Body torep)
forall fromrep torep inner.
Allocable fromrep torep inner =>
[MemReqType] -> Body torep -> AllocM fromrep torep (Body torep)
addCtxToMatchBody [MemReqType]
reqs Body torep
defbody'
[Case (Body torep)]
cases'' <- (Case (Body torep) -> AllocM fromrep torep (Case (Body torep)))
-> [Case (Body torep)] -> AllocM fromrep torep [Case (Body torep)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((Body torep -> AllocM fromrep torep (Body torep))
-> Case (Body torep) -> AllocM fromrep torep (Case (Body torep))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse ((Body torep -> AllocM fromrep torep (Body torep))
-> Case (Body torep) -> AllocM fromrep torep (Case (Body torep)))
-> (Body torep -> AllocM fromrep torep (Body torep))
-> Case (Body torep)
-> AllocM fromrep torep (Case (Body torep))
forall a b. (a -> b) -> a -> b
$ [MemReqType] -> Body torep -> AllocM fromrep torep (Body torep)
forall fromrep torep inner.
Allocable fromrep torep inner =>
[MemReqType] -> Body torep -> AllocM fromrep torep (Body torep)
addCtxToMatchBody [MemReqType]
reqs) [Case (Body torep)]
cases'
let ([Case (Body torep)]
cases''', Body torep
defbody''', [BranchTypeMem]
rets') =
[Case (Body torep)]
-> Body torep
-> [BranchTypeMem]
-> ([Case (Body torep)], Body torep, [BranchTypeMem])
forall rep inner.
Mem rep inner =>
[Case (Body rep)]
-> Body rep
-> [BranchTypeMem]
-> ([Case (Body rep)], Body rep, [BranchTypeMem])
simplifyMatch [Case (Body torep)]
cases'' Body torep
defbody'' ([BranchTypeMem]
-> ([Case (Body torep)], Body torep, [BranchTypeMem]))
-> [BranchTypeMem]
-> ([Case (Body torep)], Body torep, [BranchTypeMem])
forall a b. (a -> b) -> a -> b
$ [MemReqType] -> [BranchTypeMem]
mkBranchRet [MemReqType]
reqs
Exp torep -> AllocM fromrep torep (Exp torep)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp torep -> AllocM fromrep torep (Exp torep))
-> Exp torep -> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ [SubExp]
-> [Case (Body torep)]
-> Body torep
-> MatchDec (BranchType torep)
-> Exp torep
forall rep.
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
ses [Case (Body torep)]
cases''' Body torep
defbody''' (MatchDec (BranchType torep) -> Exp torep)
-> MatchDec (BranchType torep) -> Exp torep
forall a b. (a -> b) -> a -> b
$ [BranchTypeMem] -> MatchSort -> MatchDec BranchTypeMem
forall rt. [rt] -> MatchSort -> MatchDec rt
MatchDec [BranchTypeMem]
rets' MatchSort
ifsort
where
onCase :: Case (Body fromrep)
-> AllocM fromrep torep (Case (Body torep), [MemReqType])
onCase (Case [Maybe PrimValue]
vs Body fromrep
body) = (Body torep -> Case (Body torep))
-> (Body torep, [MemReqType]) -> (Case (Body torep), [MemReqType])
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first ([Maybe PrimValue] -> Body torep -> Case (Body torep)
forall body. [Maybe PrimValue] -> body -> Case body
Case [Maybe PrimValue]
vs) ((Body torep, [MemReqType]) -> (Case (Body torep), [MemReqType]))
-> AllocM fromrep torep (Body torep, [MemReqType])
-> AllocM fromrep torep (Case (Body torep), [MemReqType])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [ExtType]
-> Body fromrep -> AllocM fromrep torep (Body torep, [MemReqType])
forall fromrep torep inner.
Allocable fromrep torep inner =>
[ExtType]
-> Body fromrep -> AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody [ExtType]
[BranchType fromrep]
rets Body fromrep
body
allocInExp (WithAcc [WithAccInput fromrep]
inputs Lambda fromrep
bodylam) =
[WithAccInput torep] -> Lambda torep -> Exp torep
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc ([WithAccInput torep] -> Lambda torep -> Exp torep)
-> AllocM fromrep torep [WithAccInput torep]
-> AllocM fromrep torep (Lambda torep -> Exp torep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (WithAccInput fromrep -> AllocM fromrep torep (WithAccInput torep))
-> [WithAccInput fromrep]
-> AllocM fromrep torep [WithAccInput torep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM WithAccInput fromrep -> AllocM fromrep torep (WithAccInput torep)
forall {t :: * -> *} {a} {rep} {inner} {fromrep} {b}.
(Traversable t, ArrayShape a, HasLetDecMem (LetDec rep),
BuilderOps rep, OpReturns inner, SizeSubst inner,
PrettyRep fromrep, BranchType rep ~ BranchTypeMem,
BodyDec fromrep ~ (), RetType rep ~ RetTypeMem,
RetType fromrep ~ DeclExtType, FParamInfo rep ~ FParamMem,
BranchType fromrep ~ ExtType, LParamInfo rep ~ LParamMem,
ExpDec rep ~ (), BodyDec rep ~ (), Op rep ~ MemOp inner,
FParamInfo fromrep ~ DeclType, LParamInfo fromrep ~ Type,
LetDec rep ~ LParamMem) =>
(a, [VName], t (Lambda fromrep, b))
-> AllocM fromrep rep (a, [VName], t (Lambda rep, b))
onInput [WithAccInput fromrep]
inputs AllocM fromrep torep (Lambda torep -> Exp torep)
-> AllocM fromrep torep (Lambda torep)
-> AllocM fromrep torep (Exp torep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Lambda fromrep -> AllocM fromrep torep (Lambda torep)
forall {fromrep} {torep} {inner}.
(PrettyRep fromrep, HasLetDecMem (LetDec torep), OpReturns inner,
SizeSubst inner, BuilderOps torep, ExpDec torep ~ (),
LParamInfo fromrep ~ Type, BodyDec torep ~ (),
FParamInfo torep ~ FParamMem, LetDec torep ~ LParamMem,
BodyDec fromrep ~ (), RetType torep ~ RetTypeMem,
FParamInfo fromrep ~ DeclType, RetType fromrep ~ DeclExtType,
BranchType torep ~ BranchTypeMem, LParamInfo torep ~ LParamMem,
BranchType fromrep ~ ExtType, Op torep ~ MemOp inner) =>
Lambda fromrep -> AllocM fromrep torep (Lambda torep)
onLambda Lambda fromrep
bodylam
where
onLambda :: Lambda fromrep -> AllocM fromrep torep (Lambda torep)
onLambda Lambda fromrep
lam = do
[Param LParamMem]
params <- [Param Type]
-> (Param Type -> AllocM fromrep torep (Param LParamMem))
-> AllocM fromrep torep [Param LParamMem]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Lambda fromrep -> [LParam fromrep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda fromrep
lam) ((Param Type -> AllocM fromrep torep (Param LParamMem))
-> AllocM fromrep torep [Param LParamMem])
-> (Param Type -> AllocM fromrep torep (Param LParamMem))
-> AllocM fromrep torep [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ \(Param Attrs
attrs VName
pv Type
t) ->
case Type
t of
Prim PrimType
Unit -> Param LParamMem -> AllocM fromrep torep (Param LParamMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param LParamMem -> AllocM fromrep torep (Param LParamMem))
-> Param LParamMem -> AllocM fromrep torep (Param LParamMem)
forall a b. (a -> b) -> a -> b
$ Attrs -> VName -> LParamMem -> Param LParamMem
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
pv (LParamMem -> Param LParamMem) -> LParamMem -> Param LParamMem
forall a b. (a -> b) -> a -> b
$ PrimType -> LParamMem
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
Unit
Acc VName
acc Shape
ispace [Type]
ts NoUniqueness
u -> Param LParamMem -> AllocM fromrep torep (Param LParamMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param LParamMem -> AllocM fromrep torep (Param LParamMem))
-> Param LParamMem -> AllocM fromrep torep (Param LParamMem)
forall a b. (a -> b) -> a -> b
$ Attrs -> VName -> LParamMem -> Param LParamMem
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
pv (LParamMem -> Param LParamMem) -> LParamMem -> Param LParamMem
forall a b. (a -> b) -> a -> b
$ VName -> Shape -> [Type] -> NoUniqueness -> LParamMem
forall d u ret. VName -> Shape -> [Type] -> u -> MemInfo d u ret
MemAcc VName
acc Shape
ispace [Type]
ts NoUniqueness
u
Type
_ -> String -> AllocM fromrep torep (Param LParamMem)
forall a. HasCallStack => String -> a
error (String -> AllocM fromrep torep (Param LParamMem))
-> String -> AllocM fromrep torep (Param LParamMem)
forall a b. (a -> b) -> a -> b
$ String
"Unexpected WithAcc lambda param: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Param Type -> String
forall a. Pretty a => a -> String
pretty (Attrs -> VName -> Type -> Param Type
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
pv Type
t)
[LParam torep]
-> Body fromrep -> AllocM fromrep torep (Lambda torep)
forall fromrep torep inner.
Allocable fromrep torep inner =>
[LParam torep]
-> Body fromrep -> AllocM fromrep torep (Lambda torep)
allocInLambda [LParam torep]
[Param LParamMem]
params (Lambda fromrep -> Body fromrep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda fromrep
lam)
onInput :: (a, [VName], t (Lambda fromrep, b))
-> AllocM fromrep rep (a, [VName], t (Lambda rep, b))
onInput (a
shape, [VName]
arrs, t (Lambda fromrep, b)
op) =
(a
shape,[VName]
arrs,) (t (Lambda rep, b) -> (a, [VName], t (Lambda rep, b)))
-> AllocM fromrep rep (t (Lambda rep, b))
-> AllocM fromrep rep (a, [VName], t (Lambda rep, b))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Lambda fromrep, b) -> AllocM fromrep rep (Lambda rep, b))
-> t (Lambda fromrep, b) -> AllocM fromrep rep (t (Lambda rep, b))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (a
-> [VName]
-> (Lambda fromrep, b)
-> AllocM fromrep rep (Lambda rep, b)
forall {a} {rep} {fromrep} {inner} {b}.
(ArrayShape a, HasLetDecMem (LetDec rep), BuilderOps rep,
PrettyRep fromrep, OpReturns inner, SizeSubst inner,
BodyDec fromrep ~ (), FParamInfo rep ~ FParamMem, ExpDec rep ~ (),
FParamInfo fromrep ~ DeclType, LParamInfo fromrep ~ Type,
BodyDec rep ~ (), BranchType fromrep ~ ExtType,
RetType rep ~ RetTypeMem, BranchType rep ~ BranchTypeMem,
RetType fromrep ~ DeclExtType, LParamInfo rep ~ LParamMem,
Op rep ~ MemOp inner, LetDec rep ~ LParamMem) =>
a
-> [VName]
-> (Lambda fromrep, b)
-> AllocM fromrep rep (Lambda rep, b)
onOp a
shape [VName]
arrs) t (Lambda fromrep, b)
op
onOp :: a
-> [VName]
-> (Lambda fromrep, b)
-> AllocM fromrep rep (Lambda rep, b)
onOp a
accshape [VName]
arrs (Lambda fromrep
lam, b
nes) = do
let num_vs :: Int
num_vs = [Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda fromrep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda fromrep
lam)
num_is :: Int
num_is = a -> Int
forall a. ArrayShape a => a -> Int
shapeRank a
accshape
([Param Type]
i_params, [Param Type]
x_params, [Param Type]
y_params) =
Int
-> Int
-> [Param Type]
-> ([Param Type], [Param Type], [Param Type])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 Int
num_is Int
num_vs ([Param Type] -> ([Param Type], [Param Type], [Param Type]))
-> [Param Type] -> ([Param Type], [Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$ Lambda fromrep -> [LParam fromrep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda fromrep
lam
i_params' :: [Param LParamMem]
i_params' = (Param Type -> Param LParamMem)
-> [Param Type] -> [Param LParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (\(Param Attrs
attrs VName
v Type
_) -> Attrs -> VName -> LParamMem -> Param LParamMem
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
v (LParamMem -> Param LParamMem) -> LParamMem -> Param LParamMem
forall a b. (a -> b) -> a -> b
$ PrimType -> LParamMem
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
int64) [Param Type]
i_params
is :: [DimIndex SubExp]
is = (Param LParamMem -> DimIndex SubExp)
-> [Param LParamMem] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (Param LParamMem -> SubExp)
-> Param LParamMem
-> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp)
-> (Param LParamMem -> VName) -> Param LParamMem -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param LParamMem -> VName
forall dec. Param dec -> VName
paramName) [Param LParamMem]
i_params'
[Param LParamMem]
x_params' <- (Param Type -> VName -> AllocM fromrep rep (Param LParamMem))
-> [Param Type] -> [VName] -> AllocM fromrep rep [Param LParamMem]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM ([DimIndex SubExp]
-> Param Type -> VName -> AllocM fromrep rep (Param LParamMem)
forall {f :: * -> *} {rep} {inner} {u}.
(Monad f, HasLetDecMem (LetDec rep), ASTRep rep, OpReturns inner,
HasScope rep f, Pretty u, FParamInfo rep ~ FParamMem,
LParamInfo rep ~ LParamMem, RetType rep ~ RetTypeMem,
BranchType rep ~ BranchTypeMem, Op rep ~ MemOp inner) =>
[DimIndex SubExp]
-> Param (TypeBase Shape u)
-> VName
-> f (Param (MemInfo SubExp u MemBind))
onXParam [DimIndex SubExp]
is) [Param Type]
x_params [VName]
arrs
[Param LParamMem]
y_params' <- (Param Type -> VName -> AllocM fromrep rep (Param LParamMem))
-> [Param Type] -> [VName] -> AllocM fromrep rep [Param LParamMem]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM ([DimIndex SubExp]
-> Param Type -> VName -> AllocM fromrep rep (Param LParamMem)
forall {rep} {fromrep} {inner} {u}.
(PrettyRep fromrep, HasLetDecMem (LetDec rep), OpReturns inner,
SizeSubst inner, BuilderOps rep, Pretty u,
FParamInfo rep ~ FParamMem, LetDec rep ~ LParamMem,
ExpDec rep ~ (), LParamInfo rep ~ LParamMem,
FParamInfo fromrep ~ DeclType, BodyDec rep ~ (),
RetType rep ~ RetTypeMem, LParamInfo fromrep ~ Type,
BodyDec fromrep ~ (), BranchType rep ~ BranchTypeMem,
BranchType fromrep ~ ExtType, RetType fromrep ~ DeclExtType,
Op rep ~ MemOp inner) =>
[DimIndex SubExp]
-> Param (TypeBase Shape u)
-> VName
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
onYParam [DimIndex SubExp]
is) [Param Type]
y_params [VName]
arrs
Lambda rep
lam' <-
[LParam rep] -> Body fromrep -> AllocM fromrep rep (Lambda rep)
forall fromrep torep inner.
Allocable fromrep torep inner =>
[LParam torep]
-> Body fromrep -> AllocM fromrep torep (Lambda torep)
allocInLambda
([Param LParamMem]
i_params' [Param LParamMem] -> [Param LParamMem] -> [Param LParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param LParamMem]
x_params' [Param LParamMem] -> [Param LParamMem] -> [Param LParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param LParamMem]
y_params')
(Lambda fromrep -> Body fromrep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda fromrep
lam)
(Lambda rep, b) -> AllocM fromrep rep (Lambda rep, b)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda rep
lam', b
nes)
mkP :: Attrs
-> VName
-> PrimType
-> Shape
-> u
-> VName
-> IxFun (TPrimExp Int64 VName)
-> [DimIndex SubExp]
-> Param (MemInfo SubExp u MemBind)
mkP Attrs
attrs VName
p PrimType
pt Shape
shape u
u VName
mem IxFun (TPrimExp Int64 VName)
ixfun [DimIndex SubExp]
is =
Attrs
-> VName
-> MemInfo SubExp u MemBind
-> Param (MemInfo SubExp u MemBind)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
p (MemInfo SubExp u MemBind -> Param (MemInfo SubExp u MemBind))
-> (Slice (TPrimExp Int64 VName) -> MemInfo SubExp u MemBind)
-> Slice (TPrimExp Int64 VName)
-> Param (MemInfo SubExp u MemBind)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> Shape -> u -> MemBind -> MemInfo SubExp u MemBind
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape u
u (MemBind -> MemInfo SubExp u MemBind)
-> (Slice (TPrimExp Int64 VName) -> MemBind)
-> Slice (TPrimExp Int64 VName)
-> MemInfo SubExp u MemBind
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> IxFun (TPrimExp Int64 VName) -> MemBind
ArrayIn VName
mem (IxFun (TPrimExp Int64 VName) -> MemBind)
-> (Slice (TPrimExp Int64 VName) -> IxFun (TPrimExp Int64 VName))
-> Slice (TPrimExp Int64 VName)
-> MemBind
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IxFun (TPrimExp Int64 VName)
-> Slice (TPrimExp Int64 VName) -> IxFun (TPrimExp Int64 VName)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun (TPrimExp Int64 VName)
ixfun (Slice (TPrimExp Int64 VName) -> Param (MemInfo SubExp u MemBind))
-> Slice (TPrimExp Int64 VName) -> Param (MemInfo SubExp u MemBind)
forall a b. (a -> b) -> a -> b
$
(SubExp -> TPrimExp Int64 VName)
-> Slice SubExp -> Slice (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 (Slice SubExp -> Slice (TPrimExp Int64 VName))
-> Slice SubExp -> Slice (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
[DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$
[DimIndex SubExp]
is [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
sliceDim (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape)
onXParam :: [DimIndex SubExp]
-> Param (TypeBase Shape u)
-> VName
-> f (Param (MemInfo SubExp u MemBind))
onXParam [DimIndex SubExp]
_ (Param Attrs
attrs VName
p (Prim PrimType
t)) VName
_ =
Param (MemInfo SubExp u MemBind)
-> f (Param (MemInfo SubExp u MemBind))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param (MemInfo SubExp u MemBind)
-> f (Param (MemInfo SubExp u MemBind)))
-> Param (MemInfo SubExp u MemBind)
-> f (Param (MemInfo SubExp u MemBind))
forall a b. (a -> b) -> a -> b
$ Attrs
-> VName
-> MemInfo SubExp u MemBind
-> Param (MemInfo SubExp u MemBind)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
p (PrimType -> MemInfo SubExp u MemBind
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
t)
onXParam [DimIndex SubExp]
is (Param Attrs
attrs VName
p (Array PrimType
pt Shape
shape u
u)) VName
arr = do
(VName
mem, IxFun (TPrimExp Int64 VName)
ixfun) <- VName -> f (VName, IxFun (TPrimExp Int64 VName))
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
arr
Param (MemInfo SubExp u MemBind)
-> f (Param (MemInfo SubExp u MemBind))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param (MemInfo SubExp u MemBind)
-> f (Param (MemInfo SubExp u MemBind)))
-> Param (MemInfo SubExp u MemBind)
-> f (Param (MemInfo SubExp u MemBind))
forall a b. (a -> b) -> a -> b
$ Attrs
-> VName
-> PrimType
-> Shape
-> u
-> VName
-> IxFun (TPrimExp Int64 VName)
-> [DimIndex SubExp]
-> Param (MemInfo SubExp u MemBind)
forall {u}.
Attrs
-> VName
-> PrimType
-> Shape
-> u
-> VName
-> IxFun (TPrimExp Int64 VName)
-> [DimIndex SubExp]
-> Param (MemInfo SubExp u MemBind)
mkP Attrs
attrs VName
p PrimType
pt Shape
shape u
u VName
mem IxFun (TPrimExp Int64 VName)
ixfun [DimIndex SubExp]
is
onXParam [DimIndex SubExp]
_ Param (TypeBase Shape u)
p VName
_ =
String -> f (Param (MemInfo SubExp u MemBind))
forall a. HasCallStack => String -> a
error (String -> f (Param (MemInfo SubExp u MemBind)))
-> String -> f (Param (MemInfo SubExp u MemBind))
forall a b. (a -> b) -> a -> b
$ String
"Cannot handle MkAcc param: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Param (TypeBase Shape u) -> String
forall a. Pretty a => a -> String
pretty Param (TypeBase Shape u)
p
onYParam :: [DimIndex SubExp]
-> Param (TypeBase Shape u)
-> VName
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
onYParam [DimIndex SubExp]
_ (Param Attrs
attrs VName
p (Prim PrimType
t)) VName
_ =
Param (MemInfo SubExp u MemBind)
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param (MemInfo SubExp u MemBind)
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind)))
-> Param (MemInfo SubExp u MemBind)
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
forall a b. (a -> b) -> a -> b
$ Attrs
-> VName
-> MemInfo SubExp u MemBind
-> Param (MemInfo SubExp u MemBind)
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs VName
p (MemInfo SubExp u MemBind -> Param (MemInfo SubExp u MemBind))
-> MemInfo SubExp u MemBind -> Param (MemInfo SubExp u MemBind)
forall a b. (a -> b) -> a -> b
$ PrimType -> MemInfo SubExp u MemBind
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
t
onYParam [DimIndex SubExp]
is (Param Attrs
attrs VName
p (Array PrimType
pt Shape
shape u
u)) VName
arr = do
Type
arr_t <- VName -> AllocM fromrep rep Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
VName
mem <- Type -> Space -> AllocM fromrep rep VName
forall fromrep torep inner.
Allocable fromrep torep inner =>
Type -> Space -> AllocM fromrep torep VName
allocForArray Type
arr_t Space
DefaultSpace
let base_dims :: [TPrimExp Int64 VName]
base_dims = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
arr_t
ixfun :: IxFun (TPrimExp Int64 VName)
ixfun = [TPrimExp Int64 VName] -> IxFun (TPrimExp Int64 VName)
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [TPrimExp Int64 VName]
base_dims
Param (MemInfo SubExp u MemBind)
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param (MemInfo SubExp u MemBind)
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind)))
-> Param (MemInfo SubExp u MemBind)
-> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
forall a b. (a -> b) -> a -> b
$ Attrs
-> VName
-> PrimType
-> Shape
-> u
-> VName
-> IxFun (TPrimExp Int64 VName)
-> [DimIndex SubExp]
-> Param (MemInfo SubExp u MemBind)
forall {u}.
Attrs
-> VName
-> PrimType
-> Shape
-> u
-> VName
-> IxFun (TPrimExp Int64 VName)
-> [DimIndex SubExp]
-> Param (MemInfo SubExp u MemBind)
mkP Attrs
attrs VName
p PrimType
pt Shape
shape u
u VName
mem IxFun (TPrimExp Int64 VName)
ixfun [DimIndex SubExp]
is
onYParam [DimIndex SubExp]
_ Param (TypeBase Shape u)
p VName
_ =
String -> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
forall a. HasCallStack => String -> a
error (String -> AllocM fromrep rep (Param (MemInfo SubExp u MemBind)))
-> String -> AllocM fromrep rep (Param (MemInfo SubExp u MemBind))
forall a b. (a -> b) -> a -> b
$ String
"Cannot handle MkAcc param: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Param (TypeBase Shape u) -> String
forall a. Pretty a => a -> String
pretty Param (TypeBase Shape u)
p
allocInExp Exp fromrep
e = Mapper fromrep torep (AllocM fromrep torep)
-> Exp fromrep -> AllocM fromrep torep (Exp torep)
forall (m :: * -> *) frep trep.
Monad m =>
Mapper frep trep m -> Exp frep -> m (Exp trep)
mapExpM Mapper fromrep torep (AllocM fromrep torep)
alloc Exp fromrep
e
where
alloc :: Mapper fromrep torep (AllocM fromrep torep)
alloc =
Mapper Any Any (AllocM fromrep torep)
forall (m :: * -> *) rep. Monad m => Mapper rep rep m
identityMapper
{ mapOnBody :: Scope torep -> Body fromrep -> AllocM fromrep torep (Body torep)
mapOnBody = String
-> Scope torep -> Body fromrep -> AllocM fromrep torep (Body torep)
forall a. HasCallStack => String -> a
error String
"Unhandled Body in ExplicitAllocations",
mapOnRetType :: RetType fromrep -> AllocM fromrep torep (RetType torep)
mapOnRetType = String -> RetType fromrep -> AllocM fromrep torep (RetType torep)
forall a. HasCallStack => String -> a
error String
"Unhandled RetType in ExplicitAllocations",
mapOnBranchType :: BranchType fromrep -> AllocM fromrep torep (BranchType torep)
mapOnBranchType = String
-> BranchType fromrep -> AllocM fromrep torep (BranchType torep)
forall a. HasCallStack => String -> a
error String
"Unhandled BranchType in ExplicitAllocations",
mapOnFParam :: FParam fromrep -> AllocM fromrep torep (FParam torep)
mapOnFParam = String -> FParam fromrep -> AllocM fromrep torep (FParam torep)
forall a. HasCallStack => String -> a
error String
"Unhandled FParam in ExplicitAllocations",
mapOnLParam :: LParam fromrep -> AllocM fromrep torep (LParam torep)
mapOnLParam = String -> LParam fromrep -> AllocM fromrep torep (LParam torep)
forall a. HasCallStack => String -> a
error String
"Unhandled LParam in ExplicitAllocations",
mapOnOp :: Op fromrep -> AllocM fromrep torep (Op torep)
mapOnOp = \Op fromrep
op -> do
Op fromrep -> AllocM fromrep torep (Op torep)
handle <- (AllocEnv fromrep torep
-> Op fromrep -> AllocM fromrep torep (Op torep))
-> AllocM
fromrep torep (Op fromrep -> AllocM fromrep torep (Op torep))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AllocEnv fromrep torep
-> Op fromrep -> AllocM fromrep torep (Op torep)
forall fromrep torep.
AllocEnv fromrep torep
-> Op fromrep -> AllocM fromrep torep (Op torep)
allocInOp
Op fromrep -> AllocM fromrep torep (Op torep)
handle Op fromrep
op
}
-- | Annotate the parameters of a loop form with memory information.
--
-- A @while@ loop binds no loop variables and is returned unchanged.
-- For a @for@ loop, every loop variable @(param, arr)@ is decorated
-- according to the element type of @param@:
--
-- * an array-typed @param@ lives in the memory block of @arr@, with
--   @arr@'s index function sliced by fixing the outermost dimension at
--   the loop index @i@;
-- * primitive, memory-block and accumulator parameters get the
--   corresponding 'MemInfo' directly.
allocInLoopForm ::
  (Allocable fromrep torep inner) =>
  LoopForm fromrep ->
  AllocM fromrep torep (LoopForm torep)
allocInLoopForm (WhileLoop cond) = pure (WhileLoop cond)
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> mapM annotate loopvars
  where
    annotate (param, arr) = do
      (mem, ixfun) <- lookupArraySummary arr
      let withDec dec = pure (param {paramDec = dec}, arr)
      case paramType param of
        Prim bt ->
          withDec $ MemPrim bt
        Mem space ->
          withDec $ MemMem space
        Acc acc ispace ts u ->
          withDec $ MemAcc acc ispace ts u
        Array pt shape u -> do
          arr_dims <- map pe64 . arrayDims <$> lookupType arr
          -- Fix the outermost dimension of the iterated array at the
          -- current iteration.
          -- NOTE(review): 'le64' treats the loop index as 64-bit; confirm
          -- that loops with loop variables always use an Int64 index.
          let row = fullSliceNum arr_dims [DimFix (le64 i)]
          withDec $ MemArray pt shape u $ ArrayIn mem $ IxFun.slice ixfun row
-- | Operations that can contribute allocation-size information about
-- the names they bind.
class SizeSubst op where
  -- | Size substitutions for the names bound by the pattern.  The
  -- resulting map is the kind consulted by 'dimAllocationSize'.
  opSizeSubst :: Pat dec -> op -> ChunkMap

  -- | Whether the operation is considered constant; 'stmConsts'
  -- collects the names bound by such operations.  Defaults to 'False'.
  opIsConst :: op -> Bool
  opIsConst _ = False
-- | The unit operation never contributes size substitutions.
instance SizeSubst () where
  opSizeSubst _pat _unit = mempty
-- | Size information and constness are delegated to the wrapped
-- 'Inner' operation; any other memory operation contributes no
-- substitutions and is not constant.
instance (SizeSubst op) => SizeSubst (MemOp op) where
  opSizeSubst pat memop =
    case memop of
      Inner op -> opSizeSubst pat op
      _ -> mempty

  opIsConst memop =
    case memop of
      Inner op -> opIsConst op
      _ -> False
-- | The size substitutions contributed by a statement.  Only 'Op'
-- statements contribute, via 'opSizeSubst' on their pattern.
sizeSubst :: (SizeSubst (Op rep)) => Stm rep -> ChunkMap
sizeSubst (Let pat _ e) =
  case e of
    Op op -> opSizeSubst pat op
    _ -> mempty
-- | The names bound by a statement whose 'Op' is reported constant by
-- 'opIsConst'; empty for every other statement.
stmConsts :: (SizeSubst (Op rep)) => Stm rep -> S.Set VName
stmConsts (Let pat _ e) =
  case e of
    Op op | opIsConst op -> S.fromList $ patNames pat
    _ -> mempty
-- | Build a statement binding @names@ to the expression @e@.  The
-- pattern gets its memory annotations from 'patWithAllocations' in
-- 'DefaultSpace', with no size substitutions and 'NoHint' for every
-- name.  @dec@ becomes the statement's expression decoration.
mkLetNamesB' ::
  ( LetDec (Rep m) ~ LetDecMem,
    Mem (Rep m) inner,
    MonadBuilder m,
    ExpDec (Rep m) ~ ()
  ) =>
  ExpDec (Rep m) ->
  [VName] ->
  Exp (Rep m) ->
  m (Stm (Rep m))
mkLetNamesB' dec names e =
  bindTo <$> patWithAllocations DefaultSpace mempty names e hints
  where
    hints = NoHint <$ names
    bindTo pat = Let pat (defAux dec) e
-- | Like 'mkLetNamesB'', but for the wisdom-annotated representation
-- used by the simplifier.  The pattern from 'patWithAllocations' gets
-- wisdom attached, and the statement's expression decoration is
-- derived from that pattern and @e@.
mkLetNamesB'' ::
  ( Mem rep inner,
    LetDec rep ~ LetDecMem,
    OpReturns (Engine.OpWithWisdom inner),
    ExpDec rep ~ (),
    Rep m ~ Engine.Wise rep,
    HasScope (Engine.Wise rep) m,
    MonadBuilder m,
    Engine.CanBeWise inner
  ) =>
  [VName] ->
  Exp (Engine.Wise rep) ->
  m (Stm (Engine.Wise rep))
mkLetNamesB'' names e =
  withWisdom <$> patWithAllocations DefaultSpace mempty names e hints
  where
    hints = NoHint <$ names
    withWisdom pat =
      let wise_pat = Engine.addWisdomToPat pat e
          wise_dec = Engine.mkWiseExpDec wise_pat () e
       in Let wise_pat (defAux wise_dec) e
-- | Build the 'SimpleOps' for a memory-annotated representation.
--
-- @innerUsage@ computes the usage table of an inner operation.
-- @simplifyInnerOp@ simplifies an inner operation and also returns any
-- statements it hoisted.  'Alloc' operations are handled directly here.
simplifiable ::
  ( Engine.SimplifiableRep rep,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    LetDec rep ~ LetDecMem,
    OpReturns (Engine.OpWithWisdom inner),
    Mem rep inner
  ) =>
  (Engine.OpWithWisdom inner -> UT.UsageTable) ->
  (Engine.OpWithWisdom inner -> Engine.SimpleM rep (Engine.OpWithWisdom inner, Stms (Engine.Wise rep))) ->
  SimpleOps rep
simplifiable innerUsage simplifyInnerOp =
  SimpleOps wiseExpDec wiseBody guardAlloc memOpUsage refinePat simplifyMemOp
  where
    -- Expression and body decorations are both (), so only the
    -- wisdom needs to be computed.
    wiseExpDec _ pat e = pure $ Engine.mkWiseExpDec pat () e

    wiseBody _ stms res = pure $ mkWiseBody () stms res

    -- Protect a hoisted 'Alloc'.  Its size becomes @size@ when @taken@
    -- is true and 0 otherwise, bound to a fresh "hoisted_alloc_size".
    -- Every other operation is left unprotected.
    guardAlloc taken pat (Alloc size space) = Just $ do
      body_taken <- resultBodyM [size]
      body_not_taken <- resultBodyM [intConst Int64 0]
      let size_if_taken =
            Match [taken] [Case [Just (BoolValue True)] body_taken] body_not_taken $
              MatchDec [MemPrim int64] MatchFallback
      size' <- letSubExp "hoisted_alloc_size" size_if_taken
      letBind pat $ Op $ Alloc size' space
    guardAlloc _ _ _ = Nothing

    -- A variable allocation size counts as a size usage; a constant
    -- size uses nothing.
    memOpUsage op =
      case op of
        Inner inner -> innerUsage inner
        Alloc (Var size) _ -> UT.sizeUsage size
        Alloc _ _ -> mempty

    simplifyMemOp (Alloc size space) = do
      size' <- Engine.simplify size
      pure (Alloc size' space, mempty)
    simplifyMemOp (Inner op) = do
      (op', hoisted) <- simplifyInnerOp op
      pure (Inner op', hoisted)

    -- Simplify the decorations of a pattern bound to @e@.  Consider an
    -- array element whose matching return is 'ReturnsInBlock', where
    -- every existential in that index function resolves to a pattern
    -- name.  That element keeps its own (simplified) memory block but
    -- adopts the returned index function.  All other elements are
    -- simplified as they are.
    refinePat (Pat pes) e =
      Pat <$> (zipWithM refine pes =<< expReturns e)
      where
        names = map patElemName pes

        -- Existential i refers to the i-th pattern element.
        resolve (Ext i) = maybeNth i names
        resolve (Free v) = Just v

        refine pe ret
          | PatElem pe_v (MemArray pt shape u (ArrayIn mem _)) <- pe,
            MemArray _ _ _ (Just (ReturnsInBlock _ ext_ixfun)) <- ret,
            Just ixfun <- traverse (traverse resolve) ext_ixfun = do
              shape' <- Engine.simplify shape
              mem' <- Engine.simplify mem
              pure $ PatElem pe_v $ MemArray pt shape' u $ ArrayIn mem' ixfun
          | otherwise =
              traverse Engine.simplify pe
-- | A hint for one result of an expression.  The builders in this
-- module supply one hint per bound name, and 'defaultExpHints' supplies
-- one per result of the expression.
data ExpHint
  = -- | No preference for this result.
    NoHint
  | -- | A suggested index function and memory space for the result.
    -- NOTE(review): confirm how the allocator consumes this hint.
    Hint IxFun Space
-- | The default hints for an expression: 'NoHint' for each of its
-- results, as counted by 'expExtTypeSize'.
defaultExpHints :: (Monad m, ASTRep rep) => Exp rep -> m [ExpHint]
defaultExpHints e =
  let num_results = expExtTypeSize e
   in pure $ replicate num_results NoHint