{-# LANGUAGE BangPatterns #-}
module Cryptography.WringTwistree.Mix3
  ( findMaxOrder
  , mix3Parts
  , mix3Parts'
  ) where

{- This module is used in Wring.
 - This module splits a buffer (an array of bytes) into three equal parts, with
 - 0, 1, or 2 bytes left over, and mixes them as follows:
 -
 - The mix function takes three bytes and flips each bit which is not the same
 - in all three bytes. This is a self-inverse, nonlinear operation.
 -
 - The 0th third is traversed forward, the 1st third is traversed backward,
 - and the 2nd third is traversed by steps close to 1/φ the length of a third.
 - Taking a 16-byte buffer as an example:
 - 00 0d 1a 27 34|41 4e 5b 68 75|82 8f 9c a9 b6|c3
 - <>            |            <>|<>            |
 -    <>         |         <>   |         <>   |
 -       <>      |      <>      |   <>         |
 -          <>   |   <>         |            <>|
 -             <>|<>            |      <>      |
 - f7 e8 cf de c9|bc b7 8e 8d 82|75 5a 61 4c 4f|c3
 -}

import Data.Bits
import Data.Word
import Data.Array.Unboxed
import Math.NumberTheory.ArithmeticFunctions
import GHC.Natural
import Math.NumberTheory.Primes
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV
import Control.Monad.ST
import Control.Monad

{-# INLINE mix #-}
-- | Mixes the first byte with two partner bytes: every bit position where the
-- three bytes do not all agree is flipped in the first byte, and positions
-- where all three agree are left unchanged.
mix :: Word8 -> Word8 -> Word8 -> Word8
mix a b c = a `xor` disagree
  where
    -- Bits set in some but not all of the three bytes. The AND bits are a
    -- subset of the OR bits, so XOR-ing them out equals subtracting them.
    disagree = (a .|. b .|. c) `xor` (a .&. b .&. c)

-- | The infinite Fibonacci sequence 0, 1, 1, 2, 3, 5, …, shared as a CAF.
-- Each element after the leading 0 is a running sum seeded with 1.
fibonacci :: [Integer]
fibonacci = 0 : scanl (+) 1 fibonacci

-- | The first two Fibonacci numbers strictly greater than @n@, in ascending
-- order. Always a two-element list, because 'fibonacci' is infinite.
fiboPair :: Integer -> [Integer]
fiboPair n = take 2 larger
  where
    larger = dropWhile (<= n) fibonacci

-- | Starting point and direction for the search in 'findMaxOrder'.
--
-- The first component is @n/φ@ rounded to the nearest integer, approximated
-- by @n * F_k / F_(k+1)@ for the first two Fibonacci numbers above @2n@. The
-- second component is @+1@ or @-1@: the side of the start on which the
-- unrounded value lay, i.e. which neighbour to try first.
--
-- e.g. @searchDir 89 == (55, 1)@, so search 55, 56, 54, 57, 53, …;
-- @searchDir 144 == (89, -1)@, so search 89, 88, 90, 87, 91, …
searchDir :: Integer -> (Integer, Int)
searchDir n =
  case fiboPair (2 * n) of
    [num, den] ->
      let (q, r) = (n * num) `divMod` den
      in if 2 * r < den
           then (q, 1)        -- rounded down: look upward next
           else (q + 1, -1)   -- rounded up: look downward next
    _ -> error "searchDir: fiboPair must return exactly two numbers"

-- | @isMaxOrder modl car fac n@ tests whether @n@ has maximum multiplicative
-- order modulo @modl@.
--
-- * @modl@ is the modulus;
-- * @car@ is its Carmichael function;
-- * @fac@ is the list of distinct prime factors of @car@ (no multiplicities);
-- * @n@ is the candidate.
--
-- The test is @n^car ≡ 1 (mod modl)@ and @n^(car/p) ≢ 1 (mod modl)@ for
-- every @p@ in @fac@. Maximum order implies @n@ is a primitive root whenever
-- the modulus has any. For @modl = 1@ every power reduces to 0, so the
-- result is always 'False'.
--
-- @n@ is reduced modulo @modl@ before conversion to 'Natural'. Negative
-- candidates are therefore tested by their residue instead of failing the
-- conversion, and non-negative candidates give the same results as before.
isMaxOrder :: Integral a => a -> a -> [a] -> a -> Bool
isMaxOrder modl car fac n =
  powModNatural base (toNat car) m == 1 && all (/= 1) subPowers
  where
    toNat :: Integral b => b -> Natural
    toNat = fromIntegral
    m = toNat modl
    base = toNat (n `mod` modl)
    -- n^(car/p) mod modl for each prime p dividing car. 'all' stops at the
    -- first power equal to 1, unlike the previous foldl, which walked the
    -- whole list.
    subPowers = [powModNatural base (toNat (car `div` p)) m | p <- fac]

-- | Offsets for an outward search around a starting point:
-- 0, 1, -1, 2, -2, 3, -3, …
searchSeq :: [Integer]
searchSeq = 0 : concatMap (\k -> [k, negate k]) [1 ..]

-- | Candidates @start, start+dir, start-dir, start+2*dir, start-2*dir, …@.
-- The input is the pair produced by 'searchDir'.
searchFrom :: (Integer, Int) -> [Integer]
searchFrom (start, dir) = [start + step * offset | offset <- searchSeq]
  where
    step = fromIntegral dir

-- | Returns the number of maximum multiplicative order modulo @n@ closest to
-- @n/φ@, searching outward from 'searchDir' @n@ in the order 'searchFrom'
-- gives.
--
-- @n@ must be positive. A non-positive @n@ now raises an explicit error,
-- rather than an obscure arithmetic failure later in the computation.
--
-- @n = 1@ is a special case and returns 1, because @isMaxOrder 1 1 [] i@ is
-- 'False' for every @i >= 0@.
findMaxOrder :: Integer -> Integer
findMaxOrder n
  | n < 1 = error ("findMaxOrder: n must be positive, got " ++ show n)
  | n == 1 = 1
  | otherwise =
      case filter (isMaxOrder n car fac) (searchFrom (searchDir n)) of
        found : _ -> found
        -- searchFrom is infinite, so an empty result cannot occur.
        [] -> error "findMaxOrder: candidate search exhausted"
  where
    car = carmichael n
    -- Distinct prime factors of the Carmichael function, as isMaxOrder expects.
    fac = map (unPrime . fst) (factorise car)

-- | Replaces every triple @(a, b, c)@ with its three rotations
-- @(a, b, c)@, @(b, c, a)@, @(c, a, b)@, in that order.
triplicate :: [(a, a, a)] -> [(a, a, a)]
triplicate = concatMap rotations
  where
    rotations t@(a, b, c) = [t, (b, c, a), (c, a, b)]

-- | Index triples that 'mix3Parts' mixes for a buffer of length @len@.
--
-- With @third = len \`div\` 3@, the k-th triple pairs:
--
-- * index @k@ of the first third (walked forward);
-- * index @2*third - 1 - k@ (the second third walked backward);
-- * @2*third + (k*rprime mod third)@ (the last third walked by stride
--   @rprime@).
--
-- Each triple is followed by its rotations (see 'triplicate'). The result is
-- empty when @len < 3@.
--
-- @rprime@ is expected to be relatively prime to @len \`div\` 3@; presumably
-- this makes the last third a permutation.
mixOrder :: Int -> Int -> [(Int, Int, Int)]
mixOrder len rprime
  | len < 3   = []
  | otherwise = triplicate (zip3 forward backward strided)
  where
    third    = len `div` 3
    forward  = [0 .. third - 1]
    backward = [2 * third - 1, 2 * third - 2 ..]
    strided  = map (+ 2 * third) (iterate (\x -> (x + rprime) `mod` third) 0)

-- | Splits @buf@ into three equal parts, with 0-2 bytes left over, and mixes
-- the three parts in the order given by 'mixOrder'. The leftover bytes are
-- not touched. Exported for testing.
--
-- Every new byte is computed from the original @buf@, not from partially
-- updated values.
--
-- @rprime@ is the stride through the last third. Compute it once per length
-- (@findMaxOrder (fromIntegral (div len 3))@) and pass it on every round.
mix3Parts :: V.Vector Word8 -> Int -> V.Vector Word8
mix3Parts buf rprime = buf V.// updates
  where
    byteAt = (buf V.!)
    updates =
      [ (i, mix (byteAt i) (byteAt j) (byteAt k))
      | (i, j, k) <- mixOrder (V.length buf) rprime
      ]

-- | In-place counterpart of 'mix3Parts'. It mixes the three thirds of @buf@:
--
-- * the first third walked forward;
-- * the second third walked backward;
-- * the last third walked by stride @rprime@.
--
-- The 0-2 trailing bytes are left untouched.
--
-- The stride is reduced modulo the third's length, as 'mixOrder' does. This
-- keeps this function consistent with 'mix3Parts' for any @rprime@. The
-- previous code assumed @0 <= rprime <= third@: a larger stride read out of
-- bounds, and a negative one silently mixed bytes of the middle third.
mix3Parts' :: MV.MVector s Word8 -> Int -> ST s ()
mix3Parts' buf rprime = do
    let third = MV.length buf `quot` 3
        -- third is 0 for buffers shorter than 3 bytes. The loop below then
        -- runs zero times, so the stride value is irrelevant, but `mod 0`
        -- must be avoided.
        step
          | third > 0 = rprime `mod` third
          | otherwise = 0
        go _  _  _  0 = pure ()
        go !a !b !c n = do
            x <- MV.read buf a
            y <- MV.read buf b
            z <- MV.read buf c
            -- Same mask as 'mix': bits where the three bytes disagree.
            let mask = (x .|. y .|. z) - (x .&. y .&. z)
            MV.write buf a (x `xor` mask)
            MV.write buf b (y `xor` mask)
            MV.write buf c (z `xor` mask)
            -- Since 0 <= step < third and c is in [2*third, 3*third), one
            -- subtraction wraps c' back into the last third.
            let c' = c + step
            go (a + 1) (b - 1) (if c' >= 3 * third then c' - third else c') (n - 1)
    go 0 (2 * third - 1) (2 * third) third