{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      :  Crypto.Ecdsa.Signature
-- Copyright   :  Aleksandr Krupenkin 2016-2021
-- License     :  Apache-2.0
--
-- Maintainer  :  mail@akru.me
-- Stability   :  experimental
-- Portability :  portable
--
-- Recoverable ECC signature support.
--

module Crypto.Ecdsa.Signature
    (
      sign
    , pack
    , unpack
    ) where

import           Control.Monad               (when)
import           Crypto.Hash                 (SHA256)
import           Crypto.Number.Generate      (generateBetween)
import           Crypto.Number.ModArithmetic (inverse)
import           Crypto.Number.Serialize     (i2osp, i2ospOf_, os2ip)
import           Crypto.PubKey.ECC.ECDSA     (PrivateKey (..))
import           Crypto.PubKey.ECC.Prim      (pointMul)
import           Crypto.PubKey.ECC.Types     (CurveCommon (ecc_g, ecc_n),
                                              Point (..), common_curve)
import           Crypto.Random               (MonadRandom, withDRG)
import           Crypto.Random.HmacDrbg      (HmacDrbg, initialize)
import           Data.Bits                   (xor, (.|.))
import           Data.ByteArray              (ByteArray, ByteArrayAccess, Bytes,
                                              convert, singleton, takeView,
                                              view)
import qualified Data.ByteArray              as BA (unpack)
import           Data.Word                   (Word8)

import           Crypto.Ecdsa.Utils          (exportKey)

-- | Sign arbitrary data by given private key.
--
-- Only the first 32 bytes of the input are used as the message
-- representative. The nonce generator is an HMAC-DRBG over SHA-256,
-- seeded with the exported private key followed by those 32 bytes and
-- run purely via 'withDRG', so a given key and input always produce the
-- same signature.
--
-- Returns @(r, s, v)@ where @v@ is the recovery byte (27 plus the
-- recovery id).
--
-- /WARNING:/ Vulnerable to timing attacks.
sign :: ByteArrayAccess bin
     => PrivateKey
     -> bin
     -> (Integer, Integer, Word8)
sign pk bin =
    let msg32 = convert (takeView bin 32) :: Bytes
        seed  = exportKey pk <> msg32 :: Bytes
        drbg  = initialize seed :: HmacDrbg SHA256
        -- The final generator state is not needed; keep only the result.
        (sig, _drbg') = withDRG drbg (ecsign pk (os2ip msg32))
    in sig

-- | Compute a recoverable ECDSA signature of the message representative
-- @z@ with the private scalar @d@ of the given key.
--
-- A nonce is drawn from @[0, n-1]@ via 'generateBetween'. If that nonce
-- yields a degenerate result, a fresh nonce is drawn and the attempt is
-- repeated. Degenerate results are: the point at infinity, a nonce with
-- no inverse modulo @n@, @r == 0@ or @s == 0@.
--
-- The result is @(r, s, v)@:
--
-- * @s@ is normalised to at most @n/2@. When it is replaced by @n - s@,
--   the y-parity bit of the recovery id is flipped.
-- * @v = 27 + recovery id@. Bit 0 of the recovery id is the parity of
--   the nonce point's y coordinate. Bit 1 is set when its x coordinate
--   differs from @r@, i.e. was not already reduced modulo @n@.
ecsign :: MonadRandom m
       => PrivateKey
       -> Integer
       -> m (Integer, Integer, Word8)
ecsign pk@(PrivateKey curve d) z =
    generateBetween 0 (order - 1) >>= maybe (ecsign pk z) return . attempt
  where
    params = common_curve curve
    order  = ecc_n params
    base   = ecc_g params

    -- Affine coordinates of a point; the point at infinity has none.
    coords PointO      = Nothing
    coords (Point x y) = Just (x, y)

    -- One signing attempt with nonce k; Nothing asks for a new nonce.
    attempt k = do
        (px, py) <- coords (pointMul curve k base)
        let r = px `mod` order
        kInv <- inverse k order
        let s = kInv * (z + r * d) `mod` order
        when (r == 0 || s == 0) Nothing
        let recId = fromIntegral (fromEnum (odd py))
                .|. (if px /= r then 2 else 0)
        -- Prefer the low-s form; negating s flips the y-parity bit.
        return $ if s > order `div` 2
            then (r, order - s, (recId `xor` 1) + 27)
            else (r, s, recId + 27)

-- | Unpack recoverable signature from byte array.
--
-- The input is read as @v || r || s@:
--
-- * byte 0 is the recovery byte @v@;
-- * bytes 1..32 hold the big-endian @r@;
-- * bytes 33..64 hold the big-endian @s@.
--
-- The input should therefore be 65 bytes long. The length is not
-- validated, and any bytes beyond offset 64 are ignored.
--
-- Calls 'error' on empty input, since there is no @v@ byte to read.
--
-- NOTE(review): this layout puts @v@ first, whereas 'pack' writes @v@
-- last (@r || s || v@), so @unpack . pack@ is not the identity. Confirm
-- which layout callers expect.
unpack :: ByteArrayAccess rsv => rsv -> (Integer, Integer, Word8)
unpack vrs = (r, s, v)
  where
    -- 'view' takes an offset and a /size/, not an end index: each field
    -- is exactly 32 bytes.
    r = os2ip (view vrs 1 32)
    s = os2ip (view vrs 33 32)
    v = case BA.unpack vrs of
        (b : _) -> b
        []      -> error "Crypto.Ecdsa.Signature.unpack: empty input"

-- | Pack recoverable signature as byte array (65 byte length).
--
-- The layout is @r || s || v@. @r@ and @s@ are each encoded as exactly
-- 32 big-endian bytes, left-padded with zeros, followed by the single
-- byte @v@.
--
-- The fixed width matters. 'i2osp' emits the minimal number of bytes, so
-- a value with a leading zero byte would shorten the output below 65
-- bytes and shift the @s@ and @v@ fields.
--
-- Calls 'error' if @r@ or @s@ does not fit in 32 bytes. This should not
-- occur for values reduced modulo a group order below 2^256 (as produced
-- by 'sign' on a 256-bit curve).
pack :: ByteArray rsv => (Integer, Integer, Word8) -> rsv
pack (r, s, v) = i2ospOf_ 32 r <> i2ospOf_ 32 s <> singleton v