{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Internal.Sparse(
GMatrix(..), CSR(..), mkCSR, fromCSR, impureCSR,
mkSparse, mkDiagR, mkDense,
AssocMatrix,
toDense,
gmXv, (!#>)
)where
import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as M
import Control.Arrow((***))
import Control.Monad(when, foldM)
import Control.Monad.ST (runST)
import Control.Monad.Primitive (PrimMonad)
import Data.List(sort)
import Foreign.C.Types(CInt(..))
import Internal.Devel
import System.IO.Unsafe(unsafePerformIO)
import Foreign(Ptr)
import Text.Printf(printf)
-- | A sparse matrix in coordinate form: a list of @((row, column), value)@
-- entries with zero-based indices.  The pair type is exactly 'IndexOf Matrix'.
type AssocMatrix = [((Int, Int), Double)]
-- | Compressed sparse row storage.
--
-- The non-zero values are kept row by row together with their column
-- indices, plus one row offset per row and a final end marker.  Matrices
-- built by 'mkCSR' / 'impureCSR' store one-based indices and offsets.
data CSR = CSR
    { csrVals  :: Vector Double  -- ^ non-zero values, row by row
    , csrCols  :: Vector CInt    -- ^ column index of each value
    , csrRows  :: Vector CInt    -- ^ start offset of each row, then an end marker
    , csrNRows :: Int            -- ^ number of rows
    , csrNCols :: Int            -- ^ number of columns
    }
    deriving (Show)
-- | Compressed sparse column storage.
--
-- Within this module a 'CSC' is only produced by transposing a 'CSR'
-- (the arrays are reused with their roles swapped).
data CSC = CSC
    { cscVals  :: Vector Double  -- ^ non-zero values, column by column
    , cscRows  :: Vector CInt    -- ^ row index of each value
    , cscCols  :: Vector CInt    -- ^ start offset of each column, then an end marker
    , cscNRows :: Int            -- ^ number of rows
    , cscNCols :: Int            -- ^ number of columns
    }
    deriving (Show)
-- | Compress an association list into CSR form.
--
-- The entries are sorted first, so they reach the builder in row-major
-- order, and are then streamed through the 'impureCSR' builder inside 'runST'.
mkCSR :: AssocMatrix -> CSR
mkCSR entries = runST (impureCSR driveFold (sort entries))
  where
    -- Run a builder: create its initial state, fold every entry into it
    -- with the step function, then finalise the state.
    driveFold :: (Monad m, Foldable f)
              => (s -> a -> m s) -> m s -> (s -> m b) -> f a -> m b
    driveFold step start finish xs =
      start >>= \s0 -> foldM step s0 xs >>= finish
-- | Build a 'CSR' matrix incrementally.
--
-- Instead of taking the entries directly, 'impureCSR' passes three actions
-- to the continuation @f@:
--
--   * a step function that appends one @((row, column), value)@ entry,
--   * an action that creates the initial (empty) builder state, and
--   * a finaliser that freezes the state into a 'CSR'.
--
-- The caller decides how to drive them ('mkCSR' folds them over a sorted
-- list), so any 'PrimMonad' can be used.  Entries must arrive in
-- non-decreasing row order; a row index smaller than the previous one is
-- reported with 'error'.  Stored column indices and row offsets are
-- one-based.  The result has @1 + last row@ rows and @1 + largest column@
-- columns.
impureCSR
  :: PrimMonad m
  => (forall x . (x -> (IndexOf Matrix, Double) -> m x) -> m x -> (x -> m CSR) -> r)
  -> r
impureCSR f = f next begin done
  where
    -- Convert a zero-based index to the one-based 'CInt' that is stored.
    sfi = succ . fi

    -- State tuple: (values, row offsets, column indices,
    --               next free value slot, next free row-offset slot,
    --               largest column seen, current row; -1 before any entry).
    begin = do
      mv <- M.unsafeNew 64
      mr <- M.unsafeNew 64
      mc <- M.unsafeNew 64
      return (mv, mr, mc, 0, 0, 0, -1)

    next (!mv, !mr, !mc, !idxVC, !idxR, !maxC, !curRow) ((r,c),d) = do
      when (r < curRow) $
        error (printf "impureCSR: row %i specified after %i" r curRow)
      let lenVC = M.length mv
          lenR  = M.length mr
          maxC' = max maxC c
          -- Moving from curRow to r writes (r - curRow) row offsets starting
          -- at idxR, and 'done' later writes one more.  The offset buffer
          -- must hold all of them: a single doubling is not enough when
          -- rows are skipped (e.g. row 0 followed by row 200), and
          -- 'M.unsafeWrite' does no bounds checking.
          needR = idxR + (r - curRow) + 1
      (mv', mc') <-
        if idxVC >= lenVC then do
          -- Exactly one value/column is written per entry, so doubling
          -- always makes enough room here.
          mv' <- M.unsafeGrow mv lenVC
          mc' <- M.unsafeGrow mc lenVC
          return (mv', mc')
        else
          return (mv, mc)
      mr' <-
        if needR > lenR then
          -- Grow geometrically, but never by less than what is required.
          M.unsafeGrow mr (max lenR (needR - lenR))
        else
          return mr
      M.unsafeWrite mc' idxVC (sfi c)
      M.unsafeWrite mv' idxVC d
      -- Record where the new row(s) start.  Every skipped row gets the same
      -- offset as row r, so skipped rows are empty.
      idxR' <-
        foldM
          (\i _ -> i + 1 <$ M.unsafeWrite mr' i (sfi idxVC))
          idxR [1 .. r - curRow]
      return (mv', mr', mc', idxVC + 1, idxR', maxC', r)

    done (!mv, !mr, !mc, !idxVC, !idxR, !maxC, !curR) = do
      -- Final offset marking the end of the last row.
      M.unsafeWrite mr idxR (sfi idxVC)
      vv <- V.unsafeFreeze (M.unsafeTake idxVC mv)
      vc <- V.unsafeFreeze (M.unsafeTake idxVC mc)
      vr <- V.unsafeFreeze (M.unsafeTake (idxR + 1) mr)
      return $ CSR vv vc vr (succ curR) (succ maxC)
-- | A matrix in one of several storage formats, used by 'gmXv'.
--
-- Every constructor carries its dimensions in 'nRows' and 'nCols'.  The
-- payload selectors ('gmCSR', 'gmCSC', 'diagVals', 'gmDense') are partial:
-- each one is defined only for its own constructor.
data GMatrix
    = SparseR
        { gmCSR    :: CSR            -- ^ compressed sparse rows
        , nRows    :: Int
        , nCols    :: Int
        }
    | SparseC
        { gmCSC    :: CSC            -- ^ compressed sparse columns
        , nRows    :: Int
        , nCols    :: Int
        }
    | Diag
        { diagVals :: Vector Double  -- ^ leading diagonal entries
        , nRows    :: Int
        , nCols    :: Int
        }
    | Dense
        { gmDense  :: Matrix Double  -- ^ ordinary dense matrix
        , nRows    :: Int
        , nCols    :: Int
        }
    deriving (Show)
-- | Wrap a dense matrix as a 'GMatrix', recording its dimensions.
mkDense :: Matrix Double -> GMatrix
mkDense m = Dense { gmDense = m, nRows = rows m, nCols = cols m }
-- | Build a CSR-backed sparse 'GMatrix' from an association list.
mkSparse :: AssocMatrix -> GMatrix
mkSparse entries = fromCSR (mkCSR entries)
-- | Wrap a 'CSR' matrix as a 'GMatrix', copying its stored dimensions.
fromCSR :: CSR -> GMatrix
fromCSR csr = SparseR { gmCSR = csr, nRows = csrNRows csr, nCols = csrNCols csr }
-- | Build an @nr x nc@ diagonal 'GMatrix' from its leading diagonal entries.
--
-- The vector may be shorter than @min nr nc@; in 'gmXv' the missing
-- diagonal entries act as zeros.  A longer vector raises an 'error'.
mkDiagR :: Int -> Int -> Vector Double -> GMatrix
mkDiagR nr nc d =
    if dim d > min nr nc
      then error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" nr nc (dim d)
      else Diag { diagVals = d, nRows = nr, nCols = nc }
-- | Foreign argument pair for a 'CInt' array, followed by the rest of the
-- signature @r@.  Presumably the array's length and data pointer, as
-- supplied by the '#' combinator -- TODO confirm.
type IV r = CInt -> Ptr CInt -> r

-- | Like 'IV', but for an array of 'Double'.
type V r = CInt -> Ptr Double -> r

-- | Signature of the C kernels @smXv@ / @smTXv@ with five array arguments:
-- the values, the two index arrays, the input vector and the output vector.
-- The 'CInt' result is checked by '#|' in 'gmXv' (presumably an error code).
type SMxV = V (IV (IV (V (V (IO CInt)))))
-- | Multiply a 'GMatrix' by a dense vector.
--
-- * 'SparseR' and 'SparseC' delegate to the C kernels @smXv@ and @smTXv@.
-- * 'Diag' multiplies the leading @dim diagVals@ entries of the input by the
--   diagonal, then pads the result with zeros up to 'nRows'.
-- * 'Dense' uses the ordinary dense product 'mXv'.
--
-- In every case the input length must equal 'nCols'; otherwise an 'error'
-- describing the sizes is raised.
gmXv :: GMatrix -> Vector Double -> Vector Double
gmXv SparseR { gmCSR = CSR{..}, .. } v =
    sparseMulV "gmXv (CSR): incorrect sizes: (%d,%d) x %d" "CSRXv"
               c_smXv csrVals csrCols csrRows nRows nCols v
gmXv SparseC { gmCSC = CSC{..}, .. } v =
    sparseMulV "gmXv (CSC): incorrect sizes: (%d,%d) x %d" "CSCXv"
               c_smTXv cscVals cscRows cscCols nRows nCols v
gmXv Diag{..} v
    | dim v /= nCols =
        error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
                       nRows nCols nDiag (dim v)
    | otherwise = vjoin [scaled, padding]
  where
    nDiag   = dim diagVals
    scaled  = subVector 0 nDiag v `mul` diagVals
    padding = konst 0 (nRows - nDiag)
gmXv Dense{..} v
    | dim v /= nCols =
        error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
                       nRows nCols (dim v)
    | otherwise = mXv gmDense v

-- Shared driver for the sparse cases of 'gmXv'.  It checks that the input
-- length equals @nc@, allocates an @nr@-element result and lets the C
-- kernel fill it.  @sizeMsg@ is the printf format used for the size error;
-- @tag@ names the call for the failure check done by '#|'.
--
-- 'unsafePerformIO' is used because the only effect is writing into the
-- freshly allocated result vector before it is returned, so callers
-- observe a pure function.
sparseMulV :: String -> String -> SMxV
           -> Vector Double -> Vector CInt -> Vector CInt
           -> Int -> Int -> Vector Double -> Vector Double
sparseMulV sizeMsg tag kernel vals idxA idxB nr nc v = unsafePerformIO $ do
    when (dim v /= nc) $
        error (printf sizeMsg nr nc (dim v))
    out <- createVector nr
    (vals # idxA # idxB # v #! out) kernel #| tag
    return out
infixr 8 !#>

-- | Operator form of 'gmXv': @m !#> v@ is @gmXv m v@.
(!#>) :: GMatrix -> Vector Double -> Vector Double
m !#> v = gmXv m v
-- | C kernel used by 'gmXv' for 'SparseR' (CSR) matrices.
foreign import ccall unsafe "smXv" c_smXv :: SMxV

-- | C kernel used by 'gmXv' for 'SparseC' (CSC) matrices.
foreign import ccall unsafe "smTXv" c_smTXv :: SMxV
-- | Expand an association list into a dense matrix.
--
-- The result is @(1 + largest row index) x (1 + largest column index)@;
-- every position not listed is zero.  An empty list produces a @0 x 0@
-- matrix instead of failing.
toDense :: AssocMatrix -> Matrix Double
toDense asm = assoc (r + 1, c + 1) 0 asm
  where
    -- Largest row and column indices.  'maximum' is partial, so the empty
    -- list is handled separately: (-1, -1) makes the dimensions (0, 0).
    (r, c)
      | null asm  = (-1, -1)
      | otherwise = (maximum *** maximum) . unzip . map fst $ asm
-- | The CSR arrays of @A@ are exactly the CSC arrays of @A^T@, so
-- transposition moves no data.  Only the roles of the index arrays and the
-- two dimensions are swapped.
instance Transposable CSR CSC
  where
    tr CSR{..} = CSC
        { cscVals  = csrVals
        , cscRows  = csrCols
        , cscCols  = csrRows
        , cscNRows = csrNCols
        , cscNCols = csrNRows
        }
    tr' = tr
-- | Inverse relabelling of the 'CSR' -> 'CSC' transpose: the CSC arrays of
-- @A@ are reused as the CSR arrays of @A^T@.
instance Transposable CSC CSR
  where
    tr CSC{..} = CSR
        { csrVals  = cscVals
        , csrCols  = cscRows
        , csrRows  = cscCols
        , csrNRows = cscNCols
        , csrNCols = cscNRows
        }
    tr' = tr
-- | Transpose any 'GMatrix', swapping its recorded dimensions.  Sparse row
-- storage becomes sparse column storage and vice versa; a diagonal keeps
-- its entries; a dense matrix is transposed with the dense 'tr'.
instance Transposable GMatrix GMatrix
  where
    tr g = case g of
        SparseR s n m -> SparseC (tr s) m n
        SparseC s n m -> SparseR (tr s) m n
        Diag    v n m -> Diag v m n
        Dense   a n m -> Dense (tr a) m n
    tr' = tr