module Crypto.PHKDF.Primitives.Subtle
  ( PhkdfCtx(..)
  , phkdfCtx_unsafeFeed
  , PhkdfSlowCtx(..)
  , phkdfSlowCtx_lift
  , PhkdfGen(..)
  ) where

import           Prelude hiding (null)
import qualified Crypto.Hash.SHA256 as SHA256
import           Crypto.PHKDF.HMAC (HmacKey)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import           Data.Foldable(foldl', null)
import           Data.Word

-- NOTE: I should be using the byte counter inside the SHA256 ctx, but this is a proof of concept.

-- TODO: Should phkdfCtx_length count bytes or bits? Double-check how the SHA256 internal
-- counter works, decide how this should behave, and then export it from the Primitives module.
-- For truly bulletproof code we probably need to return (Maybe Ctx), so that we don't
-- overflow SHA256's internal counter. That would be a bit of a conceptual problem for the
-- cryptohash-style interface I'm mimicking, not to mention for the cryptohash implementation
-- I depend upon.

-- Note that there is an offset between the SHA256 internal counter and phkdfCtx_length, but
-- it is always exactly 64 bytes. Since the internals of this module only care about the
-- internal counter modulo 64, this doesn't matter. However, we should probably export the
-- SHA256 counter itself.

-- | Incremental PHKDF context: a SHA-256 state together with a running byte
-- count and an HMAC key.
data PhkdfCtx = PhkdfCtx
  { phkdfCtx_byteLen :: !Word64
    -- ^ Running count of bytes absorbed.  'phkdfCtx_unsafeFeed' adds the
    -- length of every string it feeds.  It is a 'Word64', so it wraps on
    -- overflow.  NOTE(review): the module comments describe a constant
    -- 64-byte offset between this count and SHA-256's internal counter.
  , phkdfCtx_state :: !SHA256.Ctx
    -- ^ SHA-256 state that input bytes are absorbed into.
  , phkdfCtx_hmacKey :: !HmacKey
    -- ^ HMAC key carried with the context.  'phkdfCtx_unsafeFeed' does not
    -- read it; presumably it is used when finalizing (confirm elsewhere).
  }

-- | Strict accumulator for the fold in 'phkdfCtx_unsafeFeed': a running
-- byte count paired with a running SHA-256 state.  Both fields are strict,
-- so forcing a 'P' to WHNF also evaluates both components.
data P
  = P
      !Word64      -- running byte count
      !SHA256.Ctx  -- running hash state

-- | Feed a collection of strings directly into the context's SHA-256 state,
-- adding their combined length to 'phkdfCtx_byteLen'.
--
-- The strings are passed to 'SHA256.update' one at a time, in the
-- 'Foldable' order of @strs@.  No framing, padding, or length prefix is
-- added here.  NOTE(review): the "unsafe" in the name presumably refers to
-- this lack of framing; confirm against callers.
--
-- If @strs@ is empty, @ctx0@ is returned unchanged.  The byte count is a
-- 'Word64' and wraps silently on overflow.
phkdfCtx_unsafeFeed :: Foldable f => f ByteString -> PhkdfCtx -> PhkdfCtx
phkdfCtx_unsafeFeed strs ctx0
  | null strs = ctx0
  | otherwise =
      ctx0
        { phkdfCtx_byteLen = byteLen'
        , phkdfCtx_state = state'
        }
  where
    -- Absorb one string: bump the byte count and update the hash state.
    delta :: P -> ByteString -> P
    delta (P len ctx) str =
      P (len + fromIntegral (B.length str)) (SHA256.update ctx str)

    -- Start from the context's current count and state.
    p0 :: P
    p0 = P (phkdfCtx_byteLen ctx0) (phkdfCtx_state ctx0)

    -- foldl' forces the accumulator to WHNF at each step.  The strict fields
    -- of 'P' then keep both the count and the hash state evaluated, so no
    -- thunk chain builds up.
    P byteLen' state' = foldl' delta p0 strs

-- | A 'PhkdfCtx' bundled with a counter and a tag.
data PhkdfSlowCtx = PhkdfSlowCtx
  { phkdfSlowCtx_phkdfCtx :: !PhkdfCtx
    -- ^ Underlying PHKDF context.  'phkdfSlowCtx_lift' transforms this field
    -- in place.
  , phkdfSlowCtx_counter :: !Word32
    -- ^ 32-bit counter.  NOTE(review): its role (e.g. an iteration or block
    -- counter) is not visible in this module; confirm with the code that
    -- sets it.
  , phkdfSlowCtx_tag :: !ByteString
    -- ^ Tag bytes.  NOTE(review): how they are used is not visible in this
    -- module.
  }

-- | Apply a transformation to the 'PhkdfCtx' embedded in a 'PhkdfSlowCtx'.
-- The counter and tag fields are left untouched.
--
-- For example, @phkdfSlowCtx_lift (phkdfCtx_unsafeFeed strs)@ feeds
-- @strs@ into the inner context.
phkdfSlowCtx_lift :: (PhkdfCtx -> PhkdfCtx) -> PhkdfSlowCtx -> PhkdfSlowCtx
phkdfSlowCtx_lift f ctx =
  ctx { phkdfSlowCtx_phkdfCtx = f (phkdfSlowCtx_phkdfCtx ctx) }

-- | Generator state: an HMAC key, an extension tag, a counter, the current
-- state bytes, and an optional SHA-256 context.
data PhkdfGen = PhkdfGen
  { phkdfGen_hmacKey :: !HmacKey
    -- ^ HMAC key used by the generator.
  , phkdfGen_extTag :: !ByteString
    -- ^ Extension tag bytes.  NOTE(review): how they are used is not
    -- visible in this module.
  , phkdfGen_counter :: !Word32
    -- ^ 32-bit counter.  NOTE(review): presumably advanced per output
    -- block; confirm with the code that steps the generator.
  , phkdfGen_state :: !ByteString
    -- ^ Current state bytes.  NOTE(review): presumably the previous output
    -- block; confirm.
  , phkdfGen_initCtx :: !(Maybe SHA256.Ctx)
    -- ^ Optional SHA-256 context.  NOTE(review): the meaning of 'Nothing'
    -- versus 'Just' is not visible here; presumably a cached starting
    -- state.  Verify.
  }