{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE ConstraintKinds            #-}

-- | Some utility functions useful for all Sha hashes.
module Raaz.Hash.Sha.Util
       ( shaImplementation
       , length64Write
       , length128Write
       , Compressor
       ) where

import Data.Monoid                  ( (<>)      )
import Data.Word
import Foreign.Ptr                  ( Ptr       )
import Foreign.Storable             ( Storable  )

import Raaz.Core
import Raaz.Core.Transfer
import Raaz.Hash.Internal

-- | Constraint that a primitive must satisfy for the utilities of this
-- module to be applicable to it.
type IsSha h = ( Primitive h
               , Storable h
               , Memory (HashMemory h)
               )

-- | The monad in which every action of this module runs.
type ShaMonad h = MT (HashMemory h)

-- | Write actions over 'ShaMonad', as used throughout this module.
type ShaWrite h = WriteM (ShaMonad h)
--
-- Hashes in the SHA family pad the message, and the last few bytes of
-- the padding store the message length. Hashes like sha1 and sha256
-- write the message length in 64 bits, while sha512 uses 128 bits. The
-- generic write actions `length64Write` and `length128Write` implement
-- these two length encodings.

-- | An encoder for the message length. It turns the total number of
-- message bits into the write action that stores that length in the
-- padding.
type LengthWrite h
  = BITS Word64
  -> ShaWrite h

-- | The length encoding that uses 64 bits. The total bit count is written
-- as a single big-endian 'Word64'.
length64Write :: LengthWrite h
length64Write (BITS w) = write $ bigEndian w

-- | The length encoding that uses 128 bits.
--
-- The length is only available as a 'Word64' bit count (see
-- 'LengthWrite'), so the upper 64 bits are written as zero. They are
-- followed by the 64-bit big-endian encoding produced by 'length64Write'.
length128Write :: LengthWrite h
length128Write w =
  -- A zero word has the same bytes in every byte order, so the
  -- host-order 'writeStorable' is safe for the high half.
  writeStorable (0 :: Word64) <> length64Write w


-- | The raw compressor function. Its arguments, in order, are:
--
--   * the buffer to compress,
--
--   * the number of blocks to compress,
--
--   * a pointer to the cell memory that holds the hash.
--
-- A compressor does not need the length of the message processed so far,
-- and so it is not supposed to update lengths.
type Compressor h
  =  Pointer
  -> Int
  -> Ptr h
  -> IO ()

-- | An action on a buffer. It receives the data buffer and the total
-- amount of data present in it (in @bufSize@ units), and runs in
-- 'ShaMonad'.
type ShaBufferAction bufSize h = Pointer -> bufSize -> ShaMonad h ()

-- | Lifts the raw compressor to a buffer action.
--
-- The block count is converted to an 'Int' with 'fromEnum'. The compressor
-- then runs on the pointer to the hash cell ('hashCell') of the hash
-- memory. This function does not update the lengths.
liftCompressor :: IsSha h => Compressor h -> ShaBufferAction (BLOCKS h) h
liftCompressor comp ptr =
  onSubMemory hashCell . withCellPointer . comp ptr . fromEnum


-- | The combinator `shaBlocks` takes an input compressor @comp@ and gives a
-- buffer action that processes blocks of data. It first runs @comp@ on the
-- blocks, then adds the number of blocks to the recorded message length
-- with 'updateLength'.
shaBlocks :: Primitive h
          => ShaBufferAction (BLOCKS h) h -- ^ the compressor function
          -> ShaBufferAction (BLOCKS h) h
shaBlocks comp ptr nblocks = comp ptr nblocks >> updateLength nblocks

-- | The combinator `shaFinal` takes an input compressor @comp@ and gives
-- the buffer action for the final chunk of data. It proceeds in steps:
--
--   * adds @msgLen@ to the recorded message length,
--
--   * reads back the total length in bits ('extractLength'),
--
--   * writes the padding (the @0x80@ byte, a glue of zeros, and the
--     length encoded by @lenW@) just after the @msgLen@ message bytes
--     at @ptr@,
--
--   * compresses the resulting whole blocks with @comp@.
shaFinal :: (Primitive h, Storable h)
         => ShaBufferAction (BLOCKS h) h   -- ^ the raw compressor
         -> LengthWrite h                  -- ^ the length writer
         -> ShaBufferAction (BYTES Int) h
shaFinal comp lenW ptr msgLen = do
  updateLength msgLen
  totalBits <- extractLength
  -- 'shaPad' takes a hash value only to fix the hash type, which is already
  -- determined by @lenW totalBits :: ShaWrite h@. The placeholder is not
  -- meant to be evaluated (NOTE(review): relies on 'blocksOf' not forcing
  -- its second argument). A named 'error' replaces a bare 'undefined' so an
  -- accidental evaluation is easy to diagnose.
  let pad    = shaPad
                 (error "Raaz.Hash.Sha.Util.shaFinal: hash type placeholder evaluated")
                 msgLen
                 (lenW totalBits)
      -- The padding ends on a block boundary, so its byte count is a
      -- whole number of blocks.
      blocks = atMost $ bytesToWrite pad
  unsafeWrite pad ptr
  comp ptr blocks


-- | Padding is the message, followed by a single 1 bit (written as the
-- byte @0x80@), a glue of zeros, and finally the given length write. The
-- glue makes the whole write align to the block boundary.
--
-- The message bytes themselves are skipped ('skipWrite'), not rewritten.
-- The first argument only fixes the hash type: it is passed to 'blocksOf'
-- to obtain a one-block boundary and is otherwise unused.
shaPad :: IsSha h
       => h           -- ^ Hash-type witness
       -> BYTES Int   -- ^ Message length
       -> ShaWrite h  -- ^ Length write, placed after the zero glue
       -> ShaWrite h
shaPad h msgLen = glueWrites 0 boundary hdr
  where skipMessage = skipWrite msgLen
        oneBit      = writeStorable (0x80 :: Word8)
        hdr         = skipMessage <> oneBit
        boundary    = blocksOf 1 h



-- | Creates an implementation for a sha hash given the compressor and
-- the length writer.
--
-- Full blocks are handled by 'shaBlocks' and the final chunk by 'shaFinal'.
-- Both are built on the raw compressor lifted with 'liftCompressor'. The
-- start of the buffer is required to have 'wordAlignment'.
shaImplementation :: IsSha h
                  => String          -- ^ Name
                  -> String          -- ^ Description
                  -> Compressor  h   -- ^ Raw block compressor
                  -> LengthWrite h   -- ^ Encoder for the message length
                  -> HashI h (HashMemory h)
shaImplementation nam des comp lenW
  = HashI { hashIName              = nam
          , hashIDescription       = des
          , compress               = shaBlocks shaComp
          , compressFinal          = shaFinal shaComp lenW
          , compressStartAlignment = wordAlignment
          }
  where shaComp = liftCompressor comp

{-# INLINE shaImplementation #-}