{-# LANGUAGE BangPatterns, FlexibleContexts #-}
-- |
-- Module      : Statistics.Correlation.Kendall
--
-- Fast O(NlogN) implementation of
-- <http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient Kendall's tau>.
--
-- This module implements Kendall's tau form b which allows ties in the data.
-- This is the same formula used by other statistical packages, e.g., R, matlab.
--
-- > \tau = \frac{n_c - n_d}{\sqrt{(n_0 - n_1)(n_0 - n_2)}}
--
-- where n_0 = n(n-1)\/2, n_1 = number of pairs tied for the first quantity,
-- n_2 = number of pairs tied for the second quantity,
-- n_c = number of concordant pairs, n_d = number of discordant pairs.

module Statistics.Correlation.Kendall
    ( kendall

    -- * References
    -- $references
    ) where

import Control.Monad.ST (ST, runST)
import Data.Bits (shiftR)
import Data.Function (on)
import Data.STRef
import qualified Data.Vector.Algorithms.Intro as I
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM

-- | /O(n log n)/ Compute Kendall's tau (form b) from a vector of paired data.
--
-- Returns NaN when the vector holds fewer than two observations.  The
-- result is also not a finite number when every first component or every
-- second component is identical, because the denominator
-- @sqrt ((n_0 - tieX) * (n_0 - tieY))@ is then zero.
--
-- The input vector is not modified; all sorting happens on a thawed copy.
kendall :: (Ord a, Ord b, G.Vector v (a, b)) => v (a, b) -> Double
kendall xy'
  | G.length xy' <= 1 = 0 / 0
  | otherwise = runST $ do
      xy <- G.thaw xy'
      let n = GM.length xy
      n_dRef <- newSTRef 0
      -- Lexicographic sort (by first component, then by second), so that
      -- equal first components and fully equal pairs become adjacent.
      I.sort xy
      tieX <- numOfTiesBy ((==) `on` fst) xy
      tieXY <- numOfTiesBy (==) xy
      -- Re-sort by the second component only; 'mergeSort' accumulates the
      -- number of exchanges (the discordant pairs) into n_dRef.
      tmp <- GM.new n
      mergeSort (compare `on` snd) xy tmp n_dRef
      -- The vector is now ordered by the second component, so ties in it
      -- are adjacent and can be counted the same way.
      tieY <- numOfTiesBy ((==) `on` snd) xy
      n_d <- readSTRef n_dRef
      let -- Total number of pairs, computed in Integer to avoid overflow.
          n_0 = (fromIntegral n * (fromIntegral n - 1)) `shiftR` 1 :: Integer
          -- Pairs tied in both components were subtracted twice (once in
          -- tieX, once in tieY), hence tieXY is added back.
          n_c = n_0 - n_d - tieX - tieY + tieXY
      return $ fromIntegral (n_c - n_d) /
               (sqrt . fromIntegral) ((n_0 - tieX) * (n_0 - tieY))
{-# INLINE kendall #-}

-- Count the number of tied pairs in a vector, where two elements are tied
-- when the supplied equivalence predicate holds for them.
--
-- Only adjacent elements are compared, so the vector must already be
-- ordered such that equivalent elements are contiguous.  Every maximal run
-- of k equivalent elements contributes k*(k-1)/2 tied pairs.
numOfTiesBy :: GM.MVector v a
            => (a -> a -> Bool) -> v s a -> ST s Integer
numOfTiesBy f xs = do
    count <- newSTRef (0 :: Integer)
    loop count (1 :: Int) (0 :: Int)
    readSTRef count
  where
    n = GM.length xs
    -- c   : running total of tied pairs
    -- acc : length of the run of equivalent elements ending at index i
    -- i   : index of the left element of the pair being compared
    loop c !acc !i
      | i >= n - 1 = modifySTRef' c (+ g acc)   -- flush the final run
      | otherwise = do
          x1 <- GM.unsafeRead xs i
          x2 <- GM.unsafeRead xs (i + 1)
          if f x1 x2
            then loop c (acc + 1) (i + 1)
            -- The run ends here: record its pairs and start a new run.
            else modifySTRef' c (+ g acc) >> loop c 1 (i + 1)
    -- Number of unordered pairs in a run of length k.  The run length is
    -- widened to Integer before multiplying so the product cannot overflow
    -- Int for very long runs.
    g :: Int -> Integer
    g k = (k' * (k' - 1)) `shiftR` 1
      where
        k' = toInteger k
{-# INLINE numOfTiesBy #-}

-- Implementation of Knight's merge sort (adapted from vector-algorithms).
-- This function is used to count the number of discordant pairs.
--
-- Sorts @src@ in place according to @cmp@, using @buf@ as scratch space,
-- and adds to @count@ the number of pairs that were out of order.
-- Elements comparing equal are never counted and never exchanged.
mergeSort :: GM.MVector v e
          => (e -> e -> Ordering)
          -> v s e
          -> v s e
          -> STRef s Integer
          -> ST s ()
mergeSort cmp src buf count = loop 0 (GM.length src - 1)
  where
    -- Sort the inclusive index range [l, u].
    loop l u
      -- Zero or one element: nothing to do.  Using (<=) rather than (==)
      -- makes the empty vector (u == -1) terminate instead of recursing
      -- forever.
      | u <= l = return ()
      -- Exactly two elements: swap them if they are out of order.
      | u - l == 1 = do
          eL <- GM.unsafeRead src l
          eU <- GM.unsafeRead src u
          case cmp eL eU of
            GT -> do
              GM.unsafeWrite src l eU
              GM.unsafeWrite src u eL
              modifySTRef' count (+ 1)
            _ -> return ()
      | otherwise = do
          let mid = (u + l) `shiftR` 1
          -- Both recursive calls include index mid.  After the first call
          -- it holds the maximum of [l, mid], and the second call sorts it
          -- into [mid, u]; [l, mid-1] stays sorted, which is what 'merge'
          -- needs for its split point (mid - l).
          loop l mid
          loop mid u
          merge cmp (GM.unsafeSlice l (u - l + 1) src) buf (mid - l) count
{-# INLINE mergeSort #-}

-- Merge the two adjacent sorted runs @src[0 .. mid-1]@ and @src[mid ..]@
-- in place, adding to @count@ the number of (lower, upper) element pairs
-- that are out of order with respect to @cmp@.
--
-- The lower run is first copied into the first @mid@ slots of @buf@, and
-- the merged output is then written back into @src@ from index 0.
--
-- Preconditions (not checked, all accesses are unsafe): both runs are
-- non-empty, i.e. @0 < mid < GM.length src@, and @buf@ has at least @mid@
-- elements.
merge :: GM.MVector v e
      => (e -> e -> Ordering)
      -> v s e
      -> v s e
      -> Int
      -> STRef s Integer
      -> ST s ()
merge cmp src buf mid count = do
    GM.unsafeCopy tmp lower
    eTmp <- GM.unsafeRead tmp 0
    eUpp <- GM.unsafeRead upper 0
    loop tmp 0 eTmp upper 0 eUpp 0
  where
    lower = GM.unsafeSlice 0 mid src
    upper = GM.unsafeSlice mid (GM.length src - mid) src
    tmp = GM.unsafeSlice 0 mid buf

    -- An element of the upper run was just written.  If the upper run is
    -- exhausted, copy the remainder of the lower run into place; otherwise
    -- fetch the next upper element and continue merging.
    wroteHigh low iLow eLow high iHigh iIns
      | iHigh >= GM.length high =
          GM.unsafeCopy (GM.unsafeSlice iIns (GM.length low - iLow) src)
                        (GM.unsafeSlice iLow (GM.length low - iLow) low)
      | otherwise = do
          eHigh <- GM.unsafeRead high iHigh
          loop low iLow eLow high iHigh eHigh iIns

    -- An element of the lower run was just written.  If the lower run is
    -- exhausted, the rest of the upper run is already in its final
    -- position inside src, so there is nothing left to do.
    wroteLow low iLow high iHigh eHigh iIns
      | iLow >= GM.length low = return ()
      | otherwise = do
          eLow <- GM.unsafeRead low iLow
          loop low iLow eLow high iHigh eHigh iIns

    -- Compare the current heads of both runs and emit the smaller one at
    -- index iIns of src; on equality the lower element is emitted.
    loop !low !iLow !eLow !high !iHigh !eHigh !iIns =
      case cmp eHigh eLow of
        LT -> do
          GM.unsafeWrite src iIns eHigh
          -- eHigh is smaller than every element still pending in the
          -- lower run, so each of them forms one out-of-order pair with it.
          modifySTRef' count (+ fromIntegral (GM.length low - iLow))
          wroteHigh low iLow eLow high (iHigh + 1) (iIns + 1)
        _ -> do
          GM.unsafeWrite src iIns eLow
          wroteLow low (iLow + 1) high iHigh eHigh (iIns + 1)
{-# INLINE merge #-}

-- $references
--
-- * William R. Knight. (1966) A computer method for calculating Kendall's Tau
--   with ungrouped data. /Journal of the American Statistical Association/,
--   Vol. 61, No. 314, Part 1, pp. 436-439. <http://www.jstor.org/pss/2282833>
--