{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module BLAKE3.IO
  ( 
    init
  , update
  , finalize
  , Hasher
  , allocRetHasher
  , Digest
  , allocRetDigest
  
  , Raw.HasherInternal
  , copyHasher
  , withHasherInternal
  
  , Key
  , key
  , allocRetKey
  , initKeyed
  
  , Context
  , context
  , initDerive
  
  , Raw.HASHER_ALIGNMENT
  , Raw.HASHER_SIZE
  , Raw.KEY_LEN
  , Raw.BLOCK_SIZE
  , Raw.DEFAULT_DIGEST_LEN
  )
  where
import Control.Monad (guard)
import qualified Data.ByteArray as BA
import qualified Data.ByteArray.Encoding as BA
import Data.Foldable
import qualified Data.Memory.PtrMethods as BA
import Data.Proxy
import Data.String
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
import Prelude hiding (init)
import qualified BLAKE3.Raw as Raw
-- | Storage for a BLAKE3 hasher state, held in a 'BA.ScrubbedBytes' buffer.
--
-- The constructor is not exported from this module; values are created
-- with 'allocRetHasher' or 'copyHasher', and the underlying state is
-- reached through 'withHasherInternal'.
newtype Hasher = Hasher BA.ScrubbedBytes
  
-- | Allocate a fresh 'Hasher' buffer of exactly 'Raw.HASHER_SIZE' bytes and
-- run the given action on a pointer to it.
--
-- This function does not initialise the state itself; the action is
-- expected to do so (e.g. with 'init', 'initKeyed' or 'initDerive').
-- The action's result is returned alongside the new 'Hasher'.
--
-- NOTE(review): alignment relies on the allocator behind
-- 'BA.ScrubbedBytes' satisfying 'Raw.HASHER_ALIGNMENT' — confirm.
allocRetHasher
  :: forall a
  .  (Ptr Raw.HasherInternal -> IO a)
  -> IO (a, Hasher)
allocRetHasher g = fmap Hasher <$> BA.allocRet hasherSize g
  where
    hasherSize :: Int
    hasherSize = fromIntegral (natVal (Proxy :: Proxy Raw.HASHER_SIZE))
-- | Run an action with a raw pointer to the state stored in a 'Hasher'.
--
-- The pointer comes from 'BA.withByteArray', so it should not be used
-- after the action returns.
withHasherInternal
  :: Hasher
  -> (Ptr Raw.HasherInternal -> IO a)
  -> IO a
withHasherInternal (Hasher bytes) act = BA.withByteArray bytes act
-- | Duplicate a 'Hasher' into a freshly allocated buffer.
--
-- The copy and the original occupy separate memory, so each can be
-- updated or finalised without affecting the other.
copyHasher :: Hasher -> IO Hasher
copyHasher (Hasher src) = do
  -- The post-copy hook is a no-op: the bytes are copied verbatim.
  dst <- BA.copy src (\_ -> pure ())
  pure (Hasher dst)
-- | A BLAKE3 output whose length in bytes is the type-level @len@
-- (see 'allocRetDigest' and 'finalize', which both allocate @len@ bytes).
newtype Digest (len :: Nat) = Digest BA.ScrubbedBytes
  deriving newtype (Eq, BA.ByteArrayAccess)

-- | Renders the digest bytes as base-16 text.
instance Show (Digest len) where
  -- Goes through the derived 'BA.ByteArrayAccess' instance, which exposes
  -- exactly the wrapped bytes.
  show = showBase16
-- | Allocate a @len@-byte 'Digest' buffer and run the given action on a
-- pointer to it, returning the action's result alongside the 'Digest'.
--
-- The buffer is not initialised here; the action is expected to fill all
-- @len@ bytes.
allocRetDigest
  :: forall len a
  .  KnownNat len
  => (Ptr Word8 -> IO a)
  -> IO (a, Digest len)
allocRetDigest g = fmap Digest <$> BA.allocRet digestLen g
  where
    digestLen :: Int
    digestLen = fromIntegral (natVal (Proxy :: Proxy len))
-- | A key for keyed BLAKE3 hashing ('initKeyed').
--
-- The constructor is not exported; values come from 'key' (which checks
-- the length is 'keyLen') or 'allocRetKey' (which allocates 'keyLen'
-- bytes).
newtype Key = Key BA.ScrubbedBytes
  deriving newtype (Eq, BA.ByteArrayAccess)

-- | Renders the key bytes as base-16 text.
--
-- NOTE(review): this prints the raw (presumably secret) key material;
-- consider whether 'show' should redact it instead — confirm with callers.
instance Show Key where
  show = showBase16
-- | Length in bytes of a 'Key', taken from the type-level 'Raw.KEY_LEN'.
keyLen :: Int
keyLen = fromIntegral $ natVal (Proxy :: Proxy Raw.KEY_LEN)
-- | Build a 'Key' from raw bytes.
--
-- Returns 'Nothing' unless the input is exactly 'keyLen' bytes long; on
-- success the bytes are copied into a new 'Key'.
key
  :: BA.ByteArrayAccess bin
  => bin
  -> Maybe Key
key bin
  | BA.length bin /= keyLen = Nothing
  | otherwise = Just . Key $ BA.convert bin
-- | Allocate a 'keyLen'-byte 'Key' buffer and run the given action on a
-- pointer to it, returning the action's result alongside the 'Key'.
--
-- The buffer is not initialised here; the action is expected to write all
-- 'keyLen' bytes.
allocRetKey
  :: forall a
  . (Ptr Word8 -> IO a)
  -> IO (a, Key)
allocRetKey g = fmap Key <$> BA.allocRet keyLen g
-- | A context string for BLAKE3 key derivation ('initDerive').
--
-- Invariant: the bytes end with a single 0 byte and contain no other 0.
-- Both builders ('context' and the 'IsString' instance) append the
-- terminator and reject interior zeros. Presumably this is so the buffer
-- can be read as a NUL-terminated C string, since 'initDerive' passes only
-- a pointer and no length.
newtype Context = Context BA.Bytes
  deriving newtype (Eq)

-- | Renders the context as base-16 text, omitting the trailing 0 byte.
instance Show Context where
  show (Context bytes) = showBase16 (BA.takeView bytes lenWithoutNul)
    where
      lenWithoutNul = BA.length bytes - 1
-- | Lets a 'Context' be written as a string literal (with
-- @OverloadedStrings@) or built with 'fromString'.
--
-- Every 'Char' must have a code point in 1..255. Each one becomes a single
-- byte equal to its code point, and a terminating 0 byte is appended. Any
-- other character makes this call 'error'.
--
-- NOTE(review): code points 128..255 are stored as Latin-1 bytes, not
-- UTF-8, so non-ASCII contexts may not match BLAKE3 implementations that
-- take UTF-8 strings — confirm this is intended.
instance IsString Context where
  fromString str =
    maybe (error "Not a valid String for Context")
          (\bytes -> Context $! BA.pack (bytes <> [0]))
          (traverse toByte str)
    where
      -- 0 is reserved for the terminator; anything above 255 does not fit
      -- in a single byte.
      toByte :: Char -> Maybe Word8
      toByte ch = fromIntegral code <$ guard (code > 0 && code < 256)
        where
          code = fromEnum ch
-- | Build a 'Context' from raw bytes.
--
-- Returns 'Nothing' if the input contains a 0 byte, because 0 is used as
-- the terminator. Otherwise the input is copied into a buffer one byte
-- longer, and a 0 is written at the end.
context
  :: BA.ByteArrayAccess bin
  => bin
  -> Maybe Context
context src
  | BA.any (== 0) src = Nothing
  | otherwise = Just (Context nulTerminated)
  where
    n = BA.length src
    nulTerminated :: BA.Bytes
    nulTerminated = BA.allocAndFreeze (n + 1) $ \dst -> do
      BA.withByteArray src $ \from -> BA.memCopy dst from n
      pokeByteOff dst n (0 :: Word8)
-- | Initialise the hasher state at the given pointer by delegating to
-- 'Raw.init' (presumably regular, unkeyed hashing mode — see 'initKeyed'
-- and 'initDerive' for the other modes).
init
  :: Ptr Raw.HasherInternal
  -> IO ()
init ph = Raw.init ph
-- | Initialise the hasher state at the given pointer in keyed mode, passing
-- a pointer to the 'Key' bytes to 'Raw.init_keyed'.
initKeyed
  :: Ptr Raw.HasherInternal
  -> Key
  -> IO ()
initKeyed ph key0 = BA.withByteArray key0 (Raw.init_keyed ph)
-- | Initialise the hasher state at the given pointer for key derivation,
-- passing a pointer to the 0-terminated 'Context' bytes to
-- 'Raw.init_derive_key'. Only the pointer is passed, with no length.
initDerive
  :: Ptr Raw.HasherInternal
  -> Context
  -> IO ()
initDerive ph (Context ctx) = BA.withByteArray ctx (Raw.init_derive_key ph)
-- | Feed each chunk into the hasher state, in list order.
--
-- Each chunk is passed to 'Raw.update' as a pointer plus its length in
-- bytes. An empty list performs no calls.
update
  :: forall bin
  .  BA.ByteArrayAccess bin
  => Ptr Raw.HasherInternal
  -> [bin]
  -> IO ()
update ph = traverse_ feed
  where
    feed :: bin -> IO ()
    feed chunk =
      BA.withByteArray chunk $ \p ->
        Raw.update ph p (fromIntegral (BA.length chunk))
-- | Produce a @len@-byte 'Digest' from the hasher state at the given pointer.
--
-- A fresh @len@-byte buffer is allocated and 'Raw.finalize' is asked to
-- write @len@ output bytes into it.
--
-- NOTE(review): whether the state may still be updated afterwards depends
-- on 'Raw.finalize' — confirm.
finalize
  :: forall len
  .  KnownNat len
  => Ptr Raw.HasherInternal
  -> IO (Digest len)
finalize ph = snd <$> allocRetDigest (\pd -> Raw.finalize ph pd outLen)
  where
    outLen = fromIntegral (natVal (Proxy :: Proxy len))
-- | Render bytes as base-16 text, one 'Char' per output byte of
-- 'BA.convertToBase'.
--
-- The intermediate hex buffer is a 'BA.ScrubbedBytes'. Presumably this is
-- so the hex form of secret material is scrubbed too; note the returned
-- 'String' itself is not.
showBase16 :: BA.ByteArrayAccess x => x -> String
showBase16 x =
  let hex = BA.convertToBase BA.Base16 x :: BA.ScrubbedBytes
  in map (toEnum . fromIntegral) (BA.unpack hex)