{-# LANGUAGE CPP, OverloadedStrings, RecordWildCards, ScopedTypeVariables, FlexibleContexts, MultiWayIf #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Network.Riak.Connection.Internal
(
Network.Riak.Connection.Internal.connect
, disconnect
, setClientID
, defaultClient
, makeClientID
, exchange
, exchangeMaybe
, exchange_
, pipeline
, pipelineMaybe
, pipeline_
, sendRequest
, recvResponse
, recvMaybeResponse
, recvResponse_
) where
import Control.Concurrent.Async (async, concurrently, waitBoth)
import Control.Exception (Exception, IOException, throwIO, bracketOnError)
import Control.Monad (forM_, replicateM)
import Data.Binary.Put (Put, putWord32be, runPut)
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.Int (Int64)
import Network.Riak.Connection.NoPush (setNoPush)
import Network.Riak.Debug as Debug
import Network.Riak.Protocol.ErrorResponse
import Network.Riak.Protocol.SetClientIDRequest
import Network.Riak.Tag (getTag, putTag)
import Network.Riak.Types.Internal hiding (MessageTag(..))
import Network.Socket as Socket
import Numeric (showHex)
import System.Random (randomIO)
import Text.ProtocolBuffers (messageGetM, messagePutM, messageSize)
import Text.ProtocolBuffers.Get (Get, Result(..), getWord32be, runGet)
import qualified Control.Exception as E
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy.Char8 as L
import qualified Network.Riak.Types.Internal as T
import qualified Network.Socket.ByteString as B
import qualified Network.Socket.ByteString.Lazy as L
-- | Default client settings: server 127.0.0.1, port 8087, and an empty
-- client ID. 'connect' replaces an empty ID with a randomly generated one.
defaultClient :: Client
defaultClient = Client
  { clientID = L.empty
  , port     = "8087"
  , host     = "127.0.0.1"
  }
-- | Send a 'SetClientIDRequest' carrying the given ID, then wait for the
-- server's 'T.SetClientIDResponse'. Any response body is discarded.
setClientID :: Connection -> ClientID -> IO ()
setClientID conn i =
  sendRequest conn (SetClientIDRequest i)
    >> recvResponse_ conn T.SetClientIDResponse
-- | Generate a random client ID of the form @hs_<hex digits>@.
--
-- The hex digits are the magnitude of a random 'Int'.
makeClientID :: IO ClientID
makeClientID = do
  r <- randomIO :: IO Int
  -- For a fixed-width Int, @abs minBound == minBound@ is still negative, and
  -- 'showHex' rejects negative numbers with an error. Reinterpreting the
  -- result as a 'Word' gives the exact magnitude (2^63 for minBound) and
  -- leaves every other value unchanged.
  let magnitude = fromIntegral (abs r) :: Word
  return . L.append "hs_" . L.pack . showHex magnitude $ ""
-- | Ensure the client has a client ID. A non-empty ID is kept as is; an
-- empty one is replaced with a fresh value from 'makeClientID'.
addClientID :: Client -> IO Client
addClientID client
  | not (L.null (clientID client)) = return client
  | otherwise = fmap (\fresh -> client { clientID = fresh }) makeClientID
-- | Open a connection to the Riak server described by the client settings.
--
-- Steps:
--
-- * An empty client ID is first replaced via 'addClientID'.
-- * Host and port are resolved with 'getAddrInfo' (stream sockets,
--   'AI_ADDRCONFIG'); only the first address returned is tried.
-- * After connecting, the client ID is sent with 'setClientID'.
--
-- If anything fails after the socket has been created, the socket is closed
-- ('bracketOnError'). 'IOException's from that stage are rethrown through
-- 'onIOException'.
connect :: Client -> IO Connection
connect cli0 = do
  client@Client{..} <- addClientID cli0
  debug "connect" $ "server " ++ host ++ ":" ++ port ++ ", client ID " ++
    L.unpack clientID
  addrs <- getAddrInfo (Just streamHints) (Just host) (Just port)
  let target = case addrs of
        (a:_) -> a
        _     -> moduleError "connect" $
                   "could not look up server " ++ host ++ ":" ++ port
      openSock = socket (addrFamily target) (addrSocketType target)
                        (addrProtocol target)
      handshake sock = do
        Socket.connect sock (addrAddress target)
        buffer <- newIORef L.empty
        let conn = Connection sock client buffer
        setClientID conn clientID
        return conn
  onIOException "connect" $ bracketOnError openSock close handshake
  where
    streamHints = defaultHints { addrFlags      = [AI_ADDRCONFIG]
                               , addrSocketType = Stream
                               }
-- | Close the connection's socket and then clear its receive buffer.
--
-- An 'IOException' raised along the way is rethrown through
-- 'onIOException'.
disconnect :: Connection -> IO ()
disconnect Connection{..} = onIOException "disconnect" $ do
  debug "disconnect" banner
  close connSock
  writeIORef connBuffer L.empty
  where
    banner = concat
      [ "server ", host connClient, ":", port connClient
      , ", client ID ", L.unpack (clientID connClient)
      ]
-- | Number of bytes requested per socket read: 16 KiB.
-- It is polymorphic so it can be used directly as an Int or an Int64 count.
recvBufferSize :: Integral a => a
recvBufferSize = 16 * 1024
{-# INLINE recvBufferSize #-}
-- | Read exactly @n0@ bytes from the connection.
--
-- Bytes already held in the connection's receive buffer are used first. Any
-- shortfall is read from the socket. Each socket read asks for at least
-- 'recvBufferSize' bytes, so the last read may overshoot the request; the
-- surplus is stored back in the receive buffer for the next reader.
--
-- A non-positive @n0@ returns an empty string without touching the buffer
-- or the socket. If the socket returns zero bytes before @n0@ bytes have
-- arrived, this fails via 'moduleError' ("short read from network").
recvExactly :: Connection -> Int64 -> IO L.ByteString
recvExactly Connection{..} n0
  | n0 <= 0 = return L.empty
  | otherwise = do
      bs <- readIORef connBuffer
      let (h,t) = L.splitAt n0 bs
          len = L.length h
      if len == n0
        -- The buffer alone satisfies the request; keep the remainder.
        then writeIORef connBuffer t >> return h
        -- Otherwise the whole buffer was taken; fetch the rest from the
        -- socket. Chunks are accumulated in reverse order.
        else go (reverse (L.toChunks h)) (n0-len)
  where
    -- Upper bound on a single read size; keeps the later Int conversion
    -- from overflowing.
    maxInt = fromIntegral (maxBound :: Int)
    -- go acc n: acc holds the chunks gathered so far, most recent first;
    -- n is how many bytes are still needed. It is negative when the most
    -- recent read overshot.
    go (s:acc) n
      | n < 0 = do
          -- Keep only the needed prefix of the last chunk; its tail is the
          -- start of later data and goes back into the buffer.
          let (h,t) = B.splitAt (B.length s + fromIntegral n) s
          writeIORef connBuffer $! L.fromChunks [t]
          return $ L.fromChunks (reverse (h:acc))
    go acc n
      | n == 0 = do
          writeIORef connBuffer L.empty
          return $ L.fromChunks (reverse acc)
      | otherwise = do
          -- Read at least recvBufferSize bytes (reading ahead), but never
          -- more than maxInt.
          let n' = max recvBufferSize $ min n maxInt
          bs <- B.recv connSock (fromIntegral n')
          let len = B.length bs
          if len == 0
            then moduleError "recvExactly" "short read from network"
            else go (bs:acc) (n - fromIntegral len)
-- | Run a binary parser against the connection's input.
--
-- Parsing starts from any bytes left in the receive buffer. When the buffer
-- is empty, or the parser asks for more input, up to 'recvBufferSize' bytes
-- are read from the socket. If the socket yields no bytes, its receiving
-- side is shut down and the parser is told that input has ended.
--
-- Input left unconsumed by a successful parse is stored back in the buffer.
-- A parser failure, or a closed socket before any input arrives, fails via
-- 'moduleError'.
recvGet :: Connection -> Get a -> IO a
recvGet Connection{..} get = do
  pending <- readIORef connBuffer
  firstChunk <- if L.null pending then refill else return (Just pending)
  maybe (moduleError "recvGet" "socket closed") (feed . runGet get) firstChunk
  where
    refill = do
      chunk <- L.recv connSock recvBufferSize
      if L.null chunk
        then shutdown connSock ShutdownReceive >> return Nothing
        else return (Just chunk)
    feed (Failed _ err)      = moduleError "recvGet" err
    feed (Finished rest _ r) = writeIORef connBuffer rest >> return r
    feed (Partial k)         = refill >>= feed . k
-- | Read exactly @n@ bytes (via 'recvExactly') and run the parser over them.
--
-- If the parser asks for more input, it is told that input has ended. A
-- parser that fails, or still wants input after that, fails via
-- 'moduleError'.
recvGetN :: Connection -> Int64 -> Get a -> IO a
recvGetN conn n get = finish . runGet get =<< recvExactly conn n
  where
    finish (Finished _ _ r) = return r
    finish (Failed _ err)   = moduleError "recvGetN" err
    finish (Partial k)      = case k Nothing of
      Finished _ _ r -> return r
      Failed _ err   -> moduleError "recvGetN" err
      Partial _      -> moduleError "recvGetN" "parser wants more input!?"
-- | Serialise a request as one wire frame, in this order:
--
-- 1. a 32-bit big-endian length, which counts one byte for the tag plus the
--    encoded message size;
-- 2. the message tag;
-- 3. the encoded message.
putRequest :: (Request req) => req -> Put
putRequest req = frameLength >> putTag (messageTag req) >> messagePutM req
  where
    frameLength = putWord32be (fromIntegral (1 + messageSize req))
-- | Orphan instance (orphan warnings are disabled at the top of this module)
-- so that a decoded server 'ErrorResponse' can be thrown as an exception.
instance Exception ErrorResponse

-- | Throw a server 'ErrorResponse' in 'IO' using 'throwIO'.
throwError :: ErrorResponse -> IO a
throwError err = throwIO err
-- | Decode the rest of a response frame whose length field @len@ has already
-- been read. @len@ includes the tag byte.
--
-- * Expected tag: the remaining @len - 1@ bytes are decoded and returned.
-- * 'T.ErrorResponse' tag: the body is decoded as an 'ErrorResponse' and
--   thrown via 'throwError'.
-- * Any other tag: fails via 'moduleError'.
--
-- The third argument only fixes the result type and is ignored.
getResponse :: Response a => Connection -> Int64 -> a -> T.MessageTag -> IO a
getResponse conn len _ expected = do
  tag <- recvGet conn getTag
  let body = len - 1
      mismatch = "received unexpected response: expected " ++
                 show expected ++ ", received " ++ show tag
  case () of
    _ | tag == expected        -> recvGetN conn body messageGetM
      | tag == T.ErrorResponse -> recvGetN conn body messageGetM >>= throwError
      | otherwise              -> moduleError "getResponse" mismatch
-- | Send one request and wait for its typed response.
-- 'IOException's from the round trip are rethrown through 'onIOException'.
exchange :: Exchange req resp => Connection -> req -> IO resp
exchange conn@Connection{} req = do
  debug "exchange" (">>> " ++ showM req)
  let label = "exchange " ++ show (messageTag req)
  onIOException label (sendRequest conn req >> recvResponse conn)
-- | Send one request and wait for a response that may be empty, in which
-- case the result is 'Nothing'. 'IOException's are rethrown through
-- 'onIOException'.
exchangeMaybe :: Exchange req resp => Connection -> req -> IO (Maybe resp)
exchangeMaybe conn@Connection{} req = do
  debug "exchangeMaybe" (">>> " ++ showM req)
  let label = "exchangeMaybe " ++ show (messageTag req)
  onIOException label (sendRequest conn req >> recvMaybeResponse conn)
-- | Send one request, then check that the reply carries the request's
-- 'expectedResponse' tag, discarding any body. 'IOException's are rethrown
-- through 'onIOException'.
exchange_ :: Request req => Connection -> req -> IO ()
exchange_ conn req = do
  debug "exchange_" (">>> " ++ showM req)
  onIOException label $
    sendRequest conn req >> recvResponse_ conn (expectedResponse req)
  where
    label = "exchange_ " ++ show (messageTag req)
-- | Write an entire lazy ByteString to the socket.
--
-- NoPush is switched on for the duration of the write. NOTE(review):
-- presumably this coalesces the frames into fewer segments; see
-- Network.Riak.Connection.NoPush.
--
-- NoPush is switched back off even if the write throws, including on an
-- asynchronous exception such as a timeout. Previously an exception skipped
-- the reset and left NoPush enabled for later writes on the same socket.
sendAll :: Socket -> L.ByteString -> IO ()
sendAll sock bs =
  E.bracket_ (setNoPush sock True) (setNoPush sock False) (L.sendAll sock bs)
-- | Frame a single request with 'putRequest' and write it to the
-- connection's socket.
sendRequest :: (Request req) => Connection -> req -> IO ()
sendRequest Connection{connSock = sock} = \req ->
  sendAll sock (runPut (putRequest req))
-- | Receive one full response frame and decode it as the requested type.
-- Error responses are thrown (see 'getResponse').
recvResponse :: (Response a) => Connection -> IO a
recvResponse conn = debugRecv showM (receiveAs undefined)
  where
    -- The argument is passed as 'undefined', so it must never be forced. It
    -- only selects the response type whose 'messageTag' is expected.
    receiveAs :: Response b => b -> IO b
    receiveAs witness = do
      frameLen <- fmap fromIntegral (recvGet conn getWord32be)
      getResponse conn frameLen witness (messageTag witness)
-- | Receive one response frame, check that its tag is @expected@, and
-- discard the body.
--
-- The frame's length field includes the tag byte, so the body length passed
-- on to 'recvCorrectTag' is one less.
recvResponse_ :: Connection -> T.MessageTag -> IO ()
recvResponse_ conn expected = debugRecv show $
  recvGet conn getWord32be >>= \frameLen ->
    recvCorrectTag "recvResponse_" conn expected (fromIntegral frameLen - 1) ()
-- | Receive a response that the server may send with an empty body, in which
-- case 'Nothing' is returned.
--
-- A frame whose length field is 1 holds only the tag byte. Its tag is still
-- checked, so an 'ErrorResponse' tag is thrown. Any longer frame is decoded
-- with 'getResponse'.
recvMaybeResponse :: (Response a) => Connection -> IO (Maybe a)
recvMaybeResponse conn = debugRecv (maybe "Nothing" (("Just " ++) . showM)) $
                         go undefined
  where
    -- The argument is passed as 'undefined' and only selects the expected
    -- response type; it must not be forced.
    go :: Response b => b -> IO (Maybe b)
    go dummy = do
      len <- fromIntegral `fmap` recvGet conn getWord32be
      let tag = messageTag dummy
      if len == 1
        -- Only the tag byte is in this frame, so zero body bytes follow it.
        -- recvCorrectTag takes the body length, not the frame length.
        then recvCorrectTag "recvMaybeResponse" conn tag 0 Nothing
        else Just `fmap` getResponse conn len dummy tag

-- | Read a tag and check it against @expected@, consume the rest of the
-- frame, and return @v@.
--
-- @len@ is the number of body bytes after the tag, i.e. the frame's length
-- field minus one. Every branch that reads the body consumes exactly @len@
-- bytes, so the next frame starts in the right place:
--
-- * expected tag: the @len@ body bytes are read and discarded;
-- * 'T.ErrorResponse' tag: the @len@ body bytes are decoded as an
--   'ErrorResponse' and thrown;
-- * any other tag: fails via 'moduleError', naming @func@.
--
-- Previously the expected branch read only @len - 1@ bytes. That left one
-- byte of any non-empty body unread and desynchronised later reads.
recvCorrectTag :: String -> Connection -> T.MessageTag -> Int64 -> a -> IO a
recvCorrectTag func conn expected len v = do
  tag <- recvGet conn getTag
  if | tag == expected -> recvExactly conn len >> return v
     | tag == T.ErrorResponse -> throwError =<< recvGetN conn len messageGetM
     | otherwise -> moduleError func $
         "received unexpected response: expected " ++
         show expected ++ ", received " ++ show tag
-- | Wrap a receive action.
--
-- When built with DEBUG, the rendered result is passed to 'debug' under
-- "recv". Otherwise the action is returned unchanged.
debugRecv :: (a -> String) -> IO a -> IO a
#ifdef DEBUG
debugRecv render act = do
  result <- act
  debug "recv" ("<<< " ++ render result)
  return result
#else
debugRecv _ = id
{-# INLINE debugRecv #-}
#endif
-- | Send a batch of requests in a single write while receiving one response
-- per request with the supplied receiver. Responses are returned in order.
--
-- Sending and receiving run on separate threads via 'concurrently'. If
-- either side throws, or the caller is interrupted, the other side is
-- cancelled and the exception is rethrown here. Previously, with
-- 'async' + 'waitBoth', the surviving thread kept running against the
-- shared socket and receive buffer.
--
-- An 'IOException' from either side is rethrown through 'onIOException'.
pipe :: (Request req) =>
        (Connection -> IO resp) -> Connection -> [req] -> IO [resp]
pipe _ _ [] = return []
pipe receive conn@Connection{..} reqs@(firstReq:_) = do
  let numReqs = length reqs
      tag = show (messageTag firstReq)
  if Debug.level > 1
    then forM_ reqs $ \req -> debug "pipe" $ ">>> " ++ showM req
    else debug "pipe" $ ">>> " ++ show numReqs ++ "x " ++ tag
  (resps, _) <- onIOException ("pipe " ++ tag) $
    concurrently (replicateM numReqs (receive conn))
                 (sendAll connSock . runPut . mapM_ putRequest $ reqs)
  return resps
-- | Pipeline several requests and decode one typed response per request,
-- in order.
pipeline :: (Exchange req resp) => Connection -> [req] -> IO [resp]
pipeline conn reqs = pipe recvResponse conn reqs
-- | Pipeline several requests whose responses may be empty; an empty
-- response yields 'Nothing'.
pipelineMaybe :: (Exchange req resp) => Connection -> [req] -> IO [Maybe resp]
pipelineMaybe conn reqs = pipe recvMaybeResponse conn reqs
-- | Pipeline several requests, checking that each reply carries that
-- request's 'expectedResponse' tag and discarding any bodies.
--
-- As in 'pipe', the send and the receives run via 'concurrently'. A failure
-- on either side, or an interruption of the caller, cancels the other side
-- instead of leaving a stray thread reading the shared socket and buffer.
-- 'IOException's are rethrown through 'onIOException'.
pipeline_ :: (Request req) => Connection -> [req] -> IO ()
pipeline_ _ [] = return ()
pipeline_ conn@Connection{..} reqs@(firstReq:_) = do
  if Debug.level > 1
    then forM_ reqs $ \req -> debug "pipe" $ ">>> " ++ showM req
    else debug "pipe" $ ">>> " ++ show (length reqs) ++ "x " ++
                        show (messageTag firstReq)
  _ <- onIOException "pipeline_" $
    concurrently (forM_ reqs (recvResponse_ conn . expectedResponse))
                 (sendAll connSock . runPut . mapM_ putRequest $ reqs)
  return ()
-- | Run an action, converting any 'IOException' it throws into a module
-- error via 'moduleError', tagged with @func@.
--
-- The exception text is also passed to 'debug' first. Exceptions of other
-- types pass through untouched.
onIOException :: String -> IO a -> IO a
onIOException func = E.handle rethrow
  where
    rethrow :: IOException -> IO b
    rethrow e = do
      let msg = show e
      debug func ("caught IO exception: " ++ msg)
      moduleError func msg
-- | Raise a network error attributed to this module. The first argument
-- names the failing function; the second is the message.
moduleError :: String -> String -> a
moduleError func msg = netError "Network.Riak.Connection.Internal" func msg