{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Eigen.Matrix (
) where
import qualified Prelude as P
import qualified Data.List as L
import Prelude hiding (null, sum, all, any, map, filter)
import Data.Tuple
import Data.Complex hiding (conjugate)
import Data.Binary hiding (encode, decode)
import qualified Data.Binary as B
import Foreign.Ptr
import Foreign.C.Types
import Foreign.C.String
import Foreign.Storable
import Foreign.Marshal.Alloc
import Text.Printf
import Control.Monad
import Control.Monad.ST
import Control.Monad.Primitive
#if __GLASGOW_HASKELL__ >= 710
import Control.Applicative hiding (empty)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import qualified Data.Eigen.Internal as I
import qualified Data.Eigen.Matrix.Mutable as M
import qualified Data.ByteString.Lazy as BSL
-- | Dense matrix whose Haskell-side element type is @a@ and whose storage
-- element type is @b@.
--
-- The single constructor carries, in order: the number of rows, the number of
-- columns and a flat storable vector of values. All three fields are strict.
-- Elsewhere in this module the element at @(row, col)@ is read from index
-- @col * rows + row@ of that vector.
data Matrix a b where
    Matrix :: I.Elem a b => !Int -> !Int -> !(VS.Vector b) -> Matrix a b
-- | Matrix of 'Float' elements stored as 'CFloat'.
type MatrixXf = Matrix Float CFloat
-- | Matrix of 'Double' elements stored as 'CDouble'.
type MatrixXd = Matrix Double CDouble
-- | Matrix of @'Complex' 'Float'@ elements stored as @'I.CComplex' 'CFloat'@.
type MatrixXcf = Matrix (Complex Float) (I.CComplex CFloat)
-- | Matrix of @'Complex' 'Double'@ elements stored as @'I.CComplex' 'CDouble'@.
type MatrixXcd = Matrix (Complex Double) (I.CComplex CDouble)
-- | Renders a header line @Matrix RxC@ followed by one line per row, with the
-- row's elements separated by tabs, and a trailing newline.
instance (I.Elem a b, Show a) => Show (Matrix a b) where
    show m@(Matrix nrows ncols _) =
        "Matrix " ++ show nrows ++ "x" ++ show ncols ++ "\n" ++ body ++ "\n"
      where
        body = L.intercalate "\n" [L.intercalate "\t" [show x | x <- xs] | xs <- toList m]
-- | Arithmetic on matrices.
--
-- '*' is 'mul' (matrix product, not element-wise), '+' and '-' are 'add' and
-- 'sub'. 'fromInteger' builds a 1x1 matrix; 'signum', 'abs' and 'negate' are
-- applied element by element.
instance I.Elem a b => Num (Matrix a b) where
    lhs * rhs = mul lhs rhs
    lhs + rhs = add lhs rhs
    lhs - rhs = sub lhs rhs
    fromInteger n = constant 1 1 (fromInteger n)
    signum m = map signum m
    abs m = map abs m
    negate m = map negate m
-- | Binary serialization.
--
-- Layout: the type tag @I.magicCode@ of the storage type, then the row count,
-- the column count and the value vector.
--
-- Decoding fails when the tag does not match the requested element type, or
-- when the decoded dimensions are inconsistent with the number of stored
-- values (see 'valid'), so an ill-formed payload can never yield a 'Matrix'
-- whose dimensions disagree with its storage.
instance I.Elem a b => Binary (Matrix a b) where
    put (Matrix nrows ncols vals) = do
        put $ I.magicCode (undefined :: b)
        put nrows
        put ncols
        put vals

    get = do
        code <- get
        when (code /= I.magicCode (undefined :: b)) $ fail "wrong matrix type"
        m <- Matrix <$> get <*> get <*> get
        -- Reject payloads whose header does not describe the stored vector.
        unless (valid m) $ fail "Matrix.get: dimensions do not match the number of stored values"
        return m
-- | Serialize a matrix to a lazy 'BSL.ByteString' using its 'Binary' instance.
encode :: I.Elem a b => Matrix a b -> BSL.ByteString
encode m = B.encode m
-- | Deserialize a matrix from a lazy 'BSL.ByteString' using its 'Binary' instance.
decode :: I.Elem a b => BSL.ByteString -> Matrix a b
decode bytes = B.decode bytes
-- | The 0x0 matrix with no stored values.
empty :: I.Elem a b => Matrix a b
empty = Matrix noRows noCols VS.empty
  where
    noRows = 0
    noCols = 0
{-# INLINE empty #-}
-- | 'True' exactly when both the row count and the column count are zero.
null :: I.Elem a b => Matrix a b -> Bool
null (Matrix 0 0 _) = True
null _ = False
{-# INLINE null #-}
-- | 'True' when the matrix has as many rows as columns.
square :: I.Elem a b => Matrix a b -> Bool
square (Matrix nrows ncols _) = nrows == ncols
{-# INLINE square #-}
-- | Matrix of the given size in which every element equals the given value.
constant :: I.Elem a b => Int -> Int -> a -> Matrix a b
constant nrows ncols val = Matrix nrows ncols filled
  where
    filled = VS.replicate (nrows * ncols) (I.cast val)
{-# INLINE constant #-}
-- | Matrix of the given size filled with @0@.
zero :: I.Elem a b => Int -> Int -> Matrix a b
zero nrows ncols = constant nrows ncols 0
{-# INLINE zero #-}
-- | Matrix of the given size filled with @1@.
ones :: I.Elem a b => Int -> Int -> Matrix a b
ones nrows ncols = constant nrows ncols 1
{-# INLINE ones #-}
-- | Matrix of the given size filled in by the backend routine @I.identity@.
--
-- A fresh mutable matrix is allocated, handed to the backend and then frozen
-- with 'unsafeFreeze'; the mutable matrix never escapes this function.
identity :: I.Elem a b => Int -> Int -> Matrix a b
identity nrows ncols = I.performIO $
    M.new nrows ncols >>= \mm -> do
        I.call (M.unsafeWith mm I.identity)
        unsafeFreeze mm
-- | Matrix of the given size filled in by the backend routine @I.random@.
--
-- Runs in 'IO'; a fresh mutable matrix is allocated, handed to the backend
-- and then frozen with 'unsafeFreeze'.
random :: I.Elem a b => Int -> Int -> IO (Matrix a b)
random nrows ncols =
    M.new nrows ncols >>= \mm -> do
        I.call (M.unsafeWith mm I.random)
        unsafeFreeze mm
-- | Number of rows of the matrix.
rows :: I.Elem a b => Matrix a b -> Int
rows (Matrix nrows _ _) = nrows
{-# INLINE rows #-}
-- | Number of columns of the matrix.
cols :: I.Elem a b => Matrix a b -> Int
cols (Matrix _ ncols _) = ncols
{-# INLINE cols #-}
-- | Matrix size as a @(rows, cols)@ pair.
dims :: I.Elem a b => Matrix a b -> (Int, Int)
dims (Matrix nrows ncols _) = (nrows, ncols)
{-# INLINE dims #-}
-- | Bounds-checked element access by @(row, col)@ pair; operator form of 'coeff'.
(!) :: I.Elem a b => Matrix a b -> (Int, Int) -> a
m ! (r, c) = coeff r c m
{-# INLINE (!) #-}
-- | Bounds-checked element access at the given row and column.
--
-- Calls 'error' when the matrix fails 'valid' or when either index lies
-- outside @[0..size)@; otherwise defers to 'unsafeCoeff'.
coeff :: I.Elem a b => Int -> Int -> Matrix a b -> a
coeff r c m@(Matrix nrows ncols _)
    | invalid = error "matrix is not valid"
    | outside r nrows = error (printf "Matrix.coeff: row %d is out of bounds [0..%d)" r nrows)
    | outside c ncols = error (printf "Matrix.coeff: col %d is out of bounds [0..%d)" c ncols)
    | otherwise = unsafeCoeff r c m
  where
    invalid = not (valid m)
    -- An index is acceptable only inside the half-open range [0..limit).
    outside :: Int -> Int -> Bool
    outside i limit = i < 0 || i >= limit
{-# INLINE coeff #-}
-- | Element access without any validity or bounds check.
--
-- Reads storage index @col * rows + row@ with 'VS.unsafeIndex'.
unsafeCoeff :: I.Elem a b => Int -> Int -> Matrix a b -> a
unsafeCoeff r c (Matrix nrows _ vals) = I.cast (VS.unsafeIndex vals (c * nrows + r))
{-# INLINE unsafeCoeff #-}
-- | Elements of the given column, top to bottom, read through 'coeff'.
col :: I.Elem a b => Int -> Matrix a b -> [a]
col c m@(Matrix nrows _ _) = P.map (\r -> coeff r c m) [0 .. pred nrows]
{-# INLINE col #-}
-- | Elements of the given row, left to right, read through 'coeff'.
row :: I.Elem a b => Int -> Matrix a b -> [a]
row r m@(Matrix _ ncols _) = P.map (\c -> coeff r c m) [0 .. pred ncols]
{-# INLINE row #-}
-- | Extract a rectangular sub-matrix.
--
-- Arguments: first row, first column, number of rows and number of columns of
-- the block, then the source matrix. Every element is fetched with the
-- bounds-checked 'coeff'.
block :: I.Elem a b => Int -> Int -> Int -> Int -> Matrix a b -> Matrix a b
block startRow startCol blockRows blockCols m = generate blockRows blockCols pick
  where
    pick r c = coeff (startRow + r) (startCol + c) m
-- | Consistency check: both dimensions are non-negative and the storage
-- vector holds exactly @rows * cols@ values.
valid :: I.Elem a b => Matrix a b -> Bool
valid (Matrix nrows ncols vals) =
    and [nrows >= 0, ncols >= 0, VS.length vals == nrows * ncols]
{-# INLINE valid #-}
-- | Largest element of the matrix, found with a strict 'fold1'' over 'max'.
maxCoeff :: (I.Elem a b, Ord a) => Matrix a b -> a
maxCoeff m = fold1' max m
{-# INLINE maxCoeff #-}
-- | Smallest element of the matrix, found with a strict 'fold1'' over 'min'.
minCoeff :: (I.Elem a b, Ord a) => Matrix a b -> a
minCoeff m = fold1' min m
{-# INLINE minCoeff #-}
-- | The first @n@ rows of the matrix, as a 'block' spanning all columns.
topRows :: I.Elem a b => Int -> Matrix a b -> Matrix a b
topRows n m@(Matrix _ ncols _) = block firstRow firstCol n ncols m
  where
    firstRow = 0
    firstCol = 0
{-# INLINE topRows #-}
-- | The last @n@ rows of the matrix, as a 'block' spanning all columns.
bottomRows :: I.Elem a b => Int -> Matrix a b -> Matrix a b
bottomRows n m@(Matrix nrows ncols _) = block firstRow 0 n ncols m
  where
    firstRow = nrows - n
{-# INLINE bottomRows #-}
-- | The first @n@ columns of the matrix, as a 'block' spanning all rows.
leftCols :: I.Elem a b => Int -> Matrix a b -> Matrix a b
leftCols n m@(Matrix nrows _ _) = block firstRow firstCol nrows n m
  where
    firstRow = 0
    firstCol = 0
{-# INLINE leftCols #-}
-- | The last @n@ columns of the matrix, as a 'block' spanning all rows.
rightCols :: I.Elem a b => Int -> Matrix a b -> Matrix a b
rightCols n m@(Matrix nrows ncols _) = block 0 firstCol nrows n m
  where
    firstCol = ncols - n
{-# INLINE rightCols #-}
-- | Build a matrix from a list of rows.
--
-- The row count is the length of the outer list and the column count is the
-- length of the longest inner list; positions missing from shorter rows are
-- left at @0@.
fromList :: I.Elem a b => [[a]] -> Matrix a b
fromList list = Matrix nrows ncols storage
  where
    nrows = length list
    ncols = L.foldl' max 0 (P.map length list)
    storage = VS.create $ do
        -- 'head' is never forced here: 'asTypeOf' only borrows the element type.
        buf <- VSM.replicate (nrows * ncols) (I.cast (0 `asTypeOf` (head (head list))))
        forM_ [(c * nrows + r, x) | (r, xs) <- zip [0 ..] list, (c, x) <- zip [0 ..] xs] $ \(i, x) ->
            VSM.write buf i (I.cast x)
        return buf
-- | Convert a matrix to a list of rows, each row listed left to right.
--
-- Calls 'error' when the matrix fails 'valid'.
toList :: I.Elem a b => Matrix a b -> [[a]]
toList m@(Matrix nrows ncols vals)
    | valid m =
        P.map (\r -> P.map (\c -> I.cast (VS.unsafeIndex vals (c * nrows + r))) colIxs) rowIxs
    | otherwise = error "matrix is not valid"
  where
    rowIxs = [0 .. pred nrows]
    colIxs = [0 .. pred ncols]
-- | Build a matrix of the given size from a flat list given row by row.
--
-- Calls 'error' when the list length differs from @rows * cols@.
fromFlatList :: I.Elem a b => Int -> Int -> [a] -> Matrix a b
fromFlatList nrows ncols list
    | nrows * ncols /= length list =
        error ("cannot construct " ++ show nrows ++ "x" ++ show ncols ++ " matrix from " ++ show (length list) ++ " values")
    | otherwise = Matrix nrows ncols storage
  where
    -- Storage offsets visited in the order the input list is laid out (row by row).
    offsets = [c * nrows + r | r <- [0 .. pred nrows], c <- [0 .. pred ncols]]
    storage = VS.create $ do
        -- 'head' is never forced here: 'asTypeOf' only borrows the element type.
        buf <- VSM.replicate (nrows * ncols) (I.cast (0 `asTypeOf` (head list)))
        zipWithM_ (\i x -> VSM.write buf i (I.cast x)) offsets list
        return buf
-- | Convert a matrix to a flat list, row by row.
--
-- Calls 'error' when the matrix fails 'valid'.
toFlatList :: I.Elem a b => Matrix a b -> [a]
toFlatList m@(Matrix nrows ncols vals)
    | valid m =
        concatMap (\r -> P.map (\c -> I.cast (VS.unsafeIndex vals (c * nrows + r))) colIxs) rowIxs
    | otherwise = error "matrix is not valid"
  where
    rowIxs = [0 .. pred nrows]
    colIxs = [0 .. pred ncols]
-- | Build a matrix of the given size by calling the function with each
-- @row@ and @col@ index and storing its result at that position.
generate :: I.Elem a b => Int -> Int -> (Int -> Int -> a) -> Matrix a b
generate nrows ncols f = Matrix nrows ncols $ VS.create $ do
    buf <- VSM.new (nrows * ncols)
    forM_ [(r, c) | r <- [0 .. pred nrows], c <- [0 .. pred ncols]] $ \(r, c) ->
        VSM.write buf (c * nrows + r) (I.cast (f r c))
    return buf
-- | Scalar computed from the whole matrix by the backend routine @I.sum@.
sum :: I.Elem a b => Matrix a b -> a
sum m = _prop I.sum m
-- | Scalar computed from the whole matrix by the backend routine @I.prod@.
prod :: I.Elem a b => Matrix a b -> a
prod m = _prop I.prod m
-- | Scalar computed from the whole matrix by the backend routine @I.mean@.
mean :: I.Elem a b => Matrix a b -> a
mean m = _prop I.mean m
-- | Scalar computed from the whole matrix by the backend routine @I.trace@.
trace :: I.Elem a b => Matrix a b -> a
trace m = _prop I.trace m
-- | 'True' when the predicate holds for every stored element.
all :: I.Elem a b => (a -> Bool) -> Matrix a b -> Bool
all f m = VS.all (\x -> f (I.cast x)) (_vals m)
-- | 'True' when the predicate holds for at least one stored element.
any :: I.Elem a b => (a -> Bool) -> Matrix a b -> Bool
any f m = VS.any (\x -> f (I.cast x)) (_vals m)
-- | Number of stored elements that satisfy the predicate.
count :: I.Elem a b => (a -> Bool) -> Matrix a b -> Int
count f m = VS.foldl' tally 0 (_vals m)
  where
    tally acc x = if f (I.cast x) then succ acc else acc
-- | Scalar computed from the whole matrix by the backend routine @I.norm@.
norm :: I.Elem a b => Matrix a b -> a
norm m = _prop I.norm m
-- | Scalar computed from the whole matrix by the backend routine @I.squaredNorm@.
squaredNorm :: I.Elem a b => Matrix a b -> a
squaredNorm m = _prop I.squaredNorm m
-- | Scalar computed from the whole matrix by the backend routine @I.blueNorm@.
blueNorm :: I.Elem a b => Matrix a b -> a
blueNorm m = _prop I.blueNorm m
-- | Scalar computed from the whole matrix by the backend routine @I.hypotNorm@.
hypotNorm :: I.Elem a b => Matrix a b -> a
hypotNorm m = _prop I.hypotNorm m
-- | Determinant of a square matrix, computed by the backend routine
-- @I.determinant@. Calls 'error' for a non-square matrix.
determinant :: I.Elem a b => Matrix a b -> a
determinant m
    | not (square m) = error "Matrix.determinant: non-square matrix"
    | otherwise = _prop I.determinant m
-- | Sum of two matrices of identical size, computed by the backend routine
-- @I.add@. Calls 'error' when the sizes differ.
add :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
add m1 m2
    | dims m1 /= dims m2 = error "Matrix.add: matrices should have the same size"
    | otherwise = _binop (\shape _ -> shape) I.add m1 m2
-- | Difference of two matrices of identical size, computed by the backend
-- routine @I.sub@.
--
-- Calls 'error' when the sizes differ. The message names @Matrix.sub@ (it
-- previously reported @Matrix.add@, pointing at the wrong operation).
sub :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
sub m1 m2
    | dims m1 == dims m2 = _binop const I.sub m1 m2
    | otherwise = error "Matrix.sub: matrices should have the same size"
-- | Matrix product computed by the backend routine @I.mul@.
--
-- The result has the row count of the left operand and the column count of
-- the right operand. Calls 'error' when the column count of the left operand
-- differs from the row count of the right operand.
mul :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
mul m1 m2
    | cols m1 /= rows m2 = error "Matrix.mul: number of columns for lhs matrix should be the same as number of rows for rhs matrix"
    | otherwise = _binop productDims I.mul m1 m2
  where
    productDims (lhsRows, _) (_, rhsCols) = (lhsRows, rhsCols)
-- | Apply a function to every element, keeping the matrix size.
map :: I.Elem a b => (a -> a) -> Matrix a b -> Matrix a b
map f (Matrix nrows ncols vals) = Matrix nrows ncols mapped
  where
    mapped = VS.map (\x -> I.cast (f (I.cast x))) vals
-- | Apply a function to every element together with its row and column
-- index (in that order), keeping the matrix size.
--
-- The indices are recovered from the storage offset @n@ as
-- @(col, row) = divMod n rows@.
imap :: I.Elem a b => (Int -> Int -> a -> a) -> Matrix a b -> Matrix a b
imap f (Matrix nrows ncols vals) = Matrix nrows ncols (VS.imap step vals)
  where
    step n x =
        let (c, r) = divMod n nrows
        in I.cast (f r c (I.cast x))
-- | Selects which part of a matrix 'triangularView' keeps.
data TriangularMode
    = Lower          -- ^ keep the diagonal and everything below it
    | Upper          -- ^ keep the diagonal and everything above it
    | StrictlyLower  -- ^ keep only what lies below the diagonal
    | StrictlyUpper  -- ^ keep only what lies above the diagonal
    | UnitLower      -- ^ keep what lies below the diagonal and set the diagonal to 1
    | UnitUpper      -- ^ keep what lies above the diagonal and set the diagonal to 1
    deriving (Eq, Enum, Show, Read)
-- | Triangular part of a matrix selected by the given 'TriangularMode'.
--
-- Elements outside the selected part become @0@; the two @Unit@ modes
-- additionally replace the diagonal with @1@.
triangularView :: I.Elem a b => TriangularMode -> Matrix a b -> Matrix a b
triangularView mode = case mode of
    Lower -> imap (\r c x -> if r < c then 0 else x)
    Upper -> imap (\r c x -> if r > c then 0 else x)
    StrictlyLower -> imap (\r c x -> if r > c then x else 0)
    StrictlyUpper -> imap (\r c x -> if r < c then x else 0)
    UnitLower -> imap unitLower
    UnitUpper -> imap unitUpper
  where
    unitLower r c x
        | r > c = x
        | r < c = 0
        | otherwise = 1
    unitUpper r c x
        | r < c = x
        | r > c = 0
        | otherwise = 1
-- | Shorthand for @'triangularView' 'Lower'@.
lowerTriangle :: I.Elem a b => Matrix a b -> Matrix a b
lowerTriangle m = triangularView Lower m
-- | Shorthand for @'triangularView' 'Upper'@.
upperTriangle :: I.Elem a b => Matrix a b -> Matrix a b
upperTriangle m = triangularView Upper m
-- | Keep the elements that satisfy the predicate and replace the others
-- with @0@; the matrix size is unchanged.
filter :: I.Elem a b => (a -> Bool) -> Matrix a b -> Matrix a b
filter f = map keep
  where
    keep x
        | f x = x
        | otherwise = 0
-- | Indexed variant of 'filter': the predicate also receives the row and
-- column index of each element; rejected elements become @0@.
ifilter :: I.Elem a b => (Int -> Int -> a -> Bool) -> Matrix a b -> Matrix a b
ifilter f = imap keep
  where
    keep r c x
        | f r c x = x
        | otherwise = 0
-- | Lazy left fold over the stored elements, in storage order.
fold :: I.Elem a b => (c -> a -> c) -> c -> Matrix a b -> c
fold f a (Matrix _ _ vals) = VS.foldl (\acc -> f acc . I.cast) a vals
-- | Strict left fold over the stored elements, in storage order.
fold' :: I.Elem a b => (c -> a -> c) -> c -> Matrix a b -> c
fold' f a (Matrix _ _ vals) = VS.foldl' (\acc -> f acc . I.cast) a vals
-- | Lazy left fold over the stored elements in which the step function also
-- receives the row and column index (in that order) of each element.
--
-- The indices are recovered from the storage offset @n@ as
-- @(col, row) = divMod n rows@.
ifold :: I.Elem a b => (Int -> Int -> c -> a -> c) -> c -> Matrix a b -> c
ifold f a (Matrix nrows _ vals) = VS.ifoldl step a vals
  where
    step acc n x =
        let (c, r) = divMod n nrows
        in f r c acc (I.cast x)
-- | Strict left fold over the stored elements in which the step function also
-- receives the row and column index (in that order) of each element.
--
-- The indices are recovered from the storage offset @n@ as
-- @(col, row) = divMod n rows@.
ifold' :: I.Elem a b => (Int -> Int -> c -> a -> c) -> c -> Matrix a b -> c
ifold' f a (Matrix nrows _ vals) = VS.ifoldl' step a vals
  where
    step acc n x =
        let (c, r) = divMod n nrows
        in f r c acc (I.cast x)
-- | Lazy left fold without a starting value; the stored elements are first
-- converted to a list. Partial: 'foldl1' fails on a matrix with no elements.
fold1 :: I.Elem a b => (a -> a -> a) -> Matrix a b -> a
fold1 f m = foldl1 f (P.map I.cast (VS.toList (_vals m)))
-- | Strict left fold without a starting value; the stored elements are first
-- converted to a list. Partial: 'L.foldl1'' fails on a matrix with no elements.
fold1' :: I.Elem a b => (a -> a -> a) -> Matrix a b -> a
fold1' f m = L.foldl1' f (P.map I.cast (VS.toList (_vals m)))
-- | Result of the backend routine @I.diagonal@, returned as a matrix with
-- @min rows cols@ rows and a single column.
diagonal :: I.Elem a b => Matrix a b -> Matrix a b
diagonal m = _unop columnShape I.diagonal m
  where
    columnShape (nrows, ncols) = (min nrows ncols, 1)
-- | Result of the backend routine @I.inverse@ for a square matrix; the
-- result has the same size. Calls 'error' for a non-square matrix.
inverse :: I.Elem a b => Matrix a b -> Matrix a b
inverse m
    | not (square m) = error "Matrix.inverse: non-square matrix"
    | otherwise = _unop id I.inverse m
-- | Result of the backend routine @I.adjoint@; the result has the row and
-- column counts swapped.
adjoint :: I.Elem a b => Matrix a b -> Matrix a b
adjoint m = _unop swap I.adjoint m
-- | Result of the backend routine @I.transpose@; the result has the row and
-- column counts swapped.
transpose :: I.Elem a b => Matrix a b -> Matrix a b
transpose m = _unop swap I.transpose m
-- | Result of the backend routine @I.conjugate@; the result has the same size.
conjugate :: I.Elem a b => Matrix a b -> Matrix a b
conjugate m = _unop id I.conjugate m
-- | Result of running the backend routine @I.normalize@ on a private copy of
-- the matrix values; the input matrix is left untouched and the result has
-- the same size.
--
-- Calls 'error' when the matrix fails 'valid'. The raw pointer and the
-- dimensions are handed directly to foreign code here (not through
-- 'unsafeWith'), so the layout is checked first: otherwise the backend would
-- be told about more elements than the buffer holds.
normalize :: I.Elem a b => Matrix a b -> Matrix a b
normalize m@(Matrix nrows ncols vals)
    | not (valid m) = error "Matrix.normalize: matrix is not valid"
    | otherwise = I.performIO $ do
        -- Work on a copy so the immutable input is never mutated.
        buf <- VS.thaw vals
        VSM.unsafeWith buf $ \p ->
            I.call $ I.normalize p (I.cast nrows) (I.cast ncols)
        -- 'buf' is not used again, so freezing it in place is safe.
        Matrix nrows ncols <$> VS.unsafeFreeze buf
-- | Apply a mutating 'ST' action to the matrix values through 'VS.modify',
-- presenting them to the action as a mutable matrix of the same size.
modify :: I.Elem a b => (forall s. M.MMatrix a b s -> ST s ()) -> Matrix a b -> Matrix a b
modify f (Matrix nrows ncols vals) =
    Matrix nrows ncols (VS.modify (\mv -> f (M.MMatrix nrows ncols mv)) vals)
-- | Convert a matrix to another element type by applying the function to
-- every element; the size is unchanged.
convert :: (I.Elem a b, I.Elem c d) => (a -> c) -> Matrix a b -> Matrix c d
convert f (Matrix nrows ncols vals) = Matrix nrows ncols converted
  where
    converted = VS.map (\x -> I.cast (f (I.cast x))) vals
-- | Immutable matrix built from a mutable one; the values are obtained with
-- 'VS.freeze'.
freeze :: I.Elem a b => PrimMonad m => M.MMatrix a b (PrimState m) -> m (Matrix a b)
freeze (M.MMatrix mrows mcols mvals) = do
    frozen <- VS.freeze mvals
    return (Matrix mrows mcols frozen)
-- | Mutable matrix built from an immutable one; the values are obtained with
-- 'VS.thaw'.
thaw :: I.Elem a b => PrimMonad m => Matrix a b -> m (M.MMatrix a b (PrimState m))
thaw (Matrix nrows ncols vals) = do
    thawed <- VS.thaw vals
    return (M.MMatrix nrows ncols thawed)
-- | Immutable matrix built from a mutable one; the values are obtained with
-- 'VS.unsafeFreeze', so the mutable matrix must not be used afterwards.
unsafeFreeze :: I.Elem a b => PrimMonad m => M.MMatrix a b (PrimState m) -> m (Matrix a b)
unsafeFreeze (M.MMatrix mrows mcols mvals) = do
    frozen <- VS.unsafeFreeze mvals
    return (Matrix mrows mcols frozen)
-- | Mutable matrix built from an immutable one; the values are obtained with
-- 'VS.unsafeThaw', so the immutable matrix must not be used afterwards.
unsafeThaw :: I.Elem a b => PrimMonad m => Matrix a b -> m (M.MMatrix a b (PrimState m))
unsafeThaw (Matrix nrows ncols vals) = do
    thawed <- VS.unsafeThaw vals
    return (M.MMatrix nrows ncols thawed)
-- | Run an action on a pointer to the matrix values together with the row
-- and column counts converted to 'CInt'.
--
-- Fails in 'IO' when the matrix does not pass 'valid'.
unsafeWith :: I.Elem a b => Matrix a b -> (Ptr b -> CInt -> CInt -> IO c) -> IO c
unsafeWith m@(Matrix nrows ncols vals) f
    | valid m = VS.unsafeWith vals (\p -> f p (I.cast nrows) (I.cast ncols))
    | otherwise = fail "Matrix.unsafeWith: matrix layout is invalid"
-- | Helper for scalar-valued backend routines.
--
-- Allocates a temporary cell, passes it to the routine as its first argument
-- followed by the matrix pointer and dimensions, then reads the cell back
-- and converts it to the Haskell-side element type.
_prop :: I.Elem a b => (Ptr b -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> a
_prop f m = I.cast . I.performIO $ alloca $ \out ->
    I.call (unsafeWith m (f out)) >> peek out
{-# INLINE _prop #-}
-- | Helper for backend routines taking two matrices and producing a matrix.
--
-- The first argument computes the result size from the operand sizes. A
-- fresh mutable matrix of that size is allocated and the routine is called
-- with pointer, rows and cols of the result, the left operand and the right
-- operand (in that order); the result is then frozen.
_binop :: I.Elem a b => ((Int, Int) -> (Int, Int) -> (Int, Int)) -> (Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> Matrix a b -> Matrix a b
_binop f g m1 m2 = I.performIO $ do
    let (outRows, outCols) = f (dims m1) (dims m2)
    out <- M.new outRows outCols
    M.unsafeWith out $ \p0 r0 c0 ->
        unsafeWith m1 $ \p1 r1 c1 ->
            unsafeWith m2 $ \p2 r2 c2 ->
                I.call (g p0 r0 c0 p1 r1 c1 p2 r2 c2)
    unsafeFreeze out
{-# INLINE _binop #-}
-- | Helper for backend routines taking one matrix and producing a matrix.
--
-- The first argument computes the result size from the operand size. A fresh
-- mutable matrix of that size is allocated and the routine is called with
-- pointer, rows and cols of the result followed by those of the operand; the
-- result is then frozen.
_unop :: I.Elem a b => ((Int,Int) -> (Int,Int)) -> (Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> Matrix a b
_unop f g m1 = I.performIO $ do
    let (outRows, outCols) = f (dims m1)
    out <- M.new outRows outCols
    M.unsafeWith out $ \p0 r0 c0 ->
        unsafeWith m1 $ \p1 r1 c1 ->
            I.call (g p0 r0 c0 p1 r1 c1)
    unsafeFreeze out
{-# INLINE _unop #-}
-- | The flat storage vector of a matrix.
_vals :: I.Elem a b => Matrix a b -> VS.Vector b
_vals (Matrix _ _ storage) = storage
{-# INLINE _vals #-}