-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}

module Database.CQL.IO.Connection
    ( Connection
    , ConnId
    , ident
    , host

    -- * Lifecycle
    , connect
    , canConnect
    , close

    -- * Requests
    , request
    , Raw
    , requestRaw

    -- ** Queries
    , query
    , defQueryParams

    -- ** Events
    , EventHandler
    , allEventTypes
    , register

    -- * Re-exports
    , Socket.resolve
    ) where

import Control.Concurrent (myThreadId, forkIOWithUnmask)
import Control.Concurrent.Async
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Exception (throwTo)
import Control.Lens ((^.), makeLenses, view, set)
import Control.Monad
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.ByteString.Builder
import Data.Foldable (for_)
import Data.Semigroup ((<>))
import Data.Text.Lazy (fromStrict)
import Data.Unique
import Data.Vector (Vector, (!))
import Database.CQL.Protocol
import Database.CQL.IO.Cluster.Host
import Database.CQL.IO.Connection.Socket (Socket)
import Database.CQL.IO.Connection.Settings
import Database.CQL.IO.Exception
import Database.CQL.IO.Log
import Database.CQL.IO.Protocol
import Database.CQL.IO.Signal (Signal, signal, (|->), emit)
import Database.CQL.IO.Sync (Sync)
import Database.CQL.IO.Timeouts (TimeoutManager, withTimeout)

import qualified Data.HashMap.Strict                as HashMap
import qualified Data.Vector                        as Vector
import qualified Database.CQL.IO.Connection.Socket  as Socket
import qualified Database.CQL.IO.Sync               as Sync
import qualified Database.CQL.IO.Tickets            as Tickets

-- | The streams of a connection are a vector of slots, each
-- containing the last received CQL protocol frame on that stream.
-- Slots are looked up by position: both the sender and the read loop
-- index this vector with the numeric stream ID of a frame.
type Streams = Vector (Sync Frame)

-- | A connection to a 'Host' in a Cassandra cluster.
data Connection = Connection
    { _settings :: !ConnectionSettings -- ^ Settings this connection was created with.
    , _host     :: !Host               -- ^ The remote host.
    , _tmanager :: !TimeoutManager     -- ^ Used for send and response timeouts.
    , _protocol :: !Version            -- ^ Protocol version used to serialise requests.
    , _sock     :: !Socket             -- ^ The underlying socket.
    , _status   :: !(TVar Bool)        -- ^ 'True' until the read loop has cleaned up.
    , _streams  :: !Streams            -- ^ One response slot per stream ID.
    , _wLock    :: !(MVar ())          -- ^ Held while writing to (or closing) the socket.
    , _reader   :: !(Async ())         -- ^ The thread running the read loop.
    , _tickets  :: !Tickets.Pool       -- ^ Pool handing out stream IDs.
    , _logger   :: !Logger
    , _eventSig :: !(Signal Event)     -- ^ Signal on which server events are emitted.
    , _ident    :: !ConnId             -- ^ Unique identity; the basis of 'Eq'.
    }

makeLenses ''Connection

-- Two connections are equal iff they carry the same identifier.
instance Eq Connection where
    lhs == rhs = view ident lhs == view ident rhs

-- Rendered as the host, a @#@, and the socket.
instance Show Connection where
    show c = shows (c^.host) (showString "#" (shows (c^.sock) ""))

------------------------------------------------------------------------------
-- Lifecycle

-- | Open a new connection to a Cassandra host and initialise it.
--
-- A socket is opened and the read loop is started on it. Afterwards
-- the configured compression algorithm is checked against the options
-- the server reports as supported, the startup handshake is performed
-- and, if a default keyspace is configured, it is selected. If socket
-- set-up fails the socket is closed; if initialisation fails the
-- connection is closed. In both cases the exception propagates.
connect :: MonadIO m
        => ConnectionSettings
        -> TimeoutManager
        -> Version
        -> Logger
        -> Host
        -> m Connection
connect cfg tm ver lgr hst = liftIO $ do
    conn <- bracketOnError openSocket Socket.close newConnection
    prepare conn
    return conn
  where
    openSocket =
        Socket.open (cfg^.connectTimeout) (hst^.hostAddr) (cfg^.tlsContext)

    -- Allocate all per-connection state and start the read loop.
    newConnection s = do
        tck <- Tickets.pool (cfg^.maxStreams)
        syn <- Vector.replicateM (cfg^.maxStreams) Sync.create
        lck <- newMVar ()
        sta <- newTVarIO True
        sig <- signal
        rdr <- async (readLoop ver lgr cfg tck hst s syn sig sta lck)
        uid <- newUnique
        return (Connection cfg hst tm ver s sta syn lck rdr tck lgr sig (ConnId uid))

    -- Any failure during initialisation closes the connection.
    prepare conn = flip onException (close conn) $ do
        checkSettings conn
        startup conn
        for_ (cfg^.defKeyspace) (useKeyspace conn)

    -- Reject a compression algorithm the server does not announce.
    checkSettings conn = do
        Supported offered _ <- queryOptions conn
        let wanted = algorithm (conn^.settings.compression)
        when (wanted /= None && wanted `notElem` offered) $
            throwM (UnsupportedCompression wanted offered)

    -- The OPTIONS request is always sent without compression.
    queryOptions conn = do
        let plain = set (settings.compression) noCompression conn
        rs <- requestRaw plain (RqOptions Options)
        case rs of
            RsSupported _ _ opts -> return opts
            other                -> unhandled conn other

-- | Check the connectivity of a Cassandra host on a new connection.
canConnect :: MonadIO m => Host -> m Bool
canConnect hst = liftIO (recover probe False)
  where
    -- Open a plain (non-TLS) socket with a fixed 5000ms timeout and
    -- close it again straight away.
    probe    = bracket openSock Socket.close (\_ -> return True)
    openSock = Socket.open (Ms 5000) (hst^.hostAddr) Nothing

-- | Close a connection by cancelling its reader thread.
--
-- Note: The socket itself is closed when the 'readLoop' exits.
close :: Connection -> IO ()
close conn = cancel (conn^.reader)

------------------------------------------------------------------------------
-- Low-level operations

-- | A request or response whose type parameters are all unit.
type Raw a = a () () ()

-- | Send a request and wait for the response on the same stream.
--
-- Sending is subject to the configured send timeout, upon which the
-- connection is closed. Waiting for the response is subject to the
-- configured response timeout, upon which a 'ResponseTimeout' is
-- thrown to the calling thread.
request :: (Tuple a, Tuple b) => Connection -> Request k a b -> IO (Response k a b)
request c rq = do
    sid <- send
    receive sid
  where
    -- The response slot belonging to a stream ID.
    slot sid = view streams c ! sid

    send = withTimeout (c^.tmanager) (c^.settings.sendTimeout) (close c) $ do
        sid   <- Tickets.toInt <$> Tickets.get (c^.tickets)
        bytes <- serialise (c^.protocol) (c^.settings.compression) rq sid
        logRequest (c^.logger) bytes
        -- Writes are serialised through the write lock.
        withMVar (c^.wLock) $ \_ -> do
            open <- readTVarIO (c^.status)
            unless open $
                throwM (ConnectionClosed (c^.host.hostAddr))
            Socket.send (c^.sock) bytes
        return sid

    receive sid = do
        let timedOut  = ResponseTimeout (c^.host.hostAddr)
            -- If waiting is interrupted, the slot is killed.
            waitFrame = Sync.get (slot sid)
                `onException` Sync.kill timedOut (slot sid)
        me    <- myThreadId
        frame <- withTimeout (c^.tmanager) (c^.settings.responseTimeout) (throwTo me timedOut) $ do
            f <- waitFrame
            Tickets.markAvailable (c^.tickets) sid
            return f
        parse (c^.settings.compression) frame

-- | 'request' specialised to 'Raw' requests and responses.
requestRaw :: Connection -> Raw Request -> IO (Raw Response)
requestRaw conn rq = request conn rq

-----------------------------------------------------------------------------
-- High-level operations

-- | Perform the startup handshake, authenticating if the server asks for it.
startup :: MonadIO m => Connection -> m ()
startup c = liftIO $ do
    rs <- requestRaw c (RqStartup (Startup Cqlv300 algo))
    case rs of
        RsReady _ _ Ready       -> checkAuth c
        RsAuthenticate _ _ auth -> authenticate c auth
        other                   -> unhandled c other
  where
    algo = algorithm (c^.settings.compression)

-- | Log a warning if authenticators are configured but the server
-- did not request authentication.
checkAuth :: Connection -> IO ()
checkAuth c
    | null (c^.settings.authenticators) = return ()
    | otherwise =
        logWarn' (c^.logger) (c^.host)
            "Authentication configured but none required by the server."
-- | Run the authentication exchange for the mechanism requested by
-- the server, using the matching configured authenticator.
--
-- Throws 'AuthenticationRequired' if no authenticator is configured
-- for the mechanism, and 'UnexpectedAuthenticationChallenge' if the
-- server sends a challenge to an authenticator without a challenge
-- handler.
authenticate :: Connection -> Authenticate -> IO ()
authenticate c (Authenticate (AuthMechanism -> mech)) =
    case HashMap.lookup mech (c^.settings.authenticators) of
        Nothing -> throwM (AuthenticationRequired mech)
        Just Authenticator
            { authOnRequest   = onRequest
            , authOnChallenge = onChallenge
            , authOnSuccess   = onSuccess
            } -> do
                (resp, st) <- onRequest context
                case onChallenge of
                    Just step -> exchange step onSuccess (resp, st)
                    Nothing   -> do
                        outcome <- authResponse c resp
                        case outcome of
                            Left challenge ->
                                throwM (UnexpectedAuthenticationChallenge mech challenge)
                            Right success ->
                                onSuccess st success
  where
    context = AuthContext (c^.ident) (c^.host.hostAddr)

    -- Keep answering challenges until the server reports success.
    exchange step done (resp, st) = do
        outcome <- authResponse c resp
        case outcome of
            Left challenge -> step st challenge >>= exchange step done
            Right success  -> done st success

-- | Send an authentication response and classify the server's answer
-- as either a further challenge or success.
authResponse :: Connection -> AuthResponse -> IO (Either AuthChallenge AuthSuccess)
authResponse c resp = do
    rs <- requestRaw c (RqAuthResp resp)
    case rs of
        RsAuthSuccess _ _ success -> return (Right success)
        RsAuthChallenge _ _ chall -> return (Left chall)
        other                     -> unhandled c other

-- | Issue a @use@ statement for the given keyspace on this connection.
useKeyspace :: MonadIO m => Connection -> Keyspace -> m ()
useKeyspace c ks = liftIO $ do
    rs <- requestRaw c useStmt
    case rs of
        RsResult _ _ (SetKeyspaceResult _) -> return ()
        other                              -> unhandled c other
  where
    name    = quoted (fromStrict (unKeyspace ks))
    useStmt = RqQuery (Query (QueryString ("use " <> name)) (defQueryParams One ()))

------------------------------------------------------------------------------
-- Queries

-- | Run a query with default parameters and return the resulting rows.
-- Any response other than a rows result is treated as unhandled.
query :: (Tuple a, Tuple b, MonadIO m)
      => Connection
      -> Consistency
      -> QueryString k a b
      -> a
      -> m [b]
query c cons q p = liftIO $ do
    rs <- request c (RqQuery (Query q (defQueryParams cons p)))
    case rs of
        RsResult _ _ (RowsResult _ rows) -> return rows
        other                            -> unhandled c other

-- | Construct default 'QueryParams' for the given consistency
-- and bound values. In particular, no page size, paging state
-- or serial consistency will be set.
defQueryParams :: Consistency -> a -> QueryParams a
defQueryParams cons vals = QueryParams
    { values            = vals
    , consistency       = cons
    , serialConsistency = Nothing
    , pageSize          = Nothing
    , queryPagingState  = Nothing
    , skipMetaData      = False
    , enableTracing     = Nothing
    }

------------------------------------------------------------------------------
-- Events

-- | A callback invoked with an 'Event'.
type EventHandler = Event -> IO ()

-- | The event types listed here: topology, status and schema changes.
allEventTypes :: [EventType]
allEventTypes =
    [ TopologyChangeEvent
    , StatusChangeEvent
    , SchemaChangeEvent
    ]

-- | Register for the given event types with the server and, once the
-- server answers ready, attach the handler to the connection's event
-- signal.
register :: MonadIO m => Connection -> [EventType] -> EventHandler -> m ()
register c ev f = liftIO $ do
    rs <- requestRaw c (RqRegister (Register ev))
    case rs of
        RsReady _ _ Ready -> c^.eventSig |-> f
        other             -> unhandled c other

------------------------------------------------------------------------------
-- Read loop

-- | Continuously read frames from the socket and dispatch them.
--
-- Frames with stream ID @-1@ are parsed and must be events, which are
-- emitted on the given signal. All other frames are put into the slot
-- of their stream.
--
-- Note: The read loop owns the given socket and is responsible for
-- closing it once the loop ends, including when it is interrupted.
readLoop :: Version
         -> Logger
         -> ConnectionSettings
         -> Tickets.Pool
         -> Host
         -> Socket
         -> Streams
         -> Signal Event
         -> TVar Bool
         -> MVar ()
         -> IO ()
readLoop v g cset tck h sck syn sig sref wlck =
    finally (catch run logException) cleanup
  where
    run = forever $ do
        frame@(Frame hdr _) <- readFrame v g h sck (cset^.maxRecvBuffer)
        let sid = fromStreamId (streamId hdr)
        if sid == (-1)
            then onEvent frame
            else deliver sid frame

    -- Only events are acceptable on stream -1.
    onEvent frame = do
        rs <- parse (cset^.compression) frame :: IO (Raw Response)
        case rs of
            RsEvent _ _ ev -> emit sig ev
            _              -> throwM (UnexpectedResponse h rs)

    -- If 'Sync.put' reports 'False', the stream ID is handed back to
    -- the pool here.
    deliver sid frame = do
        accepted <- Sync.put frame (syn ! sid)
        unless accepted $
            Tickets.markAvailable tck sid

    -- Runs at most once effectively: only the call that flips the
    -- status from open to closed tears things down.
    cleanup = uninterruptibleMask_ $ do
        wasOpen <- atomically (swapTVar sref False)
        when wasOpen $ do
            let closed = ConnectionClosed (h^.hostAddr)
            Tickets.close closed tck
            Vector.mapM_ (Sync.close closed) syn
            -- Attempt a graceful shutdown of the socket on a separate
            -- thread with interruptions (i.e. all exceptions) allowed
            -- again, while still guaranteeing that the socket is
            -- eventually closed.
            _ <- forkIOWithUnmask $ \restore ->
                restore closeGracefully `onException` Socket.close sck
            return ()

    -- The socket is closed under the write lock.
    closeGracefully = do
        Socket.shutdown sck Socket.ShutdownReceive
        withMVar wlck (\_ -> Socket.close sck)

    -- Cancellation of the reader is expected and not logged.
    logException e
        | Just AsyncCancelled <- fromException e = return ()
        | otherwise = logWarn' g h ("read-loop: " <> string8 (show e))

-- | Read a single response frame: a 9-byte header followed by a body
-- of the length announced in that header.
--
-- Throws a 'ParseError' if the header cannot be decoded or is a
-- request header.
readFrame :: Version -> Logger -> Host -> Socket -> Int -> IO Frame
readFrame v g h s n = do
    raw <- Socket.recv n (h^.hostAddr) s 9
    hdr <- either badHeader return (header v raw)
    case headerType hdr of
        RqHeader -> throwM (ParseError "unexpected header")
        RsHeader -> do
            let size = lengthRepr (bodyLength hdr)
            body <- Socket.recv n (h^.hostAddr) s (fromIntegral size)
            logResponse g (raw <> body)
            return (Frame hdr body)
  where
    badHeader e = throwM (ParseError ("response header reading: " ++ e))

-- | Fail on a response the caller has no case for: server errors are
-- thrown as 'ResponseError', anything else as 'UnexpectedResponse'.
unhandled :: Connection -> Response k a b -> IO c
unhandled c (RsError t w e) = throwM (ResponseError (c^.host) t w e)
unhandled c other           = unexpected c other

-- | Throw an 'UnexpectedResponse' for the connection's host.
unexpected :: Connection -> Response k a b -> IO c
unexpected c r = throwM (UnexpectedResponse (c^.host) r)

-- | Log a warning prefixed with the shown host and a colon.
logWarn' :: Logger -> Host -> Builder -> IO ()
logWarn' l h msg = logWarn l (mconcat [string8 (show h), string8 ": ", msg])