{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module DBus.MainLoop where

import           Control.Concurrent
import           Control.Concurrent.Async
import           Control.Concurrent.STM
import qualified Control.Exception as Ex
import           Control.Monad
import           Control.Monad.Catch (throwM)
import           Control.Monad.Fix (mfix)
import           Control.Monad.Trans
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy.Builder as BS
import qualified Data.Conduit as C
import qualified Data.Conduit.Binary as CB
import           Data.Map (Map)
import qualified Data.Map as Map
import           Data.Text (Text)
import qualified Data.Text as Text
import           Network.Socket (Socket, socketToHandle)
import           Network.Socket.ByteString (send)
import           System.Environment
import           System.IO
import           System.Log.Logger

import           DBus.Auth
import           DBus.Error
import           DBus.Message
import           DBus.MessageBus
import           DBus.Object
import           DBus.Transport
import           DBus.Types
import           DBus.Signal
import           DBus.Introspect

-- | Dispatch one decoded incoming message.
--
-- * Method calls are logged and passed to the method-call handler.
-- * Method returns and errors complete (and remove) the answer slot that is
--   registered under the message's reply serial: returns are delivered as
--   'Right', errors as 'Left'. Replies without a reply serial, or with an
--   unknown one, are dropped.
-- * Signals are first passed to the generic signal handler. A well-formed
--   @org.freedesktop.DBus.Properties.PropertiesChanged@ signal is then routed
--   to the matching property slots (each handler runs in its own thread); any
--   other signal is delivered to every registered signal slot whose match
--   rule accepts it.
-- * All other message types are ignored.
handleMessage :: (MessageHeader -> [SomeDBusValue] -> IO ())
              -- ^ Handler for incoming method calls
              -> (MessageHeader -> [SomeDBusValue] -> IO a)
              -- ^ Generic handler run for every signal (result discarded)
              -> TVar AnswerSlots
              -> TVar SignalSlots
              -> TVar PropertySlots
              -> (MessageHeader, [SomeDBusValue])
              -> IO ()
handleMessage handleCall handleSignals answerSlots signalSlots propertySlots
              (header, body) =
    case messageType header of
        MessageTypeMethodCall -> do
            let hfs = fields header
            logDebug $ "Dispatching method call " ++ show (hFPath hfs)
                ++ "; " ++ maybe "" Text.unpack (hFInterface hfs)
                ++ "; " ++ maybe "" Text.unpack (hFMember hfs)
                ++ ": " ++ show body
            handleCall header body
        MessageTypeMethodReturn -> handleReturn True
        MessageTypeError -> handleError
        MessageTypeSignal -> handleSignal'
        _ -> return ()
  where
    -- Hand the body to the answer slot waiting for this reply serial (if any)
    -- and unregister it; 'nonError' selects 'Right' (return) or 'Left' (error).
    handleReturn nonError =
        case hFReplySerial $ fields header of
            Nothing -> return ()
            Just s -> atomically $ do
                slots <- readTVar answerSlots
                case Map.lookup s slots of
                    Nothing -> return ()
                    Just putSlot -> do
                        writeTVar answerSlots (Map.delete s slots)
                        putSlot $ if nonError then Right body else Left body

    handleError =
        case hFReplySerial $ fields header of
            Nothing -> return () -- TODO: handle non-response errors
            Just _ -> handleReturn False

    handleSignal' = do
        _ <- handleSignals header body
        sSlots <- atomically $ readTVar signalSlots
        let fs = fields header
        case (hFInterface fs, hFMember fs, hFPath fs, hFSender fs) of
            (Just iface, Just member, Just path, Just sender)
                -- PropertiesChanged with the expected (s, a{sv}, as) body goes
                -- to the property slots instead of the generic signal slots.
                | iface == "org.freedesktop.DBus.Properties"
                , member == "PropertiesChanged"
                , [DBV pi', DBV uds, DBV invs] <- body
                , Just propIface <- fromRep =<< castDBV pi' :: Maybe Text
                , Just updates <- fromRep =<< castDBV uds
                      :: Maybe (Map Text (DBusValue 'TypeVariant))
                , Just ivs <- (fromRep =<< castDBV invs :: Maybe [Text])
                -> handlePropertyUpdates path propIface updates ivs
                | otherwise ->
                    case filter (match4 ( Match iface
                                        , Match member
                                        , Match path
                                        , Match sender) . fst) sSlots of
                        handlers@(_:_) -> case listToSomeArguments body of
                            SDBA as ->
                                let sig = SomeSignal $ Signal { signalPath = path
                                                              , signalInterface = iface
                                                              , signalMember = member
                                                              , signalBody = as
                                                              }
                                in forM_ handlers $ \(_, handler) -> handler sig
                        _ -> logDebug $ "Unhandled signal " ++ show path
                             ++ " / " ++ show iface ++ "." ++ show member
                             ++ " from " ++ show sender ++ ": " ++ show body ++ "\n"
            _ -> logDebug $ "Signal is missing header fields: " ++ show header
                 ++ "; " ++ show body

    -- A slot matches when all four components (interface, member, path,
    -- sender) satisfy 'checkMatch'.
    match4 (x1, x2, x3, x4) (y1, y2, y3, y4) =
        and [ x1 `checkMatch` y1
            , x2 `checkMatch` y2
            , x3 `checkMatch` y3
            , x4 `checkMatch` y4
            ]

    -- Notify the property slots registered for (path, interface, property).
    -- Changed properties are passed 'Just' the new value, invalidated ones
    -- 'Nothing'. Every handler is run in its own thread.
    handlePropertyUpdates path iface updates ivs = do
        pSlots <- readTVarIO propertySlots
        let items = Map.toList (Just <$> updates) ++ ((\i -> (i, Nothing)) <$> ivs)
        forM_ items $ \(member, mbV) ->
            case Map.lookup (path, iface, member) pSlots of
                Nothing -> logDebug $ "unexpected property update for " ++ show path
                    ++ " / " ++ Text.unpack iface ++ "." ++ Text.unpack member
                Just hs -> do
                    let v = variantToDBV <$> mbV
                    logDebug $ "Received property updates " ++ show updates
                        ++ " and invalidated properties " ++ show ivs
                    forM_ hs $ \h -> forkIO $ h v

-- | Unwrap a variant into the value it contains.
variantToDBV :: DBusValue 'TypeVariant -> SomeDBusValue
variantToDBV (DBVVariant v) = DBV v

-- | Create a message handler that dispatches method calls to the methods of
-- the given root objects.
--
-- The path, interface, member and sender header fields must all be present;
-- otherwise the message is ignored. If the call cannot be resolved to a
-- method, the resolution error is sent back as an error reply. Otherwise the
-- method runs in a separate thread: a 'MsgError' it throws becomes an error
-- reply, and any other exception is reported as
-- @org.freedesktop.DBus.Error.Failed@. Signals returned by the method are
-- emitted before the reply is sent.
--
-- Synchronous exceptions escaping this logic are printed to 'stderr'.
-- Asynchronous exceptions (such as the 'killThread' used to close the
-- connection) are re-thrown, so they still stop the calling thread.
objectRoot :: Objects -> Handler
objectRoot o conn header args
    | fs <- fields header
    , Just path <- hFPath fs
    , Just iface <- hFInterface fs
    , Just member <- hFMember fs
    , ser <- serial header
    , Just sender <- hFSender fs
    = Ex.handle reportException $ do
        let errToErrMessage s e = errorMessage s (Just ser) sender
                                               (errorName e) (errorText e) (errorBody e)
            mkReturnMethod s args' = methodReturn s ser sender args'
        (ret, sigs) <- case callAtPath o path iface member args of
            Left e -> return (Left e, [])
            Right f -> do
                outcome <- withAsync (Ex.catch (runMethodHandlerT f)
                                               -- catches MsgError only:
                                               (\e -> return (Left e, [])))
                                     waitCatch
                case outcome of
                    Left e -> return ( Left (MsgError "org.freedesktop.DBus.Error.Failed"
                                                      (Just $ "Method threw exception: "
                                                              `Text.append` Text.pack (show e))
                                                      [])
                                     , [])
                    Right r -> return r
        serial' <- atomically $ dBusCreateSerial conn
        forM_ sigs $ flip emitSignal' conn
        logDebug $ "method call returned " ++ show ret
        case ret of
            Left err -> sendBS conn $ errToErrMessage serial' err
            Right r -> sendBS conn $ mkReturnMethod serial' r
        logDebug "done"
  where
    -- Print synchronous failures, but never swallow asynchronous exceptions:
    -- the connection is closed by killing the thread this handler runs on.
    reportException :: Ex.SomeException -> IO ()
    reportException e
        | Just (_ :: Ex.SomeAsyncException) <- Ex.fromException e = Ex.throwIO e
        | otherwise = hPutStrLn stderr (show e)
objectRoot _ _ _ _ = return ()

-- | Check whether the connection is alive.
checkAlive :: DBusConnection -> IO Bool
checkAlive conn = readTVarIO (dBusConnectionAliveRef conn)

-- | Wait until the connection is closed. The intended use is to keep servers
-- alive.
waitFor :: DBusConnection -> IO ()
waitFor conn = atomically $ do
    alive <- readTVar (dBusConnectionAliveRef conn)
    when alive retry
    void $ readTVar (dBusGcRef conn) -- avoid closing the connection prematurely

-- | Which bus to connect to
data ConnectionType = System -- ^ The well-known system bus. First
                             -- the environment variable
                             -- DBUS_SYSTEM_BUS_ADDRESS is checked and if it
                             -- doesn't exist the address
                             -- /unix:path=\/var\/run\/dbus\/system_bus_socket/
                             -- is used
                    | Session -- ^ The well-known session bus. Refers to the
                              -- address stored in the environment variable
                              -- DBUS_SESSION_BUS_ADDRESS
                    | Address String -- ^ The bus at the given address

-- | Callback invoked for incoming method calls.
type MethodCallHandler = DBusConnection -- ^ Connection that the call was
                                        -- received from. Should be used to
                                        -- return the result or error
                       -> MessageHeader
                       -> [SomeDBusValue]
                       -> IO ()

-- | Callback invoked for incoming signals.
type SignalHandler = ( DBusConnection -> MessageHeader -> [SomeDBusValue] -> IO ())

-- | General way to connect to a message bus, see 'connectBusWithAuth'.
--
-- Uses the @EXTERNAL@ authentication mechanism.
connectBus :: ConnectionType -- ^ Bus to connect to
           -> MethodCallHandler -- ^ Handler for incoming method calls
           -> SignalHandler -- ^ Handler for incoming signals
           -> IO DBusConnection
connectBus transport = connectBusWithAuth transport external

-- | General way to connect to a message bus, with a custom authentication
-- method. Takes two callback functions:
--
-- * A 'MethodCallHandler' that is invoked when a method call is received.
--
-- * A 'SignalHandler' that is invoked when a signal is received.
--
-- The address comes from the 'ConnectionType'. For 'Session' it is read with
-- 'getEnv', which throws if @DBUS_SESSION_BUS_ADDRESS@ is unset. Throws
-- 'CouldNotConnect' if no address could be connected to. After connecting,
-- 'sendCredentials' is called, SASL authentication is run with the given
-- mechanism, and a background thread is forked that reads and dispatches
-- incoming messages. Finally 'hello' is called to obtain the connection name.
--
-- When the background thread ends, for whatever reason, the connection is
-- marked as dead, its handle is closed, and every pending method call is
-- failed with "Connection Closed".
connectBusWithAuth :: ConnectionType -- ^ Bus to connect to
                   -> SASL BS.ByteString -- ^ The authentication mechanism
                   -> MethodCallHandler -- ^ Handler for incoming method calls
                   -> SignalHandler -- ^ Handler for incoming signals
                   -> IO DBusConnection
connectBusWithAuth transport auth handleCalls handleSignals = do
    addressString <- case transport of
        Session -> getEnv "DBUS_SESSION_BUS_ADDRESS"
        System -> do
            fromEnv <- lookupEnv "DBUS_SYSTEM_BUS_ADDRESS"
            case fromEnv of
                Nothing -> return "unix:path=/var/run/dbus/system_bus_socket"
                Just addr -> return addr
        Address addr -> return addr
    debugM "DBus" $ "connecting to " ++ addressString
    mbS <- connectString addressString
    s <- case mbS of
        Nothing -> throwM (CouldNotConnect "All addresses failed to connect")
        Just s -> return s
    _ <- sendCredentials s
    h <- socketToHandle s ReadWriteMode
    debugM "DBus" $ "Running SASL"
    -- SASL transport: write builders to the handle, read replies line by line.
    let saslSend bs = do
            debugM "DBus.Sasl" $ "C: " ++ show (BS.toLazyByteString bs)
            BS.hPutBuilder h bs
        saslRecv = do
            bs <- BS.hGetLine h
            debugM "DBus.Sasl" $ "S: " ++ show bs
            return bs
    _ <- runSasl saslSend saslRecv auth
    serialCounter <- newTVarIO 1
    let getSerial = do
            s' <- readTVar serialCounter
            writeTVar serialCounter (s'+1)
            return s'
    lock <- newTMVarIO $ BS.hPutBuilder h
    answerSlots <- newTVarIO (Map.empty :: AnswerSlots)
    signalSlots <- newTVarIO ([] :: SignalSlots)
    propertySlots <- newTVarIO (Map.empty :: PropertySlots)
    aliveRef <- newTVarIO True
    -- True and fake GC refs, see the explanation below.
    gcRef' <- newTVarIO ()
    fakeGcRef <- newTVarIO ()
    -- 'failPending' fails every pending method call with "Connection Closed"
    -- and drops all registered signal and property handlers.
    -- 'kill' tears the connection down when the message-handling thread exits.
    -- It runs 'failPending' even if 'hClose' throws (for example when flushing
    -- to a broken socket fails); otherwise callers waiting for answers would
    -- block forever.
    let failPending = atomically $ do
            sls <- readTVar answerSlots
            writeTVar answerSlots Map.empty
            writeTVar signalSlots []
            writeTVar propertySlots Map.empty
            forM_ (Map.elems sls) $ \s' ->
                s' (Left [DBV $ DBVString "Connection Closed"])
        kill = do
            atomically $ writeTVar aliveRef False
            hClose h `Ex.finally` failPending
    -- In order not to retain a reference to gcRef in the connection thread,
    -- the DBusConnection in the connection thread (forked below) needs to
    -- contain a "fake" TVar (), different from the one to which the
    -- finalizer is attached.
    --
    -- Originally, I tried to overwrite the gcRef after that thread
    -- is forked and set it to fakeGcRef. However, that didn't work:
    --
    -- · We can't force evaluation of the DBusConnection in the forked
    --   thread; it makes the mfix diverge due to the lack of laziness.
    -- · OTOH, if we don't force the DBusConnection, the update thunk
    --   continues to hold the reference to the original DBusConnection, and,
    --   therefore, to the "true" gcRef.
    --
    -- Hence, we do it the other way around: initialize the connection with
    -- the fakeGcRef and later overwrite it with the true one.
    -- Note that we update gcRef *outside* of the mfix block.
    conn <- mfix $ \conn' -> do
        debugM "DBus" $ "Forking"
        handlerThread <- forkIO $
            (CB.sourceHandle h
                C.$= parseMessages
                C.$$ C.awaitForever (liftIO . handleMessage (handleCalls conn')
                                                            (handleSignals conn')
                                                            answerSlots
                                                            signalSlots
                                                            propertySlots))
            `Ex.finally` kill
        addTVarFinalizer gcRef' $ killThread handlerThread
        let conn = DBusConnection
                { dBusCreateSerial = getSerial
                , dBusAnswerSlots = answerSlots
                , dBusSignalSlots = signalSlots
                , dBusPropertySlots = propertySlots
                , dBusWriteLock = lock
                , dBusConnectionName = ""
                , dBusConnectionAliveRef = aliveRef
                , dBusGcRef = fakeGcRef
                , dBusKillConnection = killThread handlerThread
                }
        debugM "DBus" $ "hello"
        connName <- hello conn
        debugM "DBus" $ "Done"
        return conn{dBusConnectionName = connName}
    return conn{dBusGcRef = gcRef'}
  where
    -- Run the finalizer once the given TVar becomes unreachable.
    addTVarFinalizer :: TVar a -> IO () -> IO ()
    addTVarFinalizer tvar fin = void $ mkWeakTVar tvar fin

-- | Create a simple server that exports @Objects@ and ignores all incoming signals.
--
-- Use the default @EXTERNAL@ authentication mechanism (see 'makeServerWithAuth').
makeServer :: ConnectionType -> Objects -> IO DBusConnection
makeServer transport = makeServerWithAuth transport external

-- | Create a simple server with a custom bus authentication mechanism that
-- exports @Objects@ and ignores all incoming signals.
makeServerWithAuth :: ConnectionType -> SASL BS.ByteString -> Objects -> IO DBusConnection
makeServerWithAuth transport auth objs =
    connectBusWithAuth transport auth
                       (objectRoot (addIntrospectable objs))
                       (\_ _ _ -> return ())

-- | A handler for one incoming message. It receives the connection the
-- message arrived on, the message header and the message body.
type Handler = DBusConnection -> MessageHeader -> [SomeDBusValue] -> IO ()

-- | Send the initial byte on a freshly connected socket, before SASL
-- authentication starts.
--
-- By default this sends a single NUL byte and returns the result of 'send'.
-- When built with @SEND_CREDENTIALS@, the C helper
-- @send_credentials_and_zero@ is used instead (presumably it also passes the
-- process credentials; confirm against the C source).
--
-- NOTE(review): the @SEND_CREDENTIALS@ branch refers to 'CInt' and 'MkSocket',
-- which are not in this module's visible import list ('MkSocket' also no
-- longer exists in network >= 3). Verify that this branch still builds.
sendCredentials :: Socket -> IO Int
#ifdef SEND_CREDENTIALS
foreign import ccall "send_credentials_and_zero"
  sendCredentialsAndZero :: CInt -> IO CInt
sendCredentials (MkSocket si _ _ _ _) = fromIntegral <$> sendCredentialsAndZero si
#else
sendCredentials s = send s "\0"
#endif

-- | Close the connection and finalize all handlers.
--
-- This is automatically done when the connection is garbage collected, but
-- the moment of collection is not predictable. Call 'close' explicitly to
-- shut the connection down promptly.
close :: DBusConnection -> IO ()
close = dBusKillConnection