module Cryptography.WringTwistree.Compress
  ( blockSize
  , relPrimes
  , lfsr
  , backCrc
  , compress
  , compress2
  , compress3
  ) where

{- This module is used in Twistree.
 - It compresses two or three 32-byte blocks into one, using three s-boxes in
 - an order specified by the sboxalt argument.
 -}

import Data.Bits
import Data.Word
import Data.List (mapAccumR)
import Data.Array.Unboxed
import Cryptography.WringTwistree.Mix3
import Cryptography.WringTwistree.RotBitcount
import Cryptography.WringTwistree.Sboxes
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV
import Control.Monad.ST
import Control.Monad
import Debug.Trace

-- | Length in bytes of one Twistree block; 'compress' shortens its input
-- 4 bytes per round until it is this long. Must be a multiple of 4.
blockSize :: Integral a => a
blockSize = 32
-- | Rotation parameter handed to 'rotBitcount' / 'rotBitcount'' in every
-- compression round: the smallest prime greater than 'blockSize'.
twistPrime :: Integral a => a
twistPrime = 37
-- blockSize must be a multiple of 4. Blocks in the process of compression
-- can be any size from blockSize to 3*blockSize in steps of 4. twistPrime is
-- the smallest prime greater than blockSize, which is relatively prime to all
-- block sizes during compression.

-- | Per-length multiplier table, indexed by buffer length from 'blockSize'
-- to @3 * blockSize@. Each entry is @findMaxOrder (len `div` 3)@. A round
-- looks up its buffer length here and passes the result to 'mix3Parts' /
-- 'mix3Parts''.
--
-- 3/4 of this is waste, because only lengths that are multiples of 4 are
-- looked up. The numbers are Word16, because the last number is 19, and the
-- program will multiply 37 by 19, which doesn't fit in Word8.
relPrimes :: UArray Word16 Word16
relPrimes =
  listArray
    (blockSize, 3 * blockSize)
    [ fromIntegral (findMaxOrder (size `div` 3))
    | size <- [blockSize .. 3 * blockSize :: Integer]
    ]

-- | One step of a right-shifting LFSR. The word is shifted right by one bit;
-- if the bit shifted out was 1, the feedback constant 0x84802140 is XORed in.
lfsr1 :: Word32 -> Word32
lfsr1 n
  | testBit n 0 = shifted `xor` 0x84802140
  | otherwise = shifted
  where
    shifted = n .>>. 1

-- | 256-entry lookup table: entry @k@ is 'lfsr1' applied eight times to @k@.
-- 'backCrc1' indexes it with the low byte of the CRC state.
lfsr :: UArray Word32 Word32
lfsr = listArray (0, 255) [eightSteps k | k <- [0 .. 255 :: Word32]]
  where
    eightSteps :: Word32 -> Word32
    eightSteps = foldr (.) id (replicate 8 lfsr1)

-- | Advance the backward-CRC state by one input value. The state is shifted
-- right by 8 bits, then XORed with the 'lfsr' table entry for the byte that
-- was shifted out, and finally XORed with @input@.
backCrc1 :: Word32 -> Word32 -> Word32
backCrc1 state input = shifted `xor` feedback `xor` input
  where
    shifted = state .>>. 8
    feedback = lfsr ! (state .&. 255)

-- | 'mapAccumR' step for 'backCrc'. It feeds one byte into 'backCrc1' and
-- returns the new state together with that state's low byte, which replaces
-- the input byte.
backCrcM :: Word32 -> Word8 -> (Word32, Word8)
backCrcM state byte =
  let next = backCrc1 state (fromIntegral byte)
   in (next, fromIntegral next)

-- | Backward CRC over a byte list. The list is processed from its last byte
-- to its first, starting from state 0xdeadc0de, and each byte is replaced
-- with the low byte of the running state. The output has the same length as
-- the input.
backCrc :: [Word8] -> [Word8]
backCrc = snd . mapAccumR backCrcM 0xdeadc0de

-- Original purely functional version, modified to use vectors

-- | One compression round, pure list-based version. It mixes the buffer,
-- substitutes every byte through the s-boxes (selector stream
-- @drop sboxalt cycle3@), rotates by bit count with 'twistPrime', runs
-- 'backCrc', and drops the last 4 bytes of the result.
roundCompressFun :: SBox -> V.Vector Word8 -> Int -> V.Vector Word8
roundCompressFun sbox buf sboxalt =
  V.fromListN (len - 4) (backCrc (V.toList twisted))
  where
    len = V.length buf
    mixed = mix3Parts buf (fromIntegral (relPrimes ! fromIntegral len))
    substituted =
      V.fromListN
        len
        [ sbox V.! sboxInx which byte
        | (which, byte) <- zip (drop sboxalt cycle3) (V.toList mixed)
        ]
    twisted = rotBitcount substituted twistPrime

-- | Pure compressor. It repeats 'roundCompressFun' (4 bytes shorter per
-- round) until the buffer is at most 'blockSize' bytes long. Before each
-- round it fails with "bad block size" if the current length is a multiple
-- of 'twistPrime'.
compressFun :: V.Vector Word8 -> V.Vector Word8 -> Int -> V.Vector Word8
compressFun sbox buf sboxalt = loop buf
  where
    loop current
      | size <= blockSize = current
      | size `mod` twistPrime == 0 = error "bad block size"
      | otherwise = loop (roundCompressFun sbox current sboxalt)
      where
        size = V.length current

-- ST monad version modifies memory in place

-- | One compression round on a mutable buffer, done in place. This is the
-- ST counterpart of 'roundCompressFun'. The result is a slice of @buf@ that
-- is 4 bytes shorter and shares memory with @buf@.
--
-- Steps:
--
--   1. 'mix3Parts'' on @buf@ with the multiplier @relPrimes ! len@. It
--      returns unit, so it presumably mixes @buf@ in place.
--   2. S-box substitution into a scratch buffer. Byte @i@ uses selector
--      @(i + sboxalt) `rem` 3@.
--   3. 'rotBitcount'' from the scratch buffer back into @buf@ (presumably the
--      third argument is the destination; confirm in RotBitcount).
--   4. Backward CRC from the last byte to the first, seeded with 0xdeadc0de.
--      Each byte is replaced by the low byte of the running state
--      (see 'backCrcM').
--
-- @len@ must be an index of 'relPrimes'.
roundCompressST ::
  SBox ->
  MV.MVector s Word8 ->
  Int ->
  ST s (MV.MVector s Word8)
roundCompressST sbox buf sboxalt = do
  let len = MV.length buf
      rprime = relPrimes ! fromIntegral len
  tmp <- MV.new len
  mix3Parts' buf (fromIntegral rprime)
  forM_ [0 .. len - 1] $ \i -> do
    a <- MV.read buf i
    MV.write tmp i (sbox V.! sboxInx ((i + sboxalt) `rem` 3) a)
  rotBitcount' tmp twistPrime buf
  -- The CRC state is one running value, so it is carried in the fold
  -- accumulator. The original allocated a (len+1)-element Word32 scratch
  -- vector and a reversed index list every round just to hold it.
  foldM_
    ( \crc i -> do
        byte <- MV.read buf i
        let (crc', byte') = backCrcM crc byte
        MV.write buf i byte'
        pure crc'
    )
    0xdeadc0de
    [len - 1, len - 2 .. 0]
  pure (MV.take (len - 4) buf)

-- | Compress a buffer down to 'blockSize' bytes. It thaws a copy and runs
-- 'roundCompressST' on it @(length - blockSize) `div` 4@ times; each round
-- removes 4 bytes. A buffer of at most 'blockSize' bytes gets no rounds and
-- is returned with the same contents. Unlike 'compressFun', it does not
-- check the length against 'twistPrime'.
compressST :: V.Vector Word8 -> V.Vector Word8 -> Int -> V.Vector Word8
compressST sbox buf sboxalt = V.create $ do
  mbuf <- V.thaw buf
  foldM (\b _ -> roundCompressST sbox b sboxalt) mbuf [1 .. nRounds]
  where
    -- Same count as the former @[0 .. (len - blockSize) `div` 4 - 1]@.
    nRounds = (V.length buf - blockSize) `div` 4

-- | Compress one buffer (normally two or three 'blockSize'-byte blocks
-- concatenated) down to a single block. @sboxalt@ selects the order in which
-- the three s-boxes are used. Exported for cryptanalysis.
compress :: SBox -> V.Vector Word8 -> Int -> V.Vector Word8
compress sbox buf sboxalt = compressST sbox buf sboxalt

{-
compress2 takes 100x operations, compress3 takes 264x operations. This assumes
a round costs time proportional to its buffer length. compress2 runs rounds on
64, 60, ..., 36 bytes (400 in total) and compress3 on 96, 92, ..., 36 bytes
(1056 in total), and 400:1056 = 100:264.
Twistree does twice as many compress2 calls as compress3 calls (2*100 : 264).
So it spends 100 ms in compress2 for every 132 ms in compress3, or 43% and 57%.
Profiling shows 42.5% and 56.0%, with the rest spent in blockize.
Hashing 1 MiB takes 12.5 s on my box, using two threads, one for the 2-tree and
one for the 3-tree. The 3-tree takes longer, so that's 16384 compress3 calls
(ignoring the few compress2 calls in the 3-tree) in 12.5 s, or 763 µs for compress3
and 289 µs for compress2.
-}

-- | Compress two blocks into one by concatenating them and calling
-- 'compress'. Block lengths are not checked here.
compress2 ::
  SBox ->
  V.Vector Word8 ->
  V.Vector Word8 ->
  Int ->
  V.Vector Word8
compress2 sbox buf0 buf1 = compress sbox (buf0 <> buf1)

-- | Compress three blocks into one by concatenating them and calling
-- 'compress'. Block lengths are not checked here.
compress3 ::
  SBox ->
  V.Vector Word8 ->
  V.Vector Word8 ->
  V.Vector Word8 ->
  Int ->
  V.Vector Word8
compress3 sbox buf0 buf1 buf2 sboxalt =
  compress sbox (V.concat [buf0, buf1, buf2]) sboxalt