-- | A Backend provides a unified way to perform I/O on different
-- transport types, without burdening the calling API with multiple
-- ways to initialize a new context.
--
-- Typically, a backend provides:
--
-- * a way to read data
-- * a way to write data
-- * a way to close the stream
-- * a way to flush the stream
module Network.TLS.Backend (
    HasBackend (..),
    Backend (..),
) where

import qualified Data.ByteString as B
import qualified Network.Socket as Network
import qualified Network.Socket.ByteString as Network
import Network.TLS.Imports
import System.IO (BufferMode (..), Handle, hClose, hFlush, hSetBuffering)

-- | Connection IO backend: a record of the four operations used to
-- perform I/O on an underlying transport.
data Backend = Backend
    { backendFlush :: IO ()
    -- ^ Flush the connection sending buffer, if any.
    , backendClose :: IO ()
    -- ^ Close the connection.
    , backendSend :: ByteString -> IO ()
    -- ^ Send a bytestring through the connection.
    , backendRecv :: Int -> IO ByteString
    -- ^ Receive up to the specified number of bytes from the connection.
    -- Implementations may return fewer bytes when the connection reaches
    -- end of input (the socket and 'Handle' backends in this module both
    -- do).
    }

-- | Values that can serve as the transport of a connection, by being
-- turned into a 'Backend'.
class HasBackend a where
    -- | Build the 'Backend' operations for the given value.
    getBackend :: a -> Backend

    -- | Perform any setup the value needs before it is used as a backend
    -- (for example, the 'Handle' instance disables buffering).
    initializeBackend :: a -> IO ()

-- | A 'Backend' is trivially its own backend and needs no initialization.
instance HasBackend Backend where
    initializeBackend _ = return ()
    getBackend = id

-- | Receive at most @n@ bytes from the socket. An empty result means the
-- peer has closed its end of the connection.
--
-- NOTE(review): despite its name, this is a plain alias for 'Network.recv'
-- and adds no extra error handling.
safeRecv :: Network.Socket -> Int -> IO ByteString
safeRecv = Network.recv

-- | A connected socket.
--
-- * Initialization and flushing are no-ops.
-- * Closing closes the socket.
-- * Sending uses 'Network.sendAll'.
-- * Receiving keeps reading until the requested number of bytes has
--   arrived or the peer closes the connection. In the latter case the
--   result may be shorter than requested.
instance HasBackend Network.Socket where
    initializeBackend _ = return ()
    getBackend sock =
        Backend
            { backendFlush = return ()
            , backendClose = Network.close sock
            , backendSend = Network.sendAll sock
            , backendRecv = recvAll
            }
      where
        -- Read exactly @n@ bytes, or fewer if the peer closes first.
        recvAll :: Int -> IO ByteString
        recvAll n = B.concat <$> loop n
          where
            -- Collect chunks, requesting only the bytes still missing.
            loop :: Int -> IO [ByteString]
            loop 0 = return []
            loop left = do
                r <- safeRecv sock left
                if B.null r
                    -- An empty chunk signals end of input: stop early.
                    then return []
                    else (r :) <$> loop (left - B.length r)

-- | A 'Handle'. Buffering is switched off when the backend is
-- initialized.
--
-- The other operations map onto the corresponding handle functions:
-- 'hFlush', 'hClose', 'B.hPut' and 'B.hGet'. 'B.hGet' returns fewer bytes
-- than requested only when end of file is reached.
instance HasBackend Handle where
    initializeBackend handle = hSetBuffering handle NoBuffering
    getBackend handle =
        Backend
            { backendFlush = hFlush handle
            , backendClose = hClose handle
            , backendSend = B.hPut handle
            , backendRecv = B.hGet handle
            }