module Network.Wai.Handler.WarpTLS (
TLSSettings
, defaultTlsSettings
, tlsSettings
, tlsSettingsMemory
, tlsSettingsChain
, tlsSettingsChainMemory
, certFile
, keyFile
, tlsLogging
, tlsAllowedVersions
, tlsCiphers
, tlsWantClientCert
, tlsServerHooks
, onInsecure
, OnInsecure (..)
, runTLS
, runTLSSocket
, runHTTP2TLS
, runHTTP2TLSSocket
, WarpTLSException (..)
) where
#if __GLASGOW_HASKELL__ < 709
import Control.Applicative ((<$>))
#endif
import Control.Applicative ((<|>))
import Control.Exception (Exception, throwIO, bracket, finally, handle, fromException, try, IOException, onException, SomeException(..))
import qualified Control.Exception as E
import Control.Monad (void)
import qualified Crypto.Random.AESCtr
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Default.Class (def)
import qualified Data.IORef as I
import Data.Streaming.Network (bindPortTCP, safeRecv)
import Data.Typeable (Typeable)
import Network.Socket (Socket, sClose, withSocketsDo, SockAddr, accept)
import Network.Socket.ByteString (sendAll)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLSExtra
import Network.Wai (Application)
import Network.Wai.HTTP2 (HTTP2Application)
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Internal
import System.IO.Error (isEOFError)
-- | Settings for the TLS side of a Warp server.
--
-- Build a value with 'defaultTlsSettings' or one of the @tlsSettings*@
-- helpers, then override individual fields.
data TLSSettings = TLSSettings {
    certFile :: FilePath
      -- ^ Certificate file; read whenever 'certMemory' is 'Nothing'.
  , chainCertFiles :: [FilePath]
      -- ^ Chain certificate files; only used when both 'certMemory' and
      -- 'keyMemory' are 'Nothing'.
  , keyFile :: FilePath
      -- ^ Private key file; read whenever 'keyMemory' is 'Nothing'.
  , certMemory :: Maybe S.ByteString
      -- ^ In-memory certificate. If this or 'keyMemory' is set, the
      -- credential is loaded with 'TLS.credentialLoadX509ChainFromMemory'.
  , chainCertsMemory :: [S.ByteString]
      -- ^ In-memory chain certificates; only used when 'certMemory' or
      -- 'keyMemory' is set.
  , keyMemory :: Maybe S.ByteString
      -- ^ In-memory private key.
  , onInsecure :: OnInsecure
      -- ^ What to do when the first chunk read from a client is empty or
      -- does not start with the TLS handshake record type (0x16).
  , tlsLogging :: TLS.Logging
      -- ^ Logging hooks installed on every TLS context.
  , tlsAllowedVersions :: [TLS.Version]
      -- ^ Protocol versions offered (becomes 'TLS.supportedVersions').
  , tlsCiphers :: [TLS.Cipher]
      -- ^ Cipher suites offered (becomes 'TLS.supportedCiphers').
  , tlsWantClientCert :: Bool
      -- ^ Whether to request a client certificate
      -- (becomes 'TLS.serverWantClientCert').
  , tlsServerHooks :: TLS.ServerHooks
      -- ^ Server hooks. A user-set ALPN hook takes precedence over the
      -- built-in HTTP/2 protocol selection.
  }
-- | Baseline settings. The certificate and key are read from the relative
-- paths @certificate.pem@ and @key.pem@. Clients that do not start a TLS
-- handshake are refused. TLS 1.0 through 1.2 are allowed with this
-- module's 'ciphers' list, and no client certificate is requested.
defaultTlsSettings :: TLSSettings
defaultTlsSettings = TLSSettings
    { -- credentials loaded from files
      certFile           = "certificate.pem"
    , keyFile            = "key.pem"
    , chainCertFiles     = []
      -- in-memory credentials; 'Nothing' falls back to the file above
    , certMemory         = Nothing
    , keyMemory          = Nothing
    , chainCertsMemory   = []
      -- policy for non-TLS clients
    , onInsecure         = DenyInsecure "This server only accepts secure HTTPS connections."
      -- TLS protocol parameters
    , tlsAllowedVersions = [TLS.TLS12, TLS.TLS11, TLS.TLS10]
    , tlsCiphers         = ciphers
    , tlsWantClientCert  = False
    , tlsServerHooks     = def
    , tlsLogging         = def
    }
-- | Cipher suites used by 'defaultTlsSettings'. They are listed in this
-- order: ECDHE-RSA, DHE-RSA, DHE-DSS, then suites without an ephemeral
-- key exchange.
--
-- NOTE(review): 'runServeTLSSocket'' sets @serverDHEParams = Nothing@.
-- Confirm whether the DHE suites can actually be negotiated.
ciphers :: [TLS.Cipher]
ciphers = ecdheRsa ++ dheRsa ++ dheDss ++ nonEphemeral
  where
    ecdheRsa =
      [ TLSExtra.cipher_ECDHE_RSA_AES128GCM_SHA256
      , TLSExtra.cipher_ECDHE_RSA_AES128CBC_SHA256
      , TLSExtra.cipher_ECDHE_RSA_AES128CBC_SHA
      ]
    dheRsa =
      [ TLSExtra.cipher_DHE_RSA_AES128GCM_SHA256
      , TLSExtra.cipher_DHE_RSA_AES256_SHA256
      , TLSExtra.cipher_DHE_RSA_AES128_SHA256
      , TLSExtra.cipher_DHE_RSA_AES256_SHA1
      , TLSExtra.cipher_DHE_RSA_AES128_SHA1
      ]
    dheDss =
      [ TLSExtra.cipher_DHE_DSS_AES128_SHA1
      , TLSExtra.cipher_DHE_DSS_AES256_SHA1
      ]
    nonEphemeral =
      [ TLSExtra.cipher_AES128_SHA1
      , TLSExtra.cipher_AES256_SHA1
      ]
-- | Policy for connections whose first bytes are not a TLS handshake.
data OnInsecure = DenyInsecure L.ByteString
                  -- ^ Send a @426 Upgrade Required@ response with this text
                  -- as the body, close the socket, and throw
                  -- 'InsecureConnectionDenied'.
                | AllowInsecure
                  -- ^ Serve the connection as plain, unencrypted HTTP.
                deriving (Show)
-- | Settings that load the certificate and private key from the given
-- files, with no chain certificates. All other fields come from
-- 'defaultTlsSettings'.
tlsSettings
    :: FilePath -- ^ certificate file
    -> FilePath -- ^ key file
    -> TLSSettings
tlsSettings cert key = tlsSettingsChain cert [] key
-- | Settings that load the certificate, the intermediate chain
-- certificates, and the private key from files. All other fields come
-- from 'defaultTlsSettings'.
tlsSettingsChain
    :: FilePath   -- ^ certificate file
    -> [FilePath] -- ^ chain certificate files
    -> FilePath   -- ^ key file
    -> TLSSettings
tlsSettingsChain cert chainCerts key = withFiles defaultTlsSettings
  where
    withFiles base = base
        { certFile       = cert
        , chainCertFiles = chainCerts
        , keyFile        = key
        }
-- | Settings that take the certificate and private key as in-memory bytes
-- instead of file paths, with no chain certificates.
tlsSettingsMemory
    :: S.ByteString -- ^ certificate bytes
    -> S.ByteString -- ^ private key bytes
    -> TLSSettings
tlsSettingsMemory cert key = tlsSettingsChainMemory cert [] key
-- | Settings that take the certificate, the chain certificates, and the
-- private key as in-memory bytes. All other fields come from
-- 'defaultTlsSettings'.
tlsSettingsChainMemory
    :: S.ByteString   -- ^ certificate bytes
    -> [S.ByteString] -- ^ chain certificate bytes
    -> S.ByteString   -- ^ private key bytes
    -> TLSSettings
tlsSettingsChainMemory cert chainCerts key = withMemory defaultTlsSettings
  where
    withMemory base = base
        { certMemory       = Just cert
        , chainCertsMemory = chainCerts
        , keyMemory        = Just key
        }
-- | Apply a one-argument function to the result of a two-argument one:
-- @(f .: g) a b = f (g a b)@.
(.:) :: (c -> d) -> (a -> b -> c) -> a -> b -> d
(f .: g) x y = f (g x y)
-- | Serve a WAI 'Application' over TLS on the port and host taken from
-- 'Settings'.
runTLS :: TLSSettings -> Settings -> Application -> IO ()
runTLS tset set app = runServeTLS tset set (serveDefault app)
-- | Serve an 'HTTP2Application' together with a WAI 'Application' over
-- TLS on the port and host taken from 'Settings', via 'serveHTTP2'.
runHTTP2TLS :: TLSSettings -> Settings -> HTTP2Application -> Application -> IO ()
runHTTP2TLS tset set app2 app = runServeTLS tset set (serveHTTP2 app2 app)
-- | Bind a listening socket for the configured port and host, and serve on
-- it. The socket is closed when serving ends, whether normally or by
-- exception.
runServeTLS :: TLSSettings -> Settings -> ServeConnection -> IO ()
runServeTLS tset set serve = withSocketsDo $ bracket acquire sClose useSocket
  where
    acquire = bindPortTCP (getPort set) (getHost set)
    useSocket sock = runServeTLSSocket tset set sock serve
-- | Like 'runTLS', but serves on an already-bound listening socket.
runTLSSocket :: TLSSettings -> Settings -> Socket -> Application -> IO ()
runTLSSocket tset set sock app = runServeTLSSocket tset set sock (serveDefault app)
-- | Like 'runHTTP2TLS', but serves on an already-bound listening socket.
runHTTP2TLSSocket :: TLSSettings -> Settings -> Socket -> HTTP2Application -> Application -> IO ()
runHTTP2TLSSocket tset set sock = serveOn .: serveHTTP2
  where
    serveOn = runServeTLSSocket tset set sock
-- | Serve on an already-bound listening socket.
--
-- The server credential is loaded once, before any connection is handed
-- to 'runServeTLSSocket'':
--
-- * If neither 'certMemory' nor 'keyMemory' is set, the certificate,
--   'chainCertFiles' and key are read from 'certFile' and 'keyFile'.
-- * Otherwise the in-memory values are used, with 'chainCertsMemory' as
--   the chain. Whichever of certificate or key is missing is read from
--   'certFile' or 'keyFile'.
--
-- If the credential fails to load, an 'ErrorCall' is raised here rather
-- than later.
runServeTLSSocket :: TLSSettings -> Settings -> Socket -> ServeConnection -> IO ()
runServeTLSSocket tlsset@TLSSettings{..} set sock serve = do
    credential <- case (certMemory, keyMemory) of
        (Nothing, Nothing) ->
            -- `either error return =<<` raises the load error now.
            -- `either error id <$>` would only create a lazy thunk, which
            -- would blow up inside the first client handshake.
            either error return =<<
                TLS.credentialLoadX509Chain certFile chainCertFiles keyFile
        (mcert, mkey) -> do
            cert <- maybe (S.readFile certFile) return mcert
            key <- maybe (S.readFile keyFile) return mkey
            either error return $
                TLS.credentialLoadX509ChainFromMemory cert chainCertsMemory key
    runServeTLSSocket' tlsset set credential sock serve
-- | Serve on a listening socket with an already-loaded credential.
--
-- Builds the server-side TLS parameters from the settings. It then gives
-- Warp an acceptor ('getter') that yields the per-connection setup action.
runServeTLSSocket' :: TLSSettings -> Settings -> TLS.Credential -> Socket -> ServeConnection -> IO ()
runServeTLSSocket' tlsset@TLSSettings{..} set credential sock serve =
    runServeSettingsConnectionMakerSecure set acceptor serve
  where
    acceptor = getter tlsset sock serverParams

    serverParams = def
        { TLS.serverWantClientCert = tlsWantClientCert
        , TLS.serverCACertificates = []
        , TLS.serverDHEParams      = Nothing
        , TLS.serverHooks          = hooksWithAlpn
        , TLS.serverShared         = sharedParams
        , TLS.serverSupported      = supportedParams
        }

    -- A user-supplied ALPN hook wins. Ours is installed only when HTTP/2
    -- is enabled in the Warp settings.
    hooksWithAlpn = tlsServerHooks
        { TLS.onALPNClientSuggest =
            TLS.onALPNClientSuggest tlsServerHooks
                <|> (if settingsHTTP2Enabled set then Just alpn else Nothing)
        }

    sharedParams = def { TLS.sharedCredentials = TLS.Credentials [credential] }

    supportedParams = def
        { TLS.supportedVersions                     = tlsAllowedVersions
        , TLS.supportedCiphers                      = tlsCiphers
        , TLS.supportedCompressions                 = [TLS.nullCompression]
        , TLS.supportedHashSignatures               = hashSigs
        , TLS.supportedSecureRenegotiation          = True
        , TLS.supportedClientInitiatedRenegotiation = False
        , TLS.supportedSession                      = True
        , TLS.supportedFallbackScsv                 = True
        }

    hashSigs =
        [ (TLS.HashSHA256, TLS.SignatureRSA)
        , (TLS.HashSHA224, TLS.SignatureRSA)
        , (TLS.HashSHA1,   TLS.SignatureRSA)
        , (TLS.HashSHA1,   TLS.SignatureDSS)
        ]

    -- Choose the first protocol in our preference list that the client
    -- offered, falling back to HTTP/1.1.
    alpn :: [S.ByteString] -> IO S.ByteString
    alpn offered = return (foldr pick "http/1.1" preferred)
      where
        preferred = ["h2", "h2-16", "h2-15", "h2-14"]
        pick proto rest
            | proto `elem` offered = proto
            | otherwise            = rest
-- | Accept one client on the listening socket.
--
-- Returns the peer address together with the connection-setup action
-- ('mkConn'). That action is not run here, so it can be executed later by
-- the caller.
getter :: TLS.TLSParams params => TLSSettings -> Socket -> params -> IO (IO (Connection, Transport), SockAddr)
getter tlsset@TLSSettings{} sock params = do
    -- `TLSSettings{}` matches the constructor without binding any fields.
    -- The old `{..}` bound every field and used none of them.
    (s, sa) <- accept sock
    return (mkConn tlsset s params, sa)
-- | Turn a freshly accepted socket into a Warp connection.
--
-- Reads a first chunk of up to 4096 bytes. If it starts with 0x16 (the TLS
-- handshake record type), 'httpOverTls' is used. Otherwise, including when
-- the chunk is empty, 'plainHTTP' is used.
--
-- The socket is closed if any step throws, including that first read.
-- Previously the read ran outside 'onException', which leaked the socket
-- when it failed.
mkConn :: TLS.TLSParams params => TLSSettings -> Socket -> params -> IO (Connection, Transport)
mkConn tlsset s params = dispatch `onException` sClose s
  where
    dispatch = do
        firstBS <- safeRecv s 4096
        if startsWithHandshake firstBS
            then httpOverTls tlsset s firstBS params
            else plainHTTP tlsset s firstBS
    -- Total replacement for `not (S.null bs) && S.head bs == 0x16`.
    startsWithHandshake bs = fmap fst (S.uncons bs) == Just 0x16
-- | Run the TLS handshake on an accepted socket and wrap the resulting
-- context as a Warp 'Connection'.
--
-- @bs0@ holds the bytes 'mkConn' already read from the socket. It is
-- passed to 'makePlainReceiveN', which presumably serves those bytes
-- before reading more from the socket (that helper is not visible here).
--
-- Decrypted input that 'recvBuf' cannot fit into the caller's buffer is
-- kept in an 'I.IORef'. The next receive through either 'connRecv' or
-- 'connRecvBuf' consumes it first.
httpOverTls :: TLS.TLSParams params => TLSSettings -> Socket -> S.ByteString -> params -> IO (Connection, Transport)
httpOverTls TLSSettings{..} s bs0 params = do
    recvN <- makePlainReceiveN s bs0
#if MIN_VERSION_tls(1,3,0)
    ctx <- TLS.contextNew (backend recvN) params
#else
    gen <- Crypto.Random.AESCtr.makeSystem
    ctx <- TLS.contextNew (backend recvN) params gen
#endif
    TLS.contextHookSetLogging ctx tlsLogging
    TLS.handshake ctx
    writeBuf <- allocateBuffer bufferSize
    -- Leftover decrypted input, shared by recv and recvBuf below.
    ref <- I.newIORef ""
    tls <- getTLSinfo ctx
    return (conn ctx writeBuf ref, tls)
  where
    backend recvN = TLS.Backend {
        TLS.backendFlush = return ()
      , TLS.backendClose = sClose s
      , TLS.backendSend  = sendAll' s
      , TLS.backendRecv  = recvN
      }

    -- Socket I/O failures during a send become ConnectionClosedByPeer.
    -- Only IOException is caught. Catching SomeException would also
    -- intercept asynchronous exceptions (e.g. a killThread issued by a
    -- timeout) and rethrow them as an ordinary ConnectionClosedByPeer.
    sendAll' sock bs = sendAll sock bs `E.catch` onSendError
    onSendError :: IOException -> IO ()
    onSendError _ = throwIO ConnectionClosedByPeer

    conn ctx writeBuf ref = Connection {
        connSendMany    = TLS.sendData ctx . L.fromChunks
      , connSendAll     = sendall
      , connSendFile    = sendfile
      , connClose       = close
      , connRecv        = recv ref
      , connRecvBuf     = recvBuf ref
      , connWriteBuffer = writeBuf
      , connBufferSize  = bufferSize
      }
      where
        sendall = TLS.sendData ctx . L.fromChunks . return
        sendfile fid offset len hook headers =
            readSendFile writeBuf bufferSize sendall fid offset len hook headers

        -- Free the write buffer, attempt a TLS close_notify (ignoring I/O
        -- errors), and always close the context.
        close = freeBuffer writeBuf `finally`
                void (tryIO $ TLS.bye ctx) `finally`
                TLS.contextClose ctx

        -- Receive: hand out cached leftover input first, otherwise decrypt.
        recv cref = do
            cached <- I.readIORef cref
            if cached /= "" then do
                I.writeIORef cref ""
                return cached
              else
                recv'

        -- Decrypting receive without the cache. End of stream (TLS or I/O
        -- EOF) is reported as an empty ByteString; empty TLS records are
        -- skipped.
        recv' = handle onEOF go
          where
            onEOF e
              | Just TLS.Error_EOF <- fromException e       = return S.empty
              | Just ioe <- fromException e, isEOFError ioe = return S.empty
              | otherwise                                   = throwIO e
            go = do
                x <- TLS.recvData ctx
                if S.null x then
                    go
                  else
                    return x

        -- Fill the caller's buffer, starting from the cache, and store
        -- whatever did not fit back in the cache.
        recvBuf cref buf siz = do
            cached <- I.readIORef cref
            (ret, leftover) <- fill cached buf siz recv'
            I.writeIORef cref leftover
            return ret
-- | Copy exactly @siz0@ bytes into @buf0@. The leftover bytes @bs0@ are
-- used first, then chunks from repeated calls to @recv@.
--
-- Returns @(True, leftover)@ once @siz0@ bytes have been written, where
-- @leftover@ is the input that did not fit. Returns @(False, "")@ if
-- @recv@ yields an empty chunk before the buffer is full.
--
-- Fixes the remaining-size arithmetic: the previous @loop buf (siz0 len0)@
-- and @loop buf' (siz len)@ were missing the subtraction.
fill :: S.ByteString -> Buffer -> BufSize -> Recv -> IO (Bool,S.ByteString)
fill bs0 buf0 siz0 recv
  | siz0 <= len0 = do
      let (bs, leftover) = S.splitAt siz0 bs0
      void $ copy buf0 bs
      return (True, leftover)
  | otherwise = do
      buf <- copy buf0 bs0
      loop buf (siz0 - len0)
  where
    len0 = S.length bs0
    -- `siz` is the number of bytes still needed. This assumes `copy`
    -- returns the buffer position just past the copied bytes; confirm
    -- against warp's Buffer module.
    loop _ 0 = return (True, "")
    loop buf siz = do
      bs <- recv
      let len = S.length bs
      if len == 0 then return (False, "")
        else if len <= siz then do
          buf' <- copy buf bs
          loop buf' (siz - len)
        else do
          let (bs1, bs2) = S.splitAt siz bs
          void $ copy buf bs1
          return (True, bs2)
-- | Describe the negotiated session as a Warp 'Transport'.
--
-- Returns 'TCP' when the context reports no session information.
-- Otherwise it returns the protocol version as its (major, minor) wire
-- numbers, the ALPN-negotiated protocol, and the cipher ID.
getTLSinfo :: TLS.Context -> IO Transport
getTLSinfo ctx = do
    proto <- TLS.getNegotiatedProtocol ctx
    minfo <- TLS.contextGetInformation ctx
    case minfo of
        Nothing -> return TCP
        Just TLS.Information{..} -> do
            let (major, minor) = case infoVersion of
                    TLS.SSL2  -> (2,0)
                    TLS.SSL3  -> (3,0)
                    TLS.TLS10 -> (3,1)
                    TLS.TLS11 -> (3,2)
                    TLS.TLS12 -> (3,3)
#if MIN_VERSION_tls(1,5,0)
                    -- tls >= 1.5 adds TLS 1.3 (wire version 3.4). Without
                    -- this case the match is non-exhaustive and fails at
                    -- runtime for a TLS 1.3 session.
                    TLS.TLS13 -> (3,4)
#endif
            return TLS {
                tlsMajorVersion = major
              , tlsMinorVersion = minor
              , tlsNegotiatedProtocol = proto
              , tlsChiperID = TLS.cipherID infoCipher
              }
-- | 'try' restricted to 'IOException': synchronous I/O errors become
-- 'Left', and every other exception propagates.
tryIO :: IO a -> IO (Either IOException a)
tryIO action = E.try action
-- | Handle a connection whose first bytes were not a TLS handshake record.
--
-- With 'AllowInsecure', the socket is served as plain HTTP, and the first
-- 'connRecv' returns the bytes already read (@bs0@).
--
-- With 'DenyInsecure', a @426 Upgrade Required@ response with the
-- configured body is written. The socket is then closed and
-- 'InsecureConnectionDenied' is thrown.
plainHTTP :: TLSSettings -> Socket -> S.ByteString -> IO (Connection, Transport)
plainHTTP TLSSettings{..} s bs0 = case onInsecure of
    AllowInsecure -> do
        conn' <- socketConnection s
        cachedRef <- I.newIORef bs0
        -- NOTE(review): only connRecv replays bs0; connRecvBuf is left
        -- untouched. Confirm that no consumer reads via connRecvBuf first.
        let conn'' = conn'
                { connRecv = recvPlain cachedRef (connRecv conn')
                }
        return (conn'', TCP)
    DenyInsecure lbs -> do
        -- Written as separate literals on purpose. With a string gap
        -- ("...\<newline>\r\n..."), the gap's closing backslash swallows
        -- the one meant for "\r". That produced a stray 'r' and a bare LF
        -- at the end of each header line.
        sendAll s $ S.concat
            [ "HTTP/1.1 426 Upgrade Required\r\n"
            , "Upgrade: TLS/1.0, HTTP/1.1\r\n"
            , "Connection: Upgrade\r\n"
            , "Content-Type: text/plain\r\n\r\n"
            ]
        mapM_ (sendAll s) $ L.toChunks lbs
        sClose s
        throwIO InsecureConnectionDenied
-- | Receive for a plain connection that replays cached bytes first.
--
-- If the ref holds a non-empty chunk, that chunk is returned and the ref
-- is emptied. Otherwise @fallback@ is run.
recvPlain :: I.IORef S.ByteString -> IO S.ByteString -> IO S.ByteString
recvPlain ref fallback = I.readIORef ref >>= serveCached
  where
    serveCached pending
      | S.null pending = fallback
      | otherwise = do
          I.writeIORef ref S.empty
          return pending
-- | The exception type thrown by this module.
data WarpTLSException = InsecureConnectionDenied
    -- ^ A non-TLS connection was refused because 'onInsecure' is
    -- 'DenyInsecure'. By the time it is thrown, the 426 response has
    -- been sent and the socket closed.
    deriving (Show, Typeable)
instance Exception WarpTLSException