{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}

module Data.LLVM.BitCode.BitString
  (
    BitString
  , emptyBitString
  , toBitString
  , showBitString
  , fromBitString
  , bitStringValue
  , take, drop
  , joinBitString
  , NumBits, NumBytes, pattern Bits', pattern Bytes'
  , bitCount, bitCount#
  , bitsToBytes, bytesToBits
  , addBitCounts
  , subtractBitCounts
  )
where

import Data.Bits ( bit, bitSizeMaybe, Bits, finiteBitSize )
import GHC.Exts
import Numeric ( showIntAtBase, showHex )

import Prelude hiding (take,drop,splitAt)

----------------------------------------------------------------------
-- Define some convenience newtypes to clarify whether the count of bits or count
-- of bytes is being referenced, and to convert between the two.

-- | A count of bits.  Wrapping the raw 'Int' keeps bit counts from being
-- mixed up with byte counts ('NumBytes').
newtype NumBits = NumBits Int
  deriving (Show, Eq, Ord)
-- | A count of bytes.  Wrapping the raw 'Int' keeps byte counts from being
-- mixed up with bit counts ('NumBits').
newtype NumBytes = NumBytes Int
  deriving (Show, Eq, Ord)

-- | Bidirectional pattern for building or deconstructing a 'NumBits' from its
-- raw 'Int' count.
pattern Bits' :: Int -> NumBits
pattern Bits' count = NumBits count
{-# COMPLETE Bits' #-}

-- | Bidirectional pattern for building or deconstructing a 'NumBytes' from its
-- raw 'Int' count.
pattern Bytes' :: Int -> NumBytes
pattern Bytes' count = NumBytes count
{-# COMPLETE Bytes' #-}

-- | Unwrap a 'NumBits' to its raw 'Int' count.
bitCount :: NumBits -> Int
bitCount (NumBits count) = count

-- | Unwrap a 'NumBits' all the way to an unboxed 'Int#' count.
bitCount# :: NumBits -> Int#
bitCount# (NumBits (I# count#)) = count#

{-# INLINE addBitCounts #-}
-- | Add two bit counts using unboxed arithmetic.  No overflow check is made.
addBitCounts :: NumBits -> NumBits -> NumBits
addBitCounts (NumBits (I# lhs#)) (NumBits (I# rhs#)) =
  let total# = lhs# +# rhs#
  in NumBits (I# total#)

{-# INLINE subtractBitCounts #-}
-- | Subtract the second bit count from the first using unboxed arithmetic.
-- The result is not checked for being non-negative.
subtractBitCounts :: NumBits -> NumBits -> NumBits
subtractBitCounts (NumBits (I# lhs#)) (NumBits (I# rhs#)) =
  let remaining# = lhs# -# rhs#
  in NumBits (I# remaining#)

{-# INLINE bitsToBytes #-}
-- | Split a bit count into the number of whole bytes it covers plus the
-- leftover bits (0-7) that do not fill a complete byte.
--
-- The byte count uses a logical right shift, so a negative bit count would
-- produce a very large byte count; callers are expected to pass non-negative
-- counts.
bitsToBytes :: NumBits -> (NumBytes, NumBits)
bitsToBytes (NumBits (I# n#)) =
  ( NumBytes (I# (n# `uncheckedIShiftRL#` 3#))  -- n `div` 8
  , NumBits  (I# (n# `andI#` 7#))               -- n `mod` 8
  )

{-# INLINE bytesToBits #-}
-- | Convert a whole number of bytes into the equivalent number of bits.
bytesToBits :: NumBytes -> NumBits
bytesToBits (NumBytes (I# n#)) = NumBits (I# (n# `uncheckedIShiftL#` 3#))  -- n * 8

----------------------------------------------------------------------

-- | A fixed-length run of bits held in the low-order bits of an 'Int'.
data BitString = BitString
  { bsLength :: !NumBits
    -- ^ Number of valid bits in 'bsData'.
  , bsData   :: !Int
    -- ^ The bit values themselves.
    --
    -- Note: bsData was originally an Integer, which allows an essentially
    -- unlimited size value.  However, that adds overhead to various
    -- computations, and LLVM Bitcode is unlikely to ever represent values
    -- larger than the native size (64 bits) as discrete values.  Using @Int@
    -- instead enables unboxed calculations for better performance.
    --
    -- Using Int is potentially unsound because GHC only guarantees it is a
    -- signed integer of at least 32 bits.  However, every environment where
    -- it is reasonable to use this parser currently has a 64-bit Int.  This
    -- can be verified via:
    --
    --  > import Data.Bits
    --  > bitSizeMaybe (maxBound :: Int) >= Just 64
    --
    -- There is no good place here to automate this check (perhaps
    -- GetBits.hs:runGetBits?), which is why it is not verified at runtime.
  }
  deriving (Show, Eq)

-- | The BitString with zero length and an all-zero value.
emptyBitString :: BitString
emptyBitString = BitString { bsLength = NumBits 0, bsData = 0 }


-- | Join two BitString representations together to form a single larger
-- BitString.  The first BitString is the \"lower\" value portion of the
-- resulting BitString.  The second is placed directly above it, and the
-- result's length is the sum of both lengths.
--
-- If the first BitString already fills the whole machine word, the second
-- cannot be shifted in above it.  'uncheckedIShiftL#' is undefined for a shift
-- of the full word width, so in that case only the first value is kept.
joinBitString :: BitString -> BitString -> BitString
joinBitString (BitString (Bits' lenA@(I# szA#)) lo@(I# a#))
              (BitString (Bits' lenB) (I# b#)) =
  BitString { bsLength = NumBits (lenA + lenB)
            , bsData   = joined
            }
  where
    joined
      | lenA >= finiteBitSize lo = lo
      | otherwise                = I# (a# `orI#` (b# `uncheckedIShiftL#` szA#))


-- | Given a number of bits to keep and an 'Int', create a 'BitString' that
-- holds only the low-order @len@ bits of the value.
--
-- A length that reaches the machine word size keeps the whole value unmasked.
-- Building the mask as @(1 `shiftL` len) - 1@ would otherwise shift by the
-- full word width.  'uncheckedIShiftL#' leaves that undefined, and in practice
-- it can yield an all-zero mask that silently discards every bit.
toBitString :: NumBits -> Int -> BitString
toBitString len@(Bits' width@(I# len#)) val@(I# val#)
  | width >= finiteBitSize val = BitString len val
  | otherwise =
      let !mask# = (1# `uncheckedIShiftL#` len#) -# 1#
      in BitString len (I# (val# `andI#` mask#))


-- | Extract the raw 'Int' value stored in a BitString.
bitStringValue :: BitString -> Int
bitStringValue (BitString _ value) = value


-- | Extract a target (Num) value of the desired type from a BitString, using
-- 'fromInteger' to perform the conversion.
--
-- If the target type has a fixed bit width narrower than the BitString, the
-- conversion still succeeds when the value fits in that width; otherwise
-- 'error' is called.  Targets without a fixed width always accept the value.
fromBitString :: (Num a, Bits a) => BitString -> a
fromBitString (BitString len raw) =
  maybe converted checkWidth (bitSizeMaybe converted)
  where
    wide      = toInteger raw     -- widen first so the range check cannot overflow
    converted = fromInteger wide  -- the requested target type
    checkWidth width
      | width >= bitCount len = converted
      | wide < bit width      = converted
      | otherwise             = error (unwords
          [ "Data.LLVM.BitCode.BitString.fromBitString: bitstring value of length", show len
          , "(mask=0x" <> showHex raw ")"
          , "could not be parsed into type with only", show width, "bits"
          ])


-- | Render a BitString in binary, most-significant bit first, left-padded with
-- zeros to the BitString's full length.
--
-- The stored 'Int' is reinterpreted as an unsigned 'Word' before formatting.
-- A BitString that fills the whole word with its top bit set is held as a
-- negative 'Int', and 'showIntAtBase' rejects negative numbers with an error.
showBitString :: BitString -> ShowS
showBitString bs = showString padding . showString bin
  where
  bin     = showIntAtBase 2 fmt (fromIntegral (bsData bs) :: Word) ""
  padding = replicate (bitCount (bsLength bs) - length bin) '0'
  fmt 0   = '0'
  fmt 1   = '1'
  fmt _   = error "invalid binary digit value"


-- | Extract a smaller BitString with the specified number of bits from the
-- \"start\" (low-order end) of a larger BitString.  A BitString that is
-- already no longer than the requested size is returned unchanged.
take :: NumBits -> BitString -> BitString
take n bs@(BitString len val)
  | n < len   = toBitString n val
  | otherwise = bs


-- | Remove the specified number of bits from the beginning (low-order end) of
-- a BitString and return the remainder as a smaller BitString.  Dropping at
-- least as many bits as the BitString holds yields 'emptyBitString'.
drop :: NumBits -> BitString -> BitString
drop !n !(BitString l i)
  | n < l =
      let !(I# shift#) = bitCount n
          !(I# val#)   = i
      in BitString (subtractBitCounts l n) (I# (val# `uncheckedIShiftRL#` shift#))
  | otherwise = emptyBitString