module Numeric.Matrix (
Matrix,
MatrixElement (..),
(<|>),
(<->),
scale,
isUnit,
isZero,
isDiagonal,
isEmpty,
isSquare
) where
import Control.Applicative ((<$>))
import Control.DeepSeq
import Control.Monad
import Control.Monad.ST
import Data.Function (on)
import Data.Ratio
import Data.Complex
import Data.Maybe
import qualified Data.List as L
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unboxed
import Data.Array.ST
import qualified Data.Array.Unsafe as U
import Data.STRef
import Data.Typeable
import Prelude hiding (any, all, read, map)
import qualified Prelude as P
-- | A dense matrix whose representation depends on the element type.
--
-- Every instance stores the number of rows, the number of columns, and an
-- array of rows (each row is itself an array).  The row arrays built by
-- '_fromList' and '_matrix' are indexed from 1.  'Int', 'Float' and 'Double'
-- rows are unboxed ('UArray'); 'Integer', 'Ratio' and 'Complex' rows are boxed.
--
-- NOTE(review): a data family needs the TypeFamilies extension, which is not
-- visible in this chunk — presumably enabled by a pragma or the build flags.
data family Matrix e
-- Rows, columns, unboxed rows.
data instance Matrix Int
= IntMatrix !Int !Int (Array Int (UArray Int Int))
-- Rows, columns, unboxed rows.
data instance Matrix Float
= FloatMatrix !Int !Int (Array Int (UArray Int Float))
-- Rows, columns, unboxed rows.
data instance Matrix Double
= DoubleMatrix !Int !Int (Array Int (UArray Int Double))
-- Rows, columns, boxed rows ('Integer' has no unboxed array instance).
data instance Matrix Integer
= IntegerMatrix !Int !Int (Array Int (Array Int Integer))
-- Rows, columns, boxed rows.
data instance Matrix (Ratio a)
= RatioMatrix !Int !Int (Array Int (Array Int (Ratio a)))
-- Rows, columns, boxed rows.
data instance Matrix (Complex a)
= ComplexMatrix !Int !Int (Array Int (Array Int (Complex a)))
-- | Hand-written 'Typeable' instance.  It reports the type constructor
-- @Matrix@ from module @Numeric.Matrix@ in package @bed-and-breakfast@,
-- applied to the element type.
--
-- NOTE(review): recent GHC versions derive 'Typeable' for every type
-- automatically and reject user-written instances.  Newer @base@ releases may
-- also no longer export 'mkTyCon3' / 'mkTyConApp' from "Data.Typeable".
-- Confirm against the targeted compiler; this instance can most likely be
-- deleted (or guarded with CPP for old compilers).
instance Typeable a => Typeable (Matrix a) where
typeOf x = mkTyConApp (mkTyCon3 "bed-and-breakfast"
"Numeric.Matrix"
"Matrix") [typeOf (unT x)]
where
-- 'unT' is 'undefined'.  It only names the element type, relying on
-- 'typeOf' not forcing its argument.
unT :: Matrix a -> a
unT = undefined
-- | Renders one line per row.  Every entry is 'show'n with a leading space,
-- and entries are separated by one more space.
instance (MatrixElement e, Show e) => Show (Matrix e) where
  show m = unlines [ unwords [ ' ' : show x | x <- r ] | r <- toList m ]
-- | Parses a whole-input matrix: one row per line, entries separated by
-- whitespace.  The entire input is consumed.
--
-- Malformed entries make the parse fail (an empty result), so 'reads' and
-- 'Text.Read.readMaybe' report "no parse".  Previously each entry went
-- through the partial 'P.read', which threw an exception instead.
-- Entries whose 'Show' form contains spaces (e.g. @1 % 2@) cannot be read in
-- this format.
instance (Read e, MatrixElement e) => Read (Matrix e) where
  readsPrec _ input =
    case mapM parseRow (lines input) of
      Just rows -> [(fromList rows, "")]
      Nothing -> []
    where
      parseRow line = mapM parseCell (words line)
      -- A cell must be consumed completely by a single parse.
      parseCell w = case reads w of
        [(x, "")] -> Just x
        _ -> Nothing
-- | Element-wise addition and subtraction, and matrix multiplication.
--
-- 'fromInteger', 'signum' and 'fromRational' produce 1x1 matrices.  'signum'
-- is the sign of the determinant.  'negate' and 'abs' act entry by entry.
instance (MatrixElement e) => Num (Matrix e) where
  (+) = plus
  (-) = minus
  (*) = times
  -- Defined explicitly: the class default @fromInteger 0 - m@ would subtract
  -- m from a 1x1 matrix and fail the dimension check for any other shape.
  negate = map negate
  abs = map abs
  signum = matrix (1, 1) . const . signum . det
  fromInteger = matrix (1, 1) . const . fromInteger
-- | 'recip' is the matrix inverse.  'fromRational' builds a 1x1 matrix.
instance (MatrixElement e, Fractional e) => Fractional (Matrix e) where
  -- 'inv' returns Nothing for singular or non-square matrices and for element
  -- types that do not override it.  Fail with a descriptive message instead
  -- of the bare 'fromJust' error.
  recip m = fromMaybe
    (error "Matrix.recip: matrix is not invertible (or inversion is unsupported for this element type).")
    (inv m)
  fromRational = matrix (1, 1) . const . fromRational
-- | Matrices are equal when they have the same dimensions and agree at every
-- index.
instance (MatrixElement e) => Eq (Matrix e) where
  m == n = dimensions m == dimensions n && allWithIndex sameAt n
    where
      sameAt ix e = m `at` ix == e
-- | Forces every entry of the matrix to weak head normal form.
--
-- The old definition, @rnf matrix = matrix `deepseq` ()@, called 'rnf' on the
-- same value again (@deepseq a b = rnf a `seq` b@) and never terminated.
--
-- NOTE(review): WHNF is full normal form for Int, Integer, Float and Double.
-- 'Ratio' and 'Complex' have strict components in base, so the same holds
-- for them when their component type is one of the above.  Revisit this if
-- other element types are added.
instance (MatrixElement e) => NFData (Matrix e) where
  rnf m = L.foldl' (flip seq) () (concat (toList m))
-- | Places two matrices side by side: the columns of @m1@ followed by the
-- columns of @m2@.  If the row counts differ, the shorter matrix is padded
-- with zeros at the bottom.
(<|>) :: MatrixElement e => Matrix e -> Matrix e -> Matrix e
m1 <|> m2 = matrix (max n1 n2, c1 + numCols m2) entry
  where
    c1 = numCols m1
    n1 = numRows m1
    n2 = numRows m2
    entry (i, j)
      -- Columns past c1 belong to m2, shifted back by c1.
      | j > c1 = if i > n2 then 0 else m2 `at` (i, j - c1)
      | otherwise = if i > n1 then 0 else m1 `at` (i, j)
-- | Stacks @m1@ on top of @m2@.  If the column counts differ, the narrower
-- matrix is padded with zeros on the right.
(<->) :: MatrixElement e => Matrix e -> Matrix e -> Matrix e
m1 <-> m2 = matrix (r1 + numRows m2, max n1 n2) entry
  where
    r1 = numRows m1
    n1 = numCols m1
    n2 = numCols m2
    entry (i, j)
      -- Rows past r1 belong to m2, shifted back by r1.
      | i > r1 = if j > n2 then 0 else m2 `at` (i - r1, j)
      | otherwise = if j > n1 then 0 else m1 `at` (i, j)
-- | Multiplies every entry by the scalar @s@.  The scalar is applied on the
-- right: each entry becomes @e * s@.
scale :: MatrixElement e => Matrix e -> e -> Matrix e
scale m s = map (\e -> e * s) m
-- | Shape and content predicates.  Indices are (row, column), 1-based.
isUnit, isDiagonal, isZero, isEmpty, isSquare :: MatrixElement e => Matrix e -> Bool
-- Every entry is zero; vacuously true when there are no entries.
isZero = all (== 0)
-- Square, with ones on the diagonal and zeros everywhere else.
isUnit m = isSquare m && allWithIndex unitEntry m
  where
    unitEntry (i, j) e
      | i == j = e == 1
      | otherwise = e == 0
-- No rows or no columns.
isEmpty m = numRows m == 0 || numCols m == 0
-- Square, with every off-diagonal entry equal to zero.
isDiagonal m = isSquare m && allWithIndex offDiagonalZero m
  where
    offDiagonalZero (i, j) e = i == j || e == 0
-- As many rows as columns.
isSquare m = uncurry (==) (dimensions m)
-- | The division used by the fraction-free elimination in '_det' and
-- '_rank'.  Integral types truncate toward zero ('quot'); fractional types
-- use true division.
class Division e where
  divide :: e -> e -> e

instance Division Int where
  divide x y = x `quot` y

instance Division Integer where
  divide x y = x `quot` y

instance Division Float where
  divide x y = x / y

instance Division Double where
  divide x y = x / y

instance Integral a => Division (Ratio a) where
  divide x y = x / y

instance RealFloat a => Division (Complex a) where
  divide x y = x / y
-- | Element types that can be stored in a 'Matrix', together with the matrix
-- operations available for them.
--
-- All indices are 1-based (row, column): entry @(1, 1)@ is the top-left
-- corner.  'matrix', 'fromList', 'toList', 'det' and 'rank' have no default
-- and must be given by every instance.  The other methods default to
-- definitions in terms of those, and instances may override them.
class (Eq e, Num e) => MatrixElement e where
  -- | @matrix (r, c) f@ builds an r-by-c matrix whose entry @(i, j)@ is @f (i, j)@.
  matrix :: (Int, Int) -> ((Int, Int) -> e) -> Matrix e
  -- | Entries whose index satisfies the predicate, in row-major order.
  select :: ((Int, Int) -> Bool) -> Matrix e -> [e]
  -- | Entry at a 1-based (row, column) index.
  at :: Matrix e -> (Int, Int) -> e
  -- | Row @i@ (1-based).
  row :: Int -> Matrix e -> [e]
  -- | Column @j@ (1-based).
  col :: Int -> Matrix e -> [e]
  -- | (number of rows, number of columns).
  dimensions :: Matrix e -> (Int, Int)
  numRows :: Matrix e -> Int
  numCols :: Matrix e -> Int
  -- | Builds a matrix from a list of rows.
  fromList :: [[e]] -> Matrix e
  -- | The rows, top to bottom.
  toList :: Matrix e -> [[e]]
  -- | The n-by-n identity matrix.
  unit :: Int -> Matrix e
  -- | The n-by-n zero matrix.
  zero :: Int -> Matrix e
  -- | A square matrix with the given diagonal and zeros elsewhere.
  diag :: [e] -> Matrix e
  -- | The 0-by-0 matrix.
  empty :: Matrix e
  -- | Element-wise difference; the dimensions must match.
  minus :: Matrix e -> Matrix e -> Matrix e
  -- | Element-wise sum; the dimensions must match.
  plus :: Matrix e -> Matrix e -> Matrix e
  -- | Matrix product; columns of the left operand must equal rows of the right.
  times :: Matrix e -> Matrix e -> Matrix e
  -- | The inverse, or 'Nothing'.  The default always returns 'Nothing'.
  inv :: Matrix e -> Maybe (Matrix e)
  det :: Matrix e -> e
  transpose :: Matrix e -> Matrix e
  rank :: Matrix e -> e
  -- | The main-diagonal entries.  Note: the list of entries, not their sum.
  trace :: Matrix e -> [e]
  -- | Determinant of the matrix with the given row and column removed.
  minor :: Matrix e -> (Int, Int) -> e
  cofactors :: Matrix e -> Matrix e
  adjugate :: Matrix e -> Matrix e
  -- | The matrix with the given row and column removed.
  minorMatrix :: Matrix e -> (Int, Int) -> Matrix e
  map :: MatrixElement f => (e -> f) -> Matrix e -> Matrix f
  all :: (e -> Bool) -> Matrix e -> Bool
  any :: (e -> Bool) -> Matrix e -> Bool
  mapWithIndex :: MatrixElement f => ((Int, Int) -> e -> f) -> Matrix e -> Matrix f
  allWithIndex :: ((Int, Int) -> e -> Bool) -> Matrix e -> Bool
  anyWithIndex :: ((Int, Int) -> e -> Bool) -> Matrix e -> Bool

  -- Built with 'matrix' rather than 'fromList', so n = 0 also works.
  unit n = matrix (n, n) (\(i, j) -> if i == j then 1 else 0)
  zero n = matrix (n, n) (const 0)
  -- 'fromList []' cannot be used here: '_fromList' fails on an empty list.
  empty = matrix (0, 0) (const 0)
  diag xs = matrix (n, n) (\(i, j) -> if i == j then xs !! (i - 1) else 0)
    where
      n = length xs
  select p m =
    [ m `at` (i, j) | i <- [1 .. numRows m], j <- [1 .. numCols m], p (i, j) ]
  -- Indices are 1-based, but '!!' is 0-based.
  at m (i, j) = toList m !! (i - 1) !! (j - 1)
  row i m = toList m !! (i - 1)
  col j m = row j (transpose m)
  numRows = fst . dimensions
  numCols = snd . dimensions
  dimensions m = case toList m of
    [] -> (0, 0)
    (x : xs) -> (length xs + 1, length x)
  adjugate = transpose . cofactors
  -- An r-by-c matrix transposes to a c-by-r matrix.
  transpose m = matrix (c, r) (\(i, j) -> m `at` (j, i))
    where
      (r, c) = dimensions m
  trace = select (uncurry (==))
  inv _ = Nothing
  minorMatrix m (i, j) = matrix (numRows m - 1, numCols m - 1) $ \(i', j') ->
    m `at` (skip i i', skip j j')
    where
      -- Indices at or past the removed row/column come from one further on.
      skip removed k = if k >= removed then k + 1 else k
  minor m = det . minorMatrix m
  -- Cofactor sign is (-1)^(i+j).
  cofactors m = matrix (dimensions m) $ \(i, j) ->
    (if even (i + j) then 1 else negate 1) * minor m (i, j)
  map f = mapWithIndex (const f)
  all f = allWithIndex (const f)
  any f = anyWithIndex (const f)
  mapWithIndex f m = matrix (dimensions m) (\ix -> f ix (m `at` ix))
  allWithIndex f m =
    and [ f (i, j) (m `at` (i, j)) | i <- [1 .. numRows m], j <- [1 .. numCols m] ]
  anyWithIndex f m =
    or [ f (i, j) (m `at` (i, j)) | i <- [1 .. numRows m], j <- [1 .. numCols m] ]
  a `plus` b
    | dimensions a /= dimensions b = error "Matrix.plus: dimensions don't match."
    | otherwise = matrix (dimensions a) (\x -> a `at` x + b `at` x)
  a `minus` b
    | dimensions a /= dimensions b = error "Matrix.minus: dimensions don't match."
    | otherwise = matrix (dimensions a) (\x -> a `at` x - b `at` x)
  a `times` b
    | numCols a /= numRows b = error "Matrix.times: `numCols a' and `numRows b' don't match."
    | otherwise = _mult a b
-- | 'Int' matrices keep each row in an unboxed array.  'inv' is not
-- overridden, so the class default applies.
instance MatrixElement Int where
  dimensions (IntMatrix r c _) = (r, c)
  matrix dims gen = runST (_matrix IntMatrix arrayST arraySTU dims gen)
  fromList = _fromList IntMatrix
  toList (IntMatrix _ _ rows) = _toList rows
  at (IntMatrix _ _ rows) = _at rows
  row i (IntMatrix _ _ rows) = _row i rows
  col j (IntMatrix _ _ rows) = _col j rows
  -- A non-square matrix reports determinant 0 rather than failing.
  det (IntMatrix r c rows)
    | r /= c = 0
    | otherwise = runST (_det thawsUnboxed rows)
  rank (IntMatrix _ _ rows) = runST (_rank thawsBoxed rows)
-- | 'Integer' matrices keep each row in a boxed array.  'inv' is not
-- overridden, so the class default applies.
instance MatrixElement Integer where
  dimensions (IntegerMatrix r c _) = (r, c)
  matrix dims gen = runST (_matrix IntegerMatrix arrayST arrayST dims gen)
  fromList = _fromList IntegerMatrix
  toList (IntegerMatrix _ _ rows) = _toList rows
  at (IntegerMatrix _ _ rows) = _at rows
  row i (IntegerMatrix _ _ rows) = _row i rows
  col j (IntegerMatrix _ _ rows) = _col j rows
  -- A non-square matrix reports determinant 0 rather than failing.
  det (IntegerMatrix r c rows)
    | r /= c = 0
    | otherwise = runST (_det thawsBoxed rows)
  rank (IntegerMatrix _ _ rows) = runST (_rank thawsBoxed rows)
-- | 'Float' matrices keep each row in an unboxed array and support 'inv'.
instance MatrixElement Float where
  dimensions (FloatMatrix r c _) = (r, c)
  matrix dims gen = runST (_matrix FloatMatrix arrayST arraySTU dims gen)
  fromList = _fromList FloatMatrix
  toList (FloatMatrix _ _ rows) = _toList rows
  at (FloatMatrix _ _ rows) = _at rows
  row i (FloatMatrix _ _ rows) = _row i rows
  col j (FloatMatrix _ _ rows) = _col j rows
  -- A non-square matrix reports determinant 0 rather than failing.
  det (FloatMatrix r c rows)
    | r /= c = 0
    | otherwise = runST (_det thawsUnboxed rows)
  rank (FloatMatrix _ _ rows) = runST (_rank thawsBoxed rows)
  -- Only square matrices are passed to '_inv'.
  inv (FloatMatrix r c rows)
    | r /= c = Nothing
    | otherwise = FloatMatrix r c <$> runST (_inv unboxedST rows)
-- | 'Double' matrices keep each row in an unboxed array and support 'inv'.
instance MatrixElement Double where
  dimensions (DoubleMatrix r c _) = (r, c)
  matrix dims gen = runST (_matrix DoubleMatrix arrayST arraySTU dims gen)
  fromList = _fromList DoubleMatrix
  toList (DoubleMatrix _ _ rows) = _toList rows
  at (DoubleMatrix _ _ rows) = _at rows
  row i (DoubleMatrix _ _ rows) = _row i rows
  col j (DoubleMatrix _ _ rows) = _col j rows
  -- Only square matrices are passed to '_inv'.
  inv (DoubleMatrix r c rows)
    | r /= c = Nothing
    | otherwise = DoubleMatrix r c <$> runST (_inv unboxedST rows)
  -- A non-square matrix reports determinant 0 rather than failing.
  det (DoubleMatrix r c rows)
    | r /= c = 0
    | otherwise = runST (_det thawsUnboxed rows)
  rank (DoubleMatrix _ _ rows) = runST (_rank thawsBoxed rows)
-- | Exact rational matrices.  Rows are boxed arrays; 'inv' is supported.
instance (Show a, Integral a) => MatrixElement (Ratio a) where
  dimensions (RatioMatrix r c _) = (r, c)
  matrix dims gen = runST (_matrix RatioMatrix arrayST arrayST dims gen)
  fromList = _fromList RatioMatrix
  toList (RatioMatrix _ _ rows) = _toList rows
  at (RatioMatrix _ _ rows) = _at rows
  row i (RatioMatrix _ _ rows) = _row i rows
  col j (RatioMatrix _ _ rows) = _col j rows
  -- Only square matrices are passed to '_inv'.
  inv (RatioMatrix r c rows)
    | r /= c = Nothing
    | otherwise = RatioMatrix r c <$> runST (_inv boxedST rows)
  -- A non-square matrix reports determinant 0 rather than failing.
  det (RatioMatrix r c rows)
    | r /= c = 0
    | otherwise = runST (_det thawsBoxed rows)
  rank (RatioMatrix _ _ rows) = runST (_rank thawsBoxed rows)
-- | Complex matrices.  Rows are boxed arrays.  'inv' is not overridden
-- ('_inv' needs 'Ord', which 'Complex' lacks), so the class default applies.
instance (Show a, RealFloat a) => MatrixElement (Complex a) where
  dimensions (ComplexMatrix r c _) = (r, c)
  matrix dims gen = runST (_matrix ComplexMatrix arrayST arrayST dims gen)
  fromList = _fromList ComplexMatrix
  toList (ComplexMatrix _ _ rows) = _toList rows
  at (ComplexMatrix _ _ rows) = _at rows
  row i (ComplexMatrix _ _ rows) = _row i rows
  col j (ComplexMatrix _ _ rows) = _col j rows
  -- A non-square matrix reports determinant 0 rather than failing.
  det (ComplexMatrix r c rows)
    | r /= c = 0
    | otherwise = runST (_det thawsBoxed rows)
  rank (ComplexMatrix _ _ rows) = runST (_rank thawsBoxed rows)
-- | Entry @(i, j)@ of an array of rows: select row @i@, then element @j@.
_at :: (IArray a (u Int e), IArray u e)
    => a Int (u Int e) -> (Int, Int) -> e
_at rows (i, j) =
  let r = rows ! i
  in r ! j
-- | Row @i@ and column @j@ of an array of rows, as lists.
--
-- '_row' reads the whole selected row.  The old version iterated over the
-- number of rows, which truncated or overran rows of a non-square matrix.
-- '_col' picks element @j@ from every row, top to bottom.
_row, _col :: (IArray a (u Int e), IArray u e) => Int -> a Int (u Int e) -> [e]
_row i arr = elems (arr ! i)
_col j arr = [ arr ! i ! j | i <- [1 .. snd (bounds arr)] ]
-- | All rows as lists, in index order.
_toList :: (IArray a e) => Array Int (a Int e) -> [[e]]
_toList rows = [ elems r | r <- elems rows ]
-- | Builds a matrix from a list of rows using the constructor @c@, which
-- receives the row count, the column count and the 1-based array of rows.
--
-- The column count is the length of the shortest row.  Longer rows are
-- truncated to that length.  An empty list gives a 0x0 matrix; the old
-- version crashed on it with @foldl1@.
_fromList :: (IArray a (u Int e), IArray u e)
          => (Int -> Int -> a Int (u Int e) -> matrix e) -> [[e]] -> matrix e
_fromList c xs =
  c rowCount colCount
    $ array (1, rowCount)
    $ zip [1 .. rowCount]
    $ P.map (array (1, colCount) . zip [1 .. colCount]) xs
  where
    lengths = P.map length xs
    rowCount = length lengths
    colCount = if null lengths then 0 else minimum lengths
-- | Copies every row into a fresh boxed mutable array, in row order.
thawsBoxed :: (IArray a e, MArray (STArray s) e (ST s))
           => Array Int (a Int e)
           -> ST s [STArray s Int e]
thawsBoxed rows = forM (elems rows) thaw
-- | Copies every row into a fresh unboxed mutable array, in row order.
thawsUnboxed :: (IArray a e, MArray (STUArray s) e (ST s))
             => Array Int (a Int e)
             -> ST s [STUArray s Int e]
thawsUnboxed rows = forM (elems rows) thaw
-- | Collects a list of mutable rows into a boxed mutable array indexed from 1.
arrays :: [(u s) Int e]
       -> ST s ((STArray s) Int ((u s) Int e))
arrays rows = newListArray (1, count) rows
  where
    count = length rows
-- | Builds the augmented matrix @[A | I]@ as mutable rows of length @2n@,
-- where @n@ is the number of rows of @A@.
--
-- The first argument is never called; its type only selects the mutable row
-- representation @u@.  Only the first @n@ entries of each row are read, so
-- @A@ is presumably square.  NOTE(review): the visible 'inv' instances only
-- reach this through '_inv' after checking that the matrix is square.
augment :: (IArray a e, MArray (u s) e (ST s), Num e)
        => ((Int, Int) -> [e] -> ST s ((u s) Int e))
        -> Array Int (a Int e)
        -> ST s (STArray s Int (u s Int e))
augment _ arr = do
  let (_, n) = bounds arr
      identityPart i = [ if c == i then 1 else 0 | c <- [1 .. n] ]
      extend (r, i) = newListArray (1, 2 * n) ([ r ! c | c <- [1 .. n] ] ++ identityPart i)
  extended <- mapM extend (zip (elems arr) [1 ..])
  newListArray (1, n) extended
-- | Boxed mutable array with the given bounds, filled from a list.
boxedST :: MArray (STArray s) e (ST s)
        => (Int, Int) -> [e] -> ST s ((STArray s) Int e)
boxedST bnds xs = newListArray bnds xs

-- | Unboxed mutable array with the given bounds, filled from a list.
unboxedST :: MArray (STUArray s) e (ST s)
          => (Int, Int) -> [e] -> ST s ((STUArray s) Int e)
unboxedST bnds xs = newListArray bnds xs

-- | Boxed mutable array with every element set to the given value.
arrayST :: MArray (STArray s) e (ST s)
        => (Int, Int) -> e -> ST s ((STArray s) Int e)
arrayST bnds x = newArray bnds x

-- | Unboxed mutable array with every element set to the given value.
arraySTU :: MArray (STUArray s) e (ST s)
         => (Int, Int) -> e -> ST s ((STUArray s) Int e)
arraySTU bnds x = newArray bnds x
-- | Runs @f x@ for its effect, discards its result, and returns @x@.
tee :: Monad m => (b -> m a) -> b -> m b
tee f x = do
  _ <- f x
  return x
-- | Reads entry @(i, j)@ from a mutable array of mutable rows.
read :: (MArray a1 b m, MArray a (a1 Int b) m) =>
        a Int (a1 Int b) -> Int -> Int -> m b
read a i j = do
  r <- readArray a i
  readArray r j
-- | Inverts a square matrix by Gauss–Jordan elimination with partial pivoting
-- on the augmented matrix @[A | I]@.
--
-- Returns 'Nothing' if a zero pivot is met during forward elimination
-- (the matrix is singular).  The first argument is never called; it only
-- fixes the mutable row type @u@ (see 'augment').  The size is taken from
-- the number of rows, so callers must pass a square matrix.
--
-- The subtraction operators in both elimination passes and in the
-- back-substitution range had been lost; they are restored here.
_inv :: (IArray a e, MArray (u s) e (ST s), Fractional e, Ord e, Show e)
     => ((Int, Int) -> [e] -> ST s ((u s) Int e))
     -> Array Int (a Int e)
     -> ST s (Maybe (Array Int (a Int e)))
_inv mkArrayST mat = do
  let m = snd $ bounds mat
      n = 2 * m
      swap arr i j = do
        tmp <- readArray arr i
        readArray arr j >>= writeArray arr i
        writeArray arr j tmp
  okay <- newSTRef True
  a <- augment mkArrayST mat
  -- Forward elimination: clear every entry below the diagonal.
  forM_ [1 .. m] $ \k -> do
    -- Partial pivoting: use the row at or below k with the largest |a_ik|.
    iPivot <- fst . L.maximumBy (compare `on` snd) . zip [k .. m]
                <$> mapM (\i -> abs <$> read a i k) [k .. m]
    p <- read a iPivot k
    if p == 0
      then writeSTRef okay False
      else do
        swap a iPivot k
        forM_ [k + 1 .. m] $ \i -> do
          a_i <- readArray a i
          a_k <- readArray a k
          forM_ [k + 1 .. n] $ \j -> do
            a_ij <- readArray a_i j
            a_kj <- readArray a_k j
            a_ik <- readArray a_i k
            writeArray a_i j (a_ij - a_kj * (a_ik / p))
          writeArray a_i k 0
  invertible <- readSTRef okay
  if not invertible
    then return Nothing
    else do
      -- Back substitution from the bottom row up: normalise row i, then
      -- clear its entries to the right of the diagonal with the rows below.
      forM_ [m, m - 1 .. 1] $ \i -> do
        a_i <- readArray a i
        p <- readArray a_i i
        writeArray a_i i 1
        forM_ [i + 1 .. n] $ \j ->
          readArray a_i j >>= writeArray a_i j . (/ p)
        forM_ [i + 1 .. m] $ \k -> do
          a_k <- readArray a k
          q <- readArray a_i k
          forM_ [k .. n] $ \j -> do
            a_ij <- readArray a_i j
            a_kj <- readArray a_k j
            writeArray a_i j (a_ij - q * a_kj)
      -- The right half of [I | A^-1] is the inverse.
      rows <- forM [1 .. m] $ \i ->
        listArray (1, m) . drop m <$> (readArray a i >>= getElems)
      return (Just (listArray (1, m) rows))
-- | Rank of a matrix, computed by fraction-free (Bareiss-style) row reduction.
--
-- Columns are scanned left to right.  When a column has a non-zero entry at
-- or below the current pivot row, that row is swapped into the pivot row.
-- The rows beneath it are then reduced, and the pivot row advances.  The
-- rank is the number of pivots found.  Each update divides by the previous
-- pivot with 'divide'.  NOTE(review): this relies on the Bareiss property
-- that the division is exact, which is what lets integral types use 'quot'.
--
-- Fixes in this version:
--
-- * the lost subtraction in the update is restored;
-- * the pivot is read from row @pivotRow@, where the swap put it, not from
--   row @k@ (the two differ once a column without a pivot has been skipped);
-- * a matrix without rows has rank 0 instead of failing on @mat ! 1@.
_rank :: (IArray a e, MArray (u s) e (ST s), Num e, Division e, Eq e)
      => (Array Int (a Int e) -> ST s [(u s) Int e])
      -> Array Int (a Int e)
      -> ST s e
_rank thaws mat = do
  let m = snd $ bounds mat
      n = if m < 1 then 0 else snd $ bounds (mat ! 1)
      swap arr i j = do
        tmp <- readArray arr i
        readArray arr j >>= writeArray arr i
        writeArray arr j tmp
  a <- thaws mat >>= arrays
  ixPivot <- newSTRef 1
  prevR <- newSTRef 1
  forM_ [1 .. n] $ \k -> do
    pivotRow <- readSTRef ixPivot
    switchRow <- L.findIndex (/= 0) <$> mapM (\i -> read a i k) [pivotRow .. m]
    case switchRow of
      Nothing -> return ()
      Just offset -> do
        let ix = offset + pivotRow
        when (pivotRow /= ix) (swap a pivotRow ix)
        a_p <- readArray a pivotRow
        pivot <- readArray a_p k
        prev <- readSTRef prevR
        forM_ [pivotRow + 1 .. m] $ \i -> do
          a_i <- readArray a i
          forM_ [k + 1 .. n] $ \j -> do
            a_ij <- readArray a_i j
            a_ik <- readArray a_i k
            a_pj <- readArray a_p j
            writeArray a_i j ((pivot * a_ij - a_ik * a_pj) `divide` prev)
        writeSTRef ixPivot (pivotRow + 1)
        writeSTRef prevR pivot
  -- ixPivot is one past the last pivot row, i.e. rank + 1.
  subtract 1 . fromIntegral <$> readSTRef ixPivot
-- | Determinant by fraction-free (Bareiss) elimination.
--
-- At step k a zero pivot is replaced by swapping in the first row below with
-- a non-zero entry in column k, and each swap flips the sign.  If no such row
-- exists, the determinant is 0 and no further work is done.  The result is
-- the last pivot times the accumulated sign.  An empty matrix gives 1.
-- The size is taken from the number of rows, so callers must pass a square
-- matrix.
--
-- The subtraction in the Bareiss update had been lost and is restored here.
_det :: (IArray a e, MArray (u s) e (ST s),
         Num e, Eq e, Division e)
     => (Array Int (a Int e) -> ST s [(u s) Int e])
     -> Array Int (a Int e) -> ST s e
_det thaws mat = do
  let size = snd $ bounds mat
  a <- thaws mat >>= arrays
  signR <- newSTRef 1
  pivotR <- newSTRef 1
  forM_ [1 .. size] $ \k -> do
    sign <- readSTRef signR
    unless (sign == 0) $ do
      prev <- readSTRef pivotR
      pivot <- read a k k >>= tee (writeSTRef pivotR)
      when (pivot == 0) $ do
        -- Rows below k that have a non-zero entry in column k.
        candidates <- filterM (\r -> (/= 0) <$> read a r k) [k + 1 .. size]
        case candidates of
          sw : _ -> do
            rowK <- readArray a k
            readArray a sw >>= writeArray a k
            writeArray a sw rowK
            read a k k >>= writeSTRef pivotR
            readSTRef signR >>= writeSTRef signR . negate
          [] -> writeSTRef signR 0
      sign' <- readSTRef signR
      unless (sign' == 0) $ do
        pivot' <- readSTRef pivotR
        forM_ [k + 1 .. size] $ \i -> do
          a_i <- readArray a i
          forM_ [k + 1 .. size] $ \j -> do
            a_ij <- readArray a_i j
            a_ik <- readArray a_i k
            a_kj <- read a k j
            writeArray a_i j ((pivot' * a_ij - a_ik * a_kj) `divide` prev)
  liftM2 (*) (readSTRef pivotR) (readSTRef signR)
-- | Naive matrix product.  Entry @(i, j)@ is the strict left fold of
-- @a(i,k) * b(k,j)@ over @k@, starting from 0.  Dimension checking is left
-- to the caller.
_mult :: MatrixElement e => Matrix e -> Matrix e -> Matrix e
_mult a b = matrix (numRows a, numCols b) entry
  where
    inner = numRows b
    entry (i, j) =
      L.foldl' (\acc k -> acc + a `at` (i, k) * b `at` (k, j)) 0 [1 .. inner]
-- | Builds an m-by-n matrix by evaluating @g@ at every 1-based index
-- @(i, j)@, then wraps the frozen rows with the constructor @c@.
--
-- @mkRows@ allocates the mutable array of rows; its initial element is never
-- read, because every slot is overwritten.  @mkRow@ allocates one
-- zero-filled mutable row.  Each row, and finally the row array, is frozen
-- with 'U.unsafeFreeze'.  This is safe here because nothing writes to a
-- mutable array after it has been frozen.
_matrix :: (IArray a1 (u Int e), IArray u e,
            MArray a2 (u Int e) (ST s), MArray a3 e (ST s),
            Num e)
        => (Int -> Int -> a1 Int (u Int e) -> matrix)
        -> ((Int, Int) -> a -> ST s (a2 Int (u Int e)))
        -> ((Int, Int) -> e -> ST s (a3 Int e))
        -> (Int, Int)
        -> ((Int, Int) -> e)
        -> ST s matrix
_matrix c mkRows mkRow (m, n) g = do
  rows <- mkRows (1, m) undefined
  forM_ [1 .. m] $ \i -> do
    cells <- mkRow (1, n) 0
    forM_ [1 .. n] $ \j -> writeArray cells j (g (i, j))
    frozen <- U.unsafeFreeze cells
    writeArray rows i frozen
  c m n <$> U.unsafeFreeze rows