{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric  #-}
{-# LANGUAGE Trustworthy    #-}
{-|
  Module        : Crypto.NewHope.Poly
  Description   : Polynomials and related operations.
  Copyright     : © Jeremy Bornstein 2019
  License       : Apache 2.0
  Maintainer    : jeremy@bornstein.org
  Stability     : experimental
  Portability   : portable

  Polynomials and related operations.

-}

module Crypto.NewHope.Poly where

import           Control.DeepSeq
import           Control.Monad.State         (join)
import           Data.Bits
import qualified Data.ByteString             as BS
import           Data.Int
import qualified Data.Vector.Unboxed         as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import           Data.Word
import           GHC.Generics                (Generic)
import           Prelude                     hiding (length)

import           Crypto.NewHope.FIPS202
import           Crypto.NewHope.Internals (N (N1024, N512))
import qualified Crypto.NewHope.Internals as Internals
import qualified Crypto.NewHope.NTT       as NTT
import           Crypto.NewHope.Precomp
import           Crypto.NewHope.Reduce    (montgomeryReduce)
import           MiscUtils


-- | A polynomial, stored as an unboxed vector of 'Word16' coefficients.
--
-- The rest of this module expects the vector to hold exactly N
-- coefficients ('getN' accepts lengths 512 and 1024 only); the
-- constructor itself does not enforce that.
newtype Poly = Poly (VU.Vector Word16)
  deriving (Eq, Show, Generic, NFData)


-- | Size in bytes of a serialized (uncompressed) polynomial for the
-- given parameter set: 14 bits per coefficient, i.e. @14 * N / 8@.
polyBytes :: N -> Int
polyBytes n = totalBits `div` 8
  where
    totalBits = 14 * Internals.value n

-- | Bytes taken by a Poly-derived msg.  This is simply an alias for
-- 'Internals.symBytes'.
--
-- NOTE(review): 'fromMsg' and 'toMsg' in this module work on 32-byte
-- messages (256 bits), so this is presumably 32 -- confirm against
-- "Crypto.NewHope.Internals".
polyMsgBytes :: Int
polyMsgBytes = Internals.symBytes

-- | Size in bytes of a compressed polynomial (see 'compress') for the
-- given parameter set: 3 bits per coefficient, i.e. @3 * N / 8@.
polyCompressedBytes :: N -> Int
polyCompressedBytes n = totalBits `div` 8
  where
    totalBits = 3 * Internals.value n


-- | Number of coefficients stored in the polynomial.
length :: Poly -> Int
length poly = case poly of
  Poly coeffs -> VU.length coeffs


-- | Recover the parameter set from the number of coefficients:
-- 512 gives 'N512', 1024 gives 'N1024'.  Any other length is a
-- programming error and calls 'error'.
getN :: Poly -> N
getN p
  | len == 512  = N512
  | len == 1024 = N1024
  | otherwise   = error "Unexpected Poly length"
  where
    len = length p


-- | Fully reduce a coefficient modulo q.  Returns the value in
-- {0,...,q-1} congruent to @x@ modulo q.
--
-- After the initial @mod@, the result is chosen between @reduced@ and
-- @reduced - q@ with a branch-free mask built from the sign bit of the
-- (wrapped) difference, rather than with a comparison.
--
-- NOTE(review): whether the @mod@ itself runs in constant time depends
-- on the compiler and target; this code does not guarantee it.
coeffFreeze :: Word16 -> Word16
coeffFreeze x = wrapped `xor` ((reduced `xor` wrapped) .&. selector)
  where
    modulus  = fromIntegral Internals.q :: Word16
    reduced  = x `mod` modulus
    -- reduced - q, wrapping around in Word16 when reduced < q
    wrapped  = reduced - modulus
    -- all ones when 'wrapped' is negative as a signed 16-bit value
    -- (arithmetic shift of the sign bit), all zeros otherwise
    selector = fromIntegral (shiftR (fromIntegral wrapped :: Int16) 15) :: Word16


-- | Computes @|(x mod q) - q/2|@ without branching.
--
-- The coefficient is first frozen to {0,...,q-1}, then centred by
-- subtracting @q `div` 2@.  The absolute value is taken with the usual
-- sign-mask trick: @(v + m) `xor` m@, where @m@ is -1 for negative @v@
-- and 0 otherwise.
flipabs :: Word16 -> Word16
flipabs x = fromIntegral ((centered + signMask) `xor` signMask)
  where
    halfQ    = fromIntegral Internals.q `div` 2 :: Int
    centered = fromIntegral (coeffFreeze x) - halfQ
    -- arithmetic shift on Int: -1 when centered < 0, else 0
    signMask = centered `shiftR` 15


-- | Deserialize a polynomial from its packed form (the layout written
-- by 'toByteString'): every 7 input bytes carry four 14-bit
-- coefficients, least-significant bits first.
--
-- Calls 'error' unless the input is exactly @polyBytes N512@ or
-- @polyBytes N1024@ bytes long.
fromByteString :: BS.ByteString -> Poly
fromByteString a
    | bytes == polyBytes N512 || bytes == polyBytes N1024 = Poly coefficients
    | otherwise = error $ "Invalid number (" ++ show bytes ++ ") of serialized bytes for Poly."
  where
    bytes = BS.length a

    coefficients = VU.fromList $ concatMap unpackGroup [0 .. bytes `div` 7 - 1]

    -- Decode the four coefficients held in 7-byte group number @g@.
    unpackGroup :: Int -> [Word16]
    unpackGroup g = [c0, c1, c2, c3]
      where
        -- Byte @k@ of this group, widened so the left shifts keep all bits.
        byte :: Int -> Word16
        byte k = fromIntegral $ BS.index a (7 * g + k)

        c0 =                              byte 0    .|. shiftL (byte 1 .&. 0x3f)  8
        c1 = shiftR (byte 1) 6 .|. shiftL (byte 2) 2 .|. shiftL (byte 3 .&. 0x0f) 10
        c2 = shiftR (byte 3) 4 .|. shiftL (byte 4) 4 .|. shiftL (byte 5 .&. 0x03) 12
        c3 = shiftR (byte 5) 2 .|. shiftL (byte 6) 6


-- | Serialize a polynomial.  Coefficients are fully reduced with
-- 'coeffFreeze' and then packed four at a time into 7 bytes (14 bits
-- per coefficient, least-significant bits first).
toByteString :: Poly -> BS.ByteString
toByteString (Poly v) = BS.pack $ concatMap packGroup (chunk 4 v)
  where
    -- Pack one group of four coefficients into its 7 output bytes.
    packGroup :: VU.Vector Word16 -> [Word8]
    packGroup grp = fromIntegral <$>
        [ c0 .&. 0xff
        , shiftR c0  8 .|. shiftL c1 6
        , shiftR c1  2
        , shiftR c1 10 .|. shiftL c2 4
        , shiftR c2  4
        , shiftR c2 12 .|. shiftL c3 2
        , shiftR c3  6
        ]  -- narrowing to Word8 keeps only the low 8 bits of each entry
      where
        frozen k = coeffFreeze (grp VU.! k)
        c0 = frozen 0
        c1 = frozen 1
        c2 = frozen 2
        c3 = frozen 3


-- | Compress and serialize a polynomial.
--
-- Each coefficient is frozen to {0,...,q-1} and quantized to 3 bits as
-- @round(8 * c / q) mod 8@; eight such values are then packed into 3
-- output bytes.
compress :: Poly -> BS.ByteString
compress (Poly pData) = BS.pack $ concatMap packGroup groups
  where
    groups = chunk 8 (VU.map quantize pData)

    -- 3-bit quantization of a single coefficient.
    quantize :: Word16 -> Word32
    quantize c = fromIntegral (rounded .&. 0x07)
      where
        q       = Internals.q
        frozen  = fromIntegral (coeffFreeze c)
        -- adding q/2 before dividing rounds to nearest
        rounded = (shiftL frozen 3 + q `div` 2) `div` q

    -- Pack eight 3-bit values into three bytes, low bits first.
    packGroup :: VU.Vector Word32 -> [Word8]
    packGroup grp = fromIntegral <$>
        [       t 0    .|. shiftL (t 1) 3 .|. shiftL (t 2) 6
        , shiftR (t 2) 2 .|. shiftL (t 3) 1 .|. shiftL (t 4) 4 .|. shiftL (t 5) 7
        , shiftR (t 5) 1 .|. shiftL (t 6) 2 .|. shiftL (t 7) 5
        ]
      where
        t k = grp VU.! k


-- | Deserialize and decompress a polynomial; approximate inverse of
-- 'compress'.
--
-- Every 3 input bytes are split into eight 3-bit values, and each value
-- @v@ is mapped back to a coefficient as @(v * q + 4) `shiftR` 3@.
decompress :: BS.ByteString -> Poly
decompress input = Poly . VU.fromList $ rescale <$> concatMap unpackGroup groups
  where
    groups = chunk 3 . VU.fromList $ fromIntegral <$> BS.unpack input

    -- Split one 3-byte group into its eight 3-bit values.
    unpackGroup :: VU.Vector Word16 -> [Word16]
    unpackGroup grp =
        [        b0   .&. 7
        , shiftR b0 3 .&. 7
        , shiftR b0 6 .|. (shiftL b1 2 .&. 4)
        , shiftR b1 1 .&. 7
        , shiftR b1 4 .&. 7
        , shiftR b1 7 .|. (shiftL b2 1 .&. 6)
        , shiftR b2 2 .&. 7
        , shiftR b2 5
        ]
      where
        b0 = grp VU.! 0
        b1 = grp VU.! 1
        b2 = grp VU.! 2

    -- Scale a 3-bit value back up; done in Word32 so the product fits.
    rescale :: Word16 -> Word16
    rescale v = fromIntegral (shiftR scaled 3)
      where
        scaled = (fromIntegral v :: Word32) * fromIntegral Internals.q + 4


-- | Convert a (32-byte) message into a polynomial.
--
-- Bit @j@ of message byte @i@ becomes coefficient @8 * i + j@ of a
-- 256-coefficient block: @q `div` 2@ when the bit is set, 0 otherwise.
-- That block is then repeated to fill the polynomial (twice for 'N512',
-- four times for 'N1024').
fromMsg :: N -> BS.ByteString -> Poly
fromMsg n msg = Poly (VU.concat (replicate copies block))
  where
    copies :: Int
    copies
        | n == N512  = 2
        | n == N1024 = 4
        | otherwise  = error "Invalid N"

    halfQ :: Word16
    halfQ = fromIntegral Internals.q `div` 2

    block :: VU.Vector Word16
    block = VU.generate 256 coefficient

    coefficient :: Int -> Word16
    coefficient k = bitMask .&. halfQ
      where
        (byteIx, bitIx) = k `divMod` 8
        bit     = (fromIntegral (BS.index msg byteIx) `shiftR` bitIx) .&. 1 :: Word16
        -- 0xffff when the message bit is set, 0 otherwise
        bitMask = negate bit


-- | Convert a polynomial to a (32-byte) message.
--
-- Message bit @i@ (for @i@ in 0..255) is decided from the coefficients
-- at @i@, @i + 256@, ... (two of them for 'N512', four for 'N1024'):
-- their 'flipabs' distances are summed, a threshold is subtracted, and
-- the top bit of the 16-bit result is the message bit.  Bits are packed
-- eight to a byte, least-significant bit first.
toMsg :: Poly -> BS.ByteString
toMsg p@(Poly x) = BS.pack (foldr (.|.) 0 <$> chunk 8 bits)
  where
    n = getN p

    -- Positions of the coefficients that all encode the same bit.
    offsets :: [Int]
    offsets
        | n == N512  = [0, 256]
        | n == N1024 = [0, 256, 512, 768]
        | otherwise  = error "Invalid vector size"

    threshold :: Word16
    threshold
        | n == N1024 = fromIntegral Internals.q
        | otherwise  = fromIntegral (Internals.q `div` 2)

    bits :: [Word8]
    bits = bitAt <$> [0 .. 255]

    -- Message bit @i@, already shifted to its position inside its byte.
    bitAt :: Int -> Word8
    bitAt i = fromIntegral (shiftL decided (i .&. 7))
      where
        distance = sum [flipabs (x VU.! (off + i)) | off <- offsets] :: Word16
        -- Word16 subtraction wraps, so the top bit is set exactly when
        -- distance < threshold
        decided  = shiftR (distance - threshold) 15


-- | Sample a polynomial deterministically from a seed, with output
-- polynomial looking uniformly random.
--
-- Coefficients are generated in blocks of 64.  For block @i@ the seed is
-- extended with the byte @i@ and absorbed into SHAKE-128; the squeezed
-- output is read as little-endian 16-bit values, and a value is accepted
-- as a coefficient only if it is below @5 * q@ (rejection sampling).
--
-- Squeezing continues, one block at a time, until 64 values have been
-- accepted.  A single squeezed block is almost always enough, but
-- stopping after one block regardless would leave the remaining
-- coefficients of that group at zero whenever too many candidates were
-- rejected.
--
-- NOTE(review): this relies on the second component returned by
-- 'shake128SqueezeBlocks' being the updated sponge state -- confirm
-- against "Crypto.NewHope.FIPS202".
uniform :: N -> Internals.Seed -> Poly
uniform n seed = Poly . VU.fromList $ concatMap block [0 .. size `div` 64 - 1]
  where
    Internals.Seed seed' = seed
    size = Internals.value n

    -- Candidates at or above this bound are rejected.
    bound :: Word16
    bound = 5 * fromIntegral Internals.q

    -- The 64 coefficients of block @i@.
    block :: Int -> [Word16]
    block i = Prelude.take 64 [v | v <- candidates (shake128Absorb extseed), v < bound]
      where
        -- the extra byte domain-separates the per-block SHAKE calls
        extseed = BS.snoc seed' (fromIntegral i)

    -- Lazy, unbounded stream of 16-bit candidates; a further block is
    -- squeezed only when the consumer runs past the current one.
    candidates st = words16 (BS.unpack (BS.take shake128Rate buf)) ++ candidates st'
      where
        (buf, st') = shake128SqueezeBlocks st 1

    -- Pair up bytes as little-endian 16-bit words.
    words16 :: [Word8] -> [Word16]
    words16 (lo : hi : rest) = (fromIntegral lo .|. shiftL (fromIntegral hi) 8) : words16 rest
    words16 _                = []


-- | The Hamming weight of a byte, i.e. how many of its 8 bits are set.
hw :: Word8 -> Word8
hw = fromIntegral . popCount


-- | Sample a polynomial deterministically from a seed and a nonce,
-- with output polynomial close to centered binomial distribution with
-- parameter k=8.
--
-- Coefficients are produced in blocks of 64.  For block @i@, the seed
-- followed by the nonce and the byte @i@ is hashed with SHAKE-256 to
-- 128 bytes; coefficient @j@ of the block is
-- @hw(buf[2j]) + q - hw(buf[2j+1])@.
--
-- This expects N to be a multiple of 64.
sample :: N -> Internals.Seed -> Word8 -> Poly
sample n seed nonce = Poly . VU.fromList $ concatMap block [0 .. size `div` 64 - 1]
  where
    size = Internals.value n
    Internals.Seed seedData = seed
    prefix = BS.snoc seedData nonce

    -- The 64 coefficients of block @i@.
    block :: Int -> [Word16]
    block i = coefficient <$> [0 .. 63]
      where
        buf = shake256 (BS.snoc prefix (fromIntegral i)) 128

        -- Hamming weight of byte @k@ of this block's hash output.
        weight :: Int -> Word16
        weight k = fromIntegral (hw (BS.index buf k))

        coefficient j = weight (2 * j) + fromIntegral Internals.q - weight (2 * j + 1)


-- | Multiply two polynomials pointwise (i.e., coefficient-wise).
--
-- Each coefficient of the second operand is first multiplied by the
-- constant 3186 and Montgomery-reduced, and the product with the first
-- operand's coefficient is Montgomery-reduced again.
--
-- NOTE: the lengths are not checked, which is fine given existing code.
-- NOTE: the two operands are treated differently, so swapping them is
-- not guaranteed to give bit-identical coefficients.
mulPointwise :: Poly -> Poly -> Poly
mulPointwise (Poly a) (Poly b) = Poly (VU.zipWith multiply a b)
  where
    multiply x y =
      let yMont = montgomeryReduce (3186 * fromIntegral y)   -- now in Montgomery domain
      in  montgomeryReduce (fromIntegral x * fromIntegral yMont) -- back in normal domain


-- | Add two polynomials coefficient-wise, reducing each sum modulo q.
--
-- The sum is formed in 'Word32' so that it cannot wrap around before
-- the reduction: two 'Word16' coefficients can add up to more than
-- 65535, and wrapping modulo 2^16 first would change the residue
-- modulo q.
--
-- NOTE: the lengths are not checked, which is fine given existing code;
-- 'VU.zipWith' stops at the shorter operand.
add :: Poly -> Poly -> Poly
add (Poly a) (Poly b) = Poly $ VU.zipWith go a b
  where
    q :: Word32
    q = fromIntegral Internals.q

    go :: Word16 -> Word16 -> Word16
    go c d = fromIntegral ((fromIntegral c + fromIntegral d) `mod` q)


-- | Subtract two polynomials coefficient-wise: each result coefficient
-- is @(c + 3q - d) mod q@, in {0,...,q-1}.
--
-- The arithmetic is done in 'Int' rather than 'Word16': in 16 bits,
-- @c + 3q@ can exceed 65535 and @c + 3q - d@ can drop below zero, and
-- either wrap-around would change the residue modulo q.  With 'Int'
-- the intermediate value is exact and @mod@ always returns the
-- non-negative residue.
--
-- The lengths are not checked; 'VU.zipWith' stops at the shorter
-- operand.
sub :: Poly -> Poly -> Poly
sub (Poly a) (Poly b) = Poly $ VU.zipWith go a b
  where
    q :: Int
    q = fromIntegral Internals.q

    go :: Word16 -> Word16 -> Word16
    go c d = fromIntegral ((fromIntegral c + 3 * q - fromIntegral d) `mod` q)


-- | Forward NTT transform of a polynomial.
--
-- The input is assumed to have its coefficients in bitreversed order;
-- the output has its coefficients in normal order.
--
-- The coefficients are first multiplied by the 'ψBitrevMontgomery'
-- table for this polynomial's parameter set and then passed to
-- 'NTT.ntt' together with the 'ωBitrevMontgomery' table.
ntt :: Poly  -- ^ input polynomial, in bitreversed order
    -> Poly  -- ^ transformed polynomial, in normal order
ntt p@(Poly coeffs) = Poly (NTT.ntt scaled (ωBitrevMontgomery n))
  where
    n      = getN p
    scaled = NTT.mulCoefficients coeffs (ψBitrevMontgomery n)


-- | Inverse NTT transform of a polynomial.
--
-- Input and output both have their coefficients in normal order.
--
-- The input is bit-reversed with 'NTT.bitrev', transformed by 'NTT.ntt'
-- using the 'ωInvBitrevMontgomery' table for this polynomial's
-- parameter set, and finally multiplied coefficient-wise by the
-- 'ψInvMontgomery' table.
invntt :: Poly  -- ^ input, with coefficients in normal order
       -> Poly  -- ^ output, with coefficients in normal order
invntt p@(Poly coeffs) = Poly (NTT.mulCoefficients transformed (ψInvMontgomery n))
  where
    n           = getN p
    reversed    = NTT.bitrev coeffs
    transformed = NTT.ntt reversed (ωInvBitrevMontgomery n)