module Network.Nats (
Nats
, NatsSID
, connect
, connectSettings
, NatsHost(..)
, NatsSettings(..)
, defaultSettings
, NatsException
, MsgCallback
, subscribe
, unsubscribe
, publish
, request
, requestMany
, disconnect
) where
import Control.Applicative ((<$>), (<*>))
import Control.Concurrent
import Control.Concurrent.Async (concurrently)
import Control.Exception (AsyncException, Exception,
                          Handler (..), IOException,
                          SomeAsyncException,
                          SomeException, bracket,
                          bracketOnError, catch, catches,
                          throwIO)
import Control.Monad (forever, replicateM, unless, void,
                      when, mzero)
import Data.Dequeue as D
import qualified Data.Foldable as FOLD
import Data.IORef
import Data.Maybe (fromMaybe)
import Data.Typeable
import Network.Socket (SockAddr (..),
                       SocketOption (KeepAlive, NoDelay),
                       getAddrInfo, setSocketOption)
import qualified Network.Socket as S
import System.IO
import System.Random (randomRIO)
import System.Timeout
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.Char (isUpper, toLower)
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8)
import qualified Data.Aeson as AE
import Data.Aeson ((.:), (.:?), (.!=))
import Data.Aeson.TH (defaultOptions, deriveJSON,
                      fieldLabelModifier)
import qualified Network.URI as URI
-- | Delay between two client PINGs, in microseconds (3 seconds); used with
-- 'threadDelay' by 'pingerThread'.
pingInterval :: Int
pingInterval = 3 * 1000 * 1000

-- | One second, in microseconds.  Bounds a single socket write and a whole
-- connection attempt, and is also the pause between reconnection attempts.
timeoutInterval :: Int
timeoutInterval = 1000 * 1000
-- | Exception thrown by this client for timeouts, authentication failures,
-- unexpected server input, missed pings and server-reported errors.  The
-- 'String' is a human-readable description.
data NatsException = NatsException String
    deriving (Show, Typeable)
instance Exception NatsException
-- | Payload of the @CONNECT@ command sent by 'authenticate'.  JSON keys are
-- the field names without the @natsConn@ prefix, in snake_case
-- (e.g. @natsConnSslRequired@ is serialized as @ssl_required@).
data NatsConnectionOptions = NatsConnectionOptions {
    natsConnUser :: String
    -- ^ user name (@user@)
  , natsConnPass :: String
    -- ^ password (@pass@)
  , natsConnVerbose :: Bool
    -- ^ @verbose@ flag
  , natsConnPedantic :: Bool
    -- ^ @pedantic@ flag
  , natsConnSslRequired :: Bool
    -- ^ @ssl_required@ flag
  } deriving (Show)
-- | Baseline @CONNECT@ options.  Credentials are @nats@/@nats@ ('authenticate'
-- overrides them with the host's credentials), verbose and pedantic are
-- enabled, and SSL is not requested.
defaultConnectionOptions :: NatsConnectionOptions
defaultConnectionOptions = NatsConnectionOptions
    { natsConnUser        = "nats"
    , natsConnPass        = "nats"
    , natsConnVerbose     = True
    , natsConnPedantic    = True
    , natsConnSslRequired = False
    }
-- JSON instances for NatsConnectionOptions.  The field-label modifier drops
-- the 8-character "natsConn" prefix, inserts '_' before each upper-case
-- letter, strips the leading '_' and lower-cases the result, so
-- natsConnSslRequired becomes "ssl_required".
-- NOTE(review): the helper is written inline, presumably because the
-- Template Haskell stage restriction prevents calling a function defined in
-- this module from inside the splice.
$(deriveJSON defaultOptions{fieldLabelModifier =
    let insertUnderscore acc chr
          | isUpper chr = chr : '_' : acc
          | otherwise = chr : acc
    in
      map toLower . drop 1 . reverse . foldl insertUnderscore [] . drop 8
  } ''NatsConnectionOptions)
-- | The part of the server's @INFO@ line this client decodes.
data NatsServerInfo = NatsServerInfo {
    natsSvrAuthRequired :: Bool
    -- ^ @auth_required@: when 'True', 'authenticate' sends credentials
  } deriving (Show)
-- JSON instances for NatsServerInfo.  Same snake_case conversion as for
-- NatsConnectionOptions, but dropping the 7-character "natsSvr" prefix, so
-- natsSvrAuthRequired becomes "auth_required".
$(deriveJSON defaultOptions{fieldLabelModifier =
    let insertUnderscore acc chr
          | isUpper chr = chr : '_' : acc
          | otherwise = chr : acc
    in
      map toLower . drop 1 . reverse . foldl insertUnderscore [] . drop 7
  } ''NatsServerInfo)
-- | Subscription identifier allocated by the client ('newNatsSid') and
-- echoed back by the server in @MSG@ lines.  It is shown and read as a
-- bare decimal integer, matching the wire format.
newtype NatsSID = NatsSID Int deriving (Num, Ord, Eq)

instance Show NatsSID where
    show (NatsSID n) = show n

instance Read NatsSID where
    readsPrec prec input =
        [ (NatsSID n, remainder) | (n, remainder) <- readsPrec prec input ]
-- | Handler run by the connection's reader thread for every message that
-- arrives on a subscription.  Arguments: subscription id, subject the
-- message was published to, payload, and reply subject (if the publisher
-- supplied one).
type MsgCallback = NatsSID
                 -> String
                 -> BL.ByteString
                 -> Maybe String
                 -> IO ()
-- | A registered subscription, kept in 'natsSubMap' so incoming messages can
-- be dispatched and subscriptions can be replayed after a reconnect.
data NatsSubscription = NatsSubscription {
    subSubject :: Subject
    -- ^ subscribed subject
  , subQueue :: Maybe Subject
    -- ^ optional queue group
  , subCallback :: MsgCallback
    -- ^ handler for delivered messages
  , subSid :: NatsSID
    -- ^ subscription id used on the wire
  }
-- | Callbacks waiting for the server's @+OK@ / @-ERR@ reply, oldest first.
-- Each is called with 'Nothing' on success or the error text on failure.
type FifoQueue = D.BankersDequeue (Maybe T.Text -> IO ())
-- | Client handle returned by 'connect' / 'connectSettings'.
data Nats = Nats {
    natsSettings :: NatsSettings
    -- ^ settings the client was created with
  , natsRuntime :: MVar (Handle,
                         FifoQueue,
                         Bool,
                         MVar ()
                        )
    -- ^ connection state, guarded by the MVar: socket handle (undefined
    -- while disconnected), pending acknowledgement callbacks (undefined
    -- while disconnected), connected flag, and a signal filled once a new
    -- connection is established (undefined while connected)
  , natsThreadId :: MVar ThreadId
    -- ^ background connection thread, killed by 'disconnect'
  , natsNextSid :: IORef NatsSID
    -- ^ next subscription id to hand out
  , natsSubMap :: IORef (Map.Map NatsSID NatsSubscription)
    -- ^ active subscriptions by id
  }
-- | One server the client may connect to, with its credentials.
data NatsHost = NatsHost {
    natsHHost :: String
    -- ^ host name passed to getAddrInfo
  , natsHPort :: Int
    -- ^ TCP port
  , natsHUser :: String
    -- ^ user name sent when the server requires authentication
  , natsHPass :: String
    -- ^ password sent when the server requires authentication
  } deriving (Show)
-- | Parse a host entry such as @{"host": "example.org", "port": 4223}@.
--
-- Only @host@ is required.  Missing or null @port@, @user@ and @pass@ fall
-- back to 4222, @nats@ and @nats@, matching 'defaultSettings'.  ('.:?'
-- rather than '.:' is what makes the keys optional: with '.:' a missing key
-- is a parse failure and the default is never used.)
instance AE.FromJSON NatsHost where
    parseJSON (AE.Object v) =
        NatsHost <$>
            v .: "host" <*>
            v .:? "port" .!= 4222 <*>
            v .:? "user" .!= "nats" <*>
            v .:? "pass" .!= "nats"
    parseJSON _ = mzero
-- | Client configuration accepted by 'connectSettings'.
data NatsSettings = NatsSettings {
    natsHosts :: [NatsHost]
    -- ^ servers to try, in order; cycled when reconnecting
  , natsOnReconnect :: Nats -> (String, Int) -> IO ()
    -- ^ called with (host, port) after a reconnection (not the first connect)
  , natsOnDisconnect :: Nats -> String -> IO ()
    -- ^ called with the text of the error that ended a connection
  }
-- | A single local server (@localhost:4222@, credentials @nats@/@nats@) and
-- no-op reconnect/disconnect hooks.
defaultSettings :: NatsSettings
defaultSettings = NatsSettings
    { natsHosts        = [localServer]
    , natsOnReconnect  = \_ _ -> return ()
    , natsOnDisconnect = \_ _ -> return ()
    }
  where
    localServer = NatsHost
        { natsHHost = "localhost"
        , natsHPort = 4222
        , natsHUser = "nats"
        , natsHPass = "nats"
        }
-- | A line received from the server, as decoded by 'decodeMessage'.
data NatsSvrMessage =
    NatsSvrMsg { msgSubject::String, msgSid::NatsSID, msgText::BS.ByteString, msgReply::Maybe String}
    -- ^ @MSG@: subject, subscription id, payload (read separately by the
    -- reader after the header line) and optional reply subject
  | NatsSvrOK
    -- ^ @+OK@ acknowledgement
  | NatsSvrError T.Text
    -- ^ @-ERR@ with the rest of the line
  | NatsSvrPing
    -- ^ @PING@
  | NatsSvrPong
    -- ^ @PONG@
  | NatsSvrInfo NatsServerInfo
    -- ^ @INFO@ with its decoded JSON body
  deriving (Show)
-- | A subject name that passed 'makeSubject' validation.
newtype Subject = Subject String deriving (Show)

-- | The raw text of a 'Subject'.
subjectToStr :: Subject -> String
subjectToStr (Subject raw) = raw

-- | Validate a subject: it must be non-empty and contain no character at or
-- below @' '@ (spaces and control characters).  Calls 'error' otherwise.
makeSubject :: String -> Subject
makeSubject str
    | null str         = error "Empty subject"
    | any (<= ' ') str = error $ "Subject contains incorrect characters: " ++ str
    | otherwise        = Subject str
-- | A command sent to the server; serialized by 'makeClntMsg'.
data NatsClntMessage =
    NatsClntPing
    -- ^ @PING@
  | NatsClntPong
    -- ^ @PONG@
  | NatsClntSubscribe Subject NatsSID (Maybe Subject)
    -- ^ @SUB@ subject, subscription id, optional queue group
  | NatsClntUnsubscribe NatsSID
    -- ^ @UNSUB@ subscription id
  | NatsClntPublish Subject (Maybe Subject) BL.ByteString
    -- ^ @PUB@ subject, optional reply subject, payload
  | NatsClntConnect NatsConnectionOptions
    -- ^ @CONNECT@ with JSON options
-- | Serialize a client command, without the terminating CRLF (which
-- '_sendMessage' appends).  For @PUB@ the header line ends with CRLF and is
-- followed by the payload bytes.
makeClntMsg :: NatsClntMessage -> BL.ByteString
makeClntMsg cmsg = BL.fromChunks (encodeChunks cmsg)
  where
    -- A protocol line made of space-separated words.
    header :: [String] -> BS.ByteString
    header = BS.pack . unwords

    -- An optional subject contributes zero or one word.
    optSubject :: Maybe Subject -> [String]
    optSubject = maybe [] (\s -> [subjectToStr s])

    encodeChunks :: NatsClntMessage -> [BS.ByteString]
    encodeChunks msg = case msg of
        NatsClntPing -> ["PING"]
        NatsClntPong -> ["PONG"]
        NatsClntSubscribe subj sid mqueue ->
            [header (["SUB", subjectToStr subj] ++ optSubject mqueue ++ [show sid])]
        NatsClntUnsubscribe sid ->
            [header ["UNSUB", show sid]]
        NatsClntPublish subj mreply body ->
            let sizeLine = unwords (["PUB", subjectToStr subj] ++ optSubject mreply
                                    ++ [show (BL.length body)]) ++ "\r\n"
            in BS.pack sizeLine : BL.toChunks body
        NatsClntConnect info ->
            "CONNECT " : BL.toChunks (AE.encode info)
-- | Decode one line received from the server (as returned by 'BS.hGetLine',
-- possibly still ending in @\\r@).
--
-- Returns the message and, for @MSG@, the payload length that must be read
-- next.  Returns 'Nothing' for unknown verbs, undecodable @INFO@ JSON, and
-- malformed @MSG@ headers (wrong field count, non-numeric sid or length,
-- negative length).  Malformed input therefore no longer crashes the reader
-- with a 'read' error.  For @MSG@ the payload field is left empty; the reader
-- fills it in after reading the body.
decodeMessage :: BS.ByteString -> Maybe (NatsSvrMessage, Maybe Int)
decodeMessage line = decodeMessage_ mid (BS.drop 1 mrest)
  where
    (mid, mrest) = BS.span (\x -> x /= ' ' && x /= '\r') line

    decodeMessage_ :: BS.ByteString -> BS.ByteString -> Maybe (NatsSvrMessage, Maybe Int)
    decodeMessage_ "PING" _ = Just (NatsSvrPing, Nothing)
    decodeMessage_ "PONG" _ = Just (NatsSvrPong, Nothing)
    decodeMessage_ "+OK" _ = Just (NatsSvrOK, Nothing)
    decodeMessage_ "-ERR" msg = Just (NatsSvrError (decodeUtf8 msg), Nothing)
    decodeMessage_ "INFO" msg = do
        info <- AE.decode $ BL.fromChunks [msg]
        return (NatsSvrInfo info, Nothing)
    decodeMessage_ "MSG" msg =
        case map BS.unpack (BS.words msg) of
            [subj, sid, len] -> mkMsg subj sid Nothing len
            [subj, sid, reply, len] -> mkMsg subj sid (Just reply) len
            _ -> Nothing
    decodeMessage_ _ _ = Nothing

    -- Build an MSG header, validating the numeric fields.
    mkMsg :: String -> String -> Maybe String -> String -> Maybe (NatsSvrMessage, Maybe Int)
    mkMsg subj sidStr reply lenStr = do
        sid <- readExact sidStr
        len <- readExact lenStr
        if len < 0
            then Nothing
            else Just (NatsSvrMsg subj sid BS.empty reply, Just len)

    -- Total replacement for 'read': the whole string must parse.
    readExact :: Read a => String -> Maybe a
    readExact s = case reads s of
        [(x, "")] -> Just x
        _ -> Nothing
-- | Atomically take the next unused subscription id.
newNatsSid :: Nats -> IO NatsSID
newNatsSid Nats{natsNextSid = counter} =
    atomicModifyIORef' counter (\current -> (current + 1, current))
-- | A fresh reply subject: @_INBOX.@ followed by 13 random lower-case letters.
newInbox :: IO String
newInbox = ("_INBOX." ++) <$> replicateM 13 (randomRIO ('a', 'z'))
-- | Open an unbuffered read/write TCP handle to @hostname:port@, using the
-- first address returned by 'getAddrInfo'.  Keep-alive and TCP_NODELAY are
-- enabled; the socket is closed if setup fails.
connectToServer :: String -> Int -> IO Handle
connectToServer hostname port = do
    addrs <- getAddrInfo Nothing (Just hostname) Nothing
    let ai = head addrs
    bracketOnError (S.socket (S.addrFamily ai) S.Stream S.defaultProtocol) S.sClose $ \sock -> do
        mapM_ (\opt -> setSocketOption sock opt 1) [KeepAlive, NoDelay]
        S.connect sock (withPort (S.addrAddress ai))
        h <- S.socketToHandle sock ReadWriteMode
        hSetBuffering h NoBuffering
        return h
  where
    -- getAddrInfo was asked for no service, so patch the port in afterwards.
    portNum = fromIntegral port
    withPort (SockAddrInet _ haddr) = SockAddrInet portNum haddr
    withPort (SockAddrInet6 _ finfo haddr scopeid) = SockAddrInet6 portNum finfo haddr scopeid
    withPort other = other
-- | Run an action on the live connection while holding the runtime lock.
--
-- The action receives the socket handle and the queue of pending
-- acknowledgement callbacks, and returns the new queue.
--
-- * @blockIfDisconnected = True@: if the client is disconnected, wait for
--   the reconnection signal and retry, so the action eventually runs on a
--   live connection.
-- * @blockIfDisconnected = False@: while disconnected the action is skipped.
--
-- If the action throws, the runtime state is left unchanged and the
-- exception propagates to the caller.
ensureConnection :: Nats
                 -> Bool
                 -> ((Handle, FifoQueue) -> IO FifoQueue)
                 -> IO ()
ensureConnection nats True f = do
    -- Decide under the lock, but wait and retry only after releasing it.
    -- Waiting inside the lock-holding bracket meant an exception raised while
    -- waiting or retrying triggered a second putMVar on the already-restored
    -- runtime, blocking this thread forever.
    mwait <- modifyMVar (natsRuntime nats) runAction
    case mwait of
        Nothing -> return ()
        Just csig -> do
            readMVar csig
            ensureConnection nats True f
  where
    runAction (handle, queue, True, csig) = do
        nqueue <- f (handle, queue)
        return ((handle, nqueue, True, csig), Nothing)
    runAction r@(_, _, False, csig) = return (r, Just csig)
ensureConnection nats False f = modifyMVarMasked_ (natsRuntime nats) runAction
  where
    runAction (handle, queue, True, csig) = do
        nqueue <- f (handle, queue)
        return (handle, nqueue, True, csig)
    runAction (handle, queue, False, csig) =
        return (handle, queue, False, csig)
-- | Write a command to the server via 'ensureConnection'.
--
-- Commands that the server acknowledges (@CONNECT@, @PUB@, @SUB@, @UNSUB@)
-- always enqueue a callback for the reply; a no-op is used when none is
-- given.  Passing a callback for any other command is an 'error'.
sendMessage :: Nats -> Bool -> NatsClntMessage -> Maybe (Maybe T.Text -> IO ()) -> IO ()
sendMessage nats blockIfDisconnected msg mcb =
    case (mcb, acceptsCallback msg) of
        (Just cb, True)  -> deliver (\queue -> D.pushBack queue cb)
        (Nothing, True)  -> deliver (\queue -> D.pushBack queue (\_ -> return ()))
        (Just _, False)  -> error "Callback not supported"
        (Nothing, False) -> deliver id
  where
    deliver updateQueue = ensureConnection nats blockIfDisconnected $ \(handle, queue) -> do
        _sendMessage handle msg
        return (updateQueue queue)

    acceptsCallback NatsClntConnect {}     = True
    acceptsCallback NatsClntPublish {}     = True
    acceptsCallback NatsClntSubscribe {}   = True
    acceptsCallback NatsClntUnsubscribe {} = True
    acceptsCallback _                      = False
-- | Run an action with a time limit in microseconds, throwing
-- 'NatsException' @"Reached timeout"@ when it does not finish in time.
timeoutThrow :: Int -> IO a -> IO a
timeoutThrow micros action =
    timeout micros action >>= maybe (throwIO $ NatsException "Reached timeout") return
-- | Serialize a command and write it plus CRLF to the socket, bounded by
-- 'timeoutInterval'.  Frames shorter than 1024 bytes are written with a
-- single strict write; longer ones stream the lazy body and then the CRLF.
_sendMessage :: Handle -> NatsClntMessage -> IO ()
_sendMessage handle cmsg = timeoutThrow timeoutInterval (writeFrame (makeClntMsg cmsg))
  where
    crlf = "\r\n" :: BS.ByteString
    writeFrame frame
        | BL.length frame >= 1024 = BL.hPut handle frame >> BS.hPut handle crlf
        | otherwise = BS.hPut handle (BS.concat (BL.toChunks frame ++ [crlf]))
-- | Handshake on a freshly opened socket.
--
-- Reads the server's @INFO@ line.  If it requires authentication, sends
-- @CONNECT@ with the given credentials and expects @+OK@.  Throws
-- 'NatsException' on an @-ERR@ reply or any unexpected input.
authenticate :: Handle -> String -> String -> IO ()
authenticate handle user password = do
    greeting <- BS.hGetLine handle
    case decodeMessage greeting of
        Just (NatsSvrInfo srvInfo, Nothing)
            | natsSvrAuthRequired srvInfo -> sendCredentials
            | otherwise -> return ()
        _ -> throwIO $ NatsException "Incorrect input from server"
  where
    sendCredentials = do
        let opts = defaultConnectionOptions{natsConnUser = user, natsConnPass = password}
        BL.hPut handle (makeClntMsg (NatsClntConnect opts))
        BS.hPut handle "\r\n"
        reply <- BS.hGetLine handle
        case decodeMessage reply of
            Just (NatsSvrOK, Nothing) -> return ()
            Just (NatsSvrError err, Nothing) ->
                throwIO . NatsException $ "Authentication error: " ++ show err
            _ -> throwIO $ NatsException "Incorrect server response"
-- | Open and authenticate a connection to @nhost@ (the whole attempt bounded
-- by 'timeoutInterval'), then install it as the live connection with an
-- empty callback queue and fill the previous reconnection signal so blocked
-- senders resume.  The socket is closed if any step throws.
prepareConnection :: Nats -> NatsHost -> IO ()
prepareConnection nats nhost =
    timeoutThrow timeoutInterval $ bracketOnError openSocket hClose activate
  where
    openSocket = connectToServer (natsHHost nhost) (natsHPort nhost)
    activate sock = do
        authenticate sock (natsHUser nhost) (natsHPass nhost)
        waiters <- modifyMVar (natsRuntime nats) $ \(_, _, _, oldSig) ->
            return ((sock, D.empty, True, undefined), oldSig)
        putMVar waiters ()
-- | Body of the background thread forked by 'connectSettings'.
--
-- Runs 'connectionHandler' on the head of the host list.
--
-- * If that fails with an 'IOException' or 'NatsException', the connection
--   is torn down ('finalize').  The remaining hosts are then tried in order
--   until one accepts ('prepareConnection'), and the thread loops on the
--   list starting at that host.
-- * An 'AsyncException' (e.g. the 'killThread' issued by 'disconnect') tears
--   the connection down and ends the thread.
--
-- 'connectSettings' passes an infinite (cycled) list; an empty list is a
-- programming error.
connectionThread :: Bool
                 -> Nats
                 -> [NatsHost]
                 -> IO ()
connectionThread _ _ [] = error "Empty list of connections"
connectionThread firstTime nats (thisconn:nextconn) = do
    mnewconnlist <-
        (connectionHandler firstTime nats thisconn >> return Nothing)
          `catches` [Handler (\(e :: IOException) -> Just <$> errorHandler e),
                     Handler (\(e :: NatsException) -> Just <$> errorHandler e),
                     Handler (\e -> finalHandler e >> return Nothing)]
    case mnewconnlist of
        Nothing -> return ()
        Just newconnlist -> connectionThread False nats newconnlist
  where
    -- Mark the client as disconnected with a fresh, empty reconnection
    -- signal, close the socket, fail every callback still awaiting an
    -- acknowledgement with the error text, and run the user's hook.
    finalize :: (Show e) => e -> IO ()
    finalize e = do
        (handle, queue, _, _) <- takeMVar (natsRuntime nats)
        finsignal <- newEmptyMVar
        putMVar (natsRuntime nats) (undefined, undefined, False, finsignal)
        hClose handle
        FOLD.mapM_ (\f -> f $ Just (T.pack $ show e)) queue
        (natsOnDisconnect $ natsSettings nats) nats (show e)
    -- Recoverable failure: tear down, then reconnect.  Returns the host
    -- list starting at the host that accepted the new connection.
    -- NOTE(review): this runs inside a 'catches' handler, so asynchronous
    -- exceptions are masked here except at interruptible points such as
    -- 'threadDelay' and blocking I/O.
    errorHandler :: (Show e) => e -> IO [NatsHost]
    errorHandler e = do
        finalize e
        tryToConnect nextconn
      where
        -- Try each host in turn, pausing 'timeoutInterval' after a failure.
        tryToConnect connlist@(conn:rest) = do
            res <- (prepareConnection nats conn >> return (Just connlist))
                     `catches` [ Handler (\(_ :: IOException) -> return Nothing),
                                 Handler (\(_ :: NatsException) -> return Nothing) ]
            case res of
                Just restlist -> return restlist
                Nothing -> threadDelay timeoutInterval >> tryToConnect rest
        tryToConnect [] = error "Empty list of connections"
    -- The thread is being killed: tear down and stop.
    finalHandler :: AsyncException -> IO ()
    finalHandler = finalize
-- | Keep-alive loop for one connection.
--
-- Every 'pingInterval' microseconds it sends a @PING@.  The IORef holds
-- (pings sent, pongs received); the reader bumps the second component on
-- each @PONG@.  If two or more pings are still unanswered when the next one
-- is due, it throws 'NatsException' so the connection is torn down.
pingerThread :: Nats -> IORef (Int, Int) -> IO ()
pingerThread nats pingStatus = forever $ do
    threadDelay pingInterval
    -- Count this ping and check the number of outstanding ones.  This was
    -- previously written @pings pongs < 2@, which applies an Int as a
    -- function; the intended test is the difference.
    ok <- atomicModifyIORef' pingStatus $ \(pings, pongs) ->
        ((pings + 1, pongs), pings - pongs < 2)
    unless ok $ throwIO (NatsException "Ping timeouted")
    sendMessage nats True NatsClntPing Nothing
-- | Drive one established connection.
--
-- Re-sends @SUB@ for every known subscription and calls 'natsOnReconnect'
-- unless this is the first connection.  It then runs the pinger and the
-- reader concurrently; both loop forever, so this returns only by throwing.
connectionHandler :: Bool -> Nats -> NatsHost -> IO ()
connectionHandler firstTime nats (NatsHost hostName hostPort _ _) = do
    (conn, _, _, _) <- readMVar (natsRuntime nats)
    subs <- readIORef (natsSubMap nats)
    FOLD.forM_ subs resubscribe
    unless firstTime $
        natsOnReconnect (natsSettings nats) nats (hostName, hostPort)
    pongCounter <- newIORef (0, 0)
    _ <- concurrently (pingerThread nats pongCounter)
                      (connectionHandler' conn nats pongCounter)
    return ()
  where
    resubscribe sub =
        sendMessage nats True (NatsClntSubscribe (subSubject sub) (subSid sub) (subQueue sub)) Nothing
-- | Reader loop for one connection.  Reads server lines and dispatches them
-- until an exception ends the loop.
--
-- * @PING@ is answered with @PONG@; @PONG@ increments the pong count shared
--   with 'pingerThread'.
-- * @+OK@ / @-ERR@ pop the oldest pending acknowledgement callback and call
--   it with 'Nothing' / the error text.  With no callback pending, an error
--   text (if any) is printed.
-- * @MSG@ reads the payload and runs the subscription's callback.  A message
--   for an unknown subscription id triggers @UNSUB@.
-- * Undecodable lines are printed and skipped.
--
-- Synchronous exceptions from user callbacks are printed and ignored.
-- Asynchronous ones (thread kill, cancellation by 'concurrently') are
-- rethrown so the loop cannot survive being cancelled.
connectionHandler' :: Handle -> Nats -> IORef (Int, Int) -> IO ()
connectionHandler' handle nats pingStatus = forever $ do
    line <- BS.hGetLine handle
    case decodeMessage line of
        Just (msg, Nothing) ->
            handleMessage msg
        Just (msg@(NatsSvrMsg {}), Just paylen) -> do
            -- The payload is followed by "\r\n", which is read and dropped.
            payload <- BS.hGet handle (paylen + 2)
            handleMessage msg{msgText=BS.take paylen payload}
        _ ->
            putStrLn $ "Incorrect message: " ++ show line
  where
    popCb (h, queue, x1, x2) = return ((h, newq, x1, x2), item)
      where
        (item, newq) = case D.popFront queue of
            Just inq -> inq
            Nothing -> (maybe (return ()) print, D.empty)
    handleMessage NatsSvrPing = sendMessage nats True NatsClntPong Nothing
    handleMessage NatsSvrPong =
        atomicModifyIORef' pingStatus $
            \(pings, pongs) -> ((pings, pongs + 1), ())
    handleMessage NatsSvrOK = do
        cb <- modifyMVar (natsRuntime nats) popCb
        cb Nothing
    handleMessage (NatsSvrError txt) = do
        cb <- modifyMVar (natsRuntime nats) popCb
        cb $ Just txt
    handleMessage (NatsSvrInfo _) = return ()
    handleMessage (NatsSvrMsg {..}) = do
        msubscription <- Map.lookup msgSid <$> readIORef (natsSubMap nats)
        case msubscription of
            Just subscription ->
                subCallback subscription msgSid msgSubject (BL.fromChunks [msgText]) msgReply
                    `catches`
                    [ Handler (\(e :: SomeAsyncException) -> throwIO e)
                    , Handler (\(e :: SomeException) -> print e)
                    ]
            Nothing -> sendMessage nats True (NatsClntUnsubscribe msgSid) Nothing
-- | Connect using a URL of the form @nats://[user[:password]\@]host[:port]@.
--
-- The port defaults to 4222 when omitted; it previously crashed in
-- @read ""@.  Missing credentials become empty strings.  Calls 'error' for
-- an unparsable URL, a scheme other than @nats:@, or a missing host.
connect :: String
        -> IO Nats
connect uri = do
    let parsedUri = fromMaybe (error ("Error parsing NATS url: " ++ uri))
                              (URI.parseURI uri)
    when (URI.uriScheme parsedUri /= "nats:") $ error "Incorrect URL scheme"
    let (host, port, user, password) = case URI.uriAuthority parsedUri of
            Just (URI.URIAuth {..}) ->
                -- uriUserInfo is "" or "user[:password]@".  Strip the '@'
                -- before splitting so a user-only URL does not keep it.
                let credentials = takeWhile (/= '@') uriUserInfo
                    (userPart, passPart) = break (== ':') credentials
                in (uriRegName, parsePort uriPort, userPart, drop 1 passPart)
            Nothing -> error "Missing hostname section"
    connectSettings defaultSettings{
        natsHosts=[NatsHost host port user password]
      }
  where
    -- uriPort is either "" or ":<digits>".
    parsePort p = case drop 1 p of
        "" -> 4222
        digits -> read digits
-- | Create a client and connect it.
--
-- The hosts are tried in order and the first that accepts becomes the live
-- connection.  A failure on any host but the last moves on to the next; the
-- last host's failure propagates.  A background thread then serves the
-- connection and reconnects by cycling through the host list.
connectSettings :: NatsSettings -> IO Nats
connectSettings settings = do
    nats <- mkNats
    let hosts = natsHosts settings
    firstHost <- tryUntilSuccess hosts (prepareConnection nats)
    tid <- forkIO $ connectionThread True nats (firstHost : cycle hosts)
    putMVar (natsThreadId nats) tid
    return nats
  where
    -- Fresh client state: disconnected, empty reconnection signal, no
    -- thread id yet, subscription ids starting at 1, no subscriptions.
    mkNats = do
        sig <- newEmptyMVar
        runtime <- newMVar (undefined, undefined, False, sig)
        tidVar <- newEmptyMVar
        sidCounter <- newIORef 1
        subs <- newIORef Map.empty
        return Nats { natsSettings = settings
                    , natsRuntime  = runtime
                    , natsThreadId = tidVar
                    , natsNextSid  = sidCounter
                    , natsSubMap   = subs
                    }

    tryUntilSuccess candidates attempt = case candidates of
        [] -> error "Empty list"
        [only] -> attempt only >> return only
        (h:more) -> (attempt h >> return h)
            `catch` \(_ :: SomeException) -> tryUntilSuccess more attempt
-- | Subscribe to a subject, optionally as a member of a queue group.
--
-- Blocks until the server acknowledges the @SUB@ (waiting for a reconnect if
-- necessary).  On @+OK@ the subscription is registered and its id returned;
-- on @-ERR@ a 'NatsException' with the server's text is thrown.
subscribe :: Nats
          -> String
          -> Maybe String
          -> MsgCallback
          -> IO NatsSID
subscribe nats subject queue cb = do
    reply <- newEmptyMVar :: IO (MVar (Maybe T.Text))
    sid <- newNatsSid nats
    let subj = makeSubject subject
        qsubj = makeSubject <$> queue
        register = atomicModifyIORef' (natsSubMap nats) $ \subs ->
            (Map.insert sid NatsSubscription{subSubject=subj, subQueue=qsubj, subCallback=cb, subSid=sid} subs, ())
        onAck merr = do
            maybe register (const (return ())) merr
            putMVar reply merr
    sendMessage nats True (NatsClntSubscribe subj sid qsubj) (Just onAck)
    outcome <- takeMVar reply
    maybe (return sid) (throwIO . NatsException . T.unpack) outcome
-- | Forget a subscription locally and send @UNSUB@ if currently connected.
-- Network and protocol errors while sending are ignored.
unsubscribe :: Nats
            -> NatsSID
            -> IO ()
unsubscribe nats sid = do
    atomicModifyIORef' (natsSubMap nats) (\subs -> (Map.delete sid subs, ()))
    ignoringFailures $ sendMessage nats False (NatsClntUnsubscribe sid) Nothing
  where
    ignoringFailures act =
        act `catches` [ Handler (\(_ :: IOException) -> return ())
                      , Handler (\(_ :: NatsException) -> return ())
                      ]
-- | Publish @body@ to @subject@ with a fresh inbox as reply subject and
-- return the first reply.
--
-- The inbox subscription is removed afterwards.  Throws 'NatsException' if
-- the server rejects the publish.  There is no timeout: it waits until a
-- reply or an error arrives.
request :: Nats
        -> String
        -> BL.ByteString
        -> IO BL.ByteString
request nats subject body = do
    reply <- newEmptyMVar :: IO (MVar (Either String BL.ByteString))
    inbox <- newInbox
    let onResponse _ _ payload _ = void $ tryPutMVar reply (Right payload)
        onAck = maybe (return ()) (void . tryPutMVar reply . Left . T.unpack)
        publishRequest =
            sendMessage nats True
                (NatsClntPublish (makeSubject subject) (Just (makeSubject inbox)) body)
                (Just onAck)
    bracket (subscribe nats inbox Nothing onResponse) (unsubscribe nats) $ \_ -> do
        publishRequest
        takeMVar reply >>= either (throwIO . NatsException) return
-- | Publish @body@ to @subject@ with a fresh inbox as reply subject and
-- collect every reply that arrives within @time@ microseconds (the delay
-- passed to 'threadDelay').  Replies are returned in arrival order.
--
-- The publish goes through 'publish'', so it is skipped when disconnected and
-- its send errors are ignored; in that case the result is just whatever
-- arrived (typically nothing).  Replies are accumulated with the strict
-- 'atomicModifyIORef'', as elsewhere in this module, so the collected list
-- does not build up a chain of thunks.
requestMany :: Nats
            -> String
            -> BL.ByteString
            -> Int
            -> IO [BL.ByteString]
requestMany nats subject body time = do
    result <- newIORef []
    inbox <- newInbox
    bracket
        (subscribe nats inbox Nothing $ \_ _ response _ ->
            atomicModifyIORef' result $ \old -> (response:old, ())
        )
        (unsubscribe nats)
        (\_ -> do
            publish' nats subject (Just inbox) body
            threadDelay time
        )
    reverse <$> readIORef result
-- | Fire-and-forget publish of @body@ to @subject@ with no reply subject.
-- Skipped while disconnected; send errors are ignored.
publish :: Nats
        -> String
        -> BL.ByteString
        -> IO ()
publish nats subject body = publish' nats subject Nothing body
-- | Publish with an optional reply subject.  Does not block when
-- disconnected (the message is dropped), and swallows 'IOException' and
-- 'NatsException' raised while sending.
publish' :: Nats
         -> String
         -> Maybe String
         -> BL.ByteString
         -> IO ()
publish' nats subject inbox body =
    ignoringFailures $ sendMessage nats False outgoing Nothing
  where
    outgoing = NatsClntPublish (makeSubject subject) (fmap makeSubject inbox) body
    ignoringFailures act =
        act `catches` [ Handler (\(_ :: IOException) -> return ())
                      , Handler (\(_ :: NatsException) -> return ())
                      ]
-- | Stop the client by killing its background connection thread, which
-- closes the socket and fails any pending acknowledgement callbacks.
disconnect :: Nats -> IO ()
disconnect nats = readMVar (natsThreadId nats) >>= killThread