{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE NamedFieldPuns #-}

{-|
Module: Crypto.Spake2
Description: Implementation of SPAKE2 key exchange protocol

Say that you and someone else share a secret password, and you want to use
this password to establish a secure channel of communication. You want:

 * to know that the other party also knows the secret password (maybe
   they're an imposter!)
 * the password to be secure against offline dictionary attacks
 * probably some other things

SPAKE2 is a password-authenticated key exchange protocol that meets these
criteria. See [Simple Password-Based Encrypted Key Exchange
Protocols](http://www.di.ens.fr/~pointche/Documents/Papers/2005_rsa.pdf) by
Michel Abdalla and David Pointcheval for more details.

== How it works

=== Preliminaries

Before exchanging messages, the two sides need to agree on the following out of band:

In general:

* hash algorithm, \(H\)
* group to use, \(G\)
* arbitrary members of the group to use for blinding
* a means of converting the password to a scalar of the group

For a specific exchange:

* whether the connection is symmetric or asymmetric
* the IDs of the respective sides
* a shared, secret password in bytes

#protocol#

=== Protocol

==== How we map the password to a scalar

Use HKDF expansion (see 'expandData') to expand the password to 16 bytes more
than the group's scalar size (see 'scalarSizeBytes'),
using an empty salt, and "SPAKE2 pw" as the info.

Then, use a group-specific mapping from bytes to scalars.
Since scalars are normally isomorphic to integers,
this will normally be a matter of converting the bytes to an integer
using standard deserialization
and then turning the integer into a scalar.

==== How we exchange information

See 'Crypto.Spake2.Math' for details on the mathematics of the exchange.

==== How python-spake2 works

- Each message to the other side is prefixed with a single character, @A@,
  @B@, or @S@, to indicate which side it came from
- The hash function for generating the session key has a few interesting properties:
    - uses SHA256 for hashing
    - does not include password or IDs directly, but rather uses /their/ SHA256
      digests as inputs to the hash
    - for the symmetric version, it sorts \(X^{\star}\) and \(Y^{\star}\),
      because neither side knows which is which
- By default, the ID of either side is the empty bytestring

== Open questions

* how does endianness come into play?
* what is Shallue-Woestijne-Ulas and why is it relevant?

== References

* [Javascript implementation](https://github.com/bitwiseshiftleft/sjcl/pull/273/), includes long, possibly relevant discussion
* [Python implementation](https://github.com/warner/python-spake2)
* [SPAKE2 random elements](http://www.lothar.com/blog/54-spake2-random-elements/) - blog post by warner about choosing \(M\) and \(N\)
* [Simple Password-Based Encrypted Key Exchange Protocols](http://www.di.ens.fr/~pointche/Documents/Papers/2005_rsa.pdf) by Michel Abdalla and David Pointcheval
* [draft-irtf-cfrg-spake2-03](https://tools.ietf.org/html/draft-irtf-cfrg-spake2-03) - expired IRTF draft for SPAKE2

-}

module Crypto.Spake2
  ( Password
  , makePassword
  -- * The SPAKE2 protocol
  , Protocol
  , makeAsymmetricProtocol
  , makeSymmetricProtocol
  , spake2Exchange
  , startSpake2
  , Math.computeOutboundMessage
  , Math.generateKeyMaterial
  , extractElement
  , MessageError
  , formatError
  , elementToMessage
  , createSessionKey
  , SideID(..)
  , WhichSide(..)
  ) where

import Protolude hiding (group)

import Crypto.Error (CryptoError, CryptoFailable(..))
import Crypto.Hash (HashAlgorithm, hashWith)
import Crypto.Random.Types (MonadRandom(..))
import Data.ByteArray (ByteArrayAccess)
import qualified Data.ByteArray as ByteArray
import qualified Data.ByteString as ByteString

import Crypto.Spake2.Group (AbelianGroup(..), Group(..), decodeScalar, scalarSizeBytes)
import qualified Crypto.Spake2.Math as Math
import Crypto.Spake2.Util (expandData)


-- | Shared secret password used to negotiate the connection.
--
-- The constructor is deliberately not exported, so that once a 'Password'
-- has been created, the underlying bytes cannot be retrieved by other
-- modules. Only 'Eq' and 'Ord' are derived here; there is no 'Show'
-- instance.
--
-- Construct with 'makePassword'.
newtype Password
  = Password ByteString
  deriving (Eq, Ord)

-- | Construct a password from the raw shared-secret bytes.
--
-- The 'Password' constructor is not exported, so other modules build
-- passwords with this function.
makePassword :: ByteString -> Password
makePassword secret = Password secret

-- | Bytes that identify one side of the protocol.
--
-- Wrap raw identifier bytes with 'SideID' and recover them with 'unSideID'.
newtype SideID
  = SideID { unSideID :: ByteString }
  deriving (Eq, Ord, Show)

-- | Convert a user-supplied password into a scalar on a group.
--
-- The password bytes are stretched with 'expandData', using the info string
-- @"SPAKE2 pw"@, to 16 bytes more than the group's scalar size
-- ('scalarSizeBytes'). The result is then turned into a scalar with the
-- group's 'decodeScalar'.
passwordToScalar :: AbelianGroup group => group -> Password -> Scalar group
passwordToScalar grp (Password secret) =
  decodeScalar grp stretched
  where
    -- NOTE(review): the extra 16 bytes presumably keep any modular reduction
    -- inside 'decodeScalar' close to uniform; confirm in Crypto.Spake2.Group.
    stretchedLength :: Int
    stretchedLength = scalarSizeBytes grp + 16

    stretched :: ByteString
    stretched = expandData info secret stretchedLength

    -- This needs to be exactly "SPAKE2 pw"
    -- See <https://github.com/bitwiseshiftleft/sjcl/pull/273/#issuecomment-185251593>
    info :: ByteString
    info = "SPAKE2 pw"

-- | Turn an element into a message from this side of the protocol.
--
-- The message is a one-byte prefix naming our side, followed by the group
-- encoding of the element. The prefix is @A@ or @B@ for our side of an
-- asymmetric exchange, or @S@ for a symmetric exchange.
elementToMessage :: Group group => Protocol group hashAlgorithm -> Element group -> ByteString
elementToMessage protocol element =
  ourPrefix (relation protocol) <> encodeElement (group protocol) element
  where
    ourPrefix :: Relation g -> ByteString
    ourPrefix Symmetric{} = "S"
    ourPrefix Asymmetric{us = SideA} = "A"
    ourPrefix Asymmetric{us = SideB} = "B"

-- | An error that occurs when interpreting messages from the other side of
-- the exchange.
--
-- The type parameter @e@ is the error type produced by the caller-supplied
-- "receive" action (see 'UnknownError').
data MessageError e
  = EmptyMessage
    -- ^ We received an empty bytestring.
  | UnexpectedPrefix Word8 Word8
    -- ^ The bytestring had an unexpected prefix.
    -- We expect the prefix to be @A@ if the other side is side A,
    -- @B@ if they are side B,
    -- or @S@ if the connection is symmetric.
    -- First argument is the received prefix, second is the expected one.
  | BadCrypto CryptoError ByteString
    -- ^ The message payload (second argument, with the prefix already
    -- stripped) could not be decoded to an element of the group.
    -- This can indicate either an error in serialization logic,
    -- or in mathematics.
  | UnknownError e
    -- ^ An error arising from the "receive" action in 'spake2Exchange'.
    -- Since 0.4.0
  deriving (Eq, Show)

-- | Turn a 'MessageError' into human-readable text.
--
-- Prefix bytes are rendered as quoted characters (e.g. @'A'@). Any other
-- payload is rendered with 'show'.
formatError :: Show e => MessageError e -> Text
formatError messageError =
  case messageError of
    EmptyMessage ->
      "Other side sent us an empty message"
    UnexpectedPrefix got expected ->
      "Other side claims to be " <> showPrefix got <> ", expected " <> showPrefix expected
    BadCrypto cryptoErr payload ->
      "Could not decode message (" <> show payload <> ") to element: " <> show cryptoErr
    UnknownError receiveErr ->
      "Error receiving message from other side: " <> show receiveErr
  where
    -- Interpret the prefix byte as a character so it prints as e.g. 'A'.
    showPrefix :: Word8 -> Text
    showPrefix = show . chr . fromIntegral

-- | Extract an element on the group from an incoming message.
--
-- The first byte of the message must be the prefix we expect from the other
-- side (see 'theirPrefix'). The remaining bytes are decoded with the group's
-- 'decodeElement'. Returns 'EmptyMessage' for an empty message,
-- 'UnexpectedPrefix' if the first byte is wrong, and 'BadCrypto' if decoding
-- fails.
--
-- TODO: Need to protect against reflection attack at some point.
-- NOTE(review): in a symmetric exchange both directions use the prefix @S@,
-- so a reflected copy of our own message passes the prefix check.
extractElement
  :: Group group
  => Protocol group hashAlgorithm
  -> ByteString
  -> Either (MessageError error) (Element group)
extractElement protocol message = do
  (prefix, payload) <- maybe (Left EmptyMessage) Right (ByteString.uncons message)
  when (prefix /= expectedPrefix) $
    Left (UnexpectedPrefix prefix expectedPrefix)
  case decodeElement (group protocol) payload of
    CryptoPassed element -> Right element
    CryptoFailed err -> Left (BadCrypto err payload)
  where
    expectedPrefix = theirPrefix (relation protocol)


-- | One side of the SPAKE2 protocol: its identifier and its blinding element.
data Side group = Side
  { sideID :: SideID
    -- ^ Bytes identifying this side.
  , blind :: Element group
    -- ^ Arbitrarily chosen element in the group, used by this side to
    -- blind outgoing messages.
  }

-- | Which side of an asymmetric exchange we are.
data WhichSide
  = SideA
  | SideB
  deriving (Eq, Ord, Show, Bounded, Enum)

-- | How the two sides of a SPAKE2 exchange relate to each other.
--
-- An exchange is either asymmetric (distinct sides A and B, and we know
-- which one we are) or symmetric (both sides are described identically).
--
-- NOTE: the record selectors are partial. For example, 'sideA' applied to a
-- 'Symmetric' value fails at runtime, so prefer matching on the constructor.
data Relation group
  = Asymmetric
      { sideA :: Side group
        -- ^ Side A. Both sides need to agree who side A is.
      , sideB :: Side group
        -- ^ Side B. Both sides need to agree who side B is.
      , us :: WhichSide
        -- ^ Which side we are.
      }
  | Symmetric
      { bothSides :: Side group
        -- ^ Description used by both sides.
      }

-- | The prefix byte we expect on messages from the other side.
--
-- In an asymmetric exchange this is the letter of the opposite side: @B@
-- when we are side A, and @A@ when we are side B. In a symmetric exchange it
-- is @S@.
theirPrefix :: Relation a -> Word8
theirPrefix rel =
  fromIntegral (ord letter)
  where
    letter =
      case rel of
        Symmetric{} -> 'S'
        Asymmetric{us = SideA} -> 'B'
        Asymmetric{us = SideB} -> 'A'

-- | Everything required for the SPAKE2 protocol.
--
-- Both sides must agree on these values for the protocol to work. This
-- /mostly/ means value equality. The exception is the 'us' field of an
-- asymmetric 'Relation', where each side must hold the complementary value.
--
-- Construct with 'makeAsymmetricProtocol' or 'makeSymmetricProtocol'.
data Protocol group hashAlgorithm = Protocol
  { group :: group
    -- ^ The group to use for encryption.
  , hashAlgorithm :: hashAlgorithm
    -- ^ Hash algorithm used for generating the session key.
  , relation :: Relation group
    -- ^ How the two sides relate to each other.
  }

-- | Construct an asymmetric SPAKE2 protocol.
--
-- Sides A and B each get their own identifier and blinding element. The
-- final argument records which of the two sides this party is.
makeAsymmetricProtocol
  :: hashAlgorithm  -- ^ Hash algorithm used to derive the session key
  -> group  -- ^ Group in which the exchange takes place
  -> Element group  -- ^ Blinding element for side A
  -> Element group  -- ^ Blinding element for side B
  -> SideID  -- ^ Identifier for side A
  -> SideID  -- ^ Identifier for side B
  -> WhichSide  -- ^ Which of the two sides we are
  -> Protocol group hashAlgorithm
makeAsymmetricProtocol hashAlg grp blindA blindB idA idB whichSide =
  Protocol
    { group = grp
    , hashAlgorithm = hashAlg
    , relation =
        Asymmetric
          { sideA = Side { sideID = idA, blind = blindA }
          , sideB = Side { sideID = idB, blind = blindB }
          , us = whichSide
          }
    }

-- | Construct a symmetric SPAKE2 protocol.
--
-- Both sides are described by the same identifier and blinding element.
-- Messages in a symmetric exchange carry the @S@ prefix in both directions.
makeSymmetricProtocol
  :: hashAlgorithm  -- ^ Hash algorithm used to derive the session key
  -> group  -- ^ Group in which the exchange takes place
  -> Element group  -- ^ Blinding element shared by both sides
  -> SideID  -- ^ Identifier shared by both sides
  -> Protocol group hashAlgorithm
makeSymmetricProtocol hashAlg grp sharedBlind sharedID =
  -- Parameters are named so they do not shadow Protolude's 'id' or this
  -- module's record fields.
  Protocol
    { group = grp
    , hashAlgorithm = hashAlg
    , relation = Symmetric { bothSides = Side { sideID = sharedID, blind = sharedBlind } }
    }

-- | Get the parameters for the mathematical part of SPAKE2 from the protocol
-- specification.
--
-- These are the group, our blinding element, and the other side's blinding
-- element. In a symmetric exchange both blinding elements are the same.
getParams :: Protocol group hashAlgorithm -> Math.Params group
getParams Protocol{group, relation} =
  case relation of
    Symmetric{bothSides} -> paramsFor bothSides bothSides
    Asymmetric{sideA, sideB, us = SideA} -> paramsFor sideA sideB
    Asymmetric{sideA, sideB, us = SideB} -> paramsFor sideB sideA
  where
    -- Build parameters from our side's description and theirs.
    paramsFor ours theirs =
      Math.Params
        { Math.group = group
        , Math.ourBlind = blind ours
        , Math.theirBlind = blind theirs
        }

-- | Perform an entire SPAKE2 exchange.
--
-- Given a SPAKE2 protocol that has all of the parameters for this exchange,
-- generate a one-off message from this side and receive a one-off message
-- from the other.
--
-- Once we are done, return a key shared between both sides for a single
-- session.
--
-- Note: as per the SPAKE2 definition, the session key is not guaranteed
-- to actually /work/. If the other side has failed to authenticate, you will
-- still get a session key. Therefore, you must exchange some other message
-- that has been encrypted using this key in order to confirm that the session
-- key is indeed shared.
--
-- Note: the "send" and "receive" actions are performed 'concurrently'. If you
-- have ordering requirements, consider using a 'TVar' or 'MVar' to coordinate,
-- or implementing your own equivalent of 'spake2Exchange'.
--
-- If the "receive" action fails, its error is wrapped in 'UnknownError'. If
-- the received message cannot be parsed, return the corresponding
-- 'MessageError'.
--
-- Since 0.4.0.
spake2Exchange
  :: (AbelianGroup group, HashAlgorithm hashAlgorithm)
  => Protocol group hashAlgorithm
  -- ^ A 'Protocol' with all the parameters for the exchange. These parameters
  -- must be shared by both sides. Construct with 'makeAsymmetricProtocol' or
  -- 'makeSymmetricProtocol'.
  -> Password
  -- ^ The password shared between both sides. Construct with 'makePassword'.
  -> (ByteString -> IO ())
  -- ^ An action to send a message. The 'ByteString' parameter is this side's
  -- SPAKE2 element, encoded using the group encoding, prefixed according to
  -- the parameters in the 'Protocol'.
  -> IO (Either error ByteString)
  -- ^ An action to receive a message. The 'ByteString' generated ought to be
  -- the protocol-prefixed, group-encoded version of the other side's SPAKE2
  -- element.
  -> IO (Either (MessageError error) ByteString)
  -- ^ Either the shared session key or an error indicating we couldn't parse
  -- the other side's message.
spake2Exchange protocol password send receive = do
  exchange <- startSpake2 protocol password
  let outboundElement = Math.computeOutboundMessage exchange
      -- NOTE(review): 'inboundElement' is passed in the position that
      -- 'createSessionKey''s Haddock documents as the outbound element
      -- (X*), and 'outboundElement' in the inbound (Y*) position. The body
      -- of 'createSessionKey' may compensate for this. Confirm against it
      -- before reordering: any change alters the derived key and breaks
      -- interoperability.
      sessionKeyFor inboundElement =
        createSessionKey
          protocol
          inboundElement
          outboundElement
          (Math.generateKeyMaterial exchange inboundElement)
          password
  (_, received) <-
    concurrently (send (elementToMessage protocol outboundElement)) receive
  pure (sessionKeyFor <$> (first UnknownError received >>= extractElement protocol))

-- | Commence a SPAKE2 exchange.
--
-- This derives the mathematical parameters from the protocol
-- ('getParams'), maps the password onto a scalar of the protocol's group
-- ('passwordToScalar'), and hands both to 'Math.startSpake2'. That function
-- runs in 'MonadRandom'; presumably it draws this side's secret there.
-- See "Crypto.Spake2.Math".
startSpake2
  :: (MonadRandom randomly, AbelianGroup group)
  => Protocol group hashAlgorithm
  -> Password
  -> randomly (Math.Spake2Exchange group)
startSpake2 protocol pw =
  Math.startSpake2 spec
  where
    spec =
      Math.Spake2
        { Math.params = getParams protocol
        , Math.password = passwordToScalar (group protocol) pw
        }

-- | Create a session key based on the output of SPAKE2.
--
-- \[SK \leftarrow H(A, B, X^{\star}, Y^{\star}, K, pw)\]
--
-- Including \(pw\) in the session key is what makes this SPAKE2, not SPAKE1.
--
-- The password and the side IDs enter the transcript as their hashes under
-- the protocol's hash algorithm; the elements \(X^{\star}\), \(Y^{\star}\)
-- and \(K\) enter as their byte encodings ('encodeElement'). The session key
-- is the hash of the concatenated transcript. In the symmetric case only the
-- single shared side ID is hashed, and the two messages are concatenated in
-- sorted byte order so that both participants build the same transcript.
--
-- __Note__: In spake2 0.3 and earlier, the \(X^{\star}\) and \(Y^{\star}\)
-- were expected to be from side A and side B respectively. Since spake2 0.4,
-- they are the outbound and inbound elements respectively. This fixes an
-- interoperability concern with the Python library, and reduces the burden on
-- the caller. Apologies for the possibly breaking change to any users of
-- older versions of spake2.
--
-- NOTE(review): in the asymmetric case the transcript orders the two message
-- arguments by our side: as (second argument, first argument) when we are
-- 'SideA', and as (first argument, second argument) when we are 'SideB'. The
-- in-module caller that derives its key material with
-- 'Math.generateKeyMaterial' passes its /inbound/ element first and its
-- /outbound/ element second, which makes \(X^{\star}\) side A's message on
-- both sides. A caller following the parameter documentation below literally
-- gets the opposite order. Confirm the intended contract before changing
-- either that call site or this mapping; they currently compensate for each
-- other.
createSessionKey
  :: (Group group, HashAlgorithm hashAlgorithm)
  => Protocol group hashAlgorithm  -- ^ The protocol used for this exchange
  -> Element group  -- ^ The outbound message, generated by this side, \(X^{\star}\), or either side if symmetric
  -> Element group  -- ^ The inbound message, generated by the other side, \(Y^{\star}\), or either side if symmetric
  -> Element group  -- ^ The calculated key material, \(K\)
  -> Password  -- ^ The shared secret password
  -> ByteString  -- ^ A session key to use for further communication
createSessionKey Protocol{group, hashAlgorithm, relation} outbound inbound k (Password password) =
  hashDigest transcript
  where
    -- The protocol expects that when we include the hash of various
    -- components (e.g. the password) as input for the session key hash,
    -- that we use the *byte* representation of these elements.
    hashDigest :: ByteArrayAccess input => input -> ByteString
    hashDigest thing = ByteArray.convert (hashWith hashAlgorithm thing)

    -- Bytes fed to the final hash, laid out per the relation in use.
    transcript :: ByteString
    transcript =
      case relation of
        Asymmetric{sideA, sideB, us} ->
          -- See the NOTE(review) in the Haddock above on this ordering.
          let (x, y) = case us of
                         SideA -> (inbound, outbound)
                         SideB -> (outbound, inbound)
          in mconcat [ hashDigest password
                     , hashDigest (unSideID (sideID sideA))
                     , hashDigest (unSideID (sideID sideB))
                     , encodeElement group x
                     , encodeElement group y
                     , encodeElement group k
                     ]
        Symmetric{bothSides} ->
          mconcat [ hashDigest password
                  , hashDigest (unSideID (sideID bothSides))
                  , symmetricElements
                  , encodeElement group k
                  ]

    -- With a symmetric relation neither message is inherently "first", so
    -- both parties order the two encodings bytewise to agree on one
    -- transcript. Concatenating the sorted list is total, unlike the former
    -- two-element list pattern, and yields the same bytes.
    symmetricElements :: ByteString
    symmetricElements =
      mconcat (sort [encodeElement group inbound, encodeElement group outbound])