{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Data.Bits (
Bits(..),
FiniteBits(..),
) where
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Classes.Eq
import Data.Array.Accelerate.Classes.Ord
import Data.Array.Accelerate.Classes.Num
import Data.Array.Accelerate.Classes.Integral ()
import Prelude ( ($), (.), undefined, otherwise )
import qualified Data.Bits as B
infixl 8 `shift`, `rotate`, `shiftL`, `shiftR`, `rotateL`, `rotateR`
infixl 7 .&.
infixl 6 `xor`
infixl 5 .|.

-- | Bitwise operations on embedded scalar expressions, modelled on the host
-- 'B.Bits' class: every method consumes and produces 'Exp' terms.
--
-- An instance must provide the four bitwise connectives, either 'shift' or
-- the pair 'shiftL' / 'shiftR', either 'rotate' or the pair 'rotateL' /
-- 'rotateR', and 'isSigned', 'testBit', 'bit' and 'popCount'. The remaining
-- methods have defaults written in terms of those.
class Eq a => Bits a where
  {-# MINIMAL (.&.), (.|.), xor, complement,
              (shift | (shiftL, shiftR)),
              (rotate | (rotateL, rotateR)),
              isSigned, testBit, bit, popCount #-}

  -- | Bitwise AND.
  (.&.) :: Exp a -> Exp a -> Exp a

  -- | Bitwise OR.
  (.|.) :: Exp a -> Exp a -> Exp a

  -- | Bitwise exclusive OR.
  xor :: Exp a -> Exp a -> Exp a

  -- | Invert every bit.
  complement :: Exp a -> Exp a

  -- | @shift x i@ shifts left for positive @i@ and right (by @-i@) for
  -- negative @i@. It returns @x@ unchanged when @i@ is zero.
  shift :: Exp a -> Exp Int -> Exp a
  shift x i =
    cond (i > 0) (x `shiftL` i) (cond (i < 0) (x `shiftR` (-i)) x)

  -- | @rotate x i@ rotates left for positive @i@ and right (by @-i@) for
  -- negative @i@. It returns @x@ unchanged when @i@ is zero.
  rotate :: Exp a -> Exp Int -> Exp a
  rotate x i =
    cond (i > 0) (x `rotateL` i) (cond (i < 0) (x `rotateR` (-i)) x)

  -- | The value with no bits set. By default this is @bit 0@ with bit 0
  -- cleared again.
  zeroBits :: Exp a
  zeroBits = bit 0 `clearBit` 0

  -- | The value with exactly bit @i@ set.
  bit :: Exp Int -> Exp a

  -- | Set bit @i@ of @x@.
  setBit :: Exp a -> Exp Int -> Exp a
  setBit x = (x .|.) . bit

  -- | Clear bit @i@ of @x@.
  clearBit :: Exp a -> Exp Int -> Exp a
  clearBit x = (x .&.) . complement . bit

  -- | Flip bit @i@ of @x@.
  complementBit :: Exp a -> Exp Int -> Exp a
  complementBit x = xor x . bit

  -- | Whether bit @i@ of @x@ is set.
  testBit :: Exp a -> Exp Int -> Exp Bool

  -- | Whether the element type is signed.
  isSigned :: Exp a -> Exp Bool

  -- | Shift left. Defaults to 'shift'.
  shiftL :: Exp a -> Exp Int -> Exp a
  shiftL = shift

  -- | Variant of 'shiftL' that instances may implement without range
  -- checks. Defaults to 'shiftL'.
  unsafeShiftL :: Exp a -> Exp Int -> Exp a
  unsafeShiftL = shiftL

  -- | Shift right. Defaults to 'shift' with the amount negated.
  shiftR :: Exp a -> Exp Int -> Exp a
  shiftR x i = shift x (-i)

  -- | Variant of 'shiftR' that instances may implement without range
  -- checks. Defaults to 'shiftR'.
  unsafeShiftR :: Exp a -> Exp Int -> Exp a
  unsafeShiftR = shiftR

  -- | Rotate left. Defaults to 'rotate'.
  rotateL :: Exp a -> Exp Int -> Exp a
  rotateL = rotate

  -- | Rotate right. Defaults to 'rotate' with the amount negated.
  rotateR :: Exp a -> Exp Int -> Exp a
  rotateR x i = rotate x (-i)

  -- | Number of set bits.
  popCount :: Exp a -> Exp Int
-- | 'Bits' for types with a fixed bit width, modelled on the host
-- 'B.FiniteBits' class.
class Bits b => FiniteBits b where
  -- | Number of bits in the type. The instances in this module ignore the
  -- argument's value.
  finiteBitSize :: Exp b -> Exp Int

  -- | Zero-bit counts from the most significant end and from the least
  -- significant end, mirroring 'B.countLeadingZeros' and
  -- 'B.countTrailingZeros'.
  countLeadingZeros, countTrailingZeros :: Exp b -> Exp Int
-- | 'Bool' as a one-bit value. The logical connectives stand in for the
-- bitwise ones, and only bit index 0 exists.
instance Bits Bool where
  (.&.)      = (&&)
  (.|.)      = (||)
  xor        = (/=)
  complement = not

  -- Only index 0 holds a bit; any other index reads as False.
  bit         = (== 0)
  testBit x i = cond (i /= 0) (constant False) x

  -- A non-zero shift moves the single bit out; rotating it is the identity.
  shift x i  = cond (i /= 0) (constant False) x
  rotate x _ = x

  isSigned = isSignedDefault
  popCount = mkBoolToInt
-- | 'Bits' for 'Int'. Bitwise logic and the unsafe shifts map straight onto
-- the embedded primitives. Checked shifts, rotations and single-bit access
-- go through the shared @*Default@ helpers.
instance Bits Int where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Int8'. Bitwise logic and the unsafe shifts map straight onto
-- the embedded primitives. Checked shifts, rotations and single-bit access
-- go through the shared @*Default@ helpers.
instance Bits Int8 where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Int16'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits Int16 where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Int32'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits Int32 where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Int64'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits Int64 where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Word'. Bitwise logic and the unsafe shifts map straight onto
-- the embedded primitives. Checked shifts, rotations and single-bit access
-- go through the shared @*Default@ helpers.
instance Bits Word where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Word8'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits Word8 where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Word16'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits Word16 where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Word32'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits Word32 where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'Word64'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits Word64 where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'CInt'. Bitwise logic and the unsafe shifts map straight onto
-- the embedded primitives. Checked shifts, rotations and single-bit access
-- go through the shared @*Default@ helpers.
instance Bits CInt where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'CUInt'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits CUInt where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'CLong'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits CLong where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'CULong'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits CULong where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'CLLong'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits CLLong where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'CULLong'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits CULLong where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'CShort'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits CShort where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bits' for 'CUShort'. Bitwise logic and the unsafe shifts map straight
-- onto the embedded primitives. Checked shifts, rotations and single-bit
-- access go through the shared @*Default@ helpers.
instance Bits CUShort where
  -- bitwise connectives
  (.&.)        = mkBAnd
  (.|.)        = mkBOr
  xor          = mkBXor
  complement   = mkBNot

  -- single-bit access and queries
  bit          = bitDefault
  testBit      = testBitDefault
  popCount     = mkPopCount
  isSigned     = isSignedDefault

  -- shifts; the unsafe variants skip the width check
  shift        = shiftDefault
  shiftL       = shiftLDefault
  shiftR       = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR

  -- rotations
  rotate       = rotateDefault
  rotateL      = rotateLDefault
  rotateR      = rotateRDefault
-- | 'Bool' seen as a single bit. 'True' has no zero bits; 'False' is one
-- zero bit, counted from either end.
instance FiniteBits Bool where
  countLeadingZeros  x = cond x 0 1
  countTrailingZeros x = cond x 0 1
  finiteBitSize _      = constant $ B.finiteBitSize (undefined :: Bool)
-- | 'FiniteBits' for 'Int'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Int where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Int)
-- | 'FiniteBits' for 'Int8'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Int8 where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Int8)
-- | 'FiniteBits' for 'Int16'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Int16 where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Int16)
-- | 'FiniteBits' for 'Int32'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Int32 where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Int32)
-- | 'FiniteBits' for 'Int64'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Int64 where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Int64)
-- | 'FiniteBits' for 'Word'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Word where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Word)
-- | 'FiniteBits' for 'Word8'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Word8 where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Word8)
-- | 'FiniteBits' for 'Word16'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Word16 where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Word16)
-- | 'FiniteBits' for 'Word32'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Word32 where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Word32)
-- | 'FiniteBits' for 'Word64'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits Word64 where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: Word64)
-- | 'FiniteBits' for 'CInt'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits CInt where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: CInt)
-- | 'FiniteBits' for 'CUInt'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits CUInt where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: CUInt)
-- | 'FiniteBits' for 'CLong'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits CLong where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: CLong)
-- | 'FiniteBits' for 'CULong'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits CULong where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: CULong)
-- | 'FiniteBits' for 'CLLong'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits CLLong where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: CLLong)
-- | 'FiniteBits' for 'CULLong'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits CULLong where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: CULLong)
-- | 'FiniteBits' for 'CShort'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits CShort where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: CShort)
-- | 'FiniteBits' for 'CUShort'. The width is computed on the host with
-- 'B.finiteBitSize' and embedded via 'constant'.
instance FiniteBits CUShort where
  countLeadingZeros  = mkCountLeadingZeros
  countTrailingZeros = mkCountTrailingZeros
  finiteBitSize _    = constant $ B.finiteBitSize (undefined :: CUShort)
-- | Default 'bit': the constant 1 shifted left by @i@ through the
-- instance's 'shiftL'.
bitDefault :: (IsIntegral t, Bits t) => Exp Int -> Exp t
bitDefault = shiftL (constant 1)
-- | Default 'testBit': mask @x@ with @bit i@ and check whether anything
-- survives.
testBitDefault :: (IsIntegral t, Bits t) => Exp t -> Exp Int -> Exp Bool
testBitDefault x i = masked /= constant 0
  where
    masked = x .&. bit i
-- | Default 'shift'. A negative amount becomes a right shift by its
-- magnitude; zero or positive amounts shift left.
shiftDefault :: (FiniteBits t, IsIntegral t, B.Bits t) => Exp t -> Exp Int -> Exp t
shiftDefault x i =
  cond (i < 0) (shiftRDefault x (-i)) (shiftLDefault x i)
-- | Checked left shift. Amounts at or beyond the bit width give 0 instead of
-- being passed to the raw 'mkBShiftL' primitive.
shiftLDefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftLDefault x i = cond inRange (mkBShiftL x i) (constant 0)
  where
    inRange = i < finiteBitSize x
-- | Checked right shift. The host 'B.isSigned' for @t@ decides between the
-- signed ('shiftRADefault') and unsigned ('shiftRLDefault') variants. The
-- choice is made in Haskell, so only one of them ends up in the generated
-- expression.
shiftRDefault :: forall t. (B.Bits t, FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftRDefault =
  if B.isSigned (undefined :: t)
    then shiftRADefault
    else shiftRLDefault
-- | Checked right shift for signed types. Shifting by the bit width or more
-- yields @-1@ for negative inputs and @0@ otherwise. Smaller amounts use the
-- raw 'mkBShiftR' primitive (NOTE(review): presumably sign-extending for
-- signed types; confirm against the backend).
shiftRADefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftRADefault x i = cond tooFar signFill (mkBShiftR x i)
  where
    tooFar   = i >= finiteBitSize x
    signFill = cond (mkLt x (constant 0)) (constant (-1)) (constant 0)
-- | Checked right shift for unsigned types. Amounts at or beyond the bit
-- width give 0; smaller amounts use the raw 'mkBShiftR' primitive.
shiftRLDefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftRLDefault x i = cond inRange (mkBShiftR x i) (constant 0)
  where
    inRange = i < finiteBitSize x
-- | Default 'rotate' for the integral types.
--
-- Based on the 'IntegralType' tag of @t@, it picks an unsigned companion
-- type and passes a dummy value of it to 'rotateDefault'', which performs
-- the rotation on that type (the 'BitSizeEq' constraints there tie the
-- companion's width to @t@'s). Each signed/unsigned pair shares a
-- companion, so the pairs are listed together.
rotateDefault :: forall t. (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotateDefault =
  case (integralType :: IntegralType t) of
    TypeInt{}     -> rotateDefault' (undefined :: Word)
    TypeWord{}    -> rotateDefault' (undefined :: Word)
    TypeInt8{}    -> rotateDefault' (undefined :: Word8)
    TypeWord8{}   -> rotateDefault' (undefined :: Word8)
    TypeInt16{}   -> rotateDefault' (undefined :: Word16)
    TypeWord16{}  -> rotateDefault' (undefined :: Word16)
    TypeInt32{}   -> rotateDefault' (undefined :: Word32)
    TypeWord32{}  -> rotateDefault' (undefined :: Word32)
    TypeInt64{}   -> rotateDefault' (undefined :: Word64)
    TypeWord64{}  -> rotateDefault' (undefined :: Word64)
    TypeCShort{}  -> rotateDefault' (undefined :: CUShort)
    TypeCUShort{} -> rotateDefault' (undefined :: CUShort)
    TypeCInt{}    -> rotateDefault' (undefined :: CUInt)
    TypeCUInt{}   -> rotateDefault' (undefined :: CUInt)
    TypeCLong{}   -> rotateDefault' (undefined :: CULong)
    TypeCULong{}  -> rotateDefault' (undefined :: CULong)
    TypeCLLong{}  -> rotateDefault' (undefined :: CULLong)
    TypeCULLong{} -> rotateDefault' (undefined :: CULLong)
-- | Rotate @x@ left by @i@ bits, doing the shifts on the type @w@. Only the
-- type of the first argument matters; its value is never inspected.
--
-- The amount is reduced by masking with @width - 1@. That equals
-- @i `mod` width@ only when the width is a power of two (NOTE(review):
-- assumed for every type routed here by 'rotateDefault'; confirm).
rotateDefault'
  :: forall i w. (Elt w, FiniteBits i, IsIntegral i, IsIntegral w, BitSizeEq i w, BitSizeEq w i)
  => w        -- ^ dummy selecting the type the shifts run on; never evaluated
  -> Exp i    -- ^ value to rotate
  -> Exp Int  -- ^ rotation amount
  -> Exp i
rotateDefault' _ x i = cond (amount == 0) x (fromW rotated)
  -- The zero case is kept separate: otherwise the right shift below would be
  -- by the full width (presumably not well-defined for the raw primitive).
  where
    width   = finiteBitSize x
    amount  = i `mkBAnd` (width - 1)
    xw      = toW x
    rotated = (xw `mkBShiftL` amount) `mkBOr` (xw `mkBShiftR` (width - amount))
    toW     = mkBitcast :: Exp i -> Exp w
    fromW   = mkBitcast :: Exp w -> Exp i
-- | Left rotation through the 'mkBRotateL' primitive. A zero amount returns
-- @x@ directly.
rotateLDefault :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotateLDefault x i = cond (i /= 0) (mkBRotateL x i) x
-- | Right rotation through the 'mkBRotateR' primitive. A zero amount returns
-- @x@ directly.
rotateRDefault :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotateRDefault x i = cond (i /= 0) (mkBRotateR x i) x
-- | Embeds the host 'B.isSigned' answer for @b@ as a constant. The argument
-- is ignored.
isSignedDefault :: forall b. B.Bits b => Exp b -> Exp Bool
isSignedDefault = \_ -> constant $ B.isSigned (undefined :: b)
-- | Generic 'popCount' fallback. The 'Bits' instances above use
-- 'mkPopCount' instead, which is presumably why this binding carries a
-- leading underscore.
--
-- It dispatches on the host bit width of @a@:
--
-- * 8, 16, 32 or 64 bits: reinterpret the value with 'mkUnsafeCoerce' as the
--   word type taken by the matching parallel bit-count kernel, then run
--   that kernel.
-- * Any other width: fall back to the loop-based 'popCountKernighan'.
--
-- NOTE(review): splicing a quotation, @$( [e| ... |] )@, just yields the
-- quoted expression, so the @case@ is still evaluated when the 'Exp' is
-- built rather than at compile time. The Template Haskell wrapper looks
-- redundant; confirm before removing.
_popCountDefault :: forall a. (B.FiniteBits a, IsScalar a, Bits a, Num a) => Exp a -> Exp Int
_popCountDefault =
  $( [e| case B.finiteBitSize (undefined::a) of
           8 -> popCnt8 . mkUnsafeCoerce
           16 -> popCnt16 . mkUnsafeCoerce
           32 -> popCnt32 . mkUnsafeCoerce
           64 -> popCnt64 . mkUnsafeCoerce
           _ -> popCountKernighan |] )
-- | Population count by Kernighan's method. Each loop iteration clears the
-- lowest set bit (@v .&. (v - 1)@) and bumps the counter; the loop stops
-- once no bits remain. The loop state is the pair (bits counted so far,
-- remaining value).
popCountKernighan :: (Bits a, Num a) => Exp a -> Exp Int
popCountKernighan x = bits
  where
    start     = tup2 (0, x)
    (bits, _) = untup2 (while (\(untup2 -> (_, rest)) -> rest /= 0)
                              (\(untup2 -> (n, rest)) -> tup2 (n + 1, rest .&. (rest - 1)))
                              start)
-- | Population count of an 8-bit word using the parallel (SWAR) bit count:
--
-- 1. Add adjacent bits into 2-bit fields.
-- 2. Add those into 4-bit fields.
-- 3. Add the two 4-bit fields into the low 4-bit field.
--
-- With a single byte, the masked sum from step 3 is already the answer. The
-- wider variants need a multiply-and-shift to add up several byte counts;
-- here that step would be a multiply by one, so it is omitted to keep the
-- generated expression smaller.
popCnt8 :: Exp Word8 -> Exp Int
popCnt8 w = mkFromIntegral bytes
  where
    pairs   = w - ((w `unsafeShiftR` 1) .&. 0x55)
    nibbles = (pairs .&. 0x33) + ((pairs `unsafeShiftR` 2) .&. 0x33)
    bytes   = (nibbles + (nibbles `unsafeShiftR` 4)) .&. 0x0F
-- | Population count of a 16-bit word using the parallel (SWAR) bit count.
-- It builds per-field counts at widths 2, 4 and 8, then multiplies by
-- 0x0101 so that the top byte holds the sum of both byte counts.
popCnt16 :: Exp Word16 -> Exp Int
popCnt16 w = mkFromIntegral total
  where
    pairs   = w - ((w `unsafeShiftR` 1) .&. 0x5555)
    nibbles = (pairs .&. 0x3333) + ((pairs `unsafeShiftR` 2) .&. 0x3333)
    bytes   = (nibbles + (nibbles `unsafeShiftR` 4)) .&. 0x0F0F
    total   = (bytes * 0x0101) `unsafeShiftR` 8
-- | Population count of a 32-bit word using the parallel (SWAR) bit count.
-- It builds per-field counts at widths 2, 4 and 8, then multiplies by
-- 0x01010101 so that the top byte holds the sum of all four byte counts.
popCnt32 :: Exp Word32 -> Exp Int
popCnt32 w = mkFromIntegral total
  where
    pairs   = w - ((w `unsafeShiftR` 1) .&. 0x55555555)
    nibbles = (pairs .&. 0x33333333) + ((pairs `unsafeShiftR` 2) .&. 0x33333333)
    bytes   = (nibbles + (nibbles `unsafeShiftR` 4)) .&. 0x0F0F0F0F
    total   = (bytes * 0x01010101) `unsafeShiftR` 24
-- | Population count of a 64-bit word using the parallel (SWAR) bit count.
-- It builds per-field counts at widths 2, 4 and 8, then multiplies by
-- 0x0101010101010101 so that the top byte holds the sum of all eight byte
-- counts.
popCnt64 :: Exp Word64 -> Exp Int
popCnt64 w = mkFromIntegral total
  where
    pairs   = w - ((w `unsafeShiftR` 1) .&. 0x5555555555555555)
    nibbles = (pairs .&. 0x3333333333333333) + ((pairs `unsafeShiftR` 2) .&. 0x3333333333333333)
    bytes   = (nibbles + (nibbles `unsafeShiftR` 4)) .&. 0x0F0F0F0F0F0F0F0F
    total   = (bytes * 0x0101010101010101) `unsafeShiftR` 56