module Network.Stomp (
Command (..),
ClientCommand (..),
ServerCommand (..),
Frame (..),
Connection,
StompUri,
Host,
Destination,
MessageId,
Transaction,
Subscription,
Version,
StompException (..),
connect,
stomp,
connect',
stomp',
disconnect,
send,
send',
subscribe,
unsubscribe,
ack,
nack,
begin,
commit,
abort,
startConsumer,
receiveFrame,
setExcpHandler,
startSendBeat,
startRecvBeat,
beat,
sendTimeout,
recvTimeout,
lastSend,
lastRecv,
versions,
session,
server
)
where
import qualified Data.ByteString.UTF8 as BU
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.List (intercalate)
import Data.Time (UTCTime, getCurrentTime, diffUTCTime)
import Data.Typeable
import Data.Maybe
import Data.Char (toLower)
import qualified Data.Binary.Builder as B
import Data.Binary.Put
import Network.BSD
import Network.URI
import qualified Network.Socket as N
import qualified Control.Exception as E
import Control.Concurrent
import Control.Monad
import System.IO
import qualified Network.Socket.ByteString.Lazy as L
-- | A single STOMP frame: a command, its headers and an optional body.
data Frame = Frame
  { command :: Command        -- ^ frame command (client or server side)
  , headers :: [Header]       -- ^ header name/value pairs, in wire order
  , body    :: BL.ByteString  -- ^ raw frame body (may be empty)
  } deriving Show
-- | Commands a client may send to the broker.  The derived 'Show'
-- instance is what gets written on the wire as the command line.
data ClientCommand
  = SEND
  | SUBSCRIBE
  | UNSUBSCRIBE
  | BEGIN
  | COMMIT
  | ABORT
  | ACK
  | NACK
  | DISCONNECT
  | CONNECT
  | STOMP
  deriving (Show, Read, Eq)
-- | Commands the broker may send to a client.  The derived 'Read'
-- instance is used to parse the command line of incoming frames.
data ServerCommand
  = CONNECTED
  | MESSAGE
  | RECEIPT
  | ERROR
  deriving (Show, Read, Eq)
-- | Any frame command, tagged by the side that issues it.
data Command
  = CC ClientCommand  -- ^ command sent by the client
  | SC ServerCommand  -- ^ command sent by the broker
  deriving (Show, Read, Eq)
-- | A frame header as a @(name, value)@ pair.
type Header = (String, String)
-- | State of one broker connection.  Field order matters: the
-- constructor is also applied positionally elsewhere in this module.
data Connection = Connection
  { handle      :: Handle        -- ^ handle wrapping the broker socket
  , versions    :: [Version]     -- ^ parsed from the broker's @version@ header
  , session     :: String        -- ^ value of the broker's @session@ header
  , server      :: String        -- ^ value of the broker's @server@ header
  , listener    :: MVar ThreadId -- ^ consumer thread, when one is running
  , sendBeat    :: MVar ThreadId -- ^ outgoing heart-beat thread, when running
  , recvBeat    :: MVar ThreadId -- ^ incoming heart-beat watchdog, when running
  , sendTimeout :: Int           -- ^ outgoing heart-beat interval (milliseconds)
  , recvTimeout :: Int           -- ^ incoming heart-beat tolerance (milliseconds)
  , lastSend    :: MVar UTCTime  -- ^ time of the most recent write
  , lastRecv    :: MVar UTCTime  -- ^ time of the most recent read
  , sockLock    :: MVar ()       -- ^ serialises writes to the handle
  , closed      :: MVar (Maybe StompException)
    -- ^ 'Nothing' while usable; otherwise the error raised on further sends
  , disconReq   :: MVar String   -- ^ receipt id a pending disconnect waits for
  , disconResp  :: MVar ()       -- ^ filled when a worker thread terminates
  , excpHandle  :: MVar (Maybe (StompException -> IO ()))
    -- ^ optional user callback invoked on asynchronous errors
  }
-- | Errors raised by this module.
data StompException
  = ConnectionError String     -- ^ connection refused, closed or unusable
  | InvalidUri String          -- ^ the given URI is not a usable stomp URI
  | InvalidFrame String        -- ^ a received frame could not be parsed
  | BrokerError String         -- ^ the broker misbehaved (e.g. heart-beat timeout)
  | StompIOError E.IOException -- ^ wrapped low-level I/O failure
  deriving (Typeable, Show, Eq)

instance E.Exception StompException
-- | Broker URI; the scheme must be @stomp@ (checked case-insensitively).
type StompUri = String
-- | Broker host as handed to the socket layer.
type Host = String
-- | Value of the @destination@ header.
type Destination = String
-- | Value of the @message-id@ header used by 'ack' and 'nack'.
type MessageId = String
-- | Value of the @transaction@ header used by 'begin', 'commit' and 'abort'.
type Transaction = String
-- | Subscription identifier (sent as @id@ or @subscription@ header).
type Subscription = String
-- | Protocol version as @(major, minor)@, rendered as @major.minor@.
type Version = (Int,Int)
-- | Open a connection described by a stomp URI using a @CONNECT@ frame.
-- Headers derived from the URI (login, passcode) are appended after the
-- caller's headers.
connect :: StompUri -> [Version] -> [Header] -> IO Connection
connect uri vs hs =
  processUri uri >>= \(host, port, uriHdrs) ->
    connect' host port vs (hs ++ uriHdrs)
-- | Like 'connect', but performs the handshake with a @STOMP@ frame and
-- adds no @accept-version@ header of its own.
stomp :: StompUri -> [Header] -> IO Connection
stomp uri hs =
  processUri uri >>= \(host, port, uriHdrs) ->
    stomp' host port (hs ++ uriHdrs)
-- | Connect to an explicit host and port with a @CONNECT@ frame.  An
-- @accept-version@ header is appended: @1.0@ when no versions are given,
-- otherwise the comma-separated list of @major.minor@ pairs.
connect' :: Host -> PortNumber -> [Version] -> [Header] -> IO Connection
connect' h p vs hs =
  mkConnection CONNECT h p (hs ++ [("accept-version", accepted)])
  where
    accepted
      | null vs   = "1.0"
      | otherwise =
          intercalate "," [show major ++ "." ++ show minor | (major, minor) <- vs]
-- | Connect to an explicit host and port with a @STOMP@ frame, sending
-- exactly the given headers.
stomp' :: Host -> PortNumber -> [Header] -> IO Connection
stomp' host port hs = mkConnection STOMP host port hs
-- | Send a @DISCONNECT@ frame and tear the connection down.
--
-- If @hs@ contains a @receipt@ header, the call waits until the consumer
-- thread has seen the matching receipt (or any worker thread has ended)
-- before stopping the worker threads.  Without a receipt the workers are
-- stopped first and the frame is sent without waiting.
--
-- Teardown (stopping workers, closing the handle, marking the connection
-- closed) runs even when sending the frame throws, so a failed disconnect
-- no longer leaks the socket handle.  The exception is still propagated.
disconnect :: Connection -> [Header] -> IO ()
disconnect con hs = farewell `E.finally` teardown
  where
    frame = Frame (CC DISCONNECT) hs BL.empty

    farewell = case lookup "receipt" hs of
      Nothing -> do
        stop con
        sendFrame con frame
      Just r -> do
        putMVar (disconReq con) r
        sendFrame con frame
        readMVar (disconResp con)

    -- 'stop' only kills threads still registered, so calling it again
    -- after the receipt-less path is harmless.
    teardown = do
      stop con
      hClose (handle con) `E.finally` markClosed

    -- Keep an earlier failure reason if one was already recorded.
    markClosed = modifyMVar_ (closed con) $ \x ->
      return (Just $ fromMaybe (ConnectionError "Connection closed") x)
-- | Send one message body to a destination with the given extra headers.
send :: Connection -> Destination -> [Header] -> BL.ByteString -> IO ()
send con dest hs body = sendFrame con frame
  where
    frame = mkSendFrame (dest, hs, body)
-- | Send several messages in one write; each triple is
-- @(destination, headers, body)@.
send' :: Connection -> [(Destination, [Header], BL.ByteString)] -> IO ()
send' con = sendFrames con . map mkSendFrame
-- | Build a @SEND@ frame.  The @destination@ header is appended after the
-- caller's headers; a @content-length@ header is appended only when the
-- caller did not supply one.
mkSendFrame :: (Destination, [Header], BL.ByteString) -> Frame
mkSendFrame (dest, hs, body) = Frame (CC SEND) (hs ++ destHdr ++ lenHdr) body
  where
    destHdr = [("destination", dest)]
    lenHdr  = case lookup "content-length" hs of
      Just _  -> []
      Nothing -> [("content-length", show (BL.length body))]
-- | Send a @SUBSCRIBE@ frame for a destination under the given
-- subscription id.
subscribe :: Connection -> Destination -> Subscription -> [Header] -> IO ()
subscribe con dest id hs =
  let extra = [("destination", dest), ("id", id)]
  in sendFrame con (Frame (CC SUBSCRIBE) (hs ++ extra) BL.empty)
-- | Send an @UNSUBSCRIBE@ frame for the given subscription id.
unsubscribe :: Connection -> Subscription -> [Header] -> IO ()
unsubscribe con id hs =
  let extra = [("id", id)]
  in sendFrame con (Frame (CC UNSUBSCRIBE) (hs ++ extra) BL.empty)
-- | Send an @ACK@ frame for a message received on a subscription.
ack :: Connection -> Subscription -> MessageId -> [Header] -> IO ()
ack con id msgid hs =
  let extra = [("subscription", id), ("message-id", msgid)]
  in sendFrame con (Frame (CC ACK) (hs ++ extra) BL.empty)
-- | Send a @NACK@ frame for a message received on a subscription.
nack :: Connection -> Subscription -> MessageId -> [Header] -> IO ()
nack con id msgid hs =
  let extra = [("subscription", id), ("message-id", msgid)]
  in sendFrame con (Frame (CC NACK) (hs ++ extra) BL.empty)
-- | Send a @BEGIN@ frame for the named transaction.
begin :: Connection -> Transaction -> [Header] -> IO ()
begin con tid hs =
  let extra = [("transaction", tid)]
  in sendFrame con (Frame (CC BEGIN) (hs ++ extra) BL.empty)
-- | Send a @COMMIT@ frame for the named transaction.
commit :: Connection -> Transaction -> [Header] -> IO ()
commit con tid hs =
  let extra = [("transaction", tid)]
  in sendFrame con (Frame (CC COMMIT) (hs ++ extra) BL.empty)
-- | Send an @ABORT@ frame for the named transaction.
abort :: Connection -> Transaction -> [Header] -> IO ()
abort con tid hs =
  let extra = [("transaction", tid)]
  in sendFrame con (Frame (CC ABORT) (hs ++ extra) BL.empty)
-- | Open a socket to the broker and perform the STOMP handshake.
--
-- The given client command (@CONNECT@ or @STOMP@) is sent with the given
-- headers; the broker must answer with @CONNECTED@.  Any other reply
-- raises 'ConnectionError' carrying the reply body.  On success the
-- returned connection has its heart-beat intervals, protocol versions,
-- session and server fields filled in from the reply headers.
--
-- If any step of the handshake throws, the freshly opened handle is
-- closed before the exception propagates, so a rejected connection does
-- not leak a socket.
mkConnection :: ClientCommand -> Host -> PortNumber -> [Header] -> IO Connection
mkConnection cmd host port hs = do
  con <- newConn host port hs
  handshake con `E.onException` hClose (handle con)
  where
    handshake con = do
      sendFrame con (Frame (CC cmd) hs BL.empty)
      -- Single-constructor pattern: this bind cannot fail, and the reply
      -- command no longer shadows the outgoing one.
      Frame reply hs' body <- receiveFrame con
      when (reply /= SC CONNECTED) $
        E.throwIO $ ConnectionError (BL.unpack body)
      let (sendMs, recvMs) = getBeats hs hs'
      return con { recvTimeout = recvMs
                 , sendTimeout = sendMs
                 , versions    = maybe [(1,0)] parseVer $ lookup "version" hs'
                 , session     = fromMaybe [] $ lookup "session" hs'
                 , server      = fromMaybe [] $ lookup "server" hs'
                 }
-- | Parse a comma-separated list of @major.minor@ versions, e.g.
-- @"1.0,1.1"@ gives @[(1,0),(1,1)]@.
--
-- Entries that are not of the form @int.int@ (including empty entries
-- produced by an empty string or a trailing comma) are skipped instead
-- of crashing with a @read@ or @tail@ error.
parseVer :: String -> [Version]
parseVer = mapMaybe mkV . splitOnComma
  where
    -- Split on every comma, keeping empty segments.
    splitOnComma s = case break (== ',') s of
      (chunk, [])       -> [chunk]
      (chunk, _ : rest) -> chunk : splitOnComma rest

    mkV s = case break (== '.') s of
      (major, '.' : minor) -> liftM2 (,) (readInt major) (readInt minor)
      _                    -> Nothing

    -- Total replacement for 'read': same acceptance rules, no exception.
    readInt :: String -> Maybe Int
    readInt t = case [n | (n, rest) <- reads t, ("", "") <- lex rest] of
      [n] -> Just n
      _   -> Nothing
-- | Open a block-buffered read/write handle on a TCP connection to the
-- given host and port.
--
-- The host is resolved with 'N.getAddrInfo', so both host names and
-- dotted-quad addresses are accepted (the previous @inet_addr@ call
-- accepted numeric addresses only).  Resolution is restricted to IPv4
-- stream sockets to match the earlier behaviour.  If connecting or
-- wrapping the socket fails, the socket is closed before the exception
-- propagates.
openSocket :: String -> PortNumber -> IO Handle
openSocket host port = do
  let hints = N.defaultHints { N.addrFamily     = N.AF_INET
                             , N.addrSocketType = N.Stream
                             }
  addrs <- N.getAddrInfo (Just hints) (Just host) (Just (show port))
  case addrs of
    [] -> E.throwIO $ ConnectionError ("Cannot resolve host: " ++ host)
    (ai : _) ->
      E.bracketOnError
        (N.socket (N.addrFamily ai) (N.addrSocketType ai) (N.addrProtocol ai))
        N.close
        (\sock -> do
            N.connect sock (N.addrAddress ai)
            h <- N.socketToHandle sock ReadWriteMode
            hSetBuffering h (BlockBuffering Nothing)
            return h)
-- | Open the socket and build a fresh 'Connection' with empty
-- negotiation results, zero heart-beat intervals and no worker threads.
-- The header argument is accepted but not used here.
newConn :: Host -> PortNumber -> [Header] -> IO Connection
newConn host port hs = do
  sock        <- openSocket host port
  listenerVar <- newEmptyMVar
  sendBeatVar <- newEmptyMVar
  recvBeatVar <- newEmptyMVar
  lastSendVar <- newEmptyMVar
  lastRecvVar <- newEmptyMVar
  writeLock   <- newMVar ()
  closedVar   <- newMVar Nothing
  receiptVar  <- newEmptyMVar
  doneVar     <- newEmptyMVar
  handlerVar  <- newMVar Nothing
  return Connection
    { handle      = sock
    , versions    = []
    , session     = []
    , server      = []
    , listener    = listenerVar
    , sendBeat    = sendBeatVar
    , recvBeat    = recvBeatVar
    , sendTimeout = 0
    , recvTimeout = 0
    , lastSend    = lastSendVar
    , lastRecv    = lastRecvVar
    , sockLock    = writeLock
    , closed      = closedVar
    , disconReq   = receiptVar
    , disconResp  = doneVar
    , excpHandle  = handlerVar
    }
-- | Extract the authority part of a URI, provided the URI parses and its
-- scheme is @stomp@ (compared case-insensitively).
stompAuth :: String -> Maybe URIAuth
stompAuth str = parseURI str >>= authority
  where
    authority u
      | map toLower (uriScheme u) == "stomp:" = uriAuthority u
      | otherwise                             = Nothing
-- | Turn a URI authority into @(host, port, headers)@, where the headers
-- are @login@ and @passcode@ taken from the user-info part.
--
-- The user-info is read up to its trailing @\'\@\'@ and split at the first
-- @\':\'@.  A user-info without a password (@user\@@) now yields login
-- @user@ and an empty passcode; previously the @\'\@\'@ ended up in the
-- login.  A missing or unparsable port falls back to @defPort@.
fromAuth :: PortNumber -> URIAuth -> (Host, PortNumber, [Header])
fromAuth defPort ua =
  (host, portNum, [("login", user), ("passcode", pwd)])
  where
    info         = takeWhile (/= '@') (uriUserInfo ua)
    (user, rest) = break (== ':') info
    pwd          = drop 1 rest                -- drop the ':' separator
    port         = drop 1 (uriPort ua)        -- drop the leading ':'
    portNum      = case reads port of
      [(n, "")] -> fromIntegral (n :: Int)
      _         -> defPort
    host         = uriRegName ua
-- | Resolve a stomp URI into host, port and credential headers, raising
-- 'InvalidUri' when the URI is not a valid stomp URI with an authority.
processUri :: String -> IO (Host, PortNumber, [Header])
processUri uri = case stompAuth uri of
  Nothing -> E.throwIO (InvalidUri uri)
  Just ua -> return (fromAuth defaultPort ua)
-- | Kill the consumer and both heart-beat threads, if registered.  Each
-- slot is emptied, so a second call finds nothing to kill.
stop :: Connection -> IO ()
stop c = forM_ [listener c, sendBeat c, recvBeat c] $ \slot -> do
  running <- tryTakeMVar slot
  case running of
    Just tid -> killThread tid
    Nothing  -> return ()
-- | Port used by 'processUri' when the URI does not specify one.
defaultPort :: PortNumber
defaultPort = 61613
-- | Install (or replace) the callback invoked when a worker thread hits
-- a 'StompException'.
setExcpHandler :: Connection -> (StompException -> IO ()) -> IO ()
setExcpHandler con fun =
  modifyMVar_ (excpHandle con) (const (return (Just fun)))
-- | React to an error raised in a worker thread: run the user callback
-- if one is installed, then record the error as the reason the
-- connection is closed unless a reason is already recorded.
onException :: Connection -> StompException -> IO ()
onException con e =
  withMVar (excpHandle con) $ \registered -> do
    case registered of
      Just notify -> notify e
      Nothing     -> return ()
    modifyMVar_ (closed con) (return . (`mplus` Just e))
-- | Start a background thread that reads frames and passes each one to
-- the given callback.  Does nothing if a consumer is already registered.
--
-- When the thread ends for any reason, 'disconResp' is filled so that a
-- 'disconnect' waiting for a receipt is released.
--
-- The registered thread id is put back when a consumer is already
-- running.  Previously it was taken out and dropped, so a second call
-- made the running consumer unreachable for 'stop' and a third call
-- started a duplicate consumer.
startConsumer :: Connection -> (Frame -> IO ()) -> IO ()
startConsumer c fun = do
  running <- tryTakeMVar (listener c)
  case running of
    Just tid -> void $ tryPutMVar (listener c) tid
    Nothing  -> do
      tid <- forkIO $ E.finally
               (consumeFrames c fun)
               (tryPutMVar (disconResp c) ())
      void $ tryPutMVar (listener c) tid
-- | Consumer loop: receive a frame, hand it to the callback, then stop
-- if it is the receipt a pending disconnect is waiting for; otherwise
-- restore the pending receipt id (if any) and continue.  A
-- 'StompException' ends the loop and is passed to 'onException'.
consumeFrames :: Connection -> (Frame -> IO ()) -> IO ()
consumeFrames con fun = E.catch step report
  where
    report :: StompException -> IO ()
    report = onException con

    step = do
      frame <- receiveFrame con
      fun frame
      pending <- tryTakeMVar (disconReq con)
      case pending of
        Just rcpt
          | checkReceipt rcpt frame -> return ()
          | otherwise               -> do
              putMVar (disconReq con) rcpt
              consumeFrames con fun
        Nothing -> consumeFrames con fun
-- | True when the frame is a @RECEIPT@ whose @receipt-id@ header equals
-- the given receipt id.
checkReceipt :: String -> Frame -> Bool
checkReceipt rcpt frame = case frame of
  Frame (SC RECEIPT) hs _ -> lookup "receipt-id" hs == Just rcpt
  _                       -> False
-- | Serialise a single frame and write it to the connection.
sendFrame :: Connection -> Frame -> IO ()
sendFrame con f = sendBuf con wire
  where
    wire = putFrames [f]
-- | Serialise several frames into one buffer and write it in one go.
sendFrames :: Connection -> [Frame] -> IO ()
sendFrames con = sendBuf con . putFrames
-- | Serialise client frames: command line, header lines, a blank line,
-- the body (if non-empty) and a terminating NUL byte.  Only client
-- frames are handled; a server frame is a pattern-match failure.
putFrames :: [Frame] -> BL.ByteString
putFrames fs = runPut (forM_ fs encode)
  where
    encode (Frame (CC cmd) hs body) = do
      text (show cmd)
      text "\n"
      text (hdrToStr cmd hs)
      text "\n"
      unless (BL.null body) (putLazyByteString body)
      putWord8 0x00

    text = putByteString . BU.fromString
-- | Render headers as @name:value@ lines.  Names and values are escaped
-- with 'esc' for every command except @CONNECT@, whose headers are
-- written verbatim.
hdrToStr :: ClientCommand -> [Header] -> String
hdrToStr _ [] = []
hdrToStr cmd hs = hdrToStr' conv hs
  where
    conv = if cmd == CONNECT then id else esc

-- | Render headers as newline-terminated @name:value@ lines, applying
-- the given transformation to each name and each value.
hdrToStr' :: (String -> String) -> [Header] -> String
hdrToStr' f xs = concat [f k ++ ":" ++ f v ++ "\n" | (k, v) <- xs]
-- | Write a raw buffer to the connection.  If the connection has been
-- marked closed, the recorded exception is thrown instead.  The write,
-- flush and last-send timestamp update happen under the socket lock;
-- an 'E.IOException' is rethrown wrapped in 'StompIOError'.
sendBuf :: Connection -> BL.ByteString -> IO ()
sendBuf con bs =
  withMVar (closed con) $ \st -> case st of
    Just err -> E.throwIO err
    Nothing  -> E.catch transmit wrapIO
  where
    transmit = withMVar (sockLock con) $ \_ -> do
      BS.hPut (handle con) (BS.concat (BL.toChunks bs))
      hFlush (handle con)
      beatTime (lastSend con)

    wrapIO :: E.IOException -> IO ()
    wrapIO = E.throwIO . StompIOError
-- | Block until the next real frame arrives and return it.  Input that
-- parses to no frame (an empty line) is skipped, but still refreshes the
-- last-receive timestamp.  Raises 'ConnectionError' at end of stream.
receiveFrame :: Connection -> IO Frame
receiveFrame con = do
  atEnd <- hIsEOF (handle con)
  when atEnd $
    E.throwIO $ ConnectionError "Connection closed by broker"
  parsed <- getFrame (handle con)
  beatTime (lastRecv con)
  case parsed of
    Just frame -> return frame
    Nothing    -> receiveFrame con
-- | Read one frame from the handle.
--
-- Leading NUL bytes on the command line are dropped.  An empty command
-- line gives 'Nothing'.  Otherwise the line must name a 'ServerCommand';
-- an unknown command raises 'InvalidFrame' immediately.  The former lazy
-- @read@ deferred the failure to whoever first inspected the command and
-- raised an @ErrorCall@ that the module's 'StompException' handlers do
-- not catch.
getFrame :: Handle -> IO (Maybe Frame)
getFrame h = do
  line <- liftM (dropWhile (== '\x00')) (readLine h)
  if null line
    then return Nothing
    else do
      cmd     <- parseCommand line
      hdrs    <- readHeaders h cmd
      payload <- readBody h hdrs
      return (Just $ Frame cmd hdrs payload)
  where
    -- Same acceptance rules as 'read' (surrounding whitespace allowed),
    -- but failure is reported as a StompException.
    parseCommand l =
      case [sc | (sc, rest) <- reads l, ("", "") <- lex rest] of
        [sc] -> return (SC sc)
        _    -> E.throwIO $ InvalidFrame ("Unknown server command: " ++ l)
-- | Read header lines up to (and including) the blank line that ends
-- the header section.
--
-- Each line is split at its first @\':\'@ into name and value.  For every
-- command except @CONNECTED@ the name and value are then unescaped with
-- 'unesc'; @CONNECTED@ headers are taken verbatim.
--
-- Changes from the previous version:
--
-- * the line is split /before/ unescaping, so an escaped colon (@\\c@)
--   inside a header name no longer moves the name/value boundary;
-- * a stray @return@ whose result was discarded has been removed;
-- * a line without any colon yields an empty value instead of a @tail@
--   error.
readHeaders :: Handle -> Command -> IO [Header]
readHeaders h cmd = do
  l <- readLine h
  if null l
    then return []
    else do
      hs <- readHeaders h cmd
      return (header l : hs)
  where
    decode = case cmd of
      SC CONNECTED -> id
      _            -> unesc

    header x =
      let (name, val) = break (== ':') x
      in (decode name, decode (drop 1 val))
-- | Read a frame body followed by its NUL terminator.
--
-- With a @content-length@ header exactly that many bytes are read and
-- the next byte must be NUL, otherwise 'InvalidFrame' is raised.
-- Without the header, bytes are read one at a time up to the first NUL.
--
-- Changes from the previous version:
--
-- * a non-numeric or negative @content-length@ raises 'InvalidFrame'
--   instead of a @read@ error;
-- * end of stream while scanning for the terminator raises
--   'ConnectionError'.  The old loop kept appending empty reads forever.
readBody :: Handle -> [Header] -> IO BL.ByteString
readBody h hs = maybe readTill readBuf len
  where
    len = lookup "content-length" hs

    readBuf x = do
      n  <- parseLen x
      bs <- BL.hGet h n
      tr <- BL.hGet h 1
      when (tr /= term) $
        E.throwIO $ InvalidFrame "Missing frame terminator"
      return bs

    parseLen :: String -> IO Int
    parseLen x = case [n | (n, rest) <- reads x, ("", "") <- lex rest] of
      [n] | n >= 0 -> return n
      _            -> E.throwIO $ InvalidFrame ("Invalid content-length: " ++ x)

    readTill = liftM B.toLazyByteString readTill'

    readTill' = do
      ch <- BL.hGet h 1
      if BL.null ch
        then E.throwIO $ ConnectionError "Connection closed by broker"
        else if ch == term
          then return B.empty
          else liftM (B.append $ B.fromLazyByteString ch) readTill'

    term = BL.singleton '\x00'
-- | Read one line from the handle and decode it from UTF-8.
readLine :: Handle -> IO String
readLine = liftM BU.toString . BS.hGetLine
-- | Escape a header name or value: @\':\'@ becomes @\\c@, newline becomes
-- @\\n@ and backslash becomes @\\\\@.  All other characters pass through.
esc :: String -> String
esc = concatMap encode
  where
    encode ':'  = "\\c"
    encode '\n' = "\\n"
    encode '\\' = "\\\\"
    encode c    = [c]
-- | Inverse of 'esc': @\\n@, @\\c@ and @\\\\@ are turned back into
-- newline, colon and backslash.  Any other backslash sequence, and a
-- trailing lone backslash, is left untouched.
unesc :: String -> String
unesc ('\\' : 'n'  : rest) = '\n' : unesc rest
unesc ('\\' : 'c'  : rest) = ':'  : unesc rest
unesc ('\\' : '\\' : rest) = '\\' : unesc rest
unesc (c : rest)           = c : unesc rest
unesc []                   = []
-- | Derive @(send interval, receive tolerance)@ from the first header
-- list (passed as the client's headers) and the second (the broker's).
-- Each direction is disabled (0) when either side announces 0 for it;
-- otherwise the larger of the two values is used.  The receive value is
-- doubled.
getBeats :: [Header] -> [Header] -> (Int, Int)
getBeats xs ys = (agree clientSend serverRecv, 2 * agree clientRecv serverSend)
  where
    (clientSend, clientRecv) = mkBeat xs
    (serverSend, serverRecv) = mkBeat ys

    agree a b = if a == 0 || b == 0 then 0 else max a b
-- | Read the @heart-beat@ header as a pair of integers (@"x,y"@).
--
-- A missing header yields @(0,0)@.  A header that is not two
-- comma-separated integers now also yields @(0,0)@ (heart-beating off)
-- instead of a deferred @read@ or @tail@ error.
mkBeat :: [Header] -> (Int,Int)
mkBeat hs = fromMaybe (0, 0) (lookup "heart-beat" hs >>= parseBeat)
  where
    parseBeat s = case break (== ',') s of
      (x, ',' : y) -> liftM2 (,) (readInt x) (readInt y)
      _            -> Nothing

    -- Total replacement for 'read': same acceptance rules, no exception.
    readInt :: String -> Maybe Int
    readInt t = case [n | (n, rest) <- reads t, ("", "") <- lex rest] of
      [n] -> Just n
      _   -> Nothing
-- | Store the current time in the MVar, filling it if it is empty and
-- overwriting the old value otherwise.
beatTime :: MVar UTCTime -> IO ()
beatTime v = do
  now    <- getCurrentTime
  stored <- tryPutMVar v now
  if stored
    then return ()
    else modifyMVar_ v (const (return now))
-- | Send a heart-beat: a newline followed by a NUL byte.
beat :: Connection -> IO ()
beat con = sendBuf con heartbeat
  where
    heartbeat = BL.fromChunks [BU.fromString "\n\x00"]
-- | Outgoing heart-beat loop.  If at least 'sendTimeout' milliseconds
-- have passed since the last write, a heart-beat is sent and the thread
-- sleeps for a full interval; otherwise it sleeps for the remainder of
-- the interval.  A 'StompException' ends the loop and is passed to
-- 'onException'.
--
-- Fix: the remaining time is @delay - elapsed@.  The previous expression
-- @delay diff@ applied an 'Int' as a function and did not type-check.
clientBeat :: Connection -> IO ()
clientBeat con = E.catch step report
  where
    report :: StompException -> IO ()
    report = onException con

    step = do
      prev <- readMVar (lastSend con)
      now  <- getCurrentTime
      let elapsed = floor (diffUTCTime now prev * 1000) :: Int
          delay   = sendTimeout con
      if elapsed >= delay
        then do
          beat con
          threadDelay (1000 * delay)            -- threadDelay takes microseconds
        else threadDelay (1000 * (delay - elapsed))
      clientBeat con
-- | Incoming heart-beat watchdog.  If at least 'recvTimeout'
-- milliseconds have passed since the last read, 'BrokerError' is raised;
-- otherwise the thread sleeps for the remainder of the interval and
-- checks again.  A 'StompException' ends the loop and is passed to
-- 'onException'.
--
-- Fix: the remaining time is @delay - elapsed@.  The previous expression
-- @delay diff@ applied an 'Int' as a function and did not type-check.
brokerBeat :: Connection -> IO ()
brokerBeat con = E.catch step report
  where
    report :: StompException -> IO ()
    report = onException con

    step = do
      prev <- readMVar (lastRecv con)
      now  <- getCurrentTime
      let elapsed = floor (diffUTCTime now prev * 1000) :: Int
          delay   = recvTimeout con
      if elapsed >= delay
        then E.throwIO $ BrokerError "Broker timeout expired"
        else threadDelay (1000 * (delay - elapsed))  -- microseconds
      brokerBeat con
-- | Start the outgoing heart-beat thread if the send interval is
-- positive and no such thread is registered yet.  When the thread ends,
-- 'disconResp' is filled.
--
-- The registered thread id is put back when a thread is already
-- running; previously it was taken out and dropped, so 'stop' could no
-- longer kill it and a later call could start a duplicate.
startSendBeat :: Connection -> IO ()
startSendBeat c = do
  running <- tryTakeMVar (sendBeat c)
  case running of
    Just tid -> void $ tryPutMVar (sendBeat c) tid
    Nothing  -> when (sendTimeout c > 0) $ do
      tid <- forkIO $ E.finally
               (clientBeat c)
               (tryPutMVar (disconResp c) ())
      void $ tryPutMVar (sendBeat c) tid
-- | Start the incoming heart-beat watchdog if the receive tolerance is
-- positive and no such thread is registered yet.  When the thread ends,
-- 'disconResp' is filled.
--
-- The registered thread id is put back when a thread is already
-- running; previously it was taken out and dropped, so 'stop' could no
-- longer kill it and a later call could start a duplicate.
startRecvBeat :: Connection -> IO ()
startRecvBeat c = do
  running <- tryTakeMVar (recvBeat c)
  case running of
    Just tid -> void $ tryPutMVar (recvBeat c) tid
    Nothing  -> when (recvTimeout c > 0) $ do
      tid <- forkIO $ E.finally
               (brokerBeat c)
               (tryPutMVar (disconResp c) ())
      void $ tryPutMVar (recvBeat c) tid