{-# LANGUAGE BangPatterns, CPP, FlexibleContexts, Rank2Types #-}
{-# OPTIONS_GHC -fsimpl-tick-factor=200 #-}
-- |
-- Module    : Statistics.Function
-- Copyright : (c) 2009, 2010, 2011 Bryan O'Sullivan
-- License   : BSD3
--
-- Maintainer  : bos@serpentine.com
-- Stability   : experimental
-- Portability : portable
--
-- Useful functions.

module Statistics.Function
    (
    -- * Scanning
      minMax
    -- * Sorting
    , sort
    , gsort
    , sortBy
    , partialSort
    -- * Indexing
    , indexed
    , indices
    -- * Bit twiddling
    , nextHighestPowerOfTwo
    -- * Comparison
    , within
    -- * Arithmetic
    , square
    -- * Vectors
    , unsafeModify
    -- * Combinators
    , for
    , rfor
    ) where

#include "MachDeps.h"

import Control.Monad.ST (ST)
import Data.Bits ((.|.), shiftR)
import qualified Data.Vector.Algorithms.Intro as I
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Numeric.MathFunctions.Comparison (within)

-- | Sort an unboxed vector of 'Double's into ascending order.
--
-- The input is not changed: 'G.modify' runs the in-place introsort
-- from "Data.Vector.Algorithms.Intro" on a mutable copy and returns
-- the result as a new immutable vector.
sort :: U.Vector Double -> U.Vector Double
sort = G.modify I.sort
-- Deliberately kept out of line; presumably so that the monomorphic
-- sort is compiled once rather than duplicated at every call site.
{-# NOINLINE sort #-}

-- | Sort any vector whose elements have an 'Ord' instance.
--
-- This is the polymorphic counterpart of 'sort': it works for every
-- 'G.Vector' instance and is marked @INLINE@ so that it can be
-- specialised at the call site.
gsort :: (Ord e, G.Vector v e) => v e -> v e
gsort = G.modify I.sort
{-# INLINE gsort #-}

-- | Sort a vector using a custom ordering.
--
-- The comparison function @f@ is handed to 'I.sortBy', which sorts a
-- mutable copy of the input in place; the input itself is left
-- untouched.
sortBy :: (G.Vector v e) => I.Comparison e -> v e -> v e
sortBy f = G.modify (I.sortBy f)
{-# INLINE sortBy #-}

-- | Partially sort a vector, such that the least /k/ elements will be
-- at the front.
--
-- The work is delegated to 'I.partialSort' on a mutable copy of the
-- input; no guarantee is made here about the order of the remaining
-- elements.
partialSort :: (G.Vector v e, Ord e) =>
               Int -- ^ The number /k/ of least elements.
            -> v e
            -> v e
partialSort k = G.modify (\mv -> I.partialSort mv k)
{-# SPECIALIZE partialSort :: Int -> U.Vector Double -> U.Vector Double #-}

-- | Return the indices of a vector, i.e. the vector
-- @[0, 1, .., length a - 1]@.  An empty input yields an empty result.
--
-- 'G.enumFromN' is used rather than 'G.enumFromTo', as the vector
-- documentation recommends; starting at 0 with @length a@ elements
-- produces exactly the same values.
indices :: (G.Vector v a, G.Vector v Int) => v a -> v Int
indices a = G.enumFromN 0 (G.length a)
{-# INLINE indices #-}

-- | Zip a vector with its indices: element number /i/ of the result
-- is the pair @(i, a ! i)@.
indexed :: (G.Vector v e, G.Vector v Int, G.Vector v (Int,e)) => v e -> v (Int,e)
indexed a = G.zip (indices a) a
{-# INLINE indexed #-}

-- Strict, unpacked pair of 'Double's: the running (minimum, maximum)
-- accumulator used by 'minMax'.
data MM = MM {-# UNPACK #-} !Double {-# UNPACK #-} !Double

-- | Compute the minimum and maximum of a vector in one pass.
--
-- The result is @(minimum, maximum)@.  The fold starts from
-- @(+Infinity, -Infinity)@, so that is also what an empty vector
-- returns.
minMax :: (G.Vector v Double) => v Double -> (Double, Double)
minMax = fini . G.foldl' go (MM (1/0) (-1/0))
  where
    -- Fold one element into the running bounds.
    go :: MM -> Double -> MM
    go (MM lo hi) k = MM (min lo k) (max hi k)

    -- Convert the strict accumulator into an ordinary tuple.
    fini :: MM -> (Double, Double)
    fini (MM lo hi) = (lo, hi)
{-# INLINE minMax #-}

-- | Efficiently compute the next highest power of two for a
-- non-negative integer.  If the given value is already a power of
-- two, it is returned unchanged.  If zero or negative, zero is
-- returned.
--
-- The value @n - 1@ has its highest set bit copied into every lower
-- position by a sequence of shift-and-or steps; adding one then gives
-- the power of two.  No overflow check is made: if the answer does not
-- fit in an 'Int' the final addition wraps around.
nextHighestPowerOfTwo :: Int -> Int
nextHighestPowerOfTwo n
#if WORD_SIZE_IN_BITS == 64
  = 1 + _i32
#else
  = 1 + i16
#endif
  where
    i0   = n - 1
    i1   = i0  .|. i0  `shiftR` 1
    i2   = i1  .|. i1  `shiftR` 2
    i4   = i2  .|. i2  `shiftR` 4
    i8   = i4  .|. i4  `shiftR` 8
    i16  = i8  .|. i8  `shiftR` 16
    -- Only needed when Int is 64 bits wide; the leading underscore
    -- keeps the binding from being reported as unused otherwise.
    _i32 = i16 .|. i16 `shiftR` 32
-- It could be implemented as
--
-- > nextHighestPowerOfTwo n = 1 + foldl' go (n-1) [1, 2, 4, 8, 16, 32]
-- >   where go m i = m .|. m `shiftR` i
--
-- But GHC does not inline foldl' (probably because it is recursive),
-- and as a result the function walks a list of boxed ints. The
-- hand-rolled version uses unboxed arithmetic.

-- | Multiply a number by itself, i.e. compute @x * x@.
square :: Double -> Double
square x = x * x

-- | Simple for loop.  Counts from /start/ to /end/-1, running the
-- action once for each index in ascending order.
--
-- If /start/ is not less than /end/ the loop body is never run.
for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for n0 !n f = loop n0
  where
    -- The guard is '>=' rather than '==' so that a start beyond the
    -- end terminates immediately instead of counting upwards until
    -- the index wraps around.
    loop i | i >= n    = return ()
           | otherwise = f i >> loop (i+1)
{-# INLINE for #-}

-- | Simple reverse-for loop.  Counts from /start/-1 down to /end/,
-- running the action once for each index in descending order.
--
-- /end/ is expected to be less than /start/; if it is not, the loop
-- body is never run.
rfor :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
rfor n0 !n f = loop n0
  where
    -- The guard is '<=' rather than '==' so that a start at or below
    -- the end terminates immediately instead of counting downwards
    -- until the index wraps around.
    loop i | i <= n    = return ()
           | otherwise = let i' = i-1 in f i' >> loop i'
{-# INLINE rfor #-}

-- | Replace the element at index @i@ of a mutable vector with the
-- result of applying @f@ to it.
--
-- The element is accessed with 'M.unsafeRead' and 'M.unsafeWrite',
-- which perform no bounds checking, so the caller must ensure that
-- @i@ is a valid index of @v@.
unsafeModify :: M.MVector s Double -> Int -> (Double -> Double) -> ST s ()
unsafeModify v i f = do
  k <- M.unsafeRead v i
  M.unsafeWrite v i (f k)
{-# INLINE unsafeModify #-}