-- | Sparse linear algebra routines: matrix factorizations ('qr', 'lu', 'chol'),
-- triangular and Krylov-subspace linear solvers, and iterations whose names
-- suggest eigenvalue estimation ('eigsQR', 'eigRayleigh').
--
-- Several exported names (the list conversions and the iteration helpers)
-- are not defined in this module; presumably they are re-exported from the
-- imported modules.
module Numeric.LinearAlgebra.Sparse
(
-- * Linear solver interface
linSolve0, LinSolveMethod(..), (<\>),
pinv,
-- * Helpers derived from the diagonal / LU structure of a matrix
jacobiPre, ilu0, mSsor,
-- * Direct solvers for triangular factors
luSolve,
triLowerSolve,
triUpperSolve,
-- * Iterations built on QR and Rayleigh-quotient updates
eigsQR,
eigRayleigh,
-- * Factorizations
qr,
lu,
chol,
arnoldi,
-- * Utilities
diagPartitions,
givens,
conditionNumberSM,
hhMat, hhRefl,
-- * Names not defined in this module (presumably re-exports)
fromListSV, toListSV,
fromListSM, toListSM,
untilConvergedG0, untilConvergedG, untilConvergedGM,
modifyInspectGuarded, modifyInspectGuardedM, IterationConfig (..),
modifyUntil, modifyUntilM
)
where
import Control.Exception.Common
import Control.Iterative
import Data.Sparse.Common
import Control.Monad.Catch
import Data.Typeable
import Control.Monad.State.Strict
import qualified Control.Monad.Trans.State as MTS
import Data.Complex
import Data.VectorSpace hiding (magnitude)
import qualified Data.Sparse.Internal.IntM as I
import Data.Maybe
import qualified Data.Vector as V
-- | Constraint synonym bundling the 'Epsilon', 'Elt', 'Show' and 'Ord'
-- requirements placed on a scalar type.
type Num' x =
  ( Epsilon x
  , Elt x
  , Show x
  , Ord x
  )
-- | Condition number estimate computed from the diagonal of the R factor
-- returned by 'qr': the ratio between the largest and the smallest absolute
-- value found on that diagonal.
--
-- Throws 'HugeConditionNumber' when the smallest magnitude is 'nearZero'.
conditionNumberSM :: (MonadThrow m, MatrixRing (SpMatrix a), Num' a, Typeable a) =>
     SpMatrix a -> m a
conditionNumberSM m = do
  (_, r) <- qr m
  let
    -- Take magnitudes entry-wise *before* ranking them. Ranking the signed
    -- values first (e.g. a diagonal of [-5, 1, 2]) would pick 2 as the
    -- "largest" and |-5| as the "smallest", yielding a ratio below 1.
    mags = abs <$> extractDiagDense r
    lmax = maximum mags
    lmin = minimum mags
    kappa = lmax / lmin
  if nearZero lmin
    then throwM (HugeConditionNumber "conditionNumberSM" kappa)
    else return kappa
-- | Build @eye n ^-^ beta `scale` (x >< x)@, where @n@ is the dimension of @x@.
hhMat :: Num a => a -> SpVector a -> SpMatrix a
hhMat beta x =
  let scaledProduct = beta `scale` (x >< x)
  in eye (dim x) ^-^ scaledProduct
-- | 'hhMat' with its scaling factor fixed to 2.
hhRefl :: Num a => SpVector a -> SpMatrix a
hhRefl x = hhMat 2 x
-- | Givens rotation matrix built to annihilate entry @(i, j)@ of @aa@.
--
-- The rotation is an identity matrix of size @nrows aa@ with the four
-- entries @(i,i)@, @(j,j)@, @(j,i)@ and @(i,j)@ overwritten. The coefficients
-- come from 'givensCoef', applied to the target entry @aa(i, j)@ and to
-- @aa(i', j)@, where @i'@ is the first row other than @i@ whose leading
-- stored entry sits in column @j@.
--
-- Throws 'OOBIxsError' when @(i, j)@ is not a valid index or when @aa@ has
-- fewer rows than columns, and 'OOBNoCompatRows' when no row @i'@ exists.
givens :: (Elt a, MonadThrow m) => SpMatrix a -> Int -> Int -> m (SpMatrix a)
givens aa i j
  | isValidIxSM aa (i,j) && nrows aa >= ncols aa = do
      i' <- candidateRows' (immSM aa) i j
      return $ givensMat aa i i' j
  | otherwise = throwM (OOBIxsError "Givens" [i, j])
  where
    givensMat mm ir ip jc =
      fromListSM'
        -- The two off-diagonal entries must carry opposite signs; without
        -- the negation the 2x2 block [[c, s], [conj s, conj c]] is not a
        -- rotation.
        [(ir,ir, c), (jc,jc, conj c), (jc,ir, negate (conj s)), (ir,jc, s)]
        (eye (nrows mm))
      where
        (c, s, _) = givensCoef a b
        a = mm @@ (ip, jc)
        b = mm @@ (ir, jc) -- the entry being annihilated
    -- First row (other than @ir@) whose leading stored column is @jc@.
    candidateRows' mm ir jc =
      case I.keys u of
        []      -> throwM (OOBNoCompatRows "Givens" (ir,jc))
        (k : _) -> return k
      where
        u = I.filterWithKey (\irow row -> irow /= ir &&
                                          firstNZColumn row jc) mm
    -- True when @k@ is stored in @m@ and no smaller key is.
    firstNZColumn m k = isJust (I.lookup k m) &&
                        isNothing (I.lookupLT k m)

-- | Identity matrix of size @n@ with a Givens rotation, computed from the
-- pair @(a, b)@ by 'givensCoef', embedded at rows/columns @i@ and @j@.
rotMat :: Elt e => e -> e -> IxRow -> IxRow -> Int -> SpMatrix e
rotMat a b i j n =
  fromListSM' [(i,i, c), (j,j, conj c), (j,i, negate (conj s)), (i,j, s)] (eye n)
  where
    (c, s, _) = givensCoef a b

-- | Givens coefficients @(c, s, r)@ for the pair @(a, b)@:
-- @c = conj a / r@ and @s = - conj b / r@, where @r@ is the value of the
-- local @hypot@ helper applied to @conj a@ and @- conj b@.
--
-- The sign of @s@ is what makes @c * b + s * a@ cancel for real inputs.
givensCoef :: Elt t => t -> t -> (t, t, t)
givensCoef a b = (c0/r, s0/r, r) where
  c0 = conj a
  s0 = negate (conj b)
  r = hypot c0 s0
  -- NOTE(review): divides by @x@, so it is undefined when @a@ is zero.
  hypot x y = abs x * sqrt (1 + (y/x)**2)
-- | QR factorization by successive Givens rotations.
--
-- Starting from @(eye (nrows mm), mm, subdiagIndicesSM mm)@, each step takes
-- the next index pair @(i, j)@, builds the rotation @givens m i j@ and
-- left-multiplies both the accumulated transposed-Q matrix and the working
-- matrix by it. Returns @(Q, R)@, where @Q@ is the transpose of the
-- accumulated product.
--
-- Exceptions thrown by 'givens' are propagated.
qr :: (Elt a, MatrixRing (SpMatrix a), Epsilon a, MonadThrow m) =>
     SpMatrix a -> m (SpMatrix a, SpMatrix a)
qr mm = do
  (qt, r, _) <- MTS.execStateT (modifyUntilM haltf qrstepf) gminit
  return (transpose qt, r)
  where
    gminit = (eye (nrows mm), mm, subdiagIndicesSM mm)
    haltf (_, _, iis) = null iis
    -- Nothing left to annihilate: leave the state untouched. This keeps the
    -- step total even if it is invoked before the halting test is evaluated
    -- (e.g. for an input that has no subdiagonal entries to begin with).
    qrstepf st@(_, _, []) = return st
    qrstepf (qmatt, m, (i, j) : rest) = do
      g <- givens m i j
      return (g #~# qmatt, g #~# m, rest)
-- | Iterated QR updates: each step factors the current iterate with 'qr' and
-- replaces it by @q #~^# (mm ## q)@. The dense diagonal of the final iterate
-- is returned.
--
-- The first two arguments are forwarded to 'IterConf' (presumably the
-- iteration cap and a debug-printing switch — confirm against
-- "Control.Iterative").
eigsQR :: (MonadThrow m, MonadIO m, Elt a, Normed (SpVector a), MatrixRing (SpMatrix a), Epsilon a, Typeable (Magnitude (SpVector a)), Typeable a, Show a) =>
        Int
     -> Bool
     -> SpMatrix a
     -> m (SpVector a)
eigsQR nitermax debq m = pf <$> untilConvergedGM "eigsQR" c (const True) stepf m
  where
    pf = extractDiagDense
    c = IterConf nitermax debq pf prd
    stepf mm = do
      (q, _) <- qr mm
      -- The similarity transform must be applied to the *current* iterate
      -- @mm@; applying it to the initial matrix discards the accumulated
      -- transformations of the previous steps.
      return $ q #~^# (mm ## q)
-- | Iteration over a pair @(b, mu)@: each step solves
-- @(m ^-^ mu `matScale` eye (nrows m)) <\> b@, normalizes the solution with
-- @normalize2'@ to obtain the new @b@, and recomputes @mu@ as the quotient
-- @(b <.> (m #> b)) / (b <.> b)@.
--
-- @nitermax@, @debq@ and @prntf@ are forwarded to 'IterConf'.
eigRayleigh nitermax debq prntf m =
  untilConvergedGM "eigRayleigh" config (const True) (rayStep m)
  where
    config = IterConf nitermax debq fst prntf
    ii = eye (nrows m)
    rayStep aa (b, mu) = do
      let shifted = m ^-^ (mu `matScale` ii)
      nom <- shifted <\> b
      let b' = normalize2' nom
          numer = b' <.> (aa #> b')
          denom = b' <.> b'
      return (b', numer / denom)
-- | Run 'arnoldi' on @aa@ and @b@ for @nitermax@ steps, factor the resulting
-- second matrix with 'qr', and return the Arnoldi basis, the Q factor and
-- the dense diagonal of the R factor.
eigArnoldi :: (Scalar (SpVector t) ~ t, MatrixType (SpVector t) ~ SpMatrix t,
      Elt t, Normed (SpVector t), MatrixRing (SpMatrix t),
      LinearVectorSpace (SpVector t), Epsilon t, MonadThrow m) =>
        Int
     -> SpMatrix t
     -> SpVector t
     -> m (SpMatrix t, SpMatrix t, SpVector t)
eigArnoldi nitermax aa b = do
  (q, h) <- arnoldi aa b nitermax
  assemble q <$> qr h
  where
    assemble q (o, r) = (q, o, extractDiagDense r)
-- | Householder vector @v@ (normalized so that its first component is 1) and
-- coefficient @beta@ for the input vector @x@.
--
-- With @sigma@ the inner product of the tail of @x@ with itself:
--
-- * if @sigma@ is 'nearZero', returns @(1 : tail x, 0)@;
-- * otherwise the first component @vh@ is @xh - mu@ (or the algebraically
--   equal @- sigma / (xh + mu)@), with @mu = sqrt (xh^2 + sigma)@, and
--   @beta = 2 vh^2 / (sigma + vh^2)@.
hhV :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t), Epsilon t) =>
    SpVector t -> (SpVector t, t)
hhV x = (v, beta) where
  tx = tailSV x
  sigma = tx <.> tx
  vtemp = singletonSV 1 `concatSV` tx
  (v, beta)
    | nearZero sigma = (vtemp, 0)
    | otherwise =
        let xh = headSV x
            mu = sqrt (xh**2 + sigma)
            -- Both branches are equal in exact arithmetic, because
            -- (xh - mu) * (xh + mu) = - sigma.
            vh | mag xh <= 1 = xh - mu
               | otherwise = negate sigma / (xh + mu)
            vnew = (1 / vh) `scale` insertSpVector 0 vh vtemp
            -- After dividing by @vh@ the squared norm of the vector is
            -- (vh^2 + sigma) / vh^2, hence beta = 2 / |v|^2 uses @vh@,
            -- not the original head @xh@.
        in (vnew, 2 * vh**2 / (sigma + vh**2))
-- | Cholesky factorization, built one row at a time.
--
-- For each row @i@ the subdiagonal entries
-- @L(i,j) = (A(i,j) - contractSub L L i j (j - 1)) / L(j,j)@ are computed for
-- @j@ in @[0 .. i - 1]@, followed by the diagonal entry
-- @L(i,i) = sqrt (A(i,i) - sum of squares of the entries of row i left of the diagonal)@.
--
-- Throws 'NeedsPivoting' when a required diagonal entry of @L@ or of the
-- input fails the 'isNz' test.
chol :: (Elt a, Epsilon a, MonadThrow m) =>
        SpMatrix a
     -> m (SpMatrix a)
chol aa = do
  let n = nrows aa
      q (i, _) = i == n
  l0 <- cholUpd aa (0, zeroSM n n)
  (_, lfin) <- MTS.execStateT (modifyUntilM q (cholUpd aa)) l0
  return lfin
  where
    oops i = throwM (NeedsPivoting "chol" (unwords ["L", show (i,i)]) :: MatrixException Double)
    -- One full row update: subdiagonal part first, then the diagonal entry.
    cholUpd aa (i, ll) = do
      sd <- cholSDRowUpd aa ll i
      ll' <- cholDiagUpd aa sd i
      return (i + 1, ll')
    -- Subdiagonal entries of row @i@ (columns 0 .. i - 1).
    cholSDRowUpd aa ll i = do
      lrs <- fromListSV (i + 1) <$> onRangeSparseM cholSubDiag [0 .. i - 1]
      return $ insertRow ll lrs i
      where
        cholSubDiag j
          | isNz ljj = return $ 1/ljj*(aij - inn)
          | otherwise = oops j
          where
            ljj = ll @@! (j, j)
            aij = aa @@! (i, j)
            inn = contractSub ll ll i j (j - 1)
    -- Diagonal entry of row @i@.
    cholDiagUpd aa ll i = do
      cd <- cholDiag
      return $ insertSpMatrix i i cd ll
      where
        cholDiag
          | i == 0 = sqrt <$> aai
          | otherwise = do
              a <- aai
              let l = sum (fmap (**2) lrow)
              return $ sqrt (a - l)
          where
            lrow = ifilterSV (\j _ -> j < i) (extractRow ll i)
        aai
          | isNz aaii = return aaii
          | otherwise = oops i
          where
            aaii = aa @@! (i,i)
-- | LU factorization without pivoting, returning @(L, U)@.
--
-- @L@ starts as the identity with its first column filled from the first
-- column of the input divided by @U(0,0)@; @U@ starts with the first row of
-- the input. Each step @i@ then fills row @i@ of @U@ with
-- @A(i,j) - contractSub L U i j (i - 1)@ and column @i@ of @L@ with
-- @(A(k,i) - contractSub L U k i (k - 1)) / U(i,i)@.
--
-- Throws 'NeedsPivoting' when a diagonal entry of @U@ fails the 'isNz' test.
lu :: (Scalar (SpVector t) ~ t, Elt t, VectorSpace (SpVector t), Epsilon t,
        MonadThrow m) =>
     SpMatrix t
     -> m (SpMatrix t, SpMatrix t)
lu aa = do
  let oops j = throwM (NeedsPivoting "solveForLij" ("U" ++ show (j, j)) :: MatrixException Double)
      n = nrows aa
      -- Stop once the row counter reaches the last row; that row of U is
      -- filled by the final 'uUpd' call below.
      q (i, _, _) = i == n - 1
      luInit
        | isNz u00 = return (1, l0, u0)
        | otherwise = oops (0 :: Int)
        where
          l0 = insertCol (eye n) ((extractSubCol aa 0 (1, n - 1)) ./ u00 ) 0
          u0 = insertRow (zeroSM n n) (extractRow aa 0) 0
          u00 = u0 @@! (0,0)
      luUpd (i, l, u) = do
        u' <- uUpd aa n (i, l, u)
        l' <- lUpd (i, l, u')
        return (i + 1, l', u')
      -- Fill row @ix@ of U (columns ix .. n - 1).
      uUpd aa n (ix, lmat, umat) = do
        let us = onRangeSparse (solveForUij ix) [ix .. n - 1]
            solveForUij i j = a - p where
              a = aa @@! (i, j)
              p = contractSub lmat umat i j (i - 1)
        return $ insertRow umat (fromListSV n us) ix
      -- Fill column @ix@ of L (rows ix + 1 .. n - 1).
      lUpd (ix, lmat, umat) = do
        ls <- lsm
        return $ insertCol lmat (fromListSV n ls) ix
        where
          lsm = onRangeSparseM (`solveForLij` ix) [ix + 1 .. n - 1]
          solveForLij i j
            | isNz ujj = return $ (a - p)/ujj
            | otherwise = oops j
            where
              a = aa @@! (i, j)
              ujj = umat @@! (j , j)
              p = contractSub lmat umat i j (i - 1)
  s0 <- luInit
  (ixf, lf, uf) <- MTS.execStateT (modifyUntilM q luUpd) s0
  ufin <- uUpd aa n (ixf, lf, uf)
  return (lf, ufin)
-- | Arnoldi iteration on the matrix @aa@ starting from the vector @b@.
--
-- The state holds the basis vectors collected so far, the coordinate-format
-- entries of the second result matrix, a step counter and a break flag. The
-- loop stops when the counter equals @kn@ or when the norm of the newly
-- orthogonalized vector is 'nearZero'.
--
-- Returns the basis vectors assembled as columns, together with a matrix of
-- size @(nmax + 1, nmax)@ built from the accumulated entries, where @nmax@
-- is the final counter value.
--
-- Throws 'MatVecSizeMismatchException' when @ncols aa /= dim b@.
arnoldi :: (MatrixType (SpVector a) ~ SpMatrix a, V (SpVector a) ,
            Scalar (SpVector a) ~ a, Epsilon a, MonadThrow m) =>
     SpMatrix a
     -> SpVector a
     -> Int
     -> m (SpMatrix a, SpMatrix a)
arnoldi aa b kn
  | n == nb = return (fromCols qvfin, fromListSM (nmax + 1, nmax) hhfin)
  | otherwise = throwM (MatVecSizeMismatchException "arnoldi" (m,n) nb)
  where
    (m, n) = (nrows aa, ncols aa)
    nb = dim b
    (qvfin, hhfin, nmax, _) = execState (modifyUntil finished arnoldiStep) arnInit
    finished (_, _, k, stop) = k == kn || stop
    -- Initial state: two basis vectors and the first column of entries.
    arnInit = (V.fromList [q0, q1], V.fromList [(0, 0, h11), (1, 0, h21)], 1, False)
      where
        q0 = normalize2 b
        aq0 = aa #> q0
        h11 = q0 `dot` aq0
        q1nn = aq0 ^-^ (h11 .* q0)
        h21 = norm2' q1nn
        q1 = normalize2 q1nn
    -- One step: orthogonalize @aa #> (last basis vector)@ against the basis.
    arnoldiStep (qv, hh, i, _) = (V.snoc qv qip, (V.++) hh newEntries, i + 1, nearZero qipnorm)
      where
        aqi = aa #> V.last qv
        hhcoli = fmap (`dot` aqi) qv
        projection = V.foldl' (^+^) (zeroSV m) (V.zipWith (.*) hhcoli qv)
        qipnn = aqi ^-^ projection
        qipnorm = norm2' qipnn
        qip = normalize2 qipnn
        newEntries = V.zip3 rowIxs colIxs (V.snoc hhcoli qipnorm)
        rowIxs = V.fromList [0 .. n]
        colIxs = V.replicate (n + 1) i
-- | Split a matrix into the triple
-- @(extractSubDiag aa, extractDiag aa, extractSuperDiag aa)@.
diagPartitions :: SpMatrix a
     -> (SpMatrix a, SpMatrix a, SpMatrix a)
diagPartitions aa =
  ( extractSubDiag aa
  , extractDiag aa
  , extractSuperDiag aa
  )
-- | Element-wise reciprocal of the result of 'extractDiag'.
jacobiPre :: Fractional a => SpMatrix a -> SpMatrix a
jacobiPre = fmap recip . extractDiag
-- | Compute the 'lu' factors of @aa@ and keep, in each factor, only the
-- entries at positions where @aa@ itself has a stored entry.
ilu0 :: (Scalar (SpVector t) ~ t, Elt t, VectorSpace (SpVector t),
         Epsilon t, MonadThrow m) =>
     SpMatrix t
     -> m (SpMatrix t, SpMatrix t)
ilu0 aa = do
  (l, u) <- lu aa
  return (restrictTo aa l, restrictTo aa u)
  where
    -- Drop every entry of @factor@ whose position is not stored in @ref@.
    restrictTo ref factor = ifilterSM (\i j _ -> isJust (lookupSM ref i j)) factor
-- | Pair of matrices built from the 'diagPartitions' @(e, d, f)@ of @aa@ and
-- the factor @omega@:
--
-- * first component: @(eye n ^-^ scale omega e) ## reciprocal d@, with @n = nrows e@;
-- * second component: @d ^-^ scale omega f@.
mSsor :: (MatrixRing (SpMatrix b), Fractional b) =>
     SpMatrix b
     -> b
     -> (SpMatrix b, SpMatrix b)
mSsor aa omega =
  let (subD, diagD, superD) = diagPartitions aa
      identity = eye (nrows subD)
      leftFactor = (identity ^-^ scale omega subD) ## reciprocal diagD
      rightFactor = diagD ^-^ scale omega superD
  in (leftFactor, rightFactor)
-- | Solve a system given its triangular factors: first 'triLowerSolve' with
-- @ll@ and @b@, then 'triUpperSolve' with @uu@ and the intermediate result.
--
-- Throws 'NonTriangularException' unless @ll@ is lower triangular and @uu@
-- is upper triangular.
luSolve :: (Scalar (SpVector t) ~ t, MonadThrow m, Elt t, InnerSpace (SpVector t), Epsilon t) =>
     SpMatrix t
     -> SpMatrix t
     -> SpVector t
     -> m (SpVector t)
luSolve ll uu b
  | isLowerTriSM ll && isUpperTriSM uu = triLowerSolve ll b >>= triUpperSolve uu
  | otherwise = throwM (NonTriangularException "luSolve")
-- | Forward substitution for a lower-triangular system @ll w = b@.
--
-- The first component is @b(0) / ll(0,0)@; each following component is
-- @(b(i) - row_i(ll)[0 .. i-1] . w[0 .. i-1]) / ll(i,i)@. The result is passed
-- through 'sparsifySV'.
--
-- Throws 'NeedsPivoting' when a diagonal entry of @ll@ fails the 'isNz' test.
triLowerSolve :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t),
                  Epsilon t, MonadThrow m) =>
     SpMatrix t
     -> SpVector t
     -> m (SpVector t)
triLowerSolve ll b = do
  let q (_, i) = i == nb
      nb = svDim b
      oops i = throwM (NeedsPivoting "triLowerSolve" (unwords ["L", show (i, i)]) :: MatrixException Double)
      lStep (ww, i) = do
        let
          lii = ll @@ (i, i)
          bi = b @@ i
          -- Contribution of the components that are already known.
          r = extractSubRow ll i (0, i - 1) `dot` takeSV i ww
          wi = (bi - r)/lii
        if isNz lii
          then return (insertSpVector i wi ww, i + 1)
          else oops i
      lInit = do
        let
          l00 = ll @@ (0, 0)
          b0 = b @@ 0
          w0 = b0 / l00
        if isNz l00
          then return (insertSpVector 0 w0 $ zeroSV (dim b), 1)
          else oops (0 :: Int)
  l0 <- lInit
  (v, _) <- MTS.execStateT (modifyUntilM q lStep) l0
  return $ sparsifySV v
-- | Backward substitution for an upper-triangular system @uu x = w@.
--
-- The last component is @w(n-1) / uu(n-1,n-1)@ with @n = svDim w@; the index
-- then walks downwards, each component being
-- @(w(i) - row_i(uu)[i+1 .. n-1] . x[i+1 ..]) / uu(i,i)@, until the index
-- reaches @-1@. The result is passed through 'sparsifySV'.
--
-- Throws 'NeedsPivoting' (carrying the offending diagonal index) when a
-- diagonal entry of @uu@ fails the 'isNz' test.
triUpperSolve :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t),
                  Epsilon t, MonadThrow m) =>
     SpMatrix t
     -> SpVector t
     -> m (SpVector t)
triUpperSolve uu w = do
  let q (_, i) = i == negate 1
      nw = svDim w
      oops i = throwM (NeedsPivoting "triUpperSolve" (unwords ["U", show (i, i)]) :: MatrixException Double)
      uStep (xx, i) = do
        let uii = uu @@ (i, i)
            wi = w @@ i
            -- Contribution of the components that are already known.
            r = extractSubRow_RK uu i (i + 1, nw - 1) `dot` dropSV (i + 1) xx
            xi = (wi - r) / uii
        if isNz uii
          then return (insertSpVector i xi xx, i - 1)
          else oops i
      uInit = do
        let i = nw - 1
            u00 = uu @@! (i, i)
            w0 = w @@ i
            x0 = w0 / u00
        if isNz u00
          then return (insertSpVector i x0 (zeroSV nw), i - 1)
          -- Report the pivot that was actually inspected (the last diagonal
          -- entry), not a hard-coded index 0.
          else oops i
  u0 <- uInit
  (x, _) <- MTS.execStateT (modifyUntilM q uStep) u0
  return $ sparsifySV x
-- | Solve @aa x = b@ through an Arnoldi decomposition:
--
-- 1. run 'arnoldi' for at most @ncols aa@ steps, obtaining @(qa, ha)@;
-- 2. factor @ha@ with 'qr';
-- 3. solve the triangular system formed by all rows but the last of the R
--    factor, against all components but the last of the transformed
--    right-hand side (@norm2' b@ times a unit vector);
-- 4. map the solution back with all columns but the last of @qa@.
--
-- Exceptions from 'arnoldi', 'qr' and 'triUpperSolve' are propagated.
gmres :: (Scalar (SpVector t) ~ t, MatrixType (SpVector t) ~ SpMatrix t,
      Elt t, Normed (SpVector t), LinearVectorSpace (SpVector t), Epsilon t,
      MonadThrow m) =>
     SpMatrix t -> SpVector t -> m (SpVector t)
gmres aa b = do
  let m = ncols aa
  (qa, ha) <- arnoldi aa b m
  let mp1 = nrows ha
      b' = norm2' b .* ei mp1 1
  (qh, rh) <- qr ha
  -- Drop the last component / row so that the triangular system is square.
  let rhs' = takeSV (dim b' - 1) (transpose qh #> b')
      rh' = takeRows (nrows rh - 1) rh
  yhat <- triUpperSolve rh' rhs'
  let qa' = takeCols (ncols qa - 1) qa
  return $ qa' #> yhat
-- | Iteration state for the CGNE solver: three vectors, named @x@, @r@ and
-- @p@ by the 'Show' instance below.
data CGNE a = CGNE
  { _xCgne :: SpVector a
  , _rCgne :: SpVector a
  , _pCgne :: SpVector a
  } deriving Eq

-- | One labelled line per field.
instance Show a => Show (CGNE a) where
  show (CGNE x r p) =
    concat
      [ "x = ", show x, "\n"
      , "r = ", show r, "\n"
      , "p = ", show p, "\n"
      ]
-- | Initial CGNE state for the starting guess @x0@: the residual
-- @b ^-^ (aa #> x0)@ and the direction @transposeSM aa #> residual@.
cgneInit :: (MatrixType (SpVector a) ~ SpMatrix a,
             LinearVectorSpace (SpVector a)) =>
     SpMatrix a -> SpVector a -> SpVector a -> CGNE a
cgneInit aa b x0 =
  let residual = b ^-^ (aa #> x0)
      direction = transposeSM aa #> residual
  in CGNE x0 residual direction
-- | One CGNE step.
--
-- @
-- alpha = (r . r) / (p . p)
-- x1    = x + alpha p
-- r1    = r - alpha (A p)
-- beta  = (r1 . r1) / (r . r)
-- p1    = A^T r1 + beta p
-- @
cgneStep :: (MatrixType (SpVector a) ~ SpMatrix a,
             LinearVectorSpace (SpVector a), InnerSpace (SpVector a),
             Fractional (Scalar (SpVector a))) =>
     SpMatrix a -> CGNE a -> CGNE a
cgneStep aa (CGNE x r p) = CGNE x1 r1 p1 where
  alphai = (r `dot` r) / (p `dot` p)
  x1 = x ^+^ (alphai .* p)
  r1 = r ^-^ (alphai .* (aa #> p))
  beta = (r1 `dot` r1) / (r `dot` r)
  -- The new search direction is built from the *updated* residual @r1@
  -- (consistently with @beta@, which already uses it); the parentheses make
  -- the intended grouping independent of operator fixities.
  p1 = (transpose aa #> r1) ^+^ (beta .* p)
-- | Iteration state for the BCG solver: five vectors, named @x@, @r@,
-- @r_hat@, @p@ and @p_hat@ by the corresponding 'Show' instance.
data BCG a = BCG
  { _xBcg :: SpVector a
  , _rBcg :: SpVector a
  , _rHatBcg :: SpVector a
  , _pBcg :: SpVector a
  , _pHatBcg :: SpVector a
  } deriving Eq
-- | Initial BCG state for the starting guess @x0@: every vector other than
-- @x0@ is set to the initial residual @b ^-^ (aa #> x0)@.
bcgInit :: LinearVectorSpace (SpVector a) =>
     MatrixType (SpVector a) -> SpVector a -> SpVector a -> BCG a
bcgInit aa b x0 =
  let residual = b ^-^ (aa #> x0)
  in BCG x0 residual residual residual residual
-- | One BCG step.
--
-- @
-- alpha = (r . rhat) / (A p . phat)
-- x1    = x + alpha p
-- r1    = r - alpha (A p)
-- rhat1 = rhat - alpha (A^T phat)
-- beta  = (r1 . rhat1) / (r . rhat)
-- p1    = r1 + beta p
-- phat1 = rhat1 + beta phat
-- @
bcgStep :: (MatrixType (SpVector a) ~ SpMatrix a,
            LinearVectorSpace (SpVector a), InnerSpace (SpVector a),
            Fractional (Scalar (SpVector a))) =>
     SpMatrix a -> BCG a -> BCG a
bcgStep aa (BCG x r rhat p phat) =
  let aap = aa #> p
      atphat = transpose aa #> phat
      rho = r `dot` rhat
      alpha = rho / (aap `dot` phat)
      xNext = x ^+^ (alpha .* p)
      rNext = r ^-^ (alpha .* aap)
      rhatNext = rhat ^-^ (alpha .* atphat)
      beta = (rNext `dot` rhatNext) / rho
      pNext = rNext ^+^ (beta .* p)
      phatNext = rhatNext ^+^ (beta .* phat)
  in BCG xNext rNext rhatNext pNext phatNext
-- | One labelled line per field.
instance Show a => Show (BCG a) where
  show (BCG x r rhat p phat) =
    concat
      [ "x = ", show x, "\n"
      , "r = ", show r, "\n"
      , "r_hat = ", show rhat, "\n"
      , "p = ", show p, "\n"
      , "p_hat = ", show phat, "\n"
      ]
-- | Iteration state for the CGS solver: four vectors, named @x@, @r@, @p@
-- and @u@ by the corresponding 'Show' instance.
data CGS a = CGS
  { _x :: SpVector a
  , _r :: SpVector a
  , _p :: SpVector a
  , _u :: SpVector a
  } deriving Eq
-- | Initial CGS state for the starting guess @x0@: every vector other than
-- @x0@ is set to the initial residual @b ^-^ (aa #> x0)@.
cgsInit :: LinearVectorSpace (SpVector a) =>
     MatrixType (SpVector a) -> SpVector a -> SpVector a -> CGS a
cgsInit aa b x0 =
  let residual = b ^-^ (aa #> x0)
  in CGS x0 residual residual residual
-- | One CGS step; @rhat@ is a fixed vector supplied by the caller.
--
-- @
-- alpha = (r . rhat) / (A p . rhat)
-- q     = u - alpha (A p)
-- x'    = x + alpha (u + q)
-- r'    = r - alpha (A (u + q))
-- beta  = (r' . rhat) / (r . rhat)
-- u'    = r' + beta q
-- p'    = u' + beta (q + beta p)
-- @
cgsStep :: (V (SpVector a), Fractional (Scalar (SpVector a))) =>
     MatrixType (SpVector a) -> SpVector a -> CGS a -> CGS a
cgsStep aa rhat (CGS x r p u) =
  let aap = aa #> p
      rho = r `dot` rhat
      alphaj = rho / (aap `dot` rhat)
      q = u ^-^ (alphaj .* aap)
      uq = u ^+^ q
      xNext = x ^+^ (alphaj .* uq)
      rNext = r ^-^ (alphaj .* (aa #> uq))
      betaj = (rNext `dot` rhat) / rho
      uNext = rNext ^+^ (betaj .* q)
      pNext = uNext ^+^ (betaj .* (q ^+^ (betaj .* p)))
  in CGS xNext rNext pNext uNext
-- | One labelled line per field.
instance (Show a) => Show (CGS a) where
  show (CGS x r p u) =
    concat
      [ "x = ", show x, "\n"
      , "r = ", show r, "\n"
      , "p = ", show p, "\n"
      , "u = ", show u, "\n"
      ]
-- | Iteration state for the BiCGSTAB solver: three vectors, named @x@, @r@
-- and @p@ by the corresponding 'Show' instance.
data BICGSTAB a = BICGSTAB
  { _xBicgstab :: SpVector a
  , _rBicgstab :: SpVector a
  , _pBicgstab :: SpVector a
  } deriving Eq
-- | Initial BiCGSTAB state for the starting guess @x0@: both remaining
-- vectors are set to the initial residual @b ^-^ (aa #> x0)@.
bicgsInit :: LinearVectorSpace (SpVector a) =>
     MatrixType (SpVector a) -> SpVector a -> SpVector a -> BICGSTAB a
bicgsInit aa b x0 =
  let residual = b ^-^ (aa #> x0)
  in BICGSTAB x0 residual residual
-- | One BiCGSTAB step; @r0hat@ is a fixed vector supplied by the caller.
--
-- @
-- alpha = (r . r0hat) / (A p . r0hat)
-- s     = r - alpha (A p)
-- omega = (A s . s) / (A s . A s)
-- x'    = x + alpha p + omega s
-- r'    = s - omega (A s)
-- beta  = (r' . r0hat) / (r . r0hat) * alpha / omega
-- p'    = r' + beta (p - omega (A p))
-- @
bicgstabStep :: (V (SpVector a), Fractional (Scalar (SpVector a))) =>
     MatrixType (SpVector a) -> SpVector a -> BICGSTAB a -> BICGSTAB a
bicgstabStep aa r0hat (BICGSTAB x r p) =
  let aap = aa #> p
      alphaj = (r <.> r0hat) / (aap <.> r0hat)
      sj = r ^-^ (alphaj .* aap)
      aasj = aa #> sj
      omegaj = (aasj <.> sj) / (aasj <.> aasj)
      xNext = x ^+^ (alphaj .* p) ^+^ (omegaj .* sj)
      rNext = sj ^-^ (omegaj .* aasj)
      betaj = (rNext <.> r0hat)/(r <.> r0hat) * alphaj / omegaj
      pNext = rNext ^+^ (betaj .* (p ^-^ (omegaj .* aap)))
  in BICGSTAB xNext rNext pNext
-- | One labelled line per field.
instance Show a => Show (BICGSTAB a) where
  show (BICGSTAB x r p) =
    concat
      [ "x = ", show x, "\n"
      , "r = ", show r, "\n"
      , "p = ", show p, "\n"
      ]
-- | Solve a system through the expression @aa #~^# aa <\> atb@, where
-- @atb = transpose aa #> b@.
pinv :: (MatrixType v ~ SpMatrix a, LinearSystem v, Epsilon a,
         MonadThrow m, MonadIO m) =>
     SpMatrix a -> v -> m v
pinv aa b =
  let atb = transpose aa #> b
  in aa #~^# aa <\> atb
-- | Solve @aa x = b@ with the selected 'LinSolveMethod', starting from the
-- guess @x0@.
--
-- * Throws 'MatVecSizeMismatchException' when the first dimension of @aa@
--   differs from @dim b@.
-- * A diagonal @aa@ is solved directly as @reciprocal aa #> b@.
-- * Otherwise the chosen method runs: 'GMRES_' calls 'gmres' (which ignores
--   @x0@); the other methods iterate their step function through
--   'untilConvergedG' with an 'IterConf' whose first field is 200.
linSolve0 method aa b x0
  | m /= nb = throwM (MatVecSizeMismatchException "linSolve0" dm nb)
  | isDiagonalSM aa = return $ reciprocal aa #> b
  | otherwise =
      case method of
        BICGSTAB_ -> solver "BICGSTAB" nits _xBicgstab (bicgstabStep aa r0hat) (bicgsInit aa b x0)
        BCG_ -> solver "BCG" nits _xBcg (bcgStep aa) (bcgInit aa b x0)
        CGS_ -> solver "CGS" nits _x (cgsStep aa r0hat) (cgsInit aa b x0)
        GMRES_ -> gmres aa b
        CGNE_ -> solver "CGNE" nits _xCgne (cgneStep aa) (cgneInit aa b x0)
  where
    dm@(m, _) = dim aa
    nb = dim b
    nits = 200
    -- Initial residual, used as the fixed vector by BICGSTAB and CGS.
    r0hat = b ^-^ (aa #> x0)
    -- Iterate @stepf@ from @initf@ and project the final state.
    solver fname nitermax fproj stepf initf =
      fproj <$> untilConvergedG fname config (const True) stepf initf
      where
        config = IterConf nitermax True fproj prd
-- | Selector for the algorithm used by 'linSolve0'.
data LinSolveMethod
  = GMRES_
  | CGNE_
  | BCG_
  | CGS_
  | BICGSTAB_
  deriving (Eq, Show)
-- | Real-valued systems are solved with 'GMRES_', from an initial guess
-- built with 'mkSpVR' out of @ncols aa@ copies of 0.1.
instance LinearSystem (SpVector Double) where
  aa <\> b =
    let n = ncols aa
        x0 = mkSpVR n $ replicate n 0.1
    in linSolve0 GMRES_ aa b x0
-- | Complex-valued systems are solved with 'GMRES_', from an initial guess
-- built with 'mkSpVC' out of @ncols aa@ copies of 0.1.
instance LinearSystem (SpVector (Complex Double)) where
  aa <\> b =
    let n = ncols aa
        x0 = mkSpVC n $ replicate n 0.1
    in linSolve0 GMRES_ aa b x0
-- | Small fixed example matrix, built by 'fromListDenseSM' with size
-- argument 3 from nine values.
aa4 :: SpMatrix Double
aa4 =
  fromListDenseSM 3
    [ 3, 2, 2
    , 2, 2, 1
    , 6, 5, 4
    ]
-- | 'aa4' with every entry mapped through @toC@.
aa4c :: SpMatrix (Complex Double)
aa4c = fmap toC aa4