{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE PatternGuards #-}
module Network.Wai.Handler.WarpTLS (
TLSSettings
, defaultTlsSettings
, tlsSettings
, tlsSettingsMemory
, tlsSettingsChain
, tlsSettingsChainMemory
, certFile
, keyFile
, tlsLogging
, tlsAllowedVersions
, tlsCiphers
, tlsWantClientCert
, tlsServerHooks
, tlsServerDHEParams
, tlsSessionManagerConfig
, onInsecure
, OnInsecure (..)
, runTLS
, runTLSSocket
, WarpTLSException (..)
, DH.Params
, DH.generateParams
) where
import Control.Applicative ((<|>))
import Control.Exception (Exception, throwIO, bracket, finally, handle, fromException, try, IOException, onException, SomeException(..), handleJust)
import qualified Control.Exception as E
import Control.Monad (void, guard)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Default.Class (def)
import qualified Data.IORef as I
import Data.Streaming.Network (bindPortTCP, safeRecv)
import Data.Typeable (Typeable)
import Network.Socket (Socket, close, withSocketsDo, SockAddr, accept)
import Network.Socket.ByteString (sendAll)
import qualified Network.TLS as TLS
import qualified Crypto.PubKey.DH as DH
import qualified Network.TLS.Extra as TLSExtra
import qualified Network.TLS.SessionManager as SM
import Network.Wai (Application)
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Internal
import System.IO.Error (isEOFError)
-- | Configuration for a TLS-enabled Warp server.
--
-- Start from 'defaultTlsSettings' (or one of 'tlsSettings',
-- 'tlsSettingsChain', 'tlsSettingsMemory', 'tlsSettingsChainMemory') and
-- override fields with record-update syntax.
--
-- Credential selection (see 'runTLSSocket'): when both 'certMemory' and
-- 'keyMemory' are 'Nothing', everything is loaded from disk using
-- 'certFile', 'chainCertFiles' and 'keyFile'. Otherwise the in-memory
-- values are used, with 'certFile' / 'keyFile' read from disk only for
-- whichever of the two is missing, and the chain comes from
-- 'chainCertsMemory'.
data TLSSettings = TLSSettings {
    certFile :: FilePath
    -- ^ Path of the server certificate file.
  , chainCertFiles :: [FilePath]
    -- ^ Paths of chain (intermediate) certificates. Only consulted when
    -- both 'certMemory' and 'keyMemory' are 'Nothing'.
  , keyFile :: FilePath
    -- ^ Path of the private key file.
  , certMemory :: Maybe S.ByteString
    -- ^ In-memory server certificate; when 'Just', used instead of
    -- reading 'certFile'.
  , chainCertsMemory :: [S.ByteString]
    -- ^ In-memory chain certificates. Used whenever 'certMemory' or
    -- 'keyMemory' is 'Just'; ignored otherwise.
  , keyMemory :: Maybe S.ByteString
    -- ^ In-memory private key; when 'Just', used instead of reading
    -- 'keyFile'.
  , onInsecure :: OnInsecure
    -- ^ Policy for connections whose first received byte is not a TLS
    -- handshake record (0x16), including connections that send nothing.
  , tlsLogging :: TLS.Logging
    -- ^ Logging hooks installed on every TLS context.
  , tlsAllowedVersions :: [TLS.Version]
    -- ^ Protocol versions offered (becomes 'TLS.supportedVersions').
  , tlsCiphers :: [TLS.Cipher]
    -- ^ Cipher suites offered (becomes 'TLS.supportedCiphers').
  , tlsWantClientCert :: Bool
    -- ^ Whether to request a client certificate
    -- (becomes 'TLS.serverWantClientCert').
  , tlsServerHooks :: TLS.ServerHooks
    -- ^ Server hooks. If 'TLS.onALPNClientSuggest' is 'Nothing' and HTTP/2
    -- is enabled in the Warp 'Settings', an ALPN selector preferring
    -- @h2@ is filled in.
  , tlsServerDHEParams :: Maybe DH.Params
    -- ^ Diffie-Hellman parameters (becomes 'TLS.serverDHEParams').
  , tlsSessionManagerConfig :: Maybe SM.Config
    -- ^ 'Nothing' uses 'TLS.noSessionManager'; 'Just' creates a session
    -- manager with 'SM.newSessionManager' using this config.
  }
-- | Baseline 'TLSSettings'.
--
-- * certificate \/ key files: @certificate.pem@ \/ @key.pem@, no chain,
--   nothing in memory;
-- * plain-HTTP clients are refused with a short explanatory body;
-- * TLS 1.2, 1.1 and 1.0 are offered with the 'ciphers' list;
-- * no client certificate is requested, default logging and hooks,
--   no DH parameters and no session manager.
--
-- NOTE(review): TLS 1.0 and 1.1 are deprecated (RFC 8996); consider
-- narrowing 'tlsAllowedVersions' in deployments that allow it.
defaultTlsSettings :: TLSSettings
defaultTlsSettings =
    TLSSettings
      { certFile                = "certificate.pem"
      , keyFile                 = "key.pem"
      , chainCertFiles          = []
      , certMemory              = Nothing
      , keyMemory               = Nothing
      , chainCertsMemory        = []
      , onInsecure              = DenyInsecure "This server only accepts secure HTTPS connections."
      , tlsLogging              = def
      , tlsAllowedVersions      = [TLS.TLS12, TLS.TLS11, TLS.TLS10]
      , tlsCiphers              = ciphers
      , tlsWantClientCert       = False
      , tlsServerHooks          = def
      , tlsServerDHEParams      = Nothing
      , tlsSessionManagerConfig = Nothing
      }

-- | Cipher suites offered by default: the @tls@ package's
-- 'TLSExtra.ciphersuite_strong' list, unchanged.
ciphers :: [TLS.Cipher]
ciphers = TLSExtra.ciphersuite_strong
-- | What to do with a client that connects without starting a TLS
-- handshake.
data OnInsecure = DenyInsecure L.ByteString
                  -- ^ Answer with @426 Upgrade Required@ followed by this
                  -- body, close the socket and throw
                  -- 'InsecureConnectionDenied'.
                | AllowInsecure
                  -- ^ Serve the connection as plain, unencrypted HTTP.
    deriving (Show)
-- | 'defaultTlsSettings' with the given certificate and key file paths and
-- no chain certificates.
tlsSettings :: FilePath -- ^ Certificate file
            -> FilePath -- ^ Private key file
            -> TLSSettings
tlsSettings cert key = tlsSettingsChain cert [] key
-- | 'defaultTlsSettings' loading the certificate, its chain certificates
-- and the private key from the given file paths.
tlsSettingsChain
    :: FilePath   -- ^ Certificate file
    -> [FilePath] -- ^ Chain certificate files
    -> FilePath   -- ^ Private key file
    -> TLSSettings
tlsSettingsChain cert chainCerts key =
    defaultTlsSettings
      { keyFile        = key
      , certFile       = cert
      , chainCertFiles = chainCerts
      }
-- | 'defaultTlsSettings' using an in-memory certificate and private key
-- and no chain certificates.
tlsSettingsMemory
    :: S.ByteString -- ^ Certificate bytes
    -> S.ByteString -- ^ Private key bytes
    -> TLSSettings
tlsSettingsMemory cert key = tlsSettingsChainMemory cert [] key
-- | 'defaultTlsSettings' using an in-memory certificate, chain
-- certificates and private key.
tlsSettingsChainMemory
    :: S.ByteString   -- ^ Certificate bytes
    -> [S.ByteString] -- ^ Chain certificate bytes
    -> S.ByteString   -- ^ Private key bytes
    -> TLSSettings
tlsSettingsChainMemory cert chainCerts key =
    defaultTlsSettings
      { keyMemory        = Just key
      , certMemory       = Just cert
      , chainCertsMemory = chainCerts
      }
-- | Bind a TCP listener on the port and host from the Warp 'Settings' and
-- serve the application over TLS on it. The listening socket is closed
-- when serving ends, whether normally or by an exception.
runTLS :: TLSSettings -> Settings -> Application -> IO ()
runTLS tset set app = withSocketsDo (bracket openListener close serve)
  where
    openListener = bindPortTCP (getPort set) (getHost set)
    serve sock   = runTLSSocket tset set sock app
-- | Serve the application over TLS on an already-listening 'Socket'.
--
-- The server credential is loaded once, before any connection is
-- accepted:
--
-- * if both 'certMemory' and 'keyMemory' are 'Nothing', from disk via
--   'certFile', 'chainCertFiles' and 'keyFile';
-- * otherwise from memory, reading 'certFile' / 'keyFile' from disk only
--   for whichever in-memory value is missing, with the chain taken from
--   'chainCertsMemory'.
--
-- A session manager is created only when 'tlsSessionManagerConfig' is
-- 'Just'.
--
-- Calls 'error' (an 'ErrorCall') if the certificate or key cannot be
-- parsed. Both loading paths now raise it here, at start-up.
runTLSSocket :: TLSSettings -> Settings -> Socket -> Application -> IO ()
runTLSSocket tlsset@TLSSettings{..} set sock app = do
    credential <- loadCredential
    mgr <- maybe (return TLS.noSessionManager) SM.newSessionManager
                 tlsSessionManagerConfig
    runTLSSocket' tlsset set credential mgr sock app
  where
    -- Both branches fail inside IO. Previously the file branch used
    -- `either error id <$> ...`, which bound an unevaluated thunk, so a bad
    -- certificate/key only blew up during the first client handshake
    -- instead of when the server started.
    loadCredential = case (certMemory, keyMemory) of
        (Nothing, Nothing) ->
            TLS.credentialLoadX509Chain certFile chainCertFiles keyFile
                >>= either error return
        (mcert, mkey) -> do
            cert <- maybe (S.readFile certFile) return mcert
            key  <- maybe (S.readFile keyFile) return mkey
            either error return $
                TLS.credentialLoadX509ChainFromMemory cert chainCertsMemory key
-- | Run Warp's secure connection loop using an already-loaded credential
-- and session manager.
--
-- The 'TLS.ServerParams' are assembled from 'TLSSettings'. No CA
-- certificates are advertised, and secure renegotiation, session
-- resumption and fallback SCSV are enabled. Client-initiated
-- renegotiation is disabled and only null compression is offered.
runTLSSocket' :: TLSSettings -> Settings -> TLS.Credential -> TLS.SessionManager -> Socket -> Application -> IO ()
runTLSSocket' tlsset@TLSSettings{..} set credential mgr sock app =
    runSettingsConnectionMakerSecure set connMaker app
  where
    connMaker = getter tlsset sock serverParams

    serverParams = def
      { TLS.serverWantClientCert = tlsWantClientCert
      , TLS.serverCACertificates = []
      , TLS.serverDHEParams      = tlsServerDHEParams
      , TLS.serverHooks          = serverHooks
      , TLS.serverShared         = serverShared
      , TLS.serverSupported      = serverSupported
      }

    -- A user-supplied ALPN callback always wins. Ours is only installed
    -- as a fallback, and only when HTTP/2 is enabled in Warp's settings.
    serverHooks = tlsServerHooks
      { TLS.onALPNClientSuggest =
          TLS.onALPNClientSuggest tlsServerHooks
            <|> (alpn <$ guard (settingsHTTP2Enabled set))
      }

    serverShared = def
      { TLS.sharedCredentials    = TLS.Credentials [credential]
      , TLS.sharedSessionManager = mgr
      }

    serverSupported = def
      { TLS.supportedVersions                      = tlsAllowedVersions
      , TLS.supportedCiphers                       = tlsCiphers
      , TLS.supportedCompressions                  = [TLS.nullCompression]
      , TLS.supportedSecureRenegotiation           = True
      , TLS.supportedClientInitiatedRenegotiation  = False
      , TLS.supportedSession                       = True
      , TLS.supportedFallbackScsv                  = True
      }
-- | ALPN selector: pick @h2@ if the client offers it. Otherwise answer
-- @http/1.1@, even when the client did not list it.
alpn :: [S.ByteString] -> IO S.ByteString
alpn offered = return selected
  where
    selected
      | "h2" `elem` offered = "h2"
      | otherwise           = "http/1.1"
-- | Accept one client on the listening socket and mark it close-on-exec.
--
-- Returns the peer address together with a deferred action that performs
-- the TLS-or-plain setup ('mkConn') for that client.
getter :: TLS.TLSParams params => TLSSettings -> Socket -> params -> IO (IO (Connection, Transport), SockAddr)
getter tlsset@TLSSettings{} sock params = do
    (clientSock, peerAddr) <- acceptClient
    setSocketCloseOnExec clientSock
    return (mkConn tlsset clientSock params, peerAddr)
  where
#if WINDOWS
    acceptClient = windowsThreadBlockHack (accept sock)
#else
    acceptClient = accept sock
#endif
-- | Read the first chunk (up to 4096 bytes) from a freshly accepted socket.
-- If it starts with a TLS handshake record byte (0x16), set up HTTP over
-- TLS; otherwise, including when nothing was read, hand it to
-- 'plainHTTP'. The socket is closed if either path throws.
mkConn :: TLS.TLSParams params => TLSSettings -> Socket -> params -> IO (Connection, Transport)
mkConn tlsset s params = dispatch `onException` close s
  where
    dispatch = safeRecv s 4096 >>= \firstBS ->
        case S.uncons firstBS of
            Just (0x16, _) -> httpOverTls tlsset s firstBS params
            _              -> plainHTTP tlsset s firstBS
-- | Perform the TLS handshake on an accepted socket and expose the TLS
-- context as a Warp 'Connection'.
--
-- @bs0@ is data already read from @s@ by the caller. It is passed to
-- 'makePlainReceiveN' together with @s@, presumably so the TLS layer
-- sees those bytes first (TODO confirm against makePlainReceiveN).
-- Returns the connection and the negotiated 'Transport'
-- (see 'getTLSinfo').
httpOverTls :: TLS.TLSParams params => TLSSettings -> Socket -> S.ByteString -> params -> IO (Connection, Transport)
httpOverTls TLSSettings{..} s bs0 params = do
    recvN <- makePlainReceiveN s bs0
    ctx <- TLS.contextNew (backend recvN) params
    TLS.contextHookSetLogging ctx tlsLogging
    TLS.handshake ctx
    writeBuf <- allocateBuffer bufferSize
    -- Decrypted bytes left over by 'recvBuf'; shared with (and drained
    -- first by) 'recv'.
    ref <- I.newIORef ""
    tls <- getTLSinfo ctx
    return (conn ctx writeBuf ref, tls)
  where
    backend recvN = TLS.Backend {
        TLS.backendFlush = return ()
      , TLS.backendClose = close s
      , TLS.backendSend  = sendAll' s
      , TLS.backendRecv  = recvN
      }

    -- Report socket write failures as 'ConnectionClosedByPeer'. Only
    -- synchronous 'IOException's are translated. Previously this caught
    -- 'SomeException', which also swallowed asynchronous exceptions
    -- (e.g. a thread kill delivered by a timeout) and re-threw them as
    -- an ordinary synchronous error.
    sendAll' sock bs = sendAll sock bs `E.catch` onSendError
      where
        onSendError :: IOException -> IO ()
        onSendError _ = throwIO ConnectionClosedByPeer

    conn ctx writeBuf ref = Connection {
        connSendMany    = TLS.sendData ctx . L.fromChunks
      , connSendAll     = sendall
      , connSendFile    = sendfile
      , connClose       = close'
      , connFree        = freeBuffer writeBuf
      , connRecv        = recv ref
      , connRecvBuf     = recvBuf ref
      , connWriteBuffer = writeBuf
      , connBufferSize  = bufferSize
      }
      where
        sendall = TLS.sendData ctx . L.fromChunks . return
        sendfile fid offset len hook headers =
            readSendFile writeBuf bufferSize sendall fid offset len hook headers

        -- Best-effort close_notify; the context is closed regardless.
        close' = void (tryIO sendBye) `finally` TLS.contextClose ctx

        -- A peer that already went away before receiving close_notify is
        -- not an error here (cf. RFC 5246 section 7.2.1).
        sendBye =
            handleJust
              (\e -> guard (e == ConnectionClosedByPeer) >> return e)
              (const (return ()))
              (TLS.bye ctx)

        -- Receive decrypted data, serving cached leftovers first.
        recv cref = do
            cached <- I.readIORef cref
            if not (S.null cached)
              then do
                I.writeIORef cref ""
                return cached
              else recv'

        -- Receive decrypted data directly from TLS (no cache). End of
        -- stream, whether signalled by TLS or by an EOF IOError, is
        -- reported as an empty ByteString. Empty TLS records are skipped.
        recv' = handle onEOF go
          where
            onEOF e
              | Just TLS.Error_EOF <- fromException e       = return S.empty
              | Just ioe <- fromException e, isEOFError ioe = return S.empty
              | otherwise                                   = throwIO e
            go = do
                x <- TLS.recvData ctx
                if S.null x then go else return x

        -- Fill a buffer from the cache plus fresh TLS data, storing any
        -- excess back into the cache.
        recvBuf cref buf siz = do
            cached <- I.readIORef cref
            (ret, leftover) <- fill cached buf siz recv'
            I.writeIORef cref leftover
            return ret
-- | Copy exactly @siz0@ bytes into @buf0@: first from the cached bytes
-- @bs0@, then by calling @recv@ repeatedly.
--
-- Returns @(True, leftover)@ once the buffer is full, where @leftover@ is
-- the unused tail of the last chunk. Returns @(False, "")@ if @recv@
-- yields an empty chunk (end of input) before that.
fill :: S.ByteString -> Buffer -> BufSize -> Recv -> IO (Bool,S.ByteString)
fill bs0 buf0 siz0 recv
  | siz0 <= len0 = do
      let (wanted, rest) = S.splitAt siz0 bs0
      void $ copy buf0 wanted
      return (True, rest)
  | otherwise = copy buf0 bs0 >>= \buf -> loop buf (siz0 - len0)
  where
    len0 = S.length bs0

    -- Keep reading until exactly the requested amount has been copied.
    loop _   0         = return (True, S.empty)
    loop buf remaining = recv >>= consume buf remaining

    consume buf remaining chunk
      | chunkLen == 0         = return (False, S.empty)
      | chunkLen <= remaining = copy buf chunk >>= \buf' -> loop buf' (remaining - chunkLen)
      | otherwise             = do
          let (front, back) = S.splitAt remaining chunk
          void $ copy buf front
          return (True, back)
      where
        chunkLen = S.length chunk
-- | Describe a TLS context as a Warp 'Transport'.
--
-- Reports the negotiated protocol version as its (major, minor) wire
-- numbers, the ALPN protocol and the cipher ID. Falls back to 'TCP' when
-- the context has no information available.
getTLSinfo :: TLS.Context -> IO Transport
getTLSinfo ctx = do
    proto <- TLS.getNegotiatedProtocol ctx
    minfo <- TLS.contextGetInformation ctx
    case minfo of
        Just info@TLS.Information{} -> do
            let (major, minor) = wireVersion (TLS.infoVersion info)
            return TLS
                { tlsMajorVersion       = major
                , tlsMinorVersion       = minor
                , tlsNegotiatedProtocol = proto
                , tlsChiperID           = TLS.cipherID (TLS.infoCipher info)
                }
        Nothing -> return TCP
  where
    wireVersion v = case v of
        TLS.SSL2  -> (2,0)
        TLS.SSL3  -> (3,0)
        TLS.TLS10 -> (3,1)
        TLS.TLS11 -> (3,2)
        TLS.TLS12 -> (3,3)
-- | Run an action, capturing only 'IOException's as 'Left'. All other
-- exceptions propagate.
tryIO :: IO a -> IO (Either IOException a)
tryIO action = E.try action
-- | Handle a connection that did not begin with a TLS handshake record,
-- according to 'onInsecure'.
--
-- * 'AllowInsecure': serve it as plain HTTP over 'TCP'. The bytes already
--   read (@bs0@), if non-empty, are returned by the first 'connRecv' call
--   before the socket is read again.
-- * 'DenyInsecure': send a @426 Upgrade Required@ response followed by
--   the configured body, close the socket and throw
--   'InsecureConnectionDenied'.
plainHTTP :: TLSSettings -> Socket -> S.ByteString -> IO (Connection, Transport)
plainHTTP TLSSettings{..} s bs0 = case onInsecure of
    AllowInsecure -> do
        conn' <- socketConnection s
        cachedRef <- I.newIORef bs0
        let conn'' = conn'
                { connRecv = recvPlain cachedRef (connRecv conn')
                }
        return (conn'', TCP)
    DenyInsecure lbs -> do
        -- 426 tells the client that TLS is required
        -- (RFC 2817 section 4.2, RFC 7231 section 6.5.15).
        sendAll s upgradeRequired
        mapM_ (sendAll s) $ L.toChunks lbs
        close s
        throwIO InsecureConnectionDenied
  where
    -- Built from separate literals on purpose. The previous multi-line
    -- string used gaps whose continuation lines began with a backslash
    -- followed by "r\n". A gap is backslash-whitespace-backslash, so that
    -- first backslash closed the gap. Each line then contributed a
    -- literal 'r' plus a bare LF ("...Requiredr\n", "Upgrader\n"), which
    -- is a malformed status line and malformed headers.
    upgradeRequired = S.concat
        [ "HTTP/1.1 426 Upgrade Required\r\n"
        , "Upgrade: TLS/1.0, HTTP/1.1\r\n"
        , "Connection: Upgrade\r\n"
        , "Content-Type: text/plain\r\n"
        , "\r\n"
        ]
-- | Receive with a one-shot cache. If the 'I.IORef' holds non-empty bytes,
-- return them and empty the ref. Otherwise run the fallback receive
-- action.
recvPlain :: I.IORef S.ByteString -> IO S.ByteString -> IO S.ByteString
recvPlain ref fallback = I.readIORef ref >>= serve
  where
    serve cached
      | S.null cached = fallback
      | otherwise     = cached <$ I.writeIORef ref S.empty
-- | Exceptions thrown by this module.
data WarpTLSException = InsecureConnectionDenied
    -- ^ Thrown after a non-TLS client has been answered with
    -- @426 Upgrade Required@ and disconnected under 'DenyInsecure'.
    deriving (Show, Typeable)
instance Exception WarpTLSException