{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
-- | TCP server and client for Franz streams.
module Database.Franz.Network
  ( startServer
  , defaultPort
  , Connection
  , withConnection
  , connect
  , disconnect
  , Query(..)
  , ItemRef(..)
  , RequestType(..)
  , defQuery
  , Response
  , awaitResponse
  , SomeIndexMap
  , Contents
  , fetch
  , fetchTraverse
  , fetchSimple
  , atomicallyWithin
  , FranzException(..)
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Concurrent.STM.Delay
import Control.Exception
import Control.Monad
import Control.Monad.Trans.Cont (ContT(..))
import qualified Data.ByteString.Char8 as B
import qualified Data.HashMap.Strict as HM
import Data.Int (Int64)
import qualified Data.IntMap.Strict as IM
import Data.IORef
import Data.Serialize
import qualified Data.Vector as V
import Database.Franz.Reader
import GHC.Generics (Generic)
import qualified Network.Socket as S
import qualified Network.Socket.ByteString as SB
import qualified Network.Socket.SendFile.Handle as SF
import System.Directory
import System.FilePath
import System.IO
import System.Process (callProcess)

-- | Default TCP port of the Franz server.
defaultPort :: S.PortNumber
defaultPort = 1886

-- | A message sent from a client to the server.
data RawRequest
  = RawRequest !ResponseId !Query
    -- ^ run a query; the ID tags every response to it
  | RawClean !ResponseId
    -- ^ cancel the request with the given ID (kills its delayed-response thread)
  deriving Generic
instance Serialize RawRequest

type ResponseId = Int

-- | Header of every message sent from the server to a client.
-- The constructor order determines the wire encoding; do not reorder.
data ResponseHeader
  = ResponseInstant !ResponseId
    -- ^ response ID; items satisfying the query follow immediately
  | ResponseWait !ResponseId
    -- ^ response ID; requested elements are not available right now
  | ResponseDelayed !ResponseId
    -- ^ response ID; the awaited items became available and follow immediately
  | ResponseError !ResponseId !FranzException
    -- ^ something went wrong
  deriving (Show, Generic)
instance Serialize ResponseHeader

-- | Initial seqno, final seqno, base offset, index names
data PayloadHeader = PayloadHeader !Int !Int !Int ![B.ByteString]

instance Serialize PayloadHeader where
  -- The three integers are always sent as little-endian 64-bit values.
  put (PayloadHeader s t u xs) = f s *> f t *> f u *> put xs
    where
      f = putInt64le . fromIntegral
  get = PayloadHeader <$> f <*> f <*> f <*> get
    where
      f = fromIntegral <$> getInt64le

-- | Read one request from the connection and handle it.
--
-- * A 'RawRequest' whose items are ready is answered with 'ResponseInstant'
--   followed by the payload.
-- * Otherwise 'ResponseWait' is sent and a thread is forked that sends
--   'ResponseDelayed' and the payload once the query becomes satisfiable.
--   Its 'ThreadId' is recorded in @refThreads@.
-- * A 'RawClean' kills (if any) and forgets the thread recorded for its ID.
--
-- A 'FranzException' raised while handling a query is reported to the client
-- as 'ResponseError'. Throws 'MalformedRequest' if the incoming bytes cannot
-- be decoded.
respond :: FranzReader
  -> IORef (IM.IntMap ThreadId) -- ^ delayed-response threads by request ID
  -> B.ByteString -- ^ stream path requested when the connection was opened
  -> IORef B.ByteString -- ^ bytes received but not yet decoded
  -> MVar S.Socket -- ^ the connection; the lock serialises writes
  -> IO ()
respond env refThreads (B.unpack -> path) buf vConn = do
  recvConn <- readMVar vConn
  runGetRecv buf recvConn get >>= \case
    Right (RawRequest reqId req) ->
      handleRequest reqId req `catch` \e -> sendHeader $ ResponseError reqId e
    Right (RawClean reqId) -> do
      m <- readIORef refThreads
      mapM_ killThread $ IM.lookup reqId m
      writeIORef refThreads $! IM.delete reqId m
    Left err -> throwIO $ MalformedRequest err
  where
    handleRequest reqId req = do
      (stream, query) <- handleQuery env path req
      let decide = do
            (ready, offsets) <- query
            return $ if ready
              then removeActivity stream >> send (ResponseInstant reqId) stream offsets
              else delayed reqId stream query
          onError e = return $ do
            removeActivity stream
            sendHeader $ ResponseError reqId e
      join $ atomically $ decide `catchSTM` onError

    -- The items are not ready yet: acknowledge with ResponseWait and wait in a
    -- separate thread so this connection can keep accepting requests.
    delayed reqId stream query = do
      m <- readIORef refThreads
      if IM.member reqId m
        then do
          -- Every other outcome after 'handleQuery' ends with 'removeActivity';
          -- do the same here so a rejected request does not leave the stream
          -- marked as active.
          removeActivity stream
          sendHeader $ ResponseError reqId $ MalformedRequest "duplicate request ID"
        else do
          sendHeader $ ResponseWait reqId
          -- Fork a thread to send a delayed response
          tid <- flip forkFinally (const $ removeActivity stream) $ join $ atomically $ do
            (ready', offsets') <- query
            check ready'
            return $ send (ResponseDelayed reqId) stream offsets'
          writeIORef refThreads $! IM.insert reqId tid m

    sendHeader x = withMVar vConn $ \conn -> SB.sendAll conn $ encode x

    -- Send a header, then the index entries and payload bytes straight from
    -- the stream files. Each index entry is 8 * (number of index names + 1)
    -- bytes long.
    send header Stream{..} ((s0, p0), (s1, p1)) = withMVar vConn $ \conn -> do
      SB.sendAll conn $ encode (header, PayloadHeader s0 s1 p0 indexNames)
      let siz = 8 * (length indexNames + 1)
      SF.sendFile' conn indexHandle (fromIntegral $ siz * succ s0) (fromIntegral $ siz * (s1 - s0))
      SF.sendFile' conn payloadHandle (fromIntegral p0) (fromIntegral $ p1 - p0)

-- | Run a Franz server that accepts connections until an exception is thrown.
startServer
  :: Double -- ^ reaping interval
  -> Double -- ^ stream life (seconds)
  -> S.PortNumber
  -> FilePath -- ^ live prefix
  -> Maybe FilePath -- ^ archive prefix
  -> IO ()
startServer interval life port lprefix aprefix = withFranzReader lprefix $ \env -> do
  hSetBuffering stderr LineBuffering
  _ <- forkIO $ reaper interval life env
  -- Number of connections currently using each mounted archive.
  vMountCount <- newTVarIO (HM.empty :: HM.HashMap B.ByteString Int)
  let
    -- Mount the archive @src@ on @dest@ unless @path@ is already mounted, in
    -- which case only its reference count is bumped. Nothing is recorded when
    -- @src@ does not exist, so no unmount is attempted for it later.
    mountArchive path src dest = join $ atomically $ do
      m <- readTVar vMountCount
      case HM.lookup path m of
        Nothing -> return $ do
          exists <- doesFileExist src
          when exists $ do
            createDirectoryIfMissing True dest
            callProcess "squashfuse" [src, dest]
            -- Update the current table instead of writing back the snapshot
            -- read above, which may no longer be up to date.
            atomically $ modifyTVar' vMountCount $ HM.insertWith (+) path 1
        Just n -> do
          writeTVar vMountCount $! HM.insert path (n + 1) m
          pure (pure ())

    -- Drop one reference to the mount of @path@; the last one unmounts @dest@.
    unmountArchive path dest = join $ atomically $ do
      m <- readTVar vMountCount
      case HM.lookup path m of
        Just 1 -> return $ do
          callProcess "fusermount" ["-u", dest]
          atomically $ modifyTVar' vMountCount $ HM.delete path
        Just n -> do
          writeTVar vMountCount $! HM.insert path (n - 1) m
          pure (pure ())
        Nothing -> pure (pure ())

    hints = S.defaultHints
      { S.addrFlags = [S.AI_NUMERICHOST, S.AI_NUMERICSERV]
      , S.addrSocketType = S.Stream
      }
  addr:_ <- S.getAddrInfo (Just hints) (Just "0.0.0.0") (Just $ show port)
  bracket (S.socket (S.addrFamily addr) S.Stream (S.addrProtocol addr)) S.close $ \sock -> do
    S.setSocketOption sock S.ReuseAddr 1
    S.setSocketOption sock S.NoDelay 1
    S.bind sock $ S.SockAddrInet (fromIntegral port) (S.tupleToHostAddress (0,0,0,0))
    S.listen sock S.maxListenQueue
    forever $ do
      (conn, connAddr) <- S.accept sock
      let
        -- Greet the client, then serve its requests until one fails; all
        -- outstanding delayed-response threads are killed on the way out.
        respondLoop path = do
          SB.sendAll conn apiVersion
          logServer [show connAddr, show path]
          ref <- newIORef IM.empty
          buf <- newIORef B.empty
          vConn <- newMVar conn
          forever (respond env ref path buf vConn) `finally` do
            readIORef ref >>= mapM_ killThread

        -- The first message of a connection is the stream path.
        serve = decode <$> SB.recv conn 4096 >>= \case
          Left _ -> throwIO $ MalformedRequest "Expecting a path"
          Right path
            | Just apath <- aprefix -> do
                -- NOTE(review): path comes from the client unvalidated; an
                -- absolute path or ".." components escape both prefixes.
                let src = combine apath (B.unpack path)
                    dest = combine lprefix (B.unpack path)
                mountArchive path src dest
                respondLoop path `finally` unmountArchive path dest
            | otherwise -> respondLoop path

        -- Report a 'FranzException' to the client (other exceptions are
        -- logged). The socket is closed even if reporting fails.
        finish result = flip finally (S.close conn) $ case result of
          Left ex -> case fromException ex of
            Just e -> SB.sendAll conn $ encode $ ResponseError (-1) e
            Nothing -> logServer [show ex]
          Right _ -> return ()
      void $ forkFinally serve finish
  where
    logServer = hPutStrLn stderr . unwords . (:) "[server]"

-- The protocol
--
-- @
-- Client                      Server
--  | ---- stream path P -----> | Mounts the archive of P if possible
--  | <--- apiVersion --------- |
--  | ---- RawRequest i p ----> |
--  | ---- RawRequest j q ----> |
--  | ---- RawRequest k r ----> |
--  | <--- ResponseInstant i -- |
--  | <--- result for p ------- |
--  | <--- ResponseWait j ----- |
--  | <--- ResponseWait k ----- |
--  | <--- ResponseDelayed j -- |
--  | <--- result for q ------- |
--  | ---- RawClean i --------> |
--  | ---- RawClean j --------> |
--  | ---- RawClean k --------> |
-- @

-- | A connection to a Franz server.
data Connection = Connection
  { connSocket :: MVar S.Socket
  , connReqId :: TVar Int
  , connStates :: TVar (IM.IntMap (ResponseStatus Contents))
  , connThread :: !ThreadId
  }

-- | Client-side state of an outstanding request.
data ResponseStatus a
  = WaitingInstant
  | WaitingDelayed
  | Errored !FranzException
  | Available !a
  deriving (Show, Functor)

-- | Run an action with a connection that is closed afterwards.
withConnection :: String -> S.PortNumber -> B.ByteString -> (Connection -> IO r) -> IO r
withConnection host port dir = bracket (connect host port dir) disconnect

apiVersion :: B.ByteString
apiVersion = "0"

-- | Connect to a server and open the stream directory @dir@.
--
-- Throws the server's 'FranzException' if it rejects the request, or
-- 'ClientError' on an unexpected handshake reply. The socket is closed if
-- the handshake fails.
connect :: String -> S.PortNumber -> B.ByteString -> IO Connection
connect host port dir = do
  let hints = S.defaultHints
        { S.addrFlags = [S.AI_NUMERICSERV]
        , S.addrSocketType = S.Stream
        }
  addr:_ <- S.getAddrInfo (Just hints) (Just host) (Just $ show port)
  sock <- bracketOnError (S.socket (S.addrFamily addr) S.Stream (S.addrProtocol addr)) S.close $ \s -> do
    S.connect s $ S.addrAddress addr
    SB.sendAll s $ encode dir
    readyMsg <- SB.recv s 4096
    unless (readyMsg == apiVersion) $ case decode readyMsg of
      Right (ResponseError _ e) -> throwIO e
      e -> throwIO $ ClientError $ "Database.Franz.Network.connect: Unexpected response: " ++ show e
    pure s
  connSocket <- newMVar sock
  connReqId <- newTVarIO 0
  connStates <- newTVarIO IM.empty
  buf <- newIORef B.empty
  let
    -- Decode one message from the server and return the transaction that
    -- applies it to 'connStates'.
    getUpdate :: Get (STM ())
    getUpdate = get >>= \case
      ResponseInstant i -> do
        resp <- getResponse
        return $ do
          m <- readTVar connStates
          case IM.lookup i m of
            Nothing -> pure ()
            Just WaitingInstant -> writeTVar connStates $! IM.insert i (Available resp) m
            e -> throwSTM $ ClientError $ "Unexpected state on ResponseInstant " ++ show i ++ ": " ++ show e
      ResponseWait i -> return $ do
        m <- readTVar connStates
        case IM.lookup i m of
          Nothing -> pure ()
          Just WaitingInstant -> writeTVar connStates $! IM.insert i WaitingDelayed m
          e -> throwSTM $ ClientError $ "Unexpected state on ResponseWait " ++ show i ++ ": " ++ show e
      ResponseDelayed i -> do
        resp <- getResponse
        return $ do
          m <- readTVar connStates
          case IM.lookup i m of
            Nothing -> pure ()
            Just WaitingDelayed -> writeTVar connStates $! IM.insert i (Available resp) m
            e -> throwSTM $ ClientError $ "Unexpected state on ResponseDelayed " ++ show i ++ ": " ++ show e
      ResponseError i e -> return $ do
        m <- readTVar connStates
        case IM.lookup i m of
          Nothing -> throwSTM e
          Just _ -> writeTVar connStates $! IM.insert i (Errored e) m
  connThread <- flip forkFinally (either throwIO pure) $ forever $
    runGetRecv buf sock getUpdate >>= either (throwIO . ClientError) atomically
  return Connection{..}

-- | Stop the receiver thread and close the socket.
disconnect :: Connection -> IO ()
disconnect Connection{..} = do
  killThread connThread
  withMVar connSocket S.close

-- | Decode one value from the socket. Leftover bytes are kept in @refBuf@
-- and consumed first by the next call.
runGetRecv :: IORef B.ByteString -> S.Socket -> Get a -> IO (Either String a)
runGetRecv refBuf sock m = do
  lo <- readIORef refBuf
  bs <- if B.null lo then SB.recv sock 4096 else pure lo
  go $ runGetPartial m bs
  where
    go (Done a lo') = Right a <$ writeIORef refBuf lo'
    go (Partial cont) = SB.recv sock 4096 >>= go . cont
    go (Fail str lo') = Left (show sock ++ ": " ++ str) <$ writeIORef refBuf lo'

-- | A query for all items of the given stream, from and to sequence number 0.
defQuery :: B.ByteString -> Query
defQuery name = Query
  { reqStream = name
  , reqFrom = BySeqNum 0
  , reqTo = BySeqNum 0
  , reqType = AllItems
  }

type SomeIndexMap = HM.HashMap B.ByteString Int64

-- | (seqno, indices, payloads)
type Contents = [(Int, SomeIndexMap, B.ByteString)]

-- | When it is 'Right', it might block until the content arrives.
type Response = Either Contents (STM Contents)

-- | Wait for the response, whether it is instant or delayed.
awaitResponse :: STM (Either a (STM a)) -> STM a
awaitResponse = (>>= either pure id)

-- | Decode a payload: a 'PayloadHeader', then one index entry per item
-- (end offset followed by one value per index name), then the payload bytes.
getResponse :: Get Contents
getResponse = do
  PayloadHeader s0 s1 p0 names <- get
  ixs <- V.replicateM (s1 - s0) $
    (,) <$> fmap fromIntegral getInt64le <*> traverse (const getInt64le) names
  -- Start offsets: the base offset followed by every item's end offset.
  let ofss = V.cons p0 $ V.map fst ixs
  payload <- getByteString $ fromIntegral $ V.last ofss - p0
  return
    [ (seqNo, HM.fromList $ zip names indices, B.take (ofs1 - ofs0) $ B.drop (ofs0 - p0) payload)
    | (seqNo, ofs0, (ofs1, indices)) <- zip3 [s0 + 1 ..] (V.toList ofss) (V.toList ixs)
    ]

-- | Fetch requested data from the server.
-- Termination of the continuation cancels the request (a 'RawClean' is sent),
-- allowing flexible control of its lifetime.
fetch :: Connection
  -> Query
  -> (STM Response -> IO r) -- ^ running the STM action blocks until the response arrives
  -> IO r
fetch Connection{..} req cont = bracket register release $ \reqId -> do
  withMVar connSocket $ \sock -> SB.sendAll sock $ encode $ RawRequest reqId req
  cont $ awaitFirst reqId
  where
    -- Allocate a request ID and mark it as waiting for the first reply.
    register = atomically $ do
      i <- readTVar connReqId
      writeTVar connReqId $! i + 1
      modifyTVar' connStates $ IM.insert i WaitingInstant
      return i

    -- Forget the request locally and tell the server to cancel it.
    release reqId = withMVar connSocket $ \sock -> do
      atomically $ modifyTVar' connStates $ IM.delete reqId
      SB.sendAll sock $ encode $ RawClean reqId

    awaitFirst reqId = do
      m <- readTVar connStates
      case IM.lookup reqId m of
        Nothing -> return $ Left [] -- fetch ended; nothing to return
        Just WaitingInstant -> retry -- wait for an instant response
        Just (Available xs) -> do
          writeTVar connStates $! IM.delete reqId m
          return $ Left xs
        Just WaitingDelayed -> return $ Right $ awaitDelayed reqId
        Just (Errored e) -> throwSTM e

    awaitDelayed reqId = do
      m <- readTVar connStates
      case IM.lookup reqId m of
        Nothing -> return [] -- fetch ended; nothing to return
        Just WaitingDelayed -> retry
        Just (Available xs) -> do
          writeTVar connStates $! IM.delete reqId m
          return xs
        Just (Errored e) -> throwSTM e
        Just WaitingInstant -> throwSTM $ ClientError $ "fetch/WaitingDelayed: unexpected state WaitingInstant"

-- | Queries in traversable @t@ form an atomic request. The response will become
-- available once all the elements are available.
--
-- Generalisation to Traversable guarantees that the response preserves the
-- shape of the request.
fetchTraverse :: Traversable t
  => Connection
  -> t Query
  -> (STM (Either (t Contents) (STM (t Contents))) -> IO r)
  -> IO r
fetchTraverse conn reqs = runContT $ do
  tresps <- traverse (ContT . fetch conn) reqs
  return $ do
    resps <- sequence tresps
    case traverse (either Just (const Nothing)) resps of
      Just instant -> return $ Left instant
      Nothing -> return $ Right $ traverse (either pure id) resps

-- | Send a single query and wait for the result. If it timeouts, it returns an empty list.
fetchSimple :: Connection
  -> Int -- ^ timeout in microseconds
  -> Query
  -> IO Contents
fetchSimple conn timeout req = fetch conn req (fmap (maybe [] id) . atomicallyWithin timeout . awaitResponse)

-- | Run a transaction, giving up with 'Nothing' after the timeout.
-- The timer is cancelled as soon as the transaction finishes.
atomicallyWithin :: Int -- ^ timeout in microseconds
  -> STM a
  -> IO (Maybe a)
atomicallyWithin timeout m = bracket (newDelay timeout) cancelDelay $ \d ->
  atomically $ fmap Just m `orElse` (Nothing <$ waitDelay d)