-- |
-- Module      : Crypto.PubKey.ECC.DH
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Elliptic-curve Diffie-Hellman (ECDH) key agreement.
--
module Crypto.PubKey.ECC.DH
    (
      Curve
    , PublicPoint
    , PrivateNumber
    , SharedKey(..)
    , generatePrivate
    , calculatePublic
    , getShared
    ) where

import Crypto.Number.Generate (generateMax)
import Crypto.Number.Serialize (i2ospOf_)
import Crypto.PubKey.ECC.Prim (pointMul)
import Crypto.Random.Types
import Crypto.PubKey.DH (SharedKey(..))
import Crypto.PubKey.ECC.Types (PublicPoint, PrivateNumber, Curve, Point(..), curveSizeBits)
import Crypto.PubKey.ECC.Types (ecc_n, ecc_g, common_curve)

-- | Generate a random private number @d@ for the given curve.
--
-- The result lies in the inclusive range @[1, n-1]@, where @n@ is
-- 'ecc_n', the order of the curve's base point.  Zero is deliberately
-- excluded: @0 * G@ is the point at infinity, which is not a usable
-- public key and has no x-coordinate from which to derive a shared key.
generatePrivate :: MonadRandom m => Curve -> m PrivateNumber
generatePrivate curve = (+ 1) <$> generateMax (n - 1)
  where
    -- generateMax (n - 1) yields a value in [0, n-2]; shifting by one
    -- gives the valid private-key range [1, n-1].
    n :: PrivateNumber
    n = ecc_n (common_curve curve)

-- | Derive the public point @Q = d * G@ from the private number @d@,
--   where @G@ is the curve's generator point ('ecc_g').
calculatePublic :: Curve -> PrivateNumber -> PublicPoint
calculatePublic curve d =
    let base = ecc_g (common_curve curve)
     in pointMul curve d base

-- | Compute the shared key from our private number @db@ and the other
--   party's public point @qa@.
--
-- The key is the x-coordinate of @db * qa@, serialized with 'i2ospOf_'
-- into @(curveSizeBits curve + 7) `div` 8@ bytes.
--
-- Calls 'error' with a descriptive message if @db * qa@ is the point at
-- infinity ('PointO'), which has no x-coordinate.
--
-- NOTE(review): @qa@ is not checked to lie on @curve@ here; callers that
-- receive @qa@ from an untrusted peer should validate it first to avoid
-- invalid-curve attacks.
getShared :: Curve -> PrivateNumber -> PublicPoint -> SharedKey
getShared curve db qa =
    case pointMul curve db qa of
        Point x _ -> SharedKey $ i2ospOf_ nbBytes x
        PointO    -> error "Crypto.PubKey.ECC.DH.getShared: shared point is the point at infinity"
  where
    -- Byte length of the curve's field size, rounding partial bytes up.
    nbBytes :: Int
    nbBytes = (curveSizeBits curve + 7) `div` 8