module Statistics.Correlation.Kendall
( kendall
) where
import Control.Monad.ST (ST, runST)
import Data.Bits (shiftR)
import Data.Function (on)
import Data.STRef
import qualified Data.Vector.Algorithms.Intro as I
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM
-- | Compute Kendall's rank correlation coefficient for a vector of
-- paired observations, using the tie-corrected form
--
-- > (n_c - n_d) / sqrt ((n_0 - tieX) * (n_0 - tieY))
--
-- where @n_0 = n*(n-1)/2@ is the number of pairs, @n_c@ and @n_d@ are the
-- numbers of concordant and discordant pairs, and @tieX@ / @tieY@ count
-- the pairs tied on the first / second component.
--
-- Returns NaN when the vector holds fewer than two pairs. The result is
-- also NaN when the denominator vanishes, e.g. when every first (or
-- every second) component is identical.
--
-- The input vector is not modified: all work happens on a thawed copy.
kendall :: (Ord a, Ord b, G.Vector v (a, b)) => v (a, b) -> Double
kendall xy'
  | G.length xy' <= 1 = 0/0
  | otherwise = runST $ do
      xy <- G.thaw xy'
      let n = GM.length xy
      n_dRef <- newSTRef 0
      -- Lexicographic sort on (x, y): equal x values (and equal pairs)
      -- become adjacent so that runs of ties can be counted linearly.
      I.sort xy
      tieX  <- numOfTiesBy ((==) `on` fst) xy
      tieXY <- numOfTiesBy (==) xy
      -- Re-sort by y alone; every inversion removed by the merge sort is
      -- one discordant pair, accumulated into n_dRef.
      tmp <- GM.new n
      mergeSort (compare `on` snd) xy tmp n_dRef
      -- Now that the data is ordered by y, equal y values are adjacent.
      tieY <- numOfTiesBy ((==) `on` snd) xy
      n_d <- readSTRef n_dRef
      let n_0 = (fromIntegral n * (fromIntegral n - 1)) `shiftR` 1 :: Integer
          -- Pairs tied on both components were subtracted twice (once in
          -- tieX, once in tieY), hence the correction by tieXY.
          n_c = n_0 - n_d - tieX - tieY + tieXY
      return $ fromIntegral (n_c - n_d) /
               (sqrt . fromIntegral) ((n_0 - tieX) * (n_0 - tieY))
-- | Count tied pairs in a mutable vector.
--
-- The vector is scanned once; for every maximal run of @k@ consecutive
-- elements in which each neighbouring pair satisfies the predicate, the
-- run contributes @k*(k-1)/2@ (the number of unordered pairs inside the
-- run) to the total. The result is only a true tie count when elements
-- considered equal by the predicate are adjacent, i.e. the vector has
-- been sorted compatibly beforehand.
--
-- An empty or single-element vector yields 0.
numOfTiesBy :: GM.MVector v a
            => (a -> a -> Bool) -> v s a -> ST s Integer
numOfTiesBy f xs = do count <- newSTRef (0::Integer)
                      loop count (1::Int) (0::Int)
                      readSTRef count
  where
    n = GM.length xs
    -- acc is the length of the run currently being extended, i the index
    -- of its last element. The guard keeps i+1 in bounds for unsafeRead.
    loop c !acc !i
      | i >= n - 1 = modifySTRef' c (+ g acc)
      | otherwise = do
          x1 <- GM.unsafeRead xs i
          x2 <- GM.unsafeRead xs (i+1)
          if f x1 x2
            then loop c (acc+1) (i+1)
            else modifySTRef' c (+ g acc) >> loop c 1 (i+1)
    -- Number of unordered pairs in a run of length x. Widen to Integer
    -- before multiplying so that long runs cannot overflow Int.
    g :: Int -> Integer
    g x = (x' * (x' - 1)) `shiftR` 1
      where x' = toInteger x
-- | Sort a mutable vector in place with a top-down merge sort while
-- counting inversions.
--
-- Arguments, in order: the comparison function, the vector to sort, a
-- scratch buffer (handed to 'merge'; 'kendall' allocates it with the same
-- length as the vector), and a counter that is incremented by the number
-- of inversions removed. The counter is only ever added to, so the caller
-- decides its initial value.
--
-- Elements that compare 'EQ' are never counted as inverted.
mergeSort :: GM.MVector v e
          => (e -> e -> Ordering)
          -> v s e
          -> v s e
          -> STRef s Integer
          -> ST s ()
mergeSort cmp src buf count = loop 0 (GM.length src - 1)
  where
    -- Sort the inclusive index range [l, u].
    loop l u
      -- Zero or one element. Using (<=) rather than (==) also covers the
      -- empty vector, where the initial call is @loop 0 (-1)@ and would
      -- otherwise recurse forever.
      | u <= l = return ()
      -- Exactly two elements: swap if out of order (one inversion).
      | u - l == 1 = do
          eL <- GM.unsafeRead src l
          eU <- GM.unsafeRead src u
          case cmp eL eU of
            GT -> do GM.unsafeWrite src l eU
                     GM.unsafeWrite src u eL
                     modifySTRef' count (+1)
            _  -> return ()
      | otherwise = do
          let mid = (u + l) `shiftR` 1
          -- The two recursive ranges share index mid; afterwards both
          -- [l, mid-1] and [mid, u] are sorted, which is what 'merge'
          -- expects of its lower and upper halves.
          loop l mid
          loop mid u
          merge cmp (GM.unsafeSlice l (u - l + 1) src) buf (mid - l) count
-- | Merge two adjacent sorted runs of a mutable vector in place,
-- counting the inversions between them.
--
-- Arguments, in order: the comparison function; the vector @src@ whose
-- first @mid@ elements and remaining elements are each already sorted;
-- a scratch buffer with room for at least @mid@ elements; the split
-- position @mid@; and the inversion counter.
--
-- Precondition: @0 < mid < length src@. The first element of each half is
-- read unconditionally, and all accesses are unchecked.
--
-- Whenever an element of the upper half is emitted ahead of the remaining
-- lower-half elements (it compares strictly 'LT'), the counter grows by
-- the number of lower-half elements still pending.
merge :: GM.MVector v e
      => (e -> e -> Ordering)
      -> v s e
      -> v s e
      -> Int
      -> STRef s Integer
      -> ST s ()
merge cmp src buf mid count = do GM.unsafeCopy tmp lower
                                 eTmp <- GM.unsafeRead tmp 0
                                 eUpp <- GM.unsafeRead upper 0
                                 loop tmp 0 eTmp upper 0 eUpp 0
  where
    lower = GM.unsafeSlice 0 mid src
    upper = GM.unsafeSlice mid (GM.length src - mid) src
    -- The lower half is moved to scratch space so that src can be
    -- overwritten from the front while the upper half is read in place.
    tmp   = GM.unsafeSlice 0 mid buf
    -- Called after an element of the upper half has been written.
    wroteHigh low iLow eLow high iHigh iIns
      -- Upper half exhausted: flush what is left of the lower half.
      | iHigh >= GM.length high =
          GM.unsafeCopy (GM.unsafeSlice iIns (GM.length low - iLow) src)
                        (GM.unsafeSlice iLow (GM.length low - iLow) low)
      | otherwise = do eHigh <- GM.unsafeRead high iHigh
                       loop low iLow eLow high iHigh eHigh iIns
    -- Called after an element of the lower half has been written.
    wroteLow low iLow high iHigh eHigh iIns
      -- Lower half exhausted: the rest of the upper half already sits at
      -- its final position inside src, so there is nothing to copy.
      | iLow >= GM.length low = return ()
      | otherwise = do eLow <- GM.unsafeRead low iLow
                       loop low iLow eLow high iHigh eHigh iIns
    -- Compare the current heads of both halves and emit the smaller one;
    -- on ties the lower-half element goes first and nothing is counted.
    loop !low !iLow !eLow !high !iHigh !eHigh !iIns = case cmp eHigh eLow of
      LT -> do GM.unsafeWrite src iIns eHigh
               modifySTRef' count (+ fromIntegral (GM.length low - iLow))
               wroteHigh low iLow eLow high (iHigh+1) (iIns+1)
      _  -> do GM.unsafeWrite src iIns eLow
               wroteLow low (iLow+1) high iHigh eHigh (iIns+1)
#if !MIN_VERSION_base(4,6,0)
-- | Fallback used only when building against @base@ older than 4.6.
-- Applies the function and forces the result to weak head normal form
-- before storing it, so repeated counter updates do not pile up a chain
-- of unevaluated thunks inside the reference.
modifySTRef' :: STRef s a -> (a -> a) -> ST s ()
modifySTRef' ref f = do
  x <- readSTRef ref
  let x' = f x
  x' `seq` writeSTRef ref x'
#endif