{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE Trustworthy #-}
module Crypto.NewHope.FIPS202 where
import Data.Bits
import qualified Data.ByteString as BS
import qualified Data.Map as Map
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import Data.Word
import Prelude hiding (round)
import StringUtils
-- | A Keccak-f[1600] state: one 'Word64' per lane.
type KeccakVector = VU.Vector Word64

-- | Number of 64-bit lanes in the Keccak-f[1600] state (a 5 x 5 grid).
keccakStateLength :: Int
keccakStateLength = 5 * 5

-- | SHAKE128 rate in bytes: the 1600-bit state minus a 256-bit capacity.
shake128Rate :: Int
shake128Rate = (1600 - 2 * 128) `div` 8

-- | SHAKE256 rate in bytes: the 1600-bit state minus a 512-bit capacity.
shake256Rate :: Int
shake256Rate = (1600 - 2 * 256) `div` 8

-- | Number of Keccak-f[1600] rounds (12 + 2 * l, with l = 6 for 64-bit lanes).
keccakF1600StatePermuteRoundsCount :: Int
keccakF1600StatePermuteRoundsCount = 12 + 2 * 6

-- | The all-zero state that absorption starts from.
keccakEmpty :: KeccakVector
keccakEmpty = VU.generate keccakStateLength (const 0)
-- | Keccak-f[1600] iota round constants, one per round (24 entries), indexed
-- by round number. 'keccakF1600StatePermute' XORs entry @round@ into lane 0.
keccakFRoundConstants :: KeccakVector
keccakFRoundConstants = VU.fromList [ 0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000,
  0x000000000000808b, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009,
  0x000000000000008a, 0x0000000000000088, 0x0000000080008009, 0x000000008000000a,
  0x000000008000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
  0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a,
  0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008]
-- | Rotate a 64-bit lane left by @offset@ bits.
--
-- Uses the standard 'rotateL' instead of a hand-rolled pair of shifts.
-- @offset@ must be non-negative, as 'rotateL' requires; the permutation only
-- uses offsets in the range 1-62.
rol :: Word64 -> Int -> Word64
rol a offset = rotateL a offset
-- | Decode the 8 bytes of @input@ starting at @offset@ as a little-endian
-- 'Word64' (byte @k@ of the window lands in bits @8k .. 8k+7@).
--
-- Calls 'error', with a hex dump of whatever bytes were available, when fewer
-- than 8 bytes remain at @offset@.
load64 :: BS.ByteString
       -> Int
       -> Word64
load64 input offset
  | BS.length window >= 8 = BS.foldr step 0 window
  | otherwise = error ("TOO SHORT: \"" ++ (byteStringToHexString window :: String) ++ "\"")
  where
    window = BS.take 8 (BS.drop offset input)
    -- Folding from the right puts the first byte in the lowest bits.
    step byte acc = shiftL acc 8 .|. fromIntegral byte
-- | Little-endian decode of up to 8 bytes of @input@ starting at @offset@.
--
-- The window is @VU.take 8 (VU.drop offset input)@, and byte @k@ of it lands
-- in bits @8k .. 8k+7@. A window shorter than 8 bytes is zero-extended. An
-- empty window (offset at or past the end) yields 0 rather than failing on a
-- partial pattern match.
load64' :: VU.Vector Word8
        -> Int
        -> Word64
load64' input offset = VU.foldr step 0 window
  where
    window = VU.take 8 (VU.drop offset input)
    -- Folding from the right puts the first byte in the lowest bits.
    step byte acc = shiftL acc 8 .|. fromIntegral byte
-- | Encode a 'Word64' as 8 little-endian bytes (least significant byte first).
--
-- Each byte is isolated by shifting right first and then masking with 0xFF,
-- so the mask applies to the shifted value and not to the shift amount.
store64 :: Word64 -> BS.ByteString
store64 value = BS.pack [ fromIntegral (shiftR value (8 * i) .&. 0xFF) | i <- [0 .. 7] ]
-- | Apply the Keccak-f[1600] permutation to a state of 25 lanes.
--
-- Lanes @state VU.! 0 .. 24@ are named @aba .. asu@. The second letter
-- (b, g, k, m, s) selects a group of five consecutive indices (0-4, 5-9, ...).
-- The third letter (a, e, i, o, u) selects the position within that group.
--
-- Rounds are unrolled two at a time. Each 'runRounds' step applies round
-- @round@ (lanes @a*@ -> @e*@) and round @round + 1@ (lanes @e*@ -> @a*@).
-- It recurses with @round + 2@ until the pair starting at
-- @keccakF1600StatePermuteRoundsCount - 2@ has been applied.
--
-- Each round computes, in order:
--
--   * theta: column parities @bc*@ and correction terms @d*@;
--   * rho/pi: a fixed rotation of each corrected lane, gathered into rows;
--   * chi: @x `xor` (complement y .&. z)@ within each row;
--   * iota: XOR of the round constant into lane 0.
--
-- NOTE(review): the input must have at least 25 lanes, because 'VU.!' errors
-- otherwise. Any lanes beyond index 24 are not carried into the result.
keccakF1600StatePermute :: KeccakVector -> KeccakVector
keccakF1600StatePermute state' = runRounds state' 0
  where
    -- Apply rounds @round@ and @round + 1@; stop after the final pair.
    runRounds state round = if round >= keccakF1600StatePermuteRoundsCount - 2
      then result
      else runRounds result (round + 2)
      where
        -- Unpack the 25 input lanes.
        aba = state VU.! 0
        abe = state VU.! 1
        abi = state VU.! 2
        abo = state VU.! 3
        abu = state VU.! 4
        aga = state VU.! 5
        age = state VU.! 6
        agi = state VU.! 7
        ago = state VU.! 8
        agu = state VU.! 9
        aka = state VU.! 10
        ake = state VU.! 11
        aki = state VU.! 12
        ako = state VU.! 13
        aku = state VU.! 14
        ama = state VU.! 15
        ame = state VU.! 16
        ami = state VU.! 17
        amo = state VU.! 18
        amu = state VU.! 19
        asa = state VU.! 20
        ase = state VU.! 21
        asi = state VU.! 22
        aso = state VU.! 23
        asu = state VU.! 24
        -- Round @round@, theta: XOR of the five lanes sharing each position.
        bca = aba `xor` aga `xor` aka `xor` ama `xor` asa
        bce = abe `xor` age `xor` ake `xor` ame `xor` ase
        bci = abi `xor` agi `xor` aki `xor` ami `xor` asi
        bco = abo `xor` ago `xor` ako `xor` amo `xor` aso
        bcu = abu `xor` agu `xor` aku `xor` amu `xor` asu
        -- Theta correction terms (named do1 because @do@ is a keyword).
        da = bcu `xor` rol bce 1
        de = bca `xor` rol bci 1
        di = bce `xor` rol bco 1
        do1 = bci `xor` rol bcu 1
        du = bco `xor` rol bca 1
        -- Output row eb*: correct, rotate and gather five lanes, then chi.
        -- Iota XORs round constant @round@ into eba.
        aba2 = aba `xor` da
        bca2 = aba2
        age2 = age `xor` de
        bce2 = rol age2 44
        aki2 = aki `xor` di
        bci2 = rol aki2 43
        amo2 = amo `xor` do1
        bco2 = rol amo2 21
        asu2 = asu `xor` du
        bcu2 = rol asu2 14
        eba4 = bca2 `xor` complement bce2 .&. bci2
        eba5 = eba4 `xor` (keccakFRoundConstants VU.! round)
        ebe = bce2 `xor` complement bci2 .&. bco2
        ebi = bci2 `xor` complement bco2 .&. bcu2
        ebo = bco2 `xor` complement bcu2 .&. bca2
        ebu = bcu2 `xor` complement bca2 .&. bce2
        -- Output row eg*.
        abo2 = abo `xor` do1
        bca3 = rol abo2 28
        agu2 = agu `xor` du
        bce3 = rol agu2 20
        aka2 = aka `xor` da
        bci3 = rol aka2 3
        ame2 = ame `xor` de
        bco3 = rol ame2 45
        asi2 = asi `xor` di
        bcu3 = rol asi2 61
        ega = bca3 `xor` complement bce3 .&. bci3
        ege = bce3 `xor` complement bci3 .&. bco3
        egi = bci3 `xor` complement bco3 .&. bcu3
        ego = bco3 `xor` complement bcu3 .&. bca3
        egu = bcu3 `xor` complement bca3 .&. bce3
        -- Output row ek*.
        abe2 = abe `xor` de
        bca4 = rol abe2 1
        agi2 = agi `xor` di
        bce4 = rol agi2 6
        ako2 = ako `xor` do1
        bci4 = rol ako2 25
        amu2 = amu `xor` du
        bco4 = rol amu2 8
        asa2 = asa `xor` da
        bcu4 = rol asa2 18
        eka = bca4 `xor` complement bce4 .&. bci4
        eke = bce4 `xor` complement bci4 .&. bco4
        eki = bci4 `xor` complement bco4 .&. bcu4
        eko = bco4 `xor` complement bcu4 .&. bca4
        eku = bcu4 `xor` complement bca4 .&. bce4
        -- Output row em*.
        abu2 = abu `xor` du
        bca5 = rol abu2 27
        aga2 = aga `xor` da
        bce5 = rol aga2 36
        ake2 = ake `xor` de
        bci5 = rol ake2 10
        ami2 = ami `xor` di
        bco5 = rol ami2 15
        aso2 = aso `xor` do1
        bcu5 = rol aso2 56
        ema = bca5 `xor` complement bce5 .&. bci5
        eme = bce5 `xor` complement bci5 .&. bco5
        emi = bci5 `xor` complement bco5 .&. bcu5
        emo = bco5 `xor` complement bcu5 .&. bca5
        emu = bcu5 `xor` complement bca5 .&. bce5
        -- Output row es*.
        abi2 = abi `xor` di
        bca6 = rol abi2 62
        ago2 = ago `xor` do1
        bce6 = rol ago2 55
        aku2 = aku `xor` du
        bci6 = rol aku2 39
        ama2 = ama `xor` da
        bco6 = rol ama2 41
        ase2 = ase `xor` de
        bcu6 = rol ase2 2
        esa = bca6 `xor` complement bce6 .&. bci6
        ese = bce6 `xor` complement bci6 .&. bco6
        esi = bci6 `xor` complement bco6 .&. bcu6
        eso = bco6 `xor` complement bcu6 .&. bca6
        esu = bcu6 `xor` complement bca6 .&. bce6
        -- Round @round + 1@, theta on the e* lanes.
        bca7 = eba5 `xor` ega `xor` eka `xor` ema `xor` esa
        bce7 = ebe `xor` ege `xor` eke `xor` eme `xor` ese
        bci7 = ebi `xor` egi `xor` eki `xor` emi `xor` esi
        bco7 = ebo `xor` ego `xor` eko `xor` emo `xor` eso
        bcu7 = ebu `xor` egu `xor` eku `xor` emu `xor` esu
        da2 = bcu7 `xor` rol bce7 1
        de2 = bca7 `xor` rol bci7 1
        di2 = bce7 `xor` rol bco7 1
        do2 = bci7 `xor` rol bcu7 1
        du2 = bco7 `xor` rol bca7 1
        -- Output row ab*; iota XORs round constant @round + 1@ into aba.
        eba6 = eba5 `xor` da2
        bca8 = eba6
        ege2 = ege `xor` de2
        bce8 = rol ege2 44
        eki2 = eki `xor` di2
        bci8 = rol eki2 43
        emo2 = emo `xor` do2
        bco8 = rol emo2 21
        esu2 = esu `xor` du2
        bcu8 = rol esu2 14
        aba3 = bca8 `xor` complement bce8 .&. bci8
        aba4 = aba3 `xor` (keccakFRoundConstants VU.! (round + 1))
        abe3 = bce8 `xor` complement bci8 .&. bco8
        abi3 = bci8 `xor` complement bco8 .&. bcu8
        abo3 = bco8 `xor` complement bcu8 .&. bca8
        abu3 = bcu8 `xor` complement bca8 .&. bce8
        -- Output row ag*.
        ebo2 = ebo `xor` do2
        bca9 = rol ebo2 28
        egu2 = egu `xor` du2
        bce9 = rol egu2 20
        eka2 = eka `xor` da2
        bci9 = rol eka2 3
        eme2 = eme `xor` de2
        bco9 = rol eme2 45
        esi2 = esi `xor` di2
        bcu9 = rol esi2 61
        aga3 = bca9 `xor` complement bce9 .&. bci9
        age3 = bce9 `xor` complement bci9 .&. bco9
        agi3 = bci9 `xor` complement bco9 .&. bcu9
        ago3 = bco9 `xor` complement bcu9 .&. bca9
        agu3 = bcu9 `xor` complement bca9 .&. bce9
        -- Output row ak*.
        ebe2 = ebe `xor` de2
        bcaA = rol ebe2 1
        egi2 = egi `xor` di2
        bceA = rol egi2 6
        eko2 = eko `xor` do2
        bciA = rol eko2 25
        emu2 = emu `xor` du2
        bcoA = rol emu2 8
        esa2 = esa `xor` da2
        bcuA = rol esa2 18
        aka3 = bcaA `xor` complement bceA .&. bciA
        ake3 = bceA `xor` complement bciA .&. bcoA
        aki3 = bciA `xor` complement bcoA .&. bcuA
        ako3 = bcoA `xor` complement bcuA .&. bcaA
        aku3 = bcuA `xor` complement bcaA .&. bceA
        -- Output row am*.
        ebu2 = ebu `xor` du2
        bcaB = rol ebu2 27
        ega2 = ega `xor` da2
        bceB = rol ega2 36
        eke2 = eke `xor` de2
        bciB = rol eke2 10
        emi2 = emi `xor` di2
        bcoB = rol emi2 15
        eso2 = eso `xor` do2
        bcuB = rol eso2 56
        ama3 = bcaB `xor` complement bceB .&. bciB
        ame3 = bceB `xor` complement bciB .&. bcoB
        ami3 = bciB `xor` complement bcoB .&. bcuB
        amo3 = bcoB `xor` complement bcuB .&. bcaB
        amu3 = bcuB `xor` complement bcaB .&. bceB
        -- Output row as*.
        ebi2 = ebi `xor` di2
        bcaC = rol ebi2 62
        ego2 = ego `xor` do2
        bceC = rol ego2 55
        eku2 = eku `xor` du2
        bciC = rol eku2 39
        ema2 = ema `xor` da2
        bcoC = rol ema2 41
        ese2 = ese `xor` de2
        bcuC = rol ese2 2
        asa3 = bcaC `xor` complement bceC .&. bciC
        ase3 = bceC `xor` complement bciC .&. bcoC
        asi3 = bciC `xor` complement bcoC .&. bcuC
        aso3 = bcoC `xor` complement bcuC .&. bcaC
        asu3 = bcuC `xor` complement bcaC .&. bceC
        -- Repack the 25 lanes in the same order they were unpacked.
        result = VU.fromList [ aba4, abe3, abi3, abo3, abu3, aga3, age3, agi3,
          ago3, agu3, aka3, ake3, aki3, ako3, aku3, ama3,
          ame3, ami3, amo3, amu3, asa3, ase3, asi3, aso3, asu3]
-- | Split @input@ into full @rate@-byte blocks and decode each one as a Keccak
-- state vector, keyed by the block's starting byte offset.
--
-- Each block contributes @rate / 8@ little-endian lanes via 'load64'. The
-- remaining lanes, up to 'keccakStateLength', are zero. A trailing partial
-- block is not included.
inputLoadVectors :: Int -> BS.ByteString -> Map.Map Int KeccakVector
inputLoadVectors rate input = Map.fromList [ (offset, blockAt offset) | offset <- offsets ]
  where
    offsets = [0, rate .. BS.length input - rate]
    laneCount = rate `div` 8
    blockAt offset =
      VU.fromList . take keccakStateLength $
        [ load64 input (offset + 8 * lane) | lane <- [0 .. laneCount - 1] ] ++ repeat 0
-- | Absorb every full @rate@-byte block of @input@ into @start@.
--
-- The blocks come from 'inputLoadVectors' in ascending offset order. Each
-- block's lanes are XORed into the state, which is then permuted with
-- 'keccakF1600StatePermute'. Bytes after the last full block are not
-- absorbed here.
--
-- The state is forced after every block. This stops long inputs from building
-- a chain of unevaluated permutations proportional to the input length.
unifiedInputLoadVectors :: KeccakVector -> Int -> BS.ByteString -> KeccakVector
unifiedInputLoadVectors start rate input = absorbBlocks start blocks
  where
    blocks = Map.elems $ inputLoadVectors rate input
    absorbBlocks acc [] = acc
    absorbBlocks acc (block : rest) =
      let acc' = keccakF1600StatePermute (VU.zipWith xor acc block)
       in acc' `seq` absorbBlocks acc' rest
-- | Absorb an entire message into a fresh Keccak state, with the given @rate@
-- (in bytes) and domain-separation byte @domainSep@.
--
-- Full @rate@-byte blocks are absorbed and permuted by
-- 'unifiedInputLoadVectors'. The trailing partial block is copied into a
-- zeroed @rate@-byte buffer and followed by @domainSep@. The buffer's last byte
-- then has its top bit ORed in, so the final padding bit is always set, even
-- when @domainSep@ itself lands in that byte.
--
-- The padded block is XORed into the first @rate / 8@ lanes, and the remaining
-- lanes are left unchanged. No permutation is applied after this step.
keccakAbsorb :: Int -> BS.ByteString -> Word8 -> KeccakVector
keccakAbsorb rate input domainSep = VU.imap absorbLane s
  where
    s = unifiedInputLoadVectors keccakEmpty rate input
    laneCount = rate `div` 8
    absorbLane i lane
      | i < laneCount = lane `xor` load64' t (8 * i)
      | otherwise = lane
    t :: VU.Vector Word8
    t = VU.modify (\ v -> VUM.write v lastIndex ((basic VU.! lastIndex) .|. 128)) basic
      where
        inputOffset = (inputLength `div` rate) * rate
        input' = BS.drop inputOffset input
        inputLength = BS.length input
        inputLength' = BS.length input'
        lastIndex = rate - 1
        basic = VU.concat [ VU.fromList (BS.unpack input')
                          , VU.replicate 1 domainSep
                          , VU.replicate (rate - inputLength' - 1) 0
                          ]
-- | Absorb @seed@ into a fresh SHAKE128 state ('shake128Rate', domain byte 0x1f).
shake128Absorb :: BS.ByteString -> KeccakVector
shake128Absorb = \seed -> keccakAbsorb shake128Rate seed shakeDomain
  where
    shakeDomain :: Word8
    shakeDomain = 0x1f
-- | Squeeze @blocks@ SHAKE128 output blocks from @state@. Returns the output
-- bytes and the state after the last permutation.
shake128SqueezeBlocks :: KeccakVector
                      -> Int
                      -> (BS.ByteString, KeccakVector)
shake128SqueezeBlocks state blocks = keccakSqueezeblocks state shake128Rate blocks
-- | Squeeze @blocks@ output blocks of @rate@ bytes each from @state@.
--
-- For each block the state is permuted first. The first @rate / 8@ lanes are
-- then emitted as little-endian bytes via 'store64'. Returns the concatenated
-- output and the state after the last permutation. When @blocks <= 0@, the
-- output is empty and the input state is returned unchanged.
--
-- Blocks are collected in a list and joined once with 'BS.concat', which
-- avoids re-copying the output with nested 'BS.append' calls.
keccakSqueezeblocks :: KeccakVector -> Int -> Int -> (BS.ByteString, KeccakVector)
keccakSqueezeblocks state rate blocks = (BS.concat chunks, finalState)
  where
    (chunks, finalState) = go state blocks
    go st n
      | n <= 0 = ([], st)
      | otherwise =
          let st' = keccakF1600StatePermute st
              (rest, final) = go st' (n - 1)
           in (squeeze st' : rest, final)
    squeeze st = BS.concat (fmap store64 (VU.toList (VU.take (rate `div` 8) st)))
-- | SHAKE256: absorb @input@ with domain byte 0x1F and return exactly
-- @outputLength@ bytes of output.
--
-- Whole 'shake256Rate'-byte blocks are squeezed first. If @outputLength@ is not
-- a multiple of the rate, one more block is squeezed and truncated. A
-- non-positive @outputLength@ yields an empty string.
shake256 :: BS.ByteString
         -> Int
         -> BS.ByteString
shake256 input outputLength
  | outputLength <= 0 = BS.empty
  | extraBytes == 0 = output
  | otherwise = BS.append output (BS.take extraBytes extraOutput)
  where
    state = keccakAbsorb shake256Rate input 0x1F
    (nblocks, extraBytes) = outputLength `divMod` shake256Rate
    (output, state') = keccakSqueezeblocks state shake256Rate nblocks
    (extraOutput, _) = keccakSqueezeblocks state' shake256Rate 1