module Network.Wai.Handler.WarpTLS (
TLSSettings
, certFile
, keyFile
, onInsecure
, tlsLogging
, tlsAllowedVersions
, tlsCiphers
, defaultTlsSettings
, tlsSettings
, tlsSettingsMemory
, OnInsecure (..)
, runTLS
, runTLSSocket
, WarpTLSException (..)
) where
import Control.Applicative ((<$>))
import Control.Exception (bracket, finally, handle, fromException, onException, try, IOException)
import Control.Exception (Exception, throwIO)
import Control.Monad (unless, void)
import qualified Data.IORef as I
import Data.Typeable (Typeable)
import qualified System.IO as IO
import System.IO.Error (isEOFError)

import qualified Crypto.Random.AESCtr
import qualified Data.ByteString as B
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Lazy.Internal (defaultChunkSize)
import Data.Default.Class (def)
import Data.Streaming.Network (bindPortTCP, acceptSafe, safeRecv)
import Network.Socket (Socket, sClose, withSocketsDo, SockAddr)
import Network.Socket.ByteString (sendAll)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLSExtra
import Network.Wai (Application)
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Buffer (allocateBuffer, bufferSize, freeBuffer)
-- | Configuration for serving a WAI application over TLS.
--
-- Build a value with 'tlsSettings', 'tlsSettingsMemory' or
-- 'defaultTlsSettings' and adjust individual fields with record update.
data TLSSettings = TLSSettings
    { certFile :: FilePath
      -- ^ Path of the certificate file. Read when no in-memory
      -- certificate is supplied.
    , keyFile :: FilePath
      -- ^ Path of the private key file. Read when no in-memory key is
      -- supplied.
    , certMemory :: Maybe S.ByteString
      -- ^ Certificate contents held in memory; when present it is used
      -- instead of reading 'certFile'.
    , keyMemory :: Maybe S.ByteString
      -- ^ Private key contents held in memory; when present it is used
      -- instead of reading 'keyFile'.
    , onInsecure :: OnInsecure
      -- ^ What to do with a connection whose first byte does not look
      -- like the start of a TLS handshake.
    , tlsLogging :: TLS.Logging
      -- ^ Logging hooks installed on every TLS context that is created.
    , tlsAllowedVersions :: [TLS.Version]
      -- ^ Protocol versions offered to clients.
    , tlsCiphers :: [TLS.Cipher]
      -- ^ Cipher suites offered to clients.
    }
-- | Baseline settings: credentials are read from @certificate.pem@ and
-- @key.pem@, nothing is preloaded in memory, plain-text connections are
-- refused with a short message, and TLS 1.0 through 1.2 are offered with
-- the module's default cipher list.
defaultTlsSettings :: TLSSettings
defaultTlsSettings =
    TLSSettings
        { certFile = "certificate.pem"
        , keyFile = "key.pem"
        , certMemory = Nothing
        , keyMemory = Nothing
        , onInsecure = DenyInsecure refusal
        , tlsLogging = def
        , tlsAllowedVersions = [TLS.TLS10, TLS.TLS11, TLS.TLS12]
        , tlsCiphers = ciphers
        }
  where
    -- Body sent to clients that connect without TLS.
    refusal = "This server only accepts secure HTTPS connections."
-- | Cipher suites used by 'defaultTlsSettings', AES suites first.
--
-- NOTE(review): the RC4 (and MD5-based) suites are generally regarded as
-- weak today; confirm whether they should remain in the default list.
ciphers :: [TLS.Cipher]
ciphers = aesSuites ++ rc4Suites
  where
    aesSuites =
        [ TLSExtra.cipher_AES128_SHA1
        , TLSExtra.cipher_AES256_SHA1
        ]
    rc4Suites =
        [ TLSExtra.cipher_RC4_128_MD5
        , TLSExtra.cipher_RC4_128_SHA1
        ]
-- | Policy for a client that connects without starting a TLS handshake.
data OnInsecure
    = DenyInsecure L.ByteString
      -- ^ Refuse the connection; the given bytes are sent as the response
      -- body before the socket is closed.
    | AllowInsecure
      -- ^ Serve the request over plain HTTP on the same socket.
-- | 'defaultTlsSettings' with the certificate and key file paths replaced.
tlsSettings
    :: FilePath -- ^ certificate file
    -> FilePath -- ^ key file
    -> TLSSettings
tlsSettings certPath keyPath =
    defaultTlsSettings { certFile = certPath, keyFile = keyPath }
-- | 'defaultTlsSettings' with the certificate and key supplied as
-- in-memory bytes rather than as files.
tlsSettingsMemory
    :: S.ByteString -- ^ certificate bytes
    -> S.ByteString -- ^ key bytes
    -> TLSSettings
tlsSettingsMemory certBytes keyBytes =
    defaultTlsSettings { certMemory = Just certBytes, keyMemory = Just keyBytes }
-- | Bind a listening socket on the port and host taken from the Warp
-- 'Settings', serve the application over TLS on it, and close the socket
-- when serving ends (normally or by exception).
runTLS :: TLSSettings -> Settings -> Application -> IO ()
runTLS tset set app =
    withSocketsDo $ bracket acquire sClose serve
  where
    acquire = bindPortTCP (getPort set) (getHost set)
    serve listener = runTLSSocket tset set listener app
-- | Serve the application over TLS on an already-listening socket.
--
-- The credential is taken from memory when either 'certMemory' or
-- 'keyMemory' is set (a missing half is read from its file); otherwise
-- both halves are loaded from 'certFile' and 'keyFile'.
--
-- A credential that fails to load aborts immediately with 'error',
-- before any connection is accepted. Checking the result here, rather
-- than mapping over it lazily, keeps the failure from being deferred
-- until the credential is first used.
runTLSSocket :: TLSSettings -> Settings -> Socket -> Application -> IO ()
runTLSSocket tlsset@TLSSettings{..} set sock app = do
    loaded <- case (certMemory, keyMemory) of
        (Nothing, Nothing) -> TLS.credentialLoadX509 certFile keyFile
        (mcert, mkey) -> do
            cert <- maybe (S.readFile certFile) return mcert
            key <- maybe (S.readFile keyFile) return mkey
            return $ TLS.credentialLoadX509FromMemory cert key
    -- Inspect the Either now so that both branches fail eagerly.
    credential <- either error return loaded
    runTLSSocket' tlsset set credential sock app
-- | Run Warp's connection loop on the socket with server parameters built
-- from the settings and the already-loaded credential. Client
-- certificates are never requested.
runTLSSocket' :: TLSSettings -> Settings -> TLS.Credential -> Socket -> Application -> IO ()
runTLSSocket' tlsset set credential sock app =
    runSettingsConnectionMakerSecure set (getter tlsset sock serverParams) app
  where
    -- Protocol versions and cipher suites come straight from the settings.
    supported = def
        { TLS.supportedVersions = tlsAllowedVersions tlsset
        , TLS.supportedCiphers = tlsCiphers tlsset
        }
    shared = def
        { TLS.sharedCredentials = TLS.Credentials [credential]
        }
    serverParams = def
        { TLS.serverWantClientCert = False
        , TLS.serverSupported = supported
        , TLS.serverShared = shared
        }
-- | Accept one client from the listening socket. Returns the (not yet
-- run) action that builds its 'Connection', paired with the peer address.
getter :: TLS.TLSParams params => TLSSettings -> Socket -> params -> IO (IO (Connection, Bool), SockAddr)
getter tlsset sock params = do
    (client, peer) <- acceptSafe sock
    return (mkConn tlsset client params, peer)
-- | Build the 'Connection' for a freshly accepted socket.
--
-- Up to 4096 bytes are read first and kept in an 'I.IORef' so they can be
-- replayed to whichever protocol handler is chosen. A first byte of
-- @0x16@ (the TLS handshake record type) selects 'httpOverTls'; anything
-- else, including an empty read, goes to 'plainHTTP'.
--
-- If any part of the setup throws, the socket is closed before the
-- exception propagates: at that point no 'Connection' exists yet whose
-- close action could release it.
mkConn :: TLS.TLSParams params => TLSSettings -> Socket -> params -> IO (Connection, Bool)
mkConn tlsset s params =
    (safeRecv s 4096 >>= dispatch) `onException` sClose s
  where
    dispatch firstBS = do
        cachedRef <- I.newIORef firstBS
        -- 'B.uncons' covers the empty-read case without a partial 'B.head'.
        case B.uncons firstBS of
            Just (0x16, _) -> httpOverTls tlsset s cachedRef params
            _ -> plainHTTP tlsset s cachedRef
-- | Perform the TLS handshake on the socket and wrap the resulting context
-- in a Warp 'Connection'. The 'Bool' in the result is 'True', marking the
-- connection as secure.
--
-- Bytes already read from the socket are replayed through @cachedRef@ by
-- the backend's receive function.
httpOverTls :: TLS.TLSParams params => TLSSettings -> Socket -> I.IORef B.ByteString -> params -> IO (Connection, Bool)
httpOverTls TLSSettings{..} s cachedRef params = do
    gen <- Crypto.Random.AESCtr.makeSystem
    ctx <- TLS.contextNew backend params gen
    TLS.contextHookSetLogging ctx tlsLogging
    TLS.handshake ctx
    readBuf <- allocateBuffer bufferSize
    writeBuf <- allocateBuffer bufferSize
    return (conn ctx readBuf writeBuf, True)
  where
    -- Transport used by the TLS context: raw socket I/O, with reads going
    -- through the cache first.
    backend = TLS.Backend
        { TLS.backendFlush = return ()
        , TLS.backendClose = sClose s
        , TLS.backendSend = sendAll s
        , TLS.backendRecv = getNext cachedRef s
        }

    conn ctx readBuf writeBuf = Connection
        { connSendMany = TLS.sendData ctx . L.fromChunks
        , connSendAll = TLS.sendData ctx . L.fromChunks . return
        , connSendFile = sendfile
        , connClose = close
        , connRecv = recv
        , connSendFileOverride = NotOverride
        , connReadBuffer = readBuf
        , connWriteBuffer = writeBuf
        , connBufferSize = bufferSize
        }
      where
        -- Send the headers, then stream @len@ bytes of the file starting
        -- at @offset@ through the TLS context, running @tickle@ after
        -- every chunk.
        sendfile fp offset len tickle headers = do
            TLS.sendData ctx $ L.fromChunks headers
            IO.withBinaryFile fp IO.ReadMode $ \h -> do
                IO.hSeek h IO.AbsoluteSeek offset
                loop h $ fromIntegral len
          where
            loop _ remaining | remaining <= 0 = return ()
            loop h remaining = do
                bs <- B.hGetSome h defaultChunkSize
                -- An empty read means end of file: stop even if fewer
                -- than @len@ bytes were available.
                unless (B.null bs) $ do
                    let x = B.take remaining bs
                    TLS.sendData ctx $ L.fromChunks [x]
                    tickle
                    -- Count down by what was actually sent.
                    loop h $ remaining - B.length x

        -- Each step runs even if an earlier one throws. An 'IOException'
        -- from the TLS goodbye is ignored.
        close = freeBuffer readBuf `finally`
                freeBuffer writeBuf `finally`
                void (tryIO $ TLS.bye ctx) `finally`
                TLS.contextClose ctx

        -- Receive decrypted data. EOF (from TLS or from the underlying
        -- I/O) is reported as an empty chunk; other exceptions are
        -- rethrown.
        recv = handle onEOF go
          where
            onEOF e
                | Just TLS.Error_EOF <- fromException e = return B.empty
                | Just ioe <- fromException e, isEOFError ioe = return B.empty
                | otherwise = throwIO e
            -- Skip empty reads and keep going until some data arrives.
            go = do
                x <- TLS.recvData ctx
                if B.null x
                    then go
                    else return x
-- | 'try' specialised to 'IOException': run the action and return an
-- 'IOException' as 'Left' instead of letting it propagate.
tryIO :: IO a -> IO (Either IOException a)
tryIO action = try action
-- | Handle a client that did not start a TLS handshake, according to
-- 'onInsecure'.
--
-- * 'AllowInsecure': return a plain socket 'Connection' (flagged as not
--   secure) whose first receive yields the bytes that were already read
--   from the socket; later receives use the connection's normal receive
--   function. The cached bytes are handed over as-is rather than topped
--   up to a fixed size, so a short request is not left waiting for more
--   input.
--
-- * 'DenyInsecure': send a @426 Upgrade Required@ response carrying the
--   configured body, close the socket (even if sending fails) and throw
--   'InsecureConnectionDenied'.
plainHTTP :: TLSSettings -> Socket -> I.IORef B.ByteString -> IO (Connection, Bool)
plainHTTP TLSSettings{..} s cachedRef = case onInsecure of
    AllowInsecure -> do
        conn' <- socketConnection s
        return (conn' { connRecv = recvCached (connRecv conn') }, False)
    DenyInsecure lbs -> do
        sendDenial lbs `finally` sClose s
        throwIO InsecureConnectionDenied
  where
    -- Serve the cached bytes once, then defer to the socket's own receive.
    recvCached fallback = do
        cached <- I.readIORef cachedRef
        if B.null cached
            then fallback
            else do
                I.writeIORef cachedRef B.empty
                return cached

    -- A refused request must not look successful to the client, and a 426
    -- response is required to name the protocol to upgrade to.
    sendDenial lbs = do
        sendAll s "HTTP/1.1 426 Upgrade Required\r\n\
                  \Upgrade: TLS/1.0, HTTP/1.1\r\n\
                  \Connection: Upgrade\r\n\
                  \Content-Type: text/plain\r\n\r\n"
        mapM_ (sendAll s) $ L.toChunks lbs
-- | Return the next @size@ bytes of input, drawing on the cached bytes
-- first and then on the socket.
--
-- Keeps receiving (up to 4096 bytes at a time) until at least @size@ bytes
-- are available; any surplus is stored back in the cache. If the socket
-- reports end of input first, the cache is cleared and whatever was
-- gathered so far is returned, which may be shorter than @size@.
getNext :: I.IORef B.ByteString -> Socket -> Int -> IO B.ByteString
getNext cachedRef s size = I.readIORef cachedRef >>= fill
  where
    fill buffered
        | B.length buffered >= size = do
            let (wanted, leftover) = B.splitAt size buffered
            I.writeIORef cachedRef leftover
            return wanted
        | otherwise = do
            chunk <- safeRecv s 4096
            if B.null chunk
                then do
                    I.writeIORef cachedRef B.empty
                    return buffered
                else fill (buffered `B.append` chunk)
-- | Exceptions thrown by this module.
data WarpTLSException
    = InsecureConnectionDenied
      -- ^ Thrown after a non-TLS client has been refused.
    deriving (Typeable, Show)

instance Exception WarpTLSException