module Data.BitVector
(
BitVector
, BV
, size, width
, nat, uint, int
, bitVec
, ones, zeros
, isNat
, isPos
, (==.), (/=.)
, (<.), (<=.), (>.), (>=.)
, slt, sle, sgt, sge
, (@.), index
, (@@), extract
, (!.)
, least, most
, msb, lsb, msb1, lsb1
, signumI
, sdiv, srem, smod
, lg2
, (#), cat
, zeroExtend, signExtend
, foldl, foldl_
, foldr, foldr_
, reverse, reverse_
, replicate, replicate_
, and, and_
, or, or_
, split
, group, group_
, join
, module Data.Bits
, not, not_
, nand, nor, xnor
, (<<.), shl, (>>.), shr, ashr
, (<<<.), rol, (>>>.), ror
, fromBool
, fromBits
, toBits
, showBin
, showOct
, showHex
, maxNat
, integerWidth
) where
import Control.Exception ( assert )
import Data.Bits
import Data.Bool ( Bool(..), otherwise, (&&))
import qualified Data.Bool as Bool
import Data.Data ( Data )
import qualified Data.List as List
( foldr, foldl1'
, length
, map
, maximum
)
import Data.Ord
import Data.Typeable ( Typeable )
import Prelude
( Char
, Eq(..)
, Enum(..), Num(..)
, Integral(..), Int, Integer
, Maybe(..)
, Real(..)
, Show(..), String
, const
, error
, flip, fromIntegral
, maxBound
, snd
, undefined
, ($), (.), (^), (++)
)
-- | A bit-vector: a width in bits together with its value read as an
-- unsigned 'Integer'.
--
-- Nothing in the type enforces @0 <= nat < 2^size@; the functions that
-- build vectors are responsible for that.
data BV = BV
  { size :: !Int      -- ^ Width of the vector, in bits.
  , nat  :: !Integer  -- ^ Value of the vector as an unsigned number.
  }
  deriving (Data, Typeable)

-- | Synonym for 'BV'.
type BitVector = BV
-- | Width of a bit-vector in bits (same as 'size').
width :: BV -> Int
width (BV n _) = n
-- | The value of a bit-vector read as an unsigned integer (same as 'nat').
uint :: BV -> Integer
uint (BV _ a) = a
-- | The value of a bit-vector read as a two's-complement signed integer.
--
-- When the most significant bit is set the result is @nat u - 2^size u@
-- (a negative number); otherwise it is the unsigned value unchanged.
-- Calls 'error' (via 'msb') on a zero-width vector.
int :: BV -> Integer
int u@(BV n a)
  | msb u     = a - 2 ^ n
  | otherwise = a
-- | Renders a bit-vector as @[width]value@, with the value in decimal,
-- e.g. @[4]10@.
instance Show BV where
  show (BV n a) = '[' : show n ++ ']' : show a
-- | @bitVec n a@ builds an @n@-bit vector holding @a@.
--
-- A negative @a@ is stored in two's complement, i.e. as @a mod 2^n@.
-- A non-negative @a@ is stored as-is; it is NOT truncated to @n@ bits.
--
-- Calls 'error' if @n@ is negative.
bitVec :: Integral a => Int -> a -> BV
bitVec n a
  | n < 0     = error "Data.BitVector.bitVec: negative size"
  | a >= 0    = BV n $ fromIntegral a
    -- 'mod' with a positive divisor is never negative, so this is exactly
    -- the two's-complement encoding (and stays within n bits even when
    -- |a| > 2^n).
  | otherwise = BV n $ toInteger a `mod` 2 ^ n
-- | @ones n@ is the @n@-bit vector with every bit set (value @2^n - 1@).
-- Calls 'error' if @n@ is negative.
ones :: Int -> BV
ones n
  | n < 0     = error "Data.BitVector.ones: negative size"
  | otherwise = BV n (2 ^ n - 1)
-- | @zeros n@ is the @n@-bit vector with every bit cleared.
-- Calls 'error' if @n@ is negative.
zeros :: Int -> BV
zeros n =
  if n < 0
    then error "Data.BitVector.zeros: negative size"
    else BV n 0
-- | True when the most significant bit ('msb') is clear.
isNat :: BV -> Bool
isNat u = Bool.not (msb u)
-- | True when the signed reading of the vector ('int') is strictly
-- positive.
isPos :: BV -> Bool
isPos u = 0 < int u
infix 4 ==., /=., <., <=., >., >=.
infix 4 `slt`, `sle`, `sgt`, `sge`
-- | Equality on the unsigned values only; widths are ignored, so a 4-bit
-- and an 8-bit vector holding the same number are equal. Use '==.' for
-- width-sensitive equality.
instance Eq BV where
  u == v = nat u == nat v
-- | Ordering on the unsigned values only; widths are ignored.
instance Ord BV where
  compare (BV _ a) (BV _ b) = compare a b
-- | Width-sensitive equality: both the widths and the values must match.
(==.) :: BV -> BV -> Bool
u ==. v = size u == size v && nat u == nat v

-- | Width-sensitive inequality; the negation of '==.'.
(/=.) :: BV -> BV -> Bool
(/=.) u v = Bool.not (u ==. v)
-- | Apply an ordering test to the unsigned values of two bit-vectors.
-- Vectors of different widths are never related: the result is then
-- 'False'.
unsignedCmp :: (Integer -> Integer -> Bool) -> BV -> BV -> Bool
unsignedCmp rel (BV n a) (BV m b) = n == m && rel a b

-- | Unsigned less-than; 'False' when the widths differ.
(<.) :: BV -> BV -> Bool
(<.) = unsignedCmp (<)

-- | Unsigned less-or-equal; 'False' when the widths differ.
(<=.) :: BV -> BV -> Bool
(<=.) = unsignedCmp (<=)

-- | Unsigned greater-than; 'False' when the widths differ.
(>.) :: BV -> BV -> Bool
(>.) = unsignedCmp (>)

-- | Unsigned greater-or-equal; 'False' when the widths differ.
(>=.) :: BV -> BV -> Bool
(>=.) = unsignedCmp (>=)
-- | Apply an ordering test to the signed readings ('int') of two
-- bit-vectors. Vectors of different widths are never related: the result
-- is then 'False'.
signedCmp :: (Integer -> Integer -> Bool) -> BV -> BV -> Bool
signedCmp rel u v = size u == size v && rel (int u) (int v)

-- | Signed less-than; 'False' when the widths differ.
slt :: BV -> BV -> Bool
slt = signedCmp (<)

-- | Signed less-or-equal; 'False' when the widths differ.
sle :: BV -> BV -> Bool
sle = signedCmp (<=)

-- | Signed greater-than; 'False' when the widths differ.
sgt :: BV -> BV -> Bool
sgt = signedCmp (>)

-- | Signed greater-or-equal; 'False' when the widths differ.
sge :: BV -> BV -> Bool
sge = signedCmp (>=)
infixl 9 @., @@, !.
-- | Bit at index @i@, counting from the least significant bit (index 0).
-- Calls 'error' if @i@ is outside @[0, size u)@.
(@.) :: Integral ix => BV -> ix -> Bool
(BV n a) @. i
  | 0 <= i' && i' < n = testBit a i'
  | otherwise         = error "Data.BitVector.(@.): index out of bounds"
  where i' = fromIntegral i
-- | Prefix form of the least-significant-first bit-index operator:
-- @index i u@ is bit @i@ of @u@, with bit 0 the least significant.
index :: Integral ix => ix -> BV -> Bool
index i u = u @. i
-- | Extracts the bits from index @j@ down to index @i@ (both inclusive,
-- counted from the least significant bit), giving a vector of width
-- @j - i + 1@.
--
-- Calls 'error' unless @0 <= i <= j@. @j@ is not checked against the
-- source width; bits beyond it are taken from the stored value as-is.
(@@) :: Integral ix => BV -> (ix,ix) -> BV
(BV _ a) @@ (j,i)
  | 0 <= i && i <= j = BV m $ (a `shiftR` i') `mod` 2 ^ m
  | otherwise        = error "Data.BitVector.(@@): invalid range"
  where i' = fromIntegral i
        m  = fromIntegral (j - i + 1)
-- | Prefix form of the range-extraction operator: @extract j i u@ takes
-- bits @j@ down to @i@ (inclusive) of @u@.
extract :: Integral ix => ix -> ix -> BV -> BV
extract j i u = u @@ (j, i)
-- | Bit at index @i@ counting from the MOST significant bit: index 0 is
-- bit @size u - 1@, index @size u - 1@ is bit 0.
-- Calls 'error' if @i@ is outside @[0, size u)@.
(!.) :: Integral ix => BV -> ix -> Bool
(BV n a) !. i
  | 0 <= i' && i' < n = testBit a (n - i' - 1)
  | otherwise         = error "Data.BitVector.(!.): index out of bounds"
  where i' = fromIntegral i
-- | @least m u@ keeps the @m@ least significant bits of @u@, giving an
-- @m@-bit vector. Calls 'error' if @m < 1@; @m@ is not checked against
-- the width of @u@.
least :: Integral ix => ix -> BV -> BV
least m (BV _ a)
  | k < 1     = error "Data.BitVector.least: non-positive index"
  | otherwise = BV k (a `mod` 2 ^ m)
  where k = fromIntegral m
-- | @most m u@ keeps the @m@ most significant bits of @u@, giving an
-- @m@-bit vector. Calls 'error' if @m < 1@ or @m > size u@.
most :: Integral ix => ix -> BV -> BV
most m (BV n a)
  | m' < 1    = error "Data.BitVector.most: non-positive index"
  | m' > n    = error "Data.BitVector.most: index out of bounds"
  | otherwise = BV m' $ a `shiftR` (n - m')
  where m' = fromIntegral m
-- | Most significant bit: index 0 of the most-significant-first
-- indexing operator. Calls 'error' on a zero-width vector.
msb :: BV -> Bool
msb u = u !. (0 :: Int)

-- | Least significant bit: index 0 of the least-significant-first
-- indexing operator. Calls 'error' on a zero-width vector.
lsb :: BV -> Bool
lsb u = u @. (0 :: Int)
-- | Index (counted from the least significant bit) of the highest set bit.
-- The search starts at bit @size - 1@ and walks downwards.
-- Calls 'error' when the value is zero.
msb1 :: BV -> Int
msb1 (BV _ 0) = error "Data.BitVector.msb1: zero bit-vector"
msb1 (BV n a) = go (n - 1)
  where go i | testBit a i = i
             | otherwise   = go (i - 1)
-- | Index of the lowest set bit, searching upwards from bit 0.
-- Calls 'error' when the value is zero.
lsb1 :: BV -> Int
lsb1 (BV _ 0) = error "Data.BitVector.lsb1: zero bit-vector"
lsb1 (BV _ a) = scan 0
  where
    scan k = if testBit a k then k else scan (k + 1)
-- | Modular arithmetic. Binary operations produce a result as wide as the
-- wider operand and reduce it modulo @2^width@. 'abs' and 'signum' use
-- the signed reading ('int'); 'fromInteger' picks the smallest width
-- reported by 'integerWidth'.
instance Num BV where
  (BV n1 a) + (BV n2 b) = BV n $ (a + b) `mod` 2 ^ n
    where n = max n1 n2
  -- Defined explicitly: the default @u - v = u + negate v@ negates @v@ at
  -- its own width, which gives a wrong result when @v@ is the narrower
  -- operand.
  (BV n1 a) - (BV n2 b) = BV n $ (a - b) `mod` 2 ^ n
    where n = max n1 n2
  (BV n1 a) * (BV n2 b) = BV n $ (a * b) `mod` 2 ^ n
    where n = max n1 n2
  -- The outer 'mod' makes @negate 0@ equal 0 rather than @2^n@.
  negate (BV n a) = BV n $ (2 ^ n - a) `mod` 2 ^ n
  abs u | msb u     = negate u
        | otherwise = u
  signum u = bitVec 2 $ signum $ int u
  fromInteger i = bitVec (integerWidth i) i
-- | Sign of the signed reading ('int') as any integral type: -1, 0 or 1.
signumI :: Integral a => BV -> a
signumI u = fromInteger (signum (int u))
-- | Conversion to 'Rational' goes through the unsigned value ('nat').
instance Real BV where
  toRational (BV _ a) = toRational a
-- | Enumeration through the unsigned value. 'toEnum' goes through
-- 'fromIntegral' (hence 'fromInteger'); 'fromEnum' asserts that the value
-- fits in an 'Int'.
instance Enum BV where
  toEnum = fromIntegral
  -- @maxBound :: Int@ itself is representable, so the bound is inclusive.
  fromEnum (BV _ a) = assert (a <= maxInt) $ fromIntegral a
    where maxInt = toInteger (maxBound :: Int)
-- | Unsigned integer division on the unsigned values. 'quotRem' and
-- 'divMod' coincide; both results are as wide as the wider operand.
instance Integral BV where
  quotRem (BV n1 a) (BV n2 b) =
    let w      = max n1 n2
        (q, r) = a `quotRem` b
    in (BV w q, BV w r)
  divMod = quotRem
  toInteger (BV _ a) = a
-- | Apply an 'Integer' division-like operation to the signed readings
-- ('int') of two vectors and encode the result with 'bitVec' at the
-- wider of the two widths.
signedDivOp :: (Integer -> Integer -> Integer) -> BV -> BV -> BV
signedDivOp op u v = bitVec (max (size u) (size v)) (int u `op` int v)

-- | Signed division truncating toward zero ('quot').
sdiv :: BV -> BV -> BV
sdiv = signedDivOp quot

-- | Signed remainder paired with 'sdiv'; the sign follows the dividend
-- ('rem').
srem :: BV -> BV -> BV
srem = signedDivOp rem

-- | Signed modulus; the sign follows the divisor ('mod').
smod :: BV -> BV -> BV
smod = signedDivOp mod
-- | Ceiling of the base-2 logarithm of the unsigned value, returned as a
-- vector of the same width. For @a >= 2@ this is the number of bits
-- needed to hold @a - 1@ ('integerWidth'). Calls 'error' on zero.
lg2 :: BV -> BV
lg2 (BV _ 0) = error "Data.BitVector.lg2: zero bit-vector"
lg2 (BV n 1) = BV n 0
lg2 (BV n a) = BV n $ toInteger $ integerWidth (a - 1)
infixr 5 #

-- | Concatenation: the left operand supplies the high-order bits and the
-- right operand the low-order bits; the result width is the sum of both
-- widths.
(#), cat :: BV -> BV -> BV
u # v = BV (size u + size v) (nat u `shiftL` size v + nat v)
cat = (#)
-- | @zeroExtend d u@ widens @u@ by @d@ bits. The stored value is left
-- unchanged, so the new high-order bits read as zero.
zeroExtend :: Integral size => size -> BV -> BV
zeroExtend d (BV n a) = BV (fromIntegral d + n) a
-- | @signExtend d u@ widens @u@ by @d@ bits, filling the new high-order
-- bits with copies of the current most significant bit. A zero-width
-- vector has no sign bit and is simply zero-extended.
signExtend :: Integral size => size -> BV -> BV
signExtend d (BV n a)
  | n > 0 && testBit a (n - 1) = BV (n + d') $ (maxNat d `shiftL` n) + a
  | otherwise                  = BV (n + d') a
  where d' = fromIntegral d
-- | Left fold over the bits, visiting bit @size - 1@ (most significant)
-- first and bit 0 last. Both the accumulator and each bit are forced at
-- every step.
foldl, foldl_ :: (a -> Bool -> a) -> a -> BV -> a
foldl f e (BV n a) = go (n - 1) e
  where go i !x | i >= 0    = let !b = testBit a i in go (i - 1) $ f x b
                | otherwise = x
foldl_ = foldl
-- | Right fold over the bits in most-significant-first order: the result
-- is @f b_{n-1} (f b_{n-2} (... (f b_0 e)))@, where @b_i@ is bit @i@.
foldr, foldr_ :: (Bool -> a -> a) -> a -> BV -> a
foldr f e (BV n a) = go (n - 1) e
  where go i !x | i >= 0    = let !b = testBit a i in f b (go (i - 1) x)
                | otherwise = x
foldr_ = foldr
-- | Reverses the bit order while keeping the width: the most significant
-- bit of the input becomes bit 0 of the result, and so on.
reverse, reverse_ :: BV -> BV
reverse u = BV (size u) (snd (foldl step (1, 0) u))
  where
    -- 'foldl' delivers bits most significant first; each one lands at the
    -- current weight, which doubles after every bit.
    step (w, acc) isSet = (2 * w, if isSet then acc + w else acc)
reverse_ = reverse
-- | @replicate k u@ concatenates @k@ copies of @u@ (result width
-- @k * size u@). Calls 'error' when @k@ is zero or negative; a negative
-- count previously never terminated.
replicate, replicate_ :: Integral size => size -> BV -> BV
replicate n u
  | n == 0    = error "Data.BitVector.replicate: cannot replicate 0-times"
  | n < 0     = error "Data.BitVector.replicate: negative count"
  | otherwise = go (n - 1) u
  where go 0 !acc = acc
        go k !acc = go (k - 1) (u # acc)
replicate_ = replicate
-- | Bitwise AND of a non-empty list of vectors; the result is as wide as
-- the widest input. Calls 'error' on an empty list.
and, and_ :: [BV] -> BV
and ws = case ws of
  [] -> error "Data.BitVector.and: empty list"
  _  -> BV (List.maximum (List.map size ws))
           (List.foldl1' (.&.) (List.map nat ws))
and_ = and
-- | Bitwise OR of a non-empty list of vectors; the result is as wide as
-- the widest input. Calls 'error' on an empty list.
or, or_ :: [BV] -> BV
or ws = case ws of
  [] -> error "Data.BitVector.or: empty list"
  _  -> BV (List.maximum (List.map size ws))
           (List.foldl1' (.|.) (List.map nat ws))
or_ = or
-- | @split k u@ cuts @u@ into @k@ pieces of equal width
-- @ceiling (size u / k)@, most significant piece first. When the width is
-- not a multiple of @k@, the top piece is padded with high-order zeros.
-- Calls 'error' if @k@ is not positive.
split :: Integral times => times -> BV -> [BV]
split k (BV n a)
  | k <= 0    = error "Data.BitVector.split: non-positive splits"
  | otherwise = List.map (BV pieceWidth) (splitInteger pieceWidth k' a)
  where
    k'         = fromIntegral k
    -- Ceiling division.
    pieceWidth = (n + k' - 1) `div` k'
-- | @group s u@ cuts @u@ into pieces of @s@ bits each, most significant
-- piece first; there are @ceiling (size u / s)@ pieces, and the top one is
-- padded with high-order zeros when needed. Calls 'error' if @s@ is not
-- positive.
group, group_ :: Integral size => size -> BV -> [BV]
group s (BV n a)
  | s <= 0    = error "Data.BitVector.group: non-positive size"
  | otherwise = List.map (BV s') (splitInteger s' count a)
  where
    s'    = fromIntegral s
    -- Ceiling division.
    count = (n + s' - 1) `div` s'
group_ = group
-- | @splitInteger w k a@ returns the @k@ lowest @w@-bit chunks of @a@,
-- most significant chunk first (chunks are peeled off the low end and
-- prepended). A non-positive @k@ yields the empty list instead of
-- recursing forever.
splitInteger :: (Integral size, Integral times) =>
                size -> times -> Integer -> [Integer]
splitInteger n = go []
  where n' = fromIntegral n
        go acc k a
          | k <= 0    = acc
          | otherwise = go (v : acc) (k - 1) a'
          where v  = a `mod` 2 ^ n
                a' = a `shiftR` n'
-- | Concatenates a non-empty list of vectors with '#'; the first element
-- supplies the most significant bits. Calls 'error' on an empty list,
-- with a message in the same style as 'and' and 'or'.
join :: [BV] -> BV
join [] = error "Data.BitVector.join: empty list"
join ws = List.foldl1' (#) ws
infixl 8 <<., `shl`, >>., `shr`, `ashr`, <<<., `rol`, >>>., `ror`
-- | Bitwise operations.
--
-- * '.&.', '.|.' and 'xor' produce a result as wide as the wider operand.
-- * 'complement' flips every bit within the vector's own width.
-- * Shifts and rotations keep the width of their argument; shifting by
--   the width or more yields zero.
-- * 'testBit' past the width is 'False'.
instance Bits BV where
  (BV n1 a) .&. (BV n2 b) = BV n $ a .&. b
    where n = max n1 n2
  (BV n1 a) .|. (BV n2 b) = BV n $ a .|. b
    where n = max n1 n2
  (BV n1 a) `xor` (BV n2 b) = BV n $ a `xor` b
    where n = max n1 n2
  complement (BV n a) = BV n $ 2 ^ n - 1 - a
#if MIN_VERSION_base(4,7,0)
  zeroBits = BV 1 0
#endif
  bit i = BV (i + 1) (2 ^ i)
  testBit (BV n a) i | i < n     = testBit a i
                     | otherwise = False
  bitSize = undefined
#if MIN_VERSION_base(4,7,0)
  bitSizeMaybe = const Nothing
#endif
  isSigned = const False
  shiftL (BV n a) k
    | k > n     = BV n 0
    | otherwise = BV n $ shiftL a k `mod` 2 ^ n
  shiftR (BV n a) k
    | k > n     = BV n 0
    | otherwise = BV n $ shiftR a k
  rotateL bv 0 = bv
  rotateL (BV n a) k
    | k == n    = BV n a
    | k > n     = rotateL (BV n a) (k `mod` n)
    | otherwise = BV n $ h + l
    -- The top k bits wrap around to the bottom.
    where s = n - k
          l = a `shiftR` s
          h = (a `shiftL` k) `mod` 2 ^ n
  rotateR bv 0 = bv
  rotateR (BV n a) k
    | k == n    = BV n a
    | k > n     = rotateR (BV n a) (k `mod` n)
    | otherwise = BV n $ h + l
    -- The bottom k bits wrap around to the top.
    where s = n - k
          l = a `shiftR` k
          h = (a `shiftL` s) `mod` 2 ^ n
  popCount (BV _ a) = assert (a >= 0) $ popCount a
-- | Bitwise negation within the vector's width; an alias of 'complement'.
not, not_ :: BV -> BV
not u = complement u
not_ = not
-- | Bitwise NAND: the complement of '.&.'.
nand :: BV -> BV -> BV
nand u v = complement (u .&. v)

-- | Bitwise NOR: the complement of '.|.'.
nor :: BV -> BV -> BV
nor u v = complement (u .|. v)

-- | Bitwise XNOR: the complement of 'xor'.
xnor :: BV -> BV -> BV
xnor u v = complement (u `xor` v)
-- | Logical left shift by the unsigned value of the second vector. The
-- width of the first vector is kept; shifting by its width or more gives
-- zero.
(<<.), shl :: BV -> BV -> BV
u <<. (BV _ k)
  | k >= toInteger n = BV n 0
  | otherwise        = u `shiftL` fromInteger k
  where n = size u
shl = (<<.)
-- | Logical right shift by the unsigned value of the second vector. The
-- width of the first vector is kept; shifting by its width or more gives
-- zero.
(>>.), shr :: BV -> BV -> BV
u >>. (BV _ k)
  | k >= toInteger n = BV n 0
  | otherwise        = u `shiftR` fromInteger k
  where n = size u
shr = (>>.)
-- | Arithmetic right shift: like the logical right shift, but when the
-- most significant bit is set the vacated high bits are filled with ones
-- (by shifting the complement and complementing back).
ashr :: BV -> BV -> BV
ashr u v = if msb u then complement (complement u >>. v) else u >>. v
-- | Left rotation by the unsigned value of the second vector, within the
-- width of the first.
(<<<.), rol :: BV -> BV -> BV
u <<<. (BV _ k) = u `rotateL` fromInteger amount
  where
    w = toInteger (size u)
    -- Reduce amounts of at least the width before converting to 'Int'.
    amount | k >= w    = k `mod` w
           | otherwise = k
rol = (<<<.)
-- | Right rotation by the unsigned value of the second vector, within the
-- width of the first.
(>>>.), ror :: BV -> BV -> BV
u >>>. (BV _ k) = u `rotateR` fromInteger amount
  where
    w = toInteger (size u)
    -- Reduce amounts of at least the width before converting to 'Int'.
    amount | k >= w    = k `mod` w
           | otherwise = k
ror = (>>>.)
-- | A one-bit vector holding 1 for 'True' and 0 for 'False'.
fromBool :: Bool -> BV
fromBool b = BV 1 (if b then 1 else 0)
-- | Builds a vector from a list of bits given most significant first; the
-- width is the length of the list.
fromBits :: [Bool] -> BV
fromBits bs = BV (List.length bs) (snd (List.foldr step (1, 0) bs))
  where
    -- A right fold meets the last (least significant) bit first, at
    -- weight 1; the weight doubles for each earlier bit.
    step isSet (!w, !acc) = (2 * w, if isSet then acc + w else acc)
-- | The bits of a vector, most significant first; the list has 'size'
-- elements (empty for a zero-width vector).
toBits :: BV -> [Bool]
toBits (BV n a) = List.map (testBit a) [n - 1, n - 2 .. 0]
-- | Binary rendering with a @0b@ prefix: one digit per bit, most
-- significant first.
showBin :: BV -> String
showBin u = "0b" ++ List.map (\b -> if b then '1' else '0') (toBits u)
-- | The lowercase hexadecimal digit for a value in @[0, 15]@.
-- Calls 'error' for any other value.
hexChar :: Integral a => a -> Char
hexChar d
  | 0 <= d && d <= 9   = toEnum (fromEnum '0' + fromIntegral d)
  | 10 <= d && d <= 15 = toEnum (fromEnum 'a' + fromIntegral d - 10)
  | otherwise          = error "Data.BitVector.hexChar: invalid input"
-- | Octal rendering with a @0o@ prefix: one digit per 3-bit group, most
-- significant group first (see 'group').
showOct :: BV -> String
showOct u = "0o" ++ [hexChar (nat g) | g <- group (3 :: Int) u]

-- | Hexadecimal rendering with a @0x@ prefix: one digit per 4-bit group,
-- most significant group first (see 'group').
showHex :: BV -> String
showHex u = "0x" ++ [hexChar (nat g) | g <- group (4 :: Int) u]
-- | @maxNat n@ is @2^n - 1@, the largest unsigned value that fits in
-- @n@ bits.
maxNat :: (Integral a, Integral b) => a -> b
maxNat n = 2 ^ n - 1
-- | Number of bits used to hold an 'Integer'. For @n >= 0@ it is the
-- smallest @k >= 1@ with @n <= 2^k - 1@, so 0 and 1 both take one bit.
-- For negative @n@ it is one more than the width of @abs n@.
integerWidth :: Integer -> Int
integerWidth !n
  | n < 0     = 1 + integerWidth (negate n)
  | otherwise = grow 1 1
  where
    -- Invariant: cap == 2^bits - 1.
    grow :: Int -> Integer -> Int
    grow !bits !cap
      | n <= cap  = bits
      | otherwise = grow (bits + 1) (2 * cap + 1)