-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Database.CQL.IO.Connection
    ( Connection
    , resolve
    , ping
    , connect
    , close
    , request
    , startup
    , register
    , query
    , useKeyspace
    , address
    , protocol
    , eventSig

    , ConnectionSettings
    , defSettings
    , connectTimeout
    , sendTimeout
    , responseTimeout
    , maxStreams
    , compression
    , defKeyspace
    , maxRecvBuffer
    ) where

import Control.Applicative
import Control.Concurrent (myThreadId)
import Control.Concurrent.Async
import Control.Concurrent.MVar
import Control.Exception (throwTo, AsyncException (ThreadKilled))
import Control.Lens ((^.), makeLenses, view)
import Control.Monad
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.ByteString.Builder
import Data.ByteString.Lazy (ByteString)
import Data.Int
import Data.Maybe (isJust, fromMaybe)
import Data.Monoid
import Data.Text.Lazy (fromStrict)
import Data.Unique
import Data.Vector (Vector, (!))
import Database.CQL.Protocol
import Database.CQL.IO.Hexdump
import Database.CQL.IO.Protocol
import Database.CQL.IO.Signal hiding (connect)
import Database.CQL.IO.Sync (Sync)
import Database.CQL.IO.Types
import Database.CQL.IO.Tickets (Pool, toInt, markAvailable)
import Database.CQL.IO.Timeouts (TimeoutManager, withTimeout)
import Foreign.C.Types (CInt (..))
import Network
import Network.Socket hiding (connect, close, recv, send)
import Network.Socket.ByteString.Lazy (sendAll)
import System.IO (nativeNewline, Newline (..))
import System.Logger hiding (Settings, close, defSettings, settings)
import System.Timeout

import qualified Data.ByteString            as B
import qualified Data.ByteString.Lazy       as L
import qualified Data.ByteString.Lazy.Char8 as Char8
import qualified Data.Vector                as Vector
import qualified Database.CQL.IO.Sync       as Sync
import qualified Database.CQL.IO.Tickets    as Tickets
import qualified Network.Socket             as S
import qualified Network.Socket.ByteString  as NB

data ConnectionSettings = ConnectionSettings
    { _connectTimeout  :: !Milliseconds
    , _sendTimeout     :: !Milliseconds
    , _responseTimeout :: !Milliseconds
    , _maxStreams      :: !Int
    , _compression     :: !Compression
    , _defKeyspace     :: !(Maybe Keyspace)
    , _maxRecvBuffer   :: !Int
    }

type Streams = Vector (Sync (Header, ByteString))

data Connection = Connection
    { _settings :: !ConnectionSettings
    , _address  :: !InetAddr
    , _tmanager :: !TimeoutManager
    , _protocol :: !Version
    , _sock     :: !Socket
    , _streams  :: !Streams
    , _wLock    :: !(MVar ())
    , _reader   :: !(Async ())
    , _tickets  :: !Pool
    , _logger   :: !Logger
    , _eventSig :: !(Signal Event)
    , _ident    :: !Unique
    }

makeLenses ''ConnectionSettings
makeLenses ''Connection

instance Eq Connection where
    a == b = a^.ident == b^.ident

instance Show Connection where
    show = Char8.unpack . eval . bytes

instance ToBytes Connection where
    bytes c = bytes (c^.address) +++ val "#" +++ fd (c^.sock)

defSettings :: ConnectionSettings
defSettings =
    ConnectionSettings 5000          -- connect timeout
                       3000          -- send timeout
                       10000         -- response timeout
                       128           -- max streams per connection
                       noCompression -- compression
                       Nothing       -- keyspace
                       16384         -- receive buffer size

resolve :: String -> PortNumber -> IO InetAddr
resolve host port =
    InetAddr . addrAddress . head <$> getAddrInfo (Just hints) (Just host) (Just (show port))
  where
    hints = defaultHints { addrFlags = [AI_ADDRCONFIG], addrSocketType = Stream }

connect :: MonadIO m => ConnectionSettings -> TimeoutManager -> Version -> Logger -> InetAddr -> m Connection
connect t m v g a = liftIO $
    bracketOnError (mkSock a) S.close $ \s -> do
        ok <- timeout (ms (t^.connectTimeout) * 1000) (S.connect s (sockAddr a))
        unless (isJust ok) $
            throwM (ConnectTimeout a)
        c <- open s
        validateSettings c
        return c
  where
    open s = do
        tck <- Tickets.pool (t^.maxStreams)
        syn <- Vector.replicateM (t^.maxStreams) Sync.create
        lck <- newMVar ()
        sig <- signal
        rdr <- async (readLoop v g t tck a s syn sig)
        Connection t a m v s syn lck rdr tck g sig <$> newUnique

mkSock :: InetAddr -> IO Socket
mkSock (InetAddr a) = socket (familyOf a) Stream defaultProtocol
  where
    familyOf (SockAddrInet  {..}) = AF_INET
    familyOf (SockAddrInet6 {..}) = AF_INET6
    familyOf (SockAddrUnix  {..}) = AF_UNIX

ping :: MonadIO m => InetAddr -> m Bool
ping a = liftIO $ bracket (mkSock a) S.close $ \s ->
    fromMaybe False <$> timeout 5000000
        ((S.connect s (sockAddr a) >> return True) `catchAll` const (return False))

readLoop :: Version -> Logger -> ConnectionSettings -> Pool -> InetAddr -> Socket -> Streams -> Signal Event -> IO ()
readLoop v g set tck i sck syn s = run `catch` logException `finally` cleanup
  where
    run = forever $ do
        x <- readSocket v g i sck (set^.maxRecvBuffer)
        case fromStreamId $ streamId (fst x) of
            -1 ->
                case parse (set^.compression) x :: Response () () () of
                    RsError _ e -> throwM e
                    RsEvent _ e -> emit s e
                    r           -> throwM (UnexpectedResponse' r)
            sid -> do
                ok <- Sync.put x (syn ! sid)
                unless ok $
                    markAvailable tck sid

    cleanup = do
        Tickets.close (ConnectionClosed i) tck
        Vector.mapM_ (Sync.close (ConnectionClosed i)) syn
        S.close sck

    logException :: SomeException -> IO ()
    logException e = case fromException e of
        Just ThreadKilled -> return ()
        _                 -> warn g $ msg i ~~ msg (val "read-loop: " +++ show e)

close :: Connection -> IO ()
close = cancel . view reader

request :: Connection -> (Int -> ByteString) -> IO (Header, ByteString)
request c f = send >>= receive
  where
    send = withTimeout (c^.tmanager) (c^.settings.sendTimeout) (close c) $ do
        i <- toInt <$> Tickets.get (c^.tickets)
        let req = f i
        trace (c^.logger) $ msg c
            ~~ "stream" .= i
            ~~ "type"   .= val "request"
            ~~ msg' (hexdump (L.take 160 req))
        withMVar (c^.wLock) $
            const $ sendAll (c^.sock) req
        return i

    receive i = do
        let e = TimeoutRead (show c ++ ":" ++ show i)
        tid <- myThreadId
        withTimeout (c^.tmanager) (c^.settings.responseTimeout) (throwTo tid e) $ do
            x <- Sync.get (view streams c ! i) `onException` (Sync.kill e) (view streams c ! i)
            markAvailable (c^.tickets) i
            return x

readSocket :: Version -> Logger -> InetAddr -> Socket -> Int -> IO (Header, ByteString)
readSocket v g i s n = do
    b <- recv n i s (if v == V3 then 9 else 8)
    h <- case header v b of
            Left  e -> throwM $ InternalError ("response header reading: " ++ e)
            Right h -> return h
    case headerType h of
        RqHeader -> throwM $ InternalError "unexpected request header"
        RsHeader -> do
            let len = lengthRepr (bodyLength h)
            x <- recv n i s (fromIntegral len)
            trace g $ msg (i +++ val "#" +++ fd s)
                ~~ "stream" .= fromStreamId (streamId h)
                ~~ "type"   .= val "response"
                ~~ msg' (hexdump $ L.take 160 (b <> x))
            return (h, x)

recv :: Int -> InetAddr -> Socket -> Int -> IO ByteString
recv _ _ _ 0 = return L.empty
recv x i s n = toLazyByteString <$> go n mempty
  where
    go !k !bb = do
        a <- NB.recv s (k `min` x)
        when (B.null a) $
            throwM (ConnectionClosed i)
        let b = bb <> byteString a
        let m = k - B.length a
        if m > 0 then go m b else return b

-----------------------------------------------------------------------------
-- Operations

startup :: MonadIO m => Connection -> m ()
startup c = liftIO $ do
    let cmp = c^.settings.compression
    let req = RqStartup (Startup Cqlv300 (algorithm cmp))
    let enc = serialise (c^.protocol) cmp (req :: Request () () ())
    res <- request c enc
    (parse cmp res :: Response () () ()) `seq` return ()

register :: MonadIO m => Connection -> [EventType] -> EventHandler -> m ()
register c e f = liftIO $ do
    let req = RqRegister (Register e) :: Request () () ()
    let enc = serialise (c^.protocol) (c^.settings.compression) req
    res <- request c enc
    case parse (c^.settings.compression) res :: Response () () () of
        RsReady _ Ready -> c^.eventSig |-> f
        other           -> throwM (UnexpectedResponse' other)

validateSettings :: MonadIO m => Connection -> m ()
validateSettings c = liftIO $ do
    Supported ca _ <- supportedOptions c
    let x = algorithm (c^.settings.compression)
    unless (x == None || x `elem` ca) $
        throwM $ UnsupportedCompression ca

supportedOptions :: MonadIO m => Connection -> m Supported
supportedOptions c = liftIO $ do
    let options = RqOptions Options :: Request () () ()
    res <- request c (serialise (c^.protocol) noCompression options)
    case parse noCompression res :: Response () () () of
        RsSupported _ x -> return x
        other           -> throwM (UnexpectedResponse' other)

useKeyspace :: MonadIO m => Connection -> Keyspace -> m ()
useKeyspace c ks = liftIO $ do
    let cmp    = c^.settings.compression
        params = QueryParams One False () Nothing Nothing Nothing
        kspace = quoted (fromStrict $ unKeyspace ks)
        req    = RqQuery (Query (QueryString $ "use " <> kspace) params)
    res <- request c (serialise (c^.protocol) cmp req)
    case parse cmp res :: Response () () () of
        RsResult _ (SetKeyspaceResult _) -> return ()
        other                            -> throwM (UnexpectedResponse' other)

query :: forall k a b m. (Tuple a, Tuple b, Show b, MonadIO m)
      => Connection
      -> Consistency
      -> QueryString k a b
      -> a
      -> m [b]
query c cons q p = liftIO $ do
    let req = RqQuery (Query q params) :: Request k a b
    let enc = serialise (c^.protocol) (c^.settings.compression) req
    res <- request c enc
    case parse (c^.settings.compression) res :: Response k a b of
        RsResult _ (RowsResult _ b) -> return b
        other                       -> throwM (UnexpectedResponse' other)
  where
    params = QueryParams cons False p Nothing Nothing Nothing

-- logging helpers:

fd :: Socket -> Int32
fd !s = let CInt !n = fdSocket s in n

msg' :: ByteString -> Msg -> Msg
msg' x = msg $ case nativeNewline of
    LF   -> val "\n"   +++ x
    CRLF -> val "\r\n" +++ x