module Statistics.Matrix.Mutable
(
MMatrix(..)
, MVector
, replicate
, thaw
, bounds
, unsafeNew
, unsafeFreeze
, unsafeRead
, unsafeWrite
, unsafeModify
, immutably
, unsafeBounds
) where
import Control.Applicative ((<$>))
import Control.DeepSeq (NFData(..))
import Control.Monad.ST (ST)
import Statistics.Matrix.Types (Matrix(..), MMatrix(..), MVector)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
import Prelude hiding (replicate)
-- | Allocate a mutable matrix in which every element holds the same
-- value.
--
-- The backing vector has @rows * columns@ elements.
replicate :: Int      -- ^ Number of rows.
          -> Int      -- ^ Number of columns.
          -> Double   -- ^ Value stored in every element.
          -> ST s (MMatrix s)
replicate r c k = MMatrix r c <$> M.replicate (r * c) k
-- | Build a mutable matrix from an immutable one.
--
-- The dimensions are carried over unchanged and the element vector is
-- passed through 'U.thaw', which (per the @vector@ documentation)
-- yields a mutable copy, so the input matrix is left untouched.
thaw :: Matrix -> ST s (MMatrix s)
thaw (Matrix r c v) = MMatrix r c <$> U.thaw v
-- | Turn a mutable matrix into an immutable one with the same
-- dimensions.
--
-- The element vector is converted with 'U.unsafeFreeze'.  Per the
-- @vector@ documentation that conversion does not copy, so the mutable
-- matrix must not be modified after this call.
unsafeFreeze :: MMatrix s -> ST s Matrix
unsafeFreeze (MMatrix r c mv) = Matrix r c <$> U.unsafeFreeze mv
-- | Allocate a new mutable matrix of the given size.
--
-- This function writes no initial value into the matrix, so callers
-- should not rely on any particular contents.
--
-- Calls 'error' if either dimension is negative.
unsafeNew :: Int                -- ^ Number of rows.
          -> Int                -- ^ Number of columns.
          -> ST s (MMatrix s)
unsafeNew r c
  | r < 0     = error "Statistics.Matrix.Mutable.unsafeNew: negative number of rows"
  | c < 0     = error "Statistics.Matrix.Mutable.unsafeNew: negative number of columns"
  | otherwise = MMatrix r c <$> M.new (r * c)
-- | Read the element at the given row and column.
--
-- The position is translated to a flat index by 'unsafeBounds' and
-- fetched with 'M.unsafeRead'; neither step checks that the indices
-- are in range.
unsafeRead :: MMatrix s
           -> Int   -- ^ Row.
           -> Int   -- ^ Column.
           -> ST s Double
unsafeRead mat r c = unsafeBounds mat r c M.unsafeRead
{-# INLINE unsafeRead #-}
-- | Store a value at the given row and column.
--
-- The position is translated to a flat index by 'unsafeBounds' and
-- written with 'M.unsafeWrite'; neither step checks that the indices
-- are in range.
unsafeWrite :: MMatrix s
            -> Int      -- ^ Row.
            -> Int      -- ^ Column.
            -> Double   -- ^ New value.
            -> ST s ()
unsafeWrite mat row col k = unsafeBounds mat row col $ \v i ->
  M.unsafeWrite v i k
{-# INLINE unsafeWrite #-}
-- | Replace the element at the given row and column with the result of
-- applying a function to its current value.
--
-- The position is translated to a flat index by 'unsafeBounds'; the
-- element is then read with 'M.unsafeRead' and written back with
-- 'M.unsafeWrite'.  None of these steps checks that the indices are in
-- range.
unsafeModify :: MMatrix s
             -> Int                  -- ^ Row.
             -> Int                  -- ^ Column.
             -> (Double -> Double)   -- ^ Function applied to the old value.
             -> ST s ()
unsafeModify mat row col f = unsafeBounds mat row col $ \v i -> do
  k <- M.unsafeRead v i
  M.unsafeWrite v i (f k)
{-# INLINE unsafeModify #-}
-- | Given row and column numbers, check them against the matrix
-- dimensions and hand the underlying vector together with the flat
-- offset @row * columns + column@ to the supplied continuation.
--
-- Calls 'error' if the row is not in @[0, rows)@ or the column is not
-- in @[0, columns)@.  The offset is forced (via '$!') before the
-- continuation receives it.
bounds :: MMatrix s
       -> Int                       -- ^ Row.
       -> Int                       -- ^ Column.
       -> (MVector s -> Int -> r)   -- ^ Continuation taking the vector and offset.
       -> r
bounds (MMatrix rs cs mv) r c k
  | r < 0 || r >= rs = error "row out of bounds"
  | c < 0 || c >= cs = error "column out of bounds"
  | otherwise        = k mv $! r * cs + c
{-# INLINE bounds #-}
-- | Given row and column numbers, hand the underlying vector together
-- with the flat offset @row * columns + column@ to the supplied
-- continuation.
--
-- Unlike 'bounds', the indices are not validated and the stored row
-- count is ignored.  The offset is forced (via '$!') before the
-- continuation receives it.
unsafeBounds :: MMatrix s
             -> Int                       -- ^ Row.
             -> Int                       -- ^ Column.
             -> (MVector s -> Int -> r)   -- ^ Continuation taking the vector and offset.
             -> r
unsafeBounds (MMatrix _ cs mv) r c k = k mv $! r * cs + c
{-# INLINE unsafeBounds #-}
-- | Apply a pure function to an immutable view of a mutable matrix.
--
-- The view is obtained with 'unsafeFreeze', and the function's result
-- is reduced to normal form with 'rnf' before it is returned.
--
-- NOTE(review): the forcing is presumably there so that no unevaluated
-- part of the result still refers to the matrix when it is mutated
-- afterwards -- confirm against the callers.
immutably :: NFData a => MMatrix s -> (Matrix -> a) -> ST s a
immutably mmat f = do
  k <- f <$> unsafeFreeze mmat
  rnf k `seq` return k
{-# INLINE immutably #-}