-- SPDX-FileCopyrightText: 2020 Serokell
--
-- SPDX-License-Identifier: MPL-2.0

-- | Scalar multiplication in a group.
--
-- This is @crypto_scalarmult_*@ from NaCl.
--
-- Note that this primitive is designed to only make the /Computational Diffie–Hellman/
-- problem hard. It makes no promises about other assumptions, therefore it is
-- the user’s responsibility to hash the output if required for the security
-- of the specific application.
module NaCl.Scalarmult
  ( Point (..)
  , toPoint
  , Scalar (..)
  , toScalar

  , mult
  , multBase
  ) where

import Data.ByteArray (ByteArray, ByteArrayAccess, withByteArray)
import Data.ByteArray.Sized (ByteArrayN, SizedByteArray, allocRet, sizedByteArray)
import Data.Proxy (Proxy (Proxy))
import System.IO.Unsafe (unsafePerformIO)

import qualified Libsodium as Na


-- | Point in the group.
--
-- This type is parametrised by the actual data type that contains
-- bytes. This can be, for example, a @ByteString@.
--
-- The length of the underlying bytes is fixed at the type level to
-- @Na.CRYPTO_SCALARMULT_BYTES@ by the wrapped 'SizedByteArray'. Use 'toPoint'
-- to construct a value from unsized bytes.
newtype Point a = Point (SizedByteArray Na.CRYPTO_SCALARMULT_BYTES a)
  deriving
    ( ByteArrayAccess, ByteArrayN Na.CRYPTO_SCALARMULT_BYTES
    , Eq, Ord, Show
    )

-- | Convert bytes to a group point.
--
-- Returns 'Nothing' when 'sizedByteArray' rejects the input, i.e. when its
-- length is not exactly @Na.CRYPTO_SCALARMULT_BYTES@.
toPoint :: ByteArrayAccess bytes => bytes -> Maybe (Point bytes)
toPoint = fmap Point . sizedByteArray

-- | Scalar that can be used for group multiplication.
--
-- This type is parametrised by the actual data type that contains
-- bytes. This can be, for example, a @ByteString@.
--
-- The length of the underlying bytes is fixed at the type level to
-- @Na.CRYPTO_SCALARMULT_SCALARBYTES@ by the wrapped 'SizedByteArray'. Use
-- 'toScalar' to construct a value from unsized bytes.
--
-- NOTE(review): 'Eq' / 'Ord' / 'Show' are delegated to the underlying byte
-- container. They are presumably not constant-time and may print the raw
-- bytes. Confirm before using them on secret scalars.
newtype Scalar a = Scalar (SizedByteArray Na.CRYPTO_SCALARMULT_SCALARBYTES a)
  deriving
    ( ByteArrayAccess, ByteArrayN Na.CRYPTO_SCALARMULT_SCALARBYTES
    , Eq, Ord, Show
    )

-- | Convert bytes to a scalar.
--
-- Returns 'Nothing' when 'sizedByteArray' rejects the input, i.e. when its
-- length is not exactly @Na.CRYPTO_SCALARMULT_SCALARBYTES@.
toScalar :: ByteArrayAccess bytes => bytes -> Maybe (Scalar bytes)
toScalar = fmap Scalar . sizedByteArray

-- | Multiply a group point by an integer.
--
-- Note that this function is slightly different from the corresponding function
-- in NaCl. Namely, unlike @crypto_scalarmult@ in NaCl, this one will return
-- @Nothing@ if:
--
-- * either the group point has a small order (1, 2, 4, or 8)
-- * or the result of the multiplication is the identity point.
--
-- This is how it is implemented in libsodium.
--
-- Concretely, 'Nothing' is returned whenever @crypto_scalarmult@ reports a
-- non-zero status.
mult
  :: forall outBytes pointBytes scalarBytes.
     ( ByteArrayAccess pointBytes
     , ByteArrayAccess scalarBytes
     , ByteArray outBytes
     )
  => Point pointBytes  -- ^ Group point.
  -> Scalar scalarBytes  -- ^ Scalar.
  -> Maybe (Point outBytes)
mult point scalar = unsafePerformIO $ do
    -- 'unsafePerformIO' is used so that this can be a pure function. The
    -- foreign call reads only the two input buffers, and it writes only into
    -- the output buffer that 'allocRet' freshly allocates here.
    (ret, out) <-
      allocRet (Proxy @Na.CRYPTO_SCALARMULT_BYTES) $ \outPtr ->
      withByteArray point $ \pointPtr ->
      withByteArray scalar $ \scalarPtr ->
        -- libsodium argument order: output, scalar, point.
        Na.crypto_scalarmult outPtr scalarPtr pointPtr
    pure $ if ret == 0 then Just out else Nothing

-- | Multiply the standard group point by an integer.
--
-- Unlike 'mult', this always produces a point.
multBase
  :: forall outBytes scalarBytes.
     ( ByteArrayAccess scalarBytes
     , ByteArray outBytes
     )
  => Scalar scalarBytes  -- ^ Scalar.
  -> Point outBytes
multBase scalar = unsafePerformIO $
    -- 'unsafePerformIO' is used so that this can be a pure function. The
    -- foreign call reads only the scalar buffer, and it writes only into the
    -- output buffer that 'allocRet' freshly allocates here.
    --
    -- The return code of @crypto_scalarmult_base@ can be only 0, so we don’t
    -- check it: 'snd' keeps just the output point.
    snd <$> allocRet (Proxy @Na.CRYPTO_SCALARMULT_BYTES) (\outPtr ->
      withByteArray scalar $ \scalarPtr ->
        Na.crypto_scalarmult_base outPtr scalarPtr)