{-# LANGUAGE OverloadedStrings, TemplateHaskell, ScopedTypeVariables #-}
{-# OPTIONS_HADDOCK hide #-}
----------------------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.NoiseState
-- Maintainer  : John Galt
-- Stability   : experimental
-- Portability : POSIX
--
-- Drives a Noise handshake. The handshake pattern is interpreted as a
-- coroutine that suspends once per handshake message; when the pattern
-- has been fully interpreted, the symmetric state is split into a
-- sending and a receiving cipher state.
module Crypto.Noise.Internal.NoiseState where

import Control.Monad.Catch.Pure
import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.State (MonadState(..), runStateT, get, put)
import Control.Monad.Free.Church
import Control.Lens
import Data.ByteArray (ScrubbedBytes, convert, length, splitAt)
import Data.ByteString (ByteString)
import Data.Maybe (isJust)
import Data.Monoid ((<>))
import Data.Proxy (Proxy(..))
import Prelude hiding (concat, splitAt, length)

import Crypto.Noise.Cipher
import Crypto.Noise.DH
import Crypto.Noise.Hash
import Crypto.Noise.Internal.CipherState
import Crypto.Noise.Internal.SymmetricState
import Crypto.Noise.Internal.Handshake
import Crypto.Noise.Internal.HandshakePattern hiding (e, s, ee, se, es, ss)
import Crypto.Noise.Internal.Types

-- | Represents the complete state of a Noise conversation.
--
-- The handshake is kept as a suspended computation: each step feeds one
-- message into '_nsHandshakeSuspension' and stores whatever suspension
-- comes back. Both cipher states stay 'Nothing' until the handshake
-- pattern has been interpreted to completion.
data NoiseState c d h =
  NoiseState { _nsHandshakeState       :: HandshakeState c d h
               -- ^ Symmetric state, handshake options (keys) and message buffer.
             , _nsHandshakeSuspension  :: ScrubbedBytes -> Handshake c d h ()
               -- ^ Continuation that resumes the handshake with the next input.
             , _nsSendingCipherState   :: Maybe (CipherState c)
               -- ^ Cipher state for outgoing messages, once the handshake is done.
             , _nsReceivingCipherState :: Maybe (CipherState c)
               -- ^ Cipher state for incoming messages, once the handshake is done.
             }

$(makeLenses ''NoiseState)

-- | Returns a default set of handshake options. The prologue is set to an
-- empty string, PSK-mode is disabled, and all keys are set to 'Nothing'.
defaultHandshakeOpts :: HandshakePattern
                     -> HandshakeRole
                     -> HandshakeOpts d
defaultHandshakeOpts hp r =
  HandshakeOpts { _hoPattern             = hp
                , _hoRole                = r
                , _hoPrologue            = ""
                , _hoPreSharedKey        = Nothing
                , _hoLocalStatic         = Nothing
                , _hoLocalSemiEphemeral  = Nothing
                , _hoLocalEphemeral      = Nothing
                , _hoRemoteStatic        = Nothing
                , _hoRemoteSemiEphemeral = Nothing
                , _hoRemoteEphemeral     = Nothing
                }

-- | Builds the protocol name used to seed the symmetric state.
--
-- The result has the shape
-- @\<prefix\>\<pattern name\>_\<dh name\>_\<cipher name\>_\<hash name\>@,
-- where the prefix is @NoisePSK_@ when @psk@ is 'True' and @Noise_@
-- otherwise. The proxy argument only selects the cipher, DH and hash types.
mkHandshakeName :: forall c d h proxy. (Cipher c, DH d, Hash h)
                => ByteString
                -> Bool
                -> proxy (c, d, h)
                -> ScrubbedBytes
mkHandshakeName hpn psk _ =
  p <> convert hpn <> "_" <> d <> "_" <> c <> "_" <> h
  where
    p = if psk then "NoisePSK_" else "Noise_"
    c = cipherName (Proxy :: Proxy c)
    d = dhName     (Proxy :: Proxy d)
    h = hashName   (Proxy :: Proxy h)

-- | Builds the initial 'HandshakeState' from the given options.
--
-- The symmetric state is seeded with the protocol name from
-- 'mkHandshakeName'. The prologue is then mixed into the hash and, if a
-- pre-shared key is present, it is mixed in with 'mixPSK'. The message
-- buffer starts empty.
--
-- Calls 'error' if a pre-shared key is supplied that is not 32 bytes long.
handshakeState :: forall c d h. (Cipher c, DH d, Hash h)
               => HandshakeOpts d
               -> HandshakeState c d h
handshakeState ho
  | not (validPSK (ho ^. hoPreSharedKey)) =
      error "pre-shared key must be 32 bytes in length"
  | otherwise =
      HandshakeState { _hsSymmetricState = ss''
                     , _hsOpts           = ho
                     , _hsMsgBuffer      = mempty
                     }
  where
    validPSK = maybe True (\psk -> length psk == 32)
    ss       = symmetricState $ mkHandshakeName (ho ^. hoPattern . hpName)
                                                (isJust (ho ^. hoPreSharedKey))
                                                (Proxy :: Proxy (c, d, h))
    ss'      = mixHash (ho ^. hoPrologue) ss
    ss''     = maybe ss' (`mixPSK` ss') $ ho ^. hoPreSharedKey

-- | Feeds one message into the suspended handshake. The handshake then runs
-- until it either suspends again or finishes.
--
-- If it suspends, the bytes it yields are returned and the new suspension
-- is stored. If the pattern has been fully interpreted, the final message
-- buffer is returned and the symmetric state is split into two cipher
-- states. The initiator sends with the first and receives with the second;
-- the responder does the opposite. Exceptions raised inside the handshake
-- are re-thrown in @m@.
runHandshake :: (MonadThrow m, Cipher c, Hash h)
             => ScrubbedBytes
             -> NoiseState c d h
             -> m (ScrubbedBytes, NoiseState c d h)
runHandshake msg ns = reThrow . runCatch $ do
  ((res, ns''), hs') <- runStateT st $ ns ^. nsHandshakeState
  return (res, ns'' & nsHandshakeState .~ hs')

  where
    reThrow = either throwM return

    st = do
      x <- resume . runHandshake' . (ns ^. nsHandshakeSuspension) $ msg
      case x of
        Left (Request req resp) ->
          return (req, ns & nsHandshakeSuspension .~ (Handshake . resp))
        Right _ -> do
          hs <- get
          let (cs1, cs2) = split (hs ^. hsSymmetricState)
              ns'        = if hs ^. hsOpts . hoRole == InitiatorRole
                             then ns & nsSendingCipherState   .~ Just cs1
                                     & nsReceivingCipherState .~ Just cs2
                             else ns & nsSendingCipherState   .~ Just cs2
                                     & nsReceivingCipherState .~ Just cs1
          return (hs ^. hsMsgBuffer, ns')

-- | Creates a 'NoiseState'.
--
-- The handshake pattern is interpreted up to its first suspension. Any
-- pre-message tokens that come before the first message are therefore
-- processed here, and the result is ready for the first 'runHandshake'.
-- Calls 'error' if the interpreter throws or finishes without suspending.
noiseState :: forall c d h. (Cipher c, DH d, Hash h)
           => HandshakeOpts d
           -> NoiseState c d h
noiseState ho =
  NoiseState { _nsHandshakeState       = hs''
             , _nsHandshakeSuspension  = suspension
             , _nsSendingCipherState   = Nothing
             , _nsReceivingCipherState = Nothing
             }
  where
    hs        = handshakeState ho :: HandshakeState c d h
    coroutine = iterM evalPattern $ ho ^. hoPattern . hpActions

    (suspension, hs'') =
      case runCatch (runStateT (resume (runHandshake' coroutine)) hs) of
        Left err     -> error $ "handshake pattern interpreter threw exception: " <> show err
        Right result ->
          case result of
            (Left (Request _ resp), hs') -> (Handshake . resp, hs')
            _                            -> error "handshake pattern interpreter ended pre-maturely"

-- | Processes one handshake message (a sequence of tokens) sent by @opRole@.
--
-- The coroutine first suspends. It yields the current message buffer and
-- receives the next input. What happens next depends on who sent it:
--
-- * @opRole@ is the local role: the input is the payload to send. The
--   buffer is cleared, the tokens append their data to it, and the payload
--   is encrypted with 'encryptAndHash' and appended last.
-- * Otherwise: the input is a received message. The tokens consume their
--   data from the front of the buffer. The remainder is decrypted with
--   'decryptAndHash', and the plaintext payload is left in the buffer.
processPatternOp :: (Cipher c, DH d, Hash h)
                 => HandshakeRole
                 -> F TokenF ()
                 -> Handshake c d h ()
                 -> Handshake c d h ()
processPatternOp opRole t next = do
  hs    <- get
  -- Suspend: hand the current buffer out and wait for the next input.
  input <- Handshake . request $ hs ^. hsMsgBuffer
  hs'   <- get

  if opRole == hs' ^. hsOpts . hoRole
    then do
      put $ hs' & hsMsgBuffer .~ mempty
      iterM (evalMsgToken opRole) t
      hs'' <- get
      let enc = encryptAndHash (convert input) $ hs'' ^. hsSymmetricState
      (ep, ss) <- either throwM return enc
      put $ hs'' & hsMsgBuffer      <>~ convert ep
                 & hsSymmetricState .~  ss
    else do
      put $ hs' & hsMsgBuffer .~ input
      iterM (evalMsgToken opRole) t
      hs'' <- get
      let remaining = hs'' ^. hsMsgBuffer
          dec       = decryptAndHash (cipherBytesToText (convert remaining))
                    $ hs'' ^. hsSymmetricState
      (dp, ss) <- either (const . throwM . HandshakeError $ "handshake payload failed to decrypt")
                         return
                         dec
      put $ hs'' & hsMsgBuffer      .~ convert dp
                 & hsSymmetricState .~ ss

  next

-- | Interprets one step of a handshake pattern.
--
-- Pre-message steps only mix already-known public keys into the hash.
-- Message steps are handled by 'processPatternOp'.
evalPattern :: (Cipher c, DH d, Hash h)
            => HandshakePatternF (Handshake c d h ())
            -> Handshake c d h ()
evalPattern (PreInitiator t next) = do
  iterM (evalPreMsgToken InitiatorRole) t
  next
evalPattern (PreResponder t next) = do
  iterM (evalPreMsgToken ResponderRole) t
  next
evalPattern (Initiator t next) = processPatternOp InitiatorRole t next
evalPattern (Responder t next) = processPatternOp ResponderRole t next

-- | Interprets a single token of a message sent by @opRole@.
--
-- * 'E': the sender appends its ephemeral public key in the clear. The
--   receiver takes @dhLength@ bytes from the front of the buffer and
--   stores them as the remote ephemeral key. Either way the key is mixed
--   into the hash and, when the symmetric state has a PSK, into the key.
-- * 'S': the sender appends its static public key, passed through
--   'encryptAndHash'. The receiver takes @dhLength@ bytes (plus 16 tag
--   bytes when a key is already set) and decrypts them. The receiver
--   refuses to overwrite an already-known remote static key.
-- * 'Ee', 'Es', 'Se', 'Ss': the DH result of the matching local private
--   key and remote public key is mixed into the key. For 'Es' and 'Se',
--   which of the two keys is local depends on whether we are the initiator.
evalMsgToken :: forall c d h. (Cipher c, DH d, Hash h)
             => HandshakeRole
             -> TokenF (Handshake c d h ())
             -> Handshake c d h ()
evalMsgToken opRole (E next) = do
  hs <- get

  if opRole == hs ^. hsOpts . hoRole
    then do
      (_, pk) <- getLocalEphemeral hs
      let pk'  = dhPubToBytes pk
          ss   = hs ^. hsSymmetricState
          ss'  = mixHash pk' ss
          ss'' = if ss' ^. ssHasPSK then mixKey pk' ss' else ss'
      put $ hs & hsSymmetricState .~  ss''
               & hsMsgBuffer      <>~ convert pk'
    else do
      let (b, rest) = splitAt (dhLength (Proxy :: Proxy d)) $ hs ^. hsMsgBuffer
          reBytes   = convert b
          ss        = hs ^. hsSymmetricState
          ss'       = mixHash reBytes ss
          -- Read the PSK flag from ss', exactly as the sending branch does.
          ss''      = if ss' ^. ssHasPSK then mixKey reBytes ss' else ss'
          theirKey  = dhBytesToPub reBytes
      theirKey' <- maybe (throwM . HandshakeError $ "invalid remote ephemeral key") return theirKey
      put $ hs & hsOpts . hoRemoteEphemeral .~ Just theirKey'
               & hsSymmetricState           .~ ss''
               & hsMsgBuffer                .~ rest

  next

evalMsgToken opRole (S next) = do
  hs <- get

  if opRole == hs ^. hsOpts . hoRole
    then do
      pk <- dhPubToBytes . snd <$> getLocalStatic hs
      let ss  = hs ^. hsSymmetricState
          enc = encryptAndHash (convert pk) ss
      (ct, ss') <- either throwM return enc
      put $ hs & hsSymmetricState .~  ss'
               & hsMsgBuffer      <>~ convert ct
    else if isJust (hs ^. hsOpts . hoRemoteStatic)
      then throwM . InvalidHandshakeOptions $ "unable to overwrite remote static key"
      else do
        let hasKey    = hs ^. hsSymmetricState . ssHasKey
            len       = dhLength (Proxy :: Proxy d)
            -- The magic 16 here represents the length of the auth tag.
            d         = if hasKey then len + 16 else len
            (b, rest) = splitAt d $ hs ^. hsMsgBuffer
            ct        = cipherBytesToText . convert $ b
            ss        = hs ^. hsSymmetricState
            dec       = decryptAndHash ct ss
        (pt, ss') <- either (const . throwM . HandshakeError $ "failed to decrypt remote static key")
                            return
                            dec
        theirKey' <- maybe (throwM . HandshakeError $ "invalid remote static key provided")
                           return
                           $ dhBytesToPub pt
        put $ hs & hsOpts . hoRemoteStatic .~ Just theirKey'
                 & hsSymmetricState        .~ ss'
                 & hsMsgBuffer             .~ rest

  next

evalMsgToken _ (Ee next) = do
  hs <- get
  ~(sk, _) <- getLocalEphemeral hs
  rpk <- getRemoteEphemeral hs
  let ss' = mixKey (dhPerform sk rpk) $ hs ^. hsSymmetricState
  put $ hs & hsSymmetricState .~ ss'
  next

evalMsgToken _ (Es next) = do
  hs <- get
  let ss = hs ^. hsSymmetricState

  if hs ^. hsOpts . hoRole == InitiatorRole
    then do
      rpk <- getRemoteStatic hs
      ~(sk, _) <- getLocalEphemeral hs
      let dh  = dhPerform sk rpk
          ss' = mixKey dh ss
      put $ hs & hsSymmetricState .~ ss'
    else do
      ~(sk, _) <- getLocalStatic hs
      rpk <- getRemoteEphemeral hs
      let dh  = dhPerform sk rpk
          ss' = mixKey dh ss
      put $ hs & hsSymmetricState .~ ss'

  next

evalMsgToken _ (Se next) = do
  hs <- get
  let ss = hs ^. hsSymmetricState

  if hs ^. hsOpts . hoRole == InitiatorRole
    then do
      ~(sk, _) <- getLocalStatic hs
      rpk <- getRemoteEphemeral hs
      let dh  = dhPerform sk rpk
          ss' = mixKey dh ss
      put $ hs & hsSymmetricState .~ ss'
    else do
      rpk <- getRemoteStatic hs
      ~(sk, _) <- getLocalEphemeral hs
      let dh  = dhPerform sk rpk
          ss' = mixKey dh ss
      put $ hs & hsSymmetricState .~ ss'

  next

evalMsgToken _ (Ss next) = do
  hs <- get
  let ss = hs ^. hsSymmetricState
  ~(sk, _) <- getLocalStatic hs
  rpk <- getRemoteStatic hs
  let dh  = dhPerform sk rpk
      ss' = mixKey dh ss
  put $ hs & hsSymmetricState .~ ss'
  next

-- | Interprets a pre-message token by mixing an already-known public key
-- into the hash. When @opRole@ is the local role, the local key pair's
-- public half is used; otherwise the remote public key is used.
-- Only 'E' (semi-ephemeral key) and 'S' (static key) are valid here; any
-- other token calls 'error'.
evalPreMsgToken :: (Cipher c, DH d, Hash h)
                => HandshakeRole
                -> TokenF (Handshake c d h ())
                -> Handshake c d h ()
evalPreMsgToken opRole (E next) = do
  hs <- get
  let ss = hs ^. hsSymmetricState
  pk <- if opRole == hs ^. hsOpts . hoRole
          then snd <$> getLocalSemiEphemeral hs
          else getRemoteSemiEphemeral hs
  let ss' = mixHash (dhPubToBytes pk) ss
  put $ hs & hsSymmetricState .~ ss'
  next

evalPreMsgToken opRole (S next) = do
  hs <- get
  let ss = hs ^. hsSymmetricState
  pk <- if opRole == hs ^. hsOpts . hoRole
          then snd <$> getLocalStatic hs
          else getRemoteStatic hs
  let ss' = mixHash (dhPubToBytes pk) ss
  put $ hs & hsSymmetricState .~ ss'
  next

evalPreMsgToken _ _ = error "invalid pre-message pattern token"

-- | Returns the local static key pair, throwing 'InvalidHandshakeOptions'
-- if it was not set in the handshake options.
getLocalStatic :: HandshakeState c d h -> Handshake c d h (KeyPair d)
getLocalStatic hs =
  maybe (throwM (InvalidHandshakeOptions "local static key not set"))
        return
        (hs ^. hsOpts . hoLocalStatic)

-- | Returns the local semi-ephemeral key pair, throwing
-- 'InvalidHandshakeOptions' if it was not set in the handshake options.
getLocalSemiEphemeral :: HandshakeState c d h -> Handshake c d h (KeyPair d)
getLocalSemiEphemeral hs =
  maybe (throwM (InvalidHandshakeOptions "local semi-ephemeral key not set"))
        return
        (hs ^. hsOpts . hoLocalSemiEphemeral)

-- | Returns the local ephemeral key pair, throwing 'InvalidHandshakeOptions'
-- if it was not set in the handshake options.
getLocalEphemeral :: HandshakeState c d h -> Handshake c d h (KeyPair d)
getLocalEphemeral hs =
  maybe (throwM (InvalidHandshakeOptions "local ephemeral key not set"))
        return
        (hs ^. hsOpts . hoLocalEphemeral)

-- | Returns the remote static public key, throwing 'InvalidHandshakeOptions'
-- if it is not known yet.
getRemoteStatic :: HandshakeState c d h -> Handshake c d h (PublicKey d)
getRemoteStatic hs =
  maybe (throwM (InvalidHandshakeOptions "remote static key not set"))
        return
        (hs ^. hsOpts . hoRemoteStatic)

-- | Returns the remote semi-ephemeral public key, throwing
-- 'InvalidHandshakeOptions' if it is not known.
getRemoteSemiEphemeral :: HandshakeState c d h -> Handshake c d h (PublicKey d)
getRemoteSemiEphemeral hs =
  maybe (throwM (InvalidHandshakeOptions "remote semi-ephemeral key not set"))
        return
        (hs ^. hsOpts . hoRemoteSemiEphemeral)

-- | Returns the remote ephemeral public key, throwing
-- 'InvalidHandshakeOptions' if it is not known yet.
getRemoteEphemeral :: HandshakeState c d h -> Handshake c d h (PublicKey d)
getRemoteEphemeral hs =
  maybe (throwM (InvalidHandshakeOptions "remote ephemeral key not set"))
        return
        (hs ^. hsOpts . hoRemoteEphemeral)