{-# LANGUAGE OverloadedStrings, BangPatterns, ScopedTypeVariables, ViewPatterns #-}

module Crypto.Encoding.SHA3.TupleHash
  ( leftEncodeZero
  , leftEncodeInteger
  , leftEncodeIntegerFromBytes
  , leftEncode
  , leftEncodeFromBytes
  , encodeString
  , encodedByteLength
  , encodedVectorByteLength
  , bareEncodeZero
  , bareEncodeInteger
  , bareEncodeIntegerFromBytes
  , bareEncode
  , bareEncodeFromBytes
  , lengthOfBareEncode
  , lengthOfBareEncodeFromBytes
  , lengthOfBareEncodeInteger
  , lengthOfBareEncodeIntegerFromBytes
  , lengthOfLeftEncode
  , lengthOfLeftEncodeFromBytes
  , lengthOfLeftEncodeInteger
  , lengthOfLeftEncodeIntegerFromBytes
  ) where

import Data.Monoid((<>))
import Data.ByteString(ByteString)
import qualified Data.ByteString as B
import Data.Bits
import Data.List(foldl')
import Data.Word
import Math.NumberTheory.Logarithms(integerLog2)

-- | Count down from one below the argument to zero, inclusive.
--
-- @downFrom 3 == [2, 1, 0]@ and @downFrom 0 == []@ (for signed types).
-- Used to list byte indices from most to least significant.
downFrom :: (Num a, Enum a) => a -> [a]
downFrom top = enumFromThenTo (top - 1) (top - 2) 0

-- | The @left_encode@ of zero: a length byte of 1 followed by a single
-- zero byte.
leftEncodeZero :: ByteString
leftEncodeZero = B.pack [0x01, 0x00]

-- I don't like the current interface, because:
--   leftEncodeInteger returns Nothing when given a negative argument, but
--   leftEncode (on FiniteBits) returns leftEncodeZero when given one.
--   If Haskell offered a way to constrain the latter to unsigned word types,
--   I would do so; a more generic version accepting finite signed integers
--   could then return (Maybe ByteString) instead.
--   The current interface is a compromise.

-- | @left_encode@ for arbitrary-precision integers.
--
-- The output is one length byte (the count of significant bytes) followed by
-- the big-endian bytes of the argument.  Returns 'Nothing' for negative
-- inputs, and for values needing more than 255 bytes because the length
-- prefix is a single byte.
leftEncodeInteger :: Integer -> Maybe ByteString
leftEncodeInteger n
  | n < 0 = Nothing
  | n == 0 = Just leftEncodeZero
  | byteCount > 255 = Nothing
  | otherwise =
      Just (B.pack (fromIntegral byteCount : map byteAt (downFrom byteCount)))
  where
    -- The bit length is integerLog2 n + 1.  Rounding that up to whole bytes,
    -- (bits + 7) `div` 8, simplifies to (integerLog2 n `div` 8) + 1.
    byteCount :: Int
    byteCount = (integerLog2 n `shiftR` 3) + 1
    -- FIXME: shifting the whole Integer once per byte makes this quadratic.
    byteAt :: Int -> Word8
    byteAt ix = fromIntegral ((n `shiftR` (8 * ix)) .&. 0xFF)

-- | 'leftEncodeInteger' applied to a byte count converted to bits (times 8).
-- Negative inputs and bit counts needing more than 255 bytes give 'Nothing'.
leftEncodeIntegerFromBytes :: Integer -> Maybe ByteString
leftEncodeIntegerFromBytes = leftEncodeInteger . (* 8)

-- | @left_encode@ for fixed-width integral types.
--
-- Emits a length byte followed by the big-endian significant bytes of the
-- argument.  Every non-positive input, including negative values of signed
-- types, encodes as 'leftEncodeZero'.
leftEncode :: forall b. (Integral b, FiniteBits b) => b -> ByteString
leftEncode n
  | n > 0 = B.pack (fromIntegral byteCount : map byteAt (downFrom byteCount))
  | otherwise = leftEncodeZero
  where
    significantBits :: Int
    significantBits = finiteBitSize n - countLeadingZeros n
    -- Round up to whole bytes, and never emit fewer than one.
    byteCount :: Int
    byteCount = max 1 ((significantBits + 7) `shiftR` 3)
    byteAt :: Int -> Word8
    byteAt ix = fromIntegral ((n `shiftR` (8 * ix)) .&. 255)

-- | @left_encode@ of @8 * n@: the bit length corresponding to a byte
-- count @n@.
--
-- The product is not formed in @b@, where it could overflow.  Instead the
-- code counts three extra significant bits and extracts each output byte
-- with a shift offset by 3.  Non-positive inputs encode as
-- 'leftEncodeZero'.
leftEncodeFromBytes :: (Integral b, FiniteBits b) => b -> ByteString
leftEncodeFromBytes n
  | n > 0 = B.pack (fromIntegral byteCount : map byteAt (downFrom byteCount))
  | otherwise = leftEncodeZero
  where
    -- Multiplying by 8 adds three significant bits.
    significantBits :: Int
    significantBits = finiteBitSize n - countLeadingZeros n + 3
    byteCount :: Int
    byteCount = max 1 ((significantBits + 7) `shiftR` 3)
    -- Byte ix of (8 * n) is (n shifted by 3 - 8*ix) masked to 8 bits;
    -- 'shift' with a negative amount shifts right.
    byteAt :: Int -> Word8
    byteAt ix = fromIntegral (shift n (3 - 8 * ix) .&. 0xFF)

{--
-- FIXME: this doesn't work for x0 > 7
leftEncodeBytesMinusBits :: Word -> Word8 -> ByteString
leftEncodeBytesMinusBits n0 x0 = B.pack (fromIntegral nSigBytes : output)
  where
    n = n0 - fromIntegral (fromEnum (x /= 0))
    wordLen = finiteBitSize n
    zeros = countLeadingZeros n
    nSigBits = wordLen - zeros + 3
    nSigBytes = fromIntegral (max 1 (shiftR (nSigBits + 7) 3)) :: Word8
    getByte ix = fromIntegral byte .|. mx
      where byte = shift n (3 - 8 * ix) .&. 0xFF
            mx   = (-x) * fromIntegral (fromEnum (ix /= 0))
    output = map getByte (nSigBytes `downTo` 0)

-- rightEncode :: Int -> ByteString
--}

-- | @encode_string@: the @left_encode@ of the input's length in bits,
-- followed by the input bytes.
encodeString :: ByteString -> ByteString
encodeString bytes = lengthPrefix <> bytes
  where
    byteLen :: Word
    byteLen = fromIntegral (B.length bytes)
    -- leftEncodeFromBytes encodes 8 * byteLen, i.e. the length in bits.
    lengthPrefix
      | byteLen == 0 = leftEncodeZero
      | otherwise = leftEncodeFromBytes byteLen

-- | Number of bytes 'encodeString' produces for this input: the
-- @left_encode@ of its length in bits, plus the bytes themselves.
--
-- 'encodeString' prefixes with 'leftEncodeFromBytes', which encodes the bit
-- length @8 * len@.  The prefix length must therefore come from
-- 'lengthOfLeftEncodeFromBytes', not 'lengthOfLeftEncode'.
encodedByteLength :: ByteString -> Int
encodedByteLength bytes = lengthOfLeftEncodeFromBytes len + len
  where
    len = B.length bytes

-- | Sum of 'encodedByteLength' over every string in the container.
encodedVectorByteLength :: Foldable f => f ByteString -> Int
encodedVectorByteLength = foldl' (\acc str -> acc + encodedByteLength str) 0

{--
encodeBitString :: Word8 -> ByteString -> [ByteString]
encodeBitString truncBits bytes
   | byteLen <= 0 = [ "\x01\x00" ]
   | otherwise = [ leftEncodeBytesMinusBits byteLen truncBits


     leftEncode bitL, take (byteLength - 1) bytes,
--}

-- | Bare (unprefixed) encoding of zero: a single zero byte, i.e.
-- 'leftEncodeZero' without its length byte.
bareEncodeZero :: ByteString
bareEncodeZero = "\x00"

-- | Big-endian bytes of a non-negative 'Integer', with no length prefix.
--
-- Zero encodes as 'bareEncodeZero' (one zero byte) and negative inputs give
-- 'Nothing'.  For inputs that 'leftEncodeInteger' accepts, the result equals
-- that encoding with its leading length byte dropped.  There is no 255-byte
-- limit, because no length byte is emitted.
bareEncodeInteger :: Integer -> Maybe ByteString
bareEncodeInteger n =
  case compare n 0 of
    LT -> Nothing
    EQ -> Just bareEncodeZero
    GT ->
      -- Bit length is integerLog2 n + 1; rounded up to bytes that is
      -- (integerLog2 n `div` 8) + 1.
      let nSigBytes :: Int
          nSigBytes = shiftR (integerLog2 n) 3 + 1
       in Just (B.pack (map getByte (downFrom nSigBytes)))
  where
    -- FIXME: using shiftR here results in a quadratic algorithm
    getByte :: Int -> Word8
    getByte ix = fromIntegral (shiftR n (8 * ix) .&. 0xFF)

-- | Bare (unprefixed) encoding of the bit count @8 * n@ for a byte count
-- @n@.  Negative inputs give 'Nothing'; zero gives 'bareEncodeZero'.
bareEncodeIntegerFromBytes :: Integer -> Maybe ByteString
bareEncodeIntegerFromBytes n = bareEncodeInteger (8 * n)

-- | 'leftEncode' with its leading length byte dropped.
bareEncode :: (Integral b, FiniteBits b) => b -> ByteString
bareEncode = B.drop 1 . leftEncode

-- | 'leftEncodeFromBytes' with its leading length byte dropped.
bareEncodeFromBytes :: (Integral b, FiniteBits b) => b -> ByteString
bareEncodeFromBytes = B.drop 1 . leftEncodeFromBytes

-- | Length in bytes of 'bareEncode' @n@, i.e. 'leftEncode' @n@ without
-- its length byte.  Non-positive inputs give 1, matching 'bareEncodeZero'.
lengthOfBareEncode :: (Integral b, FiniteBits b) => b -> Int
lengthOfBareEncode n
    | n <= 0 = 1
    | otherwise = nSigBytes
  where
    wordLen :: Int
    wordLen = finiteBitSize n
    zeros :: Int
    zeros = countLeadingZeros n
    -- No "+ 3" here: n itself is encoded, not the bit count 8 * n that the
    -- *FromBytes variants encode.
    nSigBits :: Int
    nSigBits = wordLen - zeros
    nSigBytes :: Int
    nSigBytes = max 1 (shiftR (nSigBits + 7) 3)

-- | Length in bytes of 'bareEncodeFromBytes' @n@: the number of bytes
-- needed to hold the bit count @8 * n@.  Non-positive inputs give 1.
lengthOfBareEncodeFromBytes :: (Integral b, FiniteBits b) => b -> Int
lengthOfBareEncodeFromBytes n
  | n > 0 = max 1 ((bitsOf8n + 7) `shiftR` 3)
  | otherwise = 1
  where
    -- Multiplying by 8 adds three significant bits to n.
    bitsOf8n :: Int
    bitsOf8n = finiteBitSize n + 3 - countLeadingZeros n

-- | Number of big-endian bytes needed to hold a non-negative 'Integer'.
-- Zero needs one byte; negative inputs give 'Nothing'.
lengthOfBareEncodeInteger :: Integer -> Maybe Int
lengthOfBareEncodeInteger n
  | n < 0 = Nothing
  | n == 0 = Just 1
  | otherwise = Just ((integerLog2 n `shiftR` 3) + 1)

-- | Number of big-endian bytes needed to hold the bit count @8 * n@ for a
-- byte count @n@.  Zero needs one byte; negative inputs give 'Nothing'.
lengthOfBareEncodeIntegerFromBytes :: Integer -> Maybe Int
lengthOfBareEncodeIntegerFromBytes n
  | n < 0 = Nothing
  | n == 0 = Just 1
  -- integerLog2 (8 * n) == integerLog2 n + 3
  | otherwise = Just (((integerLog2 n + 3) `shiftR` 3) + 1)

-- | One length byte plus 'lengthOfBareEncode' @n@.
lengthOfLeftEncode :: (Integral b, FiniteBits b) => b -> Int
lengthOfLeftEncode n = 1 + lengthOfBareEncode n

-- | One length byte plus 'lengthOfBareEncodeFromBytes' @n@.
lengthOfLeftEncodeFromBytes :: (Integral b, FiniteBits b) => b -> Int
lengthOfLeftEncodeFromBytes n = 1 + lengthOfBareEncodeFromBytes n

-- | Length of 'leftEncodeInteger' output: one length byte plus the bare
-- length.  Gives 'Nothing' for negative inputs, and when the bare length
-- exceeds 255 (too large for the single length byte).
lengthOfLeftEncodeInteger :: Integer -> Maybe Int
lengthOfLeftEncodeInteger n = do
  bareLen <- lengthOfBareEncodeInteger n
  if bareLen > 255 then Nothing else Just (bareLen + 1)

-- | Length of 'leftEncodeIntegerFromBytes' output: one length byte plus the
-- bare length of the bit count @8 * n@.  Gives 'Nothing' for negative
-- inputs, and when the bare length exceeds 255.
lengthOfLeftEncodeIntegerFromBytes :: Integer -> Maybe Int
lengthOfLeftEncodeIntegerFromBytes n = do
  bareLen <- lengthOfBareEncodeIntegerFromBytes n
  if bareLen > 255 then Nothing else Just (bareLen + 1)