{-# LANGUAGE DataKinds         #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- |
--
-- Module      : Raaz.Primitive.HashMemory
-- Description : Memory elements for typical hashes.
-- Copyright   : (c) Piyush P Kurur, 2019
-- License     : Apache-2.0 OR BSD-3-Clause
-- Maintainer  : Piyush P Kurur <ppk@iitpkd.ac.in>
-- Stability   : experimental
--
module Raaz.Primitive.HashMemory
       ( HashMemory128, HashMemory64
       , hashCellPointer, hashCell128Pointer
       , lengthCellPointer, uLengthCellPointer, lLengthCellPointer
       , getLength, getULength, getLLength
       , updateLength, updateLength128
       ) where


import Foreign.Storable           ( Storable(..)  )
import Raaz.Core

-- | Similar to 'HashMemory128' but keeps track of length as a 64-bit quantity.
data HashMemory64 h = HashMemory64
  { hashCell   :: MemoryCell h
    -- ^ Cell holding the hash value.
  , lengthCell :: MemoryCell (BYTES Word64)
    -- ^ Cell holding the length as a single 64-bit byte count.
  }

-- | Memory element that keeps track of a hash and the total bytes
-- processed (as a 128 bit quantity). Such a memory element is useful
-- for building the memory element for cryptographic hashes.
data HashMemory128 h = HashMemory128
  { hashCell128 :: MemoryCell h
    -- ^ Cell holding the hash value.
  , uLengthCell :: MemoryCell (BYTES Word64)
    -- ^ Higher order 64 bits of the length.
  , lLengthCell :: MemoryCell (BYTES Word64)
    -- ^ Lower order 64 bits of the length.
  }



-- | Get the length stored in the 64-bit length cell.
getLength :: HashMemory64 h -> IO (BYTES Word64)
getLength = extract . lengthCell

-- | Get the higher order 64-bits of the length.
getULength :: HashMemory128 h -> IO (BYTES Word64)
getULength = extract . uLengthCell

-- | Get the lower order 64-bits of the length.
getLLength :: HashMemory128 h -> IO (BYTES Word64)
getLLength = extract . lLengthCell


-- | Get the pointer to the hash stored in a 'HashMemory64'.
hashCellPointer :: Storable h
                => HashMemory64 h
                -> Ptr h
hashCellPointer = unsafeGetCellPointer . hashCell
-- | Get the pointer to the hash cell (which stores the digest) of a
-- 'HashMemory128'.
hashCell128Pointer :: Storable h
                   => HashMemory128 h
                   -> Ptr h
hashCell128Pointer = unsafeGetCellPointer . hashCell128



-- | Get the pointer to the 64-bit length cell of a 'HashMemory64'.
-- Unlike 'HashMemory128', there is a single length cell here, not an
-- upper and a lower half.
lengthCellPointer :: Storable h
                  => HashMemory64 h
                  -> Ptr (BYTES Word64)
lengthCellPointer = unsafeGetCellPointer . lengthCell

-- | Get the pointer to the upper half of the length bytes.
uLengthCellPointer :: Storable h
                   => HashMemory128 h
                   -> Ptr (BYTES Word64)
uLengthCellPointer = unsafeGetCellPointer . uLengthCell

-- | Get the pointer to the lower half of the length bytes.
lLengthCellPointer :: Storable h
                   => HashMemory128 h
                   -> Ptr (BYTES Word64)
lLengthCellPointer = unsafeGetCellPointer . lLengthCell


-- | Update the 128 bit length stored in the hash memory.
--
-- The byte count of @len@ is added to the lower 64-bit cell. When that
-- addition goes past 'maxBound', the upper cell is incremented by one
-- to carry.
updateLength128 :: LengthUnit len
                => len
                -> HashMemory128 h
                -> IO ()
updateLength128 len hmem =
  do l <- getLLength hmem
     initialise (l + lenBytes) $ lLengthCell hmem
     -- @l + lenBytes@ exceeds 'maxBound' exactly when
     -- @l > maxBound - lenBytes@; the test uses the value of @l@ read
     -- before the lower cell was overwritten.
     when (l > maxBound - lenBytes) $
       modifyMem (+(1 :: BYTES Word64)) $ uLengthCell hmem
  where lenBytes :: BYTES Word64
        lenBytes = fromIntegral $ inBytes len

-- | Update the 64-bit length stored in the hash memory by adding the
-- byte count of @len@ to the length cell.
updateLength :: LengthUnit len
             => len
             -> HashMemory64 h
             -> IO ()
updateLength len = modifyMem (+lenBytes) . lengthCell
  where lenBytes :: BYTES Word64
        lenBytes = fromIntegral $ inBytes len

-- | Allocation combines three cells, in order: the hash cell, the upper
-- length cell and the lower length cell. The underlying pointer of the
-- whole element is taken from the hash cell.
instance Storable h  => Memory (HashMemory128 h) where
  memoryAlloc     = HashMemory128 <$> memoryAlloc <*> memoryAlloc <*> memoryAlloc
  -- NOTE(review): presumably the hash cell, being allocated first,
  -- sits at the start of the element's buffer — confirm against the
  -- semantics of 'Alloc'.
  unsafeToPointer = unsafeToPointer . hashCell128

-- | Allocation combines two cells, in order: the hash cell and the
-- length cell. The underlying pointer of the whole element is taken
-- from the hash cell.
instance Storable h  => Memory (HashMemory64 h) where
  memoryAlloc     = HashMemory64 <$> memoryAlloc <*> memoryAlloc
  -- NOTE(review): presumably the hash cell, being allocated first,
  -- sits at the start of the element's buffer — confirm against the
  -- semantics of 'Alloc'.
  unsafeToPointer = unsafeToPointer . hashCell

-- | Initialising with a hash value stores it in the hash cell and
-- resets both halves of the 128-bit length to zero.
instance Storable h => Initialisable (HashMemory128 h) h where
  initialise h hmem = do initialise h $ hashCell128 hmem
                         initialise (0 :: BYTES Word64) $ uLengthCell hmem
                         initialise (0 :: BYTES Word64) $ lLengthCell hmem


-- | Initialising with a hash value stores it in the hash cell and
-- resets the 64-bit length to zero.
instance Storable h => Initialisable (HashMemory64 h) h where
  initialise h hmem = do initialise h $ hashCell hmem
                         initialise (0 :: BYTES Word64) $ lengthCell hmem


-- | Extraction reads the value in the hash cell; the length cells are
-- not consulted.
instance Storable h => Extractable (HashMemory128 h) h where
  extract = extract . hashCell128

-- | Extraction reads the value in the hash cell; the length cell is
-- not consulted.
instance Storable h => Extractable (HashMemory64 h) h where
  extract = extract . hashCell