module Data.Eigen.Matrix (
Matrix(..),
fromList,
toList,
generate,
empty,
zero,
ones,
identity,
constant,
cols,
rows,
coeff,
minCoeff,
maxCoeff,
col,
row,
block,
topRows,
bottomRows,
leftCols,
rightCols,
norm,
squaredNorm,
determinant,
add,
sub,
mul,
inverse,
adjoint,
conjugate,
transpose,
normalize,
modify,
thaw,
freeze,
unsafeThaw,
unsafeFreeze
) where
import Data.List (intercalate)
import Data.Tuple
import Foreign.Ptr
import Foreign.C.Types
import Foreign.C.String
import Control.Monad
import Control.Monad.ST
import Control.Monad.Primitive
import Control.Applicative hiding (empty)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Data.Eigen.Internal
import Data.Eigen.Matrix.Mutable
-- | A dense matrix of doubles.
--
-- Coefficients are stored in column-major order: the element at
-- (row @r@, column @c@) lives at index @c * m_rows + r@ of 'm_vals'.
data Matrix = Matrix
    { m_rows :: Int                -- ^ number of rows
    , m_cols :: Int                -- ^ number of columns
    , m_vals :: VS.Vector CDouble  -- ^ coefficients, column-major
    }
-- | Renders a header line with the dimensions, followed by one line per row
-- with the coefficients separated by tabs, and a trailing newline.
instance Show Matrix where
    show m = header ++ "\n" ++ body ++ "\n"
      where
        header = "Matrix " ++ show (m_rows m) ++ "x" ++ show (m_cols m)
        body = intercalate "\n" [intercalate "\t" (map show r) | r <- toList m]
-- | Matrix arithmetic delegated to the Eigen-backed 'add', 'sub' and 'mul'.
--
-- 'fromInteger', 'signum' and 'abs' are not supported and raise a
-- descriptive error rather than a bare 'undefined'.
instance Num Matrix where
    (*) = mul
    (+) = add
    (-) = sub
    -- Element-wise negation. It is defined explicitly because the class
    -- default (@negate x = 0 - x@) goes through 'fromInteger', which is
    -- unsupported here.
    negate m = m { m_vals = VS.map negate (m_vals m) }
    fromInteger _ = error "Data.Eigen.Matrix: fromInteger is not supported for Matrix"
    signum _ = error "Data.Eigen.Matrix: signum is not supported for Matrix"
    abs _ = error "Data.Eigen.Matrix: abs is not supported for Matrix"
-- | The 0x0 matrix with no coefficients.
empty :: Matrix
empty = Matrix { m_rows = 0, m_cols = 0, m_vals = VS.empty }
-- | A matrix of the given number of rows and columns where every
-- coefficient equals the supplied value.
constant :: Int -> Int -> Double -> Matrix
constant r c x = Matrix r c (VS.replicate (r * c) (cast x))
-- | An all-zero matrix with the given number of rows and columns.
zero :: Int -> Int -> Matrix
zero r c = constant r c 0
-- | A matrix with the given number of rows and columns where every
-- coefficient is one.
ones :: Int -> Int -> Matrix
ones r c = constant r c 1
-- | The @size x size@ identity matrix: ones on the main diagonal and zeros
-- everywhere else.
identity :: Int -> Matrix
identity size = Matrix size size $ VS.create $ do
    buf <- VSM.replicate (size * size) 0
    -- In column-major storage, diagonal entry (i, i) sits at offset
    -- i * size + i, which equals i * (size + 1).
    mapM_ (\i -> VSM.write buf (i * (size + 1)) 1) [0 .. pred size]
    return buf
-- | Number of rows of a matrix.
rows :: Matrix -> Int
rows m = m_rows m
-- | Number of columns of a matrix.
cols :: Matrix -> Int
cols m = m_cols m
-- | The coefficient at the given zero-based (row, column).
--
-- Both indices are checked against the matrix shape. Storage is a flat
-- column-major vector, so without this check an out-of-range row index
-- would silently return an element of a neighbouring column.
coeff :: Int -> Int -> Matrix -> Double
coeff r c Matrix{..}
    | r < 0 || r >= m_rows || c < 0 || c >= m_cols =
        error $ "coeff: index (" ++ show r ++ "," ++ show c
            ++ ") out of bounds for " ++ show m_rows ++ "x" ++ show m_cols
            ++ " matrix"
    | otherwise = cast $ m_vals VS.! (c * m_rows + r)
-- | The coefficients of column @c@, listed from the top row down.
col :: Int -> Matrix -> [Double]
col c m = map (\r -> coeff r c m) [0 .. pred (m_rows m)]
-- | The coefficients of row @r@, listed from the leftmost column rightwards.
row :: Int -> Matrix -> [Double]
row r m = map (\c -> coeff r c m) [0 .. pred (m_cols m)]
-- | Extract the @blockRows x blockCols@ sub-matrix whose top-left corner is
-- at (@startRow@, @startCol@) in the source matrix.
--
-- The requested region must lie entirely inside the source matrix, matching
-- Eigen's @block()@ precondition. An out-of-range request raises an error
-- instead of reading coefficients from the wrong place.
block :: Int -> Int -> Int -> Int -> Matrix -> Matrix
block startRow startCol blockRows blockCols m@Matrix{..}
    | outOfRange =
        error $ "block: " ++ show blockRows ++ "x" ++ show blockCols
            ++ " block at (" ++ show startRow ++ "," ++ show startCol
            ++ ") does not fit in " ++ show m_rows ++ "x" ++ show m_cols
            ++ " matrix"
    | otherwise =
        generate blockRows blockCols $ \r c ->
            coeff (startRow + r) (startCol + c) m
  where
    outOfRange =
        startRow < 0 || startCol < 0 || blockRows < 0 || blockCols < 0
            || startRow + blockRows > m_rows
            || startCol + blockCols > m_cols
-- | The largest coefficient. This fails on an empty matrix, because the
-- underlying vector maximum is partial.
maxCoeff :: Matrix -> Double
maxCoeff = cast . VS.maximum . m_vals
-- | The smallest coefficient. This fails on an empty matrix, because the
-- underlying vector minimum is partial.
minCoeff :: Matrix -> Double
minCoeff m = cast (VS.minimum (m_vals m))
-- | The first @n@ rows of a matrix, keeping all of its columns.
topRows :: Int -> Matrix -> Matrix
topRows n m = block 0 0 n (m_cols m) m
-- | The last @n@ rows of a matrix, keeping all of its columns.
--
-- The block starts at row @m_rows - n@, so it ends exactly at the last row.
bottomRows :: Int -> Matrix -> Matrix
bottomRows n m@Matrix{..} = block (m_rows - n) 0 n m_cols m
-- | The first @n@ columns of a matrix, keeping all of its rows.
leftCols :: Int -> Matrix -> Matrix
leftCols n m = block 0 0 (m_rows m) n m
-- | The last @n@ columns of a matrix, keeping all of its rows.
--
-- The block starts at column @m_cols - n@, so it ends exactly at the last
-- column.
rightCols :: Int -> Matrix -> Matrix
rightCols n m@Matrix{..} = block 0 (m_cols - n) m_rows n m
-- | Build a matrix from a list of rows.
--
-- The column count is the length of the longest row, and shorter rows are
-- padded with zeros. An empty list yields a 0x0 matrix instead of crashing.
fromList :: [[Double]] -> Matrix
fromList list = Matrix nRows nCols vals
  where
    nRows = length list
    -- Seeding with 0 keeps 'maximum' total when the outer list is empty.
    nCols = maximum (0 : map length list)
    vals = VS.create $ do
        -- Start from zeros so that missing cells in short rows stay 0.
        vm <- VSM.replicate (nRows * nCols) 0
        forM_ (zip [0 ..] list) $ \(r, rowVals) ->
            forM_ (zip [0 ..] rowVals) $ \(c, x) ->
                -- column-major layout: (r, c) -> c * nRows + r
                VSM.write vm (c * nRows + r) (cast x)
        return vm
-- | Convert a matrix to a list of rows, each row listed left to right.
-- This is the inverse of 'fromList' for rectangular input.
toList :: Matrix -> [[Double]]
toList Matrix{..} = map rowAt [0 .. pred m_rows]
  where
    rowAt :: Int -> [Double]
    rowAt r = [cast (m_vals VS.! (c * m_rows + r)) | c <- [0 .. pred m_cols]]
-- | Build a matrix with the given number of rows and columns. The
-- coefficient at (row, column) is @f row column@.
generate :: Int -> Int -> (Int -> Int -> Double) -> Matrix
generate nRows nCols f = Matrix nRows nCols $ VS.create $ do
    -- 'VSM.new' leaves memory uninitialised. That is fine here because the
    -- writes below cover every (r, c) in range.
    buf <- VSM.new (nRows * nCols)
    sequence_
        [ VSM.write buf (c * nRows + r) (cast (f r c))
        | r <- [0 .. pred nRows]
        , c <- [0 .. pred nCols]
        ]
    return buf
-- | Norm of the matrix, computed by the C helper 'c_norm'. This is
-- presumably Eigen's @norm()@, i.e. the Frobenius norm; TODO confirm.
norm :: Matrix -> Double
norm m = _unop c_norm m
-- | Squared norm of the matrix, computed by the C helper 'c_squaredNorm'.
-- This is presumably Eigen's @squaredNorm()@; TODO confirm.
squaredNorm :: Matrix -> Double
squaredNorm m = _unop c_squaredNorm m
-- | Determinant of a square matrix, computed by the C helper
-- 'c_determinant'. Calling it on a non-square matrix is an error.
determinant :: Matrix -> Double
determinant m
    | m_rows m /= m_cols m = error "you tried calling determinant on non-square matrix"
    | otherwise = _unop c_determinant m
-- | Sum of two matrices, computed by the C helper 'c_add'. Any shape
-- mismatch is left to the C side to report (via 'call').
add :: Matrix -> Matrix -> Matrix
add lhs rhs = _binop c_add lhs rhs
-- | Difference of two matrices (@lhs - rhs@), computed by the C helper
-- 'c_sub'. Any shape mismatch is left to the C side to report.
sub :: Matrix -> Matrix -> Matrix
sub lhs rhs = _binop c_sub lhs rhs
-- | Matrix product @lhs * rhs@, computed by the C helper 'c_mul'.
--
-- NOTE(review): the result reuses the left operand's shape and buffer
-- (see '_binop'), so a product whose shape differs from @lhs@ (e.g.
-- 2x3 * 3x2) cannot be represented. Presumably 'c_mul' rejects such
-- inputs or assumes square operands; confirm.
mul :: Matrix -> Matrix -> Matrix
mul lhs rhs = _binop c_mul lhs rhs
-- | Inverse of a square matrix, computed by the C helper 'c_inverse' into a
-- fresh buffer of the same shape. Calling it on a non-square matrix is an
-- error.
inverse :: Matrix -> Matrix
inverse m
    | m_rows m /= m_cols m = error "you tried calling inverse on non-square matrix"
    | otherwise = _modify id c_inverse m
-- | Adjoint (conjugate transpose), computed by the C helper 'c_adjoint'.
-- The result has the row and column counts swapped.
adjoint :: Matrix -> Matrix
adjoint = _modify (\(r, c) -> (c, r)) c_adjoint
-- | Transpose, computed by the C helper 'c_transpose'. The result has the
-- row and column counts swapped.
transpose :: Matrix -> Matrix
transpose = _modify (\(r, c) -> (c, r)) c_transpose
-- | Coefficient-wise conjugate, computed by the C helper 'c_conjugate'
-- into a fresh buffer of the same shape.
conjugate :: Matrix -> Matrix
conjugate m = _modify id c_conjugate m
-- | Normalised copy of the matrix, produced by the C helper 'c_normalize'.
-- This presumably scales to unit norm, as Eigen's @normalize()@ does;
-- TODO confirm.
--
-- The input is copied first, so the original matrix is never mutated.
normalize :: Matrix -> Matrix
normalize Matrix{..} = performIO $ do
    buf <- VS.thaw m_vals
    VSM.unsafeWith buf $ \ptr ->
        call (c_normalize ptr (cast m_rows) (cast m_cols))
    frozen <- VS.unsafeFreeze buf
    return (Matrix m_rows m_cols frozen)
-- | Apply a destructive 'ST' action to a mutable view of the matrix and
-- return the result. The shape is kept. The input is not mutated, because
-- 'VS.modify' works on a copy whenever it cannot do the update safely in
-- place.
modify :: (forall s. MMatrix s -> ST s ()) -> Matrix -> Matrix
modify f m@Matrix{..} =
    m { m_vals = VS.modify (\v -> f (MMatrix m_rows m_cols v)) m_vals }
-- | Snapshot a mutable matrix into an immutable one by copying its buffer.
freeze :: PrimMonad m => MMatrix (PrimState m) -> m Matrix
freeze MMatrix{..} = Matrix mm_rows mm_cols <$> VS.freeze mm_vals
-- | Make a mutable copy of an immutable matrix.
thaw :: PrimMonad m => Matrix -> m (MMatrix (PrimState m))
thaw Matrix{..} = MMatrix m_rows m_cols <$> VS.thaw m_vals
-- | Convert a mutable matrix to an immutable one without copying, using
-- 'VS.unsafeFreeze'. The mutable matrix must not be written to afterwards.
unsafeFreeze :: PrimMonad m => MMatrix (PrimState m) -> m Matrix
unsafeFreeze MMatrix{..} = Matrix mm_rows mm_cols <$> VS.unsafeFreeze mm_vals
-- | Convert an immutable matrix to a mutable one without copying, using
-- 'VS.unsafeThaw'. The original immutable matrix must not be used afterwards.
unsafeThaw :: PrimMonad m => Matrix -> m (MMatrix (PrimState m))
unsafeThaw Matrix{..} = MMatrix m_rows m_cols <$> VS.unsafeThaw m_vals
-- | Run a C reduction over a matrix's storage and convert the result.
--
-- The callback receives a pointer to the coefficients plus the row and
-- column counts, and returns a single 'CDouble'.
_unop :: (Ptr CDouble -> CInt -> CInt -> IO CDouble) -> Matrix -> Double
_unop f Matrix{..} = performIO $
    VS.unsafeWith m_vals $ \ptr -> do
        result <- f ptr (cast m_rows) (cast m_cols)
        return (cast result)
-- | Run a binary C routine on two matrices.
--
-- The left operand's coefficients are copied into a fresh buffer, which is
-- passed to the C routine as its first pointer. After the call that buffer
-- is frozen and returned as the result. The result therefore always has the
-- left operand's shape. The C routine's returned 'CString' is handled by
-- 'call', presumably to report errors; confirm in Data.Eigen.Internal.
_binop :: (Ptr CDouble -> CInt -> CInt -> Ptr CDouble -> CInt -> CInt -> IO CString) -> Matrix -> Matrix -> Matrix
_binop f m1 m2 = performIO $ do
    out <- VS.thaw (m_vals m1)
    VSM.unsafeWith out $ \lhs ->
        VS.unsafeWith (m_vals m2) $ \rhs ->
            call (f lhs (cast (m_rows m1)) (cast (m_cols m1))
                    rhs (cast (m_rows m2)) (cast (m_cols m2)))
    fmap (Matrix (m_rows m1) (m_cols m1)) (VS.unsafeFreeze out)
-- | Run a C routine that writes a transformed copy of a matrix into a new
-- buffer.
--
-- The first argument maps the source (rows, cols) to the destination shape,
-- e.g. 'id' or a swap for transposition. The C routine receives the
-- destination pointer and shape followed by the source pointer and shape.
--
-- NOTE(review): 'VSM.new' does not initialise memory, so this relies on the
-- C routine writing every destination element; confirm.
_modify :: ((Int,Int) -> (Int,Int)) -> (Ptr CDouble -> CInt -> CInt -> Ptr CDouble -> CInt -> CInt -> IO CString) -> Matrix -> Matrix
_modify reshape g Matrix{..} = performIO $ do
    let (outRows, outCols) = reshape (m_rows, m_cols)
    out <- VSM.new (outRows * outCols)
    VS.unsafeWith m_vals $ \src ->
        VSM.unsafeWith out $ \dst ->
            call (g dst (cast outRows) (cast outCols)
                    src (cast m_rows) (cast m_cols))
    fmap (Matrix outRows outCols) (VS.unsafeFreeze out)