{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE ViewPatterns, PackageImports, RecordWildCards,
    NamedFieldPuns, ScopedTypeVariables, BangPatterns, DeriveDataTypeable #-}
module Network.Hermes.Core(
  withHermes, CoreContext, TrustLevel(..), HermesID, HermesException(..)
  -- * Context management
  ,newContext, restoreContext, restoreContext', snapshotContext, addAuthority, setKeySignature
  ,myHermesID, setTimeout, timeout, setTrustLimit, snapshotContext'
  -- * Listeners
  ,startListener
   -- * Peer management
  ,connect, setHermesID
  -- * Communication
  ,send, send', recv, recv', NoTag(..), acceptType, refuseType
) where

import Prelude hiding(catch)

import Control.Arrow
import Control.Applicative
import Control.Monad
import Control.Monad.Tools
import "monads-tf" Control.Monad.State
import Control.Concurrent.STM
import Control.Exception(throwIO, throw, onException, block, unblock, catch, IOException(..))
import Data.Typeable

import System.Log.Logger

import Data.Maybe
import Data.Word
import qualified Data.Set as S
import Data.Map(Map)
import qualified Data.Map as M
import qualified Network
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import System.IO(Handle, hFlush, hClose)
import qualified System.Timeout

import qualified Data.Serialize
import Data.Serialize(encode,decode,Serialize)
import Data.Serialize.Put
import Data.Serialize.Get
import Network.Hermes.Protocol
import Network.Hermes.Types
import Network.Hermes.Misc
import Network.Hermes.MChan
import qualified Network.Hermes.Net as N
import Codec.Crypto.AES.Random
import Codec.Crypto.RSA as RSA
import Codec.Crypto.AES.IO as AES
import Codec.Digest.SHA

-- * Misc

-- | Derives a peer's 'HermesID' from its public key: the serialized key
-- is hashed with SHA-256 and the digest converted with
-- 'byteStringToInteger'.
hashKey :: PublicKey -> HermesID
hashKey publicKey = byteStringToInteger (hash SHA256 (encode publicKey))

-- | Wrap all use of Hermes in this. It delegates to
-- 'Network.withSocketsDo', which is only required on Windows.
withHermes :: IO a -> IO a
withHermes action = Network.withSocketsDo action

-- | Reads exactly @n@ bytes from the handle. Unlike 'B.hGet', a short
-- read is not returned; 'EOF' is thrown instead.
hGet :: Handle -> Int -> IO B.ByteString
hGet h n = do
  chunk <- B.hGet h n
  if B.length chunk /= n
    then throwIO EOF
    else return chunk


-- * Sending and receiving

-- Wire messages are always strict bytestrings. The API uses Data.Serialize
-- to encode arbitrary values as bytestrings, and tags them with their
-- type.
--
-- All header integers are unsigned 32-bit little-endian.
--
-- Each clause (marked with a - ) is passed separately through the
-- connection's AES context, in the order listed below. Both HMACs are
-- HMAC-SHA256 keyed with the session key.
--
-- Wire message format:
--
-- - Header (16 bytes):
--   * Length of the message
--   * Length of the tag
--   * Message type index
--   * Tag type index
--
-- - 16-byte truncated HMAC of the header
--
-- - 32-byte HMAC of the tag and message
-- - Variable length tag
-- - The message itself
--
-- Message type index 0 is reserved: such a message (sent with an empty
-- tag) registers a new type index, and its payload is the
-- Data.Serialize-encoded (index, type) pair.

-- | Encrypts and writes one wire message to the connection's handle, then
-- flushes it.
--
-- The header and the payload are each authenticated with HMAC-SHA256
-- keyed by the connection's 'aesKey' (the header MAC truncated to 16
-- bytes). Every clause is run through the connection's AES context in
-- wire order: header, header MAC, payload MAC, tag, message.
cryptSend :: Connection
             -> Int          -- ^ The tag's type index
             -> B.ByteString -- ^ The tag
             -> Int          -- ^ The message's type index
             -> B.ByteString -- ^ The message
             -> IO ()
cryptSend Connection{..} tagIndex tag messageIndex message = do
  mapM_ (debugM "hermes.core")
    [ "  message index: " ++ show messageIndex
    , "  tag index    : " ++ show tagIndex
    , "  header HMAC  : " ++ showBSasHex headerHMAC
    , "  payload HMAC : " ++ showBSasHex payloadHMAC
    , "  tag          : " ++ showBSasHex tag
    , "  message      : " ++ showBSasHex message
    ]
  mapM_ writeClause [header, headerHMAC, payloadHMAC, tag, message]
  hFlush handle
  where
    -- Four little-endian Word32s: the two lengths, then the two type indexes.
    header = runPut $ mapM_ putWord32le
      [ fromIntegral (B.length message)
      , fromIntegral (B.length tag)
      , fromIntegral messageIndex
      , fromIntegral tagIndex
      ]
    headerHMAC  = B.take 16 (hmac SHA256 aesKey header)
    payloadHMAC = hmac SHA256 aesKey (BL.fromChunks [tag, message])
    -- The AES context is stateful, so clauses must be written in order.
    writeClause clause = AES.crypt aesctx clause >>= B.hPut handle

-- | Reads, decrypts and authenticates one wire message from a handle.
--
-- The header MAC is checked before the payload is read, and the payload
-- MAC before anything is returned; either mismatch throws 'MessageError'.
-- Clauses are decrypted with the given AES context in wire order, so it
-- must be set up with the same key and IV as the sender's context.
cryptRecv :: Handle
             -> B.ByteString -- ^ The HMAC secret key
             -> AESCtx
             -> IO (Int,B.ByteString,Int,B.ByteString) -- ^ (Tag type index, tag, message index, message)
cryptRecv h key ctx = do
  header             <- readClause 16
  expectedHeaderHMAC <- readClause 16
  let getHeader = liftM4 (,,,) getWord32le getWord32le getWord32le getWord32le
      (messageLength, tagLength, messageIndex, tagIndex) = runGet' getHeader header
      headerHMAC = B.take 16 (hmac SHA256 key header)
  debugM "hermes.core" $ "Receiving message, type index " ++ show (messageIndex,tagIndex)
  debugM "hermes.core" $ "  header HMAC : " ++ showBSasHex headerHMAC
  when (headerHMAC /= expectedHeaderHMAC) $ throwIO MessageError
  expectedPayloadHMAC <- readClause 32
  tag                 <- readClause (fromIntegral tagLength)
  message             <- readClause (fromIntegral messageLength)
  let payloadHMAC = hmac SHA256 key (BL.fromChunks [tag, message])
  debugM "hermes.core" $ "  payload HMAC: " ++ showBSasHex payloadHMAC
  when (payloadHMAC /= expectedPayloadHMAC) $ throwIO MessageError
  debugM "hermes.core" $ "  tag         : " ++ showBSasHex tag
  debugM "hermes.core" $ "  message     : " ++ showBSasHex message
  return (fromIntegral tagIndex, tag, fromIntegral messageIndex, message)
  where
    -- Read exactly n bytes and decrypt them with the (stateful) context.
    readClause n = hGet h n >>= AES.crypt ctx

-- | Sends an already-encoded tag and message to a peer, translating the
-- textual type representations into per-connection type indexes.
--
-- Messages addressed to ourselves go straight to the local message box.
-- For remote peers, a type not yet known on the connection is first
-- announced with an index-0 registration message. A successful send
-- resets the peer's failure count. On an 'IOException' the connection is
-- killed and the send retried once; a second failure is only logged.
baseSend :: CoreContext
            -> HermesID
            -> Type         -- ^ The tag's type
            -> B.ByteString -- ^ The tag
            -> Type         -- ^ The message's type
            -> B.ByteString -- ^ The message
            -> IO ()
baseSend ctx uuid tagType tag msgType msg = attempt True
  where
    attempt :: Bool -> IO ()
    attempt mayRetry = deliver `catch` onFailure mayRetry

    -- OPTIMIZE: Use Data.Typeable's tag int
    -- OPTIMIZE: Don't encode tag/message if they're already (strict) bytestrings
    deliver = do
      infoM "hermes.core" $ "Sending message, type " ++ show (tagType,msgType)
      if uuid == myHermesID ctx
        then insertMessage ctx msgType tagType tag uuid msg
        else withConnection ctx uuid $ \conn -> do
          tagIndex     <- indexFor conn tagType
          messageIndex <- indexFor conn msgType
          cryptSend conn tagIndex tag messageIndex msg
          atomically $ modifyTVar (peerFailures ctx) (M.insert uuid 0)

    -- Returns the connection's index for a type, allocating the next free
    -- index and announcing it to the peer if the type is new.
    indexFor :: Connection -> Type -> IO Int
    indexFor conn typeString = do
      (index, isNew) <- atomically $ do
        known <- readTVar (typeMap conn)
        case M.lookup typeString known of
          Just index -> return (index, False)
          Nothing -> do
            index <- succ <$> readTVar (typeMax conn)
            writeTVar (typeMap conn) (M.insert typeString index known)
            writeTVar (typeMax conn) index
            return (index, True)
      when isNew $ cryptSend conn 0 B.empty 0 (encode (index, typeString))
      return index

    onFailure :: Bool -> IOException -> IO ()
    onFailure mayRetry e = do
      noticeM "hermes.core" $ "IO error while sending: " ++ show e ++ if mayRetry then ", retrying" else ""
      killConnection ctx uuid
      when mayRetry $ attempt False
      
-- | Kills the connection to a peer, increasing its failure count in
-- the process.
--
-- The failure count is incremented even if no connection is currently
-- stored for the peer, for example when negotiation itself failed, since
-- the caller is reporting a failed attempt either way. 'baseSend' resets
-- the count to 0 after a successful send.
--
-- If the connection's box is empty (held by another thread, or still
-- being filled), the transaction waits until a value is available before
-- removing it. The handle is closed outside the transaction.
killConnection :: CoreContext -> HermesID -> IO ()
killConnection ctx uuid = do
  maybeHandle <- atomically $ do
    modifyTVar (peerFailures ctx) (M.insertWith (+) uuid 1)
    boxes <- readTVar (peerConnections ctx)
    maybeConn <- case M.lookup uuid boxes of
      Nothing  -> return Nothing
      Just box -> Just <$> takeTMVar box
    writeTVar (peerConnections ctx) (M.delete uuid boxes)
    return (handle <$> maybeConn)
  maybe (return ()) hClose maybeHandle

-- | Sends a message with no tag (equivalent to 'send'' with 'NoTag').
--
-- The message's type representation travels with it, so a modicum of
-- type safety is provided: 'recv' only attempts to decode and return a
-- message of the matching (not necessarily correct!) type. Parse errors
-- are still possible if application versions differ.
--
-- The call blocks while the message is written. If the connection fails
-- with an IO error, it is torn down and the send is retried once; a
-- second failure is logged rather than thrown.
send :: (Serialize msg, Typeable msg) => CoreContext -> HermesID -> msg -> IO ()
send ctx target payload = send' ctx target payload NoTag

-- | The tag used by 'send' and 'recv' when no tag is given. It serializes
-- to zero bytes.
data NoTag = NoTag deriving (Typeable)

instance Data.Serialize.Serialize NoTag where
  put = const (return ())
  get = return NoTag

-- | Like 'send', but with an arbitrary serializable tag. A 'recv'' call
-- only returns the message if it asks for a tag of the same type with
-- the same encoding.
send' :: (Serialize msg, Typeable msg, Serialize tag, Typeable tag)
         => CoreContext -> HermesID -> msg -> tag -> IO ()
send' ctx uuid msg tag =
  baseSend ctx uuid tagType tagBytes msgType msgBytes
  where
    tagType  = showType tag
    tagBytes = encode tag
    msgType  = showType msg
    msgBytes = encode msg

-- | Receives an untagged message of the requested type, blocking until
-- one arrives (equivalent to 'recv'' with 'NoTag').
--
-- Multiple calls to recv are allowed; each message is returned only once.
recv :: forall msg. (Serialize msg, Typeable msg) => CoreContext -> IO (HermesID,msg)
recv = flip recv' NoTag

-- | Receives a message of the requested type whose tag has the same type
-- and encoding as the given one, blocking until such a message arrives.
--
-- Messages whose type/tag combination has not been requested via recv'
-- or 'acceptType' are dropped, and a RejectedMessage is sent back to
-- the sender. Once recv' has been called for a combination, later
-- messages of that combination are queued indefinitely.
--
-- Throws 'RecvCancelled' if the combination is refused (see 'refuseType').
recv' :: forall msg tag. (Serialize msg, Typeable msg, Serialize tag, Typeable tag)
         => CoreContext
         -> tag
         -> IO (HermesID,msg)
recv' ctx tag = do
  infoM "hermes.core" $ "Requesting message of type " ++ show key
  acceptType ctx (undefined :: msg) tag
  delivered <- atomically $ readMChan (messageBox ctx) key
  maybe (throwIO RecvCancelled) unpack delivered
  where
    tagType     = showType tag
    messageType = showType (undefined :: msg)
    key         = (messageType, tagType, encode tag)
    unpack (sender, payload) = do
      infoM "hermes.core" $ "Message of type " ++ show (tagType,messageType) ++ " returned"
      return (sender, decode' payload)

-- | Starts queueing messages of a type/tag combination without waiting for
-- one; 'recv'' calls this itself.
--
-- acceptType is idempotent.
acceptType :: forall tag msg. (Typeable msg, Serialize tag, Typeable tag)
              => CoreContext
              -> msg -- ^ The message type to accept. Only the type is used, so undefined is fine.
              -> tag
              -> IO ()
acceptType CoreContext{messageBox = box} msgWitness tag = do
  debugM "hermes.core" $ "Accepting key: " ++ show key
  atomically $ ensureMChan box key
  where
    key = (showType msgWitness, showType tag, encode tag)

-- | Stops queueing messages of a type/tag combination.
--
-- Calling refuseType will cause all recv calls to this type/tag
-- combination to throw RecvCancelled.
--
-- refuseType is idempotent.
refuseType :: forall tag msg. (Typeable msg, Serialize tag, Typeable tag)
              => CoreContext
              -> msg -- ^ The message type to refuse. Only the type is used, so undefined is fine.
              -> tag
              -> IO ()
refuseType CoreContext{messageBox = box} msgWitness tag = do
  debugM "hermes.core" $ "Refusing key: " ++ show key
  atomically $ deleteMChan box key
  where
    key = (showType msgWitness, showType tag, encode tag)


-- * Context management

-- | Rebuilds a live context from a snapshot inside STM. Listeners,
-- listener killers, connections and the message box start out empty;
-- all other state comes from the snapshot.
restoreContext' :: CoreContextSnapshot -> STM CoreContext
restoreContext' CoreContextSnapshot{..} = do
  -- Runtime-only state always starts empty.
  listeners       <- newTVar S.empty
  listenerKillers <- newTVar M.empty
  peerConnections <- newTVar M.empty
  messageBox      <- newMChan
  -- Persistent state is taken from the snapshot.
  myKeySignature  <- newTVar myKeySignatureSnap
  authorities     <- newTVar authoritiesSnap
  peerAddress     <- newTVar peerAddressSnap
  peerKeys        <- newTVar peerKeysSnap
  peerFailures    <- newTVar peerFailuresSnap
  trustLimit      <- newTVar trustLimitSnap
  timeLimit       <- newTVar timeLimitSnap
  return CoreContext
    { myKey        = myKeySnap
    , myPrivateKey = myPrivateKeySnap
    , myHermesID   = myHermesIDSnap
    , ..
    }

-- | Restores a context from a serialized snapshot, such as one produced
-- by 'snapshotContext''.
restoreContext :: B.ByteString -> IO CoreContext
restoreContext bytes = do
  infoM "hermes.core" "Restoring context from snapshot"
  atomically (restoreContext' (decode' bytes))



-- | Captures the persistent part of a context for storage.
--
-- Messages that have been received but not yet processed are discarded,
-- as are listeners and, naturally, connections. Keys, authorities, our
-- key signature, peer addresses, failure counts, and the trust and time
-- limits are all saved.
snapshotContext :: CoreContext -> STM CoreContextSnapshot
snapshotContext ctx = do
  sig      <- readTVar (myKeySignature ctx)
  auths    <- readTVar (authorities ctx)
  addrs    <- readTVar (peerAddress ctx)
  keys     <- readTVar (peerKeys ctx)
  trustLim <- readTVar (trustLimit ctx)
  timeLim  <- readTVar (timeLimit ctx)
  failures <- readTVar (peerFailures ctx)
  return CoreContextSnapshot
    { myKeySnap          = myKey ctx
    , myPrivateKeySnap   = myPrivateKey ctx
    , myHermesIDSnap     = myHermesID ctx
    , myKeySignatureSnap = sig
    , authoritiesSnap    = auths
    , peerAddressSnap    = addrs
    , peerKeysSnap       = keys
    , trustLimitSnap     = trustLim
    , timeLimitSnap      = timeLim
    , peerFailuresSnap   = failures
    }

-- | Snapshots a context and serializes it; 'restoreContext' accepts the
-- result.
snapshotContext' :: CoreContext -> IO B.ByteString
snapshotContext' = fmap encode . atomically . snapshotContext

-- | Records address and/or key information for a HermesID. Arguments
-- given as 'Nothing' leave the corresponding stored entry untouched.
-- You should probably not call this function directly.
setHermesID :: CoreContext -> HermesID -> Maybe Address -> Maybe PeerKey -> IO ()
setHermesID ctx uuid address key = atomically $ do
  maybe (return ()) (modifyTVar (peerAddress ctx) . M.insert uuid) address
  maybe (return ()) (modifyTVar (peerKeys ctx) . M.insert uuid) key

-- | Sets the minimum trust level required of peer keys. It takes effect
-- on the next connection.
setTrustLimit :: CoreContext -> TrustLevel -> IO ()
setTrustLimit ctx limit = atomically (writeTVar (trustLimit ctx) limit)

-- | Sets the time limit, in seconds, applied by 'timeout'. A context from
-- 'newContext' starts at 30 seconds.
setTimeout :: CoreContext -> Double -> IO ()
setTimeout ctx seconds = atomically (writeTVar (timeLimit ctx) seconds)

-- | Like 'System.Timeout.timeout', but reads the limit (in seconds) from
-- the context and throws 'Timeout' instead of returning 'Nothing'.
timeout :: CoreContext -> IO a -> IO a
timeout ctx act = do
  seconds <- atomically (readTVar (timeLimit ctx))
  let micros = round (seconds * 1000000)
  System.Timeout.timeout micros act >>= maybe (throwIO Timeout) return

-- | Adds an authority key. Peer key signatures are verified against these
-- keys. The key is prepended to the list, and duplicates are not removed.
addAuthority :: CoreContext -> PublicKey -> IO ()
addAuthority ctx = atomically . modifyTVar (authorities ctx) . (:)

-- | Sets the signature on our own key. It is sent to peers that request
-- our key.
setKeySignature :: CoreContext -> Signature -> IO ()
setKeySignature ctx = atomically . writeTVar (myKeySignature ctx) . Just

-- | Creates an empty context with a freshly generated RSA key pair (size
-- 'rsaKeySize').
--
-- The defaults are:
--
-- * trust limit 'Indirect', but no authorities;
-- * a 30-second time limit;
-- * no key signature;
-- * no peers, listeners or queued messages.
newContext :: IO CoreContext
newContext = do
  infoM "hermes.core" "Creating new context"
  gen <- newAESGen
  let (publicKey, privateKey, _) = RSA.generateKeyPair gen rsaKeySize
  signatureVar   <- newTVarIO Nothing
  authoritiesVar <- newTVarIO []
  listenersVar   <- newTVarIO S.empty
  killersVar     <- newTVarIO M.empty
  addressesVar   <- newTVarIO M.empty
  keysVar        <- newTVarIO M.empty
  connectionsVar <- newTVarIO M.empty
  failuresVar    <- newTVarIO M.empty
  trustVar       <- newTVarIO Indirect
  timeVar        <- newTVarIO 30
  box            <- newMChanIO
  return CoreContext
    { myKey           = publicKey
    , myPrivateKey    = privateKey
    , myHermesID      = hashKey publicKey
    , myKeySignature  = signatureVar
    , authorities     = authoritiesVar
    , listeners       = listenersVar
    , listenerKillers = killersVar
    , peerAddress     = addressesVar
    , peerKeys        = keysVar
    , peerConnections = connectionsVar
    , peerFailures    = failuresVar
    , trustLimit      = trustVar
    , timeLimit       = timeVar
    , messageBox      = box
    }

-- | Connects to an address without first knowing who will answer, and
-- returns the answerer's HermesID once the connection is established.
-- Typically used for bootstrapping.
--
-- The address is recorded for the returned HermesID. The new connection
-- is stored only if none exists yet for that HermesID; otherwise it is
-- closed.
connect :: CoreContext -> Address -> IO HermesID
connect ctx address = block $ do
  infoM "hermes.core" $ "Connecting to " ++ show address
  (conn, uuid) <- unblock $ negotiate ctx address Nothing
  redundant <- atomically $ do
    modifyTVar (peerAddress ctx) (M.insert uuid address)
    connections <- readTVar (peerConnections ctx)
    if M.member uuid connections
      then return True
      else do
        box <- newTMVar conn
        writeTVar (peerConnections ctx) (M.insert uuid box connections)
        return False
  when redundant $ hClose (handle conn)
  return uuid
    

-- * Connection negotiation & listeners



-- | Starts accepting incoming connections on a local address; each one is
-- served by 'handleConnection'.
--
-- Throws 'ListenerAlreadyExists' if an equal listener address is already
-- registered. The advertised address (defaulting to the local one) is
-- what 'negotiate' offers to peers as our address. If the server cannot
-- be started, the registration is rolled back. Otherwise the address
-- would keep being advertised and would block a retry.
startListener :: CoreContext
                -> Address -- ^ The local address we bind to on this system
                -> Maybe Address -- ^ If present, the address other systems see
                -> IO ()
startListener ctx localAddress remoteAddressMaybe = do
  infoM "hermes.core" $ "Listener started on address " ++ show localAddress
  let remoteAddress = fromMaybe localAddress remoteAddressMaybe
      address = ListenerAddress {..}
  ok <- atomically $ do
    set <- readTVar (listeners ctx)
    if S.member address set
      then return False
      else writeTVar (listeners ctx) (S.insert address set) >> return True
  unless ok $ throwM ListenerAlreadyExists
  killer <- N.streamServer localAddress (handleConnection ctx)
              `onException` atomically (modifyTVar (listeners ctx) (S.delete address))
  atomically $ do
    False <- M.member address <$> readTVar (listenerKillers ctx)
    modifyTVar (listenerKillers ctx) (M.insert address killer)

-- | Handles an incoming connection; this is the server side of 'negotiate'.
--
-- The handshake steps are:
--
-- * protocol versions;
-- * HermesIDs;
-- * keys;
-- * a challenge encrypted to the client's key;
-- * the RSA-sealed session setup.
--
-- After that, the handler loops forever decrypting messages with the
-- session key and handing them to 'insertMessage'. Index-0 messages
-- register new type indexes.
--
-- An 'EOF' from the peer is logged and ends the handler. Any other
-- 'HermesException' is re-thrown unchanged and propagates to traplogging.
-- Examples are 'MessageError' for a bad MAC or an unregistered type
-- index, and 'AuthError' for a challenge mismatch.
--
-- NOTE(review): the handle is not closed here; presumably 'N.streamServer'
-- closes it when this returns — confirm.
handleConnection :: CoreContext -> Handle -> Address -> IO ()
handleConnection ctx h address = traplogging "hermes.core" CRITICAL "trap: handleConnection" $
  flip catch onHermesException $ do
    infoM "hermes.core" $ "Incoming connection from " ++ show address
    -- Exchange protocol version
    exchangeVersions h
    -- Exchange HermesIDs. We go last.
    AHermesID theirHermesID <- rawRecv h
    rawSend h (AHermesID $ myHermesID ctx)
    -- Exchange keys, if required. We go last.
    answerKeyQuery ctx h
    -- Then we get to ask the client for a key. If we've gotten this far,
    -- both sides have the other's key. The connecting client then creates
    -- a session key and seals it for the server's (i.e. our) eyes only.
    -- Note that there's no provision in the protocol for informing the
    -- other side of errors. That's probably a TODO, but we have to handle
    -- disconnections at any point regardless; doing so would be purely
    -- for courtesy.
    theirKey <- ensureKey ctx h theirHermesID
    -- We send the client a challenge that must be included with the
    -- session key, to prevent replay attacks
    challenge <- prandBytes 16
    debugM "hermes.core" $ "Sending challenge: " ++ showBSasHex challenge
    do gen <- newAESGen
       rawSend h $ AChallenge $ fst $ rsaEncrypt gen (peerKey theirKey) challenge
    ASessionSetup setupBS <- rawRecv h
    let setupMsg@SessionSetup{..} = decode' $ rsaDecrypt (myPrivateKey ctx) setupBS
    infoM "hermes.core" $ "Receiving session setup: " ++ show setupMsg
    unless (challenge == setupChallenge) (throwM $ AuthError "Challenge mismatch")
    -- Well, they must be who they say they are.
    -- Opportunistically (re)insert the client's address in our address database
    maybe (return ())
          (\a -> atomically $ modifyTVar (peerAddress ctx) (M.insert theirHermesID a))
          clientAddress
    -- Create the session decryption context
    aesctx <- AES.newCtx AES.CTR setupKey setupIV AES.Decrypt
    -- And a map for type indexes
    indexMap <- newTVarIO (M.empty :: Map Int Type)
    -- Resolve a type index; an index the peer never registered is a
    -- protocol error rather than a crash on a lazy fromJust.
    let lookupType typeIndex = do
          known <- atomically $ M.lookup typeIndex <$> readTVar indexMap
          maybe (throwIO MessageError) return known
    -- Enter a loop, receiving and forwarding messages.
    infoM "hermes.core" $ "Connection setup for " ++ show address ++ " complete"
    forever $ do
      -- Receive a new message.
      (tagIndex,tag,messageIndex,message) <- cryptRecv h setupKey aesctx
      case messageIndex of
        0 -> do
          let decoded = decode' message
          infoM "hermes.core" $ "New type registered: " ++ show decoded
          atomically $ modifyTVar indexMap (uncurry M.insert decoded)
        _ -> do
          tagType <- lookupType tagIndex
          messageType <- lookupType messageIndex
          infoM "hermes.core" $ "Received message of type " ++ show (tagType,messageType)
          insertMessage ctx messageType tagType tag theirHermesID message
  where
    -- EOF is the normal end of a connection; anything else is re-thrown
    -- unchanged so the real cause is reported.
    onHermesException e = case e of
      EOF -> infoM "hermes.core.handleConnection" $ "EOF on " ++ show address
      _   -> throwIO e
          
-- | Delivers a received (or self-addressed) message to the local message
-- box under the key (message type, tag type, tag).
--
-- If nothing has accepted that key, the message is dropped with a
-- warning, and a 'RejectedMessage' carrying the tag type and tag is sent
-- back to the originator. This is skipped when the dropped message is
-- itself a 'RejectedMessage', so rejects cannot loop.
insertMessage :: CoreContext
              -> Type          -- ^ The message's type
              -> Type          -- ^ The tag's type
              -> B.ByteString  -- ^ The encoded tag
              -> HermesID      -- ^ The sender
              -> B.ByteString  -- ^ The encoded message
              -> IO ()
insertMessage ctx messageType tagType tag theirHermesID message = do
  accepted <- atomically $ writeMChan (messageBox ctx) (messageType,tagType,tag) (theirHermesID,message)
  unless accepted $ do
    warningM "hermes.core" $ "Message discarded: type, tag type, tag value: " ++
      show (messageType,tagType,tag)
    -- Return the reject message, assuming this isn't a reject message - no loops!
    unless (messageType == showType RejectedMessage) $
      send' ctx theirHermesID RejectedMessage (tagType,tag)

-- | Negotiates a connection, client side. Returns the new connection
-- (handle, session encryption context and empty type map) together with
-- the HermesID of the peer that answered.
--
-- If an expected HermesID is given and a different peer answers, then:
--
-- * the stored address for the /expected/ ID is forgotten, unless it has
--   changed in the meantime;
-- * 'AddressUnknown' is thrown for the expected ID.
--
-- The socket is closed if negotiation fails with an exception.
negotiate :: CoreContext
            -> Address -- ^ Address to connect to
            -> Maybe HermesID -- ^ Expected HermesID of the remote host, if any
            -> IO (Connection,HermesID) -- ^ (Connection,actual HermesID)
negotiate ctx address expectedHermesID = block $ do
  infoM "hermes.core" $ "Negotiating connection to " ++ show address ++ ", HermesID " ++ show expectedHermesID
  h <- N.connectStream address
  flip onException (hClose h) $ unblock $ do
    -- Exchange protocol version
    exchangeVersions h
    -- Exchange HermesIDs. We go first.
    rawSend h $ AHermesID $ myHermesID ctx
    AHermesID theirHermesID <- rawRecv h
    case expectedHermesID of
      Just expected | expected /= theirHermesID -> do
        -- Our stored address for 'expected' reached someone else, so it is
        -- apparently wrong. Delete that entry, but only if it hasn't been
        -- updated in the meantime.
        atomically $ whenM ((== Just address) . M.lookup expected
                            <$> readTVar (peerAddress ctx))
          (modifyTVar (peerAddress ctx) (M.delete expected))
        throwM $ AddressUnknown expected
      _ -> return ()
    -- Exchange keys, if required. Client (that is, us) goes first.
    theirKey <- ensureKey ctx h theirHermesID
    answerKeyQuery ctx h
    -- Both sides have the other's key - see handleConnection for details.
    -- Generate the session key and send it, with the challenge
    AChallenge challengeEncrypted <- rawRecv h
    let setupChallenge = rsaDecrypt (myPrivateKey ctx) challengeEncrypted
    setupKey <- prandBytes (aesKeySize `div` 8)
    setupIV <- prandBytes 16
    -- TODO: Do something smarter than picking one listener at random, and abstract
    clientAddress <- atomically $ listToMaybe . map remoteAddress . S.toList <$> readTVar (listeners ctx)
    let setupMsg = SessionSetup {..}
    debugM "hermes.core" $ "Sending session setup: " ++ show setupMsg
    do g <- newAESGen
       rawSend h $ ASessionSetup $ fst $ rsaEncrypt g (peerKey theirKey) (encode setupMsg)
    -- Finally, create and return the session context
    aesctx <- AES.newCtx AES.CTR setupKey setupIV AES.Encrypt
    infoM "hermes.core" $ "Negotiation complete"
    typeMap <- newTVarIO M.empty
    typeMax <- newTVarIO 0
    return (Connection { aesctx = aesctx, handle = h, aesKey = setupKey, typeMap, typeMax }, theirHermesID)

-- | Runs an action with the connection to a peer held (see 'withTMVar').
-- A new connection is negotiated first if none is stored.
--
-- Throws 'AddressUnknown' if a connection has to be negotiated but no
-- address is known for the peer. An already-stored connection is reused
-- even when no address is recorded.
withConnection :: CoreContext -> HermesID -> (Connection -> IO a) -> IO a
withConnection ctx theirHermesID act = do
  maybeAddress <- atomically $ M.lookup theirHermesID <$> readTVar (peerAddress ctx)
  -- Only consulted when withTMVar finds no stored connection.
  let establish = case maybeAddress of
        Nothing      -> throwIO (AddressUnknown theirHermesID)
        Just address -> fst <$> negotiate ctx address (Just theirHermesID)
  withTMVar
    (peerConnections ctx)
    theirHermesID
    establish
    (get >>= liftIO . act)

-- | Looks up a TMVar in a TVar Map. If it doesn't exist, an empty
-- placeholder is inserted and the filler is run to produce its value.
-- Once it does exist, the given state action is run with the value held.
-- Only after that action completes is the (possibly updated) value put in
-- the TMVar. If an exception occurs while filling or running the action,
-- the key is deleted from the map again.
--
-- Async exceptions are masked ('block') while the map is updated and the
-- cleanup handler is installed; the filler and the action run unmasked.
--
-- NOTE(review): on that exception path the placeholder TMVar is never
-- filled, so threads that already looked it up are not woken by this
-- code. Those threads are presumably blocked inside 'runTMVar'. The value
-- produced by the filler is also dropped without cleanup; for
-- connections, the handle is not closed. Confirm whether callers
-- compensate for this.
withTMVar :: Ord a =>
            TVar (Map a (TMVar b)) -- ^ The map
            -> a -- ^ A key for the map
            -> IO b -- ^ Function used to fill the TMVar
            -> StateT b IO r -- ^ Code to run once the TMVar is full
            -> IO r
withTMVar tvar key filler act = block $ do
  -- Reuse the existing box, or reserve an empty placeholder that this
  -- thread is then responsible for filling.
  (needFill,var) <- atomically $ do
    exists <- M.member key <$> readTVar tvar
    if exists
      then do var <- fromJust . M.lookup key <$> readTVar tvar
              return (False,var)
      else do placeholder <- newEmptyTMVar
              modifyTVar tvar (M.insert key placeholder)
              return (True,placeholder)
  if needFill
    then flip onException (atomically $ modifyTVar tvar (M.delete key)) $ unblock $ do
    -- We own the empty placeholder: fill it, run the action on the fresh
    -- value, and only then publish the final state.
    fill <- filler
    (ret, fill') <- runStateT act fill
    atomically $ putTMVar var fill'
    return ret
    else unblock $ runTMVar var act

-- | Returns a sufficiently trusted key for the given HermesID.
--
-- If one is already stored (see 'getKey'), the peer is told 'KeyOK'.
-- Otherwise the key is requested over the handle and checked against the
-- context's trust limit; 'AuthError' is thrown if it falls short. A key
-- that passes is stored.
ensureKey :: MonadIO m => CoreContext -> Handle -> HermesID -> m PeerKey
ensureKey ctx h uuid = do
  known <- getKey ctx uuid -- See if we need a new key
  case known of
    Just theirKey -> do
      rawSend h $ AKeyQuery KeyOK
      return theirKey
    Nothing -> do
      theirKey <- requestKey ctx h uuid
      limit <- liftIO $ atomically $ readTVar (trustLimit ctx)
      when (trust theirKey < limit)
        (throwM $ AuthError $ "Key insufficiently trusted: " ++ show (trust theirKey))
      liftIO $ atomically $ modifyTVar (peerKeys ctx) (M.insert uuid theirKey)
      return theirKey


-- | Asks the peer on the handle for its key and verifies it. The checks
-- are:
--
-- * the key must hash to the expected HermesID;
-- * any signature must verify against one of our authorities.
--
-- Either failure throws 'AuthError'. Unsigned keys get trust 'None' and
-- signed ones 'Indirect'. The trust level is not compared with the
-- context's limit here.
requestKey :: MonadIO m => CoreContext -> Handle -> HermesID -> m PeerKey
requestKey ctx h uuid = liftIO $ do
  rawSend h $ AKeyQuery RequestKey
  AKeyReply reply <- rawRecv h
  let key = keyReplyKey reply
  unless (hashKey key == uuid) (throwM $ AuthError "key HermesID mismatch")
  case keyReplySig reply of
    Nothing -> return PeerKey { peerKey = key, trust = None, signature = Nothing }
    Just sig -> do
      authorityKeys <- atomically $ readTVar (authorities ctx)
      let signedBy authority = rsaVerify authority (encode key) sig
      unless (any signedBy authorityKeys) (throwM $ AuthError "signature not verifiable")
      return PeerKey { peerKey = key, trust = Indirect, signature = Just sig }

-- | Looks up the stored key for a HermesID. Returns 'Nothing' if there is
-- none, or if its trust level is below the context's trust limit.
getKey :: MonadIO m => CoreContext -> HermesID -> m (Maybe PeerKey)
getKey ctx uuid = liftIO $ atomically $ do
  limit <- readTVar (trustLimit ctx)
  keys  <- readTVar (peerKeys ctx)
  return $ case M.lookup uuid keys of
    Just key | trust key >= limit -> Just key
    _                             -> Nothing

-- | Answers the peer's key query on the handle:
--
-- * 'KeyOK' needs no reply;
-- * 'RequestKey' is answered with our public key and our key signature,
--   if any.
--
-- Any other message throws 'AuthError'.
answerKeyQuery :: MonadIO m => CoreContext -> Handle -> m ()
answerKeyQuery ctx h = do
  query <- rawRecv h
  case query of
    AKeyQuery KeyOK -> return ()
    AKeyQuery RequestKey -> do
      sig <- liftIO $ atomically $ readTVar (myKeySignature ctx)
      rawSend h $ AKeyReply KeyReply { keyReplyKey = myKey ctx, keyReplySig = sig }
    _ -> throwM $ AuthError "answerKeyQuery: Unexpected reply"


-- | First step of both handshakes. Each side writes the magic string and
-- its protocol version, then checks the peer's.
--
-- Throws 'WrongProtocol' on a bad magic string, and
-- 'ProtocolVersionMismatch' (ours, theirs) on a version mismatch.
exchangeVersions :: MonadIO m => Handle -> m ()
exchangeVersions h = liftIO $ do
  mapM_ (B.hPut h) [magicString, encode protocolVersion]
  hFlush h
  theirString <- hGet h (B.length magicString)
  when (theirString /= magicString) $ throwIO WrongProtocol
  theirVersion <- decode' <$> hGet h 4
  when (theirVersion /= protocolVersion) $
    throwIO (ProtocolVersionMismatch protocolVersion theirVersion)

-- | Writes an unencrypted pre-session message, then flushes. The message
-- is a 'Word32' length prefix (Data.Serialize encoding) followed by the
-- serialized 'AnyMessage'.
rawSend :: (MonadIO m) => Handle -> AnyMessage -> m ()
rawSend h message = liftIO $ do
  let body   = encode message
      prefix = encode (fromIntegral (B.length body) :: Word32)
  mapM_ (B.hPut h) [prefix, body]
  hFlush h

-- | Largest payload 'rawRecv' accepts. Pre-session messages are small:
-- IDs, keys, signatures, challenges and the session setup. The length
-- prefix, however, is read before the peer is authenticated, and
-- 'B.hGet' allocates the requested size up front. Without a bound,
-- anyone could make us allocate up to 4 GiB.
maxRawMessageSize :: Word32
maxRawMessageSize = 1024 * 1024

-- | Reads one unencrypted pre-session message written by 'rawSend': a
-- 'Word32' length prefix followed by the serialized 'AnyMessage'.
-- Throws 'MessageError' if the announced length exceeds
-- 'maxRawMessageSize', and 'EOF' on a short read.
rawRecv :: (MonadIO m) => Handle -> m AnyMessage
rawRecv h = liftIO $ do
  size <- decode' <$> hGet h 4 :: IO Word32
  when (size > maxRawMessageSize) $ throwIO MessageError
  decode' <$> hGet h (fromIntegral size)