{-# LANGUAGE DeriveGeneric #-}
-- | Matrix datatype and operations.
--
--   Every provided example has been tested.
--   Run @cabal test@ for further tests.
module Data.Matrix (
    -- * Matrix type
    Matrix , prettyMatrix
  , nrows , ncols
  , forceMatrix
    -- * Builders
  , matrix
  , rowVector
  , colVector
    -- ** Special matrices
  , zero
  , identity
  , diagonalList
  , diagonal
  , permMatrix
    -- * List conversions
  , fromList , fromLists
  , toList   , toLists
    -- * Accessing
  , getElem , (!) , unsafeGet , safeGet, safeSet
  , getRow  , safeGetRow , getCol , safeGetCol
  , getDiag
  , getMatrixAsVector
    -- * Manipulating matrices
  , setElem
  , unsafeSet
  , transpose , setSize , extendTo
  , inverse, rref
  , mapRow , mapCol, mapPos
    -- * Submatrices
    -- ** Splitting blocks
  , submatrix
  , minorMatrix
  , splitBlocks
   -- ** Joining blocks
  , (<|>) , (<->)
  , joinBlocks
    -- * Matrix operations
  , elementwise, elementwiseUnsafe
    -- * Matrix multiplication
    -- ** About matrix multiplication
    -- $mult

    -- ** Functions
  , multStd
  , multStd2
  , multStrassen
  , multStrassenMixed
    -- * Linear transformations
  , scaleMatrix
  , scaleRow
  , combineRows
  , switchRows
  , switchCols
    -- * Decompositions
  , luDecomp , luDecompUnsafe
  , luDecomp', luDecompUnsafe'
  , cholDecomp
    -- * Properties
  , trace , diagProd
    -- ** Determinants
  , detLaplace
  , detLU
  , flatten
  ) where

import Prelude hiding (foldl1)
-- Classes
import Control.DeepSeq
import Control.Monad (forM_)
import Control.Loop (numLoop,numLoopFold)
import Data.Foldable (Foldable, foldMap, foldl1)
import Data.Maybe
import Data.Monoid
import qualified Data.Semigroup as S
import Data.Traversable
import Control.Applicative(Applicative, (<$>), (<*>), pure)
import GHC.Generics (Generic)
-- Data
import           Control.Monad.Primitive (PrimMonad, PrimState)
import           Data.List               (maximumBy,foldl1',find)
import           Data.Ord                (comparing)
import qualified Data.Vector             as V
import qualified Data.Vector.Mutable     as MV

-------------------------------------------------------
-------------------------------------------------------
---- MATRIX TYPE

-- | Map a 1-based @(row, column)@ position to the 0-based index it has in
--   a flat row-major vector whose rows are @m@ elements wide.
encode :: Int -> (Int,Int) -> Int
{-# INLINE encode #-}
encode m (i,j) = (i-1)*m + j - 1

-- | Inverse of 'encode': map a 0-based index of a flat row-major vector
--   whose rows are @m@ elements wide back to a 1-based @(row, column)@.
decode :: Int -> Int -> (Int,Int)
{-# INLINE decode #-}
decode m k = (q+1,r+1)
 where
  (q,r) = quotRem k m

-- | Type of matrices.
--
--   Elements can be of any type. Rows and columns
--   are indexed starting by 1. This means that, if @m :: Matrix a@ and
--   @i,j :: Int@, then @m ! (i,j)@ is the element in the @i@-th row and
--   @j@-th column of @m@.
--
--   Internally a matrix is a view over a row-major vector: element
--   @(i,j)@ lives at row @i + rowOffset@ and column @j + colOffset@ of a
--   backing vector whose rows are 'vcols' elements wide.
data Matrix a = M {
   nrows     :: {-# UNPACK #-} !Int -- ^ Number of rows.
 , ncols     :: {-# UNPACK #-} !Int -- ^ Number of columns.
 , rowOffset :: {-# UNPACK #-} !Int -- ^ Rows of the backing vector skipped before the view.
 , colOffset :: {-# UNPACK #-} !Int -- ^ Columns of each backing row skipped before the view.
 , vcols     :: {-# UNPACK #-} !Int -- ^ Number of columns of the matrix without offset
 , mvect     :: V.Vector a          -- ^ Content of the matrix as a plain vector.
   } deriving (Generic)

-- | Two matrices are equal when they have the same dimensions and equal
--   elements at every position. Offsets and backing vectors are not
--   compared, so a submatrix equals a fresh matrix with the same contents.
instance Eq a => Eq (Matrix a) where
  m1 == m2 =
    let r = nrows m1
        c = ncols m1
    -- 'and' stops at the first False, so elements are only compared
    -- (and indexed) once both dimension checks have succeeded.
    in  and $ (r == nrows m2) : (c == ncols m2)
            : [ m1 ! (i,j) == m2 ! (i,j) | i <- [1 .. r] , j <- [1 .. c] ]

-- | Render matrix dimensions as @\"<rows>x<cols>\"@, used in error messages.
sizeStr :: Int -> Int -> String
sizeStr n m = show n ++ "x" ++ show m

-- | Display a matrix as a 'String' using the 'Show' instance of its elements.
--
--   Every element is right-aligned to the width of the widest visible
--   element. A matrix without elements is rendered as an empty frame
--   instead of failing.
prettyMatrix :: Show a => Matrix a -> String
prettyMatrix m = concat
   [ "┌ ", unwords (replicate (ncols m) blank), " ┐\n"
   , unlines
   [ "│ " ++ unwords (fmap (\j -> fill $ strings ! (i,j)) [1..ncols m]) ++ " │" | i <- [1..nrows m] ]
   , "└ ", unwords (replicate (ncols m) blank), " ┘"
   ]
 where
   strings = fmap show m
   -- Measure only the visible elements: the backing vector of a
   -- submatrix may hold other (wider) elements. The leading 0 keeps this
   -- total for matrices with no elements (e.g. 1x0), where taking the
   -- maximum of an empty vector would crash.
   widest = maximum (0 : fmap length (toList strings))
   fill str = replicate (widest - length str) ' ' ++ str
   blank = fill ""


-- | Matrices are shown with 'prettyMatrix'.
instance Show a => Show (Matrix a) where
 show = prettyMatrix

-- | The dimension and offset fields are strict 'Int's, so deeply
--   evaluating a matrix only requires deeply evaluating its vector.
instance NFData a => NFData (Matrix a) where
 rnf = rnf . mvect

-- | /O(rows*cols)/. Similar to 'V.force'. It copies the matrix content
--   dropping any extra memory.
--
--   Useful when using 'submatrix' from a big matrix.
--
forceMatrix :: Matrix a -> Matrix a
forceMatrix m = matrix (nrows m) (ncols m) $ \(i,j) -> unsafeGet i j m

-------------------------------------------------------
-------------------------------------------------------
---- FUNCTOR INSTANCE

-- | Maps over the whole backing vector and keeps dimensions and offsets,
--   so the result is a view with the same shape as the input.
instance Functor Matrix where
 {-# INLINE fmap #-}
 fmap f (M n m ro co w v) = M n m ro co w $ V.map f v

-------------------------------------------------------
-------------------------------------------------------

-------------------------------------------------------
-------------------------------------------------------
---- MONOID INSTANCE

-- | Element-wise combination. The result has the larger number of rows and
--   the larger number of columns of the two operands. Positions present in
--   only one operand keep that operand's element; positions present in
--   neither are filled with 'mempty'.
instance Monoid a => S.Semigroup (Matrix a) where
  m <> m' = matrix (max (nrows m) (nrows m')) (max (ncols m) (ncols m')) $ uncurry zipTogether
    where zipTogether row column = fromMaybe mempty $ safeGet row column m <> safeGet row column m'

-- | 'mempty' is the 1x1 matrix holding 'mempty'; 'mappend' is the
--   'S.Semigroup' operation (canonical definition).
instance Monoid a => Monoid (Matrix a) where
  mempty = fromList 1 1 [mempty]
  mappend = (S.<>)


-------------------------------------------------------
-------------------------------------------------------
-------------------------------------------------------
-------------------------------------------------------

-------------------------------------------------------
-------------------------------------------------------
---- APPLICATIVE INSTANCE
---- Works like tensor product but applies a function

-- | Works like a tensor product, but applies functions: @fs <*> xs@
--   replaces every function @f@ of @fs@ with the block @fmap f xs@ and
--   then 'flatten's the blocks. 'pure' builds a 1x1 matrix.
instance Applicative Matrix where
  pure x = fromList 1 1 [x]
  m <*> m' = flatten $ (\f -> f <$> m') <$> m


-------------------------------------------------------
-------------------------------------------------------



-- | Flatten a matrix of matrices into a single matrix. The blocks of each
--   row are joined side by side with '<|>', and the resulting rows are
--   stacked with '<->'.
--
--   All submatrices must have the same dimensions; this criterion is not
--   checked. The outer matrix must have at least one row and one column,
--   otherwise 'foldl1' fails on an empty structure.
flatten :: Matrix (Matrix a) -> Matrix a
flatten m = foldl1 (<->) $ map (foldl1 (<|>) . (\i -> getRow i m)) [1..(nrows m)]

-- | /O(rows*cols)/. Map a function over a row.
--   Example:
--
-- >                          ( 1 2 3 )   ( 1 2 3 )
-- >                          ( 4 5 6 )   ( 5 6 7 )
-- > mapRow (\_ x -> x + 1) 2 ( 7 8 9 ) = ( 7 8 9 )
--
mapRow :: (Int -> a -> a) -- ^ Function takes the current column as additional argument.
        -> Int            -- ^ Row to map.
        -> Matrix a -> Matrix a
mapRow f r m =
  matrix (nrows m) (ncols m) $ \(i,j) ->
    let a = unsafeGet i j m
    in  if i == r
           then f j a
           else a

-- | /O(rows*cols)/. Map a function over a column.
--   Example:
--
-- >                          ( 1 2 3 )   ( 1 3 3 )
-- >                          ( 4 5 6 )   ( 4 6 6 )
-- > mapCol (\_ x -> x + 1) 2 ( 7 8 9 ) = ( 7 9 9 )
--
mapCol :: (Int -> a -> a) -- ^ Function takes the current row as additional argument.
        -> Int            -- ^ Column to map.
        -> Matrix a -> Matrix a
mapCol f c m =
  matrix (nrows m) (ncols m) $ \(i,j) ->
    let a = unsafeGet i j m
    in  if j == c
           then f i a
           else a


-- | /O(rows*cols)/. Map a function over elements.
--   Example:
--
-- >                            ( 1 2 3 )   ( 0 -1 -2 )
-- >                            ( 4 5 6 )   ( 1  0 -1 )
-- > mapPos (\(r,c) a -> r - c) ( 7 8 9 ) = ( 2  1  0 )
--
mapPos :: ((Int, Int) -> a -> b) -- ^ Function takes the current Position as additional argument.
        -> Matrix a
        -> Matrix b
mapPos f m =
  -- Positions come from the generator rather than from decoding indices
  -- of the backing vector. Decoding with 'ncols' was wrong for submatrices,
  -- whose backing rows are 'vcols' wide and shifted by row/column offsets.
  matrix (nrows m) (ncols m) $ \p@(i,j) -> f p (unsafeGet i j m)

-------------------------------------------------------
-------------------------------------------------------
---- FOLDABLE AND TRAVERSABLE INSTANCES

-- | Folds the elements in row-major order. 'forceMatrix' runs first so
--   only the visible elements of a submatrix are folded, not the whole
--   backing vector.
instance Foldable Matrix where
 foldMap f = foldMap f . mvect . forceMatrix

-- | Traverses the visible elements in row-major order. The result is a
--   compact matrix (no offsets, row width equal to 'ncols').
instance Traversable Matrix where
 sequenceA m = fmap (M (nrows m) (ncols m) 0 0 (ncols m)) . sequenceA . mvect $ forceMatrix m

-------------------------------------------------------
-------------------------------------------------------
---- BUILDERS

-- | /O(rows*cols)/. The zero matrix of the given size.
--
-- > zero n m =
-- >                 m
-- >   1 ( 0 0 ... 0 0 )
-- >   2 ( 0 0 ... 0 0 )
-- >     (     ...     )
-- >     ( 0 0 ... 0 0 )
-- >   n ( 0 0 ... 0 0 )
zero :: Num a =>
     Int -- ^ Rows
  -> Int -- ^ Columns
  -> Matrix a
{-# INLINE zero #-}
zero n m = M n m 0 0 m $ V.replicate (n*m) 0

-- | /O(rows*cols)/. Generate a matrix from a generator function.
--   Example of usage:
--
-- >                                  (  1  0 -1 -2 )
-- >                                  (  3  2  1  0 )
-- >                                  (  5  4  3  2 )
-- > matrix 4 4 $ \(i,j) -> 2*i - j = (  7  6  5  4 )
matrix :: Int -- ^ Rows
       -> Int -- ^ Columns
       -> ((Int,Int) -> a) -- ^ Generator function
       -> Matrix a
{-# INLINE matrix #-}
matrix n m f = M n m 0 0 m $ V.create $ do
  v <- MV.new $ n * m
  let en = encode m
  -- The nested loops visit every position (i,j) with 1 <= i <= n and
  -- 1 <= j <= m and write its slot, filling the freshly allocated vector.
  numLoop 1 n $
    \i -> numLoop 1 m $
    \j -> MV.unsafeWrite v (en (i,j)) (f (i,j))
  return v

-- | /O(rows*cols)/. Identity matrix of the given order.
--
-- > identity n =
-- >                 n
-- >   1 ( 1 0 ... 0 0 )
-- >   2 ( 0 1 ... 0 0 )
-- >     (     ...     )
-- >     ( 0 0 ... 1 0 )
-- >   n ( 0 0 ... 0 1 )
--
identity :: Num a => Int -> Matrix a
identity n = matrix n n $ \(i,j) -> if i == j then 1 else 0

-- | Similar to 'diagonalList', but using 'V.Vector', which
--   should be more efficient. The order of the resulting square matrix
--   is the length of the vector.
diagonal :: a -- ^ Default element
         -> V.Vector a  -- ^ Diagonal vector
         -> Matrix a
diagonal e v = matrix n n $ \(i,j) -> if i == j then V.unsafeIndex v (i - 1) else e
  where
    -- i ranges over 1..n, so (i - 1) is always a valid index of v.
    n = V.length v

-- | Create a matrix from a non-empty list given the desired size.
--   The list must have at least /rows*cols/ elements; extra elements are
--   ignored, and a shorter list is an 'error'.
--   An example:
--
-- >                       ( 1 2 3 )
-- >                       ( 4 5 6 )
-- > fromList 3 3 [1..] =  ( 7 8 9 )
--
fromList :: Int -- ^ Rows
         -> Int -- ^ Columns
         -> [a] -- ^ List of elements
         -> Matrix a
{-# INLINE fromList #-}
fromList n m xs
    | n*m > V.length v =
        error $
            "List size "
            ++ show (V.length v)
            ++ " is inconsistent with matrix size "
            ++ sizeStr n m
            ++ " in fromList"
    | otherwise       = M n m 0 0 m v
    where
      -- Takes at most n*m elements, so infinite lists are fine.
      v = V.fromListN (n*m) xs

-- | Get the elements of a matrix stored in a list, in row-major order.
--
-- >        ( 1 2 3 )
-- >        ( 4 5 6 )
-- > toList ( 7 8 9 ) = [1,2,3,4,5,6,7,8,9]
--
toList :: Matrix a -> [a]
toList m = [ unsafeGet i j m | i <- [1 .. nrows m] , j <- [1 .. ncols m] ]

-- | Get the elements of a matrix stored in a list of lists,
--   where each list contains the elements of a single row.
--
-- >         ( 1 2 3 )   [ [1,2,3]
-- >         ( 4 5 6 )   , [4,5,6]
-- > toLists ( 7 8 9 ) = , [7,8,9] ]
--
toLists :: Matrix a -> [[a]]
toLists m = [ [ unsafeGet i j m | j <- [1 .. ncols m] ] | i <- [1 .. nrows m] ]

-- | Diagonal matrix from a non-empty list given the desired size.
--   Non-diagonal elements will be filled with the given default element.
--   The list must have at least /order/ elements.
--
-- > diagonalList n 0 [1..] =
-- >                   n
-- >   1 ( 1 0 ... 0   0 )
-- >   2 ( 0 2 ... 0   0 )
-- >     (     ...       )
-- >     ( 0 0 ... n-1 0 )
-- >   n ( 0 0 ... 0   n )
--
diagonalList :: Int -> a -> [a] -> Matrix a
diagonalList n e xs = matrix n n $ \(i,j) -> if i == j then d V.! (i - 1) else e
  where
    -- Copy the first n elements into a vector once instead of walking the
    -- list from its head with (!!) for every diagonal element. 'V.!' is
    -- still bounds-checked, so a too-short list fails when the missing
    -- diagonal element is demanded, as before.
    d = V.fromListN n xs

-- | Create a matrix from a non-empty list of non-empty lists.
--   /Each list must have at least as many elements as the first list/.
--   Longer rows are truncated to the length of the first one.
--   Examples:
--
-- > fromLists [ [1,2,3]      ( 1 2 3 )
-- >           , [4,5,6]      ( 4 5 6 )
-- >           , [7,8,9] ] =  ( 7 8 9 )
--
-- > fromLists [ [1,2,3  ]     ( 1 2 3 )
-- >           , [4,5,6,7]     ( 4 5 6 )
-- >           , [8,9,0  ] ] = ( 8 9 0 )
--
fromLists :: [[a]] -> Matrix a
{-# INLINE fromLists #-}
fromLists [] = error "fromLists: empty list."
fromLists (xs:xss) = fromList n m $ concat $ xs : fmap (take m) xss
  where
    n = 1 + length xss
    m = length xs

-- | /O(1)/. Represent a vector as a one row matrix (the vector is shared,
--   not copied).
rowVector :: V.Vector a -> Matrix a
rowVector v = M 1 m 0 0 m v
  where
    m = V.length v

-- | /O(1)/. Represent a vector as a one column matrix (the vector is
--   shared, not copied).
colVector :: V.Vector a -> Matrix a
colVector v = M (V.length v) 1 0 0 1 v

-- | /O(rows*cols)/. Permutation matrix.
--
-- > permMatrix n i j =
-- >               i     j       n
-- >   1 ( 1 0 ... 0 ... 0 ... 0 0 )
-- >   2 ( 0 1 ... 0 ... 0 ... 0 0 )
-- >     (     ...   ...   ...     )
-- >   i ( 0 0 ... 0 ... 1 ... 0 0 )
-- >     (     ...   ...   ...     )
-- >   j ( 0 0 ... 1 ... 0 ... 0 0 )
-- >     (     ...   ...   ...     )
-- >     ( 0 0 ... 0 ... 0 ... 1 0 )
-- >   n ( 0 0 ... 0 ... 0 ... 0 1 )
--
-- When @i == j@ it reduces to 'identity' @n@.
--
permMatrix :: Num a
           => Int -- ^ Size of the matrix.
           -> Int -- ^ Permuted row 1.
           -> Int -- ^ Permuted row 2.
           -> Matrix a -- ^ Permutation matrix.
permMatrix n r1 r2 | r1 == r2 = identity n
permMatrix n r1 r2 = matrix n n f
 where
  -- Rows r1 and r2 of the identity are swapped; every other row is unchanged.
  f (i,j)
   | i == r1 = if j == r2 then 1 else 0
   | i == r2 = if j == r1 then 1 else 0
   | i == j = 1
   | otherwise = 0

-------------------------------------------------------
-------------------------------------------------------
---- ACCESSING

-- | /O(1)/. Get an element of a matrix. Indices range from /(1,1)/ to /(n,m)/.
--   It returns an 'error' if the requested element is outside of range.
getElem :: Int      -- ^ Row
        -> Int      -- ^ Column
        -> Matrix a -- ^ Matrix
        -> a
{-# INLINE getElem #-}
getElem i j m =
  fromMaybe
    (error $
       "getElem: Trying to get the "
        ++ show (i, j)
        ++ " element from a "
        ++ sizeStr (nrows m) (ncols m)
        ++ " matrix."
    )
    (safeGet i j m)

-- | /O(1)/. Unsafe variant of 'getElem', without bounds checking.
--   The position is shifted by the matrix offsets and encoded against the
--   width of the backing vector.
unsafeGet :: Int      -- ^ Row
          -> Int      -- ^ Column
          -> Matrix a -- ^ Matrix
          -> a
{-# INLINE unsafeGet #-}
unsafeGet i j (M _ _ ro co w v) = V.unsafeIndex v $ encode w (i+ro,j+co)

-- | Short alias for 'getElem' taking the position as a pair.
(!) :: Matrix a -> (Int,Int) -> a
{-# INLINE (!) #-}
m ! (i,j) = getElem i j m

-- | Internal alias for 'unsafeGet' taking the position as a pair.
(!.) :: Matrix a -> (Int,Int) -> a
{-# INLINE (!.) #-}
m !. (i,j) = unsafeGet i j m

-- | Variant of 'getElem' that returns 'Nothing' instead of an error when
--   the position lies outside /(1,1)/ to /(nrows, ncols)/.
safeGet :: Int -> Int -> Matrix a -> Maybe a
safeGet i j a@(M n m _ _ _ _)
 | i > n || j > m || i < 1 || j < 1 = Nothing
 | otherwise = Just $ unsafeGet i j a

-- | Variant of 'setElem' that returns 'Nothing' instead of an error when
--   the position lies outside /(1,1)/ to /(nrows, ncols)/.
safeSet:: a -> (Int, Int) -> Matrix a -> Maybe (Matrix a)
safeSet x p@(i,j) a@(M n m _ _ _ _)
  | i > n || j > m || i < 1 || j < 1 = Nothing
  | otherwise = Just $ unsafeSet x p a

-- | /O(1)/. Get a row of a matrix as a vector. The result is a slice of
--   the backing vector, not a copy. The row index is not checked against
--   'nrows'; use 'safeGetRow' for a checked variant.
getRow :: Int -> Matrix a -> V.Vector a
{-# INLINE getRow #-}
getRow i (M _ m ro co w v) = V.slice (w*(i-1+ro) + co) m v

-- | Variant of 'getRow' that returns 'Nothing' when the row index lies
--   outside /1/ to /nrows/.
safeGetRow :: Int -> Matrix a -> Maybe (V.Vector a)
safeGetRow r m
    | r > nrows m || r < 1 = Nothing
    | otherwise = Just $ getRow r m

-- | /O(rows)/. Get a column of a matrix as a vector. Unlike 'getRow' this
--   builds a new vector, since a column is not contiguous in the row-major
--   backing vector.
getCol :: Int -> Matrix a -> V.Vector a
{-# INLINE getCol #-}
getCol j (M n _ ro co w v) = V.generate n $ \i -> v V.! encode w (i+1+ro,j+co)

-- | Variant of 'getCol' that returns 'Nothing' when the column index lies
--   outside /1/ to /ncols/.
safeGetCol :: Int -> Matrix a -> Maybe (V.Vector a)
safeGetCol c m
    | c > ncols m || c < 1 = Nothing
    | otherwise = Just $ getCol c m

-- | /O(min rows cols)/. Diagonal of a /not necessarily square/ matrix:
--   the elements @(i,i)@ for @i@ from 1 to @min rows cols@.
getDiag :: Matrix a -> V.Vector a
getDiag m = V.generate k $ \i -> m ! (i+1,i+1)
 where
  k = min (nrows m) (ncols m)

-- | /O(rows*cols)/. Transform a 'Matrix' to a 'V.Vector' of size /rows*cols/.
--  This is equivalent to getting all the rows of the matrix using 'getRow'
--  and then appending them, but far more efficient.
getMatrixAsVector :: Matrix a -> V.Vector a
getMatrixAsVector = mvect . forceMatrix

-------------------------------------------------------
-------------------------------------------------------
---- MANIPULATING MATRICES

-- | Write an element into a mutable row-major backing vector. The 1-based
--   position is shifted by the given offsets and encoded against the given
--   row width. Bounds-checked via 'MV.write'.
msetElem :: PrimMonad m
         => a -- ^ New element
         -> Int -- ^ Number of columns of the backing vector (row width passed to 'encode')
         -> Int -- ^ Row offset
         -> Int -- ^ Column offset
         -> (Int,Int) -- ^ Position to set the new element
         -> MV.MVector (PrimState m) a -- ^ Mutable vector
         -> m ()
{-# INLINE msetElem #-}
msetElem x w ro co (i,j) v = MV.write v (encode w (i+ro,j+co)) x

-- | Like 'msetElem', but writes with 'MV.unsafeWrite', so the flat index
--   is not bounds-checked at all. Callers must guarantee it is valid.
unsafeMset :: PrimMonad m
         => a -- ^ New element
         -> Int -- ^ Number of columns of the matrix
         -> Int -- ^ Row offset
         -> Int -- ^ Column offset
         -> (Int,Int) -- ^ Position to set the new element
         -> MV.MVector (PrimState m) a -- ^ Mutable vector
         -> m ()
{-# INLINE unsafeMset #-}
unsafeMset x w ro co (i,j) v =
  let flatIx = encode w (i + ro, j + co)
  in  MV.unsafeWrite v flatIx x

-- | Replace the value of a cell in a matrix.
--
--   Raises an error when the position lies outside the matrix. The
--   position must be checked here against the matrix dimensions. Otherwise
--   an out-of-range column such as @(1,3)@ on a 2x2 matrix would encode to
--   a valid flat index and silently overwrite a different cell (here
--   @(2,1)@). Use 'unsafeSet' to skip this check.
setElem :: a -- ^ New value.
        -> (Int,Int) -- ^ Position to replace.
        -> Matrix a -- ^ Original matrix.
        -> Matrix a -- ^ Matrix with the given position replaced with the given value.
{-# INLINE setElem #-}
setElem x p@(i, j) (M n m ro co w v)
  | i < 1 || i > n || j < 1 || j > m =
      error $ "setElem: Trying to set the " ++ show p
           ++ " element of a " ++ sizeStr n m ++ " matrix."
  | otherwise = M n m ro co w $ V.modify (msetElem x w ro co p) v

-- | Unsafe variant of 'setElem', without bounds checking: the new value is
--   written with 'unsafeMset' into a modified copy of the backing vector.
unsafeSet :: a -- ^ New value.
        -> (Int,Int) -- ^ Position to replace.
        -> Matrix a -- ^ Original matrix.
        -> Matrix a -- ^ Matrix with the given position replaced with the given value.
{-# INLINE unsafeSet #-}
unsafeSet x p mat = case mat of
  M rows cols ro co w v ->
    M rows cols ro co w (V.modify (unsafeMset x w ro co p) v)

-- | /O(rows*cols)/. The transpose of a matrix: entry @(i,j)@ of the result
--   is entry @(j,i)@ of the input.
--   Example:
--
-- >           ( 1 2 3 )   ( 1 4 7 )
-- >           ( 4 5 6 )   ( 2 5 8 )
-- > transpose ( 7 8 9 ) = ( 3 6 9 )
transpose :: Matrix a -> Matrix a
transpose m = matrix (ncols m) (nrows m) flipped
  where
    flipped (i, j) = m ! (j, i)

-- | /O(rows*rows*rows*rows) = O(cols*cols*cols*cols)/. The inverse of a square matrix.
--   Uses naive Gaussian elimination:
--
--   1. augment the matrix with the identity, giving @[m | I]@;
--   2. bring the result to reduced row echelon form with 'rref';
--   3. return the right half.
--
--   Returns 'Left' with a message when the matrix is not square or is
--   not invertible. Dimensions in the message are reported as rows x cols.
inverse :: (Fractional a, Eq a) => Matrix a -> Either String (Matrix a)
inverse m
    | ncols m /= n
        = Left $ "Inverting non-square matrix with dimensions "
                   ++ sizeStr n (ncols m)
    | otherwise
        = submatrix 1 n (n + 1) (2 * n)
            <$> (rref (m <|> identity n) >>= checkInvertible)
  where
    n = nrows m
    -- After 'rref', the left half of [m | I] is the identity exactly when m
    -- is invertible. Otherwise fewer than n pivots fall in the left half,
    -- so its last row is zero. Inspecting the bottom-right entry of the
    -- left half is therefore enough.
    checkInvertible a
      | unsafeGet n n a == 1 = Right a
      | otherwise            = Left "Attempt to invert a non-invertible matrix"


-- | Converts a matrix to reduced row echelon form, thus solving a linear
--   system of equations given as an augmented matrix.
--
--   The matrix must have at least as many columns as rows. Otherwise
--   'Left' is returned, with the dimensions reported as rows x cols
--   (such a system has fewer variables than equations). Square matrices
--   are accepted too: an invertible one is reduced to the identity, and a
--   singular one is returned in its (non-identity) echelon form.
--
--   This implementation is taken from rosettacode
--   https://rosettacode.org/wiki/Reduced_row_echelon_form#Haskell
rref :: (Fractional a, Eq a) => Matrix a -> Either String (Matrix a)
rref m
    | cols < rows = Left $ "Invalid dimensions " ++ sizeStr rows cols
    | otherwise   = Right . fromLists $ go (toLists m) 0 [0 .. rows - 1]
  where
    rows = nrows m
    cols = ncols m

    -- go a lead rs: process the (0-based) rows rs in order, looking for
    -- the next pivot at or to the right of column lead.
    go a _ [] = a
    go a lead (r : rs) =
        case find nonZero candidates of
          -- No non-zero entry left in the remaining rows: done.
          Nothing -> a
          Just (lead', i) ->
            let pivotRow = a !! i
                -- Scale the pivot row so that its pivot becomes 1.
                newRow   = map (/ pivotRow !! lead') pivotRow
                -- Put the scaled pivot row at position r (swapping rows r and i).
                swapped  = replace r newRow $ replace i (a !! r) a
                -- Clear the pivot column in every other row.
                eliminate n row
                  | n == r    = row
                  | otherwise = zipWith (\p y -> y - p * (row !! lead')) newRow row
            in go (zipWith eliminate [0 ..] swapped) (lead' + 1) rs
      where
        -- Column-major search so the left-most pivot column is found first.
        candidates = [ (col, row) | col <- [lead .. cols - 1], row <- [r .. rows - 1] ]
        nonZero (col, row) = a !! row !! col /= 0

    -- Replace the element at the given (0-based) index.
    replace n e t = before ++ e : after
      where (before, _ : after) = splitAt n t


-- | Extend a matrix to a given size, filling new cells with a default
--   element. If the matrix is already at least that large, it is returned
--   with the same size; the matrix is /never/ reduced in size.
--   Example:
--
-- >                            ( 1 2 3 0 0 )
-- >                ( 1 2 3 )   ( 4 5 6 0 0 )
-- >                ( 4 5 6 )   ( 7 8 9 0 0 )
-- > extendTo 0 4 5 ( 7 8 9 ) = ( 0 0 0 0 0 )
--
-- The definition of 'extendTo' is based on 'setSize':
--
-- > extendTo e n m a = setSize e (max n $ nrows a) (max m $ ncols a) a
--
extendTo :: a   -- ^ Element to add when extending.
         -> Int -- ^ Minimal number of rows.
         -> Int -- ^ Minimal number of columns.
         -> Matrix a -> Matrix a
extendTo e n m a = setSize e targetRows targetCols a
  where
    targetRows = max n (nrows a)
    targetCols = max m (ncols a)

-- | Set the size of a matrix to the given dimensions. Cells that exist in
--   the original matrix keep their values. Cells outside it (when the
--   matrix grows) take the default element. Rows or columns beyond the
--   new size are dropped.
setSize :: a   -- ^ Default element.
        -> Int -- ^ Number of rows.
        -> Int -- ^ Number of columns.
        -> Matrix a
        -> Matrix a
{-# INLINE setSize #-}
setSize e n m a@(M n0 m0 _ _ _ _) = matrix n m pick
  where
    pick (i, j)
      | i > n0 || j > m0 = e
      | otherwise        = unsafeGet i j a

-------------------------------------------------------
-------------------------------------------------------
---- WORKING WITH BLOCKS

-- | /O(1)/. Extract a submatrix given row and column limits (inclusive,
--   1-based). The result shares the backing vector of the input; only the
--   dimensions and offsets change. Raises an error when a limit is out of
--   range or an ending index precedes its starting index.
--   Example:
--
-- >                   ( 1 2 3 )
-- >                   ( 4 5 6 )   ( 2 3 )
-- > submatrix 1 2 2 3 ( 7 8 9 ) = ( 5 6 )
submatrix :: Int    -- ^ Starting row
          -> Int -- ^ Ending row
          -> Int    -- ^ Starting column
          -> Int -- ^ Ending column
          -> Matrix a
          -> Matrix a
{-# INLINE submatrix #-}
submatrix r1 r2 c1 c2 (M n m ro co w v)
  | outside 1 n r1 = failWith
      [ "submatrix: starting row (", show r1
      , ") is out of range. Matrix has ", show n, " rows." ]
  | outside 1 m c1 = failWith
      [ "submatrix: starting column (", show c1
      , ") is out of range. Matrix has ", show m, " columns." ]
  | outside r1 n r2 = failWith
      [ "submatrix: ending row (", show r2
      , ") is out of range. Matrix has ", show n
      , " rows, and starting row is ", show r1, "." ]
  | outside c1 m c2 = failWith
      [ "submatrix: ending column (", show c2
      , ") is out of range. Matrix has ", show m
      , " columns, and starting column is ", show c1, "." ]
  | otherwise =
      M (r2 - r1 + 1) (c2 - c1 + 1) (ro + r1 - 1) (co + c1 - 1) w v
  where
    -- True when x is not within the closed interval [lo, hi].
    outside lo hi x = x < lo || x > hi
    failWith parts = error (concat parts)

-- | /O(rows*cols)/. Remove a row and a column from a matrix.
--   The work is linear in the size of the backing vector. The removed row
--   and column are translated into storage coordinates using the view
--   offsets, then filtered out of the whole backing vector, whose width
--   drops by one.
--   Example:
--
-- >                 ( 1 2 3 )
-- >                 ( 4 5 6 )   ( 1 3 )
-- > minorMatrix 2 2 ( 7 8 9 ) = ( 7 9 )
minorMatrix :: Int -- ^ Row @r@ to remove.
            -> Int -- ^ Column @c@ to remove.
            -> Matrix a -- ^ Original matrix.
            -> Matrix a -- ^ Matrix with row @r@ and column @c@ removed.
minorMatrix r0 c0 (M n m ro co w v) =
  M (n - 1) (m - 1) ro co (w - 1) (V.ifilter keep v)
  where
    dropRow = r0 + ro
    dropCol = c0 + co
    keep k _ = i /= dropRow && j /= dropCol
      where (i, j) = decode w k

-- | /O(1)/. Make a block-partition of a matrix using a given element as reference.
--   The element will stay in the bottom-right corner of the top-left block.
--
-- >                 (             )   (      |      )
-- >                 (             )   ( ...  | ...  )
-- >                 (    x        )   (    x |      )
-- > splitBlocks i j (             ) = (-------------) , where x = a_{i,j}
-- >                 (             )   (      |      )
-- >                 (             )   ( ...  | ...  )
-- >                 (             )   (      |      )
--
--   We use the following notation for the blocks:
--
-- > ( TL | TR )
-- > (---------)
-- > ( BL | BR )
--
--   Where T = Top, B = Bottom, L = Left, R = Right.
--
--   Each block is built with 'submatrix', which rejects empty ranges. When
--   @i@ is the last row or @j@ the last column (or either is 0), the block
--   that would be empty raises an error when it is evaluated. The tuple
--   itself is still returned, and the other blocks remain usable.
--
splitBlocks :: Int      -- ^ Row of the splitting element.
            -> Int      -- ^ Column of the splitting element.
            -> Matrix a -- ^ Matrix to split.
            -> (Matrix a,Matrix a
               ,Matrix a,Matrix a) -- ^ (TL,TR,BL,BR)
{-# INLINE[1] splitBlocks #-}
splitBlocks i j a@(M n m _ _ _ _) = (topLeft, topRight, bottomLeft, bottomRight)
  where
    topLeft     = submatrix 1       i 1       j a
    topRight    = submatrix 1       i (j + 1) m a
    bottomLeft  = submatrix (i + 1) n 1       j a
    bottomRight = submatrix (i + 1) n (j + 1) m a

-- | Join blocks of the form detailed in 'splitBlocks'. Precisely:
--
-- > joinBlocks (tl,tr,bl,br) =
-- >   (tl <|> tr)
-- >       <->
-- >   (bl <|> br)
--
--   Block sizes are not checked. The row counts come from @tl@ (top) and
--   @bl@ (bottom), and the column counts from @tl@ (left) and @tr@ (right).
--   The result is a fresh matrix with zero offsets.
joinBlocks :: (Matrix a,Matrix a,Matrix a,Matrix a) -> Matrix a
{-# INLINE[1] joinBlocks #-}
joinBlocks (tl,tr,bl,br) =
  M totalRows totalCols 0 0 totalCols $ V.create $ do
    out <- MV.new (totalRows * totalCols)
    let -- Copy the top-left h-by-wd region of block b into out, so that
        -- its (1,1) entry lands at position (1 + dr, 1 + dc).
        blit b h wd dr dc =
          numLoop 1 h $ \i ->
            numLoop 1 wd $ \j ->
              MV.write out (encode totalCols (i + dr, j + dc)) (b ! (i, j))
    blit tl nTop    mLeft  0    0
    blit tr nTop    mRight 0    mLeft
    blit bl nBottom mLeft  nTop 0
    blit br nBottom mRight nTop mLeft
    return out
  where
    nTop      = nrows tl
    nBottom   = nrows bl
    mLeft     = ncols tl
    mRight    = ncols tr
    totalRows = nTop + nBottom
    totalCols = mLeft + mRight

-- Rewrite rule: joining the blocks produced by splitting a matrix gives
-- back the original matrix, so GHC may drop the split/join round trip.
-- Both functions carry INLINE[1] pragmas, presumably so they stay
-- un-inlined long enough for this rule to match.
-- NOTE(review): for split points where 'splitBlocks' would produce an
-- erroring (empty) block, the rewritten program returns m instead of
-- failing; confirm that is acceptable.
{-# RULES
"matrix/splitAndJoin"
   forall i j m. joinBlocks (splitBlocks i j m) = m
  #-}

-- | Horizontally join two matrices. Visually:
--
-- > ( A ) <|> ( B ) = ( A | B )
--
-- Where both matrices /A/ and /B/ have the same number of rows.
-- /This condition is not checked/. The row count of the result is taken
-- from the left matrix.
(<|>) :: Matrix a -> Matrix a -> Matrix a
{-# INLINE (<|>) #-}
m <|> m' = matrix (nrows m) (leftCols + ncols m') pick
  where
    leftCols = ncols m
    pick (i, j)
      | j > leftCols = m' ! (i, j - leftCols)
      | otherwise    = m  ! (i, j)

-- | Vertically join two matrices. Visually:
--
-- >                   ( A )
-- > ( A ) <-> ( B ) = ( - )
-- >                   ( B )
--
-- Where both matrices /A/ and /B/ have the same number of columns.
-- /This condition is not checked/. The column count of the result is taken
-- from the upper matrix.
(<->) :: Matrix a -> Matrix a -> Matrix a
{-# INLINE (<->) #-}
m <-> m' = matrix (topRows + nrows m') (ncols m) pick
  where
    topRows = nrows m
    pick (i, j)
      | i > topRows = m' ! (i - topRows, j)
      | otherwise   = m  ! (i, j)

-------------------------------------------------------
-------------------------------------------------------
---- MATRIX OPERATIONS

-- | Perform an operation element-wise.
--   The result has the size of the first matrix. The second matrix must
--   have at least as many rows and columns as the first. Extra entries
--   in the second matrix are ignored. If it is smaller, a run-time error
--   is raised when the missing entries are accessed.
--   Use 'elementwiseUnsafe' when the sizes are known to be compatible.
elementwise :: (a -> b -> c) -> (Matrix a -> Matrix b -> Matrix c)
elementwise f m m' = matrix (nrows m) (ncols m) combine
  where
    combine pos = f (m ! pos) (m' ! pos)

-- | Unsafe version of 'elementwise', but faster: entries are read with
--   'unsafeGet', so the second matrix is not bounds-checked.
elementwiseUnsafe :: (a -> b -> c) -> (Matrix a -> Matrix b -> Matrix c)
{-# INLINE elementwiseUnsafe #-}
elementwiseUnsafe f m m' = matrix (nrows m) (ncols m) combine
  where
    combine (i, j) = f (unsafeGet i j m) (unsafeGet i j m')

infixl 6 +., -.

-- | Internal unsafe addition: element-wise sum built on
--   'elementwiseUnsafe', so the sizes of the operands are not checked.
(+.) :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE (+.) #-}
(+.) = elementwiseUnsafe (+)

-- | Internal unsafe subtraction: element-wise difference built on
--   'elementwiseUnsafe', so the sizes of the operands are not checked.
(-.) :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE (-.) #-}
(-.) = elementwiseUnsafe (-)

-------------------------------------------------------
-------------------------------------------------------
---- MATRIX MULTIPLICATION

{- $mult

Four methods are provided for matrix multiplication.

* 'multStd':
     Matrix multiplication following directly the definition.
     This is the best choice when you know for sure that your
     matrices are small.

* 'multStd2':
     Matrix multiplication following directly the definition.
     However, it uses a different implementation from 'multStd'.
     According to our benchmarks with this version, 'multStd2' is
     around 3 times faster than 'multStd'.

* 'multStrassen':
     Matrix multiplication following Strassen's algorithm.
     Complexity grows slower but also some work is added
     partitioning the matrix. Also, it only works on square
     matrices of order @2^n@, so if this condition is not
     met, it is zero-padded until this is accomplished.
     Therefore, its use is not recommended.

* 'multStrassenMixed':
     This function mixes the previous methods.
     It provides a better performance in general. Method @(@'*'@)@
     of the 'Num' class uses this function because it gives the best
     average performance. However, if you know for sure that your matrices are
     small (size less than 500x500), you should use 'multStd' or 'multStd2' instead,
     since 'multStrassenMixed' is going to switch to those functions anyway.

We keep researching how to get better performance for matrix multiplication.
If you want to be on the safe side, use ('*').

-}

-- | Standard matrix multiplication by definition.
--   Raises an error when the number of columns of the first matrix differs
--   from the number of rows of the second one.
multStd :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStd #-}
multStd a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
   | m == n'   = multStd_ a1 a2
   | otherwise = error $ concat
       [ "Multiplication of ", sizeStr n m, " and "
       , sizeStr n' m', " matrices." ]

-- | Standard matrix multiplication by definition, using the alternative
--   kernel 'multStd__'. Raises an error when the number of columns of the
--   first matrix differs from the number of rows of the second one.
multStd2 :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStd2 #-}
multStd2 a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
   | m == n'   = multStd__ a1 a2
   | otherwise = error $ concat
       [ "Multiplication of ", sizeStr n m, " and "
       , sizeStr n' m', " matrices." ]

-- | Standard matrix multiplication by definition, without checking if sizes match.
multStd_ :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStd_ #-}
multStd_ :: forall a. Num a => Matrix a -> Matrix a -> Matrix a
multStd_ a :: Matrix a
a@(M Int
1 Int
1 Int
_ Int
_ Int
_ Vector a
_) b :: Matrix a
b@(M Int
1 Int
1 Int
_ Int
_ Int
_ Vector a
_) = forall a. Int -> Int -> Int -> Int -> Int -> Vector a -> Matrix a
M Int
1 Int
1 Int
0 Int
0 Int
1 forall a b. (a -> b) -> a -> b
$ forall a. a -> Vector a
V.singleton forall a b. (a -> b) -> a -> b
$ (Matrix a
a forall a. Matrix a -> (Int, Int) -> a
! (Int
1,Int
1)) forall a. Num a => a -> a -> a
* (Matrix a
b forall a. Matrix a -> (Int, Int) -> a
! (Int
1,Int
1))
multStd_ a :: Matrix a
a@(M Int
2 Int
2 Int
_ Int
_ Int
_ Vector a
_) b :: Matrix a
b@(M Int
2 Int
2 Int
_ Int
_ Int
_ Vector a
_) =
  forall a. Int -> Int -> Int -> Int -> Int -> Vector a -> Matrix a
M Int
2 Int
2 Int
0 Int
0 Int
2 forall a b. (a -> b) -> a -> b
$
    let -- A
        a11 :: a
a11 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
1) ; a12 :: a
a12 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
2)
        a21 :: a
a21 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
1) ; a22 :: a
a22 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
2)
        -- B
        b11 :: a
b11 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
1) ; b12 :: a
b12 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
2)
        b21 :: a
b21 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
1) ; b22 :: a
b22 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
2)
    in forall a. [a] -> Vector a
V.fromList
         [ a
a11forall a. Num a => a -> a -> a
*a
b11 forall a. Num a => a -> a -> a
+ a
a12forall a. Num a => a -> a -> a
*a
b21 , a
a11forall a. Num a => a -> a -> a
*a
b12 forall a. Num a => a -> a -> a
+ a
a12forall a. Num a => a -> a -> a
*a
b22
         , a
a21forall a. Num a => a -> a -> a
*a
b11 forall a. Num a => a -> a -> a
+ a
a22forall a. Num a => a -> a -> a
*a
b21 , a
a21forall a. Num a => a -> a -> a
*a
b12 forall a. Num a => a -> a -> a
+ a
a22forall a. Num a => a -> a -> a
*a
b22
           ]
multStd_ a :: Matrix a
a@(M Int
3 Int
3 Int
_ Int
_ Int
_ Vector a
_) b :: Matrix a
b@(M Int
3 Int
3 Int
_ Int
_ Int
_ Vector a
_) =
  forall a. Int -> Int -> Int -> Int -> Int -> Vector a -> Matrix a
M Int
3 Int
3 Int
0 Int
0 Int
3 forall a b. (a -> b) -> a -> b
$
    let -- A
        a11 :: a
a11 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
1) ; a12 :: a
a12 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
2) ; a13 :: a
a13 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
3)
        a21 :: a
a21 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
1) ; a22 :: a
a22 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
2) ; a23 :: a
a23 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
3)
        a31 :: a
a31 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
3,Int
1) ; a32 :: a
a32 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
3,Int
2) ; a33 :: a
a33 = Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
3,Int
3)
        -- B
        b11 :: a
b11 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
1) ; b12 :: a
b12 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
2) ; b13 :: a
b13 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
1,Int
3)
        b21 :: a
b21 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
1) ; b22 :: a
b22 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
2) ; b23 :: a
b23 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
2,Int
3)
        b31 :: a
b31 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
3,Int
1) ; b32 :: a
b32 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
3,Int
2) ; b33 :: a
b33 = Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
3,Int
3)
    in forall a. [a] -> Vector a
V.fromList
         [ a
a11forall a. Num a => a -> a -> a
*a
b11 forall a. Num a => a -> a -> a
+ a
a12forall a. Num a => a -> a -> a
*a
b21 forall a. Num a => a -> a -> a
+ a
a13forall a. Num a => a -> a -> a
*a
b31 , a
a11forall a. Num a => a -> a -> a
*a
b12 forall a. Num a => a -> a -> a
+ a
a12forall a. Num a => a -> a -> a
*a
b22 forall a. Num a => a -> a -> a
+ a
a13forall a. Num a => a -> a -> a
*a
b32 , a
a11forall a. Num a => a -> a -> a
*a
b13 forall a. Num a => a -> a -> a
+ a
a12forall a. Num a => a -> a -> a
*a
b23 forall a. Num a => a -> a -> a
+ a
a13forall a. Num a => a -> a -> a
*a
b33
         , a
a21forall a. Num a => a -> a -> a
*a
b11 forall a. Num a => a -> a -> a
+ a
a22forall a. Num a => a -> a -> a
*a
b21 forall a. Num a => a -> a -> a
+ a
a23forall a. Num a => a -> a -> a
*a
b31 , a
a21forall a. Num a => a -> a -> a
*a
b12 forall a. Num a => a -> a -> a
+ a
a22forall a. Num a => a -> a -> a
*a
b22 forall a. Num a => a -> a -> a
+ a
a23forall a. Num a => a -> a -> a
*a
b32 , a
a21forall a. Num a => a -> a -> a
*a
b13 forall a. Num a => a -> a -> a
+ a
a22forall a. Num a => a -> a -> a
*a
b23 forall a. Num a => a -> a -> a
+ a
a23forall a. Num a => a -> a -> a
*a
b33
         , a
a31forall a. Num a => a -> a -> a
*a
b11 forall a. Num a => a -> a -> a
+ a
a32forall a. Num a => a -> a -> a
*a
b21 forall a. Num a => a -> a -> a
+ a
a33forall a. Num a => a -> a -> a
*a
b31 , a
a31forall a. Num a => a -> a -> a
*a
b12 forall a. Num a => a -> a -> a
+ a
a32forall a. Num a => a -> a -> a
*a
b22 forall a. Num a => a -> a -> a
+ a
a33forall a. Num a => a -> a -> a
*a
b32 , a
a31forall a. Num a => a -> a -> a
*a
b13 forall a. Num a => a -> a -> a
+ a
a32forall a. Num a => a -> a -> a
*a
b23 forall a. Num a => a -> a -> a
+ a
a33forall a. Num a => a -> a -> a
*a
b33
           ]
multStd_ a :: Matrix a
a@(M Int
n Int
m Int
_ Int
_ Int
_ Vector a
_) b :: Matrix a
b@(M Int
_ Int
m' Int
_ Int
_ Int
_ Vector a
_) = forall a. Int -> Int -> ((Int, Int) -> a) -> Matrix a
matrix Int
n Int
m' forall a b. (a -> b) -> a -> b
$ \(Int
i,Int
j) -> forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [ Matrix a
a forall a. Matrix a -> (Int, Int) -> a
!. (Int
i,Int
k) forall a. Num a => a -> a -> a
* Matrix a
b forall a. Matrix a -> (Int, Int) -> a
!. (Int
k,Int
j) | Int
k <- [Int
1 .. Int
m] ]

-- | Direct matrix multiplication.
--
--   Every row of the left operand and every column of the right operand is
--   extracted once into a vector; each result cell @(i,j)@ is then the
--   'dotProduct' of row @i@ and column @j@.
--
--   No dimension check happens here: callers must ensure that
--   @ncols a == nrows b@, since vectors are indexed with 'V.unsafeIndex'.
multStd__ :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStd__ #-}
multStd__ a b = matrix (nrows a) (ncols b) cell
  where
    -- Cached rows of @a@ (0-based vector index k holds row k+1).
    rowsOfA = V.generate (nrows a) (\k -> getRow (k + 1) a)
    -- Cached columns of @b@ (0-based vector index k holds column k+1).
    colsOfB = V.generate (ncols b) (\k -> getCol (k + 1) b)
    -- 'matrix' hands out 1-based coordinates.
    cell (i, j) =
      dotProduct (V.unsafeIndex rowsOfA (i - 1)) (V.unsafeIndex colsOfB (j - 1))

-- | Dot product of two vectors.
--
--   The length of the first vector drives the loop, and entries are read
--   with 'V.unsafeIndex'. The second vector must therefore be at least as
--   long as the first; this is not checked. An empty first vector yields @0@.
dotProduct :: Num a => V.Vector a -> V.Vector a -> a
{-# INLINE dotProduct #-}
dotProduct v1 v2
  -- 'numLoopFold' only requires 'Eq' (not 'Ord') on its index, so it cannot
  -- check @start <= end@. Guard the empty inclusive range (0 .. -1)
  -- explicitly rather than hand it to the loop.
  -- NOTE(review): confirm against the loop package's numLoopFold docs.
  | V.null v1 = 0
  | otherwise =
      numLoopFold 0 (V.length v1 - 1) 0 $ \acc i ->
        V.unsafeIndex v1 i * V.unsafeIndex v2 i + acc

{-
dotProduct v1 v2 = go (V.length v1 - 1) 0
  where
    go (-1) a = a
    go i a = go (i-1) $ (V.unsafeIndex v1 i) * (V.unsafeIndex v2 i) + a
-}

-- | Return the first list element that satisfies the predicate.
--
--   Calls 'error' when no element satisfies it, including for the empty
--   list. Only as much of the list as needed is inspected, so it also works
--   on infinite lists that contain a matching element.
first :: (a -> Bool) -> [a] -> a
first p xs =
  case find p xs of
    Just x  -> x
    Nothing -> error "first: no element match the condition."

-- | Strassen's algorithm over square matrices of order @2^n@.
--
--   Two 1x1 matrices are multiplied directly. Any other input is assumed
--   (not checked) to be square with a power-of-two order, so it can be cut
--   into four equal blocks at every level of the recursion.
strassen :: Num a => Matrix a -> Matrix a -> Matrix a
-- Base case: product of two 1x1 matrices.
strassen x@(M 1 1 _ _ _ _) y@(M 1 1 _ _ _ _) =
  M 1 1 0 0 1 (V.singleton ((x ! (1,1)) * (y ! (1,1))))
-- Recursive case: 2x2 block decomposition with seven half-size products.
strassen x y = joinBlocks (c11, c12, c21, c22)
  where
    -- Order of each block: half the current order.
    h = nrows x `div` 2
    (x11, x12, x21, x22) = splitBlocks h h x
    (y11, y12, y21, y22) = splitBlocks h h y
    -- Strassen's seven products.
    m1 = strassen (x11 + x22) (y11 + y22)
    m2 = strassen (x21 + x22) y11
    m3 = strassen x11 (y12 - y22)
    m4 = strassen x22 (y21 - y11)
    m5 = strassen (x11 + x12) y22
    m6 = strassen (x21 - x11) (y11 + y12)
    m7 = strassen (x12 - x22) (y21 + y22)
    -- Reassembly of the result blocks.
    c11 = m1 + m4 - m5 + m7
    c12 = m3 + m5
    c21 = m2 + m4
    c22 = m1 - m2 + m3 + m6

-- | Strassen's matrix multiplication.
--
--   Both operands are zero-padded to a square whose order is the smallest
--   power of two not below any of their dimensions. The padded operands
--   are multiplied with 'strassen', and the result is cropped back to
--   @n x m'@. Calls 'error' when the inner dimensions differ.
multStrassen :: Num a => Matrix a -> Matrix a -> Matrix a
multStrassen a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m == n'   = submatrix 1 n 1 m' (strassen (pad a1) (pad a2))
  | otherwise = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                     ++ sizeStr n' m' ++ " matrices."
  where
    -- Smallest power of two that is >= every dimension involved.
    order = until (>= maximum [n, m, n', m']) (* 2) 1
    pad x = setSize 0 order order x

-- | Order threshold used by 'strassenMixed' and 'multStrassenMixed'.
--   Below this value they stop applying Strassen steps and use the direct
--   'multStd__' multiplication instead.
strmixFactor :: Int
strmixFactor = 300

-- | Strassen's mixed algorithm.
--
--   Multiplies two square matrices of the same order @r@. The order is
--   read from @nrows a@; squareness is assumed, not checked.
--
--   * If @r < 'strmixFactor'@, the direct 'multStd__' is used.
--
--   * If @r@ is odd, both operands are zero-padded to order @r + 1@, and the
--     result is cropped back to @r x r@.
--
--   * Otherwise, one Strassen step is performed. The seven half-size products
--     are computed recursively, and the four result quadrants are written
--     straight into a freshly allocated vector.
strassenMixed :: Num a => Matrix a -> Matrix a -> Matrix a
{-# SPECIALIZE strassenMixed :: Matrix Double -> Matrix Double -> Matrix Double #-}
{-# SPECIALIZE strassenMixed :: Matrix Int -> Matrix Int -> Matrix Int #-}
{-# SPECIALIZE strassenMixed :: Matrix Rational -> Matrix Rational -> Matrix Rational #-}
strassenMixed a b
  | r < strmixFactor = multStd__ a b
  | odd r =
      let padded = r + 1
      in  submatrix 1 r 1 r $
            strassenMixed (setSize 0 padded padded a) (setSize 0 padded padded b)
  | otherwise =
      M r r 0 0 r $ V.create $ do
        out <- MV.unsafeNew (r * r)
        -- Quadrants are written in the order c11, c12, c21, c22.
        forM_ (concat [c11, c12, c21, c22]) $ \(k, x) -> MV.write out k x
        return out
  where
    r = nrows a
    -- Order of each block (an exact half whenever the Strassen branch runs).
    h = r `quot` 2
    -- Split of both operands into 2x2 block form.
    (a11, a12, a21, a22) = splitBlocks h h a
    (b11, b12, b21, b22) = splitBlocks h h b
    -- Strassen's seven products, computed recursively.
    p1 = strassenMixed (a11 +. a22) (b11 +. b22)
    p2 = strassenMixed (a21 +. a22) b11
    p3 = strassenMixed a11 (b12 -. b22)
    p4 = strassenMixed a22 (b21 -. b11)
    p5 = strassenMixed (a11 +. a12) b22
    p6 = strassenMixed (a21 -. a11) (b11 +. b12)
    p7 = strassenMixed (a12 -. a22) (b21 +. b22)
    -- (flat index, value) pairs for the h x h quadrant whose top-left cell
    -- is at (rowOff + 1, colOff + 1); @cell@ takes quadrant-local 1-based
    -- coordinates. Rows vary slowest, columns fastest.
    quadrant rowOff colOff cell =
      [ (encode r (i + rowOff, j + colOff), cell i j)
      | i <- [1 .. h]
      , j <- [1 .. h]
      ]
    -- c11 = p1 + p4 - p5 + p7
    c11 = quadrant 0 0 $ \i j ->
      unsafeGet i j p1 + unsafeGet i j p4 - unsafeGet i j p5 + unsafeGet i j p7
    -- c12 = p3 + p5
    c12 = quadrant 0 h $ \i j -> unsafeGet i j p3 + unsafeGet i j p5
    -- c21 = p2 + p4
    c21 = quadrant h 0 $ \i j -> unsafeGet i j p2 + unsafeGet i j p4
    -- c22 = p1 - p2 + p3 + p6
    c22 = quadrant h h $ \i j ->
      unsafeGet i j p1 - unsafeGet i j p2 + unsafeGet i j p3 + unsafeGet i j p6

-- | Mixed Strassen's matrix multiplication.
--
--   * Calls 'error' when the inner dimensions disagree.
--
--   * If the left operand has fewer than 'strmixFactor' rows, the direct
--     'multStd__' is used.
--
--   * Otherwise, both operands are zero-padded to a square of even order
--     that is at least every dimension involved. They are multiplied with
--     'strassenMixed', and the result is cropped back to @n x m'@.
multStrassenMixed :: Num a => Matrix a -> Matrix a -> Matrix a
{-# INLINE multStrassenMixed #-}
multStrassenMixed a1@(M n m _ _ _ _) a2@(M n' m' _ _ _ _)
  | m /= n'          = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                            ++ sizeStr n' m' ++ " matrices."
  | n < strmixFactor = multStd__ a1 a2
  | otherwise        = submatrix 1 n 1 m' (strassenMixed (pad a1) (pad a2))
  where
    largest = maximum [n, m, n', m']
    -- Round up to an even order so 'strassenMixed' can halve it directly.
    order = if odd largest then largest + 1 else largest
    pad x = setSize 0 order order x

-------------------------------------------------------
-------------------------------------------------------
---- NUMERICAL INSTANCE

instance Num a => Num (Matrix a) where
  -- An integer literal becomes a 1x1 matrix holding the converted value.
  fromInteger k = M 1 1 0 0 1 (V.singleton (fromInteger k))
  -- Sign-related operations act on each entry independently.
  negate mat = fmap negate mat
  abs    mat = fmap abs mat
  signum mat = fmap signum mat

  -- Addition of matrices, entry by entry.
  {-# SPECIALIZE (+) :: Matrix Double -> Matrix Double -> Matrix Double #-}
  {-# SPECIALIZE (+) :: Matrix Int -> Matrix Int -> Matrix Int #-}
  {-# SPECIALIZE (+) :: Matrix Rational -> Matrix Rational -> Matrix Rational #-}
  x + y = elementwise (+) x y

  -- Subtraction of matrices, entry by entry.
  {-# SPECIALIZE (-) :: Matrix Double -> Matrix Double -> Matrix Double #-}
  {-# SPECIALIZE (-) :: Matrix Int -> Matrix Int -> Matrix Int #-}
  {-# SPECIALIZE (-) :: Matrix Rational -> Matrix Rational -> Matrix Rational #-}
  x - y = elementwise (-) x y

  -- Multiplication of matrices (mixed Strassen / direct algorithm).
  {-# INLINE (*) #-}
  x * y = multStrassenMixed x y

-------------------------------------------------------
-------------------------------------------------------
---- TRANSFORMATIONS

-- | Multiply every entry of a matrix by the given factor.
--   Example:
--
-- >               ( 1 2 3 )   (  2  4  6 )
-- >               ( 4 5 6 )   (  8 10 12 )
-- > scaleMatrix 2 ( 7 8 9 ) = ( 14 16 18 )
scaleMatrix :: Num a => a -> Matrix a -> Matrix a
scaleMatrix k = fmap (k *)

-- | Multiply every entry of one row by the given factor.
--   Example:
--
-- >              ( 1 2 3 )   (  1  2  3 )
-- >              ( 4 5 6 )   (  8 10 12 )
-- > scaleRow 2 2 ( 7 8 9 ) = (  7  8  9 )
scaleRow :: Num a => a -> Int -> Matrix a -> Matrix a
scaleRow k = mapRow (\_ x -> k * x)

-- | Add a scalar multiple of one row (@r2@) to another row (@r1@).
--   Entries of @r2@ are read from the input matrix.
--   Example:
--
-- >                   ( 1 2 3 )   (  1  2  3 )
-- >                   ( 4 5 6 )   (  6  9 12 )
-- > combineRows 2 2 1 ( 7 8 9 ) = (  7  8  9 )
combineRows :: Num a => Int -> a -> Int -> Matrix a -> Matrix a
combineRows r1 l r2 m = mapRow addScaled r1 m
  where
    -- Entry in column j of row r1 gains l times the matching entry of row r2.
    addScaled j x = x + l * getElem r2 j m

-- | Swap two rows of a matrix.
--   Example:
--
-- >                ( 1 2 3 )   ( 4 5 6 )
-- >                ( 4 5 6 )   ( 1 2 3 )
-- > switchRows 1 2 ( 7 8 9 ) = ( 7 8 9 )
switchRows :: Int -- ^ Row 1.
           -> Int -- ^ Row 2.
           -> Matrix a -- ^ Original matrix.
           -> Matrix a -- ^ Matrix with rows 1 and 2 switched.
switchRows r1 r2 (M n m ro co w vs) =
  M n m ro co w (V.modify (\mv -> numLoop 1 m (\j -> MV.swap mv (at r1 j) (at r2 j))) vs)
  where
    -- Position in the backing vector of entry (i, j) of this matrix,
    -- accounting for the row/column offsets and the row stride w.
    at i j = encode w (i + ro, j + co)

-- | Swap two columns of a matrix.
--   Example:
--
-- >                ( 1 2 3 )   ( 2 1 3 )
-- >                ( 4 5 6 )   ( 5 4 6 )
-- > switchCols 1 2 ( 7 8 9 ) = ( 8 7 9 )
switchCols :: Int -- ^ Col 1.
           -> Int -- ^ Col 2.
           -> Matrix a -- ^ Original matrix.
           -> Matrix a -- ^ Matrix with cols 1 and 2 switched.
switchCols c1 c2 (M n m ro co w vs) = M n m ro co w $ V.modify (\mv ->
  numLoop 1 n $ \i ->
    MV.swap mv (at i c1) (at i c2)) vs
  where
    -- Position in the backing vector of entry (i, j) of this matrix.
    -- The row stride is w (the backing width, as used by 'switchRows'),
    -- NOT m (the visible column count). The two differ for matrices
    -- carrying row/column offsets, so encoding with m swapped wrong cells.
    at i j = encode w (i + ro, j + co)

-------------------------------------------------------
-------------------------------------------------------
---- DECOMPOSITIONS

-- LU DECOMPOSITION

-- | Matrix LU decomposition with /partial pivoting/.
--   The result for a matrix /M/ is given in the format /(U,L,P,d)/ where:
--
--   * /U/ is an upper triangular matrix.
--
--   * /L/ is a /unit/ lower triangular matrix.
--
--   * /P/ is a permutation matrix.
--
--   * /d/ is the determinant of /P/.
--
--   * /PM = LU/.
--
--   These properties are only guaranteed when the input matrix is invertible.
--   The pivoting strategy also guarantees one more property:
--
--   * /|L_(i,j)| <= 1/, for all /i,j/.
--
--   This holds because each step picks the candidate pivot with the largest
--   absolute value. That choice also improves the numerical stability of
--   the algorithm.
--
--   Example:
--
-- >          ( 1 2 0 )     ( 2 0  2 )   (   1 0 0 )   ( 0 0 1 )
-- >          ( 0 2 1 )     ( 0 2 -1 )   ( 1/2 1 0 )   ( 1 0 0 )
-- > luDecomp ( 2 0 2 ) = ( ( 0 0  2 ) , (   0 1 1 ) , ( 0 1 0 ) , 1 )
--
--   'Nothing' is returned when a zero pivot is met, e.g. for a singular
--   square matrix.
luDecomp :: (Ord a, Fractional a) => Matrix a -> Maybe (Matrix a,Matrix a,Matrix a,a)
luDecomp a = recLUDecomp a ident ident 1 1 steps
  where
    -- L and P both start as the identity of the row count.
    ident = identity (nrows a)
    -- One elimination step per row/column of the leading square part.
    steps = min (nrows a) (ncols a)

-- | Worker for 'luDecomp': performs pivot step @k@ and recurses on @k + 1@.
--
--   At each step:
--
--   * the row among @k..n@ with the largest @|U(row,k)|@ becomes the pivot row;
--
--   * that row is swapped into row @k@ of /U/ and /P/, and into the already
--     filled columns @1..k-1@ of /L/;
--
--   * /d/ is negated if a swap happened;
--
--   * every entry below the pivot is eliminated, and the multipliers are
--     stored in column @k@ of /L/.
--
--   Returns 'Nothing' as soon as a zero pivot is encountered.
recLUDecomp ::  (Ord a, Fractional a)
            =>  Matrix a -- ^ U
            ->  Matrix a -- ^ L
            ->  Matrix a -- ^ P
            ->  a        -- ^ d
            ->  Int      -- ^ Current step (row/column index of the pivot)
            ->  Int      -- ^ Last step ('luDecomp' passes min of rows and cols)
            -> Maybe (Matrix a,Matrix a,Matrix a,a)
recLUDecomp u l p d k n
  | k > n      = Just (u, l, p, d)
  | pivot == 0 = Nothing
  | otherwise  = recLUDecomp uDone lDone pSwapped dSwapped (k + 1) n
  where
    -- Pivot row: largest absolute value in column k among rows k..n.
    best = maximumBy (comparing (\row -> abs (u ! (row, k)))) [k .. n]
    -- Move the pivot row into place.
    uSwapped = switchRows k best u
    pSwapped = switchRows k best p
    -- Each row swap flips the sign of the permutation determinant.
    dSwapped = if best == k then d else negate d
    -- Only columns 1..k-1 of L hold multipliers so far; swap just those
    -- entries of rows best and k, directly in L's backing vector.
    lSwapped
      | best == k = l
      | otherwise =
          let lw  = vcols l
              lro = rowOffset l
              lco = colOffset l
              at row col = encode lw (row + lro, col + lco)
          in  M (nrows l) (ncols l) lro lco lw $
                V.modify (\mv -> forM_ [1 .. k - 1] $ \col ->
                            MV.swap mv (at best col) (at k col))
                         (mvect l)
    pivot = uSwapped ! (k, k)
    -- Cancel every entry below the pivot, recording multipliers in L.
    (uDone, lDone) = eliminate uSwapped lSwapped (k + 1)
    eliminate uAcc lAcc row
      | row > nrows uAcc = (uAcc, lAcc)
      | otherwise =
          let factor = (uAcc ! (row, k)) / pivot
          in  eliminate (combineRows row (-factor) k uAcc)
                        (setElem factor (row, k) lAcc)
                        (row + 1)

-- | Unsafe version of 'luDecomp': returns the decomposition directly instead
--   of wrapping it in 'Maybe', and calls 'error' when 'luDecomp' yields
--   'Nothing' (which the error message attributes to a singular input).
luDecompUnsafe :: (Ord a, Fractional a) => Matrix a -> (Matrix a, Matrix a, Matrix a, a)
luDecompUnsafe m = fromMaybe (error "luDecompUnsafe of singular matrix.") (luDecomp m)

-- | Matrix LU decomposition with /complete pivoting/.
--   The result for a matrix /M/ is given in the format /(U,L,P,Q,d,e)/ where:
--
--   * /U/ is an upper triangular matrix.
--
--   * /L/ is a /unit/ lower triangular matrix.
--
--   * /P,Q/ are permutation matrices.
--
--   * /d,e/ are the determinants of /P/ and /Q/ respectively.
--
--   * /PMQ = LU/.
--
--   These properties are only guaranteed when the input matrix is invertible.
--   An additional property holds thanks to the strategy followed for pivoting:
--
--   * /abs (L_(i,j)) <= 1/, for all /i,j/.
--
--   This follows from the maximal property of the selected pivots (every
--   multiplier is an entry divided by the largest entry, in absolute value,
--   of the remaining submatrix), which also leads to a better numerical
--   stability of the algorithm.
--
--   Example:
--
-- >           ( 1 0 )     ( 2 1 )   (   1    0 0 )   ( 0 0 1 )
-- >           ( 0 2 )     ( 0 2 )   (   0    1 0 )   ( 0 1 0 )   ( 1 0 )
-- > luDecomp' ( 2 1 ) = ( ( 0 0 ) , ( 1/2 -1/4 1 ) , ( 1 0 0 ) , ( 0 1 ) , -1 , 1 )
--
--   The result is wrapped in 'Maybe' for symmetry with 'luDecomp'. Because
--   the pivot is the largest entry of the whole remaining submatrix, a zero
--   pivot means that submatrix is already zero and the decomposition simply
--   stops early, so in practice the result is always 'Just'.
luDecomp' :: (Ord a, Fractional a) => Matrix a -> Maybe (Matrix a,Matrix a,Matrix a,Matrix a,a,a)
luDecomp' a = recLUDecomp' a rowIdentity rowIdentity colIdentity 1 1 1 steps
 where
  -- L and P start as the identity of size (rows of a); Q as that of size
  -- (columns of a). Both permutation determinants start at 1.
  rowIdentity = identity (nrows a)
  colIdentity = identity (ncols a)
  -- One pivot step per row/column of the smaller dimension.
  steps = min (nrows a) (ncols a)

-- | Unsafe version of 'luDecomp'': returns the decomposition directly
--   instead of wrapping it in 'Maybe'.
--
--   Note: 'luDecomp'' pivots over the whole remaining submatrix, so a zero
--   pivot means the rest of /U/ is already zero and the worker returns the
--   current state as a 'Just', even for singular matrices. The 'error'
--   branch is therefore a safeguard that is not expected to trigger.
luDecompUnsafe' :: (Ord a, Fractional a) => Matrix a -> (Matrix a, Matrix a, Matrix a, Matrix a, a, a)
luDecompUnsafe' m = case luDecomp' m of
  Just x  -> x
  Nothing -> error "luDecompUnsafe' of singular matrix."

-- | Worker for 'luDecomp'': performs pivot step @k@ of an LU decomposition
--   with complete pivoting and recurses on step @k+1@ until @k > n@.
--
--   At step @k@ the entry of largest absolute value among rows and columns
--   @>= k@ of /U/ is moved to position @(k,k)@ by one row swap and one column
--   swap. The row swap is mirrored in /P/ and in the already-computed part of
--   /L/; the column swap is mirrored in /Q/. @d@ and @e@ change sign for each
--   swap that actually exchanges two distinct rows/columns. The entries below
--   the pivot are then eliminated and their multipliers stored in /L/.
--
--   Because the pivot is the maximum over the whole trailing submatrix, a
--   zero pivot means that submatrix is entirely zero: /U/ is already upper
--   triangular, so the current state is returned as is. This function
--   therefore always returns 'Just'; the 'Maybe' is kept for interface
--   compatibility.
recLUDecomp' ::  (Ord a, Fractional a)
            =>  Matrix a -- ^ U (partially reduced)
            ->  Matrix a -- ^ L (multipliers found so far)
            ->  Matrix a -- ^ P (accumulated row permutation)
            ->  Matrix a -- ^ Q (accumulated column permutation)
            ->  a        -- ^ d, determinant of P so far
            ->  a        -- ^ e, determinant of Q so far
            ->  Int      -- ^ Current pivot step k
            ->  Int      -- ^ Last pivot step (luDecomp' passes min rows cols)
            ->  Maybe (Matrix a,Matrix a,Matrix a,Matrix a,a,a)
recLUDecomp' u l p q d e k n
  -- All steps done, or the trailing submatrix is all zeros (see above).
  -- 'ukk' is only evaluated when k <= n, so the pivot search is safe.
  | k > n || ukk == 0 = Just (u,l,p,q,d,e)
  | otherwise         = recLUDecomp' u'' l'' p' q' d' e' (k+1) n
 where
  -- Pivot strategy: maximum value in absolute value below the current row & col.
  (i, j) = maximumBy (comparing (\(i0, j0) -> abs $ u ! (i0,j0)))
           [ (i0, j0) | i0 <- [k .. nrows u], j0 <- [k .. ncols u] ]
  -- Move the pivot to position (k,k).
  u'  = switchCols k j $ switchRows k i u
  -- Swapping rows k,i of L and then columns k,i moves only the multipliers
  -- already stored in columns < k, keeping the unit diagonal in place.
  l'0 = switchRows k i l
  l'  = switchCols k i l'0
  p'  = switchRows k i p
  q'  = switchCols k j q
  -- Permutation determinants: each effective swap flips the sign.
  d' = if i == k then d else negate d
  e' = if j == k then e else negate e
  -- Cancel elements below the pivot.
  (u'',l'') = go u' l' (k+1)
  ukk = u' ! (k,k)
  -- Eliminate row h and every row after it, recording each multiplier in L.
  go u_ l_ h
    | h > nrows u_ = (u_,l_)
    | otherwise    =
        let x = (u_ ! (h,k)) / ukk
        in  go (combineRows h (-x) k u_) (setElem x (h,k) l_) (h+1)

-- CHOLESKY DECOMPOSITION

-- | Simple Cholesky decomposition of a symmetric, positive definite matrix.
--   The result for a matrix /M/ is a lower triangular matrix /L/ such that:
--
--   * /M = LL^T/.
--
--   The factor is built recursively: the first column of /L/ is derived from
--   the first column of /M/, and the procedure repeats on the Schur
--   complement of the top-left entry. Only the lower-left block of /M/ is
--   read, so symmetry of the input is assumed rather than checked.
--
--   Example (entries truncated to two decimals):
--
-- >            (  2 -1  0 )   (  1.41  0     0    )
-- >            ( -1  2 -1 )   ( -0.70  1.22  0    )
-- > cholDecomp (  0 -1  2 ) = (  0.00 -0.81  1.15 )
cholDecomp :: (Floating a) => Matrix a -> Matrix a
cholDecomp a
  | isOneByOne = fmap sqrt a
  | otherwise  = joinBlocks (pivotBlock, upperZeros, column, cholDecomp schur)
  where
    isOneByOne = nrows a == 1 && ncols a == 1
    -- Split off the first row and column: a = [ topLeft topRight ; bottomLeft bottomRight ].
    (topLeft, topRight, bottomLeft, bottomRight) = splitBlocks 1 1 a
    pivot      = sqrt (topLeft ! (1,1))
    pivotBlock = fromList 1 1 [pivot]
    -- L is lower triangular, so everything right of the pivot is zero.
    upperZeros = zero (nrows topRight) (ncols topRight)
    column     = scaleMatrix (1/pivot) bottomLeft
    -- Schur complement, factorised recursively.
    schur      = bottomRight - multStd column (transpose column)

-------------------------------------------------------
-------------------------------------------------------
---- PROPERTIES

-- Rewrite rules for 'trace': the trace of a sum / scaled matrix is computed
-- from the traces of the operands, avoiding building the full intermediate
-- matrix.
-- NOTE(review): '+' here is the 'Num' class method for 'Matrix'; GHC may
-- inline it before the rule can match — confirm with -ddump-rule-firings.
-- NOTE(review): for floating-point elements the rewritten expressions sum
-- and multiply in a different order, so results may differ in rounding.
{-# RULES
"matrix/traceOfSum"
    forall a b. trace (a + b) = trace a + trace b

"matrix/traceOfScale"
    forall k a. trace (scaleMatrix k a) = k * trace a
  #-}

-- | Trace of a matrix: the sum of the entries returned by 'getDiag'.
--   Example:
--
-- >       ( 1 2 3 )
-- >       ( 4 5 6 )
-- > trace ( 7 8 9 ) = 15
trace :: Num a => Matrix a -> a
trace m = V.sum (getDiag m)

-- | Product of the entries returned by 'getDiag'.
--   Example:
--
-- >          ( 1 2 3 )
-- >          ( 4 5 6 )
-- > diagProd ( 7 8 9 ) = 45
diagProd :: Num a => Matrix a -> a
diagProd m = V.product (getDiag m)

-- DETERMINANT

-- Rewrite rules using the multiplicativity of the determinant,
-- det(AB) = det(A) det(B).
-- NOTE(review): that identity only holds when A and B are both square. If
-- A*B is square but A and B are not (e.g. n x 1 times 1 x n), these rules
-- would replace a well-defined determinant with determinants of non-square
-- matrices. Also, '*' is the 'Num' class method for 'Matrix', so the rules
-- may never fire in practice — confirm with -ddump-rule-firings, and
-- consider removing them.
{-# RULES
"matrix/detLaplaceProduct"
    forall a b. detLaplace (a*b) = detLaplace a * detLaplace b

"matrix/detLUProduct"
    forall a b. detLU (a*b) = detLU a * detLU b
  #-}

-- | Matrix determinant using Laplace (cofactor) expansion along the first
--   column.
--   If the elements of the 'Matrix' are instances of 'Ord' and 'Fractional',
--   consider using 'detLU' in order to obtain better performance.
--   Function 'detLaplace' is /extremely/ slow: every call recurses once per
--   row, giving factorial running time.
--
--   The determinant of a 0x0 matrix is 1 (the empty-product convention).
detLaplace :: Num a => Matrix a -> a
detLaplace m
  -- Previously this reached 'foldl1'' on an empty list and crashed.
  | r == 0 && c == 0 = 1
  | r == 1 && c == 1 = m ! (1,1)
  | otherwise        = foldl1' (+) [ term i | i <- [1 .. r] ]
  where
    r = nrows m
    c = ncols m
    -- Entry (i,1) times its cofactor: alternating sign times the minor.
    term i = (-1)^(i-1) * m ! (i,1) * detLaplace (minorMatrix i 1 m)

-- | Matrix determinant using LU decomposition.
--   It works even when the input matrix is singular: when 'luDecomp' yields
--   'Nothing' the determinant is reported as 0.
detLU :: (Ord a, Fractional a) => Matrix a -> a
detLU m = maybe 0 fromLU (luDecomp m)
  where
    -- Determinant = permutation sign d times the product of U's diagonal.
    fromLU (u, _, _, d) = d * diagProd u