{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE Trustworthy      #-}
{-|
  Module        : Crypto.NewHope.NTT
  Description   : Number-theoretic transforms for NewHope.
  Copyright     : © Jeremy Bornstein 2019
  License       : Apache 2.0
  Maintainer    : jeremy@bornstein.org
  Stability     : experimental
  Portability   : portable

  Number-theoretic transforms for NewHope.

  See: https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general)#Number-theoretic_transform

-}

module Crypto.NewHope.NTT where


import           Data.Bits
import qualified Data.Vector.Unboxed as VU
import           Data.Word

import qualified Crypto.NewHope.Internals as Internals
import           Crypto.NewHope.Reduce    (montgomeryReduce)


-- | An unboxed vector of 16-bit words.  This module uses it for
-- polynomial coefficients, for the powers of the root of unity passed
-- to 'ntt', and for the bit-reversal index tables.
type NTTVector = VU.Vector Word16

-- | Re-orders a vector according to the bit-reversal permutation: the
-- element at position @i@ of the result is the element of the input
-- found at the index stored in slot @i@ of the matching table.
-- See https://en.wikipedia.org/wiki/Bit-reversal_permutation
--
-- Only 512- and 1024-element vectors have a table.  The table is
-- selected lazily, so an empty vector is returned unchanged while any
-- other length raises an error as soon as the first element is built.
bitrev :: NTTVector -> NTTVector
bitrev p = VU.generate size pick
  where
    size = VU.length p

    -- Source element for output position i.
    pick i = p VU.! fromIntegral (permutation VU.! i)

    permutation
      | size == 512  = bitrevTable512
      | size == 1024 = bitrevTable1024
      | otherwise    = error "unexpected vector length"


-- | Multiplies two polynomials coefficient by coefficient.  Each pair
-- of coefficients is converted to the argument type of
-- 'montgomeryReduce', multiplied, and reduced.  The second polynomial
-- ("factors") is assumed to hold coefficients in Montgomery
-- representation.  The result is as long as the shorter input.
mulCoefficients :: NTTVector  -- ^ first polynomial, in standard representation
                -> NTTVector  -- ^ second polynomial ("factors"), in Montgomery representation
                -> NTTVector
mulCoefficients poly factors = VU.zipWith scaled poly factors
  where
    scaled coefficient factor =
      montgomeryReduce (fromIntegral coefficient * fromIntegral factor)


-- | Computes @w * (a + 3 * q - b)@ in 32-bit unsigned arithmetic,
-- where @q@ is 'Internals.q'.  'ntt' feeds the result to
-- 'montgomeryReduce' to obtain the second output of a butterfly.
--
-- NOTE(review): the @3 * q@ offset presumably keeps @a + 3 * q - b@
-- from wrapping below zero for the operand ranges 'ntt' produces --
-- confirm against the reference implementation.
getPreReduce :: Integral a => a -> a -> a -> Word32
getPreReduce a b w = widen w * (widen a + 3 * modulus - widen b)
  where
    -- All operands are brought to Word32 before any arithmetic.
    widen :: Integral n => n -> Word32
    widen = fromIntegral

    modulus :: Word32
    modulus = fromIntegral Internals.q


-- | Computes the number-theoretic transform (NTT) of a polynomial.
-- The input is expected in bit-reversed order; the output is in
-- standard order.
--
-- The transform is a sequence of butterfly levels with distances 1, 2,
-- 4, ...: nine levels for a 512-element vector, ten otherwise.  Levels
-- alternate between storing the plain 16-bit sum (lazy reduction) and
-- storing the sum reduced modulo q.
--
-- Within one level every butterfly reads and writes its own pair of
-- indices, so a level is computed entirely from its input vector and
-- written back with a single bulk update rather than one vector copy
-- per butterfly.
--
-- NOTE(review): the vector length is not validated here; presumably
-- callers only pass 512 or 1024 coefficients -- confirm.
ntt :: NTTVector  -- ^ Input polynomial, in bit-reversed order
    -> NTTVector  -- ^ Powers of root of unity omega, in Montgomery domain
    -> NTTVector  -- ^ Result polynomial, in normal order
ntt vec omega = VU.foldl' level vec (VU.enumFromN 0 levelCount)
  where
    vectorLength = VU.length vec

    -- A 512-element transform stops after level 8; every other length
    -- also runs level 9.
    levelCount :: Int
    levelCount = if vectorLength == 512 then 9 else 10

    q32 :: Word32
    q32 = fromIntegral Internals.q

    -- Apply every butterfly of level l (distance 2^l) to v.
    level :: NTTVector -> Int -> NTTVector
    level v l = v VU.// concat
        [ butterfly j twiddle
        | start        <- [0 .. distance - 1]
        , (j, twiddle) <- zip [start, start + 2 * distance .. vectorLength - 2] [0 ..]
        ]
      where
        distance :: Int
        distance = shift 1 l

        -- Even-numbered levels leave the sum unreduced.
        lazy :: Bool
        lazy = even l

        -- New values for positions j and j + distance.
        butterfly :: Int -> Int -> [(Int, Word16)]
        butterfly j twiddle = [(j, sumValue), (j', diffValue)]
          where
            j' = j + distance
            lo = v VU.! j
            hi = v VU.! j'
            w  = omega VU.! twiddle

            -- Formed in 32 bits so the sum cannot wrap at 2^16 before
            -- it is reduced modulo q.
            wideSum :: Word32
            wideSum = fromIntegral lo + fromIntegral hi

            sumValue
              | lazy      = lo + hi
              | otherwise = fromIntegral (wideSum `mod` q32)

            diffValue = montgomeryReduce (getPreReduce lo hi w)


-- | Selects the table of bit-reversed indices used to re-order a
-- polynomial before the number-theoretic transform.  The argument is
-- the number of coefficients (512 or 1024), not a bit count; any other
-- value raises an error.
bitrevTable :: Int -> NTTVector
bitrevTable 512  = bitrevTable512
bitrevTable 1024 = bitrevTable1024
bitrevTable _    = error "Unsupported"

-- | Bit-reversal table for 512-element vectors: entry @i@ is the value
-- of @i@ with its low 9 bits written in reverse order (so the table
-- starts 0, 256, 128, 384, ... and ends ..., 255, 511).
bitrevTable512 :: NTTVector
bitrevTable512 = VU.generate 512 reversed
  where
    -- Bit k of the index becomes bit (8 - k) of the entry.
    reversed :: Int -> Word16
    reversed i = sum [bit (8 - k) | k <- [0 .. 8], testBit i k]


-- | Bit-reversal table for 1024-element vectors: entry @i@ is the
-- value of @i@ with its low 10 bits written in reverse order (so the
-- table starts 0, 512, 256, 768, ... and ends ..., 511, 1023).
bitrevTable1024 :: NTTVector
bitrevTable1024 = VU.generate 1024 reversed
  where
    -- Bit k of the index becomes bit (9 - k) of the entry.
    reversed :: Int -> Word16
    reversed i = sum [bit (9 - k) | k <- [0 .. 9], testBit i k]