{-# LANGUAGE OverloadedStrings, BangPatterns, ScopedTypeVariables, ViewPatterns #-}

-------------------------------------------------------------------------------
-- |
-- Module:      Crypto.Encoding.SHA3.TupleHash
-- Copyright:   (c) 2024 Auth Global
-- License:     Apache2
--
-- See NIST Special Publication 800-185: SHA-3 Derived Functions: cSHAKE, KMAC,
-- TupleHash and ParallelHash. <https://www.nist.gov/publications/sha-3-derived-functions-cshake-kmac-tuplehash-and-parallelhash>
-- Note that this module does not implement TupleHash itself, though it could
-- be used to implement TupleHash given an implementation of Keccak/SHA-3. This
-- only implements the data-encoding portions of TupleHash, and does not
-- implement any kind of cryptographic state machine.
--
-------------------------------------------------------------------------------

module Crypto.Encoding.SHA3.TupleHash
  ( leftEncodeZero
  , leftEncodeInteger
  , leftEncodeIntegerFromBytes
  , leftEncode
  , leftEncodeFromBytes
  , encodeString
  , encodedByteLength
  , encodedVectorByteLength
  , bareEncodeZero
  , bareEncodeInteger
  , bareEncodeIntegerFromBytes
  , bareEncode
  , bareEncodeFromBytes
  , lengthOfBareEncode
  , lengthOfBareEncodeFromBytes
  , lengthOfBareEncodeInteger
  , lengthOfBareEncodeIntegerFromBytes
  , lengthOfLeftEncode
  , lengthOfLeftEncodeFromBytes
  , lengthOfLeftEncodeInteger
  , lengthOfLeftEncodeIntegerFromBytes
  ) where

import Data.Monoid((<>))
import Data.ByteString(ByteString)
import qualified Data.ByteString as B
import Data.Bits
import Data.List(foldl')
import Data.Word
import Math.NumberTheory.Logarithms(integerLog2)

-- | @downFrom x@ lists @x - 1, x - 2, ..., 0@ in descending order. This
-- module uses it to walk the byte indices of an @x@-byte number from the
-- most significant byte to the least significant one.
--
-- For signed types such as 'Int' the result is empty when @x <= 0@ (beware:
-- for unsigned types, @x - 1@ wraps around when @x == 0@).
downFrom :: (Num a, Enum a) => a -> [a]
downFrom x = enumFromThenTo (x - 1) (x - 2) 0

-- | The @left_encode@ of zero: a length byte of 1 followed by a single zero
-- byte.
leftEncodeZero :: ByteString
leftEncodeZero = B.pack [0x01, 0x00]

-- I don't like the current interface, in that:
--   leftEncodeInteger returns Nothing when given a negative argument, but
--   leftEncode (on FiniteBits) returns leftEncodeZero when given the same.
--   If there were a way in Haskell to constrain the latter to unsigned words, I'd do it;
--   then a more generic version that works on finite signed ints could return (Maybe ByteString).
--   The current interface is a bit of a compromise.

-- | Encode a non-negative 'Integer' with NIST SP 800-185 @left_encode@: a
-- length byte holding the number of significant bytes, followed by those
-- bytes in big-endian order.
--
-- Returns 'Nothing' for negative inputs. It also returns 'Nothing' for
-- inputs that need more than 255 bytes (i.e. @n >= 2^2040@), because the
-- length has to fit in a single byte.
leftEncodeInteger :: Integer -> Maybe ByteString
leftEncodeInteger n
  | n < 0 = Nothing
  | n == 0 = Just leftEncodeZero
  | byteCount > 255 = Nothing
  | otherwise =
      Just (B.pack (fromIntegral byteCount : map byteAt (downFrom byteCount)))
  where
    -- The bit length of n is integerLog2 n + 1. Rounding that up to whole
    -- bytes ((bits + 7) / 8) simplifies to (integerLog2 n `div` 8) + 1.
    -- Only forced once n > 0 has been established by the guards above.
    byteCount = shiftR (integerLog2 n) 3 + 1
    -- Byte i of n, counting from the least significant byte. Every call
    -- shifts the whole Integer, so the loop is quadratic in the byte count.
    -- That is acceptable here because the count is capped at 255.
    byteAt i = fromIntegral (shiftR n (8 * i) .&. 0xFF)

-- | 'leftEncodeInteger' applied to a length given in bytes: it encodes the
-- corresponding bit count @8 * n@. Returns 'Nothing' for negative @n@, or
-- when @8 * n@ needs more than 255 bytes.
leftEncodeIntegerFromBytes :: Integer -> Maybe ByteString
leftEncodeIntegerFromBytes = leftEncodeInteger . (8 *)

-- | Encode a fixed-width integral value with NIST SP 800-185 @left_encode@:
-- a single length byte giving the number of significant bytes, followed by
-- those bytes in big-endian order.
--
-- Zero and negative inputs both produce the encoding of zero, @01 00@ (the
-- same bytes as 'leftEncodeZero').
--
-- NOTE(review): the length byte is truncated to a 'Word8'. A type wider than
-- 2040 bits would therefore get an incorrect prefix. The standard
-- fixed-width types are far below that limit.
leftEncode :: forall b. (Integral b, FiniteBits b) => b -> ByteString
leftEncode n
  | n > 0 = B.pack (fromIntegral byteCount : payload)
  | otherwise = B.pack [0x01, 0x00]
  where
    -- Number of bits actually used by n (at least 1, since n > 0 here).
    bitCount = finiteBitSize n - countLeadingZeros n
    -- Round up to whole bytes, never fewer than one.
    byteCount = max 1 ((bitCount + 7) `shiftR` 3)
    -- Significant bytes, most significant first.
    payload =
      [ fromIntegral ((n `shiftR` (8 * i)) .&. 255)
      | i <- [byteCount - 1, byteCount - 2 .. 0]
      ]

-- | @left_encode@ of the bit length @8 * n@ for a byte count @n@, computed
-- directly from @n@. The multiplication by 8 is never materialised, so it
-- cannot overflow @b@.
--
-- Zero and negative inputs produce the encoding of zero, @01 00@ (the same
-- bytes as 'leftEncodeZero').
leftEncodeFromBytes :: (Integral b, FiniteBits b) => b -> ByteString
leftEncodeFromBytes n
  | n > 0 = B.pack (fromIntegral byteCount : map bitLengthByte indices)
  | otherwise = B.pack [0x01, 0x00]
  where
    -- Multiplying by 8 adds exactly three bits to the bit length.
    bitCount = finiteBitSize n - countLeadingZeros n + 3
    byteCount = max 1 (shiftR (bitCount + 7) 3)
    indices = [byteCount - 1, byteCount - 2 .. 0]
    -- Byte i of (8 * n): i.e. (n << 3) >> (8 * i), expressed as a single
    -- shift by (3 - 8 * i).
    --   * For i == 0 this is a left shift by 3. Bits pushed past the top of
    --     b are discarded, but only the low byte is kept, and that byte is
    --     unaffected.
    --   * For i >= 1 it is a right shift, so nothing overflows.
    bitLengthByte i = fromIntegral (shift n (3 - 8 * i) .&. 0xFF)

{--
-- FIXME: this doesn't work for x0 > 7
leftEncodeBytesMinusBits :: Word -> Word8 -> ByteString
leftEncodeBytesMinusBits n0 x0 = B.pack (fromIntegral nSigBytes : output)
  where
    n = n0 - fromIntegral (fromEnum (x /= 0))
    wordLen = finiteBitSize n
    zeros = countLeadingZeros n
    nSigBits = wordLen - zeros + 3
    nSigBytes = fromIntegral (max 1 (shiftR (nSigBits + 7) 3)) :: Word8
    getByte ix = fromIntegral byte .|. mx
      where byte = shift n (3 - 8 * ix) .&. 0xFF
            mx   = (-x) * fromIntegral (fromEnum (ix /= 0))
    output = map getByte (nSigBytes `downTo` 0)

-- rightEncode :: Int -> ByteString
--}

-- | NIST SP 800-185 @encode_string@: the @left_encode@ of the input's bit
-- length (@8 * length@), followed by the input bytes themselves. The empty
-- string encodes as @01 00@.
encodeString :: ByteString -> ByteString
encodeString bytes
  | B.null bytes = leftEncodeZero
  | otherwise = leftEncodeFromBytes byteCount `B.append` bytes
  where
    byteCount = fromIntegral (B.length bytes) :: Word

-- | Number of bytes 'encodeString' produces for the given input, computed
-- without building the encoding. It is the length-prefix size plus the
-- payload size.
encodedByteLength :: ByteString -> Int
encodedByteLength bs = payloadLen + lengthOfLeftEncodeFromBytes payloadLen
  where
    payloadLen = B.length bs

-- | Sum of 'encodedByteLength' over a collection of strings. This is the
-- number of bytes obtained by encoding each element with 'encodeString' and
-- concatenating the results.
encodedVectorByteLength :: Foldable f => f ByteString -> Int
encodedVectorByteLength strings = foldl' addEncoded 0 strings
  where
    -- Strict accumulator, so no chain of (+) thunks builds up.
    addEncoded !total s = total + encodedByteLength s

{--
encodeBitString :: Word8 -> ByteString -> [ByteString]
encodeBitString truncBits bytes
   | byteLen <= 0 = [ "\x01\x00" ]
   | otherwise = [ leftEncodeBytesMinusBits byteLen truncBits


     leftEncode bitL, take (byteLength - 1) bytes,
--}

-- | The bare (unprefixed) encoding of zero: a single zero byte.
bareEncodeZero :: ByteString
bareEncodeZero = B.singleton 0

-- | Encode a non-negative 'Integer' as its minimal big-endian byte string,
-- with no length prefix. Zero encodes as a single @00@ byte.
--
-- Returns 'Nothing' for negative inputs. Unlike 'leftEncodeInteger', there
-- is no upper bound on the size of the input.
bareEncodeInteger :: Integer -> Maybe ByteString
bareEncodeInteger n =
  case compare n 0 of
    LT -> Nothing
    EQ -> Just bareEncodeZero
    -- The bit length is integerLog2 n + 1. Rounding that up to whole bytes
    -- simplifies to (integerLog2 n `div` 8) + 1.
    GT -> Just (B.pack (bytesBE (shiftR (integerLog2 n) 3 + 1) n))
  where
    -- @bytesBE k m@ renders @m@ as exactly @k@ big-endian bytes, assuming
    -- @0 <= m < 2^(8*k)@.
    --
    -- Extracting every byte with its own shift of the whole number costs
    -- time proportional to the number's size per byte, which is quadratic
    -- overall. Since this function's input is unbounded, large chunks are
    -- instead split in half at a byte boundary and both halves are rendered
    -- recursively. Each recursion level is then linear in k, for
    -- O(k log k) total. Only small chunks use direct per-byte extraction.
    bytesBE :: Int -> Integer -> [Word8]
    bytesBE k m
      | k <= 32 = map (byteOf m) (downFrom k)
      | otherwise =
          bytesBE hi (shiftR m (8 * lo)) ++ bytesBE lo (m .&. (bit (8 * lo) - 1))
      where
        lo = k `div` 2
        hi = k - lo

    -- Byte ix of m, counting from the least significant byte.
    byteOf :: Integer -> Int -> Word8
    byteOf m ix = fromIntegral (shiftR m (8 * ix) .&. 0xFF)

-- | 'bareEncodeInteger' applied to a length given in bytes. It produces the
-- minimal big-endian encoding of the bit count @8 * n@, with no length
-- prefix. Returns 'Nothing' only for negative @n@.
--
-- Defined via 'bareEncodeInteger' rather than by dropping the prefix from
-- 'leftEncodeIntegerFromBytes'. A bare encoding has no length byte, so it
-- must not inherit left_encode's 255-byte limit. This keeps it consistent
-- with 'bareEncodeInteger' and 'lengthOfBareEncodeIntegerFromBytes', both of
-- which accept arbitrarily large inputs.
bareEncodeIntegerFromBytes :: Integer -> Maybe ByteString
bareEncodeIntegerFromBytes n = bareEncodeInteger (8 * n)

-- | Minimal big-endian encoding of a fixed-width integral value, with no
-- length prefix. This is 'leftEncode' with its leading length byte removed.
-- Zero and negative inputs yield a single zero byte.
bareEncode :: (Integral b, FiniteBits b) => b -> ByteString
bareEncode n = B.drop 1 (leftEncode n)

-- | Minimal big-endian encoding of the bit count @8 * n@ for a byte count
-- @n@, with no length prefix. This is 'leftEncodeFromBytes' with its leading
-- length byte removed. Zero and negative inputs yield a single zero byte.
bareEncodeFromBytes :: (Integral b, FiniteBits b) => b -> ByteString
bareEncodeFromBytes n = B.drop 1 (leftEncodeFromBytes n)

-- | Number of bytes 'bareEncode' produces for the given value (always at
-- least 1), computed without building the encoding.
lengthOfBareEncode :: (Integral b, FiniteBits b) => b -> Int
lengthOfBareEncode n
  | n > 0 = max 1 ((bitCount + 7) `div` 8)
  | otherwise = 1
  where
    -- Bits actually used by n. Non-negative, so `div` 8 rounds as intended.
    bitCount = finiteBitSize n - countLeadingZeros n

-- | Number of bytes 'bareEncodeFromBytes' produces for a byte count @n@
-- (always at least 1). It is the minimal byte length of the bit count
-- @8 * n@, computed without overflowing @b@.
lengthOfBareEncodeFromBytes :: (Integral b, FiniteBits b) => b -> Int
lengthOfBareEncodeFromBytes n
  | n > 0 = max 1 ((bitCount + 7) `div` 8)
  | otherwise = 1
  where
    -- Multiplying by 8 adds exactly three bits to the bit length.
    bitCount = finiteBitSize n - countLeadingZeros n + 3

-- | Number of bytes 'bareEncodeInteger' produces, computed without building
-- the encoding. Returns 'Nothing' for negative inputs, and 1 for zero.
lengthOfBareEncodeInteger :: Integer -> Maybe Int
lengthOfBareEncodeInteger n
  | n < 0 = Nothing
  | n == 0 = Just 1
  -- Bit length is integerLog2 n + 1. Rounded up to whole bytes, that is
  -- (integerLog2 n `div` 8) + 1.
  | otherwise = Just (integerLog2 n `div` 8 + 1)

-- | Number of bytes in the minimal big-endian encoding of the bit count
-- @8 * n@ for a byte count @n@. Returns 'Nothing' for negative @n@, and 1
-- for zero. Multiplying by 8 raises integerLog2 by exactly 3, so this is
-- simply 'lengthOfBareEncodeInteger' of @8 * n@.
lengthOfBareEncodeIntegerFromBytes :: Integer -> Maybe Int
lengthOfBareEncodeIntegerFromBytes n = lengthOfBareEncodeInteger (8 * n)

-- | Number of bytes 'leftEncode' produces: the bare encoding plus one
-- length byte.
lengthOfLeftEncode :: (Integral b, FiniteBits b) => b -> Int
lengthOfLeftEncode n = 1 + lengthOfBareEncode n

-- | Number of bytes 'leftEncodeFromBytes' produces: the bare encoding of
-- @8 * n@ plus one length byte.
lengthOfLeftEncodeFromBytes :: (Integral b, FiniteBits b) => b -> Int
lengthOfLeftEncodeFromBytes n = 1 + lengthOfBareEncodeFromBytes n

-- | Number of bytes 'leftEncodeInteger' produces. Returns 'Nothing' exactly
-- when 'leftEncodeInteger' does: for negative inputs, or when more than 255
-- significant bytes are needed (the length must fit in one byte).
lengthOfLeftEncodeInteger :: Integer -> Maybe Int
lengthOfLeftEncodeInteger n = lengthOfBareEncodeInteger n >>= withLengthByte
  where
    withLengthByte k
      | k <= 255 = Just (k + 1)
      | otherwise = Nothing

-- | Number of bytes 'leftEncodeIntegerFromBytes' produces. Returns 'Nothing'
-- for negative @n@, or when the bit count @8 * n@ needs more than 255 bytes.
lengthOfLeftEncodeIntegerFromBytes :: Integer -> Maybe Int
lengthOfLeftEncodeIntegerFromBytes n =
  lengthOfBareEncodeIntegerFromBytes n >>= withLengthByte
  where
    withLengthByte k
      | k <= 255 = Just (k + 1)
      | otherwise = Nothing