{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.HTTP2.TLS.IO where

import Control.Monad (void, when)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Network.Socket
import Network.Socket.BufferPool
import qualified Network.Socket.ByteString as NSB
import Network.TLS hiding (HostName)
import System.IO.Error (isEOFError)
import qualified System.TimeManager as T
import qualified UnliftIO.Exception as E

import Network.HTTP2.TLS.Settings

----------------------------------------------------------------

-- HTTP2: confReadN == recvTLS
-- TLS:   recvData  == contextRecv == backendRecv

----------------------------------------------------------------

-- | Build a receive action for a plain TCP socket.
--
--   A single 'BufferPool' is allocated up front, sized from
--   'settingReadBufferLowerLimit' and 'settingReadBufferSize', and every
--   call of the returned action receives from @sock@ through that pool.
mkRecvTCP :: Settings -> Socket -> IO (IO ByteString)
mkRecvTCP settings sock =
    receive sock
        <$> newBufferPool
            (settingReadBufferLowerLimit settings)
            (settingReadBufferSize settings)

-- | Send a whole 'ByteString' over a TCP socket.
--
--   Delegates to 'NSB.sendAll', which keeps writing until the entire
--   buffer has been sent (or an exception is raised).
sendTCP :: Socket -> ByteString -> IO ()
sendTCP = NSB.sendAll

----------------------------------------------------------------

-- | Sending and receiving functions for a connection.
--
--   The timeout is reset whenever one of these functions returns.
--   The one exception is the Slowloris attack prevention on receiving;
--   see 'settingsSlowlorisSize'.
data IOBackend = IOBackend
    { -- | Send one 'ByteString'.
      send :: ByteString -> IO ()
    , -- | Send a list of 'ByteString's.
      sendMany :: [ByteString] -> IO ()
    , -- | Receive the next available data.
      recv :: IO ByteString
    }

-- | Wrap an 'IOBackend' so that completed I/O resets the timeout handle.
--
--   Every completed 'send' and 'sendMany' tickles @th@. A 'recv' tickles
--   @th@ only when it returned more than 'settingsSlowlorisSize' bytes, so
--   receiving tiny pieces of data does not by itself keep the timeout from
--   firing (Slowloris prevention).
timeoutIOBackend :: T.Handle -> Settings -> IOBackend -> IOBackend
timeoutIOBackend
    th
    Settings{settingsSlowlorisSize = slowlorisLimit}
    IOBackend{send = rawSend, sendMany = rawSendMany, recv = rawRecv} =
        IOBackend
            { send = \bs -> rawSend bs >> T.tickle th
            , sendMany = \bss -> rawSendMany bss >> T.tickle th
            , recv = rawRecv >>= tickleIfLarge
            }
  where
    -- Only a chunk larger than the Slowloris threshold counts as progress.
    tickleIfLarge bs = do
        when (BS.length bs > slowlorisLimit) $ T.tickle th
        return bs

-- | An 'IOBackend' that sends and receives through a TLS 'Context'.
--
--   The fields, in constructor order, are 'sendTLS', 'sendManyTLS' and
--   'recvTLS', each applied to @ctx@.
tlsIOBackend :: Context -> IOBackend
tlsIOBackend ctx = IOBackend (sendTLS ctx) (sendManyTLS ctx) (recvTLS ctx)

-- | An 'IOBackend' over a plain (unencrypted) TCP socket.
--
--   Receiving uses the buffer-pool based action from 'mkRecvTCP'.
--   Both sending functions transmit their entire payload:
--
--   * 'send' uses 'sendTCP' ('NSB.sendAll'), because 'NSB.send' may write
--     only a prefix of its argument. Ignoring its returned byte count would
--     silently drop the rest of the data.
--
--   * 'sendMany' forwards to 'NSB.sendMany' instead of discarding its input.
tcpIOBackend :: Settings -> Socket -> IO IOBackend
tcpIOBackend settings sock = do
    recv' <- mkRecvTCP settings sock
    return
        IOBackend
            { send = sendTCP sock
            , sendMany = NSB.sendMany sock
            , recv = recv'
            }

----------------------------------------------------------------

-- | Encrypt and send a single strict 'ByteString' over a TLS 'Context'.
--
--   The chunk is converted to a lazy 'LBS.ByteString' because 'sendData'
--   takes lazy input.
sendTLS :: Context -> ByteString -> IO ()
sendTLS ctx bs = sendData ctx (LBS.fromStrict bs)

-- | Encrypt and send several strict chunks over a TLS 'Context'.
--
--   The chunks are joined into one lazy 'LBS.ByteString' without copying
--   ('LBS.fromChunks') and handed to 'sendData' in a single call.
sendManyTLS :: Context -> [ByteString] -> IO ()
sendManyTLS ctx chunks = sendData ctx $ LBS.fromChunks chunks

-- | TLS version of recv (decrypting), without a cache.
--
--   End of stream is reported as an empty 'ByteString' rather than as an
--   exception. Both a TLS 'Error_EOF' and an 'IOError' satisfying
--   'isEOFError' are mapped to @""@. Any other synchronous exception is
--   rethrown unchanged.
recvTLS :: Context -> IO ByteString
recvTLS ctx = recvData ctx `E.catch` swallowEOF
  where
    swallowEOF :: E.SomeException -> IO ByteString
    swallowEOF se = case E.fromException se of
        Just Error_EOF -> return ""
        _ -> case E.fromException se of
            Just ioe | isEOFError ioe -> return ""
            _ -> E.throwIO se

----------------------------------------------------------------

-- | Build the TLS library's transport 'Backend' on top of a TCP socket.
--
--   Sending uses 'sendTCP'. Receiving uses 'makeRecvN' over the
--   'mkRecvTCP' action, starting from an empty leftover buffer.
--   Flushing is a no-op.
mkBackend :: Settings -> Socket -> IO Backend
mkBackend settings sock = do
    recvN <- makeRecvN "" =<< mkRecvTCP settings sock
    return
        Backend
            { backendFlush = return ()
            , backendClose = closeQuietly
            , backendSend = sendTCP sock
            , backendRecv = recvN
            }
  where
    -- Best-effort close: 'gracefulClose' is given a 5000 (ms) timeout, and
    -- any synchronous exception it raises is deliberately ignored.
    closeQuietly :: IO ()
    closeQuietly =
        gracefulClose sock 5000 `E.catch` \(E.SomeException _) -> return ()