{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE OverloadedStrings, ScopedTypeVariables, BangPatterns #-}
module Network.TLS.Core
    (
    
      sendPacket
    , recvPacket
    
    , bye
    , handshake
    
    , getNegotiatedProtocol
    
    , getClientSNI
    
    , sendData
    , recvData
    , recvData'
    , updateKey
    , KeyUpdateRequest(..)
    , requestCertificate
    ) where
import Network.TLS.Cipher
import Network.TLS.Context
import Network.TLS.Crypto
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.State (getSession)
import Network.TLS.Parameters
import Network.TLS.IO
import Network.TLS.Session
import Network.TLS.Handshake
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Process
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.PostHandshake
import Network.TLS.KeySchedule
import Network.TLS.Types (Role(..), HostName)
import Network.TLS.Util (catchException, mapChunks_)
import Network.TLS.Extension
import qualified Network.TLS.State as S
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy as L
import qualified Control.Exception as E
import Control.Monad.State.Strict
-- | Tell the peer we are closing the connection by sending a warning-level
-- @close_notify@ alert.
--
-- Nothing is sent when the context is already at EOF. The alert is framed
-- as a TLS 1.3 or a pre-1.3 packet depending on the negotiated version, and
-- is written while holding the write lock.
bye :: MonadIO m => Context -> m ()
bye ctx = liftIO $ do
    alreadyClosed <- ctxEOF ctx
    isTLS13 <- tls13orLater ctx
    let alerts = [(AlertLevel_Warning, CloseNotify)]
        sendClose
          | isTLS13   = sendPacket13 ctx (Alert13 alerts)
          | otherwise = sendPacket ctx (Alert alerts)
    unless alreadyClosed $ withWriteLock ctx sendClose
-- | Read the negotiated protocol recorded in the context's TLS state.
-- Yields 'Nothing' when no protocol has been recorded.
getNegotiatedProtocol :: MonadIO m => Context -> m (Maybe B.ByteString)
getNegotiatedProtocol = liftIO . flip usingState_ S.getNegotiatedProtocol
-- | Read the client SNI host name recorded in the context's TLS state.
-- Yields 'Nothing' when none has been recorded.
getClientSNI :: MonadIO m => Context -> m (Maybe HostName)
getClientSNI = liftIO . flip usingState_ S.getClientSNI
-- | Send application data to the peer.
--
-- Each strict chunk of the lazy input is handed to 'mapChunks_' with a
-- size of 16384 (presumably splitting it into pieces of at most that many
-- bytes; see "Network.TLS.Util"). Each piece is sent as an 'AppData13'
-- packet on TLS 1.3, or as an 'AppData' packet otherwise.
--
-- The whole transfer runs under the write lock, after 'checkValid' has
-- accepted the context.
sendData :: MonadIO m => Context -> L.ByteString -> m ()
sendData ctx dataToSend = liftIO $ do
    isTLS13 <- tls13orLater ctx
    let maxFragment = 16384
        sendFragment fragment
          | isTLS13   = sendPacket13 ctx (AppData13 fragment)
          | otherwise = sendPacket ctx (AppData fragment)
    withWriteLock ctx $ do
        checkValid ctx
        mapM_ (mapChunks_ maxFragment sendFragment) (L.toChunks dataToSend)
-- | Receive the next chunk of application data.
--
-- Runs under the read lock, after 'checkValid' has accepted the context.
-- It dispatches to the TLS 1.3 receive loop or the pre-1.3 one, according
-- to the negotiated version.
--
-- An empty bytestring signals that the peer closed the connection
-- (@close_notify@) or that EOF was reached.
recvData :: MonadIO m => Context -> m B.ByteString
recvData ctx = liftIO $ do
    isTLS13 <- tls13orLater ctx
    let receive = if isTLS13 then recvData13 else recvData1
    withReadLock ctx (checkValid ctx >> receive ctx)
-- | Pre-TLS 1.3 receive loop used by 'recvData'.
--
-- A lone 'ClientHello' or 'HelloRequest' triggers a new handshake, after
-- which reading resumes. Empty application-data records are skipped.
--
-- The loop ends in one of these ways:
--
-- * a non-empty application payload, which is returned;
-- * a warning-level @close_notify@: best-effort reply, mark EOF, return
--   empty;
-- * a fatal alert: mark EOF, throw 'Terminated';
-- * anything else: terminate with a fatal @unexpected_message@ alert.
recvData1 :: Context -> IO B.ByteString
recvData1 ctx = recvPacket ctx >>= either (onError abortConn) onPacket
  where
    abortConn = terminateWithWriteLock ctx (sendPacket ctx . Alert)

    rehandshakeWith msg = handshakeWith ctx msg >> recvData1 ctx

    onPacket pkt = case pkt of
        Handshake [msg@ClientHello{}] -> rehandshakeWith msg
        Handshake [msg@HelloRequest]  -> rehandshakeWith msg
        Alert [(AlertLevel_Warning, CloseNotify)] -> do
            tryBye ctx
            setEOF ctx
            return B.empty
        Alert [(AlertLevel_Fatal, desc)] -> do
            setEOF ctx
            E.throwIO (Terminated True ("received fatal error: " ++ show desc) (Error_Protocol ("remote side fatal error", True, desc)))
        AppData payload
            | B.null payload -> recvData1 ctx
            | otherwise      -> return payload
        _ -> let reason = "unexpected message " ++ show pkt
              in abortConn (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
-- | TLS 1.3 receive loop used by 'recvData'.
--
-- Returns the next non-empty application-data payload, or an empty
-- bytestring in two cases:
--
-- * the peer sent a warning-level @close_notify@ (after a best-effort
--   reply and marking EOF);
-- * the record layer reported EOF.
--
-- Along the way it:
--
-- * processes post-handshake handshake messages;
-- * enforces the early-data budget stored in the context's established
--   state;
-- * skips empty application-data records and ChangeCipherSpec records.
--
-- A fatal alert from the peer marks EOF and throws 'Terminated'. Local
-- protocol violations send a fatal alert via 'terminateWithWriteLock',
-- which throws 'Terminated'.
recvData13 :: Context -> IO B.ByteString
recvData13 ctx = do
    pkt <- recvPacket13 ctx
    either (onError terminate) process pkt
  where process (Alert13 [(AlertLevel_Warning, CloseNotify)]) = tryBye ctx >> setEOF ctx >> return B.empty
        process (Alert13 [(AlertLevel_Fatal, desc)]) = do
            setEOF ctx
            E.throwIO (Terminated True ("received fatal error: " ++ show desc) (Error_Protocol ("remote side fatal error", True, desc)))
        process (Handshake13 hs) = do
            loopHandshake13 hs
            recvData13 ctx
        -- An empty application-data record carries nothing; keep reading.
        process (AppData13 "") = recvData13 ctx
        process (AppData13 x) = do
            let chunkLen = C8.length x
            established <- ctxEstablished ctx
            case established of
              -- Early data allowed: deliver it while it fits in the remaining
              -- budget, shrinking the budget by the chunk length.
              EarlyDataAllowed maxSize
                | chunkLen <= maxSize -> do
                    setEstablished ctx $ EarlyDataAllowed (maxSize - chunkLen)
                    return x
                | otherwise ->
                    let reason = "early data overflow" in
                    terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
              -- Early data not allowed: discard up to n such records (x is
              -- not returned), then treat any further one as an error.
              EarlyDataNotAllowed n
                | n > 0     -> do
                    setEstablished ctx $ EarlyDataNotAllowed (n - 1)
                    recvData13 ctx
                | otherwise ->
                    let reason = "early data deprotect overflow" in
                    terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
              Established         -> return x
              NotEstablished      -> throwCore $ Error_Protocol ("data at not-established", True, UnexpectedMessage)
        -- ChangeCipherSpec records are ignored.
        process ChangeCipherSpec13 = recvData13 ctx
        process p             = let reason = "unexpected message " ++ show p in
                                terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
        loopHandshake13 [] = return ()
        -- A ClientHello is never accepted at this point.
        loopHandshake13 (ClientHello13{}:_) = do
            let reason = "Client hello is not allowed"
            terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
        -- Only a client may receive NewSessionTicket. The ticket becomes
        -- resumption session data registered with the session manager:
        --   * a PSK derived from the resumption secret and the ticket nonce;
        --   * a lifetime capped at 604800 seconds (7 days);
        --   * an early-data size taken from the ticket's EarlyData
        --     extension, or 0 when absent.
        loopHandshake13 (NewSessionTicket13 life add nonce label exts:hs) = do
            role <- usingState_ ctx S.isClientContext
            unless (role == ClientRole) $
                let reason = "Session ticket is allowed for client only"
                 in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
            -- Session-manager access is done under the write lock ('recvData'
            -- already holds the read lock around this loop).
            withWriteLock ctx $ do
                mResumptionSecret <- usingHState ctx getTLS13ResumptionSecret
                -- Report a missing secret as a TLS protocol error instead of
                -- an opaque pattern-match failure. 'terminate' is not used
                -- here because it takes the write lock, which is already held.
                resumptionMasterSecret <- case mResumptionSecret of
                    Just secret -> return secret
                    Nothing     -> throwCore $ Error_Protocol ("no resumption secret for NewSessionTicket", True, InternalError)
                (_, usedCipher, _) <- getTxState ctx
                let choice = makeCipherChoice TLS13 usedCipher
                    psk = derivePSK choice resumptionMasterSecret nonce
                    maxSize = case extensionLookup extensionID_EarlyData exts >>= extensionDecode MsgTNewSessionTicket of
                        Just (EarlyDataIndication (Just ms)) -> fromIntegral $ safeNonNegative32 ms
                        _                                    -> 0
                    life7d = min life 604800  -- 7 days, in seconds
                tinfo <- createTLS13TicketInfo life7d (Right add) Nothing
                sdata <- getSessionData13 ctx usedCipher tinfo maxSize psk
                -- Copy the label before storing it, presumably so the stored
                -- key does not keep a larger receive buffer alive.
                let !label' = B.copy label
                sessionEstablish (sharedSessionManager $ ctxShared ctx) label' sdata
            loopHandshake13 hs
        loopHandshake13 (KeyUpdate13 mode:hs) = do
            checkAlignment hs
            established <- ctxEstablished ctx
            -- The update is accepted whether or not we requested one; it is
            -- only rejected before the connection is fully established.
            if established == Established then do
                keyUpdate ctx getRxState setRxState
                -- The write lock covers both sending the reply and switching
                -- the sending key, so no other thread can send a record in
                -- between.
                when (mode == UpdateRequested) $ withWriteLock ctx $ do
                    sendPacket13 ctx $ Handshake13 [KeyUpdate13 UpdateNotRequested]
                    keyUpdate ctx getTxState setTxState
                loopHandshake13 hs
              else do
                let reason = "received key update before established"
                terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
        -- Post-handshake authentication messages.
        loopHandshake13 (h@CertRequest13{}:hs) =
            postHandshakeAuthWith ctx h >> loopHandshake13 hs
        loopHandshake13 (h@Certificate13{}:hs) =
            postHandshakeAuthWith ctx h >> loopHandshake13 hs
        -- Any other message must be consumed by a registered pending action.
        loopHandshake13 (h:hs) = do
            mPendingAction <- popPendingAction ctx
            case mPendingAction of
                Nothing -> let reason = "unexpected handshake message " ++ show h in
                           terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
                Just action -> do
                    -- The pending action runs under the write lock, wrapped in
                    -- 'handleException'.
                    -- NOTE(review): on misalignment, 'checkAlignment' calls
                    -- 'terminate', which takes the write lock again while it
                    -- is held here. If 'withWriteLock' is not reentrant this
                    -- blocks; confirm.
                    withWriteLock ctx $ handleException ctx $
                        case action of
                            PendingAction needAligned pa -> do
                                when needAligned $ checkAlignment hs
                                processHandshake13 ctx h >> pa h
                            PendingActionHash needAligned pa -> do
                                when needAligned $ checkAlignment hs
                                -- The hash is read before 'processHandshake13'
                                -- handles h.
                                d <- transcriptHash ctx
                                processHandshake13 ctx h
                                pa d h
                    loopHandshake13 hs
        terminate = terminateWithWriteLock ctx (sendPacket13 ctx . Alert13)
        -- Fail unless the current record is fully consumed and no further
        -- handshake messages follow in this batch.
        checkAlignment hs = do
            complete <- isRecvComplete ctx
            unless (complete && null hs) $
                let reason = "received message not aligned with record boundary"
                 in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
-- | Best-effort 'bye': any exception raised while sending @close_notify@ is
-- discarded. It is used when the peer's @close_notify@ arrives, right before
-- the context is marked EOF.
tryBye :: Context -> IO ()
tryBye ctx = bye ctx `catchException` const (return ())
-- | Map a receive error onto the caller-supplied termination action.
--
-- * 'Error_EOF' ends the stream with an empty bytestring, and the action
--   is not called.
-- * A protocol error is reported with its own alert description and
--   reason, at fatal or warning level according to its flag.
-- * Any other error is reported as a fatal @internal_error@ whose reason
--   is the shown error.
onError :: Monad m => (TLSError -> AlertLevel -> AlertDescription -> String -> m B.ByteString)
                   -> TLSError -> m B.ByteString
onError terminate err = case err of
    Error_EOF -> return B.empty
    Error_Protocol (reason, fatal, desc) ->
        let level | fatal     = AlertLevel_Fatal
                  | otherwise = AlertLevel_Warning
         in terminate err level desc reason
    _ -> terminate err AlertLevel_Fatal InternalError (show err)
-- | Abort the connection after a local error; never returns normally.
--
-- The current session is read first. Then, under the write lock:
--
-- * the session is invalidated in the shared session manager, when it has
--   an ID;
-- * one alert with the given level and description is sent through @send@,
--   ignoring any exception.
--
-- Finally the context is marked EOF and @'Terminated' False reason err@ is
-- thrown.
terminateWithWriteLock :: Context -> ([(AlertLevel, AlertDescription)] -> IO ())
                       -> TLSError -> AlertLevel -> AlertDescription -> String -> IO a
terminateWithWriteLock ctx send err level desc reason = do
    session <- usingState_ ctx getSession
    withWriteLock ctx $ do
        let Session mSessionId = session
            manager = sharedSessionManager (ctxShared ctx)
        mapM_ (sessionInvalidate manager) mSessionId
        catchException (send [(level, desc)]) (\_ -> pure ())
    setEOF ctx
    E.throwIO (Terminated False reason err)
{-# DEPRECATED recvData' "use recvData that returns strict bytestring" #-}
-- | Lazy-bytestring variant of 'recvData'. The received strict chunk becomes
-- a single-chunk lazy bytestring, which is empty when 'recvData' returns
-- empty.
recvData' :: MonadIO m => Context -> m L.ByteString
recvData' ctx = fmap L.fromStrict (recvData ctx)
-- | Move one direction of the connection to its next traffic secret.
--
-- @getState@ supplies the current (hash, cipher, application traffic
-- secret). The next secret is derived with HKDF-Expand-Label using:
--
-- * label @"traffic upd"@;
-- * an empty context;
-- * the hash's digest size as the output length.
--
-- This matches the RFC 8446 Section 7.2 derivation. The new secret is stored
-- with @setState@, and hash and cipher are unchanged.
keyUpdate :: Context
          -> (Context -> IO (Hash,Cipher,C8.ByteString))
          -> (Context -> Hash -> Cipher -> C8.ByteString -> IO ())
          -> IO ()
keyUpdate ctx getState setState =
    getState ctx >>= \(hashAlg, currentCipher, currentSecret) ->
        let nextSecret = hkdfExpandLabel hashAlg currentSecret "traffic upd" "" (hashDigestSize hashAlg)
         in setState ctx hashAlg currentCipher nextSecret
-- | The kind of key update requested through 'updateKey'.
data KeyUpdateRequest
    = OneWay -- ^ Update our sending key only; the KeyUpdate is sent as @update_not_requested@.
    | TwoWay -- ^ Update our sending key and send @update_requested@, asking the peer to update too.
    deriving (Eq, Show)
-- | Start a key update from this side of a TLS 1.3 connection.
--
-- On TLS 1.3, a KeyUpdate is sent and the connection switches to the next
-- sending key. 'OneWay' is sent as @update_not_requested@ and 'TwoWay' as
-- @update_requested@. Sending and switching happen under the write lock, so
-- no other record is sent in between. On earlier versions nothing is done.
--
-- Returns 'True' when the connection is TLS 1.3 (so the update was sent),
-- and 'False' otherwise.
updateKey :: MonadIO m => Context -> KeyUpdateRequest -> m Bool
updateKey ctx way = liftIO $ do
    isTLS13 <- tls13orLater ctx
    if isTLS13
        then do
            let request | way == TwoWay = UpdateRequested
                        | otherwise     = UpdateNotRequested
            withWriteLock ctx $ do
                sendPacket13 ctx $ Handshake13 [KeyUpdate13 request]
                keyUpdate ctx getTxState setTxState
            return True
        else return False