module Database.Redis.Internal where
import Prelude hiding (putStrLn, putStr, catch)
import Control.Concurrent (ThreadId, myThreadId)
import qualified Control.Concurrent.RLock as RLock
import Data.IORef
import qualified System.IO as IO
import System.IO.UTF8 (putStrLn, putStr)
import qualified Data.ByteString as B
import Data.ByteString (ByteString)
import Data.ByteString.Char8 (readInt)
import qualified Data.ByteString.UTF8 as U
import Data.Maybe (fromJust, isNothing, isJust)
import Data.List (intersperse)
import qualified Data.Map as Map
import Data.Map (Map(..))
import Control.Monad (when)
import Control.Exception (bracket, bracketOnError, catch, SomeException)
import Database.Redis.ByteStringClass
#if __GLASGOW_HASKELL__ < 700
import Control.Exception (block)
#else
import Control.Exception.Base (mask)
-- | Compatibility replacement for the removed 'Control.Exception.block':
-- run an action with asynchronous exceptions masked, ignoring the
-- @restore@ function that 'mask' offers.
block :: IO a -> IO a
block action = mask (\_restore -> action)
#endif
-- | Debug helpers: decode a 'ByteString' as UTF-8 and print it to stdout,
-- with ('tracebs') or without ('tracebs'') a trailing newline.
tracebs, tracebs' :: ByteString -> IO ()
tracebs bs = putStrLn (U.toString bs)
tracebs' bs = putStr (U.toString bs)
-- | Mutable state of one Redis connection.
data RedisState = RedisState
    { server          :: (String, String)
      -- ^ Connection target as passed to 'newRedis' (presumably host and
      -- port — TODO confirm against callers).
    , database        :: Int
      -- ^ Database index; 'newRedis' initialises it to 0.
    , handle          :: IO.Handle
      -- ^ Handle that 'sendCommand' writes to and 'recv' reads from.
    , isSubscribed    :: Int
      -- ^ Initialised to 0 by 'newRedis'; presumably the number of active
      -- pub/sub subscriptions — TODO confirm.
    , renamedCommands :: Map ByteString ByteString
      -- ^ Command-name substitutions consulted by 'lookupRenamed'; an
      -- empty replacement marks the command as disabled.
    }
-- | A shared connection: a re-entrant lock ('RLock.RLock') guarding the
-- mutable 'RedisState'.
data Redis = Redis
    { r_lock :: RLock.RLock
      -- ^ Acquired by 'takeState'; released by 'putState' or
      -- 'putStateUnmodified'.
    , r_st   :: IORef RedisState
      -- ^ Current connection state.
    } deriving Eq
-- | Wrap an already-open handle in a fresh 'Redis' value: a new lock and a
-- state with database 0, subscription counter 0 and no renamed commands.
newRedis :: (String, String) -> IO.Handle -> IO Redis
newRedis server h =
    RLock.new >>= \lock ->
    newIORef (RedisState server 0 h 0 Map.empty) >>= \ref ->
    return (Redis lock ref)
-- | A command to transmit; 'sendCommand' defines the exact wire encoding.
data Command
    = CInline ByteString
      -- ^ A single word followed by CRLF.
    | CMInline [ByteString]
      -- ^ Space-separated words (first word is the command name), then CRLF.
    | CBulk [ByteString] ByteString
      -- ^ Words as in 'CMInline', followed by a length-prefixed payload.
    | CMBulk [ByteString]
      -- ^ Arguments in multi-bulk framing (@*N@ header, @$len@ per argument).
-- | A server reply.  Textual payloads are converted to the type @s@.
data Reply s
    = RTimeout
      -- ^ Not produced by 'recv'; presumably signals a timed-out wait —
      -- TODO confirm.
    | RParseError String
      -- ^ The reply could not be decoded; carries a description.
    | ROk
      -- ^ Status reply @+OK@.
    | RPong
      -- ^ Status reply @+PONG@.
    | RQueued
      -- ^ Status reply @+QUEUED@.
    | RError String
      -- ^ Error reply (the text after @-@).
    | RInline s
      -- ^ Any other status reply.
    | RInt Int
      -- ^ Integer reply (@:@).
    | RBulk (Maybe s)
      -- ^ Bulk reply (@$@); 'Nothing' for a nil bulk.
    | RMulti (Maybe [Reply s])
      -- ^ Multi-bulk reply (@*@); 'Nothing' for a nil multi-bulk.
    deriving Eq
-- | Render a 'BS' value as a 'String' by converting it with 'toBS' and
-- decoding the bytes as UTF-8.
showbs :: BS s => s -> String
showbs value = U.toString (toBS value)
-- | Debug-oriented rendering of a 'Reply'.  Payloads of type @s@ are shown
-- via 'showbs'; nested multi-bulk elements are comma-separated.
instance BS s => Show (Reply s) where
    show reply = case reply of
        RTimeout         -> "RTimeout"
        RParseError msg  -> "RParseError: " ++ msg
        ROk              -> "ROk"
        RPong            -> "RPong"
        RQueued          -> "RQueued"
        RError msg       -> "RError: " ++ msg
        RInline s        -> "RInline (" ++ showbs s ++ ")"
        RInt n           -> "RInt " ++ show n
        RBulk Nothing    -> "RBulk Nil"
        RBulk (Just s)   -> "RBulk " ++ showbs s
        RMulti Nothing   -> "RMulti Nil"
        RMulti (Just rs) -> "RMulti [" ++ concat (intersperse ", " (map show rs)) ++ "]"
-- | A pub/sub message.  Not constructed in this part of the module; field
-- meanings below are inferred from constructor names — TODO confirm.
data Message s
    = MSubscribe s Int
      -- ^ Presumably channel and subscription count.
    | MUnsubscribe s Int
      -- ^ Presumably channel and subscription count.
    | MPSubscribe s Int
      -- ^ Presumably pattern and subscription count.
    | MPUnsubscribe s Int
      -- ^ Presumably pattern and subscription count.
    | MMessage s s
      -- ^ Presumably channel and payload.
    | MPMessage s s s
      -- ^ Presumably pattern, channel and payload.
    deriving (Eq, Show)
-- Protocol fragments, pre-encoded once as 'ByteString's.
urn, uspace, uminus, uplus, ucolon, ubucks, uasterisk :: ByteString
urn       = U.fromString "\r\n"  -- line terminator (CRLF)
uspace    = U.fromString " "     -- word separator
uminus    = U.fromString "-"
uplus     = U.fromString "+"
ucolon    = U.fromString ":"
ubucks    = U.fromString "$"     -- bulk length prefix
uasterisk = U.fromString "*"     -- multi-bulk count prefix
-- | Write the protocol line terminator (CRLF) to a handle.
hPutRn :: IO.Handle -> IO ()
hPutRn h = B.hPut h urn
-- | Acquire the connection lock (with asynchronous exceptions masked) and
-- return the current state.  Must be paired with 'putState' or
-- 'putStateUnmodified'.
takeState :: Redis -> IO RedisState
takeState r = block (RLock.acquire lock >> readIORef ref)
  where
    lock = r_lock r
    ref  = r_st r
-- | Store a new state and release the lock taken by 'takeState'.
--
-- The calling thread must currently hold the lock.  Otherwise this fails
-- with 'error', and the stored state is left unchanged.  Runs with
-- asynchronous exceptions masked, so that the write and the release are
-- not separated by an asynchronous exception.
putState :: Redis -> RedisState -> IO ()
putState r s = block $ do
    lstate <- RLock.state (r_lock r)
    me     <- myThreadId
    case lstate of
        -- Only the owning thread may publish state: compare the recorded
        -- owner with the current thread id.
        Just (owner, _) | owner == me -> do
            writeIORef (r_st r) s
            RLock.release (r_lock r)
        _ -> error "putState: trying put state that was not took"
-- | Release the connection lock without writing back any state.
putStateUnmodified :: Redis -> IO ()
putStateUnmodified = RLock.release . r_lock
-- | Run an action with exclusive access to the state.  On success, store
-- the state the action returns and release the lock.  If the action
-- throws, release the lock without storing anything and rethrow.
inState :: Redis -> (RedisState -> IO (RedisState, a)) -> IO a
inState r action = bracketOnError (takeState r) releaseOnError run
  where
    releaseOnError _ = putStateUnmodified r
    run s = do
        (s', result) <- action s
        putState r s'
        return result
-- | Like 'inState' for actions that only produce a new state.
inState_ :: Redis -> (RedisState -> IO RedisState) -> IO ()
inState_ r action = bracketOnError acquire onError update
  where
    acquire   = takeState r
    onError _ = putStateUnmodified r
    update s  = do
        s' <- action s
        putState r s'
-- | Run a read-only action with exclusive access to the state.  The lock is
-- always released afterwards, and the state is never written back.
withState :: Redis -> (RedisState -> IO a) -> IO a
withState r action = bracket acquire release action
  where
    acquire   = takeState r
    release _ = putStateUnmodified r
-- | 'withState' with its arguments flipped.
withState' :: (RedisState -> IO a) -> Redis -> IO a
withState' = flip withState
-- | Write each word followed by a single space (including after the last
-- word).  Does not terminate the line or flush the handle.
send :: IO.Handle -> [ByteString] -> IO ()
send h = mapM_ putWord
  where
    putWord word = B.hPut h word >> B.hPut h uspace
-- | Translate a command name through 'renamedCommands'.  Names absent from
-- the map are returned unchanged.  A name that maps to (or is) the empty
-- string is treated as disabled and raises an 'error'.
lookupRenamed :: RedisState -> ByteString -> ByteString
lookupRenamed r c
    | B.null renamed = error $ "Command " ++ (fromBS c :: String) ++ " is disabled"
    | otherwise      = renamed
  where
    renamed = Map.findWithDefault c c (renamedCommands r)
-- | Encode a 'Command' onto the connection handle and flush it.
--
-- The first word of every command is passed through 'lookupRenamed'.  That
-- name is resolved (and a disabled command rejected with 'error') before any
-- byte is written, so a rejected command never leaves a partial frame in the
-- handle's buffer.
--
-- Encodings:
--
-- * 'CInline':  the word, then CRLF.
--
-- * 'CMInline': each word followed by a space, then CRLF.
--
-- * 'CBulk':    words as for 'CMInline', a space, the payload length in
--   bytes, CRLF, the payload, CRLF.
--
-- * 'CMBulk':   @*N@ CRLF, then for every argument @$len@ CRLF, the bytes,
--   CRLF.
--
-- A command with an empty word list raises an 'error'.
sendCommand :: RedisState -> Command -> IO ()
sendCommand r command = case command of
    CInline bs -> withName bs $ \cmd -> do
        B.hPut h cmd
        hPutRn h
        IO.hFlush h
    CMInline (l:ls) -> withName l $ \cmd -> do
        send h (cmd : ls)
        hPutRn h
        IO.hFlush h
    CBulk (l:ls) payload -> withName l $ \cmd -> do
        send h (cmd : ls)
        B.hPut h uspace
        B.hPut h (showLen (B.length payload))
        hPutRn h
        B.hPut h payload
        hPutRn h
        IO.hFlush h
    CMBulk args@(c:cs) -> withName c $ \cmd -> do
        B.hPut h uasterisk
        B.hPut h (showLen (length args))
        hPutRn h
        mapM_ sendArg (cmd : cs)
        IO.hFlush h
    CMInline [] -> emptyCommand "CMInline"
    CBulk [] _  -> emptyCommand "CBulk"
    CMBulk []   -> emptyCommand "CMBulk"
  where
    h = handle r
    -- Force the (possibly renamed) command name before any output, so a
    -- disabled command fails before anything is written.
    withName name k = let cmd = lookupRenamed r name in cmd `seq` k cmd
    -- Decimal rendering of a length or count.
    showLen n = U.fromString (show n)
    -- One multi-bulk argument: "$len" CRLF bytes CRLF.
    sendArg bs = do
        B.hPut h ubucks
        B.hPut h (showLen (B.length bs))
        hPutRn h
        B.hPut h bs
        hPutRn h
    emptyCommand ctor = error $ "sendCommand: " ++ ctor ++ " with no command word"
-- | 'sendCommand' with its arguments flipped.
sendCommand' :: Command -> RedisState -> IO ()
sendCommand' = flip sendCommand
-- | Read one reply from the connection handle.
--
-- The first character of the header line selects the reply type: @-@
-- error, @+@ status, @:@ integer, @$@ bulk, @*@ multi-bulk (read
-- recursively).  A bulk length or multi-bulk count of @-1@ yields the
-- corresponding @Nil@ reply.
--
-- Malformed input does not crash; it is reported as 'RParseError'.  This
-- covers an unknown prefix, an unparsable integer, a length or count below
-- @-1@, and a payload that 'fromBS' rejects.  After such a reply the
-- stream may no longer be positioned at a reply boundary.
recv :: BS s => RedisState -> IO (Reply s)
recv r = do
    first <- trim `fmap` B.hGetLine h
    case U.uncons first of
        Just ('-', rest) -> recv_err rest
        Just ('+', rest) -> recv_inline rest
        Just (':', rest) -> recv_int rest
        Just ('$', rest) -> recv_bulk rest
        Just ('*', rest) -> recv_multi rest
        _                -> return $ RParseError $ "recv: unexpected reply line: " ++ U.toString first
  where
    h = handle r
    -- Keep only the bytes before the first CR (13) or LF (10).
    trim = B.takeWhile (\c -> c /= 13 && c /= 10)
    -- Convert with 'fromBS', forcing the result so a conversion failure
    -- surfaces here as 'RParseError'.  NOTE(review): this catches
    -- SomeException, so asynchronous exceptions are also turned into
    -- parse errors.
    safeFromBS constructor bs = (return $! constructor $! fromBS bs)
                                `catch`
                                (\e -> let msg = show (e :: SomeException)
                                       in return $ RParseError msg)
    -- Parse the leading decimal integer of a header line (trailing bytes
    -- are ignored) and continue with it.
    withInt rest k = case readInt rest of
        Just (n, _) -> k n
        Nothing     -> return $ RParseError $ "recv: malformed integer: " ++ U.toString rest
    recv_err rest = return $ RError $ U.toString rest
    -- String-literal patterns on ByteString rely on OverloadedStrings
    -- being enabled for this module.
    recv_inline rest = case rest of
        "OK"     -> return ROk
        "PONG"   -> return RPong
        "QUEUED" -> return RQueued
        _        -> safeFromBS RInline rest
    recv_int rest = withInt rest (return . RInt)
    recv_bulk rest = withInt rest recv_bulk_body
    recv_bulk_body (-1) = return $ RBulk Nothing
    recv_bulk_body size
        | size < 0  = return $ RParseError $ "recv: invalid bulk length " ++ show size
        | otherwise = do
            -- Read the payload plus its trailing CRLF, then drop the CRLF.
            body <- B.hGet h (size + 2)
            safeFromBS (RBulk . Just) (B.take size body)
    recv_multi rest = withInt rest recv_multi_n
    recv_multi_n (-1) = return $ RMulti Nothing
    recv_multi_n n
        | n < 0     = return $ RParseError $ "recv: invalid multi-bulk count " ++ show n
        | otherwise = (RMulti . Just) `fmap` sequence (replicate n (recv r))
-- | Wait up to @timeout@ milliseconds for input on the connection handle.
-- Returns 'True' as soon as input is available, 'False' on timeout
-- (semantics of 'IO.hWaitForInput').
wait :: RedisState -> Int -> IO Bool
wait rs timeout = IO.hWaitForInput (handle rs) timeout