{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE Trustworthy #-}
module Crypto.NewHope.Poly where
import Control.DeepSeq
import Control.Monad.State (join)
import Data.Bits
import qualified Data.ByteString as BS
import Data.Int
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import Data.Word
import GHC.Generics (Generic)
import Prelude hiding (length)
import Crypto.NewHope.FIPS202
import Crypto.NewHope.Internals (N (N1024, N512))
import qualified Crypto.NewHope.Internals as Internals
import qualified Crypto.NewHope.NTT as NTT
import Crypto.NewHope.Precomp
import Crypto.NewHope.Reduce (montgomeryReduce)
import MiscUtils
-- | A polynomial, stored as an unboxed vector of 'Word16' coefficients.
-- 'getN' only accepts vectors of 512 or 1024 coefficients.
newtype Poly = Poly (VU.Vector Word16) deriving (Eq, Show, Generic, NFData)
-- | Number of bytes in the serialized form of a polynomial for the given
-- parameter set: 14 bits for each of the @Internals.value n@ coefficients.
polyBytes :: N -> Int
polyBytes n = bits `div` 8
  where
    bits = 14 * Internals.value n
-- | Size constant for messages; simply re-exports 'Internals.symBytes'.
polyMsgBytes :: Int
polyMsgBytes = Internals.symBytes
-- | Number of bytes in the compressed form of a polynomial for the given
-- parameter set: 3 bits for each of the @Internals.value n@ coefficients.
polyCompressedBytes :: N -> Int
polyCompressedBytes n = bits `div` 8
  where
    bits = 3 * Internals.value n
-- | Number of coefficients stored in the polynomial.
-- (Shadows the Prelude function of the same name, which this module hides.)
length :: Poly -> Int
length (Poly coeffs) = VU.length coeffs
-- | Recover the parameter set from the number of coefficients.
--
-- Calls 'error' when the polynomial has neither 512 nor 1024 coefficients.
getN :: Poly -> N
getN p
  | count == 512 = N512
  | count == 1024 = N1024
  | otherwise = error "Unexpected Poly length"
  where
    count = length p
-- | Bring a coefficient into reduced form with respect to 'Internals.q'.
--
-- The value is first reduced with @mod@; the result @reduced@ and
-- @reduced - q@ are then combined with a branch-free mask select: when
-- @reduced - q@ has its top bit set (i.e. is negative when read as 'Int16')
-- the answer is @reduced@, otherwise it is @reduced - q@.
coeffFreeze :: Word16 -> Word16
coeffFreeze x = lowered `xor` ((reduced `xor` lowered) .&. signMask)
  where
    modulus :: Word16
    modulus = fromIntegral Internals.q
    reduced = x `mod` modulus
    -- Unsigned subtraction: wraps around whenever reduced < modulus.
    lowered = reduced - modulus
    -- Arithmetic shift of the signed view: all ones if negative, else zero.
    signMask :: Word16
    signMask = fromIntegral (shiftR (fromIntegral lowered :: Int16) 15)
-- | Absolute distance of the frozen coefficient from @q `div` 2@.
--
-- Computed without branching: @signMask@ is the arithmetic right shift of the
-- centred value by 15 bits, and @(v + signMask) `xor` signMask@ yields the
-- absolute value when @signMask@ is 0 or -1.
flipabs :: Word16 -> Word16
flipabs x = fromIntegral ((centred + signMask) `xor` signMask)
  where
    modulus :: Int
    modulus = fromIntegral Internals.q
    frozen :: Int
    frozen = fromIntegral (coeffFreeze x)
    centred = frozen - modulus `div` 2
    signMask = centred `shiftR` 15
-- | Deserialize a polynomial from its packed byte form.
--
-- The input must be exactly @polyBytes N512@ or @polyBytes N1024@ bytes long;
-- any other length raises an 'error'. Each group of 7 input bytes is unpacked
-- into 4 coefficients of 14 bits each (least-significant bits first).
fromByteString :: BS.ByteString -> Poly
fromByteString a
  | len `elem` [polyBytes N512, polyBytes N1024] = Poly (VU.fromList (concatMap unpackGroup groups))
  | otherwise = error $ "Invalid number (" ++ show len ++ ") of serialized bytes for Poly."
  where
    len = BS.length a

    -- Input bytes widened to Word16 so the shifts below cannot lose bits.
    widened :: [Word16]
    widened = fromIntegral <$> BS.unpack a

    groups :: [VU.Vector Word16]
    groups = Prelude.take (len `div` 4) (VU.fromList <$> chunk 7 widened)

    unpackGroup :: VU.Vector Word16 -> [Word16]
    unpackGroup g = [c0, c1, c2, c3]
      where
        byte = (g VU.!)
        c0 = byte 0 .|. shiftL (byte 1 .&. 0x3f) 8
        c1 = shiftR (byte 1) 6 .|. shiftL (byte 2) 2 .|. shiftL (byte 3 .&. 0x0f) 10
        c2 = shiftR (byte 3) 4 .|. shiftL (byte 4) 4 .|. shiftL (byte 5 .&. 0x03) 12
        c3 = shiftR (byte 5) 2 .|. shiftL (byte 6) 6
-- | Serialize a polynomial to its packed byte form.
--
-- Coefficients are frozen with 'coeffFreeze' and then packed four at a time
-- into 7 bytes (14 bits per coefficient, least-significant bits first). Each
-- output byte keeps only the low 8 bits of the expression that produces it.
--
-- The output is assembled with a single 'BS.pack' rather than by repeatedly
-- prepending 7-byte strings with 'BS.append', which re-copied the accumulated
-- tail on every step.
toByteString :: Poly -> BS.ByteString
toByteString (Poly v) = BS.pack (concatMap encodeGroup (chunk 4 v))
  where
    encodeGroup :: VU.Vector Word16 -> [Word8]
    encodeGroup g = fromIntegral <$> [i0, i1, i2, i3, i4, i5, i6]
      where
        t0 = coeffFreeze $ g VU.! 0
        t1 = coeffFreeze $ g VU.! 1
        t2 = coeffFreeze $ g VU.! 2
        t3 = coeffFreeze $ g VU.! 3
        i0 = t0 .&. 0xff
        i1 = shiftR t0 8 .|. shiftL t1 6
        i2 = shiftR t1 2
        i3 = shiftR t1 10 .|. shiftL t2 4
        i4 = shiftR t2 4
        i5 = shiftR t2 12 .|. shiftL t3 2
        i6 = shiftR t3 6
-- | Compress a polynomial to 3 bits per coefficient.
--
-- Every coefficient is frozen, scaled by 8, rounded to the nearest multiple
-- of @q@ and reduced to its low 3 bits; eight such values are then packed
-- into 3 output bytes.
compress :: Poly -> BS.ByteString
compress (Poly pData) = BS.pack (join (packGroup <$> groups))
  where
    q = Internals.q

    groups = chunk 8 (VU.map quantise pData)

    quantise :: Word16 -> Word32
    quantise c = fromIntegral (rounded .&. 0x07)
      where
        frozen = fromIntegral (coeffFreeze c)
        rounded = (shiftL frozen 3 + q `div` 2) `div` q

    packGroup :: VU.Vector Word32 -> [Word8]
    packGroup g =
      [ fromIntegral $ t 0 .|. shiftL (t 1) 3 .|. shiftL (t 2) 6
      , fromIntegral $ shiftR (t 2) 2 .|. shiftL (t 3) 1 .|. shiftL (t 4) 4 .|. shiftL (t 5) 7
      , fromIntegral $ shiftR (t 5) 1 .|. shiftL (t 6) 2 .|. shiftL (t 7) 5
      ]
      where
        t = (g VU.!)
-- | Expand a compressed polynomial (3 bits per coefficient).
--
-- Every 3 input bytes yield eight 3-bit values; each value @v@ is then mapped
-- to @(v * q + 4) `shiftR` 3@, computed in 'Word32'.
decompress :: BS.ByteString -> Poly
decompress input = Poly (VU.fromList (rescale <$> concatMap expand triples))
  where
    triples = chunk 3 (VU.fromList (fromIntegral <$> BS.unpack input))

    rescale :: Word16 -> Word16
    rescale v = fromIntegral (shiftR (wide * fromIntegral Internals.q + 4) 3)
      where
        wide = fromIntegral v :: Word32

    expand :: VU.Vector Word16 -> [Word16]
    expand g =
      [ a0 .&. 7
      , shiftR a0 3 .&. 7
      , shiftR a0 6 .|. (shiftL a1 2 .&. 4)
      , shiftR a1 1 .&. 7
      , shiftR a1 4 .&. 7
      , shiftR a1 7 .|. (shiftL a2 1 .&. 6)
      , shiftR a2 2 .&. 7
      , shiftR a2 5
      ]
      where
        a0 = g VU.! 0
        a1 = g VU.! 1
        a2 = g VU.! 2
-- | Encode a message as a polynomial.
--
-- Each of the first 256 message bits (32 bytes, least-significant bit of each
-- byte first) becomes one coefficient: @q `div` 2@ for a 1 bit and 0 for a
-- 0 bit. The resulting 256 coefficients are repeated 2 times for 'N512' and
-- 4 times for 'N1024'. Any other parameter value raises an 'error'.
--
-- The message must hold at least 32 bytes; a shorter message fails on vector
-- indexing. Bytes beyond the 32nd are ignored.
--
-- The 256-entry vector is built in one pass with 'VU.generate' instead of 256
-- individual 'VU.modify' writes, each of which may copy the whole vector.
fromMsg :: N -> BS.ByteString -> Poly
fromMsg n msg = Poly spread
  where
    msg' = VU.fromList $ BS.unpack msg

    halfQ :: Word16
    halfQ = fromIntegral Internals.q `div` 2

    base :: VU.Vector Word16
    base = VU.generate 256 coefficient

    -- Coefficient k is driven by bit (k mod 8) of byte (k div 8).
    coefficient :: Int -> Word16
    coefficient k = mask .&. halfQ
      where
        byte = fromIntegral (msg' VU.! (k `shiftR` 3)) :: Word16
        -- 0xFFFF when the bit is set, 0 otherwise.
        mask = negate ((byte `shiftR` (k .&. 7)) .&. 1)

    spread
      | n == N512 = VU.concat (replicate 2 base)
      | n == N1024 = VU.concat (replicate 4 base)
      | otherwise = error "Invalid N"
-- | Decode a polynomial back into a 32-byte message.
--
-- For each bit position @i@ in @[0..255]@ the 'flipabs' values of the
-- coefficients at @i@, @i + 256@ (and @i + 512@, @i + 768@ for 'N1024') are
-- summed, a threshold (@q@ for 'N1024', otherwise @q `div` 2@) is subtracted
-- in 'Word16' arithmetic, and the top bit of the result becomes the message
-- bit. Bits are packed eight to a byte, least-significant bit first.
toMsg :: Poly -> BS.ByteString
toMsg p@(Poly x) = BS.pack (packByte <$> chunk 8 (bitAt <$> [0 .. 255]))
  where
    n = getN p

    strides :: [Int]
    strides
      | n == N512 = [0, 256]
      | n == N1024 = [0, 256, 512, 768]
      | otherwise = error "Invalid vector size"

    threshold :: Word16
    threshold = fromIntegral (if n == N1024 then Internals.q else Internals.q `div` 2)

    -- Message bit i, already shifted to its position within its byte.
    bitAt :: Int -> Word8
    bitAt i = fromIntegral (shiftL (shiftR (total - threshold) 15) (i .&. 7))
      where
        total = sum [flipabs (x VU.! (off + i)) | off <- strides]

    packByte bits = foldr (.|.) 0 bits
-- | Expand a seed into a polynomial by rejection sampling.
--
-- Coefficients are produced in blocks of 64. For block @i@ the seed bytes
-- followed by the byte @i@ are absorbed with 'shake128Absorb'; the squeezed
-- output is read as little-endian 16-bit words, and words that are not below
-- @5 * q@ are discarded. The first 64 accepted words form the block.
--
-- Unlike the previous version, which squeezed exactly one block and silently
-- left any unfilled coefficients at zero, this keeps squeezing further blocks
-- until 64 candidates have been accepted.
--
-- NOTE(review): this assumes the second component returned by
-- 'shake128SqueezeBlocks' is the updated sponge state and can be passed back
-- to 'shake128SqueezeBlocks' — confirm against Crypto.NewHope.FIPS202.
-- The dimension is assumed to be a multiple of 64.
uniform :: N -> Internals.Seed -> Poly
uniform n seed = Poly $ VU.concat (block <$> [0 .. size `div` 64 - 1])
  where
    Internals.Seed seed' = seed
    size = Internals.value n

    -- Candidates greater than or equal to this bound are rejected.
    bound :: Word16
    bound = 5 * fromIntegral Internals.q

    -- The 64 coefficients of block i.
    block :: Int -> VU.Vector Word16
    block i = VU.fromList (Prelude.take 64 (filter (< bound) (candidates state0)))
      where
        extseed = BS.snoc seed' (fromIntegral i)
        state0 = shake128Absorb extseed

    -- Lazy, unbounded stream of 16-bit candidates; a new block is squeezed
    -- only when the consumer runs past the end of the previous one.
    candidates st = words16 buf ++ candidates st'
      where
        (buf, st') = shake128SqueezeBlocks st 1

    -- Little-endian 16-bit words of one squeezed block.
    words16 :: BS.ByteString -> [Word16]
    words16 buf =
      [ lo .|. shiftL hi 8
      | j <- [0, 2 .. shake128Rate - 2]
      , let lo = fromIntegral (BS.index buf j)
            hi = fromIntegral (BS.index buf (j + 1))
      ]
-- | Hamming weight: the number of set bits in a byte.
hw :: Word8 -> Word8
hw a = foldr addBit 0 [0 .. 7]
  where
    -- Add bit number i of the input to the running count.
    addBit i acc = acc + ((a `shiftR` i) .&. 1)
-- | Sample a polynomial from a seed and a nonce.
--
-- Coefficients are produced in blocks of 64. For block @i@, 128 bytes are
-- derived with 'shake256' from the seed bytes followed by the nonce and the
-- byte @i@. Coefficient @j@ of the block is
-- @hw (byte (2*j)) + q - hw (byte (2*j+1))@, computed in 'Word16'.
--
-- Each block is built in one pass with 'VU.generate' and the blocks are
-- joined with 'VU.concat', replacing one 'VU.modify' write per coefficient
-- (each of which may copy the whole vector).
-- The dimension is assumed to be a multiple of 64.
sample :: N -> Internals.Seed -> Word8 -> Poly
sample n seed nonce = Poly $ VU.concat (block <$> [0 .. size `div` 64 - 1])
  where
    size = Internals.value n
    Internals.Seed seedData = seed
    seed' = BS.snoc seedData nonce

    q :: Word16
    q = fromIntegral Internals.q

    block :: Int -> VU.Vector Word16
    block i = VU.generate 64 coefficient
      where
        extseed = BS.snoc seed' $ fromIntegral i
        buf = shake256 extseed 128

        weight :: Int -> Word16
        weight k = fromIntegral (hw (BS.index buf k))

        coefficient j = weight (2 * j) + q - weight (2 * j + 1)
-- | Coefficient-wise product of two polynomials.
--
-- For coefficients @x@ (first argument) and @y@ (second argument) the result
-- is @montgomeryReduce (x * montgomeryReduce (3186 * y))@.
-- NOTE(review): 3186 is presumably a Montgomery-domain conversion constant —
-- confirm against Crypto.NewHope.Reduce.
mulPointwise :: Poly -> Poly -> Poly
mulPointwise (Poly a) (Poly b) = Poly (VU.zipWith mulCoeff a b)
  where
    mulCoeff x y = montgomeryReduce (fromIntegral x * fromIntegral scaled)
      where
        scaled = montgomeryReduce (3186 * fromIntegral y)
-- | Coefficient-wise sum of two polynomials, reduced modulo @q@.
--
-- The sum is formed in 'Word32' so that it cannot wrap: adding two 'Word16'
-- values directly overflows modulo 65536 whenever they total 65536 or more,
-- which would make the subsequent reduction modulo @q@ incorrect.
add :: Poly -> Poly -> Poly
add (Poly a) (Poly b) = Poly $ VU.zipWith go a b
  where
    q :: Word32
    q = fromIntegral Internals.q

    go :: Word16 -> Word16 -> Word16
    go c d = fromIntegral ((fromIntegral c + fromIntegral d) `mod` q)
-- | Coefficient-wise difference of two polynomials, reduced modulo @q@.
--
-- Computes @(c + 3q - d) mod q@ in 'Int' so the intermediate value can
-- neither wrap past 65535 nor underflow below zero, both of which corrupted
-- the result when the expression was evaluated directly in 'Word16'.
-- Results are unchanged wherever the intermediate value already fit in
-- 'Word16'.
sub :: Poly -> Poly -> Poly
sub (Poly a) (Poly b) = Poly $ VU.zipWith go a b
  where
    q :: Int
    q = fromIntegral Internals.q

    go :: Word16 -> Word16 -> Word16
    go c d = fromIntegral ((fromIntegral c + 3 * q - fromIntegral d) `mod` q)
-- | Transform a polynomial with the tables selected by its dimension.
--
-- The coefficients are first multiplied by the @ψBitrevMontgomery@ table via
-- 'NTT.mulCoefficients', then passed to 'NTT.ntt' with the
-- @ωBitrevMontgomery@ table. Fails (through 'getN') if the polynomial does
-- not have 512 or 1024 coefficients.
ntt :: Poly -> Poly
ntt p@(Poly r) = Poly (NTT.ntt scaled (ωBitrevMontgomery dim))
  where
    dim = getN p
    scaled = NTT.mulCoefficients r (ψBitrevMontgomery dim)
-- | Counterpart of 'ntt', using the inverse tables.
--
-- The coefficients are reordered with 'NTT.bitrev', passed to 'NTT.ntt' with
-- the @ωInvBitrevMontgomery@ table, and finally multiplied by the
-- @ψInvMontgomery@ table via 'NTT.mulCoefficients'. Fails (through 'getN')
-- if the polynomial does not have 512 or 1024 coefficients.
invntt :: Poly -> Poly
invntt p@(Poly r) = Poly (NTT.mulCoefficients transformed (ψInvMontgomery dim))
  where
    dim = getN p
    transformed = NTT.ntt (NTT.bitrev r) (ωInvBitrevMontgomery dim)