{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{- | Junction Trees

The Tree data structures are not working very well with message passing
algorithms. So, junction trees are using a different representation.
-}
module Bayes.FactorElimination.JTree(
      IsCluster(..)
    , Cluster(..)
    , JTree(..)
    , JunctionTree(..)
    , setFactors
    , distribute
    , collect
    , fromCluster
    , changeEvidence
    , nodeIsMemberOfTree
    , singletonTree
    , addNode
    , addSeparator
    , leaves
    , nodeValue
    , NodeValue(..)
    , SeparatorValue(..)
    , downMessage
    , upMessage
    , nodeParent
    , nodeChildren
    , traverseTree
    , separatorChild
    , treeNodes
    , Action(..)
    ) where

import qualified Data.Map as Map
import qualified Data.Tree as Tree
import Data.Maybe(fromJust,mapMaybe)
import qualified Data.Set as Set
import Data.Monoid
import Data.List((\\), intersect,partition, foldl')
import Bayes.PrivateTypes
import Bayes.Factor
import Bayes

import Debug.Trace

-- | Debugging helper: trace a label followed by the shown value, and
-- return the value unchanged.
debug :: Show a => String -> a -> a
debug s a = trace (s ++ " " ++ show a ++ "\n") a

-- | Message sent from a child cluster towards the root
type UpMessage a = a
-- | Message sent from the root towards a child cluster
-- ('Nothing' while not yet computed)
type DownMessage a = Maybe a

-- | Separator value
data SeparatorValue a = SeparatorValue !(UpMessage a) !(DownMessage a)
                      | EmptySeparator -- ^ Use to track the progress in the collect phase
                      deriving(Eq)

instance Show a => Show (SeparatorValue a) where
    show EmptySeparator = ""
    show (SeparatorValue u Nothing) = "u(" ++ show u ++ ")"
    show (SeparatorValue u (Just d)) = "u(" ++ show u ++ ") d(" ++ show d ++ ")"

-- | Factor attributed to a cluster
type FactorValue a = a
-- | Evidence attributed to a cluster
type EvidenceValue a = a

-- | Node value: the cluster factor and the cluster evidence
data NodeValue a = NodeValue !(FactorValue a) !(EvidenceValue a) deriving(Eq)

instance Show a => Show (NodeValue a) where
    show (NodeValue f e) = "f(" ++ show f ++ ") e(" ++ show e ++ ")"

-- | Junction tree.
-- 'c' is the node / separator identifier (for instance a set of 'DV');
-- 'f' is the type of the values stored in the nodes and separators.
data JTree c f = JTree {
      root :: !c
      -- ^ Root node of the tree
    , leavesSet :: !(Set.Set c)
      -- ^ Leaves of the tree
    , childrenMap :: !(Map.Map c [c])
      -- ^ The children of a node are separators
    , parentMap :: !(Map.Map c c)
      -- ^ Parent separator of a node
    , separatorParentMap :: !(Map.Map c c)
      -- ^ Parent node of a separator
    , separatorChildMap :: !(Map.Map c c)
      -- ^ The child of a separator is a node
    , nodeValueMap :: !(Map.Map c (NodeValue f))
      -- ^ Values for nodes
    , separatorValueMap :: !(Map.Map c (SeparatorValue f))
      -- ^ Values for separators
    } deriving(Eq)

-- | Create a singleton tree with just one root node
singletonTree :: Ord c
              => c -- ^ Root node
              -> f -- ^ Factor value of the root
              -> f -- ^ Evidence value of the root
              -> JTree c f
singletonTree r factorValue evidenceValue =
    addNode r factorValue evidenceValue emptyTree
  where
    emptyTree = JTree r Set.empty Map.empty Map.empty Map.empty Map.empty Map.empty Map.empty

-- | Leaves of the tree
leaves :: JTree c a -> [c]
leaves = Set.toList . leavesSet

-- | All nodes of the tree
treeNodes :: JTree c a -> [c]
treeNodes = Map.keys . nodeValueMap

-- | Value of a node. Fails with an explicit message if the node is not
-- in the tree.
nodeValue :: Ord c => JTree c a -> c -> NodeValue a
nodeValue t e =
    Map.findWithDefault
        (error "Bayes.FactorElimination.JTree.nodeValue: node not in the tree")
        e (nodeValueMap t)

-- | Change the value of a node
setNodeValue :: Ord c => c -> NodeValue a -> JTree c a -> JTree c a
setNodeValue c v t = t {nodeValueMap = Map.insert c v (nodeValueMap t)}

-- | Parent separator of a node ('Nothing' when the node has no parent,
-- e.g. the root)
nodeParent :: Ord c => JTree c a -> c -> Maybe c
nodeParent t e = Map.lookup e (parentMap t)

-- | Value of a separator. Fails with an explicit message if the
-- separator is not in the tree.
separatorValue :: Ord c => JTree c a -> c -> SeparatorValue a
separatorValue t e =
    Map.findWithDefault
        (error "Bayes.FactorElimination.JTree.separatorValue: separator not in the tree")
        e (separatorValueMap t)

-- | Parent node of a separator. Fails with an explicit message if the
-- separator is not in the tree.
separatorParent :: Ord c => JTree c a -> c -> c
separatorParent t e =
    Map.findWithDefault
        (error "Bayes.FactorElimination.JTree.separatorParent: separator not in the tree")
        e (separatorParentMap t)

-- | UpMessage for a separator node.
-- Fails if the separator has not received an up message yet.
upMessage :: Ord c => JTree c a -> c -> a
upMessage t c =
    case separatorValue t c of
        SeparatorValue up _ -> up
        _ -> error "Trying to get an up message on an empty seperator ! Should never occur !"
-- | DownMessage for a separator node ('Nothing' when no down message has
-- been set yet). Fails on an empty separator.
downMessage :: Ord c => JTree c a -> c -> Maybe a
downMessage t c =
    case separatorValue t c of
        SeparatorValue _ down -> down
        EmptySeparator -> error "Trying to get a down message on an empty separator ! Should never occur !"

-- | Return the separator children of a node (empty when the node has none)
nodeChildren :: Ord c => JTree c a -> c -> [c]
nodeChildren t e = Map.findWithDefault [] e (childrenMap t)

-- | Return the child node of a separator. Fails with an explicit message
-- if the separator is not in the tree.
separatorChild :: Ord c => JTree c a -> c -> c
separatorChild t e =
    Map.findWithDefault
        (error "Bayes.FactorElimination.JTree.separatorChild: separator not in the tree")
        e (separatorChildMap t)

-- | Check if a node is member of the tree
nodeIsMemberOfTree :: Ord c => c -> JTree c a -> Bool
nodeIsMemberOfTree c t = Map.member c (nodeValueMap t)

-- | Add a separator between two nodes.
-- The nodes MUST already be in the tree
addSeparator :: (Ord c)
             => c -- ^ Origin node
             -> c -- ^ Separator
             -> c -- ^ Destination node
             -> JTree c a -- ^ Current tree
             -> JTree c a -- ^ Modified tree
addSeparator node sep dest t =
    t { -- 'insertWith'' was removed from containers-0.6; 'insertWith'
        -- combines as @new ++ old@, so the new separator is prepended.
        childrenMap = Map.insertWith (++) node [sep] (childrenMap t)
      , separatorChildMap = Map.insert sep dest (separatorChildMap t)
      , separatorValueMap = Map.insert sep EmptySeparator (separatorValueMap t)
        -- the origin node now has a child, so it is no longer a leaf
      , leavesSet = Set.delete node (leavesSet t)
      , parentMap = Map.insert dest sep (parentMap t)
      , separatorParentMap = Map.insert sep node (separatorParentMap t)
      }

-- | Add a new node. The node is registered as a leaf until a separator
-- is added below it.
addNode :: (Ord c)
        => c -- ^ Node
        -> a -- ^ Factor value
        -> a -- ^ Evidence value
        -> JTree c a
        -> JTree c a
addNode node factorValue evidenceValue t =
    t { nodeValueMap = Map.insert node (NodeValue factorValue evidenceValue) (nodeValueMap t)
      , leavesSet = Set.insert node (leavesSet t)
      }

-- | Update the up message of a separator, keeping its down message
updateUpMessage :: Ord c
                => Maybe c -- ^ Separator node to update (if any : none for root node)
                -> a -- ^ New value
                -> JTree c a -- ^ Old tree
                -> JTree c a
updateUpMessage Nothing _ t = t
updateUpMessage (Just sep) newval t =
    let newSepValue = case separatorValue t sep of
            EmptySeparator -> SeparatorValue newval Nothing
            SeparatorValue _ down -> SeparatorValue newval down
    in t {separatorValueMap = Map.insert sep newSepValue (separatorValueMap t)}

-- | Update the down message of a separator, keeping its up message.
-- Fails if the separator has not received an up message yet.
updateDownMessage :: Ord c
                  => c -- ^ Separator node to update
                  -> a -- ^ New value
                  -> JTree c a -- ^ Old tree
                  -> JTree c a
updateDownMessage sep newval t =
    let newSepValue = case separatorValue t sep of
            EmptySeparator -> error "Can't set a down message on an empty separator"
            SeparatorValue up _ -> SeparatorValue up (Just newval)
    in t {separatorValueMap = Map.insert sep newSepValue (separatorValueMap t)}

{-

Message passing algorithms

-}

-- | Functions used to generate new messages
class Message f c | f -> c where
    -- | Generate a new message from the received ones.
    -- Arguments: the incoming messages, the value of the sending node,
    -- and the separator through which the message is sent.
    newMessage :: [f] -> NodeValue f -> c -> f

-- | Check that a separator is initialized
separatorInitialized :: SeparatorValue a -> Bool
separatorInitialized EmptySeparator = False
separatorInitialized _ = True

-- | True when every given separator has received an up message
allSeparatorsHaveReceivedAMessage :: Ord c
                                  => JTree c a -- ^ Tree
                                  -> [c] -- ^ Separators
                                  -> Bool
allSeparatorsHaveReceivedAMessage t = all (separatorInitialized . separatorValue t)

-- | Update the up separator by sending a message
-- But only if all the down separators have received a message
updateUpSeparator :: (Message a c, Ord c)
                  => JTree c a
                  -> c -- ^ Node generating the new upMessage
                  -> JTree c a
updateUpSeparator t h
    | not (allSeparatorsHaveReceivedAMessage t seps) = t
    | otherwise =
        case nodeParent t h of
            Nothing -> t -- When root
            Just p ->
                let generatedMessage = newMessage (map (upMessage t) seps) (nodeValue t h) p
                in updateUpMessage (Just p) generatedMessage t
  where
    seps = nodeChildren t h

-- | Update the down separator by sending a message.
-- The message combines the up messages of the other children and the
-- down message coming from the node's parent separator (if any).
updateDownSeparator :: (Message a c, Ord c)
                    => c -- ^ Node generating the message
                    -> JTree c a
                    -> c -- ^ Child receiving the message
                    -> JTree c a
updateDownSeparator node t child =
    let incomingMessagesFromBelow = map (upMessage t) (nodeChildren t node \\ [child])
        messageFromAbove = downMessage t =<< nodeParent t node
        incomingMessages = maybe incomingMessagesFromBelow (:incomingMessagesFromBelow) messageFromAbove
        generatedMessage = newMessage incomingMessages (nodeValue t node) child
    in updateDownMessage child generatedMessage t

-- | Remove duplicates (the result is sorted)
unique :: Ord c => [c] -> [c]
unique = Set.toList . Set.fromList

-- | Whether the traversal is currently processing clusters or separators
data TraversalState = ACluster | ASeparator

-- | Collect message taking into account that the tree depth may be different for different leaves.
-- A node only sends its up message once all its child separators have
-- received one; nodes are revisited each time one of their children sends.
collect :: (Ord c, Message a c) => JTree c a -> JTree c a
collect t = _collect ACluster (leaves t) t

_collect :: (Ord c, Message a c)
         => TraversalState -- ^ Node processing phase or separator processing phase
         -> [c]
         -> JTree c a -- ^ Tree
         -> JTree c a -- ^ Modified tree
_collect _ [] t = t
_collect ACluster l t =
    let newTree = foldl' updateUpSeparator t l
    in _collect ASeparator (mapMaybe (nodeParent t) l) newTree
_collect ASeparator l t = _collect ACluster (unique (map (separatorParent t) l)) t

-- | Distribute messages from the root towards the leaves
distribute :: (Ord c, Message a c) => JTree c a -> JTree c a
distribute t = _distribute ACluster t (root t)

_distribute :: (Ord c, Message a c)
            => TraversalState -- ^ Whether the visited element is a cluster or a separator
            -> JTree c a
            -> c -- ^ Destination of the distribute
            -> JTree c a
_distribute ACluster t node =
    let children = nodeChildren t node
        -- all down messages of this node are sent before descending
        newTree = foldl' (updateDownSeparator node) t children
    in foldl' (_distribute ASeparator) newTree children
_distribute ASeparator t node = _distribute ACluster t (separatorChild t node)

{-

Factors and evidence modifications

-}

-- | This class is used to check if evidence or a factor is relevant
-- for a cluster
class IsCluster c where
    -- | Evidence whose variable belongs to the cluster
    overlappingEvidence :: c -> [DVI Int] -> [DVI Int]
    -- | Variables of the cluster
    clusterVariables :: c -> [DV]
    -- | Build the separator of two clusters
    mkSeparator :: c -> c -> c

instance IsCluster [DV] where
    overlappingEvidence c e = filter (\x -> instantiationVariable x `elem` c) e
    clusterVariables = id
    mkSeparator = intersect

-- | Result of the visiting function used by 'traverseTree'
data Action s a = Skip !s -- ^ Keep the node value and visit the children
                | ModifyAndStop !s !a -- ^ Replace the node value and do not visit the children
                | Modify !s !a -- ^ Replace the node value and visit the children
                | Stop !s -- ^ Keep the node value and do not visit the children

-- | Traverse a tree depth-first from the root and modify it.
-- The state is threaded through the whole traversal; stopping only prunes
-- the subtree of the current node, siblings are still visited.
traverseTree :: Ord c
             => (s -> c -> NodeValue f -> Action s (NodeValue f)) -- ^ Modification function
             -> s -- ^ Current state
             -> JTree c f -- ^ Input tree
             -> (JTree c f,s)
traverseTree action state t = _traverseTree True action (t,state) (root t)

_traverseTree :: Ord c
              => Bool -- ^ True when visiting a node, False when visiting a separator
              -> (s -> c -> NodeValue f -> Action s (NodeValue f))
              -> (JTree c f, s)
              -> c
              -> (JTree c f, s)
_traverseTree False action (t,state) current =
    _traverseTree True action (t,state) (separatorChild t current)
_traverseTree True action (t,state) current =
    case action state current (nodeValue t current) of
        Stop newState -> (t,newState)
        -- the new state is returned, as for 'Stop'
        ModifyAndStop newState newValue -> (setNodeValue current newValue t, newState)
        Skip newState -> visitChildren t newState
        Modify newState newValue -> visitChildren (setNodeValue current newValue t) newState
  where
    visitChildren tree st = foldl' (_traverseTree False action) (tree,st) (nodeChildren tree current)

-- | Map a function over all node values, giving it the node identifier
mapWithCluster :: Ord c => (c -> NodeValue f -> NodeValue f) -> JTree c f -> JTree c f
mapWithCluster f t = t {nodeValueMap = Map.mapWithKey f (nodeValueMap t)}

-- | Set the factors in the tree.
-- Each factor is attributed to the first cluster, in traversal order,
-- containing all its variables. Every cluster gets the product of its
-- attributed factors.
setFactors :: (Graph g, Factor f, Show f, IsCluster c, Ord c)
           => BayesianNetwork g f
           -> JTree c f
           -> JTree c f
setFactors g t =
    let factors = allVertexValues g
    in fst . traverseTree updateFactor factors $ t

-- | Update factors in a cluster.
-- The traversal is never pruned: once all factors are attributed, the
-- remaining clusters still get @factorProduct []@ instead of keeping a
-- previous, possibly unrelated, factor.
updateFactor :: (Factor f, IsCluster c)
             => [f] -- ^ Remaining list of factors to attribute
             -> c -- ^ Current cluster
             -> NodeValue f -- ^ Current value
             -> Action [f] (NodeValue f)
updateFactor lf c (NodeValue _ evidence) =
    Modify remainingFactors (NodeValue (factorProduct attributedFactors) evidence)
  where
    isFactorIncluded l = all (`elem` clusterVariables c) (factorVariables l)
    (attributedFactors,remainingFactors) = partition isFactorIncluded lf

-- | Change evidence in the network.
-- Separators are reset, the evidence replaces any previous one in every
-- cluster, and messages are recomputed with 'collect' then 'distribute'.
changeEvidence :: (IsCluster c, Ord c, Factor f, Message f c)
               => [DVI Int] -- ^ Evidence
               -> JTree c f
               -> JTree c f
changeEvidence e t =
    distribute . collect . fst . traverseTree changeNodeEvidence e $
        t { separatorValueMap = Map.map (const EmptySeparator) (separatorValueMap t)}

-- | Set the evidence of a cluster.
-- Each instantiation is attributed to the first cluster visited that
-- contains its variable. The traversal is never pruned, so clusters
-- visited after all evidence is attributed have their old evidence
-- cleared (set to @factorProduct []@, as for clusters without overlap).
changeNodeEvidence :: (IsCluster c, Factor f)
                   => [DVI Int] -- ^ Evidence
                   -> c -- ^ Current cluster
                   -> NodeValue f -- ^ Current value
                   -> Action [DVI Int] (NodeValue f)
changeNodeEvidence [] _ (NodeValue f _) = Modify [] (NodeValue f (factorProduct []))
changeNodeEvidence e c (NodeValue f _) =
    let oe = overlappingEvidence c e
        ns = e \\ oe
        newEvidence = factorProduct $ map factorFromInstantiation oe
    in Modify ns (NodeValue f newEvidence)

-- | Cluster of discrete variables.
-- Discrete variables instead of vertices are needed because the
-- factor are using 'DV' and we need to find
-- which factors must be contained in a given cluster.
newtype Cluster = Cluster (Set.Set DV) deriving(Eq,Ord)

instance Show Cluster where
    show (Cluster s) = show . Set.toList $ s

-- | Variables of a cluster (in ascending order)
fromCluster :: Cluster -> [DV]
fromCluster (Cluster s) = Set.toList s

instance Factor f => Message f Cluster where
    newMessage input (NodeValue f e) dv = factorProjectTo (fromCluster dv) (factorProduct (f:e:input))

type JunctionTree f = JTree Cluster f

{-

Implement the show function to see the structure of the tree
(without the values)

-}

data NodeKind c = N !c | S !c

-- | Label of an element of the drawn tree; the value is appended when
-- the flag is True
label :: Show a => Bool -> String -> a -> String
label True c a = c ++ "=" ++ show a
label False c _ = c

-- | Convert the JTree into a tree of string
-- using the cluster.
toTree :: (Ord c, Show c, Show a)
       => Bool -- ^ True if the data must be displayed
       -> JTree c a
       -> Tree.Tree String
toTree d t =
    let r = root t
        v = nodeValue t r
        nodec = map S (nodeChildren t r)
    in Tree.Node (label d (show r) v) (_toTree d t nodec)

_toTree :: (Ord c, Show c, Show a) => Bool -> JTree c a -> [NodeKind c] -> [Tree.Tree String]
_toTree _ _ [] = []
_toTree d t ((N h):l) =
    let nodec = map S (nodeChildren t h) -- Node children are separators
        v = nodeValue t h
    in Tree.Node (label d (show h) v) (_toTree d t nodec) : _toTree d t l
_toTree d t ((S h):l) =
    let separatorc = [N $ separatorChild t h] -- separator child is a node
        v = separatorValue t h
    in Tree.Node (label d ("<" ++ show h ++ ">") v) (_toTree d t separatorc) : _toTree d t l

instance (Ord c, Show c, Show a) => Show (JTree c a) where
    show = Tree.drawTree . toTree False

-- | Draw the tree, with the node and separator values when the flag is True
displayTree :: (Ord c, Show c, Show a) => Bool -> JTree c a -> String
displayTree b = Tree.drawTree . toTree b

{-

Debug functions for tests

-}

--instance Message (Sum Int) String where
--    newMessage l (NodeValue a b) _ = mconcat (a:b:l)
--
--testTree :: JTree String (Sum Int)
--testTree = let s a= Sum a
--    in
--    addSeparator "ROOT" "RB" "B" .
--    addNode "B" (s 3) (s 3) .
--    addSeparator "ROOT" "RA" "A" .
--    addNode "A" (s 2) (s 2) $
--    singletonTree "ROOT" (s 4) (s 5)
--