{-# LANGUAGE FlexibleContexts #-}
module Statistics.Test.Internal (
    rank
  , rankUnsorted  
  , splitByTags  
  ) where

import Data.Ord
import           Data.Vector.Generic           ((!))
import qualified Data.Vector.Generic         as G
import qualified Data.Vector.Generic.Mutable as M
import Statistics.Function


-- Private data type for unfolding (state threaded through 'G.unfoldr'
-- in 'rank').
--
-- While 'rankCnt' is positive, 'rankVal' is emitted and 'rankCnt' is
-- decremented. When it is zero, the next run of ties is taken from the
-- front of 'rankVec'.
data Rank v a = Rank {
      rankCnt :: {-# UNPACK #-} !Int        -- Number of ranks still to emit
    , rankVal :: {-# UNPACK #-} !Double     -- Rank to emit (shared by a run of ties)
    , rankNum :: {-# UNPACK #-} !Double     -- Rank of the next unranked element
    , rankVec :: v a                        -- Remaining, not yet ranked, vector
    }

-- | Calculate rank of every element of sample. In case of ties ranks
--   are averaged. Sample should be already sorted in ascending order.
--
--   Rank is index of element in the sample, numeration starts from 1.
--   In case of ties the average of the ranks of the equal elements is
--   assigned to each of them.
--
--   Only adjacent elements are compared. Each group of ties is the
--   first remaining element plus the longest run after it that is
--   equivalent (according to @eq@) to that first element. Because the
--   first element always belongs to its own group, a relation that is
--   not reflexive for some value (e.g. '==' on a NaN) gives that
--   element a rank of its own instead of looping forever.
--
-- >>> rank (==) (fromList [10,20,30::Int])
-- fromList [1.0,2.0,3.0]
--
-- >>> rank (==) (fromList [10,10,10,30::Int])
-- fromList [2.0,2.0,2.0,4.0]
rank :: (G.Vector v a, G.Vector v Double)
     => (a -> a -> Bool)        -- ^ Equivalence relation
     -> v a                     -- ^ Vector to rank
     -> v Double
rank eq vec = G.unfoldr go (Rank 0 (-1) 1 vec)
  where
    -- No ranks pending: take the next run of ties from the front of the
    -- remaining vector. (The initial rankVal of -1 is never emitted,
    -- because rankCnt starts at 0.)
    go (Rank 0 _ r v)
      | G.null v  = Nothing
      | n == 1    = Just (r, Rank 0 0 (r+1) rest)
      -- A run of n ties occupies ranks r .. r+n-1; each gets their mean.
      | otherwise = go Rank { rankCnt = n
                            , rankVal = 0.5 * (r*2 + fromIntegral (n-1))
                            , rankNum = r + fromIntegral n
                            , rankVec = rest
                            }
      where
        -- Only demanded after the 'G.null' guard fails, so 'G.head' and
        -- 'G.tail' are applied to a non-empty vector. Counting the head
        -- explicitly guarantees n >= 1, so every step makes progress.
        x            = G.head v
        (ties, rest) = G.span (eq x) (G.tail v)
        n            = 1 + G.length ties
    -- Ranks pending: emit the shared rank of the current run once more.
    go (Rank n val r v) = Just (val, Rank (n-1) val r v)
{-# INLINE rank #-}

-- | Compute rank of every element of vector. Unlike 'rank' it doesn't
--   require sample to be sorted.
--
--   The sample is sorted internally (remembering each element's original
--   index) and ranked with 'rank', using '==' to detect ties. The ranks
--   are then written back to the elements' original positions. Numbering
--   starts from 1 and tied elements receive the average of their ranks.
rankUnsorted :: ( Ord a
                , G.Vector v a
                , G.Vector v Int
                , G.Vector v Double
                , G.Vector v (Int, a)
                )
             => v a
             -> v Double
rankUnsorted xs = G.create $ do
    -- Put ranks into their original positions. The element at sorted
    -- position i came from position (index ! i), so this is a scatter:
    -- out[index[i]] = ranks[i].
    -- NOTE: backpermute would compute the gather out[i] = ranks[index[i]],
    -- i.e. apply the inverse permutation, which is wrong here.
    -- 'unsafeWrite' relies on index holding the original indices of the
    -- n elements, so every target slot is within [0, n).
    vec <- M.new n
    for 0 n $ \i ->
      M.unsafeWrite vec (index ! i) (ranks ! i)
    return vec
  where
    n = G.length xs
    -- Calculate ranks for sorted array
    ranks = rank (==) sorted
    -- Sort vector and retain original indices of elements
    (index, sorted)
      = G.unzip
      $ sortBy (comparing snd)
      $ indexed xs
{-# INLINE rankUnsorted #-}


-- | Split tagged vector.
--
--   Returns the payloads tagged 'True' as the first component and the
--   payloads tagged 'False' as the second. The split uses
--   'G.unstablePartition', so the order of elements within each part is
--   not guaranteed to follow the input order.
splitByTags :: (G.Vector v a, G.Vector v (Bool,a)) => v (Bool,a) -> (v a, v a)
splitByTags vs = (G.map snd tagged, G.map snd untagged)
  where
    (tagged, untagged) = G.unstablePartition fst vs
{-# INLINE splitByTags #-}