{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Simple C runtime representation.
module Futhark.CodeGen.Backends.SimpleRep
  ( tupleField,
    funName,
    defaultMemBlockType,
    intTypeToCType,
    primTypeToCType,
    signedPrimTypeToCType,
    arrayName,
    opaqueName,
    externalValueType,
    cproduct,
    csum,

    -- * Primitive value operations
    cIntOps,
    cFloat32Ops,
    cFloat32Funs,
    cFloat64Ops,
    cFloat64Funs,
    cFloatConvOps,

    -- * Storing/restoring values in byte sequences
    storageSize,
    storeValueHeader,
    loadValueHeader,
  )
where

import Data.Bits (shiftR, xor)
import Data.Char (isAlphaNum, isDigit, ord)
import Futhark.CodeGen.ImpCode
import Futhark.Util (zEncodeString)
import Futhark.Util.Pretty (prettyOneLine)
import qualified Language.C.Quote.C as C
import qualified Language.C.Syntax as C
import Text.Printf

-- | The C type corresponding to a signed integer type.  Each integer
-- width is mapped to the fixed-width @intN_t@ typedef of the same size.
intTypeToCType :: IntType -> C.Type
intTypeToCType Int8 = [C.cty|typename int8_t|]
intTypeToCType Int16 = [C.cty|typename int16_t|]
intTypeToCType Int32 = [C.cty|typename int32_t|]
intTypeToCType Int64 = [C.cty|typename int64_t|]

-- | The C type corresponding to an unsigned integer type.  Each integer
-- width is mapped to the fixed-width @uintN_t@ typedef of the same size.
uintTypeToCType :: IntType -> C.Type
uintTypeToCType Int8 = [C.cty|typename uint8_t|]
uintTypeToCType Int16 = [C.cty|typename uint16_t|]
uintTypeToCType Int32 = [C.cty|typename uint32_t|]
uintTypeToCType Int64 = [C.cty|typename uint64_t|]

-- | The C type corresponding to a float type: @float@ for 32-bit and
-- @double@ for 64-bit floats.
floatTypeToCType :: FloatType -> C.Type
floatTypeToCType Float32 = [C.cty|float|]
floatTypeToCType Float64 = [C.cty|double|]

-- | The C type corresponding to a primitive type.  Integers are
-- assumed to be signed (they are translated with 'intTypeToCType');
-- use 'signedPrimTypeToCType' when a specific signedness is required.
-- Both 'Bool' and 'Unit' are represented by the C @bool@ type.
primTypeToCType :: PrimType -> C.Type
primTypeToCType (IntType t) = intTypeToCType t
primTypeToCType (FloatType t) = floatTypeToCType t
primTypeToCType Bool = [C.cty|typename bool|]
primTypeToCType Unit = [C.cty|typename bool|]

-- | The C type corresponding to a primitive type.  Integers are
-- assumed to have the specified sign: 'TypeUnsigned' selects the
-- @uintN_t@ family and 'TypeDirect' the @intN_t@ family.  For every
-- other combination the result is that of 'primTypeToCType'.
signedPrimTypeToCType :: Signedness -> PrimType -> C.Type
signedPrimTypeToCType TypeUnsigned (IntType t) = uintTypeToCType t
signedPrimTypeToCType TypeDirect (IntType t) = intTypeToCType t
signedPrimTypeToCType _ t = primTypeToCType t

-- | @tupleField i@ is the name of field number @i@ in a tuple, formed
-- by prefixing the decimal rendering of @i@ with @v@ (so field 0 is
-- @v0@).
tupleField :: Int -> String
tupleField i = "v" ++ show i

-- | @funName f@ is the name of the C function corresponding to
-- the Futhark function @f@: the name is passed through
-- 'zEncodeString' and prefixed with @futrts_@.
funName :: Name -> String
funName = ("futrts_" ++) . zEncodeString . nameToString

-- | Like 'funName', but taking the function name as a plain 'String'
-- instead of a 'Name'.
funName' :: String -> String
funName' = funName . nameFromString

-- | The type of memory blocks in the default memory space: a plain
-- C @char*@.
defaultMemBlockType :: C.Type
defaultMemBlockType = [C.cty|char*|]

-- | The name of exposed array type structs.  The name consists of the
-- element type as rendered by 'prettySigned' (which is told whether the
-- type is unsigned), an underscore, the rank, and a trailing @d@.
arrayName :: PrimType -> Signedness -> Int -> String
arrayName pt signed rank =
  prettySigned (signed == TypeUnsigned) pt ++ "_" ++ show rank ++ "d"

-- | The name of exposed opaque types.
--
-- If the given name is non-empty, does not start with an underscore or
-- a digit, and consists solely of alphanumeric characters and
-- underscores, it is used directly (prefixed with @opaque_@).
-- Otherwise the name is replaced by a hexadecimal hash computed from
-- the name and a textual rendering of the constituent value
-- descriptions.
opaqueName :: String -> [ValueDesc] -> String
opaqueName s _
  | valid = "opaque_" ++ s
  where
    -- Pattern match instead of using 'head', so that an empty name is
    -- simply treated as invalid (and hashed below) rather than
    -- crashing.
    valid :: Bool
    valid = case s of
      c : _ -> c /= '_' && not (isDigit c) && all ok s
      [] -> False
    ok :: Char -> Bool
    ok c = isAlphaNum c || c == '_'
opaqueName s vds =
  "opaque_" ++ hash (zipWith xor [0 ..] $ map ord (s ++ concatMap p vds))
  where
    -- Textual fingerprint of a single value description.  Variable
    -- names are ignored; for arrays only the rank (not the actual
    -- dimensions) takes part.
    p :: ValueDesc -> String
    p (ScalarValue pt signed _) =
      show (pt, signed)
    p (ArrayValue _ space pt signed dims) =
      show (space, pt, signed, length dims)

    -- FIXME: a stupid hash algorithm; may have collisions.
    --
    -- Every input (a character code already xor'ed with its position)
    -- is converted to a 'Word32' and mixed by alternating 'iter' with
    -- multiplication by a constant; the mixed values are then xor'ed
    -- together and printed in hexadecimal.
    hash :: [Int] -> String
    hash =
      printf "%x" . foldl xor 0
        . map
          ( iter . (* 0x45d9f3b)
              . iter
              . (* 0x45d9f3b)
              . iter
              . fromIntegral
          )

    -- Fold the upper 16 bits of the word into the lower 16 bits.
    iter :: Word32 -> Word32
    iter x = (x `shiftR` 16) `xor` x

-- | The type used to expose a Futhark value in the C API.  A pointer
-- in the case of arrays and opaques: opaques become a pointer to
-- @struct futhark_\<opaque name\>@ (see 'opaqueName'), arrays a pointer
-- to @struct futhark_\<array name\>@ (see 'arrayName'), and scalars are
-- passed by value using 'signedPrimTypeToCType'.
externalValueType :: ExternalValue -> C.Type
externalValueType (OpaqueValue desc vds) =
  [C.cty|struct $id:("futhark_" ++ opaqueName desc vds)*|]
externalValueType (TransparentValue (ArrayValue _ _ pt signed shape)) =
  [C.cty|struct $id:("futhark_" ++ arrayName pt signed (length shape))*|]
externalValueType (TransparentValue (ScalarValue pt signed _)) =
  signedPrimTypeToCType signed pt

-- | Return an expression multiplying together the given expressions.
-- If an empty list is given, the expression @1@ is returned.  The
-- factors are combined left-to-right, i.e. @[a, b, c]@ becomes
-- @(a * b) * c@.
cproduct :: [C.Exp] -> C.Exp
cproduct [] = [C.cexp|1|]
cproduct (e : es) = foldl mult e es
  where
    mult x y = [C.cexp|$exp:x * $exp:y|]

-- | Return an expression summing the given expressions.
-- If an empty list is given, the expression @0@ is returned.  The
-- terms are combined left-to-right, i.e. @[a, b, c]@ becomes
-- @(a + b) + c@.
csum :: [C.Exp] -> C.Exp
csum [] = [C.cexp|0|]
csum (e : es) = foldl add e es
  where
    add x y = [C.cexp|$exp:x + $exp:y|]

-- | A 'Name' is used as a C identifier after being passed through
-- 'zEncodeString'.
instance C.ToIdent Name where
  toIdent = C.toIdent . zEncodeString . nameToString

-- | A 'VName' is used as a C identifier by pretty-printing it and
-- passing the result through 'zEncodeString'.
instance C.ToIdent VName where
  toIdent = C.toIdent . zEncodeString . pretty

-- | A 'VName' in expression position is a reference to the C
-- identifier of the same name (see the 'C.ToIdent' instance); the
-- source location is ignored.
instance C.ToExp VName where
  toExp v _ = [C.cexp|$id:v|]

-- | Integer constants are translated by delegating to the 'C.ToExp'
-- instance of the underlying fixed-width Haskell integer.
instance C.ToExp IntValue where
  toExp (Int8Value v) = C.toExp v
  toExp (Int16Value v) = C.toExp v
  toExp (Int32Value v) = C.toExp v
  toExp (Int64Value v) = C.toExp v

-- | Floating-point constants are translated by delegating to the
-- 'C.ToExp' instance of the underlying 'Float' or 'Double'.
instance C.ToExp FloatValue where
  toExp (Float32Value v) = C.toExp v
  toExp (Float64Value v) = C.toExp v

-- | Primitive constants.  Integers and floats are delegated to their
-- own instances; booleans become the 8-bit integer constants @1@
-- (true) and @0@ (false), and the unit value becomes @1@.
instance C.ToExp PrimValue where
  toExp (IntValue v) = C.toExp v
  toExp (FloatValue v) = C.toExp v
  toExp (BoolValue True) = C.toExp (1 :: Int8)
  toExp (BoolValue False) = C.toExp (0 :: Int8)
  toExp UnitValue = C.toExp (1 :: Int8)

-- | A 'SubExp' is either a variable reference or a constant; each case
-- is delegated to the corresponding 'C.ToExp' instance.
instance C.ToExp SubExp where
  toExp (Var v) = C.toExp v
  toExp (Constant c) = C.toExp c

-- | C definitions of the integer primitive operations.  Every
-- generator in @ops@ is instantiated for each integer type (all
-- instances of one operation are emitted before moving on to the next
-- operation), after which the definitions in 'cIntPrimFuns' are
-- appended.  The order of @ops@ matters, because some of the generated
-- functions call functions generated earlier in the list (for example
-- @sdiv_up@ calls @sdiv@).
cIntOps :: [C.Definition]
cIntOps =
  concatMap (`map` [minBound .. maxBound]) ops
    ++ cIntPrimFuns
  where
    ops =
      [ mkAdd,
        mkSub,
        mkMul,
        mkUDiv,
        mkUDivUp,
        mkUMod,
        mkUDivSafe,
        mkUDivUpSafe,
        mkUModSafe,
        mkSDiv,
        mkSDivUp,
        mkSMod,
        mkSDivSafe,
        mkSDivUpSafe,
        mkSModSafe,
        mkSQuot,
        mkSRem,
        mkSQuotSafe,
        mkSRemSafe,
        mkSMin,
        mkUMin,
        mkSMax,
        mkUMax,
        mkShl,
        mkLShr,
        mkAShr,
        mkAnd,
        mkOr,
        mkXor,
        mkUlt,
        mkUle,
        mkSlt,
        mkSle,
        mkPow,
        mkIToB,
        mkBToI
      ]
        -- One conversion generator per source type; the target type is
        -- supplied when the generators are instantiated above.
        ++ map mkSExt [minBound .. maxBound]
        ++ map mkZExt [minBound .. maxBound]

    -- Append the bit width of the integer type to an operation name.
    taggedI s Int8 = s ++ "8"
    taggedI s Int16 = s ++ "16"
    taggedI s Int32 = s ++ "32"
    taggedI s Int64 = s ++ "64"

    -- Use unsigned types for add/sub/mul so we can do
    -- well-defined overflow.
    mkAdd = simpleUintOp "add" [C.cexp|x + y|]
    mkSub = simpleUintOp "sub" [C.cexp|x - y|]
    mkMul = simpleUintOp "mul" [C.cexp|x * y|]
    mkUDiv = simpleUintOp "udiv" [C.cexp|x / y|]
    mkUDivUp = simpleUintOp "udiv_up" [C.cexp|(x+y-1) / y|]
    mkUMod = simpleUintOp "umod" [C.cexp|x % y|]
    -- The "safe" variants yield 0 instead of dividing by zero.
    mkUDivSafe = simpleUintOp "udiv_safe" [C.cexp|y == 0 ? 0 : x / y|]
    mkUDivUpSafe = simpleUintOp "udiv_up_safe" [C.cexp|y == 0 ? 0 : (x+y-1) / y|]
    mkUModSafe = simpleUintOp "umod_safe" [C.cexp|y == 0 ? 0 : x % y|]
    mkUMax = simpleUintOp "umax" [C.cexp|x < y ? y : x|]
    mkUMin = simpleUintOp "umin" [C.cexp|x < y ? x : y|]

    -- Signed division that subtracts one from the C quotient whenever
    -- the remainder is nonzero and its sign differs from the sign of
    -- the divisor.
    mkSDiv t =
      let ct = intTypeToCType t
       in [C.cedecl|static inline $ty:ct $id:(taggedI "sdiv" t)($ty:ct x, $ty:ct y) {
                         $ty:ct q = x / y;
                         $ty:ct r = x % y;
                         return q -
                           (((r != 0) && ((r < 0) != (y < 0))) ? 1 : 0);
             }|]
    mkSDivUp t =
      simpleIntOp "sdiv_up" [C.cexp|$id:(taggedI "sdiv" t)(x+y-1,y)|] t
    -- Signed modulus that adds the divisor to the C remainder unless
    -- the remainder is zero or both operands have the same sign.
    mkSMod t =
      let ct = intTypeToCType t
       in [C.cedecl|static inline $ty:ct $id:(taggedI "smod" t)($ty:ct x, $ty:ct y) {
                         $ty:ct r = x % y;
                         return r +
                           ((r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0)) ? 0 : y);
              }|]
    mkSDivSafe t =
      simpleIntOp "sdiv_safe" [C.cexp|y == 0 ? 0 : $id:(taggedI "sdiv" t)(x,y)|] t
    mkSDivUpSafe t =
      simpleIntOp "sdiv_up_safe" [C.cexp|$id:(taggedI "sdiv_safe" t)(x+y-1,y)|] t
    mkSModSafe t =
      simpleIntOp "smod_safe" [C.cexp|y == 0 ? 0 : $id:(taggedI "smod" t)(x,y)|] t

    mkSQuot = simpleIntOp "squot" [C.cexp|x / y|]
    mkSRem = simpleIntOp "srem" [C.cexp|x % y|]
    mkSQuotSafe = simpleIntOp "squot_safe" [C.cexp|y == 0 ? 0 : x / y|]
    mkSRemSafe = simpleIntOp "srem_safe" [C.cexp|y == 0 ? 0 : x % y|]
    mkSMax = simpleIntOp "smax" [C.cexp|x < y ? y : x|]
    mkSMin = simpleIntOp "smin" [C.cexp|x < y ? x : y|]
    mkShl = simpleUintOp "shl" [C.cexp|x << y|]
    -- The logical shift operates on the unsigned type, the arithmetic
    -- shift on the signed type.
    mkLShr = simpleUintOp "lshr" [C.cexp|x >> y|]
    mkAShr = simpleIntOp "ashr" [C.cexp|x >> y|]
    mkAnd = simpleUintOp "and" [C.cexp|x & y|]
    mkOr = simpleUintOp "or" [C.cexp|x | y|]
    mkXor = simpleUintOp "xor" [C.cexp|x ^ y|]
    mkUlt = uintCmpOp "ult" [C.cexp|x < y|]
    mkUle = uintCmpOp "ule" [C.cexp|x <= y|]
    mkSlt = intCmpOp "slt" [C.cexp|x < y|]
    mkSle = intCmpOp "sle" [C.cexp|x <= y|]

    -- We define some operations as macros rather than functions,
    -- because this allows us to use them as constant expressions
    -- in things like array sizes and static initialisers.
    macro name rhs =
      [C.cedecl|$esc:("#define " ++ name ++ "(x) (" ++ prettyOneLine rhs ++ ")")|]

    -- Integer exponentiation by repeated squaring.
    mkPow t =
      let ct = intTypeToCType t
       in [C.cedecl|static inline $ty:ct $id:(taggedI "pow" t)($ty:ct x, $ty:ct y) {
                         $ty:ct res = 1, rem = y;
                         while (rem != 0) {
                           if (rem & 1) {
                             res *= x;
                           }
                           rem >>= 1;
                           x *= x;
                         }
                         return res;
              }|]

    -- Sign extension: cast to the signed source type first, then to
    -- the target type.
    mkSExt from_t to_t = macro name [C.cexp|($ty:to_ct)(($ty:from_ct)x)|]
      where
        name = "sext_" ++ pretty from_t ++ "_" ++ pretty to_t
        from_ct = intTypeToCType from_t
        to_ct = intTypeToCType to_t

    -- Zero extension: cast to the unsigned source type first, then to
    -- the target type.
    mkZExt from_t to_t = macro name [C.cexp|($ty:to_ct)(($ty:from_ct)x)|]
      where
        name = "zext_" ++ pretty from_t ++ "_" ++ pretty to_t
        from_ct = uintTypeToCType from_t
        to_ct = intTypeToCType to_t

    -- Conversion from bool to an integer type.
    mkBToI to_t =
      [C.cedecl|static inline $ty:to_ct
                    $id:name($ty:from_ct x) { return x; } |]
      where
        name = "btoi_bool_" ++ pretty to_t
        from_ct = primTypeToCType Bool
        to_ct = intTypeToCType to_t

    -- Conversion from an integer type to bool.
    mkIToB from_t =
      [C.cedecl|static inline $ty:to_ct
                    $id:name($ty:from_ct x) { return x; } |]
      where
        name = "itob_" ++ pretty from_t ++ "_bool"
        to_ct = primTypeToCType Bool
        from_ct = intTypeToCType from_t

    -- A binary function on the unsigned C type for @t@, whose body
    -- returns the expression @e@ over the parameters @x@ and @y@.
    simpleUintOp s e t =
      [C.cedecl|static inline $ty:ct $id:(taggedI s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = uintTypeToCType t
    -- As 'simpleUintOp', but on the signed C type.
    simpleIntOp s e t =
      [C.cedecl|static inline $ty:ct $id:(taggedI s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = intTypeToCType t
    -- A comparison of two signed operands, returning bool.
    intCmpOp s e t =
      [C.cedecl|static inline typename bool $id:(taggedI s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = intTypeToCType t
    -- A comparison of two unsigned operands, returning bool.
    uintCmpOp s e t =
      [C.cedecl|static inline typename bool $id:(taggedI s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = uintTypeToCType t

cIntPrimFuns :: [C.Definition]
cIntPrimFuns :: [Definition]
cIntPrimFuns =
  [C.cunit|
$esc:("#if defined(__OPENCL_VERSION__)")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
      return popcount(x);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
      return __popc(zext_i8_i32(x));
   }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
      return __popc(zext_i16_i32(x));
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
      return __popc(x);
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
      return __popcll(x);
   }
$esc:("#else")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
     int c = 0;
     for (; x; ++c) {
       x &= x - 1;
     }
     return c;
    }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
     int c = 0;
     for (; x; ++c) {
       x &= x - 1;
     }
     return c;
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
     int c = 0;
     for (; x; ++c) {
       x &= x - 1;
     }
     return c;
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
     int c = 0;
     for (; x; ++c) {
       x &= x - 1;
     }
     return c;
   }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
      return mul_hi(a, b);
   }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
      return mul_hi(a, b);
   }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
      return mul_hi(a, b);
   }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
      return mul_hi(a, b);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
     typename uint16_t aa = a;
     typename uint16_t bb = b;
     return (aa * bb) >> 8;
   }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
     typename uint32_t aa = a;
     typename uint32_t bb = b;
     return (aa * bb) >> 16;
   }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
      return mulhi(a, b);
   }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
      return mul64hi(a, b);
   }
$esc:("#else")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
     typename uint16_t aa = a;
     typename uint16_t bb = b;
     return (aa * bb) >> 8;
    }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
     typename uint32_t aa = a;
     typename uint32_t bb = b;
     return (aa * bb) >> 16;
    }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
     typename uint64_t aa = a;
     typename uint64_t bb = b;
     return (aa * bb) >> 32;
    }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
     typename __uint128_t aa = a;
     typename __uint128_t bb = b;
     return (aa * bb) >> 64;
    }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   static typename uint8_t $id:(funName' "mad_hi8") (typename uint8_t a, typename uint8_t b, typename uint8_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint16_t $id:(funName' "mad_hi16") (typename uint16_t a, typename uint16_t b, typename uint16_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint32_t $id:(funName' "mad_hi32") (typename uint32_t a, typename uint32_t b, typename uint32_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint64_t $id:(funName' "mad_hi64") (typename uint64_t a, typename uint64_t b, typename uint64_t c) {
      return mad_hi(a, b, c);
   }
$esc:("#else")
   static typename uint8_t $id:(funName' "mad_hi8") (typename uint8_t a, typename uint8_t b, typename uint8_t c) {
     return futrts_mul_hi8(a, b) + c;
    }
   static typename uint16_t $id:(funName' "mad_hi16") (typename uint16_t a, typename uint16_t b, typename uint16_t c) {
     return futrts_mul_hi16(a, b) + c;
    }
   static typename uint32_t $id:(funName' "mad_hi32") (typename uint32_t a, typename uint32_t b, typename uint32_t c) {
     return futrts_mul_hi32(a, b) + c;
    }
   static typename uint64_t $id:(funName' "mad_hi64") (typename uint64_t a, typename uint64_t b, typename uint64_t c) {
     return futrts_mul_hi64(a, b) + c;
    }
$esc:("#endif")


$esc:("#if defined(__OPENCL_VERSION__)")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
      return clz(x);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
      return __clz(zext_i8_i32(x))-24;
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
      return __clz(zext_i16_i32(x))-16;
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
      return __clz(x);
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
      return __clzll(x);
   }
$esc:("#else")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
    int n = 0;
    int bits = sizeof(x) * 8;
    for (int i = 0; i < bits; i++) {
        if (x < 0) break;
        n++;
        x <<= 1;
    }
    return n;
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
    int n = 0;
    int bits = sizeof(x) * 8;
    for (int i = 0; i < bits; i++) {
        if (x < 0) break;
        n++;
        x <<= 1;
    }
    return n;
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
    int n = 0;
    int bits = sizeof(x) * 8;
    for (int i = 0; i < bits; i++) {
        if (x < 0) break;
        n++;
        x <<= 1;
    }
    return n;
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
    int n = 0;
    int bits = sizeof(x) * 8;
    for (int i = 0; i < bits; i++) {
        if (x < 0) break;
        n++;
        x <<= 1;
    }
    return n;
   }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   // OpenCL has ctz, but only from version 2.0, which we cannot assume we are using.
   static typename int32_t $id:(funName' "ctz8") (typename int8_t x) {
      int i = 0;
      for (; i < 8 && (x&1)==0; i++, x>>=1);
      return i;
   }
   static typename int32_t $id:(funName' "ctz16") (typename int16_t x) {
      int i = 0;
      for (; i < 16 && (x&1)==0; i++, x>>=1);
      return i;
   }
   static typename int32_t $id:(funName' "ctz32") (typename int32_t x) {
      int i = 0;
      for (; i < 32 && (x&1)==0; i++, x>>=1);
      return i;
   }
   static typename int32_t $id:(funName' "ctz64") (typename int64_t x) {
      int i = 0;
      for (; i < 64 && (x&1)==0; i++, x>>=1);
      return i;
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename int32_t $id:(funName' "ctz8") (typename int8_t x) {
     int y = __ffs(x);
     return y == 0 ? 8 : y-1;
   }
   static typename int32_t $id:(funName' "ctz16") (typename int16_t x) {
     int y = __ffs(x);
     return y == 0 ? 16 : y-1;
   }
   static typename int32_t $id:(funName' "ctz32") (typename int32_t x) {
     int y = __ffs(x);
     return y == 0 ? 32 : y-1;
   }
   static typename int32_t $id:(funName' "ctz64") (typename int64_t x) {
     int y = __ffsll(x);
     return y == 0 ? 64 : y-1;
   }
$esc:("#else")
// FIXME: assumes GCC or clang.
   static typename int32_t $id:(funName' "ctz8") (typename int8_t x) {
     return x == 0 ? 8 : __builtin_ctz((typename uint32_t)x);
   }
   static typename int32_t $id:(funName' "ctz16") (typename int16_t x) {
     return x == 0 ? 16 : __builtin_ctz((typename uint32_t)x);
   }
   static typename int32_t $id:(funName' "ctz32") (typename int32_t x) {
     return x == 0 ? 32 :  __builtin_ctz(x);
   }
   static typename int32_t $id:(funName' "ctz64") (typename int64_t x) {
     return x == 0 ? 64 : __builtin_ctzll(x);
   }
$esc:("#endif")
                |]

-- | C definitions of the basic operations on 32-bit floats: the
-- arithmetic, min/max, power and comparison functions listed in
-- @mkOps@, followed by conversions from and to every integer type.
cFloat32Ops :: [C.Definition]

-- | As 'cFloat32Ops', but for 64-bit floats.
cFloat64Ops :: [C.Definition]

-- | C definitions of the float-to-float conversions (@fpconv@), one for
-- every ordered pair of float types (including a type with itself).
cFloatConvOps :: [C.Definition]
(cFloat32Ops, cFloat64Ops, cFloatConvOps) =
  ( map ($ Float32) mkOps,
    map ($ Float64) mkOps,
    [ mkFPConvFF "fpconv" from to
      | from <- [minBound .. maxBound],
        to <- [minBound .. maxBound]
    ]
  )
  where
    -- Suffix an operation name with the bit width of the float type.
    taggedF s Float32 = s ++ "32"
    taggedF s Float64 = s ++ "64"

    -- Name of a conversion function: @<op>_<from>_<to>@, where the type
    -- components are whatever 'pretty' renders for the two types.
    convOp s from to = s ++ "_" ++ pretty from ++ "_" ++ pretty to

    -- Every per-float-type operation, still awaiting the float type.
    -- The 'flip' is needed for the float-to-int conversions because
    -- their builders take the float (source) type first.
    mkOps =
      [mkFDiv, mkFAdd, mkFSub, mkFMul, mkFMin, mkFMax, mkPow, mkCmpLt, mkCmpLe]
        ++ map (mkFPConvIF "sitofp") [minBound .. maxBound]
        ++ map (mkFPConvUF "uitofp") [minBound .. maxBound]
        ++ map (flip $ mkFPConvFI "fptosi") [minBound .. maxBound]
        ++ map (flip $ mkFPConvFU "fptoui") [minBound .. maxBound]

    mkFDiv = simpleFloatOp "fdiv" [C.cexp|x / y|]
    mkFAdd = simpleFloatOp "fadd" [C.cexp|x + y|]
    mkFSub = simpleFloatOp "fsub" [C.cexp|x - y|]
    mkFMul = simpleFloatOp "fmul" [C.cexp|x * y|]
    mkFMin = simpleFloatOp "fmin" [C.cexp|fmin(x, y)|]
    mkFMax = simpleFloatOp "fmax" [C.cexp|fmax(x, y)|]
    mkCmpLt = floatCmpOp "cmplt" [C.cexp|x < y|]
    mkCmpLe = floatCmpOp "cmple" [C.cexp|x <= y|]

    -- Power is spelled out per type rather than going through
    -- 'simpleFloatOp'; both variants call the unsuffixed @pow@.
    mkPow Float32 =
      [C.cedecl|static inline float fpow32(float x, float y) { return pow(x, y); }|]
    mkPow Float64 =
      [C.cedecl|static inline double fpow64(double x, double y) { return pow(x, y); }|]

    -- Generic conversion: a one-argument function that casts its
    -- argument to the target C type.  @from_f@ and @to_f@ map the source
    -- and target types to their C representations.
    mkFPConv from_f to_f s from_t to_t =
      [C.cedecl|static inline $ty:to_ct
                    $id:(convOp s from_t to_t)($ty:from_ct x) { return ($ty:to_ct)x;} |]
      where
        from_ct = from_f from_t
        to_ct = to_f to_t

    -- Float -> float.
    mkFPConvFF = mkFPConv floatTypeToCType floatTypeToCType
    -- Float -> signed integer.
    mkFPConvFI = mkFPConv floatTypeToCType intTypeToCType
    -- Signed integer -> float.
    mkFPConvIF = mkFPConv intTypeToCType floatTypeToCType
    -- Float -> unsigned integer.
    mkFPConvFU = mkFPConv floatTypeToCType uintTypeToCType
    -- Unsigned integer -> float.
    mkFPConvUF = mkFPConv uintTypeToCType floatTypeToCType

    -- Binary operation on two floats @x@ and @y@ of type @t@, returning
    -- a value of the same type; @e@ is the C expression for the result.
    simpleFloatOp s e t =
      [C.cedecl|static inline $ty:ct $id:(taggedF s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = floatTypeToCType t

    -- As 'simpleFloatOp', but the generated function returns @bool@.
    floatCmpOp s e t =
      [C.cedecl|static inline typename bool $id:(taggedF s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = floatTypeToCType t

-- | C definitions of the built-in functions on 32-bit floats.
--
-- The @isnan32@/@isinf32@ predicates, the bit-cast helpers
-- (@to_bits32@/@from_bits32@) and @fsignum32@ are emitted
-- unconditionally.  The math functions come in two variants selected
-- by the C preprocessor: when @__OPENCL_VERSION__@ is defined the
-- unsuffixed names (@log@, @sqrt@, @mix@, @mad@, ...) are called,
-- otherwise the @f@-suffixed C functions (@logf@, @sqrtf@, ...) are
-- used, with @lerp32@ and @mad32@ written out by hand.
--
-- Most names are produced through @funName'@; @fmod32@ and @fsignum32@
-- are emitted under exactly those names.
cFloat32Funs :: [C.Definition]
cFloat32Funs =
  [C.cunit|
    static inline typename bool $id:(funName' "isnan32")(float x) {
      return isnan(x);
    }

    static inline typename bool $id:(funName' "isinf32")(float x) {
      return isinf(x);
    }

$esc:("#ifdef __OPENCL_VERSION__")
    static inline float $id:(funName' "log32")(float x) {
      return log(x);
    }

    static inline float $id:(funName' "log2_32")(float x) {
      return log2(x);
    }

    static inline float $id:(funName' "log10_32")(float x) {
      return log10(x);
    }

    static inline float $id:(funName' "sqrt32")(float x) {
      return sqrt(x);
    }

    static inline float $id:(funName' "exp32")(float x) {
      return exp(x);
    }

    static inline float $id:(funName' "cos32")(float x) {
      return cos(x);
    }

    static inline float $id:(funName' "sin32")(float x) {
      return sin(x);
    }

    static inline float $id:(funName' "tan32")(float x) {
      return tan(x);
    }

    static inline float $id:(funName' "acos32")(float x) {
      return acos(x);
    }

    static inline float $id:(funName' "asin32")(float x) {
      return asin(x);
    }

    static inline float $id:(funName' "atan32")(float x) {
      return atan(x);
    }

    static inline float $id:(funName' "cosh32")(float x) {
      return cosh(x);
    }

    static inline float $id:(funName' "sinh32")(float x) {
      return sinh(x);
    }

    static inline float $id:(funName' "tanh32")(float x) {
      return tanh(x);
    }

    static inline float $id:(funName' "acosh32")(float x) {
      return acosh(x);
    }

    static inline float $id:(funName' "asinh32")(float x) {
      return asinh(x);
    }

    static inline float $id:(funName' "atanh32")(float x) {
      return atanh(x);
    }

    static inline float $id:(funName' "atan2_32")(float x, float y) {
      return atan2(x,y);
    }

    static inline float $id:(funName' "hypot32")(float x, float y) {
      return hypot(x,y);
    }

    static inline float $id:(funName' "gamma32")(float x) {
      return tgamma(x);
    }

    static inline float $id:(funName' "lgamma32")(float x) {
      return lgamma(x);
    }

    static inline float fmod32(float x, float y) {
      return fmod(x, y);
    }
    static inline float $id:(funName' "round32")(float x) {
      return rint(x);
    }
    static inline float $id:(funName' "floor32")(float x) {
      return floor(x);
    }
    static inline float $id:(funName' "ceil32")(float x) {
      return ceil(x);
    }
    static inline float $id:(funName' "lerp32")(float v0, float v1, float t) {
      return mix(v0, v1, t);
    }
    static inline float $id:(funName' "mad32")(float a, float b, float c) {
      return mad(a,b,c);
    }
    static inline float $id:(funName' "fma32")(float a, float b, float c) {
      return fma(a,b,c);
    }
$esc:("#else")
    static inline float $id:(funName' "log32")(float x) {
      return logf(x);
    }

    static inline float $id:(funName' "log2_32")(float x) {
      return log2f(x);
    }

    static inline float $id:(funName' "log10_32")(float x) {
      return log10f(x);
    }

    static inline float $id:(funName' "sqrt32")(float x) {
      return sqrtf(x);
    }

    static inline float $id:(funName' "exp32")(float x) {
      return expf(x);
    }

    static inline float $id:(funName' "cos32")(float x) {
      return cosf(x);
    }

    static inline float $id:(funName' "sin32")(float x) {
      return sinf(x);
    }

    static inline float $id:(funName' "tan32")(float x) {
      return tanf(x);
    }

    static inline float $id:(funName' "acos32")(float x) {
      return acosf(x);
    }

    static inline float $id:(funName' "asin32")(float x) {
      return asinf(x);
    }

    static inline float $id:(funName' "atan32")(float x) {
      return atanf(x);
    }

    static inline float $id:(funName' "cosh32")(float x) {
      return coshf(x);
    }

    static inline float $id:(funName' "sinh32")(float x) {
      return sinhf(x);
    }

    static inline float $id:(funName' "tanh32")(float x) {
      return tanhf(x);
    }

    static inline float $id:(funName' "acosh32")(float x) {
      return acoshf(x);
    }

    static inline float $id:(funName' "asinh32")(float x) {
      return asinhf(x);
    }

    static inline float $id:(funName' "atanh32")(float x) {
      return atanhf(x);
    }

    static inline float $id:(funName' "atan2_32")(float x, float y) {
      return atan2f(x,y);
    }

    static inline float $id:(funName' "hypot32")(float x, float y) {
      return hypotf(x,y);
    }

    static inline float $id:(funName' "gamma32")(float x) {
      return tgammaf(x);
    }

    static inline float $id:(funName' "lgamma32")(float x) {
      return lgammaf(x);
    }

    static inline float fmod32(float x, float y) {
      return fmodf(x, y);
    }
    static inline float $id:(funName' "round32")(float x) {
      return rintf(x);
    }
    static inline float $id:(funName' "floor32")(float x) {
      return floorf(x);
    }
    static inline float $id:(funName' "ceil32")(float x) {
      return ceilf(x);
    }
    static inline float $id:(funName' "lerp32")(float v0, float v1, float t) {
      return v0 + (v1-v0)*t;
    }
    static inline float $id:(funName' "mad32")(float a, float b, float c) {
      return a*b+c;
    }
    static inline float $id:(funName' "fma32")(float a, float b, float c) {
      return fmaf(a,b,c);
    }
$esc:("#endif")
    static inline typename int32_t $id:(funName' "to_bits32")(float x) {
      union {
        float f;
        typename int32_t t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline float $id:(funName' "from_bits32")(typename int32_t x) {
      union {
        typename int32_t f;
        float t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline float fsignum32(float x) {
      return $id:(funName' "isnan32")(x) ? x : ((x > 0) - (x < 0));
    }
|]

-- | C definitions of the built-in functions on 64-bit floats.
--
-- Unlike the 32-bit variant, the math functions here call the
-- unsuffixed names (@log@, @sqrt@, ...) unconditionally.  Only
-- @lerp64@ and @mad64@ are selected by the preprocessor: @mix@ and
-- @mad@ are used when @__OPENCL_VERSION__@ is defined, and hand-written
-- expressions otherwise.
--
-- Most names are produced through @funName'@; @fmod64@ and @fsignum64@
-- are emitted under exactly those names.
cFloat64Funs :: [C.Definition]
cFloat64Funs =
  [C.cunit|
    static inline double $id:(funName' "log64")(double x) {
      return log(x);
    }

    static inline double $id:(funName' "log2_64")(double x) {
      return log2(x);
    }

    static inline double $id:(funName' "log10_64")(double x) {
      return log10(x);
    }

    static inline double $id:(funName' "sqrt64")(double x) {
      return sqrt(x);
    }

    static inline double $id:(funName' "exp64")(double x) {
      return exp(x);
    }

    static inline double $id:(funName' "cos64")(double x) {
      return cos(x);
    }

    static inline double $id:(funName' "sin64")(double x) {
      return sin(x);
    }

    static inline double $id:(funName' "tan64")(double x) {
      return tan(x);
    }

    static inline double $id:(funName' "acos64")(double x) {
      return acos(x);
    }

    static inline double $id:(funName' "asin64")(double x) {
      return asin(x);
    }

    static inline double $id:(funName' "atan64")(double x) {
      return atan(x);
    }

    static inline double $id:(funName' "cosh64")(double x) {
      return cosh(x);
    }

    static inline double $id:(funName' "sinh64")(double x) {
      return sinh(x);
    }

    static inline double $id:(funName' "tanh64")(double x) {
      return tanh(x);
    }

    static inline double $id:(funName' "acosh64")(double x) {
      return acosh(x);
    }

    static inline double $id:(funName' "asinh64")(double x) {
      return asinh(x);
    }

    static inline double $id:(funName' "atanh64")(double x) {
      return atanh(x);
    }

    static inline double $id:(funName' "atan2_64")(double x, double y) {
      return atan2(x,y);
    }

    static inline double $id:(funName' "hypot64")(double x, double y) {
      return hypot(x,y);
    }

    static inline double $id:(funName' "gamma64")(double x) {
      return tgamma(x);
    }

    static inline double $id:(funName' "lgamma64")(double x) {
      return lgamma(x);
    }

    static inline double $id:(funName' "fma64")(double a, double b, double c) {
      return fma(a,b,c);
    }

    static inline double $id:(funName' "round64")(double x) {
      return rint(x);
    }

    static inline double $id:(funName' "ceil64")(double x) {
      return ceil(x);
    }

    static inline double $id:(funName' "floor64")(double x) {
      return floor(x);
    }

    static inline typename bool $id:(funName' "isnan64")(double x) {
      return isnan(x);
    }

    static inline typename bool $id:(funName' "isinf64")(double x) {
      return isinf(x);
    }

    static inline typename int64_t $id:(funName' "to_bits64")(double x) {
      union {
        double f;
        typename int64_t t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline double $id:(funName' "from_bits64")(typename int64_t x) {
      union {
        typename int64_t f;
        double t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline double fmod64(double x, double y) {
      return fmod(x, y);
    }

    static inline double fsignum64(double x) {
      return $id:(funName' "isnan64")(x) ? x : ((x > 0) - (x < 0));
    }

$esc:("#ifdef __OPENCL_VERSION__")
    static inline double $id:(funName' "lerp64")(double v0, double v1, double t) {
      return mix(v0, v1, t);
    }
    static inline double $id:(funName' "mad64")(double a, double b, double c) {
      return mad(a,b,c);
    }
$esc:("#else")
    static inline double $id:(funName' "lerp64")(double v0, double v1, double t) {
      return v0 + (v1-v0)*t;
    }
    static inline double $id:(funName' "mad64")(double a, double b, double c) {
      return a*b+c;
    }
$esc:("#endif")
|]

-- | @storageSize pt rank shape@ is a C expression for the number of
-- bytes needed to store a value with element type @pt@ and @rank@
-- dimensions: a fixed-size header, one @int64_t@ per dimension, and
-- the elements themselves.  @shape@ must be a C expression that can be
-- indexed with @0 .. rank-1@ to obtain the size of each dimension.
storageSize :: PrimType -> Int -> C.Exp -> C.Exp
storageSize pt rank shape =
  [C.cexp|$int:header_size +
          $int:rank * sizeof(typename int64_t) +
          $exp:(cproduct dims) * $int:pt_size|]
  where
    header_size, pt_size :: Int
    header_size = 1 + 1 + 1 + 4 -- 'b' <version> <num_dims> <type>
    pt_size = primByteSize pt
    -- One indexing expression per dimension; empty when rank is 0.
    dims = [[C.cexp|$exp:shape[$int:i]|] | i <- [0 .. rank - 1]]

-- | The four-character type tag written into (and expected in) a value
-- header.  Every result is exactly four characters long, padded on the
-- left with spaces; the header code copies and compares exactly four
-- bytes.  Signedness only matters for integer types, and 'Unit' shares
-- the tag of 'Bool'.
typeStr :: Signedness -> PrimType -> String
typeStr sign pt =
  case (sign, pt) of
    (_, Bool) -> "bool"
    (_, Unit) -> "bool"
    (_, FloatType Float32) -> " f32"
    (_, FloatType Float64) -> " f64"
    (TypeDirect, IntType Int8) -> "  i8"
    (TypeDirect, IntType Int16) -> " i16"
    (TypeDirect, IntType Int32) -> " i32"
    (TypeDirect, IntType Int64) -> " i64"
    (TypeUnsigned, IntType Int8) -> "  u8"
    (TypeUnsigned, IntType Int16) -> " u16"
    (TypeUnsigned, IntType Int32) -> " u32"
    (TypeUnsigned, IntType Int64) -> " u64"

-- | @storeValueHeader sign pt rank shape dest@ produces C statements
-- that write a value header through the pointer expression @dest@,
-- advancing @dest@ past everything written.  The header consists of
-- the byte @'b'@, the version byte @2@, the rank, the four-byte type
-- tag from 'typeStr', and finally @rank@ 64-bit dimension sizes copied
-- from @shape@.  No shape is copied when @rank@ is zero.
storeValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
storeValueHeader sign pt rank shape dest =
  [C.cstms|
          *$exp:dest++ = 'b';
          *$exp:dest++ = 2;
          *$exp:dest++ = $int:rank;
          memcpy($exp:dest, $string:(typeStr sign pt), 4);
          $exp:dest += 4;
          $stms:copy_shape
          |]
  where
    -- Scalars have no shape, so emit nothing for them.
    copy_shape
      | rank == 0 = []
      | otherwise =
        [C.cstms|
                memcpy($exp:dest, $exp:shape, $int:rank*sizeof(typename int64_t));
                $exp:dest += $int:rank*sizeof(typename int64_t);|]

-- | @loadValueHeader sign pt rank shape src@ produces C statements that
-- read and check a value header through the pointer expression @src@,
-- advancing @src@ as they go.  Each mismatch (magic byte @'b'@,
-- version @2@, rank, or the four-byte type tag from 'typeStr') is
-- OR-ed into a C variable named @err@, which the generated code
-- expects to already be in scope.  Only if @err@ is still zero are the
-- @rank@ 64-bit dimension sizes copied into @shape@ and @src@ advanced
-- past them.
loadValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
loadValueHeader sign pt rank shape src =
  [C.cstms|
     err |= (*$exp:src++ != 'b');
     err |= (*$exp:src++ != 2);
     err |= (*$exp:src++ != $int:rank);
     err |= (memcmp($exp:src, $string:(typeStr sign pt), 4) != 0);
     $exp:src += 4;
     if (err == 0) {
       $stms:load_shape
       $exp:src += $int:rank*sizeof(typename int64_t);
     }|]
  where
    -- Scalars have no shape, so emit nothing for them.  The source is
    -- spliced in with @$exp:src@ (not a literal C identifier) so that
    -- this works whatever expression the caller passes.
    load_shape
      | rank == 0 = []
      | otherwise =
        [C.cstms|memcpy($exp:shape, $exp:src, $int:rank*sizeof(typename int64_t));|]