{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
module Database.CQL.IO.Client
( Client
, MonadClient (..)
, ClientState
, DebugInfo (..)
, ControlState (..)
, runClient
, init
, shutdown
, request
, requestN
, request1
, execute
, executeWithPrepare
, prepare
, retry
, once
, debugInfo
, preparedQueries
, withPrepareStrategy
, getResult
, unexpected
, C.defQueryParams
) where
import Control.Applicative
import Control.Concurrent (threadDelay, forkIO)
import Control.Concurrent.Async (async, wait)
import Control.Concurrent.STM (STM, atomically)
import Control.Concurrent.STM.TVar
import Control.Exception (IOException, SomeAsyncException (..))
import Control.Lens (makeLenses, (^.), set, over, view)
import Control.Monad (when, unless)
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.IO.Unlift
import Control.Monad.Reader (ReaderT (..), runReaderT, MonadReader, ask)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Control.Retry (capDelay, exponentialBackoff, rsIterNumber)
import Control.Retry (recovering)
import Data.Foldable (for_, foldrM)
import Data.List (find)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, listToMaybe)
import Data.Semigroup
import Data.Text.Encoding (encodeUtf8)
import Data.Word
import Database.CQL.IO.Cluster.Host
import Database.CQL.IO.Cluster.Policies
import Database.CQL.IO.Connection (Connection, host, Raw)
import Database.CQL.IO.Connection.Settings
import Database.CQL.IO.Exception
import Database.CQL.IO.Jobs
import Database.CQL.IO.Log
import Database.CQL.IO.Pool (Pool)
import Database.CQL.IO.PrepQuery (PrepQuery, PreparedQueries)
import Database.CQL.IO.Settings
import Database.CQL.IO.Signal
import Database.CQL.IO.Timeouts (TimeoutManager)
import Database.CQL.Protocol hiding (Map)
import OpenSSL.Session (SomeSSLException)
import Prelude hiding (init)
import qualified Control.Monad.Reader as Reader
import qualified Control.Monad.State.Strict as S
import qualified Control.Monad.State.Lazy as LS
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as Map
import qualified Database.CQL.IO.Cluster.Discovery as Disco
import qualified Database.CQL.IO.Connection as C
import qualified Database.CQL.IO.Pool as Pool
import qualified Database.CQL.IO.PrepQuery as PQ
import qualified Database.CQL.IO.Timeouts as TM
import qualified Database.CQL.Protocol as Cql
-- | Lifecycle state of the control connection.
data ControlState
    = Connected     -- ^ A control connection has been installed
                    --   (written by 'init' and 'setupControl').
    | Reconnecting  -- ^ 'replaceControl' has claimed the control connection
                    --   and a reconnect thread is running.
    | Disconnected  -- ^ The reconnect thread was aborted by an exception.
    deriving (Eq, Ord, Show)
-- | The control connection together with its lifecycle state.
data Control = Control
    { _state      :: !ControlState
      -- ^ current lifecycle state, see 'ControlState'
    , _connection :: !Connection
      -- ^ connection used for host discovery queries, event registration
      --   and the liveness check in 'checkControl'
    }
-- | Environment shared by all parts of a client.
data Context = Context
    { _settings :: !Settings
      -- ^ client configuration (may be overridden locally, e.g. by 'retry')
    , _timeouts :: !TimeoutManager
      -- ^ timeout manager handed to every 'C.connect'
    , _sigMonit :: !(Signal HostEvent)
      -- ^ host events emitted by 'monitor'; 'init' subscribes the policy's
      --   'onEvent' to it
    }
-- | Complete state of a running client, as created by 'init'.
data ClientState = ClientState
    { _context     :: !Context
      -- ^ settings, timeout manager and host-event signal
    , _policy      :: !Policy
      -- ^ host policy consulted for 'select', 'hostCount', 'acceptable', ...
    , _prepQueries :: !PreparedQueries
      -- ^ registry of prepared queries: query -> id and id -> query string
    , _control     :: !(TVar Control)
      -- ^ the current control connection and its state
    , _hostmap     :: !(TVar (Map Host Pool))
      -- ^ known hosts and their connection pools
    , _jobs        :: !(Jobs InetAddr)
      -- ^ background jobs (e.g. host monitors), keyed by host address
    }
-- Generate lenses ('state', 'connection', 'settings', 'timeouts', 'sigMonit',
-- 'context', 'policy', 'prepQueries', 'control', 'hostmap', 'jobs') for the
-- underscore-prefixed record fields declared above.
makeLenses ''Control
makeLenses ''Context
makeLenses ''ClientState
-- | The client monad: a 'ReaderT' over 'ClientState' in 'IO'.
-- Execute it with 'runClient'.
newtype Client a = Client
    { client :: ReaderT ClientState IO a
    } deriving ( Functor
               , Applicative
               , Monad
               , MonadIO
               , MonadUnliftIO
               , MonadThrow
               , MonadCatch
               , MonadMask
               , MonadReader ClientState
               )
-- | Monads in which 'Client' actions can be run.
class (MonadIO m, MonadThrow m) => MonadClient m
  where
    -- | Lift a 'Client' action into @m@.
    liftClient :: Client a -> m a

    -- | Run an action under a locally modified 'ClientState'
    -- (the analogue of 'Reader.local').
    localState :: (ClientState -> ClientState) -> m a -> m a
-- | Base instance: lifting is the identity and local state changes go
-- straight through the underlying reader.
instance MonadClient Client where
    liftClient action = action
    localState f action = Reader.local f action
-- | Lift through 'ReaderT': the inner action runs under the modified state.
instance MonadClient m => MonadClient (ReaderT r m) where
    liftClient c = lift (liftClient c)
    localState f m = ReaderT $ \env -> localState f (runReaderT m env)
-- | Lift through strict 'S.StateT': the inner action runs under the
-- modified state.
instance MonadClient m => MonadClient (S.StateT s m) where
    liftClient c = lift (liftClient c)
    localState f m = S.StateT $ \st -> localState f (S.runStateT m st)
-- | Lift through lazy 'LS.StateT': the inner action runs under the
-- modified state.
instance MonadClient m => MonadClient (LS.StateT s m) where
    liftClient c = lift (liftClient c)
    localState f m = LS.StateT $ \st -> localState f (LS.runStateT m st)
-- | Lift through 'ExceptT': the unwrapped action runs under the modified
-- state and is re-wrapped.
instance MonadClient m => MonadClient (ExceptT e m) where
    liftClient c = lift (liftClient c)
    localState f m = ExceptT (localState f (runExceptT m))
-- | Run a 'Client' action against the given state in any 'MonadIO'.
runClient :: MonadIO m => ClientState -> Client a -> m a
runClient st (Client r) = liftIO (runReaderT r st)
-- | Run an action with the given retry settings instead of the configured
-- ones.
retry :: MonadClient m => RetrySettings -> m a -> m a
retry r action = localState (set (context.settings.retrySettings) r) action
-- | Run an action with retries disabled ('noRetry').
once :: MonadClient m => m a -> m a
once action = retry noRetry action
-- | Run an action with the given prepare strategy instead of the configured
-- one.
withPrepareStrategy :: MonadClient m => PrepareStrategy -> m a -> m a
withPrepareStrategy strategy action =
    localState (set (context.settings.prepStrategy) strategy) action
-- | Send a request to a host chosen by the policy, honouring the current
-- retry settings. Up to 'hostCount' hosts are tried when a selected host
-- cannot take the request (see 'requestN').
request :: (MonadClient m, Tuple a, Tuple b) => Request k a b -> m (HostResponse k a b)
request rq = liftClient $ do
    pol <- view policy
    cnt <- liftIO (hostCount pol)
    withRetries (requestN cnt) rq
-- | Select a host via the policy and try the request there, moving on to
-- another selection at most @n - 1@ more times when the host cannot take it.
--
-- Throws 'NoHostAvailable' (after triggering 'replaceControl') when the
-- policy has no host to offer, and 'HostsBusy' when all attempts yielded
-- 'Nothing'.
requestN :: (Tuple b, Tuple a)
         => Word
         -> Request k a b
         -> ClientState
         -> Client (HostResponse k a b)
requestN !n a s = do
    picked <- liftIO (select (s^.policy))
    case picked of
        Nothing -> replaceControl >> throwM NoHostAvailable
        Just h  -> tryRequest1 h a s >>= maybe retryOrFail return
  where
    retryOrFail
        | n > 1     = requestN (n - 1) a s
        | otherwise = throwM HostsBusy
-- | Send a request to one specific host; throws 'HostsBusy' if the host
-- cannot take it.
request1 :: (Tuple a, Tuple b)
         => Host
         -> Request k a b
         -> ClientState
         -> Client (HostResponse k a b)
request1 h r s = tryRequest1 h r s >>= maybe (throwM HostsBusy) return
-- | Try to send a request to a specific host using that host's pool.
--
-- Returns 'Nothing' when 'Pool.with' yields 'Nothing'. Server warnings in the
-- response are logged. If the host has no pool yet, one is created and
-- registered, then the request is retried.
--
-- Connection-level failures ('ConnectionError', 'IOException',
-- 'SomeSSLException') handle the host as follows, then re-throw:
--
-- * report it down to the policy;
-- * start a monitor job for it;
-- * if it is the control host, check the control connection and replace it
--   when the check fails.
tryRequest1 :: (Tuple a, Tuple b)
            => Host
            -> Request k a b
            -> ClientState
            -> Client (Maybe (HostResponse k a b))
tryRequest1 h a s = do
    pool <- Map.lookup h <$> readTVarIO' (s^.hostmap)
    case pool of
        Just p -> do
            result <- Pool.with p exec `catches` handlers
            for_ result $ \(HostResponse _ r) ->
                for_ (Cql.warnings r) $ \w ->
                    logWarn' $ "Server warning: " <> byteString (encodeUtf8 w)
            return result
        Nothing -> do
            logError' $ "No pool for host: " <> string8 (show h)
            p' <- mkPool (s^.context) h
            -- Insert only if still absent, and report whether ours was used.
            inserted <- atomically' $ do
                hm <- readTVar (s^.hostmap)
                if Map.member h hm
                    then return False
                    else do
                        writeTVar (s^.hostmap) (Map.insert h p' hm)
                        return True
            -- Another thread registered a pool for this host while ours was
            -- being created: keep theirs and release ours instead of leaking it.
            unless inserted $
                liftIO (Pool.destroy p')
            tryRequest1 h a s
  where
    exec c = do
        r <- C.request c a
        return $ HostResponse h r

    handlers =
        [ Handler $ \(e :: ConnectionError)  -> onConnectionError e
        , Handler $ \(e :: IOException)      -> onConnectionError e
        , Handler $ \(e :: SomeSSLException) -> onConnectionError e
        ]

    onConnectionError exc = do
        e <- ask
        logWarn' (string8 (show exc))
        liftIO $ ignore $ onEvent (e^.policy) (HostDown (h^.hostAddr))
        runJob_ (e^.jobs) (h^.hostAddr) $
            runClient e $ monitor (Ms 0) (Ms 30000) h
        -- If the failing host also carries the control connection, make
        -- sure the control connection is still usable.
        ch <- fmap (view (connection.host)) . readTVarIO' =<< view control
        when (h == ch) $ do
            ok <- checkControl
            unless ok replaceControl
        throwM exc
-- | Send a request (typically an execute) either to the given host or to a
-- policy-selected one.
--
-- If the server answers 'Unprepared', the query string is looked up in the
-- prepared-query registry, re-prepared lazily, and the request is re-sent to
-- the host that prepared it. Throws 'UnexpectedQueryId' if the id is not in
-- the registry.
executeWithPrepare :: (Tuple b, Tuple a)
                   => Maybe Host
                   -> Request k a b
                   -> Client (HostResponse k a b)
executeWithPrepare mh rq = case mh of
    Just h  -> runAndReprepare (request1 h)
    Nothing -> do
        cnt <- liftIO . hostCount =<< view policy
        runAndReprepare (requestN cnt)
  where
    runAndReprepare send = do
        resp <- withRetries send rq
        case hrResponse resp of
            RsError _ _ (Unprepared _ qid) -> reprepare qid
            _                              -> return resp

    reprepare qid = do
        pq <- preparedQueries
        atomically' (PQ.lookupQueryString (QueryId qid) pq) >>= \case
            Nothing  -> throwM $ UnexpectedQueryId (QueryId qid)
            Just str -> do
                (h, _) <- prepare (Just LazyPrepare) (str :: Raw QueryString)
                executeWithPrepare (Just h) rq
-- | Prepare a query string and return the preparing host and the query id.
--
-- * 'LazyPrepare': prepare on one policy-selected host.
-- * 'EagerPrepare': prepare on every host returned by 'current' and return
--   the first result; throws 'NoHostAvailable' (after 'replaceControl') when
--   there are no hosts.
-- * 'Nothing': use the strategy from the settings.
prepare :: (Tuple b, Tuple a) => Maybe PrepareStrategy -> QueryString k a b -> Client (Host, QueryId k a b)
prepare strategy qs = case strategy of
    Just LazyPrepare -> do
        cnt <- liftIO . hostCount =<< view policy
        withRetries (requestN cnt) prepRq >>= getPreparedQueryId
    Just EagerPrepare -> do
        hosts <- liftIO . current =<< view policy
        ids <- mapM prepareOn hosts
        case ids of
            (x:_) -> return x
            []    -> replaceControl >> throwM NoHostAvailable
    Nothing -> do
        configured <- view (context.settings.prepStrategy)
        prepare (Just configured) qs
  where
    prepRq = RqPrepare (Prepare qs)
    prepareOn h = withRetries (request1 h) prepRq >>= getPreparedQueryId
-- | Execute a prepared query.
--
-- On first use the query is prepared lazily, its id is stored in the
-- registry, and it is executed on the preparing host. Subsequent calls reuse
-- the stored id with a policy-selected host.
execute :: (Tuple b, Tuple a) => PrepQuery k a b -> QueryParams a -> Client (HostResponse k a b)
execute q p = do
    pq <- view prepQueries
    cached <- atomically' (PQ.lookupQueryId q pq)
    case cached of
        Just qid -> run Nothing qid
        Nothing  -> do
            (h, qid) <- prepare (Just LazyPrepare) (PQ.queryString q)
            atomically' (PQ.insert q qid pq)
            run (Just h) qid
  where
    run mh qid = executeWithPrepare mh (RqExecute (Execute qid p))
-- | Re-prepare every registered query string on the given host. Individual
-- responses are discarded.
prepareAllQueries :: Host -> Client ()
prepareAllQueries h = do
    pq <- view prepQueries
    strs <- atomically' (PQ.queryStrings pq)
    for_ strs $ \str -> withRetries (request1 h) (prepareReq str)
  where
    prepareReq str = RqPrepare (Prepare (QueryString str :: Raw QueryString))
-- | Snapshot of internal client state for debugging, see 'debugInfo'.
data DebugInfo = DebugInfo
    { policyInfo  :: String
      -- ^ output of the policy's 'display'
    , jobInfo     :: [InetAddr]
      -- ^ keys of the currently running background jobs
    , hostInfo    :: [Host]
      -- ^ hosts present in the host map
    , controlInfo :: (Host, ControlState)
      -- ^ host and state of the control connection
    }
-- | Multi-line, human-readable rendering. Every field is rendered with
-- 'show', so 'policyInfo' appears as a quoted string.
instance Show DebugInfo where
    show dbg = concat
        [ "running jobs: ",   show (jobInfo dbg)
        , "\nknown hosts: ",  show (hostInfo dbg)
        , "\npolicy info: ",  show (policyInfo dbg)
        , "\ncontrol host: ", show (controlInfo dbg)
        ]
-- | Collect a 'DebugInfo' snapshot: known hosts, policy display, running job
-- keys and the control connection's host and state.
debugInfo :: MonadClient m => m DebugInfo
debugInfo = liftClient $ do
    st <- ask
    knownHosts <- Map.keys <$> readTVarIO' (st^.hostmap)
    polInfo <- liftIO (display (st^.policy))
    running <- listJobKeys (st^.jobs)
    Control cstate conn <- readTVarIO' (st^.control)
    return DebugInfo
        { policyInfo  = polInfo
        , jobInfo     = running
        , hostInfo    = knownHosts
        , controlInfo = (conn^.host, cstate)
        }
-- | The client's prepared-query registry.
preparedQueries :: Client PreparedQueries
preparedQueries = (^. prepQueries) <$> ask
-- | Create a new client from the given settings.
--
-- Steps:
--
-- * create a timeout manager;
-- * connect to one of the configured contacts;
-- * build the client state and subscribe the policy to host events;
-- * set up the control connection over the contact connection.
--
-- If a later step throws, the contact connection is closed and the timeout
-- manager is destroyed before the exception propagates.
init :: MonadIO m => Settings -> m ClientState
init s = liftIO $
    -- Destroy the timeout manager on failure; previously it leaked whenever
    -- no contact could be reached. 'True' matches the call in 'shutdown'.
    bracketOnError (TM.create (Ms 250)) (\tm -> TM.destroy tm True) $ \tom -> do
        ctx <- Context s tom <$> signal
        bracketOnError (mkContact ctx) C.close $ \con -> do
            pol <- s^.policyMaker
            cst <- ClientState ctx
                    <$> pure pol
                    <*> PQ.new
                    <*> newTVarIO (Control Connected con)
                    <*> newTVarIO Map.empty
                    <*> newJobs
            ctx^.sigMonit |-> onEvent pol
            -- NOTE(review): pools and monitor jobs created inside
            -- setupControl are not released here if it throws.
            runClient cst (setupControl con)
            return cst
-- | Connect to one of the contacts from the settings.
--
-- Each contact is resolved with 'C.resolve', and its addresses are handed to
-- 'tryAll' for connecting. A contact that resolves to no addresses fails with
-- 'NoHostAvailable' rather than crashing in the partial 'NE.fromList'.
mkContact :: Context -> IO Connection
mkContact (Context s t _) = tryAll (s^.contacts) mkConnection
  where
    mkConnection h = do
        as <- C.resolve h (s^.portnumber)
        case NE.nonEmpty as of
            Just addrs -> addrs `tryAll` doConnect
            -- An empty resolution result would make 'NE.fromList' throw an
            -- uninformative ErrorCall; report it as an unavailable host.
            Nothing    -> throwM NoHostAvailable

    doConnect a = do
        logDebug (s^.logger) $ "Connecting to " <> string8 (show a)
        C.connect (s^.connSettings) t (s^.protoVersion) (s^.logger) (Host a "" "")
-- | Run the 'Disco.peers' query over the given connection and turn each row
-- into a 'Host', using the configured port number.
discoverPeers :: MonadIO m => Context -> Connection -> m [Host]
discoverPeers ctx c = liftIO $ do
    rows <- C.query c One Disco.peers ()
    return [ peer2Host prt (asRecord row) | row <- rows ]
  where
    prt = ctx^.settings.portnumber
-- | Create a connection pool for a host.
--
-- Connections use the client's connection settings, timeout manager and
-- protocol version. Opening and closing each connection is logged at debug
-- level.
mkPool :: MonadIO m => Context -> Host -> m Pool
mkPool ctx h = liftIO $
    Pool.create openConn closeConn lgr (cfg^.poolSettings) (cfg^.connSettings.maxStreams)
  where
    cfg = ctx^.settings
    lgr = cfg^.logger

    openConn = do
        conn <- C.connect (cfg^.connSettings) (ctx^.timeouts) (cfg^.protoVersion) lgr h
        logDebug lgr $ "Connection established: " <> string8 (show conn)
        return conn

    closeConn conn = do
        C.close conn
        logDebug lgr $ "Connection closed: " <> string8 (show conn)
-- | Shut a client down. In order:
--
-- * destroy the timeout manager;
-- * cancel background jobs;
-- * close the control connection, ignoring errors;
-- * destroy all connection pools.
--
-- The work runs in a separate 'async' thread, and the caller waits for it.
shutdown :: MonadIO m => ClientState -> m ()
shutdown s = liftIO $ do
    worker <- async teardown
    wait worker
  where
    teardown = do
        TM.destroy (s^.context.timeouts) True
        cancelJobs (s^.jobs)
        ignore $ do
            ctl <- readTVarIO (s^.control)
            C.close (ctl^.connection)
        pools <- readTVarIO (s^.hostmap)
        mapM_ Pool.destroy (Map.elems pools)
-- | Watch an unreachable host until it becomes reachable again.
--
-- Waits @initial@ ms, then probes the host with 'C.canConnect'. On success
-- it emits 'HostUp' on the monitor signal and stops. Otherwise it sleeps
-- @2^n * 50ms@ and probes again. The exponent @n@ grows by one per attempt
-- and is capped so a single sleep does not exceed roughly @maxDelay@.
-- Monitoring stops as soon as the host is no longer in the host map. All
-- delays are clamped to at most 5 minutes.
monitor :: Milliseconds -> Milliseconds -> Host -> Client ()
monitor initial maxDelay h = do
    liftIO $ threadDelay (toMicros initial)
    logInfo' $ "Monitoring: " <> string8 (show h)
    hostCheck 0
  where
    hostCheck :: Int -> Client ()
    hostCheck !n = do
        hosts <- readTVarIO' =<< view hostmap
        when (Map.member h hosts) $ do
            isUp <- C.canConnect h
            if isUp
                then do
                    sig <- view (context.sigMonit)
                    liftIO $ sig $$ (HostUp (h^.hostAddr))
                    logInfo' $ "Reachable: " <> string8 (show h)
                else do
                    logInfo' $ "Unreachable: " <> string8 (show h)
                    liftIO $ threadDelay (2^n * minDelay)
                    hostCheck (min (n + 1) maxExp)

    toMicros :: Milliseconds -> Int
    toMicros (Ms s) = min (s * 1000) (5 * 60 * 1000000)

    minDelay :: Int
    minDelay = 50000

    -- Largest exponent with 2^maxExp * minDelay <= maxDelay. It must never be
    -- negative: for maxDelay < minDelay, 'logBase 2 0' is -Infinity, and a
    -- negative exponent would make '2^n' throw "Negative exponent".
    maxExp :: Int
    maxExp
        | steps <= 1 = 0
        | otherwise  = floor (logBase 2 (fromIntegral steps :: Double))
      where
        steps = toMicros maxDelay `div` minDelay
-- | Run a request function under the configured retry policy and handlers.
--
-- The first attempt uses the original request and client state. Later
-- attempts use 'newRequest', which applies the configured reduced
-- consistency if any, and 'adjust', which adds the configured deltas to the
-- send and response timeouts.
--
-- A response for which 'toResponseError' yields an error is thrown so that
-- 'recovering' can decide whether to retry. If it is still thrown after
-- retrying stops, 'try' catches it and 'fromResponseError' turns it back into
-- a response. Exceptions of other types propagate.
withRetries
    :: (Tuple a, Tuple b)
    => (Request k a b -> ClientState -> Client (HostResponse k a b))
    -> Request k a b
    -> Client (HostResponse k a b)
withRetries fn a = do
    s <- ask
    let how  = s^.context.settings.retrySettings.retryPolicy
    let what = s^.context.settings.retrySettings.retryHandlers
    -- The exception type caught by 'try' is fixed by 'fromResponseError'.
    r <- try $ recovering how what $ \i -> do
        hr <- if rsIterNumber i == 0
                 then fn a s
                 else fn (newRequest s) (adjust s)
        maybe (return hr) throwM (toResponseError hr)
    return $ either fromResponseError id r
  where
    -- Add the configured send/receive timeout changes for retry attempts.
    adjust s =
        let Ms x = s^.context.settings.retrySettings.sendTimeoutChange
            Ms y = s^.context.settings.retrySettings.recvTimeoutChange
        in over (context.settings.connSettings.sendTimeout)     (Ms . (+ x) . ms)
         . over (context.settings.connSettings.responseTimeout) (Ms . (+ y) . ms)
         $ s

    -- Swap in the reduced consistency (if configured) for query, execute and
    -- batch requests; other requests are resent unchanged.
    newRequest s =
        case s^.context.settings.retrySettings.reducedConsistency of
            Nothing -> a
            Just  c ->
                case a of
                    RqQuery   (Query   q p) -> RqQuery (Query q p { consistency = c })
                    RqExecute (Execute q p) -> RqExecute (Execute q p { consistency = c })
                    RqBatch b               -> RqBatch b { batchConsistency = c }
                    _                       -> a
-- | Install @c@ as the new control connection. Over @c@ it:
--
-- * queries the local and peer host rows;
-- * builds pools for all policy-acceptable hosts, split into reachable and
--   unreachable ones by 'mkHostMap';
-- * replaces the host map and informs the policy via 'setup';
-- * registers @c@ for all event types;
-- * starts monitor jobs for the unreachable hosts;
-- * stores @c@, with its host updated from the local row via 'updateHost',
--   as the 'Connected' control connection.
setupControl :: Connection -> Client ()
setupControl c = do
    env <- ask
    pol <- view policy
    ctx <- view context
    l <- updateHost (c^.host) . listToMaybe <$> C.query c One Disco.local ()
    r <- discoverPeers ctx c
    (up, down) <- mkHostMap ctx pol (l:r)
    m <- view hostmap
    let h = Map.union up down
    -- NOTE(review): this overwrites the previous host map; pools that were
    -- in it are not destroyed here.
    atomically' $ writeTVar m h
    liftIO $ setup pol (Map.keys up) (Map.keys down)
    C.register c C.allEventTypes (runClient env . onCqlEvent)
    logInfo' $ "Known hosts: " <> string8 (show (Map.keys h))
    j <- view jobs
    for_ (Map.keys down) $ \d ->
        runJob j (d^.hostAddr) $
            runClient env $ monitor (Ms 1000) (Ms 60000) d
    ctl <- view control
    let c' = set C.host l c
    atomically' $ writeTVar ctl (Control Connected c')
    logInfo' $ "New control connection: " <> string8 (show c')
-- | Create pools for all hosts the policy accepts. Returns two maps: hosts
-- that 'C.canConnect' could reach and hosts it could not. Rejected hosts
-- appear in neither map.
mkHostMap :: Context -> Policy -> [Host] -> Client (Map Host Pool, Map Host Pool)
mkHostMap c p = liftIO . foldrM classify (Map.empty, Map.empty)
  where
    classify h acc@(up, down) = do
        okay <- acceptable p h
        if okay
            then do
                reachable <- C.canConnect h
                pool <- mkPool c h
                return $ if reachable
                    then (Map.insert h pool up, down)
                    else (up, Map.insert h pool down)
            else return acc
-- | Check that the control connection still works by sending an OPTIONS
-- request. Yields 'True' only for an 'RsSupported' answer. Failures are
-- mapped to 'False' via 'recover'.
checkControl :: Client Bool
checkControl = probe `recover` False
  where
    probe = do
        ctl <- readTVarIO' =<< view control
        rs <- liftIO $ C.requestRaw (ctl^.connection) (RqOptions Options)
        return $ case rs of
            RsSupported {} -> True
            _              -> False
-- | Start replacing the control connection in a background thread, unless a
-- replacement is already in progress.
--
-- The state is switched to 'Reconnecting' in a single STM transaction; only
-- the caller that performs the switch forks the reconnect thread. That
-- thread:
--
-- * closes the old connection, ignoring errors;
-- * retries indefinitely with capped exponential backoff, first via the known
--   hosts ('renewControl'), then via the initial contacts ('rebootControl');
-- * if aborted by any exception, logs it and sets the state to
--   'Disconnected'.
replaceControl :: Client ()
replaceControl = do
    e <- ask
    let l = e^.context.settings.logger
    -- Async exceptions are masked between claiming the reconnect and forking,
    -- so the state is not left at 'Reconnecting' without a reconnect thread.
    -- The forked thread inherits the mask and unmasks its work via 'restore'.
    liftIO $ mask $ \restore -> do
        cc <- setReconnecting e
        for_ cc $ \c -> forkIO $
            restore $ do
                ignore (C.close c)
                reconnect e l
            `catchAll` \ex -> do
                logError l $ "Control connection reconnect aborted: " <> string8 (show ex)
                atomically $ modifyTVar' (e^.control) (set state Disconnected)
  where
    -- Claim the reconnect: returns the old connection if we switched the
    -- state to 'Reconnecting', 'Nothing' if it already was.
    setReconnecting e = atomically $ do
        ctrl <- readTVar (e^.control)
        if ctrl^.state /= Reconnecting
            then do
                writeTVar (e^.control) (set state Reconnecting ctrl)
                return $ Just (ctrl^.connection)
            else
                return Nothing

    -- One attempt: try every known host, else fall back to the initial
    -- contacts. Asynchronous exceptions are re-thrown, not treated as
    -- "unreachable".
    reconnect e l = recovering adInf (onExc l) $ \_ -> do
        hosts <- NE.nonEmpty . Map.keys <$> readTVarIO (e^.hostmap)
        case hosts of
            Just hs -> hs `tryAll` (runClient e . renewControl)
                `catch` \x -> case fromException x of
                    Just (SomeAsyncException _) -> throwM x
                    Nothing -> do
                        logError l "All known hosts unreachable."
                        runClient e rebootControl
            Nothing -> do
                logError l "No known hosts."
                runClient e rebootControl

    -- Exponential backoff starting at 5 ms, each delay capped at 5 s
    -- (Control.Retry delays are in microseconds); no retry limit.
    adInf = capDelay 5000000 (exponentialBackoff 5000)

    -- Stop retrying on asynchronous exceptions; log and retry on all others.
    onExc l =
        [ const $ Handler $ \(_ :: SomeAsyncException) -> return False
        , const $ Handler $ \(e :: SomeException) -> do
            logError l $ "Replacement of control connection failed with: "
                <> string8 (show e)
                <> ". Retrying ..."
            return True
        ]
-- | Open a new connection to the given known host and install it as the
-- control connection. The connection is closed if setup fails.
renewControl :: Host -> Client ()
renewControl h = do
    ctx <- view context
    logInfo' "Renewing control connection with known host ..."
    let cfg  = ctx^.settings
        open = C.connect (cfg^.connSettings) (ctx^.timeouts) (cfg^.protoVersion) (cfg^.logger) h
    bracketOnError open (liftIO . C.close) setupControl
-- | Connect to one of the initial contacts (see 'mkContact') and install it
-- as the control connection. The connection is closed if setup fails.
rebootControl :: Client ()
rebootControl = do
    ctx <- view context
    logInfo' "Renewing control connection with initial contacts ..."
    bracketOnError (liftIO (mkContact ctx)) (liftIO . C.close) setupControl
-- | React to a server-pushed CQL event.
--
-- * Status down: report 'HostDown' to the policy.
-- * Node removed: drop the host from the host map and report 'HostGone'.
-- * Status up: if the host is known, start a monitor job that re-prepares
--   all queries once the host is reachable.
-- * New node: look the host up among the peers. If the policy accepts it,
--   register a pool for it, report 'HostNew' and re-prepare all queries on
--   it.
-- * Schema changes are ignored.
onCqlEvent :: Event -> Client ()
onCqlEvent x = do
    logInfo' $ "Event: " <> string8 (show x)
    pol <- view policy
    prt <- view (context.settings.portnumber)
    case x of
        StatusEvent Down (sock2inet prt -> a) ->
            liftIO $ onEvent pol (HostDown a)
        TopologyEvent RemovedNode (sock2inet prt -> a) -> do
            hmap <- view hostmap
            atomically' $
                modifyTVar' hmap (Map.filterWithKey (\h _ -> h^.hostAddr /= a))
            liftIO $ onEvent pol (HostGone a)
        StatusEvent Up (sock2inet prt -> a) -> do
            s <- ask
            startMonitor s a
        TopologyEvent NewNode (sock2inet prt -> a) -> do
            s <- ask
            let ctx  = s^.context
            let hmap = s^.hostmap
            ctrl <- readTVarIO' (s^.control)
            let c = ctrl^.connection
            peers <- liftIO $ discoverPeers ctx c `recover` []
            let h = fromMaybe (Host a "" "") $ find ((a == ) . view hostAddr) peers
            okay <- liftIO $ acceptable pol h
            when okay $ do
                p <- mkPool ctx h
                inserted <- atomically' $ do
                    hm <- readTVar hmap
                    if Map.member h hm
                        then return False
                        else do
                            writeTVar hmap (Map.insert h p hm)
                            return True
                -- The host may already be known (e.g. from control setup);
                -- keep the existing pool and release ours instead of leaking it.
                unless inserted $
                    liftIO (Pool.destroy p)
                liftIO $ onEvent pol (HostNew h)
                -- Only for accepted hosts: preparing on a rejected host would
                -- make 'tryRequest1' create and register a pool for it.
                tryRunJob_ (s^.jobs) a $ runClient s (prepareAllQueries h)
        SchemaEvent _ -> return ()
  where
    startMonitor s a = do
        hmp <- readTVarIO' (s^.hostmap)
        case find ((a ==) . view hostAddr) (Map.keys hmp) of
            Just h -> tryRunJob_ (s^.jobs) a $ runClient s $ do
                monitor (Ms 3000) (Ms 60000) h
                prepareAllQueries h
            Nothing -> return ()
-- | Extract the 'Result' from a response. Throws 'ResponseError' for server
-- errors and 'UnexpectedResponse' for any other kind of response.
getResult :: MonadThrow m => HostResponse k a b -> m (Result k a b)
getResult hr@(HostResponse h rsp) = case rsp of
    RsResult _ _ r -> return r
    RsError t w e  -> throwM (ResponseError h t w e)
    _              -> unexpected hr
{-# INLINE getResult #-}
-- | Extract the preparing host and the query id from a PREPARE response.
-- Throws 'UnexpectedResponse' if the result is not 'PreparedResult'.
getPreparedQueryId :: MonadThrow m => HostResponse k a b -> m (Host, QueryId k a b)
getPreparedQueryId hr = do
    res <- getResult hr
    case res of
        PreparedResult qid _ _ -> return (hrHost hr, qid)
        _                      -> unexpected hr
{-# INLINE getPreparedQueryId #-}
-- | Throw an 'UnexpectedResponse' carrying the response and its host.
unexpected :: MonadThrow m => HostResponse k a b -> m c
unexpected (HostResponse src rsp) = throwM (UnexpectedResponse src rsp)
-- | Run an STM transaction from 'Client'.
atomically' :: STM a -> Client a
atomically' tx = liftIO (atomically tx)
-- | Read a 'TVar' from 'Client' outside of a transaction.
readTVarIO' :: TVar a -> Client a
readTVarIO' var = liftIO (readTVarIO var)
-- | Log a message at info level with the client's configured logger.
logInfo' :: Builder -> Client ()
logInfo' msg = view (context.settings.logger) >>= \lgr -> liftIO (logInfo lgr msg)
{-# INLINE logInfo' #-}
-- | Log a message at warning level with the client's configured logger.
logWarn' :: Builder -> Client ()
logWarn' msg = view (context.settings.logger) >>= \lgr -> liftIO (logWarn lgr msg)
{-# INLINE logWarn' #-}
-- | Log a message at error level with the client's configured logger.
logError' :: Builder -> Client ()
logError' msg = view (context.settings.logger) >>= \lgr -> liftIO (logError lgr msg)
{-# INLINE logError' #-}