{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Transport layer for Redis connections.
--
-- A connection is either a plain socket 'Handle' or a TLS context created
-- over such a handle. This module provides connect, send, recv, flush and
-- disconnect for both.
module Database.Redis.ConnectionContext
  ( ConnectionContext(..)
  , ConnectTimeout(..)
  , ConnectionLostException(..)
  , PortID(..)
  , connect
  , disconnect
  , send
  , recv
  , errConnClosed
  , enableTLS
  , flush
  , ioErrorToConnLost
  ) where

import           Control.Concurrent (threadDelay)
import           Control.Concurrent.Async (race)
import           Control.Concurrent.MVar (MVar, newMVar, readMVar, swapMVar)
import           Control.Exception (Exception, bracketOnError, finally, throwIO, try)
import           Control.Monad (when)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import           Data.Functor (void)
import qualified Data.IORef as IOR
import           Data.Typeable
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import           System.IO
  (Handle, IOMode(..), hClose, hFlush, hIsOpen, hSetBinaryMode)
import           System.IO.Error (catchIOError)

-- | An open connection: either a plain 'Handle', or a TLS context that
-- 'enableTLS' created on top of such a handle.
data ConnectionContext = NormalHandle Handle | TLSContext TLS.Context

instance Show ConnectionContext where
  show (NormalHandle _) = "NormalHandle"
  show (TLSContext _) = "TLSContext"

-- NOTE(review): 'Connection' is not exported and nothing in this module uses
-- it; confirm it is still needed before relying on it.
data Connection = Connection
  { ctx :: ConnectionContext
  , lastRecvRef :: IOR.IORef (Maybe B.ByteString)
  }

instance Show Connection where
  show Connection{..} =
    "Connection{ ctx = " ++ show ctx ++ ", lastRecvRef = IORef}"

-- | How far a 'connect' attempt had progressed when it timed out.
data ConnectPhase
  = PhaseUnknown     -- ^ The attempt had not started resolving yet.
  | PhaseResolve     -- ^ Resolving the host name ('NS.getAddrInfo').
  | PhaseOpenSocket  -- ^ Opening and connecting the socket.
  deriving (Show)

-- | Thrown by 'connect' when its timeout expires before the connection is
-- established. It carries the phase the attempt had reached.
newtype ConnectTimeout = ConnectTimeout ConnectPhase
  deriving (Show, Typeable)

instance Exception ConnectTimeout

-- | Thrown (via 'errConnClosed') when an I/O error occurs on the connection.
-- 'ioErrorToConnLost' converts such errors into this exception.
data ConnectionLostException = ConnectionLost
  deriving Show

instance Exception ConnectionLostException

-- | Where to connect: a TCP port on the given host, or the path of a
-- Unix-domain socket.
data PortID
  = PortNumber NS.PortNumber
  | UnixSocket String
  deriving (Eq, Show)

-- | Open a connection and return it as a 'NormalHandle' in binary mode.
--
-- For 'PortNumber', the host name is resolved with 'NS.getAddrInfo' and each
-- returned address is tried in order. For 'UnixSocket', the host name is
-- ignored.
--
-- If a timeout is given (in microseconds, as for 'threadDelay') and it
-- expires first, a 'ConnectTimeout' is thrown. It records the phase the
-- attempt was in.
connect :: NS.HostName -> PortID -> Maybe Int -> IO ConnectionContext
connect hostName portId timeoutOpt =
  -- Close the new handle if switching it to binary mode fails.
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    return $ NormalHandle h
  where
    hConnect :: IO Handle
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> do
          -- 'race' cancels the losing action. Cancelling 'doConnect' triggers
          -- its 'bracketOnError' cleanup, so a half-open socket is closed.
          result <- race doConnect (threadDelay micros)
          case result of
            Left h -> return h
            Right () -> readMVar phaseMVar >>= errConnectTimeout

    hConnect' :: MVar ConnectPhase -> IO Handle
    hConnect' mvar = bracketOnError createSock NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        NS.socketToHandle sock ReadWriteMode
      where
        -- Record each phase *before* entering it, so that a timeout reports
        -- the step that was actually in progress.
        enterPhase :: ConnectPhase -> IO ()
        enterPhase = void . swapMVar mvar

        createSock :: IO NS.Socket
        createSock = case portId of
          PortNumber portNumber -> do
            enterPhase PhaseResolve
            addrInfo <- getHostAddrInfo hostName portNumber
            enterPhase PhaseOpenSocket
            connectSocket addrInfo
          UnixSocket addr -> do
            enterPhase PhaseOpenSocket
            bracketOnError
              (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
              NS.close
              (\sock -> NS.connect sock (NS.SockAddrUnix addr) >> return sock)

-- | Resolve a host name and port to stream-socket addresses.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  NS.getAddrInfo (Just hints) (Just hostname) (Just $ show port)
  where
    hints = NS.defaultHints { NS.addrSocketType = NS.Stream }

-- | Throw a 'ConnectTimeout' for the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout phase = throwIO $ ConnectTimeout phase

-- | Try each address in order and return the first socket that connects.
--
-- Each failed attempt closes its socket before the next address is tried.
-- If every address fails, the 'IOError' from the last attempt is rethrown.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] =
  -- 'NS.getAddrInfo' throws instead of returning an empty list, so this
  -- should be unreachable. Fail with an 'IOError', like the other connection
  -- failures, rather than with 'error'.
  ioError (userError "connectSocket: unexpected empty list")
connectSocket (addr:rest) = tryConnect >>= \case
  Right sock -> return sock
  Left err
    | null rest -> throwIO err
    | otherwise -> connectSocket rest
  where
    -- A synchronous connect failure is returned as 'Left' after closing the
    -- socket. Any other exception closes the socket via 'bracketOnError'.
    tryConnect :: IO (Either IOError NS.Socket)
    tryConnect = bracketOnError createSock NS.close $ \sock ->
      try (NS.connect sock (NS.addrAddress addr)) >>= \case
        Right () -> return (Right sock)
        Left err -> NS.close sock >> return (Left err)

    createSock :: IO NS.Socket
    createSock =
      NS.socket (NS.addrFamily addr) (NS.addrSocketType addr) (NS.addrProtocol addr)

-- | Write a request to the connection.
--
-- An 'IOError' during the write is rethrown as 'ConnectionLost'.
send :: ConnectionContext -> B.ByteString -> IO ()
send (NormalHandle h) requestData =
  ioErrorToConnLost (B.hPut h requestData)
send (TLSContext ctx) requestData =
  ioErrorToConnLost (TLS.sendData ctx (LB.fromStrict requestData))

-- | Read the next chunk of data from the connection.
--
-- For a plain handle, this reads at most 4096 bytes with 'B.hGetSome', which
-- returns an empty 'B.ByteString' at end of file. For both transports, an
-- 'IOError' is rethrown as 'ConnectionLost', matching 'send'.
recv :: ConnectionContext -> IO B.ByteString
recv (NormalHandle h) = ioErrorToConnLost $ B.hGetSome h 4096
recv (TLSContext ctx) = ioErrorToConnLost $ TLS.recvData ctx

-- | Run an action, replacing any 'IOError' it throws with 'ConnectionLost'.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost a = a `catchIOError` const errConnClosed

-- | Throw 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost

-- | Upgrade a plain connection to TLS.
--
-- Creates a TLS context over the connection's handle and performs the
-- handshake. A connection that is already 'TLSContext' is returned
-- unchanged.
enableTLS :: TLS.ClientParams -> ConnectionContext -> IO ConnectionContext
enableTLS tlsParams (NormalHandle h) = do
  ctx <- TLS.contextNew h tlsParams
  TLS.handshake ctx
  return $ TLSContext ctx
enableTLS _ c@(TLSContext _) = return c

-- | Close the connection.
--
-- A plain handle is closed only if it is still open. For TLS, 'TLS.bye' is
-- attempted first. The context is then closed even if 'TLS.bye' throws
-- (e.g. because the peer is already gone); that exception still propagates.
disconnect :: ConnectionContext -> IO ()
disconnect (NormalHandle h) = do
  open <- hIsOpen h
  when open $ hClose h
disconnect (TLSContext ctx) =
  TLS.bye ctx `finally` TLS.contextClose ctx

-- | Flush any buffered output on the connection.
flush :: ConnectionContext -> IO ()
flush (NormalHandle h) = hFlush h
flush (TLSContext c) = TLS.contextFlush c