-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Bayes.Inference
   ( getFactors, inferNetwork, pruneNetwork, infer, simulate, Query
   , inferEvidence
   , trimFor
   , toEvidence
   , toNetwork
   ) where

import Data.List
import Data.Maybe
import Bayes.EliminationOrdering
import Bayes.Evidence
import Bayes.Factor hiding (size)
import qualified Bayes.Factor as F
import Bayes.Network
import Bayes.NodeTypes
import Bayes.Probability
import qualified Data.Map as M
import qualified Data.Set as S

-- | The set of query variables, identified by name. These are the variables
-- for which a posterior distribution is requested.
type Query = S.Set String

-- | Compute the posterior distribution of every remaining query variable.
--
-- Takes the set of query variables @qs@ that still need an answer, the
-- current list of factors, and the elimination ordering. Variables are
-- obtained one at a time with 'nextVariable'. If the variable is a query
-- variable, its distribution is computed by 'query' on the current factors
-- and the remaining ordering; after that the variable is eliminated from the
-- factor list and the process continues with the rest of the ordering.
--
-- The result maps each answered query variable to its list of probabilities.
posteriors :: Query -> [Factor] -> EliminationOrdering -> M.Map String [Probability]
posteriors qs list order
   -- Entries are only ever inserted for members of 'qs', so once the set is
   -- empty the result is necessarily empty: stop here instead of eliminating
   -- the remaining variables of the ordering for nothing.
   | S.null qs = M.empty
   | otherwise =
        case nextVariable order list [] of
           Nothing -> M.empty
           Just (v, rest)
              | v `S.member` qs -> M.insert v ps results
              | otherwise       -> results
            where
              -- distribution of v, computed before v is eliminated
              ps      = query v list rest
              -- answers for the other query variables, with v eliminated
              results = posteriors (S.delete v qs) (eliminate list v) rest

-- | Distribution of a single query variable @q@.
--
-- Variables are requested from 'nextVariable' (passing @q@ in its third
-- argument) and eliminated from the factor list, except @q@ itself, which is
-- never eliminated. When 'nextVariable' yields nothing more, the remaining
-- factors are combined with 'mconcat', normalized, and their values are
-- converted with 'fromDouble'.
query :: String -> [Factor] -> EliminationOrdering -> [Probability]
query q factors order = maybe answer continue (nextVariable order factors [q])
 where
   -- no variable left: multiply out what remains and normalize
   answer = map fromDouble . values . normalize $ mconcat factors

   continue (v, rest)
      | v == q    = query q factors rest
      | otherwise = query q (eliminate factors v) rest

-- | Filter the network (with 'filterNodes') on the query variables together
-- with everything reachable from them by repeatedly following 'parentIds'.
-- Names that cannot be found in the network are ignored.
pruneNetwork :: Query -> Network a -> Network a
pruneNetwork qs nw = filterNodes relevant nw
 where
   relevant n = nodeId n `S.member` reachable

   reachable = visit S.empty (S.toList qs)

   -- worklist traversal: 'seen' holds the names already accepted
   visit seen todo =
      case todo of
         [] -> seen
         x:rest
            | x `S.member` seen -> visit seen rest
            | otherwise ->
                 maybe (visit seen rest)
                       (\n -> visit (S.insert x seen) (parentIds n ++ rest))
                       (findNode nw x)

-- | One factor per node of the network, in the order given by 'nodes'.
getFactors :: Network a -> [Factor]
getFactors nw = map (nodeToFactor nw) (nodes nw)

-- | Make a factor from a list of nodes, each paired with its size, and a
-- list of probabilities that are converted with 'toDouble'.
cptFactor :: [Node a] -> [Probability] -> Factor
cptFactor ns probs = makeFactor dims (map toDouble probs)
 where
   dims = [ (nodeId n, size n) | n <- ns ]

-- | Build the factor of a node from its definition. A 'CPT' becomes a factor
-- over the node's parents followed by the node itself; 'NoisyMax' and
-- 'NoisyAdder' definitions are delegated to 'mkNoisyOr' and 'mkNoisyAdder'.
nodeToFactor :: Network a -> Node a -> Factor
nodeToFactor nw n =
   case definition n of
      CPT probs            -> cptFactor (directParents ++ [n]) probs
      NoisyMax str xs      -> mkNoisyOr self parentDims str xs
      NoisyAdder dst ws xs -> mkNoisyAdder self parentDims dst ws xs
 where
   directParents = parents nw n
   -- (name, size) of the node itself and of each of its parents
   self       = (nodeId n, size n)
   parentDims = [ (nodeId p, size p) | p <- directParents ]

-- | Infer the distributions of the query variables given the evidence.
--
-- The network is first pruned (see 'pruneNetwork') to the query variables and
-- the names in the evidence index-map. The factors of the pruned network,
-- plus one factor for each 'Virtual' evidence entry, are passed through
-- 'conditions' with the index-map and handed to 'posteriors'.
--
-- Query variables that occur in the index-map are not passed to 'posteriors'.
-- For those observed with an 'Index', the result contains a distribution that
-- is 1 at the observed index and 0 elsewhere.
infer :: Network () -> Evidence -> EliminationOrdering -> Query -> M.Map String [Probability]
infer nw0 ev vs qs0 = posteriors qs list vs `M.union` givens
 where
   -- 'cs' is the index-map from the evidence
   cs = indexMap ev

   -- names in the index-map, built once and shared below
   observed = S.fromList (map fst cs)

   nw   = pruneNetwork (qs0 `S.union` observed) nw0
   qs   = qs0 S.\\ observed
   list = map (conditions cs) (getFactors nw ++ virtualFactors ev)

   -- distributions for query variables that were observed directly
   givens = M.fromList $ concatMap f $ fromEvidenceTp ev
    where
      -- 'S.member' rather than the Foldable 'elem', which is linear on a Set
      f (s, Index i) | s `S.member` qs0 =
         [(s, [ if a == i then 1 else 0 | a <- take n [0..] ])]
       where
         -- number of entries; 0 when the node is not in the pruned network
         n = maybe 0 size (findNode nw s)
      f _ = []

   virtualFactors :: Evidence -> [Factor]
   virtualFactors = concatMap f . fromEvidenceTp
    where
      f (s, Virtual ps) = [virtualFactor s (map snd ps)]
      f (_, Index _)    = []


-- | Attach inferred probabilities to the states of every node. The
-- probabilities are looked up by node id and paired positionally with the
-- node's states; a node without an entry in the map ends up with no states.
toNetwork :: Network () -> M.Map String [Probability] -> Network Probability
toNetwork nw result = mapNodes annotate nw
 where
   annotate n = n { states = [ (s, p) | ((s, _), p) <- zip (states n) probs ] }
    where
      probs = M.findWithDefault [] (nodeId n) result

-- | Turn inferred probabilities into 'Evidence': the probabilities are first
-- put into the network with 'toNetwork', and the result is converted with
-- 'getVirtuals'.
toEvidence :: Network () -> M.Map String [Probability] -> Evidence
toEvidence nw result = getVirtuals (toNetwork nw result)

-- | The index-map of the evidence: one (name, index) pair per entry. An
-- 'Index' entry keeps its name and index; a 'Virtual' entry is listed under
-- its name prefixed with @#@, at index 0.
indexMap :: Evidence -> [(String, Int)]
indexMap ev = map (uncurry entry) (fromEvidenceTp ev)
 where
   entry name (Index i)   = (name, i)
   entry name (Virtual _) = ('#' : name, 0)

-- | Run 'infer' and store the resulting probabilities in the network (see
-- 'toNetwork').
inferNetwork :: Network () -> Evidence -> EliminationOrdering -> Query -> Network Probability
inferNetwork nw ev vs = toNetwork nw . infer nw ev vs

-- | Run 'infer' and return the resulting probabilities as 'Evidence' (see
-- 'toEvidence').
inferEvidence :: Network () -> Evidence -> EliminationOrdering -> Query -> Evidence
inferEvidence nw ev vs = toEvidence nw . infer nw ev vs

-- | Print a trace of the elimination process for one query variable.
--
-- At every step the total factor size, the number of factors, the number of
-- distinct variables and the list of factor sizes are printed, followed by
-- the variable chosen by 'nextVariable'. When no variable is left, the
-- normalized factors are printed.
simulate :: String -> EliminationOrdering -> [Factor] -> IO ()
simulate qv eo fs = do
   report
   maybe finish step (nextVariable eo fs [qv])
 where
   finish = print (map normalize fs)

   step (x, eo') =
      putStrLn ("   => " ++ x) >> simulate qv eo' (eliminate fs x)

   sizes        = map F.size fs
   distinctVars = nub (concatMap vars fs)

   report = do
      mapM_ putStrLn
         [ "total size: " ++ show (sum sizes)
         , "#factors: " ++ show (length sizes)
         , "#vars: " ++ show (length distinctVars)
         ]
      print sizes


-- | Remove all 'Evidence' that cannot be fed to the given 'Network': the
-- evidence is filtered (with 'filterEvidence') on membership of the set of
-- node ids of the network.
trimFor :: Network a -> Evidence -> Evidence
trimFor nw = filterEvidence (`S.member` known)
 where
   -- built once, so that each membership test is logarithmic instead of a
   -- linear scan over all nodes
   known = S.fromList (map nodeId (nodes nw))

-- | Combine, over all nodes of the network, the evidence obtained by applying
-- 'virtual' to the node and the second components of its states.
getVirtuals :: Network Probability -> Evidence
getVirtuals nw =
   mconcat [ virtual node (map snd (states node)) | node <- nodes nw ]