{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE MultiParamTypeClasses      #-}

module XChaCha20.Implementation where

import           Raaz.Core
import           Raaz.Primitive.ChaCha20.Internal


import qualified Implementation as Base

-- | Name of this implementation: the base ChaCha20 implementation's name
-- with an @x@ prepended.
name :: String
name = concat ["x", Base.name]

-- | Description of this implementation: the base implementation's
-- description followed by a note that this is the XChaCha variant.
description :: String
description = concat [Base.description, " This is the XChaCha variant."]

-- | The primitive implemented by this module.
type Prim = XChaCha20

-- | Memory used by the XChaCha20 implementation.
data Internals = XChaCha20Mem
  { copyOfKey         :: MemoryCell (Key ChaCha20)
    -- ^ Copy of the key as initialised by the user. Initialising with a
    -- 'Nounce' copies this key into 'chacha20Internals' before running the
    -- XChaCha20 setup on it.
  , chacha20Internals :: Base.Internals
    -- ^ Memory of the underlying ChaCha20 implementation; 'processBlocks'
    -- and 'processLast' delegate to the base implementation on this memory.
  }

-- | Buffer alignment, inherited unchanged from the base implementation.
type BufferAlignment = Base.BufferAlignment

-- | Aligned pointer to a buffer of 'Prim' blocks.
type BufferPtr = AlignedBlockPtr BufferAlignment Prim

-- | Allocates the key copy and the ChaCha20 memory, in field order. The
-- pointer exposed by 'unsafeToPointer' is that of the key copy.
instance Memory Internals where
  memoryAlloc = pure XChaCha20Mem <*> memoryAlloc <*> memoryAlloc
  unsafeToPointer mem = unsafeToPointer (copyOfKey mem)


-- | Setting the key only stores it in 'copyOfKey'; the ChaCha20 memory is
-- not modified here.
instance Initialisable Internals (Key XChaCha20) where
  initialise xkey mem = initialise xkey (copyOfKey mem)

-- | Write access is delegated entirely to the key copy 'copyOfKey'.
instance WriteAccessible Internals where
  writeAccess mem          = writeAccess (copyOfKey mem)
  afterWriteAdjustment mem = afterWriteAdjustment (copyOfKey mem)

-- | Initialising with a nounce first copies the stored key ('copyOfKey')
-- into the ChaCha20 memory, then runs the base implementation's XChaCha20
-- setup with the nounce on that same memory. The order matters: the setup
-- operates on the freshly copied key.
instance Initialisable Internals (Nounce XChaCha20) where
  initialise xnounce imem = do
      Base.copyKey (destination chachaMem) (source keyCell)
      Base.xchacha20Setup xnounce chachaMem
    where chachaMem = chacha20Internals imem
          keyCell   = copyOfKey imem

-- | Setting the block count forwards to the ChaCha20 memory. The count is
-- re-typed from 'XChaCha20' to 'ChaCha20' blocks through its 'Int' value.
instance Initialisable Internals (BlockCount XChaCha20) where
  initialise bcount mem = initialise asChaCha20 (chacha20Internals mem)
    where asChaCha20 :: BlockCount ChaCha20
          asChaCha20 = toEnum (fromEnum bcount)

-- | Reads the block count from the ChaCha20 memory and re-types it as an
-- 'XChaCha20' block count through its 'Int' value.
instance Extractable Internals (BlockCount XChaCha20) where
  extract mem = fromChaCha20 <$> extract (chacha20Internals mem)
    where fromChaCha20 :: BlockCount ChaCha20 -> BlockCount XChaCha20
          fromChaCha20 = toEnum . fromEnum

-- | The base implementation's 'Base.additionalBlocks', re-expressed as a
-- block count of 'XChaCha20' through its 'Int' value.
additionalBlocks :: BlockCount XChaCha20
additionalBlocks = toEnum (fromEnum baseBlocks)
  where baseBlocks :: BlockCount Base.Prim
        baseBlocks = Base.additionalBlocks


-- | Process @bcount@ blocks of the buffer by delegating to
-- 'Base.processBlocks' on the ChaCha20 memory. Only the types change: the
-- buffer pointer is cast and the block count is converted through its 'Int'
-- value to the base primitive.
processBlocks :: BufferPtr
              -> BlockCount Prim
              -> Internals
              -> IO ()
processBlocks buf bcount mem =
  Base.processBlocks (castPointer buf) baseCount (chacha20Internals mem)
  where baseCount :: BlockCount Base.Prim
        baseCount = toEnum (fromEnum bcount)

-- | Process the last @nbytes@ bytes in the buffer by delegating to
-- 'Base.processLast' on the ChaCha20 memory; the buffer pointer is only
-- cast to the base implementation's pointer type.
processLast :: BufferPtr
            -> BYTES Int
            -> Internals
            -> IO ()
processLast buf nbytes mem =
  Base.processLast (castPointer buf) nbytes (chacha20Internals mem)