{-|
Module      : Control.Monad.Bayes.Enumerator
Description : Exhaustive enumeration of discrete random variables
Copyright   : (c) Adam Scibior, 2015-2020
License     : MIT
Maintainer  : leonhard.markert@tweag.io
Stability   : experimental
Portability : GHC

-}

module Control.Monad.Bayes.Enumerator (
    Enumerator,
    logExplicit,
    explicit,
    evidence,
    mass,
    compact,
    enumerate,
    expectation,
    normalForm
            ) where

import Data.AEq (AEq, (===), (~==))
import Control.Applicative (Applicative, Alternative)
import Control.Monad (MonadPlus)
import Control.Arrow (second)
import qualified Data.Map as Map
import qualified Data.Map.Strict as MapStrict
import qualified Data.Vector.Generic as V
import Numeric.Log as Log
import Control.Monad.Trans.Writer
import Data.Monoid
import Data.Maybe

import Control.Monad.Bayes.Class


-- | An exact inference transformer that integrates discrete random
-- variables by enumerating all execution paths.
--
-- Each execution path is a value paired with the product of the weights
-- accumulated along it (the 'Product' writer monoid), with weights kept in
-- the log domain.
newtype Enumerator a
  = Enumerator (WriterT (Product (Log Double)) [] a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , Alternative
    , MonadPlus
    )

-- | Discrete sampling primitives branch into one execution path per
-- outcome, weighted by that outcome's probability. Continuous sampling
-- ('random') is not supported and raises an error when evaluated.
instance MonadSample Enumerator where
  random =
    error "Infinitely supported random variables not supported in Enumerator"
  bernoulli p = fromList [(True, toLog p), (False, toLog (1 - p))]
    where
      -- lift a plain probability into the log domain
      toLog x = Exp (log x)
  categorical v =
    fromList [(i, Exp (log w)) | (i, w) <- zip [0 ..] (V.toList v)]

-- | Conditioning does not branch: it yields a single path carrying the
-- given factor, so binding through it multiplies the current path weight.
instance MonadCond Enumerator where
  score w = fromList (pure ((), w))

-- | 'Enumerator' supports both sampling and conditioning.
instance MonadInfer Enumerator

-- | Construct an 'Enumerator' from values paired with their log-domain
-- weights; every pair becomes one execution path.
fromList :: [(a, Log Double)] -> Enumerator a
fromList pairs = Enumerator (WriterT [(x, Product w) | (x, w) <- pairs])

-- | Returns the posterior as a list of value-weight pairs, one per
-- execution path, without any post-processing such as normalization or
-- aggregation. Weights stay in the log domain.
logExplicit :: Enumerator a -> [(a, Log Double)]
logExplicit (Enumerator m) = map (second getProduct) $ runWriterT m

-- | Same as 'logExplicit', only weights are converted from the log domain
-- to plain 'Double's. No normalization or aggregation is performed.
explicit :: Enumerator a -> [(a, Double)]
explicit = map (second (exp . ln)) . logExplicit

-- | The model evidence: the sum of the unnormalized weights of all
-- execution paths, in the log domain.
evidence :: Enumerator a -> Log Double
evidence d = Log.sum [w | (_, w) <- logExplicit d]

-- | Normalized probability mass of a specific value.
-- Values that never occur have mass 0.
mass :: Ord a => Enumerator a -> a -> Double
mass d = \a -> fromMaybe 0 (lookup a table)
  where
    -- bound outside the lambda so a partially applied @mass d@ shares it
    table = enumerate d

-- | Aggregate weights of equal values by summing them.
-- The resulting list is sorted ascendingly according to values.
--
-- Uses the value-strict map so each running sum is evaluated as it is
-- accumulated, instead of building a chain of unevaluated additions when
-- many paths share the same value.
compact :: (Num r, Ord a) => [(a, r)] -> [(a, r)]
compact = MapStrict.toAscList . MapStrict.fromListWith (+)

-- | Normalize the weights and aggregate those of equal values.
-- The resulting list is sorted ascendingly according to values.
--
-- > enumerate = compact . map (second (exp . ln)) . normalizeWeights . logExplicit
--
-- Weights are divided by their total (the 'evidence'), so the result is
-- only meaningful when the evidence is positive.
enumerate :: Ord a => Enumerator a -> [(a, Double)]
enumerate = compact . map (second (exp . ln)) . normalizeWeights . logExplicit

-- | Expectation of a function under the posterior, computed with the
-- weights normalized by their total.
expectation :: (a -> Double) -> Enumerator a -> Double
expectation f d =
  Prelude.sum [f x * exp (ln w) | (x, w) <- normalizeWeights (logExplicit d)]

-- | Divide each log-domain weight by the sum of all weights.
-- Only meaningful when that sum is positive.
normalize :: [Log Double] -> [Log Double]
normalize ws = [w / total | w <- ws]
  where
    total = Log.sum ws

-- | Divide all weights by their sum, keeping each weight paired with its
-- value and preserving the order of the pairs.
normalizeWeights :: [(a, Log Double)] -> [(a, Log Double)]
normalizeWeights pairs = zip values (normalize weights)
  where
    (values, weights) = unzip pairs

-- | 'compact' applied to 'explicit', followed by removing values whose
-- aggregated weight is exactly zero. The weights are not normalized.
normalForm :: Ord a => Enumerator a -> [(a, Double)]
normalForm d = [entry | entry@(_, w) <- compact (explicit d), w /= 0]

-- | Two enumerators are equal when their normal forms coincide exactly.
instance Ord a => Eq (Enumerator a) where
  (==) p q = lhs == rhs
    where
      lhs = normalForm p
      rhs = normalForm q

-- | Approximate equality on normal forms.
--
-- '===' requires identical supports and compares the weights with '==='.
-- '~==' first drops values whose weight is approximately zero, then
-- requires identical supports and approximately equal weights.
instance Ord a => AEq (Enumerator a) where
  p === q = supportP == supportQ && weightsP === weightsQ
    where
      (supportP, weightsP) = unzip (normalForm p)
      (supportQ, weightsQ) = unzip (normalForm q)
  p ~== q = supportP == supportQ && weightsP ~== weightsQ
    where
      prune r = unzip (filter (not . (~== 0) . snd) (normalForm r))
      (supportP, weightsP) = prune p
      (supportQ, weightsQ) = prune q