{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnliftedFFITypes #-}

module Hash.Md5
  ( Context (..)

    -- * Context Reuse
  , context
  , reinitialize
  , update
  , finalize

    -- * One Shot
  , boundedBuilder
  ) where

import Control.Monad.ST (ST)
import Data.Bytes.Builder.Bounded as BB
import Data.Bytes.Builder.Bounded.Unsafe as BBU
import Data.Bytes.Types (Bytes (Bytes))
import Data.Primitive (ByteArray (..), MutableByteArray (..), newByteArray)
import GHC.Exts (unsafeCoerce#)
import GHC.IO (unsafeIOToST)

import Hash.Md5.Internal

-- | A mutable MD5 hashing context living in the 'ST' thread @s@.
--
-- The wrapped byte array is the raw state buffer that this module hands to
-- the foreign @c_md5_*@ routines from "Hash.Md5.Internal".
newtype Context s
  = Context (MutableByteArray s)

-- | Allocate a fresh context and initialize it with 'reinitialize', so it
-- is immediately ready to receive input through 'update'.
context :: ST s (Context s)
context = do
  ctx <- Context <$> newByteArray stateSize
  reinitialize ctx
  pure ctx
  where
    -- Size in bytes of the buffer backing a context.
    -- NOTE(review): presumably equal to the C-side MD5 state struct size;
    -- confirm against "Hash.Md5.Internal".
    stateSize :: Int
    stateSize = 88

-- | Reset the context so that it may be used to hash another byte sequence.
reinitialize :: Context s -> ST s ()
reinitialize (Context (MutableByteArray st#)) =
  -- The state-token parameter of a MutableByteArray# is phantom. The foreign
  -- binding is typed at RealWorld, so unsafeCoerce# changes only the type.
  unsafeIOToST $ c_md5_init (unsafeCoerce# st#)

-- | Write the MD5 digest (16 bytes) of the input fed to the context into the
-- destination buffer, starting at the given offset.
--
-- No bounds check is performed here. The caller must ensure that
-- @offset + 16@ does not exceed the destination's size. To hash another
-- sequence with the same context, call 'reinitialize' first.
finalize ::
  Context s ->
  -- | Destination, implied length is 16
  MutableByteArray s ->
  -- | Destination offset
  Int ->
  ST s ()
finalize (Context (MutableByteArray st#)) (MutableByteArray dst#) !off =
  -- Both arrays carry a phantom state token; the foreign binding expects
  -- RealWorld, so the coercions change only their types.
  unsafeIOToST $
    c_md5_finalize (unsafeCoerce# st#) (unsafeCoerce# dst#) off

-- | Feed a slice of bytes into the context. The slice's backing array,
-- offset and length are passed unchanged to the foreign update routine.
--
-- NOTE(review): judging by its name, @c_md5_update_unsafe@ is an @unsafe@
-- foreign call, which blocks garbage collection while it runs. Confirm in
-- "Hash.Md5.Internal" whether very large inputs warrant a safe variant.
update ::
  Context s ->
  Bytes ->
  ST s ()
update (Context (MutableByteArray st#)) (Bytes (ByteArray src#) off len) =
  -- Phantom state token coerced to RealWorld to match the foreign binding.
  unsafeIOToST $ c_md5_update_unsafe (unsafeCoerce# st#) src# off len

-- | One-shot MD5: hash @len@ bytes of @src@ starting at @off@, and write the
-- 16-byte digest into @dst@ starting at offset @ix@.
--
-- Built from 'context', 'update' and 'finalize' rather than calling the
-- foreign routines directly. The context size and the state-token coercions
-- then live in one place instead of being duplicated here.
performHash :: MutableByteArray s -> Int -> ByteArray -> Int -> Int -> ST s ()
performHash !dst !ix !src !off !len = do
  -- Bangs preserve the original's forcing of every argument on entry.
  ctx <- context
  update ctx (Bytes src off len)
  finalize ctx dst ix

-- | A bounded builder that emits the 16-byte MD5 digest of the given bytes.
-- A fresh context is used internally, so no 'Context' is exposed.
boundedBuilder :: Bytes -> BB.Builder 16
boundedBuilder (Bytes src off len) = BBU.construct writeDigest
  where
    -- Write the digest at index @ix@ and report the index just past it.
    writeDigest :: MutableByteArray s -> Int -> ST s Int
    writeDigest dst ix =
      performHash dst ix src off len >> pure (ix + digestLength)

    digestLength :: Int
    digestLength = 16