{-# LANGUAGE CPP                      #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE CApiFFI                  #-}
-- | Diffie-Hellman key exchange
module OpenSSL.DH
    ( DHP
    , DH
    , DHGen(..)
    , genDHParams
    , getDHLength
    , checkDHParams
    , genDH
    , getDHParams
    , getDHPublicKey
    , computeDHKey
    )
    where
import Data.Word (Word8)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Internal as BS
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
import Foreign.Ptr (Ptr, nullPtr)
#if MIN_VERSION_base(4,5,0)
import Foreign.C.Types (CInt(..))
#else
import Foreign.C.Types (CInt)
#endif
import Foreign.Marshal.Alloc (alloca)
import Foreign.Storable (peek, poke)
import OpenSSL.BN
import OpenSSL.DH.Internal
import OpenSSL.Utils

-- | Generator used when creating Diffie-Hellman parameters with
-- 'genDHParams'.
data DHGen
  = DHGen2  -- ^ Use the generator @2@.
  | DHGen5  -- ^ Use the generator @5@.
  deriving (Eq, Ord, Show)

-- | @'genDHParams' gen n@ generates @n@-bit long DH parameters using the
-- generator selected by @gen@ (2 for 'DHGen2', 5 for 'DHGen5').
--
-- The pointer returned by @DH_generate_parameters@ is checked with
-- 'failIfNull' before being wrapped as a 'DHP'.
genDHParams :: DHGen -> Int -> IO DHP
genDHParams gen len = do
    -- The two null pointers fill the progress-callback and callback-argument
    -- slots of DH_generate_parameters (per the OpenSSL API), i.e. no callback.
    raw <- _DH_generate_parameters (fromIntegral len) (generatorOf gen)
                                   nullPtr nullPtr
    failIfNull raw >>= wrapDHPPtr
  where
    -- Numeric generator handed to OpenSSL for each 'DHGen'.
    generatorOf :: DHGen -> CInt
    generatorOf DHGen2 = 2
    generatorOf DHGen5 = 5

-- | Get the length of the DH parameters, in bits.
getDHLength :: DHP -> IO Int
getDHLength dh = do
    bits <- withDHPPtr dh _DH_length
    return (fromIntegral bits)

-- | Check that DH parameters are coherent.
--
-- Returns 'True' only when @DH_check@ both ran successfully /and/ reported
-- no problems. @DH_check@'s own return value only says whether the check
-- could be performed. Any problems it detects are reported as bit flags
-- written through its second argument, so that flag word has to be
-- inspected too. Previously only the return value was consulted, so
-- incoherent parameters could be reported as valid.
checkDHParams :: DHP -> IO Bool
checkDHParams dh =
    alloca $ \pErr ->
    withDHPPtr dh $ \dhPtr -> do
        -- Start from "no problems reported" so an uninitialised buffer can
        -- never be misread as a failure (or success).
        poke pErr 0
        performed <- _DH_check dhPtr pErr
        if performed
            then (== 0) <$> peek pErr
            else return False

-- | The first step of a key exchange: generate a public/private key pair
-- for the given parameters.
--
-- The key pair is generated on a duplicate of @dh@ (made with
-- @HsOpenSSL_DHparams_dup@), not on @dh@ itself. A @DH_generate_key@ status
-- other than 1 is reported through 'failIf_'.
genDH :: DHP -> IO DH
genDH dh = do
  params <- duplicate
  status <- withDHPPtr params _DH_generate_key
  failIf_ (/= 1) status
  return (asDH params)
  where
    -- Copy the parameters, null-check the copy, and wrap it as a new DHP.
    duplicate :: IO DHP
    duplicate = withDHPPtr dh _DH_dup >>= failIfNull >>= wrapDHPPtr

-- | Get the parameters that a key exchange was created from.
getDHParams :: DH -> DHP
getDHParams dh = asDHP dh

-- | Get the public key.
--
-- The raw @BIGNUM@ pointer obtained from OpenSSL is checked with
-- 'failIfNull' before conversion. This way a key with no public component
-- set produces an error instead of a null pointer being handed to
-- 'bnToInteger'.
getDHPublicKey :: DH -> IO Integer
getDHPublicKey dh =
  withDHPtr dh $ \dhPtr -> do
    pKey <- _DH_get_pub_key dhPtr >>= failIfNull
    bnToInteger (wrapBN pKey)

-- | Compute the shared key using the other party's public key.
--
-- The output buffer is allocated with @DH_size@ bytes and trimmed to the
-- byte count @DH_compute_key@ returns. A negative count is treated as
-- failure via 'failIf'.
computeDHKey :: DH -> Integer -> IO ByteString
computeDHKey dh pubKey =
  withDHPtr dh $ \dhPtr ->
  withBN pubKey $ \peerKey -> do
    maxLen <- _DH_size dhPtr
    let fill buf = do
          written <- _DH_compute_key buf (unwrapBN peerKey) dhPtr
          failIf (< 0) (fromIntegral written)
    BS.createAndTrim (fromIntegral maxLen) fill

-- Raw OpenSSL / HsOpenSSL entry points used by this module.
--
-- These are imported explicitly as @safe@, which is the FFI default, so the
-- runtime can keep other Haskell threads running while OpenSSL works.
foreign import capi safe "openssl/dh.h DH_generate_parameters"
  _DH_generate_parameters :: CInt -> CInt -> Ptr () -> Ptr () -> IO (Ptr DH_)
foreign import capi safe "openssl/dh.h DH_check"
  _DH_check :: Ptr DH_ -> Ptr CInt -> IO Bool
foreign import capi safe "openssl/dh.h DH_generate_key"
  _DH_generate_key :: Ptr DH_ -> IO CInt
foreign import capi safe "openssl/dh.h DH_compute_key"
  _DH_compute_key :: Ptr Word8 -> Ptr BIGNUM -> Ptr DH_ -> IO CInt

-- The remaining calls are imported @unsafe@. That is cheaper per call but
-- holds the calling capability for the call's duration; presumably they are
-- expected to return quickly.
foreign import capi unsafe "openssl/dh.h DH_size"
  _DH_size :: Ptr DH_ -> IO CInt
foreign import capi unsafe "HsOpenSSL.h HsOpenSSL_DH_length"
  _DH_length :: Ptr DH_ -> IO CInt
foreign import capi unsafe "HsOpenSSL.h HsOpenSSL_DH_get_pub_key"
  _DH_get_pub_key :: Ptr DH_ -> IO (Ptr BIGNUM)
foreign import capi unsafe "HsOpenSSL.h HsOpenSSL_DHparams_dup"
  _DH_dup :: Ptr DH_ -> IO (Ptr DH_)