module Bayes.FactorElimination.JTree(
IsCluster(..)
, Cluster(..)
, JTree(..)
, JunctionTree(..)
, Sep
, setFactors
, distribute
, collect
, fromCluster
, changeEvidence
, nodeIsMemberOfTree
, singletonTree
, addNode
, addSeparator
, leaves
, nodeValue
, NodeValue(..)
, SeparatorValue(..)
, downMessage
, upMessage
, nodeParent
, nodeChildren
, traverseTree
, separatorChild
, treeNodes
, treeValues
, displayTreeValues
, Action(..)
) where
import qualified Data.Map as Map
import qualified Data.Tree as Tree
import Data.Maybe(fromJust,mapMaybe)
import qualified Data.Set as Set
import Data.Monoid
import Data.List((\\), intersect,partition, foldl',minimumBy,nub)
import Bayes.PrivateTypes
import Bayes.Factor
import Bayes
import Data.Function(on)
import Bayes.VariableElimination(marginal)
import Debug.Trace
-- | Debugging helper: emits @s@, a space and @show a@ through
-- 'Debug.Trace.trace', then returns @a@ unchanged.
-- Not exported from this module.
debug :: Show a => String -> a -> a
debug s a = trace (s ++ " " ++ show a ++ "\n") a
-- | Message sent from a child cluster to its parent separator.
type UpMessage a = a

-- | Message sent from a parent cluster down to a separator. It stays
-- 'Nothing' until 'updateDownMessage' stores one.
type DownMessage a = Maybe a

-- | Contents of a separator. It is either 'EmptySeparator' (no message yet)
-- or an up message together with an optional down message.
data SeparatorValue a
  = SeparatorValue !(UpMessage a) !(DownMessage a)
  | EmptySeparator
  deriving (Eq)

-- | Renders @u(up)@, with @ d(down)@ appended when a down message is present.
-- An empty separator renders as the empty string.
instance Show a => Show (SeparatorValue a) where
  show sepValue = case sepValue of
    EmptySeparator -> ""
    SeparatorValue up Nothing -> "u(" ++ show up ++ ")"
    SeparatorValue up (Just down) -> "u(" ++ show up ++ ") d(" ++ show down ++ ")"
-- | Factors attached to a cluster.
type FactorValues a = [a]

-- | Evidence factors attached to a cluster.
type EvidenceValues a = [a]

-- | Value stored at each tree node: a 'Vertex', the node's factors and its
-- evidence factors.
data NodeValue a = NodeValue !Vertex !(FactorValues a) !(EvidenceValues a)
  deriving (Eq)

-- | Shows the factors and the evidence as @f(..) e(..)@. The vertex is not shown.
instance Show a => Show (NodeValue a) where
  show (NodeValue _ factors evidence) =
    concat ["f(", show factors, ") e(", show evidence, ")"]
-- | Key of a separator in a 'JTree'. New keys are obtained by incrementing
-- 'separatorCurrentKey' (see 'addSeparator').
newtype Sep = Sep Int
  deriving (Eq, Ord, Show, Num)

-- | Rooted tree of clusters @c@ linked through separators ('Sep'). Each
-- cluster holds a 'NodeValue' and each separator holds a 'SeparatorValue'
-- over values of type @f@.
data JTree c f = JTree
  { root                :: !c                                -- ^ root cluster
  , leavesSet           :: !(Set.Set c)                      -- ^ clusters without children
  , childrenMap         :: !(Map.Map c [Sep])                -- ^ separators below each cluster
  , parentMap           :: !(Map.Map c Sep)                  -- ^ separator above each non-root cluster
  , separatorParentMap  :: !(Map.Map Sep c)                  -- ^ cluster above each separator
  , separatorChildMap   :: !(Map.Map Sep c)                  -- ^ cluster below each separator
  , nodeValueMap        :: !(Map.Map c (NodeValue f))        -- ^ value of each cluster
  , separatorValueMap   :: !(Map.Map Sep (SeparatorValue f)) -- ^ messages held by each separator
  , separatorCurrentKey :: !Sep                              -- ^ most recently allocated separator key
  , separatorClusterMap :: !(Map.Map Sep c)                  -- ^ cluster stored on each separator
  } deriving (Eq)
-- | Tree made of the single root cluster @r@. The root carries the given
-- vertex, factors and evidence.
--
-- 'separatorCurrentKey' starts at @Sep 0@, so the first separator added by
-- 'addSeparator' receives key @Sep 1@. Record syntax is used so that the
-- initial value of each field is explicit and independent of field order.
singletonTree :: Ord c => c -> Vertex -> [a] -> [a] -> JTree c a
singletonTree r rootVertex factorValue evidenceValue =
  addNode r rootVertex factorValue evidenceValue emptyTree
  where
    emptyTree = JTree
      { root                = r
      , leavesSet           = Set.empty
      , childrenMap         = Map.empty
      , parentMap           = Map.empty
      , separatorParentMap  = Map.empty
      , separatorChildMap   = Map.empty
      , nodeValueMap        = Map.empty
      , separatorValueMap   = Map.empty
      , separatorCurrentKey = Sep 0
      , separatorClusterMap = Map.empty
      }
-- | Empties the evidence list of every node. Vertices and factors are kept.
--
-- No constraint on @f@ is required because the factors are never inspected,
-- so the unused @Factor f@ constraint has been dropped.
resetEvidences :: JTree c f -> JTree c f
resetEvidences t = t {nodeValueMap = Map.map resetNodeEvidence (nodeValueMap t)}
  where
    resetNodeEvidence (NodeValue v f _) = NodeValue v f []
-- | Cluster stored on separator @s@, i.e. the @sepCluster@ passed to
-- 'addSeparator'.
--
-- Calls 'error' naming @s@ if the separator is not part of @t@. The previous
-- 'fromJust' failure gave no context.
separatorCluster :: JTree c a -> Sep -> c
separatorCluster t s =
  case Map.lookup s (separatorClusterMap t) of
    Just c -> c
    Nothing -> error ("separatorCluster: unknown separator " ++ show s)
-- | Clusters that currently have no children, in ascending order.
leaves :: JTree c a -> [c]
leaves t = Set.toAscList (leavesSet t)
-- | Every cluster that has a value in the tree, in ascending order.
treeNodes :: JTree c a -> [c]
treeNodes t = map fst (Map.toAscList (nodeValueMap t))
-- | Every cluster paired with its value, in ascending cluster order.
treeValues :: JTree c f -> [(c,NodeValue f)]
treeValues t = Map.toAscList (nodeValueMap t)
-- | Value stored at cluster @e@.
--
-- Calls 'error' with a descriptive message if @e@ is not a node of @t@, rather
-- than failing inside 'fromJust'. The cluster itself cannot be shown because
-- only @Ord c@ is available.
nodeValue :: Ord c => JTree c a -> c -> NodeValue a
nodeValue t e =
  case Map.lookup e (nodeValueMap t) of
    Just v -> v
    Nothing -> error "nodeValue: cluster is not a node of the junction tree"
-- | Stores @v@ as the value of cluster @c@, replacing any previous value.
setNodeValue :: Ord c => c -> NodeValue a -> JTree c a -> JTree c a
setNodeValue c v t = t { nodeValueMap = updated }
  where
    updated = Map.insert c v (nodeValueMap t)
-- | Separator above cluster @e@. Returns 'Nothing' for the root, or for a
-- cluster that was never attached below a separator.
nodeParent :: Ord c => JTree c a -> c -> Maybe Sep
nodeParent t e = Map.lookup e . parentMap $ t
-- | Messages currently held by separator @s@.
--
-- Calls 'error' naming @s@ if the separator is not part of @t@ (previously an
-- uninformative 'fromJust' failure). The lookup is keyed on 'Sep', so no
-- constraint on @c@ is needed.
separatorValue :: JTree c a -> Sep -> SeparatorValue a
separatorValue t s =
  case Map.lookup s (separatorValueMap t) of
    Just v -> v
    Nothing -> error ("separatorValue: unknown separator " ++ show s)
-- | Cluster directly above separator @s@.
--
-- Calls 'error' naming @s@ if the separator is not part of @t@ (previously an
-- uninformative 'fromJust' failure). The lookup is keyed on 'Sep', so no
-- constraint on @c@ is needed.
separatorParent :: JTree c a -> Sep -> c
separatorParent t s =
  case Map.lookup s (separatorParentMap t) of
    Just c -> c
    Nothing -> error ("separatorParent: unknown separator " ++ show s)
-- | Up message held by separator @c@. Calls 'error' if the separator is
-- still 'EmptySeparator'.
upMessage :: Ord c => JTree c a -> Sep -> a
upMessage t c = extractUp (separatorValue t c)
  where
    extractUp (SeparatorValue up _) = up
    extractUp _ = error "Trying to get an up message on an empty seperator ! Should never occur !"
-- | Down message held by separator @c@. Returns 'Nothing' if only an up
-- message has been stored. Calls 'error' if the separator is still
-- 'EmptySeparator'.
downMessage :: Ord c => JTree c a -> Sep -> Maybe a
downMessage t c =
  case separatorValue t c of
    EmptySeparator -> error "Trying to get a down message on an empty separator ! Should never occur !"
    SeparatorValue _ down -> down
-- | Separators directly below cluster @e@. Returns an empty list when there
-- are none.
nodeChildren :: Ord c => JTree c a -> c -> [Sep]
nodeChildren t e = Map.findWithDefault [] e (childrenMap t)
-- | Cluster directly below separator @s@.
--
-- Calls 'error' naming @s@ if the separator is not part of @t@ (previously an
-- uninformative 'fromJust' failure). The lookup is keyed on 'Sep', so the
-- former @Ord c@ constraint was unnecessary. Dropping it is backward compatible.
separatorChild :: JTree c a -> Sep -> c
separatorChild t s =
  case Map.lookup s (separatorChildMap t) of
    Just c -> c
    Nothing -> error ("separatorChild: unknown separator " ++ show s)
-- | True when cluster @c@ has a value in the tree.
nodeIsMemberOfTree :: Ord c => c -> JTree c a -> Bool
nodeIsMemberOfTree c = Map.member c . nodeValueMap
-- | @addSeparator node sepCluster dest t@ attaches cluster @dest@ below
-- cluster @node@ through a new separator whose cluster is @sepCluster@.
--
-- The new separator:
--
-- * receives key @separatorCurrentKey t + 1@;
-- * starts as 'EmptySeparator';
-- * is prepended to @node@'s child separators;
-- * becomes @dest@'s parent separator.
--
-- @node@ is also removed from the leaf set.
addSeparator :: (Ord c)
             => c          -- ^ parent cluster
             -> c          -- ^ cluster stored on the new separator
             -> c          -- ^ child cluster
             -> JTree c a
             -> JTree c a
addSeparator node sepCluster dest t =
  let newSep = separatorCurrentKey t + 1
  -- 'Map.insertWith' calls @(++) [newSep] old@, so the new key is prepended.
  -- The primed 'Map.insertWith'' used before was deprecated and has been
  -- removed from current @containers@ releases.
  in t { childrenMap         = Map.insertWith (++) node [newSep] (childrenMap t)
       , separatorChildMap   = Map.insert newSep dest (separatorChildMap t)
       , separatorValueMap   = Map.insert newSep EmptySeparator (separatorValueMap t)
       , separatorClusterMap = Map.insert newSep sepCluster (separatorClusterMap t)
       , leavesSet           = Set.delete node (leavesSet t)
       , parentMap           = Map.insert dest newSep (parentMap t)
       , separatorParentMap  = Map.insert newSep node (separatorParentMap t)
       , separatorCurrentKey = newSep
       }
-- | Stores the given vertex, factors and evidence at cluster @node@ and marks
-- @node@ as a leaf. Links to other clusters are added with 'addSeparator'.
addNode :: (Ord c)
        => c          -- ^ cluster to add
        -> Vertex
        -> [a]        -- ^ factors
        -> [a]        -- ^ evidence
        -> JTree c a
        -> JTree c a
addNode node vertex factorValue evidenceValue t =
  t { leavesSet    = Set.insert node (leavesSet t)
    , nodeValueMap = Map.insert node value (nodeValueMap t)
    }
  where
    value = NodeValue vertex factorValue evidenceValue
-- | Stores @newval@ as the up message of the given separator and keeps any
-- existing down message. With 'Nothing' (no separator, e.g. above the root)
-- the tree is returned unchanged.
updateUpMessage :: Ord c
                => Maybe Sep
                -> a
                -> JTree c a
                -> JTree c a
updateUpMessage Nothing _ t = t
updateUpMessage (Just sep) newval t =
  t { separatorValueMap = Map.insert sep (withUp (separatorValue t sep)) (separatorValueMap t) }
  where
    withUp EmptySeparator = SeparatorValue newval Nothing
    withUp (SeparatorValue _ down) = SeparatorValue newval down
-- | Stores @newval@ as the down message of separator @sep@ and keeps its up
-- message. Calls 'error' when forced if the separator has no up message yet.
updateDownMessage :: Ord c
                  => Sep
                  -> a
                  -> JTree c a
                  -> JTree c a
updateDownMessage sep newval t =
  t { separatorValueMap = Map.insert sep withDown (separatorValueMap t) }
  where
    withDown = case separatorValue t sep of
      SeparatorValue up _ -> SeparatorValue up (Just newval)
      EmptySeparator -> error "Can't set a down message on an empty separator"
-- | How a cluster builds the message it sends across a separator. The factor
-- type @f@ determines the cluster type @c@ (functional dependency).
class Message f c | f -> c where
  -- | @newMessage incoming value sepCluster@ combines the messages received
  -- from the other neighbours with the sending node's value, and yields the
  -- message for the separator whose cluster is @sepCluster@.
  newMessage :: [f]          -- ^ incoming messages
             -> NodeValue f  -- ^ value of the sending cluster
             -> c            -- ^ cluster stored on the target separator
             -> f
-- | False only for 'EmptySeparator', i.e. when no message has been stored yet.
separatorInitialized :: SeparatorValue a -> Bool
separatorInitialized sepValue = case sepValue of
  EmptySeparator -> False
  SeparatorValue _ _ -> True
-- | True when every separator in @seps@ holds at least an up message. This is
-- vacuously true for an empty list.
allSeparatorsHaveReceivedAMessage :: Ord c
                                  => JTree c a
                                  -> [Sep]
                                  -> Bool
allSeparatorsHaveReceivedAMessage t seps =
  all (separatorInitialized . separatorValue t) seps
-- | Tries to send the up message from cluster @h@ to its parent separator.
--
-- The tree is returned unchanged in two cases: some child separator of @h@
-- has no message yet (a leaf has no child separators, so it is always ready),
-- or @h@ has no parent separator. Otherwise the children's up messages and
-- @h@'s value are combined with 'newMessage' for the parent separator's cluster.
updateUpSeparator :: (Message a c, Ord c)
                  => JTree c a
                  -> c
                  -> JTree c a
updateUpSeparator t h
  | not (allSeparatorsHaveReceivedAMessage t childSeps) = t
  | otherwise = maybe t sendUp (nodeParent t h)
  where
    childSeps = nodeChildren t h
    sendUp parentSep =
      let incoming = map (upMessage t) childSeps
          msg = newMessage incoming (nodeValue t h) (separatorCluster t parentSep)
      in updateUpMessage (Just parentSep) msg t
-- | Computes the down message from @node@ to its child separator @child@ and
-- stores it there.
--
-- The message combines three things with 'newMessage':
--
-- * @node@'s value;
-- * the up messages of @node@'s other child separators;
-- * the down message on @node@'s parent separator, if there is one.
updateDownSeparator :: (Message a c, Ord c)
                    => c
                    -> JTree c a
                    -> Sep
                    -> JTree c a
updateDownSeparator node t child = updateDownMessage child outgoing t
  where
    fromOtherChildren = map (upMessage t) (nodeChildren t node \\ [child])
    fromParent = nodeParent t node >>= downMessage t
    incoming = maybe fromOtherChildren (: fromOtherChildren) fromParent
    outgoing = newMessage incoming (nodeValue t node) (separatorCluster t child)
-- | Removes duplicates and returns the remaining elements in ascending order.
unique :: Ord c => [c] -> [c]
unique xs = Set.toAscList (Set.fromList xs)
-- | Collection pass: starting from the leaves, pushes up messages toward the
-- root until every separator has received its up message.
collect :: (Ord c, Message a c)
        => JTree c a
        -> JTree c a
collect t = _collectNodes startNodes t
  where
    startNodes = leaves t
-- | Continues the collection pass with the (deduplicated) clusters above the
-- given separators.
_collectSeparators :: (Ord c, Message a c)
                   => [Sep]
                   -> JTree c a
                   -> JTree c a
_collectSeparators seps t = _collectNodes parents t
  where
    parents = unique [separatorParent t s | s <- seps]
-- | One step of the collection pass. Every listed cluster tries to send its
-- up message; then the separators above those clusters are visited. The pass
-- ends once the list is empty, i.e. after the root.
_collectNodes :: (Ord c, Message a c)
              => [c]
              -> JTree c a
              -> JTree c a
_collectNodes [] t = t
_collectNodes nodes t = _collectSeparators upward (foldl' updateUpSeparator t nodes)
  where
    -- Parents are read from the tree before the fold; 'updateUpSeparator'
    -- never changes the parent links.
    upward = mapMaybe (nodeParent t) nodes
-- | Distribution pass: from the root downward, computes the down message of
-- every separator. A cluster's child separators are filled before its
-- subtrees are visited.
distribute :: (Ord c, Message a c)
           => JTree c a
           -> JTree c a
distribute t = _distributeNodes t start
  where
    start = root t
-- | Continues the distribution pass at the cluster below separator @sep@.
_distributeSeparators :: (Ord c, Message a c)
                      => JTree c a
                      -> Sep
                      -> JTree c a
_distributeSeparators t sep = _distributeNodes t (separatorChild t sep)
-- | Fills the down message of every child separator of @node@, then recurses
-- into each child subtree.
_distributeNodes :: (Ord c, Message a c)
                 => JTree c a
                 -> c
                 -> JTree c a
_distributeNodes t node = foldl' _distributeSeparators withDownMessages childSeps
  where
    childSeps = nodeChildren t node
    withDownMessages = foldl' (updateDownSeparator node) t childSeps
-- | Operations a cluster type must support to be used in a junction tree.
class IsCluster c where
  -- | Evidence relevant to the cluster. The visible instances keep the
  -- instantiations whose variable belongs to the cluster.
  overlappingEvidence :: c -> [DVI Int] -> [DVI Int]
  -- | Variables of the cluster.
  clusterVariables :: c -> [DV]
  -- | Separator between two clusters. The visible instances use the intersection.
  mkSeparator :: c -> c -> c
-- | A plain list of variables used as a cluster.
instance IsCluster [DV] where
  overlappingEvidence vars = filter ((`elem` vars) . instantiationVariable)
  clusterVariables vars = vars
  mkSeparator a b = a `intersect` b
-- | Result of the callback passed to 'traverseTree' for one node. Every
-- constructor carries a state; the modifying ones also carry a new node value.
data Action s a
  = Skip !s             -- ^ keep the value and visit the children
  | ModifyAndStop !s !a -- ^ replace the value and do not visit the children
  | Modify !s !a        -- ^ replace the value and visit the children
  | Stop !s             -- ^ keep the value and do not visit the children
-- | Depth-first, pre-order traversal from the root that threads a state @s@
-- and lets the callback replace node values.
--
-- The callback receives the current state, the cluster and its value, and
-- returns an 'Action':
--
-- * 'Skip' / 'Modify' continue into the node's children ('Modify' first
--   stores the new value).
-- * 'Stop' / 'ModifyAndStop' do not descend below the node ('ModifyAndStop'
--   first stores the new value). Sibling subtrees are still visited.
--
-- For all four constructors, the state carried by the returned 'Action'
-- becomes the current state.
traverseTree :: Ord c
             => (s -> c -> NodeValue f -> Action s (NodeValue f))
             -> s
             -> JTree c f
             -> (JTree c f,s)
traverseTree action state t = _traverseTreeNodes action (t,state) (root t)

-- | Visits the cluster below separator @current@.
_traverseTreeSeparators :: Ord c
                        => (s -> c -> NodeValue f -> Action s (NodeValue f))
                        -> (JTree c f, s)
                        -> Sep
                        -> (JTree c f, s)
_traverseTreeSeparators action (t,state) current =
  _traverseTreeNodes action (t,state) (separatorChild t current)

-- | Applies the callback to cluster @current@ and then, depending on the
-- returned 'Action', to its subtree.
_traverseTreeNodes :: Ord c
                   => (s -> c -> NodeValue f -> Action s (NodeValue f))
                   -> (JTree c f, s)
                   -> c
                   -> (JTree c f, s)
_traverseTreeNodes action (t,state) current =
  case action state current (nodeValue t current) of
    Stop newState -> (t, newState)
    -- Return the state carried by the action, as the other constructors do.
    -- Previously the incoming state was returned and this one was discarded.
    ModifyAndStop newState newValue -> (setNodeValue current newValue t, newState)
    Skip newState -> descend t newState
    Modify newState newValue -> descend (setNodeValue current newValue t) newState
  where
    descend tree st =
      foldl' (_traverseTreeSeparators action) (tree, st) (nodeChildren tree current)
-- | Rewrites every node value with @f@, which also receives the node's cluster.
mapWithCluster :: Ord c
               => (c -> NodeValue f -> NodeValue f)
               -> JTree c f
               -> JTree c f
mapWithCluster f t = t { nodeValueMap = rewritten }
  where
    rewritten = Map.mapWithKey f (nodeValueMap t)
-- | Attaches each factor in @factors@ to one cluster of the tree using @change@.
--
-- Candidate clusters are those that contain all of the factor's variables.
-- Among them, the factor goes to the one with the smallest size, measured as
-- the product of its variables' dimensions.
--
-- If no cluster contains all of a factor's variables, this calls 'error'
-- naming the factor. Previously 'minimumBy' failed on the empty candidate
-- list without saying which factor was at fault.
updateTreeValues :: (Factor f, IsCluster c, Ord c, Show c, Show f)
                 => (f -> NodeValue f -> NodeValue f)
                 -> [f]
                 -> JTree c f
                 -> JTree c f
updateTreeValues change factors t = foldl' addFactor t factors
  where
    allNodes = treeNodes t
    factorIncludedInCluster f c = all (`elem` clusterVariables c) (factorVariables f)
    coveringClusters f = filter (f `factorIncludedInCluster`) allNodes
    clusterSize a = product . map (fromIntegral . dimension) . clusterVariables $ a :: Integer
    smallestCoveringCluster factor =
      case coveringClusters factor of
        [] -> error ("updateTreeValues: no cluster contains all the variables of factor " ++ show factor)
        candidates -> minimumBy (compare `on` clusterSize) candidates
    addFactor tree newFactor =
      let target = smallestCoveringCluster newFactor
      in setNodeValue target (change newFactor (nodeValue tree target)) tree
-- | Adds every vertex value of the network as a factor. Each one is prepended
-- to the factor list of the smallest cluster that contains all its variables.
setFactors :: (Graph g, Factor f, IsCluster c, Ord c, Show c, Show f)
           => BayesianNetwork g f
           -> JTree c f
           -> JTree c f
setFactors g t = updateTreeValues prependFactor (allVertexValues g) t
  where
    prependFactor newFactor (NodeValue v oldFactors e) = NodeValue v (newFactor : oldFactors) e
-- | Replaces all evidence and re-propagates messages. The steps are:
--
-- 1. Every separator is reset to 'EmptySeparator' and every node's evidence
--    is cleared.
-- 2. One evidence factor per instantiation in @e@ is attached to the smallest
--    cluster covering it.
-- 3. 'collect' and then 'distribute' are run.
changeEvidence :: (IsCluster c, Ord c, Factor f, Message f c, Show c, Show f)
               => [DVI Int]
               -> JTree c f
               -> JTree c f
changeEvidence e t = distribute (collect withEvidence)
  where
    clearedSeparators = t { separatorValueMap = Map.map (const EmptySeparator) (separatorValueMap t) }
    cleared = resetEvidences clearedSeparators
    -- Named differently from the top-level function to avoid shadowing it.
    prependEvidence newe (NodeValue v f olde) = NodeValue v f (newe : olde)
    withEvidence = updateTreeValues prependEvidence (map factorFromInstantiation e) cleared
-- | A cluster represented as a set of discrete variables.
newtype Cluster = Cluster (Set.Set DV) deriving(Eq,Ord)

-- | Delegates to the @[DV]@ instance through 'fromCluster'. Separators are
-- set intersections.
instance IsCluster Cluster where
  overlappingEvidence c = overlappingEvidence (fromCluster c)
  clusterVariables c = clusterVariables (fromCluster c)
  mkSeparator (Cluster a) (Cluster b) = Cluster (Set.intersection a b)

-- | Shows the variables as a list in ascending order.
instance Show Cluster where
  show (Cluster s) = show . Set.toList $ s

-- | Variables of a cluster in ascending order. The signature is now explicit
-- because this function is exported.
fromCluster :: Cluster -> [DV]
fromCluster (Cluster s) = Set.toList s
-- | Combines the node's factors, its evidence and the incoming messages. The
-- result is passed to 'marginal', which keeps the separator cluster's
-- variables and is told to remove every other variable appearing in those
-- factors.
instance Factor f => Message f Cluster where
  newMessage input (NodeValue _ f e) dv = marginal allFactors toRemove toKeep []
    where
      allFactors = f ++ e ++ input
      toKeep = fromCluster dv
      toRemove = nub (concatMap factorVariables allFactors) \\ toKeep
-- | A 'JTree' whose nodes and separator clusters are 'Cluster's.
type JunctionTree f = JTree Cluster f

-- | Presumably tags a node ('N') or a separator ('S').
-- NOTE(review): not referenced by the code visible in this module.
data NodeKind c
  = N !c
  | S !c
-- | Label used when drawing a tree. With 'True' the shown value is appended
-- after @=@; with 'False' the label is returned unchanged.
label :: Show a => Bool -> String -> a -> String
label True c a = c ++ "=" ++ show a
label False c _ = c
-- | Converts the tree into a 'Tree.Tree' of labels for drawing.
--
-- Cluster levels and separator levels alternate. Clusters are labelled with
-- @show c@ and separators with @<cluster>@. When @d@ is 'True', each label
-- also carries @=@ and the shown value.
toTree :: (Ord c, Show c, Show a)
       => Bool
       -> JTree c a
       -> Tree.Tree String
toTree d t = Tree.Node rootLabel (_toTreeSeparators d t (nodeChildren t r))
  where
    r = root t
    rootLabel = label d (show r) (nodeValue t r)
-- | Drawing subtrees for the given clusters, one per cluster, in list order.
_toTreeNodes :: (Ord c, Show c, Show a)
             => Bool
             -> JTree c a
             -> [c]
             -> [Tree.Tree String]
_toTreeNodes d t clusters = map toNode clusters
  where
    toNode h = Tree.Node (label d (show h) (nodeValue t h))
                         (_toTreeSeparators d t (nodeChildren t h))
-- | Drawing subtrees for the given separators. Each one has the separator's
-- child cluster as its only child.
_toTreeSeparators :: (Ord c, Show c, Show a)
                  => Bool
                  -> JTree c a
                  -> [Sep]
                  -> [Tree.Tree String]
_toTreeSeparators d t seps = map toSeparator seps
  where
    toSeparator h =
      let sepLabel = label d ("<" ++ show (separatorCluster t h) ++ ">") (separatorValue t h)
      in Tree.Node sepLabel (_toTreeNodes d t [separatorChild t h])
-- | Draws the tree structure (clusters and separator clusters) without values.
instance (Ord c, Show c, Show a) => Show (JTree c a) where
  show t = Tree.drawTree (toTree False t)
-- | Draws the tree as text. When @b@ is 'True', node and separator values are
-- included in the labels.
displayTree :: (Ord c, Show c, Show a) => Bool -> JTree c a -> String
displayTree b = Tree.drawTree . toTree b
-- | Prints, for every cluster in ascending order, the cluster, its factors and
-- its evidence, followed by a dashed separator line.
displayTreeValues :: (Show f, Show c) => JTree c f -> IO ()
displayTreeValues t = mapM_ showEntry (treeValues t)
  where
    showEntry (cluster, NodeValue _ factors evidence) =
      sequence_
        [ print cluster
        , putStrLn "FACTOR"
        , print factors
        , putStrLn "EVIDENCE"
        , print evidence
        , putStrLn "------"
        ]