{-# LANGUAGE GeneralizedNewtypeDeriving, FlexibleContexts, FlexibleInstances, MultiParamTypeClasses, TypeFamilies, ScopedTypeVariables #-}
{-# LANGUAGE DefaultSignatures #-}
module Futhark.TypeCheck
(
checkProg
, TypeError (..)
, ErrorCase (..)
, TypeM
, bad
, context
, message
, Checkable (..)
, CheckableOp (..)
, lookupVar
, lookupAliases
, Occurences
, UsageMap
, usageMap
, collectOccurences
, subCheck
, require
, requireI
, requirePrimExp
, checkSubExp
, checkExp
, checkStms
, checkStm
, checkType
, checkExtType
, matchExtPattern
, matchExtReturnType
, matchExtBranchType
, argType
, argAliases
, noArgAliases
, checkArg
, checkSOACArrayArgs
, checkLambda
, checkFun'
, checkLambdaParams
, checkBody
, checkLambdaBody
, consume
, consumeOnlyParams
, binding
)
where
import Control.Parallel.Strategies
import Control.Monad.Reader
import Control.Monad.Writer
import Control.Monad.State
import Control.Monad.RWS.Strict
import Data.List
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Futhark.Analysis.PrimExp
import Futhark.Construct (instantiateShapes)
import Futhark.Representation.Aliases
import Futhark.Analysis.Alias
import Futhark.Util
import Futhark.Util.Pretty (Pretty, prettyDoc, indent, ppr, text, (<+>), align)
-- | The individual kinds of failure the type checker can report.  The
-- 'Show' instance for this type defines how each one is rendered.
data ErrorCase lore
  = TypeError String
    -- ^ A free-form error message.
  | UnexpectedType (Exp lore) Type [Type]
    -- ^ An expression, the type it has, and the types it was allowed
    -- to have.
  | ReturnTypeError Name [ExtType] [ExtType]
    -- ^ Function name, declared return type, and type of the body.
  | DupDefinitionError Name
    -- ^ A function name that is defined more than once.
  | DupParamError Name VName
    -- ^ Function name and the parameter that occurs several times.
  | DupPatternError VName
    -- ^ A variable bound twice in a pattern.
  | InvalidPatternError (Pattern (Aliases lore)) [ExtType] (Maybe String)
    -- ^ A pattern, the type of the value it cannot match, and an
    -- optional explanation.
  | UnknownVariableError VName
    -- ^ Use of a variable that is not in scope.
  | UnknownFunctionError Name
    -- ^ Call of a function that is not known.
  | ParameterMismatch (Maybe Name) [Type] [Type]
    -- ^ Callee ('Nothing' is shown as an anonymous function), the
    -- expected parameter types, and the argument types provided.
  | SlicingError Int Int
    -- ^ Number of dimensions of the indexee, then number of indices given.
  | BadAnnotation String Type Type
    -- ^ Description of the annotation, the annotated type, and the
    -- derived type.
  | ReturnAliased Name VName
    -- ^ Function name and the name its unique return value is aliased to.
  | UniqueReturnAliased Name
    -- ^ Function whose unique return component aliases another component.
  | NotAnArray VName Type
    -- ^ A name expected to be an array, and its actual type.
  | PermutationError [Int] Int (Maybe VName)
    -- ^ The permutation, the rank it was applied to, and possibly the
    -- name of the array.
-- | Human-readable rendering of every 'ErrorCase'.
instance Checkable lore => Show (ErrorCase lore) where
  show (TypeError msg) =
    "Type error:\n" ++ msg

  -- An empty list of permitted types gets its own message.
  show (UnexpectedType e _ []) =
    concat [ "Type of expression\n"
           , prettyDoc 160 (indent 2 $ ppr e)
           , "\ncannot have any type - possibly a bug in the type checker."
           ]

  show (UnexpectedType e t ts) =
    concat [ "Type of expression\n"
           , prettyDoc 160 (indent 2 $ ppr e)
           , "\nmust be one of "
           , intercalate ", " (map pretty ts)
           , ", but is "
           , pretty t
           , "."
           ]

  show (ReturnTypeError fname rettype bodytype) =
    concat [ "Declaration of function "
           , nameToString fname
           , " declares return type\n "
           , prettyTuple rettype
           , "\nBut body has type\n "
           , prettyTuple bodytype
           ]

  show (DupDefinitionError name) =
    "Duplicate definition of function " ++ nameToString name

  show (DupParamError funname paramname) =
    concat [ "Parameter "
           , pretty paramname
           , " mentioned multiple times in argument list of function "
           , nameToString funname
           , "."
           ]

  show (DupPatternError name) =
    "Variable " ++ pretty name ++ " bound twice in pattern."

  show (InvalidPatternError pat t desc) =
    "Pattern " ++ pretty pat ++
    " cannot match value of type " ++ prettyTuple t ++
    -- Either terminate the sentence or append the explanation.
    maybe "." (":\n" ++) desc

  show (UnknownVariableError name) =
    "Use of unknown variable " ++ pretty name ++ "."

  show (UnknownFunctionError fname) =
    "Call of unknown function " ++ nameToString fname ++ "."

  show (ParameterMismatch fname expected got) =
    concat [ "In call of ", callee, ":\n"
           , "expecting ", show (length expected), " argument(s) of type(s) "
           , commaSep expected
           , ", but got ", show (length got)
           , " arguments of types ", commaSep got, "."
           ]
    where commaSep ts = intercalate ", " (map pretty ts)
          callee = case fname of
                     Nothing -> "anonymous function"
                     Just f -> "function " ++ nameToString f

  show (SlicingError dims got) =
    show got ++ " indices given, but type of indexee has " ++
    show dims ++ " dimension(s)."

  show (BadAnnotation desc expected got) =
    concat [ "Annotation of \"", desc, "\" type of expression is "
           , pretty expected
           , ", but derived to be "
           , pretty got
           , "."
           ]

  show (ReturnAliased fname name) =
    concat [ "Unique return value of function "
           , nameToString fname
           , " is aliased to "
           , pretty name
           , ", which is not consumed."
           ]

  show (UniqueReturnAliased fname) =
    "A unique tuple element of return value of function " ++
    nameToString fname ++ " is aliased to some other tuple component."

  show (NotAnArray e t) =
    "The expression " ++ pretty e ++
    " is expected to be an array, but is " ++ pretty t ++ "."

  show (PermutationError perm rank name) =
    concat [ "The permutation ("
           , intercalate ", " (map show perm)
           , ") is not valid for array "
           , arrayName
           , "of rank "
           , show rank
           , "."
           ]
    where -- The array name, followed by a space, when one is known.
          arrayName = case name of
                        Nothing -> ""
                        Just v -> pretty v ++ " "
-- | A type error: a list of context strings together with the
-- specific 'ErrorCase'.  The 'Show' instance prints the context
-- strings, one per line, before the error itself.
data TypeError lore = Error [String] (ErrorCase lore)
-- | Shows the context lines (if any), each on its own line, followed
-- by the rendered error case.
instance Checkable lore => Show (TypeError lore) where
  show (Error msgs err)
    | null msgs = show err
    | otherwise = intercalate "\n" msgs ++ "\n" ++ show err
-- | What the function table stores per function: its return types
-- and its parameters.
type FunBinding lore = ([RetType (Aliases lore)], [FParam (Aliases lore)])
-- | What the variable table stores per variable name.
type VarBinding lore = NameInfo (Aliases lore)
-- | The two ways in which a name can be used, as recorded in a
-- 'UsageMap'.
data Usage
  = Consumed
  | Observed
  deriving (Eq, Ord, Show)
-- | A record of names that were observed and names that were consumed.
data Occurence = Occurence
  { observed :: Names
    -- ^ The names that were observed.
  , consumed :: Names
    -- ^ The names that were consumed.
  }
  deriving (Eq, Show)
-- | An occurence in which the given names are observed and nothing
-- is consumed.
observation :: Names -> Occurence
observation names = Occurence { observed = names, consumed = S.empty }
-- | An occurence in which the given names are consumed and nothing
-- is observed.
consumption :: Names -> Occurence
consumption names = Occurence { observed = S.empty, consumed = names }
-- | Is the occurence empty, i.e. does it neither observe nor consume
-- any name?
nullOccurence :: Occurence -> Bool
nullOccurence (Occurence obs cons) = S.null obs && S.null cons
-- | A list of occurences.
type Occurences = [Occurence]
-- | A mapping from a name to the usages recorded for it.
type UsageMap = M.Map VName [Usage]
-- | Build a 'UsageMap' from a list of occurences.  For every
-- occurence, each observed name gets an 'Observed' entry and each
-- consumed name gets a 'Consumed' entry; entries for the same name
-- accumulate in a list.
usageMap :: Occurences -> UsageMap
usageMap = foldl' comb M.empty
  where comb m (Occurence obs cons) =
          let m' = S.foldl' (ins Observed) m obs
          in S.foldl' (ins Consumed) m' cons
        -- Prepend usage @v@ to whatever is already recorded for @k@.
        ins v m k = M.insertWith (++) k [v] m
-- | The union of the consumed names of all the given occurences.
allConsumed :: Occurences -> Names
allConsumed occs = S.unions [ consumed occ | occ <- occs ]
-- | Combine the occurences of two computations, the first list
-- followed by the second.  Names consumed anywhere in the second list
-- are removed from the observations of the first; occurences that
-- become empty are dropped.
seqOccurences :: Occurences -> Occurences -> Occurences
seqOccurences occurs1 occurs2 =
  [ occ'
  | occ <- occurs1
  , let occ' = occ { observed = observed occ `S.difference` consumedLater }
  , not $ nullOccurence occ'
  ] ++ occurs2
  where consumedLater = allConsumed occurs2
-- | Combine the occurences of two alternatives.  Names consumed
-- anywhere in the second list are removed from both the observations
-- and the consumptions of the first; occurences that become empty are
-- dropped.
altOccurences :: Occurences -> Occurences -> Occurences
altOccurences occurs1 occurs2 =
  [ occ'
  | occ <- occurs1
  , let occ' = occ { consumed = consumed occ `S.difference` consumedInOther
                   , observed = observed occ `S.difference` consumedInOther
                   }
  , not $ nullOccurence occ'
  ] ++ occurs2
  where consumedInOther = allConsumed occurs2
-- | Remove the given names from every occurence, both where they
-- are observed and where they are consumed, and drop the occurences
-- that end up empty.
unOccur :: Names -> Occurences -> Occurences
unOccur to_be_removed occs =
  [ occ'
  | occ <- occs
  , let occ' = occ { observed = observed occ `S.difference` to_be_removed
                   , consumed = consumed occ `S.difference` to_be_removed
                   }
  , not $ nullOccurence occ'
  ]
-- | The writer output of 'TypeM': either an error message, or the
-- occurences gathered so far.
data Consumption
  = ConsumptionError String
  | Consumption Occurences
  deriving (Show)
-- | Sequential combination.  An error on either side is propagated
-- (the left one wins if both are errors).  Otherwise, it is an error
-- if some name consumed on the left is consumed or observed on the
-- right; if not, the occurences are combined with 'seqOccurences'.
instance Semigroup Consumption where
  ConsumptionError e <> _ = ConsumptionError e
  _ <> ConsumptionError e = ConsumptionError e
  Consumption o1 <> Consumption o2 =
    case S.toList (consumed_before `S.intersection` used_after) of
      v:_ ->
        ConsumptionError $
        "Variable " <> pretty v <> " referenced after being consumed."
      [] ->
        Consumption $ o1 `seqOccurences` o2
    where consumed_before = S.unions $ map consumed o1
          used_after = S.unions $ map consumed o2 ++ map observed o2
-- | The neutral element is a 'Consumption' without any occurences.
instance Monoid Consumption where
  mempty = Consumption []
-- | The environment that the type checker reads from.
data Env lore = Env
  { envVtable :: M.Map VName (VarBinding lore)
    -- ^ Information about the variables in scope.
  , envFtable :: M.Map Name (FunBinding lore)
    -- ^ Information about the known functions.
  , envContext :: [String]
    -- ^ Context strings, most recently added first (see 'context');
    -- they are attached to errors raised by 'bad'.
  }
-- | The type checker monad.  It is an 'RWST' that reads an 'Env',
-- writes 'Consumption', keeps a set of 'Names' as state (the names
-- registered by 'bound'), and fails with a 'TypeError' through
-- 'Either'.
newtype TypeM lore a =
  TypeM (RWST (Env lore) Consumption Names (Either (TypeError lore)) a)
  deriving ( Monad
           , Functor
           , Applicative
           , MonadReader (Env lore)
           , MonadWriter Consumption
           , MonadState Names
           )
-- | The scope of the type checker monad is taken from the variable
-- table of the environment.
instance Checkable lore =>
         HasScope (Aliases lore) (TypeM lore) where
  lookupType name = typeOf <$> lookupVar name

  askScope = asks $ \env ->
    M.fromList $ mapMaybe keep $ M.toList $ envVtable env
    where -- Every entry of the variable table is part of the scope.
          keep (name, attr) = Just (name, attr)
-- | Run a type checker action in the given environment, starting
-- from an empty state.  Produces the result together with the
-- 'Consumption' written, or the 'TypeError' raised.
runTypeM :: Env lore -> TypeM lore a
         -> Either (TypeError lore) (a, Consumption)
runTypeM env (TypeM action) =
  evalRWST action env mempty
-- | Abort type checking with the given error.  The context strings
-- of the environment are attached to the error, oldest first.
bad :: ErrorCase lore -> TypeM lore a
bad e =
  asks envContext >>= \ctx ->
    TypeM $ lift $ Left $ Error (reverse ctx) e
-- | Run an action with one more context string in the environment.
-- The string is included in errors raised with 'bad' inside the
-- action.
context :: String
        -> TypeM lore a
        -> TypeM lore a
context s = local addContext
  where addContext env = env { envContext = s : envContext env }
-- | Render a string followed by a pretty-printed, aligned value, at
-- a width of 80 columns.
message :: Pretty a =>
           String -> a -> String
message s x = prettyDoc 80 doc
  where doc = text s <+> align (ppr x)
-- | Register a name as bound by adding it to the state.  It is an
-- error if the name is already in the state.
bound :: VName -> TypeM lore ()
bound name = do
  seen_before <- gets (S.member name)
  when seen_before $
    bad $ TypeError $ "Name " ++ pretty name ++ " bound twice"
  modify (S.insert name)
-- | Write the given occurences to the output, leaving out the ones
-- that are empty.
occur :: Occurences -> TypeM lore ()
occur occs = tell $ Consumption [ occ | occ <- occs, not (nullOccurence occ) ]
-- | Record that a variable, together with its aliases, has been
-- observed.  Nothing is recorded for variables of primitive type.
-- Fails if the variable is unknown.
observe :: Checkable lore =>
           VName -> TypeM lore ()
observe name = do
  info <- lookupVar name
  let seen = S.insert name (aliases info)
  unless (primType (typeOf info)) $
    occur [observation seen]
-- | Record that the given names have been consumed.  Only names that
-- are in scope with an array type (rank greater than zero) are
-- recorded; the rest are ignored.
consume :: Checkable lore => Names -> TypeM lore ()
consume als = do
  scope <- askScope
  let isArray v = case M.lookup v scope of
                    Just info -> arrayRank (typeOf info) > 0
                    Nothing -> False
  occur [consumption (S.filter isArray als)]
-- | Run an action and return its occurences alongside its result.
-- The output written by the action is replaced with the empty
-- output.  Fails if the action's consumption is an error.
collectOccurences :: TypeM lore a -> TypeM lore (a, Occurences)
collectOccurences m = pass $ do
  (result, written) <- listen m
  occs <- checkConsumption written
  -- The second component tells 'pass' to discard what was written.
  return ((result, occs), const mempty)
-- | Extract the occurences of a 'Consumption', turning a
-- 'ConsumptionError' into a type error.
checkConsumption :: Consumption -> TypeM lore Occurences
checkConsumption c = case c of
  ConsumptionError e -> bad $ TypeError e
  Consumption os -> return os
-- | Run two actions as alternatives of each other.  Their occurences
-- are combined with 'altOccurences' rather than sequentially, and
-- that combination replaces what the two actions wrote.
alternative :: TypeM lore a -> TypeM lore b -> TypeM lore (a,b)
alternative m1 m2 = pass $ do
  (res1, written1) <- listen m1
  (res2, written2) <- listen m2
  occs1 <- checkConsumption written1
  occs2 <- checkConsumption written2
  let combined = Consumption (occs1 `altOccurences` occs2)
  return ((res1, res2), const combined)
-- | Run an action in which only the names listed in the association
-- list may be consumed.  Consuming any other name is a type error.
-- In the occurences passed on, each consumed name is replaced by the
-- names associated with it in the list.
consumeOnlyParams :: [(VName, Names)] -> TypeM lore a -> TypeM lore a
consumeOnlyParams consumable m = do
  (x, os) <- collectOccurences m
  os' <- mapM inspect os
  tell $ Consumption os'
  return x
  where inspect o = do
          replacements <- mapM wasConsumed $ S.toList $ consumed o
          return o { consumed = mconcat replacements }

        wasConsumed v =
          case lookup v consumable of
            Just als -> return als
            Nothing ->
              bad $ TypeError $
              unlines [pretty v ++ " was invalidly consumed.",
                       what ++ " can be consumed here."]

        what
          | null consumable = "Nothing"
          | otherwise =
              "Only " ++ intercalate ", " (map (pretty . fst) consumable)
-- | Extend a set of names with the aliases that the environment
-- records for each of them.  Only let-bound names contribute
-- aliases; the expansion is a single step.
expandAliases :: Names -> Env lore -> Names
expandAliases names env =
  names `S.union` mconcat [ aliasesOf k | k <- S.toList names ]
  where aliasesOf k =
          case M.lookup k (envVtable env) of
            Just (LetInfo (als, _)) -> unNames als
            _ -> mempty
-- | Run an action with the names of the given scope added to the
-- variable table.
--
-- Every name in the scope is first registered with 'bound', so
-- binding a name twice is an error.  For let-bound names the alias
-- set is expanded with 'expandAliases' (or emptied, for primitive
-- types), and the new name is added to the alias sets of the
-- let-bound names it aliases.  Occurences of the bound names are
-- removed from what the action wrote, as the names are not visible
-- outside.
binding :: Checkable lore =>
           Scope (Aliases lore)
        -> TypeM lore a
        -> TypeM lore a
binding bnds = check . local (`bindVars` bnds)
  where bindVars = M.foldlWithKey' bindVar
        boundnames = M.keys bnds
        boundnameset = S.fromList boundnames

        bindVar env name (LetInfo (Names' als, attr)) =
          let als' | primType (typeOf attr) = mempty
                   | otherwise = expandAliases als env
              inedges = S.toList als'
              -- Make an aliased let-bound name alias the new name as well.
              update (LetInfo (Names' thesenames, thisattr)) =
                LetInfo (Names' $ S.insert name thesenames, thisattr)
              update b = b
          in env { envVtable =
                     M.insert name (LetInfo (Names' als', attr)) $
                     adjustSeveral update inedges $
                     envVtable env
                 }
        bindVar env name attr =
          env { envVtable = M.insert name attr $ envVtable env }

        -- Apply @f@ at each of the given keys; the fold is strict in
        -- the map being built.
        adjustSeveral f = flip $ foldl' $ flip $ M.adjust f

        check m = do
          mapM_ bound boundnames
          (a, os) <- collectOccurences m
          tell $ Consumption $ unOccur boundnameset os
          return a
-- | Look up a variable in the variable table.  Fails with
-- 'UnknownVariableError' if it is not there.
lookupVar :: VName -> TypeM lore (NameInfo (Aliases lore))
lookupVar name =
  asks (M.lookup name . envVtable) >>=
  maybe (bad $ UnknownVariableError name) return
-- | The aliases of a variable, including the variable itself.
-- Variables of primitive type have no aliases at all.
lookupAliases :: Checkable lore => VName -> TypeM lore Names
lookupAliases name = withSelf <$> lookupVar name
  where withSelf info
          | primType (typeOf info) = mempty
          | otherwise = S.insert name (aliases info)
-- | The aliases stored for a name.  Only let-bound names carry
-- aliases; everything else has none.
aliases :: NameInfo (Aliases lore) -> Names
aliases info = case info of
  LetInfo (als, _) -> unNames als
  _ -> mempty
-- | The aliases of a subexpression: none for a constant, and those
-- given by 'lookupAliases' for a variable.
subExpAliasesM :: Checkable lore => SubExp -> TypeM lore Names
subExpAliasesM se = case se of
  Constant{} -> return mempty
  Var v -> lookupAliases v
-- | Look up a function and apply its return type to the given
-- arguments.  Returns the resulting return type and the declared
-- types of the parameters.  Fails if the function is unknown, or if
-- the return type cannot be applied to the arguments.
lookupFun :: Checkable lore =>
             Name
          -> [SubExp]
          -> TypeM lore ([RetType lore], [DeclType])
lookupFun fname args = do
  entry <- asks (M.lookup fname . envFtable)
  case entry of
    Nothing ->
      bad $ UnknownFunctionError fname
    Just (ftype, params) -> do
      argts <- mapM subExpType args
      let typed_args = zip args argts
      case applyRetType ftype params typed_args of
        Just rt ->
          return (rt, map paramDeclType params)
        Nothing ->
          bad $ ParameterMismatch (Just fname) (map paramType params) argts
-- | Fail with 'BadAnnotation' unless the two types are equal.  The
-- string describes what kind of annotation is being checked.
checkAnnotation :: String -> Type -> Type
                -> TypeM lore ()
checkAnnotation desc t1 t2 =
  unless (t2 == t1) $
    bad $ BadAnnotation desc t1 t2
-- | Check a subexpression and fail with 'UnexpectedType' unless its
-- type is one of the given types.
require :: Checkable lore => [Type] -> SubExp -> TypeM lore ()
require ts se = do
  t <- checkSubExp se
  when (t `notElem` ts) $
    bad $ UnexpectedType (BasicOp $ SubExp se) t ts
-- | Like 'require', but for a variable name.
requireI :: Checkable lore => [Type] -> VName -> TypeM lore ()
requireI ts = require ts . Var
-- | Look up the type of a variable and return it, failing with
-- 'NotAnArray' if it is not an array type.
checkArrIdent :: Checkable lore =>
                 VName -> TypeM lore Type
checkArrIdent v =
  lookupType v >>= \t ->
    case t of
      Array{} -> return t
      _ -> bad $ NotAnArray v t
-- | Type check a program.  Alias analysis is run on the program
-- first.  A function table is built from the built-in functions and
-- the functions of the program (a duplicate definition is an error),
-- after which every function is checked against that table, using
-- 'parMap' with the 'rpar' strategy.
checkProg :: Checkable lore =>
             Prog lore -> Either (TypeError lore) ()
checkProg prog = do
  let emptyEnv = Env { envVtable = M.empty
                     , envFtable = mempty
                     , envContext = []
                     }

  let checkOne ftable fun =
        fst <$> runTypeM emptyEnv
        (local (\env -> env { envFtable = ftable }) $ checkFun fun)

  (ftable, _) <- runTypeM emptyEnv buildFtable
  sequence_ $ parMap rpar (checkOne ftable) $ progFunctions prog'
  where
    prog' = aliasAnalysis prog

    buildFtable = do
      builtins <- initialFtable prog'
      foldM addFun builtins $ progFunctions prog'

    addFun ftable (FunDef _ name ret params _)
      | name `M.member` ftable =
          bad $ DupDefinitionError name
      | otherwise =
          return $ M.insert name (ret,params) ftable
-- | The function table holding just the built-in functions.  The
-- program argument is not used.  All parameters of the built-ins are
-- given the same placeholder name.
initialFtable :: Checkable lore =>
                 Prog (Aliases lore) -> TypeM lore (M.Map Name (FunBinding lore))
initialFtable _ =
  M.fromList <$> mapM addBuiltin (M.toList builtInFunctions)
  where addBuiltin (fname, (t, ts)) = do
          ps <- mapM (primFParam name) ts
          return (fname, ([primRetType t], ps))
        name = VName (nameFromString "x") 0
-- | Type check a function definition: its parameters, its return
-- type, and its body against that return type.  Only parameters with
-- a unique declared type may be consumed by the body.
checkFun :: Checkable lore =>
            FunDef (Aliases lore) -> TypeM lore ()
checkFun (FunDef _ fname rettype params body) =
  context ("In function " ++ nameToString fname) $
    checkFun' funinfo consumable $ do
      checkFunParams params
      checkRetType rettype
      checkFunBody rettype body
  where funinfo = (fname,
                   retTypeValues rettype,
                   funParamsToNameInfos params,
                   body)
        consumable = [ (paramName param, mempty)
                     | param <- params
                     , unique $ paramDeclType param
                     ]
-- | Pair the name of every function parameter with an 'FParamInfo'
-- holding the attribute of the parameter.
funParamsToNameInfos :: [FParam lore]
                     -> [(VName, NameInfo (Aliases lore))]
funParamsToNameInfos fparams =
  [ (paramName fparam, FParamInfo $ paramAttr fparam)
  | fparam <- fparams ]
-- | Check the attribute of every function parameter with
-- 'checkFParamLore'.
checkFunParams :: Checkable lore =>
                  [FParam lore] -> TypeM lore ()
checkFunParams params =
  forM_ params $ \param ->
    context ("In function parameter " ++ pretty param) $
      checkFParamLore (paramName param) (paramAttr param)
-- | Check the attribute of every lambda parameter with
-- 'checkLParamLore'.
checkLambdaParams :: Checkable lore =>
                     [LParam lore] -> TypeM lore ()
checkLambdaParams params =
  forM_ params $ \param ->
    context ("In lambda parameter " ++ pretty param) $
      checkLParamLore (paramName param) (paramAttr param)
-- | The part of function checking shared between functions, loop
-- bodies and lambdas.
--
-- The first argument holds the function name, its return type, its
-- parameters, and its body.  The second lists what may be consumed
-- (see 'consumeOnlyParams').  The third is the action that performs
-- the actual checks; it runs with the parameters bound.
--
-- In addition to running the action, this function rejects duplicate
-- parameter names and checks the aliasing of the body result against
-- the uniqueness of the return type.
checkFun' :: Checkable lore =>
             (Name,
              [DeclExtType],
              [(VName, NameInfo (Aliases lore))],
              BodyT (Aliases lore))
          -> [(VName, Names)]
          -> TypeM lore ()
          -> TypeM lore ()
checkFun' (fname, rettype, params, body) consumable check = do
  checkNoDuplicateParams
  binding (M.fromList params) $
    consumeOnlyParams consumable $ do
      check
      scope <- askScope
      let isArray = maybe False ((>0) . arrayRank . typeOf) . (`M.lookup` scope)
      -- Only aliases that are arrays matter for the aliasing check.
      checkReturnAlias $ map (S.filter isArray) $ bodyAliases body
  where param_names = map fst params

        checkNoDuplicateParams = foldM_ expand [] param_names

        expand seen pname
          | pname `elem` seen =
              bad $ DupParamError fname pname
          | otherwise =
              return $ pname : seen

        checkReturnAlias =
          foldM_ checkReturnAlias' S.empty . returnAliasing rettype

        -- The accumulator holds the names seen so far, each tagged
        -- with the uniqueness of the return value it belonged to.
        checkReturnAlias' seen (Unique, names)
          | any (`S.member` S.map snd seen) $ S.toList names =
              bad $ UniqueReturnAliased fname
          | otherwise = do
              consume names
              return $ seen `S.union` tag Unique names
        checkReturnAlias' seen (Nonunique, names)
          | any (`S.member` seen) $ S.toList $ tag Unique names =
              bad $ UniqueReturnAliased fname
          | otherwise =
              return $ seen `S.union` tag Nonunique names

        tag u = S.map $ \name -> (u, name)

        -- Pair the alias sets with uniqueness attributes, matching
        -- them up from the end; alias sets without a corresponding
        -- return type are treated as nonunique.
        returnAliasing expected got =
          reverse $
          zip (reverse (map uniqueness expected) ++ repeat Nonunique) $
          reverse got
-- | Run a type checker action for a different lore whose return
-- types and attributes coincide with those of the current lore.  The
-- action is run with 'runTypeM' in a converted copy of the current
-- environment.  An error in the action is re-raised as a 'TypeError'
-- case holding the rendered error; on success the consumption of the
-- action is written to the output.
subCheck :: forall lore newlore a.
            (Checkable newlore,
             RetType lore ~ RetType newlore,
             LetAttr lore ~ LetAttr newlore,
             FParamAttr lore ~ FParamAttr newlore,
             LParamAttr lore ~ LParamAttr newlore) =>
            TypeM newlore a ->
            TypeM lore a
subCheck m = do
  typeenv <- asks newEnv
  either failed succeeded $ runTypeM typeenv m
  where failed err = bad $ TypeError $ show err
        succeeded (x, cons) = tell cons >> return x

        newEnv :: Env lore -> Env newlore
        newEnv (Env vtable ftable ctx) =
          Env (M.map coerceVar vtable) ftable ctx

        -- Rebuild the entry so that it is indexed by the new lore.
        coerceVar (LetInfo x) = LetInfo x
        coerceVar (FParamInfo x) = FParamInfo x
        coerceVar (LParamInfo x) = LParamInfo x
        coerceVar (IndexInfo it) = IndexInfo it
-- | The type of a subexpression.  A constant has the primitive type
-- of its value.  A variable is recorded as observed and its type is
-- looked up.
checkSubExp :: Checkable lore => SubExp -> TypeM lore Type
checkSubExp se = case se of
  Constant val ->
    return $ Prim $ primValueType val
  Var ident ->
    context ("In subexp " ++ pretty ident) $
      observe ident >> lookupType ident
-- | Check a sequence of statements, then run the given action in the
-- scope of all of them.  For each statement, its expression is
-- checked first, and 'checkStm' then runs the remainder with the
-- statement bound.
checkStms :: Checkable lore =>
             Stms (Aliases lore) -> TypeM lore a
          -> TypeM lore a
checkStms origbnds m = go $ stmsToList origbnds
  where go stms =
          case stms of
            [] ->
              m
            stm@(Let pat _ e) : rest -> do
              context ("In expression of statement " ++ pretty pat) $
                checkExp e
              checkStm stm $ go rest
-- | Check every subexpression of a result.
checkResult :: Checkable lore =>
               Result -> TypeM lore ()
checkResult res = mapM_ checkSubExp res
-- | Check the body of a function: its statements, its result, that
-- the result matches the declared return type, and finally the body
-- attribute.
checkFunBody :: Checkable lore =>
                [RetType lore]
             -> Body (Aliases lore)
             -> TypeM lore ()
checkFunBody rt (Body (_,lore) bnds res) = do
  checkStms bnds checkEnd
  checkBodyLore lore
  where -- Runs in the scope of the statements of the body.
        checkEnd = do
          context "When checking body result" $ checkResult res
          context "When matching declared return type to result of body" $
            matchReturnType rt res
-- | Check the body of a lambda: its statements, its result against
-- the given return types, and finally the body attribute.
checkLambdaBody :: Checkable lore =>
                   [Type] -> Body (Aliases lore) -> TypeM lore ()
checkLambdaBody ret (Body (_,lore) bnds res) =
  checkStms bnds (checkLambdaResult ret res) >> checkBodyLore lore
-- | Check the result of a lambda against its return types.  The
-- number of values must be the same, and the type of every result
-- subexpression must be equal to the corresponding return type.
checkLambdaResult :: Checkable lore =>
                     [Type] -> Result -> TypeM lore ()
checkLambdaResult ts es
  | length ts /= length es =
      bad $ TypeError $
      "Lambda has return type " ++ prettyTuple ts ++
      " describing " ++ show (length ts) ++ " values, but body returns " ++
      show (length es) ++ " values: " ++ prettyTuple es
  | otherwise =
      zipWithM_ checkOne ts es
  where checkOne t e = do
          et <- checkSubExp e
          when (et /= t) $
            bad $ TypeError $
            "Subexpression " ++ pretty e ++ " has type " ++ pretty et ++
            " but expected " ++ pretty t
-- | Check a body: its statements, its result, and finally the body
-- attribute.
checkBody :: Checkable lore =>
             Body (Aliases lore) -> TypeM lore ()
checkBody (Body (_,lore) bnds res) =
  checkStms bnds (checkResult res) >> checkBodyLore lore
-- | Type check a basic operation.  There is one equation per kind
-- of operation.
checkBasicOp :: Checkable lore =>
                BasicOp (Aliases lore) -> TypeM lore ()

checkBasicOp (SubExp es) =
  void $ checkSubExp es

checkBasicOp (Opaque es) =
  void $ checkSubExp es

checkBasicOp (ArrayLit [] _) =
  return ()

-- The first element must have the annotated type, and all other
-- elements must have the type of the first.
checkBasicOp (ArrayLit (e:es') t) = do
  et <- checkSubExp e
  checkAnnotation "array-element" t et
  forM_ es' $ \other -> do
    othert <- checkSubExp other
    unless (othert == et) $
      bad $ TypeError $ pretty othert ++
      " is not of expected type " ++ pretty et ++ "."

checkBasicOp (UnOp op e) =
  require [Prim $ unOpType op] e

checkBasicOp (BinOp op e1 e2) =
  checkBinOpArgs (binOpType op) e1 e2

checkBasicOp (CmpOp op e1 e2) =
  checkCmpOp op e1 e2

checkBasicOp (ConvOp op e) =
  require [Prim $ fst $ convOpType op] e

-- There must be exactly one index per dimension.
checkBasicOp (Index ident idxes) = do
  vt <- lookupType ident
  observe ident
  let rank = arrayRank vt
      nidxes = length idxes
  when (rank /= nidxes) $
    bad $ SlicingError rank nidxes
  mapM_ checkDimIndex idxes

-- The written value must not alias the array written to, and the
-- array (with its aliases) is consumed.
checkBasicOp (Update src idxes se) = do
  src_t <- checkArrIdent src
  let rank = arrayRank src_t
      nidxes = length idxes
  when (rank /= nidxes) $
    bad $ SlicingError rank nidxes

  se_aliases <- subExpAliasesM se
  when (src `S.member` se_aliases) $
    bad $ TypeError "The target of an Update must not alias the value to be written."

  mapM_ checkDimIndex idxes
  require [Prim (elemType src_t) `arrayOfShape` Shape (sliceDims idxes)] se
  consume =<< lookupAliases src

checkBasicOp (Iota e x s et) = do
  require [Prim int32] e
  require [Prim $ IntType et] x
  require [Prim $ IntType et] s

checkBasicOp (Replicate (Shape dims) valexp) = do
  mapM_ (require [Prim int32]) dims
  void $ checkSubExp valexp

-- One shape is required per dimension of the array.
checkBasicOp (Repeat shapes innershape v) = do
  v_t <- lookupType v
  forM_ (innershape : shapes) $ \shape ->
    mapM_ (require [Prim int32]) $ shapeDims shape
  unless (length shapes == arrayRank v_t) $
    bad $ TypeError "Incorrect number of shapes in repeat."

checkBasicOp (Scratch _ shape) =
  mapM_ checkSubExp shape

-- A dimension can only be coerced if the array has that dimension.
checkBasicOp (Reshape newshape arrexp) = do
  rank <- arrayRank <$> checkArrIdent arrexp
  mapM_ (require [Prim int32] . newDim) newshape
  zipWithM_ (checkDimChange rank) newshape [0..]
  where checkDimChange _ (DimNew _) _ =
          return ()
        checkDimChange rank (DimCoercion se) i
          | i >= rank =
              bad $ TypeError $
              "Asked to coerce dimension " ++ show i ++ " to " ++ pretty se ++
              ", but array " ++ pretty arrexp ++ " has only " ++ pretty rank ++ " dimensions"
          | otherwise =
              return ()

-- The permutation must contain each of the dimensions exactly once.
checkBasicOp (Rearrange perm arr) = do
  rank <- arrayRank <$> lookupType arr
  when (length perm /= rank || sort perm /= [0..rank-1]) $
    bad $ PermutationError perm rank $ Just arr

-- One rotation amount is required per dimension.
checkBasicOp (Rotate rots arr) = do
  rank <- arrayRank <$> lookupType arr
  mapM_ (require [Prim int32]) rots
  when (length rots /= rank) $
    bad $ TypeError $ "Cannot rotate " ++ show (length rots) ++
    " dimensions of " ++ show rank ++ "-dimensional array."

-- Disregarding dimension @i@, all arrays must have the same dimensions.
checkBasicOp (Concat i arr1exp arr2exps ressize) = do
  arr1t <- checkArrIdent arr1exp
  arr2ts <- mapM checkArrIdent arr2exps
  let otherDims = dropAt i 1 . arrayDims
      success = all ((== otherDims arr1t) . otherDims) arr2ts
  unless success $
    bad $ TypeError $
    "Types of arguments to concat do not match.  Got " ++
    pretty arr1t ++ " and " ++ intercalate ", " (map pretty arr2ts)
  require [Prim int32] ressize

checkBasicOp (Copy e) =
  void $ checkArrIdent e

-- Checked exactly as the corresponding 'Rearrange'.
checkBasicOp (Manifest perm arr) =
  checkBasicOp $ Rearrange perm arr

checkBasicOp (Assert e _ _) =
  require [Prim Bool] e
-- | Type check an expression.
--
-- * Basic operations are handled by 'checkBasicOp'.
--
-- * For an @If@, the condition must be boolean, the two branches are
--   checked as alternatives of each other, and each branch is matched
--   against the annotated return type.
--
-- * For an @Apply@, the return type derived from the function table
--   must be equal to the annotated one, and the arguments are checked
--   against the parameter types.
--
-- * For a @DoLoop@, the initial values of the merge parameters are
--   checked like the arguments of a function call, and the loop body
--   is checked like a function whose parameters and return type are
--   given by the merge parameters.
--
-- * Other operations are handled by 'checkOp'.
checkExp :: Checkable lore =>
            Exp (Aliases lore) -> TypeM lore ()

checkExp (BasicOp op) = checkBasicOp op

checkExp (If e1 e2 e3 info) = do
  require [Prim Bool] e1
  _ <- checkBody e2 `alternative` checkBody e3
  context "in true branch" $ matchBranchType (ifReturns info) e2
  context "in false branch" $ matchBranchType (ifReturns info) e3

checkExp (Apply fname args rettype_annot _) = do
  (rettype_derived, paramtypes) <- lookupFun fname $ map fst args
  argflows <- mapM (checkArg . fst) args
  when (rettype_derived /= rettype_annot) $
    bad $ TypeError $ "Expected apply result type " ++ pretty rettype_derived
    ++ " but annotation is " ++ pretty rettype_annot
  checkFuncall (Just fname) paramtypes argflows

checkExp (DoLoop ctxmerge valmerge form loopbody) = do
  let merge = ctxmerge ++ valmerge
      (mergepat, mergeexps) = unzip merge
  mergeargs <- mapM checkArg mergeexps

  binding (scopeOf form) $ do
    case form of
      ForLoop loopvar it boundexp loopvars -> do
        iparam <- primFParam loopvar $ IntType it
        let funparams = iparam : mergepat
            paramts = map paramDeclType funparams

        -- Every array looped over must have a row type that fits the
        -- parameter it is bound to.
        forM_ loopvars $ \(p,a) -> do
          a_t <- lookupType a
          observe a
          case peelArray 1 a_t of
            Just a_t_r -> do
              checkLParamLore (paramName p) $ paramAttr p
              unless (a_t_r `subtypeOf` typeOf (paramAttr p)) $
                bad $ TypeError $ "Loop parameter " ++ pretty p ++
                " not valid for element of " ++ pretty a ++ ", which has row type " ++ pretty a_t_r
            _ -> bad $ TypeError $ "Cannot loop over " ++ pretty a ++
                 " of type " ++ pretty a_t

        boundarg <- checkArg boundexp
        checkFuncall Nothing paramts $ boundarg : mergeargs

      WhileLoop cond -> do
        -- The condition must be one of the merge parameters, and boolean.
        case find ((==cond) . paramName . fst) merge of
          Just (condparam,_) ->
            unless (paramType condparam == Prim Bool) $
            bad $ TypeError $
            "Conditional '" ++ pretty cond ++ "' of while-loop is not boolean, but " ++
            pretty (paramType condparam) ++ "."
          Nothing ->
            bad $ TypeError $
            "Conditional '" ++ pretty cond ++ "' of while-loop is not a merge variable."
        let funparams = mergepat
            paramts = map paramDeclType funparams
        checkFuncall Nothing paramts mergeargs

    let rettype = map paramDeclType mergepat
        -- Only merge parameters of unique type may be consumed.
        consumable = [ (paramName param, mempty)
                     | param <- mergepat,
                       unique $ paramDeclType param
                     ]

    context "Inside the loop body" $
      checkFun' (nameFromString "<loop body>",
                 staticShapes rettype,
                 funParamsToNameInfos mergepat,
                 loopbody) consumable $ do
          checkFunParams mergepat
          checkBody loopbody

          let rettype_ext = existentialiseExtTypes (map paramName mergepat) $
                            staticShapes $ map fromDecl rettype

          bodyt <- extendedScope (traverse subExpType $ bodyResult loopbody) $
                   scopeOf $ bodyStms loopbody

          case instantiateShapes (`maybeNth` bodyResult loopbody) rettype_ext of
            Nothing -> bad $ ReturnTypeError (nameFromString "<loop body>")
                       (staticShapes $ map fromDecl rettype) (staticShapes bodyt)
            Just rettype' ->
              unless (bodyt `subtypesOf` rettype') $
              bad $ ReturnTypeError (nameFromString "<loop body>")
              (staticShapes rettype') (staticShapes bodyt)

checkExp (Op op) = checkOp op
-- | Check the array arguments of a SOAC.  Every argument must be an
-- array whose outer size is equal to the given width.  The resulting
-- arguments have the row types of the arrays.
checkSOACArrayArgs :: Checkable lore =>
                      SubExp -> [VName] -> TypeM lore [Arg]
checkSOACArrayArgs width vs =
  forM vs $ \v -> do
    (vt, rowarg) <- checkSOACArrayArg v
    let outer = arraySize 0 vt
    when (outer /= width) $
      bad $ TypeError $
      "SOAC argument " ++ pretty v ++ " has outer size " ++
      pretty outer ++ ", but width of SOAC is " ++
      pretty width
    return rowarg
  where checkSOACArrayArg ident = do
          (t, als) <- checkArg $ Var ident
          case peelArray 1 t of
            Just rt ->
              return (t, (rt, als))
            Nothing ->
              bad $ TypeError $
              "SOAC argument " ++ pretty ident ++ " is not an array"
-- | Check a type by checking each of its dimensions as a
-- subexpression.
checkType :: Checkable lore =>
             TypeBase Shape u -> TypeM lore ()
checkType t = mapM_ checkSubExp (arrayDims t)
-- | Check a type with a possibly existential shape.  Only the free
-- dimensions are checked.
checkExtType :: Checkable lore =>
                TypeBase ExtShape u
             -> TypeM lore ()
checkExtType t = mapM_ checkExtDim $ shapeDims $ arrayShape t
  where checkExtDim (Free se) = void $ checkSubExp se
        checkExtDim (Ext _) = return ()
-- | Check that both operands of a comparison have the primitive type
-- that the comparison operator works on.
checkCmpOp :: Checkable lore =>
              CmpOp -> SubExp -> SubExp
           -> TypeM lore ()
checkCmpOp op x y = case op of
  CmpEq t ->
    require [Prim t] x >> require [Prim t] y
  CmpUlt t -> checkBinOpArgs (IntType t) x y
  CmpUle t -> checkBinOpArgs (IntType t) x y
  CmpSlt t -> checkBinOpArgs (IntType t) x y
  CmpSle t -> checkBinOpArgs (IntType t) x y
  FCmpLt t -> checkBinOpArgs (FloatType t) x y
  FCmpLe t -> checkBinOpArgs (FloatType t) x y
  CmpLlt -> checkBinOpArgs Bool x y
  CmpLle -> checkBinOpArgs Bool x y
-- | Require both operands, first the left and then the right, to
-- have the given primitive type.
checkBinOpArgs :: Checkable lore =>
                  PrimType -> SubExp -> SubExp -> TypeM lore ()
checkBinOpArgs t e1 e2 =
  mapM_ (require [Prim t]) [e1, e2]
-- | Check a pattern element by passing its name and attribute to
-- 'checkLetBoundLore'.
checkPatElem :: Checkable lore =>
                PatElemT (LetAttr lore) -> TypeM lore ()
checkPatElem pe = case pe of
  PatElem name attr -> checkLetBoundLore name attr
-- | Check that every component of a dimension index is a 32-bit
-- integer.
checkDimIndex :: Checkable lore =>
                 DimIndex SubExp -> TypeM lore ()
checkDimIndex idx = case idx of
  DimFix i -> require [Prim int32] i
  DimSlice i n s -> mapM_ (require [Prim int32]) [i,n,s]
-- | Check a statement and run the given action with the names of the
-- statement bound.  The certificates, the expression attribute, and
-- the match between pattern and expression are checked before the
-- names are bound; the pattern elements are checked afterwards.  The
-- expression itself is not checked here.
checkStm :: Checkable lore =>
            Stm (Aliases lore)
         -> TypeM lore a
         -> TypeM lore a
checkStm stm@(Let pat (StmAux (Certificates cs) (_,attr)) e) m = do
  mapM_ (requireI [Prim Cert]) cs
  checkExpLore attr
  context matching $ matchPattern pat e
  binding (scopeOf stm) $ do
    mapM_ checkPatElem $ patternElements $ removePatternAliases pat
    m
  where matching =
          "When matching\n" ++ message " " pat ++ "\nwith\n" ++ message " " e
-- | Fail with 'InvalidPatternError' unless the types described by
-- the pattern are exactly the given types.
matchExtPattern :: Checkable lore =>
                   Pattern (Aliases lore) -> [ExtType] -> TypeM lore ()
matchExtPattern pat ts =
  when (expExtTypesFromPattern pat /= ts) $
    bad $ InvalidPatternError pat ts Nothing
-- | Match a result against a return type, using the types that the
-- result subexpressions have in the current scope.
matchExtReturnType :: Checkable lore =>
                      [ExtType] -> Result -> TypeM lore ()
matchExtReturnType rettype res =
  mapM subExpType res >>= matchExtReturns rettype res
-- | Match the result of a branch body against a return type.  The
-- types of the result subexpressions are found in the scope extended
-- with the statements of the body.
matchExtBranchType :: Checkable lore =>
                      [ExtType] -> Body (Aliases lore) -> TypeM lore ()
matchExtBranchType rettype (Body _ stms res) = do
  ts <- extendedScope (traverse subExpType res) (scopeOf stms)
  matchExtReturns rettype res ts
-- | Match result values and their types against a return type.
--
-- The last values of the result, as many as there are return types,
-- are the proper values; anything before them is context.  The
-- existential dimensions of the return type are instantiated with
-- context values, which must be 32-bit integers, and the
-- instantiated return type must equal the types of the proper values.
matchExtReturns :: [ExtType] -> Result -> [Type] -> TypeM lore ()
matchExtReturns rettype res ts = do
  let problem :: TypeM lore a
      problem = bad $ TypeError $ unlines [ "Type annotation is"
                                          , " " ++ prettyTuple rettype
                                          , "But result returns type"
                                          , " " ++ prettyTuple ts ]

      nvals = length rettype
      (ctx_res, val_res) = splitFromEnd nvals res
      (ctx_ts, val_ts) = splitFromEnd nvals ts

  unless (length val_res == nvals) problem

  let ctx_vals = zip ctx_res ctx_ts
      instantiateExt i =
        case maybeNth i ctx_vals of
          Just (se, Prim (IntType Int32)) -> return se
          _ -> problem

  rettype' <- instantiateShapes instantiateExt rettype

  unless (rettype' == val_ts) problem
-- | Can arguments of the given types be passed for parameters of
-- the expected types?  The numbers must agree, and every argument
-- type must be a subtype of the corresponding parameter type, where
-- only the ranks of the shapes are compared.
validApply :: ArrayShape shape =>
              [TypeBase shape Uniqueness]
           -> [TypeBase shape NoUniqueness]
           -> Bool
validApply expected got =
  length got == length expected &&
  and [ g `subtypeOf` e | (g, e) <- zip got' expected' ]
  where got' = map rankShaped got
        expected' = map (fromDecl . rankShaped) expected
-- | An argument: its type together with its aliases.
type Arg = (Type, Names)
-- | The type of an argument.
argType :: Arg -> Type
argType = fst
-- | The aliases of an argument.
argAliases :: Arg -> Names
argAliases = snd
-- | The same argument, but without any aliases.
noArgAliases :: Arg -> Arg
noArgAliases arg = case arg of
  (t, _) -> (t, mempty)
-- | Check a subexpression used as an argument, returning its type
-- and its aliases.
checkArg :: Checkable lore =>
            SubExp -> TypeM lore Arg
checkArg arg = do
  argt <- checkSubExp arg
  als <- subExpAliasesM arg
  return (argt, als)
-- | Check a call, given the declared parameter types and the
-- arguments.  Fails with 'ParameterMismatch' if the arguments are
-- not valid for the parameters (see 'validApply').  Afterwards, the
-- aliases of every argument passed for a consuming parameter are
-- recorded as consumed.
checkFuncall :: Maybe Name
             -> [DeclType] -> [Arg]
             -> TypeM lore ()
checkFuncall fname paramts args = do
  let argts = map argType args
  unless (validApply paramts argts) $
    bad $ ParameterMismatch fname (map fromDecl paramts) argts
  forM_ (zip paramts args) $ \(paramt, (_, als)) ->
    occur [consumption $ consumeArg als $ diet paramt]
  where consumeArg als Consume = als
        consumeArg _ Observe = mempty
-- | Check a lambda applied to the given arguments.  The number of
-- parameters must equal the number of arguments, and the arguments
-- must be valid for the parameters.  Inside the body, consuming a
-- parameter counts as consuming the aliases of the corresponding
-- argument.
checkLambda :: Checkable lore =>
               Lambda (Aliases lore) -> [Arg] -> TypeM lore ()
checkLambda (Lambda params body rettype) args
  | length params /= length args =
      bad $ TypeError $ "Anonymous function defined with " ++ show (length params) ++ " parameters, but expected to take " ++ show (length args) ++ " arguments."
  | otherwise = do
      checkFuncall Nothing
        (map ((`toDecl` Nonunique) . paramType) params) args
      let consumable = zip (map paramName params) (map argAliases args)
          param_infos = [ (paramName param, LParamInfo $ paramAttr param)
                        | param <- params ]
      checkFun' (fname,
                 staticShapes $ map (`toDecl` Nonunique) rettype,
                 param_infos,
                 body) consumable $ do
        checkLambdaParams params
        mapM_ checkType rettype
        checkLambdaBody rettype body
  where fname = nameFromString "<anonymous>"
-- | Check a primitive expression.  The operands of every operator
-- must have the type that the operator expects, and a function call
-- must name a known primitive function, with the right number of
-- arguments and the right return type annotation.
checkPrimExp :: Checkable lore => PrimExp VName -> TypeM lore ()
checkPrimExp ValueExp{} =
  return ()
checkPrimExp (LeafExp v pt) =
  requireI [Prim pt] v
checkPrimExp (BinOpExp op x y) =
  mapM_ (requirePrimExp $ binOpType op) [x, y]
checkPrimExp (CmpOpExp op x y) =
  mapM_ (requirePrimExp $ cmpOpType op) [x, y]
checkPrimExp (UnOpExp op x) =
  requirePrimExp (unOpType op) x
checkPrimExp (ConvOpExp op x) =
  requirePrimExp (fst $ convOpType op) x
checkPrimExp (FunExp h args t) = do
  (h_ts, h_ret, _) <- maybe (bad $ TypeError $ "Unknown function: " ++ h)
                      return $ M.lookup h primFuns
  let nparams = length h_ts
      nargs = length args
  when (nparams /= nargs) $
    bad $ TypeError $ "Function expects " ++ show nparams ++
    " parameters, but given " ++ show nargs ++ " arguments."
  when (h_ret /= t) $
    bad $ TypeError $ "Function return annotation is " ++ pretty t ++
    ", but expected " ++ pretty h_ret
  zipWithM_ requirePrimExp h_ts args
-- | Check a primitive expression and require that it has the given
-- type.
requirePrimExp :: Checkable lore => PrimType -> PrimExp VName -> TypeM lore ()
requirePrimExp t e =
  context ("in PrimExp " ++ pretty e) $ do
    checkPrimExp e
    when (primExpType e /= t) $
      bad $ TypeError $ pretty e ++ " must have type " ++ pretty t
-- | The class of lores whose operations can be type checked.
class Attributes lore => CheckableOp lore where
  -- | Type check an operation of the lore.  'checkExp' uses this for
  -- the expressions that are operations.
  checkOp :: OpWithAliases (Op lore) -> TypeM lore ()
-- | The class of lores that can be type checked.  Most methods have
-- a default implementation, which is available when the attribute or
-- type in question is the one named in the default signature.
class (Attributes lore, CanBeAliased (Op lore), CheckableOp lore) => Checkable lore where
  -- | Check the attribute of an expression.
  checkExpLore :: ExpAttr lore -> TypeM lore ()
  default checkExpLore :: ExpAttr lore ~ () => ExpAttr lore -> TypeM lore ()
  checkExpLore = return

  -- | Check the attribute of a body.
  checkBodyLore :: BodyAttr lore -> TypeM lore ()
  default checkBodyLore :: BodyAttr lore ~ () => BodyAttr lore -> TypeM lore ()
  checkBodyLore = return

  -- | Check the attribute of a function parameter with the given name.
  checkFParamLore :: VName -> FParamAttr lore -> TypeM lore ()
  default checkFParamLore :: FParamAttr lore ~ DeclType => VName -> FParamAttr lore -> TypeM lore ()
  checkFParamLore _ = checkType

  -- | Check the attribute of a lambda parameter with the given name.
  checkLParamLore :: VName -> LParamAttr lore -> TypeM lore ()
  default checkLParamLore :: LParamAttr lore ~ Type => VName -> LParamAttr lore -> TypeM lore ()
  checkLParamLore _ = checkType

  -- | Check the attribute of a let-bound name.
  checkLetBoundLore :: VName -> LetAttr lore -> TypeM lore ()
  default checkLetBoundLore :: LetAttr lore ~ Type => VName -> LetAttr lore -> TypeM lore ()
  checkLetBoundLore _ = checkType

  -- | Check the return type of a function.
  checkRetType :: [RetType lore] -> TypeM lore ()
  default checkRetType :: RetType lore ~ DeclExtType => [RetType lore] -> TypeM lore ()
  checkRetType = mapM_ checkExtType . retTypeValues

  -- | Check that a pattern matches an expression.  By default, the
  -- pattern is matched against the type of the expression.
  matchPattern :: Pattern (Aliases lore) -> Exp (Aliases lore) -> TypeM lore ()
  default matchPattern :: Pattern (Aliases lore) -> Exp (Aliases lore) -> TypeM lore ()
  matchPattern pat = matchExtPattern pat <=< expExtType

  -- | Construct a function parameter of primitive type.
  primFParam :: VName -> PrimType -> TypeM lore (FParam (Aliases lore))
  default primFParam :: FParamAttr lore ~ DeclType => VName -> PrimType -> TypeM lore (FParam (Aliases lore))
  primFParam name t = return $ Param name (Prim t)

  -- | Check that a result matches the return type of a function.
  matchReturnType :: [RetType lore] -> Result -> TypeM lore ()
  default matchReturnType :: RetType lore ~ DeclExtType => [RetType lore] -> Result -> TypeM lore ()
  matchReturnType = matchExtReturnType . map fromDecl

  -- | Check that the result of a branch body matches the return type
  -- of the branch.
  matchBranchType :: [BranchType lore] -> Body (Aliases lore) -> TypeM lore ()
  default matchBranchType :: BranchType lore ~ ExtType => [BranchType lore] -> Body (Aliases lore) -> TypeM lore ()
  matchBranchType = matchExtBranchType