{-# LANGUAGE BangPatterns, FlexibleContexts #-}
-- |
-- Module      : Statistics.Correlation.Kendall
--
-- Fast O(NlogN) implementation of
-- <http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient Kendall's tau>.
--
-- This module implements Kendall's tau form b which allows ties in the data.
-- This is the same formula used by other statistical packages, e.g., R, matlab.
--
-- > \tau = \frac{n_c - n_d}{\sqrt{(n_0 - n_1)(n_0 - n_2)}}
--
-- where n_0 = n(n-1)\/2, n_1 = number of pairs tied for the first quantity,
-- n_2 = number of pairs tied for the second quantity,
-- n_c = number of concordant pairs, n_d = number of discordant pairs.

module Statistics.Correlation.Kendall
    ( kendall

    -- * References
    -- $references
    ) where

import Control.Monad.ST (ST, runST)
import Data.Bits (shiftR)
import Data.Function (on)
import Data.STRef
import qualified Data.Vector.Algorithms.Intro as I
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM

-- | /O(nlogn)/ Compute the Kendall's tau from a vector of paired data.
-- Return NaN when number of pairs <= 1.
--
-- NaN is also the result when every first component, or every second
-- component, of the input is equal: both the numerator and the denominator
-- of the tau-b formula are then zero.
kendall :: (Ord a, Ord b, G.Vector v (a, b)) => v (a, b) -> Double
kendall xy'
  | G.length xy' <= 1 = 0 / 0
  | otherwise = runST $ do
    -- Work on a mutable copy; the input vector is left untouched.
    xy <- G.thaw xy'
    let n = GM.length xy
    n_dRef <- newSTRef 0
    -- Sort by the whole pair (first component, then second) so that both
    -- equal first components and fully equal pairs become adjacent.
    I.sort xy
    tieX <- numOfTiesBy ((==) `on` fst) xy
    tieXY <- numOfTiesBy (==) xy
    -- Re-sort by the second component only. The merge sort accumulates the
    -- number of exchanges it performs in 'n_dRef', which is the number of
    -- discordant pairs.
    tmp <- GM.new n
    mergeSort (compare `on` snd) xy tmp n_dRef
    -- The vector is now ordered by the second component, so ties in it are
    -- adjacent and can be counted.
    tieY <- numOfTiesBy ((==) `on` snd) xy
    n_d <- readSTRef n_dRef
    -- All pair counts are kept as 'Integer' so that n(n-1)/2 cannot overflow.
    let n_0 = (fromIntegral n * (fromIntegral n - 1)) `shiftR` 1 :: Integer
        -- Pairs tied in both components were subtracted twice (once in
        -- tieX, once in tieY), hence tieXY is added back.
        n_c = n_0 - n_d - tieX - tieY + tieXY
    return $ fromIntegral (n_c - n_d) /
             (sqrt . fromIntegral) ((n_0 - tieX) * (n_0 - tieY))
{-# INLINE kendall #-}

-- Calculate the number of tied pairs in a sorted vector.
--
-- The predicate is only ever applied to neighbouring elements, so elements
-- that should count as tied must be adjacent (i.e. the vector must already be
-- sorted consistently with the predicate). Every maximal run of @k@ tied
-- elements contributes @k(k-1)/2@ pairs to the result. An empty or
-- single-element vector yields 0.
numOfTiesBy :: GM.MVector v a
            => (a -> a -> Bool) -> v s a -> ST s Integer
numOfTiesBy f xs = do count <- newSTRef (0::Integer)
                      loop count (1::Int) (0::Int)
                      readSTRef count
  where
    n = GM.length xs
    -- c   : running total of tied pairs
    -- acc : length of the run of tied elements ending at index i
    -- i   : index of the left element of the pair being compared
    loop c !acc !i
      | i >= n - 1 = modifySTRef' c (+ pairsIn acc)  -- flush the final run
      | otherwise = do
          x1 <- GM.unsafeRead xs i
          x2 <- GM.unsafeRead xs (i + 1)
          if f x1 x2
            then loop c (acc + 1) (i + 1)
            -- The run ended: account for it and start a new run of length 1.
            else modifySTRef' c (+ pairsIn acc) >> loop c 1 (i + 1)
    -- Number of unordered pairs in a run of k elements. The run length is
    -- widened to 'Integer' before multiplying so the product cannot overflow
    -- a machine 'Int'.
    pairsIn :: Int -> Integer
    pairsIn k = (k' * (k' - 1)) `shiftR` 1
      where k' = fromIntegral k
{-# INLINE numOfTiesBy #-}

-- Implementation of Knight's merge sort (adapted from vector-algorithm). This
-- function is used to count the number of discordant pairs.
--
-- Sorts @src@ in place according to @cmp@, using @buf@ as scratch space for
-- the merge step, and adds to @count@ the number of out-of-order pairs that
-- were resolved while sorting. Elements that compare 'EQ' are never counted.
-- An empty or single-element @src@ is left untouched.
mergeSort :: GM.MVector v e
          => (e -> e -> Ordering)
          -> v s e
          -> v s e
          -> STRef s Integer
          -> ST s ()
mergeSort cmp src buf count = loop 0 (GM.length src - 1)
  where
    -- Sort the inclusive index range [l, u].
    loop l u
      -- Zero or one element: nothing to do. Using (<=) rather than (==) also
      -- terminates for an empty vector, where the initial call is loop 0 (-1).
      | u <= l = return ()
      -- Exactly two elements: swap them if out of order, counting one pair.
      | u - l == 1 = do
          eL <- GM.unsafeRead src l
          eU <- GM.unsafeRead src u
          case cmp eL eU of
              GT -> do GM.unsafeWrite src l eU
                       GM.unsafeWrite src u eL
                       modifySTRef' count (+1)
              _ -> return ()
      | otherwise = do
          let mid = (u + l) `shiftR` 1
          -- The two recursive ranges deliberately share index mid: after the
          -- first call [l, mid] is sorted (hence so is [l, mid-1]), and the
          -- second call sorts [mid, u].
          loop l mid
          loop mid u
          -- Merge the sorted halves [l, mid-1] and [mid, u]; the split point
          -- is passed relative to the start of the slice.
          merge cmp (GM.unsafeSlice l (u - l + 1) src) buf (mid - l) count
{-# INLINE mergeSort #-}

-- Merge the two adjacent sorted runs @src[0 .. mid-1]@ and @src[mid ..]@ in
-- place, adding to @count@ the number of (lower, upper) element pairs that
-- were out of order with respect to @cmp@.
--
-- The lower run is first copied into the first @mid@ slots of @buf@; the
-- merged output is then written back into @src@ from index 0.
--
-- NOTE(review): all accesses are unchecked. The code reads index 0 of both
-- runs unconditionally, so it relies on @0 < mid < GM.length src@ and on
-- @buf@ holding at least @mid@ elements; the only caller visible here
-- ('mergeSort') appears to satisfy this.
merge :: GM.MVector v e
      => (e -> e -> Ordering)
      -> v s e
      -> v s e
      -> Int
      -> STRef s Integer
      -> ST s ()
merge cmp src buf mid count = do GM.unsafeCopy tmp lower
                                 eTmp <- GM.unsafeRead tmp 0
                                 eUpp <- GM.unsafeRead upper 0
                                 loop tmp 0 eTmp upper 0 eUpp 0
  where
    lower = GM.unsafeSlice 0 mid src
    upper = GM.unsafeSlice mid (GM.length src - mid) src
    tmp = GM.unsafeSlice 0 mid buf

    -- An element of the upper run has just been written to the output.
    -- If the upper run is exhausted, the remainder of the lower run (still
    -- held in the scratch copy) is moved into place in one block copy;
    -- otherwise fetch the next upper element and continue merging.
    wroteHigh low iLow eLow high iHigh iIns
      | iHigh >= GM.length high =
          GM.unsafeCopy (GM.unsafeSlice iIns (GM.length low - iLow) src)
                        (GM.unsafeSlice iLow (GM.length low - iLow) low)
      | otherwise = do eHigh <- GM.unsafeRead high iHigh
                       loop low iLow eLow high iHigh eHigh iIns

    -- An element of the lower run has just been written to the output.
    -- If the lower run is exhausted, the rest of the upper run is already in
    -- its final position inside src, so there is nothing left to do.
    wroteLow low iLow high iHigh eHigh iIns
      | iLow  >= GM.length low  = return ()
      | otherwise = do eLow <- GM.unsafeRead low iLow
                       loop low iLow eLow high iHigh eHigh iIns

    -- Compare the current heads of both runs and emit the smaller one at
    -- output position iIns. On a tie the lower element wins.
    loop !low !iLow !eLow !high !iHigh !eHigh !iIns = case cmp eHigh eLow of
        -- The upper element is strictly smaller than the current lower head,
        -- and therefore than every element still left in the lower run:
        -- each of those forms one out-of-order pair with it.
        LT -> do GM.unsafeWrite src iIns eHigh
                 modifySTRef' count (+ fromIntegral (GM.length low - iLow))
                 wroteHigh low iLow eLow high (iHigh+1) (iIns+1)
        _  -> do GM.unsafeWrite src iIns eLow
                 wroteLow low (iLow+1) high iHigh eHigh (iIns+1)
{-# INLINE merge #-}

-- $references
--
-- * William R. Knight. (1966) A computer method for calculating Kendall's Tau
--   with ungrouped data. /Journal of the American Statistical Association/,
--   Vol. 61, No. 314, Part 1, pp. 436-439. <http://www.jstor.org/pss/2282833>
--