{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RankNTypes #-}
module Network.HTTP2.Client.RawConnection (
RawHttp2Connection (..)
, newRawHttp2Connection
, newRawHttp2ConnectionSocket
) where
import Control.Exception (finally, onException)
import Control.Monad (forever, when)
import Data.Monoid ((<>))
import System.IO.Error (eofErrorType, mkIOError)

import Control.Concurrent.Async (Async, async, cancel, pollSTM)
import Control.Concurrent.STM (STM, atomically, retry, throwSTM)
import Control.Concurrent.STM.TVar (TVar, modifyTVar', newTVarIO, readTVar, writeTVar)
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import Data.ByteString.Lazy (fromChunks)
import qualified Network.HTTP2 as HTTP2
import Network.Socket hiding (recv)
import Network.Socket.ByteString
import qualified Network.TLS as TLS
-- | A raw, byte-level transport for one HTTP/2 connection, expressed as a
-- record of the three actions the rest of the client needs.
data RawHttp2Connection = RawHttp2Connection
    { _sendRaw :: [ByteString] -> IO ()
      -- ^ Send a list of chunks. The constructors in this module hand the
      -- chunks to a background writer thread instead of writing directly.
    , _nextRaw :: Int -> IO ByteString
      -- ^ Obtain exactly the requested number of bytes. The constructors in
      -- this module block until that many bytes have been buffered.
    , _close   :: IO ()
      -- ^ Stop the background workers and close the underlying transport.
    }
-- | Resolve @host@:@port@, open a TCP stream socket with @NoDelay@ set,
-- connect it, and wrap it via 'newRawHttp2ConnectionSocket' (TLS when
-- client parameters are given, plain text otherwise).
--
-- If anything after the socket is created throws (setting the option,
-- connecting, or the connection setup such as a TLS handshake), the socket
-- is closed before the exception propagates, so a failed attempt does not
-- leak a file descriptor.
newRawHttp2Connection :: HostName
                      -> PortNumber
                      -> Maybe TLS.ClientParams
                      -> IO RawHttp2Connection
newRawHttp2Connection host port mparams = do
    let hints = defaultHints { addrFlags = [AI_NUMERICSERV], addrSocketType = Stream }
    -- Only the first resolved address is tried; an empty result would fail
    -- this pattern match with an IOError in IO.
    addr:_ <- getAddrInfo (Just hints) (Just host) (Just $ show port)
    skt <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
    setup skt addr `onException` close skt
  where
    -- Everything that can fail once we own the socket.
    setup skt addr = do
        setSocketOption skt NoDelay 1
        connect skt (addrAddress addr)
        newRawHttp2ConnectionSocket skt mparams
-- | Wrap a socket (connected beforehand, as 'newRawHttp2Connection' does) in
-- a 'RawHttp2Connection'. TLS is used when client parameters are supplied,
-- plain text otherwise. The HTTP/2 client connection preface is queued for
-- sending before the connection is returned.
newRawHttp2ConnectionSocket
    :: Socket
    -> Maybe TLS.ClientParams
    -> IO RawHttp2Connection
newRawHttp2ConnectionSocket skt mparams = do
    conn <- case mparams of
        Nothing     -> plainTextRaw skt
        Just params -> tlsRaw skt params
    _sendRaw conn [HTTP2.connectionPreface]
    pure conn
-- | Build an unencrypted connection directly on top of @skt@, with one
-- background thread for writing ('sendMany') and one for reading ('recv').
-- Closing cancels the reader, then the writer, then closes the socket.
plainTextRaw :: Socket -> IO RawHttp2Connection
plainTextRaw skt = do
    (writer, enqueue) <- startWriteWorker (sendMany skt)
    (reader, dequeue) <- startReadWorker (recv skt)
    pure RawHttp2Connection
        { _sendRaw = atomically . enqueue
        , _nextRaw = atomically . dequeue
        , _close   = cancel reader >> cancel writer >> close skt
        }
-- | Build a TLS connection on top of @skt@. The handshake advertises the
-- ALPN protocols @h2@ and @h2-17@. Background threads then send with
-- 'TLS.sendData' and receive with 'TLS.recvData'.
--
-- Closing cancels the reader and the writer, sends a TLS close via
-- 'TLS.bye', and always closes the TLS context afterwards, even if the
-- preceding steps throw.
tlsRaw :: Socket -> TLS.ClientParams -> IO RawHttp2Connection
tlsRaw skt params = do
    tlsContext <- TLS.contextNew skt (modifyParams params)
    TLS.handshake tlsContext
    (b, putRaw) <- startWriteWorker (TLS.sendData tlsContext . fromChunks)
    -- TLS.recvData chooses its own read size, so the requested amount is ignored.
    (a, getRaw) <- startReadWorker (const $ TLS.recvData tlsContext)
    -- TLS.bye is an I/O action that may throw (e.g. when the peer is already
    -- gone); 'finally' keeps that from skipping TLS.contextClose.
    let doClose = (cancel a >> cancel b >> TLS.bye tlsContext)
                  `finally` TLS.contextClose tlsContext
    return $ RawHttp2Connection (atomically . putRaw) (atomically . getRaw) doClose
  where
    modifyParams prms = prms
        { TLS.clientHooks = (TLS.clientHooks prms)
            { TLS.onSuggestALPN = return $ Just [ "h2", "h2-17" ] }
        }
-- | Start a background thread that drains an outgoing chunk queue into
-- @sendChunks@ (see 'writeWorkerLoop').
--
-- Returns the worker's 'Async' and an STM action that appends chunks to the
-- queue. If the worker has already died with an exception, enqueueing
-- rethrows that exception instead of silently queueing data that nothing
-- will ever send. This matches how 'getRawWorker' surfaces read-worker
-- failures.
startWriteWorker
    :: ([ByteString] -> IO ())
    -> IO (Async (), [ByteString] -> STM ())
startWriteWorker sendChunks = do
    outQ <- newTVarIO []
    b <- async $ writeWorkerLoop outQ sendChunks
    let putRaw chunks = do
            status <- pollSTM b
            case status of
                Just (Left e) -> throwSTM e
                _             -> modifyTVar' outQ (++ chunks)
    return (b, putRaw)
-- | Forever: wait until the queue is non-empty, atomically take everything
-- queued so far (leaving the queue empty), and pass it to @sendChunks@ in a
-- single call.
writeWorkerLoop :: TVar [ByteString] -> ([ByteString] -> IO ()) -> IO ()
writeWorkerLoop outQ sendChunks = forever $ atomically drainQueue >>= sendChunks
  where
    drainQueue = do
        pending <- readTVar outQ
        -- Block (STM retry) until something has been queued.
        when (null pending) retry
        writeTVar outQ []
        pure pending
-- | Start a background thread that keeps appending data obtained from
-- @get@ to an inbound buffer (see 'readWorkerLoop').
--
-- Returns the worker's 'Async' and an STM action that takes a given number
-- of bytes out of that buffer (see 'getRawWorker').
startReadWorker
    :: (Int -> IO ByteString)
    -> IO (Async (), (Int -> STM ByteString))
startReadWorker get = do
    inbox <- newTVarIO ByteString.empty
    worker <- async (readWorkerLoop inbox get)
    pure (worker, getRawWorker worker inbox)
-- | Repeatedly request chunks of 4096 bytes from @next@ and append whatever
-- it returns to @buf@. A given receive function may ignore the size; the
-- TLS path does.
--
-- An empty chunk is treated as end of stream. 'recv' returns an empty chunk
-- when the peer closes the connection, and 'TLS.recvData' is assumed to do
-- the same (NOTE(review): confirm for the tls version in use). Looping on
-- empty reads would spin without ever blocking, so the worker instead ends
-- with an EOF 'IOError', which 'getRawWorker' rethrows to readers.
readWorkerLoop :: TVar ByteString -> (Int -> IO ByteString) -> IO ()
readWorkerLoop buf next = loop
  where
    loop = do
        dat <- next 4096
        if ByteString.null dat
            then ioError $ mkIOError eofErrorType "readWorkerLoop: connection closed" Nothing Nothing
            else do
                atomically $ modifyTVar' buf (<> dat)
                loop
-- | Take exactly @amount@ bytes from the read buffer @buf@, blocking (via
-- STM 'retry') until enough bytes have been buffered.
--
-- Buffered bytes are served first. The read worker @a@ is consulted only
-- when the buffer is too short: if it has died with an exception, that
-- exception is rethrown. Because the buffer is checked first, data the peer
-- sent before the connection failed or was closed is still delivered
-- instead of being dropped.
getRawWorker :: Async () -> TVar ByteString -> Int -> STM ByteString
getRawWorker a buf amount = do
    dat <- readTVar buf
    if amount <= ByteString.length dat
        then do
            let (chunk, rest) = ByteString.splitAt amount dat
            writeTVar buf rest
            return chunk
        else do
            asyncStatus <- pollSTM a
            case asyncStatus of
                Just (Left e) -> throwSTM e
                -- Worker still running: wait for more data (or for it to fail).
                _             -> retry