module Data.Sparse.SpMatrix where
import Control.Exception.Common
import Data.Sparse.SpVector
import Data.Sparse.Utils
import Data.Sparse.Types
import Numeric.Eps
import Numeric.LinearAlgebra.Class
import Data.Sparse.Internal.IntM (IntM (..))
import qualified Data.Sparse.Internal.IntM as I
import Data.Sparse.Internal.IntMap2
import GHC.Exts
import qualified Data.IntMap.Strict as IM
import Data.Complex
import Data.Foldable (foldl')
import Data.Maybe
import Data.VectorSpace hiding (magnitude)
-- | A sparse matrix: its dimensions paired with a nested sparse map from
-- row index to (column index -> entry).
--
-- Only physically stored entries appear in 'smData'; the lookup helpers in
-- this module (e.g. 'lookupWD_SM') report absent entries as zero.
data SpMatrix a = SM
  { smDim  :: !(Rows, Cols)     -- ^ (number of rows, number of columns)
  , smData :: !(IntM (IntM a))  -- ^ row index -> (column index -> entry)
  } deriving (Eq, Functor, Foldable)
-- | One-line size summary of a sparse container: its dimensions, the number
-- of stored (non-zero) entries and its sparsity ratio.
sizeStr :: (FDSize f ~ (a1, a2), Sparse f a, Show a2, Show a1) => f a -> String
sizeStr sm = unwords
  [ "(", show rowsN, "rows,", show colsN, "columns ) ,"
  , show (nnz sm), "NZ ( sparsity", show (spy sm :: Double), ")"
  ]
  where
    (rowsN, colsN) = dim sm
-- | Renders the size summary from 'sizeStr', then the stored entries
-- converted with 'toList', one list per row.
instance Show a => Show (SpMatrix a) where
  show sm = unwords ["SM", sizeStr sm, show rowsAsLists]
    where
      rowsAsLists = toList (fmap toList (smData sm))
-- | Element-wise union ('liftU2') and intersection ('liftI2') combinators,
-- lifted through both levels of the nested map. The result dimensions are
-- combined with 'maxTup' (union) or 'minTup' (intersection).
instance Set SpMatrix where
  liftU2 f (SM dA mA) (SM dB mB) = SM (maxTup dA dB) (liftU2 (liftU2 f) mA mB)
  liftI2 f (SM dA mA) (SM dB mB) = SM (minTup dA dB) (liftI2 (liftI2 f) mA mB)
-- | Element-wise addition, negation and subtraction of sparse matrices.
--
-- Subtraction is defined as adding the negation. A direct union with @(-)@
-- is wrong for a union-style combinator such as 'liftU2' (which, presumably,
-- passes entries present in only one operand through unchanged). In that
-- case entries stored only in the right operand would keep their sign, so
-- @zeroV ^-^ m@ would give @m@ instead of @negateV m@.
instance Num a => AdditiveGroup (SpMatrix a) where
  zeroV = SM (0, 0) I.empty
  (^+^) = liftU2 (+)
  negateV = fmap negate
  a ^-^ b = a ^+^ negateV b
-- | The size of a sparse matrix is its (rows, columns) pair.
instance FiniteDim SpMatrix where
  type FDSize SpMatrix = (Rows, Cols)
  dim (SM d _) = d
-- | Exposes the underlying nested map and the count of stored entries.
instance HasData SpMatrix a where
  type HDData SpMatrix a = IntM (IntM a)
  dat (SM _ m) = m
  nnz = nzSM
-- | Sparsity ratio of a matrix, delegated to 'spySM'.
instance Sparse SpMatrix a where
  spy m = spySM m
-- | Indexed insertion and lookup. '@@' validates the index against the
-- matrix dimensions and returns zero for in-bounds entries that are not
-- stored.
instance Num a => SpContainer SpMatrix a where
  type ScIx SpMatrix = (Rows, Cols)
  scInsert (i, j) x m = insertSpMatrix i j x m
  scLookup m (i, j) = lookupSM m i j
  m @@ ix
    | not (isValidIxSM m ix) =
        error $ "@@ : incompatible indices : matrix size is " ++ show (dim m) ++ ", but user looked up " ++ show ix
    | otherwise = m @@! ix
-- | A matrix of the given dimensions with no stored entries.
zeroSM :: Rows -> Cols -> SpMatrix a
zeroSM nr nc = SM (nr, nc) I.empty
-- | Square @n x n@ matrix whose main diagonal holds the given values.
mkDiagonal :: Int -> [a] -> SpMatrix a
mkDiagonal n xs = mkSubDiagonal n 0 xs
-- | The @n x n@ identity matrix (ones stored on the main diagonal).
eye :: Num a => Int -> SpMatrix a
eye n = mkDiagonal n [1 | _ <- [1 .. n]]
-- | Build an @n x n@ permutation matrix from the identity.
--
-- Row @k@ is swapped with row @iis !! k@ for @k = 0, 1, ...@, in sequence;
-- each swap acts on the result of the previous ones. Extra elements of
-- @iis@ beyond @n@ are ignored, because 'zip' truncates.
permutationSM :: Num a => Int -> [IxRow] -> SpMatrix a
permutationSM n iis = permutPairsSM n (zip [0 .. n - 1] iis)
-- | Apply a sequence of row swaps, left to right, to the @n x n@ identity.
permutPairsSM :: Num a => Int -> [(IxRow, IxRow)] -> SpMatrix a
permutPairsSM n iix = foldl' step (eye n) iix
  where
    step acc (r1, r2) = swapRows r1 r2 acc
-- | Square @n x n@ matrix with the values @xx@ placed on the diagonal at
-- offset @o@.
--
-- * @o > 0@ places the values above the main diagonal.
-- * @o < 0@ places them below it.
-- * @o == 0@ uses the main diagonal.
--
-- Values beyond the length of that diagonal are dropped, because 'zip3'
-- truncates. Errors when @|o| >= n@.
mkSubDiagonal :: Int -> Int -> [a] -> SpMatrix a
mkSubDiagonal n o xx
  | abs o < n = if o >= 0 then fz ii jj xx else fz jj ii xx
  | otherwise = error "mkSubDiagonal : |offset| >= dimension"
  where
    ii = [0 .. n - 1]
    jj = [abs o .. n - 1]
    fz a b x = fromListSM (n, n) (zip3 a b x)
-- | Store @x@ at @(i, j)@, overwriting according to 'insertIM2'.
--
-- The index is checked with 'inBounds02' against the matrix dimensions;
-- out-of-range indices raise an error.
insertSpMatrix :: IxRow -> IxCol -> a -> SpMatrix a -> SpMatrix a
insertSpMatrix i j x (SM d smd)
  | not (inBounds02 d (i, j)) = error "insertSpMatrix : index out of bounds"
  | otherwise = SM d (insertIM2 i j x smd)
-- | Insert each @(row, column, value)@ triple into an existing matrix, in
-- fold order; later triples at the same index win.
fromListSM' :: Foldable t => t (IxRow, IxCol, a) -> SpMatrix a -> SpMatrix a
fromListSM' iix sm = foldl' (\acc (i, j, x) -> insertSpMatrix i j x acc) sm iix
-- | Build an @m x n@ matrix from @(row, column, value)@ triples.
fromListSM :: Foldable t => (Int, Int) -> t (IxRow, IxCol, a) -> SpMatrix a
fromListSM (m, n) entries = fromListSM' entries blank
  where
    blank = zeroSM m n
-- | 'fromListSM' specialised to real ('Double') entries.
mkSpMR :: Foldable t =>
  (Int, Int) -> t (IxRow, IxCol, Double) -> SpMatrix Double
mkSpMR dims ixv = fromListSM dims ixv
-- | 'fromListSM' specialised to complex ('Complex' 'Double') entries.
mkSpMC :: Foldable t =>
  (Int, Int) -> t (IxRow, IxCol, Complex Double) -> SpMatrix (Complex Double)
mkSpMC dims ixv = fromListSM dims ixv
-- | Build a matrix with @m@ rows from a flat list of values.
--
-- The column count is @length ll `div` m@. Indices are assigned by
-- 'indexed2'. Note that @m == 0@ makes 'div' throw.
fromListDenseSM :: Int -> [a] -> SpMatrix a
fromListDenseSM m ll = fromListSM (m, length ll `div` m) (indexed2 m ll)
-- | All stored entries as @(row, column, value)@ triples.
--
-- The order is whatever the 'ifoldlSM' traversal produces, with each
-- visited entry consed onto the front of the list.
toListSM :: SpMatrix t -> [(IxRow, IxCol, t)]
toListSM = ifoldlSM (\i j x acc -> (i, j, x) : acc) []
-- | Every cell of the matrix as @(row, column, value)@, in row-major order.
-- Cells that are not stored appear with value zero.
toDenseListSM :: Num t => SpMatrix t -> [(IxRow, IxCol, t)]
toDenseListSM m =
  [(i, j, m @@ (i, j)) | i <- [0 .. nrows m - 1], j <- [0 .. ncols m - 1]]
-- | The entry stored at @(i, j)@, or 'Nothing' if none is stored.
-- No bounds checking is performed.
lookupSM :: SpMatrix a -> IxRow -> IxCol -> Maybe a
lookupSM (SM _ im) i j = do
  row <- I.lookup i im
  I.lookup j row
-- | Lookup with default: the stored entry at @(i, j)@, or @0@ when none is
-- stored. No bounds checking (see '@@' for the checked variant).
lookupWD_SM, (@@!):: Num a => SpMatrix a -> (IxRow, IxCol) -> a
lookupWD_SM sm (i, j) = maybe 0 id (lookupSM sm i j)
(@@!) = lookupWD_SM
-- | Keep only the stored entries whose (row, column, value) satisfy the
-- predicate. Dimensions are unchanged.
filterSM :: (IM.Key -> IM.Key -> a -> Bool) -> SpMatrix a -> SpMatrix a
filterSM f (SM d im) = SM d (ifilterIM2 f im)
-- | Stored entries strictly below ('extractSubDiag'), strictly above
-- ('extractSuperDiag') or on ('extractDiag') the main diagonal. Dimensions
-- are unchanged.
extractDiag, extractSuperDiag, extractSubDiag :: SpMatrix a -> SpMatrix a
extractSubDiag = filterSM $ \i j _ -> j < i
extractSuperDiag = filterSM $ \i j _ -> j > i
extractDiag = filterSM $ \i j _ -> j == i
-- | Extract the block of rows @i1..i2@ and columns @j1..j2@ (inclusive).
--
-- The result has dimensions @(i2 - i1 + 1, j2 - j1 + 1)@. The row and
-- column keys of the retained entries are rewritten with @fi@ and @gi@
-- respectively. With @fi = gi = id@ the keys keep their original values,
-- which may lie outside the declared dimensions. Rows left empty by the
-- filter are dropped.
--
-- Errors unless all four bounds are valid indices of the source matrix,
-- @i2 >= i1@ and @j2 >= j1@.
extractSubmatrixSM ::
  (IM.Key -> IM.Key) ->
  (IM.Key -> IM.Key) ->
  SpMatrix a ->
  (IxRow, IxRow) -> (IxCol, IxCol) ->
  SpMatrix a
extractSubmatrixSM fi gi (SM (r, c) im) (i1, i2) (j1, j2)
  | q = SM (m', n') imm'
  | otherwise = error $ "extractSubmatrixSM : invalid index " ++ show (i1, i2) ++ ", " ++ show (j1, j2)
  where
    imm' = mapKeysIM2 fi gi $
      I.filterI (not . null) $
      ifilterIM2 ff im
    ff i j _ = i1 <= i && i <= i2 && j1 <= j && j <= j2
    (m', n') = (i2 - i1 + 1, j2 - j1 + 1)
    -- the column range is validated too; without it n' could be <= 0
    q = inBounds0 r i1 &&
        inBounds0 r i2 &&
        inBounds0 c j1 &&
        inBounds0 c j2 &&
        i2 >= i1 &&
        j2 >= j1
-- | Like 'extractSubmatrix', but shifts the keys so the extracted block is
-- indexed from @(0, 0)@.
extractSubmatrixRebalanceKeys ::
  SpMatrix a -> (IxRow, IxRow) -> (IxCol, IxCol) -> SpMatrix a
extractSubmatrixRebalanceKeys mm (i1, i2) (j1, j2) =
  extractSubmatrixSM (\i -> i - i1) (\j -> j - j1) mm (i1, i2) (j1, j2)
-- | Extract a block, keeping the original row/column keys of its entries.
extractSubmatrix :: SpMatrix a -> (IxRow, IxRow) -> (IxCol, IxCol) -> SpMatrix a
extractSubmatrix mm rows cols = extractSubmatrixSM id id mm rows cols
-- | The first @n@ rows, with all columns.
takeRows :: IxRow -> SpMatrix a -> SpMatrix a
takeRows n mm = extractSubmatrix mm (0, n - 1) (0, ncols mm - 1)
-- | The first @n@ columns, with all rows.
takeCols :: IxCol -> SpMatrix a -> SpMatrix a
takeCols n mm = extractSubmatrix mm (0, nrows mm - 1) (0, n - 1)
-- | Column @j@ as a one-column matrix (entries keep their original keys).
extractColSM :: SpMatrix a -> IxCol -> SpMatrix a
extractColSM sm j = extractSubmatrix sm (0, nrows sm - 1) (j, j)
-- | Rows @i1..i2@ of column @j@, keeping the original keys.
extractSubColSM :: SpMatrix a -> IxCol -> (IxRow, IxRow) -> SpMatrix a
extractSubColSM sm j rowRange = extractSubmatrix sm rowRange (j, j)
-- | Rows @i1..i2@ of column @j@, re-keyed so the result starts at @(0, 0)@.
extractSubColSM_RK :: SpMatrix a -> IxCol -> (IxRow, IxRow) -> SpMatrix a
extractSubColSM_RK sm j rowRange =
  extractSubmatrixRebalanceKeys sm rowRange (j, j)
-- | Whether the index pair is within the matrix dimensions, as decided by
-- 'inBounds02'.
isValidIxSM :: SpMatrix a -> (Int, Int) -> Bool
isValidIxSM mm ix = inBounds02 (dim mm) ix
-- | Whether the row count equals the column count.
isSquareSM :: SpMatrix a -> Bool
isSquareSM (SM (nr, nc) _) = nr == nc
-- | True when every one of the @nrows m@ rows is stored and holds exactly
-- one entry, located on the main diagonal.
--
-- A row with no stored entry (e.g. a zero on the diagonal) makes the
-- result False.
isDiagonalSM :: SpMatrix a -> Bool
isDiagonalSM m = I.size (I.filterWithKey onlyDiagonal (immSM m)) == nrows m
  where
    onlyDiagonal irow row =
      I.size row == 1 &&
      I.size (I.filterWithKey (\j _ -> j == irow) row) == 1
-- | Whether all stored entries lie on or below ('isLowerTriSM') / on or
-- above ('isUpperTriSM') the main diagonal: the matrix is unchanged by
-- filtering out the other triangle.
isLowerTriSM, isUpperTriSM :: Eq a => SpMatrix a -> Bool
isLowerTriSM m = ifilterSM (\i j _ -> j <= i) m == m
isUpperTriSM m = ifilterSM (\i j _ -> j >= i) m == m
-- | Whether @transpose sm ## sm@, after 'roundZeroOneSM', equals the
-- identity of size @ncols sm@.
--
-- 'transpose' is the 'MatrixRing' one, i.e. the conjugate transpose for
-- complex matrices.
isOrthogonalSM :: (MatrixRing (SpMatrix a), Epsilon a, Eq a, Num a) => SpMatrix a -> Bool
isOrthogonalSM sm@(SM (_, n) _) = rsm == eye n
  where
    rsm = roundZeroOneSM $ transpose sm ## sm
-- | The nested map of stored entries (row index -> column index -> value);
-- same as 'smData'.
immSM :: SpMatrix a -> IntM (IntM a)
immSM (SM _ imm) = imm
-- | The (rows, columns) dimensions.
dimSM :: SpMatrix t -> (Rows, Cols)
dimSM = smDim
-- | Total number of cells, @rows * columns@, stored or not.
nelSM :: SpMatrix t -> Int
nelSM (SM d _) = fst d * snd d
-- | Number of rows.
nrows :: SpMatrix a -> Rows
nrows m = fst (dim m)
-- | Number of columns.
ncols :: SpMatrix a -> Cols
ncols m = snd (dim m)
-- | Summary statistics of a sparse matrix (see 'infoSM').
data SMInfo = SMInfo
  { smNz  :: Int     -- ^ number of stored entries
  , smSpy :: Double  -- ^ stored entries divided by total cells
  } deriving (Eq, Show)
-- | Stored-entry count and sparsity ratio of a matrix.
infoSM :: SpMatrix a -> SMInfo
infoSM s = SMInfo { smNz = nzSM s, smSpy = spySM s }
-- | Number of stored entries, summed over all stored rows.
nzSM :: SpMatrix a -> Int
nzSM s = foldl' (\acc row -> acc + I.size row) 0 (immSM s)
-- | Sparsity ratio: stored entries divided by total cells. A matrix with
-- zero cells divides by zero.
spySM :: Fractional b => SpMatrix a -> b
spySM s = stored / cells
  where
    stored = fromIntegral (nzSM s)
    cells = fromIntegral (nelSM s)
-- | Number of stored entries in row @i@ (0 for a row with none stored).
-- Errors when @i@ is not a valid row index.
nzRow :: SpMatrix a -> IM.Key -> Int
nzRow s i
  | not (inBounds0 (nrows s) i) = error "nzRow : index out of bounds"
  | otherwise = maybe 0 I.size (I.lookup i (immSM s))
-- | Lower bandwidth bound; first component of 'bwBoundsSM'.
bwMinSM :: SpMatrix a -> Int
bwMinSM s = fst (bwBoundsSM s)
-- | Upper bandwidth bound; second component of 'bwBoundsSM'.
bwMaxSM :: SpMatrix a -> Int
bwMaxSM s = snd (bwBoundsSM s)
-- | Smallest and largest row width over all stored rows.
--
-- A row's width is
-- @(largest stored column index) - (smallest stored column index) + 1@.
--
-- 'minimum' / 'maximum' are taken over the widths themselves. Taking
-- @snd@ of 'I.findMin' / 'I.findMax' would instead return the widths of
-- the lowest- and highest-indexed rows.
--
-- Partial: errors when no rows are stored.
bwBoundsSM :: SpMatrix a -> (Int, Int)
bwBoundsSM s = (minimum widths, maximum widths)
  where
    widths = fmap rowWidth (immSM s)
    rowWidth row = fst (I.findMax row) - fst (I.findMin row) + 1 :: Int
-- | Place the second matrix below the first.
--
-- Its row keys are shifted by the first matrix's row count. The result has
-- the summed row count and the larger of the two column counts.
vertStackSM, (-=-) :: SpMatrix a -> SpMatrix a -> SpMatrix a
vertStackSM top bottom = SM (rowsTotal, colsMax) stacked
  where
    offset = nrows top
    rowsTotal = offset + nrows bottom
    colsMax = max (ncols top) (ncols bottom)
    stacked = I.union (immSM top) (I.mapKeys (+ offset) (immSM bottom))
(-=-) = vertStackSM
-- | Place the second matrix to the right of the first, implemented as a
-- vertical stack of the transposes, transposed back.
horizStackSM, (-||-) :: SpMatrix a -> SpMatrix a -> SpMatrix a
horizStackSM left right =
  transposeSM (transposeSM left -=- transposeSM right)
(-||-) = horizStackSM
-- | Block-diagonal matrix built from the given blocks.
--
-- Each block is offset along both axes by the total row count of the
-- blocks before it. The result is square, sized by the total row count, so
-- the blocks are effectively assumed square.
fromBlocksDiag :: [SpMatrix a] -> SpMatrix a
fromBlocksDiag mml = fromListSM (n, n) (concat (zipWith shifted offsets mml))
  where
    sizes = map nrows mml
    n = sum sizes
    offsets = scanl (+) 0 sizes
    shifted s blk = [(i + s, j + s, x) | (i, j, x) <- toListSM blk]
-- | Keep the stored entries whose (row, column, value) satisfy the
-- predicate; dimensions unchanged.
ifilterSM :: (IM.Key -> IM.Key -> a -> Bool) -> SpMatrix a -> SpMatrix a
ifilterSM f sm = SM (smDim sm) (ifilterIM2 f (smData sm))
-- | Fold over the stored entries via 'foldlIM2'. The combining function
-- takes the entry first and the accumulator second.
foldlSM :: (a -> b -> b) -> b -> SpMatrix a -> b
foldlSM f z sm = foldlIM2 f z (smData sm)
-- | Indexed fold over the stored entries via 'ifoldlIM2''. The combining
-- function receives row, column, entry and accumulator.
ifoldlSM :: (IM.Key -> IM.Key -> a -> b -> b) -> b -> SpMatrix a -> b
ifoldlSM f z sm = ifoldlIM2' f z (smData sm)
-- | Count of stored sub-diagonal entries, delegated to 'countSubdiagonalNZ'.
countSubdiagonalNZSM :: SpMatrix a -> Int
countSubdiagonalNZSM = countSubdiagonalNZ . smData
-- | Indices of stored sub-diagonal entries, delegated to 'subdiagIndices'.
subdiagIndicesSM :: SpMatrix a -> [(IxRow, IxCol)]
subdiagIndicesSM = subdiagIndices . smData
-- | Drop every entry for which 'isNz' is False.
sparsifyIM2 ::
  Epsilon a => I.IntM (I.IntM a) -> I.IntM (I.IntM a)
sparsifyIM2 im = ifilterIM2 (\_ _ -> isNz) im
-- | Drop the stored entries that 'isNz' deems zero; dimensions unchanged.
sparsifySM :: Epsilon a => SpMatrix a -> SpMatrix a
sparsifySM sm = sm { smData = sparsifyIM2 (smData sm) }
-- | Apply 'roundZeroOne' to every stored entry, then sparsify the result.
roundZeroOneSM :: Epsilon a => SpMatrix a -> SpMatrix a
roundZeroOneSM sm = sparsifySM rounded
  where
    rounded = sm { smData = mapIM2 roundZeroOne (smData sm) }
-- | The stored entries as triples, with row keys mapped by @fi@ and column
-- keys mapped by @fj@.
modifyKeysSM' :: (IxRow -> a) -> (IxCol -> b) -> SpMatrix c -> [(a, b, c)]
modifyKeysSM' fi fj mm = [ (fi i, fj j, x) | (i, j, x) <- toListSM mm ]
-- | Rebuild the matrix (same dimensions) with the row and column keys
-- remapped. New keys are bounds-checked by 'fromListSM'.
modifyKeysSM :: (IxRow -> IxRow) -> (IxCol -> IxCol) -> SpMatrix a -> SpMatrix a
modifyKeysSM fi fj mm =
  fromListSM (dim mm) [ (fi i, fj j, x) | (i, j, x) <- toListSM mm ]
-- | Exchange rows @i1@ and @i2@ (no bounds check; see 'swapRowsSafe').
--
-- Rows with no stored entries (all-zero rows) are allowed. If exactly one
-- of the two rows is stored, its entries move to the other index and the
-- original index is left empty. Previously 'I.!' crashed in that case.
swapRows :: IxRow -> IxRow -> SpMatrix a -> SpMatrix a
swapRows i1 i2 (SM d im) = SM d $ place i1 ro2 (place i2 ro1 im)
  where
    ro1 = I.lookup i1 im
    ro2 = I.lookup i2 im
    -- store the row at key k, or remove key k if that row was empty
    place k (Just row) m = I.insert k row m
    place k Nothing m = I.filterWithKey (\key _ -> key /= k) m
-- | 'swapRows' after checking that both indices are valid row indices.
swapRowsSafe :: IxRow -> IxRow -> SpMatrix a -> SpMatrix a
swapRowsSafe i1 i2 m
  | not (inBounds02 (rowCount, rowCount) (i1, i2)) =
      error $ "swapRowsSafe : index out of bounds " ++ show (i1, i2)
  | otherwise = swapRows i1 i2 m
  where
    rowCount = nrows m
-- | Plain (non-conjugating) transpose: dimensions swapped and the nested
-- map transposed with 'transposeIM2'.
transposeSM :: SpMatrix a -> SpMatrix a
transposeSM sm = SM (ncols sm, nrows sm) (transposeIM2 (smData sm))
-- | Conjugate transpose: transpose, then conjugate every stored entry.
hermitianConj :: Num a => SpMatrix (Complex a) -> SpMatrix (Complex a)
hermitianConj = fmap conjugate . transposeSM
-- | Multiply every stored entry by a scalar (on the right).
matScale :: Num a => a -> SpMatrix a -> SpMatrix a
matScale a = fmap (\x -> x * a)
-- | Sum of the stored main-diagonal entries.
trace :: Num b => SpMatrix b -> b
trace = foldlSM (+) 0 . extractDiag
-- | Frobenius norm of a real matrix, computed as @sqrt (trace (m m^T))@.
normFrobeniusSM :: (MatrixRing (SpMatrix a), Floating a) => SpMatrix a -> a
normFrobeniusSM m = sqrt gram
  where
    gram = trace (m ##^ m)
-- | Frobenius norm of a complex matrix: @sqrt (sum |m_ij|^2)@ over the
-- stored entries.
--
-- The previous @trace (m ##^ m)@ used the unconjugated transpose ('matMat_'
-- with 'ABt'), so it summed @m_ij^2@ rather than @|m_ij|^2@. For example,
-- @[[1, i]]@ gave 0 instead of sqrt 2.
--
-- The 'MatrixRing' constraint is no longer needed but is kept so the
-- signature stays unchanged for callers.
normFrobeniusSMC ::
  (MatrixRing (SpMatrix (Complex a)), RealFloat a) => SpMatrix (Complex a) -> a
normFrobeniusSMC m = sqrt $ foldlSM (\z acc -> acc + sqMag z) 0 m
  where
    sqMag (re :+ im) = re * re + im * im
-- | Real matrix ring: products via 'matMat_', plain transpose, Frobenius
-- norm via 'normFrobeniusSM'.
instance MatrixRing (SpMatrix Double) where
  type MatrixNorm (SpMatrix Double) = Double
  a ## b = matMat_ AB a b
  a ##^ b = matMat_ ABt a b
  transpose = transposeSM
  normFrobenius = normFrobeniusSM
-- | Complex matrix ring.
--
-- 'transpose' is the conjugate transpose ('hermitianConj'). Note that
-- '##^' goes through 'matMat_' 'ABt', which does not conjugate.
instance MatrixRing (SpMatrix (Complex Double)) where
  type MatrixNorm (SpMatrix (Complex Double)) = Double
  a ## b = matMat_ AB a b
  a ##^ b = matMat_ ABt a b
  transpose = hermitianConj
  normFrobenius = normFrobeniusSMC
-- | Which product 'matMat_' computes.
data MatProd_
  = AB   -- ^ A B
  | ABt  -- ^ A B^T (plain transpose of the right operand)
  deriving (Eq, Show)
-- | Sparse matrix product, selected by 'MatProd_'.
--
-- * 'AB': @mm1 * mm2@. @mm2@ is transposed so its columns become rows.
-- * 'ABt': @mm1 * mm2^T@. Only the recorded dimensions of @mm2@ are
--   swapped; its rows are used directly.
--
-- Each result entry is the dot product ('liftI2' then 'sum') of a stored
-- row of @mm1@ with a stored row of the (transposed) right operand. Pairs
-- with no overlap yield an explicit stored 0; see 'matMatSparsified'.
--
-- Errors when the inner dimensions disagree.
matMat_ :: Num a => MatProd_ -> SpMatrix a -> SpMatrix a -> SpMatrix a
matMat_ pt mm1 mm2 =
  case pt of
    AB  -> matMatCheck (matMatUnsafeWith transposeIM2) mm1 mm2
    ABt -> matMatCheck (matMatUnsafeWith id) mm1 (trDim mm2)
  where
    -- swap only the recorded dimensions, leaving the data as-is
    trDim (SM (a, b) x) = SM (b, a) x
    matMatCheck mmf m1 m2
      | c1 == r2 = mmf m1 m2
      | otherwise = error $ "matMat : incompatible matrix sizes " ++ show (d1, d2)
      where
        d1@(_, c1) = dim m1
        d2@(r2, _) = dim m2
    matMatUnsafeWith ff2 m1 m2 = SM (nrows m1, ncols m2) (overRows2 <$> immSM m1)
      where
        overRows2 vm1 = (`dott` vm1) <$> ff2 (immSM m2)
        dott x y = sum $ liftI2 (*) x y
-- | Matrix product followed by 'sparsifySM', which drops entries that
-- 'isNz' deems zero.
matMatSparsified, (#~#) :: (MatrixRing (SpMatrix a), Epsilon a) =>
  SpMatrix a -> SpMatrix a -> SpMatrix a
matMatSparsified m1 m2 = sparsifySM product12
  where
    product12 = m1 ## m2
(#~#) = matMatSparsified
-- | Sparsified product @transpose a ## b@.
--
-- 'transpose' is the 'MatrixRing' one, i.e. the conjugate transpose for
-- complex matrices.
(#~^#) :: (MatrixRing (SpMatrix a), Epsilon a) =>
  SpMatrix a -> SpMatrix a -> SpMatrix a
a #~^# b = transpose a #~# b

-- | Sparsified product @a ## transpose b@ ('MatrixRing' 'transpose').
(#~#^) :: (MatrixRing (SpMatrix a), Epsilon a) =>
  SpMatrix a -> SpMatrix a -> SpMatrix a
a #~#^ b = a #~# transpose b
-- | Partial inner product of row @i@ of @a@ with column @j@ of @b@, over
-- the inner indices @0 .. n@ (inclusive):
--
-- > sum [ a(i,k) * b(k,j) | k <- [0 .. n] ]
--
-- Entries that are not stored contribute zero.
--
-- Errors unless all of the following hold:
--
-- * the inner dimensions agree (@ncols a == nrows b@);
-- * @i@ is a valid row of @a@;
-- * @j@ is a valid column of @b@ (previously it was checked against @a@'s
--   columns);
-- * @n <= ncols a@.
contractSub :: Elt a => SpMatrix a -> SpMatrix a -> IxRow -> IxCol -> Int -> a
contractSub a b i j n
  | ncols a == nrows b &&
    inBounds0 (nrows a) i &&
    inBounds0 (ncols b) j &&
    n <= ncols a = sum [ (a @@! (i, k)) * (b @@! (k, j)) | k <- [0 .. n] ]
  | otherwise =
      error "contractSub : requires ncols a == nrows b, i < nrows a, j < ncols b and n <= ncols a"