{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE FlexibleInstances          #-}
-- | This module implements memory elements used for Poly1305
-- implementations.
module Poly1305.Memory
       ( Mem(..)
       , Element
       , elementToInteger
       , rKeyPtr
       , sKeyPtr
       , accumPtr
       ) where

import qualified Data.Vector.Unboxed as V
import           Foreign.Ptr                        ( castPtr )
import           Raaz.Core
import qualified Raaz.Core.Types.Internal        as TI
import           Raaz.Primitive.Poly1305.Internal

import           Raaz.Verse.Poly1305.C.Portable (verse_poly1305_c_portable_clamp)

-- | An element of the finite field GF(2¹³⁰ - 5). Such an element
-- needs 130 bits, which are stored as three 64-bit words, least
-- significant word first; the last word holds only the top 2 bits.
type Element = Tuple 3 Word64


-- | Convert the element to an integer.
--
-- The words are stored least significant first in radix 2^64, so the
-- value of @[w0, w1, w2]@ is @w0 + w1 * 2^64 + w2 * 2^128@.
elementToInteger :: Element -> Integer
elementToInteger = V.foldr fld 0 . unsafeToVector
  where fld :: Word64 -> Integer -> Integer
        -- Every limb is a full 64-bit word, so the more significant
        -- limbs must be shifted by 64 bits. A smaller shift would make
        -- adjacent limbs overlap.
        fld w i = toInteger w + (i `shiftL` 64)


-- | The memory associated with Poly1305. It stores the running
-- accumulator together with the two key fragments @r@ and @s@.
data Mem = Mem { accCell :: MemoryCell Element
                 -- ^ The accumulator.
               , rCell   :: MemoryCell R
                 -- ^ The key fragment r. It is clamped by 'initialise'
                 -- and 'afterWriteAdjustment' after being written.
               , sCell   :: MemoryCell S
                 -- ^ The key fragment s.
               }

-- | Reset the accumulator cell to the all-zero element.
clearAcc :: Mem -> IO ()
clearAcc mem = initialise zeroElement (accCell mem)
  where zeroElement :: Element
        zeroElement = unsafeFromList (replicate 3 0)

instance Memory Mem where
  -- Allocate the three cells in field order: accumulator, r, then s.
  memoryAlloc = fmap Mem memoryAlloc <*> memoryAlloc <*> memoryAlloc

  -- The pointer to the whole memory is the pointer to the accumulator
  -- cell. Because that cell is allocated first, this is presumably the
  -- start of the allocation.
  unsafeToPointer mem = unsafeToPointer (accCell mem)


-- | Pointer to the key fragment r, viewed as an array of two 64-bit
-- words.
rKeyPtr  :: Mem  -> Ptr (Tuple 2 Word64)
rKeyPtr mem = castPtr (unsafeGetCellPointer (rCell mem))

-- | Pointer to the key fragment s, viewed as an array of two 64-bit
-- words.
sKeyPtr  :: Mem -> Ptr (Tuple 2 Word64)
sKeyPtr mem = castPtr (unsafeGetCellPointer (sCell mem))

-- | Get the pointer to the accumulator array.
--
-- The cell already stores an 'Element', so its pointer has the correct
-- type as is. No 'castPtr' is used, which keeps the type checker able
-- to catch any future change to the type of 'accCell'.
accumPtr :: Mem -> Ptr Element
accumPtr = unsafeGetCellPointer . accCell

-- | Clamp the key fragment r stored at the given pointer, in place,
-- using the verse-generated portable C routine.
--
-- NOTE(review): the literal @1@ is presumably the number of
-- @Tuple 2 Word64@ elements to clamp; confirm against the verse C code.
clampPtr :: Ptr (Tuple 2 Word64) -> IO ()
clampPtr rPtr = verse_poly1305_c_portable_clamp rPtr 1

-- | Clamp the key fragment r held in the given memory.
clamp :: Mem -> IO ()
clamp mem = clampPtr (rKeyPtr mem)

instance Initialisable Mem (Key Poly1305) where
  -- Zero the accumulator, store both key fragments, then clamp r.
  -- Clamping must come after r has been written.
  initialise (Key rKey sKey) mem = do
    clearAcc mem
    initialise rKey (rCell mem)
    initialise sKey (sCell mem)
    clamp mem

instance Extractable Mem Poly1305 where
  -- Read the accumulator and keep only its first two 64-bit words (the
  -- low 128 bits), encoded little-endian, as the tag.
  -- NOTE(review): this assumes the accumulator already holds the final
  -- tag value when extract is called; that step is not visible here.
  extract mem = toPoly1305 <$> extract (accCell mem)
    where
      toPoly1305 :: Element -> Poly1305
      toPoly1305 el = Poly1305 (TI.map littleEndian (lowWords el))

      lowWords :: Element -> Tuple 2 Word64
      lowWords = initial

instance WriteAccessible Mem where
  -- Only the key fragments are exposed for direct writing; the
  -- accumulator is not.
  writeAccess mem = concat [ writeAccess (rCell mem)
                           , writeAccess (sCell mem)
                           ]

  -- After a direct write: reset the accumulator, let each key cell run
  -- its own adjustment, and finally clamp r.
  afterWriteAdjustment mem =
    clearAcc mem
      >> afterWriteAdjustment (rCell mem)
      >> afterWriteAdjustment (sCell mem)
      >> clamp mem