{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{- | Conditional probability tables

Factors for graphical models: conditional probability tables and
probability tables.

-}
module Bayes.Factor(
 -- * Factor
   Factor(..)
 , isomorphicFactor
 , normedFactor
 -- * Set of variables 
 , Set(..)
 , BayesianDiscreteVariable(..)
 -- * Implementation
 , Vertex(..)
 -- ** Discrete variables and instantiations
 , DV
 --, DVSet(..)
 , DVI
 , setDVValue
 , instantiationValue
 , instantiationVariable
 , variableVertex
 , (=:)
 , forAllInstantiations
 , factorFromInstantiation
 , changeVariableOrder
 -- ** Factor
 , CPT
 -- * Tests
 , testProductProject_prop
 , testAssocProduct_prop
 , testScale_prop
 , testProjectCommut_prop
 , testScalarProduct_prop
 , testProjectionToScalar_prop
 ) where

import qualified Data.Vector.Unboxed as V
import Data.Vector.Unboxed((!))
import Data.Maybe(fromJust,mapMaybe,isJust)
import qualified Data.List as L
import Text.PrettyPrint.Boxes hiding((//))
import Test.QuickCheck
import Test.QuickCheck.Arbitrary
import qualified Data.IntMap as IM
import Control.Monad
import System.Random(Random)
import Data.List(partition)
import Bayes.PrivateTypes

-- | Types that carry a graph 'Vertex' together with some extra data
-- (variable dimension, instantiated value, ...).
class LabeledVertex v where
    -- | The vertex this value is attached to.
    variableVertex :: v -> Vertex

-- | Convert a variable instantiation to a factor.
-- The resulting factor is defined on the single variable of the
-- instantiation. Its value is 1.0 for the instantiated value and 0.0 for
-- every other value of the variable.
-- Useful to create evidence factors.
factorFromInstantiation :: Factor f => DVI Int -> f
factorFromInstantiation (DVI dv a) =
    let setValue i = if i == a then 1.0 else 0.0
    -- One value is produced per level of the variable, so the number of
    -- values matches the variable dimension expected by 'factorWithVariables'.
    in fromJust . factorWithVariables [dv] . map setValue $ [0 .. dimension dv - 1]

instance LabeledVertex (DVI a) where
    -- An instantiation is labelled by the vertex of its variable.
    variableVertex inst = case inst of
        DVI var _ -> variableVertex var

instance LabeledVertex DV where
    -- A discrete variable directly stores its vertex.
    variableVertex dv = case dv of
        DV vertex _ -> vertex

-- | Normalize a factor by dividing every value by its norm
-- (the sum of its values, see 'factorNorm').
normedFactor :: Factor f => f -> f
normedFactor f =
    let norm = factorNorm f
    in f `factorDivide` norm

-- | A factor as used in graphical models.
-- It may or may not be a probability distribution, so there is no reason
-- for it to be normalized to 1.
class Factor f where
    -- | When all variables of a factor have been summed out, we have a scalar
    isScalarFactor :: f -> Bool
    -- | An empty factor with no variable and no values
    emptyFactor :: f
    -- | Check if a given discrete variable is contained in a factor
    containsVariable :: f -> DV -> Bool
    -- | Give the list of discrete variables used by the factor
    factorVariables :: f -> [DV]
    -- | Return A in P(A | C D ...). It only makes sense if the factor is a
    -- conditional probability table. It must always be in the vertex
    -- corresponding to A in the bayesian graph.
    --
    -- The default implementation returns the first variable of
    -- 'factorVariables' and fails with an error for a factor without variables.
    factorMainVariable :: f -> DV
    factorMainVariable f =
        case factorVariables f of
            [] -> error "Can't get the main variable of a scalar factor"
            (h:_) -> h
    -- | Create a new factor with a given list of variables and a list of values
    -- for initialization. The creation may fail if the number of values is not
    -- coherent with the variables and their levels.
    -- For boolean variables ABC, the values must be given in the order
    -- FFF, FFT, FTF, FTT ...
    factorWithVariables :: [DV] -> [Double] -> Maybe f
    -- | Value of the factor for a given set of variable instantiations.
    -- The variable instantiation is like a multi-dimensional index.
    factorValue :: f -> [DVI Int] -> Double
    -- | Position of a discrete variable in the factor (p(AB) is different from
    -- p(BA) since values are not organized in the same order in memory)
    variablePosition :: f -> DV -> Maybe Int
    -- | Dimension of the factor (number of floating point values)
    factorDimension :: f -> Int
    -- | Norm of the factor = sum of its values
    factorNorm :: f -> Double

    -- | Scale the factor values by a given scaling factor
    factorScale :: Double -> f -> f

    -- | Create a scalar factor with no variables
    factorFromScalar :: Double -> f

    -- | Create an evidence factor from an instantiation.
    -- If the instantiation is empty then we get nothing
    evidenceFrom :: [DVI Int] -> Maybe f

    -- | Divide all the factor values by the given number
    factorDivide :: f -> Double -> f
    factorDivide f d = (1.0 / d) `factorScale` f

    -- | The values of the factor as a list
    factorToList :: f -> [Double]

    -- | Multiply factors.
    factorProduct :: [f] -> f

    -- | Project out a factor. The given variables are summed out
    factorProjectOut :: [DV] -> f -> f

    -- | Project to. The given variables are kept and the other variables are
    -- summed out
    factorProjectTo :: [DV] -> f -> f
    factorProjectTo s f =
        let alls = factorVariables f
            -- Every variable of the factor that is not explicitly kept is
            -- summed out.
            s' = alls `difference` s
        in factorProjectOut s' f

-- | Change the layout of values in the
-- factor to correspond to a new variable order.
-- Used to import external files.
--
-- Fails with an error if the number of old values is not coherent with
-- the dimensions of the old variables.
changeVariableOrder :: DVSet s -- ^ Old order
                    -> DVSet s' -- ^ New order
                    -> [Double] -- ^ Old values
                    -> [Double] -- ^ New values
changeVariableOrder (DVSet oldOrder) newOrder oldValues =
    let oldFactor = maybe (error "changeVariableOrder: the number of values does not match the old variable order")
                          id
                          (factorWithVariables oldOrder oldValues) :: CPT
    -- Read the old factor at every instantiation, enumerated in the new order.
    in [factorValue oldFactor i | i <- forAllInstantiations newOrder]

-- | Mainly used for conditional probability tables like p(A B | C D E), but
-- the normalization to 1 is not imposed, and the conditioned variables are not
-- distinguished from the conditioning ones.
-- The dimensions for each variable are listed.
-- The variables on the left or right of the condition bar are not tracked.
-- What matters is that it is encoding a function of several variables.
-- Marginalization of variables will be computed from the bayesian graph where
-- the knowledge of the dependencies is.
-- So, this same structure is used for a probability too (conditioned on nothing).
data CPT = CPT {
           dimensions :: ![DV] -- ^ Dimensions for all variables
         , mapping :: !(IM.IntMap Int) -- ^ Mapping from vertex number to position in dimensions
         , values :: !(V.Vector Double) -- ^ Table of values
         }
         | Scalar !Double -- ^ A factor without any variable

-- | Print the internal representation of a factor (variables, vertex
-- mapping and raw value vector) on standard output. Debugging helper.
debugCPT :: CPT -> IO ()
debugCPT (Scalar d) = do
    putStrLn "SCALAR CPT"
    print d
    putStrLn ""
debugCPT (CPT d m v) = do
    putStrLn "CPT"
    print d
    putStrLn ""
    print m
    putStrLn ""
    print v
    putStrLn ""

{-
A given vertex must have the same size in every CPT: CPTs can't share a vertex
with different sizes. But arbitrary CPT generation would generate several
vertices with the same vertex id and different vertex sizes.

So, we introduce a function mapping a vertex ID to a vertex size: vertex sizes are hard coded.
-}


-- | Size of the vertex with the given ID in randomly generated CPTs.
-- Currently every vertex has size 2. The ID is still evaluated, as a
-- literal pattern match on it would do.
quickCheckVertexSize :: Int -> Int
quickCheckVertexSize vertexId = vertexId `seq` 2

-- | Run a generator repeatedly until it produces a value which is not
-- already present in the list.
-- Loops forever if the generator can only produce values from the list.
whileIn :: (Arbitrary a, Eq a) => [a] -> Gen a -> Gen a
whileIn l m = do
    newVal <- m
    if newVal `elem` l
        then whileIn l m
        else return newVal

-- | Generate a random list of n distinct elements drawn from the given
-- bounds (sampling without replacement, so there are no duplicates).
-- Returns an empty list when n <= 0.
-- Loops forever if the range contains fewer than n distinct values!
generateWithoutReplacement :: (Random a, Arbitrary a, Eq a)
                           => Int -- ^ Vector size
                           -> (a,a) -- ^ Bounds
                           -> Gen [a]
generateWithoutReplacement n b
    | n <= 0 = return []
    | otherwise = do
        -- Draw the n-1 other elements first, then a fresh one not among them.
        rest <- generateWithoutReplacement (n - 1) b
        fresh <- whileIn rest (choose b)
        return (fresh : rest)

instance Arbitrary CPT where
    -- A random factor over 1 to 4 distinct vertices numbered in [0,50].
    -- Each vertex gets the size given by 'quickCheckVertexSize', and the
    -- values are chosen in [0,1].
    arbitrary = do
        vertexCount <- choose (1,4) :: Gen Int
        vertexIds <- generateWithoutReplacement vertexCount (0,50)
        let vars = [DV (Vertex vid) (quickCheckVertexSize vid) | vid <- vertexIds]
            tableSize = product [dimension var | var <- vars]
        tableValues <- vectorOf tableSize (choose (0.0,1.0) :: Gen Double)
        return (fromJust (factorWithVariables vars tableValues))

-- Test of a product followed by a projection when the factors have no
-- common variables: see testProductProject_prop.

-- | Floating point number comparison which takes into account
-- some of the subtleties of that kind of comparison:
--
-- * values comparing equal with '==' (including equal infinities) are equal;
-- * when the product of the values is zero (one of them is zero, or the
--   product underflows), the absolute difference must be below @epsilon^2@;
-- * otherwise the relative difference must be below @epsilon@.
nearlyEqual :: Double -> Double -> Bool
nearlyEqual a b
    | a == b = True -- handle infinities
    | a * b == 0 = diff < epsilon * epsilon
    | otherwise = diff / (abs a + abs b) < epsilon
  where
    diff = abs (a - b)
    epsilon = 2e-5

-- | Scaling a factor by @s@ scales its norm by @s@.
testScale_prop :: Double -> CPT -> Bool
testScale_prop s f =
    let scaledNorm = factorNorm (factorScale s f)
        expectedNorm = s * factorNorm f
    in scaledNorm `nearlyEqual` expectedNorm

-- | When two factors share no variable, multiplying them, summing out the
-- variables of the second one and dividing by the norm of the second one
-- gives back the first factor.
testProductProject_prop :: CPT -> CPT -> Property
testProductProject_prop fa fb =
    isEmpty (factorVariables fa `intersection` factorVariables fb) ==>
        let r = factorProjectOut (factorVariables fb) (factorProduct [fa,fb])
            fa' = r `factorDivide` factorNorm fb
        in fa' `isomorphicFactor` fa

-- | Multiplying by a scalar factor is the same as scaling the factor.
testScalarProduct_prop :: Double -> CPT -> Bool
testScalarProduct_prop v f =
    let viaProduct = factorProduct [Scalar v, f]
        viaScale = factorScale v f
    in isomorphicFactor viaProduct viaScale

-- | The factor product is associative, and the n-ary product agrees with
-- the nested binary products.
testAssocProduct_prop :: CPT -> CPT -> CPT -> Bool
testAssocProduct_prop a b c =
    let leftNested = factorProduct [factorProduct [a,b], c]
        rightNested = factorProduct [a, factorProduct [b,c]]
        flat = factorProduct [a,b,c]
    in isomorphicFactor leftNested rightNested && isomorphicFactor flat leftNested

-- | Summing out all the variables of a factor gives something isomorphic to
-- the scalar factor holding the norm of the factor.
testProjectionToScalar_prop :: CPT -> Bool
testProjectionToScalar_prop f =
    let allVars = factorVariables f
    in factorProjectOut allVars f `isomorphicFactor` factorFromScalar (factorNorm f)

-- | Summing out the first and the second variable of a factor gives the same
-- result whatever the order. Only factors with at least 3 variables are tested.
testProjectCommut_prop :: CPT -> Property
testProjectCommut_prop f = length (factorVariables f) >= 3 ==>
    let a = take 1 (factorVariables f)
        b = take 1 . drop 1 $ factorVariables f
        commuta = factorProjectOut a (factorProjectOut b f)
        commutb = factorProjectOut b (factorProjectOut a f)
    in commuta `isomorphicFactor` commutb

-- | Test equality of two factors, taking into account the fact
-- that the variables may be in a different order.
-- In case there is a distinction between conditioned variables and
-- conditioning variables (imposed from the exterior), this
-- comparison may not make sense. It is a comparison of
-- functions of several variables with no special interpretation of the
-- meaning of the variables according to their position.
--
-- Values are compared with 'nearlyEqual' at every instantiation of the
-- variables of the first factor.
isomorphicFactor :: Factor f => f -> f -> Bool
isomorphicFactor fa fb =
    (sa `equal` sb)
        && factorDimension fa == factorDimension fb
        && and [ factorValue fa ia `nearlyEqual` factorValue fb ia
               | ia <- forAllInstantiations (DVSet sa) ]
  where
    sa = factorVariables fa
    sb = factorVariables fb


{- The following functions are used to typeset a factor when displaying it. -}

-- | Label for one value of a variable: @v\<vertex\>=\<value\>@, built from
-- the vertex number and the instantiated value.
vname :: Int -> DVI Int -> Box
vname vc i = text (concat ["v", show vc, "=", show (instantiationValue i)])

-- | Typeset the values of a factor as a table.
--
-- @dispFactor cpt dst c vars@: for each instantiation of the first variable
-- of @vars@, this produces a column. The column is headed by that value and
-- contains the recursive display of the remaining variables. When no
-- variable is left, it produces a column with one value per instantiation
-- of @dst@. Each value is read from the factor at that instantiation
-- followed by the accumulated instantiations @c@. Those are accumulated in
-- reverse order, hence the 'reverse'.
dispFactor :: Factor f => f -> DV -> [DVI Int] -> [DV] -> Box
dispFactor cpt h c [] =
    let dstIndexes = allInstantiationsForOneVariable h
        dependentIndexes = reverse c
        factorValueAtPosition p =
            let v = factorValue cpt p
            in text . show $ v
    in vsep 0 center1 . map (factorValueAtPosition . (:dependentIndexes)) $ dstIndexes
dispFactor cpt dst c (h@(DV (Vertex vc) _):l) =
    let allInst = allInstantiationsForOneVariable h
    in hsep 1 top . map (\i -> vcat center1 [vname vc i, dispFactor cpt dst (i:c) l]) $ allInst

instance Show CPT where
    show (Scalar v) = "\nScalar Factor:\n" ++ show v
    show (CPT [] _ _) = "\nEmpty CPT:\n"
    show c@(CPT d@(h@(DV (Vertex vc) _) : rest) _ _) =
        let table = dispFactor c h [] rest
            dstIndexes = map head (forAllInstantiations . DVSet $ [h])
            -- In P(A | B ...), the dst column contains the possible values of
            -- A, below a header of empty lines so that it is aligned with the
            -- other part of the table. In the other part of the table, the
            -- header contains the values of the other variables.
            dstColumn = vcat center1 $ replicate (length rest) (text "") ++ map (vname vc) dstIndexes
        in "\n" ++ show d ++ "\n" ++ render (hsep 1 top [dstColumn, table])

instance Factor CPT where
    factorToList (Scalar v) = [v]
    factorToList (CPT _ _ v) = V.toList v
    emptyFactor = emptyCPT
    isScalarFactor (Scalar _) = True
    isScalarFactor _ = False
    factorFromScalar v = Scalar v
    factorDimension f@(CPT _ _ _) = product . map dimension . factorVariables $ f
    factorDimension _ = 1
    containsVariable (CPT _ m _) (DV (Vertex i) _) = IM.member i m
    containsVariable (Scalar _) _ = False
    factorWithVariables = createCPTWithDims
    factorVariables (CPT v _ _) = v
    factorVariables (Scalar _) = []
    -- Sum of the values read at every multi-index of the factor domain.
    factorNorm (CPT d _ vals) =
        let vars = DVSet d
            strides = indexStrides vars
        in sum [vals ! indexPosition strides x | x <- indicesForDomain vars]
    factorNorm (Scalar v) = v
    variablePosition (CPT _ m _) (DV (Vertex i) _) = IM.lookup i m
    variablePosition (Scalar _) _ = Nothing
    factorScale s (Scalar v) = Scalar (s*v)
    factorScale s f@(CPT d _ vals) =
        let vars = DVSet d
            strides = indexStrides vars
            newValues = map (s *) [vals ! indexPosition strides x | x <- indicesForDomain vars]
        in fromJust $ factorWithVariables (factorVariables f) newValues
    factorValue (Scalar v) _ = v
    factorValue (CPT d _ v) i =
        let vars = DVSet d
            -- dv: the variables of the instantiation, pos: its multi-index
            (dv,pos) = instantiationDetails i
            -- Strides reading this factor with a multi-index expressed in dv
            strides = indexStridesFor vars dv
        in v ! indexPosition strides pos
    evidenceFrom [] = Nothing
    evidenceFrom l =
        let (variables,index) = instantiationDetails l
            DVSet nakedVars = variables
            setValueForIndex i = if i == index then 1.0 else 0.0
        in factorWithVariables nakedVars . map setValueForIndex $ indicesForDomain variables
    factorProduct [] = factorFromScalar 1.0
    factorProduct l =
        let allVars = DVSet $ L.foldl1' union . map factorVariables $ l
            DVSet nakedVars = allVars
            (scalars,cpts) = partition isScalarFactor l
            -- Strides reading a CPT with a multi-index over all the variables
            -- of the product (absent variables get a zero stride).
            stridesFromCPT (CPT d _ _) = indexStridesFor (DVSet d) allVars
            -- Product of the scalar factors, read directly from the constructor.
            ps = product [v | Scalar v <- scalars]
            cptsStrides = map stridesFromCPT cpts
        in if L.null cpts
            then factorFromScalar ps
            else
                let getFactorValueAtIndex i (strides, CPT _ _ vals) = vals ! indexPosition strides i
                    instantiationProduct instantiation =
                        product . map (getFactorValueAtIndex instantiation) $ zip cptsStrides cpts
                    productValues = [ps * instantiationProduct x | x <- indicesForDomain allVars]
                in productValues `seq` fromJust (factorWithVariables nakedVars productValues)
    factorProjectOut _ f@(Scalar _) = f
    factorProjectOut s (CPT d _ v) =
        let variablesToSum = s
            variablesToKeep = d `difference` s
            keepSet = DVSet variablesToKeep
            sumSet = DVSet variablesToSum
            -- Strides reading this factor with a multi-index made of the kept
            -- variables followed by the summed ones.
            strides = indexStridesFor (DVSet d) (DVSet $ variablesToKeep ++ variablesToSum)
            projected = do
                keepIndex <- indicesForDomain keepSet
                let partialSums = do
                        sumIndex <- indicesForDomain sumSet
                        return $ v ! indexPosition strides (combineIndex strides keepIndex sumIndex)
                return (sum partialSums)
        in projected `seq` fromJust (factorWithVariables variablesToKeep projected)
-- | Used to combine the keep and sum indices in the factor projection.
-- The keep indices come first, followed by the sum indices. Both are
-- re-tagged with the phantom type of the strides; the strides themselves
-- are not inspected.
combineIndex :: Strides s'' -> [Index s] -> [Index s'] -> [Index s'']
combineIndex _ keepPart sumPart = retag keepPart ++ retag sumPart
  where
    retag :: [Index a] -> [Index b]
    retag = map (Index . fromIndex)

-- | A CPT without any variable, vertex mapping or value.
emptyCPT :: CPT
emptyCPT = CPT { dimensions = [], mapping = IM.empty, values = V.empty }

-- | Strides used to turn a multi-index into a position in the value vector
-- of a 'CPT'. The phantom parameter @s@ ties the strides to the 'DVSet' and
-- 'Index' types they were computed for.
newtype Strides s = Strides [Int]
    deriving (Eq, Show)

-- | Generate strides to read the first variable set's data using an index
-- that has meaning in the second variable set.
-- Variables of the second set that do not belong to the first one get a
-- stride of 0, so their index value does not affect the read position.
indexStridesFor :: DVSet s -- ^ DVSet to be read
                -> DVSet s' -- ^ DVSet to interpret the index
                -> Strides s'
indexStridesFor dr@(DVSet drvars) (DVSet divars) =
    let Strides originStrides = indexStrides dr
        reference = zip drvars originStrides
        getNewStrides dv = maybe 0 id (lookup dv reference)
    in Strides $ map getNewStrides divars

-- | Generate the strides to read a given factor using a multi-index that
-- lists the variables in the same order as the factor variables.
-- The layout is row-major: the last variable has stride 1.
indexStrides :: DVSet s -> Strides s
indexStrides (DVSet dvars) =
    let dim = map dimension dvars
        -- 'drop 1' rather than 'tail': a factor without variables (e.g. the
        -- result of projecting out all its variables) gets strides [1]
        -- instead of crashing.
        pos' = scanr (*) (1::Int) (drop 1 dim)
    in Strides pos'
-- | Conversion of a multi-index to its
-- position inside the data vector of a 'CPT'.
-- An empty multi-index is at position 0, without inspecting the strides.
indexPosition :: Strides s -> [Index s] -> Int
{-# INLINE indexPosition #-}
indexPosition _ [] = 0
indexPosition (Strides pos') pos = sum . map (\(x,y) -> x * fromIndex y) $ zip pos' pos

-- | Create a CPT given some dimensions and a list of Doubles.
-- Returns nothing is the length are not coherents.
createCPTWithDims :: [DV] -> [Double] -> Maybe CPT
createCPTWithDims dims values = 
    let createDVIndex i (DV (Vertex v) _)  = (v,i)
        m = IM.fromList . zipWith createDVIndex ([0,1..]::[Int]) $ dims
        p = product (map dimension dims) 
    if length values == p
            Just $ CPT dims m (V.fromList values)