{-# LANGUAGE TemplateHaskell #-}

{-# OPTIONS_GHC -fwarn-incomplete-patterns #-}

-- | Primitive Feldspar expressions

module Feldspar.Primitive.Representation where



import Data.Array
import Data.Bits (Bits (..))
import Data.Complex
import Data.Int
import Data.Typeable
import Data.Word

import Data.Constraint (Dict (..))

import Language.Embedded.Expression
import Language.Embedded.Imperative.CMD (IArr (..))

import Language.Syntactic
import Language.Syntactic.TH
import Language.Syntactic.Functional



--------------------------------------------------------------------------------
-- * Types
--------------------------------------------------------------------------------

-- | Length type (an alias for 'Word32')
type Length = Word32

-- | Index type (an alias for 'Word32'); this is the index type of 'ArrIx'
type Index = Word32

-- | Representation of primitive supported types
--
-- Each constructor is a singleton witness for exactly one supported type.
data PrimTypeRep a where
  -- Booleans
  BoolT          :: PrimTypeRep Bool
  -- Signed integers
  Int8T          :: PrimTypeRep Int8
  Int16T         :: PrimTypeRep Int16
  Int32T         :: PrimTypeRep Int32
  Int64T         :: PrimTypeRep Int64
  -- Unsigned integers
  Word8T         :: PrimTypeRep Word8
  Word16T        :: PrimTypeRep Word16
  Word32T        :: PrimTypeRep Word32
  Word64T        :: PrimTypeRep Word64
  -- Floating-point numbers
  FloatT         :: PrimTypeRep Float
  DoubleT        :: PrimTypeRep Double
  -- Complex numbers
  ComplexFloatT  :: PrimTypeRep (Complex Float)
  ComplexDoubleT :: PrimTypeRep (Complex Double)

-- | Representation of the signed integer types
data IntTypeRep a where
  Int8Type  :: IntTypeRep Int8
  Int16Type :: IntTypeRep Int16
  Int32Type :: IntTypeRep Int32
  Int64Type :: IntTypeRep Int64

-- | Representation of the unsigned integer types
data WordTypeRep a where
  Word8Type  :: WordTypeRep Word8
  Word16Type :: WordTypeRep Word16
  Word32Type :: WordTypeRep Word32
  Word64Type :: WordTypeRep Word64

-- | Representation of an integer type, either signed ('IntType') or unsigned
-- ('WordType')
data IntWordTypeRep a where
  IntType  :: IntTypeRep a -> IntWordTypeRep a
  WordType :: WordTypeRep a -> IntWordTypeRep a

-- | Representation of the floating-point types
data FloatingTypeRep a where
  FloatType  :: FloatingTypeRep Float
  DoubleType :: FloatingTypeRep Double

-- | Representation of the complex number types
data ComplexTypeRep a where
  ComplexFloatType  :: ComplexTypeRep (Complex Float)
  ComplexDoubleType :: ComplexTypeRep (Complex Double)

-- | An alternative view of 'PrimTypeRep' that groups similar types, so that
-- e.g. every integer type can be matched with a single 'PrimTypeIntWord'
-- pattern
data PrimTypeView a where
  PrimTypeBool     :: PrimTypeView Bool
  PrimTypeIntWord  :: IntWordTypeRep a -> PrimTypeView a
  PrimTypeFloating :: FloatingTypeRep a -> PrimTypeView a
  PrimTypeComplex  :: ComplexTypeRep a -> PrimTypeView a

-- Show instances for every type representation. These are GADTs, so the
-- instances are produced with standalone deriving rather than deriving clauses.
deriving instance Show (PrimTypeRep a)
deriving instance Show (PrimTypeView a)
deriving instance Show (IntWordTypeRep a)
deriving instance Show (IntTypeRep a)
deriving instance Show (WordTypeRep a)
deriving instance Show (FloatingTypeRep a)
deriving instance Show (ComplexTypeRep a)

-- | Convert a 'PrimTypeRep' to the grouped 'PrimTypeView' representation,
-- which lets callers match a whole category of types (integer, floating-point,
-- complex) with one pattern.
--
-- 'unviewPrimTypeRep' performs the opposite conversion.
viewPrimTypeRep :: PrimTypeRep a -> PrimTypeView a
viewPrimTypeRep BoolT          = PrimTypeBool
viewPrimTypeRep Int8T          = PrimTypeIntWord $ IntType Int8Type
viewPrimTypeRep Int16T         = PrimTypeIntWord $ IntType Int16Type
viewPrimTypeRep Int32T         = PrimTypeIntWord $ IntType Int32Type
viewPrimTypeRep Int64T         = PrimTypeIntWord $ IntType Int64Type
viewPrimTypeRep Word8T         = PrimTypeIntWord $ WordType Word8Type
viewPrimTypeRep Word16T        = PrimTypeIntWord $ WordType Word16Type
viewPrimTypeRep Word32T        = PrimTypeIntWord $ WordType Word32Type
viewPrimTypeRep Word64T        = PrimTypeIntWord $ WordType Word64Type
viewPrimTypeRep FloatT         = PrimTypeFloating FloatType
viewPrimTypeRep DoubleT        = PrimTypeFloating DoubleType
viewPrimTypeRep ComplexFloatT  = PrimTypeComplex ComplexFloatType
viewPrimTypeRep ComplexDoubleT = PrimTypeComplex ComplexDoubleType

-- | Convert a grouped 'PrimTypeView' back to the flat 'PrimTypeRep'
-- representation ('viewPrimTypeRep' performs the opposite conversion)
unviewPrimTypeRep :: PrimTypeView a -> PrimTypeRep a
unviewPrimTypeRep PrimTypeBool                            = BoolT
unviewPrimTypeRep (PrimTypeIntWord (IntType Int8Type))    = Int8T
unviewPrimTypeRep (PrimTypeIntWord (IntType Int16Type))   = Int16T
unviewPrimTypeRep (PrimTypeIntWord (IntType Int32Type))   = Int32T
unviewPrimTypeRep (PrimTypeIntWord (IntType Int64Type))   = Int64T
unviewPrimTypeRep (PrimTypeIntWord (WordType Word8Type))  = Word8T
unviewPrimTypeRep (PrimTypeIntWord (WordType Word16Type)) = Word16T
unviewPrimTypeRep (PrimTypeIntWord (WordType Word32Type)) = Word32T
unviewPrimTypeRep (PrimTypeIntWord (WordType Word64Type)) = Word64T
unviewPrimTypeRep (PrimTypeFloating FloatType)            = FloatT
unviewPrimTypeRep (PrimTypeFloating DoubleType)           = DoubleT
unviewPrimTypeRep (PrimTypeComplex ComplexFloatType)      = ComplexFloatT
unviewPrimTypeRep (PrimTypeComplex ComplexDoubleType)     = ComplexDoubleT

-- | Bit width of an integer type (signed or unsigned)
--
-- Returns 'Nothing' for every non-integer type ('Bool', floating-point and
-- complex types).
primTypeIntWidth :: PrimTypeRep a -> Maybe Int
primTypeIntWidth Int8T   = Just 8
primTypeIntWidth Int16T  = Just 16
primTypeIntWidth Int32T  = Just 32
primTypeIntWidth Int64T  = Just 64
primTypeIntWidth Word8T  = Just 8
primTypeIntWidth Word16T = Just 16
primTypeIntWidth Word32T = Just 32
primTypeIntWidth Word64T = Just 64
primTypeIntWidth _       = Nothing

-- | Class of the types supported as primitive values
--
-- Every instance can reify its own type representation via 'primTypeRep'.
class (Eq a, Show a, Typeable a) => PrimType' a where
  -- | Reify the primitive type @a@ as a 'PrimTypeRep'
  primTypeRep :: PrimTypeRep a

-- One instance per constructor of 'PrimTypeRep', each returning that
-- constructor.
instance PrimType' Bool             where primTypeRep = BoolT
instance PrimType' Int8             where primTypeRep = Int8T
instance PrimType' Int16            where primTypeRep = Int16T
instance PrimType' Int32            where primTypeRep = Int32T
instance PrimType' Int64            where primTypeRep = Int64T
instance PrimType' Word8            where primTypeRep = Word8T
instance PrimType' Word16           where primTypeRep = Word16T
instance PrimType' Word32           where primTypeRep = Word32T
instance PrimType' Word64           where primTypeRep = Word64T
instance PrimType' Float            where primTypeRep = FloatT
instance PrimType' Double           where primTypeRep = DoubleT
instance PrimType' (Complex Float)  where primTypeRep = ComplexFloatT
instance PrimType' (Complex Double) where primTypeRep = ComplexDoubleT

-- | Convenience function; like 'primTypeRep' but with an extra argument to
-- constrain the type parameter. The extra argument is ignored and never
-- forced, so 'undefined' may be passed.
primTypeOf :: PrimType' a => a -> PrimTypeRep a
primTypeOf _ = primTypeRep

-- | Check whether two type representations are equal
--
-- On success, the returned 'Dict' carries evidence that @a ~ b@. Matching on
-- it brings that equality into scope. Different representations give
-- 'Nothing'.
primTypeEq :: PrimTypeRep a -> PrimTypeRep b -> Maybe (Dict (a ~ b))
primTypeEq BoolT          BoolT          = Just Dict
primTypeEq Int8T          Int8T          = Just Dict
primTypeEq Int16T         Int16T         = Just Dict
primTypeEq Int32T         Int32T         = Just Dict
primTypeEq Int64T         Int64T         = Just Dict
primTypeEq Word8T         Word8T         = Just Dict
primTypeEq Word16T        Word16T        = Just Dict
primTypeEq Word32T        Word32T        = Just Dict
primTypeEq Word64T        Word64T        = Just Dict
primTypeEq FloatT         FloatT         = Just Dict
primTypeEq DoubleT        DoubleT        = Just Dict
primTypeEq ComplexFloatT  ComplexFloatT  = Just Dict
primTypeEq ComplexDoubleT ComplexDoubleT = Just Dict
primTypeEq _              _              = Nothing

-- | Reflect a 'PrimTypeRep' to a 'PrimType'' constraint
--
-- Matching on each constructor refines @a@ to a concrete type, whose
-- 'PrimType'' instance is then captured in the returned 'Dict'.
witPrimType :: PrimTypeRep a -> Dict (PrimType' a)
witPrimType BoolT          = Dict
witPrimType Int8T          = Dict
witPrimType Int16T         = Dict
witPrimType Int32T         = Dict
witPrimType Int64T         = Dict
witPrimType Word8T         = Dict
witPrimType Word16T        = Dict
witPrimType Word32T        = Dict
witPrimType Word64T        = Dict
witPrimType FloatT         = Dict
witPrimType DoubleT        = Dict
witPrimType ComplexFloatT  = Dict
witPrimType ComplexDoubleT = Dict



--------------------------------------------------------------------------------
-- * Expressions
--------------------------------------------------------------------------------

-- | Primitive operations
--
-- Each constructor is one symbol of the primitive expression language. Its
-- type index is the symbol's signature: argument types chained with ':->',
-- ending in the 'Full' result type. The Haskell semantics of each symbol is
-- given by the 'Eval' instance in this module.
data Primitive sig
  where
    -- Variables and literals. Evaluating a 'FreeVar' is a run-time error.
    -- 'Lit' has no 'PrimType'' constraint (see the note after this
    -- declaration).
    FreeVar :: PrimType' a => String -> Primitive (Full a)
    Lit     :: (Eq a, Show a) => a -> Primitive (Full a)

    -- 'Num' operations, evaluated with (+), (-), (*), negate, abs, signum
    Add  :: (Num a, PrimType' a) => Primitive (a :-> a :-> Full a)
    Sub  :: (Num a, PrimType' a) => Primitive (a :-> a :-> Full a)
    Mul  :: (Num a, PrimType' a) => Primitive (a :-> a :-> Full a)
    Neg  :: (Num a, PrimType' a) => Primitive (a :-> Full a)
    Abs  :: (Num a, PrimType' a) => Primitive (a :-> Full a)
    Sign :: (Num a, PrimType' a) => Primitive (a :-> Full a)

    -- Division, evaluated with Haskell's quot, rem, div, mod and (/)
    Quot :: (Integral a, PrimType' a)   => Primitive (a :-> a :-> Full a)
    Rem  :: (Integral a, PrimType' a)   => Primitive (a :-> a :-> Full a)
    Div  :: (Integral a, PrimType' a)   => Primitive (a :-> a :-> Full a)
    Mod  :: (Integral a, PrimType' a)   => Primitive (a :-> a :-> Full a)
    FDiv :: (Fractional a, PrimType' a) => Primitive (a :-> a :-> Full a)

    -- 'Floating' constant and functions ('Pow' is evaluated with (**))
    Pi    :: (Floating a, PrimType' a) => Primitive (Full a)
    Exp   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Log   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Sqrt  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Pow   :: (Floating a, PrimType' a) => Primitive (a :-> a :-> Full a)
    Sin   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Cos   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Tan   :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Asin  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Acos  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Atan  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Sinh  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Cosh  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Tanh  :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Asinh :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Acosh :: (Floating a, PrimType' a) => Primitive (a :-> Full a)
    Atanh :: (Floating a, PrimType' a) => Primitive (a :-> Full a)

    -- Complex numbers. 'Complex' builds one from its real and imaginary
    -- parts, 'Polar' from a magnitude and phase. The rest are projections
    -- and conjugation.
    Complex   :: (Num a, PrimType' a, PrimType' (Complex a))       => Primitive (a :-> a :-> Full (Complex a))
    Polar     :: (Floating a, PrimType' a, PrimType' (Complex a))  => Primitive (a :-> a :-> Full (Complex a))
    Real      :: (PrimType' a, PrimType' (Complex a))              => Primitive (Complex a :-> Full a)
    Imag      :: (PrimType' a, PrimType' (Complex a))              => Primitive (Complex a :-> Full a)
    Magnitude :: (RealFloat a, PrimType' a, PrimType' (Complex a)) => Primitive (Complex a :-> Full a)
    Phase     :: (RealFloat a, PrimType' a, PrimType' (Complex a)) => Primitive (Complex a :-> Full a)
    Conjugate :: (Num a, PrimType' (Complex a))                    => Primitive (Complex a :-> Full (Complex a))

    -- Conversions:
    --   'I2N' converts an integral to any numeric type.
    --   'I2B' tests for non-zero.
    --   'B2I' maps True/False to 1/0.
    --   'Round' rounds with Haskell's 'round', then converts.
    I2N   :: (Integral a, Num b, PrimType' a, PrimType' b)      => Primitive (a :-> Full b)
    I2B   :: (Integral a, PrimType' a)                          => Primitive (a :-> Full Bool)
    B2I   :: (Integral a, PrimType' a)                          => Primitive (Bool :-> Full a)
    Round :: (RealFrac a, Num b, PrimType' a, PrimType' b) => Primitive (a :-> Full b)

    -- Boolean logic and comparisons
    Not :: Primitive (Bool :-> Full Bool)
    And :: Primitive (Bool :-> Bool :-> Full Bool)
    Or  :: Primitive (Bool :-> Bool :-> Full Bool)
    Eq  :: (Eq a, PrimType' a)  => Primitive (a :-> a :-> Full Bool)
    NEq :: (Eq a, PrimType' a)  => Primitive (a :-> a :-> Full Bool)
    Lt  :: (Ord a, PrimType' a) => Primitive (a :-> a :-> Full Bool)
    Gt  :: (Ord a, PrimType' a) => Primitive (a :-> a :-> Full Bool)
    Le  :: (Ord a, PrimType' a) => Primitive (a :-> a :-> Full Bool)
    Ge  :: (Ord a, PrimType' a) => Primitive (a :-> a :-> Full Bool)

    -- Bitwise operations. The shift amount may be of any integral type; it
    -- is converted to 'Int' (via fromIntegral) when evaluated.
    BitAnd   :: (Bits a, PrimType' a) => Primitive (a :-> a :-> Full a)
    BitOr    :: (Bits a, PrimType' a) => Primitive (a :-> a :-> Full a)
    BitXor   :: (Bits a, PrimType' a) => Primitive (a :-> a :-> Full a)
    BitCompl :: (Bits a, PrimType' a) => Primitive (a :-> Full a)
    ShiftL   :: (Bits a, PrimType' a, Integral b, PrimType' b) => Primitive (a :-> b :-> Full a)
    ShiftR   :: (Bits a, PrimType' a, Integral b, PrimType' b) => Primitive (a :-> b :-> Full a)

    -- Index into an immutable array ('IArr') at an 'Index'
    ArrIx :: PrimType' a => IArr Index a -> Primitive (Index :-> Full a)

    -- Conditional (if-then-else). It has no 'PrimType'' constraint, so it can
    -- also be used at non-primitive types (see the note after this
    -- declaration).
    Cond :: Primitive (Bool :-> a :-> a :-> Full a)

deriving instance Show (Primitive a)

-- The `PrimType'` constraints on certain symbols require an explanation: The
-- constraints are actually not needed for anything in the modules in
-- `Feldspar.Primitive.*`, but they are needed by `Feldspar.Run.Compile`. They
-- guarantee to the compiler that these symbols don't operate on tuples.
--
-- It would seem more consistent to have a `PrimType'` constraint on all
-- polymorphic symbols. However, this would prevent using some symbols for
-- non-primitive types in `Feldspar.Representation`. For example, `Lit` and
-- `Cond` are used in `Feldspar.Representation`, and there they can also be used
-- for tuple types. The current design was chosen because it "just works".

-- Template Haskell splice from "Language.Syntactic.TH" that derives the
-- syntactic 'Symbol' instance (symbol signatures) for 'Primitive'
deriveSymbol ''Primitive

-- | Rendering of primitive symbols
--
-- Free variables render as their name and literals via 'show'. Array
-- indexing shows the array name for symbolic ('IArrComp') arrays and elides
-- concrete arrays as @ArrIx ...@. Every other symbol renders via its derived
-- 'Show' instance.
instance Render Primitive
  where
    renderSym (FreeVar v)            = v
    renderSym (Lit a)                = show a
    renderSym (ArrIx (IArrComp arr)) = "ArrIx " ++ arr
    renderSym (ArrIx _)              = "ArrIx ..."
    renderSym s                      = show s

    renderArgs = renderArgsSmart

instance StringTree Primitive

-- | Haskell semantics of the primitive symbols
--
-- Run-time 'error' cases:
--
-- * evaluating a 'FreeVar'
-- * indexing a symbolic ('IArrComp') array
-- * indexing a concrete ('IArrRun') array out of bounds; the message reports
--   both the index and the bounds
instance Eval Primitive
  where
    evalSym (FreeVar v) = error $ "evaluating free variable " ++ show v
    evalSym (Lit a)     = a
    evalSym Add         = (+)
    evalSym Sub         = (-)
    evalSym Mul         = (*)
    evalSym Neg         = negate
    evalSym Abs         = abs
    evalSym Sign        = signum
    evalSym Quot        = quot
    evalSym Rem         = rem
    evalSym Div         = div
    evalSym Mod         = mod
    evalSym FDiv        = (/)
    evalSym Pi          = pi
    evalSym Exp         = exp
    evalSym Log         = log
    evalSym Sqrt        = sqrt
    evalSym Pow         = (**)
    evalSym Sin         = sin
    evalSym Cos         = cos
    evalSym Tan         = tan
    evalSym Asin        = asin
    evalSym Acos        = acos
    evalSym Atan        = atan
    evalSym Sinh        = sinh
    evalSym Cosh        = cosh
    evalSym Tanh        = tanh
    evalSym Asinh       = asinh
    evalSym Acosh       = acosh
    evalSym Atanh       = atanh
    evalSym Complex     = (:+)
    evalSym Polar       = mkPolar
    evalSym Real        = realPart
    evalSym Imag        = imagPart
    evalSym Magnitude   = magnitude
    evalSym Phase       = phase
    evalSym Conjugate   = conjugate
    evalSym I2N         = fromInteger . toInteger
    evalSym I2B         = (/=0)
    evalSym B2I         = \a -> if a then 1 else 0
    evalSym Round       = fromInteger . round
    evalSym Not         = not
    evalSym And         = (&&)
    evalSym Or          = (||)
    evalSym Eq          = (==)
    evalSym NEq         = (/=)
    evalSym Lt          = (<)
    evalSym Gt          = (>)
    evalSym Le          = (<=)
    evalSym Ge          = (>=)
    evalSym BitAnd      = (.&.)
    evalSym BitOr       = (.|.)
    evalSym BitXor      = xor
    evalSym BitCompl    = complement
    -- The shift amount is converted to 'Int', as required by 'shiftL'/'shiftR'
    evalSym ShiftL      = \a -> shiftL a . fromIntegral
    evalSym ShiftR      = \a -> shiftR a . fromIntegral
    evalSym Cond        = \c t f -> if c then t else f
    -- Explicit bounds check so the error message shows the offending index
    -- and the array bounds
    evalSym (ArrIx (IArrRun arr)) = \i ->
        if i < l || i > h
          then error $ "ArrIx: index "
                    ++ show (toInteger i)
                    ++ " out of bounds "
                    ++ show (toInteger l, toInteger h)
          else arr ! i
      where
        (l, h) = bounds arr
    evalSym (ArrIx (IArrComp arr)) = error $ "evaluating symbolic array " ++ arr

-- | Assumes no occurrences of 'FreeVar' and concrete representation of arrays
instance EvalEnv Primitive env

-- | Symbol equality, decided by comparing the 'show'n symbols
instance Equality Primitive
  where
    equal s1 s2 = show s1 == show s2
      -- NOTE: It is very important not to use `renderSym` here, because it will
      -- render all concrete arrays equal.

      -- This method uses string comparison. It is probably slightly more
      -- efficient to pattern match directly on the constructors. Unfortunately
      -- `deriveEquality ''Primitive` doesn't work, so it gets quite tedious to
      -- write it with pattern matching.

-- | Symbol domain of primitive expressions: every 'Primitive' symbol is
-- decorated with the 'PrimTypeRep' of its result type.
type PrimDomain = Primitive :&: PrimTypeRep

-- | Primitive expressions
newtype Prim a = Prim { unPrim :: ASTF PrimDomain a }

-- | 'Prim' is a thin wrapper around its AST, so sugaring and desugaring are
-- just the newtype constructor and field selector.
instance Syntactic (Prim a)
  where
    type Domain (Prim a)   = PrimDomain
    type Internal (Prim a) = a
    desugar = unPrim
    sugar   = Prim

-- | Evaluate a closed expression
--
-- The AST is folded bottom-up: each symbol is interpreted with 'evalSym'
-- (its 'PrimTypeRep' decoration is ignored), and each application applies the
-- evaluated function to the evaluated argument.
--
-- NOTE(review): evaluation of a symbolic array ('IArrComp') calls 'error'
-- inside 'evalSym'; "closed" presumably also means free of 'FreeVar' — confirm.
evalPrim :: Prim a -> a
evalPrim = go . unPrim
  where
    go :: AST PrimDomain sig -> Denotation sig
    go (Sym (s :&: _)) = evalSym s
    go (f :$ a)        = go f $ go a

-- | Smart constructor for primitive symbols.
--
-- Lifts a symbol from any sub-domain of 'Primitive' into a (curried) 'Prim'
-- function, decorating it with the 'PrimTypeRep' of its result type obtained
-- from 'primTypeRep'.
sugarSymPrim
    :: ( Signature sig
       , fi  ~ SmartFun dom sig
       , sig ~ SmartSig fi
       , dom ~ SmartSym fi
       , dom ~ PrimDomain
       , SyntacticN f fi
       , sub :<: Primitive
       , PrimType' (DenResult sig)
       )
    => sub sig -> f
sugarSymPrim = sugarSymDecor primTypeRep

-- | Constants become 'Lit' symbols and variables become 'FreeVar' symbols,
-- both built through 'sugarSymPrim'.
instance FreeExp Prim
  where
    type FreePred Prim = PrimType'
    constExp = sugarSymPrim . Lit
    varExp   = sugarSymPrim . FreeVar

-- | Expression evaluation delegates to 'evalPrim'.
instance EvalExp Prim
  where
    evalExp = evalPrim



--------------------------------------------------------------------------------
-- * Interface
--------------------------------------------------------------------------------

-- | Arithmetic on primitive expressions builds AST nodes rather than computing
-- values: each operation becomes the corresponding 'Primitive' symbol, and
-- integer literals are embedded as constants via 'constExp'.
instance (Num a, PrimType' a) => Num (Prim a)
  where
    fromInteger = constExp . fromInteger
    (+)         = sugarSymPrim Add
    (-)         = sugarSymPrim Sub
    (*)         = sugarSymPrim Mul
    negate      = sugarSymPrim Neg
    abs         = sugarSymPrim Abs
    signum      = sugarSymPrim Sign