{- | Junction Trees 

The Tree data structures do not work very well with message passing algorithms, so junction trees use
a different representation.
-}

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
module Bayes.FactorElimination.JTree(
    -- Types, construction functions and message passing algorithms for junction trees.
    -- The first entry must not be preceded by a comma.
      Cluster(..)
    , JTree(..)
    , JunctionTree(..)
    , setFactors
    , distribute 
    , collect
    , fromCluster
    , changeEvidence
    , nodeIsMemberOfTree
    , singletonTree
    , addNode 
    , addSeparator
    , leaves
    , nodeValue
    , NodeValue(..)
    , SeparatorValue(..)
    , downMessage
    , upMessage 
    , nodeParent 
    , nodeChildren
    , traverseTree
    , separatorChild
    , treeNodes
    , Action(..)
    ) where 

import qualified Data.Map as Map
import qualified Data.Tree as Tree
import Data.Maybe(fromJust,mapMaybe)
import qualified Data.Set as Set
import Data.Monoid
import Data.List((\\), intersect,partition, foldl')
import Bayes.PrivateTypes 
import Bayes.Factor
import Bayes

import Debug.Trace 
-- | Debugging helper: trace the given prefix followed by the shown value,
-- then return the value unchanged.
debug :: Show a => String -> a -> a
debug s a = trace message a
  where 
    message = s ++ " " ++ show a ++ "\n"

-- | Message sent from a child node toward its parent
type UpMessage a = a 
-- | Message sent from a parent node toward a child ('Nothing' until it has been set)
type DownMessage a = Maybe a

-- | Separator value
--
-- An 'Eq' instance is derived because 'JTree' derives 'Eq' and stores
-- separator values in one of its maps.
data SeparatorValue a = SeparatorValue !(UpMessage a) !(DownMessage a)
                      | EmptySeparator -- ^ Use to track the progress in the collect phase
                      deriving(Eq)

-- | An empty separator is displayed as an empty string. Otherwise the up message
-- is displayed, followed by the down message when there is one.
instance Show a => Show (SeparatorValue a) where 
    show EmptySeparator = ""
    show (SeparatorValue u Nothing) = "u(" ++ show u ++ ")"
    show (SeparatorValue u (Just d)) = "u(" ++ show u ++ ") d(" ++ show d ++ ")"

-- | Factor part of a node value
type FactorValue a = a 
-- | Evidence part of a node value
type EvidenceValue a = a

-- | Value stored in a node: a factor value and an evidence value
data NodeValue a = NodeValue !(FactorValue a) !(EvidenceValue a) deriving(Eq)

-- | Displayed as @f(factor) e(evidence)@
instance Show a => Show (NodeValue a) where 
    show (NodeValue factor evidence) = 
        concat ["f(", show factor, ") e(", show evidence, ")"]

-- | Junction tree.
-- 'c' is the identifier type shared by nodes and separators (for instance a set of 'DV').
-- 'f' is the type of the values stored in nodes ('NodeValue') and separators ('SeparatorValue').
-- Nodes and separators alternate : the children of a node are separators and the
-- child of a separator is a node.
data JTree  c f = JTree {  root :: !c
                        -- | Leaves of the tree
                        ,  leavesSet :: !(Set.Set c)
                        -- | The children of a node are separators
                        ,  childrenMap :: !(Map.Map c [c])
                        -- | Parent of a node (the separator above it)
                        ,  parentMap :: !(Map.Map c c)
                        -- | Parent of a separator (the node above it)
                        ,  separatorParentMap :: !(Map.Map c c)
                        -- | The child of a separator is a node
                        ,  separatorChildMap :: !(Map.Map c c)
                        -- | Values for nodes
                        ,  nodeValueMap :: !(Map.Map c (NodeValue f))
                        -- | Values for separators
                        ,  separatorValueMap :: !(Map.Map c (SeparatorValue f))
                        } deriving(Eq)

-- | Create a singleton tree with just one root node.
-- The root is registered with 'addNode', so it is also recorded as a leaf.
singletonTree :: Ord c 
              => c -- ^ Root node
              -> a -- ^ Factor value of the root
              -> a -- ^ Evidence value of the root
              -> JTree c a
singletonTree r factorValue evidenceValue = 
    let t = JTree r Set.empty Map.empty Map.empty Map.empty Map.empty Map.empty Map.empty
    in addNode r factorValue evidenceValue t

-- | List of the leaves of the tree (in ascending order of identifier)
leaves :: JTree c a -> [c]
leaves tree = Set.toList (leavesSet tree)

-- | Identifiers of all the nodes of the tree (separators are not included)
treeNodes :: JTree c a -> [c]
treeNodes tree = Map.keys (nodeValueMap tree)

-- | Value of a node.
-- The node MUST be in the tree : asking for the value of an unknown node
-- is a programming error and raises an 'error' naming this function.
nodeValue :: Ord c => JTree c a -> c -> NodeValue a 
nodeValue t e = 
    case Map.lookup e (nodeValueMap t) of 
      Just v -> v 
      Nothing -> error "nodeValue: the node is not in the junction tree"

-- | Replace (or create) the value attached to a node
setNodeValue :: Ord c => c -> NodeValue a -> JTree c a -> JTree c a
setNodeValue c v t = t { nodeValueMap = updatedValues }
  where 
    updatedValues = Map.insert c v (nodeValueMap t)

-- | Parent of a node (the separator above it).
-- 'Nothing' when the node has no parent recorded in the tree.
nodeParent :: Ord c => JTree c a -> c -> Maybe c 
nodeParent tree node = node `Map.lookup` parentMap tree

-- | Value of a separator.
-- The separator MUST be in the tree : asking for the value of an unknown
-- separator is a programming error and raises an 'error' naming this function.
separatorValue :: Ord c => JTree c a -> c -> SeparatorValue a 
separatorValue t e = 
    case Map.lookup e (separatorValueMap t) of 
      Just v -> v 
      Nothing -> error "separatorValue: the separator is not in the junction tree"

-- | Parent of a separator (the node above it).
-- The separator MUST be in the tree : an unknown separator is a programming
-- error and raises an 'error' naming this function.
separatorParent :: Ord c => JTree c a -> c -> c 
separatorParent t e = 
    case Map.lookup e (separatorParentMap t) of 
      Just p -> p 
      Nothing -> error "separatorParent: the separator is not in the junction tree"

-- | Up message stored in a separator.
-- Fails with an error when the separator is still empty.
upMessage :: Ord c => JTree c a -> c -> a
upMessage t c = extract (separatorValue t c)
  where 
    extract (SeparatorValue up _) = up
    extract EmptySeparator = error "Trying to get an up message on an empty seperator ! Should never occur !"

-- | Down message stored in a separator ('Nothing' when it has not been set yet).
-- Fails with an error when the separator is still empty.
downMessage :: Ord c => JTree c a -> c -> Maybe a 
downMessage t c = 
    case separatorValue t c of 
      EmptySeparator -> error "Trying to get a down message on an empty separator ! Should never occur !"
      SeparatorValue _ down -> down

-- | Separators which are the children of a node.
-- The list is empty when the node has no recorded children.
nodeChildren :: Ord c => JTree c a -> c -> [c]
nodeChildren t e = Map.findWithDefault [] e (childrenMap t)

-- | Return the child of a separator (a node).
-- The separator MUST be in the tree : an unknown separator is a programming
-- error and raises an 'error' naming this function.
separatorChild :: Ord c => JTree c a -> c -> c 
separatorChild t e = 
    case Map.lookup e (separatorChildMap t) of 
      Just child -> child 
      Nothing -> error "separatorChild: the separator is not in the junction tree"

-- | True when the identifier is a node of the tree (it has a node value)
nodeIsMemberOfTree :: Ord c => c -> JTree c a -> Bool 
nodeIsMemberOfTree c t = c `Map.member` nodeValueMap t

-- | Add a separator between two nodes.
-- The nodes MUST already be in the tree.
--
-- The separator becomes a child of the origin node and the parent of the
-- destination node. It starts as an 'EmptySeparator', and the origin node is
-- removed from the leaves.
addSeparator :: (Ord c) 
             => c -- ^ Origin node 
             -> c -- ^ Separator
             -> c -- ^ Destination node 
             -> JTree c a -- ^ Current tree 
             -> JTree c a -- ^ Modified tree 
addSeparator node sep dest t = 
    -- 'Map.insertWith' is used because the primed variant no longer exists in
    -- recent versions of containers. The new separator is prepended to the
    -- existing children, as before.
    t { childrenMap = Map.insertWith (++) node [sep] (childrenMap t)
      , separatorChildMap = Map.insert sep dest (separatorChildMap t)
      , separatorValueMap = Map.insert sep EmptySeparator (separatorValueMap t)
      , leavesSet = Set.delete node (leavesSet t) 
      , parentMap = Map.insert dest sep (parentMap t)
      , separatorParentMap = Map.insert sep node (separatorParentMap t)
      }

-- | Add a new node.
-- The node is recorded with its value and added to the leaves. It stops being
-- a leaf when 'addSeparator' is used with this node as origin.
addNode :: (Ord c) 
        => c -- ^ Node
        -> a -- ^ Factor value 
        -> a -- ^ Evidence value
        -> JTree c a -- ^ Current tree
        -> JTree c a -- ^ Modified tree
addNode node factorValue evidenceValue t = 
   t { nodeValueMap = Map.insert node (NodeValue factorValue evidenceValue) (nodeValueMap t)
     , leavesSet = Set.insert node (leavesSet t) 
     }

-- | Update the up message of a separator.
-- An existing down message is kept. An empty separator becomes a separator
-- with the new up message and no down message.
updateUpMessage :: Ord c 
                => Maybe c -- ^ Separator node to update (if any : none for root node)
                -> a -- ^ New value
                -> JTree c a -- ^ Old tree
                -> JTree c a
updateUpMessage Nothing _ t = t
updateUpMessage (Just sep) newval t = 
    let newSepValue = case separatorValue t sep of 
                        EmptySeparator -> SeparatorValue newval Nothing
                        SeparatorValue _ down -> SeparatorValue newval down 
    in t {separatorValueMap = Map.insert sep newSepValue (separatorValueMap t)}

-- | Update the down message of a separator.
-- The up message is kept. The separator must already contain an up message :
-- setting a down message on an empty separator is an error.
updateDownMessage :: Ord c 
                  => c -- ^ Separator node to update
                  -> a -- ^ New value
                  -> JTree c a -- ^ Old tree
                  -> JTree c a
updateDownMessage sep newval t = 
    let newSepValue = case separatorValue t sep of 
                        EmptySeparator -> error "Can't set a down message on an empty separator"
                        SeparatorValue up _ -> SeparatorValue up (Just newval)
    in t {separatorValueMap = Map.insert sep newSepValue (separatorValueMap t)}


{- Message passing algorithms -}

-- | Functions used to generate new messages.
-- 'f' is the type of the messages and node values, 'c' is the type of the
-- identifiers. The functional dependency means that 'f' determines 'c'.
class Message f c | f -> c where
    -- | Generate a new message from the received ones.
    -- The arguments are the incoming messages, the value of the node generating
    -- the message and the identifier of the separator receiving it.
    newMessage :: [f] -> NodeValue f -> c -> f 

-- | True when the separator has already received a message
-- (it is no longer an 'EmptySeparator')
separatorInitialized :: SeparatorValue a -> Bool
separatorInitialized sep = 
    case sep of 
      EmptySeparator -> False 
      SeparatorValue _ _ -> True

-- | Check that every separator of the list is initialized.
-- True for an empty list.
allSeparatorsHaveReceivedAMessage :: Ord c
                                  => JTree c a -- ^ Tree
                                  -> [c] -- ^ Separators
                                  -> Bool 
allSeparatorsHaveReceivedAMessage t seps = 
  all (separatorInitialized . separatorValue t) seps

-- | Update the up separator by sending a message
-- But only if all the down separators have received a message.
-- The tree is returned unchanged when a child separator is still empty or
-- when the node has no parent (root).
updateUpSeparator :: (Message a c, Ord c) 
                  => JTree c a 
                  -> c -- ^ Node generating the new upMessage
                  -> JTree c a 
updateUpSeparator t h  = 
    let seps = nodeChildren t h
    in case allSeparatorsHaveReceivedAMessage t seps of 
         False -> t 
         True -> let incomingMessages = map (upMessage t) seps
                     currentValue = nodeValue t h
                     destinationNode = nodeParent t h
                 in case destinationNode of 
                      Nothing -> t -- When root
                      Just p -> let generatedMessage = newMessage incomingMessages currentValue p
                                in updateUpMessage destinationNode generatedMessage t

-- | Update the down separator by sending a message.
-- The message is generated from the up messages of the other child separators
-- of the node and, when there is one, from the down message of the separator
-- above the node.
updateDownSeparator :: (Message a c, Ord c) 
                    => c -- ^ Node generating the message 
                    -> JTree c a 
                    -> c -- ^ Child receiving the message
                    -> JTree c a 
updateDownSeparator node t child  = 
    let incomingMessagesFromBelow = map (upMessage t) (nodeChildren t node \\ [child])
        -- Nothing when the node has no parent or when the parent separator has no down message yet
        messageFromAbove = downMessage t =<< (nodeParent t node)
        incomingMessages = maybe incomingMessagesFromBelow (\x -> x:incomingMessagesFromBelow) messageFromAbove
        currentValue = nodeValue t node
        generatedMessage = newMessage incomingMessages currentValue child
    in updateDownMessage child generatedMessage t

-- | Remove the duplicates of a list. The result is sorted in ascending order.
unique :: Ord c => [c] -> [c]
unique xs = Set.toAscList (Set.fromList xs)

-- | Phase of a tree traversal : the identifiers being processed are either
-- nodes ('ACluster') or separators ('ASeparator')
data TraversalState = ACluster | ASeparator

-- | Collect phase : messages are sent upward, starting from the leaves.
-- The tree depth may be different for different leaves.
collect :: (Ord c, Message a c) 
        => JTree c a 
        -> JTree c a
collect tree = _collect ACluster startingNodes tree
  where 
    startingNodes = leaves tree

-- | Worker of the collect phase.
-- In the node phase, each node of the list tries to send its up message,
-- then the traversal continues with the parent separators of these nodes.
-- In the separator phase, the traversal continues with the parent nodes of
-- the separators, without duplicates. It stops when the list is empty.
_collect :: (Ord c, Message a c) 
         => TraversalState -- ^ Node processing phase or separator processing phase
         -> [c]
         -> JTree c a -- ^ Tree
         -> JTree c a -- ^ Modified tree 
_collect _ [] t = t
_collect ACluster l t = 
    let newTree = foldl' updateUpSeparator t l
    in _collect ASeparator (mapMaybe (nodeParent t) l) newTree
_collect ASeparator l t = _collect ACluster (unique . map (separatorParent t) $ l) t
-- | Distribute phase : messages are sent downward, starting from the root
distribute :: (Ord c, Message a c)
           => JTree c a 
           -> JTree c a
distribute tree = _distribute ACluster tree start
  where 
    start = root tree

-- | Worker of the distribute phase.
-- For a node, a down message is sent to each child separator, then the
-- distribution continues through each of these separators.
-- For a separator, the distribution continues with its child node.
_distribute :: (Ord c, Message a c)
            => TraversalState -- ^ Node processing phase or separator processing phase
            -> JTree c a 
            -> c -- ^ Destination of the distribute
            -> JTree c a 
_distribute ACluster t node = 
    let children = nodeChildren t node
        newTree = foldl' (updateDownSeparator node) t $ children
    in foldl' (_distribute ASeparator) newTree children
_distribute ASeparator t node = _distribute ACluster t (separatorChild t node)


{- Factors and evidence modifications -}


-- | This class is used to check if evidence or a factor is relevant
-- for a cluster
class IsCluster c where 
  -- | Evidence (from the given list) which is relevant for the cluster
  overlappingEvidence :: c -> [DVI Int] -> [DVI Int]
  -- | Discrete variables of the cluster
  clusterVariables :: c -> [DV]
  -- | Build a separator from two clusters
  mkSeparator :: c -> c -> c

-- | A list of discrete variables used as a cluster.
-- The separator of two clusters is their intersection.
instance IsCluster [DV] where 
  overlappingEvidence vars = filter ((`elem` vars) . instantiationVariable)
  clusterVariables vars = vars
  mkSeparator left right = left `intersect` right

-- | Result of the function applied to each node during 'traverseTree'.
-- 's' is the traversal state and 'a' the new value for the node.
data Action s a = Skip !s -- ^ Keep the node value and continue with the children
                | ModifyAndStop !s !a -- ^ Change the node value and do not visit the children
                | Modify !s !a -- ^ Change the node value and continue with the children
                | Stop !s -- ^ Keep the node value and do not visit the children

-- | Traverse a tree from its root and modify it.
-- The result is the modified tree and the final state.
traverseTree :: Ord c 
             => (s -> c -> NodeValue f -> Action s (NodeValue f)) -- ^ Modification function
             -> s -- ^ Current state
             -> JTree c f -- ^ Input tree
             -> (JTree c f,s)
traverseTree action state t = 
    let start = root t
        initial = (t,state)
    in _traverseTree True action initial start

-- | Worker of 'traverseTree'.
-- The boolean is True when the current identifier is a node and False when it
-- is a separator. A separator is not processed : the traversal just continues
-- with its child node.
-- For a node, the action decides if the value is modified and if the children
-- are visited. The state returned by the action is always propagated.
_traverseTree :: Ord c 
              => Bool -- ^ True if the current identifier is a node
              -> (s -> c -> NodeValue f -> Action s (NodeValue f)) -- ^ Modification function
              -> (JTree c f,s) -- ^ Current tree and state
              -> c -- ^ Current node or separator
              -> (JTree c f,s)
_traverseTree False action (t,state) current  = _traverseTree True action (t,state)  (separatorChild t current) 
_traverseTree True action (t,state) current = 
  case action state current (nodeValue t current) of 
     Stop newState -> (t,newState)
     -- The state returned by the action is kept, like for 'Stop'
     ModifyAndStop newState newValue -> (setNodeValue current newValue t, newState) 
     Skip newState -> foldl' (_traverseTree False action) (t,newState) (nodeChildren t current)
     Modify newState newValue -> 
         let newTree = setNodeValue current newValue t 
         in foldl' (_traverseTree False action) (newTree,newState) (nodeChildren newTree current)

-- | Apply a function to the value of every node.
-- The function also receives the identifier of the node.
mapWithCluster :: Ord c 
               => (c -> NodeValue f -> NodeValue f)
               -> JTree c f 
               -> JTree c f        
mapWithCluster f t = t { nodeValueMap = updatedValues }
  where 
    updatedValues = Map.mapWithKey f (nodeValueMap t)

-- | Set the factors in the tree.
-- The vertex values of the network are distributed to the nodes during
-- a traversal of the tree (see 'updateFactor').
setFactors :: (Graph g, Factor f, Show f, IsCluster c, Ord c)
           => BayesianNetwork g f 
           -> JTree c f 
           -> JTree c f
setFactors g t = 
  let factors = allVertexValues g 
  in fst . traverseTree updateFactor factors $ t 

-- | Update factors in a cluster.
-- The factors whose variables are all contained in the cluster are attributed
-- to the cluster (their product becomes the new factor value). The other
-- factors are passed on as the new state. The evidence value is unchanged.
-- The traversal stops when there are no more factors to attribute.
updateFactor :: (Factor f, IsCluster c) 
             => [f] -- ^ Remaining list of factors to attribute
             -> c -- ^ Current cluster
             -> NodeValue f -- ^ Current value
             -> Action [f] (NodeValue f)
updateFactor lf c (NodeValue _ evidence) 
  | null lf = Stop lf
  | otherwise =
      let isFactorIncluded l = all (`elem` clusterVariables c) (factorVariables l)
          (attributedFactors,remainingFactors) = partition isFactorIncluded lf 
      in Modify remainingFactors  (NodeValue (factorProduct attributedFactors) evidence)

-- | Change evidence in the network.
-- All the separators are reset to empty, the new evidence is set in the nodes,
-- then the collect and distribute phases are run again.
changeEvidence :: (IsCluster c, Ord c, Factor f, Message f c)
               => [DVI Int] -- ^ Evidence
               -> JTree c f 
               -> JTree c f 
changeEvidence e t = distribute (collect withEvidence)
  where 
    resetTree = t { separatorValueMap = Map.map (const EmptySeparator) (separatorValueMap t)}
    withEvidence = fst (traverseTree changeNodeEvidence e resetTree)

-- | Change the evidence of a node.
-- The evidence overlapping the cluster becomes the new evidence value of the
-- node (the factor value is unchanged). The remaining evidence is passed on as
-- the new state. The traversal stops when there is no more evidence.
changeNodeEvidence :: (IsCluster c, Factor f) 
                   => [DVI Int] -- ^ Evidence
                   -> c  -- ^ Current cluster
                   -> NodeValue f -- ^ Current value
                   -> Action [DVI Int] (NodeValue f)
changeNodeEvidence [] _ _ = Stop []
changeNodeEvidence e c (NodeValue f _) = 
  let oe = overlappingEvidence c e
      ns = e \\ oe
      newEvidence = factorProduct $ map factorFromInstantiation oe
  in Modify ns (NodeValue f newEvidence)

-- | Cluster of discrete variables.
-- Discrete variables are used instead of vertices because the
-- factors are using 'DV' and we need to find
-- which factors must be contained in a given cluster.
newtype Cluster = Cluster (Set.Set DV) deriving(Eq,Ord)

-- | A cluster is displayed as the list of its variables
instance Show Cluster where 
  show = show . fromCluster

-- | Variables of a cluster as a list
fromCluster :: Cluster -> [DV]
fromCluster (Cluster variables) = Set.toList variables

-- | The new message is the product of the factor value, the evidence value and
-- the incoming messages, projected to the variables of the destination cluster
instance Factor f => Message f Cluster where 
  newMessage input (NodeValue f e) dv = factorProjectTo targetVariables combined
    where 
      targetVariables = fromCluster dv
      combined = factorProduct (f:e:input)

-- | Junction tree whose nodes and separators are identified by a 'Cluster'
type JunctionTree f = JTree Cluster f


{- Implement the show function to see the structure of the tree
(without the values) -}

-- | Tag an identifier as a node ('N') or as a separator ('S')
data NodeKind c = N !c | S !c

-- | Label of a tree element : the name alone or, when the flag is True,
-- the name followed by @=@ and the displayed value
label :: Show a => Bool -> String -> a -> String
label withValue name value 
  | withValue = name ++ "=" ++ show value
  | otherwise = name

-- | Convert the JTree into a tree of string
-- using the cluster.
-- The conversion starts from the root. The labels are built with 'label'.
toTree :: (Ord c, Show c, Show a) 
       => Bool -- ^ True if the data must be displayed
       -> JTree c a 
       -> Tree.Tree String
toTree d t = 
    let r = root t
        v = nodeValue t r
        nodec = map S (nodeChildren t r) -- Node children are separators
    in Tree.Node (label d (show r) v) (_toTree d t nodec)

-- | Convert a list of nodes and separators into a forest of strings.
-- A separator is displayed between @<@ and @>@.
_toTree :: (Ord c, Show c, Show a) 
        => Bool -- ^ True if the data must be displayed
        -> JTree c a 
        -> [NodeKind c] 
        -> [Tree.Tree String]
_toTree _ _ [] = []
_toTree d t ((N h):l) = 
    let nodec = map S (nodeChildren t h) -- Node children are separators
        v = nodeValue t h
    in Tree.Node (label d (show h) v) (_toTree d t nodec):_toTree d t l
_toTree d t ((S h):l) = 
    let separatorc = [N $ separatorChild t h] -- separator child is a node
        v = separatorValue t h
    in Tree.Node (label d ("<" ++ show h ++ ">") v ) (_toTree d t separatorc):_toTree d t l

-- | Display the structure of the tree (without the values)
instance (Ord c, Show c, Show a) => Show (JTree c a) where 
    show t = Tree.drawTree (toTree False t)

-- | Draw the tree. The values are displayed when the flag is True.
displayTree :: (Ord c, Show c, Show a) => Bool -> JTree c a -> String
displayTree b t = Tree.drawTree (toTree b t)


{- Debug functions for tests -}


