module Data.Primitive.SIMD.FloatX8 (FloatX8) where
import Data.Primitive.SIMD.Class
import GHC.Types
import GHC.Prim
import GHC.Ptr
import GHC.ST
import Foreign.Storable
import Control.Monad.Primitive
import Data.Primitive.Types
import Data.Primitive.ByteArray
import Data.Primitive.Addr
import Data.Monoid
import Data.Typeable
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector.Primitive.Mutable as PMV
import Data.Vector.Unboxed (Unbox)
import qualified Data.Vector.Unboxed as UV
import Data.Vector.Generic (Vector(..))
import Data.Vector.Generic.Mutable (MVector(..))
-- | A vector of eight 'Float' lanes, stored as two 'FloatX4#' halves:
-- the first field carries lanes 1-4 and the second field lanes 5-8.
data FloatX8 = FloatX8 FloatX4# FloatX4# deriving Typeable
-- | Absolute value of a boxed 'Float', routed through 'abs#'.
abs' :: Float -> Float
abs' (F# raw) = F# (abs# raw)

-- | Absolute value of an unboxed 'Float#': box the value, apply the
-- ordinary 'abs', and unbox the result again.
abs# :: Float# -> Float#
abs# raw = unbox (abs (F# raw))
  where
    unbox (F# r) = r
-- | Sign of a boxed 'Float', routed through 'signum#'.
signum' :: Float -> Float
signum' (F# raw) = F# (signum# raw)

-- | Sign of an unboxed 'Float#': box the value, apply the ordinary
-- 'signum', and unbox the result again.
signum# :: Float# -> Float#
signum# raw = unbox (signum (F# raw))
  where
    unbox (F# r) = r
-- | Two vectors are equal when all eight lanes are pairwise equal.
instance Eq FloatX8 where
    -- The 8-tuple 'Eq' instance compares component by component, left to
    -- right, which is exactly the lane-by-lane conjunction.
    lhs == rhs = unpackFloatX8 lhs == unpackFloatX8 rhs
-- | Lexicographic ordering over the lanes: the first lane that differs
-- decides the result.
instance Ord FloatX8 where
    -- 'compare' on 8-tuples is lexicographic over the components, i.e. the
    -- same as chaining the per-lane comparisons with '<>'.
    compare lhs rhs = compare (unpackFloatX8 lhs) (unpackFloatX8 rhs)
-- | Renders as @FloatX8 (x1, x2, ..., x8)@; the precedence argument is
-- ignored.
instance Show FloatX8 where
    showsPrec _ vec rest = case unpackFloatX8 vec of
        (a, b, c, d, e, f, g, h) ->
            ( showString "FloatX8 ("
            . shows a . sep . shows b . sep . shows c . sep . shows d . sep
            . shows e . sep . shows f . sep . shows g . sep . shows h
            . showString ")"
            ) rest
      where
        -- Separator emitted between two neighbouring lanes.
        sep = showString ", "
-- | Lane-wise arithmetic. Addition, subtraction, multiplication and
-- negation use the dedicated FloatX8 helpers; 'abs' and 'signum' are
-- mapped over the lanes; integer literals are broadcast to every lane.
instance Num FloatX8 where
    (+)         = plusFloatX8
    -- The operator name had been lost here (the line read @() = ...@, a
    -- unit-pattern binding that is not valid inside an instance), so
    -- subtraction was never bound to 'minusFloatX8'.
    (-)         = minusFloatX8
    (*)         = timesFloatX8
    negate      = negateFloatX8
    abs         = mapVector abs'
    signum      = mapVector signum'
    fromInteger = broadcastVector . fromInteger
-- | Lane-wise division; rational literals are broadcast to every lane.
instance Fractional FloatX8 where
    (/)            = divideFloatX8
    -- Reciprocal: a vector of ones divided lane-wise by the argument.
    recip v        = divideFloatX8 (broadcastFloatX8 1) v
    fromRational r = broadcastFloatX8 (fromRational r)
-- | Transcendental functions, each applied lane by lane through
-- 'mapVector' (unary) or 'zipVector' (binary) using the scalar 'Float'
-- implementation.
instance Floating FloatX8 where
    pi      = broadcastVector pi
    exp     = mapVector exp
    sqrt    = mapVector sqrt
    log     = mapVector log
    (**)    = zipVector (**)
    -- Previously zipped with '(**)', which made @logBase b x@ return
    -- @b ** x@ in every lane instead of the logarithm of @x@ to base @b@.
    logBase = zipVector logBase
    sin     = mapVector sin
    tan     = mapVector tan
    cos     = mapVector cos
    asin    = mapVector asin
    atan    = mapVector atan
    acos    = mapVector acos
    sinh    = mapVector sinh
    tanh    = mapVector tanh
    cosh    = mapVector cosh
    asinh   = mapVector asinh
    atanh   = mapVector atanh
    acosh   = mapVector acosh
-- | Raw-memory access. The size is lane count times lane size, and the
-- alignment is set equal to the size.
instance Storable FloatX8 where
    sizeOf x     = vectorSize x * elementSize x
    -- Qualified on purpose: "Data.Primitive.Types" is imported unqualified
    -- as well and (in the releases that export it) brings its own 'sizeOf'
    -- into scope, which would make a bare reference ambiguous.
    alignment    = Foreign.Storable.sizeOf
    -- Both accessors delegate to the 'Prim' instance at element offset 0.
    peek (Ptr a) = readOffAddr (Addr a) 0
    poke (Ptr a) = writeOffAddr (Addr a) 0
-- | 'SIMDVector' interface for 'FloatX8': eight lanes of four bytes each,
-- with every operation forwarded to the FloatX8-specific helper.
instance SIMDVector FloatX8 where
    type Elem FloatX8      = Float
    type ElemTuple FloatX8 = (Float, Float, Float, Float, Float, Float, Float, Float)

    -- Every lane set to zero.
    nullVector         = broadcastFloatX8 0
    -- Number of lanes.
    vectorSize         = const 8
    -- Size of a single lane in bytes.
    elementSize        = const 4

    broadcastVector    = broadcastFloatX8
    unsafeInsertVector = unsafeInsertFloatX8
    packVector         = packFloatX8
    unpackVector       = unpackFloatX8
    mapVector          = mapFloatX8
    zipVector          = zipFloatX8
    foldVector         = foldFloatX8
    sumVector          = sumFloatX8
-- | 'Prim' instance so that 'FloatX8' can live in byte arrays and behind
-- raw addresses. Indices are counted in whole 'FloatX8' elements.
instance Prim FloatX8 where
    -- Size and alignment are taken from the 'Storable' instance. The
    -- references are qualified because "Data.Primitive.Types" is imported
    -- unqualified too and (in the releases that export them) provides
    -- 'sizeOf' / 'alignment' wrappers around these very methods; picking
    -- those up would be ambiguous at best and self-referential at worst.
    sizeOf# a    = let !(I# x) = Foreign.Storable.sizeOf a in x
    alignment# a = let !(I# x) = Foreign.Storable.alignment a in x

    -- Byte-array access.
    indexByteArray# ba i = indexFloatX8Array (ByteArray ba) (I# i)
    readByteArray# mba i s =
        let (ST r) = readFloatX8Array (MutableByteArray mba) (I# i)
        in r s
    -- The writers run an 'ST' action by hand and keep only the new state.
    writeByteArray# mba i v s =
        let (ST r) = writeFloatX8Array (MutableByteArray mba) (I# i) v
        in case r s of { (# s', _ #) -> s' }
    setByteArray# mba off n v s =
        let (ST r) = setByteArrayGeneric (MutableByteArray mba) (I# off) (I# n) v
        in case r s of { (# s', _ #) -> s' }

    -- Raw-address access.
    indexOffAddr# addr i = indexFloatX8OffAddr (Addr addr) (I# i)
    readOffAddr# addr i s =
        let (ST r) = readFloatX8OffAddr (Addr addr) (I# i)
        in r s
    writeOffAddr# addr i v s =
        let (ST r) = writeFloatX8OffAddr (Addr addr) (I# i) v
        in case r s of { (# s', _ #) -> s' }
    setOffAddr# addr off n v s =
        let (ST r) = setOffAddrGeneric (Addr addr) (I# off) (I# n) v
        in case r s of { (# s', _ #) -> s' }
-- | Immutable unboxed vectors of 'FloatX8' are wrappers around primitive
-- vectors of the same element type.
newtype instance UV.Vector FloatX8 = V_FloatX8 (PV.Vector FloatX8)
-- | Mutable unboxed vectors of 'FloatX8' are wrappers around mutable
-- primitive vectors of the same element type.
newtype instance UV.MVector s FloatX8 = MV_FloatX8 (PMV.MVector s FloatX8)
-- | Immutable vector operations: unwrap the newtype, call the matching
-- primitive-vector function, and re-wrap the result where one is produced.
instance Vector UV.Vector FloatX8 where
    basicUnsafeFreeze (MV_FloatX8 mv)          = fmap V_FloatX8 (PV.unsafeFreeze mv)
    basicUnsafeThaw (V_FloatX8 pv)             = fmap MV_FloatX8 (PV.unsafeThaw pv)
    basicLength (V_FloatX8 pv)                 = PV.length pv
    basicUnsafeSlice off n (V_FloatX8 pv)      = V_FloatX8 (PV.unsafeSlice off n pv)
    basicUnsafeIndexM (V_FloatX8 pv) ix        = PV.unsafeIndexM pv ix
    basicUnsafeCopy (MV_FloatX8 dst) (V_FloatX8 src) = PV.unsafeCopy dst src
    -- Forcing an element is plain 'seq'; the vector argument is ignored.
    elemseq _ x y                              = x `seq` y
-- | Mutable vector operations: unwrap the newtype, call the matching
-- mutable primitive-vector function, and re-wrap the result where one is
-- produced.
instance MVector UV.MVector FloatX8 where
    basicLength (MV_FloatX8 mv)              = PMV.length mv
    basicUnsafeSlice off n (MV_FloatX8 mv)   = MV_FloatX8 (PMV.unsafeSlice off n mv)
    basicOverlaps (MV_FloatX8 a) (MV_FloatX8 b) = PMV.overlaps a b
    basicUnsafeNew n                         = fmap MV_FloatX8 (PMV.unsafeNew n)
#if MIN_VERSION_vector(0,11,0)
    -- This method is only defined when building against vector >= 0.11.
    basicInitialize (MV_FloatX8 mv)          = basicInitialize mv
#endif
    basicUnsafeRead (MV_FloatX8 mv) ix       = PMV.unsafeRead mv ix
    basicUnsafeWrite (MV_FloatX8 mv) ix x    = PMV.unsafeWrite mv ix x
-- | Marks 'FloatX8' as usable in unboxed vectors; the instance body is
-- empty, so nothing beyond the class's own defaults is supplied here.
instance Unbox FloatX8
-- | Build a vector whose eight lanes all hold the given value.
broadcastFloatX8 :: Float -> FloatX8
broadcastFloatX8 (F# scalar) =
    -- One four-lane broadcast is shared by both halves.
    let !half = broadcastFloatX4# scalar
    in FloatX8 half half
-- | Build a vector from an 8-tuple: components 1-4 form the first half,
-- components 5-8 the second half.
packFloatX8 :: (Float, Float, Float, Float, Float, Float, Float, Float) -> FloatX8
packFloatX8 (F# a, F# b, F# c, F# d, F# e, F# f, F# g, F# h) =
    let !lo = packFloatX4# (# a, b, c, d #)
        !hi = packFloatX4# (# e, f, g, h #)
    in FloatX8 lo hi
-- | Split a vector into an 8-tuple of boxed lanes: the first half yields
-- components 1-4, the second half components 5-8.
unpackFloatX8 :: FloatX8 -> (Float, Float, Float, Float, Float, Float, Float, Float)
unpackFloatX8 (FloatX8 lo hi) =
    case unpackFloatX4# lo of
      { (# a, b, c, d #) ->
    case unpackFloatX4# hi of
      { (# e, f, g, h #) ->
    (F# a, F# b, F# c, F# d, F# e, F# f, F# g, F# h)
      }}
-- | Replace the lane at the given index with a new value. Indices below 4
-- go to the first half; every other index is reduced by 4 and goes to the
-- second half. The index is not range-checked here.
unsafeInsertFloatX8 :: FloatX8 -> Float -> Int -> FloatX8
unsafeInsertFloatX8 (FloatX8 lo hi) (F# val) idx@(I# idx#)
    | idx >= 4  = FloatX8 lo (insertFloatX4# hi val (idx# -# 4#))
    | otherwise = FloatX8 (insertFloatX4# lo val idx#) hi
-- | Apply a function on boxed floats to every lane by adapting it to the
-- unboxed interface of 'mapFloatX8#'.
mapFloatX8 :: (Float -> Float) -> FloatX8 -> FloatX8
mapFloatX8 f = mapFloatX8# onRaw
  where
    -- Box the lane, run the user function, unbox the result.
    onRaw raw = case f (F# raw) of
        F# out -> out
-- | Apply a function on unboxed floats to every lane: unpack the vector,
-- transform each lane, and pack the results back in the same order.
mapFloatX8# :: (Float# -> Float#) -> FloatX8 -> FloatX8
mapFloatX8# f = \ vec -> case unpackFloatX8 vec of
    (F# a, F# b, F# c, F# d, F# e, F# g, F# h, F# k) ->
        packFloatX8 (lane a, lane b, lane c, lane d, lane e, lane g, lane h, lane k)
  where
    -- Transform one raw lane and box the result.
    lane raw = F# (f raw)
-- | Combine two vectors lane by lane with a binary function; lane @n@ of
-- the result is the function applied to lane @n@ of each input.
zipFloatX8 :: (Float -> Float -> Float) -> FloatX8 -> FloatX8 -> FloatX8
zipFloatX8 f = \ lhs rhs ->
    -- Strict bindings keep the original order: unpack left, then right.
    let !(a1, a2, a3, a4, a5, a6, a7, a8) = unpackFloatX8 lhs
        !(b1, b2, b3, b4, b5, b6, b7, b8) = unpackFloatX8 rhs
    in packFloatX8
        ( f a1 b1, f a2 b2, f a3 b3, f a4 b4
        , f a5 b5, f a6 b6, f a7 b7, f a8 b8 )
-- | Reduce the eight lanes with a binary function, associating to the
-- left: @((x1 `op` x2) `op` x3) ... `op` x8@. Both operands are forced
-- before each application.
foldFloatX8 :: (Float -> Float -> Float) -> FloatX8 -> Float
foldFloatX8 op = \ vec -> case unpackFloatX8 vec of
    (x1, x2, x3, x4, x5, x6, x7, x8) ->
        let acc2 = step x1 x2
            acc3 = step acc2 x3
            acc4 = step acc3 x4
            acc5 = step acc4 x5
            acc6 = step acc5 x6
            acc7 = step acc6 x7
        in step acc7 x8
  where
    -- Strict wrapper around the user-supplied function.
    step l r = l `seq` r `seq` op l r
-- | Sum of all eight lanes: the two halves are added as four-lane vectors
-- first, then the four partial sums are added left to right.
sumFloatX8 :: FloatX8 -> Float
sumFloatX8 (FloatX8 lo hi) =
    case unpackFloatX4# (plusFloatX4# lo hi) of
        (# a, b, c, d #) ->
            F# (((a `plusFloat#` b) `plusFloat#` c) `plusFloat#` d)
-- | Lane-wise addition, performed separately on each four-lane half.
plusFloatX8 :: FloatX8 -> FloatX8 -> FloatX8
plusFloatX8 (FloatX8 aLo aHi) (FloatX8 bLo bHi) =
    let !lo = plusFloatX4# aLo bLo
        !hi = plusFloatX4# aHi bHi
    in FloatX8 lo hi
-- | Lane-wise subtraction (first argument minus second), performed
-- separately on each four-lane half.
minusFloatX8 :: FloatX8 -> FloatX8 -> FloatX8
minusFloatX8 (FloatX8 aLo aHi) (FloatX8 bLo bHi) =
    let !lo = minusFloatX4# aLo bLo
        !hi = minusFloatX4# aHi bHi
    in FloatX8 lo hi
-- | Lane-wise multiplication, performed separately on each four-lane half.
timesFloatX8 :: FloatX8 -> FloatX8 -> FloatX8
timesFloatX8 (FloatX8 aLo aHi) (FloatX8 bLo bHi) =
    let !lo = timesFloatX4# aLo bLo
        !hi = timesFloatX4# aHi bHi
    in FloatX8 lo hi
-- | Lane-wise division (first argument divided by second), performed
-- separately on each four-lane half.
divideFloatX8 :: FloatX8 -> FloatX8 -> FloatX8
divideFloatX8 (FloatX8 aLo aHi) (FloatX8 bLo bHi) =
    let !lo = divideFloatX4# aLo bLo
        !hi = divideFloatX4# aHi bHi
    in FloatX8 lo hi
-- | Lane-wise negation, performed separately on each four-lane half.
negateFloatX8 :: FloatX8 -> FloatX8
negateFloatX8 (FloatX8 lo hi) =
    let !negLo = negateFloatX4# lo
        !negHi = negateFloatX4# hi
    in FloatX8 negLo negHi
-- | Read the @i@-th 'FloatX8' from an immutable byte array. The element
-- occupies the two consecutive 'FloatX4#' slots @2*i@ and @2*i + 1@.
indexFloatX8Array :: ByteArray -> Int -> FloatX8
indexFloatX8Array (ByteArray arr) (I# ix) =
    let !slot = ix *# 2#
        !lo   = indexFloatX4Array# arr slot
        !hi   = indexFloatX4Array# arr (slot +# 1#)
    in FloatX8 lo hi
-- | Read the @i@-th 'FloatX8' from a mutable byte array: the 'FloatX4#'
-- slot @2*i@ is read first, then slot @2*i + 1@.
readFloatX8Array :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> m FloatX8
readFloatX8Array (MutableByteArray arr) (I# ix) =
    let !slot = ix *# 2#
    in primitive (\ st0 ->
        case readFloatX4Array# arr slot st0 of
          { (# st1, lo #) ->
        case readFloatX4Array# arr (slot +# 1#) st1 of
          { (# st2, hi #) ->
        (# st2, FloatX8 lo hi #) }})
-- | Write a 'FloatX8' as the @i@-th element of a mutable byte array: the
-- first half goes to 'FloatX4#' slot @2*i@, then the second half to slot
-- @2*i + 1@.
writeFloatX8Array :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> FloatX8 -> m ()
writeFloatX8Array (MutableByteArray arr) (I# ix) (FloatX8 lo hi) =
    let !slot = ix *# 2#
    in do
        primitive_ (writeFloatX4Array# arr slot lo)
        primitive_ (writeFloatX4Array# arr (slot +# 1#) hi)
-- | Read the @i@-th 'FloatX8' relative to a raw address. The element
-- starts at byte offset @32*i@; the second half sits 16 bytes further on.
indexFloatX8OffAddr :: Addr -> Int -> FloatX8
indexFloatX8OffAddr (Addr base) (I# ix) =
    let !byteOff = ix *# 32#
        !lo      = indexFloatX4OffAddr# (plusAddr# base byteOff) 0#
        !hi      = indexFloatX4OffAddr# (plusAddr# base (byteOff +# 16#)) 0#
    in FloatX8 lo hi
-- | Read the @i@-th 'FloatX8' relative to a raw address: the half at byte
-- offset @32*i@ is read first, then the half at @32*i + 16@.
readFloatX8OffAddr :: PrimMonad m => Addr -> Int -> m FloatX8
readFloatX8OffAddr (Addr base) (I# ix) =
    let !byteOff = ix *# 32#
    in primitive (\ st0 ->
        case readHalf byteOff st0 of
          { (# st1, lo #) ->
        case readHalf (byteOff +# 16#) st1 of
          { (# st2, hi #) ->
        (# st2, FloatX8 lo hi #) }})
  where
    -- Read one four-lane half located @off@ bytes past the base address.
    readHalf off = readFloatX4OffAddr# (plusAddr# base off) 0#
-- | Write a 'FloatX8' as the @i@-th element relative to a raw address:
-- the first half goes to byte offset @32*i@, then the second half to
-- @32*i + 16@.
writeFloatX8OffAddr :: PrimMonad m => Addr -> Int -> FloatX8 -> m ()
writeFloatX8OffAddr (Addr base) (I# ix) (FloatX8 lo hi) =
    let !byteOff = ix *# 32#
    in do
        primitive_ (writeFloatX4OffAddr# (plusAddr# base byteOff) 0# lo)
        primitive_ (writeFloatX4OffAddr# (plusAddr# base (byteOff +# 16#)) 0# hi)