{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
module Database.Redis.Sentinel
(
SentinelConnectInfo(..)
, SentinelConnection
, connect
, runRedis
, RedisSentinelException(..)
, module Database.Redis
) where
import Control.Concurrent
import Control.Exception (Exception, IOException, bracket, evaluate, throwIO)
import Control.Monad
import Control.Monad.Catch (Handler (..), MonadCatch, catches, throwM)
import Control.Monad.Except
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS8
import Data.Foldable (toList)
import Data.List (delete)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Typeable (Typeable)
import Data.Unique
import Network.Socket (HostName)
import Database.Redis hiding (Connection, connect, runRedis)
import qualified Database.Redis as Redis
-- | Run a 'Redis' action against the master currently tracked by the
-- 'SentinelConnection'.
--
-- If an earlier call flagged a possible failover, the Sentinels are queried
-- first (via 'updateMaster'). A new base connection is opened only when the
-- reported master's host/port differs from the cached one.
--
-- After the action runs, a Sentinel re-query is scheduled for the next call
-- in either of these cases:
--
-- * the action throws an 'IOException' or 'ConnectionLostException'
--   (the exception is rethrown);
-- * the reply is an error starting with @READONLY @.
--
-- The re-query is scheduled only if the shared state's token is still the
-- one this call observed.
runRedis :: SentinelConnection
         -> Redis (Either Reply a)
         -> IO (Either Reply a)
runRedis (SentinelConnection stateVar) action = do
  (activeConn, seenToken) <- modifyMVar stateVar refreshIfFlagged
  result <- (Redis.runRedis activeConn action >>= evaluate)
    `catchRedisRethrow` const (flagFailover seenToken)
  when (isReadOnlyError result) $ flagFailover seenToken
  return result
  where
    sameHost :: Redis.ConnectInfo -> Redis.ConnectInfo -> Bool
    sameHost l r = connectHost l == connectHost r && connectPort l == connectPort r

    -- Hand out the cached connection, unless a failover check is pending,
    -- in which case ask the Sentinels for the current master first and
    -- mint a fresh token for the new state.
    refreshIfFlagged st
      | rcCheckFailover st = do
          (sentinelInfo', masterInfo') <- updateMaster (rcSentinelConnectInfo st)
          freshToken <- newUnique
          (masterInfo, baseConn) <-
            if sameHost masterInfo' (rcMasterConnectInfo st)
              then return (rcMasterConnectInfo st, rcBaseConnection st)
              else (,) masterInfo' <$> Redis.connect masterInfo'
          let st' = SentinelConnection'
                { rcCheckFailover = False
                , rcToken = freshToken
                , rcSentinelConnectInfo = sentinelInfo'
                , rcMasterConnectInfo = masterInfo
                , rcBaseConnection = baseConn
                }
          return (st', (baseConn, freshToken))
      | otherwise = return (st, (rcBaseConnection st, rcToken st))

    -- Request a Sentinel re-query, but only if nobody has replaced or
    -- re-flagged the state since this call read it (token still matches).
    flagFailover expected = modifyMVar_ stateVar $ \st ->
      if rcToken st /= expected
        then return st
        else do
          replacement <- newUnique
          return st { rcToken = replacement, rcCheckFailover = True }

    -- A READONLY error means the cached host is no longer the master.
    isReadOnlyError (Left (Error e)) = "READONLY " `BS.isPrefixOf` e
    isReadOnlyError _ = False
-- | Open a 'SentinelConnection'.
--
-- Asks the configured Sentinels for the current master of
-- 'connectMasterName', then opens a base connection to that master. The
-- failover flag starts cleared.
--
-- Throws 'NoSentinels' if no Sentinel reports a master address.
connect :: SentinelConnectInfo -> IO SentinelConnection
connect origConnectInfo = do
  (sentinelInfo, masterInfo) <- updateMaster origConnectInfo
  baseConn <- Redis.connect masterInfo
  initialToken <- newUnique
  let initialState = SentinelConnection'
        { rcCheckFailover = False
        , rcToken = initialToken
        , rcSentinelConnectInfo = sentinelInfo
        , rcMasterConnectInfo = masterInfo
        , rcBaseConnection = baseConn
        }
  fmap SentinelConnection (newMVar initialState)
-- | Ask the configured Sentinels, in order, for the address of the master
-- named 'connectMasterName'.
--
-- The first Sentinel that answers with a well-formed @[host, port]@ reply
-- wins. That Sentinel is moved to the front of 'connectSentinels', so it is
-- asked first next time. Its answer is returned as a 'Redis.ConnectInfo'
-- built from 'connectBaseInfo' with the host and port replaced.
--
-- Sentinels are skipped when:
--
-- * probing them throws an 'IOException' or 'ConnectionLostException';
-- * they return any other reply shape;
-- * they return a port that is not a plain decimal in 1..65535.
--
-- Throws 'NoSentinels' when no Sentinel yields a master address.
updateMaster :: SentinelConnectInfo
             -> IO (SentinelConnectInfo, Redis.ConnectInfo)
updateMaster sci@SentinelConnectInfo{..} = do
  -- 'throwError' in 'ExceptT' is the early exit: the first Sentinel that
  -- reports a usable master address short-circuits the 'forM_' loop.
  resultEither <- runExceptT $ forM_ connectSentinels $ \(host, port) ->
    trySentinel host port `catchRedis` (\_ -> return ())
  case resultEither of
    Left (conn, sentinelPair) -> return
      ( sci
        { connectSentinels = sentinelPair :| delete sentinelPair (toList connectSentinels)
        }
      , conn
      )
    Right () -> throwIO $ NoSentinels connectSentinels
  where
    trySentinel :: HostName -> PortID -> ExceptT (Redis.ConnectInfo, (HostName, PortID)) IO ()
    trySentinel sentinelHost sentinelPort = do
      -- Each probe uses a throwaway single-connection pool. 'bracket'
      -- guarantees it is torn down even if the query throws; previously
      -- the pool was never disconnected.
      replyE <- liftIO $ bracket
        (Redis.connect (Redis.defaultConnectInfo
          { connectHost = sentinelHost
          , connectPort = sentinelPort
          , connectMaxConnections = 1
          }))
        Redis.disconnect
        (\sentinelConn ->
          Redis.runRedis sentinelConn
            (sendRequest ["SENTINEL", "get-master-addr-by-name", connectMasterName])
            >>= evaluate)
      case replyE of
        -- Only accept a port that parses completely and is in range. The old
        -- fallback to 26379 pointed at a Sentinel port, not a master.
        Right [host, port]
          | Just (portNum, rest) <- BS8.readInt port
          , BS.null rest
          , portNum > 0
          , portNum <= 65535 ->
            throwError
              ( connectBaseInfo
                { connectHost = BS8.unpack host
                , connectPort = PortNumber (fromIntegral portNum)
                }
              , (sentinelHost, sentinelPort)
              )
        _ -> return ()
-- | Run an action. If it throws an 'IOException' or a
-- 'ConnectionLostException', pass the exception's 'show' text to the handler
-- and then rethrow the original exception unchanged.
catchRedisRethrow :: MonadCatch m => m a -> (String -> m ()) -> m a
catchRedisRethrow action handler = catches action
  [ Handler (\ioErr -> do
      handler (show @IOException ioErr)
      throwM ioErr)
  , Handler (\lostErr -> do
      handler (show @ConnectionLostException lostErr)
      throwM lostErr)
  ]
-- | Run an action. If it throws an 'IOException' or a
-- 'ConnectionLostException', recover by running the handler on the
-- exception's 'show' text and returning its result.
catchRedis :: MonadCatch m => m a -> (String -> m a) -> m a
catchRedis action handler = catches action
  [ Handler (handler . show @IOException)
  , Handler (handler . show @ConnectionLostException)
  ]
-- | Connection to a Redis master located through Redis Sentinel.
--
-- All bookkeeping (current master, base connection, failover flag) lives in
-- a single 'MVar' that every 'runRedis' call on this handle shares.
newtype SentinelConnection
  = SentinelConnection (MVar SentinelConnection')
-- | Mutable state behind a 'SentinelConnection'.
data SentinelConnection'
  = SentinelConnection'
      { rcCheckFailover :: Bool
        -- ^ When 'True', the next 'runRedis' re-queries the Sentinels for the
        -- master before running its action.
      , rcToken :: Unique
        -- ^ Replaced every time the state is refreshed or flagged. 'runRedis'
        -- only sets the failover flag if the token still matches the one it
        -- read, so a failure on an already-replaced or already-flagged state
        -- does not flag it again.
      , rcSentinelConnectInfo :: SentinelConnectInfo
        -- ^ Sentinel configuration, as last reordered by 'updateMaster'.
      , rcMasterConnectInfo :: Redis.ConnectInfo
        -- ^ Connection settings of the master that 'rcBaseConnection' targets.
      , rcBaseConnection :: Redis.Connection
        -- ^ Underlying hedis connection used to run actions.
      }
-- | Configuration for 'connect': which Sentinels to ask, which master to ask
-- them about, and how to connect to that master once found.
data SentinelConnectInfo
  = SentinelConnectInfo
      { connectSentinels :: NonEmpty (HostName, PortID)
        -- ^ Sentinels to query, in order. After a successful lookup, the
        -- Sentinel that answered is moved to the front of this list.
      , connectMasterName :: ByteString
        -- ^ Master name passed to @SENTINEL get-master-addr-by-name@.
      , connectBaseInfo :: Redis.ConnectInfo
        -- ^ Settings used for the master connection. Its host and port are
        -- overwritten with the address the Sentinel reports.
      }
  deriving (Show)
-- | Errors raised by the Sentinel layer.
data RedisSentinelException
  = NoSentinels (NonEmpty (HostName, PortID))
    -- ^ None of the listed Sentinels reported a master address. Carries the
    -- full list of Sentinels that were tried.
  deriving (Show, Typeable)

instance Exception RedisSentinelException