{-# LANGUAGE OverloadedStrings, ViewPatterns #-}

module Crypto.Encoding.PHKDF where

import Data.Monoid((<>))
import Data.Bits(Bits, (.&.))
import Data.ByteString(ByteString)
import Data.Foldable(Foldable)
import qualified Data.ByteString as B
import Crypto.Encoding.SHA3.TupleHash

import Debug.Trace

-- FIXME: several functions in here have opportunities for optimization

-- | Repeat @str@ end to end until exactly @outBytes@ bytes have been
-- produced, returning the pieces as a list rather than one concatenated
-- string.
--
-- The result is @q@ whole copies of @str@ followed by the first @r@ bytes
-- of @str@, where @(q, r) = quotRem outBytes (B.length str)@.  That final
-- piece is present even when @r == 0@, in which case it is empty.
--
-- A non-positive @outBytes@ yields @[]@.  An empty @str@ cannot be cycled,
-- so in that case the output is a single run of @outBytes@ zero bytes.
cycleByteStringToList :: ByteString -> Int -> [ByteString]
cycleByteStringToList str outBytes
  | outBytes <= 0 = []
  | strLen == 0   = [B.replicate outBytes 0]
  | otherwise     = replicate wholeCopies str ++ [B.take leftover str]
  where
    strLen :: Int
    strLen = B.length str
    -- Lazily bound: only forced when strLen /= 0, so no division by zero.
    (wholeCopies, leftover) = outBytes `quotRem` strLen

-- | Like 'cycleByteStringToList', except that the unit being repeated is
-- @str@ followed by a single @0x00@ separator byte.
cycleByteStringWithNullToList :: ByteString -> Int -> [ByteString]
cycleByteStringWithNullToList str = cycleByteStringToList (B.snoc str 0)

-- | Cycle @str@ out to exactly @outBytes@ bytes, returned as one strict
-- 'ByteString'.  See 'cycleByteStringToList' for how a non-positive length
-- or an empty @str@ is handled.
cycleByteString :: ByteString -> Int -> ByteString
cycleByteString str = B.concat . cycleByteStringToList str

-- | Cycle @str@ with a @0x00@ separator after each copy out to exactly
-- @outBytes@ bytes, returned as one strict 'ByteString'.
cycleByteStringWithNull :: ByteString -> Int -> ByteString
cycleByteStringWithNull str = B.concat . cycleByteStringWithNullToList str

-- | Extend a tag, returning the pieces as a list.
--
-- A tag of at most 19 bytes is returned unchanged as a one-element list.
--
-- A longer tag of length @n@ is handled in two steps:
--
-- * It is cycled, with a @0x00@ separator after each copy, out to @n + x@
--   bytes, where @x = (18 - n) mod 64@.
-- * A single final byte holding @x@ is appended.
--
-- The total length is therefore @n + x + 1@, which is congruent to 19
-- modulo 64.  'trimExtTag' reads the trailing byte back to undo the
-- extension.
extendTagToList :: ByteString -> [ByteString]
extendTagToList tag
  | tagLen <= 19 = [tag]
  | otherwise    = cycleByteStringWithNullToList tag (tagLen + padLen)
                     ++ [B.singleton (fromIntegral padLen)]
  where
    tagLen, padLen :: Int
    tagLen = B.length tag
    -- Always in [0, 63], so it fits in the single trailing byte.
    padLen = (18 - tagLen) `mod` 64

-- | 'extendTagToList', concatenated into a single 'ByteString'.
extendTag :: ByteString -> ByteString
extendTag tag = B.concat (extendTagToList tag)

-- | Attempt to invert 'extendTag'.
--
-- An input of at most 19 bytes is returned as-is.
--
-- For a longer input, the last byte is read as the pad length @x@, and
-- the first @length - x - 1@ bytes are taken as the candidate tag.  The
-- candidate is returned only if extending it again reproduces the input
-- exactly; otherwise the result is 'Nothing'.
trimExtTag :: ByteString -> Maybe ByteString
trimExtTag extTag
  | len <= 19                     = Just extTag
  | extendTag candidate == extTag = Just candidate
  | otherwise                     = Nothing
  where
    len :: Int
    len = B.length extTag
    -- 'B.last' is only forced on the branches where len > 19,
    -- so the input is known to be non-empty here.
    padLen :: Int
    padLen = fromIntegral (B.last extTag)
    candidate :: ByteString
    candidate = B.take (len - padLen - 1) extTag

{--

FIXME: as written, this is only correct for signed types; with unsigned types
the subtraction @b - c@ wraps, which is harmless only when the modulus @a@
is a power of 2, such as 64

-- | @addWhileLt a b c@ is equivalent to  @while (b < c) { b += a }; return b@
addWhileLt :: Integral a => a -> a -> a -> a
addWhileLt a b c
   | b >= c = b
   | otherwise = c + ((b - c) `mod` a)

--}

-- | @add64WhileLt b c@ is equivalent to @while (b < c) { b += 64 }; return b@.
-- It returns the least value @b + 64*k@ (with @k >= 0@) that is not below @c@.
--
-- Instead of looping, the result is computed directly as @c@ plus the
-- residue of @b - c@ modulo 64.  That residue is taken with a bit mask
-- rather than 'mod', so it is also correct for fixed-width unsigned types.
-- There @b - c@ wraps around, but the low six bits are unaffected because
-- 64 divides the type's modulus.
add64WhileLt :: (Ord a, Num a, Bits a) => a -> a -> a
add64WhileLt b c
  | b < c     = c + ((b - c) .&. 63)
  | otherwise = b

-- | Debugging variant of 'add64WhileLt'.  It returns the same value, and
-- whenever an adjustment is actually made (@b < c@) it also emits
-- @"<b> -> <result>"@ via 'trace'.
add64WhileLt' :: (Ord a, Num a, Bits a, Show a) => a -> a -> a
add64WhileLt' b c
  | b < c     = trace (show b ++ " -> " ++ show adjusted) adjusted
  | otherwise = b
  where
    -- Same closed form as add64WhileLt: c plus (b - c) mod 64.
    adjusted = c + ((b - c) .&. 63)

-- | Build padding that follows the encoded @headerExtract@ vector.
--
-- The padding length is chosen with 'add64WhileLt' so that it is at least
-- 32 bytes.  It is also chosen so that the padding length plus the value of
-- 'encodedVectorByteLength' for @headerExtract@ is congruent to 157
-- modulo 64.
--
-- The padding is laid out in two parts:
--
-- * the leading @padLen - 32@ bytes are @fillerTag@ cycled with @0x00@
--   separators;
-- * the final 32 bytes are @domainTag@ cycled the same way.
usernamePadding :: Foldable f => f ByteString -> ByteString -> ByteString -> ByteString
usernamePadding headerExtract fillerTag domainTag =
    let padLen :: Int
        padLen = add64WhileLt (157 - encodedVectorByteLength headerExtract) 32
        filler = cycleByteStringWithNull fillerTag (padLen - 32)
        domain = cycleByteStringWithNull domainTag 32
     in filler <> domain

-- | Build password padding from the byte budget @bytes@.
--
-- The padding length is derived in three stages.  Each stage subtracts one
-- encoded length, then uses 'add64WhileLt' to add multiples of 64 until
-- the value reaches a floor:
--
-- 1. subtract 'encodedVectorByteLength' of @headerLongTag@; floor 3240;
-- 2. subtract 'encodedVectorByteLength' of @headerUsername@; floor 136;
-- 3. subtract 'encodedByteLength' of @password@; floor 32.
--
-- The result is @padLen@ bytes long, laid out in two parts:
--
-- * the leading @padLen - 32@ bytes are @fillerTag@ cycled with @0x00@
--   separators;
-- * the final 32 bytes are @domainTag@ cycled the same way.
passwordPaddingBytes :: Foldable f => Int -> f ByteString -> f ByteString -> ByteString -> ByteString -> ByteString -> ByteString
passwordPaddingBytes bytes headerUsername headerLongTag fillerTag domainTag password =
    fillerPart <> domainPart
  where
    afterLongTag, afterUsername, padLen :: Int
    afterLongTag  = add64WhileLt (bytes - encodedVectorByteLength headerLongTag) 3240
    afterUsername = add64WhileLt (afterLongTag - encodedVectorByteLength headerUsername) 136
    padLen        = add64WhileLt (afterUsername - encodedByteLength password) 32

    fillerPart = cycleByteStringWithNull fillerTag (padLen - 32)
    domainPart = cycleByteStringWithNull domainTag 32

-- | 'passwordPaddingBytes' with the byte budget fixed at 8413.
passwordPadding :: Foldable f => f ByteString -> f ByteString -> ByteString -> ByteString -> ByteString -> ByteString
passwordPadding headerUsername headerLongTag fillerTag domainTag password =
    passwordPaddingBytes 8413 headerUsername headerLongTag fillerTag domainTag password

-- | Build padding that follows the encoded @credentials@ vector.
--
-- The padding length is chosen with 'add64WhileLt' so that it is at least
-- 32 bytes.  It is also chosen so that the padding length plus the value of
-- 'encodedVectorByteLength' for @credentials@ is congruent to 122
-- modulo 64.
--
-- The padding is laid out in two parts:
--
-- * the leading @padLen - 29@ bytes are @fillerTag@ cycled with @0x00@
--   separators;
-- * the final 29 bytes are @domainTag@ cycled the same way.
--
-- Because the floor is 32 while the domain part is only 29 bytes, the
-- filler part is always at least 3 bytes.
--
-- NOTE(review): unlike 'usernamePadding', the floor (32) and the domain
-- length (29) differ here.  Presumably this is intentional; confirm
-- against the spec.
credentialsPadding :: Foldable f => f ByteString -> ByteString -> ByteString -> ByteString
credentialsPadding credentials fillerTag domainTag =
    let padLen :: Int
        padLen = add64WhileLt (122 - encodedVectorByteLength credentials) 32
     in cycleByteStringWithNull fillerTag (padLen - 29)
          <> cycleByteStringWithNull domainTag 29