{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.Mem
( LetDecMem,
FParamMem,
LParamMem,
RetTypeMem,
BranchTypeMem,
MemOp (..),
MemInfo (..),
MemBound,
MemBind (..),
MemReturn (..),
IxFun,
ExtIxFun,
isStaticIxFun,
ExpReturns,
BodyReturns,
FunReturns,
noUniquenessReturns,
bodyReturnsToExpReturns,
Mem,
AllocOp (..),
OpReturns (..),
varReturns,
expReturns,
extReturns,
lookupMemInfo,
subExpMemInfo,
lookupArraySummary,
existentialiseIxFun,
matchBranchReturnType,
matchPatternToExp,
matchFunctionReturnType,
matchLoopResultMem,
bodyReturnsFromPattern,
checkMemInfo,
module Futhark.IR.Prop,
module Futhark.IR.Traversals,
module Futhark.IR.Pretty,
module Futhark.IR.Syntax,
module Futhark.Analysis.PrimExp.Convert,
)
where
import Control.Category
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (toList, traverse_)
import Data.List (elemIndex, find)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.PrimExp.Simplify
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.Aliases
( Aliases,
removeExpAliases,
removePatternAliases,
removeScopeAliases,
)
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.Prop.Aliases
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util
import Futhark.Util.Pretty (indent, ppr, text, (<+>), (</>))
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))
-- | Memory annotation used as the 'LetDec' of a memory lore.
type LetDecMem = MemBound NoUniqueness

-- | Memory annotation used as the 'FParamInfo' of a memory lore; unlike
-- the let/lambda-parameter annotations it carries 'Uniqueness'.
type FParamMem = MemBound Uniqueness

-- | Memory annotation used as the 'LParamInfo' of a memory lore.
type LParamMem = MemBound NoUniqueness

-- | Memory annotation used as the 'RetType' of a memory lore.
type RetTypeMem = FunReturns

-- | Memory annotation used as the 'BranchType' of a memory lore.
type BranchTypeMem = BodyReturns
-- | Operations that can represent a memory allocation.
class AllocOp op where
  -- | Build an allocation from a size and the memory 'Space' to
  -- allocate in.
  allocOp :: SubExp -> Space -> op
-- | The constraints a lore must satisfy to be a memory-annotated
-- representation. Its operations can express allocations ('AllocOp'),
-- and every decoration is fixed to the corresponding memory-annotated
-- type from this module. It must also provide 'OpReturns'.
type Mem lore =
  ( AllocOp (Op lore),
    FParamInfo lore ~ FParamMem,
    LParamInfo lore ~ LParamMem,
    LetDec lore ~ LetDecMem,
    RetType lore ~ RetTypeMem,
    BranchType lore ~ BranchTypeMem,
    ASTLore lore,
    Decorations lore,
    OpReturns lore
  )
-- | Memory-annotated function return types. A primitive return is a
-- 'MemPrim', and applying a return type to arguments is delegated to
-- 'applyFunReturns'.
instance IsRetType FunReturns where
  primRetType = MemPrim
  applyRetType = applyFunReturns
-- | A primitive body result is described by 'MemPrim'.
instance IsBodyType BodyReturns where
  primBodyType = MemPrim
-- | An operation of a memory-annotated lore. It is either an explicit
-- memory allocation or an operation of the wrapped @inner@
-- representation.
data MemOp inner
  = -- | Allocate a memory block of the given size in the given 'Space'.
    Alloc SubExp Space
  | -- | An operation from the inner representation.
    Inner inner
  deriving (Eq, Ord, Show)
-- | 'MemOp' expresses allocations directly through its 'Alloc'
-- constructor.
instance AllocOp (MemOp inner) where
  allocOp = Alloc
-- | The free variables of an allocation are those of its size. The
-- 'Space' is not inspected. 'Inner' delegates to the inner operation.
instance FreeIn inner => FreeIn (MemOp inner) where
  freeIn' (Alloc size _) = freeIn' size
  freeIn' (Inner k) = freeIn' k
-- | An allocation has exactly one result: a memory block in its
-- 'Space'. 'Inner' delegates to the inner operation's type.
instance TypedOp inner => TypedOp (MemOp inner) where
  opType (Alloc _ space) = pure [Mem space]
  opType (Inner k) = opType k
-- | An allocation's single result aliases nothing, and an allocation
-- consumes nothing. 'Inner' delegates to the inner operation.
instance AliasedOp inner => AliasedOp (MemOp inner) where
  -- One (empty) alias set, matching the single result from 'opType'.
  opAliases Alloc {} = [mempty]
  opAliases (Inner k) = opAliases k

  consumedInOp Alloc {} = mempty
  consumedInOp (Inner k) = consumedInOp k
-- | Alias information is attached only to the inner operation.
-- An 'Alloc' is carried across unchanged in both directions.
instance CanBeAliased inner => CanBeAliased (MemOp inner) where
  type OpWithAliases (MemOp inner) = MemOp (OpWithAliases inner)

  -- The 'Alloc' is rebuilt rather than returned as-is because the
  -- type parameter of 'MemOp' changes.
  removeOpAliases (Alloc se space) = Alloc se space
  removeOpAliases (Inner k) = Inner $ removeOpAliases k

  addOpAliases _ (Alloc se space) = Alloc se space
  addOpAliases aliases (Inner k) = Inner $ addOpAliases aliases k
-- | Renaming an allocation renames its size and keeps its 'Space'.
-- 'Inner' delegates to the inner operation.
instance Rename inner => Rename (MemOp inner) where
  rename (Alloc size space) = Alloc <$> rename size <*> pure space
  rename (Inner k) = Inner <$> rename k
-- | Substitution affects only the size of an allocation; the 'Space'
-- is kept. 'Inner' delegates to the inner operation.
instance Substitute inner => Substitute (MemOp inner) where
  substituteNames subst (Alloc size space) =
    Alloc (substituteNames subst size) space
  substituteNames subst (Inner k) = Inner $ substituteNames subst k
-- | An allocation prints as @alloc@ applied to its size. The space is
-- added as a second argument only when it is not 'DefaultSpace'.
-- 'Inner' delegates to the inner operation.
instance PP.Pretty inner => PP.Pretty (MemOp inner) where
  ppr (Alloc e DefaultSpace) = PP.text "alloc" <> PP.apply [PP.ppr e]
  ppr (Alloc e s) = PP.text "alloc" <> PP.apply [PP.ppr e, PP.ppr s]
  ppr (Inner k) = PP.ppr k
-- | An allocation is recorded in the metrics under the name @Alloc@.
-- 'Inner' delegates to the inner operation.
instance OpMetrics inner => OpMetrics (MemOp inner) where
  opMetrics Alloc {} = seen "Alloc"
  opMetrics (Inner k) = opMetrics k
-- | Safety and cost of memory operations.
instance IsOp inner => IsOp (MemOp inner) where
  -- An allocation is safe only when its size is a constant 'Int64'
  -- that is non-negative. Any allocation with a non-constant size is
  -- conservatively treated as unsafe.
  safeOp (Alloc (Constant (IntValue (Int64Value k))) _) = k >= 0
  safeOp Alloc {} = False
  safeOp (Inner k) = safeOp k

  -- Allocations are always considered cheap.
  cheapOp (Inner k) = cheapOp k
  cheapOp Alloc {} = True
-- | Wisdom is attached only to the inner operation. An 'Alloc' is
-- carried across unchanged.
instance CanBeWise inner => CanBeWise (MemOp inner) where
  type OpWithWisdom (MemOp inner) = MemOp (OpWithWisdom inner)

  -- The 'Alloc' is rebuilt because the type parameter of 'MemOp'
  -- changes.
  removeOpWisdom (Alloc size space) = Alloc size space
  removeOpWisdom (Inner k) = Inner $ removeOpWisdom k
-- | Only inner operations can be indexed symbolically. Indexing into
-- an allocation yields 'Nothing'.
instance ST.IndexOp inner => ST.IndexOp (MemOp inner) where
  indexOp vtable k (Inner op) is = ST.indexOp vtable k op is
  indexOp _ _ _ _ = Nothing
-- | The index function representation used in memory annotations. Its
-- expressions are 'Int64'-typed primitive expressions over 'VName's.
type IxFun = IxFun.IxFun (TPrimExp Int64 VName)
-- | Like 'IxFun', but its expressions may also refer to existential
-- ('Ext') variables.
type ExtIxFun = IxFun.IxFun (TPrimExp Int64 (Ext VName))
-- | Memory information for a single value. It is parameterised over:
--
-- * @d@, the dimension type of array shapes;
-- * @u@, the uniqueness annotation;
-- * @ret@, the memory annotation attached to arrays.
--
-- NOTE(review): 'MemBind' and 'MemReturn' compare as always-equal, so
-- the derived 'Eq'/'Ord' effectively ignore @ret@ for those annotations.
data MemInfo d u ret
  = -- | A primitive value.
    MemPrim PrimType
  | -- | A memory block in the given 'Space'.
    MemMem Space
  | -- | An array with the given element type, shape, and uniqueness,
    -- together with its memory annotation.
    MemArray PrimType (ShapeBase d) u ret
  deriving (Eq, Show, Ord)
-- | Memory information whose array shape dimensions are plain
-- 'SubExp's and whose arrays are placed by a 'MemBind'. The uniqueness
-- annotation @u@ is left as a parameter.
type MemBound u = MemInfo SubExp u MemBind
-- | The declared (existential) type of memory information. The memory
-- annotation is dropped, while element type, shape, and uniqueness are
-- kept.
instance FixExt ret => DeclExtTyped (MemInfo ExtSize Uniqueness ret) where
  declExtTypeOf (MemPrim pt) = Prim pt
  declExtTypeOf (MemMem space) = Mem space
  declExtTypeOf (MemArray pt shape u _) = Array pt shape u
-- | The existential type of memory information without uniqueness.
-- The memory annotation is dropped.
instance FixExt ret => ExtTyped (MemInfo ExtSize NoUniqueness ret) where
  extTypeOf (MemPrim pt) = Prim pt
  extTypeOf (MemMem space) = Mem space
  extTypeOf (MemArray pt shape u _) = Array pt shape u
-- | Fixing an existential in memory information affects only arrays.
-- Both the shape and the memory annotation are rewritten; primitives
-- and memory blocks are unchanged.
instance FixExt ret => FixExt (MemInfo ExtSize u ret) where
  fixExt _ _ (MemPrim pt) = MemPrim pt
  fixExt _ _ (MemMem space) = MemMem space
  fixExt i se (MemArray pt shape u ret) =
    MemArray pt (fixExt i se shape) u (fixExt i se ret)
-- | The type of uniqueness-annotated memory information is its
-- declared type with the uniqueness stripped.
instance Typed (MemInfo SubExp Uniqueness ret) where
  typeOf = fromDecl . declTypeOf
-- | The type of memory information without uniqueness. The memory
-- annotation is dropped.
instance Typed (MemInfo SubExp NoUniqueness ret) where
  typeOf (MemPrim pt) = Prim pt
  typeOf (MemMem space) = Mem space
  typeOf (MemArray bt shape u _) = Array bt shape u
-- | The declared type of uniqueness-annotated memory information. The
-- memory annotation is dropped.
instance DeclTyped (MemInfo SubExp Uniqueness ret) where
  declTypeOf (MemPrim bt) = Prim bt
  declTypeOf (MemMem space) = Mem space
  declTypeOf (MemArray bt shape u _) = Array bt shape u
-- | Free variables of memory information:
--
-- * for arrays, those of the shape and of the memory annotation;
-- * for memory blocks, those of the 'Space';
-- * for primitives, none.
instance (FreeIn d, FreeIn ret) => FreeIn (MemInfo d u ret) where
  freeIn' (MemArray _ shape _ ret) = freeIn' shape <> freeIn' ret
  freeIn' (MemMem s) = freeIn' s
  freeIn' MemPrim {} = mempty
-- | Substitution rewrites the shape and memory annotation of arrays.
-- Primitives and memory blocks are left unchanged.
instance (Substitute d, Substitute ret) => Substitute (MemInfo d u ret) where
  substituteNames subst (MemArray bt shape u ret) =
    MemArray
      bt
      (substituteNames subst shape)
      u
      (substituteNames subst ret)
  substituteNames _ (MemMem space) = MemMem space
  substituteNames _ (MemPrim bt) = MemPrim bt
-- | Memory information is renamed by substitution.
instance (Substitute d, Substitute ret) => Rename (MemInfo d u ret) where
  rename = substituteRename
-- | Simplify every expression inside an index function with
-- 'simplifyPrimExp'. Each expression is untyped for simplification and
-- re-tagged as 'Int64' with 'isInt64' afterwards.
simplifyIxFun ::
  Engine.SimplifiableLore lore =>
  IxFun ->
  Engine.SimpleM lore IxFun
simplifyIxFun = traverse $ fmap isInt64 . simplifyPrimExp . untyped
-- | Like 'simplifyIxFun', but for index functions that may mention
-- existential variables. It uses 'simplifyExtPrimExp' on each
-- expression.
simplifyExtIxFun ::
  Engine.SimplifiableLore lore =>
  ExtIxFun ->
  Engine.SimpleM lore ExtIxFun
simplifyExtIxFun = traverse $ fmap isInt64 . simplifyExtPrimExp . untyped
-- | Convert an index function that may mention existentials into a
-- static one. This succeeds only if no expression refers to an 'Ext'
-- variable; otherwise the result is 'Nothing'.
isStaticIxFun :: ExtIxFun -> Maybe IxFun
isStaticIxFun = traverse $ traverse inst
  where
    -- Any existential makes the whole index function non-static.
    inst Ext {} = Nothing
    inst (Free x) = Just x
-- | Simplifying memory information simplifies the shape and memory
-- annotation of arrays. Primitives and memory blocks are returned as-is.
instance
  (Engine.Simplifiable d, Engine.Simplifiable ret) =>
  Engine.Simplifiable (MemInfo d u ret)
  where
  simplify (MemPrim bt) = pure $ MemPrim bt
  simplify (MemMem space) = pure $ MemMem space
  simplify (MemArray bt shape u ret) =
    MemArray bt <$> Engine.simplify shape <*> pure u <*> Engine.simplify ret
-- | Pretty-printing of memory information:
--
-- * primitives print as their type;
-- * memory blocks print as @mem@, followed by the space when it is not
--   'DefaultSpace';
-- * arrays print as their array type, then @\@@, then the memory
--   annotation.
instance
  ( PP.Pretty (TypeBase (ShapeBase d) u),
    PP.Pretty d,
    PP.Pretty u,
    PP.Pretty ret
  ) =>
  PP.Pretty (MemInfo d u ret)
  where
  ppr (MemPrim bt) = PP.ppr bt
  ppr (MemMem DefaultSpace) = PP.text "mem"
  ppr (MemMem s) = PP.text "mem" <> PP.ppr s
  ppr (MemArray bt shape u ret) =
    PP.ppr (Array bt shape u) <+> PP.text "@" <+> PP.ppr ret
-- | Where an array bound in the program is stored.
data MemBind
  = -- | The array is stored in the named memory block and accessed
    -- through the given index function.
    ArrayIn VName IxFun
  deriving (Show)
-- | Every 'MemBind' is equal to every other.
-- NOTE(review): presumably deliberate, so that comparisons of
-- 'MemInfo' ignore memory placement. Confirm before relying on
-- structural equality.
instance Eq MemBind where
  _ == _ = True
-- | Consistent with the 'Eq' instance: all 'MemBind's compare as 'EQ'.
instance Ord MemBind where
  compare _ _ = EQ
-- | Memory bindings are renamed by substitution.
instance Rename MemBind where
  rename = substituteRename
-- | Substitution applies to both the memory block name and the index
-- function.
instance Substitute MemBind where
  substituteNames substs (ArrayIn ident ixfun) =
    ArrayIn (substituteNames substs ident) (substituteNames substs ixfun)
-- | A memory binding prints as @mem -> ixfun@.
instance PP.Pretty MemBind where
  ppr (ArrayIn mem ixfun) =
    PP.ppr mem <+> "->" PP.</> PP.ppr ixfun
-- | A memory binding's free variables are the memory block name plus
-- the free variables of its index function.
instance FreeIn MemBind where
  freeIn' (ArrayIn mem ixfun) = freeIn' mem <> freeIn' ixfun
-- | Where a returned array is stored. The index functions here may
-- mention existential variables.
data MemReturn
  = -- | The array is in the named memory block.
    ReturnsInBlock VName ExtIxFun
  | -- | The array is in a new memory block in the given 'Space'. The
    -- 'Int' identifies that block as an existential: 'fixExt' with the
    -- same index and a variable turns this into 'ReturnsInBlock'.
    ReturnsNewBlock Space Int ExtIxFun
  deriving (Show)
-- | Every 'MemReturn' compares equal to every other 'MemReturn'.
--
-- NOTE(review): mirrors the 'MemBind' instance; presumably deliberate
-- so memory returns do not affect equality of enclosing types — confirm.
instance Eq MemReturn where
  (==) _ _ = True
-- | Consistent with the 'Eq' instance: all 'MemReturn's are ordered as
-- equal.
instance Ord MemReturn where
  compare _ _ = EQ
-- | Renaming is delegated to the 'Substitute' instance through
-- 'substituteRename'.
instance Rename MemReturn where
  rename = substituteRename
-- | Apply a name substitution to a 'MemReturn'.  The memory block name
-- of 'ReturnsInBlock' is substituted; the space and context index of
-- 'ReturnsNewBlock' are left alone.  The index function is substituted
-- in both cases.
instance Substitute MemReturn where
  substituteNames substs mr =
    case mr of
      ReturnsInBlock mem ixfun ->
        ReturnsInBlock
          (substituteNames substs mem)
          (substituteNames substs ixfun)
      ReturnsNewBlock space idx ixfun ->
        ReturnsNewBlock space idx (substituteNames substs ixfun)
-- | Fix existential @i@ to the sub-expression @se@.
--
-- A 'ReturnsNewBlock' whose context index is exactly @i@ becomes a
-- 'ReturnsInBlock' when @se@ is a variable.  Any other
-- 'ReturnsNewBlock' keeps its space, and its context index is
-- decremented when it lies above @i@.  In every case @se@ is
-- substituted into the index function via 'fixExtIxFun'.
instance FixExt MemReturn where
  fixExt i se mr =
    case mr of
      ReturnsNewBlock space j ixfun
        | j == i, Var v <- se -> ReturnsInBlock v (fixIx ixfun)
        | otherwise -> ReturnsNewBlock space (renumber j) (fixIx ixfun)
      ReturnsInBlock mem ixfun -> ReturnsInBlock mem (fixIx ixfun)
    where
      fixIx = fixExtIxFun i (primExpFromSubExp int64 se)
      -- Existentials above the removed one move down one slot.
      renumber j
        | i < j = j - 1
        | otherwise = j
-- | Replace existential @i@ in an existential index function by the
-- expression @e@.  Existentials numbered above @i@ are renumbered down
-- by one to close the gap; lower existentials and free names are kept.
fixExtIxFun :: Int -> PrimExp VName -> ExtIxFun -> ExtIxFun
fixExtIxFun i e = fmap (isInt64 . replaceInPrimExp update . untyped)
  where
    update :: Ext VName -> PrimType -> PrimExp (Ext VName)
    update (Free x) t = LeafExp (Free x) t
    update (Ext j) t =
      case compare j i of
        GT -> LeafExp (Ext (j - 1)) t
        EQ -> Free <$> e
        LT -> LeafExp (Ext j) t
-- | An @int64@-typed leaf expression referring to existential @i@.
leafExp :: Int -> TPrimExp Int64 (Ext a)
leafExp i = isInt64 (LeafExp (Ext i) int64)
-- | Turn an index function over free names into an existential one.
-- All names are first wrapped in 'Free'; then each name in @ctx@ is
-- substituted (via 'IxFun.substituteInIxFun') by a leaf referring to
-- the existential at its position in @ctx@.
existentialiseIxFun :: [VName] -> IxFun -> ExtIxFun
existentialiseIxFun ctx ixfun =
  IxFun.substituteInIxFun substs (fmap (fmap Free) ixfun)
  where
    substs :: M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
    substs = M.fromList [(Free v, leafExp k) | (v, k) <- zip ctx [0 ..]]
-- | 'ReturnsNewBlock' is rendered as @?i<space> -> ixfun@;
-- 'ReturnsInBlock' as a parenthesised @(mem -> ixfun)@.
instance PP.Pretty MemReturn where
  ppr (ReturnsNewBlock space i ixfun) =
    "?" <> PP.ppr i <> PP.ppr space <+> "->" PP.</> PP.ppr ixfun
  ppr (ReturnsInBlock v ixfun) =
    PP.parens $ PP.ppr v <+> "->" PP.</> PP.ppr ixfun
-- | Free names of a 'MemReturn': the memory block (or space) plus the
-- free names of the index function.  The context index of
-- 'ReturnsNewBlock' contributes nothing.
instance FreeIn MemReturn where
  freeIn' mr =
    case mr of
      ReturnsInBlock v ixfun -> freeIn' v <> freeIn' ixfun
      ReturnsNewBlock space _ ixfun -> freeIn' space <> freeIn' ixfun
-- | Simplify the memory block name (when present) and then the
-- existential index function.
instance Engine.Simplifiable MemReturn where
  simplify (ReturnsInBlock v ixfun) = do
    v' <- Engine.simplify v
    ReturnsInBlock v' <$> simplifyExtIxFun ixfun
  simplify (ReturnsNewBlock space i ixfun) =
    ReturnsNewBlock space i <$> simplifyExtIxFun ixfun
-- | Simplify the memory block name, then the index function.
instance Engine.Simplifiable MemBind where
  simplify (ArrayIn mem ixfun) = do
    mem' <- Engine.simplify mem
    ixfun' <- simplifyIxFun ixfun
    pure $ ArrayIn mem' ixfun'
-- | Simplify each function return annotation in order.
instance Engine.Simplifiable [FunReturns] where
  simplify = traverse Engine.simplify
-- | Memory annotation of an expression's results.  Arrays carry an
-- optional 'MemReturn'; no uniqueness information is kept.
type ExpReturns = MemInfo ExtSize NoUniqueness (Maybe MemReturn)

-- | Memory annotation of a body's results.  Arrays always carry a
-- 'MemReturn'; no uniqueness information is kept.
type BodyReturns = MemInfo ExtSize NoUniqueness MemReturn

-- | Memory annotation of a function's results.  Like 'BodyReturns',
-- but retaining 'Uniqueness'.
type FunReturns = MemInfo ExtSize Uniqueness MemReturn
-- | Wrap the return annotation of an array in 'Just'.  Scalars and
-- memory blocks carry no return annotation and are passed through.
maybeReturns :: MemInfo d u r -> MemInfo d u (Maybe r)
maybeReturns info =
  case info of
    MemPrim pt -> MemPrim pt
    MemMem space -> MemMem space
    MemArray pt shape u ret -> MemArray pt shape u (Just ret)
-- | Discard the uniqueness attribute of an array annotation, replacing
-- it with 'NoUniqueness'.  Scalars and memory blocks are unchanged.
noUniquenessReturns :: MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns info =
  case info of
    MemArray pt shape _ ret -> MemArray pt shape NoUniqueness ret
    MemPrim pt -> MemPrim pt
    MemMem space -> MemMem space
-- | View a function return annotation as an expression return
-- annotation: the array return becomes 'Just', and uniqueness is
-- dropped.
funReturnsToExpReturns :: FunReturns -> ExpReturns
funReturnsToExpReturns ret = noUniquenessReturns (maybeReturns ret)
-- | View a body return annotation as an expression return annotation
-- by wrapping the array return in 'Just'.
bodyReturnsToExpReturns :: BodyReturns -> ExpReturns
bodyReturnsToExpReturns ret = noUniquenessReturns (maybeReturns ret)
-- | Look up the memory information of every result sub-expression in
-- the current (alias-free) scope and check it against the declared
-- return type with 'matchReturnType'.
matchRetTypeToResult ::
  (Mem lore, TC.Checkable lore) =>
  [FunReturns] ->
  Result ->
  TC.TypeM lore ()
matchRetTypeToResult rettype result = do
  scope <- removeScopeAliases <$> askScope
  result_ts <- runReaderT (traverse subExpMemInfo result) scope
  matchReturnType rettype result result_ts
-- | Check a function's results against its return type.  In addition
-- to the checks of 'matchRetTypeToResult', every array returned through
-- a variable must have a linear index function ('IxFun.isLinear');
-- otherwise a 'TC.TypeError' is raised.
matchFunctionReturnType ::
  (Mem lore, TC.Checkable lore) =>
  [FunReturns] ->
  Result ->
  TC.TypeM lore ()
matchFunctionReturnType rettype result = do
  matchRetTypeToResult rettype result
  traverse_ checkResultSubExp result
  where
    checkResultSubExp (Var v) = varMemInfo v >>= checkIxFun v
    checkResultSubExp Constant {} = pure ()

    -- Only arrays are constrained; scalars and memory blocks pass.
    checkIxFun v (MemArray _ _ _ (ArrayIn _ ixfun))
      | not $ IxFun.isLinear ixfun =
        TC.bad . TC.TypeError $
          "Array "
            ++ pretty v
            ++ " returned by function, but has nontrivial index function "
            ++ pretty ixfun
    checkIxFun _ _ = pure ()
-- | Check the results of a loop body against the loop's merge
-- parameters.
--
-- The value parameters @val@ are turned into a function return type
-- and checked with 'matchRetTypeToResult'.  In that return type, any
-- name bound by a context parameter in @ctx@ becomes the existential at
-- its position in @ctx@.
matchLoopResultMem ::
  (Mem lore, TC.Checkable lore) =>
  [FParam (Aliases lore)] ->
  [FParam (Aliases lore)] ->
  [SubExp] ->
  TC.TypeM lore ()
matchLoopResultMem ctx val =
  matchRetTypeToResult $ map (toRet . paramDec) val
  where
    ctx_names = map paramName ctx

    -- Position of a name among the context parameters, if any.
    ctxIndex v = v `elemIndex` ctx_names

    toExtSE (Var v) = Var <$> maybe (Free v) Ext (ctxIndex v)
    toExtSE (Constant c) = Free $ Constant c

    toRet (MemPrim t) = MemPrim t
    toRet (MemMem space) = MemMem space
    toRet (MemArray pt shape u (ArrayIn mem ixfun)) =
      MemArray pt (fmap toExtSE shape) u $
        case ctxIndex mem of
          -- The array's block is itself a context memory parameter, so
          -- it is returned as a new block at that context position.
          Just i
            | Param _ (MemMem space) : _ <- drop i ctx ->
              ReturnsNewBlock space i ixfun'
          _ -> ReturnsInBlock mem ixfun'
      where
        ixfun' = existentialiseIxFun ctx_names ixfun
-- | Check the results of a branch body against the branch's declared
-- return type.  Result memory information is looked up in the current
-- scope extended with the body's own statements.
matchBranchReturnType ::
  (Mem lore, TC.Checkable lore) =>
  [BodyReturns] ->
  Body (Aliases lore) ->
  TC.TypeM lore ()
matchBranchReturnType rettype (Body _ stms res) = do
  outer <- askScope
  let scope = removeScopeAliases $ outer <> scopeOf stms
  ts <- runReaderT (traverse subExpMemInfo res) scope
  matchReturnType rettype res ts
-- | Helper for index function unification.
--
-- Given pairs of names and existential indices:
--
-- * The first map sends each name (wrapped in 'Free') to a leaf
--   referring to the /first/ index paired with that name.
--
-- * The second map sends each index (wrapped in 'Ext') to a leaf
--   referring to the first index of the name it is paired with.
getExtMaps ::
  [(VName, Int)] ->
  ( M.Map (Ext VName) (TPrimExp Int64 (Ext VName)),
    M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
  )
getExtMaps ctx_lst_ids = (free_to_ext, ext_to_first)
  where
    -- 'M.fromListWith' combines as @f new old@, so @const id@ keeps the
    -- earliest index for a duplicated name.
    first_idx :: M.Map VName Int
    first_idx = M.fromListWith (const id) ctx_lst_ids

    free_to_ext :: M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
    free_to_ext = M.map leafExp $ M.mapKeys Free first_idx

    -- Reuse 'first_idx' rather than a linear 'lookup' per element,
    -- which made this quadratic in the context size.  Every name is
    -- present in 'first_idx', so no pair is dropped.
    ext_to_first :: M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
    ext_to_first =
      M.fromList
        [ (Ext i, leafExp j)
          | (v, i) <- ctx_lst_ids,
            Just j <- [M.lookup v first_idx]
        ]
-- | Check that results with memory information @ts@ (and result
-- subexpressions @res@) are compatible with the return type @rettype@.
--
-- Both @ts@ and @res@ are split with @splitFromEnd (length rettype)@:
-- the second part holds the value results, which are compared pairwise
-- against @rettype@; the first part holds the context results.  An
-- 'Ext' index @i@ in @rettype@ (in a dimension, or as the block index of
-- a 'ReturnsNewBlock') is resolved as the @i@'th context result.
--
-- Reports a type error if the number of distinct existentials in
-- @rettype@ differs from the number of context results, or if any value
-- result does not match its return type.
matchReturnType ::
  PP.Pretty u =>
  [MemInfo ExtSize u MemReturn] ->
  [SubExp] ->
  [MemInfo SubExp NoUniqueness MemBind] ->
  TC.TypeM lore ()
matchReturnType rettype res ts = do
  let (ctx_ts, val_ts) = splitFromEnd (length rettype) ts
      (ctx_res, _val_res) = splitFromEnd (length rettype) res

      -- Mark every variable of a concrete index function as 'Free' so it
      -- can be compared with an existentially quantified one.
      existentialiseIxFun0 :: IxFun -> ExtIxFun
      existentialiseIxFun0 = fmap $ fmap Free

      -- The i'th context result paired with its memory information.
      fetchCtx i = case maybeNth i $ zip ctx_res ctx_ts of
        Nothing ->
          throwError $
            "Cannot find context variable "
              ++ show i
              ++ " in context results: "
              ++ pretty ctx_res
        Just ctx -> return ctx

      -- Compare one declared return type with one actual value result.
      checkReturn (MemPrim x) (MemPrim y)
        | x == y = return ()
      checkReturn (MemMem x) (MemMem y)
        | x == y = return ()
      checkReturn
        (MemArray x_pt x_shape _ x_ret)
        (MemArray y_pt y_shape _ y_ret)
          | x_pt == y_pt,
            shapeRank x_shape == shapeRank y_shape = do
            zipWithM_ checkDim (shapeDims x_shape) (shapeDims y_shape)
            checkMemReturn x_ret y_ret
      checkReturn x y =
        throwError $ unwords ["Expected", pretty x, "but got", pretty y]

      -- A free dimension must be syntactically equal; an existential one
      -- must equal the corresponding context result.
      checkDim (Free x) y
        | x == y = return ()
        | otherwise =
          throwError $
            unwords ["Expected dim", pretty x, "but got", pretty y]
      checkDim (Ext i) y = do
        (x, _) <- fetchCtx i
        unless (x == y) $
          throwError $
            unwords
              [ "Expected ext dim",
                pretty i,
                "=>",
                pretty x,
                "but got",
                pretty y
              ]

      -- Existential indices used by a single return type.
      extsInMemInfo :: MemInfo ExtSize u MemReturn -> S.Set Int
      extsInMemInfo (MemArray _ shp _ ret) =
        extInShape shp <> extInMemReturn ret
      extsInMemInfo _ = S.empty

      -- Compare where an array is declared to live with where it lives.
      checkMemReturn (ReturnsInBlock x_mem x_ixfun) (ArrayIn y_mem y_ixfun)
        | x_mem == y_mem =
          unless (IxFun.closeEnough x_ixfun $ existentialiseIxFun0 y_ixfun) $
            throwError $
              unwords
                [ "Index function unification failed (ReturnsInBlock)",
                  "\nixfun of body result: ",
                  pretty y_ixfun,
                  "\nixfun of return type: ",
                  pretty x_ixfun,
                  "\nand context elements: ",
                  pretty ctx_res
                ]
      checkMemReturn
        (ReturnsNewBlock x_space x_ext x_ixfun)
        (ArrayIn y_mem y_ixfun) = do
          (x_mem, x_mem_type) <- fetchCtx x_ext
          unless (IxFun.closeEnough x_ixfun $ existentialiseIxFun0 y_ixfun) $
            throwError $
              pretty $
                "Index function unification failed (ReturnsNewBlock)"
                  </> "Ixfun of body result:"
                  </> indent 2 (ppr y_ixfun)
                  </> "Ixfun of return type:"
                  </> indent 2 (ppr x_ixfun)
                  </> "Context elements: "
                  </> indent 2 (ppr ctx_res)
          -- The context result providing the block must be a memory
          -- block in the declared space.
          case x_mem_type of
            MemMem y_space ->
              unless (x_space == y_space) $
                throwError $
                  unwords
                    [ "Expected memory",
                      pretty y_mem,
                      "in space",
                      pretty x_space,
                      "but actually in space",
                      pretty y_space
                    ]
            t ->
              throwError $
                unwords
                  [ "Expected memory",
                    pretty x_ext,
                    "=>",
                    pretty x_mem,
                    "but has type",
                    pretty t
                  ]
      checkMemReturn x y =
        throwError $
          unwords
            [ "Expected array in",
              pretty x,
              "but array returned in",
              pretty y
            ]

      -- Wrap a mismatch description in a type error that shows both the
      -- return type and the actual result types.
      bad :: String -> TC.TypeM lore a
      bad s =
        TC.bad $
          TC.TypeError $
            pretty $
              "Return type"
                </> indent 2 (ppTuple' rettype)
                </> "cannot match returns of results"
                </> indent 2 (ppTuple' ts)
                </> text s

  -- Every distinct existential in the return type must be provided by
  -- exactly one context result.
  unless (S.size (S.unions $ map extsInMemInfo rettype) == length ctx_res) $
    TC.bad $
      TC.TypeError $
        "Too many context parameters for the number of "
          ++ "existentials in the return type! type:\n "
          ++ prettyTuple rettype
          ++ "\ncannot match context parameters:\n "
          ++ prettyTuple ctx_res

  either bad return =<< runExceptT (zipWithM_ checkReturn rettype val_ts)
-- | Check that a pattern's memory annotations agree with what the
-- expression bound to it returns.
--
-- The expected returns come from 'expReturns' on the expression (aliases
-- removed), evaluated in the current scope.  The pattern is split by
-- 'bodyReturnsFromPattern' into context and value elements.  The context
-- names are numbered from 0 and passed to @getExtMaps@; the resulting
-- maps are used to substitute into index functions before comparing
-- them.  A type error is reported unless:
--
--   * there are as many value elements as returns,
--   * every value element matches its return, and
--   * every key of the existential map is an existential that actually
--     occurs in the returns.
matchPatternToExp ::
  (Mem lore, TC.Checkable lore) =>
  Pattern (Aliases lore) ->
  Exp (Aliases lore) ->
  TC.TypeM lore ()
matchPatternToExp pat e = do
  scope <- asksScope removeScopeAliases
  rt <- runReaderT (expReturns $ removeExpAliases e) scope

  let (ctxs, vals) = bodyReturnsFromPattern $ removePatternAliases pat
      ctx_ids = map fst ctxs
      val_ts = map snd vals
      -- zip stops at the shorter list, so this numbers ctx_ids 0 .. n-1.
      (ctx_map_ids, ctx_map_exts) = getExtMaps $ zip ctx_ids [0 ..]
      rt_exts = foldMap extInExpReturns rt

      lengthsAgree = length val_ts == length rt
      allMatch = and $ zipWith (matches ctx_map_ids ctx_map_exts) val_ts rt
      extsKnown = M.keysSet ctx_map_exts `S.isSubsetOf` S.map Ext rt_exts

  unless (lengthsAgree && allMatch && extsKnown) $
    TC.bad $
      TC.TypeError $
        "Expression type:\n "
          ++ prettyTuple rt
          ++ "\ncannot match pattern type:\n "
          ++ prettyTuple val_ts
          ++ "\nwith context elements: "
          ++ pretty ctx_ids
  where
    -- Compare a pattern element (first) with an expression return (second).
    matches _ _ (MemPrim x) (MemPrim y) = x == y
    matches _ _ (MemMem x_space) (MemMem y_space) = x_space == y_space
    matches ctxids ctxexts (MemArray x_pt x_shape _ x_ret) (MemArray y_pt y_shape _ y_ret) =
      x_pt == y_pt
        && x_shape == y_shape
        && memReturnsMatch ctxids ctxexts x_ret y_ret
    matches _ _ _ _ = False

    -- An unknown expression memory return is always accepted.  Otherwise
    -- the index functions must be close enough after substitution, and
    -- two fresh blocks must also agree on their existential index.  An
    -- in-block pattern element may match a fresh-block return, but a
    -- fresh-block pattern element never matches an in-block return.
    memReturnsMatch _ _ _ Nothing = True
    memReturnsMatch ctxids ctxexts x_ret (Just y_ret) =
      case (x_ret, y_ret) of
        (ReturnsInBlock _ x_ixfun, ReturnsInBlock _ y_ixfun) ->
          ixfunsMatch x_ixfun y_ixfun
        (ReturnsInBlock _ x_ixfun, ReturnsNewBlock _ _ y_ixfun) ->
          ixfunsMatch x_ixfun y_ixfun
        (ReturnsNewBlock _ x_i x_ixfun, ReturnsNewBlock _ y_i y_ixfun) ->
          x_i == y_i && ixfunsMatch x_ixfun y_ixfun
        _ -> False
      where
        ixfunsMatch x_ixfun y_ixfun =
          IxFun.closeEnough
            (IxFun.substituteInIxFun ctxids x_ixfun)
            (IxFun.substituteInIxFun ctxexts y_ixfun)
-- | Existential indices referenced by an expression return: those in
-- an array's shape plus, when the memory return is known, those in it.
-- Non-array returns reference none.
extInExpReturns :: ExpReturns -> S.Set Int
extInExpReturns ret =
  case ret of
    MemArray _ shape _ mem_return ->
      extInShape shape <> foldMap extInMemReturn mem_return
    _ -> mempty
-- | Indices of all existential ('Ext') dimensions in a shape.
extInShape :: ShapeBase (Ext SubExp) -> S.Set Int
extInShape shape = S.fromList [i | Ext i <- shapeDims shape]
-- | Existential indices referenced by a memory return.  Both forms
-- contribute those in the index function; 'ReturnsNewBlock' also
-- contributes its own block index.
extInMemReturn :: MemReturn -> S.Set Int
extInMemReturn ret =
  case ret of
    ReturnsInBlock _ ixfun -> extInIxFn ixfun
    ReturnsNewBlock _ i ixfun -> S.insert i $ extInIxFn ixfun
-- | Indices of every existential ('Ext') variable occurring in any
-- expression of an existential index function.
extInIxFn :: ExtIxFun -> S.Set Int
extInIxFn ixfun =
  S.fromList [i | pe <- toList ixfun, Ext i <- toList pe]
-- | Memory information of a variable, as recorded in the type checker's
-- environment.  For let-bound names the first (alias) component of the
-- annotation is ignored.  Function parameters have their uniqueness
-- erased.  'IndexName' entries are reported as a primitive of their
-- integer type.
varMemInfo ::
  Mem lore =>
  VName ->
  TC.TypeM lore (MemInfo SubExp NoUniqueness MemBind)
varMemInfo name = do
  dec <- TC.lookupVar name
  pure $ case dec of
    LetName (_, summary) -> summary
    FParamName summary -> noUniquenessReturns summary
    LParamName summary -> summary
    IndexName it -> MemPrim $ IntType it
-- | Memory information carried by a scope entry.  Function-parameter
-- annotations have their uniqueness erased.  'IndexName' entries become
-- a primitive of their integer type.
nameInfoToMemInfo :: Mem lore => NameInfo lore -> MemBound NoUniqueness
nameInfoToMemInfo (FParamName summary) = noUniquenessReturns summary
nameInfoToMemInfo (LParamName summary) = summary
nameInfoToMemInfo (LetName summary) = summary
nameInfoToMemInfo (IndexName it) = MemPrim $ IntType it
-- | Look up a name in the current scope and return its memory
-- information (see 'nameInfoToMemInfo').
lookupMemInfo ::
  (HasScope lore m, Mem lore) =>
  VName ->
  m (MemInfo SubExp NoUniqueness MemBind)
lookupMemInfo v = nameInfoToMemInfo <$> lookupInfo v
-- | The memory summary of a subexpression. Variables are looked up in
-- scope; constants are always scalars of the constant's primitive type.
subExpMemInfo ::
  (HasScope lore m, Monad m, Mem lore) =>
  SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
subExpMemInfo se =
  case se of
    Var name -> lookupMemInfo name
    Constant value -> return . MemPrim $ primValueType value
-- | Find the memory block and index function backing an array variable.
--
-- Calls 'error' if the variable's summary is not an array in memory.
lookupArraySummary ::
  (Mem lore, HasScope lore m, Monad m) =>
  VName ->
  m (VName, IxFun.IxFun (TPrimExp Int64 VName))
lookupArraySummary name =
  lookupMemInfo name >>= \info ->
    case info of
      MemArray _ _ _ (ArrayIn block ixfun) -> return (block, ixfun)
      _ -> error $ "Variable " ++ pretty name ++ " does not look like an array."
-- | Type-check the memory summary attached to the binding @name@.
--
-- * Scalars, and memory blocks outside 'ScalarSpace', need no checking.
-- * A 'ScalarSpace' memory block requires each of its size
--   subexpressions to be an @i64@.
-- * An array must be stored in a variable of memory type. Every term of
--   its index function must be an @i64@ expression. The index function's
--   rank must equal the rank of the array's shape.
checkMemInfo ::
  TC.Checkable lore =>
  VName ->
  MemInfo SubExp u MemBind ->
  TC.TypeM lore ()
checkMemInfo _ MemPrim {} = return ()
checkMemInfo _ (MemMem (ScalarSpace dims _)) =
  mapM_ (TC.require [Prim int64]) dims
checkMemInfo _ MemMem {} = return ()
checkMemInfo name (MemArray _ shape _ (ArrayIn block ixfun)) = do
  blockType <- lookupType block
  case blockType of
    Mem {} -> return ()
    _ ->
      TC.bad . TC.TypeError $
        concat
          [ "Variable ",
            pretty block,
            " used as memory block, but is of type ",
            pretty blockType,
            "."
          ]

  TC.context ("in index function " ++ pretty ixfun) $ do
    -- Every term of the index function must be a 64-bit integer.
    traverse_ (TC.requirePrimExp int64 . untyped) ixfun
    let ixfunRank = IxFun.rank ixfun
        arrayRank = shapeRank shape
    unless (ixfunRank == arrayRank) $
      TC.bad . TC.TypeError $
        concat
          [ "Arity of index function (",
            pretty ixfunRank,
            ") does not match rank of array ",
            pretty name,
            " (",
            show arrayRank,
            ")"
          ]
-- | Compute the 'BodyReturns' of a pattern, split into its context
-- elements (first component) and value elements (second component).
--
-- A dimension that names a context element becomes @'Ext' i@, where @i@
-- is that element's position in the context; other dimensions stay
-- 'Free'. An array whose memory block is a context element of memory
-- type is returned in a 'ReturnsNewBlock' at that position, with its
-- index function existentialised over all context names. Otherwise it
-- is returned 'ReturnsInBlock' of its existing block.
bodyReturnsFromPattern ::
  PatternT (MemBound NoUniqueness) ->
  ([(VName, BodyReturns)], [(VName, BodyReturns)])
bodyReturnsFromPattern pat =
  (map asReturns ctx, map asReturns $ patternValueElements pat)
  where
    ctx = patternContextElements pat
    ctxNames = map patElemName ctx

    -- Name -> (position, summary); 'lookup' yields the first match,
    -- just as a left-to-right search over the context would.
    ctxByName = zip ctxNames $ zip [0 ..] $ map patElemDec ctx

    ext :: SubExp -> ExtSize
    ext (Var v) | Just i <- elemIndex v ctxNames = Ext i
    ext se = Free se

    blockReturn :: VName -> IxFun -> MemReturn
    blockReturn block ixfun =
      case lookup block ctxByName of
        Just (i, MemMem space) ->
          ReturnsNewBlock space i $ existentialiseIxFun ctxNames ixfun
        _ -> ReturnsInBlock block $ existentialiseIxFun [] ixfun

    decReturns :: MemBound NoUniqueness -> BodyReturns
    decReturns (MemPrim pt) = MemPrim pt
    decReturns (MemMem space) = MemMem space
    decReturns (MemArray pt shape u (ArrayIn block ixfun)) =
      MemArray pt (Shape $ map ext $ shapeDims shape) u $
        blockReturn block ixfun

    asReturns :: PatElemT (MemBound NoUniqueness) -> (VName, BodyReturns)
    asReturns pe = (patElemName pe, decReturns $ patElemDec pe)
-- | Turn (possibly existential) types into 'ExpReturns'.
--
-- Scalars and memory types map directly. Arrays with no existential
-- dimensions get no memory information ('Nothing'). Each array type that
-- is existential is assigned a new block in 'DefaultSpace'. The blocks
-- are numbered 0, 1, 2, ... in order of appearance, and each gets an
-- 'IxFun.iota' index function over the array's shape.
extReturns :: [ExtType] -> [ExpReturns]
extReturns ts = evalState (traverse addDec ts) 0
  where
    addDec :: ExtType -> State Int ExpReturns
    addDec (Prim bt) = return $ MemPrim bt
    addDec (Mem space) = return $ MemMem space
    addDec t@(Array bt shape u)
      | not (existential t) =
        return $ MemArray bt shape u Nothing
      | otherwise = do
        -- Take the current counter value and bump it for the next array.
        blockIdx <- state $ \n -> (n, n + 1)
        let ixfun = IxFun.iota $ map convert $ shapeDims shape
        return . MemArray bt shape u . Just $
          ReturnsNewBlock DefaultSpace blockIdx ixfun

    convert :: ExtSize -> TPrimExp Int64 (Ext VName)
    convert (Ext i) = le64 (Ext i)
    convert (Free se) = Free <$> pe64 se
-- | Element type, shape, memory block and index function of an array
-- variable.
--
-- Calls 'error' if the variable's summary is not an array in memory.
arrayVarReturns ::
  (HasScope lore m, Monad m, Mem lore) =>
  VName ->
  m (PrimType, Shape, VName, IxFun)
arrayVarReturns v =
  lookupMemInfo v >>= \info ->
    case info of
      MemArray elemType shape _ (ArrayIn block ixfun) ->
        return (elemType, shape, block, ixfun)
      _ ->
        error $ "arrayVarReturns: " ++ pretty v ++ " is not an array."
-- | The 'ExpReturns' of a variable.
--
-- Scalars and memory blocks are returned as-is. An array refers to its
-- existing memory block ('ReturnsInBlock'), with every dimension 'Free'
-- and an index function that has no existential parts.
varReturns ::
  (HasScope lore m, Monad m, Mem lore) =>
  VName ->
  m ExpReturns
varReturns v =
  lookupMemInfo v >>= \info ->
    case info of
      MemPrim bt -> return $ MemPrim bt
      MemMem space -> return $ MemMem space
      MemArray elemType shape _ (ArrayIn block ixfun) ->
        let blockRet = ReturnsInBlock block $ existentialiseIxFun [] ixfun
         in return $ MemArray elemType (Free <$> shape) NoUniqueness (Just blockRet)
expReturns ::
( Monad m,
HasScope lore m,
Mem lore
) =>
Exp lore ->
m [ExpReturns]
expReturns :: Exp lore -> m [ExpReturns]
expReturns (BasicOp (SubExp (Var VName
v))) =
ExpReturns -> [ExpReturns]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> [ExpReturns]) -> m ExpReturns -> m [ExpReturns]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m ExpReturns
forall lore (m :: * -> *).
(HasScope lore m, Monad m, Mem lore) =>
VName -> m ExpReturns
varReturns VName
v
expReturns (BasicOp (Opaque (Var VName
v))) =
ExpReturns -> [ExpReturns]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> [ExpReturns]) -> m ExpReturns -> m [ExpReturns]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m ExpReturns
forall lore (m :: * -> *).
(HasScope lore m, Monad m, Mem lore) =>
VName -> m ExpReturns
varReturns VName
v
expReturns (BasicOp (Reshape ShapeChange SubExp
newshape VName
v)) = do
(PrimType
et, Shape
_, VName
mem, IxFun
ixfun) <- VName -> m (PrimType, Shape, VName, IxFun)
forall lore (m :: * -> *).
(HasScope lore m, Monad m, Mem lore) =>
VName -> m (PrimType, Shape, VName, IxFun)
arrayVarReturns VName
v
[ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return
[ PrimType
-> ShapeBase ExtSize
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et ([ExtSize] -> ShapeBase ExtSize
forall d. [d] -> ShapeBase d
Shape ([ExtSize] -> ShapeBase ExtSize) -> [ExtSize] -> ShapeBase ExtSize
forall a b. (a -> b) -> a -> b
$ (DimChange SubExp -> ExtSize) -> ShapeChange SubExp -> [ExtSize]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ExtSize
forall a. a -> Ext a
Free (SubExp -> ExtSize)
-> (DimChange SubExp -> SubExp) -> DimChange SubExp -> ExtSize
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. DimChange SubExp -> SubExp
forall d. DimChange d -> d
newDim) ShapeChange SubExp
newshape) NoUniqueness
NoUniqueness (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$
VName -> ExtIxFun -> MemReturn
ReturnsInBlock VName
mem (ExtIxFun -> MemReturn) -> ExtIxFun -> MemReturn
forall a b. (a -> b) -> a -> b
$
[VName] -> IxFun -> ExtIxFun
existentialiseIxFun [] (IxFun -> ExtIxFun) -> IxFun -> ExtIxFun
forall a b. (a -> b) -> a -> b
$
IxFun -> ShapeChange (TPrimExp Int64 VName) -> IxFun
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> ShapeChange num -> IxFun num
IxFun.reshape IxFun
ixfun (ShapeChange (TPrimExp Int64 VName) -> IxFun)
-> ShapeChange (TPrimExp Int64 VName) -> IxFun
forall a b. (a -> b) -> a -> b
$ (DimChange SubExp -> DimChange (TPrimExp Int64 VName))
-> ShapeChange SubExp -> ShapeChange (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map ((SubExp -> TPrimExp Int64 VName)
-> DimChange SubExp -> DimChange (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64) ShapeChange SubExp
newshape
]
expReturns (BasicOp (Rearrange [Int]
perm VName
v)) = do
(PrimType
et, Shape Result
dims, VName
mem, IxFun
ixfun) <- VName -> m (PrimType, Shape, VName, IxFun)
forall lore (m :: * -> *).
(HasScope lore m, Monad m, Mem lore) =>
VName -> m (PrimType, Shape, VName, IxFun)
arrayVarReturns VName
v
let ixfun' :: IxFun
ixfun' = IxFun -> [Int] -> IxFun
forall num. IntegralExp num => IxFun num -> [Int] -> IxFun num
IxFun.permute IxFun
ixfun [Int]
perm
dims' :: Result
dims' = [Int] -> Result -> Result
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm Result
dims
[ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return
[ PrimType
-> ShapeBase ExtSize
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et ([ExtSize] -> ShapeBase ExtSize
forall d. [d] -> ShapeBase d
Shape ([ExtSize] -> ShapeBase ExtSize) -> [ExtSize] -> ShapeBase ExtSize
forall a b. (a -> b) -> a -> b
$ (SubExp -> ExtSize) -> Result -> [ExtSize]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> ExtSize
forall a. a -> Ext a
Free Result
dims') NoUniqueness
NoUniqueness (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$ VName -> ExtIxFun -> MemReturn
ReturnsInBlock VName
mem (ExtIxFun -> MemReturn) -> ExtIxFun -> MemReturn
forall a b. (a -> b) -> a -> b
$ [VName] -> IxFun -> ExtIxFun
existentialiseIxFun [] IxFun
ixfun'
]
-- Rotation keeps both the shape and the memory block of @v@; only the index
-- function changes, rotated by the offsets (converted with 'pe64').
expReturns (BasicOp (Rotate offsets v)) = do
  (et, Shape dims, mem, ixfun) <- arrayVarReturns v
  let ixfun' = IxFun.rotate ixfun $ map pe64 offsets
  return
    [ MemArray et (Shape $ map Free dims) NoUniqueness $
        Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun'
    ]
-- Indexing defers to 'sliceInfo'.  An array result is reported as living in
-- the memory block chosen by 'sliceInfo' (no existential sizes); scalar and
-- memory results are passed through unchanged.
expReturns (BasicOp (Index v slice)) = do
  info <- sliceInfo v slice
  case info of
    MemArray et shape u (ArrayIn mem ixfun) ->
      return
        [ MemArray et (fmap Free shape) u $
            Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun
        ]
    MemPrim pt -> return [MemPrim pt]
    MemMem space -> return [MemMem space]
-- The result of an 'Update' of @v@ is described by the memory information of
-- @v@ itself.
expReturns (BasicOp (Update v _ _)) =
  pure <$> varReturns v
-- Any other basic operation: derive the return information from the
-- operation's static type via 'extReturns'.
expReturns (BasicOp op) =
  extReturns . staticShapes <$> primOpType op
-- The results of a loop are described by its value merge parameters.  An
-- array result whose memory block is itself a merge parameter is returned in
-- a new block, identified by that parameter's position among all merge
-- parameters (context first, then values); otherwise it is returned in the
-- existing block.  Merge-parameter names inside the index function are made
-- existential in both cases.
expReturns e@(DoLoop ctx val _ _) = do
  ret_ts <- expExtType e
  zipWithM typeWithDec ret_ts $ map fst val
  where
    typeWithDec t p =
      case (t, paramDec p) of
        ( Array bt shape u,
          MemArray _ _ _ (ArrayIn mem ixfun)
          )
            | Just (i, mem_p) <- isMergeVar mem,
              Mem space <- paramType mem_p ->
              return $ MemArray bt shape u $ Just $ ReturnsNewBlock space i ixfun'
            | otherwise ->
              return $ MemArray bt shape u $ Just $ ReturnsInBlock mem ixfun'
            where
              ixfun' = existentialiseIxFun (map paramName mergevars) ixfun
        (Array {}, _) ->
          error "expReturns: Array return type but not array merge variable."
        (Prim bt, _) ->
          return $ MemPrim bt
        (Mem {}, _) ->
          error "expReturns: loop returns memory block explicitly."

    -- The position and parameter of the merge variable named @v@, if any.
    isMergeVar v = find ((== v) . paramName . snd) $ zip [0 ..] mergevars

    -- Context parameters come first, so positions count them too.
    mergevars = map fst $ ctx ++ val
-- A function call returns what the callee's declared return types say.
expReturns (Apply _ _ ret _) =
  return $ map funReturnsToExpReturns ret
-- A branch returns what its 'IfDec' annotation declares.
expReturns (If _ _ _ (IfDec ret _)) =
  return $ map bodyReturnsToExpReturns ret
-- Lore-specific operations describe themselves through 'OpReturns'.
expReturns (Op op) =
  opReturns op
-- | The memory information of the result of slicing array @v@ by @slice@.
--
-- If the slice leaves no dimensions ('sliceDims' is empty), the result is a
-- scalar of the array's element type.  Otherwise it is an array in the same
-- memory block as @v@, whose index function is @v@'s index function sliced
-- by @slice@ (with every index converted to an 'int64' expression).
sliceInfo ::
  (Monad m, HasScope lore m, Mem lore) =>
  VName ->
  Slice SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
sliceInfo v slice = do
  (et, _, mem, ixfun) <- arrayVarReturns v
  case sliceDims slice of
    [] -> return $ MemPrim et
    dims ->
      return $
        MemArray et (Shape dims) NoUniqueness $
          ArrayIn mem $
            IxFun.slice ixfun $
              map (fmap (isInt64 . primExpFromSubExp int64)) slice
-- | Lores whose operations can report memory-annotated return information.
class TypedOp (Op lore) => OpReturns lore where
  -- | The return information of an operation.  The default implementation
  -- derives it from the operation's type ('opType') via 'extReturns'.
  opReturns ::
    (Monad m, HasScope lore m) =>
    Op lore ->
    m [ExpReturns]
  opReturns op = extReturns <$> opType op
-- | Instantiate a function's return information for a particular call.
--
-- Returns 'Nothing' when 'applyRetType' rejects the arguments for the given
-- parameters.  Otherwise each return is rewritten in two ways:
--
-- * every 'Free' size that names a parameter is replaced by the
--   corresponding argument;
-- * a 'ReturnsInBlock' memory block that names a parameter is replaced by
--   the argument, when that argument is a variable.
applyFunReturns ::
  Typed dec =>
  [FunReturns] ->
  [Param dec] ->
  [(SubExp, Type)] ->
  Maybe [FunReturns]
applyFunReturns rets params args
  | Just _ <- applyRetType rettype params args =
    Just $ map correctDims rets
  | otherwise =
    Nothing
  where
    rettype = map declExtTypeOf rets

    parammap :: M.Map VName (SubExp, Type)
    parammap = M.fromList $ zip (map paramName params) args

    substSubExp (Var v)
      | Just (se, _) <- M.lookup v parammap = se
    substSubExp se = se

    correctDims (MemPrim t) = MemPrim t
    correctDims (MemMem space) = MemMem space
    correctDims (MemArray et shape u memsummary) =
      MemArray et (correctShape shape) u $ correctSummary memsummary

    correctShape = Shape . map correctDim . shapeDims

    -- Existential sizes are left alone; only free sizes can name parameters.
    correctDim (Ext i) = Ext i
    correctDim (Free se) = Free $ substSubExp se

    -- NOTE: parameter names occurring inside the index functions are not
    -- substituted here; only the memory block name is.
    correctSummary (ReturnsNewBlock space i ixfun) =
      ReturnsNewBlock space i ixfun
    correctSummary (ReturnsInBlock mem ixfun) =
      ReturnsInBlock mem' ixfun
      where
        mem' = case M.lookup mem parammap of
          Just (Var v, _) -> v
          _ -> mem