{-# LANGUAGE FlexibleContexts, FlexibleInstances, TypeSynonymInstances #-}
module Futhark.Representation.AST.Attributes.Types
(
rankShaped
, arrayRank
, arrayShape
, modifyArrayShape
, setArrayShape
, existential
, uniqueness
, setUniqueness
, unique
, staticShapes
, staticShapes1
, primType
, arrayOf
, arrayOfRow
, arrayOfShape
, setOuterSize
, setDimSize
, setOuterDim
, setDim
, setArrayDims
, setArrayExtDims
, peelArray
, stripArray
, arrayDims
, arrayExtDims
, shapeSize
, arraySize
, arraysSize
, rowType
, elemType
, transposeType
, rearrangeType
, diet
, subtypeOf
, subtypesOf
, toDecl
, fromDecl
, extractShapeContext
, shapeContext
, shapeContextSize
, hasStaticShape
, hasStaticShapes
, generaliseExtTypes
, existentialiseExtTypes
, shapeMapping
, shapeMapping'
, shapeExtMapping
, int8, int16, int32, int64
, float32, float64
, Typed (..)
, DeclTyped (..)
, ExtTyped (..)
, DeclExtTyped (..)
, SetType (..)
, FixExt (..)
)
where
import Control.Monad.State
import Data.Maybe
import Data.Monoid ((<>))
import Data.List (elemIndex)
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Futhark.Representation.AST.Syntax.Core
import Futhark.Representation.AST.Attributes.Constants
import Futhark.Representation.AST.Attributes.Rearrange
-- | Forget the concrete dimensions of an array type, keeping only its
-- rank.  Primitive and memory types are rebuilt unchanged.
rankShaped :: ArrayShape shape => TypeBase shape u -> TypeBase Rank u
rankShaped t = case t of
  Prim pt -> Prim pt
  Mem size space -> Mem size space
  Array pt shape u -> Array pt (Rank (shapeRank shape)) u
-- | The rank (number of dimensions) of the shape returned by 'arrayShape'.
arrayRank :: ArrayShape shape => TypeBase shape u -> Int
arrayRank t = shapeRank (arrayShape t)
-- | The shape of an array type; 'mempty' for any non-array type.
arrayShape :: ArrayShape shape => TypeBase shape u -> shape
arrayShape t
  | Array _ shape _ <- t = shape
  | otherwise = mempty
-- | Transform the shape of an array type with the given function.
--
-- If the transformed shape has rank zero, the result collapses to a
-- primitive type with the same element type.  Primitive and memory types
-- are passed through untouched; the function is never applied to them.
modifyArrayShape :: ArrayShape newshape =>
                    (oldshape -> newshape)
                 -> TypeBase oldshape u
                 -> TypeBase newshape u
modifyArrayShape f (Array t ds u)
  | shapeRank ds' == 0 = Prim t
  -- Reuse the shape whose rank was just inspected rather than applying
  -- 'f' a second time.
  | otherwise = Array t ds' u
  where ds' = f ds
modifyArrayShape _ (Prim t) = Prim t
modifyArrayShape _ (Mem size space) = Mem size space
-- | Replace the shape of a type outright by delegating to
-- 'modifyArrayShape' with a function that ignores the old shape.
setArrayShape :: ArrayShape newshape =>
                 TypeBase oldshape u
              -> newshape
              -> TypeBase newshape u
setArrayShape t newShape = modifyArrayShape (\_ -> newShape) t
-- | Does the type have at least one existential ('Ext') dimension?
existential :: ExtType -> Bool
existential t = not (all isFree (shapeDims (arrayShape t)))
  where isFree Free{} = True
        isFree Ext{} = False
-- | The uniqueness attribute of a type.  Only array types carry one;
-- every other type is reported as 'Nonunique'.
uniqueness :: TypeBase shape Uniqueness -> Uniqueness
uniqueness t = case t of
  Array _ _ u -> u
  _ -> Nonunique

-- | Is the type 'Unique'?
unique :: TypeBase shape Uniqueness -> Bool
unique t = uniqueness t == Unique

-- | Set the uniqueness of an array type.  Non-array types are returned
-- as they are, since they carry no uniqueness attribute.
setUniqueness :: TypeBase shape Uniqueness
              -> Uniqueness
              -> TypeBase shape Uniqueness
setUniqueness t u = case t of
  Array et dims _ -> Array et dims u
  _ -> t
-- | Apply 'staticShapes1' to every type in a list.
staticShapes :: [TypeBase Shape u] -> [TypeBase ExtShape u]
staticShapes ts = [staticShapes1 t | t <- ts]

-- | Convert a type with a known shape into an existential-shape type in
-- which every dimension is 'Free'.
staticShapes1 :: TypeBase Shape u -> TypeBase ExtShape u
staticShapes1 t = case t of
  Prim pt -> Prim pt
  Mem size space -> Mem size space
  Array pt (Shape dims) u -> Array pt (Shape (Free <$> dims)) u
-- | @arrayOf t s u@ wraps @t@ in the dimensions of @s@, which are
-- prepended to any dimensions @t@ already has.  The result gets uniqueness
-- @u@.  A primitive @t@ combined with a rank-zero @s@ stays primitive.
-- Calling this on a memory type is an error.
arrayOf :: ArrayShape shape =>
           TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf t outer u = case t of
  Array et inner _ -> Array et (outer <> inner) u
  Prim et
    | shapeRank outer == 0 -> Prim et
    | otherwise -> Array et outer u
  Mem{} -> error "arrayOf Mem"
-- | Wrap a type in one additional dimension of the given size (via
-- 'arrayOf'), yielding a type without uniqueness.
arrayOfRow :: ArrayShape (ShapeBase d) =>
              TypeBase (ShapeBase d) NoUniqueness
           -> d
           -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow t size = arrayOf t (Shape (pure size)) NoUniqueness

-- | Wrap a type in the dimensions of the given shape (via 'arrayOf'),
-- yielding a type without uniqueness.
arrayOfShape :: Type -> Shape -> Type
arrayOfShape t = flip (arrayOf t) NoUniqueness

-- | Replace the dimensions of a type with the given known sizes.
setArrayDims :: TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
setArrayDims t = setArrayShape t . Shape

-- | Replace the dimensions of a type with the given possibly-existential
-- sizes.
setArrayExtDims :: TypeBase oldshape u -> [ExtSize] -> TypeBase ExtShape u
setArrayExtDims t = setArrayShape t . Shape
-- | Set dimension zero of a type's shape (see 'setDimSize').
setOuterSize :: ArrayShape (ShapeBase d) =>
                TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
setOuterSize t = setDimSize 0 t

-- | @setDimSize i t e@ replaces dimension @i@ of @t@'s shape with @e@,
-- using 'setDim'.
setDimSize :: ArrayShape (ShapeBase d) =>
              Int -> TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
setDimSize i t e = setArrayShape t (setDim i (arrayShape t) e)

-- | Replace dimension zero of a shape.
setOuterDim :: ShapeBase d -> d -> ShapeBase d
setOuterDim shape = setDim 0 shape

-- | @setDim i shape e@ replaces the zero-based dimension @i@ of @shape@
-- with @e@.  No bounds check is done.  If @i@ is at or past the end, @e@
-- is appended, growing the shape.  If @i@ is negative, @e@ is prepended.
setDim :: Int -> ShapeBase d -> d -> ShapeBase d
setDim i (Shape ds) e = Shape (concat [take i ds, [e], drop (i + 1) ds])
-- | @peelArray n t@ removes @n@ dimensions from @t@ (via 'stripDims').
-- Peeling zero dimensions always succeeds.  Peeling exactly all
-- dimensions yields the primitive element type.  'Nothing' is returned
-- when @t@ has fewer than @n@ dimensions or is not an array (for nonzero
-- @n@).
peelArray :: ArrayShape shape =>
             Int -> TypeBase shape u -> Maybe (TypeBase shape u)
peelArray 0 t = Just t
peelArray n t = case t of
  Array et shape u -> case compare (shapeRank shape) n of
    EQ -> Just (Prim et)
    GT -> Just (Array et (stripDims n shape) u)
    LT -> Nothing
  _ -> Nothing
-- | Remove @n@ dimensions from an array type (via 'stripDims').  If that
-- would leave no dimensions, the primitive element type is returned.
-- Non-array types are returned unchanged.
stripArray :: ArrayShape shape => Int -> TypeBase shape u -> TypeBase shape u
stripArray n t = case t of
  Array et shape u
    | shapeRank shape > n -> Array et (stripDims n shape) u
    | otherwise -> Prim et
  _ -> t
-- | The dimension at index @i@ of a shape.  If the shape has no such
-- dimension, the result is the constant zero (an 'Int32').
shapeSize :: Int -> Shape -> SubExp
shapeSize i shape =
  fromMaybe (constant (0 :: Int32)) $ listToMaybe $ drop i $ shapeDims shape
-- | The dimensions of a type with a known shape.
arrayDims :: TypeBase Shape u -> [SubExp]
arrayDims t = shapeDims (arrayShape t)

-- | The (possibly existential) dimensions of a type.
arrayExtDims :: TypeBase ExtShape u -> [ExtSize]
arrayExtDims t = shapeDims (arrayShape t)

-- | Dimension @i@ of a type's shape, as given by 'shapeSize'.
arraySize :: Int -> TypeBase Shape u -> SubExp
arraySize i t = shapeSize i (arrayShape t)

-- | Dimension @i@ of the first type in the list.  For an empty list, the
-- result is the constant zero (an 'Int32').
arraysSize :: Int -> [TypeBase Shape u] -> SubExp
arraysSize i ts = case ts of
  [] -> constant (0 :: Int32)
  t : _ -> arraySize i t

-- | The type of one row of an array: one dimension is stripped via
-- 'stripArray'.
rowType :: ArrayShape shape => TypeBase shape u -> TypeBase shape u
rowType t = stripArray 1 t
-- | Is this neither an array type nor a memory type?
primType :: TypeBase shape u -> Bool
primType t = case t of
  Array{} -> False
  Mem{} -> False
  _ -> True
-- | The primitive element type of an array or primitive type.  Memory
-- types have no element type, and asking for one is an error.
elemType :: TypeBase shape u -> PrimType
elemType t = case t of
  Array pt _ _ -> pt
  Prim pt -> pt
  Mem{} -> error "elemType Mem"
-- | Swap the first two dimensions of an array type.
transposeType :: Type -> Type
transposeType t = rearrangeType [1, 0] t

-- | Permute the dimensions of a type with 'rearrangeShape'.  The
-- permutation may be shorter than the rank.  It is extended with the
-- identity on the remaining trailing dimensions, which therefore keep
-- their positions.
rearrangeType :: [Int] -> Type -> Type
rearrangeType perm t =
  setArrayShape t (Shape (rearrangeShape fullPerm (arrayDims t)))
  where fullPerm = perm ++ [length perm .. arrayRank t - 1]
-- | The diet corresponding to a declared type.  Unique arrays are
-- 'Consume'd; everything else is only 'Observe'd.
diet :: TypeBase shape Uniqueness -> Diet
diet t = case t of
  Array _ _ Unique -> Consume
  Array _ _ Nonunique -> Observe
  Prim _ -> Observe
  Mem{} -> Observe
-- | @t1 `subtypeOf` t2@ holds in three cases:
--
-- * both are arrays with equal element types, @shape1 `subShapeOf` shape2@,
--   and @u2 <= u1@;
-- * both are primitive types and equal;
-- * both are memory types in the same space (sizes are ignored).
--
-- Any other combination is not a subtype.
subtypeOf :: (Ord u, ArrayShape shape) =>
             TypeBase shape u
          -> TypeBase shape u
          -> Bool
subtypeOf t1 t2 = case (t1, t2) of
  (Array et1 shape1 u1, Array et2 shape2 u2) ->
    u2 <= u1 && et1 == et2 && shape1 `subShapeOf` shape2
  (Prim pt1, Prim pt2) -> pt1 == pt2
  (Mem _ space1, Mem _ space2) -> space1 == space2
  _ -> False
-- | Pointwise 'subtypeOf' on two lists.  Lists of different lengths are
-- never related.
subtypesOf :: (Ord u, ArrayShape shape) =>
              [TypeBase shape u]
           -> [TypeBase shape u]
           -> Bool
subtypesOf xs ys =
  length xs == length ys && all (uncurry subtypeOf) (zip xs ys)
-- | Attach a uniqueness attribute to a type.  Only array types carry
-- uniqueness, so the given value is ignored for primitive and memory types.
toDecl :: TypeBase shape NoUniqueness
       -> Uniqueness
       -> TypeBase shape Uniqueness
toDecl t u = case t of
  Array et shape _ -> Array et shape u
  Prim pt -> Prim pt
  Mem size space -> Mem size space

-- | Erase the uniqueness attribute of a declared type.
fromDecl :: TypeBase shape Uniqueness
         -> TypeBase shape NoUniqueness
fromDecl t = case t of
  Array et shape _ -> Array et shape NoUniqueness
  Prim pt -> Prim pt
  Mem size space -> Mem size space
-- | Given existential-shape types and, in parallel, one list of values per
-- type (one value per dimension), collect the value at the first
-- occurrence of each existential index.  Types and dimensions are scanned
-- left to right.  Repeated occurrences of an index and all 'Free'
-- dimensions contribute nothing.
extractShapeContext :: [TypeBase ExtShape u] -> [[a]] -> [a]
extractShapeContext ts shapes =
  evalState (concat <$> zipWithM fromType ts shapes) S.empty
  where
    fromType t vs =
      catMaybes <$> zipWithM fromDim (shapeDims (arrayShape t)) vs
    fromDim (Free _) _ = pure Nothing
    fromDim (Ext x) v = do
      alreadySeen <- gets (S.member x)
      if alreadySeen
        then pure Nothing
        else Just v <$ modify (S.insert x)
-- | The set of existential indices occurring in any dimension of the
-- given types.
shapeContext :: [TypeBase ExtShape u] -> S.Set Int
shapeContext ts =
  S.fromList [i | t <- ts, Ext i <- shapeDims (arrayShape t)]

-- | The number of distinct existential indices in the given types.
shapeContextSize :: [ExtType] -> Int
shapeContextSize ts = S.size (shapeContext ts)
-- | If a type has no existential dimensions, return the equivalent type
-- with a known shape; otherwise 'Nothing'.
hasStaticShape :: ExtType -> Maybe Type
hasStaticShape t = case t of
  Prim pt -> Just (Prim pt)
  Mem size space -> Just (Mem size space)
  Array pt (Shape dims) u -> do
    knownDims <- traverse knownSize dims
    Just (Array pt (Shape knownDims) u)
  where knownSize (Free se) = Just se
        knownSize (Ext _) = Nothing

-- | 'hasStaticShape' for a list of types.  Succeeds only if every type
-- has a static shape.
hasStaticShapes :: [ExtType] -> Maybe [Type]
hasStaticShapes ts = traverse hasStaticShape ts
-- | Compute a common generalisation of two lists of existential-shape
-- types, pairing types and their dimensions positionally.  Element types
-- and uniqueness come from the first list.  Each dimension is handled as
-- follows:
--
-- * kept as-is when both sides have the same 'Free' size;
-- * shared when both sides have the same 'Ext' index: every such matched
--   pair of index @x@ maps to one common new existential;
-- * otherwise (different 'Free' sizes, different 'Ext' indices, or only one
--   side existential), replaced by a fresh existential of its own.
--
-- New existential indices are allocated consecutively starting at zero.
generaliseExtTypes :: [TypeBase ExtShape u]
                   -> [TypeBase ExtShape u]
                   -> [TypeBase ExtShape u]
generaliseExtTypes rt1 rt2 =
  evalState (zipWithM unifyExtShapes rt1 rt2) (0, M.empty)
  where unifyExtShapes t1 t2 =
          setArrayShape t1 . Shape <$>
          zipWithM unifyExtDims
          (shapeDims $ arrayShape t1)
          (shapeDims $ arrayShape t2)
        unifyExtDims (Free se1) (Free se2)
          | se1 == se2 = return $ Free se1
          | otherwise = Ext <$> fresh
        unifyExtDims (Ext x) (Ext y)
          | x == y = Ext <$> (maybe (new x) return =<<
                              gets (M.lookup x . snd))
        -- Only one side is existential, or the indices differ.  The
        -- dimension gets its own existential, and it must NOT be recorded
        -- under the index.  Otherwise a later matched (Ext x, Ext x) pair
        -- would be equated with a dimension that only one of the inputs
        -- considers equal, and the result would no longer generalise both.
        unifyExtDims _ _ = Ext <$> fresh
        -- Allocate a new existential index without remembering it.
        fresh = do (n, m) <- get
                   put (n + 1, m)
                   return n
        -- Allocate a new existential index shared by all matched
        -- (Ext x, Ext x) pairs.
        new x = do (n, m) <- get
                   put (n + 1, M.insert x n m)
                   return n
-- | Turn every dimension that is a variable from @inaccessible@ into an
-- existential.  The existential index is the variable's position in
-- @inaccessible@ (its first occurrence).  All other dimensions are left
-- alone.
existentialiseExtTypes :: [VName] -> [ExtType] -> [ExtType]
existentialiseExtTypes inaccessible = map (modifyArrayShape (fmap toExt))
  where toExt d@(Free (Var v)) = maybe d Ext (elemIndex v inaccessible)
        toExt d = d
-- | Map each variable used as a dimension in the first list of types to
-- the corresponding dimension of the second list of types.
shapeMapping :: [TypeBase Shape u0] -> [TypeBase Shape u1] -> M.Map VName SubExp
shapeMapping ts ts' = shapeMapping' ts (map arrayDims ts')

-- | Map each variable used as a dimension in the types to the
-- positionally corresponding value.  Constant dimensions contribute
-- nothing.  If a variable occurs more than once, the leftmost binding wins,
-- since 'Map' combines by left-biased union.
shapeMapping' :: [TypeBase Shape u] -> [[a]] -> M.Map VName a
shapeMapping' ts dimss = dimMapping arrayDims id match ts dimss
  where match (Var v) dim = M.singleton v dim
        match Constant{} _ = M.empty

-- | Map each existential index in the first list of types to the
-- positionally corresponding dimension of the second list.  'Free'
-- dimensions contribute nothing.
shapeExtMapping :: [TypeBase ExtShape u] -> [TypeBase Shape u1] -> M.Map Int SubExp
shapeExtMapping ts ts' = dimMapping arrayExtDims arrayDims match ts ts'
  where match (Ext i) dim = M.singleton i dim
        match Free{} _ = mempty
-- | Apply @f@ to every pair of positionally matched dimensions of
-- positionally matched types, and combine the results with 'mconcat' in
-- left-to-right order.  Surplus types or dimensions on either side are
-- ignored, as with 'zip'.
dimMapping :: Monoid res =>
              (t1 -> [dim1]) -> (t2 -> [dim2]) -> (dim1 -> dim2 -> res)
           -> [t1] -> [t2]
           -> res
dimMapping getDims1 getDims2 f ts1 ts2 =
  mconcat [ f d1 d2
          | (t1, t2) <- zip ts1 ts2
          , (d1, d2) <- zip (getDims1 t1) (getDims2 t2) ]
-- | Shorthands for the integer primitive types.
int8, int16, int32, int64 :: PrimType
int8 = IntType Int8
int16 = IntType Int16
int32 = IntType Int32
int64 = IntType Int64

-- | Shorthands for the floating-point primitive types.
float32, float64 :: PrimType
float32 = FloatType Float32
float64 = FloatType Float64
-- | Things that have a 'Type' (a type without uniqueness information).
class Typed t where
  typeOf :: t -> Type

instance Typed Type where
  typeOf t = t

-- | Declared types are viewed through 'fromDecl', dropping uniqueness.
instance Typed DeclType where
  typeOf t = fromDecl t

instance Typed Ident where
  typeOf ident = identType ident

-- | A parameter has the type of its attribute.
instance Typed attr => Typed (Param attr) where
  typeOf param = typeOf (paramAttr param)

-- | A pattern element has the type of its attribute.
instance Typed attr => Typed (PatElemT attr) where
  typeOf pe = typeOf (patElemAttr pe)

-- | A pair has the type of its second component.
instance Typed b => Typed (a,b) where
  typeOf p = typeOf (snd p)
-- | Things that have a declared type (a type with uniqueness).
class DeclTyped t where
  declTypeOf :: t -> DeclType

instance DeclTyped DeclType where
  declTypeOf t = t

-- | A parameter has the declared type of its attribute.
instance DeclTyped attr => DeclTyped (Param attr) where
  declTypeOf param = declTypeOf (paramAttr param)

-- | Things that have a type with possibly existential dimensions.
class FixExt t => ExtTyped t where
  extTypeOf :: t -> ExtType

instance ExtTyped ExtType where
  extTypeOf t = t

-- | Things that have a declared type with possibly existential
-- dimensions.
class FixExt t => DeclExtTyped t where
  declExtTypeOf :: t -> DeclExtType

instance DeclExtTyped DeclExtType where
  declExtTypeOf t = t
-- | Typed things whose 'Type' can be replaced.
class Typed a => SetType a where
  setType :: a -> Type -> a

instance SetType Type where
  setType _ newType = newType

-- | Setting the type of a pair sets the type of its second component.
instance SetType b => SetType (a, b) where
  setType (x, y) newType = (x, setType y newType)

-- | Setting the type of a pattern element sets the type of its attribute
-- and keeps its name.
instance SetType attr => SetType (PatElemT attr) where
  setType (PatElem name attr) newType =
    PatElem name (setType attr newType)
-- | Things in which an existential index can be resolved to a concrete
-- size with @fixExt i se@.
class FixExt t where
  fixExt :: Int -> SubExp -> t -> t

-- | Resolve the existential in every dimension of the shape (via
-- 'modifyArrayShape').
instance (FixExt shape, ArrayShape shape) => FixExt (TypeBase shape u) where
  fixExt i se t = modifyArrayShape (fixExt i se) t

instance FixExt d => FixExt (ShapeBase d) where
  fixExt i se shape = fixExt i se <$> shape

instance FixExt a => FixExt [a] where
  fixExt i se xs = map (fixExt i se) xs

-- | @Ext i@ becomes @Free se@.  Existential indices above @i@ are shifted
-- down by one to close the gap.  Lower indices and 'Free' sizes are left
-- alone.
instance FixExt ExtSize where
  fixExt i se (Ext j) = case compare j i of
    GT -> Ext (j - 1)
    EQ -> Free se
    LT -> Ext j
  fixExt _ _ (Free x) = Free x

instance FixExt () where
  fixExt _ _ () = ()