{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
module Futhark.Representation.AST.Syntax.Core
(
module Language.Futhark.Core
, module Futhark.Representation.Primitive
, Uniqueness(..)
, NoUniqueness(..)
, ShapeBase(..)
, Shape
, Ext(..)
, ExtSize
, ExtShape
, Rank(..)
, ArrayShape(..)
, Space (..)
, SpaceId
, TypeBase(..)
, Type
, ExtType
, DeclType
, DeclExtType
, Diet(..)
, ErrorMsg (..)
, ErrorMsgPart (..)
, PrimValue(..)
, Ident (..)
, Certificates(..)
, SubExp(..)
, ParamT (..)
, Param
, DimIndex (..)
, Slice
, dimFix
, sliceIndices
, sliceDims
, unitSlice
, fixSlice
, PatElemT (..)
, Names
) where
import Control.Monad.State
import Data.Maybe
import Data.Monoid ((<>))
import Data.String
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.Traversable
import Language.Futhark.Core
import Futhark.Representation.Primitive
-- | The shape of an array, represented as a list of dimension sizes.
-- The parameter @d@ is the representation used for a single dimension.
newtype ShapeBase d = Shape
  { shapeDims :: [d]
  }
  deriving (Eq, Ord, Show)

-- | A shape in which every dimension is a 'SubExp'.
type Shape = ShapeBase SubExp
-- | Either a placeholder identified by an 'Int' index ('Ext'), or a
-- concrete value ('Free').
data Ext a
  = Ext Int
  | Free a
  deriving (Eq, Ord, Show)

-- | A dimension size that may be a placeholder ('Ext') or a 'SubExp'.
type ExtSize = Ext SubExp

-- | A shape whose dimensions are 'ExtSize's.
type ExtShape = ShapeBase ExtSize
-- | A shape that records only how many dimensions there are,
-- not their sizes.
newtype Rank
  = Rank Int
  deriving (Show, Eq, Ord)
-- | Operations shared by the different shape representations
-- ('ShapeBase' of known or placeholder sizes, and 'Rank').
class (Monoid a, Eq a, Ord a) => ArrayShape a where
  -- | The number of dimensions in the shape.
  shapeRank :: a -> Int

  -- | Remove the given number of leading dimensions.
  stripDims :: Int -> a -> a

  -- | Whether the first shape is a sub-shape of the second; the exact
  -- rule is defined per instance.
  subShapeOf :: a -> a -> Bool
-- | Combining two shapes concatenates their dimension lists.
instance Semigroup (ShapeBase d) where
  Shape xs <> Shape ys = Shape (xs ++ ys)

-- | The identity shape has no dimensions.
instance Monoid (ShapeBase d) where
  mempty = Shape []

-- | Apply a function to every dimension of the shape.
instance Functor ShapeBase where
  fmap f (Shape ds) = Shape (map f ds)
-- | Shapes whose dimensions are all 'SubExp's.  With no placeholders,
-- the sub-shape relation is plain equality.
instance ArrayShape (ShapeBase SubExp) where
  shapeRank = length . shapeDims
  stripDims n = Shape . drop n . shapeDims
  subShapeOf x y = x == y
-- | Shapes whose dimensions may be placeholders ('Ext').
--
-- @subShapeOf ds1 ds2@ requires equal rank and compares dimensions
-- pairwise, left to right:
--
-- * two 'Free' sizes must be equal;
-- * an 'Ext' in @ds1@ never matches a 'Free' in @ds2@;
-- * a 'Free' in @ds1@ always matches an 'Ext' in @ds2@;
-- * two 'Ext's match when every occurrence of the @ds2@ index is
--   paired with the same @ds1@ index (the first pairing seen is kept).
--
-- NOTE(review): a 'Free'/'Ext' match records nothing in the pairing map,
-- so e.g. @[Free a, Free b]@ is accepted against @[Ext 0, Ext 0]@ even
-- when @a /= b@ — confirm this is intended.
instance ArrayShape (ShapeBase ExtSize) where
  shapeRank = length . shapeDims
  stripDims n = Shape . drop n . shapeDims
  subShapeOf (Shape ds1) (Shape ds2) =
    length ds1 == length ds2 && go M.empty ds1 ds2
    where
      -- 'seen' maps each Ext index of ds2 to the ds1 Ext index it was
      -- first paired with.
      go seen (d1 : rest1) (d2 : rest2) =
        case (d1, d2) of
          (Free se1, Free se2) -> se1 == se2 && go seen rest1 rest2
          (Ext _, Free _) -> False
          (Free _, Ext _) -> go seen rest1 rest2
          (Ext x, Ext y) ->
            case M.lookup y seen of
              Just x' -> x' == x && go seen rest1 rest2
              Nothing -> go (M.insert y x seen) rest1 rest2
      -- Both lists are exhausted together (their lengths were checked).
      go _ _ _ = True
-- | Combining two ranks adds their dimension counts.
instance Semigroup Rank where
  Rank x <> Rank y = Rank $ x + y

-- | The identity rank has zero dimensions.
instance Monoid Rank where
  mempty = Rank 0

-- | A shape that knows only its rank.
instance ArrayShape Rank where
  shapeRank (Rank x) = x

  -- Behave like 'drop' on the 'ShapeBase' instances: stripping more
  -- dimensions than exist yields rank 0, and a negative count strips
  -- nothing.  Plain subtraction could produce a meaningless negative rank.
  stripDims n (Rank x) = Rank $ max 0 (x - max 0 n)

  subShapeOf = (==)
-- | A memory space: either the default space or one named by a 'SpaceId'.
data Space
  = DefaultSpace
  | Space SpaceId
  deriving (Show, Eq, Ord)

-- | The name identifying a non-default memory space.
type SpaceId = String
-- | A placeholder used in place of a uniqueness annotation when none is
-- tracked (see 'Type' and 'ExtType').
data NoUniqueness
  = NoUniqueness
  deriving (Eq, Ord, Show)
-- | A type, parameterised over the shape representation and the
-- uniqueness annotation.  It is one of:
--
-- * 'Prim': a primitive scalar;
-- * 'Array': an array with primitive element type, a shape and a
--   uniqueness annotation;
-- * 'Mem': memory described by a 'SubExp' (presumably its size — confirm)
--   and the 'Space' it lives in.
data TypeBase shape u
  = Prim PrimType
  | Array PrimType shape u
  | Mem SubExp Space
  deriving (Show, Eq, Ord)

-- | Known shape, no uniqueness annotation.
type Type = TypeBase Shape NoUniqueness

-- | Possibly-placeholder shape, no uniqueness annotation.
type ExtType = TypeBase ExtShape NoUniqueness

-- | Known shape, with uniqueness annotation.
type DeclType = TypeBase Shape Uniqueness

-- | Possibly-placeholder shape, with uniqueness annotation.
type DeclExtType = TypeBase ExtShape Uniqueness
-- | How a value is used: 'Consume'd or merely 'Observe'd.
data Diet
  = Consume
  | Observe
  deriving (Eq, Ord, Show)
-- | A variable name paired with its 'Type'.
data Ident = Ident
  { identName :: VName
  , identType :: Type
  }
  deriving (Show)

-- | Two identifiers are equal when their names are equal; the types are
-- not compared.
instance Eq Ident where
  Ident a _ == Ident b _ = a == b

-- | Identifiers are ordered by name alone, consistent with 'Eq'.
instance Ord Ident where
  compare (Ident a _) (Ident b _) = compare a b
-- | A list of certificate names.
newtype Certificates = Certificates
  { unCertificates :: [VName]
  }
  deriving (Eq, Ord, Show)

-- | Combining certificates concatenates their name lists.
instance Semigroup Certificates where
  Certificates xs <> Certificates ys = Certificates (xs ++ ys)

-- | The identity is the empty list of certificates.
instance Monoid Certificates where
  mempty = Certificates []
-- | A subexpression: either a primitive constant or a variable reference.
data SubExp
  = Constant PrimValue
  | Var VName
  deriving (Show, Eq, Ord)
-- | A parameter: a name together with an attribute of type @attr@.
data ParamT attr = Param
  { paramName :: VName
  , paramAttr :: attr
  }
  deriving (Ord, Show, Eq)

-- | Shorthand for 'ParamT'.
type Param = ParamT

-- | The only element visited is the attribute.
instance Foldable ParamT where
  foldMap f (Param _ attr) = f attr

-- | Mapping changes the attribute and keeps the name.
instance Functor ParamT where
  fmap f (Param name attr) = Param name (f attr)

-- | Traversal runs the effect on the attribute and keeps the name.
instance Traversable ParamT where
  traverse f (Param name attr) = fmap (Param name) (f attr)
-- | How one dimension is indexed: fixed to a single index ('DimFix'), or
-- sliced as @DimSlice offset count stride@ (see 'unitSlice' and
-- 'fixSlice' for how the three components are used).
data DimIndex d
  = DimFix d
  | DimSlice d d d
  deriving (Eq, Ord, Show)

instance Functor DimIndex where
  fmap f idx = case idx of
    DimFix i -> DimFix (f i)
    DimSlice i j s -> DimSlice (f i) (f j) (f s)

instance Foldable DimIndex where
  foldMap f idx = case idx of
    DimFix d -> f d
    DimSlice i j s -> f i <> f j <> f s

instance Traversable DimIndex where
  traverse f idx = case idx of
    DimFix d -> fmap DimFix (f d)
    DimSlice i j s -> DimSlice <$> f i <*> f j <*> f s

-- | A slice: one 'DimIndex' per indexed dimension.
type Slice d = [DimIndex d]

-- | The index of a fixed dimension, or 'Nothing' for a slice.
dimFix :: DimIndex d -> Maybe d
dimFix idx = case idx of
  DimFix d -> Just d
  DimSlice {} -> Nothing

-- | All indices of a slice, provided every dimension is fixed.
sliceIndices :: Slice d -> Maybe [d]
sliceIndices = traverse dimFix

-- | The counts (second component) of the sliced dimensions, in order;
-- fixed dimensions are skipped.
sliceDims :: Slice d -> [d]
sliceDims slice = [count | DimSlice _ count _ <- slice]

-- | A slice of @len@ elements starting at @start@ with stride 1.
unitSlice :: Num d => d -> d -> DimIndex d
unitSlice start len = DimSlice start len 1

-- | Translate indices into the sliced dimensions back into indices of the
-- original array.  Fixed dimensions emit their index without consuming an
-- input; each sliced dimension consumes one input @i@ and emits
-- @start + i * stride@.  The result stops at the first sliced dimension
-- for which no input index remains.
fixSlice :: Num d => Slice d -> [d] -> [d]
fixSlice = go
  where
    go (DimFix j : rest) is = j : go rest is
    go (DimSlice start _ stride : rest) (i : is) =
      start + i * stride : go rest is
    go _ _ = []
-- | An element of a pattern: a bound name together with an attribute of
-- type @attr@.
data PatElemT attr = PatElem
  { patElemName :: VName
  , patElemAttr :: attr
  }
  deriving (Ord, Show, Eq)

-- | Mapping changes the attribute and keeps the name.
instance Functor PatElemT where
  fmap f pe = pe {patElemAttr = f (patElemAttr pe)}
-- | A set of variable names.
type Names = S.Set VName
-- | An error message, built as a sequence of parts.
newtype ErrorMsg a = ErrorMsg [ErrorMsgPart a]
  deriving (Eq, Ord, Show)

-- | One piece of an error message: literal text ('ErrorString'), or an
-- 'ErrorInt32' holding a value of type @a@ (presumably rendered as a
-- 32-bit integer — confirm with the consumer).
data ErrorMsgPart a
  = ErrorString String
  | ErrorInt32 a
  deriving (Eq, Ord, Show)

-- | A string literal becomes a message with a single text part.
instance IsString (ErrorMsg a) where
  fromString s = ErrorMsg [fromString s]

-- | A string literal becomes a text part.
instance IsString (ErrorMsgPart a) where
  fromString = ErrorString

instance Functor ErrorMsgPart where
  fmap f part = case part of
    ErrorString s -> ErrorString s
    ErrorInt32 x -> ErrorInt32 (f x)

-- | Text parts contain no elements; an 'ErrorInt32' part contains one.
instance Foldable ErrorMsgPart where
  foldMap f part = case part of
    ErrorString {} -> mempty
    ErrorInt32 x -> f x

instance Traversable ErrorMsgPart where
  traverse f part = case part of
    ErrorString s -> pure (ErrorString s)
    ErrorInt32 x -> fmap ErrorInt32 (f x)

-- | The instances on whole messages lift the part instances over the
-- list of parts.
instance Functor ErrorMsg where
  fmap f (ErrorMsg parts) = ErrorMsg (fmap (fmap f) parts)

instance Foldable ErrorMsg where
  foldMap f (ErrorMsg parts) = foldMap (foldMap f) parts

instance Traversable ErrorMsg where
  traverse f (ErrorMsg parts) = fmap ErrorMsg (traverse (traverse f) parts)