{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE UnboxedTuples       #-}
-- |
-- Module      : Data.Array.Accelerate.Internal.Num2
-- Copyright   : [2016..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

#include "MachDeps.h"

module Data.Array.Accelerate.Internal.Num2 ( Num2(..) )
  where

import Data.Bits
import Data.Int
import Data.Word
import Prelude

#if UNBOXED_TUPLES
import GHC.Exts                                                     ( Word(W#) )
import GHC.Prim                                                     ( plusWord2#, timesWord2# )
#if WORD_SIZE_IN_BITS == 32
import GHC.Word                                                     ( Word32(..) )
#endif
#if WORD_SIZE_IN_BITS == 64
import GHC.Word                                                     ( Word64(..) )
#endif
#endif


-- | Addition and multiplication with carry.
--
-- Both arithmetic operations return a pair @(hi, lo)@ holding the high and
-- low halves of the double-width result. The high half has the operand type
-- and the low half always has the unsigned counterpart type.
class Num2 w where
  -- | Signed type of the same width as @w@.
  type Signed w
  -- | Unsigned type of the same width as @w@.
  type Unsigned w
  --
  -- | Convert to the signed counterpart.
  signed
      :: w
      -> Signed w
  -- | Convert to the unsigned counterpart.
  unsigned
      :: w
      -> Unsigned w
  -- | Full-width sum as @(high, low)@.
  addWithCarry
      :: w
      -> w
      -> (w, Unsigned w)
  -- | Full-width product as @(high, low)@.
  mulWithCarry
      :: w
      -> w
      -> (w, Unsigned w)


-- Base
-- ----

-- 8-bit signed instance: the operation is carried out in Int16 and the
-- result is split back into an Int8 high half and a Word8 low half.
instance Num2 Int8 where
  type Signed   Int8 = Int8
  type Unsigned Int8 = Word8
  --
  signed       = id
  unsigned     = fromIntegral
  addWithCarry = defaultUnwrapped ((+) :: Int16 -> Int16 -> Int16)
  mulWithCarry = defaultUnwrapped ((*) :: Int16 -> Int16 -> Int16)

-- 8-bit unsigned instance: the operation is carried out in Word16 and the
-- result is split back into Word8 high and low halves.
instance Num2 Word8 where
  type Signed   Word8 = Int8
  type Unsigned Word8 = Word8
  --
  signed       = fromIntegral
  unsigned     = id
  addWithCarry = defaultUnwrapped ((+) :: Word16 -> Word16 -> Word16)
  mulWithCarry = defaultUnwrapped ((*) :: Word16 -> Word16 -> Word16)

-- 16-bit signed instance: the operation is carried out in Int32 and the
-- result is split back into an Int16 high half and a Word16 low half.
instance Num2 Int16 where
  type Signed   Int16 = Int16
  type Unsigned Int16 = Word16
  --
  signed       = id
  unsigned     = fromIntegral
  addWithCarry = defaultUnwrapped ((+) :: Int32 -> Int32 -> Int32)
  mulWithCarry = defaultUnwrapped ((*) :: Int32 -> Int32 -> Int32)

-- 16-bit unsigned instance: the operation is carried out in Word32 and the
-- result is split back into Word16 high and low halves.
instance Num2 Word16 where
  type Signed   Word16 = Int16
  type Unsigned Word16 = Word16
  --
  signed       = fromIntegral
  unsigned     = id
  addWithCarry = defaultUnwrapped ((+) :: Word32 -> Word32 -> Word32)
  mulWithCarry = defaultUnwrapped ((*) :: Word32 -> Word32 -> Word32)

-- 32-bit signed instance: the operation is carried out in Int64 and the
-- result is split back into an Int32 high half and a Word32 low half.
instance Num2 Int32 where
  type Signed   Int32 = Int32
  type Unsigned Int32 = Word32
  --
  signed       = id
  unsigned     = fromIntegral
  addWithCarry = defaultUnwrapped ((+) :: Int64 -> Int64 -> Int64)
  mulWithCarry = defaultUnwrapped ((*) :: Int64 -> Int64 -> Int64)

-- 32-bit unsigned instance. On a 32-bit target with unboxed tuples available
-- the primops plusWord2# and timesWord2# produce both halves directly;
-- otherwise the operation is carried out in Word64 and split.
instance Num2 Word32 where
  type Signed   Word32 = Int32
  type Unsigned Word32 = Word32
  --
  signed       = fromIntegral
  unsigned     = id
#if UNBOXED_TUPLES && WORD_SIZE_IN_BITS == 32
  -- Go through Word (W#) instead of the W32# constructor: since GHC 9.2
  -- W32# wraps Word32# rather than Word#, which the primops do not accept.
  -- Word is 32 bits wide here, so the conversions are lossless.
  addWithCarry x y =
    case fromIntegral x of
      W# x# -> case fromIntegral y of
        W# y# -> case plusWord2# x# y# of
          (# hi#, lo# #) -> (fromIntegral (W# hi#), fromIntegral (W# lo#))
  mulWithCarry x y =
    case fromIntegral x of
      W# x# -> case fromIntegral y of
        W# y# -> case timesWord2# x# y# of
          (# hi#, lo# #) -> (fromIntegral (W# hi#), fromIntegral (W# lo#))
#else
  addWithCarry = defaultUnwrapped ((+) :: Word64 -> Word64 -> Word64)
  mulWithCarry = defaultUnwrapped ((*) :: Word64 -> Word64 -> Word64)
#endif

-- 64-bit signed instance, built on top of the Word64 instance.
instance Num2 Int64 where
  type Signed   Int64 = Int64
  type Unsigned Int64 = Word64
  --
  signed       = id
  unsigned     = fromIntegral
  -- Add the operands as unsigned words, then fix up the high word with the
  -- sign extension of each operand (all ones when negative, zero otherwise).
  addWithCarry x y = hi `seq` lo `seq` (hi, lo)
    where
      extX :: Word64
      extX      = if x < 0 then maxBound else 0
      extY :: Word64
      extY      = if y < 0 then maxBound else 0
      (hi', lo) = unsigned x `addWithCarry` unsigned y
      hi        = signed (hi' + extX + extY)
  -- Multiply the operands as unsigned words. Reading a negative operand as
  -- unsigned adds 2^64 to it, which inflates the high word of the product by
  -- the other operand; subtracting that operand (modulo 2^64) corrects it.
  mulWithCarry x y = hi `seq` lo `seq` (hi, lo)
    where
      extX :: Int64
      extX      = if x < 0 then negate y else 0
      extY :: Int64
      extY      = if y < 0 then negate x else 0
      (hi', lo) = unsigned x `mulWithCarry` unsigned y
      hi        = signed hi' + extX + extY

-- 64-bit unsigned instance. On a 64-bit target with unboxed tuples available
-- the primops plusWord2# and timesWord2# produce both halves directly;
-- otherwise the result is assembled from 32-bit pieces.
instance Num2 Word64 where
  type Signed   Word64 = Int64
  type Unsigned Word64 = Word64
  --
  signed       = fromIntegral
  unsigned     = id
#if UNBOXED_TUPLES && WORD_SIZE_IN_BITS == 64
  -- Go through Word (W#) instead of the W64# constructor: since GHC 9.4
  -- W64# wraps Word64# rather than Word#, which the primops do not accept.
  -- Word is 64 bits wide here, so the conversions are lossless.
  addWithCarry x y =
    case fromIntegral x of
      W# x# -> case fromIntegral y of
        W# y# -> case plusWord2# x# y# of
          (# hi#, lo# #) -> (fromIntegral (W# hi#), fromIntegral (W# lo#))
  mulWithCarry x y =
    case fromIntegral x of
      W# x# -> case fromIntegral y of
        W# y# -> case timesWord2# x# y# of
          (# hi#, lo# #) -> (fromIntegral (W# hi#), fromIntegral (W# lo#))
#else
  -- Portable fallback. The sum wraps modulo 2^64, so it is smaller than x
  -- exactly when the addition carried out of the low word.
  addWithCarry x y = (hi, lo)
    where
      !lo             = x + y
      !hi | lo < x    = 1
          | otherwise = 0
  --
  -- Portable fallback: split both operands into 32-bit halves, form the four
  -- partial products, and propagate the carries of the middle column using
  -- Word32 additions.
  mulWithCarry x y = (hi, lo)
    where
      xHi         = x `shiftR` 32
      yHi         = y `shiftR` 32
      xLo         = x .&. 0xFFFFFFFF
      yLo         = y .&. 0xFFFFFFFF
      hi0         = xHi * yHi
      lo0         = xLo * yLo
      p1          = xHi * yLo
      p2          = xLo * yHi
      -- middle column: low halves of both cross products plus the high half of lo0
      (uHi1, uLo) = addWithCarry (fromIntegral p1) (fromIntegral p2)
      (uHi2, lo') = addWithCarry (fromIntegral (lo0 `shiftR` 32)) uLo
      !hi         = hi0 + fromIntegral (uHi1 :: Word32) + fromIntegral uHi2
                  + (p1 `shiftR` 32) + (p2 `shiftR` 32)
      !lo         = (fromIntegral lo' `shiftL` 32) .|. (lo0 .&. 0xFFFFFFFF)
#endif

{-# INLINE defaultUnwrapped #-}
-- | Compute a double-width result for the narrow type @w@ by widening both
-- operands to @ww@, applying @op@ there, and splitting the wide result into
-- a high part of type @w@ and a low part of type @Unsigned w@.
--
-- The high part is the wide result shifted right by the bit width of @w@, so
-- @ww@ must be wide enough that @op@ applied to two widened @w@ values does
-- not overflow. The Int8 to Int32 and Word8 to Word32 instances each pass a
-- type twice as wide. For signed @ww@ the shift is arithmetic, so the high
-- part keeps the sign of the result.
defaultUnwrapped
    :: (FiniteBits w, Bits ww, Integral w, Integral ww, Integral (Unsigned w))
    => (ww -> ww -> ww)   -- ^ operation, performed in the wide type
    -> w
    -> w
    -> (w, Unsigned w)    -- ^ @(high, low)@ halves of the result
defaultUnwrapped op x y = (hi, lo)
  where
    !r  = fromIntegral x `op` fromIntegral y
    !lo = fromIntegral r
    !hi = fromIntegral (r `shiftR` finiteBitSize x)