-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE ViewPatterns        #-}

module Database.CQL.IO.Connection
    ( Connection
    , ConnId
    , resolve
    , ping
    , connect
    , close
    , request
    , startup
    , register
    , query
    , useKeyspace
    , address
    , protocol
    , eventSig
    ) where

import Control.Applicative
import Control.Concurrent (myThreadId, forkIOWithUnmask)
import Control.Concurrent.Async
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Exception (throwTo, AsyncException (ThreadKilled))
import Control.Lens ((^.), makeLenses, view)
import Control.Monad
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.ByteString.Lazy (ByteString)
import Data.Int
import Data.Maybe (fromMaybe)
import Data.Monoid
import Data.Text.Lazy (fromStrict)
import Data.Unique
import Data.Vector (Vector, (!))
import Database.CQL.Protocol
import Database.CQL.IO.Connection.Socket (Socket)
import Database.CQL.IO.Connection.Settings
import Database.CQL.IO.Hexdump
import Database.CQL.IO.Protocol
import Database.CQL.IO.Signal hiding (connect)
import Database.CQL.IO.Sync (Sync)
import Database.CQL.IO.Types
import Database.CQL.IO.Tickets (Pool, toInt, markAvailable)
import Database.CQL.IO.Timeouts (TimeoutManager, withTimeout)
import Network.Socket hiding (Socket, close, connect, send)
import System.IO (nativeNewline, Newline (..))
import System.Logger hiding (Settings, close, defSettings, settings)
import System.Timeout
import Prelude

import qualified Data.ByteString.Lazy              as L
import qualified Data.ByteString.Lazy.Char8        as Char8
import qualified Data.HashMap.Strict               as HashMap
import qualified Data.Vector                       as Vector
import qualified Database.CQL.IO.Connection.Socket as Socket
import qualified Database.CQL.IO.Sync              as Sync
import qualified Database.CQL.IO.Tickets           as Tickets
import qualified Network.Socket                    as S

-- | Per-stream response slots, indexed by stream id. The read loop puts each
-- response frame (header and body) into the slot of its stream id, where the
-- thread that issued the request picks it up with 'Sync.get'.
type Streams = Vector (Sync (Header, ByteString))

-- | An open connection to the server at '_address', together with the state
-- needed to multiplex requests over stream ids.
data Connection = Connection
    { _settings :: !ConnectionSettings  -- ^ settings the connection was opened with
    , _address  :: !InetAddr            -- ^ remote address this connection talks to
    , _tmanager :: !TimeoutManager      -- ^ drives the send/response timeouts in 'request'
    , _protocol :: !Version             -- ^ protocol version used to encode requests and decode response headers
    , _sock     :: !Socket              -- ^ underlying socket
    , _status   :: !(TVar Bool)         -- ^ 'True' while open; flipped to 'False' once by the read loop's cleanup
    , _streams  :: !Streams             -- ^ one response slot per stream id
    , _wLock    :: !(MVar ())           -- ^ serialises socket writes and the final socket close
    , _reader   :: !(Async ())          -- ^ the read loop; cancelling it (see 'close') tears the connection down
    , _tickets  :: !Pool                -- ^ pool of stream ids available for new requests
    , _logger   :: !Logger              -- ^ logger for trace output and warnings
    , _eventSig :: !(Signal Event)      -- ^ server events are emitted here; handlers are attached by 'register'
    , _ident    :: !ConnId              -- ^ unique identifier; the basis of the 'Eq' instance
    }

-- Generate lenses for the underscore-prefixed fields of 'Connection'
-- (e.g. @settings@ for @_settings@); they are used throughout this module
-- via (^.) and 'view'.
makeLenses ''Connection

-- | Two connections are equal exactly when their unique identifiers are.
instance Eq Connection where
    x == y = view ident x == view ident y

-- | Shows the same text the 'ToBytes' instance renders.
instance Show Connection where
    show c = Char8.unpack (eval (bytes c))

-- | Renders a connection as its address, a @#@, then its socket.
instance ToBytes Connection where
    bytes c = bytes (view address c) +++ val "#" +++ view sock c

-- | Look up all stream-socket addresses for a host name and port.
-- Only address families configured on this machine are returned
-- ('AI_ADDRCONFIG').
resolve :: String -> PortNumber -> IO [InetAddr]
resolve host port = do
    infos <- getAddrInfo (Just hints) (Just host) (Just (show port))
    return [ InetAddr (addrAddress info) | info <- infos ]
  where
    hints = defaultHints { addrFlags = [AI_ADDRCONFIG], addrSocketType = Stream }

-- | Open a connection to the given address.
--
-- Opens the socket, allocates the stream-id pool and one response slot per
-- stream, and starts the read loop. If setting up fails, the socket is
-- closed. Afterwards the server's supported options are checked against the
-- configured compression; if that check fails, the new connection is closed
-- before the exception propagates.
connect :: MonadIO m => ConnectionSettings -> TimeoutManager -> Version -> Logger -> InetAddr -> m Connection
connect t m v g a = liftIO $ do
    conn <- bracketOnError openSocket Socket.close setup
    validateSettings conn `onException` close conn
    return conn
  where
    openSocket = Socket.open (t^.connectTimeout) a (t^.tlsContext)

    -- Allocate per-connection state and start the read loop on socket @skt@.
    setup skt = do
        pool   <- Tickets.pool (t^.maxStreams)
        slots  <- Vector.replicateM (t^.maxStreams) Sync.create
        wlock  <- newMVar ()
        alive  <- newTVarIO True
        events <- signal
        rdr    <- async (readLoop v g t pool a skt slots events alive wlock)
        cid    <- ConnId <$> newUnique
        return Connection
            { _settings = t
            , _address  = a
            , _tmanager = m
            , _protocol = v
            , _sock     = skt
            , _status   = alive
            , _streams  = slots
            , _wLock    = wlock
            , _reader   = rdr
            , _tickets  = pool
            , _logger   = g
            , _eventSig = events
            , _ident    = cid
            }

-- | Check whether a connection to the given address can be established
-- within 5 seconds. The probe socket is always closed afterwards.
--
-- Only an 'IOError' raised by the connect attempt counts as "unreachable".
-- Asynchronous exceptions (e.g. from 'cancel') are not swallowed; they
-- propagate to the caller. A timeout still yields 'False', because
-- 'timeout' returns 'Nothing' in that case.
ping :: MonadIO m => InetAddr -> m Bool
ping a = liftIO $ bracket (Socket.mkSock a) S.close $ \s ->
    fromMaybe False <$> timeout 5000000
        ((S.connect s (sockAddr a) >> return True) `catch` connectFailed)
  where
    -- Connection refused / unreachable etc. are reported as IOErrors.
    connectFailed :: IOError -> IO Bool
    connectFailed _ = return False

-- | Continuously read response frames from the socket and dispatch them.
--
-- Frames on stream @-1@ are server-initiated. Events are emitted on the event
-- signal; error responses and anything else end the loop with an exception.
-- Every other frame is put into the response slot of its stream id. If the
-- slot does not accept it (presumably because the requester already gave up,
-- see 'request'), the stream id is returned to the ticket pool here instead.
--
-- A stream id without a corresponding slot ends the loop with a descriptive
-- 'InternalError', rather than failing on a partial vector index.
--
-- However the loop ends, @cleanup@ runs with async exceptions masked. It
-- marks the connection closed (only the first time), fails the ticket pool
-- and all slots with 'ConnectionClosed', and on a separate thread shuts
-- down the socket for receiving and closes it while holding the write lock.
-- Exceptions other than 'ThreadKilled' are logged as warnings.
readLoop :: Version
         -> Logger
         -> ConnectionSettings
         -> Pool
         -> InetAddr
         -> Socket
         -> Streams
         -> Signal Event
         -> TVar Bool
         -> MVar ()
         -> IO ()
readLoop v g set tck i sck syn s sref wlck =
    run `catch` logException `finally` cleanup
  where
    run = forever $ do
        x <- readSocket v g i sck (set^.maxRecvBuffer)
        case fromStreamId $ streamId (fst x) of
            -1 ->
                case parse (set^.compression) x :: Raw Response of
                    RsError _ _ e -> throwM e
                    RsEvent _ _ e -> emit s e
                    r             -> throwM (UnexpectedResponse' r)
            sid -> case syn Vector.!? sid of
                -- The server answered on a stream id we have no slot for.
                Nothing   -> throwM $ InternalError ("unexpected stream id: " ++ show sid)
                Just slot -> do
                    ok <- Sync.put x slot
                    unless ok $
                        markAvailable tck sid

    cleanup = uninterruptibleMask_ $ do
        isOpen <- atomically $ swapTVar sref False
        when isOpen $ do
            Tickets.close (ConnectionClosed i) tck
            Vector.mapM_ (Sync.close (ConnectionClosed i)) syn
            -- Close off-thread; the write lock ensures no request is mid-write.
            void $ forkIOWithUnmask $ \unmask -> unmask $ do
                Socket.shutdown sck ShutdownReceive
                withMVar wlck (const $ Socket.close sck)

    logException :: SomeException -> IO ()
    logException e = case fromException e of
        Just ThreadKilled -> return ()
        _                 -> warn g $ msg i ~~ msg (val "read-loop: " +++ show e)

-- | Close the connection by cancelling its read loop; the loop's cleanup
-- releases pending requests and closes the socket.
close :: Connection -> IO ()
close c = cancel (c^.reader)

-- | Send a request on a free stream and wait for its response.
--
-- The callback @f@ receives the stream id taken from the connection's ticket
-- pool and must produce the serialised request frame for that stream.
--
-- * Sending is wrapped in 'withTimeout' with the @sendTimeout@ setting and
--   @close c@ as the expiry action (NOTE(review): assumes withTimeout runs
--   that action when the timeout elapses).
-- * The frame is written while holding the write lock, and only if the
--   connection status is still open; otherwise 'ConnectionClosed' is thrown.
-- * Waiting for the response uses @responseTimeout@; on expiry a
--   'TimeoutRead' exception is thrown to the calling thread.
--
-- On success the stream id is returned to the ticket pool. If waiting is
-- interrupted, the stream's slot is killed instead, and the read loop
-- releases the ticket when a late response for that stream arrives.
--
-- NOTE(review): if 'send' throws after a ticket was taken (e.g. the
-- connection is closed), the ticket is not returned here.
request :: Connection -> (Int -> ByteString) -> IO (Header, ByteString)
request c f = send >>= receive
  where
    -- Acquire a stream id, log and write the request; yields the stream id.
    send = withTimeout (c^.tmanager) (c^.settings.sendTimeout) (close c) $ do
        i <- toInt <$> Tickets.get (c^.tickets)
        let req = f i
        trace (c^.logger) $ msg c
            ~~ "stream" .= i
            ~~ "type"   .= val "request"
            ~~ msg' (hexdump (L.take 160 req))
        withMVar (c^.wLock) $ const $ do
            -- Re-check under the lock so we never write to a closing socket.
            isOpen <- readTVarIO (c^.status)
            if isOpen then
                Socket.send (c^.sock) req
            else
                throwM $ ConnectionClosed (c^.address)
        return i

    -- Wait for the response on stream @i@, bounded by responseTimeout.
    receive i = do
        let e = TimeoutRead (show c ++ ":" ++ show i)
        tid <- myThreadId
        withTimeout (c^.tmanager) (c^.settings.responseTimeout) (throwTo tid e) $ do
            -- On any exception, kill the slot so the read loop knows the
            -- requester is gone.
            x <- Sync.get (view streams c ! i) `onException` Sync.kill e (view streams c ! i)
            markAvailable (c^.tickets) i
            return x

-- | Read one response frame: a 9-byte header, then a body of the length the
-- header announces. A request header, or a header that fails to decode, is
-- reported as an 'InternalError'. Each frame is trace-logged with a hexdump
-- of its first 160 bytes.
readSocket :: Version -> Logger -> InetAddr -> Socket -> Int -> IO (Header, ByteString)
readSocket v g i s n = do
    raw <- Socket.recv n i s 9
    hdr <- either headerError return (header v raw)
    case headerType hdr of
        RqHeader -> throwM $ InternalError "unexpected request header"
        RsHeader -> readBody raw hdr
  where
    headerError e = throwM $ InternalError ("response header reading: " ++ e)

    -- Read the body announced by @hdr@ and log the complete frame.
    readBody raw hdr = do
        body <- Socket.recv n i s (fromIntegral (lengthRepr (bodyLength hdr)))
        trace g $ msg (i +++ val "#" +++ s)
            ~~ "stream" .= fromStreamId (streamId hdr)
            ~~ "type"   .= val "response"
            ~~ msg' (hexdump $ L.take 160 (raw <> body))
        return (hdr, body)

-----------------------------------------------------------------------------
-- Operations

-- | Send the STARTUP request with the configured compression. Authenticates
-- if the server asks for it; otherwise warns (via 'checkAuth') when
-- authenticators are configured but unused.
startup :: MonadIO m => Connection -> m ()
startup c = liftIO $ do
    res <- request c (serialise (c^.protocol) cmp startupReq)
    case parse cmp res :: Raw Response of
        RsReady        _ _ Ready -> checkAuth c
        RsAuthenticate _ _ auth  -> authenticate c auth
        RsError        _ _ e     -> throwM e
        unexpected               -> throwM $ UnexpectedResponse' unexpected
  where
    cmp        = c^.settings.compression
    startupReq = RqStartup (Startup Cqlv300 (algorithm cmp)) :: Raw Request

-- | Warn if authenticators are configured although the server did not
-- request authentication.
checkAuth :: Connection -> IO ()
checkAuth c
    | null (c^.settings.authenticators) = return ()
    | otherwise = warn (c^.logger) $ msg $ val
        "Authentication configured but none required by server."

-- | Run the authentication exchange the server asked for.
--
-- Looks up the authenticator registered for the server's mechanism in the
-- connection settings and throws 'AuthenticationRequired' if there is none.
-- Otherwise the initial response from @authOnRequest@ is sent. Without a
-- challenge handler, any challenge from the server raises
-- 'UnexpectedAuthenticationChallenge'. With one, challenges are answered
-- repeatedly (no iteration limit) until the server reports success. The
-- success is then passed to @authOnSuccess@ along with the latest state.
authenticate :: (MonadIO m, MonadThrow m) => Connection -> Authenticate -> m ()
authenticate c (Authenticate (AuthMechanism -> m)) =
    case HashMap.lookup m (c^.settings.authenticators) of
        Nothing -> throwM $ AuthenticationRequired m
        -- NOTE(review): the fields are bound by pattern match; if the
        -- authenticator's state type is existential, field selectors would
        -- not type-check — confirm before restructuring.
        Just Authenticator {
            authOnRequest   = onR
          , authOnChallenge = onC
          , authOnSuccess   = onS
        } -> liftIO $ do
            (rs, s) <- onR context
            case onC of
                Just  f -> loop f onS (rs, s)
                -- No challenge handler: a challenge from the server is fatal.
                Nothing -> authResponse c rs >>= either
                    (throwM . UnexpectedAuthenticationChallenge m)
                    (onS s)
  where
    context = AuthContext (c^.ident) (c^.address)

    -- Send the current response; answer each challenge with the handler
    -- (threading its updated state) until the server signals success.
    loop onC onS (rs, s) =
        authResponse c rs >>= either
            (onC s >=> loop onC onS)
            (onS s)

-- | Send an AUTH_RESPONSE. Returns 'Right' on success, 'Left' with the next
-- challenge otherwise. Error responses are rethrown; any other response
-- raises 'UnexpectedResponse''.
authResponse :: MonadIO m
             => Connection
             -> AuthResponse
             -> m (Either AuthChallenge AuthSuccess)
authResponse c resp = liftIO $ do
    res <- request c (serialise (c^.protocol) cmp authReq)
    case parse cmp res :: Raw Response of
        RsAuthSuccess   _ _ done  -> return (Right done)
        RsAuthChallenge _ _ chall -> return (Left chall)
        RsError         _ _ e     -> throwM e
        unexpected                -> throwM $ UnexpectedResponse' unexpected
  where
    cmp     = c^.settings.compression
    authReq = RqAuthResp resp :: Raw Request

-- | Subscribe to the given server event types. Once the server acknowledges
-- with READY, the handler is attached to the connection's event signal.
-- Any other response raises 'UnexpectedResponse''.
register :: MonadIO m => Connection -> [EventType] -> EventHandler -> m ()
register c e f = liftIO $ do
    res <- request c (serialise (c^.protocol) cmp registerReq)
    case parse cmp res :: Raw Response of
        RsReady _ _ Ready -> c^.eventSig |-> f
        unexpected        -> throwM (UnexpectedResponse' unexpected)
  where
    cmp         = c^.settings.compression
    registerReq = RqRegister (Register e) :: Raw Request

-- | Ensure the server supports the configured compression algorithm
-- ('None' is always accepted). Throws 'UnsupportedCompression' with the
-- server's list otherwise.
validateSettings :: MonadIO m => Connection -> m ()
validateSettings c = liftIO $ do
    Supported offered _ <- supportedOptions c
    let wanted = algorithm (c^.settings.compression)
    when (wanted /= None && wanted `notElem` offered) $
        throwM $ UnsupportedCompression offered

-- | Ask the server for its supported options (sent and parsed without
-- compression, since nothing has been negotiated yet).
supportedOptions :: MonadIO m => Connection -> m Supported
supportedOptions c = liftIO $ do
    res <- request c (serialise (c^.protocol) noCompression optionsReq)
    case parse noCompression res :: Raw Response of
        RsSupported _ _ opts -> return opts
        unexpected           -> throwM (UnexpectedResponse' unexpected)
  where
    optionsReq = RqOptions Options :: Raw Request

-- | Switch the connection to the given keyspace by sending a quoted
-- @use@ statement at consistency 'One'. Anything but a set-keyspace result
-- raises 'UnexpectedResponse''.
useKeyspace :: MonadIO m => Connection -> Keyspace -> m ()
useKeyspace c ks = liftIO $ do
    res <- request c (serialise (c^.protocol) cmp useReq)
    case parse cmp res :: Raw Response of
        RsResult _ _ (SetKeyspaceResult _) -> return ()
        unexpected                         -> throwM (UnexpectedResponse' unexpected)
  where
    cmp    = c^.settings.compression
    stmt   = QueryString ("use " <> quoted (fromStrict (unKeyspace ks)))
    useReq = RqQuery (Query stmt (QueryParams One False () Nothing Nothing Nothing Nothing))

-- | Run a query with the given consistency and bound values, returning the
-- decoded rows. Any response other than a rows result raises
-- 'UnexpectedResponse''.
query :: forall k a b m. (Tuple a, Tuple b, Show b, MonadIO m)
      => Connection
      -> Consistency
      -> QueryString k a b
      -> a
      -> m [b]
query c cons q p = liftIO $ do
    res <- request c (serialise (c^.protocol) cmp queryReq)
    case parse cmp res :: Response k a b of
        RsResult _ _ (RowsResult _ rows) -> return rows
        unexpected                       -> throwM (UnexpectedResponse' unexpected)
  where
    cmp      = c^.settings.compression
    queryReq = RqQuery (Query q (QueryParams cons False p Nothing Nothing Nothing Nothing)) :: Request k a b

-- logging helpers:

-- | Log a byte string on its own line, starting with the platform's native
-- newline sequence.
msg' :: ByteString -> Msg -> Msg
msg' x = msg (prefix +++ x)
  where
    prefix = val $ case nativeNewline of
        LF   -> "\n"
        CRLF -> "\r\n"