{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- SPDX-FileCopyrightText: 2020 Serokell
--
-- SPDX-License-Identifier: MPL-2.0

-- | Internals of @crypto_stream@.
module NaCl.Stream.Internal
  ( Key
  , toKey

  , Nonce
  , toNonce

  , MaxStreamSize
  , generate

  , xor
  ) where

import Prelude hiding (length)

import Data.ByteArray (ByteArray, ByteArrayAccess, allocRet, length, withByteArray)
import Data.ByteArray.Sized (ByteArrayN, SizedByteArray, sizedByteArray)
import Data.Proxy (Proxy (Proxy))
import GHC.TypeLits (type (<=), natVal)

import qualified Data.ByteArray.Sized as Sized (allocRet)
import qualified Libsodium as Na


-- | Encryption key that can be used for Stream.
--
-- This type is parametrised by the actual data type that contains
-- bytes. This can be, for example, a @ByteString@, but, since this
-- is a secret key, it is better to use @ScrubbedBytes@.
--
-- The length is fixed at the type level to @CRYPTO_STREAM_KEYBYTES@;
-- use 'toKey' to obtain a value of this type from an unsized array.
type Key a = SizedByteArray Na.CRYPTO_STREAM_KEYBYTES a

-- | Make a 'Key' from an arbitrary byte array.
--
-- The result is @Just@ exactly when the input length equals the key
-- size required by @crypto_stream@ (@CRYPTO_STREAM_KEYBYTES@), and
-- @Nothing@ for any other length.
toKey :: ByteArrayAccess ba => ba -> Maybe (Key ba)
toKey raw = sizedByteArray raw

-- | Nonce that can be used for Stream.
--
-- This type is parametrised by the actual data type that contains
-- bytes. This can be, for example, a @ByteString@.
--
-- The length is fixed at the type level to @CRYPTO_STREAM_NONCEBYTES@;
-- use 'toNonce' to obtain a value of this type from an unsized array.
type Nonce a = SizedByteArray Na.CRYPTO_STREAM_NONCEBYTES a

-- | Make a 'Nonce' from an arbitrary byte array.
--
-- The result is @Just@ exactly when the input length equals the nonce
-- size required by @crypto_stream@ (@CRYPTO_STREAM_NONCEBYTES@), and
-- @Nothing@ for any other length.
toNonce :: ByteArrayAccess ba => ba -> Maybe (Nonce ba)
toNonce raw = sizedByteArray raw

-- | The maximum size of the stream produced by 'generate'.
type MaxStreamSize = 18446744073709551615  -- = 2^64 - 1, the largest length representable by the CULLong length argument of crypto_stream

-- | Generate a stream of pseudo-random bytes.
--
-- The number of bytes produced is the type-level size @n@ of the result
-- type @ct@, which must not exceed 'MaxStreamSize'. A buffer of exactly
-- that size is allocated and filled by libsodium's @crypto_stream@,
-- keyed by the given key and nonce.
generate
  ::  forall key nonce n ct.
      ( ByteArrayAccess key, ByteArrayAccess nonce
      , ByteArrayN n ct
      , n <= MaxStreamSize
      )
  => Key key  -- ^ Secret key
  -> Nonce nonce  -- ^ Nonce
  -> IO ct
generate key nonce =
    -- crypto_stream's status code can only be 0, so it is dropped and
    -- only the filled buffer is returned.
    fmap snd $ Sized.allocRet outSize $ \outPtr ->
      withByteArray key $ \kPtr ->
        withByteArray nonce $ \nPtr ->
          Na.crypto_stream outPtr outLen nPtr kPtr
  where
    -- Type-level size of the output buffer.
    outSize :: Proxy n
    outSize = Proxy

    -- The same @n@, as a value, tells libsodium how many bytes to write.
    outLen = fromIntegral (natVal outSize)


-- | Encrypt/decrypt a message.
--
-- The input is combined with the key stream by libsodium's
-- @crypto_stream_xor@, so the same call turns plaintext into ciphertext
-- and ciphertext back into plaintext (given the same key and nonce).
-- The output always has the same length as the input.
xor
  ::  ( ByteArrayAccess key, ByteArrayAccess nonce
      , ByteArrayAccess pt, ByteArray ct
      )
  => Key key  -- ^ Secret key
  -> Nonce nonce  -- ^ Nonce
  -> pt -- ^ Input (plain/cipher) text
  -> IO ct
xor key nonce msg = do
    (_ret, ct) <-
      allocRet clen $ \ctPtr ->
      withByteArray key $ \keyPtr ->
      withByteArray nonce $ \noncePtr ->
      withByteArray msg $ \msgPtr ->
        Na.crypto_stream_xor ctPtr
          msgPtr (fromIntegral clen)
          noncePtr
          keyPtr
    -- _ret can be only 0, so we don't check it
    pure ct
  where
    -- Length of the input. The same binding sizes the output buffer and
    -- is passed to libsodium as the message length, so the two values
    -- cannot disagree.
    clen :: Int
    clen = length msg