-- | Client side of Hadoop's protocol-buffer based RPC.
--
-- 'initConnectionV7' performs the version-7 connection handshake over an
-- already-connected socket and returns a 'Connection'. Calls can then be
-- issued either blocking ('invoke') or with a completion callback
-- ('invokeAsync').
module Network.Hadoop.Rpc
( Connection(..)
, Protocol(..)
, User
, Method
, RawRequest
, RawResponse
, initConnectionV7
, invokeAsync
, invoke
) where
import Control.Applicative ((<$>), (<*>))
import Control.Concurrent (ThreadId, forkIO, newEmptyMVar, putMVar, takeMVar)
import Control.Concurrent.STM
import Control.Exception (SomeException(..), throwIO, handle)
import Control.Monad (forever, when)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.HashMap.Strict as H
import Data.Hashable (Hashable)
import Data.Maybe (fromMaybe, isNothing)
import Data.Monoid (mempty)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.ProtocolBuffers
import Data.ProtocolBuffers.Orphans ()
import Data.Serialize.Get
import Data.Serialize.Put
import Data.Hadoop.Protobuf.Headers
import Data.Hadoop.Types
import qualified Network.Hadoop.Stream as S
import Network.Socket (Socket)
-- | Handle for issuing RPC calls over an established connection.
data Connection = Connection
    { cnVersion  :: !Int
      -- ^ RPC wire-protocol version in use ('initConnectionV7' sets 7).
    , cnConfig   :: !HadoopConfig
      -- ^ Configuration the connection was opened with.
    , cnProtocol :: !Protocol
      -- ^ Remote protocol (name and version) that requests are addressed to.
    , invokeRaw  :: !(Method -> RawRequest -> (RawResponse -> IO ()) -> IO ())
      -- ^ Submit an already-encoded request for a method; the callback is
      -- later given either the raw response bytes or the failure.
    }
-- | Identifies the remote RPC interface being called.
data Protocol = Protocol
    { prName    :: !Text
      -- ^ Protocol name; sent in the connection context and in every request.
    , prVersion :: !Int
      -- ^ Protocol version; sent in every request.
    }
    deriving (Eq, Ord, Show)
-- | Per-connection request identifier, assigned sequentially by the sender
-- and echoed back in the matching response header.
type CallId = Int

-- | Name of the remote method to invoke.
type Method = Text

-- | Encoded request message, before it is framed for the wire.
type RawRequest = ByteString

-- | Outcome of a call: the encoded response message, or why it failed.
type RawResponse = Either SomeException ByteString
-- | Mutable state shared by the caller-facing enqueue function and the
-- background send and receive threads of one connection.
data ConnectionState = ConnectionState
    { csStream        :: !S.Stream
      -- ^ Socket stream used for both writing requests and reading responses.
    , csCallId        :: !(TVar CallId)
      -- ^ Next call id to hand out; the send thread increments it per request.
    , csRecvCallbacks :: !(TVar (H.HashMap CallId (RawResponse -> IO ())))
      -- ^ Callbacks of requests already sent, keyed by call id, awaiting replies.
    , csSendQueue     :: !(TQueue (Method, RawRequest, RawResponse -> IO ()))
      -- ^ Requests waiting to be written by the send thread.
    , csFatalError    :: !(TVar (Maybe SomeException))
      -- ^ Set once the connection has failed; new requests are then refused.
    }
-- | Start an RPC session over an already-connected socket using the
-- version-7 wire protocol.
--
-- Writes the connection preamble: the bytes @hrpc@, the version byte 7,
-- the header bytes 80 and 0, and then the length-prefixed
-- 'IpcConnectionContext' naming the protocol and effective user. It then
-- forks two threads:
--
-- * a sender, which assigns call ids and writes queued requests;
-- * a receiver, which reads responses and dispatches each to the callback
--   registered under its call id.
--
-- If either thread fails, the first exception becomes the connection's
-- fatal error. Every callback still queued or awaiting a reply is completed
-- with it (at most once per callback), and later calls to 'invokeRaw'
-- throw it.
initConnectionV7 :: HadoopConfig -> Protocol -> Socket -> IO Connection
initConnectionV7 config@HadoopConfig{..} protocol sock = do
    csStream <- S.mkSocketStream sock
    S.runPut csStream $ do
        putByteString "hrpc"
        putWord8 7   -- RPC wire version
        putWord8 80  -- NOTE(review): presumably the auth method (SIMPLE) - confirm
        putWord8 0   -- NOTE(review): presumably the serialization kind (protobuf) - confirm
        let bs = runPut (encodeMessage context)
        putWord32be (fromIntegral (B.length bs))
        putByteString bs
    csCallId        <- newTVarIO 0
    csRecvCallbacks <- newTVarIO H.empty
    csSendQueue     <- newTQueueIO
    csFatalError    <- newTVarIO Nothing
    let cs = ConnectionState{..}
    _ <- forkSend cs
    _ <- forkRecv cs
    return (Connection 7 config protocol (enqueue cs))
  where
    -- Queue a request for the sender thread, or throw the fatal error if the
    -- connection has already failed. The check and the enqueue share one STM
    -- transaction, so a request cannot slip in after onSocketError drained
    -- the queue.
    enqueue :: ConnectionState
            -> Method
            -> RawRequest
            -> (RawResponse -> IO ())
            -> IO ()
    enqueue ConnectionState{..} method bs k = do
        merr <- atomically $ do
            current <- readTVar csFatalError
            when (isNothing current) $ writeTQueue csSendQueue (method, bs, k)
            return current
        case merr of
            Just err -> throwIO err
            Nothing  -> return ()

    -- Sender loop. Assigning the call id and registering the callback happen
    -- in the same transaction that dequeues the request, so no dequeued
    -- callback is ever missing from csRecvCallbacks.
    forkSend :: ConnectionState -> IO ThreadId
    forkSend cs@ConnectionState{..} = forkIO $ handle (onSocketError cs) $ forever $ do
        bs <- atomically $ do
            (method, requestBytes, k) <- readTQueue csSendQueue
            callId <- readTVar csCallId
            modifyTVar' csCallId succ
            modifyTVar' csRecvCallbacks (H.insert callId k)
            return $ runPut $ encodeLengthPrefixedMessage (requestHeaderProto callId)
                           >> encodeLengthPrefixedMessage (requestProto method requestBytes)
        S.runPut csStream $ do
            putWord32be (fromIntegral (B.length bs))
            putByteString bs

    -- Receiver loop: read a response header, then its body, then hand the
    -- result to the callback registered under the header's call id.
    forkRecv :: ConnectionState -> IO ThreadId
    forkRecv cs@ConnectionState{..} = forkIO $ handle (onSocketError cs) $ forever $ do
        hdr <- S.maybeGet csStream decodeLengthPrefixedMessage
        case hdr of
            Nothing -> throwIO ConnectionClosed
            Just rspHdr -> do
                -- Read the whole body before claiming the callback. If the
                -- read fails, the callback is still registered and
                -- onSocketError completes it instead of it being lost.
                rsp <- case getField (rspStatus rspHdr) of
                    Success -> Right <$> S.runGet csStream getResponse
                    _       -> Left . SomeException <$> S.runGet csStream getError
                let callId = fromIntegral (getField (rspCallId rspHdr))
                onResponse <- fromMaybe (const (return ()))
                    <$> lookupDelete csRecvCallbacks callId
                -- NOTE(review): the callback runs on this thread; if it throws,
                -- the exception reaches onSocketError and fails the connection.
                onResponse rsp

    -- Mark the connection failed and complete every outstanding callback
    -- with the error. Both worker threads may call this for the same broken
    -- socket, so it must be safe to run more than once.
    onSocketError :: ConnectionState -> SomeException -> IO ()
    onSocketError ConnectionState{..} ex = do
        ks <- atomically $ do
            -- Keep the first failure; it is the root cause reported to
            -- later callers of enqueue.
            prev <- readTVar csFatalError
            when (isNothing prev) $ writeTVar csFatalError (Just ex)
            sks <- map (\(_, _, k) -> k) <$> unfoldM (tryReadTQueue csSendQueue)
            rks <- H.elems <$> readTVar csRecvCallbacks
            -- Take ownership of the pending callbacks. Neither a second call
            -- of this handler nor a late response can then invoke them again.
            writeTVar csRecvCallbacks H.empty
            return (sks ++ rks)
        mapM_ (\k -> handle ignore $ k $ Left ex) ks

    -- Callbacks are run best-effort during teardown.
    ignore :: SomeException -> IO ()
    ignore _ = return ()

    context = IpcConnectionContext
        { ctxProtocol = putField (Just (prName protocol))
        , ctxUserInfo = putField (Just UserInformation
            { effectiveUser = putField (Just hcUser)
            , realUser      = mempty
            })
        }

    requestHeaderProto callId = RpcRequestHeader
        { reqKind   = putField (Just ProtocolBuffer)
        , reqOp     = putField (Just FinalPacket)
        , reqCallId = putField (fromIntegral callId)
        }

    requestProto method bytes = RpcRequest
        { reqMethodName      = putField method
        , reqBytes           = putField (Just bytes)
        , reqProtocolName    = putField (prName protocol)
        , reqProtocolVersion = putField (fromIntegral (prVersion protocol))
        }
-- | Run an action repeatedly until it yields 'Nothing', collecting the
-- 'Just' results in the order they were produced.
--
-- Results are accumulated in reverse and flipped once at the end. This keeps
-- the loop linear instead of quadratic, as repeated @xs ++ [x]@ would be.
unfoldM :: Monad m => m (Maybe a) -> m [a]
unfoldM f = go []
  where
    go acc = do
        m <- f
        case m of
            Nothing -> return (reverse acc)
            Just x  -> go (x : acc)
-- | Atomically remove a key from a shared map and return the value it was
-- bound to, if any.
lookupDelete :: (Eq k, Hashable k) => TVar (H.HashMap k v) -> k -> IO (Maybe v)
lookupDelete var k = atomically $ do
    table <- readTVar var
    let found = H.lookup k table
    writeTVar var (H.delete k table)
    return found
-- | Parse a successful response body: a 32-bit big-endian length followed
-- by that many payload bytes.
getResponse :: Get ByteString
getResponse = getWord32be >>= getByteString . fromIntegral
-- | Parse an error response body: two strings, each a 32-bit big-endian
-- length followed by that many UTF-8 bytes. They become the two fields of
-- 'RemoteError', in order.
--
-- The bytes come from the server, so invalid UTF-8 is replaced with U+FFFD.
-- The throwing 'T.decodeUtf8' would instead raise a 'UnicodeException'
-- lazily, whenever the text is first inspected. That would hide the actual
-- remote error.
getError :: Get RemoteError
getError = RemoteError <$> getText <*> getText
  where
    getText = do
        n <- fromIntegral <$> getWord32be
        T.decodeUtf8With replaceInvalid <$> getByteString n
    -- Same behaviour as Data.Text.Encoding.Error.lenientDecode.
    replaceInvalid _ _ = Just '\xfffd'
-- | Call a remote method and block until its response arrives.
--
-- Any failure delivered for the call is rethrown in the calling thread:
-- a remote error, a response that fails to decode, or a connection failure.
invoke :: (Decode b, Encode a) => Connection -> Text -> a -> IO b
invoke connection method arg = do
    reply <- newEmptyMVar
    invokeAsync connection method arg (putMVar reply)
    either throwIO return =<< takeMVar reply
-- | Call a remote method without blocking for the reply.
--
-- The argument is encoded as a protocol-buffer message. The callback later
-- receives the decoded response, or the failure: transport/remote errors
-- unchanged, or a decode error.
invokeAsync :: (Decode b, Encode a) => Connection -> Text -> a -> (Either SomeException b -> IO ()) -> IO ()
invokeAsync Connection{..} method arg k =
    invokeRaw method (encodeBytes arg) (either (k . Left) (k . decodeBytes))
-- | Serialise a protocol-buffer message to a strict 'ByteString'.
encodeBytes :: Encode a => a -> ByteString
encodeBytes msg = runPut (encodeMessage msg)
-- | Decode a protocol-buffer message that must occupy the whole input.
--
-- Fails with a 'DecodeError' if parsing fails or if bytes remain afterwards.
decodeBytes :: Decode a => ByteString -> Either SomeException a
decodeBytes bs =
    case runGetState decodeMessage bs 0 of
        Left err -> failWith (T.pack err)
        Right (x, rest)
            | B.null rest -> Right x
            | otherwise   -> failWith "decoded response but did not consume enough bytes"
  where
    failWith = Left . SomeException . DecodeError