module Cryptography.WringTwistree.KeySchedule
  ( extendKey
  , mul65537
  , keySchedule
  , reschedule
  ) where

{- This module is used in both Wring and Twistree.
 - It is part of the keying algorithm, which turns a byte string
 - into three s-boxes. The keySchedule function takes a ByteString and
 - returns a sequence of 96 Word16.
 -
 - To convert a String to a ByteString, put "- utf8-string" in your
 - package.yaml dependencies, import Data.ByteString.UTF8, and use
 - fromString.
 -}

import Data.Bits
import Data.Word
import Data.Foldable (toList,foldl')
import qualified Data.Sequence as Seq
import Data.Sequence ((><), (<|), (|>), Seq((:<|)), Seq((:|>)), update)
import qualified Data.ByteString as B

-- | Pseudorandom sequence used to fill the initial subkey (the first 96
-- terms seed 'keySchedule'). It was the PRNG in an Apple implementation of
-- Forth: starting from 1, each term is the previous one multiplied by 13
-- (wrapping mod 65536) with its two bytes swapped (a rotate by 8 on a
-- 16-bit word). Its cycle length is 64697.
swap13mult :: [Word16]
swap13mult = iterate ((`rotate` 8) . (* 13)) 1

-- | @extendKey_ str i n@ concatenates passes @i@, @i+1@, ..., @n-1@ over the
-- key bytes @str@. In pass @j@ every byte is widened to 'Word16' and has
-- @256 * j@ added to it, so later repetitions of the key differ from earlier
-- ones. Returns @[]@ when @i >= n@.
extendKey_ :: [Word8] -> Int -> Int -> [Word16]
extendKey_ str i n = concatMap pass (takeWhile (< n) [i ..])
  where
    -- 256 * j is computed in Int and then truncated to Word16, so the
    -- offset wraps modulo 65536 (pass 256 adds the same offset as pass 0).
    pass j = map ((+ fromIntegral (256 * j)) . fromIntegral) str

-- | Extends the key, if it isn't empty, to be at least as long as 384 words.
--
-- The key is repeated @ceiling (384 / length)@ times via 'extendKey_', so
-- the result has exactly that many passes times the key length words. An
-- empty key yields @[]@.
extendKey :: B.ByteString -> [Word16]
extendKey str = extendKey_ (B.unpack str) 0 passes
  where
    len = B.length str
    -- Ceiling division written as -((-384) `div` len): `div` rounds toward
    -- negative infinity, so negating twice rounds 384 / len up, with no risk
    -- of overflow.
    passes
      | len > 0 = negate ((-384) `div` len)
      | otherwise = 0

-- | Multiplies two nonzero numbers mod 65537. Exported for testing.
--
-- Each argument @x@ stands for the residue @x + 1@, so the whole 'Word16'
-- range 0..65535 encodes the nonzero residues 1..65536. Because 65537 is
-- prime, the product of two nonzero residues is nonzero, and subtracting 1
-- maps it back into 0..65535.
mul65537 :: Word16 -> Word16 -> Word16
mul65537 a b = fromIntegral (prod `mod` 65537 - 1)
  where
    -- Word64, not Word32: 65536 * 65536 = 2^32 overflows a Word32 and gives
    -- the wrong answer for mul65537 65535 65535.
    widen x = fromIntegral x + 1 :: Word64
    prod = widen a * widen b

-- | One key-schedule step: replaces the element of @subkey@ at position
-- @inx@ with
--
-- > rotate (mul65537 (subkey!inx) keyWord
-- >           + (subkey!(inx+59) `xor` subkey!(inx+36) `xor` subkey!(inx+62))) 8
--
-- where the three neighbour positions are taken mod 96. The subkey is
-- therefore expected to be 96 long, with @inx@ in [0, 95].
alter :: Seq.Seq Word16 -> (Word16, Int) -> Seq.Seq Word16
alter subkey (keyWord, inx) = update inx newval subkey
  where
    -- Element at offset k from inx, wrapping around the 96-word subkey.
    neighbour k = Seq.index subkey ((inx + k) `mod` 96)
    mixed = neighbour 59 `xor` neighbour 36 `xor` neighbour 62
    -- The current element is read at inx directly (not wrapped), as before.
    newval = (mul65537 (Seq.index subkey inx) keyWord + mixed) `rotate` 8

-- | Turns a key into a 96-word subkey.
--
-- The subkey starts as the first 96 terms of 'swap13mult'. Each word of the
-- extended key (see 'extendKey') is then folded in with 'alter' at positions
-- 0, 1, ..., 95, 0, 1, ... in turn. An empty key leaves the initial subkey
-- unchanged.
keySchedule :: B.ByteString -> Seq.Seq Word16
keySchedule key = foldl' alter initial (zip (extendKey key) (cycle [0 .. 95]))
  where
    initial = Seq.fromList (take 96 swap13mult)

-- | Stirs an existing subkey once more by applying 'alter' at every position
-- 0..95 in order, each time with the constant key word 40504.
reschedule :: Seq.Seq Word16 -> Seq.Seq Word16
reschedule subkey = foldl' alter subkey [(40504, inx) | inx <- [0 .. 95]]
-- 40505 is a primitive root near 65537/φ. mul65537 treats a Word16 x as the
-- residue x + 1, so 40505 is passed as 40504.