{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Database.Redis.IO.Client where
import Control.Applicative
import Control.Exception (throw, throwIO)
import Control.Monad.Base (MonadBase (..))
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.Operational
import Control.Monad.Reader (ReaderT (..), runReaderT, MonadReader, ask, asks)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control (MonadBaseControl (..))
import Control.Monad.Trans.Except
import Data.ByteString.Lazy (ByteString)
import Data.Int
import Data.IORef
import Data.Redis
import Data.Pool hiding (Pool)
import Database.Redis.IO.Connection (Connection)
import Database.Redis.IO.Settings
import Database.Redis.IO.Timeouts (TimeoutManager)
import Database.Redis.IO.Types (ConnectionError (..))
import Prelude hiding (readList)
import System.Logger.Class hiding (Settings, settings, eval)
import System.IO.Unsafe (unsafeInterleaveIO)
import qualified Control.Monad.State.Strict as S
import qualified Control.Monad.State.Lazy as L
import qualified Data.List.NonEmpty as NE
import qualified Data.Pool as P
import qualified Data.Sequence as Seq
import qualified Database.Redis.IO.Connection as C
import qualified System.Logger as Logger
import qualified Database.Redis.IO.Timeouts as TM
data Pool = Pool
{ settings :: !Settings
, connPool :: !(P.Pool Connection)
, logger :: !Logger.Logger
, timeouts :: !TimeoutManager
}
newtype Client a = Client
{ client :: ReaderT Pool IO a
} deriving ( Functor
, Applicative
, Monad
, MonadIO
, MonadThrow
, MonadMask
, MonadCatch
, MonadReader Pool
, MonadBase IO
)
instance MonadLogger Client where
log l m = asks logger >>= \g -> Logger.log g l m
instance MonadBaseControl IO Client where
type StM Client a = StM (ReaderT Pool IO) a
liftBaseWith f = Client $ liftBaseWith $ \run -> f (run . client)
restoreM = Client . restoreM
class (Functor m, Applicative m, Monad m, MonadIO m, MonadCatch m) => MonadClient m
where
liftClient :: Client a -> m a
instance MonadClient Client where
liftClient = id
instance MonadClient m => MonadClient (ReaderT r m) where
liftClient = lift . liftClient
instance MonadClient m => MonadClient (S.StateT s m) where
liftClient = lift . liftClient
instance MonadClient m => MonadClient (L.StateT s m) where
liftClient = lift . liftClient
instance MonadClient m => MonadClient (ExceptT e m) where
liftClient = lift . liftClient
mkPool :: MonadIO m => Logger -> Settings -> m Pool
mkPool g s = liftIO $ do
t <- TM.create 250
a <- C.resolve (sHost s) (sPort s)
Pool s <$> createPool (connOpen t a)
connClose
(sPoolStripes s)
(sIdleTimeout s)
(sMaxConnections s)
<*> pure g
<*> pure t
where
connOpen t aa = tryAll (NE.fromList aa) $ \a -> do
c <- C.connect s g t a
Logger.debug g $ "client.connect" .= sHost s ~~ msg (show c)
return c
connClose c = do
Logger.debug g $ "client.close" .= sHost s ~~ msg (show c)
C.close c
shutdown :: MonadIO m => Pool -> m ()
shutdown p = liftIO $ P.destroyAllResources (connPool p)
runRedis :: MonadIO m => Pool -> Client a -> m a
runRedis p a = liftIO $ runReaderT (client a) p
commands :: MonadClient m => Redis IO a -> m a
commands a = liftClient $ withConnection (flip (eval getLazy) a)
sync :: Redis IO a -> Redis IO a
sync cmd = cmd >>= \x -> x `seq` return x
pubSub :: MonadClient m
=> (Maybe ByteString -> ByteString -> ByteString -> PubSub IO ())
-> PubSub IO ()
-> m ()
pubSub f a = liftClient $ withConnection (loop a)
where
loop :: PubSub IO () -> Connection -> IO ((), [IO ()])
loop p h = do
cmds h p
r <- responses h
case r of
Nothing -> return ((), [])
Just k -> loop k h
cmds :: Connection -> PubSub IO () -> IO ()
cmds h c = do
r <- viewT c
case r of
Return x -> return x
Subscribe x :>>= k -> C.send h (Seq.singleton x) >>= cmds h . k
Unsubscribe x :>>= k -> C.send h (Seq.singleton x) >>= cmds h . k
PSubscribe x :>>= k -> C.send h (Seq.singleton x) >>= cmds h . k
PUnsubscribe x :>>= k -> C.send h (Seq.singleton x) >>= cmds h . k
responses :: Connection -> IO (Maybe (PubSub IO ()))
responses h = do
m <- readPushMessage <$> C.receive h
case m of
Right (Message ch ms) -> return (Just $ f Nothing ch ms)
Right (PMessage pat ch ms) -> return (Just $ f (Just pat) ch ms)
Right (UnsubscribeMessage _ 0) -> return Nothing
Right UnsubscribeMessage {} -> responses h
Right SubscribeMessage {} -> responses h
Left e -> throwIO e
eval :: (forall a. Connection -> Resp -> (Resp -> Result a) -> IO (a, IO ()))
-> Connection
-> Redis IO b
-> IO (b, [IO ()])
eval f conn red = run conn [] red
where
run :: Connection -> [IO ()] -> Redis IO a -> IO (a, [IO ()])
run h ii c = do
r <- viewT c
case r of
Return a -> return (a, ii)
Ping x :>>= k -> f h x (matchStr "PING" "PONG") >>= \(a, i) -> run h (i:ii) $ k a
Echo x :>>= k -> f h x (readBulk "ECHO") >>= \(a, i) -> run h (i:ii) $ k a
Auth x :>>= k -> f h x (matchStr "AUTH" "OK") >>= \(a, i) -> run h (i:ii) $ k a
Quit x :>>= k -> f h x (matchStr "QUIT" "OK") >>= \(a, i) -> run h (i:ii) $ k a
Select x :>>= k -> f h x (matchStr "SELECT" "OK") >>= \(a, i) -> run h (i:ii) $ k a
BgRewriteAOF x :>>= k -> f h x (anyStr "BGREWRITEAOF") >>= \(a, i) -> run h (i:ii) $ k a
BgSave x :>>= k -> f h x (anyStr "BGSAVE") >>= \(a, i) -> run h (i:ii) $ k a
Save x :>>= k -> f h x (matchStr "SAVE" "OK") >>= \(a, i) -> run h (i:ii) $ k a
FlushAll x :>>= k -> f h x (matchStr "FLUSHALL" "OK") >>= \(a, i) -> run h (i:ii) $ k a
FlushDb x :>>= k -> f h x (matchStr "FLUSHDB" "OK") >>= \(a, i) -> run h (i:ii) $ k a
DbSize x :>>= k -> f h x (readInt "DBSIZE") >>= \(a, i) -> run h (i:ii) $ k a
LastSave x :>>= k -> f h x (readInt "LASTSAVE") >>= \(a, i) -> run h (i:ii) $ k a
Multi x :>>= k -> f h x (matchStr "MULTI" "OK") >>= \(a, i) -> run h (i:ii) $ k a
Watch x :>>= k -> f h x (matchStr "WATCH" "OK") >>= \(a, i) -> run h (i:ii) $ k a
Unwatch x :>>= k -> f h x (matchStr "UNWATCH" "OK") >>= \(a, i) -> run h (i:ii) $ k a
Discard x :>>= k -> f h x (matchStr "DISCARD" "OK") >>= \(a, i) -> run h (i:ii) $ k a
Exec x :>>= k -> f h x (const $ Right ()) >>= \(a, i) -> run h (i:ii) $ k a
Del x :>>= k -> f h x (readInt "DEL") >>= \(a, i) -> run h (i:ii) $ k a
Dump x :>>= k -> f h x (readBulk'Null "DUMP") >>= \(a, i) -> run h (i:ii) $ k a
Exists x :>>= k -> f h x (readBool "EXISTS") >>= \(a, i) -> run h (i:ii) $ k a
Expire x :>>= k -> f h x (readBool "EXPIRE") >>= \(a, i) -> run h (i:ii) $ k a
ExpireAt x :>>= k -> f h x (readBool "EXPIREAT") >>= \(a, i) -> run h (i:ii) $ k a
Persist x :>>= k -> f h x (readBool "PERSIST") >>= \(a, i) -> run h (i:ii) $ k a
Keys x :>>= k -> f h x (readList "KEYS") >>= \(a, i) -> run h (i:ii) $ k a
RandomKey x :>>= k -> f h x (readBulk'Null "RANDOMKEY") >>= \(a, i) -> run h (i:ii) $ k a
Rename x :>>= k -> f h x (matchStr "RENAME" "OK") >>= \(a, i) -> run h (i:ii) $ k a
RenameNx x :>>= k -> f h x (readBool "RENAMENX") >>= \(a, i) -> run h (i:ii) $ k a
Ttl x :>>= k -> f h x (readTTL "TTL") >>= \(a, i) -> run h (i:ii) $ k a
Type x :>>= k -> f h x (readType "TYPE") >>= \(a, i) -> run h (i:ii) $ k a
Scan x :>>= k -> f h x (readScan "SCAN") >>= \(a, i) -> run h (i:ii) $ k a
Append x :>>= k -> f h x (readInt "APPEND") >>= \(a, i) -> run h (i:ii) $ k a
Get x :>>= k -> f h x (readBulk'Null "GET") >>= \(a, i) -> run h (i:ii) $ k a
GetRange x :>>= k -> f h x (readBulk "GETRANGE") >>= \(a, i) -> run h (i:ii) $ k a
GetSet x :>>= k -> f h x (readBulk'Null "GETSET") >>= \(a, i) -> run h (i:ii) $ k a
MGet x :>>= k -> f h x (readListOfMaybes "MGET") >>= \(a, i) -> run h (i:ii) $ k a
MSet x :>>= k -> f h x (matchStr "MSET" "OK") >>= \(a, i) -> run h (i:ii) $ k a
MSetNx x :>>= k -> f h x (readBool "MSETNX") >>= \(a, i) -> run h (i:ii) $ k a
Set x :>>= k -> f h x fromSet >>= \(a, i) -> run h (i:ii) $ k a
SetRange x :>>= k -> f h x (readInt "SETRANGE") >>= \(a, i) -> run h (i:ii) $ k a
StrLen x :>>= k -> f h x (readInt "STRLEN") >>= \(a, i) -> run h (i:ii) $ k a
BitAnd x :>>= k -> f h x (readInt "BITOP") >>= \(a, i) -> run h (i:ii) $ k a
BitCount x :>>= k -> f h x (readInt "BITCOUNT") >>= \(a, i) -> run h (i:ii) $ k a
BitNot x :>>= k -> f h x (readInt "BITOP") >>= \(a, i) -> run h (i:ii) $ k a
BitOr x :>>= k -> f h x (readInt "BITOP") >>= \(a, i) -> run h (i:ii) $ k a
BitPos x :>>= k -> f h x (readInt "BITPOS") >>= \(a, i) -> run h (i:ii) $ k a
BitXOr x :>>= k -> f h x (readInt "BITOP") >>= \(a, i) -> run h (i:ii) $ k a
GetBit x :>>= k -> f h x (readInt "GETBIT") >>= \(a, i) -> run h (i:ii) $ k a
SetBit x :>>= k -> f h x (readInt "SETBIT") >>= \(a, i) -> run h (i:ii) $ k a
Decr x :>>= k -> f h x (readInt "DECR") >>= \(a, i) -> run h (i:ii) $ k a
DecrBy x :>>= k -> f h x (readInt "DECRBY") >>= \(a, i) -> run h (i:ii) $ k a
Incr x :>>= k -> f h x (readInt "INCR") >>= \(a, i) -> run h (i:ii) $ k a
IncrBy x :>>= k -> f h x (readInt "INCRBY") >>= \(a, i) -> run h (i:ii) $ k a
IncrByFloat x :>>= k -> f h x (readBulk "INCRBYFLOAT") >>= \(a, i) -> run h (i:ii) $ k a
HDel x :>>= k -> f h x (readInt "HDEL") >>= \(a, i) -> run h (i:ii) $ k a
HExists x :>>= k -> f h x (readBool "HEXISTS") >>= \(a, i) -> run h (i:ii) $ k a
HGet x :>>= k -> f h x (readBulk'Null "HGET") >>= \(a, i) -> run h (i:ii) $ k a
HGetAll x :>>= k -> f h x (readFields "HGETALL") >>= \(a, i) -> run h (i:ii) $ k a
HIncrBy x :>>= k -> f h x (readInt "HINCRBY") >>= \(a, i) -> run h (i:ii) $ k a
HIncrByFloat x :>>= k -> f h x (readBulk "HINCRBYFLOAT") >>= \(a, i) -> run h (i:ii) $ k a
HKeys x :>>= k -> f h x (readList "HKEYS") >>= \(a, i) -> run h (i:ii) $ k a
HLen x :>>= k -> f h x (readInt "HLEN") >>= \(a, i) -> run h (i:ii) $ k a
HMGet x :>>= k -> f h x (readListOfMaybes "HMGET") >>= \(a, i) -> run h (i:ii) $ k a
HMSet x :>>= k -> f h x (matchStr "HMSET" "OK") >>= \(a, i) -> run h (i:ii) $ k a
HSet x :>>= k -> f h x (readBool "HSET") >>= \(a, i) -> run h (i:ii) $ k a
HSetNx x :>>= k -> f h x (readBool "HSETNX") >>= \(a, i) -> run h (i:ii) $ k a
HVals x :>>= k -> f h x (readList "HVALS") >>= \(a, i) -> run h (i:ii) $ k a
HScan x :>>= k -> f h x (readScan "HSCAN") >>= \(a, i) -> run h (i:ii) $ k a
BLPop t x :>>= k -> getNow (withTimeout t h) x (readKeyValue "BLPOP") >>= \(a, i) -> run h (i:ii) $ k a
BRPop t x :>>= k -> getNow (withTimeout t h) x (readKeyValue "BRPOP") >>= \(a, i) -> run h (i:ii) $ k a
BRPopLPush t x :>>= k -> getNow (withTimeout t h) x (readBulk'Null "BRPOPLPUSH") >>= \(a, i) -> run h (i:ii) $ k a
LIndex x :>>= k -> f h x (readBulk'Null "LINDEX") >>= \(a, i) -> run h (i:ii) $ k a
LInsert x :>>= k -> f h x (readInt "LINSERT") >>= \(a, i) -> run h (i:ii) $ k a
LLen x :>>= k -> f h x (readInt "LLEN") >>= \(a, i) -> run h (i:ii) $ k a
LPop x :>>= k -> f h x (readBulk'Null "LPOP") >>= \(a, i) -> run h (i:ii) $ k a
LPush x :>>= k -> f h x (readInt "LPUSH") >>= \(a, i) -> run h (i:ii) $ k a
LPushX x :>>= k -> f h x (readInt "LPUSHX") >>= \(a, i) -> run h (i:ii) $ k a
LRange x :>>= k -> f h x (readList "LRANGE") >>= \(a, i) -> run h (i:ii) $ k a
LRem x :>>= k -> f h x (readInt "LREM") >>= \(a, i) -> run h (i:ii) $ k a
LSet x :>>= k -> f h x (matchStr "LSET" "OK") >>= \(a, i) -> run h (i:ii) $ k a
LTrim x :>>= k -> f h x (matchStr "LTRIM" "OK") >>= \(a, i) -> run h (i:ii) $ k a
RPop x :>>= k -> f h x (readBulk'Null "RPOP") >>= \(a, i) -> run h (i:ii) $ k a
RPopLPush x :>>= k -> f h x (readBulk'Null "RPOPLPUSH") >>= \(a, i) -> run h (i:ii) $ k a
RPush x :>>= k -> f h x (readInt "RPUSH") >>= \(a, i) -> run h (i:ii) $ k a
RPushX x :>>= k -> f h x (readInt "RPUSHX") >>= \(a, i) -> run h (i:ii) $ k a
SAdd x :>>= k -> f h x (readInt "SADD") >>= \(a, i) -> run h (i:ii) $ k a
SCard x :>>= k -> f h x (readInt "SCARD") >>= \(a, i) -> run h (i:ii) $ k a
SDiff x :>>= k -> f h x (readList "SDIFF") >>= \(a, i) -> run h (i:ii) $ k a
SDiffStore x :>>= k -> f h x (readInt "SDIFFSTORE") >>= \(a, i) -> run h (i:ii) $ k a
SInter x :>>= k -> f h x (readList "SINTER") >>= \(a, i) -> run h (i:ii) $ k a
SInterStore x :>>= k -> f h x (readInt "SINTERSTORE") >>= \(a, i) -> run h (i:ii) $ k a
SIsMember x :>>= k -> f h x (readBool "SISMEMBER") >>= \(a, i) -> run h (i:ii) $ k a
SMembers x :>>= k -> f h x (readList "SMEMBERS") >>= \(a, i) -> run h (i:ii) $ k a
SMove x :>>= k -> f h x (readBool "SMOVE") >>= \(a, i) -> run h (i:ii) $ k a
SPop x :>>= k -> f h x (readBulk'Null "SPOP") >>= \(a, i) -> run h (i:ii) $ k a
SRandMember y x :>>= k -> f h x (readBulk'Array "SRANDMEMBER" y) >>= \(a, i) -> run h (i:ii) $ k a
SRem x :>>= k -> f h x (readInt "SREM") >>= \(a, i) -> run h (i:ii) $ k a
SScan x :>>= k -> f h x (readScan "SSCAN") >>= \(a, i) -> run h (i:ii) $ k a
SUnion x :>>= k -> f h x (readList "SUNION") >>= \(a, i) -> run h (i:ii) $ k a
SUnionStore x :>>= k -> f h x (readInt "SUNIONSTORE") >>= \(a, i) -> run h (i:ii) $ k a
ZAdd x :>>= k -> f h x (readInt "ZADD") >>= \(a, i) -> run h (i:ii) $ k a
ZCard x :>>= k -> f h x (readInt "ZCARD") >>= \(a, i) -> run h (i:ii) $ k a
ZCount x :>>= k -> f h x (readInt "ZCOUNT") >>= \(a, i) -> run h (i:ii) $ k a
ZIncrBy x :>>= k -> f h x (readBulk "ZINCRBY") >>= \(a, i) -> run h (i:ii) $ k a
ZInterStore x :>>= k -> f h x (readInt "ZINTERSTORE") >>= \(a, i) -> run h (i:ii) $ k a
ZLexCount x :>>= k -> f h x (readInt "ZLEXCOUNT") >>= \(a, i) -> run h (i:ii) $ k a
ZRange y x :>>= k -> f h x (readScoreList "ZRANGE" y) >>= \(a, i) -> run h (i:ii) $ k a
ZRangeByLex x :>>= k -> f h x (readList "ZRANGEBYLEX") >>= \(a, i) -> run h (i:ii) $ k a
ZRangeByScore y x :>>= k -> f h x (readScoreList "ZRANGEBYSCORE" y) >>= \(a, i) -> run h (i:ii) $ k a
ZRank x :>>= k -> f h x (readInt'Null "ZRANK") >>= \(a, i) -> run h (i:ii) $ k a
ZRem x :>>= k -> f h x (readInt "ZREM") >>= \(a, i) -> run h (i:ii) $ k a
ZRemRangeByLex x :>>= k -> f h x (readInt "ZREMRANGEBYLEX") >>= \(a, i) -> run h (i:ii) $ k a
ZRemRangeByRank x :>>= k -> f h x (readInt "ZREMRANGEBYRANK") >>= \(a, i) -> run h (i:ii) $ k a
ZRemRangeByScore x :>>= k -> f h x (readInt "ZREMRANGEBYSCORE") >>= \(a, i) -> run h (i:ii) $ k a
ZRevRange y x :>>= k -> f h x (readScoreList "ZREVRANGE" y) >>= \(a, i) -> run h (i:ii) $ k a
ZRevRangeByScore y x :>>= k -> f h x (readScoreList "ZREVRANGEBYSCORE" y) >>= \(a, i) -> run h (i:ii) $ k a
ZRevRank x :>>= k -> f h x (readInt'Null "ZREVRANK") >>= \(a, i) -> run h (i:ii) $ k a
ZScan x :>>= k -> f h x (readScan "ZSCAN") >>= \(a, i) -> run h (i:ii) $ k a
ZScore x :>>= k -> f h x (readBulk'Null "ZSCORE") >>= \(a, i) -> run h (i:ii) $ k a
ZUnionStore x :>>= k -> f h x (readInt "ZUNIONSTORE") >>= \(a, i) -> run h (i:ii) $ k a
Sort x :>>= k -> f h x (readList "SORT") >>= \(a, i) -> run h (i:ii) $ k a
PfAdd x :>>= k -> f h x (readBool "PFADD") >>= \(a, i) -> run h (i:ii) $ k a
PfCount x :>>= k -> f h x (readInt "PFCOUNT") >>= \(a, i) -> run h (i:ii) $ k a
PfMerge x :>>= k -> f h x (matchStr "PFMERGE" "OK") >>= \(a, i) -> run h (i:ii) $ k a
Publish x :>>= k -> getNow h x (readInt "PUBLISH") >>= \(a, i) -> run h (i:ii) $ k a
withConnection :: (Connection -> IO (a, [IO ()])) -> Client a
withConnection f = do
p <- ask
let c = connPool p
x <- liftIO $ tryWithResource c $ \h -> f h >>= \(a, i) -> sequence_ i >> return a
maybe (throwM ConnectionsBusy) return x
getLazy :: Connection -> Resp -> (Resp -> Result a) -> IO (a, IO ())
getLazy h x g = do
r <- newIORef (throw $ RedisError "missing response")
C.request x r h
a <- unsafeInterleaveIO $ do
C.sync h
either throwIO return =<< g <$> readIORef r
return (a, a `seq` return ())
{-# INLINE getLazy #-}
getNow :: Connection -> Resp -> (Resp -> Result a) -> IO (a, IO ())
getNow h x g = do
r <- newIORef (throw $ RedisError "missing response")
C.request x r h
C.sync h
a <- either throwIO return =<< g <$> readIORef r
return (a, return ())
{-# INLINE getNow #-}
withTimeout :: Int64 -> Connection -> Connection
withTimeout 0 c = c { C.settings = setSendRecvTimeout 0 (C.settings c) }
withTimeout t c = c { C.settings = setSendRecvTimeout (10 + fromIntegral t) (C.settings c) }
{-# INLINE withTimeout #-}
tryAll :: NonEmpty a -> (a -> IO b) -> IO b
tryAll (a :| []) f = f a
tryAll (a :| aa) f = f a `catchAll` (const $ tryAll (NE.fromList aa) f)
{-# INLINE tryAll #-}