{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Database.Redis.ProtocolPipelining (
Connection,
connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush,
ConnectionLostException(..),
PortID(..)
) where
import Prelude
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (race)
import Control.Concurrent.MVar
import Control.Exception
import Control.Monad
import qualified Scanner
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.IORef
import Data.Typeable
import qualified Network.Socket as NS
import qualified Network.TLS as TLS
import System.IO
import System.IO.Error
import System.IO.Unsafe
import Database.Redis.Protocol
-- | Where 'connect' should find the server: a port number on the host being
-- connected to, or the filesystem path of a Unix-domain socket.
data PortID
  = PortNumber NS.PortNumber
  | UnixSocket String
  deriving (Eq, Show)
-- | The transport underneath a 'Connection': a plain 'Handle', or a TLS
-- context layered over one (see 'enableTLS').
data ConnectionContext
  = NormalHandle Handle
  | TLSContext TLS.Context
-- | A pipelined connection to a Redis server.
data Connection = Conn
  { connCtx        :: ConnectionContext
    -- ^ underlying transport (plain handle or TLS context)
  , connReplies    :: IORef [Reply]
    -- ^ lazy reply stream; 'recv' pops replies off the front
  , connPending    :: IORef [Reply]
    -- ^ lazy reply stream whose head is dropped each time a reply is parsed
  , connPendingCnt :: IORef Int
    -- ^ count of replies sent for but not yet parsed (never below zero)
  }
-- | Thrown when the connection is unusable: an 'IOError' occurred while
-- sending or receiving, or incoming data failed to parse as a reply.
data ConnectionLostException
  = ConnectionLost
  deriving (Show, Typeable)

instance Exception ConnectionLostException
-- | Progress markers that 'connect' records while setting up a socket, so
-- that a 'ConnectTimeout' can report how far it got.
data ConnectPhase
  = PhaseUnknown     -- ^ initial value, before any phase has been recorded
  | PhaseResolve     -- ^ host-name resolution phase
  | PhaseOpenSocket  -- ^ socket open/connect phase
  deriving (Show)

-- | Thrown by 'connect' when its timeout expires before a connection is
-- established; carries the last phase that was recorded.
data ConnectTimeout
  = ConnectTimeout ConnectPhase
  deriving (Show, Typeable)

instance Exception ConnectTimeout
-- | Resolve a host name and port number to the candidate addresses for a
-- stream-socket connection. The port is passed to the resolver as its
-- decimal string form.
getHostAddrInfo :: NS.HostName -> NS.PortNumber -> IO [NS.AddrInfo]
getHostAddrInfo hostname port =
  let streamHints = NS.defaultHints { NS.addrSocketType = NS.Stream }
      service     = show port
  in NS.getAddrInfo (Just streamHints) (Just hostname) (Just service)
-- | Try each address in order and return the first socket that connects.
--
-- A socket whose connect attempt fails is closed before the next address is
-- tried. If every address fails, the 'IOError' from the last attempt is
-- rethrown. An empty address list is treated as a programming error.
connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = error "connectSocket: unexpected empty list"
connectSocket (addr:rest) = tryConnect >>= \case
  Right sock -> return sock
  Left err
    | null rest -> throwIO err
    | otherwise -> connectSocket rest
  where
    -- Open a socket for this address and try to connect it; a connect
    -- failure is returned as 'Left' rather than thrown.
    tryConnect :: IO (Either IOError NS.Socket)
    tryConnect = bracketOnError createSock NS.close $ \sock ->
      try (NS.connect sock (NS.addrAddress addr)) >>= \case
        Right () -> return (Right sock)
        -- The failure was caught, so 'bracketOnError' sees a normal return
        -- and will not release the socket; close it here to avoid a leak.
        Left err -> NS.close sock >> return (Left err)

    createSock = NS.socket (NS.addrFamily addr)
                           (NS.addrSocketType addr)
                           (NS.addrProtocol addr)
-- | Open a plain (non-TLS) connection to a Redis server.
--
-- For 'PortNumber' the host name is resolved and the returned addresses are
-- tried in turn. For 'UnixSocket' the given path is connected to directly.
-- The socket gets the 'NS.KeepAlive' option and is wrapped in a binary-mode
-- 'Handle'. The reply queues start out empty.
--
-- When @timeoutOpt@ is @Just micros@, socket setup is raced against a delay
-- of @micros@ microseconds. If the delay wins, a 'ConnectTimeout' is thrown
-- carrying the phase that was last recorded.
connect :: NS.HostName -> PortID -> Maybe Int -> IO Connection
connect hostName portId timeoutOpt =
  bracketOnError hConnect hClose $ \h -> do
    hSetBinaryMode h True
    connReplies <- newIORef []
    connPending <- newIORef []
    connPendingCnt <- newIORef 0
    let connCtx = NormalHandle h
    return Conn{..}
  where
    hConnect = do
      phaseMVar <- newMVar PhaseUnknown
      let doConnect = hConnect' phaseMVar
      case timeoutOpt of
        Nothing -> doConnect
        Just micros -> do
          result <- race doConnect (threadDelay micros)
          case result of
            Left h -> return h
            Right () -> do
              -- The delay won: report the phase the attempt last recorded.
              phase <- readMVar phaseMVar
              errConnectTimeout phase

    hConnect' mvar = bracketOnError (createSock mvar) NS.close $ \sock -> do
      NS.setSocketOption sock NS.KeepAlive 1
      NS.socketToHandle sock ReadWriteMode

    -- Each phase is recorded *before* it starts. A timeout that fires while
    -- resolving or connecting then reports that phase. (Previously both
    -- phases were set only after the socket was connected, so such timeouts
    -- always reported 'PhaseUnknown'.)
    createSock mvar = case portId of
      PortNumber portNumber -> do
        setPhase mvar PhaseResolve
        addrInfo <- getHostAddrInfo hostName portNumber
        setPhase mvar PhaseOpenSocket
        connectSocket addrInfo
      UnixSocket addr -> do
        setPhase mvar PhaseOpenSocket
        bracketOnError
          (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
          NS.close
          (\sock -> NS.connect sock (NS.SockAddrUnix addr) >> return sock)

    setPhase mvar = void . swapMVar mvar
-- | Upgrade a plain connection to TLS by creating a client TLS context over
-- its existing handle and running the handshake. A connection that already
-- uses TLS is returned unchanged.
enableTLS :: TLS.ClientParams -> Connection -> IO Connection
enableTLS tlsParams conn = case connCtx conn of
  TLSContext _ -> return conn
  NormalHandle h -> do
    tlsCtx <- TLS.contextNew h tlsParams
    TLS.handshake tlsCtx
    return conn { connCtx = TLSContext tlsCtx }
-- | Install the connection's lazy reply stream.
--
-- The same list from 'connGetReplies' is stored in both 'connReplies' (read
-- by 'recv') and 'connPending' (inspected by 'send').
beginReceiving :: Connection -> IO ()
beginReceiving conn@Conn{..} = do
  replyStream <- connGetReplies conn
  forM_ [connReplies, connPending] $ \ref -> writeIORef ref replyStream
-- | Close the connection's transport.
--
-- A plain handle is closed only if it is still open. For TLS, 'TLS.bye' is
-- called and the context is then closed with 'TLS.contextClose'. The context
-- is closed even if 'TLS.bye' throws; in that case the exception is re-raised
-- afterwards.
disconnect :: Connection -> IO ()
disconnect Conn{..} =
  case connCtx of
    NormalHandle h -> do
      open <- hIsOpen h
      when open $ hClose h
    TLSContext ctx ->
      -- 'finally' ensures a failing 'bye' cannot skip the close and leak
      -- the transport.
      TLS.bye ctx `finally` TLS.contextClose ctx
-- | Write an already-encoded request to the transport. This function does
-- not call 'flush' itself.
--
-- Each call adds one to 'connPendingCnt'. Once the new count reaches 1000
-- or more, the oldest entry of 'connPending' is forced. Forcing it makes the
-- reply actually be read from the connection, which limits how far sending
-- can run ahead of receiving.
--
-- An 'IOError' raised while writing is rethrown as 'ConnectionLost'.
send :: Connection -> S.ByteString -> IO ()
send Conn{..} s = do
  ioErrorToConnLost $ case connCtx of
    NormalHandle h -> S.hPut h s
    TLSContext ctx -> TLS.sendData ctx (L.fromStrict s)
  outstanding <- atomicModifyIORef' connPendingCnt (\c -> (c + 1, c + 1))
  when (outstanding >= 1000) $ do
    oldest:_ <- readIORef connPending
    oldest `seq` return ()
-- | Take the next reply from the front of 'connReplies'.
--
-- The stream is expected to be installed by 'beginReceiving'. On a freshly
-- created connection the list is empty and the pattern bind below fails.
recv :: Connection -> IO Reply
recv conn = do
  let replies = connReplies conn
  next:remaining <- readIORef replies
  writeIORef replies remaining
  pure next
-- | Flush buffered output on whichever transport the connection uses.
flush :: Connection -> IO ()
flush conn = flushCtx (connCtx conn)
  where
    flushCtx (NormalHandle h) = hFlush h
    flushCtx (TLSContext ctx) = TLS.contextFlush ctx
-- | Send one request, then take the next reply ('send' followed by 'recv').
request :: Connection -> S.ByteString -> IO Reply
request conn req = do
  send conn req
  recv conn
-- | Build the connection's lazy, unbounded stream of replies.
--
-- No reading happens up front. Each element is produced under
-- 'unsafeInterleaveIO', so a reply is read and parsed only when its value is
-- forced. Before parsing, each reply forces the previous one, so replies
-- are consumed from the input in order, whatever order they are demanded in.
--
-- Each successful parse drops the head of 'connPending' and decrements
-- 'connPendingCnt', clamped at zero.
--
-- Forcing a reply throws 'ConnectionLost' if the input fails to parse or if
-- an 'IOError' occurs while reading.
connGetReplies :: Connection -> IO [Reply]
connGetReplies conn@Conn{..} = go S.empty (SingleLine "previous of first")
  where
    -- @rest@ is input left over from the previous parse. @previous@ is the
    -- previous reply; the first call uses a placeholder.
    go rest previous = do
      -- The lazy (~) pattern defers the read until r or rest' is demanded.
      ~(r, rest') <- unsafeInterleaveIO $ do
        -- Force the previous reply first so replies are read in order.
        previous `seq` return ()
        scanResult <- Scanner.scanWith readMore reply rest
        case scanResult of
          Scanner.Fail{} -> errConnClosed
          Scanner.More{} -> error "Hedis: parseWith returned Partial"
          Scanner.Done rest' r -> do
            -- r has just been parsed, so drop it from the pending list.
            atomicModifyIORef' connPending $ \(_:rs) -> (rs, ())
            -- NOTE(review): clamped at zero, presumably because replies can
            -- arrive that 'send' never counted; confirm.
            atomicModifyIORef' connPendingCnt $ \n -> (max 0 (n-1), ())
            return (r, rest')
      rs <- unsafeInterleaveIO (go rest' r)
      return (r:rs)

    -- Flush pending writes, then fetch the next chunk of input. An
    -- 'IOError' here is rethrown as 'ConnectionLost'.
    readMore = ioErrorToConnLost $ do
      flush conn
      case connCtx of
        NormalHandle h -> S.hGetSome h 4096
        TLSContext ctx -> TLS.recvData ctx
-- | Run an action, rethrowing any 'IOError' it raises as 'ConnectionLost'.
-- Other exception types pass through untouched.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)
-- | An action that always throws 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO (ConnectionLost :: ConnectionLostException)
-- | An action that always throws a 'ConnectTimeout' for the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout