{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Analysis.Type (
AccType, arrayType, sizeOf,
accType, expType, delayedAccType, delayedExpType,
preAccType, preExpType
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Trafo.Base
import Data.Array.Accelerate.Type
import qualified Foreign.Storable as F
-- | Element-type representation of a concrete @Array@ value.
--
-- The array is matched against its @Array@ constructor, which forces it to
-- weak head normal form. The payload itself is never inspected; the type is
-- produced by @eltType@ at the element type @e@.
arrayType :: forall sh e. Array sh e -> TupleType (EltRepr e)
arrayType arr =
  case arr of
    Array{} -> eltType (undefined :: e)
-- | Shape of a function that recovers the element-type representation of
-- any array-valued term built in the array representation @acc@, for every
-- environment, shape and element type.
type AccType acc
  = forall aenv shape elt.
    acc aenv (Array shape elt) -> TupleType (EltRepr elt)
-- | Element-type representation of an array term in the @OpenAcc@
-- representation.
--
-- Unwraps the @OpenAcc@ constructor and hands the inner term to
-- @preAccType@, passing itself as the function for nested array terms.
accType :: AccType OpenAcc
accType oacc =
  case oacc of
    OpenAcc pacc -> preAccType accType pacc
-- | Element-type representation of an array term in the @DelayedOpenAcc@
-- representation.
--
-- * @Manifest@ terms delegate to @preAccType@, passing this function for
--   nested array terms.
--
-- * @Delayed@ terms are typed from the middle field of the constructor
--   (presumably the index function). When that field has the form
--   @Lam (Body e)@, the type is the type of the body expression @e@.
--
-- Any other shape of that field is treated as impossible. NOTE(review):
-- presumably unreachable for well-typed terms, since a one-argument
-- function type admits no other @Lam@/@Body@ nesting; confirm.
delayedAccType :: AccType DelayedOpenAcc
delayedAccType (Manifest acc) = preAccType delayedAccType acc
delayedAccType (Delayed _ f _)
  | Lam (Body e) <- f = delayedExpType e
  -- Descriptive message so the failure can be located if this ever fires.
  | otherwise
  = error "delayedAccType: impossible: Delayed index function is not a unary lambda"
-- | Compute the element-type representation of an array-valued term in an
-- arbitrary array representation @acc@.
--
-- The first argument @k@ is used to type nested array terms of the same
-- representation. @accType@ and @delayedAccType@ each pass themselves here,
-- so this one traversal serves both representations.
--
-- Constructors that share their element type with one of their array
-- arguments recurse into that argument with @k@. All other constructors
-- read the type directly from the result element type @e@ via @eltType@.
preAccType :: forall acc aenv sh e.
              AccType acc
           -> PreOpenAcc acc aenv (Array sh e)
           -> TupleType (EltRepr e)
preAccType k pacc =
  case pacc of
    -- Let binding: typed from its second field (presumably the body).
    Alet _ acc          -> k acc

    -- For the next five constructors, the @arrays@ representation of
    -- @Array sh e@ is matched against @ArraysRarray@ before @eltType@ is
    -- called. NOTE(review): presumably these constructors only carry an
    -- @Arrays@ constraint on their result, and the @ArraysRarray@ match is
    -- what makes the @Elt e@ dictionary available; confirm against @ArraysR@.
    -- The wildcard branches are compiled only for GHC older than 8.0,
    -- presumably to cover pattern checkers that could not rule out the
    -- other @ArraysR@ constructors.
    Avar _              -> case arrays (undefined :: (Array sh e)) of
                             ArraysRarray -> eltType (undefined::e)
#if __GLASGOW_HASKELL__ < 800
                             _            -> error "When I get sad, I stop being sad and be AWESOME instead."
#endif
    Apply _ _           -> case arrays (undefined :: Array sh e) of
                             ArraysRarray -> eltType (undefined::e)
#if __GLASGOW_HASKELL__ < 800
                             _            -> error "TRUE STORY."
#endif
    Atuple _            -> case arrays (undefined :: Array sh e) of
                             ArraysRarray -> eltType (undefined::e)
#if __GLASGOW_HASKELL__ < 800
                             _            -> error "I made you a cookie, but I eated it."
#endif
    Aprj _ _            -> case arrays (undefined :: Array sh e) of
                             ArraysRarray -> eltType (undefined::e)
#if __GLASGOW_HASKELL__ < 800
                             _            -> error "Hey look! even the leaves are falling for you."
#endif
    Aforeign _ _ _      -> case arrays (undefined :: Array sh e) of
                             ArraysRarray -> eltType (undefined::e)
#if __GLASGOW_HASKELL__ < 800
                             _            -> error "Who on earth wrote all these weird error messages?"
#endif

    -- Control flow: typed from one of the array-valued fields.
    Acond _ acc _       -> k acc
    Awhile _ _ acc      -> k acc

    -- An embedded array value is typed directly from that value.
    Use a               -> arrayType a

    -- Constructors that produce fresh elements: type taken from @e@.
    Unit _              -> eltType (undefined::e)
    Generate _ _        -> eltType (undefined::e)
    Transform _ _ _ _   -> eltType (undefined::e)

    -- Shape-changing operations keep the element type of their input.
    Reshape _ acc       -> k acc
    Replicate _ _ acc   -> k acc
    Slice _ acc _       -> k acc

    -- Element-wise maps produce a new element type: type taken from @e@.
    Map _ _             -> eltType (undefined::e)
    ZipWith _ _ _       -> eltType (undefined::e)

    -- Reductions and scans: result elements share the input element type.
    Fold _ _ acc        -> k acc
    FoldSeg _ _ acc _   -> k acc
    Fold1 _ acc         -> k acc
    Fold1Seg _ acc _    -> k acc
    Scanl _ _ acc       -> k acc
    Scanl1 _ acc        -> k acc
    Scanr _ _ acc       -> k acc
    Scanr1 _ acc        -> k acc

    -- Permutations keep the element type of the array field recursed into.
    Permute _ _ _ acc   -> k acc
    Backpermute _ _ acc -> k acc

    -- Stencils compute new elements: type taken from @e@.
    Stencil _ _ _       -> eltType (undefined::e)
    Stencil2 _ _ _ _ _  -> eltType (undefined::e)
-- | Element-type representation of a scalar expression whose embedded array
-- terms are in the @OpenAcc@ representation.
--
-- Delegates to @preExpType@, with @accType@ typing any array subterm.
expType :: OpenExp env aenv t -> TupleType (EltRepr t)
expType expr = preExpType accType expr
-- | Element-type representation of a scalar expression whose embedded array
-- terms are in the @DelayedOpenAcc@ representation.
--
-- Delegates to @preExpType@, with @delayedAccType@ typing any array subterm.
delayedExpType :: DelayedOpenExp env aenv t -> TupleType (EltRepr t)
delayedExpType expr = preExpType delayedAccType expr
-- | Compute the element-type representation of a scalar expression whose
-- embedded array terms use the representation @acc@.
--
-- The first argument @k@ types embedded array terms. @expType@ passes
-- @accType@ and @delayedExpType@ passes @delayedAccType@.
--
-- Almost every constructor reads the type directly from the result type @t@
-- via @eltType@. Each constructor is still listed individually rather than
-- caught by one wildcard. NOTE(review): presumably this is because the
-- constructor match is what supplies the @Elt t@ dictionary that @eltType@
-- needs; confirm.
--
-- NOTE(review): the type variables named @aenv@ and @env@ here sit in the
-- opposite positions from the @env aenv@ order in the @expType@ signature of
-- this module. Only the names differ; typing is unaffected.
preExpType :: forall acc aenv env t.
              AccType acc
           -> PreOpenExp acc aenv env t
           -> TupleType (EltRepr t)
preExpType k e =
  case e of
    Let _ _             -> eltType (undefined::t)
    Var _               -> eltType (undefined::t)
    Const _             -> eltType (undefined::t)
    Undef               -> eltType (undefined::t)
    Tuple _             -> eltType (undefined::t)
    Prj _ _             -> eltType (undefined::t)
    IndexNil            -> eltType (undefined::t)
    IndexCons _ _       -> eltType (undefined::t)
    IndexHead _         -> eltType (undefined::t)
    IndexTail _         -> eltType (undefined::t)
    IndexAny            -> eltType (undefined::t)
    IndexSlice _ _ _    -> eltType (undefined::t)
    IndexFull _ _ _     -> eltType (undefined::t)
    ToIndex _ _         -> eltType (undefined::t)
    FromIndex _ _       -> eltType (undefined::t)
    -- Conditional: recurse into the second field (presumably the
    -- then-branch). The term variable t is unrelated to the type variable t.
    Cond _ t _          -> preExpType k t
    While _ _ _         -> eltType (undefined::t)
    PrimConst _         -> eltType (undefined::t)
    PrimApp _ _         -> eltType (undefined::t)
    -- Array indexing: the result has the element type of the indexed
    -- array, so the array term is typed with @k@.
    Index acc _         -> k acc
    LinearIndex acc _   -> k acc
    Shape _             -> eltType (undefined::t)
    ShapeSize _         -> eltType (undefined::t)
    Intersect _ _       -> eltType (undefined::t)
    Union _ _           -> eltType (undefined::t)
    Foreign _ _ _       -> eltType (undefined::t)
    Coerce _            -> eltType (undefined::t)
-- | Total size in bytes of a tuple type representation.
--
-- The unit type contributes nothing, pairs contribute the sum of both
-- components, and scalars contribute their own size. No alignment padding
-- is added between components.
sizeOf :: TupleType a -> Int
sizeOf ty =
  case ty of
    TypeRscalar s -> sizeOfScalarType s
    TypeRpair l r -> sizeOf l + sizeOf r
    TypeRunit     -> 0
-- | Size in bytes of a scalar type, dispatching on whether it is a single
-- value or a fixed-width vector of values.
sizeOfScalarType :: ScalarType t -> Int
sizeOfScalarType st =
  case st of
    VectorScalarType vt -> sizeOfVectorType vt
    SingleScalarType s  -> sizeOfSingleType s
-- | Size in bytes of a single (non-vector) scalar type, dispatching on
-- whether it is numeric or non-numeric.
sizeOfSingleType :: SingleType t -> Int
sizeOfSingleType st =
  case st of
    NonNumSingleType nn -> sizeOfNonNumType nn
    NumSingleType n     -> sizeOfNumType n
-- | Size in bytes of a fixed-width vector type: the number of lanes
-- multiplied by the size of one lane. No padding is added.
sizeOfVectorType :: VectorType t -> Int
sizeOfVectorType vt =
  case vt of
    Vector2Type  s -> lanes 2  s
    Vector3Type  s -> lanes 3  s
    Vector4Type  s -> lanes 4  s
    Vector8Type  s -> lanes 8  s
    Vector16Type s -> lanes 16 s
  where
    -- Byte size of @n@ consecutive values of the given single type.
    lanes :: Int -> SingleType a -> Int
    lanes n s = n * sizeOfSingleType s
-- | Size in bytes of a numeric type, taken from its @Storable@ instance.
--
-- The @integralDict@ or @floatingDict@ result is matched first. That match
-- is what makes @F.sizeOf@ usable at the scoped type @t@ (NOTE(review):
-- presumably the dictionary carries the @Storable@ constraint).
sizeOfNumType :: forall t. NumType t -> Int
sizeOfNumType nt =
  case nt of
    IntegralNumType it ->
      case integralDict it of
        IntegralDict -> F.sizeOf (undefined :: t)
    FloatingNumType ft ->
      case floatingDict ft of
        FloatingDict -> F.sizeOf (undefined :: t)
-- | Size in bytes of a non-numeric scalar type.
--
-- @Bool@ is fixed at one byte rather than taken from its @Storable@
-- instance. NOTE(review): presumably because the base @Storable Bool@
-- instance is int-sized while booleans are stored one byte each; confirm.
-- Every other non-numeric type uses @F.sizeOf@, enabled by matching the
-- dictionary returned by @nonNumDict@.
sizeOfNonNumType :: forall t. NonNumType t -> Int
sizeOfNonNumType nn =
  case nn of
    TypeBool{} -> 1
    _          ->
      case nonNumDict nn of
        NonNumDict -> F.sizeOf (undefined :: t)