{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE OverloadedStrings, ScopedTypeVariables, BangPatterns #-}
module Network.TLS.Core
(
sendPacket
, recvPacket
, bye
, handshake
, getNegotiatedProtocol
, getClientSNI
, sendData
, recvData
, recvData'
, updateKey
, KeyUpdateRequest(..)
, requestCertificate
) where
import Network.TLS.Cipher
import Network.TLS.Context
import Network.TLS.Crypto
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.State (getSession)
import Network.TLS.Parameters
import Network.TLS.IO
import Network.TLS.Session
import Network.TLS.Handshake
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Process
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.PostHandshake
import Network.TLS.KeySchedule
import Network.TLS.Types (Role(..), HostName)
import Network.TLS.Util (catchException, mapChunks_)
import Network.TLS.Extension
import qualified Network.TLS.State as S
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy as L
import qualified Control.Exception as E
import Control.Monad.State.Strict
-- | Tell the peer we are closing the connection by sending a warning-level
-- @close_notify@ alert, framed as a TLS 1.3 alert when 'tls13orLater' holds
-- and as a legacy alert otherwise.
--
-- Nothing is sent if the context has already reached EOF. The alert is
-- written while holding the context's write lock.
bye :: MonadIO m => Context -> m ()
bye ctx = liftIO $ do
    alreadyClosed <- ctxEOF ctx
    useTls13 <- tls13orLater ctx
    let closeNotify = [(AlertLevel_Warning, CloseNotify)]
        emitAlert
            | useTls13  = sendPacket13 ctx (Alert13 closeNotify)
            | otherwise = sendPacket ctx (Alert closeNotify)
    unless alreadyClosed (withWriteLock ctx emitAlert)
-- | Return the protocol recorded by 'S.getNegotiatedProtocol' in the
-- context's TLS state, if any.
getNegotiatedProtocol :: MonadIO m => Context -> m (Maybe B.ByteString)
getNegotiatedProtocol = liftIO . flip usingState_ S.getNegotiatedProtocol
-- | Return the client SNI host name recorded by 'S.getClientSNI' in the
-- context's TLS state, if any.
getClientSNI :: MonadIO m => Context -> m (Maybe HostName)
getClientSNI = liftIO . flip usingState_ S.getClientSNI
-- | Send application data to the peer.
--
-- Each strict chunk of the lazy input is passed to 'mapChunks_' with a piece
-- size of 16384, and every piece is sent as an application-data record. TLS
-- 1.3 framing is used when 'tls13orLater' holds. 'checkValid' runs first, and
-- the whole write happens under one write lock, so writes from other threads
-- cannot be interleaved with this (possibly large) one.
sendData :: MonadIO m => Context -> L.ByteString -> m ()
sendData ctx dataToSend = liftIO $ do
    useTls13 <- tls13orLater ctx
    let sendFragment =
            if useTls13
                then sendPacket13 ctx . AppData13
                else sendPacket ctx . AppData
        sendChunk = mapChunks_ 16384 sendFragment
    withWriteLock ctx $
        checkValid ctx >> mapM_ sendChunk (L.toChunks dataToSend)
-- | Receive the next chunk of application data from the peer.
--
-- Runs 'checkValid' and then the version-specific receiver ('recvData13' when
-- 'tls13orLater' holds, 'recvData1' otherwise) under the context's read lock.
-- An empty result means the stream has ended: a close_notify was received or
-- the transport reported EOF.
recvData :: MonadIO m => Context -> m B.ByteString
recvData ctx = liftIO $ do
    useTls13 <- tls13orLater ctx
    let receive = if useTls13 then recvData13 else recvData1
    withReadLock ctx (checkValid ctx >> receive ctx)
-- | TLS 1.2 and earlier: receive the next chunk of application data.
--
-- * A ClientHello or HelloRequest starts a new handshake via 'handshakeWith',
--   after which reading resumes.
-- * A warning-level @close_notify@ makes a best-effort 'bye' of our own
--   ('tryBye'), marks the context as EOF and yields 'B.empty'.
-- * A fatal alert marks EOF and throws 'Terminated'.
-- * Empty application-data records are skipped.
--
-- Any other record, and any receive error other than 'Error_EOF', terminates
-- the connection through 'terminateWithWriteLock'.
recvData1 :: Context -> IO B.ByteString
recvData1 ctx = recvPacket ctx >>= \result ->
    case result of
        Left err  -> onError terminate err
        Right pkt -> dispatch pkt
  where
    terminate = terminateWithWriteLock ctx (sendPacket ctx . Alert)

    again = recvData1 ctx

    dispatch pkt = case pkt of
        Handshake [ch@ClientHello{}] -> handshakeWith ctx ch >> again
        Handshake [hr@HelloRequest] -> handshakeWith ctx hr >> again
        Alert [(AlertLevel_Warning, CloseNotify)] -> do
            tryBye ctx
            setEOF ctx
            return B.empty
        Alert [(AlertLevel_Fatal, desc)] -> do
            setEOF ctx
            E.throwIO (Terminated True ("received fatal error: " ++ show desc) (Error_Protocol ("remote side fatal error", True, desc)))
        AppData "" -> again
        AppData x -> return x
        other ->
            let reason = "unexpected message " ++ show other
             in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
-- | TLS 1.3: receive the next chunk of application data.
--
-- Records that are not application data are handled in place, and reading
-- resumes until data (or end of stream) is available:
--
-- * A warning-level @close_notify@ makes a best-effort 'bye' of our own
--   ('tryBye'), marks the context as EOF and returns 'B.empty'.
-- * A fatal alert marks the context as EOF and throws 'Terminated'.
-- * Handshake messages are processed: NewSessionTicket (client only),
--   KeyUpdate, post-handshake authentication messages, and messages expected
--   by a pending handshake action. A ClientHello is rejected.
-- * An empty application-data record is skipped.
-- * Application data is returned, subject to the early-data accounting kept
--   in the context's 'ctxEstablished' state.
-- * A @change_cipher_spec@ record is dropped while the context is not yet
--   'Established' and rejected with @unexpected_message@ afterwards.
--   RFC 8446, Section 5 only allows it to be ignored before the peer's
--   Finished.
--
-- Any other record, and any receive error other than 'Error_EOF', terminates
-- the connection through 'terminateWithWriteLock'.
recvData13 :: Context -> IO B.ByteString
recvData13 ctx = do
    pkt <- recvPacket13 ctx
    either (onError terminate) process pkt
  where
    process (Alert13 [(AlertLevel_Warning, CloseNotify)]) = tryBye ctx >> setEOF ctx >> return B.empty
    process (Alert13 [(AlertLevel_Fatal, desc)]) = do
        setEOF ctx
        E.throwIO (Terminated True ("received fatal error: " ++ show desc) (Error_Protocol ("remote side fatal error", True, desc)))
    process (Handshake13 hs) = do
        loopHandshake13 hs
        recvData13 ctx
    -- An empty application-data record carries nothing: read the next one.
    process (AppData13 "") = recvData13 ctx
    process (AppData13 x) = do
        let chunkLen = C8.length x
        established <- ctxEstablished ctx
        case established of
            -- Early data is accepted while it fits in the remaining budget.
            EarlyDataAllowed maxSize
                | chunkLen <= maxSize -> do
                    setEstablished ctx $ EarlyDataAllowed (maxSize - chunkLen)
                    return x
                | otherwise ->
                    let reason = "early data overflow"
                     in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
            -- The record is discarded ("x" is ignored); only @n@ more such
            -- records are tolerated before the connection is terminated.
            EarlyDataNotAllowed n
                | n > 0 -> do
                    setEstablished ctx $ EarlyDataNotAllowed (n - 1)
                    recvData13 ctx
                | otherwise ->
                    let reason = "early data deprotect overflow"
                     in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
            Established -> return x
            NotEstablished -> throwCore $ Error_Protocol ("data at not-established", True, UnexpectedMessage)
    -- A change_cipher_spec is only tolerated (and dropped) before the
    -- connection is established. Once 'Established', it is a protocol
    -- violation and must not be silently ignored.
    process ChangeCipherSpec13 = do
        established <- ctxEstablished ctx
        if established == Established
            then
                let reason = "CCS after Finished"
                 in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
            else recvData13 ctx
    process p =
        let reason = "unexpected message " ++ show p
         in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason

    loopHandshake13 [] = return ()
    loopHandshake13 (ClientHello13{}:_) = do
        let reason = "Client hello is not allowed"
        terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
    loopHandshake13 (NewSessionTicket13 life add nonce label exts:hs) = do
        role <- usingState_ ctx S.isClientContext
        unless (role == ClientRole) $
            let reason = "Session ticket is allowed for client only"
             in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
        -- Derive this ticket's PSK from the resumption secret and register the
        -- resulting session data with the shared session manager.
        withWriteLock ctx $ do
            ResuptionSecret resumptionMasterSecret <- usingHState ctx getTLS13Secret
            (usedHash, usedCipher, _) <- getTxState ctx
            let hashSize = hashDigestSize usedHash
                psk = hkdfExpandLabel usedHash resumptionMasterSecret "resumption" nonce hashSize
                maxSize = case extensionLookup extensionID_EarlyData exts >>= extensionDecode MsgTNewSessionTicket of
                    Just (EarlyDataIndication (Just ms)) -> fromIntegral $ safeNonNegative32 ms
                    _ -> 0
                life7d = min life 604800 -- cap the lifetime at 7 days (604800 s)
            tinfo <- createTLS13TicketInfo life7d (Right add) Nothing
            sdata <- getSessionData13 ctx usedCipher tinfo maxSize psk
            -- Presumably copied so the stored key does not keep the received
            -- record buffer alive; strictness forces the copy now.
            let !label' = B.copy label
            sessionEstablish (sharedSessionManager $ ctxShared ctx) label' sdata
        loopHandshake13 hs
    loopHandshake13 (KeyUpdate13 mode:hs) = do
        checkAlignment hs
        established <- ctxEstablished ctx
        if established == Established
            then do
                keyUpdate ctx getRxState setRxState
                -- The reply and our own Tx key switch share one write lock so
                -- that no other packet is sent between the two.
                when (mode == UpdateRequested) $ withWriteLock ctx $ do
                    sendPacket13 ctx $ Handshake13 [KeyUpdate13 UpdateNotRequested]
                    keyUpdate ctx getTxState setTxState
                loopHandshake13 hs
            else do
                let reason = "received key update before established"
                terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
    loopHandshake13 (h@CertRequest13{}:hs) =
        postHandshakeAuthWith ctx h >> loopHandshake13 hs
    loopHandshake13 (h@Certificate13{}:hs) =
        postHandshakeAuthWith ctx h >> loopHandshake13 hs
    loopHandshake13 (h:hs) = do
        mPendingAction <- popPendingAction ctx
        case mPendingAction of
            Nothing ->
                let reason = "unexpected handshake message " ++ show h
                 in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
            Just action -> do
                -- Check record alignment *before* taking the write lock. On
                -- misalignment, checkAlignment calls 'terminate', which takes
                -- the write lock itself. Doing that while already holding it
                -- would block this thread on its own lock (assuming
                -- withWriteLock is not reentrant, e.g. MVar-based -- TODO
                -- confirm).
                when (pendingNeedsAlignment action) $ checkAlignment hs
                withWriteLock ctx $ handleException ctx $
                    case action of
                        PendingAction _ pa ->
                            processHandshake13 ctx h >> pa h
                        PendingActionHash _ pa -> do
                            d <- transcriptHash ctx
                            processHandshake13 ctx h
                            pa d h
                loopHandshake13 hs

    -- Whether a pending action requires the message to end on a record
    -- boundary.
    pendingNeedsAlignment (PendingAction needAligned _) = needAligned
    pendingNeedsAlignment (PendingActionHash needAligned _) = needAligned

    terminate = terminateWithWriteLock ctx (sendPacket13 ctx . Alert13)

    -- Terminate unless the current record has been fully consumed and no
    -- further handshake messages remain in it.
    checkAlignment hs = do
        complete <- isRecvComplete ctx
        unless (complete && null hs) $
            let reason = "received message not aligned with record boundary"
             in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
-- | Best-effort 'bye': any exception raised while sending the alert is
-- discarded. Used after the peer has sent its own close_notify.
tryBye :: Context -> IO ()
tryBye ctx = bye ctx `catchException` ignore
  where
    ignore _ = return ()
-- | Map a receive error to a result or a connection termination.
--
-- * 'Error_EOF' yields an empty bytestring.
-- * A protocol error is handed to the terminator with its own alert
--   description and reason, at fatal or warning level according to its flag.
-- * Any other error terminates with a fatal @internal_error@ alert whose
--   reason is the error's 'show' output.
onError :: Monad m => (TLSError -> AlertLevel -> AlertDescription -> String -> m B.ByteString)
        -> TLSError -> m B.ByteString
onError terminate err = case err of
    Error_EOF -> return B.empty
    Error_Protocol (reason, fatal, desc) ->
        let level = if fatal then AlertLevel_Fatal else AlertLevel_Warning
         in terminate err level desc reason
    _ -> terminate err AlertLevel_Fatal InternalError (show err)
-- | Abort the connection.
--
-- Under the write lock, this invalidates the current session (if it has an
-- id) in the shared session manager and sends the given alert through @send@,
-- ignoring any failure to send. It then marks the context as EOF and throws
-- @'Terminated' False reason err@; it never returns normally.
terminateWithWriteLock :: Context -> ([(AlertLevel, AlertDescription)] -> IO ())
                       -> TLSError -> AlertLevel -> AlertDescription -> String -> IO a
terminateWithWriteLock ctx send err level desc reason = do
    session <- usingState_ ctx getSession
    withWriteLock ctx $ do
        let manager = sharedSessionManager (ctxShared ctx)
        case session of
            Session (Just sid) -> sessionInvalidate manager sid
            Session Nothing -> return ()
        -- Best effort: the peer may already be gone, and a failure to deliver
        -- the alert must not mask the error being reported.
        send [(level, desc)] `catchException` const (return ())
    setEOF ctx
    E.throwIO (Terminated False reason err)
{-# DEPRECATED recvData' "use recvData that returns strict bytestring" #-}
-- | Lazy-bytestring wrapper around 'recvData': the received strict chunk is
-- wrapped as a single-chunk lazy bytestring.
recvData' :: MonadIO m => Context -> m L.ByteString
recvData' ctx = fmap (\chunk -> L.fromChunks [chunk]) (recvData ctx)
-- | Move one direction's traffic secret to the next generation.
--
-- Reads @(hash, cipher, secret)@ with @getState@ and derives the next secret
-- with 'hkdfExpandLabel', using label @"traffic upd"@, an empty context and
-- the hash's digest size (cf. RFC 8446, Section 7.2). The new secret is
-- stored with @setState@, keeping the same hash and cipher.
keyUpdate :: Context
          -> (Context -> IO (Hash,Cipher,C8.ByteString))
          -> (Context -> Hash -> Cipher -> C8.ByteString -> IO ())
          -> IO ()
keyUpdate ctx getState setState =
    getState ctx >>= \(hashAlg, cipher, secretN) ->
        let secretN1 = hkdfExpandLabel hashAlg secretN "traffic upd" "" (hashDigestSize hashAlg)
         in setState ctx hashAlg cipher secretN1
-- | Kind of key update initiated by 'updateKey' (TLS 1.3 only).
data KeyUpdateRequest
    = OneWay -- ^ update our sending keys only ('UpdateNotRequested')
    | TwoWay -- ^ update our sending keys and send 'UpdateRequested', asking
             --   the peer to answer with a key update of its own
    deriving (Show, Eq)
-- | Initiate a TLS 1.3 key update.
--
-- When 'tls13orLater' holds, this sends a KeyUpdate message (with
-- 'UpdateRequested' for 'TwoWay', 'UpdateNotRequested' for 'OneWay') and
-- switches our sending keys. Both steps share one write lock so that no other
-- packet is sent between them. Returns whether TLS 1.3 or later is in use; if
-- not, nothing is sent.
updateKey :: MonadIO m => Context -> KeyUpdateRequest -> m Bool
updateKey ctx way = liftIO $ do
    tls13 <- tls13orLater ctx
    when tls13 $ withWriteLock ctx $ do
        sendPacket13 ctx $ Handshake13 [KeyUpdate13 (toRequest way)]
        keyUpdate ctx getTxState setTxState
    return tls13
  where
    toRequest OneWay = UpdateNotRequested
    toRequest TwoWay = UpdateRequested