module Bayes.FactorElimination.JTree(
IsCluster(..)
, Cluster(..)
, JTree(..)
, JunctionTree(..)
, setFactors
, distribute
, collect
, fromCluster
, changeEvidence
, nodeIsMemberOfTree
, singletonTree
, addNode
, addSeparator
, leaves
, nodeValue
, NodeValue(..)
, SeparatorValue(..)
, downMessage
, upMessage
, nodeParent
, nodeChildren
, traverseTree
, separatorChild
, treeNodes
, Action(..)
) where
import qualified Data.Map as Map
import qualified Data.Tree as Tree
import Data.Maybe(fromJust,mapMaybe)
import qualified Data.Set as Set
import Data.Monoid
import Data.List((\\), intersect,partition, foldl')
import Bayes.PrivateTypes
import Bayes.Factor
import Bayes
import Debug.Trace
-- | Debugging helper: print a label and the shown value through 'trace',
-- then return the value unchanged.
debug :: Show a => String -> a -> a
debug s a = trace (s ++ " " ++ show a ++ "\n") a
-- | Message sent from a child cluster towards its parent.
type UpMessage a = a

-- | Message sent from a parent cluster towards a child; 'Nothing' until a
-- down message has been stored.
type DownMessage a = Maybe a

-- | Content of a separator: either nothing yet ('EmptySeparator') or an
-- upward message together with an optional downward message.
data SeparatorValue a
  = SeparatorValue !(UpMessage a) !(DownMessage a)
  | EmptySeparator
  deriving (Eq)

instance Show a => Show (SeparatorValue a) where
  show sv =
    case sv of
      EmptySeparator -> ""
      SeparatorValue u md ->
        let downPart = maybe "" (\d -> " d(" ++ show d ++ ")") md
        in "u(" ++ show u ++ ")" ++ downPart
-- | Factor attached to a cluster.
type FactorValue a = a

-- | Evidence attached to a cluster.
type EvidenceValue a = a

-- | Everything a cluster stores: its factor and its evidence.
data NodeValue a = NodeValue !(FactorValue a) !(EvidenceValue a)
  deriving (Eq)

instance Show a => Show (NodeValue a) where
  show (NodeValue factor evidence) =
    concat ["f(", show factor, ") e(", show evidence, ")"]
-- | A junction tree. Clusters and separators are both identified by values
-- of type @c@; @f@ is the type of the factors/messages stored in the tree.
data JTree c f = JTree
  { root               :: !c
    -- ^ cluster where 'distribute' and 'traverseTree' start
  , leavesSet          :: !(Set.Set c)
    -- ^ clusters that have no separator below them
  , childrenMap        :: !(Map.Map c [c])
    -- ^ cluster -> separators directly below it
  , parentMap          :: !(Map.Map c c)
    -- ^ cluster -> separator directly above it
  , separatorParentMap :: !(Map.Map c c)
    -- ^ separator -> cluster above it
  , separatorChildMap  :: !(Map.Map c c)
    -- ^ separator -> cluster below it
  , nodeValueMap       :: !(Map.Map c (NodeValue f))
    -- ^ cluster -> its factor and evidence
  , separatorValueMap  :: !(Map.Map c (SeparatorValue f))
    -- ^ separator -> messages stored in it
  }
  deriving (Eq)
-- | Tree made of a single cluster @r@ (which is both root and leaf) holding
-- the given factor and evidence values.
singletonTree :: Ord c
              => c  -- ^ root cluster
              -> a  -- ^ factor value of the root
              -> a  -- ^ evidence value of the root
              -> JTree c a
singletonTree r factorValue evidenceValue =
  let t = JTree r Set.empty Map.empty Map.empty Map.empty Map.empty Map.empty Map.empty
  in
  addNode r factorValue evidenceValue t
-- | Clusters with no separator below them, in ascending order.
leaves :: JTree c a -> [c]
leaves t = Set.elems (leavesSet t)
-- | All clusters that carry a value, in ascending key order.
treeNodes :: JTree c a -> [c]
treeNodes t = Map.keys (nodeValueMap t)
-- | Factor and evidence stored at a cluster.
--
-- Calls 'error' with a descriptive message when the cluster is not part of
-- the tree (instead of the opaque @fromJust@ failure).
nodeValue :: Ord c => JTree c a -> c -> NodeValue a
nodeValue t e =
  case Map.lookup e (nodeValueMap t) of
    Just v  -> v
    Nothing -> error "Bayes.FactorElimination.JTree.nodeValue: cluster not found in the tree"
-- | Store (or overwrite) the value of cluster @c@.
setNodeValue :: Ord c => c -> NodeValue a -> JTree c a -> JTree c a
setNodeValue c v t = t { nodeValueMap = updated }
  where
    updated = Map.insert c v (nodeValueMap t)
-- | Separator directly above a cluster, or 'Nothing' for a cluster without
-- one (such as the root).
nodeParent :: Ord c => JTree c a -> c -> Maybe c
nodeParent t e = Map.lookup e . parentMap $ t
-- | Messages stored in a separator.
--
-- Calls 'error' with a descriptive message when the separator is unknown
-- (instead of the opaque @fromJust@ failure).
separatorValue :: Ord c => JTree c a -> c -> SeparatorValue a
separatorValue t e =
  case Map.lookup e (separatorValueMap t) of
    Just v  -> v
    Nothing -> error "Bayes.FactorElimination.JTree.separatorValue: separator not found in the tree"
-- | Cluster directly above a separator.
--
-- Calls 'error' with a descriptive message when the separator is unknown
-- (instead of the opaque @fromJust@ failure).
separatorParent :: Ord c => JTree c a -> c -> c
separatorParent t e =
  case Map.lookup e (separatorParentMap t) of
    Just p  -> p
    Nothing -> error "Bayes.FactorElimination.JTree.separatorParent: separator not found in the tree"
-- | Upward message stored in a separator.
--
-- Calls 'error' if the separator has not received any message yet.
upMessage :: Ord c => JTree c a -> c -> a
upMessage t c = case separatorValue t c of
  SeparatorValue up _ -> up
  -- Message spelling fixed ("seperator") to match 'downMessage'.
  EmptySeparator -> error "Trying to get an up message on an empty separator ! Should never occur !"
-- | Downward message stored in a separator, if one has been set.
--
-- Calls 'error' if the separator has not received any message yet.
downMessage :: Ord c => JTree c a -> c -> Maybe a
downMessage t c =
  case separatorValue t c of
    EmptySeparator        -> error "Trying to get a down message on an empty separator ! Should never occur !"
    SeparatorValue _ down -> down
-- | Separators directly below a cluster (empty for a leaf).
nodeChildren :: Ord c => JTree c a -> c -> [c]
nodeChildren t e = Map.findWithDefault [] e (childrenMap t)
-- | Cluster directly below a separator.
--
-- Calls 'error' with a descriptive message when the separator is unknown
-- (instead of the opaque @fromJust@ failure).
separatorChild :: Ord c => JTree c a -> c -> c
separatorChild t e =
  case Map.lookup e (separatorChildMap t) of
    Just child -> child
    Nothing    -> error "Bayes.FactorElimination.JTree.separatorChild: separator not found in the tree"
-- | True when the cluster has a value stored in the tree.
nodeIsMemberOfTree :: Ord c => c -> JTree c a -> Bool
nodeIsMemberOfTree c = Map.member c . nodeValueMap
-- | Link cluster @node@ (above) to cluster @dest@ (below) through the
-- separator @sep@.
--
-- Effects on the tree:
--
-- * The separator starts as 'EmptySeparator'.
-- * It is prepended to the separators listed below @node@.
-- * @node@ is no longer a leaf.
--
-- Only the links are recorded; @dest@'s value is not touched (presumably
-- the caller adds it with 'addNode').
addSeparator :: (Ord c)
             => c          -- ^ cluster above the separator
             -> c          -- ^ the separator
             -> c          -- ^ cluster below the separator
             -> JTree c a
             -> JTree c a
addSeparator node sep dest t =
  -- 'Map.insertWith'' no longer exists in containers >= 0.6; 'Map.insertWith'
  -- with the same combining function yields the same list (new separator first).
  t { childrenMap        = Map.insertWith (++) node [sep] (childrenMap t)
    , separatorChildMap  = Map.insert sep dest (separatorChildMap t)
    , separatorValueMap  = Map.insert sep EmptySeparator (separatorValueMap t)
    , leavesSet          = Set.delete node (leavesSet t)
    , parentMap          = Map.insert dest sep (parentMap t)
    , separatorParentMap = Map.insert sep node (separatorParentMap t)
    }
-- | Store a cluster with its factor and evidence, marking it as a leaf
-- (a later 'addSeparator' below it removes that mark).
addNode :: (Ord c)
        => c   -- ^ cluster
        -> a   -- ^ factor value
        -> a   -- ^ evidence value
        -> JTree c a
        -> JTree c a
addNode node factorValue evidenceValue t =
  t { leavesSet    = Set.insert node (leavesSet t)
    , nodeValueMap = Map.insert node value (nodeValueMap t)
    }
  where
    value = NodeValue factorValue evidenceValue
-- | Store an upward message in a separator, keeping any downward message it
-- already holds. With no separator ('Nothing') the tree is returned as is.
updateUpMessage :: Ord c
                => Maybe c   -- ^ separator to update, if any
                -> a         -- ^ new upward message
                -> JTree c a
                -> JTree c a
updateUpMessage Nothing _ t = t
updateUpMessage (Just sep) newval t =
  t { separatorValueMap = Map.insert sep (withUp (separatorValue t sep)) (separatorValueMap t) }
  where
    withUp EmptySeparator          = SeparatorValue newval Nothing
    withUp (SeparatorValue _ down) = SeparatorValue newval down
-- | Store a downward message in a separator, keeping its upward message.
-- The separator must already hold an upward message; otherwise the stored
-- value is an 'error'.
updateDownMessage :: Ord c
                  => c         -- ^ separator to update
                  -> a         -- ^ new downward message
                  -> JTree c a
                  -> JTree c a
updateDownMessage sep newval t =
  t { separatorValueMap = Map.insert sep (withDown (separatorValue t sep)) (separatorValueMap t) }
  where
    withDown (SeparatorValue up _) = SeparatorValue up (Just newval)
    withDown EmptySeparator        = error "Can't set a down message on an empty separator"
-- | How a cluster builds the message it sends through one of its separators.
class Message f c | f -> c where
  newMessage
    :: [f]          -- ^ messages received through the cluster's other separators
    -> NodeValue f  -- ^ factor and evidence of the sending cluster
    -> c            -- ^ separator the message is sent through
    -> f
-- | True once a separator holds at least an upward message.
separatorInitialized :: SeparatorValue a -> Bool
separatorInitialized sv =
  case sv of
    EmptySeparator -> False
    _              -> True
-- | True when every listed separator holds a message (vacuously true for an
-- empty list, i.e. for a leaf's children).
allSeparatorsHaveReceivedAMessage :: Ord c
                                  => JTree c a
                                  -> [c]
                                  -> Bool
allSeparatorsHaveReceivedAMessage t seps =
  all (separatorInitialized . separatorValue t) seps
-- | Send cluster @h@'s upward message to the separator above it.
--
-- Nothing happens while some child separator of @h@ is still empty, or
-- when @h@ has no separator above it (the root).
updateUpSeparator :: (Message a c, Ord c)
                  => JTree c a
                  -> c
                  -> JTree c a
updateUpSeparator t h
  | not (allSeparatorsHaveReceivedAMessage t childSeps) = t
  | otherwise =
      case nodeParent t h of
        Nothing -> t
        Just parentSep ->
          let outgoing = newMessage (map (upMessage t) childSeps) (nodeValue t h) parentSep
          in updateUpMessage (Just parentSep) outgoing t
  where
    childSeps = nodeChildren t h
-- | Send cluster @node@'s downward message into its child separator @child@.
--
-- The message combines the upward messages from @node@'s other child
-- separators and, when present, the downward message stored in the
-- separator above @node@.
updateDownSeparator :: (Message a c, Ord c)
                    => c          -- ^ sending cluster
                    -> JTree c a
                    -> c          -- ^ child separator receiving the message
                    -> JTree c a
updateDownSeparator node t child =
  updateDownMessage child outgoing t
  where
    fromBelow = [upMessage t s | s <- nodeChildren t node \\ [child]]
    fromAbove = nodeParent t node >>= downMessage t
    incoming  = maybe fromBelow (: fromBelow) fromAbove
    outgoing  = newMessage incoming (nodeValue t node) child
-- | Remove duplicates; the result is sorted in ascending order.
unique :: Ord c => [c] -> [c]
unique xs = Set.elems (Set.fromList xs)
-- | Kind of tree element a collect/distribute step is looking at.
data TraversalState
  = ACluster    -- ^ a cluster
  | ASeparator  -- ^ a separator between two clusters
-- | Collect phase: propagate upward messages from the leaves to the root.
collect :: (Ord c, Message a c)
        => JTree c a
        -> JTree c a
collect tree =
  let start = leaves tree
  in _collect ACluster start tree
-- | Worker for 'collect'. It alternates between two kinds of frontier and
-- stops when the frontier is empty (past the root):
--
-- * A frontier of clusters: each cluster sends its upward message if all
--   its child separators already hold one. The next frontier is the
--   separators above them.
-- * A frontier of separators: the next frontier is the (deduplicated)
--   clusters above them.
_collect :: (Ord c, Message a c)
         => TraversalState
         -> [c]
         -> JTree c a
         -> JTree c a
_collect state frontier t
  | null frontier = t
  | otherwise =
      case state of
        ACluster ->
          let updated = foldl' updateUpSeparator t frontier
          in _collect ASeparator (mapMaybe (nodeParent t) frontier) updated
        ASeparator ->
          _collect ACluster (unique [separatorParent t s | s <- frontier]) t
-- | Distribute phase: propagate downward messages from the root to the
-- leaves. The upward messages (see 'collect') are expected to be present.
distribute :: (Ord c, Message a c)
           => JTree c a
           -> JTree c a
distribute tree = _distribute ACluster tree (root tree)
-- | Worker for 'distribute'.
--
-- * On a cluster: compute the downward message of every child separator,
--   then recurse into each of them.
-- * On a separator: continue with the cluster below it.
_distribute :: (Ord c, Message a c)
            => TraversalState
            -> JTree c a
            -> c
            -> JTree c a
_distribute ASeparator t sep = _distribute ACluster t (separatorChild t sep)
_distribute ACluster t node = foldl' (_distribute ASeparator) withDown childSeps
  where
    childSeps = nodeChildren t node
    withDown  = foldl' (updateDownSeparator node) t childSeps
-- | Operations a cluster representation must provide.
class IsCluster c where
  -- | Instantiations from the list that concern the cluster (the @[DV]@
  -- instance keeps those whose variable belongs to the cluster).
  overlappingEvidence :: c -> [DVI Int] -> [DVI Int]
  -- | Variables of the cluster.
  clusterVariables :: c -> [DV]
  -- | Separator between two clusters (their intersection for @[DV]@).
  mkSeparator :: c -> c -> c
-- | A cluster given directly as a list of variables.
instance IsCluster [DV] where
  overlappingEvidence c = filter (\i -> instantiationVariable i `elem` c)
  clusterVariables c = c
  mkSeparator a b = a `intersect` b
-- | Result of visiting one cluster during 'traverseTree'. Each constructor
-- carries the updated traversal state.
data Action s a
  = Skip !s                -- ^ keep the cluster's value, continue into its children
  | ModifyAndStop !s !a    -- ^ replace the cluster's value, do not visit its children
  | Modify !s !a           -- ^ replace the cluster's value, continue into its children
  | Stop !s                -- ^ keep the cluster's value, do not visit its children
-- | Depth-first traversal of the clusters, starting at the root.
--
-- The @action@ is applied to every visited cluster. Separators are not
-- given to the action; the traversal steps through them to the cluster
-- below.
--
-- The state is threaded in visiting order: each child subtree starts with
-- the state produced by the previous sibling subtree. 'Stop' and
-- 'ModifyAndStop' only prune the current subtree; siblings are still
-- visited. Returns the updated tree and the final state.
traverseTree :: Ord c
             => (s -> c -> NodeValue f -> Action s (NodeValue f))
             -> s
             -> JTree c f
             -> (JTree c f, s)
traverseTree action state t = _traverseTree True action (t, state) (root t)

-- | Worker for 'traverseTree'. The 'Bool' is 'True' when @current@ is a
-- cluster and 'False' when it is a separator.
_traverseTree :: Ord c
              => Bool
              -> (s -> c -> NodeValue f -> Action s (NodeValue f))
              -> (JTree c f, s)
              -> c
              -> (JTree c f, s)
_traverseTree False action (t, state) current =
  _traverseTree True action (t, state) (separatorChild t current)
_traverseTree True action (t, state) current =
  case action state current (nodeValue t current) of
    Stop newState -> (t, newState)
    -- Like 'Stop', return the state carried by the action; it used to be
    -- discarded in favour of the pre-step state.
    ModifyAndStop newState newValue -> (setNodeValue current newValue t, newState)
    Skip newState -> descend t newState
    Modify newState newValue -> descend (setNodeValue current newValue t) newState
  where
    descend tree st =
      foldl' (_traverseTree False action) (tree, st) (nodeChildren tree current)
-- | Rewrite every cluster value with a function that also sees the cluster.
mapWithCluster :: Ord c
               => (c -> NodeValue f -> NodeValue f)
               -> JTree c f
               -> JTree c f
mapWithCluster f t =
  let updated = Map.mapWithKey f (nodeValueMap t)
  in t { nodeValueMap = updated }
-- | Place the factors of the Bayesian network into the clusters of the tree
-- (see 'updateFactor').
--
-- NOTE(review): factors that fit in no cluster remain in the traversal
-- state, which is discarded here, so they are silently dropped.
setFactors :: (Graph g, Factor f, Show f, IsCluster c, Ord c)
           => BayesianNetwork g f
           -> JTree c f
           -> JTree c f
setFactors g t = tree
  where
    (tree, _) = traverseTree updateFactor (allVertexValues g) t
-- | Per-cluster step of 'setFactors'.
--
-- Every pending factor whose variables all belong to the cluster is taken
-- out of the pending list. The cluster's factor becomes their product,
-- while its evidence is kept. Once nothing is pending, the traversal stops
-- descending.
--
-- NOTE(review): clusters pruned that way keep whatever factor they held
-- before.
updateFactor :: (Factor f, IsCluster c)
             => [f]
             -> c
             -> NodeValue f
             -> Action [f] (NodeValue f)
updateFactor lf c (NodeValue _ evidence)
  | null lf   = Stop lf
  | otherwise = Modify remaining (NodeValue (factorProduct attributed) evidence)
  where
    vars = clusterVariables c
    (attributed, remaining) = partition (all (`elem` vars) . factorVariables) lf
-- | Install new evidence and recompute all messages.
--
-- Steps:
--
-- 1. Clear every separator.
-- 2. Assign the evidence to the clusters (see 'changeNodeEvidence').
-- 3. Run 'collect'.
-- 4. Run 'distribute'.
changeEvidence :: (IsCluster c, Ord c, Factor f, Message f c)
               => [DVI Int]
               -> JTree c f
               -> JTree c f
changeEvidence e t =
  let cleared = t { separatorValueMap = fmap (const EmptySeparator) (separatorValueMap t) }
      (withEvidence, _) = traverseTree changeNodeEvidence e cleared
  in distribute (collect withEvidence)
-- | Per-cluster step of 'changeEvidence'.
--
-- The cluster keeps its factor. Its evidence becomes the product of the
-- instantiations selected for it by 'overlappingEvidence'. Those
-- instantiations are removed from the pending list, so each one ends up in
-- a single cluster.
--
-- Every cluster is visited, even after all evidence has been placed. A
-- cluster with nothing to receive gets @factorProduct []@, the same value
-- given to a cluster without overlapping evidence. Previously the
-- traversal stopped when the list ran out, so clusters not yet visited
-- kept the evidence of an earlier call.
changeNodeEvidence :: (IsCluster c, Factor f)
                   => [DVI Int]
                   -> c
                   -> NodeValue f
                   -> Action [DVI Int] (NodeValue f)
changeNodeEvidence [] _ (NodeValue f _) = Modify [] (NodeValue f (factorProduct []))
changeNodeEvidence e c (NodeValue f _) =
  let oe = overlappingEvidence c e
      newEvidence = factorProduct (map factorFromInstantiation oe)
  in
  Modify (e \\ oe) (NodeValue f newEvidence)
-- | A cluster (or separator) represented as a set of discrete variables.
newtype Cluster = Cluster (Set.Set DV)
  deriving (Eq, Ord)

instance Show Cluster where
  show (Cluster vars) = show . Set.elems $ vars
-- | Variables of a 'Cluster', in ascending order.
fromCluster :: Cluster -> [DV]
fromCluster (Cluster s) = Set.toList s
-- | Message through a 'Cluster' separator: multiply the cluster's factor,
-- its evidence and the incoming messages, then project the product onto
-- the separator's variables.
instance Factor f => Message f Cluster where
  newMessage incoming (NodeValue factor evidence) sep =
    factorProjectTo (fromCluster sep) combined
    where
      combined = factorProduct (factor : evidence : incoming)
-- | Junction tree whose clusters and separators are sets of variables.
type JunctionTree = JTree Cluster
-- | Element of the tree being rendered by the display code.
data NodeKind c
  = N !c  -- ^ a cluster
  | S !c  -- ^ a separator
-- | Label of a rendered node: the name, followed by @=value@ when the flag
-- asks for values to be displayed.
label :: Show a => Bool -> String -> a -> String
label True c a = c ++ "=" ++ show a
label False c _ = c
-- | Convert the junction tree into a 'Tree.Tree' of labels for display.
-- Clusters and separators become alternating levels; node values are
-- included when the flag is 'True'.
toTree :: (Ord c, Show c, Show a)
       => Bool
       -> JTree c a
       -> Tree.Tree String
toTree d t = Tree.Node rootLabel (_toTree d t [S sep | sep <- nodeChildren t r])
  where
    r = root t
    rootLabel = label d (show r) (nodeValue t r)
-- | Render a list of clusters/separators, one 'Tree.Tree' per element.
--
-- * A cluster's subtrees are its child separators.
-- * A separator (shown as @<sep>@) has the cluster below it as its single
--   subtree.
_toTree :: (Ord c, Show c, Show a)
        => Bool
        -> JTree c a
        -> [NodeKind c]
        -> [Tree.Tree String]
_toTree d t = map render
  where
    render (N cluster) =
      Tree.Node (label d (show cluster) (nodeValue t cluster))
                (_toTree d t (map S (nodeChildren t cluster)))
    render (S sep) =
      Tree.Node (label d ("<" ++ show sep ++ ">") (separatorValue t sep))
                (_toTree d t [N (separatorChild t sep)])
-- | Draw the tree structure (labels only, no values).
instance (Ord c, Show c, Show a) => Show (JTree c a) where
  show t = Tree.drawTree (toTree False t)
-- | Draw the tree; when the flag is 'True' each node also shows its value.
displayTree :: (Ord c, Show c, Show a) => Bool -> JTree c a -> String
displayTree b = Tree.drawTree . toTree b