-- |
-- Module      : Crypto.PubKey.ECC.P256
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- P256 support
--
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# OPTIONS_GHC -fno-warn-unused-binds #-}
module Crypto.PubKey.ECC.P256
    ( Scalar
    , Point
    -- * Point arithmetic
    , pointBase
    , pointAdd
    , pointNegate
    , pointMul
    , pointDh
    , pointsMulVarTime
    , pointIsValid
    , pointIsAtInfinity
    , toPoint
    , pointX
    , pointToIntegers
    , pointFromIntegers
    , pointToBinary
    , pointFromBinary
    , unsafePointFromBinary
    -- * Scalar arithmetic
    , scalarGenerate
    , scalarZero
    , scalarN
    , scalarIsZero
    , scalarAdd
    , scalarSub
    , scalarMul
    , scalarInv
    , scalarInvSafe
    , scalarCmp
    , scalarFromBinary
    , scalarToBinary
    , scalarFromInteger
    , scalarToInteger
    ) where

import           Data.Word
import           Foreign.Ptr
import           Foreign.C.Types

import           Crypto.Internal.Compat
import           Crypto.Internal.Imports
import           Crypto.Internal.ByteArray
import qualified Crypto.Internal.ByteArray as B
import           Data.Memory.PtrMethods (memSet)
import           Crypto.Error
import           Crypto.Random
import           Crypto.Number.Serialize.Internal (os2ip, i2ospOf)
import qualified Crypto.Number.Serialize as S (os2ip, i2ospOf)

-- | A P256 scalar: 'scalarSize' bytes in the representation used by the
-- C backend.
--
-- The bytes are kept in 'ScrubbedBytes' because a scalar may be secret key
-- material.
newtype Scalar = Scalar ScrubbedBytes
    deriving (Eq, Show, NFData, ByteArrayAccess)

-- | A P256 point: the x coordinate followed by the y coordinate, each
-- 'scalarSize' bytes in the C backend's representation.
--
-- An all-zero buffer denotes the point at infinity (see
-- 'pointIsAtInfinity').
newtype Point = Point Bytes
    deriving (Eq, Show, NFData)

-- | Number of bytes in a serialized scalar (a 256-bit value), which is also
-- the size of each point coordinate.
scalarSize :: Int
scalarSize = 256 `div` 8

-- | Number of bytes in a serialized point: two coordinates of 'scalarSize'
-- bytes each, x first and then y.
pointSize :: Int
pointSize = 2 * scalarSize

-- | A single digit of the C backend's multi-precision integers
-- (presumably matching the C digit type; passed by value to modmul).
type P256Digit = Word32

-- Empty tag types used only to type the pointers crossing the FFI, so that
-- x coordinates, y coordinates and scalars cannot be mixed up.
data P256X
data P256Y
data P256Scalar

-- | The order @n@ of the P256 base point; scalar arithmetic is done
-- modulo this value.
order :: Integer
order = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551

------------------------------------------------------------------------
-- Point methods
------------------------------------------------------------------------

-- | The generator @G@ of the P256 curve, obtained by lifting the scalar 1
-- with 'toPoint'.
pointBase :: Point
pointBase = liftOne (scalarFromInteger 1)
  where
    liftOne (CryptoPassed one) = toPoint one
    liftOne (CryptoFailed _)   = error "pointBase: assumption failed"

-- | Lift a scalar onto the curve by multiplying the curve generator:
--
-- > scalar * G
--
-- This function is partial: it calls 'error' when the scalar is zero.
toPoint :: Scalar -> Point
toPoint s
    | scalarIsZero s = error "cannot create point from zero"
    | otherwise      = withNewPoint mulBase
  where
    mulBase px py = withScalar s $ \p -> ccrypton_p256_basepoint_mul p px py

-- | Add two curve points, @a + b@, writing the sum into a fresh 'Point'.
pointAdd :: Point -> Point -> Point
pointAdd a b = withNewPoint sumInto
  where
    sumInto dx dy =
        withPoint a $ \ax ay ->
        withPoint b $ \bx by ->
            ccrypton_p256e_point_add ax ay bx by dx dy

-- | The additive inverse @-a@ of a point.
pointNegate :: Point -> Point
pointNegate a = withNewPoint negateInto
  where
    negateInto dx dy = withPoint a $ \ax ay ->
        ccrypton_p256e_point_negate ax ay dx dy

-- | Multiply a point by a scalar:
--
-- > scalar * p
--
-- Warning: variable time.
pointMul :: Scalar -> Point -> Point
pointMul scalar p = withNewPoint multiplyInto
  where
    multiplyInto dx dy =
        withScalar scalar $ \n ->
        withPoint p $ \px py ->
            ccrypton_p256e_point_mul n px py dx dy

-- | Like 'pointMul', but returns only the x coordinate of the product,
-- serialized as 'scalarSize' bytes.
--
-- When the scalar is a multiple of the point order the result is all zero.
-- The full intermediate point lives only in scrubbed temporary memory.
pointDh :: ByteArray binary => Scalar -> Point -> binary
pointDh scalar p = B.unsafeCreate scalarSize writeSharedX
  where
    writeSharedX dst = withTempPoint $ \dx dy -> do
        withScalar scalar $ \n ->
            withPoint p $ \px py -> ccrypton_p256e_point_mul n px py dx dy
        ccrypton_p256_to_bin (castPtr dx) dst

-- | Multiply the point @p@ by @n2@ and add @n1@ times the curve generator:
--
-- > n1 * G + n2 * p
--
-- Warning: variable time.
pointsMulVarTime :: Scalar -> Scalar -> Point -> Point
pointsMulVarTime n1 n2 p = withNewPoint combineInto
  where
    combineInto dx dy =
        withScalar n1 $ \pn1 ->
        withScalar n2 $ \pn2 ->
        withPoint p $ \px py ->
            ccrypton_p256_points_mul_vartime pn1 pn2 px py dx dy

-- | Check whether a 'Point' is valid, as decided by the C routine
-- @crypton_p256_is_valid_point@ (a non-zero result means valid).
pointIsValid :: Point -> Bool
pointIsValid p = unsafeDoIO $ withPoint p $ \px py ->
    (/= 0) <$> ccrypton_p256_is_valid_point px py

-- | Check whether a 'Point' is the point at infinity, which is encoded as an
-- all-zero buffer. The zero test is constant time ('constAllZero').
pointIsAtInfinity :: Point -> Bool
pointIsAtInfinity point = case point of
    Point raw -> constAllZero raw

-- | The x coordinate as a 'Scalar', reduced with @crypton_p256_mod@ by the
-- group order @n@, or 'Nothing' for the point at infinity.
pointX :: Point -> Maybe Scalar
pointX p
    | pointIsAtInfinity p = Nothing
    | otherwise           = Just (withNewScalarFreeze reduceX)
  where
    reduceX d = withPoint p $ \px _ ->
        ccrypton_p256_mod ccrypton_SECP256r1_n (castPtr px) (castPtr d)

-- | Convert a point to its @(x, y)@ coordinates as 'Integer's.
--
-- Each coordinate is serialized into one shared 'scalarSize'-byte buffer and
-- decoded as a big-endian integer ('os2ip').
pointToIntegers :: Point -> (Integer, Integer)
pointToIntegers p = unsafeDoIO $ withPoint p $ \px py ->
    -- Size the buffer from 'scalarSize' (was a hard-coded 32) so it always
    -- matches the byte count that 'serialize' decodes.
    allocTemp scalarSize (serialize (castPtr px) (castPtr py))
  where
    serialize :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr Word8 -> IO (Integer, Integer)
    serialize px py temp = do
        x <- decode px
        y <- decode py
        return (x, y)
      where
        -- write one coordinate into temp, then read it back as an Integer
        decode c = do
            ccrypton_p256_to_bin c temp
            os2ip temp scalarSize

-- | Build a point from @(x, y)@ coordinates given as 'Integer's.
--
-- The result is not validated; use 'pointIsValid' for that. Calls 'error'
-- when 'i2ospOf' reports that nothing was written, which presumably happens
-- for negative coordinates or ones that do not fit in 'scalarSize' bytes.
pointFromIntegers :: (Integer, Integer) -> Point
pointFromIntegers (x, y) = withNewPoint $ \dx dy ->
    allocTemp scalarSize $ \temp -> do
        fill temp (castPtr dx) x
        fill temp (castPtr dy) y
  where
    -- Encode n big-endian into the zeroed temp buffer, then have the C
    -- backend convert those bytes into its representation at dest.
    fill :: Ptr Word8 -> Ptr P256Scalar -> Integer -> IO ()
    fill temp dest n = do
        memSet temp 0 scalarSize
        written <- i2ospOf n temp scalarSize
        if written == 0
            then error "pointFromIntegers: filling failed"
            else pure ()
        ccrypton_p256_from_bin temp dest

-- | Serialize a point as 'pointSize' bytes: the x coordinate followed by the
-- y coordinate, each written by @crypton_p256_to_bin@ in 'scalarSize' bytes.
pointToBinary :: ByteArray ba => Point -> ba
pointToBinary p = B.unsafeCreate pointSize $ \dst -> withPoint p $ \px py -> do
    ccrypton_p256_to_bin (castPtr px) dst
    -- y starts right after x. Use 'scalarSize' (was a hard-coded 32) so the
    -- layout agrees with 'unsafePointFromBinary', which reads y there.
    ccrypton_p256_to_bin (castPtr py) (dst `plusPtr` scalarSize)

-- | Decode a point from 'pointSize' bytes and check that it is valid.
--
-- Fails with 'CryptoError_PublicKeySizeInvalid' when the input length is
-- wrong and with 'CryptoError_PointCoordinatesInvalid' when 'pointIsValid'
-- rejects the decoded point.
pointFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Point
pointFromBinary ba = do
    p <- unsafePointFromBinary ba
    if pointIsValid p
        then CryptoPassed p
        else CryptoFailed CryptoError_PointCoordinatesInvalid

-- | Decode a point from 'pointSize' bytes (x, then y) without checking that
-- it is valid. Only a wrong input length is rejected
-- ('CryptoError_PublicKeySizeInvalid').
unsafePointFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Point
unsafePointFromBinary ba
    | B.length ba == pointSize = CryptoPassed (withNewPoint decode)
    | otherwise                = CryptoFailed CryptoError_PublicKeySizeInvalid
  where
    decode px py = B.withByteArray ba $ \src -> do
        ccrypton_p256_from_bin src (castPtr px)
        ccrypton_p256_from_bin (src `plusPtr` scalarSize) (castPtr py)

------------------------------------------------------------------------
-- Scalar methods
------------------------------------------------------------------------

-- | Generate a random scalar in the range @[1, n-1]@, where @n@ is the group
-- order.
--
-- 'scalarFromBinary' does not reduce its input modulo @n@, so a raw
-- 32-byte draw can be zero or @>= n@ (roughly a 2^-32 chance per draw).
-- Such candidates are rejected and redrawn, so the result is always a valid
-- private scalar.
scalarGenerate :: MonadRandom randomly => randomly Scalar
scalarGenerate = do
    candidate <- unwrap . scalarFromBinary . witness <$> getRandomBytes scalarSize
    if acceptable candidate
        then return candidate
        else scalarGenerate
  where
    -- accept only 0 < s < n
    acceptable s = not (scalarIsZero s) && scalarCmp s scalarN == LT
    unwrap :: CryptoFailable a -> a
    unwrap (CryptoFailed _) = error "scalarGenerate: assumption failed"
    unwrap (CryptoPassed s) = s
    -- pins the random buffer type to ScrubbedBytes so the raw secret bytes
    -- are held in scrubbed memory
    witness :: ScrubbedBytes -> ScrubbedBytes
    witness = id

-- | The scalar 0, produced by @crypton_p256_init@ on a fresh buffer.
scalarZero :: Scalar
scalarZero = withNewScalarFreeze ccrypton_p256_init

-- | The group order @n@ as a 'Scalar'. The value is stored unreduced
-- ('scalarFromBinary' does no modular reduction).
scalarN :: Scalar
scalarN = throwCryptoError $ scalarFromInteger order

-- | Check whether a scalar is 0 (a non-zero answer from
-- @crypton_p256_is_zero@ means it is).
scalarIsZero :: Scalar -> Bool
scalarIsZero s = unsafeDoIO (withScalar s isZero)
  where
    isZero d = (/= 0) <$> ccrypton_p256_is_zero d

-- | Add two scalars modulo the group order @n@:
--
-- > a + b
scalarAdd :: Scalar -> Scalar -> Scalar
scalarAdd a b = withNewScalarFreeze addInto
  where
    addInto d =
        withScalar a $ \pa ->
        withScalar b $ \pb ->
            ccrypton_p256e_modadd ccrypton_SECP256r1_n pa pb d

-- | Subtract two scalars modulo the group order @n@:
--
-- > a - b
scalarSub :: Scalar -> Scalar -> Scalar
scalarSub a b = withNewScalarFreeze subInto
  where
    subInto d =
        withScalar a $ \pa ->
        withScalar b $ \pb ->
            ccrypton_p256e_modsub ccrypton_SECP256r1_n pa pb d

-- | Multiply two scalars modulo the group order @n@:
--
-- > a * b
scalarMul :: Scalar -> Scalar -> Scalar
scalarMul a b = withNewScalarFreeze multiplyInto
  where
    -- The 0 between the operands is a 'P256Digit' handed to the C routine;
    -- presumably an extra high digit of the second factor, unused here.
    -- NOTE(review): confirm against the C signature of crypton_p256_modmul.
    multiplyInto d =
        withScalar a $ \pa ->
        withScalar b $ \pb ->
            ccrypton_p256_modmul ccrypton_SECP256r1_n pa 0 pb d

-- | The inverse of a scalar modulo the group order @n@:
--
-- > 1 / a
--
-- Warning: variable time. See 'scalarInvSafe' for the alternative backed by
-- @crypton_p256e_scalar_invert@.
scalarInv :: Scalar -> Scalar
scalarInv a = withNewScalarFreeze invertInto
  where
    invertInto out = withScalar a $ \pa ->
        ccrypton_p256_modinv_vartime ccrypton_SECP256r1_n pa out

-- | The inverse of a scalar, computed with safe exponentiation by
-- @crypton_p256e_scalar_invert@:
--
-- > 1 / a
scalarInvSafe :: Scalar -> Scalar
scalarInvSafe a = withNewScalarFreeze $ \out ->
    withScalar a (`ccrypton_p256e_scalar_invert` out)

-- | Compare two scalars. The sign of @crypton_p256_cmp@'s result is mapped
-- to an 'Ordering'.
scalarCmp :: Scalar -> Scalar -> Ordering
scalarCmp a b = unsafeDoIO $
    withScalar a $ \pa ->
    withScalar b $ \pb ->
        (`compare` 0) <$> ccrypton_p256_cmp pa pb

-- | Read a scalar from exactly 'scalarSize' big-endian bytes.
--
-- The value is handed to the backend as-is and is not reduced modulo the
-- group order. Any other length fails with
-- 'CryptoError_SecretKeySizeInvalid'.
scalarFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Scalar
scalarFromBinary ba
    | B.length ba == scalarSize = CryptoPassed (withNewScalarFreeze load)
    | otherwise                 = CryptoFailed CryptoError_SecretKeySizeInvalid
  where
    load p = B.withByteArray ba $ \b -> ccrypton_p256_from_bin b p
{-# NOINLINE scalarFromBinary #-}

-- | Write a scalar as 'scalarSize' big-endian bytes.
scalarToBinary :: ByteArray ba => Scalar -> ba
scalarToBinary s = B.unsafeCreate scalarSize $ \b ->
    withScalar s (flip ccrypton_p256_to_bin b)
{-# NOINLINE scalarToBinary #-}

-- | Convert an 'Integer' to a P256 'Scalar'.
--
-- Fails with 'CryptoError_SecretKeySizeInvalid' when the integer cannot be
-- encoded in 'scalarSize' bytes. The value is not reduced modulo @n@.
scalarFromInteger :: Integer -> CryptoFailable Scalar
scalarFromInteger i =
    case S.i2ospOf scalarSize i :: Maybe Bytes of
        Nothing    -> CryptoFailed CryptoError_SecretKeySizeInvalid
        Just bytes -> scalarFromBinary bytes

-- | Convert a P256 'Scalar' to an 'Integer' by decoding its big-endian
-- serialization.
scalarToInteger :: Scalar -> Integer
scalarToInteger s = S.os2ip bytes
  where
    bytes :: Bytes
    bytes = scalarToBinary s

------------------------------------------------------------------------
-- Memory Helpers
------------------------------------------------------------------------
-- | Allocate a 'pointSize'-byte buffer, let @f@ fill the x half (at the
-- start) and the y half (at 'pxToPy'), then freeze it into a 'Point'.
withNewPoint :: (Ptr P256X -> Ptr P256Y -> IO ()) -> Point
withNewPoint f = Point (B.unsafeCreate pointSize fill)
  where
    fill px = f px (pxToPy px)
{-# NOINLINE withNewPoint #-}

-- | Run an action with pointers to the x and y halves of a point's buffer.
withPoint :: Point -> (Ptr P256X -> Ptr P256Y -> IO a) -> IO a
withPoint (Point raw) f = B.withByteArray raw (\px -> f px (pxToPy px))

-- | Pointer to the y coordinate, which is stored 'scalarSize' bytes after
-- the x coordinate.
pxToPy :: Ptr P256X -> Ptr P256Y
pxToPy = castPtr . (`plusPtr` scalarSize)

-- | Allocate a 'scalarSize'-byte scrubbed buffer, let @fill@ initialize it,
-- and freeze it into a 'Scalar'.
withNewScalarFreeze :: (Ptr P256Scalar -> IO ()) -> Scalar
withNewScalarFreeze fill = Scalar frozen
  where
    frozen = B.allocAndFreeze scalarSize fill
{-# NOINLINE withNewScalarFreeze #-}

-- | Run an action with a temporary 'pointSize'-byte point buffer held in
-- scrubbed memory, passing pointers to its x and y halves.
withTempPoint :: (Ptr P256X -> Ptr P256Y -> IO a) -> IO a
withTempPoint f = allocTempScrubbed pointSize $ \buf ->
    f (castPtr buf) (pxToPy (castPtr buf))

-- | Run an action with a pointer to a scalar's underlying bytes.
withScalar :: Scalar -> (Ptr P256Scalar -> IO a) -> IO a
withScalar (Scalar raw) = B.withByteArray raw

-- | Run an action on a temporary @n@-byte 'Bytes' buffer and return only
-- the action's result; the buffer itself is discarded.
allocTemp :: Int -> (Ptr Word8 -> IO a) -> IO a
allocTemp n f = fmap keepResult (B.allocRet n f)
  where
    keepResult :: (b, Bytes) -> b
    keepResult (r, _) = r

-- | Like 'allocTemp', but the temporary buffer is 'ScrubbedBytes', for
-- intermediate values that may be secret.
allocTempScrubbed :: Int -> (Ptr Word8 -> IO a) -> IO a
allocTempScrubbed n f = fmap keepResult (B.allocRet n f)
  where
    keepResult :: (b, ScrubbedBytes) -> b
    keepResult (r, _) = r

------------------------------------------------------------------------
-- Foreign bindings
------------------------------------------------------------------------
-- Addresses of curve constants exported by the C side (note the "&" in the
-- import strings). Only the group order n is used in this module; p and b
-- (presumably the field prime and the curve coefficient) are bound but
-- unused, which is why -fno-warn-unused-binds is set at the top of the file.
foreign import ccall "&crypton_SECP256r1_n"
    ccrypton_SECP256r1_n :: Ptr P256Scalar
foreign import ccall "&crypton_SECP256r1_p"
    ccrypton_SECP256r1_p :: Ptr P256Scalar
foreign import ccall "&crypton_SECP256r1_b"
    ccrypton_SECP256r1_b :: Ptr P256Scalar

-- Multi-precision scalar primitives. In the modular operations below, the
-- modulus is passed first and the output pointer last (see their call sites
-- in scalarAdd/scalarSub/scalarMul/scalarInv/pointX). clear and add_d are
-- bound but not used in this module.
foreign import ccall "crypton_p256_init"
    ccrypton_p256_init :: Ptr P256Scalar -> IO ()
-- non-zero result means the scalar is zero
foreign import ccall "crypton_p256_is_zero"
    ccrypton_p256_is_zero :: Ptr P256Scalar -> IO CInt
foreign import ccall "crypton_p256_clear"
    ccrypton_p256_clear :: Ptr P256Scalar -> IO ()
-- (modulus, a, b, out)
foreign import ccall "crypton_p256e_modadd"
    ccrypton_p256e_modadd :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
foreign import ccall "crypton_p256_add_d"
    ccrypton_p256_add_d :: Ptr P256Scalar -> P256Digit -> Ptr P256Scalar -> IO CInt
-- (modulus, a, b, out)
foreign import ccall "crypton_p256e_modsub"
    ccrypton_p256e_modsub :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- the sign of the result orders a relative to b (see scalarCmp)
foreign import ccall "crypton_p256_cmp"
    ccrypton_p256_cmp :: Ptr P256Scalar -> Ptr P256Scalar -> IO CInt
-- (modulus, in, out)
foreign import ccall "crypton_p256_mod"
    ccrypton_p256_mod :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- (modulus, a, digit, b, out); scalarMul always passes 0 as the digit
foreign import ccall "crypton_p256_modmul"
    ccrypton_p256_modmul :: Ptr P256Scalar -> Ptr P256Scalar -> P256Digit -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- (in, out); backs scalarInvSafe
foreign import ccall "crypton_p256e_scalar_invert"
    ccrypton_p256e_scalar_invert :: Ptr P256Scalar -> Ptr P256Scalar -> IO ()
--foreign import ccall "crypton_p256_modinv"
--    ccrypton_p256_modinv :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- (modulus, in, out); variable time, backs scalarInv
foreign import ccall "crypton_p256_modinv_vartime"
    ccrypton_p256_modinv_vartime :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- compute (out_x,out_y) = n * G; note the C symbol is spelled base_point_mul
foreign import ccall "crypton_p256_base_point_mul"
    ccrypton_p256_basepoint_mul :: Ptr P256Scalar
                                   -> Ptr P256X -> Ptr P256Y
                                   -> IO ()

-- compute (out_x,out_y) = (a_x,a_y) + (b_x,b_y)
foreign import ccall "crypton_p256e_point_add"
    ccrypton_p256e_point_add :: Ptr P256X -> Ptr P256Y
                                -> Ptr P256X -> Ptr P256Y
                                -> Ptr P256X -> Ptr P256Y
                                -> IO ()

-- compute (out_x,out_y) = -(in_x,in_y)
foreign import ccall "crypton_p256e_point_negate"
    ccrypton_p256e_point_negate :: Ptr P256X -> Ptr P256Y
                                   -> Ptr P256X -> Ptr P256Y
                                   -> IO ()

-- compute (out_x,out_y) = n * (in_x,in_y)
foreign import ccall "crypton_p256e_point_mul"
    ccrypton_p256e_point_mul :: Ptr P256Scalar -- n
                                -> Ptr P256X -> Ptr P256Y -- in_{x,y}
                                -> Ptr P256X -> Ptr P256Y -- out_{x,y}
                                -> IO ()

-- compute (out_x,out_y) = n1 * G + n2 * (in_x,in_y)
foreign import ccall "crypton_p256_points_mul_vartime"
    ccrypton_p256_points_mul_vartime :: Ptr P256Scalar -- n1
                                        -> Ptr P256Scalar -- n2
                                        -> Ptr P256X -> Ptr P256Y -- in_{x,y}
                                        -> Ptr P256X -> Ptr P256Y -- out_{x,y}
                                        -> IO ()
-- non-zero result means the point is valid
foreign import ccall "crypton_p256_is_valid_point"
    ccrypton_p256_is_valid_point :: Ptr P256X -> Ptr P256Y -> IO CInt

-- serialize a backend integer to 32 big-endian bytes
foreign import ccall "crypton_p256_to_bin"
    ccrypton_p256_to_bin :: Ptr P256Scalar -> Ptr Word8 -> IO ()

-- inverse of crypton_p256_to_bin: read 32 bytes into a backend integer
foreign import ccall "crypton_p256_from_bin"
    ccrypton_p256_from_bin :: Ptr Word8 -> Ptr P256Scalar -> IO ()