{-# LANGUAGE Rank2Types    #-}
{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE ViewPatterns #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Internal.ST
-- Copyright   :  (c) Alberto Ruiz 2008
-- License     :  BSD3
-- Maintainer  :  Alberto Ruiz
-- Stability   :  provisional
--
-- In-place manipulation inside the ST monad.
-- See @examples/inplace.hs@ in the repository.
--
-----------------------------------------------------------------------------

module Internal.ST (
    ST, runST,
    -- * Mutable Vectors
    STVector, newVector, thawVector, freezeVector, runSTVector,
    readVector, writeVector, modifyVector, liftSTVector,
    -- * Mutable Matrices
    STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
    readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
    mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
    -- * Unsafe functions
    newUndefinedVector,
    unsafeReadVector, unsafeWriteVector,
    unsafeThawVector, unsafeFreezeVector,
    newUndefinedMatrix,
    unsafeReadMatrix, unsafeWriteMatrix,
    unsafeThawMatrix, unsafeFreezeMatrix
) where

import Internal.Vector
import Internal.Matrix
import Internal.Vectorized
import Control.Monad.ST(ST, runST)
import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
import Control.Monad.ST.Unsafe(unsafeIOToST)

{-# INLINE ioReadV #-}
-- | Read the element at position @k@ of a vector's buffer, in 'IO'.
--
-- No bounds check is done here: @k@ is handed directly to 'peekElemOff'.
-- Bounds-checked access is layered on top by 'readVector'.
ioReadV :: Storable t => Vector t -> Int -> IO t
ioReadV vec idx = unsafeWith vec (`peekElemOff` idx)

{-# INLINE ioWriteV #-}
-- | Overwrite the element at position @k@ of a vector's buffer, in 'IO'.
--
-- This mutates the buffer behind a 'Vector' in place and performs no
-- bounds check: @k@ goes straight to 'pokeElemOff'.
ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
ioWriteV vec idx val = unsafeWith vec (\ptr -> pokeElemOff ptr idx val)

-- | A mutable vector usable inside the @ST s@ state thread.
--
-- It wraps an ordinary 'Vector'; the phantom parameter @s@ ties the value
-- to one state thread, which is what lets the rank-2 type of
-- 'runSTVector' keep it from escaping.
newtype STVector s t = STVector (Vector t)

-- | Turn an immutable vector into a mutable one by copying it first.
--
-- The copy is made with 'cloneVector'. Assuming that returns an independent
-- buffer, writes to the result leave the argument unchanged. See
-- 'unsafeThawVector' for the zero-copy variant.
thawVector :: Storable t => Vector t -> ST s (STVector s t)
thawVector v = unsafeIOToST (STVector <$> cloneVector v)

-- | Reinterpret an immutable vector as mutable without copying.
--
-- The result shares its buffer with the argument, so any write through it
-- is also visible in the original 'Vector'. Only use it when nothing else
-- still depends on the argument's contents.
unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
unsafeThawVector v = pure (STVector v)

-- | Run a state thread that produces a mutable vector and return the
-- result as an immutable 'Vector'.
--
-- The final freeze is 'unsafeFreezeVector' (no copy). The rank-2 type
-- keeps the 'STVector' itself from escaping 'runST'.
runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
runSTVector st = runST (unsafeFreezeVector =<< st)

{-# INLINE unsafeReadVector #-}
-- | Read element @k@ of a mutable vector without checking that @k@ is in
-- range. Prefer 'readVector' unless the index is already known to be valid.
unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
unsafeReadVector (STVector buf) k = unsafeIOToST (ioReadV buf k)

{-# INLINE unsafeWriteVector #-}
-- | Write element @k@ of a mutable vector without checking that @k@ is in
-- range. Prefer 'writeVector' unless the index is already known to be valid.
unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
unsafeWriteVector (STVector buf) k val = unsafeIOToST (ioWriteV buf k val)

{-# INLINE modifyVector #-}
-- | Replace element @k@ with @f@ applied to its current value.
--
-- The read goes through 'readVector', so an out-of-range @k@ raises an
-- error before anything is written. The write can therefore skip the check.
modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
modifyVector x k f = do
    old <- readVector x k
    unsafeWriteVector x k (f old)

-- | Apply a pure function to a snapshot of the vector's current contents.
--
-- The snapshot is taken with 'cloneVector' at this point of the state
-- thread. Assuming 'cloneVector' produces an independent copy, @f@ sees
-- the contents as they are now rather than after later writes.
liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a
liftSTVector f (STVector x) = unsafeIOToST $ do
    snapshot <- cloneVector x
    return (f snapshot)

-- | Snapshot a mutable vector into an immutable one. The contents are
-- copied via 'liftSTVector', so later writes do not show up in the result.
freezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
freezeVector = liftSTVector id

-- | View a mutable vector as immutable without copying.
--
-- The result shares its buffer with the 'STVector', so writes made after
-- this call are visible through the returned 'Vector'.
unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
unsafeFreezeVector (STVector x) = pure x

{-# INLINE safeIndexV #-}
-- | Bounds-check @k@ against the vector's 'dim'.
--
-- If @k@ is valid, the vector and index are passed on to @f@; otherwise it
-- fails with 'error'. This is the shared guard behind 'readVector' and
-- 'writeVector'.
safeIndexV :: Storable t2
           => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t
safeIndexV f (STVector v) k
    | inBounds  = f (STVector v) k
    | otherwise = error $ "out of range error in vector (dim="
                          ++ show n ++ ", pos=" ++ show k ++ ")"
  where
    n        = dim v
    inBounds = k >= 0 && k < n

{-# INLINE readVector #-}
-- | Read element @k@ of a mutable vector. An index outside
-- @[0, dim - 1]@ raises an \"out of range\" 'error'.
readVector :: Storable t => STVector s t -> Int -> ST s t
readVector vec k = safeIndexV unsafeReadVector vec k

{-# INLINE writeVector #-}
-- | Write element @k@ of a mutable vector. An index outside
-- @[0, dim - 1]@ raises an \"out of range\" 'error'.
writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
writeVector vec k val = safeIndexV unsafeWriteVector vec k val

-- | Allocate a mutable vector of length @n@ without initialising it.
--
-- No element is written here; the contents are whatever 'createVector'
-- leaves in the buffer.
newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
newUndefinedVector n = unsafeIOToST (STVector <$> createVector n)

{-# INLINE newVector #-}
-- | Allocate a mutable vector of length @n@ with every element set to @x@.
--
-- Elements are filled from index @n - 1@ down to 0. The loop stops as soon
-- as the index drops below zero, not only when it hits exactly @-1@. So it
-- can neither write at negative offsets nor spin forever when @n@ is
-- negative. What 'createVector' does with a negative length is up to it.
newVector :: Storable t => t -> Int -> ST s (STVector s t)
newVector x n = do
    v <- newUndefinedVector n
    let fill !k
          | k < 0     = return v
          | otherwise = unsafeWriteVector v k x >> fill (k - 1)
    fill (n - 1)

-------------------------------------------------------------------------

{-# INLINE ioReadM #-}
-- | Read element @(r, c)@ of a matrix's buffer, in 'IO', with no bounds
-- check.
--
-- The flat offset into 'xdat' is @r * xRow m + c * xCol m@, i.e. 'xRow'
-- and 'xCol' act as the row and column strides.
ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
ioReadM m r c = ioReadV (xdat m) offset
  where
    offset = r * xRow m + c * xCol m


{-# INLINE ioWriteM #-}
-- | Overwrite element @(r, c)@ of a matrix's buffer in place, in 'IO',
-- with no bounds check.
--
-- The flat offset into 'xdat' is @r * xRow m + c * xCol m@, as in 'ioReadM'.
ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
ioWriteM m r c val = ioWriteV (xdat m) offset val
  where
    offset = r * xRow m + c * xCol m


-- | A mutable matrix usable inside the @ST s@ state thread.
--
-- It wraps an ordinary 'Matrix'; the phantom parameter @s@ ties the value
-- to one state thread, which is what lets the rank-2 type of
-- 'runSTMatrix' keep it from escaping.
newtype STMatrix s t = STMatrix (Matrix t)

-- | Turn an immutable matrix into a mutable one by copying it first.
--
-- The copy is made with 'cloneMatrix'. See 'unsafeThawMatrix' for the
-- zero-copy variant.
thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t)
thawMatrix m = unsafeIOToST (STMatrix <$> cloneMatrix m)

-- | Reinterpret an immutable matrix as mutable without copying.
--
-- The result shares its buffer with the argument, so writes through it
-- are visible in the original 'Matrix'.
unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
unsafeThawMatrix m = pure (STMatrix m)

-- | Run a state thread that produces a mutable matrix and return the
-- result as an immutable 'Matrix'.
--
-- The final freeze is 'unsafeFreezeMatrix' (no copy). The rank-2 type
-- keeps the 'STMatrix' itself from escaping 'runST'.
runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
runSTMatrix st = runST (unsafeFreezeMatrix =<< st)

{-# INLINE unsafeReadMatrix #-}
-- | Read element @(r, c)@ of a mutable matrix without a bounds check.
-- Prefer 'readMatrix' unless the indices are already known to be valid.
unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
unsafeReadMatrix (STMatrix mat) r c = unsafeIOToST (ioReadM mat r c)

{-# INLINE unsafeWriteMatrix #-}
-- | Write element @(r, c)@ of a mutable matrix without a bounds check.
-- Prefer 'writeMatrix' unless the indices are already known to be valid.
unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
unsafeWriteMatrix (STMatrix mat) r c val = unsafeIOToST (ioWriteM mat r c val)

{-# INLINE modifyMatrix #-}
-- | Replace element @(r, c)@ with @f@ applied to its current value.
--
-- The read goes through 'readMatrix', so out-of-range indices raise an
-- error before anything is written. The write can therefore skip the check.
modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
modifyMatrix x r c f = do
    old <- readMatrix x r c
    unsafeWriteMatrix x r c (f old)

-- | Apply a pure function to a snapshot of the matrix's current contents.
--
-- The snapshot is taken with 'cloneMatrix' at this point of the state
-- thread, so @f@ does not see later writes.
liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a
liftSTMatrix f (STMatrix x) = unsafeIOToST $ do
    snapshot <- cloneMatrix x
    return (f snapshot)

-- | View a mutable matrix as immutable without copying.
--
-- The result shares its buffer with the 'STMatrix', so writes made after
-- this call are visible through the returned 'Matrix'.
unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
unsafeFreezeMatrix (STMatrix x) = pure x


-- | Snapshot a mutable matrix into an immutable one. The contents are
-- copied via 'liftSTMatrix', so later writes do not show up in the result.
freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
freezeMatrix = liftSTMatrix id

-- | Copy a matrix with 'copy', keeping the storage order reported by
-- 'orderOf'.
cloneMatrix :: Element t => Matrix t -> IO (Matrix t)
cloneMatrix m = copy order m
  where
    order = orderOf m

{-# INLINE safeIndexM #-}
-- | Bounds-check @(r, c)@ against the matrix's 'rows' and 'cols'.
--
-- If both indices are valid, the matrix and indices are passed on to @f@;
-- otherwise it fails with 'error'. This is the shared guard behind
-- 'readMatrix' and 'writeMatrix'.
safeIndexM :: (STMatrix s t2 -> Int -> Int -> t)
           -> STMatrix t1 t2 -> Int -> Int -> t
safeIndexM f (STMatrix m) r c
    | inBounds  = f (STMatrix m) r c
    | otherwise = error $ "out of range error in matrix (size="
                          ++ show (rows m, cols m) ++ ", pos="
                          ++ show (r, c) ++ ")"
  where
    inBounds = r >= 0 && r < rows m && c >= 0 && c < cols m

{-# INLINE readMatrix #-}
-- | Read element @(r, c)@ of a mutable matrix. Indices outside the matrix
-- raise an \"out of range\" 'error'.
readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
readMatrix mat r c = safeIndexM unsafeReadMatrix mat r c

{-# INLINE writeMatrix #-}
-- | Write element @(r, c)@ of a mutable matrix. Indices outside the matrix
-- raise an \"out of range\" 'error'.
writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
writeMatrix mat r c val = safeIndexM unsafeWriteMatrix mat r c val

-- | Write the immutable matrix @m@ into the mutable matrix at row @i@,
-- column @j@, using 'setRect'.
--
-- NOTE(review): presumably @m@ lands with its top-left corner at
-- @(i, j)@; bounds handling is entirely up to 'setRect' (confirm).
setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s ()
setMatrix (STMatrix x) i j m = unsafeIOToST (setRect i j m x)

-- | Allocate an @r@ x @c@ mutable matrix in the given storage order.
--
-- No element is initialised here; the contents are whatever
-- 'createMatrix' leaves in the buffer.
newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
newUndefinedMatrix order r c = unsafeIOToST (STMatrix <$> createMatrix order r c)

{-# NOINLINE newMatrix #-}
-- | Allocate an @r@ x @c@ mutable matrix with every element set to @v@.
--
-- The storage comes from 'newVector' of length @r * c@, shaped with
-- @reshape c@. The buffer is allocated inside the state thread each time
-- the action runs; it is never a pure value that could be shared. So two
-- runs of the action never end up mutating the same memory, even if the
-- compiler shares the action itself.
newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
newMatrix v r c = fmap asMatrix (newVector v (r * c))
  where
    asMatrix (STVector buf) = STMatrix (reshape c buf)

--------------------------------------------------------------------------------

-- | A selection of columns, resolved against a concrete column count by
-- 'getColRange' into an inclusive @(first, last)@ pair. Explicit indices
-- are taken modulo the column count, so negative values count from the
-- right (e.g. @-1@ is the last column).
data ColRange = AllCols              -- ^ every column
              | ColRange Int Int     -- ^ columns @a@ through @b@
              | Col Int              -- ^ the single column @a@
              | FromCol Int          -- ^ column @a@ through the last column

-- | Resolve a 'ColRange' to an inclusive @(first, last)@ column pair for a
-- matrix with @c@ columns.
--
-- Explicit indices are reduced @mod c@, so negative values wrap around
-- from the right. For those constructors @c == 0@ makes 'mod' throw.
getColRange :: Int -> ColRange -> (Int, Int)
getColRange c range = case range of
    AllCols      -> (0, lastCol)
    ColRange a b -> (wrap a, wrap b)
    Col a        -> (wrap a, wrap a)
    FromCol a    -> (wrap a, lastCol)
  where
    wrap i  = i `mod` c
    lastCol = c - 1

-- | A selection of rows, resolved against a concrete row count by
-- 'getRowRange' into an inclusive @(first, last)@ pair. Explicit indices
-- are taken modulo the row count, so negative values count from the
-- bottom (e.g. @-1@ is the last row).
data RowRange = AllRows              -- ^ every row
              | RowRange Int Int     -- ^ rows @a@ through @b@
              | Row Int              -- ^ the single row @a@
              | FromRow Int          -- ^ row @a@ through the last row

-- | Resolve a 'RowRange' to an inclusive @(first, last)@ row pair for a
-- matrix with @r@ rows.
--
-- Explicit indices are reduced @mod r@, so negative values wrap around
-- from the bottom. For those constructors @r == 0@ makes 'mod' throw.
getRowRange :: Int -> RowRange -> (Int, Int)
getRowRange r range = case range of
    AllRows      -> (0, lastRow)
    RowRange a b -> (wrap a, wrap b)
    Row a        -> (wrap a, wrap a)
    FromRow a    -> (wrap a, lastRow)
  where
    wrap i  = i `mod` r
    lastRow = r - 1

-- | Elementary row operations applied in place by 'rowOper'.
--
-- Row indices are taken modulo the number of rows; row and column ranges
-- are resolved by 'getRowRange' / 'getColRange'. The meanings below follow
-- the BLAS-style names and are presumptions; the actual semantics live in
-- 'rowOp' (codes 0, 1 and 2 respectively) and should be confirmed there.
data RowOper t = AXPY t Int Int  ColRange  -- ^ scalar, two row indices, columns (presumably row2 += scalar * row1)
               | SCAL t RowRange ColRange  -- ^ presumably scales the selected block by the scalar
               | SWAP Int Int    ColRange  -- ^ presumably exchanges two rows over the selected columns

-- | Apply an elementary row operation to a mutable matrix in place.
--
-- Indices are normalised before the call:
--
-- * Single row indices ('AXPY', 'SWAP') are reduced modulo the number of rows.
-- * Row and column ranges are resolved with 'getRowRange' / 'getColRange'.
--
-- The work is delegated to 'rowOp' with a numeric code: 0 for 'AXPY',
-- 1 for 'SCAL', 2 for 'SWAP'. 'SWAP' passes 0 as the scalar.
rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s ()
rowOper op (STMatrix m) = case op of
    AXPY x i1 i2 cr -> run 0 x (wrapRow i1, wrapRow i2) (getColRange nCols cr)
    SCAL x rr cr    -> run 1 x (getRowRange nRows rr)   (getColRange nCols cr)
    SWAP i1 i2 cr   -> run 2 0 (wrapRow i1, wrapRow i2) (getColRange nCols cr)
  where
    nRows = rows m
    nCols = cols m
    wrapRow i = i `mod` nRows
    -- Lazy tuple patterns keep the ranges unevaluated until 'rowOp' needs
    -- them, matching the lazy where-bindings of the per-constructor form.
    run code alpha ~(ra, rb) ~(ca, cb) =
        unsafeIOToST (rowOp code alpha ra rb ca cb m)


-- | Extract a rectangular region of a mutable matrix as an immutable one.
--
-- The rows and columns are resolved with 'getRowRange' / 'getColRange'.
-- They are passed to 'extractR' as @(first, last)@ index pairs; the result
-- is presumably a copy, so confirm against 'extractR'.
--
-- The matrix and the action share the state-thread parameter @s@. With
-- independent type variables the read could be run inside a nested pure
-- 'runST'. That pure value would then observe the matrix at an
-- unpredictable point relative to later writes.
extractMatrix :: Element a => STMatrix s a -> RowRange -> ColRange -> ST s (Matrix a)
extractMatrix (STMatrix m) rr rc =
    unsafeIOToST (extractR (orderOf m) m 0 (idxs [i1, i2]) 0 (idxs [j1, j2]))
  where
    (i1, i2) = getRowRange (rows m) rr
    (j1, j2) = getColRange (cols m) rc

-- | A rectangular window into a mutable matrix.
--
-- @Slice m r0 c0 height width@ covers the @height@ x @width@ block of @m@
-- whose top-left element is at row @r0@, column @c0@. 'slice' passes these
-- numbers to 'subMatrix' as @(r0, c0)@ and @(height, width)@.
data Slice s t = Slice (STMatrix s t) Int Int Int Int

-- | The 'Matrix' covered by a 'Slice', obtained with 'subMatrix'.
--
-- NOTE(review): 'gemmm' hands the result slice's matrix to 'gemm'. That
-- only updates the underlying mutable matrix if 'subMatrix' returns a view
-- sharing its buffer rather than a copy; confirm.
slice :: Element a => Slice t a -> Matrix a
slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix topLeft extent m
  where
    topLeft = (r0, c0)
    extent  = (nr, nc)

-- | Matrix multiply-accumulate on slices of mutable matrices.
--
-- @gemmm beta r alpha a b@ calls 'gemm' with:
--
-- * the coefficient vector @[alpha, beta]@;
-- * the matrices of slices @a@ and @b@ as operands;
-- * the matrix of slice @r@ as the last argument.
--
-- Presumably this computes @r <- alpha * a * b + beta * r@ (BLAS @gemm@
-- convention); confirm against 'gemm'.
gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
gemmm beta rSlice alpha aSlice bSlice =
    unsafeIOToST (gemm coeffs (slice aSlice) (slice bSlice) (slice rSlice))
  where
    coeffs = fromList [alpha, beta]


-- | Run an in-place algorithm on a copy of a matrix.
--
-- @mutable f a@ works in four steps:
--
-- 1. Thaw a copy of @a@ with 'thawMatrix'.
-- 2. Run @f@ on it, passing the dimensions @(rows a, cols a)@.
-- 3. Freeze the working matrix with 'unsafeFreezeMatrix' (no further copy).
-- 4. Return it together with @f@'s result.
mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
mutable f a = runST $ do
    work   <- thawMatrix a
    info   <- f dims work
    result <- unsafeFreezeMatrix work
    return (result, info)
  where
    dims = (rows a, cols a)