{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
module Language.Futhark.Attributes
(
Intrinsic(..)
, intrinsics
, maxIntrinsicTag
, namesToPrimTypes
, qualName
, qualify
, typeName
, valueType
, leadingOperator
, progImports
, decImports
, progModuleTypes
, identifierReference
, identifierReferences
, typeOf
, patIdentSet
, patternType
, patternStructType
, patternParam
, patternNoShapeAnnotations
, patternOrderZero
, patternDimNames
, uniqueness
, unique
, recordArrayElemUniqueness
, aliases
, diet
, arrayRank
, nestedDims
, returnType
, concreteType
, orderZero
, unfoldFunType
, foldFunType
, typeVars
, typeDimNames
, rank
, peelArray
, stripArray
, arrayOf
, arrayOfWithAliases
, toStructural
, toStruct
, fromStruct
, setAliases
, addAliases
, setUniqueness
, modifyShapeAnnotations
, setArrayShape
, removeShapeAnnotations
, vacuousShapeAnnotations
, typeToRecordArrayElem
, typeToRecordArrayElem'
, recordArrayElemToType
, tupleRecord
, isTupleRecord
, areTupleFields
, tupleFieldNames
, sortFields
, isTypeParam
, NoInfo(..)
, UncheckedType
, UncheckedTypeExp
, UncheckedArrayElemType
, UncheckedIdent
, UncheckedTypeDecl
, UncheckedDimIndex
, UncheckedExp
, UncheckedModExp
, UncheckedSigExp
, UncheckedTypeParam
, UncheckedPattern
, UncheckedValBind
, UncheckedDec
, UncheckedProg
)
where
import Control.Monad.Writer
import Data.Char
import Data.Foldable
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.List
import Data.Loc
import Data.Maybe
import Data.Ord
import Data.Bifunctor
import Data.Bifoldable
import Prelude
import Futhark.Util.Pretty
import Language.Futhark.Syntax
import qualified Futhark.Representation.Primitive as Primitive
-- | The number of array dimensions of a type; zero for anything that
-- is not an 'Array'.
arrayRank :: TypeBase dim as -> Int
arrayRank t = shapeRank (arrayShape t)

-- | The outer shape of an array type, or the empty shape for any
-- non-array type.
arrayShape :: TypeBase dim as -> ShapeDecl dim
arrayShape t = case t of
  Array _ ds _ -> ds
  _            -> mempty
-- | Every dimension declaration occurring anywhere inside a type,
-- including the shapes of nested array elements and the dimensions
-- passed as type arguments.
--
-- For a function type whose parameter is named @v@, occurrences of
-- @'NamedDim' v@ are left out of the result.
nestedDims :: TypeBase (DimDecl VName) as -> [DimDecl VName]
nestedDims t =
  case t of
    Array elemT shape _ -> nub $ elemDims elemT <> shapeDims shape
    Record fields       -> nub $ foldMap nestedDims fields
    Prim{}              -> []
    TypeVar _ _ _ targs -> concatMap argDims targs
    Arrow _ bound t1 t2 -> filter (notBoundBy bound) $ nestedDims t1 <> nestedDims t2
  where
    -- Dimensions mentioned by an array element type.
    elemDims ArrayPrimElem{}           = []
    elemDims (ArrayPolyElem _ targs _) = concatMap argDims targs
    elemDims (ArrayRecordElem fields)  = foldMap fieldDims fields

    -- Dimensions mentioned by one field of a record array element.
    fieldDims (RecordArrayArrayElem et shape _) = elemDims et <> shapeDims shape
    fieldDims (RecordArrayElem et)              = elemDims et

    argDims (TypeArgDim d _)   = [d]
    argDims (TypeArgType at _) = nestedDims at

    -- A named arrow parameter is excluded as a dimension.
    notBoundBy Nothing  _ = True
    notBoundBy (Just v) d = d /= NamedDim (qualName v)
-- | Replace the shape of an array type.  Non-array types are returned
-- unchanged.
setArrayShape :: TypeBase dim as -> ShapeDecl dim -> TypeBase dim as
setArrayShape t ds = case t of
  Array et _ u -> Array et ds u
  _            -> t

-- | Erase every shape annotation, replacing it with @()@.
removeShapeAnnotations :: TypeBase dim as -> TypeBase () as
removeShapeAnnotations = modifyShapeAnnotations (const ())

-- | Replace every shape annotation with 'AnyDim'.
vacuousShapeAnnotations :: TypeBase dim as -> TypeBase (DimDecl vn) as
vacuousShapeAnnotations = modifyShapeAnnotations (const AnyDim)

-- | Map a function over every shape annotation of a type, leaving the
-- alias annotations untouched.
modifyShapeAnnotations :: (oldshape -> newshape)
                       -> TypeBase oldshape as
                       -> TypeBase newshape as
modifyShapeAnnotations = first
-- | The uniqueness of a type.  Only arrays and type variables carry a
-- uniqueness attribute; every other type is 'Nonunique'.
uniqueness :: TypeBase shape as -> Uniqueness
uniqueness t = case t of
  Array _ _ u     -> u
  TypeVar _ u _ _ -> u
  _               -> Nonunique

-- | The uniqueness of a record array element.  Only nested array
-- fields carry one; other fields are 'Nonunique'.
recordArrayElemUniqueness :: RecordArrayElemTypeBase shape as -> Uniqueness
recordArrayElemUniqueness e = case e of
  RecordArrayArrayElem _ _ u -> u
  RecordArrayElem{}          -> Nonunique

-- | True if the type is marked 'Unique'.
unique :: TypeBase shape as -> Bool
unique t = uniqueness t == Unique

-- | Combine all alias annotations in a type, ignoring its shapes.
aliases :: Monoid as => TypeBase shape as -> as
aliases = bifoldMap (const mempty) id
-- | How a value of the given type is used when passed as an argument.
-- Unique arrays are consumed, records and functions get a structured
-- diet, and everything else is only observed.
diet :: TypeBase shape as -> Diet
diet t = case t of
  Record ets          -> RecordDiet (fmap diet ets)
  Prim{}              -> Observe
  TypeVar{}           -> Observe
  Arrow _ _ t1 t2     -> FuncDiet (diet t1) (diet t2)
  Array _ _ Unique    -> Consume
  Array _ _ Nonunique -> Observe

-- | Drop the aliases of the parts of a type that are consumed according
-- to the given diet.  Observed parts and functions keep their aliases;
-- records are masked field by field.  Any other combination of type
-- and diet is an error.
maskAliases :: Monoid as =>
               TypeBase shape as
            -> Diet
            -> TypeBase shape as
maskAliases t d = case (t, d) of
  (_, Consume)                -> t `setAliases` mempty
  (_, Observe)                -> t
  (Record ets, RecordDiet ds) -> Record $ M.intersectionWith maskAliases ets ds
  (_, FuncDiet{})             -> t
  _                           -> error "Invalid arguments passed to maskAliases."
-- | Strip a type down to its bare structure: every shape annotation
-- and every alias annotation becomes @()@.
--
-- 'removeNames' already erases the shape annotations, so there is no
-- need to call 'removeShapeAnnotations' first (the original
-- composition traversed the type twice for the same result).
toStructural :: TypeBase dim as
             -> TypeBase () ()
toStructural = removeNames

-- | Discard the alias annotations of a type, keeping its shapes.
toStruct :: TypeBase dim as
         -> TypeBase dim ()
toStruct t = t `setAliases` ()

-- | Give a type an empty alias set, keeping its shapes.
fromStruct :: TypeBase dim as
           -> TypeBase dim Names
fromStruct t = t `setAliases` S.empty
-- | Remove the @n@ outermost dimensions of an array type.
--
-- * @n == 0@ returns the type unchanged.
-- * If @n@ equals the array's rank, the element type itself is
--   returned.  The aliases of a primitive element are dropped (a
--   'Prim' has none).  Type variables, and nested arrays inside a
--   record element, take the outer array's uniqueness.
-- * Otherwise the dimensions are removed with 'stripDims' and the
--   element type is kept.  NOTE(review): presumably 'stripDims' yields
--   'Nothing' when @n@ exceeds the rank; confirm.
-- * A non-array type with @n /= 0@ gives 'Nothing'.
peelArray :: Int -> TypeBase dim as -> Maybe (TypeBase dim as)
peelArray 0 t = Just t
peelArray n (Array (ArrayPrimElem et _) shape _)
  | shapeRank shape == n =
    Just $ Prim et
peelArray n (Array (ArrayPolyElem et targs als) shape u)
  | shapeRank shape == n =
    Just $ TypeVar als u et targs
peelArray n (Array (ArrayRecordElem ts) shape u)
  | shapeRank shape == n =
    Just $ Record $ fmap asType ts
  -- Convert each record-array field back into an ordinary field type.
  -- Note that nested arrays use the *outer* uniqueness @u@.
  where asType (RecordArrayElem (ArrayPrimElem bt _)) = Prim bt
        asType (RecordArrayElem (ArrayPolyElem bt targs als)) = TypeVar als u bt targs
        asType (RecordArrayElem (ArrayRecordElem ts')) = Record $ fmap asType ts'
        asType (RecordArrayArrayElem et e_shape _) = Array et e_shape u
-- Reached when the rank does not equal @n@: strip only the outer
-- dimensions.
peelArray n (Array et shape u) = do
  shape' <- stripDims n shape
  return $ Array et shape' u
peelArray _ _ = Nothing
-- | Erase both the shape annotations and the alias annotations of a
-- type.
removeNames :: TypeBase dim as
            -> TypeBase () ()
removeNames t = removeShapeAnnotations t `setAliases` ()
-- | Wrap a type in an array with the given outer shape and uniqueness.
-- The element gets empty aliases.  'Nothing' if the type is a function.
arrayOf :: Monoid as =>
           TypeBase dim as
        -> ShapeDecl dim
        -> Uniqueness
        -> Maybe (TypeBase dim as)
arrayOf t shape u = arrayOfWithAliases t mempty shape u

-- | Like 'arrayOf', but attach the given aliases to the element.
--
-- Wrapping an array adds the new dimensions outermost and replaces its
-- aliases.  Records become arrays of record elements.  Function types
-- cannot be array elements, so they yield 'Nothing'.
arrayOfWithAliases :: Monoid as =>
                      TypeBase dim as
                   -> as
                   -> ShapeDecl dim
                   -> Uniqueness
                   -> Maybe (TypeBase dim as)
arrayOfWithAliases t als shape u =
  case t of
    Array et inner _ ->
      Just $ setAliases (Array et (shape <> inner) u) als
    Prim pt ->
      Just $ Array (ArrayPrimElem pt als) shape u
    TypeVar _ _ tn targs ->
      Just $ Array (ArrayPolyElem tn targs als) shape u
    Record fields -> do
      fields' <- traverse (typeToRecordArrayElem' als) fields
      pure $ Array (ArrayRecordElem fields') shape u
    Arrow{} ->
      Nothing
-- | Convert a type to a record-array element with empty extra aliases.
typeToRecordArrayElem :: Monoid as =>
                         TypeBase dim as
                      -> Maybe (RecordArrayElemTypeBase dim as)
typeToRecordArrayElem t = typeToRecordArrayElem' mempty t

-- | Convert a type to a record-array element, attaching the given
-- aliases to primitive and polymorphic leaves.  For type variables the
-- given aliases are combined with the variable's own.  Array fields
-- keep their own element and shape, and the given aliases are not
-- applied to them.  Function types yield 'Nothing'.
typeToRecordArrayElem' :: Monoid as =>
                          as -> TypeBase dim as
                       -> Maybe (RecordArrayElemTypeBase dim as)
typeToRecordArrayElem' als t =
  case t of
    Prim pt ->
      Just $ RecordArrayElem $ ArrayPrimElem pt als
    TypeVar tvAls _ tn targs ->
      Just $ RecordArrayElem $ ArrayPolyElem tn targs (als <> tvAls)
    Record fields -> do
      fields' <- traverse (typeToRecordArrayElem' als) fields
      pure $ RecordArrayElem $ ArrayRecordElem fields'
    Array et shape u ->
      Just $ RecordArrayArrayElem et shape u
    Arrow{} ->
      Nothing
-- | Turn a record-array element back into an ordinary type, and return
-- the aliases found on its leaves.  Nested arrays contribute no
-- aliases here.
recordArrayElemToType :: Monoid as =>
                         RecordArrayElemTypeBase dim as
                      -> (TypeBase dim as, as)
recordArrayElemToType e = case e of
  RecordArrayElem et              -> arrayElemToType et
  RecordArrayArrayElem et shape u -> (Array et shape u, mempty)

-- | Turn an array element type into an ordinary (non-array) type,
-- together with the union of its leaf aliases.  Type variables come
-- back as 'Nonunique'.
arrayElemToType :: Monoid as => ArrayElemTypeBase dim as -> (TypeBase dim as, as)
arrayElemToType e = case e of
  ArrayPrimElem pt als       -> (Prim pt, als)
  ArrayPolyElem tn targs als -> (TypeVar als Nonunique tn targs, als)
  ArrayRecordElem fields     ->
    let converted = fmap recordArrayElemToType fields
    in (Record (fmap fst converted), foldMap snd converted)
-- | Remove @n@ outer dimensions of an array type.  If 'stripDims'
-- cannot remove them, the element type is returned instead, carrying
-- the array's uniqueness.  Non-array types are returned unchanged.
stripArray :: Monoid as => Int -> TypeBase dim as -> TypeBase dim as
stripArray n (Array et shape u) =
  case stripDims n shape of
    Just shape' -> Array et shape' u
    Nothing     -> fst (arrayElemToType et) `setUniqueness` u
stripArray _ t = t
-- | Build a tuple type: a record whose fields are named @1@, @2@, ...
tupleRecord :: [TypeBase dim as] -> TypeBase dim as
tupleRecord ts = Record $ M.fromList $ zip tupleFieldNames ts

-- | The components of a tuple-shaped record type, in order.
isTupleRecord :: TypeBase dim as -> Maybe [TypeBase dim as]
isTupleRecord t = case t of
  Record fs -> areTupleFields fs
  _         -> Nothing

-- | If the field names are exactly @1@, @2@, ..., @k@, return the
-- field values in that order.  An empty map counts as a tuple.
areTupleFields :: M.Map Name a -> Maybe [a]
areTupleFields fs
  | and (zipWith (==) names tupleFieldNames) = Just vals
  | otherwise                                = Nothing
  where (names, vals) = unzip (sortFields fs)

-- | The infinite list of tuple field names @1@, @2@, @3@, ...
tupleFieldNames :: [Name]
tupleFieldNames = [ nameFromString (show i) | i <- [(1 :: Int) ..] ]

-- | Record fields in display order.  Fields whose name parses as an
-- 'Int' come first, in numeric order, followed by all other fields in
-- name order.
sortFields :: M.Map Name a -> [(Name,a)]
sortFields l = sortOn (fieldish . fst) (M.toList l)
  where fieldish s = case reads (nameToString s) of
          [(x, "")] -> Left (x :: Int)
          _         -> Right s
-- | True for a type parameter, False for a size (dimension) parameter.
isTypeParam :: TypeParamBase vn -> Bool
isTypeParam p = case p of
  TypeParamType{} -> True
  TypeParamDim{}  -> False
-- | Set the uniqueness of a type.  Arrays also pass it into their
-- record elements, and records pass it to every field.  Types without
-- a uniqueness attribute are returned unchanged.
setUniqueness :: TypeBase dim as -> Uniqueness -> TypeBase dim as
setUniqueness t u =
  case t of
    Array et shape _       -> Array (setArrayElemUniqueness et u) shape u
    TypeVar als _ tn targs -> TypeVar als u tn targs
    Record ets             -> Record $ fmap (`setUniqueness` u) ets
    _                      -> t

-- | Pass a uniqueness into the record fields of an array element.  Each
-- nested array field keeps its own uniqueness attribute.
setArrayElemUniqueness :: ArrayElemTypeBase dim as
                       -> Uniqueness -> ArrayElemTypeBase dim as
setArrayElemUniqueness et u =
  case et of
    ArrayRecordElem fields -> ArrayRecordElem (fmap onField fields)
    -- Primitive and polymorphic elements carry no uniqueness.
    _                      -> et
  where
    onField (RecordArrayElem e) =
      RecordArrayElem (setArrayElemUniqueness e u)
    onField (RecordArrayArrayElem e shape eu) =
      RecordArrayArrayElem (setArrayElemUniqueness e u) shape eu
-- | Replace every alias annotation in a type with the given value.
setAliases :: TypeBase dim asf -> ast -> TypeBase dim ast
setAliases t als = addAliases t (const als)

-- | Transform every alias annotation in a type with the given function.
addAliases :: TypeBase dim asf -> (asf -> ast)
           -> TypeBase dim ast
addAliases t f = second f t
-- | The integer type of an integer value.
intValueType :: IntValue -> IntType
intValueType v = case v of
  Int8Value{}  -> Int8
  Int16Value{} -> Int16
  Int32Value{} -> Int32
  Int64Value{} -> Int64

-- | The floating-point type of a floating-point value.
floatValueType :: FloatValue -> FloatType
floatValueType v = case v of
  Float32Value{} -> Float32
  Float64Value{} -> Float64

-- | The primitive type of a primitive value.
primValueType :: PrimValue -> PrimType
primValueType v = case v of
  SignedValue iv   -> Signed (intValueType iv)
  UnsignedValue iv -> Unsigned (intValueType iv)
  FloatValue fv    -> FloatType (floatValueType fv)
  BoolValue{}      -> Bool

-- | The type of a value.  Array values carry their own type.
valueType :: Value -> TypeBase () ()
valueType v = case v of
  PrimValue pv   -> Prim (primValueType pv)
  ArrayValue _ t -> t

-- | A shape of @n@ dimensions without size information.
rank :: Int -> ShapeDecl ()
rank n = ShapeDecl (replicate n ())
-- | Compute the type of a type-checked expression.
--
-- Most constructs return the type recorded in their 'Info' annotation,
-- with shape annotations removed where the annotation has them.  The
-- other constructs derive their type from sub-expressions.  Alias sets
-- are adjusted per construct.  A variable is added to its own alias
-- set.  The results of e.g. 'Update', 'Reduce', 'Scan' and 'Filter' get
-- empty alias sets.
typeOf :: ExpBase Info VName -> CompType
typeOf (Literal val _) = Prim $ primValueType val
typeOf (IntLit _ (Info t) _) = fromStruct t
typeOf (FloatLit _ (Info t) _) = fromStruct t
typeOf (Parens e _) = typeOf e
typeOf (QualParens _ e _) = typeOf e
typeOf (TupLit es _) = tupleRecord $ map typeOf es
typeOf (RecordLit fs _) =
  -- 'M.unions' is left-biased.  Reversing the list makes a later field
  -- with the same name take precedence over an earlier one.
  Record $ M.unions $ reverse $ map record fs
  where record (RecordFieldExplicit name e _) = M.singleton name $ typeOf e
        -- An implicit field is named after a variable, and its type
        -- is extended to alias that variable.
        record (RecordFieldImplicit name (Info t) _) =
          M.singleton (baseName name) $ t `addAliases` S.insert name
typeOf (ArrayLit _ (Info t) _) = t
typeOf (Range _ _ _ (Info t) _) = t
typeOf (BinOp _ _ _ _ (Info t) _) = removeShapeAnnotations t
typeOf (Project _ _ (Info t) _) = t
typeOf (If _ _ _ (Info t) _) = t
-- A variable aliases itself in addition to its recorded aliases.
typeOf (Var qn (Info t) _) = removeShapeAnnotations t `addAliases` S.insert (qualLeaf qn)
typeOf (Ascript e _ _) = typeOf e
typeOf (Apply _ _ _ (Info t) _) = removeShapeAnnotations t
typeOf (Negate e _) = typeOf e
typeOf (LetPat _ _ _ body _) = typeOf body
typeOf (LetFun _ _ body _) = typeOf body
typeOf (LetWith _ _ _ _ body _) = typeOf body
typeOf (Index _ _ (Info t) _) = t
-- The result of an in-place update is given an empty alias set.
typeOf (Update e _ _ _) = typeOf e `setAliases` mempty
typeOf (RecordUpdate _ _ _ (Info t) _) = removeShapeAnnotations t
typeOf (Zip _ _ _ (Info t) _) = t
typeOf (Unzip _ ts _) =
  tupleRecord $ map unInfo ts
typeOf (Unsafe e _) = typeOf e
typeOf (Assert _ e _ _) = typeOf e
-- The array-producing SOACs below return 'Unique' results.  Apart from
-- 'Map', whose type comes from its annotation, they also drop aliases.
typeOf (Map _ _ (Info t) _) = t `setUniqueness` Unique
typeOf (Reduce _ _ _ arr _) =
  stripArray 1 (typeOf arr) `setAliases` mempty
typeOf (GenReduce hist _ _ _ _ _) =
  typeOf hist `setAliases` mempty `setUniqueness` Unique
typeOf (Scan _ _ arr _) = typeOf arr `setAliases` mempty `setUniqueness` Unique
typeOf (Filter _ arr _) = typeOf arr `setAliases` mempty `setUniqueness` Unique
-- A partition yields the partitioned array paired with a
-- one-dimensional array of 'Int32'.
typeOf (Partition _ _ arr _) =
  tupleRecord [typeOf arr `setAliases` mempty `setUniqueness` Unique,
               Array (ArrayPrimElem (Signed Int32) mempty) (rank 1) Unique]
-- The result type of a stream is the final return type of its curried
-- lambda.
typeOf (Stream _ lam _ _) =
  rettype (typeOf lam) `setUniqueness` Unique
  where rettype (Arrow _ _ _ t) = rettype t
        rettype t = t
typeOf (DoLoop _ pat _ _ _ _) = patternType pat
-- A lambda has a curried function type built from its parameters
-- (named via 'patternParam') and its return type, with the aliases
-- recorded in its annotation.
typeOf (Lambda _ params _ _ (Info (als, t)) _) =
  removeShapeAnnotations (foldr (uncurry (Arrow ()) . patternParam) t params)
    `setAliases` als
typeOf (OpSection _ (Info t) _) =
  removeShapeAnnotations t
-- Operator sections with one operand supplied: the result is a function
-- of the second (left section) or first (right section) operand type.
typeOf (OpSectionLeft _ _ _ (_, Info pt2) (Info ret) _) =
  removeShapeAnnotations $ foldFunType [fromStruct pt2] ret
typeOf (OpSectionRight _ _ _ (Info pt1, _) (Info ret) _) =
  removeShapeAnnotations $ foldFunType [fromStruct pt1] ret
typeOf (ProjectSection _ (Info t) _) =
  removeShapeAnnotations t
typeOf (IndexSection _ (Info t) _) =
  removeShapeAnnotations t
-- | Build a curried function type from parameter types and a result
-- type.  The arrows carry empty aliases and unnamed parameters.
foldFunType :: Monoid as => [TypeBase dim as] -> TypeBase dim as -> TypeBase dim as
foldFunType params ret = foldr arrow ret params
  where arrow = Arrow mempty Nothing

-- | Split a curried function type into its parameter types and its
-- final result type.  A non-function type has no parameters.
unfoldFunType :: TypeBase dim as -> ([TypeBase dim as], TypeBase dim as)
unfoldFunType (Arrow _ _ t1 t2) = first (t1 :) (unfoldFunType t2)
unfoldFunType t = ([], t)
-- | The names of all type variables mentioned anywhere in a type,
-- including those inside type arguments and array element types.
typeVars :: Monoid as => TypeBase dim as -> Names
typeVars t =
  case t of
    Prim{}                               -> mempty
    TypeVar _ _ tn targs                 -> namedVars tn targs
    Arrow _ _ t1 t2                      -> typeVars t1 <> typeVars t2
    Record fields                        -> foldMap typeVars fields
    Array ArrayPrimElem{} _ _            -> mempty
    Array (ArrayPolyElem tn targs _) _ _ -> namedVars tn targs
    Array (ArrayRecordElem fields) _ _   ->
      foldMap (typeVars . fst . recordArrayElemToType) fields
  where
    -- The variable itself plus any variables in its type arguments.
    namedVars tn targs = S.insert (typeLeaf tn) (foldMap argVars targs)
    argVars (TypeArgType ta _) = typeVars ta
    argVars TypeArgDim{}       = mempty
-- | Instantiate the aliases of a function's declared return type, given
-- the diets of its parameters and the types of its arguments.
--
-- Unique parts of the result alias nothing.  Non-unique parts alias
-- every argument except those passed to consuming parameters (see
-- 'maskAliases').
returnType :: TypeBase dim ()
           -> [Diet]
           -> [CompType]
           -> TypeBase dim Names
returnType t ds args =
  case t of
    Array et shape Unique ->
      Array (second (const mempty) et) shape Unique
    Array et shape Nonunique ->
      Array (arrayElemReturnType et ds args) shape Nonunique
    Record fs ->
      Record $ fmap (\ft -> returnType ft ds args) fs
    Prim pt ->
      Prim pt
    TypeVar () Unique tn targs ->
      TypeVar mempty Unique tn $ map (second (const mempty)) targs
    TypeVar () Nonunique tn targs ->
      TypeVar argAliases Nonunique tn $ map (\arg -> typeArgReturnType arg ds args) targs
    Arrow _ v t1 t2 ->
      Arrow argAliases v (second (const mempty) t1) (returnType t2 ds args)
  where
    -- Aliases of the arguments that are not consumed.
    argAliases = foldMap aliases $ zipWith maskAliases args ds

-- | 'returnType' lifted to a type argument.  Dimension arguments are
-- passed through unchanged.
typeArgReturnType :: TypeArg shape () -> [Diet] -> [CompType]
                  -> TypeArg shape Names
typeArgReturnType targ ds args =
  case targ of
    TypeArgDim v loc  -> TypeArgDim v loc
    TypeArgType t loc -> TypeArgType (returnType t ds args) loc

-- | 'returnType' lifted to the element type of a non-unique array.
arrayElemReturnType :: ArrayElemTypeBase dim ()
                    -> [Diet]
                    -> [CompType]
                    -> ArrayElemTypeBase dim Names
arrayElemReturnType et ds args =
  case et of
    ArrayPrimElem pt () ->
      ArrayPrimElem pt argAliases
    ArrayPolyElem tn targs () ->
      ArrayPolyElem tn (map (\arg -> typeArgReturnType arg ds args) targs) argAliases
    ArrayRecordElem fields ->
      ArrayRecordElem $ fmap (\f -> recordArrayElemReturnType f ds args) fields
  where
    argAliases = foldMap aliases $ zipWith maskAliases args ds

-- | 'returnType' lifted to a field of a record array element.
recordArrayElemReturnType :: RecordArrayElemTypeBase dim ()
                          -> [Diet]
                          -> [CompType]
                          -> RecordArrayElemTypeBase dim Names
recordArrayElemReturnType rt ds args =
  case rt of
    RecordArrayElem et ->
      RecordArrayElem $ arrayElemReturnType et ds args
    RecordArrayArrayElem et shape u ->
      RecordArrayArrayElem (arrayElemReturnType et ds args) shape u
-- | True if a type mentions no type variables and contains no
-- function types.
concreteType :: TypeBase f vn -> Bool
concreteType t =
  case t of
    Prim{}          -> True
    TypeVar{}       -> False
    Arrow{}         -> False
    Record ts       -> all concreteType ts
    Array elemT _ _ -> concreteElem elemT
  where
    concreteElem ArrayPrimElem{}          = True
    concreteElem ArrayPolyElem{}          = False
    concreteElem (ArrayRecordElem fields) = all concreteField fields
    concreteField (RecordArrayElem et)          = concreteElem et
    concreteField (RecordArrayArrayElem et _ _) = concreteElem et

-- | True if a type contains no function types, except possibly through
-- type variables, which are treated as order zero.
orderZero :: TypeBase dim as -> Bool
orderZero t =
  case t of
    Prim{}    -> True
    Array{}   -> True
    TypeVar{} -> True
    Record fs -> all orderZero fs
    Arrow{}   -> False
-- | All names used as dimensions in the types of a pattern, including
-- those of type ascriptions.
patternDimNames :: PatternBase Info VName -> Names
patternDimNames pat =
  case pat of
    TuplePattern ps _    -> foldMap patternDimNames ps
    RecordPattern fs _   -> foldMap (patternDimNames . snd) fs
    PatternParens p _    -> patternDimNames p
    Id _ (Info tp) _     -> typeDimNames tp
    Wildcard (Info tp) _ -> typeDimNames tp
    PatternAscription p (TypeDecl _ (Info t)) _ ->
      patternDimNames p <> typeDimNames t

-- | The names of all named dimensions occurring in a type.
typeDimNames :: TypeBase (DimDecl VName) als -> Names
typeDimNames = S.fromList . mapMaybe namedDim . nestedDims
  where namedDim :: DimDecl VName -> Maybe VName
        namedDim (NamedDim qn) = Just (qualLeaf qn)
        namedDim _             = Nothing

-- | True if every component type bound by the pattern is of order
-- zero (see 'orderZero').
patternOrderZero :: PatternBase Info vn -> Bool
patternOrderZero (TuplePattern ps _)       = all patternOrderZero ps
patternOrderZero (RecordPattern fs _)      = all (patternOrderZero . snd) fs
patternOrderZero (PatternParens p _)       = patternOrderZero p
patternOrderZero (Id _ (Info t) _)         = orderZero t
patternOrderZero (Wildcard (Info t) _)     = orderZero t
patternOrderZero (PatternAscription p _ _) = patternOrderZero p
-- | The set of identifiers bound by a pattern, with shape annotations
-- removed from their types.
patIdentSet :: (Functor f, Ord vn) => PatternBase f vn -> S.Set (IdentBase f vn)
patIdentSet pat =
  case pat of
    Id v t loc              -> S.singleton $ Ident v (fmap removeShapeAnnotations t) loc
    PatternParens p _       -> patIdentSet p
    TuplePattern ps _       -> foldMap patIdentSet ps
    RecordPattern fs _      -> foldMap (patIdentSet . snd) fs
    Wildcard{}              -> S.empty
    PatternAscription p _ _ -> patIdentSet p

-- | The type of the value matched by a pattern, with shape annotations
-- removed.  Type ascriptions are looked through.
patternType :: PatternBase Info VName -> CompType
patternType pat =
  case pat of
    Wildcard (Info t) _     -> removeShapeAnnotations t
    PatternParens p _       -> patternType p
    Id _ (Info t) _         -> removeShapeAnnotations t
    TuplePattern ps _       -> tupleRecord (map patternType ps)
    RecordPattern fs _      -> Record (fmap patternType (M.fromList fs))
    PatternAscription p _ _ -> patternType p

-- | The alias-free type of a pattern.  Identifiers keep their shape
-- annotations; wildcards get 'AnyDim' for every dimension.
patternStructType :: PatternBase Info VName -> StructType
patternStructType pat =
  case pat of
    PatternAscription p _ _ -> patternStructType p
    PatternParens p _       -> patternStructType p
    Id _ (Info t) _         -> toStruct t
    TuplePattern ps _       -> tupleRecord (map patternStructType ps)
    RecordPattern fs _      -> Record (fmap patternStructType (M.fromList fs))
    Wildcard (Info t) _     -> vacuousShapeAnnotations (toStruct t)

-- | View a pattern as a function parameter.  A (possibly parenthesised)
-- type-ascribed identifier gives a named parameter with the ascribed
-- type.  Every other pattern gives an unnamed parameter.
patternParam :: PatternBase Info VName -> (Maybe VName, StructType)
patternParam pat =
  case pat of
    PatternParens p _                 -> patternParam p
    PatternAscription (Id v _ _) td _ -> (Just v, unInfo (expandedType td))
    _                                 -> (Nothing, patternStructType pat)
-- | Replace every shape annotation in the types of a pattern, including
-- ascribed types, with 'AnyDim'.
patternNoShapeAnnotations :: PatternBase Info VName -> PatternBase Info VName
patternNoShapeAnnotations pat =
  case pat of
    PatternAscription p (TypeDecl te (Info t)) loc ->
      PatternAscription (go p) (TypeDecl te (Info (vacuousShapeAnnotations t))) loc
    PatternParens p loc ->
      PatternParens (go p) loc
    Id v (Info t) loc ->
      Id v (Info (vacuousShapeAnnotations t)) loc
    TuplePattern ps loc ->
      TuplePattern (map go ps) loc
    RecordPattern fs loc ->
      RecordPattern [ (k, go p) | (k, p) <- fs ] loc
    Wildcard (Info t) loc ->
      Wildcard (Info (vacuousShapeAnnotations t)) loc
  where go = patternNoShapeAnnotations
-- | Map from the pretty-printed name of every primitive type to that
-- type.
namesToPrimTypes :: M.Map Name PrimType
namesToPrimTypes =
  M.fromList [ (nameFromString (pretty t), t) | t <- primTypes ]
  where primTypes = Bool : concat [ map Signed [minBound .. maxBound]
                                  , map Unsigned [minBound .. maxBound]
                                  , map FloatType [minBound .. maxBound] ]
-- | A built-in definition, as listed in 'intrinsics'.
data Intrinsic = IntrinsicMonoFun [PrimType] PrimType
                 -- ^ A function with fixed parameter types and a fixed
                 -- result type.
               | IntrinsicOverloadedFun [PrimType] [Maybe PrimType] (Maybe PrimType)
                 -- ^ A function overloaded over the listed primitive
                 -- types, followed by per-parameter types and a result
                 -- type.  NOTE(review): judging from its uses in
                 -- 'intrinsics', 'Nothing' seems to stand for the
                 -- overloaded type and @'Just' t@ for the fixed type @t@;
                 -- confirm.
               | IntrinsicPolyFun [TypeParamBase VName] [TypeBase () ()] (TypeBase () ())
                 -- ^ A polymorphic function: its type parameters,
                 -- parameter types and result type.
               | IntrinsicType PrimType
                 -- ^ The name of a primitive type.
               | IntrinsicEquality
                 -- ^ Used for the 'Equal' and 'NotEqual' operators.
               | IntrinsicOpaque
                 -- ^ The @opaque@ builtin.
-- | All built-in definitions, keyed by 'VName'.
--
-- The list contains, in order:
--
-- * the functions of 'Primitive.primFuns',
-- * the prefix operators @~@ and @!@, and @opaque@,
-- * the primitive unary, binary, comparison and conversion operators,
-- * @sign_@/@unsign_@ conversions for every integer type,
-- * every primitive type name,
-- * the overloaded infix operators,
-- * the polymorphic array builtins and SOACs.
--
-- Tags are assigned consecutively from 10 in that order.  Tags 0 and 1
-- are used for the type parameters @a@ and @b@ defined below.
-- NOTE(review): presumably starting at 10 keeps intrinsic tags clear of
-- those type-parameter tags; confirm.
intrinsics :: M.Map VName Intrinsic
intrinsics = M.fromList $ zipWith namify [10..] $
             map primFun (M.toList Primitive.primFuns) ++
             [ ("~", IntrinsicOverloadedFun
                     (map Signed [minBound..maxBound] ++
                      map Unsigned [minBound..maxBound])
                     [Nothing] Nothing)
             , ("!", IntrinsicMonoFun [Bool] Bool)] ++
             [("opaque", IntrinsicOpaque)] ++
             map unOpFun Primitive.allUnOps ++
             map binOpFun Primitive.allBinOps ++
             map cmpOpFun Primitive.allCmpOps ++
             map convOpFun Primitive.allConvOps ++
             map signFun Primitive.allIntTypes ++
             map unsignFun Primitive.allIntTypes ++
             map intrinsicType (map Signed [minBound..maxBound] ++
                                map Unsigned [minBound..maxBound] ++
                                map FloatType [minBound..maxBound] ++
                                [Bool]) ++
             -- Only operators for which 'intrinsicBinOp' returns 'Just'
             -- are included.
             mapMaybe mkIntrinsicBinOp [minBound..maxBound] ++
             [("flatten", IntrinsicPolyFun [tp_a]
                          [Array (ArrayPolyElem tv_a' [] ()) (rank 2) Nonunique] $
                          Array (ArrayPolyElem tv_a' [] ()) (rank 1) Nonunique),
              ("unflatten", IntrinsicPolyFun [tp_a]
                            [Prim $ Signed Int32,
                             Prim $ Signed Int32,
                             Array (ArrayPolyElem tv_a' [] ()) (rank 1) Nonunique] $
                            Array (ArrayPolyElem tv_a' [] ()) (rank 2) Nonunique),
              ("concat", IntrinsicPolyFun [tp_a]
                         [arr_a, arr_a] uarr_a),
              ("rotate", IntrinsicPolyFun [tp_a]
                         [Prim $ Signed Int32, arr_a] arr_a),
              ("transpose", IntrinsicPolyFun [tp_a] [arr_a] arr_a),
              ("cmp_threshold", IntrinsicPolyFun []
                                [Prim $ Signed Int32,
                                 Array (ArrayPrimElem (Signed Int32) ()) (rank 1) Nonunique] $
                                Prim Bool),
              ("scatter", IntrinsicPolyFun [tp_a]
                          [Array (ArrayPolyElem tv_a' [] ()) (rank 1) Unique,
                           Array (ArrayPrimElem (Signed Int32) ()) (rank 1) Nonunique,
                           Array (ArrayPolyElem tv_a' [] ()) (rank 1) Nonunique] $
                          Array (ArrayPolyElem tv_a' [] ()) (rank 1) Unique),
              ("zip", IntrinsicPolyFun [tp_a, tp_b] [arr_a, arr_b] arr_a_b),
              ("unzip", IntrinsicPolyFun [tp_a, tp_b] [arr_a_b] t_arr_a_arr_b),
              ("gen_reduce", IntrinsicPolyFun [tp_a]
                             [uarr_a,
                              t_a `arr` (t_a `arr` t_a),
                              t_a,
                              Array (ArrayPrimElem (Signed Int32) ()) (rank 1) Nonunique,
                              arr_a]
                             uarr_a),
              ("map", IntrinsicPolyFun [tp_a, tp_b] [t_a `arr` t_b, arr_a] uarr_b),
              ("reduce", IntrinsicPolyFun [tp_a]
                         [t_a `arr` (t_a `arr` t_a), t_a, arr_a] t_a),
              ("reduce_comm", IntrinsicPolyFun [tp_a]
                              [t_a `arr` (t_a `arr` t_a), t_a, arr_a] t_a),
              ("scan", IntrinsicPolyFun [tp_a]
                       [t_a `arr` (t_a `arr` t_a), t_a, arr_a] uarr_a),
              ("partition",
               IntrinsicPolyFun [tp_a]
               [Prim (Signed Int32), t_a `arr` Prim (Signed Int32), arr_a] $
               tupleRecord [uarr_a, Array (ArrayPrimElem (Signed Int32) ()) (rank 1) Unique]),
              ("stream_map",
               IntrinsicPolyFun [tp_a, tp_b] [arr_a `arr` arr_b, arr_a] uarr_b),
              ("stream_map_per",
               IntrinsicPolyFun [tp_a, tp_b] [arr_a `arr` arr_b, arr_a] uarr_b),
              ("stream_red",
               IntrinsicPolyFun [tp_a, tp_b] [t_b `arr` (t_b `arr` t_b), arr_a `arr` t_b, arr_a] t_b),
              ("stream_red_per",
               IntrinsicPolyFun [tp_a, tp_b] [t_b `arr` (t_b `arr` t_b), arr_a `arr` t_b, arr_a] t_b),
              ("trace", IntrinsicPolyFun [tp_a] [t_a] t_a),
              ("break", IntrinsicPolyFun [tp_a] [t_a] t_a)]
  -- Type parameter @a@ (tag 0) and the types built from it: the variable
  -- itself, a one-dimensional array of it, and a unique such array.
  where tv_a = VName (nameFromString "a") 0
        tv_a' = typeName tv_a
        t_a = TypeVar () Nonunique tv_a' []
        arr_a = Array (ArrayPolyElem tv_a' [] ()) (rank 1) Nonunique
        uarr_a = Array (ArrayPolyElem tv_a' [] ()) (rank 1) Unique
        tp_a = TypeParamType Unlifted tv_a noLoc

        -- Type parameter @b@ (tag 1), built the same way.
        tv_b = VName (nameFromString "b") 1
        tv_b' = typeName tv_b
        t_b = TypeVar () Nonunique tv_b' []
        arr_b = Array (ArrayPolyElem tv_b' [] ()) (rank 1) Nonunique
        uarr_b = Array (ArrayPolyElem tv_b' [] ()) (rank 1) Unique
        tp_b = TypeParamType Unlifted tv_b noLoc

        -- One-dimensional array of pairs @(a, b)@, and a pair of arrays.
        arr_a_b = Array (ArrayRecordElem (M.fromList $ zip tupleFieldNames
                                          [RecordArrayElem $ ArrayPolyElem tv_a' [] (),
                                           RecordArrayElem $ ArrayPolyElem tv_b' [] ()]))
                  (rank 1) Nonunique
        t_arr_a_arr_b = Record $ M.fromList $ zip tupleFieldNames [arr_a, arr_b]

        -- Function type with empty aliases and an unnamed parameter.
        arr = Arrow mempty Nothing

        -- Give the @i@th entry a 'VName' with tag @i@.
        namify i (k,v) = (VName (nameFromString k) i, v)

        primFun (name, (ts,t, _)) =
          (name, IntrinsicMonoFun (map unPrim ts) $ unPrim t)

        unOpFun bop = (pretty bop, IntrinsicMonoFun [t] t)
          where t = unPrim $ Primitive.unOpType bop

        binOpFun bop = (pretty bop, IntrinsicMonoFun [t, t] t)
          where t = unPrim $ Primitive.binOpType bop

        cmpOpFun bop = (pretty bop, IntrinsicMonoFun [t, t] Bool)
          where t = unPrim $ Primitive.cmpOpType bop

        convOpFun cop = (pretty cop, IntrinsicMonoFun [unPrim ft] $ unPrim tt)
          where (ft, tt) = Primitive.convOpType cop

        signFun t = ("sign_" ++ pretty t, IntrinsicMonoFun [Unsigned t] $ Signed t)

        unsignFun t = ("unsign_" ++ pretty t, IntrinsicMonoFun [Signed t] $ Unsigned t)

        -- Translate a core primitive type into a source-language one.
        -- Core integer types become signed.  'Primitive.Cert' maps to
        -- 'Bool'.  NOTE(review): presumably because the source language
        -- has no certificate type; confirm.
        unPrim (Primitive.IntType t) = Signed t
        unPrim (Primitive.FloatType t) = FloatType t
        unPrim Primitive.Bool = Bool
        unPrim Primitive.Cert = Bool

        intrinsicType t = (pretty t, IntrinsicType t)

        anyIntType = map Signed [minBound..maxBound] ++
                     map Unsigned [minBound..maxBound]
        anyNumberType = anyIntType ++
                        map FloatType [minBound..maxBound]
        anyPrimType = Bool : anyNumberType

        mkIntrinsicBinOp :: BinOp -> Maybe (String, Intrinsic)
        mkIntrinsicBinOp op = do
          op' <- intrinsicBinOp op
          return (pretty op, op')

        -- Binary operator overloaded over @ts@, with both operands and
        -- the result of the overloaded type.
        binOp ts = Just $ IntrinsicOverloadedFun ts [Nothing, Nothing] Nothing
        -- Comparison overloaded over all primitive types (including
        -- 'Bool'), returning 'Bool'.
        ordering = Just $ IntrinsicOverloadedFun anyPrimType [Nothing, Nothing] (Just Bool)

        intrinsicBinOp Plus = binOp anyNumberType
        intrinsicBinOp Minus = binOp anyNumberType
        intrinsicBinOp Pow = binOp anyNumberType
        intrinsicBinOp Times = binOp anyNumberType
        intrinsicBinOp Divide = binOp anyNumberType
        intrinsicBinOp Mod = binOp anyNumberType
        intrinsicBinOp Quot = binOp anyIntType
        intrinsicBinOp Rem = binOp anyIntType
        intrinsicBinOp ShiftR = binOp anyIntType
        intrinsicBinOp ShiftL = binOp anyIntType
        intrinsicBinOp Band = binOp anyIntType
        intrinsicBinOp Xor = binOp anyIntType
        intrinsicBinOp Bor = binOp anyIntType
        intrinsicBinOp LogAnd = Just $ IntrinsicMonoFun [Bool,Bool] Bool
        intrinsicBinOp LogOr = Just $ IntrinsicMonoFun [Bool,Bool] Bool
        intrinsicBinOp Equal = Just IntrinsicEquality
        intrinsicBinOp NotEqual = Just IntrinsicEquality
        intrinsicBinOp Less = ordering
        intrinsicBinOp Leq = ordering
        intrinsicBinOp Greater = ordering
        intrinsicBinOp Geq = ordering
        -- Every other operator has no intrinsic.
        intrinsicBinOp _ = Nothing
-- | The largest tag used by any name in 'intrinsics'.
maxIntrinsicTag :: Int
maxIntrinsicTag = maximum [ baseTag v | v <- M.keys intrinsics ]
-- | A name with no qualifiers.
qualName :: v -> QualName v
qualName v = QualName [] v

-- | Prepend a qualifier to a qualified name.
qualify :: v -> QualName v -> QualName v
qualify q (QualName quals v) = QualName (q : quals) v

-- | An unqualified type name for the given 'VName'.
typeName :: VName -> TypeName
typeName v = typeNameFromQualName (qualName v)
-- | Every file imported by a program, with the location of the import.
progImports :: ProgBase f vn -> [(String,SrcLoc)]
progImports prog = concatMap decImports (progDecs prog)

-- | Every file imported by a declaration, including inside @open@,
-- module and @local@ declarations.
decImports :: DecBase f vn -> [(String,SrcLoc)]
decImports dec =
  case dec of
    OpenDec x _ _ -> modExpImports x
    ModDec md     -> modExpImports (modExp md)
    LocalDec d _  -> decImports d
    SigDec{}      -> []
    TypeDec{}     -> []
    ValDec{}      -> []

-- | Every file imported by a module expression.  For an application,
-- only the argument is searched.  Module lambdas are not searched.
modExpImports :: ModExpBase f vn -> [(String,SrcLoc)]
modExpImports me =
  case me of
    ModImport f _ loc      -> [(f, loc)]
    ModParens p _          -> modExpImports p
    ModDecs ds _           -> concatMap decImports ds
    ModApply _ arg _ _ _   -> modExpImports arg
    ModAscript inner _ _ _ -> modExpImports inner
    ModVar{}               -> []
    ModLambda{}            -> []
-- | The names of all module types referenced, by name, in the
-- top-level declarations of a program.  Declarations under @local@ are
-- not searched.
progModuleTypes :: Ord vn => ProgBase f vn -> S.Set vn
progModuleTypes = foldMap onDec . progDecs
  where
    onDec dec = case dec of
      OpenDec x _ _ -> onModExp x
      ModDec md ->
        maybe mempty (onSigExp . fst) (modSignature md) <> onModExp (modExp md)
      SigDec{}   -> mempty
      TypeDec{}  -> mempty
      ValDec{}   -> mempty
      LocalDec{} -> mempty

    onModExp me = case me of
      ModVar{}                -> mempty
      ModParens p _           -> onModExp p
      ModImport{}             -> mempty
      ModDecs ds _            -> foldMap onDec ds
      ModApply me1 me2 _ _ _  -> onModExp me1 <> onModExp me2
      ModAscript inner se _ _ -> onModExp inner <> onSigExp se
      ModLambda p r body _    ->
        onSigExp (modParamType p) <> maybe mempty (onSigExp . fst) r <> onModExp body

    onSigExp se = case se of
      SigVar v _         -> S.singleton (qualLeaf v)
      SigParens e _      -> onSigExp e
      SigSpecs{}         -> mempty
      SigWith e _ _      -> onSigExp e
      SigArrow _ e1 e2 _ -> onSigExp e1 <> onSigExp e2
-- | Parse one identifier reference at the start of a string, of the form
--
-- > `identifier`@namespace
-- > `identifier`@namespace@"file"
--
-- The namespace must be a non-empty run of letters.  Returns the
-- identifier, namespace and optional file, plus the remaining input.
-- If the file part is malformed (no closing quote), it is not consumed
-- and the file is reported as 'Nothing'.
identifierReference :: String -> Maybe ((String, String, Maybe FilePath), String)
identifierReference ('`' : afterTick) =
  case break (== '`') afterTick of
    (identifier, '`' : '@' : rest)
      | (namespace, rest') <- span isAlpha rest
      , not (null namespace) ->
          Just $ case rest' of
            '@' : '"' : quoted
              | (file, '"' : remaining) <- span (/= '"') quoted ->
                  ((identifier, namespace, Just file), remaining)
            _ -> ((identifier, namespace, Nothing), rest')
    _ -> Nothing
identifierReference _ = Nothing

-- | Every identifier reference (see 'identifierReference') found
-- anywhere in a string, in order of occurrence.
identifierReferences :: String -> [(String, String, Maybe FilePath)]
identifierReferences s =
  case s of
    [] -> []
    _ | Just (ref, rest) <- identifierReference s -> ref : identifierReferences rest
    _ : rest -> identifierReferences rest
-- | The binary operator whose printed form is the longest prefix of the
-- given name, or 'Backtick' if no operator is a prefix.
leadingOperator :: Name -> BinOp
leadingOperator s =
  case find ((`isPrefixOf` str) . fst) candidates of
    Just (_, op) -> op
    Nothing      -> Backtick
  where str = nameToString s
        -- Longest spellings first, so an operator wins over any shorter
        -- operator that is a prefix of it.
        candidates = sortOn (Down . length . fst)
                     [ (pretty op, op) | op <- [minBound .. maxBound :: BinOp] ]
-- The @Unchecked*@ aliases instantiate the syntax tree with 'NoInfo'
-- annotations and plain 'Name's, whereas the functions in this module
-- that operate on checked code use 'Info' and 'VName'.
-- NOTE(review): presumably these describe the AST before type
-- checking; confirm against the parser and type checker.

-- | A type with 'Name'-based shape declarations and no alias
-- information.
type UncheckedType = TypeBase (ShapeDecl Name) ()
type UncheckedTypeExp = TypeExp Name
type UncheckedArrayElemType = ArrayElemTypeBase (ShapeDecl Name) ()
type UncheckedTypeDecl = TypeDeclBase NoInfo Name
type UncheckedIdent = IdentBase NoInfo Name
type UncheckedDimIndex = DimIndexBase NoInfo Name
type UncheckedExp = ExpBase NoInfo Name
type UncheckedModExp = ModExpBase NoInfo Name
type UncheckedSigExp = SigExpBase NoInfo Name
type UncheckedTypeParam = TypeParamBase Name
type UncheckedPattern = PatternBase NoInfo Name
type UncheckedValBind = ValBindBase NoInfo Name
type UncheckedDec = DecBase NoInfo Name
type UncheckedProg = ProgBase NoInfo Name