-- |
-- Module:     Network.Smtp.Protocol
-- Copyright:  (c) 2010 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
-- Stability:  experimental
--
-- This module provides a low-level implementation of the SMTP protocol:
-- sending individual commands and parsing the server's replies.

{-# LANGUAGE DeriveDataTypeable, OverloadedStrings #-}

module Network.Smtp.Protocol
    ( -- * Types
      Extension,
      Mail,
      MailConfig(..),

      -- * Sessions
      -- ** Running sessions
      runMail,
      sendMail,
      sendMailDirect,
      -- ** Chatting with the server
      waitForWelcome,
      sendHello,
      sendMailFrom,
      sendRcptTo,
      sendData,
      sendReset,
      sendQuit,

      -- * Utilities
      -- ** Parsing
      codeParser,

      -- ** Input/output
      mailPut,
      mailPutList
    )
    where

import qualified Data.ByteString.Char8 as B
import qualified Data.Set as S
import Blaze.ByteString.Builder
import Control.Applicative
import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Attoparsec as P (skipWhile, takeTill)
import Data.Attoparsec.Char8 as P hiding (skipWhile, takeTill)
import Data.Attoparsec.Enumerator
import Data.ByteString.Char8 (ByteString)
import Data.Enumerator as E hiding (map)
import Data.Enumerator.IO
import Data.List as L
import Data.Maybe
import Data.Monoid
import Data.Set (Set)
import Network.DnsCache
import Network.Fancy
import System.IO


-- ===== --
-- Types --
-- ===== --


-- | A running SMTP session: a 'StateT' layer holding the session's
-- 'MailConfig' on top of an 'Iteratee' that consumes the server's
-- output as 'ByteString' chunks.

type Mail a =
    StateT MailConfig (Iteratee ByteString IO) a


-- | Per-session configuration and state threaded through 'Mail'.

data MailConfig = MailConfig
    { mailExtensions :: Set Extension
      -- ^ Extensions recognised in the server's @EHLO@ reply.
    , mailHandle     :: Handle
      -- ^ Connection handle: commands are written to it and replies
      -- are read from it.
    }

-- | Initial configuration for a fresh connection on the given handle,
-- with no extensions known yet.
defMailConfig :: Handle -> MailConfig
defMailConfig = MailConfig S.empty


-- | An SMTP service extension that the server may advertise in its
-- @EHLO@ reply.

data Extension
    = Extension  -- ^ Placeholder: no extensions are recognised yet.
    deriving (Eq, Ord)


-- ========= --
-- Iteratees --
-- ========= --

-- | Read the server's opening greeting.  A 220 reply lets the session
-- proceed; a 554 reply is turned into an 'IOError'.  Any other code
-- makes the reply parser fail.

waitForWelcome :: Mail ()
waitForWelcome = do
    ok <- lift (iterParser welcomeParser)
    unless ok . lift . throwError $ userError "SMTP session rejected"


-- | Greet the server with @EHLO@, falling back to @HELO@ when the server
-- answers 500, 502 or 554.
--
-- On success the recognised extensions are stored in 'mailExtensions'.
-- After a @HELO@ fallback the extension set is left as it was.  A 501
-- or 421 reply to @EHLO@ is turned into an 'IOError'.

sendHello :: ByteString -> Mail ()
sendHello domain = do
    mailPutList ["EHLO ", domain, "\r\n"]
    lift (iterParser ehloResponseParser) >>= onReply
  where
    onReply :: HelloResponse -> Mail ()
    onReply (HelloOk exts) =
        modify (\cfg -> cfg { mailExtensions = exts })
    onReply HelloTryHelo = do
        mailPutList ["HELO ", domain, "\r\n"]
        () <$ lift (iterParser (codeParser "250"))
    onReply HelloInvalidArg  = reject "Invalid argument to EHLO"
    onReply HelloUnavailable = reject "Service unavailable"

    reject :: String -> Mail ()
    reject = lift . throwError . userError


-- | Send @MAIL FROM:\<from\>@ and check the reply.
--
-- A 250 reply succeeds.  501 and 503 are turned into 'IOError's.  Any
-- other code makes the reply parser fail.

sendMailFrom :: ByteString -> Mail ()
sendMailFrom from = do
    mailPutList ["MAIL FROM:<", from, ">\r\n"]
    outcome <- lift (iterParser mailFromResponseParser)
    lift $ case outcome of
        MailFromOk               -> return ()
        MailFromParseError       -> throwError (userError "Parse error")
        MailFromAlreadySpecified -> throwError (userError "Sender already specified")


-- | Send @RCPT TO:\<to\>@ and check the reply.  Call it once per
-- recipient.
--
-- Any reply classified as 'RcptToOk' succeeds.  'RcptToUnknown' and
-- 'RcptToError' are turned into 'IOError's.

sendRcptTo :: ByteString -> Mail ()
sendRcptTo to = do
    mailPutList ["RCPT TO:<", to, ">\r\n"]
    verdict <- lift (iterParser rcptToResponseParser)
    maybe (return ()) (lift . throwError . userError) (rcptFailure verdict)
  where
    -- Error message for a rejected recipient, or Nothing on success.
    rcptFailure :: RcptToResponse -> Maybe String
    rcptFailure RcptToOk      = Nothing
    rcptFailure RcptToUnknown = Just "Recipient unknown"
    rcptFailure RcptToError   = Just "RCPT TO rejected, unknown error"


-- | Send @DATA@, wait for the 354 go-ahead, then write the mail content
-- followed by a dot and CRLF, and wait for the final 250.
--
-- Receiving the two replies in the wrong order is reported as a protocol
-- error ('IOError').  Any other reply code makes the reply parser fail.
--
-- NOTE(review): the content is written verbatim.  No dot-stuffing is
-- performed and no CRLF is inserted before the terminating dot, so the
-- caller must supply content that is already dot-stuffed and ends in
-- CRLF (RFC 5321 section 4.5.2).

sendData :: Builder -> Mail ()
sendData content = do
    mailPutList ["DATA\r\n"]
    expectIntermediate =<< lift (iterParser dataResponseParser)
    mailPut (content `mappend` fromByteString ".\r\n")
    expectOk =<< lift (iterParser dataResponseParser)
  where
    expectIntermediate :: DataResponse -> Mail ()
    expectIntermediate DataIntermediate = return ()
    expectIntermediate DataOk =
        protocolError "Protocol error: Got 250 after sending DATA command."

    expectOk :: DataResponse -> Mail ()
    expectOk DataOk = return ()
    expectOk DataIntermediate =
        protocolError "Protocol error: Got 354 after sending mail."

    protocolError :: String -> Mail ()
    protocolError msg = lift (throwError (userError msg))


-- | Abort the current mail transaction with @RSET@ and require a 250
-- reply.

sendReset :: Mail ()
sendReset = do
    mailPutList ["RSET\r\n"]
    _ <- lift . iterParser $ codeParser "250"
    return ()


-- | End the session with @QUIT@.
--
-- After the 221 reply, any remaining input chunks made only of CR/LF
-- bytes are discarded (for example the line ending that 'codeParser'
-- leaves unconsumed).  The stream must then be at end-of-file; otherwise
-- an 'IOError' is raised.

sendQuit :: Mail ()
sendQuit = do
    mailPutList ["QUIT\r\n"]
    lift awaitClose
  where
    onlyNewlines :: ByteString -> Bool
    onlyNewlines = B.all (inClass "\r\n")

    awaitClose :: Iteratee ByteString IO ()
    awaitClose = do
        _ <- iterParser (codeParser "221")
        E.dropWhile onlyNewlines
        closed <- E.isEOF
        unless closed . throwError $ userError "Session still open after QUIT"


-- ======= --
-- Parsers --
-- ======= --


-- | Classify the server's greeting: 'True' for 220 (ready), 'False' for
-- 554 (rejected).  Other codes fail the parse.

welcomeParser :: Parser Bool
welcomeParser = accepted <|> rejected
  where
    accepted = True  <$ P.try (codeParser "220")
    rejected = False <$ codeParser "554"


-- | Classified reply to an @EHLO@ command.

data HelloResponse
    = HelloOk (Set Extension)  -- ^ 250, with the extensions we recognised.
    | HelloTryHelo             -- ^ 500, 502 or 554: retry with @HELO@.
    | HelloInvalidArg          -- ^ 501.
    | HelloUnavailable         -- ^ 421.


-- | Parse the reply to @EHLO@.
--
-- On 250 the first reply line is dropped.  Per RFC 5321 that line holds
-- the server's domain and greeting, not an extension keyword.  The
-- remaining lines are mapped through 'stringToExtension', and
-- unrecognised keywords are ignored.  Codes 500/502/554, 501 and 421
-- map to the other 'HelloResponse' constructors; anything else fails
-- the parse.

ehloResponseParser :: Parser HelloResponse
ehloResponseParser =
    choice [ HelloOk <$> P.try ok,
             HelloTryHelo <$ P.try tryHelo,
             HelloInvalidArg <$ P.try invArg,
             HelloUnavailable <$ unavail ]

    where
    -- 'drop 1' keeps this total even though 'codeParser' always yields
    -- at least one line.
    ok = S.fromList . mapMaybe stringToExtension . drop 1
         <$> codeParser "250"
    tryHelo =
        P.try (codeParser "500") <|>
        P.try (codeParser "502") <|>
        codeParser "554"
    invArg  = codeParser "501"
    unavail = codeParser "421"


-- | Classified reply to a @MAIL FROM@ command.

data MailFromResponse
    = MailFromOk                -- ^ 250: sender accepted.
    | MailFromParseError        -- ^ 501: reported as a parse error.
    | MailFromAlreadySpecified  -- ^ 503: a sender was already given.


-- | Parse the reply to @MAIL FROM@ (250, 501 or 503).  Other codes fail
-- the parse.

mailFromResponseParser :: Parser MailFromResponse
mailFromResponseParser =
    P.try accepted <|> P.try badSyntax <|> duplicate
  where
    accepted  = MailFromOk               <$ codeParser "250"
    badSyntax = MailFromParseError       <$ codeParser "501"
    duplicate = MailFromAlreadySpecified <$ codeParser "503"


-- | Classified reply to a @RCPT TO@ command.

data RcptToResponse
    = RcptToOk       -- ^ Code 250, or 251 (accepted and forwarded).
    | RcptToUnknown  -- ^ Code 550.
    | RcptToError    -- ^ Code 554.


-- | Parse the reply to @RCPT TO@.
--
-- RFC 5321 lists both 250 and 251 ("user not local; will forward") as
-- success replies to @RCPT@, so both map to 'RcptToOk'.  Previously a
-- 251 failed the parse and aborted the session.  Codes other than
-- 250/251/550/554 still fail the parse.

rcptToResponseParser :: Parser RcptToResponse
rcptToResponseParser =
    P.try (RcptToOk      <$ codeParser "250") <|>
    P.try (RcptToOk      <$ codeParser "251") <|>
    P.try (RcptToUnknown <$ codeParser "550") <|>
    (RcptToError         <$ codeParser "554")


-- | Classified reply to @DATA@ or to the end of the mail body.

data DataResponse
    = DataOk            -- ^ 250: completed.
    | DataIntermediate  -- ^ 354: go ahead and send the mail content.


-- | Parse a reply received during the @DATA@ exchange (250 or 354).
-- Other codes fail the parse.

dataResponseParser :: Parser DataResponse
dataResponseParser = P.try done <|> goAhead
  where
    done    = DataOk           <$ codeParser "250"
    goAhead = DataIntermediate <$ codeParser "354"


-- | Parse a complete, possibly multi-line, reply carrying the given reply
-- code (e.g. @\"250\"@).  Returns the text after the code and separator
-- for every line, in order.  The result is never empty.

codeParser :: ByteString -> Parser [ByteString]
codeParser code = choice [P.try continued, final]
  where
    continued = codeMoreParser code
    final     = codeFinalParser code


-- | Parse one continuation line (@code@ followed by @-@) and then the rest
-- of the reply via 'codeParser'.  Leading CR/LF bytes are skipped.

codeMoreParser :: ByteString -> Parser [ByteString]
codeMoreParser code = do
    skipWhile isEndOfLine
    _ <- string (B.snoc code '-')
    line <- takeTill isEndOfLine
    rest <- codeParser code
    return (line : rest)


-- | Parse the final line of a reply (@code@ followed by a space).
-- Leading CR/LF bytes are skipped.  The line's own terminator is left
-- unconsumed.

codeFinalParser :: ByteString -> Parser [ByteString]
codeFinalParser code = do
    skipWhile isEndOfLine
    _ <- string (B.snoc code ' ')
    line <- takeTill isEndOfLine
    return [line]


-- ============= --
-- Sending mails --
-- ============= --

-- | Run a 'Mail' computation on the configuration's handle, giving up
-- after the given timeout in microseconds.
--
-- The session runs in a worker thread that reads server replies from
-- 'mailHandle' with 'enumHandle' at a buffer size of 1.  Presumably this
-- is so a read never waits for more bytes than the server has sent
-- (TODO confirm).
--
-- Errors are re-thrown in the calling thread.  This covers errors
-- raised inside the iteratee and synchronous exceptions escaping it,
-- such as an 'IOException' from writing to the handle.  If the timeout
-- elapses first, an 'IOError' \"Timed out\" is thrown.  The worker is
-- killed when 'runMail' returns, throws, or is interrupted.
--
-- NOTE(review): 'registerDelay' requires the threaded RTS (@-threaded@).

runMail :: Int -> MailConfig -> Mail a -> IO a
runMail timeout cfg comp = do
    let h = mailHandle cfg
    timeoutVar <- registerDelay timeout
    resultVar <- newEmptyTMVarIO

    let session = do
            hSetBuffering h NoBuffering
            run (enumHandle 1 h $$ evalStateT comp cfg)

        -- Forward exceptions that escape 'run' instead of letting them
        -- kill the worker silently.  Otherwise the caller would block
        -- until the timeout and then report a misleading "Timed out".
        worker =
            try session >>= atomically . putTMVar resultVar . either Left id

        timedOut = do
            readTVar timeoutVar >>= check
            return . Left . toException $ userError "Timed out"

    mailerThread <- forkIO worker
    -- 'finally' ensures the worker does not outlive this call, even if
    -- the caller is interrupted while waiting.
    result <- atomically (timedOut `orElse` readTMVar resultVar)
                  `finally` killThread mailerThread
    either throwIO return result


-- | Resolve the MX records of @domain@ and run the session over port 25
-- against the first exchanger in the returned list (no re-sorting is
-- done here).  Falls back to @domain@ itself when the lookup yields
-- nothing.

sendMail :: DnsMonad m => Int -> Domain -> Mail a -> m a
sendMail timeout domain c = do
    exchangers <- resolveMX domain
    let target = case exchangers of
                   Just (mx : _) -> mx
                   _             -> domain
    liftIO $ withStream (IP target 25) (\h -> runMail timeout (defMailConfig h) c)


-- | Run the session against the given address directly, without any
-- MX lookup.

sendMailDirect :: Int -> Address -> Mail a -> IO a
sendMailDirect timeout addr c = withStream addr session
  where
    session h = runMail timeout (defMailConfig h) c


-- ================= --
-- Utility functions --
-- ================= --

-- | Map an @EHLO@ extension line to a known 'Extension'.  No extensions
-- are recognised yet, so every input yields 'Nothing'.

stringToExtension :: ByteString -> Maybe Extension
stringToExtension = const Nothing


-- | Render a 'Builder' and write it, chunk by chunk, to the session's
-- handle.

mailPut :: Builder -> Mail ()
mailPut builder =
    gets mailHandle >>= \h -> liftIO (toByteStringIO (B.hPutStr h) builder)


-- | Concatenate the given byte strings and send them with 'mailPut'.

mailPutList :: [ByteString] -> Mail ()
mailPutList chunks = mailPut (mconcat [fromByteString c | c <- chunks])