-- |
-- Module    : Statistics.Matrix.Mutable
-- Copyright : (c) 2014 Bryan O'Sullivan
-- License   : BSD3
--
-- Basic mutable matrix operations.

module Statistics.Matrix.Mutable
    (
      MMatrix(..)
    , MVector
    , replicate
    , thaw
    , bounds
    , unsafeNew
    , unsafeFreeze
    , unsafeRead
    , unsafeWrite
    , unsafeModify
    , immutably
    , unsafeBounds
    ) where

import Control.Applicative ((<$>))
import Control.DeepSeq (NFData(..))
import Control.Monad.ST (ST)
import Statistics.Matrix.Types (Matrix(..), MMatrix(..), MVector)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Prelude hiding (replicate)

-- | Allocate a matrix with the given number of rows and columns, with
-- every element initialized to the given value.
--
-- Calls 'error' if either dimension is negative, or if the total number
-- of elements @rows * columns@ does not fit in an 'Int'. The checks are
-- needed because 'M.replicate' treats a negative length as 0, which would
-- otherwise produce a matrix whose stored dimensions disagree with the
-- length of its underlying vector.
replicate :: Int -> Int -> Double -> ST s (MMatrix s)
replicate r c k
  | r < 0     = error "Statistics.Matrix.Mutable.replicate: negative number of rows"
  | c < 0     = error "Statistics.Matrix.Mutable.replicate: negative number of columns"
  -- For non-negative r and c, r*c overflows exactly when r > maxBound `quot` c.
  | c /= 0 && r > maxBound `quot` c =
      error "Statistics.Matrix.Mutable.replicate: matrix size overflows Int"
  | otherwise = MMatrix r c <$> M.replicate (r * c) k

-- | Create a mutable matrix from an immutable 'Matrix'. The element
-- vector is copied with 'U.thaw', so writes to the result do not affect
-- the original matrix.
thaw :: Matrix -> ST s (MMatrix s)
thaw (Matrix rows cols elems) = do
  buf <- U.thaw elems
  return (MMatrix rows cols buf)

-- | Convert a mutable matrix into an immutable 'Matrix' without copying
-- its elements (the vector is frozen with 'U.unsafeFreeze'). Because the
-- result shares storage with the argument, the mutable matrix should not
-- be written to afterwards.
unsafeFreeze :: MMatrix s -> ST s Matrix
unsafeFreeze (MMatrix rows cols buf) = fmap (Matrix rows cols) (U.unsafeFreeze buf)

-- | Allocate a new matrix. The matrix content is not initialized, hence
-- the function is unsafe: callers should write each element before
-- reading it.
--
-- Calls 'error' if either dimension is negative, or if the total number
-- of elements @rows * columns@ does not fit in an 'Int'.
unsafeNew :: Int                -- ^ Number of rows
          -> Int                -- ^ Number of columns
          -> ST s (MMatrix s)
unsafeNew r c
  | r < 0     = error "Statistics.Matrix.Mutable.unsafeNew: negative number of rows"
  | c < 0     = error "Statistics.Matrix.Mutable.unsafeNew: negative number of columns"
  -- Without this guard r*c could silently wrap around and allocate a
  -- buffer much smaller than the dimensions recorded in the matrix.
  -- For non-negative r and c, r*c overflows exactly when r > maxBound `quot` c.
  | c /= 0 && r > maxBound `quot` c =
      error "Statistics.Matrix.Mutable.unsafeNew: matrix size overflows Int"
  | otherwise = do
      vec <- M.new (r * c)
      return $ MMatrix r c vec

-- | Read the element at the given row and column. Neither index is
-- bounds-checked; the offset is computed by 'unsafeBounds'.
unsafeRead :: MMatrix s -> Int -> Int -> ST s Double
unsafeRead m row col = unsafeBounds m row col M.unsafeRead
{-# INLINE unsafeRead #-}

-- | Store a value at the given row and column. Neither index is
-- bounds-checked; the offset is computed by 'unsafeBounds'.
unsafeWrite :: MMatrix s -> Int -> Int -> Double -> ST s ()
unsafeWrite m row col x = unsafeBounds m row col store
  where
    store buf ix = M.unsafeWrite buf ix x
{-# INLINE unsafeWrite #-}

-- | Replace the element at the given row and column with the result of
-- applying the function to its current value. Neither index is
-- bounds-checked; the offset is computed by 'unsafeBounds'.
unsafeModify :: MMatrix s -> Int -> Int -> (Double -> Double) -> ST s ()
unsafeModify m row col f = unsafeBounds m row col $ \buf ix ->
  M.unsafeRead buf ix >>= M.unsafeWrite buf ix . f
{-# INLINE unsafeModify #-}

-- | Given row and column numbers, calculate the offset into the flat
-- row-major vector and pass the matrix's vector together with that
-- offset to the continuation. The offset is forced before the
-- continuation is called.
--
-- Calls 'error' if the row is outside @[0, rows)@ or the column is
-- outside @[0, columns)@. The message names the offending index and
-- the matrix dimension.
bounds :: MMatrix s -> Int -> Int -> (MVector s -> Int -> r) -> r
bounds (MMatrix rs cs mv) r c k
  | r < 0 || r >= rs =
      error $ "Statistics.Matrix.Mutable.bounds: row out of bounds (row "
              ++ show r ++ ", rows " ++ show rs ++ ")"
  | c < 0 || c >= cs =
      error $ "Statistics.Matrix.Mutable.bounds: column out of bounds (column "
              ++ show c ++ ", columns " ++ show cs ++ ")"
  | otherwise = k mv $! r * cs + c
{-# INLINE bounds #-}

-- | Given row and column numbers, calculate the offset into the flat
-- row-major vector, without checking, and hand the matrix's vector and
-- that offset to the continuation. The offset is evaluated before the
-- continuation is invoked.
unsafeBounds :: MMatrix s -> Int -> Int -> (MVector s -> Int -> r) -> r
unsafeBounds (MMatrix _ cols buf) row col k =
  let offset = row * cols + col
  in offset `seq` k buf offset
{-# INLINE unsafeBounds #-}

-- | Apply a pure function to a frozen view of a mutable matrix, fully
-- evaluating the result (with 'rnf') before returning it.
--
-- The view comes from 'unsafeFreeze', so no copy is made. Forcing the
-- result here makes @f@ consume the matrix's current contents rather
-- than whatever it holds after later writes.
--
-- NOTE(review): if @f@ returns the frozen 'Matrix' itself (or anything
-- sharing its storage), later writes to @mmat@ will still be visible
-- through it.
immutably :: NFData a => MMatrix s -> (Matrix -> a) -> ST s a
immutably mmat f =
  unsafeFreeze mmat >>= \frozen ->
    let result = f frozen
    in rnf result `seq` return result
{-# INLINE immutably #-}