{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE Trustworthy #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Futhark.CodeGen.Backends.SimpleRep
( tupleField,
funName,
defaultMemBlockType,
intTypeToCType,
primTypeToCType,
signedPrimTypeToCType,
arrayName,
opaqueName,
externalValueType,
cproduct,
csum,
cIntOps,
cFloat32Ops,
cFloat32Funs,
cFloat64Ops,
cFloat64Funs,
cFloatConvOps,
storageSize,
storeValueHeader,
loadValueHeader,
)
where
import Data.Bits (shiftR, xor)
import Data.Char (isAlphaNum, isDigit, ord)
import Futhark.CodeGen.ImpCode
import Futhark.Util (zEncodeString)
import Futhark.Util.Pretty (prettyOneLine)
import qualified Language.C.Quote.C as C
import qualified Language.C.Syntax as C
import Text.Printf
-- | The C type corresponding to a signed integer type.
intTypeToCType :: IntType -> C.Type
intTypeToCType t = case t of
  Int8 -> [C.cty|typename int8_t|]
  Int16 -> [C.cty|typename int16_t|]
  Int32 -> [C.cty|typename int32_t|]
  Int64 -> [C.cty|typename int64_t|]
-- | The unsigned C type with the same width as the given integer type.
uintTypeToCType :: IntType -> C.Type
uintTypeToCType t = case t of
  Int8 -> [C.cty|typename uint8_t|]
  Int16 -> [C.cty|typename uint16_t|]
  Int32 -> [C.cty|typename uint32_t|]
  Int64 -> [C.cty|typename uint64_t|]
-- | The C type corresponding to a floating-point type: @float@ for
-- 'Float32' and @double@ for 'Float64'.
floatTypeToCType :: FloatType -> C.Type
floatTypeToCType ft = case ft of
  Float32 -> [C.cty|float|]
  Float64 -> [C.cty|double|]
-- | The C type corresponding to a primitive type.  Integers are treated as
-- signed; both 'Bool' and 'Unit' are represented by C @bool@.
primTypeToCType :: PrimType -> C.Type
primTypeToCType pt = case pt of
  IntType it -> intTypeToCType it
  FloatType ft -> floatTypeToCType ft
  Bool -> [C.cty|typename bool|]
  Unit -> [C.cty|typename bool|]
-- | The C type corresponding to a primitive type, taking signedness into
-- account.  Only integer types are affected by the signedness; every other
-- primitive type is mapped exactly as by 'primTypeToCType'.
signedPrimTypeToCType :: Signedness -> PrimType -> C.Type
signedPrimTypeToCType signedness pt =
  case (signedness, pt) of
    (TypeUnsigned, IntType it) -> uintTypeToCType it
    (TypeDirect, IntType it) -> intTypeToCType it
    _ -> primTypeToCType pt
-- | The name of the C struct field holding the tuple component with the
-- given index: @v@ followed by the decimal index.
tupleField :: Int -> String
tupleField i = 'v' : show i
-- | The C function name for a Futhark function: the Z-encoded name with a
-- @futrts_@ prefix.
funName :: Name -> String
funName name = "futrts_" ++ zEncodeString (nameToString name)
-- | Like 'funName', but starting from a plain 'String'.
funName' :: String -> String
funName' s = funName (nameFromString s)
-- | The C type used for raw memory blocks: a plain @char@ pointer.
defaultMemBlockType :: C.Type
defaultMemBlockType =
  [C.cty|char*|]
-- | The name of the C array type for arrays of the given element type,
-- signedness and rank: the pretty-printed element type (signedness-aware),
-- followed by @_@, the rank, and @d@.
arrayName :: PrimType -> Signedness -> Int -> String
arrayName pt signed rank =
  concat [prettySigned isUnsigned pt, "_", show rank, "d"]
  where
    isUnsigned = signed == TypeUnsigned
-- | The name of the C type used for an opaque value.
--
-- If the description @s@ can be used directly as the tail of a C identifier,
-- it is used verbatim.  That requires @s@ to be non-empty, not start with an
-- underscore or a digit, and contain only alphanumerics and underscores.
--
-- Otherwise the name is derived from a hash of @s@ together with the
-- components' types, signedness and (for arrays) memory space and rank.
-- This fallback also covers an empty @s@, which previously crashed on
-- 'head'.
opaqueName :: String -> [ValueDesc] -> String
opaqueName s _
  | c : _ <- s,
    c /= '_',
    not (isDigit c),
    all ok s =
      "opaque_" ++ s
  where
    ok ch = isAlphaNum ch || ch == '_'
opaqueName s vds =
  "opaque_" ++ hash (zipWith xor [0 ..] $ map ord (s ++ concatMap p vds))
  where
    -- Textual description of one component; feeds the hash.
    p (ScalarValue pt signed _) =
      show (pt, signed)
    p (ArrayValue _ space pt signed dims) =
      show (space, pt, signed, length dims)

    -- FIXME: an ad-hoc, non-cryptographic hash; collisions are possible.
    hash :: [Int] -> String
    hash = printf "%x" . foldl xor (0 :: Word32) . map (mix . fromIntegral)

    -- Integer mixing step applied to each element before folding.
    mix :: Word32 -> Word32
    mix = iter . (* 0x45d9f3b) . iter . (* 0x45d9f3b) . iter

    iter :: Word32 -> Word32
    iter x = (x `shiftR` 16) `xor` x
-- | The C type used in the external API for an external value.  Opaque
-- values and transparent arrays are pointers to a @struct futhark_...@
-- named by 'opaqueName' and 'arrayName' respectively.  Transparent scalars
-- use their signedness-aware primitive C type.
externalValueType :: ExternalValue -> C.Type
externalValueType ev = case ev of
  OpaqueValue desc vds ->
    structPtr (opaqueName desc vds)
  TransparentValue (ArrayValue _ _ pt signed shape) ->
    structPtr (arrayName pt signed (length shape))
  TransparentValue (ScalarValue pt signed _) ->
    signedPrimTypeToCType signed pt
  where
    structPtr name = [C.cty|struct $id:("futhark_" ++ name)*|]
-- | The product of a list of C expressions, associated to the left.
-- The empty product is the literal @1@.
cproduct :: [C.Exp] -> C.Exp
cproduct [] = [C.cexp|1|]
cproduct es = foldl1 times es
  where
    times x y = [C.cexp|$exp:x * $exp:y|]
-- | The sum of a list of C expressions, associated to the left.
-- The empty sum is the literal @0@.
csum :: [C.Exp] -> C.Exp
csum [] = [C.cexp|0|]
csum es = foldl1 plus es
  where
    plus x y = [C.cexp|$exp:x + $exp:y|]
-- | A 'Name' becomes a C identifier via its Z-encoded string form.
instance C.ToIdent Name where
  toIdent name = C.toIdent (zEncodeString (nameToString name))
-- | A 'VName' becomes a C identifier via its pretty-printed, Z-encoded form.
instance C.ToIdent VName where
  toIdent vname = C.toIdent (zEncodeString (pretty vname))
-- | A 'VName' used as an expression is a reference to the identifier it
-- maps to (see the 'C.ToIdent' instance); the source location is ignored.
instance C.ToExp VName where
  toExp name _loc = [C.cexp|$id:name|]
-- | Integer constants become C literals of the underlying Haskell value.
instance C.ToExp IntValue where
  toExp iv = case iv of
    Int8Value v -> C.toExp v
    Int16Value v -> C.toExp v
    Int32Value v -> C.toExp v
    Int64Value v -> C.toExp v
-- | Floating-point constants become C literals of the underlying value.
instance C.ToExp FloatValue where
  toExp fv = case fv of
    Float32Value v -> C.toExp v
    Float64Value v -> C.toExp v
-- | Primitive constants as C expressions.  Booleans are emitted as the
-- 'Int8' literals 1 (true) and 0 (false); the unit value is emitted as 1.
instance C.ToExp PrimValue where
  toExp pv = case pv of
    IntValue v -> C.toExp v
    FloatValue v -> C.toExp v
    BoolValue b -> C.toExp ((if b then 1 else 0) :: Int8)
    UnitValue -> C.toExp (1 :: Int8)
-- | A subexpression is either a variable reference or a constant.
instance C.ToExp SubExp where
  toExp se = case se of
    Var v -> C.toExp v
    Constant c -> C.toExp c
-- | C definitions of the integer primitive operations, instantiated at every
-- 'IntType', followed by 'cIntPrimFuns'.
--
-- Each generated function is named by appending the bit width to the
-- operation name (e.g. @add8@, @sdiv64@).  Definitions are emitted one
-- operation at a time: all widths of @add@, then all widths of @sub@, and so
-- on.  Some operations call earlier ones (@sdiv_up@ and @sdiv_safe@ call
-- @sdiv@, @sdiv_up_safe@ calls @sdiv_safe@, @smod_safe@ calls @smod@), so
-- the order of 'ops' is significant.
--
-- Sign and zero extensions are emitted as @#define@ macros named
-- @sext_<from>_<to>@ and @zext_<from>_<to>@ (e.g. @zext_i8_i32@, which
-- 'cIntPrimFuns' relies on).
cIntOps :: [C.Definition]
cIntOps =
  [mk t | mk <- ops, t <- allIntTypes] ++ cIntPrimFuns
  where
    allIntTypes :: [IntType]
    allIntTypes = [minBound .. maxBound]

    ops :: [IntType -> C.Definition]
    ops =
      [ mkAdd,
        mkSub,
        mkMul,
        mkUDiv,
        mkUDivUp,
        mkUMod,
        mkUDivSafe,
        mkUDivUpSafe,
        mkUModSafe,
        mkSDiv,
        mkSDivUp,
        mkSMod,
        mkSDivSafe,
        mkSDivUpSafe,
        mkSModSafe,
        mkSQuot,
        mkSRem,
        mkSQuotSafe,
        mkSRemSafe,
        mkSMin,
        mkUMin,
        mkSMax,
        mkUMax,
        mkShl,
        mkLShr,
        mkAShr,
        mkAnd,
        mkOr,
        mkXor,
        mkUlt,
        mkUle,
        mkSlt,
        mkSle,
        mkPow,
        mkIToB,
        mkBToI
      ]
        ++ map mkSExt allIntTypes
        ++ map mkZExt allIntTypes

    -- Operation name suffixed with the bit width of the type.
    taggedI :: String -> IntType -> String
    taggedI s t =
      s ++ case t of
        Int8 -> "8"
        Int16 -> "16"
        Int32 -> "32"
        Int64 -> "64"

    -- A two-operand function whose operands and result all have the C type
    -- chosen by @toCType@.
    binOp :: (IntType -> C.Type) -> String -> C.Exp -> IntType -> C.Definition
    binOp toCType s e t =
      [C.cedecl|static inline $ty:ct $id:(taggedI s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = toCType t

    -- A two-operand comparison returning C bool.
    cmpOp :: (IntType -> C.Type) -> String -> C.Exp -> IntType -> C.Definition
    cmpOp toCType s e t =
      [C.cedecl|static inline typename bool $id:(taggedI s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = toCType t

    -- Operations that are sign-agnostic or explicitly unsigned are generated
    -- over the unsigned C types (presumably so that overflow wraps with
    -- defined behaviour); signed-specific ones over the signed C types.
    unsignedOp = binOp uintTypeToCType
    signedOp = binOp intTypeToCType
    unsignedCmp = cmpOp uintTypeToCType
    signedCmp = cmpOp intTypeToCType

    mkAdd = unsignedOp "add" [C.cexp|x + y|]
    mkSub = unsignedOp "sub" [C.cexp|x - y|]
    mkMul = unsignedOp "mul" [C.cexp|x * y|]
    mkUDiv = unsignedOp "udiv" [C.cexp|x / y|]
    mkUDivUp = unsignedOp "udiv_up" [C.cexp|(x+y-1) / y|]
    mkUMod = unsignedOp "umod" [C.cexp|x % y|]
    mkUDivSafe = unsignedOp "udiv_safe" [C.cexp|y == 0 ? 0 : x / y|]
    mkUDivUpSafe = unsignedOp "udiv_up_safe" [C.cexp|y == 0 ? 0 : (x+y-1) / y|]
    mkUModSafe = unsignedOp "umod_safe" [C.cexp|y == 0 ? 0 : x % y|]
    mkUMax = unsignedOp "umax" [C.cexp|x < y ? y : x|]
    mkUMin = unsignedOp "umin" [C.cexp|x < y ? x : y|]

    -- Division rounding towards negative infinity: C division truncates, so
    -- subtract one when there is a remainder and the signs differ.
    mkSDiv t =
      [C.cedecl|static inline $ty:ct $id:(taggedI "sdiv" t)($ty:ct x, $ty:ct y) {
                  $ty:ct q = x / y;
                  $ty:ct r = x % y;
                  return q -
                    (((r != 0) && ((r < 0) != (y < 0))) ? 1 : 0);
                }|]
      where
        ct = intTypeToCType t

    mkSDivUp t =
      signedOp "sdiv_up" [C.cexp|$id:(taggedI "sdiv" t)(x+y-1,y)|] t

    -- Modulo whose result takes the sign of the divisor: adjust the
    -- truncating C remainder by @y@ when it is non-zero and the operand
    -- signs are not both positive or both negative.
    mkSMod t =
      [C.cedecl|static inline $ty:ct $id:(taggedI "smod" t)($ty:ct x, $ty:ct y) {
                  $ty:ct r = x % y;
                  return r +
                    ((r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0)) ? 0 : y);
                }|]
      where
        ct = intTypeToCType t

    mkSDivSafe t =
      signedOp "sdiv_safe" [C.cexp|y == 0 ? 0 : $id:(taggedI "sdiv" t)(x,y)|] t
    mkSDivUpSafe t =
      signedOp "sdiv_up_safe" [C.cexp|$id:(taggedI "sdiv_safe" t)(x+y-1,y)|] t
    mkSModSafe t =
      signedOp "smod_safe" [C.cexp|y == 0 ? 0 : $id:(taggedI "smod" t)(x,y)|] t

    mkSQuot = signedOp "squot" [C.cexp|x / y|]
    mkSRem = signedOp "srem" [C.cexp|x % y|]
    mkSQuotSafe = signedOp "squot_safe" [C.cexp|y == 0 ? 0 : x / y|]
    mkSRemSafe = signedOp "srem_safe" [C.cexp|y == 0 ? 0 : x % y|]
    mkSMax = signedOp "smax" [C.cexp|x < y ? y : x|]
    mkSMin = signedOp "smin" [C.cexp|x < y ? x : y|]

    mkShl = unsignedOp "shl" [C.cexp|x << y|]
    mkLShr = unsignedOp "lshr" [C.cexp|x >> y|]
    mkAShr = signedOp "ashr" [C.cexp|x >> y|]
    mkAnd = unsignedOp "and" [C.cexp|x & y|]
    mkOr = unsignedOp "or" [C.cexp|x | y|]
    mkXor = unsignedOp "xor" [C.cexp|x ^ y|]

    mkUlt = unsignedCmp "ult" [C.cexp|x < y|]
    mkUle = unsignedCmp "ule" [C.cexp|x <= y|]
    mkSlt = signedCmp "slt" [C.cexp|x < y|]
    mkSle = signedCmp "sle" [C.cexp|x <= y|]

    -- Exponentiation by squaring over the signed type.
    -- NOTE(review): with a negative exponent, @rem >>= 1@ on a signed value
    -- does not reach 0 on arithmetic-shift targets, so the loop would not
    -- terminate; callers presumably never pass a negative exponent — confirm.
    mkPow t =
      [C.cedecl|static inline $ty:ct $id:(taggedI "pow" t)($ty:ct x, $ty:ct y) {
                  $ty:ct res = 1, rem = y;
                  while (rem != 0) {
                    if (rem & 1) {
                      res *= x;
                    }
                    rem >>= 1;
                    x *= x;
                  }
                  return res;
                }|]
      where
        ct = intTypeToCType t

    -- Sign extension reads the source through its signed C type; zero
    -- extension reads it through the unsigned one.
    mkSExt = extMacro "sext_" intTypeToCType
    mkZExt = extMacro "zext_" uintTypeToCType

    -- @#define <prefix><from>_<to>(x) ((to)((from')x))@, where @from'@ is
    -- the source type as given by @fromCType@ and @to@ is signed.
    extMacro :: String -> (IntType -> C.Type) -> IntType -> IntType -> C.Definition
    extMacro prefix fromCType from_t to_t =
      macro
        (prefix ++ pretty from_t ++ "_" ++ pretty to_t)
        [C.cexp|($ty:to_ct)(($ty:from_ct)x)|]
      where
        from_ct = fromCType from_t
        to_ct = intTypeToCType to_t

    -- A one-argument preprocessor macro whose body is @rhs@ on one line.
    macro name rhs =
      [C.cedecl|$esc:("#define " ++ name ++ "(x) (" ++ prettyOneLine rhs ++ ")")|]

    mkBToI to_t =
      [C.cedecl|static inline $ty:to_ct
                $id:name($ty:from_ct x) { return x; } |]
      where
        name = "btoi_bool_" ++ pretty to_t
        from_ct = primTypeToCType Bool
        to_ct = intTypeToCType to_t

    mkIToB from_t =
      [C.cedecl|static inline $ty:to_ct
                $id:name($ty:from_ct x) { return x; } |]
      where
        name = "itob_" ++ pretty from_t ++ "_bool"
        to_ct = primTypeToCType Bool
        from_ct = intTypeToCType from_t
-- | Hand-written C helpers for integer operations whose implementation
-- differs per target:
--
-- * population count (@popc@);
-- * the high half of an unsigned multiplication (@mul_hi@);
-- * @mul_hi@ plus an addend (@mad_hi@);
-- * counting leading (@clz@) and trailing (@ctz@) zero bits.
--
-- Each group has an OpenCL branch (@__OPENCL_VERSION__@), a CUDA branch
-- (@__CUDA_ARCH__@, except @mad_hi@) and a plain-C fallback.
--
-- The CUDA @popc@/@clz@ variants use the @zext_*@ macros generated by
-- 'cIntOps', which places this list after those macros.
--
-- Notes on the implementation:
--
-- * On CUDA, @mul_hi32@/@mul_hi64@ use the intrinsics @__umulhi@ and
--   @__umul64hi@.
-- * The plain-C @popc@ and @clz@ loops operate on an unsigned copy of the
--   argument.  Computing @x - 1@ on the most negative signed value, or
--   left-shifting a signed value into its sign bit, is undefined behaviour
--   in C.
-- * @mad_hi@ calls @mul_hi@ through 'funName'', like every other name
--   here.
cIntPrimFuns :: [C.Definition]
cIntPrimFuns =
  [C.cunit|
$esc:("#if defined(__OPENCL_VERSION__)")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
      return popcount(x);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
      return __popc(zext_i8_i32(x));
   }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
      return __popc(zext_i16_i32(x));
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
      return __popc(x);
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
      return __popcll(x);
   }
$esc:("#else")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
     typename uint8_t ux = x;
     int c = 0;
     for (; ux; ++c) {
       ux &= ux - 1;
     }
     return c;
    }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
     typename uint16_t ux = x;
     int c = 0;
     for (; ux; ++c) {
       ux &= ux - 1;
     }
     return c;
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
     typename uint32_t ux = x;
     int c = 0;
     for (; ux; ++c) {
       ux &= ux - 1;
     }
     return c;
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
     typename uint64_t ux = x;
     int c = 0;
     for (; ux; ++c) {
       ux &= ux - 1;
     }
     return c;
   }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
      return mul_hi(a, b);
   }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
      return mul_hi(a, b);
   }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
      return mul_hi(a, b);
   }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
      return mul_hi(a, b);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
     typename uint16_t aa = a;
     typename uint16_t bb = b;
     return (aa * bb) >> 8;
   }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
     typename uint32_t aa = a;
     typename uint32_t bb = b;
     return (aa * bb) >> 16;
   }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
      return __umulhi(a, b);
   }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
      return __umul64hi(a, b);
   }
$esc:("#else")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
     typename uint16_t aa = a;
     typename uint16_t bb = b;
     return (aa * bb) >> 8;
    }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
     typename uint32_t aa = a;
     typename uint32_t bb = b;
     return (aa * bb) >> 16;
    }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
     typename uint64_t aa = a;
     typename uint64_t bb = b;
     return (aa * bb) >> 32;
    }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
     typename __uint128_t aa = a;
     typename __uint128_t bb = b;
     return (aa * bb) >> 64;
    }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   static typename uint8_t $id:(funName' "mad_hi8") (typename uint8_t a, typename uint8_t b, typename uint8_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint16_t $id:(funName' "mad_hi16") (typename uint16_t a, typename uint16_t b, typename uint16_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint32_t $id:(funName' "mad_hi32") (typename uint32_t a, typename uint32_t b, typename uint32_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint64_t $id:(funName' "mad_hi64") (typename uint64_t a, typename uint64_t b, typename uint64_t c) {
      return mad_hi(a, b, c);
   }
$esc:("#else")
   static typename uint8_t $id:(funName' "mad_hi8") (typename uint8_t a, typename uint8_t b, typename uint8_t c) {
     return $id:(funName' "mul_hi8")(a, b) + c;
    }
   static typename uint16_t $id:(funName' "mad_hi16") (typename uint16_t a, typename uint16_t b, typename uint16_t c) {
     return $id:(funName' "mul_hi16")(a, b) + c;
    }
   static typename uint32_t $id:(funName' "mad_hi32") (typename uint32_t a, typename uint32_t b, typename uint32_t c) {
     return $id:(funName' "mul_hi32")(a, b) + c;
    }
   static typename uint64_t $id:(funName' "mad_hi64") (typename uint64_t a, typename uint64_t b, typename uint64_t c) {
     return $id:(funName' "mul_hi64")(a, b) + c;
    }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
      return clz(x);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
      return __clz(zext_i8_i32(x))-24;
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
      return __clz(zext_i16_i32(x))-16;
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
      return __clz(x);
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
      return __clzll(x);
   }
$esc:("#else")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
    typename uint8_t ux = x;
    int n = 0;
    int bits = sizeof(x) * 8;
    for (int i = 0; i < bits; i++) {
        if (ux >> (bits - 1)) break;
        n++;
        ux <<= 1;
    }
    return n;
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
    typename uint16_t ux = x;
    int n = 0;
    int bits = sizeof(x) * 8;
    for (int i = 0; i < bits; i++) {
        if (ux >> (bits - 1)) break;
        n++;
        ux <<= 1;
    }
    return n;
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
    typename uint32_t ux = x;
    int n = 0;
    int bits = sizeof(x) * 8;
    for (int i = 0; i < bits; i++) {
        if (ux >> (bits - 1)) break;
        n++;
        ux <<= 1;
    }
    return n;
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
    typename uint64_t ux = x;
    int n = 0;
    int bits = sizeof(x) * 8;
    for (int i = 0; i < bits; i++) {
        if (ux >> (bits - 1)) break;
        n++;
        ux <<= 1;
    }
    return n;
   }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   // OpenCL has ctz, but only from version 2.0, which we cannot assume we are using.
   static typename int32_t $id:(funName' "ctz8") (typename int8_t x) {
      int i = 0;
      for (; i < 8 && (x&1)==0; i++, x>>=1);
      return i;
   }
   static typename int32_t $id:(funName' "ctz16") (typename int16_t x) {
      int i = 0;
      for (; i < 16 && (x&1)==0; i++, x>>=1);
      return i;
   }
   static typename int32_t $id:(funName' "ctz32") (typename int32_t x) {
      int i = 0;
      for (; i < 32 && (x&1)==0; i++, x>>=1);
      return i;
   }
   static typename int32_t $id:(funName' "ctz64") (typename int64_t x) {
      int i = 0;
      for (; i < 64 && (x&1)==0; i++, x>>=1);
      return i;
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename int32_t $id:(funName' "ctz8") (typename int8_t x) {
     int y = __ffs(x);
     return y == 0 ? 8 : y-1;
   }
   static typename int32_t $id:(funName' "ctz16") (typename int16_t x) {
     int y = __ffs(x);
     return y == 0 ? 16 : y-1;
   }
   static typename int32_t $id:(funName' "ctz32") (typename int32_t x) {
     int y = __ffs(x);
     return y == 0 ? 32 : y-1;
   }
   static typename int32_t $id:(funName' "ctz64") (typename int64_t x) {
     int y = __ffsll(x);
     return y == 0 ? 64 : y-1;
   }
$esc:("#else")
   // FIXME: assumes GCC or clang.
   static typename int32_t $id:(funName' "ctz8") (typename int8_t x) {
     return x == 0 ? 8 : __builtin_ctz((typename uint32_t)x);
   }
   static typename int32_t $id:(funName' "ctz16") (typename int16_t x) {
     return x == 0 ? 16 : __builtin_ctz((typename uint32_t)x);
   }
   static typename int32_t $id:(funName' "ctz32") (typename int32_t x) {
     return x == 0 ? 32 : __builtin_ctz(x);
   }
   static typename int32_t $id:(funName' "ctz64") (typename int64_t x) {
     return x == 0 ? 64 : __builtin_ctzll(x);
   }
$esc:("#endif")
|]
-- | C definitions of the binary operators, comparisons and conversions on
-- floating-point values.
--
-- * 'cFloat32Ops' and 'cFloat64Ops' contain, for @float@ and @double@
--   respectively:
--
--     * the binary operators @fdiv@, @fadd@, @fsub@, @fmul@, @fmin@, @fmax@
--       and @fpow@;
--     * the comparisons @cmplt@ and @cmple@;
--     * the conversions between that float type and every integer type
--       (@sitofp@, @uitofp@, @fptosi@, @fptoui@).
--
--   Each operator name carries a @32@/@64@ suffix.
--
-- * 'cFloatConvOps' contains the @fpconv@ conversions between every pair of
--   float types, including the identity conversions.
cFloat32Ops :: [C.Definition]
cFloat64Ops :: [C.Definition]
cFloatConvOps :: [C.Definition]
(cFloat32Ops, cFloat64Ops, cFloatConvOps) =
  ( map ($ Float32) mkOps,
    map ($ Float64) mkOps,
    [ mkFPConvFF "fpconv" from to
      | from <- [minBound .. maxBound],
        to <- [minBound .. maxBound]
    ]
  )
  where
    -- Operator name with the float width appended, e.g. "fadd" -> "fadd32".
    taggedF s Float32 = s ++ "32"
    taggedF s Float64 = s ++ "64"

    -- Conversion name: the operation, then the pretty-printed source and
    -- target types, all separated by underscores.
    convOp s from to = s ++ "_" ++ pretty from ++ "_" ++ pretty to

    -- Every operator parameterised by the float type. The list order is the
    -- order in which the C definitions are emitted.
    mkOps =
      [mkFDiv, mkFAdd, mkFSub, mkFMul, mkFMin, mkFMax, mkPow, mkCmpLt, mkCmpLe]
        ++ map (mkFPConvIF "sitofp") [minBound .. maxBound]
        ++ map (mkFPConvUF "uitofp") [minBound .. maxBound]
        -- The float-to-int constructors take the float type first, so flip
        -- them to map over the integer types.
        ++ map (flip $ mkFPConvFI "fptosi") [minBound .. maxBound]
        ++ map (flip $ mkFPConvFU "fptoui") [minBound .. maxBound]

    mkFDiv = simpleFloatOp "fdiv" [C.cexp|x / y|]
    mkFAdd = simpleFloatOp "fadd" [C.cexp|x + y|]
    mkFSub = simpleFloatOp "fsub" [C.cexp|x - y|]
    mkFMul = simpleFloatOp "fmul" [C.cexp|x * y|]
    mkFMin = simpleFloatOp "fmin" [C.cexp|fmin(x, y)|]
    mkFMax = simpleFloatOp "fmax" [C.cexp|fmax(x, y)|]
    -- Named through taggedF like every other operator; this yields the same
    -- "fpow32"/"fpow64" functions as before.
    mkPow = simpleFloatOp "fpow" [C.cexp|pow(x, y)|]
    mkCmpLt = floatCmpOp "cmplt" [C.cexp|x < y|]
    mkCmpLe = floatCmpOp "cmple" [C.cexp|x <= y|]

    -- A unary cast function from_t -> to_t. The two C-type mappings are
    -- passed separately because source and target may be different kinds of
    -- primitive type (float vs. signed/unsigned int).
    mkFPConv from_f to_f s from_t to_t =
      [C.cedecl|static inline $ty:to_ct
                $id:(convOp s from_t to_t)($ty:from_ct x) { return ($ty:to_ct)x;} |]
      where
        from_ct = from_f from_t
        to_ct = to_f to_t

    mkFPConvFF = mkFPConv floatTypeToCType floatTypeToCType
    mkFPConvFI = mkFPConv floatTypeToCType intTypeToCType
    mkFPConvIF = mkFPConv intTypeToCType floatTypeToCType
    mkFPConvFU = mkFPConv floatTypeToCType uintTypeToCType
    mkFPConvUF = mkFPConv uintTypeToCType floatTypeToCType

    -- A two-argument function of the given float type whose return type is
    -- computed from that float's C type by ret.
    floatBinOp ret s e t =
      [C.cedecl|static inline $ty:(ret ct) $id:(taggedF s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = floatTypeToCType t

    -- Arithmetic: returns the same float type as its arguments.
    simpleFloatOp s e t = floatBinOp id s e t

    -- Comparison: returns bool.
    floatCmpOp s e t = floatBinOp (const [C.cty|typename bool|]) s e t
-- | C definitions of the single-precision (@float@) math library functions.
--
-- Every definition is emitted into the generated program; most names go
-- through 'funName''.
--
-- * When @__OPENCL_VERSION__@ is defined, the wrappers call the unsuffixed
--   built-ins (@log@, @sqrt@, @mix@, @mad@, @fma@, ...).
-- * Otherwise they call the @f@-suffixed C functions (@logf@, @sqrtf@,
--   @fmaf@, ...), and @lerp32@ and @mad32@ are written out arithmetically.
--
-- @to_bits32@ and @from_bits32@ reinterpret bits through a union between
-- @float@ and @int32_t@. @fsignum32@ returns NaN inputs unchanged, and
-- otherwise returns @(x > 0) - (x < 0)@.
cFloat32Funs :: [C.Definition]
cFloat32Funs =
  [C.cunit|
    static inline typename bool $id:(funName' "isnan32")(float x) {
      return isnan(x);
    }

    static inline typename bool $id:(funName' "isinf32")(float x) {
      return isinf(x);
    }

$esc:("#ifdef __OPENCL_VERSION__")
    static inline float $id:(funName' "log32")(float x) {
      return log(x);
    }

    static inline float $id:(funName' "log2_32")(float x) {
      return log2(x);
    }

    static inline float $id:(funName' "log10_32")(float x) {
      return log10(x);
    }

    static inline float $id:(funName' "sqrt32")(float x) {
      return sqrt(x);
    }

    static inline float $id:(funName' "exp32")(float x) {
      return exp(x);
    }

    static inline float $id:(funName' "cos32")(float x) {
      return cos(x);
    }

    static inline float $id:(funName' "sin32")(float x) {
      return sin(x);
    }

    static inline float $id:(funName' "tan32")(float x) {
      return tan(x);
    }

    static inline float $id:(funName' "acos32")(float x) {
      return acos(x);
    }

    static inline float $id:(funName' "asin32")(float x) {
      return asin(x);
    }

    static inline float $id:(funName' "atan32")(float x) {
      return atan(x);
    }

    static inline float $id:(funName' "cosh32")(float x) {
      return cosh(x);
    }

    static inline float $id:(funName' "sinh32")(float x) {
      return sinh(x);
    }

    static inline float $id:(funName' "tanh32")(float x) {
      return tanh(x);
    }

    static inline float $id:(funName' "acosh32")(float x) {
      return acosh(x);
    }

    static inline float $id:(funName' "asinh32")(float x) {
      return asinh(x);
    }

    static inline float $id:(funName' "atanh32")(float x) {
      return atanh(x);
    }

    static inline float $id:(funName' "atan2_32")(float x, float y) {
      return atan2(x,y);
    }

    static inline float $id:(funName' "hypot32")(float x, float y) {
      return hypot(x,y);
    }

    static inline float $id:(funName' "gamma32")(float x) {
      return tgamma(x);
    }

    static inline float $id:(funName' "lgamma32")(float x) {
      return lgamma(x);
    }

    static inline float fmod32(float x, float y) {
      return fmod(x, y);
    }

    static inline float $id:(funName' "round32")(float x) {
      return rint(x);
    }

    static inline float $id:(funName' "floor32")(float x) {
      return floor(x);
    }

    static inline float $id:(funName' "ceil32")(float x) {
      return ceil(x);
    }

    static inline float $id:(funName' "lerp32")(float v0, float v1, float t) {
      return mix(v0, v1, t);
    }

    static inline float $id:(funName' "mad32")(float a, float b, float c) {
      return mad(a,b,c);
    }

    static inline float $id:(funName' "fma32")(float a, float b, float c) {
      return fma(a,b,c);
    }

$esc:("#else")
    static inline float $id:(funName' "log32")(float x) {
      return logf(x);
    }

    static inline float $id:(funName' "log2_32")(float x) {
      return log2f(x);
    }

    static inline float $id:(funName' "log10_32")(float x) {
      return log10f(x);
    }

    static inline float $id:(funName' "sqrt32")(float x) {
      return sqrtf(x);
    }

    static inline float $id:(funName' "exp32")(float x) {
      return expf(x);
    }

    static inline float $id:(funName' "cos32")(float x) {
      return cosf(x);
    }

    static inline float $id:(funName' "sin32")(float x) {
      return sinf(x);
    }

    static inline float $id:(funName' "tan32")(float x) {
      return tanf(x);
    }

    static inline float $id:(funName' "acos32")(float x) {
      return acosf(x);
    }

    static inline float $id:(funName' "asin32")(float x) {
      return asinf(x);
    }

    static inline float $id:(funName' "atan32")(float x) {
      return atanf(x);
    }

    static inline float $id:(funName' "cosh32")(float x) {
      return coshf(x);
    }

    static inline float $id:(funName' "sinh32")(float x) {
      return sinhf(x);
    }

    static inline float $id:(funName' "tanh32")(float x) {
      return tanhf(x);
    }

    static inline float $id:(funName' "acosh32")(float x) {
      return acoshf(x);
    }

    static inline float $id:(funName' "asinh32")(float x) {
      return asinhf(x);
    }

    static inline float $id:(funName' "atanh32")(float x) {
      return atanhf(x);
    }

    static inline float $id:(funName' "atan2_32")(float x, float y) {
      return atan2f(x,y);
    }

    static inline float $id:(funName' "hypot32")(float x, float y) {
      return hypotf(x,y);
    }

    static inline float $id:(funName' "gamma32")(float x) {
      return tgammaf(x);
    }

    static inline float $id:(funName' "lgamma32")(float x) {
      return lgammaf(x);
    }

    static inline float fmod32(float x, float y) {
      return fmodf(x, y);
    }

    static inline float $id:(funName' "round32")(float x) {
      return rintf(x);
    }

    static inline float $id:(funName' "floor32")(float x) {
      return floorf(x);
    }

    static inline float $id:(funName' "ceil32")(float x) {
      return ceilf(x);
    }

    static inline float $id:(funName' "lerp32")(float v0, float v1, float t) {
      return v0 + (v1-v0)*t;
    }

    static inline float $id:(funName' "mad32")(float a, float b, float c) {
      return a*b+c;
    }

    static inline float $id:(funName' "fma32")(float a, float b, float c) {
      return fmaf(a,b,c);
    }

$esc:("#endif")

    static inline typename int32_t $id:(funName' "to_bits32")(float x) {
      union {
        float f;
        typename int32_t t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline float $id:(funName' "from_bits32")(typename int32_t x) {
      union {
        typename int32_t f;
        float t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline float fsignum32(float x) {
      return $id:(funName' "isnan32")(x) ? x : ((x > 0) - (x < 0));
    }
|]
-- | C definitions of the double-precision (@double@) math library functions.
--
-- Most wrappers call the unsuffixed @double@ functions (@log@, @sqrt@,
-- @fma@, ...) unconditionally.
--
-- * @lerp64@ and @mad64@ use the @mix@ and @mad@ built-ins when
--   @__OPENCL_VERSION__@ is defined.
-- * Otherwise they are written out arithmetically.
--
-- @to_bits64@ and @from_bits64@ reinterpret bits through a union between
-- @double@ and @int64_t@. @fsignum64@ returns NaN inputs unchanged, and
-- otherwise returns @(x > 0) - (x < 0)@.
cFloat64Funs :: [C.Definition]
cFloat64Funs =
  [C.cunit|
    static inline double $id:(funName' "log64")(double x) {
      return log(x);
    }

    static inline double $id:(funName' "log2_64")(double x) {
      return log2(x);
    }

    static inline double $id:(funName' "log10_64")(double x) {
      return log10(x);
    }

    static inline double $id:(funName' "sqrt64")(double x) {
      return sqrt(x);
    }

    static inline double $id:(funName' "exp64")(double x) {
      return exp(x);
    }

    static inline double $id:(funName' "cos64")(double x) {
      return cos(x);
    }

    static inline double $id:(funName' "sin64")(double x) {
      return sin(x);
    }

    static inline double $id:(funName' "tan64")(double x) {
      return tan(x);
    }

    static inline double $id:(funName' "acos64")(double x) {
      return acos(x);
    }

    static inline double $id:(funName' "asin64")(double x) {
      return asin(x);
    }

    static inline double $id:(funName' "atan64")(double x) {
      return atan(x);
    }

    static inline double $id:(funName' "cosh64")(double x) {
      return cosh(x);
    }

    static inline double $id:(funName' "sinh64")(double x) {
      return sinh(x);
    }

    static inline double $id:(funName' "tanh64")(double x) {
      return tanh(x);
    }

    static inline double $id:(funName' "acosh64")(double x) {
      return acosh(x);
    }

    static inline double $id:(funName' "asinh64")(double x) {
      return asinh(x);
    }

    static inline double $id:(funName' "atanh64")(double x) {
      return atanh(x);
    }

    static inline double $id:(funName' "atan2_64")(double x, double y) {
      return atan2(x,y);
    }

    static inline double $id:(funName' "hypot64")(double x, double y) {
      return hypot(x,y);
    }

    static inline double $id:(funName' "gamma64")(double x) {
      return tgamma(x);
    }

    static inline double $id:(funName' "lgamma64")(double x) {
      return lgamma(x);
    }

    static inline double $id:(funName' "fma64")(double a, double b, double c) {
      return fma(a,b,c);
    }

    static inline double $id:(funName' "round64")(double x) {
      return rint(x);
    }

    static inline double $id:(funName' "ceil64")(double x) {
      return ceil(x);
    }

    static inline double $id:(funName' "floor64")(double x) {
      return floor(x);
    }

    static inline typename bool $id:(funName' "isnan64")(double x) {
      return isnan(x);
    }

    static inline typename bool $id:(funName' "isinf64")(double x) {
      return isinf(x);
    }

    static inline typename int64_t $id:(funName' "to_bits64")(double x) {
      union {
        double f;
        typename int64_t t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline double $id:(funName' "from_bits64")(typename int64_t x) {
      union {
        typename int64_t f;
        double t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline double fmod64(double x, double y) {
      return fmod(x, y);
    }

    static inline double fsignum64(double x) {
      return $id:(funName' "isnan64")(x) ? x : ((x > 0) - (x < 0));
    }

$esc:("#ifdef __OPENCL_VERSION__")
    static inline double $id:(funName' "lerp64")(double v0, double v1, double t) {
      return mix(v0, v1, t);
    }

    static inline double $id:(funName' "mad64")(double a, double b, double c) {
      return mad(a,b,c);
    }

$esc:("#else")
    static inline double $id:(funName' "lerp64")(double v0, double v1, double t) {
      return v0 + (v1-v0)*t;
    }

    static inline double $id:(funName' "mad64")(double a, double b, double c) {
      return a*b+c;
    }

$esc:("#endif")
|]
-- | A C expression for the number of bytes needed to store a value of
-- primitive type @pt@ and rank @rank@ in the binary data format.
--
-- The total is the sum of three parts:
--
-- * the fixed-size header;
-- * one @int64_t@ per dimension;
-- * the elements, i.e. the product of all dimensions times the byte size
--   of @pt@.
--
-- @shape@ must be a C expression that can be indexed as @shape[i]@ for
-- every @0 <= i < rank@. Presumably it is a pointer to the dimension sizes
-- (confirm at the call sites).
storageSize :: PrimType -> Int -> C.Exp -> C.Exp
storageSize pt rank shape =
  [C.cexp|$int:header_size +
          $int:rank * sizeof(typename int64_t) +
          $exp:(cproduct dims) * $int:pt_size|]
  where
    header_size, pt_size :: Int
    -- Magic byte, format version, rank byte and the four-byte type tag:
    -- the bytes 'storeValueHeader' writes before the shape.
    header_size = 1 + 1 + 1 + 4
    pt_size = primByteSize pt
    -- One C expression per dimension; empty when rank is 0.
    dims = [[C.cexp|$exp:shape[$int:i]|] | i <- [0 .. rank - 1]]
-- | The type tag stored in the header of the binary data format.
--
-- The header code copies and compares exactly four bytes of this string
-- (@memcpy(..., 4)@ / @memcmp(..., 4)@). Every tag is therefore
-- right-justified with spaces to exactly four characters, e.g. @"  i8"@,
-- @" i16"@, @" f32"@, @"bool"@. Previously the 8-bit tags were only three
-- characters long, so their trailing NUL byte ended up in the header.
typeStr :: Signedness -> PrimType -> String
typeStr sign pt = padTag tag
  where
    tag =
      case (sign, pt) of
        (_, Bool) -> "bool"
        (_, Unit) -> "bool"
        (_, FloatType Float32) -> "f32"
        (_, FloatType Float64) -> "f64"
        (TypeDirect, IntType Int8) -> "i8"
        (TypeDirect, IntType Int16) -> "i16"
        (TypeDirect, IntType Int32) -> "i32"
        (TypeDirect, IntType Int64) -> "i64"
        (TypeUnsigned, IntType Int8) -> "u8"
        (TypeUnsigned, IntType Int16) -> "u16"
        (TypeUnsigned, IntType Int32) -> "u32"
        (TypeUnsigned, IntType Int64) -> "u64"

    -- Left-pad with spaces to the fixed four-byte tag width.
    padTag s = replicate (4 - length s) ' ' ++ s
-- | C statements that write a binary-data-format header through @dest@,
-- advancing @dest@ past everything written.
--
-- The bytes written are, in order:
--
-- * the magic byte @'b'@;
-- * the format version @2@;
-- * @rank@;
-- * the four-byte type tag from 'typeStr';
-- * when @rank > 0@, @rank@ @int64_t@ dimension sizes copied from @shape@.
--
-- @dest@ is spliced in several times and is incremented and
-- compound-assigned, so it must be an assignable C lvalue without side
-- effects. NOTE(review): it is presumably a byte pointer; confirm at the
-- call sites.
storeValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
storeValueHeader sign pt rank shape dest =
  [C.cstms|
          *$exp:dest++ = 'b';
          *$exp:dest++ = 2;
          *$exp:dest++ = $int:rank;
          memcpy($exp:dest, $string:(typeStr sign pt), 4);
          $exp:dest += 4;
          $stms:copy_shape
          |]
  where
    -- Scalars (rank 0) have no shape to write.
    copy_shape
      | rank == 0 = []
      | otherwise =
          [C.cstms|
                memcpy($exp:dest, $exp:shape, $int:rank*sizeof(typename int64_t));
                $exp:dest += $int:rank*sizeof(typename int64_t);|]
-- | C statements that read and validate a binary-data-format header from
-- @src@, advancing @src@ past it.
--
-- Each check ORs a mismatch into a C variable named @err@. That variable
-- is not declared here, so the caller must have it in scope. The checks
-- cover the magic byte @'b'@, format version @2@, the rank byte against
-- @rank@, and the four-byte type tag against 'typeStr'.
--
-- Only if all checks passed (@err == 0@) and @rank > 0@ are the @rank@
-- @int64_t@ dimension sizes copied into @shape@ and @src@ advanced past
-- them.
--
-- @src@ is spliced in several times and is incremented, so it must be an
-- assignable C lvalue without side effects.
loadValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
loadValueHeader sign pt rank shape src =
  [C.cstms|
     err |= (*$exp:src++ != 'b');
     err |= (*$exp:src++ != 2);
     err |= (*$exp:src++ != $int:rank);
     err |= (memcmp($exp:src, $string:(typeStr sign pt), 4) != 0);
     $exp:src += 4;
     if (err == 0) {
       $stms:load_shape
       $exp:src += $int:rank*sizeof(typename int64_t);
     }|]
  where
    -- Copy from the caller-supplied source expression. The previous code
    -- hard-coded the C identifier `src`, which is wrong whenever the
    -- caller's pointer has a different name.
    load_shape
      | rank == 0 = []
      | otherwise =
          [C.cstms|memcpy($exp:shape, $exp:src, $int:rank*sizeof(typename int64_t));|]