{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies, FlexibleInstances, FlexibleContexts, MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.ExplicitMemory
(
ExplicitMemory
, InKernel
, MemOp (..)
, MemInfo (..)
, MemBound
, MemBind (..)
, MemReturn (..)
, IxFun
, ExtIxFun
, isStaticIxFun
, ExpReturns
, BodyReturns
, FunReturns
, noUniquenessReturns
, bodyReturnsToExpReturns
, ExplicitMemorish
, expReturns
, extReturns
, sliceInfo
, lookupMemInfo
, subExpMemInfo
, lookupMemSize
, lookupArraySummary
, fullyLinear
, ixFunMatchesInnerShape
, existentialiseIxFun
, module Futhark.Representation.AST.Attributes
, module Futhark.Representation.AST.Traversals
, module Futhark.Representation.AST.Pretty
, module Futhark.Representation.AST.Syntax
, module Futhark.Representation.Kernels.Kernel
, module Futhark.Representation.Kernels.KernelExp
, module Futhark.Analysis.PrimExp.Convert
)
where
import Data.Maybe
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Except
import qualified Data.Map.Strict as M
import Data.Foldable (traverse_)
import Data.List
import Data.Monoid ((<>))
import Futhark.Analysis.Metrics
import Futhark.Representation.AST.Syntax
import Futhark.Representation.Kernels.Kernel
import Futhark.Representation.Kernels.KernelExp
import Futhark.Representation.AST.Attributes
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.AST.Traversals
import Futhark.Representation.AST.Pretty
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.PrimExp.Simplify
import Futhark.Util
import Futhark.Util.IntegralExp
import qualified Futhark.Util.Pretty as PP
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Representation.Aliases
(Aliases, removeScopeAliases, removeExpAliases, removePatternAliases)
import Futhark.Representation.AST.Attributes.Ranges
import Futhark.Analysis.Usage
import qualified Futhark.Analysis.SymbolTable as ST
-- | Uninhabited type used purely as a type-level tag.  Its 'Annotations'
-- instance in this module sets @Op ExplicitMemory = MemOp (Kernel InKernel)@.
data ExplicitMemory
-- | Uninhabited type used purely as a type-level tag.  Its 'Annotations'
-- instance in this module sets @Op InKernel = MemOp (KernelExp InKernel)@.
data InKernel
-- | Constraint synonym (see the @ConstraintKinds@ pragma) bundling what the
-- functions of this module demand of a representation @lore@: it satisfies
-- @SameScope lore ExplicitMemory@, uses 'FunReturns' as its 'RetType' and
-- 'BodyReturns' as its 'BranchType', and has the class instances listed
-- below.
type ExplicitMemorish lore = (SameScope lore ExplicitMemory,
RetType lore ~ FunReturns,
BranchType lore ~ BodyReturns,
CanBeAliased (Op lore),
Attributes lore, Annotations lore,
TC.Checkable lore,
OpReturns lore)
-- | Primitive function results are 'MemPrim'; applying a return type to
-- actual arguments is delegated to 'applyFunReturns'.
instance IsRetType FunReturns where
  primRetType pt = MemPrim pt

  applyRetType rets params args = applyFunReturns rets params args
-- | Primitive body results are represented by 'MemPrim'.
instance IsBodyType BodyReturns where
  primBodyType pt = MemPrim pt
-- | The operation type used by both representations in this module: either
-- an allocation or an operation of the wrapped type @inner@.
--
-- 'Alloc' carries a 'SubExp' (typechecked as @int64@ by the 'TC.CheckableOp'
-- instances in this module) and a 'Space'; its result type is
-- @Mem size space@ (see the 'TypedOp' instance).
data MemOp inner = Alloc SubExp Space
-- 'Inner' wraps an operation of the underlying type unchanged.
| Inner inner
deriving (Eq, Ord, Show)
-- | The free names of an allocation are those of its size; the memory space
-- contributes nothing.  Inner operations are delegated.
instance FreeIn inner => FreeIn (MemOp inner) where
  freeIn op = case op of
    Inner k      -> freeIn k
    Alloc size _ -> freeIn size
-- | An allocation produces exactly one value: a memory block of the given
-- size in the given space.  Inner operations are delegated.
instance TypedOp inner => TypedOp (MemOp inner) where
  opType op = case op of
    Inner k          -> opType k
    Alloc size space -> pure [Mem size space]
-- | An allocation has a single result with no aliases and consumes nothing;
-- inner operations are delegated.
instance AliasedOp inner => AliasedOp (MemOp inner) where
  opAliases op = case op of
    Inner k -> opAliases k
    Alloc{} -> [mempty]

  consumedInOp op = case op of
    Inner k -> consumedInOp k
    Alloc{} -> mempty
-- | Adding or removing alias information leaves allocations untouched and
-- is applied recursively to inner operations.
instance CanBeAliased inner => CanBeAliased (MemOp inner) where
  type OpWithAliases (MemOp inner) = MemOp (OpWithAliases inner)

  removeOpAliases op = case op of
    Inner k        -> Inner (removeOpAliases k)
    Alloc se space -> Alloc se space

  addOpAliases op = case op of
    Inner k        -> Inner (addOpAliases k)
    Alloc se space -> Alloc se space
-- | The single result of an allocation has an unknown range; inner
-- operations are delegated.
instance RangedOp inner => RangedOp (MemOp inner) where
  opRanges op = case op of
    Inner k -> opRanges k
    Alloc{} -> [unknownRange]
-- | Adding or removing range information leaves allocations untouched and
-- is applied recursively to inner operations.
instance CanBeRanged inner => CanBeRanged (MemOp inner) where
  type OpWithRanges (MemOp inner) = MemOp (OpWithRanges inner)

  removeOpRanges op = case op of
    Inner k          -> Inner (removeOpRanges k)
    Alloc size space -> Alloc size space

  addOpRanges op = case op of
    Inner k          -> Inner (addOpRanges k)
    Alloc size space -> Alloc size space
-- | Renaming touches the size of an allocation (the space is kept as is)
-- and is delegated for inner operations.
instance Rename inner => Rename (MemOp inner) where
  rename op = case op of
    Inner k          -> Inner <$> rename k
    Alloc size space -> (`Alloc` space) <$> rename size
-- | Substitution is applied to the size of an allocation (the space is kept
-- as is) and delegated for inner operations.
instance Substitute inner => Substitute (MemOp inner) where
  substituteNames subst op = case op of
    Inner k          -> Inner (substituteNames subst k)
    Alloc size space -> Alloc (substituteNames subst size) space
-- | Allocations print as @alloc(size)@ in the default space and as
-- @alloc(size, space)@ otherwise; inner operations print as themselves.
instance PP.Pretty inner => PP.Pretty (MemOp inner) where
  ppr (Inner k) = PP.ppr k
  ppr (Alloc e space) = PP.text "alloc" <> PP.apply args
    where args = case space of
            DefaultSpace -> [PP.ppr e]
            Space sp     -> [PP.ppr e, PP.text sp]
-- | Allocations are recorded under the name @Alloc@; inner operations
-- report their own metrics.
instance OpMetrics inner => OpMetrics (MemOp inner) where
  opMetrics op = case op of
    Inner k -> opMetrics k
    Alloc{} -> seen "Alloc"
-- | Allocations are always reported as both safe and cheap; inner
-- operations decide for themselves.
instance IsOp inner => IsOp (MemOp inner) where
  safeOp op = case op of
    Inner k -> safeOp k
    Alloc{} -> True

  cheapOp op = case op of
    Inner k -> cheapOp k
    Alloc{} -> True
-- | Allocations contribute no usage information; inner operations are
-- delegated.
instance UsageInOp inner => UsageInOp (MemOp inner) where
  usageInOp op = case op of
    Inner k -> usageInOp k
    Alloc{} -> mempty
-- | Removing simplifier wisdom leaves allocations untouched and is applied
-- recursively to inner operations.
instance CanBeWise inner => CanBeWise (MemOp inner) where
  type OpWithWisdom (MemOp inner) = MemOp (OpWithWisdom inner)

  removeOpWisdom op = case op of
    Inner k          -> Inner (removeOpWisdom k)
    Alloc size space -> Alloc size space
-- | Only inner operations can be indexed symbolically; an allocation never
-- yields an index expression.
instance ST.IndexOp inner => ST.IndexOp (MemOp inner) where
  indexOp vtable k op is = case op of
    Inner inner -> ST.indexOp vtable k inner is
    Alloc{}     -> Nothing
-- | Annotation types of the 'ExplicitMemory' representation.  Every binder
-- carries a 'MemBound' (function parameters additionally track
-- uniqueness), and operations are allocations or kernels.
instance Annotations ExplicitMemory where
  type LetAttr    ExplicitMemory = MemBound NoUniqueness
  type FParamAttr ExplicitMemory = MemBound Uniqueness
  type LParamAttr ExplicitMemory = MemBound NoUniqueness
  type RetType    ExplicitMemory = FunReturns
  type BranchType ExplicitMemory = BodyReturns
  type Op         ExplicitMemory = MemOp (Kernel InKernel)
-- | Annotation types of the 'InKernel' representation.  They coincide with
-- those of 'ExplicitMemory' except that operations are allocations or
-- kernel expressions.
instance Annotations InKernel where
  type LetAttr    InKernel = MemBound NoUniqueness
  type FParamAttr InKernel = MemBound Uniqueness
  type LParamAttr InKernel = MemBound NoUniqueness
  type RetType    InKernel = FunReturns
  type BranchType InKernel = BodyReturns
  type Op         InKernel = MemOp (KernelExp InKernel)
-- | Index functions whose scalar components are primitive expressions over
-- variable names.
type IxFun = IxFun.IxFun (PrimExp VName)
-- | Index functions whose leaves are 'Ext'-wrapped names, i.e. either
-- 'Free' variables or numbered existentials ('Ext').
type ExtIxFun = IxFun.IxFun (PrimExp (Ext VName))
-- | Memory-annotated type information, parametrised over the dimension
-- type @d@, the uniqueness annotation @u@ and the array memory annotation
-- @ret@.
--
-- 'MemPrim' describes a primitive value of the given type.
data MemInfo d u ret = MemPrim PrimType
-- 'MemMem' describes a memory block with a size of type @d@ in some 'Space'.
| MemMem d Space
-- 'MemArray' describes an array by element type, shape, uniqueness and a
-- memory annotation of type @ret@.
| MemArray PrimType (ShapeBase d) u ret
deriving (Eq, Show, Ord)
-- | 'MemInfo' specialised to concrete ('SubExp') dimensions and a 'MemBind'
-- memory annotation; only the uniqueness annotation @u@ remains open.
type MemBound u = MemInfo SubExp u MemBind
-- | Forget the memory annotation, keeping uniqueness.  A memory block whose
-- size is existential is given the constant size 0.
instance FixExt ret => DeclExtTyped (MemInfo ExtSize Uniqueness ret) where
  declExtTypeOf info = case info of
    MemArray pt shape u _    -> Array pt shape u
    MemMem (Free size) space -> Mem size space
    MemMem Ext{} space       -> Mem (intConst Int32 0) space
    MemPrim pt               -> Prim pt
-- | Forget the memory annotation.  A memory block whose size is existential
-- is given the constant size 0.
instance FixExt ret => ExtTyped (MemInfo ExtSize NoUniqueness ret) where
  extTypeOf info = case info of
    MemArray pt shape u _    -> Array pt shape u
    MemMem (Free size) space -> Mem size space
    MemMem Ext{} space       -> Mem (intConst Int32 0) space
    MemPrim pt               -> Prim pt
-- | Fixing an existential is propagated to memory sizes, array shapes and
-- array memory annotations; primitives are unaffected.
instance FixExt ret => FixExt (MemInfo ExtSize u ret) where
  fixExt i se info = case info of
    MemArray pt shape u ret ->
      MemArray pt (fixExt i se shape) u (fixExt i se ret)
    MemMem size space ->
      MemMem (fixExt i se size) space
    MemPrim pt ->
      MemPrim pt
-- | The plain type is the declared type with uniqueness stripped by
-- 'fromDecl'.
instance Typed (MemInfo SubExp Uniqueness ret) where
  typeOf info = fromDecl (declTypeOf info)
-- | The plain type is obtained by dropping the memory annotation.
instance Typed (MemInfo SubExp NoUniqueness ret) where
  typeOf info = case info of
    MemArray bt shape u _ -> Array bt shape u
    MemMem size space     -> Mem size space
    MemPrim pt            -> Prim pt
-- | The declared type is obtained by dropping the memory annotation while
-- keeping uniqueness.
instance DeclTyped (MemInfo SubExp Uniqueness ret) where
  declTypeOf info = case info of
    MemArray bt shape u _ -> Array bt shape u
    MemMem size space     -> Mem size space
    MemPrim bt            -> Prim bt
-- | Free names come from memory sizes, array shapes and array memory
-- annotations; primitives have none.
instance (FreeIn d, FreeIn ret) => FreeIn (MemInfo d u ret) where
  freeIn info = case info of
    MemPrim _              -> mempty
    MemMem size _          -> freeIn size
    MemArray _ shape _ ret -> freeIn shape <> freeIn ret
-- | Substitution is applied to memory sizes, array shapes and array memory
-- annotations; primitives are returned unchanged.
instance (Substitute d, Substitute ret) => Substitute (MemInfo d u ret) where
  substituteNames substs info = case info of
    MemPrim bt ->
      MemPrim bt
    MemMem size space ->
      MemMem (substituteNames substs size) space
    MemArray bt shape u ret ->
      let shape' = substituteNames substs shape
          ret'   = substituteNames substs ret
      in MemArray bt shape' u ret'
-- | Renaming is implemented via substitution.
instance (Substitute d, Substitute ret) => Rename (MemInfo d u ret) where
  rename info = substituteRename info
-- | Simplify every primitive expression inside an index function.
simplifyIxFun :: Engine.SimplifiableLore lore =>
                 IxFun -> Engine.SimpleM lore IxFun
simplifyIxFun ixfun = traverse simplifyPrimExp ixfun
-- | Simplify every primitive expression inside an existential index
-- function.
simplifyExtIxFun :: Engine.SimplifiableLore lore =>
                    ExtIxFun -> Engine.SimpleM lore ExtIxFun
simplifyExtIxFun ixfun = traverse simplifyExtPrimExp ixfun
-- | Convert an existential index function into an ordinary one.  The result
-- is 'Nothing' as soon as any leaf is an 'Ext'; otherwise every 'Free' leaf
-- is unwrapped.
isStaticIxFun :: ExtIxFun -> Maybe IxFun
isStaticIxFun ixfun = traverse (traverse unExt) ixfun
  where unExt (Free v) = Just v
        unExt (Ext _)  = Nothing
-- | Simplify memory sizes, array shapes and array memory annotations;
-- primitives are returned unchanged.
instance (Engine.Simplifiable d, Engine.Simplifiable ret) =>
         Engine.Simplifiable (MemInfo d u ret) where
  simplify info = case info of
    MemPrim bt ->
      return (MemPrim bt)
    MemMem size space -> do
      size' <- Engine.simplify size
      return (MemMem size' space)
    MemArray bt shape u ret -> do
      -- The shape is simplified before the memory annotation.
      shape' <- Engine.simplify shape
      ret' <- Engine.simplify ret
      return (MemArray bt shape' u ret')
-- | Primitives print as their type, memory blocks as @mem(size)@ with an
-- @\@space@ suffix outside the default space, and arrays as their type
-- followed by @\@@ and the memory annotation.
instance (PP.Pretty (TypeBase (ShapeBase d) u),
          PP.Pretty d, PP.Pretty u, PP.Pretty ret) => PP.Pretty (MemInfo d u ret) where
  ppr info = case info of
    MemArray bt shape u ret ->
      PP.ppr (Array bt shape u) <> PP.text "@" <> PP.ppr ret
    MemMem size (Space sp) ->
      PP.text "mem" <> PP.parens (PP.ppr size) <> PP.text "@" <> PP.text sp
    MemMem size DefaultSpace ->
      PP.text "mem" <> PP.parens (PP.ppr size)
    MemPrim bt ->
      PP.ppr bt
-- | Function parameters print like parameters carrying only their declared
-- type.
instance PP.Pretty (Param (MemInfo SubExp Uniqueness ret)) where
  ppr param = PP.ppr (declTypeOf <$> param)

-- | Other parameters print like parameters carrying only their type.
instance PP.Pretty (Param (MemInfo SubExp NoUniqueness ret)) where
  ppr param = PP.ppr (typeOf <$> param)

-- | Pattern elements print like pattern elements carrying only their type.
instance PP.Pretty (PatElemT (MemInfo SubExp NoUniqueness ret)) where
  ppr pe = PP.ppr (typeOf <$> pe)
-- | Memory annotation of a bound array: the name of the memory block it
-- lives in (checked to be of 'Mem' type by 'checkMemInfo') together with
-- its index function.
data MemBind = ArrayIn VName IxFun
deriving (Show)
-- | Any two 'MemBind's are considered equal, so the memory annotation never
-- influences comparisons of the enclosing 'MemInfo'.
instance Eq MemBind where
  (==) _ _ = True

-- | Consistent with the 'Eq' instance: all values compare as 'EQ'.
instance Ord MemBind where
  compare _ _ = EQ
-- | Renaming is implemented via substitution.
instance Rename MemBind where
  rename bind = substituteRename bind
-- | Substitution is applied to both the memory block name and the index
-- function.
instance Substitute MemBind where
  substituteNames substs (ArrayIn mem ixfun) = ArrayIn mem' ixfun'
    where mem'   = substituteNames substs mem
          ixfun' = substituteNames substs ixfun
-- | Printed as @\@mem->ixfun@.
instance PP.Pretty MemBind where
  ppr (ArrayIn mem ixfun) =
    let mem_doc   = PP.ppr mem
        ixfun_doc = PP.ppr ixfun
    in PP.text "@" <> mem_doc <> PP.text "->" <> ixfun_doc
-- | The free names are the memory block together with whatever occurs in
-- the index function.
instance FreeIn MemBind where
  freeIn (ArrayIn mem ixfun) = mem_names <> ixfun_names
    where mem_names   = freeIn mem
          ixfun_names = freeIn ixfun
-- | Memory annotation of a returned array.
--
-- 'ReturnsInBlock' names the memory block the array resides in and gives
-- its (possibly existential) index function.
data MemReturn = ReturnsInBlock VName ExtIxFun
-- 'ReturnsNewBlock' carries a memory space, an 'Int' (used by this module as
-- the index of the memory block among the context/existential values, see
-- 'matchReturnType' and the 'FixExt' instance), a possibly existential size
-- and an index function.
| ReturnsNewBlock Space Int ExtSize ExtIxFun
deriving (Show)
-- | Any two 'MemReturn's are considered equal, so the memory annotation
-- never influences comparisons of the enclosing 'MemInfo'.
instance Eq MemReturn where
  (==) _ _ = True

-- | Consistent with the 'Eq' instance: all values compare as 'EQ'.
instance Ord MemReturn where
  compare _ _ = EQ
-- | Renaming is implemented via substitution.
instance Rename MemReturn where
  rename ret = substituteRename ret
-- | Substitution is applied to memory block names, sizes and index
-- functions; the space and the existential index are left alone.
instance Substitute MemReturn where
  substituteNames substs ret = case ret of
    ReturnsNewBlock space i size ixfun ->
      ReturnsNewBlock space i (subst size) (subst ixfun)
    ReturnsInBlock mem ixfun ->
      ReturnsInBlock (subst mem) (subst ixfun)
    where subst :: Substitute a => a -> a
          subst = substituteNames substs
-- | Instantiate existential number @i@ with the sub-expression @se@.
--
-- If a 'ReturnsNewBlock' refers to exactly that existential as its memory
-- block and @se@ is a variable, the annotation becomes a 'ReturnsInBlock'
-- in that variable.  Otherwise the existential index is decremented when it
-- lies above @i@, and the instantiation is applied to the size and the
-- index function.
instance FixExt MemReturn where
  fixExt i se ret = case ret of
    ReturnsNewBlock space j size ixfun
      | Var v <- se, j == i ->
          ReturnsInBlock v (fixIxFun ixfun)
      | otherwise ->
          ReturnsNewBlock space (shift j) (fixExt i se size) (fixIxFun ixfun)
    ReturnsInBlock mem ixfun ->
      ReturnsInBlock mem (fixIxFun ixfun)
    where fixIxFun = fixExtIxFun i (primExpFromSubExp int32 se)
          shift j | i < j     = j - 1
                  | otherwise = j
-- | Instantiate existential number @i@ in an index function with the
-- expression @e@.  Existentials numbered above @i@ are shifted down by one;
-- those below @i@ and all 'Free' leaves are kept.
fixExtIxFun :: Int -> PrimExp VName -> ExtIxFun -> ExtIxFun
fixExtIxFun i e ixfun = fmap (replaceInPrimExp update) ixfun
  where update (Free x) t = LeafExp (Free x) t
        update (Ext j) t = case compare j i of
          GT -> LeafExp (Ext (j - 1)) t
          EQ -> fmap Free e
          LT -> LeafExp (Ext j) t
-- | An @int32@ leaf expression referring to existential number @i@.
leafExp :: Int -> PrimExp (Ext a)
leafExp = (`LeafExp` int32) . Ext
-- | Turn an index function into an existential one.  Every leaf is first
-- wrapped in 'Free'; then each name in @ctx@ is replaced by the existential
-- leaf numbered by its position in @ctx@ (for a repeated name the last
-- position wins).
existentialiseIxFun :: [VName] -> IxFun -> ExtIxFun
existentialiseIxFun ctx ixfun =
  IxFun.substituteInIxFun substs (fmap Free <$> ixfun)
  where substs = M.fromList [ (Free v, leafExp i) | (v, i) <- zip ctx [0..] ]
-- | 'ReturnsInBlock' prints as @(mem->ixfun)@; 'ReturnsNewBlock' prints as
-- @?i@, an optional @\@space@, the parenthesised size, @->@ and the index
-- function.
instance PP.Pretty MemReturn where
  ppr ret = case ret of
    ReturnsNewBlock space i size ixfun ->
      let space_doc = case space of
            Space s      -> PP.text $ "@" ++ s
            DefaultSpace -> mempty
      in PP.text ("?" ++ show i) <> space_doc <> PP.parens (PP.ppr size)
         <> PP.text "->" <> PP.ppr ixfun
    ReturnsInBlock v ixfun ->
      PP.parens $ PP.text (pretty v) <> PP.text "->" <> PP.ppr ixfun
-- | Free names of a return annotation.
--
-- For 'ReturnsInBlock' these are the memory block and the names in the
-- index function.  For 'ReturnsNewBlock' the memory block itself is
-- existential, but the size and the index function may still mention
-- ordinary variables (the 'Substitute' instance substitutes in both), so
-- those are reported as well instead of being silently dropped.
instance FreeIn MemReturn where
  freeIn (ReturnsInBlock v ixfun) =
    freeIn v <> freeIn ixfun
  freeIn (ReturnsNewBlock _ _ size ixfun) =
    freeIn size <> freeIn ixfun
-- | Simplify the size / memory block name first and the index function
-- second; space and existential index are kept.
instance Engine.Simplifiable MemReturn where
  simplify ret = case ret of
    ReturnsInBlock v ixfun -> do
      v' <- Engine.simplify v
      ixfun' <- simplifyExtIxFun ixfun
      return (ReturnsInBlock v' ixfun')
    ReturnsNewBlock space i size ixfun -> do
      size' <- Engine.simplify size
      ixfun' <- simplifyExtIxFun ixfun
      return (ReturnsNewBlock space i size' ixfun')
-- | Simplify the memory block name, then the index function.
instance Engine.Simplifiable MemBind where
  simplify (ArrayIn mem ixfun) = do
    mem' <- Engine.simplify mem
    ixfun' <- simplifyIxFun ixfun
    return (ArrayIn mem' ixfun')
-- | A list of function returns is simplified element by element, in order.
instance Engine.Simplifiable [FunReturns] where
  simplify rets = mapM Engine.simplify rets
-- | Return information of an expression: existential sizes, no uniqueness,
-- and an optional memory annotation.  'matchPatternToExp' accepts any
-- pattern annotation where this one is 'Nothing'.
type ExpReturns = MemInfo ExtSize NoUniqueness (Maybe MemReturn)
-- | Return information of a body or branch: like 'ExpReturns' but the
-- memory annotation is mandatory.
type BodyReturns = MemInfo ExtSize NoUniqueness MemReturn
-- | Return information of a function: like 'BodyReturns' but additionally
-- tracking 'Uniqueness'.
type FunReturns = MemInfo ExtSize Uniqueness MemReturn
-- | Make the memory annotation optional by wrapping it in 'Just'.
maybeReturns :: MemInfo d u r -> MemInfo d u (Maybe r)
maybeReturns info = case info of
  MemPrim bt              -> MemPrim bt
  MemMem size space       -> MemMem size space
  MemArray bt shape u ret -> MemArray bt shape u (Just ret)
-- | Discard the uniqueness annotation of an array; other cases carry none.
noUniquenessReturns :: MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns info = case info of
  MemPrim bt            -> MemPrim bt
  MemMem size space     -> MemMem size space
  MemArray bt shape _ r -> MemArray bt shape NoUniqueness r
-- | View a function return as an expression return: the memory annotation
-- becomes optional and uniqueness is dropped.
funReturnsToExpReturns :: FunReturns -> ExpReturns
funReturnsToExpReturns ret = noUniquenessReturns (maybeReturns ret)
-- | View a body return as an expression return by making the memory
-- annotation optional.
bodyReturnsToExpReturns :: BodyReturns -> ExpReturns
bodyReturnsToExpReturns ret = noUniquenessReturns (maybeReturns ret)
-- | An allocation must have an @int64@ size; kernels are checked by
-- 'typeCheckKernel' in a sub-check.
instance TC.CheckableOp ExplicitMemory where
  checkOp op = case op of
    Inner k      -> TC.subCheck $ typeCheckKernel k
    Alloc size _ -> TC.require [Prim int64] size
-- | An allocation must have an @int64@ size; kernel expressions are checked
-- by 'typeCheckKernelExp' in a sub-check.
instance TC.CheckableOp InKernel where
  checkOp op = case op of
    Inner k      -> TC.subCheck $ typeCheckKernelExp k
    Alloc size _ -> TC.require [Prim int64] size
-- | Type checking hooks for 'ExplicitMemory'.  All binder annotations are
-- validated with 'checkMemInfo', and pattern / return / branch matching is
-- delegated to the memory-aware matchers of this module.
instance TC.Checkable ExplicitMemory where
  checkFParamLore name info = checkMemInfo name info
  checkLParamLore name info = checkMemInfo name info
  checkLetBoundLore name info = checkMemInfo name info
  checkRetType rets = mapM_ TC.checkExtType (retTypeValues rets)
  primFParam name t = return (Param name (MemPrim t))
  matchPattern pat e = matchPatternToExp pat e
  matchReturnType rettype result = matchFunctionReturnType rettype result
  matchBranchType rettype body = matchBranchReturnType rettype body
-- | Type checking hooks for 'InKernel'.  All binder annotations are
-- validated with 'checkMemInfo', and pattern / return / branch matching is
-- delegated to the memory-aware matchers of this module.
instance TC.Checkable InKernel where
  checkFParamLore name info = checkMemInfo name info
  checkLParamLore name info = checkMemInfo name info
  checkLetBoundLore name info = checkMemInfo name info
  checkRetType rets = mapM_ TC.checkExtType (retTypeValues rets)
  primFParam name t = return (Param name (MemPrim t))
  matchPattern pat e = matchPatternToExp pat e
  matchReturnType rettype result = matchFunctionReturnType rettype result
  matchBranchType rettype body = matchBranchReturnType rettype body
-- | Check a function result against its declared return type.
--
-- First the plain (memory-free) types are matched, then the memory
-- annotations via 'matchReturnType', and finally every returned array
-- variable is required to have a linear index function.
matchFunctionReturnType :: ExplicitMemorish lore =>
                           [FunReturns] -> Result -> TC.TypeM lore ()
matchFunctionReturnType rettype result = do
  TC.matchExtReturnType (map (fromDecl . declExtTypeOf) rettype) result
  scope <- askScope
  let plain_scope = removeScopeAliases scope
  result_ts <- runReaderT (mapM subExpMemInfo result) plain_scope
  matchReturnType rettype result result_ts
  mapM_ checkResultSubExp result
  where checkResultSubExp (Var v) = do
          attr <- varMemInfo v
          case attr of
            MemArray _ _ _ (ArrayIn _ ixfun)
              | not (IxFun.isLinear ixfun) ->
                  TC.bad $ TC.TypeError $
                  "Array " ++ pretty v ++
                  " returned by function, but has nontrivial index function " ++
                  pretty ixfun ++ " " ++ show ixfun
            -- Primitives, memory blocks and linearly indexed arrays are fine.
            _ -> return ()
        checkResultSubExp Constant{} =
          return ()
-- | Check the result of a branch body against the branch return type.  The
-- memory information of the results is looked up in the current scope
-- extended with the statements of the body.
matchBranchReturnType :: ExplicitMemorish lore =>
                         [BodyReturns]
                      -> Body (Aliases lore)
                      -> TC.TypeM lore ()
matchBranchReturnType rettype (Body _ stms res) = do
  outer_scope <- askScope
  let body_scope = removeScopeAliases (outer_scope <> scopeOf stms)
  res_ts <- runReaderT (mapM subExpMemInfo res) body_scope
  matchReturnType rettype res res_ts
-- | Build two substitution maps from a list of context names paired with
-- their positions.
--
-- The first map sends each name (as a 'Free' leaf) to the existential leaf
-- ('leafExp') of its position.  When a name is listed more than once, the
-- position of its first occurrence is kept, because the combining function
-- @flip const@ retains the value already in the map.
--
-- The second map sends @Ext i@, for every listed position @i@, to the
-- existential leaf of the position found by 'lookup' for the name at @i@,
-- i.e. the first position in the list carrying that name.
getExtMaps :: [(VName,Int)] -> (M.Map (Ext VName) (PrimExp (Ext VName)),
M.Map (Ext VName) (PrimExp (Ext VName)))
getExtMaps ctx_lst_ids =
(M.map leafExp $ M.mapKeys Free $ M.fromListWith (flip const) ctx_lst_ids,
M.fromList $
-- Each (name, i) is turned into (Ext i, name); 'traverse' over the pair then
-- replaces the name by the leaf of its first position.  The 'lookup' cannot
-- fail here since the name is drawn from the very list being searched, so
-- 'mapMaybe' drops nothing.
mapMaybe (traverse (fmap (\i -> LeafExp (Ext i) int32) .
(`lookup` ctx_lst_ids)) .
uncurry (flip (,)) . fmap Ext) ctx_lst_ids)
-- | Check that the memory information @ts@ of the results @res@ matches the
-- return annotations @rettype@.
--
-- Both @res@ and @ts@ are split with 'splitFromEnd' by @length rettype@
-- into a context part and a value part; existential references (@Ext i@) in
-- @rettype@ are resolved against the context part.  Any mismatch is
-- reported as a 'TC.TypeError' that lists the expected return type and the
-- actual result information followed by the specific complaint.
matchReturnType :: PP.Pretty u =>
                   [MemInfo ExtSize u MemReturn]
                -> [SubExp]
                -> [MemInfo SubExp NoUniqueness MemBind]
                -> TC.TypeM lore ()
matchReturnType rettype res ts = do
  let (ctx_ts, val_ts) = splitFromEnd (length rettype) ts
      (ctx_res, _val_res) = splitFromEnd (length rettype) res

      -- Context results that are variables, paired with their position.
      getId :: (SubExp,Int) -> Maybe (VName,Int)
      getId (Var ii, i) = Just (ii,i)
      getId (Constant _, _) = Nothing

      (ctx_map_ids, ctx_map_exts) =
        getExtMaps $ mapMaybe getId $ zip ctx_res [0..length ctx_res - 1]

      -- Replace context variables in a concrete index function by the
      -- corresponding existential leaves.
      existentialiseIxFun0 :: IxFun -> ExtIxFun
      existentialiseIxFun0 =
        IxFun.substituteInIxFun ctx_map_ids . fmap (fmap Free)

      -- Context results that are constants, keyed by their existential.
      getCt :: (Int,SubExp) -> Maybe (Ext VName, PrimExp (Ext VName))
      getCt (_, Var _) = Nothing
      getCt (i, Constant c) = Just (Ext i, ValueExp c)

      ctx_map_cts = M.fromList $ mapMaybe getCt $
                    zip [0..length ctx_res - 1] ctx_res

      -- Instantiate existentials of the expected index function with the
      -- constant context results and canonicalise the remaining ones.
      substConstsInExtIndFun :: ExtIxFun -> ExtIxFun
      substConstsInExtIndFun =
        IxFun.substituteInIxFun (ctx_map_cts<>ctx_map_exts)

      -- Look up the context result (and its memory information) that
      -- existential number @i@ refers to.
      fetchCtx i = case maybeNth i $ zip ctx_res ctx_ts of
                     Nothing -> throwError $ "Cannot find context variable " ++
                                show i ++ " in context results: " ++ pretty ctx_res
                     Just (se, t) -> return (se, t)

      checkReturn (MemPrim x) (MemPrim y)
        | x == y = return ()
      checkReturn (MemMem x _) (MemMem y _) =
        checkDim x y
      checkReturn (MemArray x_pt x_shape _ x_ret)
                  (MemArray y_pt y_shape _ y_ret)
        | x_pt == y_pt, shapeRank x_shape == shapeRank y_shape = do
            zipWithM_ checkDim (shapeDims x_shape) (shapeDims y_shape)
            checkMemReturn x_ret y_ret
      checkReturn x y =
        throwError $ unwords ["Expected ", pretty x, " but got ", pretty y]

      checkDim (Free x) y
        | x == y = return ()
        | otherwise = throwError $ unwords ["Expected dim", pretty x,
                                            "but got", pretty y]
      checkDim (Ext i) y = do
        (x, _) <- fetchCtx i
        unless (x == y) $
          throwError $ unwords ["Expected ext dim", pretty i, "=>", pretty x,
                                "but got", pretty y]

      checkMemReturn (ReturnsInBlock x_mem x_ixfun) (ArrayIn y_mem y_ixfun)
        | x_mem == y_mem = do
            let x_ixfun' = substConstsInExtIndFun x_ixfun
                y_ixfun' = existentialiseIxFun0 y_ixfun
            unless (x_ixfun' == y_ixfun') $
              throwError $ unwords ["Index function unification fails1!",
                                    "\nixfun of body result: ", pretty y_ixfun',
                                    "\nixfun of return type: ", pretty x_ixfun',
                                    "\nand context elements: ", pretty ctx_res]
      checkMemReturn (ReturnsNewBlock x_space x_ext x_mem_size x_ixfun)
                     (ArrayIn y_mem y_ixfun) = do
        (x_mem, x_mem_type) <- fetchCtx x_ext
        let x_ixfun' = substConstsInExtIndFun x_ixfun
            y_ixfun' = existentialiseIxFun0 y_ixfun
        unless (x_ixfun' == y_ixfun') $
          throwError $ unwords ["Index function unification fails2!",
                                "\nixfun of body result: ", pretty y_ixfun',
                                "\nixfun of return type: ", pretty x_ixfun',
                                "\nand context elements: ", pretty ctx_res]
        case x_mem_type of
          MemMem y_mem_size y_space -> do
            unless (x_mem == Var y_mem) $
              throwError $ unwords ["Expected memory", pretty x_ext, "=>", pretty x_mem,
                                    "but got", pretty y_mem]
            unless (x_space == y_space) $
              throwError $ unwords ["Expected memory", pretty y_mem, "in space", pretty x_space,
                                    "but actually in space", pretty y_space]
            checkDim x_mem_size y_mem_size
          t ->
            -- The context entry is not a memory block at all.
            throwError $ unwords ["Expected memory", pretty x_ext, "=>", pretty x_mem,
                                  "but has type", pretty t]
      -- Only reached for a 'ReturnsInBlock' naming a different block.
      checkMemReturn x y =
        throwError $ unwords ["Expected array in", pretty x,
                              "but array returned in", pretty y]

      bad :: String -> TC.TypeM lore a
      bad s = TC.bad $ TC.TypeError $
              unlines [ "Return type"
                      , "  " ++ prettyTuple rettype
                      , "cannot match returns of results"
                      , "  " ++ prettyTuple ts
                      , s
                      ]

  either bad return =<< runExceptT (zipWithM_ checkReturn rettype val_ts)
-- | Check that the return information computed for expression @e@ by
-- 'expReturns' is compatible with the annotations of the pattern @pat@,
-- raising a 'TC.TypeError' otherwise.  Alias information is stripped from
-- the scope, the expression and the pattern before comparing.
matchPatternToExp :: (ExplicitMemorish lore) =>
Pattern (Aliases lore)
-> Exp (Aliases lore)
-> TC.TypeM lore ()
matchPatternToExp pat e = do
scope <- asksScope removeScopeAliases
rt <- runReaderT (expReturns $ removeExpAliases e) scope
-- Split the pattern into context and value elements, each paired with its
-- annotation expressed as a 'BodyReturns'.
let (ctxs, vals) = bodyReturnsFromPattern $ removePatternAliases pat
(ctx_ids, _ctx_ts) = unzip ctxs
(_val_ids, val_ts) = unzip vals
(ctx_map_ids, ctx_map_exts) =
getExtMaps $ zip ctx_ids [0..length ctx_ids - 1]
-- The number of value elements must equal the number of expression
-- returns, and they must match pairwise.
unless (length val_ts == length rt &&
and (zipWith (matches ctx_map_ids ctx_map_exts) val_ts rt)) $
TC.bad $ TC.TypeError $ "Expression type:\n  " ++ prettyTuple rt ++
"\ncannot match pattern type:\n  " ++ prettyTuple val_ts ++
"\nwith context elements: " ++ pretty ctx_ids
-- 'matches' takes the two maps from 'getExtMaps', then the pattern
-- annotation (x) and the expression return (y).
where matches _ _ (MemPrim x) (MemPrim y) = x == y
matches _ _ (MemMem x_size x_space) (MemMem y_size y_space) =
x_size == y_size && x_space == y_space
matches ctxids ctxexts (MemArray x_pt x_shape _ x_ret) (MemArray y_pt y_shape _ y_ret) =
x_pt == y_pt && x_shape == y_shape &&
case (x_ret, y_ret) of
-- Both name a block: the blocks and the index functions must agree.
(ReturnsInBlock x_mem x_ixfun, Just (ReturnsInBlock y_mem y_ixfun)) ->
let x_ixfun' = IxFun.substituteInIxFun ctxids x_ixfun
y_ixfun' = IxFun.substituteInIxFun ctxexts y_ixfun
in x_mem == y_mem && x_ixfun' == y_ixfun'
-- The pattern names a block while the expression returns a new one: only
-- the index functions are compared.
(ReturnsInBlock _ x_ixfun,
Just (ReturnsNewBlock _ _ _ y_ixfun)) ->
let x_ixfun' = IxFun.substituteInIxFun ctxids x_ixfun
y_ixfun' = IxFun.substituteInIxFun ctxexts y_ixfun
in x_ixfun' == y_ixfun'
-- Both are new blocks: space, existential index, size and index function
-- must all agree.
(ReturnsNewBlock x_space x_i x_size x_ixfun,
Just (ReturnsNewBlock y_space y_i y_size y_ixfun)) ->
let x_ixfun' = IxFun.substituteInIxFun ctxids x_ixfun
y_ixfun' = IxFun.substituteInIxFun ctxexts y_ixfun
in x_space == y_space && x_i == y_i &&
x_size == y_size && x_ixfun' == y_ixfun'
-- An expression return without memory annotation matches anything.
(_, Nothing) -> True
_ -> False
matches _ _ _ _ = False
-- | Look up the memory information of a variable in the type checker
-- environment.  Uniqueness of function parameters is dropped, and index
-- variables are primitives of their integer type.
varMemInfo :: ExplicitMemorish lore =>
              VName -> TC.TypeM lore (MemInfo SubExp NoUniqueness MemBind)
varMemInfo name = do
  info <- TC.lookupVar name
  return $ case info of
    IndexInfo it         -> MemPrim $ IntType it
    LParamInfo summary   -> summary
    FParamInfo summary   -> noUniquenessReturns summary
    LetInfo (_, summary) -> summary
-- | Extract the memory information from a scope entry.  Uniqueness of
-- function parameters is dropped, and index variables are primitives of
-- their integer type.
nameInfoToMemInfo :: ExplicitMemorish lore => NameInfo lore -> MemBound NoUniqueness
nameInfoToMemInfo (LetInfo summary)    = summary
nameInfoToMemInfo (LParamInfo summary) = summary
nameInfoToMemInfo (FParamInfo summary) = noUniquenessReturns summary
nameInfoToMemInfo (IndexInfo it)       = MemPrim $ IntType it
-- | Look up the memory information of a variable in the current scope.
lookupMemInfo :: (HasScope lore m, ExplicitMemorish lore) =>
                  VName -> m (MemInfo SubExp NoUniqueness MemBind)
lookupMemInfo v = nameInfoToMemInfo <$> lookupInfo v
-- | Memory information of a sub-expression: variables are looked up in the
-- scope, constants are primitives of the type of their value.
subExpMemInfo :: (HasScope lore m, Monad m, ExplicitMemorish lore) =>
                 SubExp -> m (MemInfo SubExp NoUniqueness MemBind)
subExpMemInfo se = case se of
  Constant v -> return $ MemPrim $ primValueType v
  Var v      -> lookupMemInfo v
-- | Return the memory block and index function of an array variable.
-- Calls 'fail' if the variable is not annotated as an array.
lookupArraySummary :: (ExplicitMemorish lore, HasScope lore m, Monad m) =>
                      VName -> m (VName, IxFun.IxFun (PrimExp VName))
lookupArraySummary name =
  lookupMemInfo name >>= \summary ->
    case summary of
      MemArray _ _ _ (ArrayIn mem ixfun) -> return (mem, ixfun)
      _ -> fail $ "Variable " ++ pretty name ++ " does not look like an array."
-- | Return the size of the memory block bound to a variable.  Calls 'fail'
-- if the variable does not have 'Mem' type.
lookupMemSize :: (HasScope lore m, Monad m) =>
                 VName -> m SubExp
lookupMemSize v = lookupType v >>= memSize
  where memSize (Mem size _) = return size
        memSize _ = fail $ "lookupMemSize: " ++ pretty v ++ " is not a memory block."
-- | Validate the memory annotation of the binder @name@.
--
-- Primitives need no checking and memory blocks must have an @int64@ size.
-- For arrays, the referenced variable must be of 'Mem' type, every
-- expression in the index function must be @int32@, and the rank of the
-- index function must equal the rank of the array.
checkMemInfo :: TC.Checkable lore =>
                 VName -> MemInfo SubExp u MemBind
             -> TC.TypeM lore ()
checkMemInfo _ (MemPrim _) = return ()
checkMemInfo _ (MemMem size _) = TC.require [Prim int64] size
checkMemInfo name (MemArray _ shape _ (ArrayIn v ixfun)) = do
  mem_t <- lookupType v
  case mem_t of
    Mem{} -> return ()
    _ -> TC.bad $ TC.TypeError $
         concat [ "Variable ", pretty v
                , " used as memory block, but is of type "
                , pretty mem_t, "."
                ]

  TC.context ("in index function " ++ pretty ixfun) $ do
    traverse_ (TC.requirePrimExp int32) ixfun
    let actual_rank = IxFun.rank ixfun
        expected_rank = shapeRank shape
    when (actual_rank /= expected_rank) $
      TC.bad $ TC.TypeError $
      concat [ "Arity of index function (", pretty actual_rank
             , ") does not match rank of array ", pretty name
             , " (", show expected_rank, ")"
             ]
-- | The expression types implied by a pattern are the 'BodyReturns' of its
-- value elements.
instance Attributes ExplicitMemory where
  expTypesFromPattern pat =
    return $ map snd $ snd $ bodyReturnsFromPattern pat

-- | The expression types implied by a pattern are the 'BodyReturns' of its
-- value elements.
instance Attributes InKernel where
  expTypesFromPattern pat =
    return $ map snd $ snd $ bodyReturnsFromPattern pat
-- | Express the annotations of a pattern as 'BodyReturns', separately for
-- its context elements and its value elements.
--
-- Sizes and dimensions that are context elements of the pattern become
-- existentials numbered by their position in the context.  An array whose
-- memory block is a context element annotated as a memory block becomes a
-- 'ReturnsNewBlock'; any other array becomes a 'ReturnsInBlock'.
bodyReturnsFromPattern :: PatternT (MemBound NoUniqueness)
                       -> ([(VName,BodyReturns)], [(VName,BodyReturns)])
bodyReturnsFromPattern pat =
  (map asReturns ctx, map asReturns (patternValueElements pat))
  where ctx = patternContextElements pat
        numbered_ctx = zip [0..] ctx

        -- Position and pattern element of the context element named @v@.
        inCtx v = find ((==v) . patElemName . snd) numbered_ctx

        ext se = case se of
          Var v | Just (i, _) <- inCtx v -> Ext i
          _ -> Free se

        asReturns pe = (patElemName pe, returnsOf (patElemAttr pe))

        returnsOf (MemPrim pt) = MemPrim pt
        returnsOf (MemMem size space) = MemMem (ext size) space
        returnsOf (MemArray pt shape u (ArrayIn mem ixfun)) =
          MemArray pt (Shape $ map ext $ shapeDims shape) u $
          case inCtx mem of
            Just (i, PatElem _ (MemMem size space)) ->
              ReturnsNewBlock space i (ext size) $
              existentialiseIxFun (map patElemName ctx) ixfun
            _ ->
              ReturnsInBlock mem $ existentialiseIxFun [] ixfun
-- | Pattern elements are annotated via 'bindeeAnnot'.
instance (PP.Pretty u, PP.Pretty r) => PrettyAnnot (PatElemT (MemInfo SubExp u r)) where
  ppAnnot pe = bindeeAnnot patElemName patElemAttr pe

-- | Parameters are annotated via 'bindeeAnnot'.
instance (PP.Pretty u, PP.Pretty r) => PrettyAnnot (ParamT (MemInfo SubExp u r)) where
  ppAnnot param = bindeeAnnot paramName paramAttr param
-- | Uses the class defaults only.
instance PrettyLore ExplicitMemory

-- | Uses the class defaults only.
instance PrettyLore InKernel
-- | Produce the one-line comment @-- name : annotation@ for binders that
-- are arrays; memory blocks and primitives get no annotation.
bindeeAnnot :: (PP.Pretty u, PP.Pretty r) =>
               (a -> VName) -> (a -> MemInfo SubExp u r)
            -> a -> Maybe PP.Doc
bindeeAnnot bindeeName bindeeLore bindee =
  case bindeeLore bindee of
    attr@MemArray{} ->
      let name_doc = PP.ppr (bindeeName bindee)
      in Just $ PP.text "-- " <>
         PP.oneLine (name_doc <> PP.text " : " <> PP.ppr attr)
    _ ->
      Nothing
-- | Default return information for a list of existential types.
--
-- An array type for which 'existential' holds is returned in a fresh block
-- in the default space with a row-major ('IxFun.iota') index function; each
-- such array takes the next two numbers of a counter starting at 0, the
-- first as its size existential and the second as its memory block
-- existential.  Other arrays get no memory annotation.
extReturns :: [ExtType] -> [ExpReturns]
extReturns ts = evalState (mapM addAttr ts) 0
  where addAttr t = case t of
          Prim bt ->
            return $ MemPrim bt
          Mem size space ->
            return $ MemMem (Free size) space
          Array bt shape u
            | existential t -> do
                i <- get
                put (i + 2)
                return $ MemArray bt shape u $ Just $
                  ReturnsNewBlock DefaultSpace (i+1) (Ext i) $
                  IxFun.iota $ map convert $ shapeDims shape
            | otherwise ->
                return $ MemArray bt shape u Nothing

        convert (Free v) = fmap Free (primExpFromSubExp int32 v)
        convert (Ext i)  = LeafExp (Ext i) int32
-- | Element type, shape, memory block and index function of an array
-- variable.  Calls 'fail' if the variable is not annotated as an array.
arrayVarReturns :: (HasScope lore m, Monad m, ExplicitMemorish lore) =>
                   VName
                -> m (PrimType, Shape, VName, IxFun.IxFun (PrimExp VName))
arrayVarReturns v =
  lookupMemInfo v >>= \summary ->
    case summary of
      MemArray et shape _ (ArrayIn mem ixfun) ->
        return (et, Shape $ shapeDims shape, mem, ixfun)
      _ ->
        fail $ "arrayVarReturns: " ++ pretty v ++ " is not an array."
-- | Return information for an expression that simply yields the variable
-- @v@: its memory information with all sizes marked 'Free' and, for
-- arrays, a 'ReturnsInBlock' annotation naming its existing memory block.
varReturns :: (HasScope lore m, Monad m, ExplicitMemorish lore) =>
              VName -> m ExpReturns
varReturns v = do
  summary <- lookupMemInfo v
  return $ case summary of
    MemMem size space ->
      MemMem (Free size) space
    MemArray et shape _ (ArrayIn mem ixfun) ->
      MemArray et (fmap Free shape) NoUniqueness $
      Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun
    MemPrim bt ->
      MemPrim bt
-- | Compute the return information (types with memory annotations) of an
-- expression.  Expressions that merely re-view an existing array report
-- the memory block of that array with a transformed index function; other
-- cases fall back to 'extReturns' or to annotations stored in the
-- expression itself.
expReturns :: (Monad m, HasScope lore m,
ExplicitMemorish lore) =>
Exp lore -> m [ExpReturns]
-- A variable, possibly wrapped in 'Opaque', returns whatever the variable
-- is.
expReturns (BasicOp (SubExp (Var v))) =
pure <$> varReturns v
expReturns (BasicOp (Opaque (Var v))) =
pure <$> varReturns v
-- The next five cases (Repeat, Reshape, Rearrange, Rotate, Index) stay in
-- the memory block of the source array and only adjust the index function.
expReturns (BasicOp (Repeat outer_shapes inner_shape v)) = do
t <- repeatDims outer_shapes inner_shape <$> lookupType v
(et, _, mem, ixfun) <- arrayVarReturns v
let outer_shapes' = map (map (primExpFromSubExp int32) . shapeDims) outer_shapes
inner_shape' = map (primExpFromSubExp int32) $ shapeDims inner_shape
return [MemArray et (Shape $ map Free $ arrayDims t) NoUniqueness $
Just $ ReturnsInBlock mem $ existentialiseIxFun [] $
IxFun.repeat ixfun outer_shapes' inner_shape']
expReturns (BasicOp (Reshape newshape v)) = do
(et, _, mem, ixfun) <- arrayVarReturns v
return [MemArray et (Shape $ map (Free . newDim) newshape) NoUniqueness $
Just $ ReturnsInBlock mem $ existentialiseIxFun [] $
IxFun.reshape ixfun $ map (fmap $ primExpFromSubExp int32) newshape]
expReturns (BasicOp (Rearrange perm v)) = do
(et, Shape dims, mem, ixfun) <- arrayVarReturns v
let ixfun' = IxFun.permute ixfun perm
dims' = rearrangeShape perm dims
return [MemArray et (Shape $ map Free dims') NoUniqueness $
Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun']
expReturns (BasicOp (Rotate offsets v)) = do
(et, Shape dims, mem, ixfun) <- arrayVarReturns v
let offsets' = map (primExpFromSubExp int32) offsets
ixfun' = IxFun.rotate ixfun offsets'
return [MemArray et (Shape $ map Free dims) NoUniqueness $
Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun']
expReturns (BasicOp (Index v slice)) = do
info <- sliceInfo v slice
case info of
MemArray et shape u (ArrayIn mem ixfun) ->
return [MemArray et (fmap Free shape) u $
Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun]
MemPrim pt -> return [MemPrim pt]
MemMem d space -> return [MemMem (Free d) space]
-- An update returns the same as the array being updated.
expReturns (BasicOp (Update v _ _)) =
pure <$> varReturns v
-- Every other basic operation gets the default returns for its type.
expReturns (BasicOp op) =
extReturns . staticShapes <$> primOpType op
-- Loops return one value per value merge parameter, annotated from the
-- memory information of that parameter.
expReturns (DoLoop ctx val _ _) =
zipWithM typeWithAttr
(loopExtType (map (paramIdent . fst) ctx) (map (paramIdent . fst) val)) $ map fst val
where typeWithAttr t p =
case (t, paramAttr p) of
(Array bt shape u, MemArray _ _ _ (ArrayIn mem ixfun))
-- The memory block is itself a merge parameter: report a new block whose
-- existential index is the position of that parameter, and whose size is
-- existential too if it is also a merge parameter.
| Just (i, mem_p) <- isMergeVar mem,
Mem mem_size space <- paramType mem_p ->
let ext_size
| Just (j, _) <- isMergeVar =<< subExpVar mem_size = Ext j
| otherwise = Free mem_size
in return $ MemArray bt shape u $ Just $ ReturnsNewBlock space i ext_size ixfun'
| otherwise ->
return (MemArray bt shape u $
Just $ ReturnsInBlock mem ixfun')
where ixfun' = existentialiseIxFun (map paramName mergevars) ixfun
(Array{}, _) ->
fail "expReturns: Array return type but not array merge variable."
(Prim bt, _) ->
return $ MemPrim bt
(Mem{}, _) ->
fail "expReturns: loop returns memory block explicitly."
-- Position and parameter of the merge parameter named @v@, if any.
isMergeVar v = find ((==v) . paramName . snd) $ zip [0..] mergevars
mergevars = map fst $ ctx ++ val
-- Function applications and branches carry their return annotations.
expReturns (Apply _ _ ret _) =
return $ map funReturnsToExpReturns ret
expReturns (If _ _ _ (IfAttr ret _)) =
return $ map bodyReturnsToExpReturns ret
-- Representation-specific operations are handled by 'opReturns'.
expReturns (Op op) =
opReturns op
-- | Memory information of the result of slicing array @v@.  When the slice
-- leaves no dimensions the result is a primitive of the element type;
-- otherwise it is an array in the same memory block with a sliced index
-- function.
sliceInfo :: (Monad m, HasScope lore m, ExplicitMemorish lore) =>
             VName
          -> Slice SubExp -> m (MemInfo SubExp NoUniqueness MemBind)
sliceInfo v slice = do
  (et, _, mem, ixfun) <- arrayVarReturns v
  let slice' = map (fmap (primExpFromSubExp int32)) slice
      sliced = ArrayIn mem (IxFun.slice ixfun slice')
  return $ case sliceDims slice of
    []   -> MemPrim et
    dims -> MemArray et (Shape dims) NoUniqueness sliced
-- | Representations whose operations can report their return information.
-- The default derives it from the plain type of the operation via
-- 'extReturns'.
class TypedOp (Op lore) => OpReturns lore where
  opReturns :: (Monad m, HasScope lore m) =>
               Op lore -> m [ExpReturns]
  opReturns op = fmap extReturns (opType op)
-- | Allocations return a memory block.  For a 'Kernel', the default
-- returns are computed from its type and then replaced, per kernel result,
-- by the returns of the target array for 'WriteReturn' and
-- 'KernelInPlaceReturn' results.  Any other inner operation uses the
-- default.
instance OpReturns ExplicitMemory where
  opReturns (Alloc size space) =
    return [MemMem (Free size) space]
  opReturns (Inner k@(Kernel _ _ _ body)) = do
    default_rets <- extReturns <$> opType k
    zipWithM correct (kernelBodyResult body) default_rets
    where body_scope = castScope $ scopeOf $ kernelBodyStms body

          correct (WriteReturn _ arr _) _ =
            varReturns arr
          correct (KernelInPlaceReturn arr) _ =
            -- The array is looked up with the kernel body's statements in
            -- scope.
            extendedScope (varReturns arr) body_scope
          correct _ ret =
            return ret
  opReturns k =
    extReturns <$> opType k
-- | Return information for in-kernel operations.  Group-level operations
-- that produce their results in existing arrays report the memory of those
-- arrays; everything else uses the default derived from the type.
instance OpReturns InKernel where
  opReturns (Alloc size space) =
    return [MemMem (Free size) space]

  -- One result per accumulator parameter, annotated like that parameter.
  opReturns (Inner (GroupStream _ _ lam _ _)) =
    mapM (accReturns . paramAttr) (groupStreamAccParams lam)
    where accReturns (MemPrim bt) =
            return $ MemPrim bt
          accReturns (MemArray et shape _ (ArrayIn mem ixfun)) =
            return $ MemArray et (Shape $ map Free $ shapeDims shape) NoUniqueness $
            Just $ ReturnsInBlock mem $ existentialiseIxFun [] ixfun
          accReturns (MemMem size space) =
            return $ MemMem (Free size) space

  -- Results are annotated like the input arrays.
  opReturns (Inner (GroupScan _ _ input)) =
    mapM (varReturns . snd) input

  -- Results are annotated like the destination arrays.
  opReturns (Inner (GroupGenReduce _ dests _ _ _ _)) =
    mapM varReturns dests

  opReturns (Inner (Barrier res)) =
    mapM subExpReturns res
    where subExpReturns (Constant v) = return $ MemPrim $ primValueType v
          subExpReturns (Var v) = varReturns v

  -- The scatter destination arrays come first, followed by default returns
  -- for the remaining types expanded by the combine space.
  opReturns (Inner (Combine (CombineSpace scatter cspace) ts _ _)) = do
    scatter_rets <- mapM varReturns as
    return $ scatter_rets ++ other_rets
    where (_, ns, as) = unzip3 scatter
          shape = Shape $ map snd cspace
          other_rets =
            extReturns $ staticShapes $
            map (`arrayOfShape` shape) $ drop (sum ns*2) ts

  opReturns k =
    extReturns <$> opType k
-- | Apply a function return type to concrete arguments.
--
-- The result is 'Nothing' when the plain return types cannot be applied to
-- the arguments.  Otherwise, parameter names occurring as 'Free' sizes or
-- dimensions are replaced by the corresponding arguments, and a
-- 'ReturnsInBlock' naming a parameter is redirected to the argument when
-- that argument is a variable.  'ReturnsNewBlock' annotations are returned
-- unchanged.
applyFunReturns :: Typed attr =>
                   [FunReturns]
                -> [Param attr]
                -> [(SubExp,Type)]
                -> Maybe [FunReturns]
applyFunReturns rets params args =
  case applyRetType (map declExtTypeOf rets) params args of
    Nothing -> Nothing
    Just _  -> Just $ map correctDims rets
  where parammap :: M.Map VName (SubExp, Type)
        parammap = M.fromList $ zip (map paramName params) args

        substSubExp se@(Var v) = maybe se fst $ M.lookup v parammap
        substSubExp se = se

        correctDim (Free se) = Free $ substSubExp se
        correctDim (Ext i)   = Ext i

        correctDims (MemPrim t) =
          MemPrim t
        correctDims (MemMem size space) =
          MemMem (correctDim size) space
        correctDims (MemArray et shape u memsummary) =
          MemArray et (Shape $ map correctDim $ shapeDims shape) u $
          correctSummary memsummary

        correctSummary (ReturnsInBlock mem ixfun) =
          case M.lookup mem parammap of
            Just (Var v, _) -> ReturnsInBlock v ixfun
            _               -> ReturnsInBlock mem ixfun
        correctSummary (ReturnsNewBlock space i size ixfun) =
          ReturnsNewBlock space i size ixfun
-- | Holds when the index function is linear ('IxFun.isLinear') and
-- additionally agrees with the given shape on all but the outermost
-- dimension.
fullyLinear :: (Eq num, IntegralExp num) =>
               ShapeBase num -> IxFun.IxFun num -> Bool
fullyLinear shape ixfun = linear && ixFunMatchesInnerShape shape ixfun
  where linear = IxFun.isLinear ixfun
-- | Holds when the shape of the index function and the given shape agree
-- once the first dimension of each has been dropped.
ixFunMatchesInnerShape :: (Eq num, IntegralExp num) =>
                          ShapeBase num -> IxFun.IxFun num -> Bool
ixFunMatchesInnerShape shape ixfun = inner_ixfun_dims == inner_shape_dims
  where inner_ixfun_dims = drop 1 $ IxFun.shape ixfun
        inner_shape_dims = drop 1 $ shapeDims shape