{- This file is part of monad-connect.
 -
 - Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
 -
 - ♡ Copying is an act of love. Please copy, reuse and share.
 -
 - The author(s) have dedicated all copyright and related and neighboring
 - rights to this software to the public domain worldwide. This software is
 - distributed without any warranty.
 -
 - You should have received a copy of the CC0 Public Domain Dedication along
 - with this software. If not, see
 - <http://creativecommons.org/publicdomain/zero/1.0/>.
 -}

{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | Simple monad transformer allowing your monad stack to manage a single TCP
-- connection, optionally over TLS and\/or SOCKS.
module Control.Monad.Trans.Connect
    ( -- * Connect and disconnect
      ConnectT ()
    , runConnectT
    , runConnectTWithCtx
    , connClose
      -- * Receive data
    , connGetSome
    , connGetChunk
    , connGetChunk'
    , connGetLine
      -- * Send data
    , connPut
      -- * TLS
    , connSetSecure
    , connSetSecureWithCtx
    , connIsSecure
    )
where

import Control.Monad.Catch
import Control.Monad.Fix (MonadFix)
import Control.Monad.IO.Class
import Control.Monad.Trans.Class (MonadTrans)
import Control.Monad.Trans.Reader
import Data.ByteString (ByteString)

import qualified Network.Connection as NC

-- TODO: handle exceptions carefully, in particular so that the same monadic
-- computation can continue after recovering from an error.


-- | A monad transformer that gives the computation running inside it access
-- to a single TCP connection, which may optionally use TLS and\/or go through
-- SOCKS. Running one 'ConnectT' computation opens the connection and lets the
-- computation send and receive data over it. When the connection gets closed
-- afterwards is controlled by 'runConnectT'.
newtype ConnectT m a = ConnectT
    { unCT :: ReaderT (NC.ConnectionContext, NC.Connection) m a
      -- ^ Environment: the context the connection was opened with, paired
      -- with the open connection itself
    }
    deriving
        ( Functor, Applicative, Monad          -- the core monad interface
        , MonadFix                             -- recursive bindings (base)
        , MonadIO                              -- the network layer lives in IO
        , MonadTrans                           -- lifting from the inner monad
        , MonadThrow, MonadCatch, MonadMask    -- exception support
        )

-- | Open a connection and run a computation that has it available.
--
-- A fresh 'NC.ConnectionContext' is created for this connection. If the
-- computation throws an exception it does not handle itself (for example
-- because of a network failure), the connection is closed and the exception
-- is rethrown to the caller.
runConnectT
    :: (MonadIO m, MonadMask m)
    => ConnectT m a
    -- ^ Computation to run with the connection available
    -> NC.ConnectionParams
    -- ^ Where and how to connect
    -> Bool
    -- ^ Whether to also close the connection when the computation finishes
    -- without errors. On an unhandled error the connection is closed
    -- regardless of this flag.
    -> m a
runConnectT act params close =
    liftIO NC.initConnectionContext >>= withContext
  where
    withContext ctx = runConnectTWithCtx act ctx params close

-- | Like 'runConnectT', but reuses an already initialized
-- 'NC.ConnectionContext' instead of creating a new one. This is useful when
-- opening several connections, because they can all share one context.
runConnectTWithCtx
    :: (MonadIO m, MonadMask m)
    => ConnectT m a
    -- ^ Computation to run with the connection available
    -> NC.ConnectionContext
    -- ^ Context used to open the connection
    -> NC.ConnectionParams
    -- ^ Where and how to connect
    -> Bool
    -- ^ Whether to also close the connection after a successful run. On an
    -- unhandled error it is closed either way.
    -> m a
runConnectTWithCtx act ctx params close =
    withConnection (liftIO (NC.connectTo ctx params)) disconnect session
  where
    -- 'bracket' runs the release action however the body ends, while
    -- 'bracketOnError' runs it only when the body throws.
    withConnection
        | close     = bracket
        | otherwise = bracketOnError
    disconnect conn = liftIO (NC.connectionClose conn)
    -- The computation sees the context together with the opened connection.
    session conn = runReaderT (unCT act) (ctx, conn)

-- | Fetch the open connection from the environment.
askConn :: Monad m => ConnectT m NC.Connection
askConn = ConnectT (asks snd)

-- | Fetch the connection context (the one the connection was opened with)
-- from the environment.
askCtx :: Monad m => ConnectT m NC.ConnectionContext
askCtx = ConnectT (asks fst)

-- | Close the connection. Do not use the connection after calling this.
--
-- Closing a plain TCP connection that is already closed currently does
-- nothing. With TLS, however, the connection is assumed to still be open, so
-- to be safe, call this at most once.
--
-- If you asked for the connection to be closed after a successful
-- computation (i.e. passed 'True' to 'runConnectT'), it is closed for you.
-- Calling this function as well is then unnecessary, and possibly unsafe.
connClose :: MonadIO m => ConnectT m ()
connClose = do
    conn <- askConn
    liftIO (NC.connectionClose conn)

-- | Block until data can be read from the connection, then return it.
--
-- The argument is the maximum number of bytes to read. Fewer bytes may be
-- returned, depending on how much data was available.
--
-- At end of input an empty bytestring is returned. Any further call throws an
-- @isEOFError@ exception.
connGetSome :: MonadIO m => Int -> ConnectT m ByteString
connGetSome limit = do
    conn <- askConn
    liftIO (NC.connectionGet conn limit)

-- | Read the next block of data available on the connection.
connGetChunk :: MonadIO m => ConnectT m ByteString
connGetChunk = liftIO . NC.connectionGetChunk =<< askConn

-- | Like 'connGetChunk', but lets the caller consume only part of the chunk.
--
-- The given function splits the next chunk into a result and an unused
-- remainder. The result is returned. The remainder is put back into the
-- buffer and will be the next chunk read.
connGetChunk' :: MonadIO m => (ByteString -> (a, ByteString)) -> ConnectT m a
connGetChunk' splitChunk = do
    conn <- askConn
    liftIO (NC.connectionGetChunk' conn splitChunk)

-- | Read the next line, treating ASCII LF (@'\n'@) as the line terminator.
--
-- Throws an @isEOFError@ exception at end of input. Throws 'NC.LineTooLong'
-- when more bytes than the limit have been gathered without finding a line
-- terminator.
--
-- The returned line may exceed the limit, as long as the last chunk returned
-- by the underlying backend contains an LF.
--
-- End of file also counts as a line terminator, provided the line is not
-- empty.
connGetLine :: MonadIO m => Int -> ConnectT m ByteString
connGetLine limit = liftIO . NC.connectionGetLine limit =<< askConn

-- | Write a block of bytes to the connection.
connPut :: MonadIO m => ByteString -> ConnectT m ()
connPut bytes = do
    conn <- askConn
    liftIO (NC.connectionPut conn bytes)

-- | Turn on the secure (TLS) layer using the given settings.
--
-- This is typically used to negotiate a TLS channel on an already established
-- connection, e.g. to support a STARTTLS command. It also flushes the receive
-- buffer, so the application cannot confuse data received before and after
-- the switch.
--
-- The connection context the connection was opened with is reused. If the
-- connection already uses TLS, nothing else happens.
connSetSecure :: MonadIO m => NC.TLSSettings -> ConnectT m ()
connSetSecure settings = do
    ctx <- askCtx
    connSetSecureWithCtx ctx settings

-- | Like 'connSetSecure', but uses the given 'NC.ConnectionContext' rather
-- than the one the connection was opened with.
connSetSecureWithCtx
    :: MonadIO m
    => NC.ConnectionContext
    -> NC.TLSSettings
    -> ConnectT m ()
connSetSecureWithCtx ctx settings =
    askConn >>= \ conn -> liftIO (NC.connectionSetSecure ctx conn settings)

-- | Report whether the connection currently runs over TLS.
connIsSecure :: MonadIO m => ConnectT m Bool
connIsSecure = do
    conn <- askConn
    liftIO (NC.connectionIsSecure conn)