-- | Pipelined connection handling: requests are written to a handle with
-- 'send', and the replies are exposed as a list that is read from the handle
-- on demand and consumed in order with 'recv'. 'request' combines the two.
module Database.Redis.ProtocolPipelining (
Connection,
connect, disconnect, request, send, recv, flush,
ConnectionLostException(..),
HostName, PortID(..)
) where
import Prelude
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (race)
import Control.Concurrent.MVar
import Control.Exception
import Control.Monad
import qualified Scanner
import qualified Data.ByteString as S
import Data.IORef
import Data.Typeable
import Network
import qualified Network.BSD as BSD
import qualified Network.Socket as NS
import System.IO
import System.IO.Error
import System.IO.Unsafe
import Database.Redis.Protocol
-- | State belonging to one open connection.
data Connection = Conn
    { connHandle :: Handle
      -- ^ Handle that requests are written to and reply bytes are read from.
    , connReplies :: IORef [Reply]
      -- ^ Replies that have not yet been handed to a caller; 'recv' removes
      --   the head of this list.
    , connPending :: IORef [Reply]
      -- ^ Replies that have not yet been parsed from the handle. The reply
      --   reader drops the head each time it finishes parsing a reply, and
      --   'send' forces the head once its counter reaches the pipeline limit.
    , connPendingCnt :: IORef Int
      -- ^ Counter that 'send' increments for every request written and checks
      --   against its pipeline limit; the reply reader also updates it,
      --   clamped at zero, when a reply has been parsed.
    }
-- | Exception thrown when the connection can no longer be used: it replaces
-- any 'IOError' raised while writing to or reading from the handle, and is
-- also thrown when the reply scanner reports a failure.
data ConnectionLostException = ConnectionLost
    deriving (Show, Typeable)

instance Exception ConnectionLostException
-- | Stage that a connection attempt had reached; reported inside
-- 'ConnectTimeout' when the attempt times out.
data ConnectPhase
    = PhaseUnknown
      -- ^ Initial value, before any stage has been recorded.
    | PhaseResolve
      -- ^ Recorded just before the host name is looked up.
    | PhaseOpenSocket
      -- ^ Recorded just before the socket is connected to the resolved address.
    deriving (Show)
-- | Exception thrown by 'connect' when the supplied timeout elapses before
-- the connection attempt completes. Carries the last recorded 'ConnectPhase'.
-- Note that this type is not part of the module's export list.
data ConnectTimeout = ConnectTimeout ConnectPhase
    deriving (Show, Typeable)

instance Exception ConnectTimeout
-- | Open a connection to the given host and port.
--
-- When a timeout is supplied (in microseconds, as it is handed straight to
-- 'threadDelay'), the connection attempt is raced against it; if the delay
-- finishes first, a 'ConnectTimeout' carrying the last recorded phase is
-- thrown. If anything fails after the handle has been obtained, the handle is
-- closed before the exception propagates.
connect :: HostName -> PortID -> Maybe Int -> IO Connection
connect hostName portID timeoutOpt =
    bracketOnError acquireHandle hClose $ \h -> do
        hSetBinaryMode h True
        repliesRef <- newIORef []
        pendingRef <- newIORef []
        countRef   <- newIORef 0
        let conn = Conn
                { connHandle     = h
                , connReplies    = repliesRef
                , connPending    = pendingRef
                , connPendingCnt = countRef
                }
        -- Both references start out pointing at the same lazily produced
        -- reply list.
        replies <- connGetReplies conn
        writeIORef repliesRef replies
        writeIORef pendingRef replies
        return conn
  where
    -- Obtain the handle, optionally bounded by the timeout.
    acquireHandle = do
        phaseVar <- newMVar PhaseUnknown
        let attempt = openHandle phaseVar
        maybe attempt (withDeadline phaseVar attempt) timeoutOpt

    -- Run the attempt against a timer; report the recorded phase on timeout.
    withDeadline phaseVar attempt micros =
        race attempt (threadDelay micros)
            >>= either return (\() -> readMVar phaseVar >>= errConnectTimeout)

    -- For a numeric port the socket is set up by hand so that each phase can
    -- be recorded; every other kind of port is delegated to 'connectTo'.
    openHandle phaseVar = case portID of
        PortNumber port ->
            bracketOnError newTcpSocket NS.close $ \sock -> do
                NS.setSocketOption sock NS.KeepAlive 1
                _ <- swapMVar phaseVar PhaseResolve
                hostEntry <- BSD.getHostByName hostName
                _ <- swapMVar phaseVar PhaseOpenSocket
                NS.connect sock
                    (NS.SockAddrInet port (BSD.hostAddress hostEntry))
                NS.socketToHandle sock ReadWriteMode
        _ -> connectTo hostName portID

    newTcpSocket = NS.socket NS.AF_INET NS.Stream 0
-- | Close the connection's handle, unless it is already closed.
disconnect :: Connection -> IO ()
disconnect conn = do
    stillOpen <- hIsOpen h
    when stillOpen $ hClose h
  where
    h = connHandle conn
-- | Write one request to the connection's handle and count it as pending.
--
-- An 'IOError' raised by the write is rethrown as 'ConnectionLost'. Once the
-- pending counter reaches 1000 or more, the head of the pending reply list is
-- forced to weak head normal form before returning, which bounds how many
-- replies can remain unread.
send :: Connection -> S.ByteString -> IO ()
send conn s = do
    ioErrorToConnLost $ S.hPut (connHandle conn) s
    outstanding <- atomicModifyIORef' (connPendingCnt conn) bump
    when (outstanding >= 1000) $ do
        firstPending : _ <- readIORef (connPending conn)
        firstPending `seq` return ()
  where
    -- Store the incremented counter and also hand it back to the caller.
    bump cnt = let cnt' = cnt + 1 in (cnt', cnt')
-- | Take the next reply off the connection's reply list.
--
-- The reply is returned without being forced here; the list's tail is stored
-- back so that the following call yields the next reply.
recv :: Connection -> IO Reply
recv conn = do
    next : remaining <- readIORef repliesRef
    writeIORef repliesRef remaining >> return next
  where
    repliesRef = connReplies conn
-- | Flush the connection's handle.
flush :: Connection -> IO ()
flush = hFlush . connHandle
-- | Send a request and then take the next reply from the connection.
request :: Connection -> S.ByteString -> IO Reply
request conn req = do
    send conn req
    recv conn
-- | Produce the lazy, unbounded list of replies arriving on the connection.
--
-- No I/O happens when this action runs; each list element is an
-- 'unsafeInterleaveIO' thunk that reads and parses exactly one reply from the
-- handle the first time it is demanded. Every element forces its predecessor
-- first, so replies are always read off the wire in order regardless of the
-- order in which callers inspect them.
--
-- Forcing an element throws 'ConnectionLost' if reading fails or the scanner
-- reports a failure.
connGetReplies :: Connection -> IO [Reply]
connGetReplies Conn{..} = go S.empty (SingleLine "previous of first")
  where
    -- @rest@ is the unparsed input left over from the previous reply and
    -- @previous@ is that reply itself (a dummy value for the first element).
    go rest previous = do
        -- The lazy pattern is what actually delays the receive: without it
        -- the tuple would be demanded immediately.
        ~(r, rest') <- unsafeInterleaveIO $ do
            -- Force the previous reply so that replies are read in order.
            previous `seq` return ()
            scanResult <- Scanner.scanWith readMore reply rest
            case scanResult of
                Scanner.Fail{} -> errConnClosed
                Scanner.More{} -> error "Hedis: parseWith returned Partial"
                Scanner.Done rest' r -> do
                    -- The reply just parsed is the head of the pending list,
                    -- so it is no longer pending.
                    atomicModifyIORef' connPending $ \(_:rs) -> (rs, ())
                    -- One reply fewer is outstanding. The counter is clamped
                    -- at zero so that replies which arrive without a matching
                    -- 'send' cannot drive it negative.
                    atomicModifyIORef' connPendingCnt $ \n -> (max 0 (n - 1), ())
                    return (r, rest')
        rs <- unsafeInterleaveIO (go rest' r)
        return (r:rs)

    -- Supply more input to the scanner. Flushing first makes sure that any
    -- buffered requests have been written before blocking on the read.
    readMore = ioErrorToConnLost $ do
        hFlush connHandle
        S.hGetSome connHandle 4096
-- | Run an action, replacing any 'IOError' it raises with 'ConnectionLost'.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)
-- | Throw 'ConnectionLost' as an exception in 'IO'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost
-- | Throw a 'ConnectTimeout' exception carrying the given phase.
errConnectTimeout :: ConnectPhase -> IO a
errConnectTimeout = throwIO . ConnectTimeout