module Bayes.FactorElimination(
moralGraph
, nodeComparisonForTriangulation
, numberOfAddedEdges
, weight
, weightedEdges
, triangulate
, createClusterGraph
, Cluster
, createJunctionTree
, createUninitializedJunctionTree
, JunctionTree
, displayTreeValues
, collect
, distribute
, posterior
, changeEvidence
, junctionTreeProperty_prop
, junctionTreeAllClusters_prop
, VertexCluster
, junctionTreeProperty
, maximumSpanningTree
, fromVertexCluster
, triangulatedebug
) where
import Bayes
import qualified Data.Foldable as F
import Data.Maybe(fromJust,mapMaybe,isJust)
import Control.Monad(mapM,guard)
import Bayes.Factor hiding (isEmpty)
import Data.Function(on)
import Data.List(minimumBy,maximumBy,inits,foldl',nub,(\\))
import qualified Data.Set as Set
import qualified Data.Map as Map
import qualified Data.Functor as Functor
import qualified Data.Tree as T
import Bayes.FactorElimination.JTree
import Control.Applicative((<$>))
import Bayes.VariableElimination(marginal)
import Test.QuickCheck hiding ((.||.), collect)
import Test.QuickCheck.Arbitrary
-- | Number of edges that must be added to make the neighbours of a vertex
-- pairwise adjacent: the number of unordered pairs of distinct neighbours of
-- @v@ that are not linked by an edge in @g@.
--
-- The graph is undirected, so @(x, y)@ and @(y, x)@ denote the same missing
-- edge; the @x < y@ filter makes sure each missing edge is counted once.
--
-- Fails (through 'fromJust') when @neighbors g v@ is 'Nothing'.
numberOfAddedEdges :: UndirectedGraph g
                   => g a b
                   -> Vertex
                   -> Integer
numberOfAddedEdges g v =
    fromIntegral $ length
        [ ()
        | x <- nodes
        , y <- nodes
        , x < y
        , not (isLinkedWithAnEdge g x y)
        ]
  where
    nodes = fromJust $ neighbors g v
-- | Total weight of the edges that must be added to make the neighbours of a
-- vertex pairwise adjacent. The weight of one missing edge is the product of
-- the 'weight's of its two end vertices.
--
-- The graph is undirected, so @(x, y)@ and @(y, x)@ denote the same missing
-- edge; the @x < y@ filter makes sure each missing edge contributes once.
--
-- Fails (through 'fromJust') when @neighbors g v@ is 'Nothing'.
weightedEdges :: (UndirectedGraph g, Factor f)
              => g a f
              -> Vertex
              -> Integer
weightedEdges g v =
    sum
        [ weight g x * weight g y
        | x <- nodes
        , y <- nodes
        , x < y
        , not (isLinkedWithAnEdge g x y)
        ]
  where
    nodes = fromJust $ neighbors g v
-- | Weight of a vertex: the 'factorDimension' of the factor stored at that
-- vertex, converted to an 'Integer'.
--
-- Fails (through 'fromJust') when @vertexValue g v@ is 'Nothing'.
weight :: (UndirectedGraph g, Factor f)
       => g a f
       -> Vertex
       -> Integer
weight g v = fromIntegral (factorDimension storedFactor)
  where
    storedFactor = fromJust (vertexValue g v)
-- | Lexicographic combination of two comparison functions: the result of the
-- first comparison is used unless it is 'EQ', in which case the second
-- comparison decides.
(.||.) :: (a -> a -> Ordering)
       -> (a -> a -> Ordering)
       -> (a -> a -> Ordering)
(.||.) primary secondary a b
    | firstVerdict == EQ = secondary a b
    | otherwise          = firstVerdict
  where
    firstVerdict = primary a b
-- | Vertex ordering used to select the next vertex during triangulation.
--
-- Vertices are compared first by 'numberOfAddedEdges' and, when that is a
-- tie, by 'weight'. Both quantities are computed on the graph @g@ given here.
nodeComparisonForTriangulation :: (UndirectedGraph g, Factor f)
                               => g a f
                               -> Vertex
                               -> Vertex
                               -> Ordering
nodeComparisonForTriangulation g = byAddedEdges .||. byWeight
  where
    byAddedEdges = compare `on` numberOfAddedEdges g
    byWeight     = compare `on` weight g
-- | A cluster described by the set of graph vertices it contains.
newtype VertexCluster = VertexCluster (Set.Set Vertex)
    deriving (Eq, Ord)

-- | Set of vertices wrapped by a 'VertexCluster'.
fromVertexCluster :: VertexCluster -> Set.Set Vertex
fromVertexCluster (VertexCluster vertexSet) = vertexSet

-- | A cluster is displayed as the list of its vertices (in 'Set.toList' order).
instance Show VertexCluster where
    show = show . Set.toList . fromVertexCluster
-- | Eliminate the vertices of a graph one by one and return the clusters
-- produced along the way.
--
-- At each step the minimum vertex according to @cmp@ is selected, a cluster
-- made of that vertex and its current neighbours is recorded, the neighbours
-- are connected pairwise ('connectAllNonAdjacentNodes') and the vertex is
-- removed. When no vertex is left, the clusters (in creation order) are passed
-- through 'keepMaximalClusters'.
--
-- Fails (through 'fromJust') when 'neighbors' returns 'Nothing' for the
-- selected vertex.
triangulate :: Graph g
            => (Vertex -> Vertex -> Ordering)
            -> g () b
            -> [VertexCluster]
triangulate cmp = eliminate []
  where
    -- The accumulator holds the clusters found so far, most recent first.
    eliminate found g
        | hasNoVertices g = keepMaximalClusters (reverse found)
        | otherwise       = eliminate (cluster : found) reduced
      where
        chosen   = minimumBy cmp (allVertices g)
        adjacent = fromJust (neighbors g chosen)
        reduced  = removeVertex chosen (connectAllNonAdjacentNodes adjacent g)
        cluster  = VertexCluster (Set.fromList (chosen : adjacent))
-- | Debugging variant of 'triangulate'.
--
-- Runs the same elimination loop but returns every cluster created (no
-- filtering through 'keepMaximalClusters') together with the graph as it was
-- at the start of each elimination step. Both lists are in step order.
triangulatedebug :: Graph g
                 => (Vertex -> Vertex -> Ordering)
                 -> g () b
                 -> ([VertexCluster],[g () b])
triangulatedebug cmp gr = eliminate gr [] []
  where
    -- Both accumulators are built most-recent-first and reversed at the end.
    eliminate g foundClusters seenGraphs
        | hasNoVertices g = (reverse foundClusters, reverse seenGraphs)
        | otherwise       = eliminate reduced (cluster : foundClusters) (g : seenGraphs)
      where
        chosen   = minimumBy cmp (allVertices g)
        adjacent = fromJust (neighbors g chosen)
        reduced  = removeVertex chosen (connectAllNonAdjacentNodes adjacent g)
        cluster  = VertexCluster (Set.fromList (chosen : adjacent))
-- | Look for the first cluster of a list that contains the given cluster
-- (i.e. of which the given cluster is a subset).
--
-- Returns the containing cluster, if any, together with the list from which
-- that cluster has been removed. When no cluster contains the given one, the
-- list is returned unchanged.
findContainingCluster :: VertexCluster
                      -> [VertexCluster]
                      -> (Maybe VertexCluster,[VertexCluster])
findContainingCluster cluster l =
    case break containsCluster l of
        (_, [])               -> (Nothing, l)
        (before, hit : after) -> (Just hit, before ++ after)
  where
    wanted = fromVertexCluster cluster
    containsCluster candidate = wanted `Set.isSubsetOf` fromVertexCluster candidate
-- | Drop every cluster that is contained in a cluster appearing earlier in
-- the list.
--
-- Clusters are processed from left to right. A cluster that is not a subset of
-- any cluster kept so far is appended to the kept clusters. Otherwise it is
-- discarded and the first kept cluster containing it is moved to the end of
-- the kept clusters.
keepMaximalClusters :: [VertexCluster] -> [VertexCluster]
keepMaximalClusters = reverse . foldl' absorb []
  where
    -- The accumulator is the list of kept clusters, most recent first.
    absorb keptReversed current =
        case findContainingCluster current (reverse keptReversed) of
            (Nothing, _)             -> current : keptReversed
            (Just container, others) -> container : reverse others
-- | Convert a cluster of vertices into a 'Cluster' of variables.
--
-- Each vertex is replaced by the 'factorMainVariable' of the factor stored at
-- that vertex in @g@. Vertices for which 'vertexValue' returns 'Nothing' are
-- ignored.
vertexClusterToCluster :: (Factor f , Graph g)
                       => g e f
                       -> VertexCluster
                       -> Cluster
vertexClusterToCluster g (VertexCluster vertexSet) =
    Cluster (Set.fromList mainVariables)
  where
    mainVariables =
        [ factorMainVariable storedFactor
        | Just storedFactor <- map (vertexValue g) (Set.toList vertexSet)
        ]
-- | Build the cluster graph for a list of vertex clusters.
--
-- Each cluster becomes a vertex, numbered from 0 in list order, whose value is
-- the cluster converted with 'vertexClusterToCluster' against @bn@. An edge is
-- added for every ordered pair of different clusters; its value is the number
-- of vertices shared by the two clusters.
createClusterGraph :: (UndirectedGraph g, Factor f, Graph g')
                   => g' e f
                   -> [VertexCluster]
                   -> g Int Cluster
createClusterGraph bn c = foldl' linkClusters verticesOnly clusterPairs
  where
    numbered = zip c (map Vertex [0..])

    verticesOnly = foldl' insertCluster emptyGraph numbered
    insertCluster acc (vc, v) = addVertex v (vertexClusterToCluster bn vc) acc

    clusterPairs =
        [ (left, right) | left <- numbered, right <- numbered, left /= right ]
    linkClusters acc ((ca, va), (cb, vb)) =
        addEdge (edge va vb) (sharedVertexCount ca cb) acc

    sharedVertexCount ca cb =
        Set.size (fromVertexCluster ca `Set.intersection` fromVertexCluster cb)
-- | Enumerate the ways a remaining cluster-graph vertex can be attached to the
-- tree under construction.
--
-- For every remaining vertex and every given tree cluster whose vertex is
-- linked to it in @g@, yields the remaining vertex, the tree cluster and the
-- value of the edge linking them.
possibilities :: (Ord c , UndirectedGraph g)
              => g Int c
              -> JTree c f
              -> [Vertex]
              -> [c]
              -> [(Vertex,c,Int)]
possibilities g currentT remaining leavesClusters =
    [ (rv, lv, linkValue)
    | rv <- remaining
    , lv <- leavesClusters
    , let NodeValue lvVertex _ _ = nodeValue currentT lv
    , isLinkedWithAnEdge g rv lvVertex
    , let linkValue = fromJust (edgeValue g (edge rv lvVertex))
    ]
-- | Select the best remaining vertex to attach to the tree.
--
-- Among all 'possibilities', the one with the greatest edge value is chosen.
-- Returns the remaining vertices without the chosen one, the chosen vertex
-- with its cluster, and the tree cluster it has to be attached to.
--
-- Fails when there is no possibility at all ('maximumBy' on an empty list) or
-- when the chosen vertex has no value in @g@ ('fromJust').
findMax :: (UndirectedGraph g, Ord c, Factor f,Show c)
        => g Int c
        -> [Vertex]
        -> JTree c f
        -> ([Vertex],(Vertex,c),c)
findMax g remaining currentT =
    (stillRemaining, (bestVertex, bestCluster), attachTo)
  where
    candidates = possibilities g currentT remaining (treeNodes currentT)
    linkValue (_, _, value) = value
    (bestVertex, attachTo, _) = maximumBy (compare `on` linkValue) candidates
    stillRemaining = filter (/= bestVertex) remaining
    bestCluster = fromJust (vertexValue g bestVertex)
-- | Build a tree spanning all the vertices of a cluster graph.
--
-- The tree starts from the vertex returned by 'someVertex' and is grown by
-- 'buildTree', which repeatedly attaches the remaining vertex having the
-- greatest edge value to the tree.
--
-- Fails (through 'fromJust') when the graph has no vertex.
maximumSpanningTree :: (UndirectedGraph g, IsCluster c, Factor f, Ord c, Show c, Show f)
                    => g Int c
                    -> JTree c f
maximumSpanningTree g = buildTree g otherVertices seedTree
  where
    rootVertex    = fromJust (someVertex g)
    rootCluster   = fromJust (vertexValue g rootVertex)
    seedTree      = singletonTree rootCluster rootVertex [] []
    otherVertices = filter (/= rootVertex) (allVertices g)
-- | Grow a tree until no vertex of the cluster graph remains.
--
-- While vertices remain, the one selected by 'findMax' is added as a new node
-- and linked, through a separator built with 'mkSeparator', to the tree
-- cluster 'findMax' paired it with.
buildTree :: (UndirectedGraph g , IsCluster c, Factor f, Ord c, Show c, Show f)
          => g Int c
          -> [Vertex]
          -> JTree c f
          -> JTree c f
buildTree g pending tree
    | null pending = tree
    | otherwise    = buildTree g pending' grownTree
  where
    (pending', (newVertex, newCluster), anchor) = findMax g pending tree
    separator = mkSeparator newCluster anchor
    -- The node is inserted first, then connected to its anchor cluster.
    withNode  = addNode newCluster newVertex [] [] tree
    grownTree = addSeparator anchor separator newCluster withNode
-- | Build the structure of a junction tree for a directed graph, without
-- putting any factor in it.
--
-- The graph is moralized ('moralGraph'), triangulated ('triangulate'), turned
-- into a cluster graph ('createClusterGraph') and finally into a tree
-- ('maximumSpanningTree').
--
-- The comparison function receives the moral graph once, before triangulation
-- starts; the resulting vertex ordering is then used for every elimination
-- step.
createUninitializedJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                                => (UndirectedSG () f -> Vertex -> Vertex -> Ordering)
                                -> g () f
                                -> JunctionTree f
createUninitializedJunctionTree cmp g = maximumSpanningTree clusterGraph
  where
    moralized = moralGraph g
    clusters  = triangulate (cmp moralized) moralized

    clusterGraph :: UndirectedSG Int Cluster
    clusterGraph = createClusterGraph g clusters
-- | Build a junction tree for a Bayesian network.
--
-- The tree structure comes from 'createUninitializedJunctionTree'; the factors
-- of the network are then stored in it with 'setFactors', and 'collect'
-- followed by 'distribute' are run on the result.
createJunctionTree :: (DirectedGraph g, FoldableWithVertex g, NamedGraph g, Factor f, Show f)
                   => (UndirectedSG () f -> Vertex -> Vertex -> Ordering)
                   -> BayesianNetwork g f
                   -> JunctionTree f
createJunctionTree cmp g =
    distribute . collect . setFactors g $ createUninitializedJunctionTree cmp g
-- | Posterior factor of a variable, computed from a junction tree.
--
-- The tree is searched (with 'findClusterFor') for a cluster containing the
-- variable; 'Nothing' is returned when there is none. Otherwise the factors
-- and evidence of that cluster are combined with the down message of its
-- parent (a scalar factor 1.0 when there is no parent or no message) and the
-- up messages of its children. Every variable except @v@ is marginalized out
-- and the result is divided by its 'factorNorm'.
posterior :: Factor f => JunctionTree f -> DV -> Maybe f
posterior t v = fmap clusterPosterior located
  where
    located = snd (traverseTree (findClusterFor v) Nothing t)

    clusterPosterior c = factorDivide unNormalized (factorNorm unNormalized)
      where
        NodeValue _ clusterFactors evidence = nodeValue t c
        fromParent   = maybe (factorFromScalar 1.0) id (nodeParent t c >>= downMessage t)
        fromChildren = map (upMessage t) (nodeChildren t c)
        everything   = fromParent : fromChildren ++ clusterFactors ++ evidence
        toEliminate  = nub (concatMap factorVariables everything) \\ [v]
        unNormalized = marginal everything toEliminate [v] []
-- | Tree traversal step looking for a cluster that contains a variable.
--
-- Stops the traversal with the current cluster when it contains @dv@;
-- otherwise skips the node and keeps the current state unchanged.
findClusterFor :: DV
               -> Maybe Cluster
               -> Cluster
               -> NodeValue f
               -> Action (Maybe Cluster) (NodeValue f)
findClusterFor dv s c@(Cluster sc) _
    | dv `Set.member` sc = Stop (Just c)
    | otherwise          = Skip s
-- | QuickCheck property: for a non-empty, connected directed graph that has
-- edges, the tree built by 'createUninitializedJunctionTree' (vertices ordered
-- by 'numberOfAddedEdges') satisfies 'junctionTreeProperty' from its root.
junctionTreeProperty_prop :: DirectedSG () CPT -> Property
junctionTreeProperty_prop g =
    usableGraph ==> junctionTreeProperty tree [] (root tree)
  where
    usableGraph = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g
    byAddedEdges ug = compare `on` numberOfAddedEdges ug
    tree = createUninitializedJunctionTree byAddedEdges g
-- | QuickCheck property: for a non-empty, connected directed graph that has
-- edges, the clusters found in the tree built by 'maximumSpanningTree' are
-- exactly the clusters produced by 'triangulate' (compared as sets).
junctionTreeAllClusters_prop :: DirectedSG () CPT -> Property
junctionTreeAllClusters_prop g =
    usableGraph ==>
        (expected `Set.isSubsetOf` actual && actual `Set.isSubsetOf` expected)
  where
    usableGraph = not (isEmpty g) && not (hasNoEdges g) && connectedGraph g

    moralized = moralGraph g
    byAddedEdges ug = compare `on` numberOfAddedEdges ug
    clusters = triangulate (byAddedEdges moralized) moralized

    clusterGraph = createClusterGraph g clusters :: UndirectedSG Int Cluster
    tree = maximumSpanningTree clusterGraph :: JunctionTree CPT

    expected = Set.fromList (map (vertexClusterToCluster g) clusters)
    actual   = Set.fromList (treeNodes tree)
-- | Check 'checkPath' for a cluster and, recursively, for all the clusters
-- below it in the tree.
--
-- @path@ lists the ancestors of @c@, nearest first; it is extended with @c@
-- before descending into the children.
junctionTreeProperty :: JTree Cluster CPT -> [Cluster] -> Cluster -> Bool
junctionTreeProperty t path c =
    checkPath c path && all (junctionTreeProperty t (c : path)) childClusters
  where
    childClusters = [ separatorChild t sep | sep <- nodeChildren t c ]
-- | Check a cluster against the list of its ancestors (nearest first).
--
-- For every ancestor, the variables the cluster shares with that ancestor must
-- be contained in each cluster of the path up to and including that ancestor.
-- An empty path is always accepted.
checkPath :: Cluster -> [Cluster] -> Bool
checkPath _ [] = True
checkPath (Cluster c) l =
    and
        [ all (shared `Set.isSubsetOf`) pathToAncestor
        | (shared, pathToAncestor) <- zip sharedWithAncestors pathPrefixes
        ]
  where
    unwrap (Cluster s) = s
    ancestorSets = map unwrap l
    sharedWithAncestors = map (Set.intersection c) ancestorSets
    -- Non-empty prefixes: [a1], [a1,a2], ... one per ancestor.
    pathPrefixes = drop 1 (inits ancestorSets)
-- | Start vertices of all the edges going into a vertex.
--
-- Fails (through 'fromJust') when 'ingoing' or any 'startVertex' lookup
-- returns 'Nothing'.
parents :: DirectedGraph g => g a b -> Vertex -> [Vertex]
parents g v = fromJust $ do
    incomingEdges <- ingoing g v
    mapM (startVertex g) incomingEdges
-- | End vertices of all the edges going out of a vertex.
--
-- Fails (through 'fromJust') when 'outgoing' or any 'endVertex' lookup
-- returns 'Nothing'.
children :: DirectedGraph g => g a b -> Vertex -> [Vertex]
children g v = fromJust $ do
    outgoingEdges <- outgoing g v
    mapM (endVertex g) outgoingEdges
-- | Connect pairwise all the given vertices of a graph.
--
-- An edge (with value @()@) is added for every ordered pair of different
-- vertices of the list that is not linked in the graph received as argument;
-- adjacency is always tested against that original graph, not against the
-- graph being built.
connectAllNonAdjacentNodes :: (Graph g)
                           => [Vertex]
                           -> g () b
                           -> g () b
connectAllNonAdjacentNodes nodes originGraph =
    foldl' insertEdge originGraph missingEdges
  where
    insertEdge acc e = addEdge e () acc
    missingEdges =
        [ edge x y
        | x <- nodes
        , y <- nodes
        , x /= y
        , not (isLinkedWithAnEdge originGraph x y)
        ]
-- | Connect pairwise the parents of a vertex.
--
-- The parents are looked up in the graph given as first argument, which is
-- also the graph the edges are added to. The third argument (the vertex value)
-- is ignored; it is there so that the function can be used as a folding step
-- over the vertices of a graph.
addMissingLinks :: DirectedGraph g => g () b -> Vertex -> b -> g () b
addMissingLinks g v _ = connectAllNonAdjacentNodes vertexParents g
  where
    vertexParents = parents g v
-- | Copy a graph into an undirected graph.
--
-- Every vertex is copied with its label and its value, then every edge of the
-- source graph is added with the value @()@.
--
-- Fails (through 'fromJust') when a vertex of the source graph has no label.
convertToUndirected :: (FoldableWithVertex g, Graph g, NamedGraph g, NamedGraph g',UndirectedGraph g')
                    => g () b
                    -> g' () b
convertToUndirected m = foldr copyEdge verticesOnly (allEdges m)
  where
    verticesOnly = foldlWithVertex' copyVertex emptyGraph m
    copyVertex acc v payload =
        addLabeledVertex (fromJust (vertexLabel m v)) v payload acc
    copyEdge e acc = addEdge e () acc
-- | Moral graph of a directed graph: the parents of every vertex are connected
-- pairwise, then the graph is converted to an undirected graph with
-- 'convertToUndirected'.
moralGraph :: (NamedGraph g, FoldableWithVertex g, DirectedGraph g)
           => g () b -> UndirectedSG () b
moralGraph g =
    convertToUndirected (foldlWithVertex' marryParents g g)
  where
    -- The parents must be read from the original graph @g@ and not from the
    -- accumulator: the accumulator already contains the edges added between
    -- the parents of previously visited vertices, and those edges would make
    -- co-parents look like parents of each other, adding edges that do not
    -- belong to the moral graph.
    marryParents acc v _ = connectAllNonAdjacentNodes (parents g v) acc