module Bayes.Inference
( getFactors, inferNetwork, pruneNetwork, infer, simulate, Query
, inferEvidence
, trimFor
, toEvidence
, toNetwork
) where
import Data.List
import Data.Maybe
import Bayes.EliminationOrdering
import Bayes.Evidence
import Bayes.Factor hiding (size)
import qualified Bayes.Factor as F
import Bayes.Network
import Bayes.NodeTypes
import Bayes.Probability
import qualified Data.Map as M
import qualified Data.Set as S
-- | Names of the variables (network nodes) whose posterior distributions
-- are requested from 'infer'.
type Query = S.Set String
-- | Posterior marginals of the query variables, computed by variable
-- elimination following @order@.
--
-- Whenever 'nextVariable' yields a variable @v@ that is still pending in the
-- query set, its marginal is computed with 'query' from the current
-- (partially eliminated) factor list. @v@ is then eliminated like any other
-- variable and dropped from the pending set. Query variables that
-- 'nextVariable' never yields are absent from the result.
posteriors :: Query -> [Factor] -> EliminationOrdering -> M.Map String [Probability]
posteriors qs list order
  -- Every query variable has been answered: stop here rather than
  -- eliminating the remaining variables, whose result would be discarded.
  | S.null qs = M.empty
  | otherwise =
      case nextVariable order list [] of
        Nothing -> M.empty
        Just (v, rest)
          | v `S.member` qs -> M.insert v ps results
          | otherwise -> results
          where
            ps = query v list rest
            results = posteriors (S.delete v qs) (eliminate list v) rest
-- | Marginal distribution of the single variable @q@.
--
-- Repeatedly asks 'nextVariable' for the next variable, passing @q@ as its
-- third argument (presumably the variables to keep; TODO confirm). Every
-- yielded variable other than @q@ is eliminated. When the ordering is
-- exhausted, the remaining factors are combined with 'mconcat', normalized,
-- and their values are converted to 'Probability'.
query :: String -> [Factor] -> EliminationOrdering -> [Probability]
query q = go
  where
    go fs ord =
      case nextVariable ord fs [q] of
        Nothing -> map fromDouble (values (normalize (mconcat fs)))
        Just (v, ord')
          -- defensive: never sum out the variable being queried
          | v == q -> go fs ord'
          | otherwise -> go (eliminate fs v) ord'
-- | Restrict a network to the given variables together with all of their
-- ancestors (transitively, via 'parentIds'). Names that 'findNode' cannot
-- locate in the network are silently skipped.
pruneNetwork :: Query -> Network a -> Network a
pruneNetwork qs nw = filterNodes (\n -> nodeId n `S.member` ancestry) nw
  where
    -- Depth-first worklist traversal; @seen@ doubles as the result set.
    ancestry = go S.empty (S.toList qs)
    go seen [] = seen
    go seen (x : todo)
      | S.member x seen = go seen todo
      | otherwise =
          maybe
            (go seen todo)
            (\node -> go (S.insert x seen) (parentIds node ++ todo))
            (findNode nw x)
-- | One factor per node of the network, in the order returned by 'nodes'.
getFactors :: Network a -> [Factor]
getFactors nw = map (nodeToFactor nw) (nodes nw)
-- | Build a factor from a flat table of probabilities.
--
-- Each node in @ns@ contributes one dimension @(nodeId, size)@, in list
-- order. The probabilities are converted with 'toDouble' and handed to
-- 'makeFactor' unchanged.
cptFactor :: [Node a] -> [Probability] -> Factor
cptFactor ns probs = makeFactor dims (map toDouble probs)
  where
    dims = [ (nodeId n, size n) | n <- ns ]
-- | Translate a node's definition into a factor.
--
-- * 'CPT': a table factor over the node's parents followed by the node.
-- * 'NoisyMax': delegated to 'mkNoisyOr'.
-- * 'NoisyAdder': delegated to 'mkNoisyAdder'.
--
-- In the last two cases the node and its parents are described as
-- @(nodeId, size)@ pairs.
nodeToFactor :: Network a -> Node a -> Factor
nodeToFactor nw n =
  case definition n of
    CPT xs -> cptFactor (parentNodes ++ [n]) xs
    NoisyMax str xs -> mkNoisyOr self parentDims str xs
    NoisyAdder dst ws xs -> mkNoisyAdder self parentDims dst ws xs
  where
    parentNodes = parents nw n
    parentDims = [ (nodeId p, size p) | p <- parentNodes ]
    self = (nodeId n, size n)
-- | Posterior marginals of the query variables given the evidence.
--
-- The steps are:
--
-- 1. Prune the network to the query variables, every variable carrying
--    evidence (hard or virtual), and all of their ancestors.
-- 2. Condition all remaining factors, plus one factor per virtual-evidence
--    entry, on the evidence.
-- 3. Answer the query variables that are not conditioned by
--    'posteriors', using the given elimination order.
--
-- Query variables with hard ('Index') evidence instead get a point-mass
-- distribution over their states.
infer :: Network () -> Evidence -> EliminationOrdering -> Query -> M.Map String [Probability]
infer nw0 ev vs qs0 = posteriors qs list vs `M.union` givens
  where
    evidence = fromEvidenceTp ev
    cs = indexMap ev
    -- Names used for conditioning; virtual evidence on @s@ appears as '#':s.
    conditioned = S.fromList (map fst cs)
    -- Real node names carrying evidence of either kind. Pruning must be
    -- seeded with these, not with the '#'-prefixed aliases (which are not
    -- network nodes). Otherwise a node with only virtual evidence is pruned
    -- together with its CPT, and the evidence no longer influences its
    -- ancestors.
    observed = S.fromList (map fst evidence)
    nw = pruneNetwork (qs0 `S.union` observed) nw0
    qs = qs0 S.\\ conditioned
    list = map (conditions cs) (getFactors nw ++ virtualFactors ev)
    givens = M.fromList (concatMap pointMass evidence)
    -- Observed query variables: probability 1 on the observed state index.
    pointMass (s, Index i)
      | s `S.member` qs0 =
          [(s, [ if a == i then 1 else 0 | a <- [0 .. n - 1] ])]
      where
        n = maybe 0 size (findNode nw s)
    pointMass _ = []
-- | One 'virtualFactor' per virtual-evidence entry, built from the entry's
-- variable name and the second components of its list. Hard ('Index')
-- evidence contributes nothing.
virtualFactors :: Evidence -> [Factor]
virtualFactors ev = concat [ asFactors entry | entry <- fromEvidenceTp ev ]
  where
    asFactors (s, tp) =
      case tp of
        Virtual ps -> [virtualFactor s (map snd ps)]
        Index _ -> []
-- | Attach the computed probabilities to each node's states.
--
-- The state labels are kept and paired positionally with the node's entry
-- in @result@. Because the pairing uses 'zipWith', the state list is
-- truncated to the shorter of the two lists. In particular, a node with no
-- entry in @result@ ends up with no states at all.
toNetwork :: Network () -> M.Map String [Probability] -> Network Probability
toNetwork nw result = mapNodes withPosterior nw
  where
    withPosterior n = n { states = zipWith relabel (states n) probs }
      where
        probs = M.findWithDefault [] (nodeId n) result
    relabel (label, _) p = (label, p)
-- | Convert computed posteriors into virtual evidence, one entry per node
-- of @nw@, via 'toNetwork' and 'getVirtuals'.
toEvidence :: Network () -> M.Map String [Probability] -> Evidence
toEvidence nw result = getVirtuals (toNetwork nw result)
-- | Flatten the evidence into @(variable, state index)@ pairs used for
-- conditioning.
--
-- Hard evidence maps to its observed index. Virtual evidence on @s@ becomes
-- the pseudo-variable @'#':s@ fixed at index 0. That is presumably the
-- auxiliary variable introduced by 'virtualFactor'; confirm there.
indexMap :: Evidence -> [(String, Int)]
indexMap ev = [ toEntry e | e <- fromEvidenceTp ev ]
  where
    toEntry (s, tp) =
      case tp of
        Index i -> (s, i)
        Virtual _ -> ('#' : s, 0)
-- | Run 'infer' and attach the resulting posteriors to the network's nodes
-- with 'toNetwork'.
inferNetwork :: Network () -> Evidence -> EliminationOrdering -> Query -> Network Probability
inferNetwork nw ev vs q =
  let posteriorMap = infer nw ev vs q
  in toNetwork nw posteriorMap
-- | Run 'infer' and repackage the resulting posteriors as virtual evidence
-- with 'toEvidence'.
inferEvidence :: Network () -> Evidence -> EliminationOrdering -> Query -> Evidence
inferEvidence nw ev vs q =
  let posteriorMap = infer nw ev vs q
  in toEvidence nw posteriorMap
-- | Trace the elimination of everything except @qv@, for diagnostics.
--
-- Before every step it prints four things:
--
-- * the total size of all factors,
-- * the number of factors,
-- * the number of distinct variables,
-- * the individual factor sizes.
--
-- It then prints the name of the variable being eliminated and recurses.
-- When 'nextVariable' yields nothing more, it prints the normalized
-- remaining factors.
simulate :: String -> EliminationOrdering -> [Factor] -> IO ()
simulate qv eo fs = do
  printFactors
  case nextVariable eo fs [qv] of
    Nothing -> print (map normalize fs)
    Just (x, eo') -> do
      putStrLn $ " => " ++ x
      simulate qv eo' (eliminate fs x)
  where
    ns = map F.size fs
    -- Distinct variable count via a Set: O(n log n) instead of nub's O(n^2).
    nVars = S.size (S.fromList (concatMap vars fs))
    printFactors = do
      putStrLn $ "total size: " ++ show (sum ns)
      putStrLn $ "#factors: " ++ show (length ns)
      putStrLn $ "#vars: " ++ show nVars
      print ns
-- | Keep only the evidence whose variable name is a node of @nw@. This
-- assumes 'filterEvidence' keeps the entries whose name satisfies the
-- predicate; TODO confirm.
--
-- The node names are collected into a 'S.Set' once, so each membership test
-- is O(log n) instead of a linear scan of the node list.
trimFor :: Network a -> Evidence -> Evidence
trimFor nw = filterEvidence (`S.member` known)
  where
    known = S.fromList (map nodeId (nodes nw))
-- | Turn every node's per-state probabilities into a 'virtual' evidence
-- entry, and combine the entries with 'mconcat'.
getVirtuals :: Network Probability -> Evidence
getVirtuals = mconcat . map nodeVirtual . nodes
  where
    nodeVirtual n = virtual n (map snd (states n))