module Network.Hadoop.Rpc
( Connection(..)
, Protocol(..)
, User
, Method
, RawRequest
, RawResponse
, initConnectionV9
, invokeAsync
, invoke
) where
import Control.Applicative ((<$>))
import Control.Concurrent (ThreadId, forkIO, newEmptyMVar, putMVar, takeMVar)
import Control.Concurrent.STM
import Control.Exception (SomeException(..), throwIO, handle)
import Control.Monad (forever, when)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as L
import qualified Data.HashMap.Strict as H
import Data.Hashable (Hashable)
import Data.Maybe (fromMaybe, isNothing)
import Data.Monoid ((<>))
import Data.Monoid (mempty)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.UUID as UUID
import System.Random (randomIO)
import Data.ProtocolBuffers
import Data.ProtocolBuffers.Orphans ()
import Data.Serialize.Get
import Data.Serialize.Put
import qualified Data.Hadoop.Protobuf.Headers as P
import Data.Hadoop.Types
import qualified Network.Hadoop.Stream as S
import Network.Socket (Socket)
data Connection = Connection
{ cnVersion :: !Int
, cnConfig :: !HadoopConfig
, cnProtocol :: !Protocol
, invokeRaw :: !(Method -> RawRequest -> (RawResponse -> IO ()) -> IO ())
}
data Protocol = Protocol
{ prName :: !Text
, prVersion :: !Int
} deriving (Eq, Ord, Show)
type Method = Text
type RawRequest = ByteString
type RawResponse = Either SomeException ByteString
type CallId = Int
newtype ClientId = ClientId { unClientId :: ByteString }
data ConnectionState = ConnectionState
{ csStream :: !S.Stream
, csClientId :: !ClientId
, csCallId :: !(TVar CallId)
, csRecvCallbacks :: !(TVar (H.HashMap CallId (RawResponse -> IO ())))
, csSendQueue :: !(TQueue (Method, RawRequest, RawResponse -> IO ()))
, csFatalError :: !(TVar (Maybe SomeException))
}
initConnectionV9 :: HadoopConfig -> Protocol -> Socket -> IO Connection
initConnectionV9 config@HadoopConfig{..} protocol sock = do
csStream <- S.mkSocketStream sock
csClientId <- mkClientId
S.runPut csStream $ do
putByteString "hrpc"
putWord8 9
putWord8 0
putWord8 0
putMessage $ delimitedBytesL (rpcRequestHeaderProto csClientId (3))
<> delimitedBytesL (contextProto protocol hcUser)
csCallId <- newTVarIO 0
csRecvCallbacks <- newTVarIO H.empty
csSendQueue <- newTQueueIO
csFatalError <- newTVarIO Nothing
let cs = ConnectionState{..}
_ <- forkSend cs
_ <- forkRecv cs
return (Connection 7 config protocol (enqueue cs))
where
enqueue :: ConnectionState
-> Method
-> RawRequest
-> (RawResponse -> IO ())
-> IO ()
enqueue ConnectionState{..} method bs k = do
merr <- atomically $ do
merr <- readTVar csFatalError
when (isNothing merr) $ writeTQueue csSendQueue (method, bs, k)
return merr
case merr of
Just err -> throwIO err
Nothing -> return ()
forkSend :: ConnectionState -> IO ThreadId
forkSend cs@ConnectionState{..} = forkIO $ handle (onSocketError cs) $ forever $ do
bs <- atomically $ do
(method, requestBytes, k) <- readTQueue csSendQueue
callId <- readTVar csCallId
modifyTVar' csCallId succ
modifyTVar' csRecvCallbacks (H.insert callId k)
return $ delimitedBytesL (rpcRequestHeaderProto csClientId callId)
<> delimitedBytesL (requestHeaderProto protocol method)
<> L.fromStrict requestBytes
S.runPut csStream (putMessage bs)
forkRecv :: ConnectionState -> IO ThreadId
forkRecv cs@ConnectionState{..} = forkIO $ handle (onSocketError cs) $ forever $ do
mget <- S.maybeGet csStream $ do
n <- fromIntegral <$> getWord32be
isolate n $ do
hdr <- decodeLengthPrefixedMessage
msg <- case getField (P.rspStatus hdr) of
P.Success -> Right <$> getRemaining
_ -> return . Left . SomeException $ rspError hdr
return (hdr, msg)
case mget of
Nothing -> throwIO ConnectionClosed
Just (hdr, msg) -> do
onResponse <- fromMaybe (return $ return ())
<$> lookupDelete csRecvCallbacks (fromIntegral $ getField $ P.rspCallId hdr)
onResponse msg
onSocketError :: ConnectionState -> SomeException -> IO ()
onSocketError ConnectionState{..} ex = do
ks <- atomically $ do
writeTVar csFatalError (Just ex)
sks <- map (\(_,_,k) -> k) <$> unfoldM (tryReadTQueue csSendQueue)
rks <- H.elems <$> readTVar csRecvCallbacks
return (sks ++ rks)
mapM_ (\k -> handle ignore $ k $ Left ex) ks
ignore :: SomeException -> IO ()
ignore _ = return ()
mkClientId :: IO ClientId
mkClientId = ClientId . L.toStrict . UUID.toByteString <$> randomIO
unfoldM :: Monad m => m (Maybe a) -> m [a]
unfoldM f = go []
where
go xs = do
m <- f
case m of
Nothing -> return xs
Just x -> go (xs ++ [x])
lookupDelete :: (Eq k, Hashable k) => TVar (H.HashMap k v) -> k -> IO (Maybe v)
lookupDelete var k = atomically $ do
hm <- readTVar var
writeTVar var (H.delete k hm)
return (H.lookup k hm)
contextProto :: Protocol -> User -> P.IpcConnectionContext
contextProto protocol user = P.IpcConnectionContext
{ P.ctxProtocol = putField (Just (prName protocol))
, P.ctxUserInfo = putField (Just P.UserInformation
{ P.effectiveUser = putField (Just user)
, P.realUser = mempty
})
}
rpcRequestHeaderProto :: ClientId -> CallId -> P.RpcRequestHeader
rpcRequestHeaderProto clientId callId = P.RpcRequestHeader
{ P.reqKind = putField (Just P.ProtocolBuffer)
, P.reqOp = putField (Just P.FinalPacket)
, P.reqCallId = putField (fromIntegral callId)
, P.reqClientId = putField (unClientId clientId)
, P.reqRetryCount = putField (Just (1))
}
requestHeaderProto :: Protocol -> Method -> P.RequestHeader
requestHeaderProto protocol method = P.RequestHeader
{ P.reqMethodName = putField method
, P.reqProtocolName = putField (prName protocol)
, P.reqProtocolVersion = putField (fromIntegral (prVersion protocol))
}
rspError :: P.RpcResponseHeader -> RemoteError
rspError rsp = RemoteError (fromMaybe "unknown error" $ getField $ P.rspExceptionClassName rsp)
(fromMaybe "unknown error" $ getField $ P.rspErrorMsg rsp)
putMessage :: L.ByteString -> Put
putMessage body = do
putWord32be (fromIntegral (L.length body))
putLazyByteString body
getRemaining :: Get ByteString
getRemaining = do
n <- remaining
getByteString n
invoke :: (Decode b, Encode a) => Connection -> Text -> a -> IO b
invoke connection method arg = do
mv <- newEmptyMVar
invokeAsync connection method arg (putMVar mv)
e <- takeMVar mv
case e of
Left ex -> throwIO ex
Right x -> return x
invokeAsync :: (Decode b, Encode a) => Connection -> Text -> a -> (Either SomeException b -> IO ()) -> IO ()
invokeAsync Connection{..} method arg k = invokeRaw method (delimitedBytes arg) k'
where
k' (Left err) = k (Left err)
k' (Right bs) = k (fromDelimitedBytes bs)
delimitedBytes :: Encode a => a -> ByteString
delimitedBytes = runPut . encodeLengthPrefixedMessage
delimitedBytesL :: Encode a => a -> L.ByteString
delimitedBytesL = L.fromStrict . delimitedBytes
fromDelimitedBytes :: Decode a => ByteString -> Either SomeException a
fromDelimitedBytes bs = case runGetState decodeLengthPrefixedMessage bs 0 of
Left err -> decodeError (T.pack err)
Right (x, "") -> Right x
Right (_, _) -> decodeError "decoded response but did not consume enough bytes"
where
decodeError = Left . SomeException . DecodeError