-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE ViewPatterns        #-}

-- | A single multiplexed connection to one node: stream-id allocation,
-- request/response matching, a background reader thread, and the basic
-- protocol operations (STARTUP, AUTH, REGISTER, OPTIONS, USE, QUERY).
module Database.CQL.IO.Connection
    ( Connection
    , ConnId
    , resolve
    , ping
    , connect
    , close
    , request
    , startup
    , register
    , query
    , useKeyspace
    , address
    , protocol
    , eventSig
    ) where

import Control.Applicative
import Control.Concurrent (myThreadId, forkIOWithUnmask)
import Control.Concurrent.Async
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Exception (throwTo, AsyncException (ThreadKilled))
import Control.Lens ((^.), makeLenses, view)
import Control.Monad
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.ByteString.Lazy (ByteString)
import Data.Int
import Data.Maybe (fromMaybe)
import Data.Monoid
import Data.Text.Lazy (fromStrict)
import Data.Unique
import Data.Vector (Vector, (!))
import Database.CQL.Protocol
import Database.CQL.IO.Connection.Socket (Socket)
import Database.CQL.IO.Connection.Settings
import Database.CQL.IO.Hexdump
import Database.CQL.IO.Protocol
import Database.CQL.IO.Signal hiding (connect)
import Database.CQL.IO.Sync (Sync)
import Database.CQL.IO.Types
import Database.CQL.IO.Tickets (Pool, toInt, markAvailable)
import Database.CQL.IO.Timeouts (TimeoutManager, withTimeout)
import Network.Socket hiding (Socket, close, connect, send)
import System.IO (nativeNewline, Newline (..))
import System.Logger hiding (Settings, close, defSettings, settings)
import System.Timeout
import Prelude

import qualified Data.ByteString.Lazy              as L
import qualified Data.ByteString.Lazy.Char8        as Char8
import qualified Data.HashMap.Strict               as HashMap
import qualified Data.Vector                       as Vector
import qualified Database.CQL.IO.Connection.Socket as Socket
import qualified Database.CQL.IO.Sync              as Sync
import qualified Database.CQL.IO.Tickets           as Tickets
import qualified Network.Socket                    as S

-- | One response slot per stream id, indexed by the stream id itself.
type Streams = Vector (Sync (Header, ByteString))

-- | A connection to a single node.
data Connection = Connection
    { _settings :: !ConnectionSettings
    , _address  :: !InetAddr
    , _tmanager :: !TimeoutManager
    , _protocol :: !Version
    , _sock     :: !Socket
    , _status   :: !(TVar Bool)
      -- ^ 'True' while open; flipped to 'False' exactly once by the reader's cleanup.
    , _streams  :: !Streams
      -- ^ @maxStreams@ response slots, filled by the reader thread.
    , _wLock    :: !(MVar ())
      -- ^ Serialises socket writes and the final socket close.
    , _reader   :: !(Async ())
      -- ^ The thread running 'readLoop'.
    , _tickets  :: !Pool
      -- ^ Pool of free stream ids.
    , _logger   :: !Logger
    , _eventSig :: !(Signal Event)
      -- ^ Server-pushed events (stream id -1) are emitted here.
    , _ident    :: !ConnId
    }

makeLenses ''Connection

-- | Connections are equal iff they have the same unique 'ConnId'.
instance Eq Connection where
    a == b = a^.ident == b^.ident

-- | Renders as @address#socket@.
instance Show Connection where
    show = Char8.unpack . eval . bytes

instance ToBytes Connection where
    bytes c = bytes (c^.address) +++ val "#" +++ c^.sock

-- | Resolve a host name and port to the stream-socket addresses returned by
-- 'getAddrInfo' (with 'AI_ADDRCONFIG').
resolve :: String -> PortNumber -> IO [InetAddr]
resolve host port =
    map (InetAddr . addrAddress) <$> getAddrInfo (Just hints) (Just host) (Just (show port))
  where
    hints = defaultHints { addrFlags = [AI_ADDRCONFIG], addrSocketType = Stream }

-- | Open a connection to the given address.
--
-- Allocates the stream-id pool and one response slot per stream, starts the
-- reader thread and then checks the configured compression against the
-- server's OPTIONS reply; the connection is closed if that check fails.
-- The STARTUP handshake is not performed here (see 'startup').
connect :: MonadIO m => ConnectionSettings -> TimeoutManager -> Version -> Logger -> InetAddr -> m Connection
connect t m v g a = liftIO $ do
    c <- bracketOnError (Socket.open (t^.connectTimeout) a (t^.tlsContext)) Socket.close $ \s -> do
        tck <- Tickets.pool (t^.maxStreams)
        syn <- Vector.replicateM (t^.maxStreams) Sync.create
        lck <- newMVar ()
        sta <- newTVarIO True
        sig <- signal
        rdr <- async (readLoop v g t tck a s syn sig sta lck)
        Connection t a m v s sta syn lck rdr tck g sig . ConnId <$> newUnique
    validateSettings c `onException` close c
    return c

-- | Check TCP reachability: 'True' iff a plain socket connect succeeds
-- within 5 seconds. Any exception counts as unreachable.
ping :: MonadIO m => InetAddr -> m Bool
ping a = liftIO $ bracket (Socket.mkSock a) S.close $ \s ->
    fromMaybe False <$> timeout 5000000
        ((S.connect s (sockAddr a) >> return True) `catchAll` const (return False))

-- | The reader thread: read frames forever and dispatch them by stream id.
--
-- Stream -1 carries server events, which are emitted on the event signal.
-- Every other stream id is delivered to its response slot. When the loop
-- ends for any reason, 'cleanup' marks the connection closed, fails all
-- outstanding tickets and slots with 'ConnectionClosed', and closes the socket.
readLoop :: Version
         -> Logger
         -> ConnectionSettings
         -> Pool
         -> InetAddr
         -> Socket
         -> Streams
         -> Signal Event
         -> TVar Bool
         -> MVar ()
         -> IO ()
readLoop v g set tck i sck syn s sref wlck =
    run `catch` logException `finally` cleanup
  where
    run = forever $ do
        x <- readSocket v g i sck (set^.maxRecvBuffer)
        case fromStreamId $ streamId (fst x) of
            -1 -> case parse (set^.compression) x :: Raw Response of
                RsError _ _ e -> throwM e
                RsEvent _ _ e -> emit s e
                r             -> throwM (UnexpectedResponse' r)
            -- Bounds-checked lookup: a stream id the server should never
            -- send is reported explicitly instead of crashing in the
            -- partial vector index operator.
            sid -> case syn Vector.!? sid of
                Nothing   -> throwM $ InternalError ("invalid response stream id: " ++ show sid)
                Just slot -> do
                    ok <- Sync.put x slot
                    -- When the slot does not accept the response, the
                    -- stream id is returned to the pool here.
                    unless ok $ markAvailable tck sid

    -- Runs at most once per connection (guarded by the status flag).
    cleanup = uninterruptibleMask_ $ do
        isOpen <- atomically $ swapTVar sref False
        when isOpen $ do
            Tickets.close (ConnectionClosed i) tck
            Vector.mapM_ (Sync.close (ConnectionClosed i)) syn
            -- Shut down and close in a separate thread so cleanup does not
            -- block on the write lock. A failing 'shutdown' (e.g. the peer
            -- already reset the connection) must not prevent the close,
            -- otherwise the socket would leak.
            void $ forkIOWithUnmask $ \unmask -> unmask $
                ignoreErrors (Socket.shutdown sck ShutdownReceive)
                    `finally` withMVar wlck (const $ Socket.close sck)

    ignoreErrors :: IO () -> IO ()
    ignoreErrors action = action `catchAll` const (return ())

    -- NOTE(review): 'close' stops this thread with 'cancel'; depending on
    -- the async version, that may arrive as 'AsyncCancelled' rather than
    -- 'ThreadKilled', in which case a normal close is logged as a warning.
    logException :: SomeException -> IO ()
    logException e = case fromException e of
        Just ThreadKilled -> return ()
        _                 -> warn g $ msg i ~~ msg (val "read-loop: " +++ show e)

-- | Close the connection by cancelling its reader thread, whose cleanup
-- releases all resources.
close :: Connection -> IO ()
close = cancel . view reader

-- | Send a request and wait for its response.
--
-- The argument builds the encoded request for the allocated stream id.
-- Sending is bounded by the send timeout (on expiry the connection is
-- closed); waiting is bounded by the response timeout (on expiry
-- 'TimeoutRead' is thrown to the calling thread). Throws 'ConnectionClosed'
-- if the connection is no longer open when the request is written.
request :: Connection -> (Int -> ByteString) -> IO (Header, ByteString)
request c f = send >>= receive
  where
    send = withTimeout (c^.tmanager) (c^.settings.sendTimeout) (close c) $ do
        i <- toInt <$> Tickets.get (c^.tickets)
        let req = f i
        trace (c^.logger) $ msg c
            ~~ "stream" .= i
            ~~ "type" .= val "request"
            ~~ msg' (hexdump (L.take 160 req))
        withMVar (c^.wLock) $ const $ do
            isOpen <- readTVarIO (c^.status)
            if isOpen
                then Socket.send (c^.sock) req
                else throwM $ ConnectionClosed (c^.address)
        return i

    -- On success the stream id is returned to the pool here. On failure
    -- the slot is killed instead; presumably the reader then frees the
    -- stream id when a late response is rejected (see 'readLoop').
    receive i = do
        let e = TimeoutRead (show c ++ ":" ++ show i)
        tid <- myThreadId
        withTimeout (c^.tmanager) (c^.settings.responseTimeout) (throwTo tid e) $ do
            x <- Sync.get (view streams c ! i) `onException` Sync.kill e (view streams c ! i)
            markAvailable (c^.tickets) i
            return x

-- | Read one response frame: a fixed 9-byte header followed by its body.
-- Request headers are rejected with 'InternalError'.
--
-- NOTE(review): 9 bytes is the header size of protocol v3 and later;
-- confirm no shorter-header protocol version can be negotiated.
readSocket :: Version -> Logger -> InetAddr -> Socket -> Int -> IO (Header, ByteString)
readSocket v g i s n = do
    b <- Socket.recv n i s 9
    h <- case header v b of
        Left  e -> throwM $ InternalError ("response header reading: " ++ e)
        Right h -> return h
    case headerType h of
        RqHeader -> throwM $ InternalError "unexpected request header"
        RsHeader -> do
            let len = lengthRepr (bodyLength h)
            x <- Socket.recv n i s (fromIntegral len)
            trace g $ msg (i +++ val "#" +++ s)
                ~~ "stream" .= fromStreamId (streamId h)
                ~~ "type" .= val "response"
                ~~ msg' (hexdump $ L.take 160 (b <> x))
            return (h, x)

-----------------------------------------------------------------------------
-- Operations

-- | Perform the STARTUP handshake, authenticating if the server asks for it.
-- Server errors are rethrown; any other reply is 'UnexpectedResponse''.
startup :: MonadIO m => Connection -> m ()
startup c = liftIO $ do
    let cmp = c^.settings.compression
    let req = RqStartup (Startup Cqlv300 (algorithm cmp))
    let enc = serialise (c^.protocol) cmp (req :: Raw Request)
    res <- request c enc
    case parse cmp res :: Raw Response of
        RsReady _ _ Ready       -> checkAuth c
        RsAuthenticate _ _ auth -> authenticate c auth
        RsError _ _ e           -> throwM e
        other                   -> throwM $ UnexpectedResponse' other

-- | Warn when authenticators are configured but the server did not request
-- authentication.
checkAuth :: Connection -> IO ()
checkAuth c = unless (null (c^.settings.authenticators)) $
    warn (_logger c) $ msg $ val "Authentication configured but none required by server."

-- | Run the configured authenticator for the server's mechanism.
--
-- Throws 'AuthenticationRequired' if no authenticator is configured for the
-- mechanism, and 'UnexpectedAuthenticationChallenge' if the server sends a
-- challenge to an authenticator without a challenge handler.
authenticate :: (MonadIO m, MonadThrow m) => Connection -> Authenticate -> m ()
authenticate c (Authenticate (AuthMechanism -> m)) =
    case HashMap.lookup m (c^.settings.authenticators) of
        Nothing -> throwM $ AuthenticationRequired m
        Just Authenticator { authOnRequest   = onR
                           , authOnChallenge = onC
                           , authOnSuccess   = onS
                           } -> liftIO $ do
            (rs, s) <- onR context
            case onC of
                Just f  -> loop f onS (rs, s)
                Nothing -> authResponse c rs >>= either
                    (throwM . UnexpectedAuthenticationChallenge m)
                    (onS s)
  where
    context = AuthContext (c^.ident) (c^.address)

    -- Answer challenges until the server reports success.
    loop onC onS (rs, s) =
        authResponse c rs >>= either (onC s >=> loop onC onS) (onS s)

-- | Send one AUTH_RESPONSE; 'Left' is a further challenge, 'Right' success.
authResponse :: MonadIO m => Connection -> AuthResponse -> m (Either AuthChallenge AuthSuccess)
authResponse c resp = liftIO $ do
    let cmp = c^.settings.compression
    let req = RqAuthResp resp
    let enc = serialise (c^.protocol) cmp (req :: Raw Request)
    res <- request c enc
    case parse cmp res :: Raw Response of
        RsAuthSuccess   _ _ success   -> return $ Right success
        RsAuthChallenge _ _ challenge -> return $ Left challenge
        RsError         _ _ e         -> throwM e
        other                         -> throwM $ UnexpectedResponse' other

-- | Register for the given event types and attach the handler to this
-- connection's event signal once the server acknowledges.
register :: MonadIO m => Connection -> [EventType] -> EventHandler -> m ()
register c e f = liftIO $ do
    let req = RqRegister (Register e) :: Raw Request
    let enc = serialise (c^.protocol) (c^.settings.compression) req
    res <- request c enc
    case parse (c^.settings.compression) res :: Raw Response of
        RsReady _ _ Ready -> c^.eventSig |-> f
        other             -> throwM (UnexpectedResponse' other)

-- | Throw 'UnsupportedCompression' if a compression algorithm is configured
-- that the server does not list as supported.
validateSettings :: MonadIO m => Connection -> m ()
validateSettings c = liftIO $ do
    Supported ca _ <- supportedOptions c
    let x = algorithm (c^.settings.compression)
    unless (x == None || x `elem` ca) $
        throwM $ UnsupportedCompression ca

-- | Ask the server for its supported options (sent and parsed uncompressed).
supportedOptions :: MonadIO m => Connection -> m Supported
supportedOptions c = liftIO $ do
    let options = RqOptions Options :: Raw Request
    res <- request c (serialise (c^.protocol) noCompression options)
    case parse noCompression res :: Raw Response of
        RsSupported _ _ x -> return x
        other             -> throwM (UnexpectedResponse' other)

-- | Switch this connection to the given keyspace via a quoted @USE@ query.
useKeyspace :: MonadIO m => Connection -> Keyspace -> m ()
useKeyspace c ks = liftIO $ do
    let cmp    = c^.settings.compression
        params = QueryParams One False () Nothing Nothing Nothing Nothing
        kspace = quoted (fromStrict $ unKeyspace ks)
        req    = RqQuery (Query (QueryString $ "use " <> kspace) params)
    res <- request c (serialise (c^.protocol) cmp req)
    case parse cmp res :: Raw Response of
        RsResult _ _ (SetKeyspaceResult _) -> return ()
        other                              -> throwM (UnexpectedResponse' other)

-- | Run a query with the given consistency and return its rows. Any
-- non-rows reply is 'UnexpectedResponse''.
query :: forall k a b m. (Tuple a, Tuple b, Show b, MonadIO m)
      => Connection
      -> Consistency
      -> QueryString k a b
      -> a
      -> m [b]
query c cons q p = liftIO $ do
    let req = RqQuery (Query q params) :: Request k a b
    let enc = serialise (c^.protocol) (c^.settings.compression) req
    res <- request c enc
    case parse (c^.settings.compression) res :: Response k a b of
        RsResult _ _ (RowsResult _ b) -> return b
        other                         -> throwM (UnexpectedResponse' other)
  where
    params = QueryParams cons False p Nothing Nothing Nothing Nothing

-- logging helpers:

-- | Log message starting on a new line (using the platform's newline).
msg' :: ByteString -> Msg -> Msg
msg' x = msg $ case nativeNewline of
    LF   -> val "\n" +++ x
    CRLF -> val "\r\n" +++ x