{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Types produced by parsing function signatures: argument/return types,
-- tensor types tagged with their library, and whole function declarations.
module CodeGen.Types.Parsed where

import Control.Monad
import Data.Data
import Data.Typeable
import qualified Data.HashSet as HS

import CodeGen.Prelude
import CodeGen.Types.CLI

-- ----------------------------------------
-- Parsed types
-- ----------------------------------------

-- | The type of a parsed argument or return value: a pointer to another
-- parsable type, a library-tagged tensor/storage type, or a primitive C type.
data Parsable
  = Ptr Parsable
  | TenType TenType
  -- (disabled) NNType NNType
  | CType CType
  deriving (Eq, Show, Generic, Hashable)

-- | Primitive C types.
--
-- Constructor order matters: the derived 'Bounded' and 'Enum' instances
-- follow it. 'CInt' must come /after/ all the other int types. NOTE(review):
-- presumably something enumerates these constructors in order and must try
-- the sized integer types before the bare @int@; confirm against the parser.
data CType
  = CBool
  | CVoid
  | CPtrdiff
  | CFloat
  | CDouble
  | CLong
  | CUInt64
  | CUInt32
  | CUInt16
  | CUInt8
  | CInt64
  | CInt32
  | CInt16
  | CInt8
  | CInt -- must come _after_ all the other int types
  | CSize
  | CChar
  | CShort
  deriving (Eq, Show, Generic, Hashable, Bounded, Enum)

-- | A raw tensor/storage type paired with the library ('LibType') it
-- belongs to.
newtype TenType = Pair { unTenType :: (RawTenType, LibType) }
  deriving (Eq, Show, Generic, Hashable)

-- | Library-independent names of tensor, storage and related types.
--
-- Constructor order matters: the derived 'Bounded' and 'Enum' instances
-- (and therefore 'allTenTypes') follow it.
data RawTenType
  = Tensor
  | ByteTensor
  | CharTensor
  | ShortTensor
  | IntTensor
  | LongTensor
  | FloatTensor
  | DoubleTensor
  | HalfTensor
  | Storage
  | ByteStorage
  | CharStorage
  | ShortStorage
  | IntStorage
  | LongStorage
  | FloatStorage
  | DoubleStorage
  | HalfStorage
  | DescBuff
  | Generator
  | Allocator
  | File
  | Half
  | State

  -- NN types
  | IndexTensor
  | IntegerTensor

  -- Real and AccReal are parameterized types that could in principle differ
  -- by library, but I don't think that happens in any library to date.
  | Real
  | AccReal

  -- FIXME: I don't think we need to enable THCThreadLocal yet. If we do, we
  -- will also need a wrapper of ThreadId from the pthread package:
  -- https://hackage.haskell.org/package/pthread-0.2.0
  -- (disabled) ThreadLocal -- THC-specific

  -- FIXME: while we can add this to codegen now, we need access to
  -- cudaStream_t in cuda_runtime_api.
  | Stream -- THC-specific
  deriving (Eq, Show, Generic, Hashable, Bounded, Enum)

-- | Whether a tensor type is one of the concrete, element-typed tensors of
-- the 'THC' or 'THCUNN' libraries.
--
-- Only the eight @ByteTensor@ … @HalfTensor@ types qualify. The generic
-- 'Tensor', all storages, and 'IndexTensor' (currently commented out of the
-- set) do not.
isConcreteCudaPrefixed :: TenType -> Bool
isConcreteCudaPrefixed (Pair (raw, lib)) = fromCudaLib && isConcrete
  where
    fromCudaLib = lib == THC || lib == THCUNN
    isConcrete  = raw `HS.member` concreteCudaTensors
    concreteCudaTensors = HS.fromList
      [ ByteTensor
      , CharTensor
      , ShortTensor
      , IntTensor
      , LongTensor
      , FloatTensor
      , DoubleTensor
      , HalfTensor
      -- , IndexTensor
      ]

-- | Every ('RawTenType', 'LibType') combination.
--
-- The raw type varies slowest and the library fastest, with both following
-- their 'Enum' order.
allTenTypes :: [TenType]
allTenTypes =
  [ Pair (raw, lib)
  | raw <- [minBound .. maxBound]
  , lib <- [minBound .. maxBound]
  ]

{- Disabled: the NN-specific types currently live in 'RawTenType' instead.
data NNType
  = IndexTensor
  | IntegerTensor
  deriving (Eq, Show, Generic, Hashable, Bounded, Enum)
-}

-- | One named argument of a parsed function.
data Arg = Arg
  { argType :: Parsable
  , argName :: Text
  } deriving (Eq, Show, Generic, Hashable)

-- | A parsed function signature.
data Function = Function
  { funPrefix :: Maybe (LibType, Text)
    -- ^ optional library tag and textual prefix of the name
    -- (NOTE(review): confirm exactly what the 'Text' holds)
  , funName   :: Text
  , funArgs   :: [Arg]
  , funReturn :: Parsable
  } deriving (Eq, Show, Generic, Hashable)

-- | The parser type: 'Parsec' over 'String' input with no custom error
-- component ('Void').
type Parser = Parsec Void String