{-# LANGUAGE 
    BangPatterns 
  , NoMonomorphismRestriction
 #-}
-- | Scoring functions commonly used for evaluation of NLP
-- systems. Most functions in this module work on sequences which are
-- instances of 'Foldable' (see "Data.Foldable"), but some take a
-- precomputed table of 'Counts'. Using such a table gives a speedup when
-- you want to compute several scores on the same data. For example, to
-- compute the Mutual Information, the Adjusted Rand Index and the
-- Variation of Information on the same pair of clusterings:
--
-- >>> let cs = counts "abcabc" "abaaba"
-- >>> mapM_ (print . ($ cs)) [mi, ari, vi]
-- 0.9182958340544894
-- 0.4444444444444445
-- 0.6666666666666663

module NLP.Scores 
    ( 
    -- * Scores for classification and ranking
      errorRate
    , accuracy
    , recipRank
    , avgPrecision
    -- * Scores for clustering
    , ari
    , mi
    , vi
    -- * Strength of association
    , logLikelihoodRatio
    -- * Comparing probability distributions
    , kullbackLeibler
    , jensenShannon
    -- * Auxiliary types and functions
    , Count
    , Counts
    , counts
    , sum
    , mean
    , jaccard
    , entropy
    , histogram
    -- * Extracting joint and marginal counts from 'Counts'
    , countJoint
    , countFst
    , countSnd
    , countTotal
      -- * Extracting lists of values from 'Counts'
    , fstElems
    , sndElems
    )
where
import Prelude hiding (sum)
import qualified Data.Foldable as F
import Data.List hiding (sum)
import Data.Monoid
import qualified Data.Traversable as T

import qualified Data.Map as Map
import qualified Data.Map.Strict as MapStrict
import qualified Data.Set as Set
import Data.Strict.Tuple (Pair((:!:)))

import NLP.Scores.Internals


-- | Error rate: the proportion of positions at which the first sequence
-- differs from the second, i.e. @1 - 'accuracy' xs ys@. The sequences
-- should have equal lengths.
errorRate :: (Eq a, Fractional c, T.Traversable t, F.Foldable s) => t a -> s a -> c
errorRate xs = (1 -) . accuracy xs
{-# SPECIALIZE errorRate :: [Double] -> [Double] -> Double #-}

-- | Accuracy: the proportion of positions at which the first sequence
-- equals the second. The sequences should have equal lengths: extra
-- elements of the second sequence are ignored, and a second sequence that
-- is too short causes an error. An empty first sequence gives @0\/0@
-- (NaN for 'Double').
accuracy :: (Eq a, Fractional c, T.Traversable t, F.Foldable s) => t a -> s a -> c
accuracy xs ys = mean hits
  where
    -- 1 at each position where the two sequences agree, 0 elsewhere
    hits = fmap (\same -> if same then 1 else (0 :: Int))
                (zipWithTF (==) xs (F.toList ys))
{-# SPECIALIZE accuracy :: [Double] -> [Double] -> Double #-}

-- | Reciprocal rank: @1 \/ r@, where @r@ is the 1-based position of the
-- first occurrence of the first argument in the sequence given as the
-- second argument, or 0 if it does not occur at all.
recipRank :: (Eq a, Fractional b, F.Foldable t) => a -> t a -> b
recipRank y ys =
    maybe 0 (\i -> 1 / fromIntegral (i + 1)) (findIndex (== y) (F.toList ys))
{-# SPECIALIZE recipRank :: Double -> [Double] -> Double #-}

-- | Average precision of a ranked sequence against a set of relevant
-- (gold) items: the sum of the precision at every rank holding a relevant
-- item, divided by the number of gold items. Returns 0 when the gold set
-- is empty. Ranks after the running count of relevant items exceeds the
-- size of the gold set are ignored (this can only happen when the ranking
-- repeats relevant items).
-- <http://en.wikipedia.org/wiki/Information_retrieval#Average_precision>
avgPrecision :: (Fractional n, Ord a, F.Foldable t) => Set.Set a -> t a -> n
avgPrecision gold xs
  | Set.null gold = 0
  | otherwise     = total / fromIntegral nGold
  where
    nGold = Set.size gold
    -- (rank, 1 if relevant else 0, relevant items seen up to this rank)
    ranked = snd (mapAccumL step 0 (zip [1 :: Int ..] (F.toList xs)))
    step hits (rank, x) =
      let rel   = fromEnum (x `Set.member` gold)
          hits' = hits + rel
      in  (hits', (rank, rel, hits'))
    withinGold (_, _, hits) = hits <= nGold
    precisionAt (rank, rel, hits)
      | rel == 0  = 0
      | otherwise = fromIntegral hits / fromIntegral rank
    total = sum (map precisionAt (takeWhile withinGold ranked))
{-# SPECIALIZE avgPrecision :: (Ord a) => Set.Set a -> [a] -> Double #-}

-- | Mutual information, in bits: MI(X,Y) = H(X) - H(X|Y) = H(Y) - H(Y|X).
-- Also known as information gain. Computed from the count table as
-- SUM_{x,y} P(x,y) log_2(P(x,y) \/ (P(x) P(y))), with the total taken from
-- the joint counts.
mi :: (Ord a, Ord b) => Counts a b -> Double
mi (Counts cxy cx cy) = sum (map term (Map.toList cxy))
  where
    total = Map.foldl' (+) 0 cxy
    term (x :!: y, nxy) =
      let nx = cx Map.! x
          ny = cy Map.! y
      in  nxy / total * logBase 2 (nxy * total / nx / ny)

-- | Variation of information, in bits: VI(X,Y) = H(X) + H(Y) - 2 MI(X,Y).
-- The entropies H(X) and H(Y) are computed from the marginal counts.
vi :: (Ord a, Ord b) => Counts a b -> Double
vi cs@(Counts _ cx cy) = hx + hy - 2 * mi cs
  where hx = entropy (Map.elems cx)
        hy = entropy (Map.elems cy)


-- | Log-likelihood ratio for two binomial distributions (Dunning 1993),
-- testing whether the occurrence of @x@ depends on the occurrence of @y@:
--
-- * H_0: P(x|y) = p = P(x|~y)
--
-- * H_1: P(x|y) = p1 =\/= p2 = P(x|~y)
--
-- The result is log(L(H_0) \/ L(H_1)), the logarithm of the likelihood
-- ratio itself (not the @-2 log@ statistic). The binomial coefficients
-- cancel between numerator and denominator and are therefore omitted.
logLikelihoodRatio :: (Ord a, Ord b) => Counts a b -> a -> b -> Double
logLikelihoodRatio cs x y =
    logL nxy ny p  + logL (nx - nxy) (n - ny) p
  - logL nxy ny p1 - logL (nx - nxy) (n - ny) p2
  where
    n   = countTotal cs
    nx  = countFst x cs
    ny  = countSnd y cs
    nxy = countJoint x y cs
    p   = nx / n                 -- relative count of x overall
    p1  = nxy / ny               -- relative count of x among the ny items with y
    p2  = (nx - nxy) / (n - ny)  -- relative count of x among the items without y
    -- Log of the binomial likelihood q^k (1-q)^(m-k) for k successes in
    -- m trials. It is computed in log space so that large counts do not
    -- underflow to log 0, using the convention 0 * log 0 = 0.
    logL k m q = xlogy k q + xlogy (m - k) (1 - q)
    xlogy 0 _ = 0
    xlogy a b = a * log b


-- | Kullback-Leibler divergence, in bits:
-- KL(X,Y) = SUM_i P(X=i) log_2(P(X=i)\/P(Y=i)).
-- The distributions can be unnormalized: each argument is divided by its
-- own sum. Elements are paired by position, so the second sequence must
-- be at least as long as the first. Terms with P(X=i) = 0 contribute 0,
-- while a position with P(X=i) > 0 and P(Y=i) = 0 makes the result
-- infinite (for 'Double').
kullbackLeibler :: (Eq a, Floating a, F.Foldable f, T.Traversable t) => t a -> f a -> a
kullbackLeibler xs ys = sum (zipWithTF term xs ys)
  where
    totalX = sum xs
    totalY = sum ys
    term !x !y
      | px == 0   = 0
      | otherwise = px * logBase 2 (px / (y / totalY))
      where px = x / totalX

-- | Jensen-Shannon divergence, in bits:
-- JS(X,Y) = 1\/2 KL(X,M) + 1\/2 KL(Y,M), where M = (X+Y)\/2.
-- The distributions can be unnormalized: both are normalized before the
-- mixture M is formed.
jensenShannon :: (Eq a, Floating a, T.Traversable t, T.Traversable u) => t a -> u a -> a
jensenShannon xs ys = 0.5 * kullbackLeibler p m + 0.5 * kullbackLeibler q m
  where
    p = normalize xs
    q = normalize ys
    -- Pointwise P + Q. kullbackLeibler divides its second argument by its
    -- sum, so this acts as the mixture (P + Q) / 2.
    m = zipWithTF (+) p q
          
-- | Adjusted Rand Index: <http://en.wikipedia.org/wiki/Rand_index>
--
-- ARI = (SUM_ij C(n_ij,2) - E) \/ (1\/2 (SUM_i C(a_i,2) + SUM_j C(b_j,2)) - E)
-- where E = SUM_i C(a_i,2) SUM_j C(b_j,2) \/ C(n,2). When the denominator
-- is zero (e.g. both clusterings put everything into a single cluster) the
-- result is @0\/0@.
ari :: (Ord a, Ord b) => Counts a b -> Double
ari (Counts cxy cx cy) = (agreement - expected) / (maxIndex - expected)
  where
    pairs c      = choice c 2
    sumPairs m   = sum (map pairs (Map.elems m))
    agreement    = sumPairs cxy
    rowPairs     = sumPairs cx
    colPairs     = sumPairs cy
    totalPairs   = pairs (sum (Map.elems cx))
    expected     = rowPairs * colPairs / totalPairs
    maxIndex     = 1/2 * (rowPairs + colPairs)

-- | Sum of a sequence of numbers. It uses a strict left fold, so no
-- chain of unevaluated additions is built up.
sum :: (F.Foldable t, Num a) => t a -> a
sum xs = F.foldl' (\acc x -> acc + x) 0 xs
{-# INLINE sum #-}

-- | Arithmetic mean of a sequence of numbers, computed in a single pass
-- with a strict running total and element count. An empty sequence gives
-- @0\/0@ (NaN for 'Double').
mean :: (F.Foldable t, Fractional n, Real a) => t a -> n
mean = go 0 0 . F.toList
  where
    go !total !len []       = realToFrac total / len
    go !total !len (v : vs) = go (total + v) (len + 1) vs
{-# SPECIALIZE mean :: [Double] -> Double #-}

-- | The binomial coefficient
-- C^n_k = PROD_{i=1}^{k} (n-k+i)\/i = ((n-k+1) * ... * n) \/ k!,
-- computed as the ratio of the two products. For integral @0 <= n < k@
-- the numerator range contains 0, so the result is 0.
choice :: (Enum b, Fractional b) => b -> b -> b
choice n k = numerator / denominator
  where
    numerator   = foldl' (*) 1 [n - k + 1 .. n]
    denominator = foldl' (*) 1 [1 .. k]
{-# SPECIALIZE choice :: Double -> Double -> Double #-}

-- | Jaccard coefficient:
-- J(A,B) = |A intersect B| \/ |A union B|.
-- Two empty sets give @0\/0@.
jaccard :: (Fractional n, Ord a) => Set.Set a -> Set.Set a -> n
jaccard a b = shared / combined
  where
    shared   = fromIntegral . Set.size $ Set.intersection a b
    combined = fromIntegral . Set.size $ Set.union a b
{-# SPECIALIZE jaccard :: (Ord a) => Set.Set a -> Set.Set a -> Double #-}

-- | Entropy, in bits: H(X) = -SUM_i P(X=i) log_2(P(X=i)).
-- @entropy xs@ treats each element of @xs@ as the count of one particular
-- value of the random variable. The counts are normalized by their sum.
-- To compute the entropy of a sequence of outcomes, build a histogram
-- first:
--
-- > entropy . Map.elems . histogram
--
-- A zero count produces the term @0 * log 0@, which is NaN for IEEE types
-- such as 'Double'. Drop zero counts beforehand.
entropy :: (Floating c, F.Foldable t) => t c -> c
entropy cx = negate $ getSum $ F.foldMap contribution cx
  where
    total    = sum cx
    logTotal = logBase 2 total
    -- P(x) * log_2 P(x), with log_2 P(x) expanded as log_2 c - log_2 total
    contribution c = Sum (c / total * (logBase 2 c - logTotal))

-- | @histogram xs@ returns a map from each distinct element of the
-- sequence @xs@ to the number of times it occurs in @xs@.
histogram :: (Num a, Ord k, F.Foldable t) => t k -> Map.Map k a
histogram = F.foldl' bump Map.empty
  where
    -- The strict insertWith evaluates each updated count, so no chain of
    -- (+) thunks builds up. The old Data.Map.insertWith' is gone from
    -- containers >= 0.6.
    bump m k = MapStrict.insertWith (+) k 1 m

-- | Creates the count table 'Counts' for two sequences of paired
-- observations. The table holds the joint counts of the pairs @(x, y)@
-- found at corresponding positions, plus the marginal counts of the
-- @x@s and of the @y@s. The second sequence must be at least as long as
-- the first. Its extra elements are ignored.
counts :: (Ord a, Ord b, T.Traversable t, F.Foldable s) => t a -> s b -> Counts a b
counts xs ys = F.foldl' step empty (zipWithTF (:!:) xs ys)
  where
    -- The strict insertWith keeps every count evaluated; the old
    -- Data.Map.insertWith' is gone from containers >= 0.6. The bang
    -- patterns force the updated maps, so no chain of pending insertions
    -- builds up inside the table even if Counts' fields are lazy (their
    -- strictness is defined in NLP.Scores.Internals).
    step (Counts cxy cx cy) p@(x :!: y) =
      let !cxy' = MapStrict.insertWith (+) p 1 cxy
          !cx'  = MapStrict.insertWith (+) x 1 cx
          !cy'  = MapStrict.insertWith (+) y 1 cy
      in  Counts cxy' cx' cy'

-- | Joint count of the pair @(x, y)@, or 0 if the pair is not in the table.
countJoint :: (Ord a, Ord b) => a -> b -> Counts a b -> Count
countJoint x y cs = Map.findWithDefault 0 (x :!: y) (joint cs)

-- | Marginal count of @x@ as a first element, or 0 if it is not in the table.
countFst :: Ord k => k -> Counts k b -> Count
countFst x cs = Map.findWithDefault 0 x (marginalFst cs)

-- | Marginal count of @y@ as a second element, or 0 if it is not in the table.
countSnd :: Ord k => k -> Counts a k -> Count
countSnd y cs = Map.findWithDefault 0 y (marginalSnd cs)

-- | Total element count: the sum of all joint counts.
countTotal :: Counts a k -> Count
countTotal cs = F.sum (joint cs)

-- | The distinct values that occur as first elements, in ascending order.
fstElems :: Counts k b -> [k]
fstElems cs = Map.keys (marginalFst cs)

-- | The distinct values that occur as second elements, in ascending order.
sndElems :: Counts a k -> [k]
sndElems cs = Map.keys (marginalSnd cs)

-- | @zipWithTF h t f@ zips the values from the traversable @t@ with the
-- values from the foldable @f@ using @h@, keeping the shape of @t@. Each
-- element @a@ of @t@ becomes @h a b@, where @b@ is the element of @f@ at
-- the same position. Extra elements of @f@ are ignored. If @f@ is shorter
-- than @t@, forcing an element that has no partner raises a descriptive
-- error.
zipWithTF :: (T.Traversable t, F.Foldable f) =>
             (a -> b -> c) -> t a -> f b -> t c
zipWithTF h t f = snd . T.mapAccumL map_one (F.toList f) $ t
  where map_one (x:xs) y = (xs, h y x)
        map_one []     _ =
          error "NLP.Scores.zipWithTF: second sequence is shorter than the first"
        
-- | @normalize xs@ divides every element of @xs@ by the sum of all
-- elements of @xs@, keeping the shape of the container.
normalize :: (Fractional b, Functor f, F.Foldable f) => f b -> f b
normalize xs = fmap (\v -> v / total) xs
  where total = sum xs