module Network.Nats (
Nats
, NatsSID
, connect
, connectSettings
, NatsHost(..)
, NatsSettings(..)
, defaultSettings
, NatsException
, MsgCallback
, subscribe
, unsubscribe
, publish
, request
, requestMany
, disconnect
) where
import System.IO
import Control.Concurrent.MVar
import Control.Concurrent
import qualified Network.Socket as S
import Network.Socket (SocketOption(KeepAlive, NoDelay), setSocketOption, getAddrInfo, SockAddr(..))
import Control.Monad (forever, replicateM)
import Data.Dequeue as D
import Control.Applicative ((<$>))
import Data.Typeable
import qualified Data.Foldable as FOLD
import Control.Exception (bracket, bracketOnError, throwIO, catch, IOException, AsyncException, Exception, SomeException,
catches, Handler(..))
import System.Random (randomRIO)
import Data.IORef
import System.Timeout
import Control.Concurrent.Async (concurrently)
import qualified Data.Map.Strict as Map
import qualified Data.ByteString.Lazy.Char8 as BL
import qualified Data.ByteString.Char8 as BS
import Data.Char (toLower, isUpper)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8)
import qualified Data.Aeson as AE
import Data.Aeson.TH (deriveJSON, defaultOptions, fieldLabelModifier)
import qualified Network.URI as URI
-- | Delay between keep-alive PINGs sent by the pinger thread, in
-- microseconds (3 seconds).
pingInterval :: Int
pingInterval = 3 * 1000 * 1000
-- | Time limit for socket writes and for connection setup, also used as the
-- pause between reconnection attempts; in microseconds (1 second).
timeoutInterval :: Int
timeoutInterval = 1000 * 1000
-- | Exception raised by this client: unexpected server input, rejected
-- commands (carrying the server's @-ERR@ text), authentication failures and
-- timeouts.
data NatsException = NatsException String
    deriving (Typeable, Show)

instance Exception NatsException
-- | Options sent in the @CONNECT@ command. The JSON keys are the field names
-- without the @natsConn@ prefix, in snake_case (e.g. @ssl_required@).
data NatsConnectionOptions = NatsConnectionOptions
    { natsConnUser        :: String -- ^ user name
    , natsConnPass        :: String -- ^ password
    , natsConnVerbose     :: Bool   -- ^ NATS @verbose@ flag
    , natsConnPedantic    :: Bool   -- ^ NATS @pedantic@ flag
    , natsConnSslRequired :: Bool   -- ^ NATS @ssl_required@ flag
    } deriving (Show)

-- | Baseline options: user @nats@, password @nats@, verbose and pedantic
-- on, SSL off. 'authenticate' overrides the credentials.
defaultConnectionOptions :: NatsConnectionOptions
defaultConnectionOptions = NatsConnectionOptions
    { natsConnUser        = "nats"
    , natsConnPass        = "nats"
    , natsConnVerbose     = True
    , natsConnPedantic    = True
    , natsConnSslRequired = False
    }
-- JSON instances for the CONNECT options. The label modifier drops the
-- 8-character "natsConn" prefix, puts an underscore before every capital
-- letter, drops the first character (the leading underscore) and lowercases
-- the rest: natsConnSslRequired -> ssl_required.
$(deriveJSON defaultOptions{fieldLabelModifier =
    map toLower . drop 1 . concatMap (\c -> if isUpper c then ['_', c] else [c]) . drop 8
  } ''NatsConnectionOptions)
-- | The part of the server's @INFO@ message that this client uses.
data NatsServerInfo = NatsServerInfo {
    natsSvrAuthRequired :: Bool -- ^ whether 'authenticate' must send CONNECT credentials
  } deriving (Show)

-- Parsed by hand instead of with deriveJSON. The derived parser requires
-- the @auth_required@ key, so an INFO without it failed to decode and the
-- connection was rejected as "Incorrect input from server".
-- NOTE(review): newer nats-server releases appear to omit the key when it
-- is false (omitempty) — verify against the server versions in use.
-- A missing key is read as False.
instance AE.FromJSON NatsServerInfo where
    parseJSON = AE.withObject "NatsServerInfo" $ \o ->
        NatsServerInfo <$> (o AE..:? "auth_required" AE..!= False)

-- | Encodes as @{"auth_required": <bool>}@, the shape the previously
-- derived instance produced.
instance AE.ToJSON NatsServerInfo where
    toJSON (NatsServerInfo authRequired) =
        AE.object ["auth_required" AE..= authRequired]
-- | Subscription identifier. It is shown and read as the bare integer, which
-- is the form used in @SUB@/@UNSUB@ commands and in @MSG@ lines.
newtype NatsSID = NatsSID Int
    deriving (Eq, Ord, Num)

instance Show NatsSID where
    show (NatsSID n) = show n

instance Read NatsSID where
    readsPrec prec input = [ (NatsSID n, rest) | (n, rest) <- readsPrec prec input ]
-- | Handler for messages delivered to a subscription. It receives, in order:
-- the subscription id, the subject the message was published on, the
-- payload, and the reply subject when the publisher supplied one.
type MsgCallback = NatsSID -> String -> BL.ByteString -> Maybe String -> IO ()
-- | A recorded subscription. Kept in 'natsSubMap' so messages can be routed
-- to its callback and the @SUB@ can be re-sent after a reconnect.
data NatsSubscription = NatsSubscription
    { subSubject  :: Subject       -- ^ subscribed subject
    , subQueue    :: Maybe Subject -- ^ queue group, if any
    , subCallback :: MsgCallback   -- ^ handler for delivered messages
    , subSid      :: NatsSID       -- ^ subscription id
    }
-- | Pending acknowledgement callbacks, oldest at the front. The reader loop
-- pops one per @+OK@ (passing 'Nothing') or @-ERR@ (passing the error
-- text); on disconnect every remaining one receives the disconnect reason.
type FifoQueue = D.BankersDequeue (Maybe T.Text -> IO ())
-- | A client connection handle.
data Nats = Nats
    { natsSettings :: NatsSettings
      -- ^ settings the client was created with
    , natsRuntime  :: MVar (Handle, FifoQueue, Bool, MVar ())
      -- ^ connection state: socket handle, pending acknowledgement
      -- callbacks, connected flag, and a signal filled when a reconnect
      -- completes. While disconnected the handle and queue are 'undefined';
      -- while connected the signal is 'undefined'.
    , natsThreadId :: MVar ThreadId
      -- ^ connection-management thread, killed by 'disconnect'
    , natsNextSid  :: IORef NatsSID
      -- ^ next subscription id to hand out
    , natsSubMap   :: IORef (Map.Map NatsSID NatsSubscription)
      -- ^ active subscriptions, re-sent after every reconnect
    }
-- | Address and credentials of one server.
data NatsHost = NatsHost
    { natsHHost :: String -- ^ host name or address
    , natsHPort :: Int    -- ^ TCP port
    , natsHUser :: String -- ^ user name, sent when the server requires authentication
    , natsHPass :: String -- ^ password, sent when the server requires authentication
    }
-- | Client configuration.
data NatsSettings = NatsSettings
    { natsHosts        :: [NatsHost]
      -- ^ servers, tried in order on connect and cycled through on reconnect
    , natsOnConnect    :: Nats -> (String, Int) -> IO ()
      -- ^ run with the host and port whenever a connection is (re)established
    , natsOnDisconnect :: Nats -> String -> IO ()
      -- ^ run with a description of the cause whenever the connection is closed
    }
-- | One server at @localhost:4222@ with credentials @nats@/@nats@ and
-- hooks that do nothing.
defaultSettings :: NatsSettings
defaultSettings = NatsSettings
    { natsHosts        = [NatsHost "localhost" 4222 "nats" "nats"]
    , natsOnConnect    = \_ _ -> return ()
    , natsOnDisconnect = \_ _ -> return ()
    }
-- | A message received from the server, as produced by 'decodeMessage'.
data NatsSvrMessage
    = NatsSvrMsg
        { msgSubject :: String        -- ^ subject the message was published on
        , msgSid     :: NatsSID       -- ^ subscription it is delivered to
        , msgText    :: BS.ByteString -- ^ payload, filled in after the header line
        , msgReply   :: Maybe String  -- ^ reply subject, if the publisher set one
        }
    | NatsSvrOK                  -- ^ @+OK@
    | NatsSvrError T.Text        -- ^ @-ERR@ and the server's error text
    | NatsSvrPing                -- ^ @PING@
    | NatsSvrPong                -- ^ @PONG@
    | NatsSvrInfo NatsServerInfo -- ^ @INFO@ and its decoded JSON body
    deriving (Show)
-- | A validated subject or queue-group name; construct it with 'makeSubject'.
newtype Subject = Subject String
    deriving (Show)

-- | The subject as a plain string.
subjectToStr :: Subject -> String
subjectToStr (Subject s) = s

-- | Validate a subject. It must be non-empty and must not contain spaces or
-- control characters (anything @<= ' '@). Invalid input is treated as a
-- programming error and raises 'error'.
makeSubject :: String -> Subject
makeSubject str
    | null str = error "Empty subject"
    | any (<= ' ') str = error $ "Subject contains incorrect characters: " ++ str
    | otherwise = Subject str
-- | A command sent by the client; serialised by 'makeClntMsg'.
data NatsClntMessage
    = NatsClntPing                                          -- ^ @PING@
    | NatsClntPong                                          -- ^ @PONG@
    | NatsClntSubscribe Subject NatsSID (Maybe Subject)     -- ^ @SUB@: subject, sid, optional queue group
    | NatsClntUnsubscribe NatsSID                           -- ^ @UNSUB@: sid
    | NatsClntPublish Subject (Maybe Subject) BL.ByteString -- ^ @PUB@: subject, optional reply subject, payload
    | NatsClntConnect NatsConnectionOptions                 -- ^ @CONNECT@ with JSON-encoded options
-- | Serialise a client command to its wire form, without the final CRLF
-- (callers append it). A @PUB@ is its header line, CRLF, then the payload.
makeClntMsg :: NatsClntMessage -> BL.ByteString
makeClntMsg cmd = BL.fromChunks (chunks cmd)
  where
    -- Space-separated tokens; unwords is exactly the " "-joining used here.
    render :: [String] -> String
    render = unwords

    optional :: Maybe Subject -> [String]
    optional = maybe [] (\s -> [subjectToStr s])

    chunks :: NatsClntMessage -> [BS.ByteString]
    chunks NatsClntPing = ["PING"]
    chunks NatsClntPong = ["PONG"]
    chunks (NatsClntSubscribe subject sid mqueue) =
        [BS.pack $ render (["SUB", subjectToStr subject] ++ optional mqueue ++ [show sid])]
    chunks (NatsClntUnsubscribe sid) = [BS.pack $ render ["UNSUB", show sid]]
    chunks (NatsClntPublish subj mreply msg) =
        let header = render (["PUB", subjectToStr subj] ++ optional mreply ++ [show (BL.length msg)])
        in BS.pack (header ++ "\r\n") : BL.toChunks msg
    chunks (NatsClntConnect info) = "CONNECT " : BL.toChunks (AE.encode info)
-- | Parse one line read from the server with 'BS.hGetLine' (so without the
-- newline, but possibly with the CR of the CRLF terminator).
--
-- Returns the decoded message and, for @MSG@, the number of payload bytes
-- that follow on the wire. The 'msgText' of a decoded @MSG@ is empty; the
-- caller fills it in after reading the payload. Returns 'Nothing' for
-- unknown commands and malformed lines (including non-numeric sid/length
-- fields and negative lengths) instead of throwing.
decodeMessage :: BS.ByteString -> Maybe (NatsSvrMessage, Maybe Int)
decodeMessage line = decodeMessage_ mid (BS.drop 1 mrest)
  where
    -- Drop the trailing CR so it does not end up in -ERR texts.
    stripped = fst (BS.spanEnd (== '\r') line)
    (mid, mrest) = BS.span (\x -> x /= ' ' && x /= '\r') stripped

    -- Total replacement for 'read': the whole token must parse.
    readExact :: Read a => String -> Maybe a
    readExact s = case reads s of
        [(v, "")] -> Just v
        _ -> Nothing

    mkMsg :: String -> String -> Maybe String -> String -> Maybe (NatsSvrMessage, Maybe Int)
    mkMsg subj sidStr reply lenStr = do
        sid <- readExact sidStr
        len <- readExact lenStr
        if len < 0
            then Nothing
            else Just (NatsSvrMsg subj sid BS.empty reply, Just len)

    decodeMessage_ :: BS.ByteString -> BS.ByteString -> Maybe (NatsSvrMessage, Maybe Int)
    decodeMessage_ "PING" _ = Just (NatsSvrPing, Nothing)
    decodeMessage_ "PONG" _ = Just (NatsSvrPong, Nothing)
    decodeMessage_ "+OK" _ = Just (NatsSvrOK, Nothing)
    decodeMessage_ "-ERR" msg = Just (NatsSvrError (decodeUtf8 msg), Nothing)
    decodeMessage_ "INFO" msg = do
        info <- AE.decode $ BL.fromChunks [msg]
        return (NatsSvrInfo info, Nothing)
    decodeMessage_ "MSG" msg =
        case map BS.unpack (BS.words msg) of
            [subj, sid, len] -> mkMsg subj sid Nothing len
            [subj, sid, reply, len] -> mkMsg subj sid (Just reply) len
            _ -> Nothing
    decodeMessage_ _ _ = Nothing
-- | Atomically take the next subscription id and advance the counter.
newNatsSid :: Nats -> IO NatsSID
newNatsSid nats = atomicModifyIORef' (natsNextSid nats) bump
  where
    bump current = (current + 1, current)
-- | A random reply subject: @_INBOX.@ followed by 13 random lowercase letters.
newInbox :: IO String
newInbox = ("_INBOX." ++) <$> replicateM 13 (randomRIO ('a', 'z'))
-- | Open a TCP connection to @hostname:port@ and return it as an unbuffered
-- read/write 'Handle', with TCP keep-alive on and Nagle's algorithm off.
--
-- Every stream address 'getAddrInfo' returns is tried in order, so a host
-- name resolving to several addresses (e.g. @localhost@ to both @::1@ and
-- @127.0.0.1@) still connects when the server listens on only one of them.
-- The socket of each failed attempt is closed. When every attempt fails,
-- the 'IOException' of the last one propagates.
connectToServer :: String -> Int -> IO Handle
connectToServer hostname port = do
    addrinfos <- getAddrInfo (Just hints) (Just hostname) Nothing
    tryAddresses addrinfos
  where
    hints = S.defaultHints { S.addrSocketType = S.Stream }

    tryAddresses [] = ioError $ userError $ "No address found for host " ++ hostname
    tryAddresses [addr] = connectAddr addr
    tryAddresses (addr:rest) =
        connectAddr addr `catch` \(_ :: IOException) -> tryAddresses rest

    -- getAddrInfo is queried without a service name, so patch in the port.
    withPort (SockAddrInet _ haddr) = SockAddrInet (fromIntegral port) haddr
    withPort (SockAddrInet6 _ finfo haddr scopeid) =
        SockAddrInet6 (fromIntegral port) finfo haddr scopeid
    withPort other = other

    connectAddr addr =
        bracketOnError
            (S.socket (S.addrFamily addr) S.Stream S.defaultProtocol)
            S.sClose
            (\sock -> do
                setSocketOption sock KeepAlive 1
                setSocketOption sock NoDelay 1
                S.connect sock (withPort (S.addrAddress addr))
                h <- S.socketToHandle sock ReadWriteMode
                hSetBuffering h NoBuffering
                return h)
-- | Run @f@ on the current handle and pending-callback queue, and store the
-- queue @f@ returns. @f@ runs while 'natsRuntime' is held, so concurrent
-- senders are serialised. If @f@ throws, the previous state is restored.
--
-- When disconnected: with @True@, wait for the connection signal (filled on
-- reconnect) and retry; with @False@, do nothing.
ensureConnection :: Nats
                 -> Bool
                 -> ((Handle, FifoQueue) -> IO FifoQueue)
                 -> IO ()
ensureConnection nats True f = do
    -- The runtime MVar is released before waiting, and the retry happens
    -- outside any bracket. An exception from the retry (e.g. a send timeout)
    -- therefore cannot make a release handler put a stale state into the
    -- MVar, which the previous bracketOnError-based version did.
    mwait <- modifyMVar (natsRuntime nats) runAction
    case mwait of
        Nothing -> return ()
        Just csig -> do
            readMVar csig
            ensureConnection nats True f
  where
    runAction (handle, queue, True, csig) = do
        nqueue <- f (handle, queue)
        return ((handle, nqueue, True, csig), Nothing)
    runAction r@(_, _, False, csig) = return (r, Just csig)
ensureConnection nats False f = modifyMVarMasked_ (natsRuntime nats) runAction
  where
    runAction (handle, queue, True, csig) = do
        nqueue <- f (handle, queue)
        return (handle, nqueue, True, csig)
    runAction (handle, queue, False, csig) =
        return (handle, queue, False, csig)
-- | Send a command through 'ensureConnection'.
--
-- @CONNECT@, @PUB@, @SUB@ and @UNSUB@ always get an entry in the pending
-- callback queue: the given callback, or a no-op when none is given
-- (presumably so each queued entry lines up with one @+OK@/@-ERR@ reply).
-- @PING@/@PONG@ never queue anything, and passing a callback with them is a
-- programming error ('error').
sendMessage :: Nats -> Bool -> NatsClntMessage -> Maybe (Maybe T.Text -> IO ()) -> IO ()
sendMessage nats blockIfDisconnected msg mcb =
    case (mcb, acknowledged msg) of
        (Just cb, True)  -> sendWith (\queue -> D.pushBack queue cb)
        (Nothing, True)  -> sendMessage nats blockIfDisconnected msg (Just $ \_ -> return ())
        (Just _, False)  -> error "Callback not supported"
        (Nothing, False) -> sendWith id
  where
    sendWith updateQueue =
        ensureConnection nats blockIfDisconnected $ \(handle, queue) -> do
            _sendMessage handle msg
            return (updateQueue queue)

    acknowledged (NatsClntConnect {})     = True
    acknowledged (NatsClntPublish {})     = True
    acknowledged (NatsClntSubscribe {})   = True
    acknowledged (NatsClntUnsubscribe {}) = True
    acknowledged _                        = False
-- | Run an action with a time limit in microseconds. Throws
-- @NatsException "Reached timeout"@ when the limit is exceeded.
timeoutThrow :: Int -> IO a -> IO a
timeoutThrow t f =
    timeout t f >>= maybe (throwIO $ NatsException "Reached timeout") return
-- | Write one serialised command and its CRLF terminator within
-- 'timeoutInterval'. Commands shorter than 1024 bytes are joined into a
-- single write; longer ones are written as the lazy chunks, then the CRLF.
_sendMessage :: Handle -> NatsClntMessage -> IO ()
_sendMessage handle cmsg = timeoutThrow timeoutInterval $
    if BL.length msg < 1024
        then BS.hPut handle $ BS.concat (BL.toChunks msg ++ ["\r\n"])
        else BL.hPut handle msg >> BS.hPut handle "\r\n"
  where
    msg = makeClntMsg cmsg
-- | Read the server's first line, which must be @INFO@.
--
-- If it says authentication is required, send @CONNECT@ with the given
-- credentials and wait for @+OK@. Throws 'NatsException' on @-ERR@ or any
-- other unexpected line.
authenticate :: Handle -> String -> String -> IO ()
authenticate handle user password = do
    info <- BS.hGetLine handle
    case decodeMessage info of
        Just (NatsSvrInfo serverInfo, Nothing)
            | natsSvrAuthRequired serverInfo -> sendCredentials
            | otherwise -> return ()
        _ -> throwIO $ NatsException "Incorrect input from server"
  where
    sendCredentials = do
        let coptions = defaultConnectionOptions { natsConnUser = user, natsConnPass = password }
        BL.hPut handle $ makeClntMsg (NatsClntConnect coptions)
        BS.hPut handle "\r\n"
        response <- BS.hGetLine handle
        case decodeMessage response of
            Just (NatsSvrOK, Nothing) -> return ()
            Just (NatsSvrError err, Nothing) ->
                throwIO $ NatsException $ "Authentication error: " ++ show err
            _ -> throwIO $ NatsException "Incorrect server response"
-- | Connect to and authenticate with one server. Then store the new handle,
-- with an empty callback queue, as the connected runtime state, and fill
-- the connection signal so callers waiting in 'ensureConnection' retry.
--
-- Only connecting and authenticating run under 'timeoutInterval'. The
-- state change happens afterwards, inside one masked 'modifyMVarMasked_'.
-- Previously the timeout could fire after the handle was stored, which
-- closed it while the state still said "connected" and left the signal
-- unfilled. If the state change itself fails, the handle is closed and the
-- old state is kept.
prepareConnection :: Nats -> NatsHost -> IO ()
prepareConnection nats nhost = do
    handle <- timeoutThrow timeoutInterval $
        bracketOnError
            (connectToServer (natsHHost nhost) (natsHPort nhost))
            hClose
            (\h -> do
                authenticate h (natsHUser nhost) (natsHPass nhost)
                return h)
    bracketOnError (return handle) hClose $ \h ->
        modifyMVarMasked_ (natsRuntime nats) $ \(_, _, _, csig) -> do
            -- Woken waiters block on natsRuntime until the new state below
            -- is in place, so they see the fresh connection.
            putMVar csig ()
            return (h, D.empty, True, undefined)
-- | Body of the connection-management thread started by 'connectSettings'.
--
-- Serves the first host of the list with 'connectionHandler'.
--
-- If that fails with an 'IOException' or 'NatsException', the connection is
-- torn down ('finalize') and the remaining hosts are tried in order, waiting
-- 'timeoutInterval' after each failed attempt. The thread then recurses on
-- the list starting at the host that accepted.
--
-- An 'AsyncException' (such as the 'ThreadKilled' sent by 'disconnect')
-- tears the connection down and ends the thread.
--
-- NOTE(review): other exception types (e.g. an 'ErrorCall') are not handled
-- here. They end the thread without calling 'finalize', so pending
-- callbacks are never invoked.
connectionThread :: Nats
                 -> [NatsHost]
                 -> IO ()
connectionThread _ [] = error "Empty list of connections"
connectionThread nats (thisconn:nextconn) = do
    mnewconnlist <-
        (connectionHandler nats thisconn >> return Nothing)
            `catches` [Handler (\(e :: IOException) -> Just <$> errorHandler e),
                       Handler (\(e :: NatsException) -> Just <$> errorHandler e),
                       Handler (\e -> finalHandler e >> return Nothing)]
    case mnewconnlist of
        Nothing -> return ()
        Just newconnlist -> connectionThread nats newconnlist
  where
    -- Tear down the current connection:
    --   * switch the runtime to disconnected, with a fresh empty connection
    --     signal that blocking senders in ensureConnection wait on;
    --   * close the socket;
    --   * fail every pending acknowledgement callback with the exception text;
    --   * run the user's disconnect hook.
    finalize :: (Show e) => e -> IO ()
    finalize e = do
        (handle, queue, _, _) <- takeMVar (natsRuntime nats)
        finsignal <- newEmptyMVar
        putMVar (natsRuntime nats) (undefined, undefined, False, finsignal)
        hClose handle
        FOLD.mapM_ (\f -> f $ Just (T.pack $ show e)) queue
        (natsOnDisconnect $ natsSettings nats) nats (show e)
    -- Tear down, then reconnect. Returns the host list starting at the host
    -- that accepted the new connection.
    errorHandler :: (Show e) => e -> IO [NatsHost]
    errorHandler e = do
        finalize e
        tryToConnect nextconn
      where
        tryToConnect connlist@(conn:rest) = do
            res <- ((prepareConnection nats conn) >> (return $ Just connlist))
                `catches` [ Handler (\(_ :: IOException) -> return Nothing),
                            Handler (\(_ :: NatsException) -> return Nothing) ]
            case res of
                Just restlist -> return restlist
                Nothing -> threadDelay timeoutInterval >> tryToConnect rest
        tryToConnect [] = error "Empty list of connections"
    -- Teardown on an AsyncException; no reconnect follows.
    finalHandler :: AsyncException -> IO ()
    finalHandler e = do
        finalize e
-- | Keep-alive loop. Every 'pingInterval' it sends a @PING@, and it throws
-- @NatsException "Ping timeouted"@ once two PINGs are unanswered.
-- @pingStatus@ holds (PINGs sent, PONGs received); the reader loop counts
-- the PONGs.
pingerThread :: Nats -> IORef (Int, Int) -> IO ()
pingerThread nats pingStatus = forever $ do
    threadDelay pingInterval
    -- Check the backlog of unanswered PINGs before counting the one about to
    -- be sent. (The original `pings pongs < 2` applied an Int as a function.)
    ok <- atomicModifyIORef' pingStatus $ \(pings, pongs) ->
        ((pings + 1, pongs), pings - pongs < 2)
    if not ok
        then throwIO (NatsException "Ping timeouted")
        else return ()
    sendMessage nats True NatsClntPing Nothing
-- | Serve one established connection.
--
-- First it re-sends @SUB@ for every recorded subscription and runs the
-- user's connect hook. Then it runs the pinger and the reader loop side by
-- side with 'concurrently'; an exception from either ends both and is
-- rethrown here.
connectionHandler :: Nats -> NatsHost -> IO ()
connectionHandler nats (NatsHost host port _ _) = do
    (handle, _, _, _) <- readMVar (natsRuntime nats)
    subs <- readIORef (natsSubMap nats)
    mapM_ resubscribe (Map.elems subs)
    natsOnConnect (natsSettings nats) nats (host, port)
    pingStatus <- newIORef (0, 0)
    _ <- concurrently (pingerThread nats pingStatus)
                      (connectionHandler' handle nats pingStatus)
    return ()
  where
    resubscribe (NatsSubscription subject queue _ sid) =
        sendMessage nats True (NatsClntSubscribe subject sid queue) Nothing
-- | Reader loop: read lines from the server forever and dispatch them.
--
--   * @PING@ is answered with @PONG@.
--   * @PONG@ increments the PONG count in @pingStatus@.
--   * @+OK@ / @-ERR@ pop the oldest pending callback and call it with
--     'Nothing' / the error text. An @-ERR@ with no pending callback is
--     printed.
--   * @MSG@ reads the payload and hands it to the subscription's callback.
--     A @MSG@ for an unknown sid is answered with @UNSUB@.
--   * Lines that do not decode are printed to stdout.
connectionHandler' :: Handle -> Nats -> IORef (Int, Int) -> IO ()
connectionHandler' handle nats pingStatus = forever $ do
    line <- BS.hGetLine handle
    case decodeMessage line of
        Just (msg, Nothing) ->
            handleMessage msg
        Just (msg@(NatsSvrMsg {}), Just paylen) -> do
            -- The payload is followed by CRLF: read it too and keep only
            -- the first paylen bytes.
            payload <- BS.hGet handle (paylen + 2)
            handleMessage msg{msgText=BS.take paylen payload}
        _ ->
            putStrLn $ "Incorrect message: " ++ (show line)
  where
    -- Pop the oldest pending acknowledgement callback from the runtime state
    -- (Nothing when the queue is empty).
    popCb (h, queue, x1, x2) = return ((h, newq, x1, x2), item)
      where
        (item, newq) = D.popFront queue
    handleMessage NatsSvrPing = sendMessage nats True NatsClntPong Nothing
    -- Count the PONG so the pinger thread sees its PING was answered.
    handleMessage NatsSvrPong =
        atomicModifyIORef' pingStatus $
            \(pings, pongs) -> ((pings, pongs + 1), ())
    handleMessage NatsSvrOK = do
        cb <- modifyMVar (natsRuntime nats) $ popCb
        case cb of
            Just f -> f Nothing
            Nothing -> return ()
    handleMessage (NatsSvrError txt) = do
        cb <- modifyMVar (natsRuntime nats) $ popCb
        case cb of
            Just f -> f $ Just txt
            Nothing -> putStrLn $ show txt
    handleMessage (NatsSvrInfo _) = return ()
    -- Deliver to the subscription's callback. Exceptions it throws are
    -- printed so they do not end the reader loop.
    -- NOTE(review): catching SomeException also catches asynchronous
    -- exceptions, such as the cancellation 'concurrently' sends when the
    -- pinger fails. One arriving while a callback runs is swallowed and
    -- the loop keeps going; consider rethrowing async exceptions.
    -- A message for a sid no longer in natsSubMap (e.g. already removed by
    -- unsubscribe) gets an UNSUB.
    handleMessage (NatsSvrMsg {..}) = do
        msubscription <- Map.lookup msgSid <$> readIORef (natsSubMap nats)
        case msubscription of
            Just subscription ->
                (subCallback subscription) msgSid msgSubject (BL.fromChunks [msgText]) msgReply
                    `catch`
                    (\(e :: SomeException) -> putStrLn $ (show e))
            Nothing -> sendMessage nats True (NatsClntUnsubscribe msgSid) Nothing
-- | Connect to one server given as a
-- @nats://[user[:password]\@]host[:port]@ URL.
--
-- The port defaults to 4222 when absent or empty. User and password are
-- percent-decoded and default to empty strings. Calls 'error' when the URL
-- does not parse, uses another scheme, has no authority section, or has a
-- non-numeric port.
connect :: String
        -> IO Nats
connect uri = do
    let parsedUri = case URI.parseURI uri of
            Just x -> x
            Nothing -> error ("Error parsing NATS url: " ++ uri)
    if URI.uriScheme parsedUri /= "nats:"
        then error "Incorrect URL scheme"
        else return ()
    let (host, port, user, password) = case URI.uriAuthority parsedUri of
            Just auth -> fromAuthority auth
            Nothing -> error "Missing hostname section"
    connectSettings defaultSettings{
        natsHosts=[NatsHost host port user password]
    }
  where
    defaultPort :: Int
    defaultPort = 4222

    -- uriUserInfo is either "" or "user[:password]@". The old code kept the
    -- '@' in the user name when no password was given.
    fromAuthority (URI.URIAuth {..}) =
        let credentials = takeWhile (/= '@') uriUserInfo
            (rawUser, rawPass) = break (== ':') credentials
        in ( uriRegName
           , parsePort uriPort
           , URI.unEscapeString rawUser
           , URI.unEscapeString (drop 1 rawPass)
           )

    -- uriPort is either "" or ":digits". The old code crashed with
    -- "Prelude.read: no parse" when the port was omitted.
    parsePort :: String -> Int
    parsePort portStr = case drop 1 portStr of
        "" -> defaultPort
        digits -> case reads digits of
            [(p, "")] -> p
            _ -> error ("Incorrect port in NATS url: " ++ uri)
-- | Create a client from explicit settings.
--
-- The hosts in 'natsHosts' are tried in order until one connects: any
-- exception moves on to the next host, and the last host's exception
-- propagates. A background thread then serves the connection and
-- reconnects by cycling through the host list.
connectSettings :: NatsSettings -> IO Nats
connectSettings settings = do
    runtime <- newEmptyMVar >>= \csig -> newMVar (undefined, undefined, False, csig)
    threadRef <- newEmptyMVar
    sidRef <- newIORef 1
    subsRef <- newIORef Map.empty
    let nats = Nats settings runtime threadRef sidRef subsRef
        hosts = natsHosts settings
    firstHost <- firstReachable (prepareConnection nats) hosts
    tid <- forkIO $ connectionThread nats (firstHost : cycle hosts)
    putMVar threadRef tid
    return nats
  where
    firstReachable _ [] = error "Empty list"
    firstReachable attempt [h] = attempt h >> return h
    firstReachable attempt (h:rest) =
        (attempt h >> return h) `catch` \(_ :: SomeException) -> firstReachable attempt rest
-- | Subscribe to a subject, optionally within a queue group, and block until
-- the server acknowledges the @SUB@.
--
-- The callback is recorded, and so starts receiving messages, only after a
-- successful acknowledgement. Throws 'NatsException' with the error text if
-- the server rejects the subscription or the connection drops first.
subscribe :: Nats
          -> String
          -> (Maybe String)
          -> MsgCallback
          -> IO NatsSID
subscribe nats subject queue cb = do
    done <- newEmptyMVar :: IO (MVar (Maybe T.Text))
    sid <- newNatsSid nats
    sendMessage nats True (NatsClntSubscribe subj sid qgroup) $ Just $ \result -> do
        maybe (register sid) (const $ return ()) result
        putMVar done result
    takeMVar done >>= maybe (return sid) (throwIO . NatsException . T.unpack)
  where
    subj = makeSubject subject
    qgroup = makeSubject <$> queue
    register sid = atomicModifyIORef' (natsSubMap nats) $ \subs ->
        ( Map.insert sid (NatsSubscription { subSubject = subj
                                           , subQueue = qgroup
                                           , subCallback = cb
                                           , subSid = sid }) subs
        , () )
-- | Forget a subscription locally, then send @UNSUB@ if currently connected
-- (no waiting for a reconnect). 'IOException' and 'NatsException' raised
-- while sending are ignored.
unsubscribe :: Nats
            -> NatsSID
            -> IO ()
unsubscribe nats sid = do
    modifySubs (Map.delete sid)
    sendMessage nats False (NatsClntUnsubscribe sid) Nothing `catches` ignored
  where
    modifySubs g = atomicModifyIORef' (natsSubMap nats) $ \subs -> (g subs, ())
    ignored = [ Handler (\(_ :: IOException) -> return ()),
                Handler (\(_ :: NatsException) -> return ()) ]
-- | Publish @body@ on @subject@ with a fresh inbox as the reply subject, and
-- return the first reply.
--
-- The inbox subscription is removed afterwards. Throws 'NatsException' when
-- the publish is rejected or the connection drops before it is
-- acknowledged. There is no timeout: if no reply ever arrives, this blocks
-- indefinitely.
request :: Nats
        -> String
        -> BL.ByteString
        -> IO BL.ByteString
request nats subject body = do
    reply <- newEmptyMVar :: IO (MVar (Either String BL.ByteString))
    inbox <- newInbox
    let onReply _ _ response _ = tryPutMVar reply (Right response) >> return ()
        onAck = maybe (return ()) (\err -> tryPutMVar reply (Left $ T.unpack err) >> return ())
        publishAndWait _ = do
            sendMessage nats True
                (NatsClntPublish (makeSubject subject) (Just $ makeSubject inbox) body)
                (Just onAck)
            takeMVar reply >>= either (throwIO . NatsException) return
    bracket (subscribe nats inbox Nothing onReply) (unsubscribe nats) publishAndWait
-- | Publish @body@ on @subject@ with a fresh inbox as the reply subject, and
-- collect every reply that arrives within @time@ microseconds, in arrival
-- order.
--
-- The inbox subscription is removed afterwards. Publish errors are ignored
-- (see 'publish''), so a failed publish just yields no replies.
requestMany :: Nats
            -> String
            -> BL.ByteString
            -> Int
            -> IO [BL.ByteString]
requestMany nats subject body time = do
    result <- newIORef []
    inbox <- newInbox
    bracket
        (subscribe nats inbox Nothing $ \_ _ response _ ->
            -- Strict update: the lazy atomicModifyIORef left a growing chain
            -- of unevaluated updates in the IORef until the final read.
            atomicModifyIORef' result $ \old -> (response:old, ()))
        (unsubscribe nats)
        (\_ -> do
            publish' nats subject (Just inbox) body
            threadDelay time)
    reverse <$> readIORef result
-- | Publish @body@ on @subject@ with no reply subject. Nothing is sent while
-- disconnected, and I/O or NATS errors are ignored (see 'publish'').
publish :: Nats
        -> String
        -> BL.ByteString
        -> IO ()
publish nats subject = publish' nats subject Nothing
-- | Publish with an optional reply subject. It sends only when currently
-- connected (no waiting for a reconnect). 'IOException' and 'NatsException'
-- raised while sending are ignored.
publish' :: Nats
         -> String
         -> Maybe String
         -> BL.ByteString
         -> IO ()
publish' nats subject inbox body =
    send `catches` [ Handler (\(_ :: IOException) -> return ()),
                     Handler (\(_ :: NatsException) -> return ()) ]
  where
    command = NatsClntPublish (makeSubject subject) (fmap makeSubject inbox) body
    send = sendMessage nats False command Nothing
-- | Kill the connection-management thread. While that thread is serving a
-- connection, this makes it close the socket and run the disconnect hook
-- before exiting.
disconnect :: Nats -> IO ()
disconnect nats = readMVar (natsThreadId nats) >>= killThread