{-# OPTIONS_HADDOCK not-home #-}
-- |
-- Description : PAKE exchange for Magic Wormhole
--
-- Once a peer is found, and both sides have mailboxes open, the peers need to
-- generate a shared 'ClientProtocol.SessionKey' based on their shared password.
--
-- 'pakeExchange' has the logic for doing this.
module MagicWormhole.Internal.Pake
  ( pakeExchange
  , PakeError(..)
  -- * Exported for testing
  , spakeBytesToMessageBody
  , messageBodyToSpakeBytes
  ) where

import Protolude

import Control.Monad (fail)
import Crypto.Hash (SHA256(..))
import qualified Crypto.Spake2 as Spake2
import Crypto.Spake2.Group (Group(arbitraryElement))
import Crypto.Spake2.Groups (Ed25519(..))
import qualified Data.Aeson as Aeson
import Data.Aeson (FromJSON, ToJSON, (.=), object, Value(..), (.:))
import Data.Aeson.Types (typeMismatch)
import Data.ByteArray.Encoding (convertToBase, convertFromBase, Base(Base16))

import qualified MagicWormhole.Internal.Messages as Messages
import qualified MagicWormhole.Internal.ClientProtocol as ClientProtocol

-- | Run the SPAKE2 key exchange with a Magic Wormhole peer.
--
-- Our SPAKE2 message is sent in 'Messages.PakePhase', and we then wait for
-- the peer's message in that same phase.
--
-- Throws a 'PakeError' if 'Spake2.spake2Exchange' reports a
-- 'Spake2.MessageError', e.g. because the incoming message body could not
-- be decoded.
pakeExchange
  :: ClientProtocol.Connection -- ^ A connection to a peer
  -> Spake2.Password -- ^ The shared password. Construct with 'Spake2.makePassword'.
  -> IO ClientProtocol.SessionKey  -- ^ A key that can be used for the remainder of the session
pakeExchange conn password =
  Spake2.spake2Exchange protocol password sendPakeMessage (atomically receivePakeMessage)
    >>= either (throwIO . Error) (pure . ClientProtocol.SessionKey)
  where
    -- Both peers derive the same protocol from the shared application ID.
    protocol = wormholeSpakeProtocol (ClientProtocol.appID conn)

    sendPakeMessage spakeBytes =
      ClientProtocol.send conn Messages.PakePhase (spakeBytesToMessageBody spakeBytes)

    -- Look at the next incoming message inside the STM transaction. If it is
    -- not a PAKE message, 'retry' rolls the transaction back (so nothing is
    -- consumed) and blocks until something it read changes. This relies on
    -- no other consumer ever taking PAKE-phase messages.
    receivePakeMessage = do
      incoming <- ClientProtocol.receive conn
      if Messages.phase incoming == Messages.PakePhase
        then pure (messageBodyToSpakeBytes (Messages.body incoming))
        else retry

-- | Wire wrapper around the raw SPAKE2 bytes exchanged while negotiating the
-- shared session key.
newtype Spake2Message = Spake2Message
  { spake2Bytes :: ByteString
  } deriving (Eq, Show)

-- | Serialised as a JSON object with a single @pake_v1@ key whose value is
-- the hex (Base16) encoding of the SPAKE2 bytes.
instance ToJSON Spake2Message where
  toJSON (Spake2Message rawBytes) =
    let hexText = toS @ByteString @Text (convertToBase Base16 rawBytes)
    in object [ "pake_v1" .= hexText ]

-- | Inverse of the 'ToJSON' instance. Expects an object whose @pake_v1@
-- field is a hex (Base16) string. Fails if the value is not an object or the
-- hex cannot be decoded.
instance FromJSON Spake2Message where
  parseJSON value = case value of
    Object fields -> do
      hexText <- fields .: "pake_v1"
      either fail (pure . Spake2Message)
        (convertFromBase Base16 (toS @Text @ByteString hexText))
    other -> typeMismatch "Spake2Message" other

-- | Wrap the bytes produced by the SPAKE2 algorithm in a Magic Wormhole
-- message body. The body is the JSON encoding of a 'Spake2Message'.
spakeBytesToMessageBody :: ByteString -> Messages.Body
spakeBytesToMessageBody spakeBytes =
  let encoded = Aeson.encode (Spake2Message spakeBytes)
  in Messages.Body (toS encoded)

-- | Extract the SPAKE2 bytes from a Magic Wormhole message body, ready to be
-- fed into the SPAKE2 algorithm. Returns the JSON decoder's error message as
-- 'Text' if the body is not a valid 'Spake2Message'.
messageBodyToSpakeBytes :: Messages.Body -> Either Text ByteString
messageBodyToSpakeBytes (Messages.Body bodyBytes) =
  case Aeson.eitherDecode (toS bodyBytes) of
    Left decodeErr -> Left (toS decodeErr)
    Right decoded -> Right (spake2Bytes decoded)

-- | Build the symmetric SPAKE2 protocol that Magic Wormhole speaks. It uses
-- SHA-256 over Ed25519, with a blinding element derived from the bytes
-- @symmetric@ and the application ID as the shared side identifier.
wormholeSpakeProtocol :: Messages.AppID -> Spake2Protocol
wormholeSpakeProtocol (Messages.AppID appId) =
  let sharedSide = Spake2.SideID (toS appId)
      blindingElement = arbitraryElement Ed25519 ("symmetric" :: ByteString)
  in Spake2.makeSymmetricProtocol SHA256 Ed25519 blindingElement sharedSide

-- | The concrete SPAKE2 instantiation used by Magic Wormhole: the Ed25519
-- group paired with the SHA-256 hash.
type Spake2Protocol
  = Spake2.Protocol Ed25519 SHA256

-- | An error that occurred during 'pakeExchange'. It wraps the
-- 'Spake2.MessageError' reported by 'Spake2.spake2Exchange'; the error
-- parameter is the 'Text' produced when decoding the peer's message body.
--
-- 'Typeable' is not listed in the deriving clause: since GHC 7.10 every type
-- gets a 'Typeable' instance automatically, so deriving it is a no-op.
newtype PakeError = Error (Spake2.MessageError Text) deriving (Eq, Show)

-- | 'pakeExchange' throws this with 'throwIO'.
instance Exception PakeError