-- |
-- Module    : Statistics.Matrix.Algorithms
-- Copyright : 2014 Bryan O'Sullivan
-- License   : BSD3
--
-- Useful matrix functions.

module Statistics.Matrix.Algorithms
    (
      qr
    ) where

import Control.Applicative ((<$>), (<*>))
import Control.Monad.ST (ST, runST)
import Prelude hiding (replicate)
import Numeric.Sum       (sumVector,kbn)
import Statistics.Matrix (Matrix, column, dimension, for, norm)
import qualified Statistics.Matrix.Mutable as M
import qualified Data.Vector.Unboxed as U

-- | /O(r*c^2)/ Compute the QR decomposition of a matrix using
-- modified Gram–Schmidt orthogonalisation.
-- The result returned is the matrices (/q/,/r/).
--
-- For an @m@×@n@ input, /q/ has the same dimensions as the input and
-- /r/ is an @n@×@n@ upper-triangular matrix.
--
-- If a column has zero norm at the moment it is normalised (for example
-- because the input is rank-deficient), it is left as a column of zeros
-- in /q/ instead of being divided by zero. The matching diagonal entry
-- of /r/ is 0. Without this guard, the resulting NaNs would spread into
-- every subsequent column of /q/ and /r/.
qr :: Matrix -> (Matrix, Matrix)
qr mat = runST $ do
  let (m,n) = dimension mat
  r <- M.replicate n n 0
  a <- M.thaw mat
  for 0 n $ \j -> do
    -- Length of the (already orthogonalised) column j becomes r[j][j].
    cn <- M.immutably a $ \aa -> norm (column aa j)
    M.unsafeWrite r j j cn
    -- Normalise column j. Skip the division for a zero column so it
    -- stays zero instead of turning into NaN.
    if cn == 0
      then return ()
      else for 0 m $ \i -> M.unsafeModify a i j (/ cn)
    -- Remove the component along column j from every later column jj,
    -- recording the projection coefficient in r[j][jj].
    for (j+1) n $ \jj -> do
      p <- innerProduct a j jj
      M.unsafeWrite r j jj p
      for 0 m $ \i -> do
        aij <- M.unsafeRead a i j
        M.unsafeModify a i jj $ subtract (p * aij)
  (,) <$> M.unsafeFreeze a <*> M.unsafeFreeze r

-- | Dot product of columns @j@ and @k@ of a mutable matrix.
--
-- The pure computation runs through 'M.immutably' on a read-only view
-- of the matrix. The element-wise products are summed with the
-- compensated 'kbn' summation from "Numeric.Sum" to reduce rounding
-- error.
innerProduct :: M.MMatrix s -> Int -> Int -> ST s Double
innerProduct mmat j k = M.immutably mmat dotColumns
  where
    dotColumns frozen =
      let colJ = column frozen j
          colK = column frozen k
      in sumVector kbn (U.zipWith (*) colJ colK)