module Network.Mail.Postie.Connection
  ( Connection,
    connIsSecure,
    connSetSecure,
    connRecv,
    connSend,
    connClose,
    mkSocketConnection,
    toProducer,
  )
where

import Control.Exception (ErrorCall (..), finally, throwIO)
import Control.Monad (unless)
import Control.Monad.IO.Class
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.ByteString.Lazy.Internal (defaultChunkSize)
import Data.IORef
import Network.Socket
import Network.Socket.ByteString hiding (sendAll)
import Network.Socket.ByteString.Lazy (sendAll)
import Network.TLS
import qualified Pipes as P

-- | The transport currently carrying a 'Connection'.
data ConnectionBackend
  = -- | A plain, unencrypted socket.
    ConnPlain Socket
  | -- | A TLS context created over the connection's socket.
    ConnSecure Context

-- | A handle to a client connection. The backend is held in an 'IORef' so
-- that it can be swapped in place (plain socket -> TLS) without changing the
-- handle callers hold.
newtype Connection
  = Connection (IORef ConnectionBackend)

-- | Upgrade a plain-socket 'Connection' to TLS in place.
--
-- Creates a TLS 'Context' over the current socket using the given server
-- parameters and runs the handshake. Only after the handshake returns is the
-- stored backend replaced with the secure one. If 'contextNew' or 'handshake'
-- throws, the exception propagates and the connection keeps its plain backend.
--
-- Throws an 'ErrorCall' if the connection is already secure.
connSetSecure :: Connection -> ServerParams -> IO ()
connSetSecure (Connection cbe) params = do
  backend <- readIORef cbe
  securedBackend <- upgrade backend
  writeIORef cbe securedBackend
  where
    upgrade :: ConnectionBackend -> IO ConnectionBackend
    upgrade (ConnPlain sock) = do
      context <- contextNew sock params
      handshake context
      return (ConnSecure context)
    -- Upgrading twice is a caller bug. Raise it explicitly as an IO exception
    -- (same 'ErrorCall' type as before) instead of returning a pure bottom.
    upgrade (ConnSecure _) = throwIO (ErrorCall "already on secure connection")

-- | Report whether the connection currently runs over a TLS backend.
connIsSecure :: Connection -> IO Bool
connIsSecure (Connection cbe) = isSecure <$> readIORef cbe
  where
    isSecure (ConnSecure _) = True
    isSecure (ConnPlain _) = False

-- | Wrap a socket in a fresh 'Connection' that starts on the plain backend.
mkSocketConnection :: Socket -> IO Connection
mkSocketConnection = fmap Connection . newIORef . ConnPlain

-- | Read the next chunk of bytes from a backend. A plain socket is read with
-- 'recv', requesting at most 'defaultChunkSize' bytes. A TLS backend
-- delegates to 'recvData'.
connBackendRecv :: ConnectionBackend -> IO BS.ByteString
connBackendRecv backend = case backend of
  ConnPlain sock -> recv sock defaultChunkSize
  ConnSecure ctx -> recvData ctx

-- | Write a lazy bytestring to a backend. A plain socket uses 'sendAll'; a
-- TLS backend uses 'sendData'.
connBackendSend :: ConnectionBackend -> LBS.ByteString -> IO ()
connBackendSend backend payload = case backend of
  ConnPlain sock -> sendAll sock payload
  ConnSecure ctx -> sendData ctx payload

-- | Receive the next chunk from whichever backend the connection currently
-- uses.
connRecv :: Connection -> IO BS.ByteString
connRecv (Connection cbe) = do
  backend <- readIORef cbe
  connBackendRecv backend

-- | Send a lazy bytestring over whichever backend the connection currently
-- uses.
connSend :: Connection -> LBS.ByteString -> IO ()
connSend (Connection cbe) lbs =
  readIORef cbe >>= \backend -> connBackendSend backend lbs

-- | Close the connection. A plain socket is closed directly. A TLS backend
-- first attempts 'bye', and 'contextClose' runs even if 'bye' throws.
connClose :: Connection -> IO ()
connClose (Connection cbe) = do
  backend <- readIORef cbe
  case backend of
    ConnPlain sock -> close sock
    ConnSecure ctx -> bye ctx `finally` contextClose ctx

-- | Stream chunks received from the connection. Each non-empty chunk is
-- yielded downstream; the first empty read ends the stream.
toProducer :: (MonadIO m) => Connection -> P.Producer' BS.ByteString m ()
toProducer conn = loop
  where
    loop = do
      chunk <- liftIO (connRecv conn)
      unless (BS.null chunk) $ do
        P.yield chunk
        loop