{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{- | Algorithms for factor elimination

-}
module Bayes.FactorElimination(
    -- * Moral graph
      moralGraph
    -- * Triangulation
    , nodeComparisonForTriangulation
    , numberOfAddedEdges
    , weight
    , weightedEdges
    , triangulate
    -- * Junction tree
    , createClusterGraph
    , Cluster
    , createJunctionTree
    , createUninitializedJunctionTree
    , JunctionTree
    -- * Shenoy-Shafer message passing
    , collect
    , distribute
    , posterior
    -- * Evidence
    , changeEvidence
    -- * Test
    , junctionTreeProperty_prop
    , VertexCluster
    -- * For debug
    , junctionTreeProperty
    , maximumSpanningTree
    , fromVertexCluster
    ) where

import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Control.Monad(mapM,guard)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits,foldl')
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T
import Bayes.FactorElimination.JTree
import Control.Applicative((<$>))
import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary

--import Debug.Trace
--debug s a = trace (s ++ " " ++ show a ++ "\n") a

{-

Comparison functions for graph triangulation

-}

-- | Number of edges that would be added when connecting all the neighbors
-- of a vertex. The neighbors are enumerated as ordered pairs of distinct
-- vertices, so a missing link between two neighbors is seen once per
-- ordering. Fails (through 'fromJust') when 'neighbors' returns 'Nothing'.
numberOfAddedEdges :: UndirectedGraph g => g a b -> Vertex -> Int
numberOfAddedEdges g v = length missingPairs
  where
    adjacent = fromJust (neighbors g v)
    missingPairs =
        [ edge p q
        | p <- adjacent
        , q <- adjacent
        , p /= q
        , not (isLinkedWithAnEdge g p q)
        ]

-- | Sum of the products of the weights of the end vertices, taken over
-- the same ordered pairs of unlinked neighbors as 'numberOfAddedEdges'.
weightedEdges :: (UndirectedGraph g, Factor f) => g a f -> Vertex -> Int
weightedEdges g v =
    sum
        [ weight g p * weight g q
        | p <- adjacent
        , q <- adjacent
        , p /= q
        , not (isLinkedWithAnEdge g p q)
        ]
  where
    adjacent = fromJust (neighbors g v)

-- | Weight of a node: the 'factorDimension' of the factor stored at the
-- vertex. Fails (through 'fromJust') when the vertex has no value.
weight :: (UndirectedGraph g, Factor f) => g a f -> Vertex -> Int
weight g v = factorDimension (fromJust (vertexValue g v))

-- | Chain two comparison functions: the second one is only consulted when
-- the first one answers 'EQ'.
(.||.) :: (a -> a -> Ordering) -> (a -> a -> Ordering) -> (a -> a -> Ordering)
(first .||. second) x y =
    case first x y of
        EQ -> second x y
        decided -> decided

-- | Node selection comparison function used for triangulating the graph.
-- Vertices are compared on 'numberOfAddedEdges' first and on 'weight' in
-- case of a tie.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex
                               -> Vertex
                               -> Ordering
nodeComparisonForTriangulation g =
    (compare `on` numberOfAddedEdges g) .||. (compare `on` weight g)

{-

Graph triangulation

-}

-- | A cluster containing only the vertices and not yet the factors
newtype VertexCluster = VertexCluster (Set.Set Vertex) deriving(Eq)

-- | Set of vertices wrapped by a 'VertexCluster'
fromVertexCluster :: VertexCluster -> Set.Set Vertex
fromVertexCluster (VertexCluster s) = s

-- | A vertex cluster is displayed as the list of its vertices
instance Show VertexCluster where
    show = show . Set.toList . fromVertexCluster

instance IsCluster Cluster where
    -- Keep the instantiations whose variable belongs to the cluster
    overlappingEvidence (Cluster c) e =
        [ x | x <- e, Set.member (instantiationVariable x) c ]
    clusterVariables (Cluster s) = Set.toList s
    -- The separator of two clusters is their intersection
    mkSeparator (Cluster sa) (Cluster sb) = Cluster (Set.intersection sa sb)

-- | Triangulate a graph using a cost function
-- The result is the triangulated graph and the list of clusters
-- which may not be maximal.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering) -- ^ Criterion function for triangulation
            -> g () b
            -> ([VertexCluster],g () b) -- ^ Returns the clusters and the triangulated graph
triangulate cmp g = eliminate g g []
  where
    -- Both graphs are the same at start.
    -- The first one is where the vertex elimination takes place.
    -- The second one only receives the added edges.
    -- The last argument accumulates the clusters, most recent first.
    eliminate remaining filled found
        | hasNoVertices remaining = (keepMaximalClusters (reverse found), filled)
        | otherwise =
            let chosen = minimumBy cmp (allVertices remaining)
                members = chosen : fromJust (neighbors remaining chosen)
                addEmptyEdge e gr = addEdge e () gr
                (remaining', filled') =
                    connectAllNodesWith remaining filled addEmptyEdge addEmptyEdge members
                newCluster = VertexCluster (Set.fromList members)
            in eliminate (removeVertex chosen remaining') filled' (newCluster : found)

-- | Look for a containing cluster.
findContainingCluster :: VertexCluster -- ^ Cluster processed
                      -> [VertexCluster] -- ^ Cluster list where to look for a containing cluster
                      -> (Maybe VertexCluster,[VertexCluster]) -- ^ Return the containing cluster and a new list without the containing cluster
findContainingCluster cluster l =
    case break containsCluster l of
        (_, []) -> (Nothing, l)
        (before, found : after) -> (Just found, before ++ after)
  where
    -- True when the processed cluster is a subset of the candidate
    containsCluster candidate =
        Set.isSubsetOf (fromVertexCluster cluster) (fromVertexCluster candidate)

-- | Remove the clusters already contained in a previous cluster.
-- The list is scanned from left to right. When a cluster is contained in
-- one of the clusters kept so far, it is dropped and the first containing
-- cluster is moved to the end of the kept clusters.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters = reverse . foldl' absorb []
  where
    -- The accumulator is the list of kept clusters in reverse order
    absorb reversedKept current =
        case findContainingCluster current (reverse reversedKept) of
            (Nothing, _) -> current : reversedKept
            (Just container, others) -> container : reverse others

-- | Convert the clusters from vertex to 'DV' clusters.
-- Vertices without a value in the graph are ignored.
vertexClusterToCluster :: (Factor f , Graph g) => g e f -> VertexCluster -> Cluster
vertexClusterToCluster g c =
    let vertices = Set.toList (fromVertexCluster c)
        variables = map factorMainVariable (mapMaybe (vertexValue g) vertices)
    in Cluster (Set.fromList variables)

-- | Create the cluster graph.
-- The clusters are numbered from 0 and each one becomes a vertex. An edge
-- is requested for every ordered pair of different clusters and its value
-- is the size of the intersection of the two vertex sets.
createClusterGraph :: (UndirectedGraph g, Factor f, Graph g')
                   => g' e f
                   -> [VertexCluster]
                   -> g Int Cluster
createClusterGraph bn c =
    let numberedClusters = zip c (map Vertex [0..])
        addCluster g (cl,v) = addVertex v (vertexClusterToCluster bn cl) g
        graphWithoutEdges = foldl' addCluster emptyGraph numberedClusters
        separatorSize ca cb =
            Set.size $ Set.intersection (fromVertexCluster ca) (fromVertexCluster cb)
        clusterPairs = [(cx,cy) | cx <- numberedClusters, cy <- numberedClusters, cx /= cy]
        addClusterEdge g ((ca,va),(cb,vb)) = addEdge (edge va vb) (separatorSize ca cb) g
    in foldl' addClusterEdge graphWithoutEdges clusterPairs

{-

Maximum spanning tree using Prim's algorithm

-}

-- | Get all possible edges between the clusters already in the tree and
-- the remaining vertices. Each result contains the remaining vertex, the
-- tree cluster it is linked to and the value of the edge between them.
possibilities :: (Ord c , UndirectedGraph g)
              => g Int c -- ^ Original graph to get the edge value
              -> JTree c (Vertex,f) -- ^ Tree to get the vertex for a tree cluster
              -> [Vertex] -- ^ Vertices to add to the tree
              -> [c] -- ^ Clusters of the tree which can be used as attachment point
              -> [(Vertex,c,Int)] -- ^ Found edges which could be added
possibilities g currentT remaining leavesClusters =
    [ (rv, lv, fromJust (edgeValue g (edge rv lvVertex)))
    | rv <- remaining
    , lv <- leavesClusters
    , let NodeValue (lvVertex, _) _ = nodeValue currentT lv
    , isLinkedWithAnEdge g rv lvVertex
    ]

-- | Find the max edge to add to the tree.
-- The result contains the vertices still to add, the selected vertex with
-- its cluster and the cluster of the tree where it must be attached.
-- 'maximumBy' fails when there is no candidate edge.
findMax :: (UndirectedGraph g, Ord c, Factor f)
        => g Int c -- ^ Graph
        -> [Vertex] -- ^ Nodes to add
        -> JTree c (Vertex,f) -- ^ Tree built so far
        -> ([Vertex],(Vertex,c),c)
findMax g remaining currentT = (remaining', (rf, foundCluster), lf)
  where
    candidates = possibilities g currentT remaining (treeNodes currentT)
    -- Named differently from the imported 'edgeValue' to avoid shadowing it
    candidateWeight (_, _, e) = e
    (rf, lf, _) = maximumBy (compare `on` candidateWeight) candidates
    remaining' = filter (/= rf) remaining
    foundCluster = fromJust (vertexValue g rf)

-- | Forget the vertices which were paired with the values of the tree
-- during its construction.
removeVertices :: JTree c (Vertex,f) -> JTree c f
removeVertices t =
    t { nodeValueMap = Map.map removeVertexFromNode (nodeValueMap t)
      , separatorValueMap = Map.map removeVertexFromSeparator (separatorValueMap t)
      }
  where
    removeVertexFromNode (NodeValue (_,f) (_,e)) = NodeValue f e
    removeVertexFromSeparator (SeparatorValue (_,u) (Just (_,d))) = SeparatorValue u (Just d)
    removeVertexFromSeparator (SeparatorValue (_,u) Nothing) = SeparatorValue u Nothing
    removeVertexFromSeparator EmptySeparator = EmptySeparator

-- | Implementing the Prim's algorithm for maximum spanning tree.
-- The tree starts from the vertex given by 'someVertex' and the other
-- vertices are added one by one with 'buildTree'. All the values of the
-- tree are initialized with the scalar factor 1.0. Fails (through
-- 'fromJust') when 'someVertex' returns 'Nothing'.
maximumSpanningTree :: (UndirectedGraph g, IsCluster c, Factor f, Ord c)
                    => g Int c
                    -> JTree c f
maximumSpanningTree g =
    let rootNodeVertex = fromJust $ someVertex g
        rootNodeValue = fromJust $ vertexValue g rootNodeVertex
        unitFactor = factorFromScalar 1.0
        rootPair = (rootNodeVertex,unitFactor)
        startTree = singletonTree rootNodeValue rootPair rootPair
        remainingVertices = filter (/= rootNodeVertex) (allVertices g)
    in removeVertices $ buildTree g remainingVertices startTree

-- | Add the remaining vertices to the tree. At each step the vertex
-- selected by 'findMax' is added as a new node and linked to the tree
-- with the separator computed by 'mkSeparator'.
buildTree :: (UndirectedGraph g , IsCluster c, Factor f, Ord c)
          => g Int c
          -> [Vertex]
          -> JTree c (Vertex,f)
          -> JTree c (Vertex,f)
buildTree _ [] currentT = currentT
buildTree g l currentT =
    let unitFactor = factorFromScalar 1.0
        (l',(foundElemVertex,foundElemValue),leaf) = findMax g l currentT
        sep = mkSeparator foundElemValue leaf
        newPair = (foundElemVertex,unitFactor)
        newTree = addSeparator leaf sep foundElemValue
                . addNode foundElemValue newPair newPair
                $ currentT
    in buildTree g l' newTree

{-

Junction tree algorithm

-}

-- | Create a junction tree with only the clusters and no factors
createUninitializedJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f)
                                => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                                -> g () f -- ^ Input directed graph
                                -> JunctionTree f -- ^ Junction tree
createUninitializedJunctionTree cmp g =
    let theMoralGraph = moralGraph g
        (clusters,_) = triangulate (cmp theMoralGraph) theMoralGraph
        clusterGraph = createClusterGraph g clusters :: UndirectedSG Int Cluster
    in maximumSpanningTree clusterGraph

-- | Create a junction tree: the factors of the network are set on the
-- uninitialized tree, then 'collect' and 'distribute' are applied.
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                   => (UndirectedSG () f -> Vertex -> Vertex -> Ordering) -- ^ Weight function on the moral graph
                   -> BayesianNetwork g f -- ^ Input directed graph
                   -> JunctionTree f -- ^ Junction tree
createJunctionTree cmp g =
    let cTree = createUninitializedJunctionTree cmp g
        -- A vertex is linked with a factor so vertex is used as the identifier
        newTree = setFactors g cTree
    in distribute . collect $ newTree

-- | Compute the marginal posterior (if some evidence is set on the junction tree)
-- otherwise compute just the marginal prior.
posterior :: Factor f => JunctionTree f -> DV -> Maybe f
posterior t v = fmap marginalIn (snd (traverseTree (findClusterFor v) Nothing t))
  where
    -- Normalized marginal of the variable computed from the values of the
    -- cluster, the down message of its parent separator (scalar 1.0 when
    -- there is none) and the up messages of its children separators.
    marginalIn c =
        let NodeValue f e = nodeValue t c
            d = maybe (factorFromScalar 1.0) id $ downMessage t =<< nodeParent t c
            u = map (upMessage t) (nodeChildren t c)
            unNormalized = factorProjectTo [v] (factorProduct (f : e : d : u))
        in factorDivide unNormalized (factorNorm unNormalized)

-- | Find a cluster containing the variable.
-- The traversal is stopped with the current cluster when it contains the
-- variable. Otherwise the state is left unchanged.
findClusterFor :: DV
               -> Maybe Cluster
               -> Cluster -- ^ Current cluster
               -> NodeValue f -- ^ Current value
               -> Action (Maybe Cluster) (NodeValue f)
findClusterFor dv s c@(Cluster sc) _
    | Set.member dv sc = Stop (Just c)
    | otherwise = Skip s

-- | QuickCheck property: for a non empty and connected graph with edges,
-- the uninitialized junction tree satisfies 'junctionTreeProperty'.
junctionTreeProperty_prop :: DirectedSG () CPT -> Property
junctionTreeProperty_prop g = precondition ==> junctionTreeProperty t [] (root t)
  where
    precondition = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    cmp ug = compare `on` numberOfAddedEdges ug
    t = createUninitializedJunctionTree cmp g

-- | Check the cluster with 'checkPath' against the clusters already
-- visited above it, then check its children recursively. The visited
-- clusters are accumulated with the most recent one first.
junctionTreeProperty :: JTree Cluster CPT -> [Cluster] -> Cluster -> Bool
junctionTreeProperty t path c =
    checkPath c path && all (junctionTreeProperty t (c : path)) childClusters
  where
    childClusters = map (separatorChild t) (nodeChildren t c)

-- | Check that the intersection of C with any parent is included in all
-- the clusters between the parent and C.
checkPath :: Cluster -> [Cluster] -> Bool
checkPath _ [] = True
checkPath (Cluster c) l = and (zipWith coveredOnPath ancestors pathsToEachAncestor)
  where
    clusterSet (Cluster s) = s
    -- When called from 'junctionTreeProperty' the list is p1 p2 p3 ...
    -- where p1 is the cluster visited just before C
    ancestors = map clusterSet l
    -- p1, p1 p2, p1 p2 p3 ...
    pathsToEachAncestor = drop 1 (inits ancestors)
    -- The intersection of C and pk must be in each of p1 ... pk
    coveredOnPath ancestor path = all (Set.isSubsetOf (Set.intersection c ancestor)) path

{-

Moral graph

-}

-- | Get the parents of a vertex: the start vertices of its ingoing edges.
-- Fails (through 'fromJust') when one of the lookups returns 'Nothing'.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust (mapM (startVertex g) =<< ingoing g v)

-- | Get the children of a vertex: the end vertices of its outgoing edges.
-- Fails (through 'fromJust') when one of the lookups returns 'Nothing'.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust (mapM (endVertex g) =<< outgoing g v)

-- | Connect all the nodes which are not connected and apply the function f for each new connection
-- The origin and dest graph must share the same vertex.
-- The missing connections are computed once on the unmodified origin graph.
connectAllNodesWith :: (Graph g, Graph g')
                    => g a b -- ^ Graph containing the nodes
                    -> g' a b -- ^ Graph to be modified
                    -> (Edge -> g a b -> g a b) -- ^ Function used to modify the source graph
                    -> (Edge -> g' a b -> g' a b) -- ^ Function used to modify a new graph
                    -> [Vertex] -- ^ List of nodes to connect
                    -> (g a b,g' a b) -- ^ Result graph
connectAllNodesWith originGraph dstGraph g f nodes =
    let missingEdges =
            [ edge x y
            | x <- nodes
            , y <- nodes
            , x /= y
            , not (isLinkedWithAnEdge originGraph x y)
            ]
        applyBoth e (src, dst) = (g e src, f e dst)
        (originGraph', dstGraph') = foldr applyBoth (originGraph, dstGraph) missingEdges
    in (originGraph', dstGraph')

-- | Add the missing parent links: the parents of the vertex which are not
-- linked are connected with an edge. The value of the vertex is not used.
addMissingLinks :: DirectedGraph g => g () b -> Vertex -> b -> g () b
addMissingLinks g v _ =
    snd (connectAllNodesWith g g keepGraph addEmptyEdge (parents g v))
  where
    -- The graph used to look for the missing links is left unchanged
    keepGraph _ gr = gr
    addEmptyEdge e gr = addEdge e () gr

-- | Convert the graph to an undirected form. The vertices are copied with
-- their label and their value, then all the edges are added with an empty
-- value. Fails (through 'fromJust') when a vertex has no label.
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
                    => g () b
                    -> g' () b
convertToUndirected m = foldr addEmptyEdge verticesOnly (allEdges m)
  where
    addVertexWithLabel gr v dat = addLabeledVertex (fromJust (vertexLabel m v)) v dat gr
    verticesOnly = foldlWithVertex' addVertexWithLabel emptyGraph m
    addEmptyEdge e gr = addEdge e () gr

-- | Moral graph of a directed graph: the parents of each vertex are linked
-- together and the result is converted to an undirected graph.
-- For the junction tree construction, only the vertices are needed during
-- the intermediate steps, so the edges of the moral graph carry no data.
moralGraph :: (NamedGraph g, FoldableWithVertex g, DirectedGraph g)
           => g () b
           -> UndirectedSG () b
moralGraph g = convertToUndirected (foldlWithVertex' addMissingLinks g g)