-- | Helpers shared by the handshake code: failure reporting
-- ('HandshakeFailed', 'handshakeFailed', 'errorToAlert', 'unexpected'),
-- session bookkeeping ('newSession', 'handshakeTerminate'), the
-- ChangeCipherSpec/Finished exchange ('sendChangeCipherAndFinish',
-- 'recvChangeCipherAndFinish') and a small receive state machine
-- ('RecvState', 'runRecvState', 'recvPacketHandshake').
module Network.TLS.Handshake.Common
( HandshakeFailed(..)
, handshakeFailed
, errorToAlert
, unexpected
, newSession
, handshakeTerminate
, sendChangeCipherAndFinish
, recvChangeCipherAndFinish
, RecvState(..)
, runRecvState
, recvPacketHandshake
) where
import Network.TLS.Context
import Network.TLS.Session
import Network.TLS.Struct
import Network.TLS.IO
import Network.TLS.State hiding (getNegotiatedProtocol)
import Network.TLS.Receiving
import Network.TLS.Measurement
import Data.Maybe
import Data.Data
import Data.ByteString.Char8 ()
import Control.Monad.State
import Control.Exception (throwIO, Exception())
-- | Exception used to abort a handshake; it carries the 'TLSError' that
-- caused the failure.  Thrown by 'handshakeFailed'.
data HandshakeFailed = HandshakeFailed TLSError
    deriving (Show, Eq, Typeable)

instance Exception HandshakeFailed
-- | Abort the handshake: wrap the error in 'HandshakeFailed' and throw it
-- with 'throwIO'.  Never returns normally.
handshakeFailed :: TLSError -> IO ()
handshakeFailed = throwIO . HandshakeFailed
-- | Build the fatal alert packet to report a 'TLSError' to the peer.
-- An 'Error_Protocol' carries its own alert description, which is reused.
-- Any other error is reported as 'InternalError'.
errorToAlert :: TLSError -> Packet
errorToAlert err = Alert [(AlertLevel_Fatal, description)]
  where
    description = case err of
        Error_Protocol (_, _, ad) -> ad
        _                         -> InternalError
-- | Throw (via 'throwCore') an 'Error_Packet_unexpected' describing a
-- message that arrived out of place.
--
-- The first argument describes what was received.  When the second
-- argument is @Just what@, the text @" expected: " ++ what@ is attached;
-- otherwise an empty string is attached.
unexpected :: MonadIO m => String -> Maybe [Char] -> m a
unexpected msg expected = throwCore (Error_Packet_unexpected msg expectation)
  where
    expectation = case expected of
        Nothing   -> ""
        Just what -> " expected: " ++ what
-- | Create the session for a new handshake.
--
-- When 'pUseSession' is set in the context parameters, the session ID is
-- the result of @getStateRNG ctx 32@.  Otherwise the session has no ID
-- (@Session Nothing@).
newSession :: MonadIO m => Context -> m Session
newSession ctx
    | sessionsEnabled = liftM (Session . Just) (getStateRNG ctx 32)
    | otherwise       = return (Session Nothing)
  where
    sessionsEnabled = pUseSession (ctxParams ctx)
-- | Final bookkeeping once a handshake has completed.
--
-- If the state holds a session ID and session data is available, both are
-- handed to the configured session manager via 'sessionEstablish'.  The
-- function then calls 'endHandshake' on the TLS state, resets the byte
-- counters in the context's measurements and marks the context as
-- established.
--
-- A session ID without session data is simply not stored.  Passing
-- @fromJust@ of a 'Nothing' here would hand the manager a lazy bottom.
-- That value might only be forced (and crash) much later, away from this
-- call site.
handshakeTerminate :: MonadIO m => Context -> m ()
handshakeTerminate ctx = do
    session <- usingState_ ctx getSession
    -- only sessions that were given an ID are offered to the session manager
    case session of
        Session (Just sessionId) -> do
            mSessionData <- usingState_ ctx getSessionData
            case mSessionData of
                Just sessionData ->
                    withSessionManager (ctxParams ctx)
                        (\s -> liftIO $ sessionEstablish s sessionId sessionData)
                Nothing -> return ()
        _ -> return ()
    usingState_ ctx endHandshake
    updateMeasure ctx resetBytesCounters
    setEstablished ctx True
    return ()
-- | Send ChangeCipherSpec followed by our Finished message.
--
-- On the client side (@isClient@), a NextProtocolNegotiation message is
-- sent between the two, but only when the client has an
-- 'onNPNServerSuggest' callback and the server suggested protocols.  The
-- protocol chosen by the callback is also recorded with
-- 'setNegotiatedProtocol'.
--
-- The context is flushed after ChangeCipherSpec (and any NPN message).
-- It is flushed again after Finished, whose payload comes from
-- 'getHandshakeDigest'.
sendChangeCipherAndFinish :: MonadIO m => Context -> Bool -> m ()
sendChangeCipherAndFinish ctx isClient = do
    sendPacket ctx ChangeCipherSpec
    when isClient negotiateNextProtocol
    liftIO $ contextFlush ctx
    finishedData <- usingState_ ctx (getHandshakeDigest isClient)
    sendPacket ctx (Handshake [Finished finishedData])
    liftIO $ contextFlush ctx
  where
    -- NPN: only acts when both a local chooser and a server suggestion exist
    negotiateNextProtocol = do
        serverSuggest <- usingState_ ctx getServerNextProtocolSuggest
        case (onNPNServerSuggest (ctxParams ctx), serverSuggest) of
            (Just chooseProtocol, Just protos) -> do
                chosen <- liftIO (chooseProtocol protos)
                sendPacket ctx (Handshake [HsNextProtocolNegotiation chosen])
                usingState_ ctx (setNegotiatedProtocol chosen)
            _ -> return ()
-- | Receive the peer's ChangeCipherSpec, then a handshake record containing
-- its Finished message, using 'runRecvState'.
--
-- Anything else at either step is reported through 'unexpected'.  This
-- function does not inspect the Finished payload itself; 'runRecvState'
-- also passes each handshake message to 'processHandshake'.
recvChangeCipherAndFinish :: MonadIO m => Context -> m ()
recvChangeCipherAndFinish ctx = runRecvState ctx (RecvStateNext onPacket)
  where
    onPacket pkt = case pkt of
        ChangeCipherSpec -> return (RecvStateHandshake onHandshake)
        _                -> unexpected (show pkt) (Just "change cipher")
    onHandshake hs = case hs of
        Finished _ -> return RecvStateDone
        _          -> unexpected (show hs) (Just "Handshake Finished")
-- | State of the receive state machine driven by 'runRecvState'.  Each
-- callback returns the next state.
data RecvState m
    = -- | Receive one packet and pass it to the callback.
      RecvStateNext (Packet -> m (RecvState m))
    | -- | Receive a handshake record and pass each contained handshake
      -- message to the callback, in order.
      RecvStateHandshake (Handshake -> m (RecvState m))
    | -- | Stop receiving.
      RecvStateDone
-- | Receive one packet, which must be a handshake record, and return the
-- handshake messages it carries.
--
-- A receive error is rethrown with 'throwCore'.  Any other packet type is
-- reported through 'unexpected', which throws an 'Error_Packet_unexpected'
-- via 'throwCore', like the other protocol violations in this module.
-- The previous code used 'fail' here.  That raised a plain userError
-- instead of a TLS error, and it requires 'MonadFail', which 'MonadIO'
-- does not imply on modern GHC.
recvPacketHandshake :: MonadIO m => Context -> m [Handshake]
recvPacketHandshake ctx = do
    pkts <- recvPacket ctx
    case pkts of
        Right (Handshake l) -> return l
        Right x             -> unexpected (show x) (Just "handshake")
        Left err            -> throwCore err
-- | Drive a 'RecvState' machine until it reaches 'RecvStateDone'.
--
-- * 'RecvStateNext': receive one packet (a receive error is rethrown with
--   'throwCore') and continue with the state its callback returns.
-- * 'RecvStateHandshake': receive a handshake record via
--   'recvPacketHandshake' and feed its messages in order to the current
--   callback.  After each callback returns, the message is passed to
--   'processHandshake'.  Suppose the machine leaves 'RecvStateHandshake'
--   while messages from the same record remain.  The leftovers are then
--   rejected with @unexpected "spurious handshake"@.
runRecvState :: MonadIO m => Context -> RecvState m -> m ()
runRecvState ctx st = case st of
    RecvStateDone -> return ()
    RecvStateNext onPacket -> do
        received <- recvPacket ctx
        following <- either throwCore onPacket received
        runRecvState ctx following
    RecvStateHandshake _ -> do
        messages <- recvPacketHandshake ctx
        following <- foldM feed st messages
        runRecvState ctx following
  where
    feed :: MonadIO m => RecvState m -> Handshake -> m (RecvState m)
    feed (RecvStateHandshake onHandshake) msg = do
        nextState <- onHandshake msg
        -- the callback sees the message before the TLS state processes it
        usingState_ ctx (processHandshake msg)
        return nextState
    feed _ _ = unexpected "spurious handshake" Nothing