{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE EmptyCase #-}
module Data.Eigen.Matrix
(
Matrix(..)
, MatrixXf
, MatrixXd
, MatrixXcf
, MatrixXcd
, I.Elem
, I.CComplex
, valid
, fromList
, toList
, fromFlatList
, toFlatList
, generate
, empty
, null
, square
, zero
, ones
, identity
, constant
, random
, cols
, rows
, dims
, (!)
, coeff
, unsafeCoeff
, col
, row
, block
, topRows
, bottomRows
, leftCols
, rightCols
, sum
, prod
, mean
, minCoeff
, maxCoeff
, trace
, norm
, squaredNorm
, blueNorm
, hypotNorm
, determinant
, fold
, fold'
, ifold
, ifold'
, fold1
, fold1'
, all
, any
, count
, add
, sub
, mul
, map
, imap
, filter
, ifilter
, diagonal
, transpose
, inverse
, adjoint
, conjugate
, normalize
, modify
, convert
, TriangularMode(..)
, triangularView
, lowerTriangle
, upperTriangle
, encode
, decode
, thaw
, freeze
, unsafeThaw
, unsafeFreeze
, unsafeWith
) where
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Binary hiding (encode, decode)
import Data.Complex hiding (conjugate)
import Data.Tuple
import Foreign.C.String
import Foreign.C.Types
import Foreign.Marshal.Alloc
import Foreign.Ptr
import Foreign.Storable
import Prelude hiding (null, sum, all, any, map, filter)
import Text.Printf
import qualified Data.Binary as B
import qualified Data.ByteString.Lazy as BSL
import qualified Data.Eigen.Internal as I
import qualified Data.Eigen.Matrix.Mutable as M
import qualified Data.List as L
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import qualified Prelude as P
-- | A matrix made of a row count, a column count and a storable vector of
-- elements of type @b@.  The 'I.Elem' constraint captured by the constructor
-- ties @b@ to the element type @a@ that the accessors of this module return.
-- All three fields are strict.
data Matrix a b where
    Matrix
        :: I.Elem a b
        => !Int            -- number of rows
        -> !Int            -- number of columns
        -> !(VS.Vector b)  -- elements; entry (row, col) is read at index col * rows + row
        -> Matrix a b
-- | Matrix exposing 'Float' elements, stored as 'CFloat'.
type MatrixXf = Matrix Float CFloat
-- | Matrix exposing 'Double' elements, stored as 'CDouble'.
type MatrixXd = Matrix Double CDouble
-- | Matrix exposing @'Complex' 'Float'@ elements, stored as @'I.CComplex' 'CFloat'@.
type MatrixXcf = Matrix (Complex Float) (I.CComplex CFloat)
-- | Matrix exposing @'Complex' 'Double'@ elements, stored as @'I.CComplex' 'CDouble'@.
type MatrixXcd = Matrix (Complex Double) (I.CComplex CDouble)
-- | Renders a header line @Matrix \<rows\>x\<cols\>@ followed by one line per
-- row (as produced by 'toList'), with the elements of a row separated by tabs
-- and a trailing newline at the end.
instance (I.Elem a b, Show a) => Show (Matrix a b) where
    show m@(Matrix nRows nCols _) = header ++ "\n" ++ body ++ "\n"
      where
        header = "Matrix " ++ show nRows ++ "x" ++ show nCols
        body = L.intercalate "\n" [L.intercalate "\t" [show x | x <- line] | line <- toList m]
-- | Numeric interface for matrices: '*', '+' and '-' delegate to 'mul', 'add'
-- and 'sub' (which check the operand dimensions), 'fromInteger' builds a 1x1
-- matrix holding the literal, and 'signum', 'abs' and 'negate' are applied to
-- every element through 'map'.
instance I.Elem a b => Num (Matrix a b) where
    lhs * rhs = mul lhs rhs
    lhs + rhs = add lhs rhs
    lhs - rhs = sub lhs rhs
    fromInteger n = constant 1 1 (fromInteger n)
    signum m = map signum m
    abs m = map abs m
    negate m = map negate m
-- | Binary serialisation of a matrix.
--
-- The encoding is the 'I.magicCode' of the storage element type followed by
-- the row count, the column count and the element vector.
--
-- Decoding fails with @"wrong matrix type"@ when the stored magic code does
-- not match the requested element type, and fails as well when the decoded
-- dimensions are inconsistent with the number of decoded elements (see
-- 'valid'), so that corrupted or hostile input can never yield a matrix that
-- the unchecked accessors of this module would index out of bounds.
instance I.Elem a b => Binary (Matrix a b) where
    put (Matrix rows cols vals) = do
        -- the value passed to magicCode only selects the instance; it is never evaluated here
        put $ I.magicCode (undefined :: b)
        put rows
        put cols
        put vals

    get = do
        get >>= (`when` fail "wrong matrix type") . (/= I.magicCode (undefined :: b))
        m <- Matrix <$> get <*> get <*> get
        -- reject payloads whose dimensions disagree with the stored vector
        unless (valid m) $
            fail "Matrix.get: dimensions do not match the number of stored values"
        return m
-- | Serialise a matrix to a lazy 'BSL.ByteString' using its 'Binary' instance.
encode :: I.Elem a b => Matrix a b -> BSL.ByteString
encode matrix = B.encode matrix
-- | Deserialise a matrix from a lazy 'BSL.ByteString' using its 'Binary'
-- instance; this is a direct call to 'B.decode'.
decode :: I.Elem a b => BSL.ByteString -> Matrix a b
decode bytes = B.decode bytes
{-# INLINE empty #-}
-- | The matrix with zero rows, zero columns and an empty element vector.
empty :: I.Elem a b => Matrix a b
empty = Matrix 0 0 VS.empty
{-# INLINE null #-}
-- | 'True' only when the matrix has both zero rows and zero columns.
null :: I.Elem a b => Matrix a b -> Bool
null (Matrix 0 0 _) = True
null _ = False
{-# INLINE square #-}
-- | 'True' when the matrix has as many rows as columns.
square :: I.Elem a b => Matrix a b -> Bool
square m =
    let (nRows, nCols) = dims m
    in nRows == nCols
{-# INLINE constant #-}
-- | Build a matrix of the given row and column counts in which every element
-- is the given value.
constant :: I.Elem a b => Int -> Int -> a -> Matrix a b
constant nRows nCols val = Matrix nRows nCols filled
  where
    filled = VS.replicate (nRows * nCols) (I.cast val)
{-# INLINE zero #-}
-- | Matrix of the given size with every element equal to 0 (see 'constant').
zero :: I.Elem a b => Int -> Int -> Matrix a b
zero nRows nCols = constant nRows nCols 0
{-# INLINE ones #-}
-- | Matrix of the given size with every element equal to 1 (see 'constant').
ones :: I.Elem a b => Int -> Int -> Matrix a b
ones nRows nCols = constant nRows nCols 1
-- | Allocate a mutable matrix of the given size, let the backend routine
-- 'I.identity' fill it in place, and freeze it without copying.
-- NOTE(review): presumably the result has ones on the main diagonal and
-- zeros elsewhere; the fill pattern is decided by 'I.identity'.
identity :: I.Elem a b => Int -> Int -> Matrix a b
identity nRows nCols = I.performIO $
    M.new nRows nCols >>= \buf ->
        I.call (M.unsafeWith buf I.identity) >> unsafeFreeze buf
-- | Allocate a mutable matrix of the given size, let the backend routine
-- 'I.random' fill it in place, and freeze it without copying.  Runs in 'IO'.
-- NOTE(review): the distribution of the generated values is decided by
-- 'I.random' and is not visible from this module.
random :: I.Elem a b => Int -> Int -> IO (Matrix a b)
random nRows nCols =
    M.new nRows nCols >>= \buf ->
        I.call (M.unsafeWith buf I.random) >> unsafeFreeze buf
{-# INLINE rows #-}
-- | Number of rows recorded in the matrix.
rows :: I.Elem a b => Matrix a b -> Int
rows (Matrix n _ _) = n
{-# INLINE cols #-}
-- | Number of columns recorded in the matrix.
cols :: I.Elem a b => Matrix a b -> Int
cols (Matrix _ n _) = n
{-# INLINE dims #-}
-- | Dimensions of the matrix as a @(rows, cols)@ pair.
dims :: I.Elem a b => Matrix a b -> (Int, Int)
dims (Matrix nRows nCols _) = (nRows, nCols)
{-# INLINE (!) #-}
-- | Bounds-checked element access by a @(row, col)@ pair; an operator form
-- of 'coeff'.
(!) :: forall a b. (I.Elem a b) => Matrix a b -> (Int, Int) -> a
m ! (r, c) = coeff r c m
{-# INLINE coeff #-}
-- | Element at the given row and column (both zero-based).
--
-- Calls 'error' when the matrix fails 'valid', or when the row or the column
-- lies outside @[0..rows)@ / @[0..cols)@; otherwise defers to 'unsafeCoeff'.
coeff :: I.Elem a b => Int -> Int -> Matrix a b -> a
coeff r c m@(Matrix nRows nCols _)
    | not (valid m) = error "matrix is not valid"
    | outside r nRows = error (printf "Matrix.coeff: row %d is out of bounds [0..%d)" r nRows)
    | outside c nCols = error (printf "Matrix.coeff: col %d is out of bounds [0..%d)" c nCols)
    | otherwise = unsafeCoeff r c m
  where
    -- an index is acceptable only inside the half-open range [0..limit)
    outside i limit = i < 0 || i >= limit
{-# INLINE unsafeCoeff #-}
-- | Element at the given row and column without any validity or bounds
-- check: reads index @col * rows + row@ of the element vector with
-- 'VS.unsafeIndex'.
unsafeCoeff :: I.Elem a b => Int -> Int -> Matrix a b -> a
unsafeCoeff r c (Matrix nRows _ vals) = I.cast (vals `VS.unsafeIndex` (c * nRows + r))
{-# INLINE col #-}
-- | Elements of the given column, from the first row to the last, each
-- fetched with the bounds-checked 'coeff'.
col :: I.Elem a b => Int -> Matrix a b -> [a]
col c m = P.map (\r -> coeff r c m) [0 .. pred (rows m)]
{-# INLINE row #-}
-- | Elements of the given row, from the first column to the last, each
-- fetched with the bounds-checked 'coeff'.
row :: I.Elem a b => Int -> Matrix a b -> [a]
row r m = P.map (\c -> coeff r c m) [0 .. pred (cols m)]
-- | Copy out a @blockRows@ x @blockCols@ sub-matrix whose top-left element is
-- at @(startRow, startCol)@ of the source.  Every element is fetched with
-- 'coeff', so a block reaching outside the source raises the 'coeff' error.
block :: I.Elem a b => Int -> Int -> Int -> Int -> Matrix a b -> Matrix a b
block startRow startCol blockRows blockCols m = generate blockRows blockCols pick
  where
    pick r c = coeff (startRow + r) (startCol + c) m
{-# INLINE valid #-}
-- | 'True' when both dimensions are non-negative and the element vector
-- holds exactly @rows * cols@ values.
valid :: I.Elem a b => Matrix a b -> Bool
valid (Matrix nRows nCols vals) =
    and [nRows >= 0, nCols >= 0, VS.length vals == nRows * nCols]
{-# INLINE maxCoeff #-}
-- | Largest element of the matrix, computed as a strict 'fold1'' with 'max'.
maxCoeff :: (I.Elem a b, Ord a) => Matrix a b -> a
maxCoeff m = fold1' max m
{-# INLINE minCoeff #-}
-- | Smallest element of the matrix, computed as a strict 'fold1'' with 'min'.
minCoeff :: (I.Elem a b, Ord a) => Matrix a b -> a
minCoeff m = fold1' min m
{-# INLINE topRows #-}
-- | The first @n@ rows of the matrix, over all columns (a 'block' anchored at
-- the top-left corner).
topRows :: I.Elem a b => Int -> Matrix a b -> Matrix a b
topRows n m = block 0 0 n (cols m) m
{-# INLINE bottomRows #-}
-- | The last @n@ rows of the matrix, over all columns (a 'block' starting at
-- row @rows - n@).
bottomRows :: I.Elem a b => Int -> Matrix a b -> Matrix a b
bottomRows n m = block (rows m - n) 0 n (cols m) m
{-# INLINE leftCols #-}
-- | The first @n@ columns of the matrix, over all rows (a 'block' anchored at
-- the top-left corner).
leftCols :: I.Elem a b => Int -> Matrix a b -> Matrix a b
leftCols n m = block 0 0 (rows m) n m
{-# INLINE rightCols #-}
-- | The last @n@ columns of the matrix, over all rows (a 'block' starting at
-- column @cols - n@).
rightCols :: I.Elem a b => Int -> Matrix a b -> Matrix a b
rightCols n m = block 0 (cols m - n) (rows m) n m
-- | Build a matrix from a list of rows.
--
-- The row count is the length of the outer list and the column count is the
-- length of the longest inner list; positions missing from shorter rows are
-- left at 0.
fromList :: I.Elem a b => [[a]] -> Matrix a b
fromList list = Matrix nRows nCols cells
  where
    nRows = length list
    nCols = L.foldl' max 0 [length line | line <- list]
    -- zero of the element type; the argument only fixes the type and is never inspected
    zeroOf :: Num x => [[x]] -> x
    zeroOf _ = 0
    cells = VS.create $ do
        buf <- VSM.replicate (nRows * nCols) (I.cast (zeroOf list))
        forM_ (zip [0 ..] list) $ \(r, line) ->
            forM_ (zip [0 ..] line) $ \(c, x) ->
                -- entry (r, c) lives at index c * nRows + r
                VSM.write buf (c * nRows + r) (I.cast x)
        return buf
-- | Convert a matrix to a list of rows.  Calls 'error' when the matrix fails
-- 'valid'.
toList :: I.Elem a b => Matrix a b -> [[a]]
toList m@(Matrix nRows nCols vals)
    | valid m = P.map rowAt [0 .. pred nRows]
    | otherwise = error "matrix is not valid"
  where
    -- all elements of row r, read at index c * nRows + r for each column c
    rowAt r = P.map (\c -> I.cast (VS.unsafeIndex vals (c * nRows + r))) [0 .. pred nCols]
-- | Build a matrix of the given size from a flat list whose values are given
-- row by row.  Calls 'error' when the list length differs from
-- @rows * cols@.
fromFlatList :: I.Elem a b => Int -> Int -> [a] -> Matrix a b
fromFlatList nRows nCols list
    | nRows * nCols /= given =
        error $ "cannot construct " ++ show nRows ++ "x" ++ show nCols ++ " matrix from " ++ show given ++ " values"
    | otherwise = Matrix nRows nCols cells
  where
    given = length list
    -- zero of the element type; the argument only fixes the type and is never inspected
    zeroOf :: Num x => [x] -> x
    zeroOf _ = 0
    -- storage index of each successive input value: rows outermost, columns innermost
    targets = [c * nRows + r | r <- [0 .. pred nRows], c <- [0 .. pred nCols]]
    cells = VS.create $ do
        buf <- VSM.replicate (nRows * nCols) (I.cast (zeroOf list))
        zipWithM_ (\idx x -> VSM.write buf idx (I.cast x)) targets list
        return buf
-- | Convert a matrix to a flat list, row by row.  Calls 'error' when the
-- matrix fails 'valid'.
toFlatList :: I.Elem a b => Matrix a b -> [a]
toFlatList m@(Matrix nRows nCols vals)
    | valid m = do
        r <- [0 .. pred nRows]
        c <- [0 .. pred nCols]
        return (I.cast (VS.unsafeIndex vals (c * nRows + r)))
    | otherwise = error "matrix is not valid"
-- | Build a matrix of the given size whose element at @(row, col)@ is
-- @f row col@.
generate :: I.Elem a b => Int -> Int -> (Int -> Int -> a) -> Matrix a b
generate nRows nCols f = Matrix nRows nCols cells
  where
    cells = VS.create $ do
        buf <- VSM.new (nRows * nCols)
        -- visit rows outermost, columns innermost; entry (r, c) goes to index c * nRows + r
        sequence_
            [ VSM.write buf (c * nRows + r) (I.cast (f r c))
            | r <- [0 .. pred nRows]
            , c <- [0 .. pred nCols]
            ]
        return buf
-- | Scalar computed by the backend routine 'I.sum' over the matrix.
-- NOTE(review): presumably the sum of all coefficients; confirm in
-- "Data.Eigen.Internal".
sum :: I.Elem a b => Matrix a b -> a
sum m = _prop I.sum m
-- | Scalar computed by the backend routine 'I.prod' over the matrix.
-- NOTE(review): presumably the product of all coefficients; confirm in
-- "Data.Eigen.Internal".
prod :: I.Elem a b => Matrix a b -> a
prod m = _prop I.prod m
-- | Scalar computed by the backend routine 'I.mean' over the matrix.
-- NOTE(review): presumably the mean of all coefficients; confirm in
-- "Data.Eigen.Internal".
mean :: I.Elem a b => Matrix a b -> a
mean m = _prop I.mean m
-- | Scalar computed by the backend routine 'I.trace' over the matrix.
-- NOTE(review): presumably the sum of the diagonal coefficients; confirm in
-- "Data.Eigen.Internal".
trace :: I.Elem a b => Matrix a b -> a
trace m = _prop I.trace m
-- | 'True' when the predicate holds for every element of the matrix.
all :: I.Elem a b => (a -> Bool) -> Matrix a b -> Bool
all p (Matrix _ _ vals) = VS.all (\x -> p (I.cast x)) vals
-- | 'True' when the predicate holds for at least one element of the matrix.
any :: I.Elem a b => (a -> Bool) -> Matrix a b -> Bool
any p (Matrix _ _ vals) = VS.any (\x -> p (I.cast x)) vals
-- | Number of elements of the matrix for which the predicate holds, counted
-- with a strict left fold.
count :: I.Elem a b => (a -> Bool) -> Matrix a b -> Int
count p (Matrix _ _ vals) = VS.foldl' step 0 vals
  where
    step n x
        | p (I.cast x) = succ n
        | otherwise = n
-- | Scalar computed by the backend routine 'I.norm' over the matrix.
-- NOTE(review): which norm this is (presumably the Frobenius / L2 norm) is
-- decided by "Data.Eigen.Internal"; confirm there.
norm :: I.Elem a b => Matrix a b -> a
norm m = _prop I.norm m
-- | Scalar computed by the backend routine 'I.squaredNorm' over the matrix.
-- NOTE(review): presumably the square of 'norm'; confirm in
-- "Data.Eigen.Internal".
squaredNorm :: I.Elem a b => Matrix a b -> a
squaredNorm m = _prop I.squaredNorm m
-- | Scalar computed by the backend routine 'I.blueNorm' over the matrix.
-- NOTE(review): presumably a norm computed with Blue's algorithm; confirm in
-- "Data.Eigen.Internal".
blueNorm :: I.Elem a b => Matrix a b -> a
blueNorm m = _prop I.blueNorm m
-- | Scalar computed by the backend routine 'I.hypotNorm' over the matrix.
-- NOTE(review): presumably a norm computed with hypot-style accumulation;
-- confirm in "Data.Eigen.Internal".
hypotNorm :: I.Elem a b => Matrix a b -> a
hypotNorm m = _prop I.hypotNorm m
-- | Scalar computed by the backend routine 'I.determinant' for a square
-- matrix.  Calls 'error' when the matrix is not 'square'.
determinant :: I.Elem a b => Matrix a b -> a
determinant m =
    if square m
        then _prop I.determinant m
        else error "Matrix.determinant: non-square matrix"
-- | Combine two matrices of identical dimensions with the backend routine
-- 'I.add'; the result has the dimensions of the first operand.  Calls
-- 'error' when the dimensions differ.
add :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
add m1 m2
    | dims m1 /= dims m2 = error "Matrix.add: matrices should have the same size"
    | otherwise = _binop const I.add m1 m2
-- | Combine two matrices of identical dimensions with the backend routine
-- 'I.sub'; the result has the dimensions of the first operand.  Calls
-- 'error' when the dimensions differ.
sub :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
sub m1 m2
    | dims m1 == dims m2 = _binop const I.sub m1 m2
    -- the message names this function, so a failure is not mistaken for one in 'add'
    | otherwise = error "Matrix.sub: matrices should have the same size"
-- | Combine two matrices with the backend routine 'I.mul'.  The column count
-- of the left operand must equal the row count of the right operand
-- (otherwise 'error' is called); the result has the rows of the left operand
-- and the columns of the right operand.
mul :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
mul m1 m2
    | cols m1 /= rows m2 = error "Matrix.mul: number of columns for lhs matrix should be the same as number of rows for rhs matrix"
    | otherwise = _binop resultDims I.mul m1 m2
  where
    resultDims (r, _) (_, c) = (r, c)
-- | Apply a function to every element of the matrix, keeping its dimensions.
map :: I.Elem a b => (a -> a) -> Matrix a b -> Matrix a b
map f (Matrix nRows nCols vals) =
    Matrix nRows nCols $ VS.map (\x -> I.cast (f (I.cast x))) vals
-- | Apply a function to every element together with its row and column
-- (passed in that order), keeping the matrix dimensions.
imap :: I.Elem a b => (Int -> Int -> a -> a) -> Matrix a b -> Matrix a b
imap f (Matrix nRows nCols vals) = Matrix nRows nCols (VS.imap step vals)
  where
    -- a storage index n splits into column = n div rows and row = n mod rows
    step n x =
        let (c, r) = n `divMod` nRows
        in I.cast (f r c (I.cast x))
-- | Selects which part of a matrix 'triangularView' keeps; all entries
-- outside the kept part become 0.
data TriangularMode
    -- | keep the diagonal and everything below it
    = Lower
    -- | keep the diagonal and everything above it
    | Upper
    -- | keep only the entries below the diagonal
    | StrictlyLower
    -- | keep only the entries above the diagonal
    | StrictlyUpper
    -- | keep the entries below the diagonal and set the diagonal to 1
    | UnitLower
    -- | keep the entries above the diagonal and set the diagonal to 1
    | UnitUpper
    deriving (Eq, Enum, Show, Read)
-- | Copy of the matrix restricted to the part selected by the
-- 'TriangularMode'.  Entries outside the selected part become 0; the @Unit@
-- modes additionally replace every diagonal entry by 1.
triangularView :: I.Elem a b => TriangularMode -> Matrix a b -> Matrix a b
triangularView Lower = imap $ \r c x -> if r < c then 0 else x
triangularView Upper = imap $ \r c x -> if r > c then 0 else x
triangularView StrictlyLower = imap $ \r c x -> if r > c then x else 0
triangularView StrictlyUpper = imap $ \r c x -> if r < c then x else 0
triangularView UnitLower = imap $ \r c x ->
    if r > c then x else if r == c then 1 else 0
triangularView UnitUpper = imap $ \r c x ->
    if r < c then x else if r == c then 1 else 0
-- | Shorthand for @'triangularView' 'Lower'@.
lowerTriangle :: I.Elem a b => Matrix a b -> Matrix a b
lowerTriangle m = triangularView Lower m
-- | Shorthand for @'triangularView' 'Upper'@.
upperTriangle :: I.Elem a b => Matrix a b -> Matrix a b
upperTriangle m = triangularView Upper m
-- | Keep the elements satisfying the predicate and replace all others by 0;
-- the dimensions are unchanged.
filter :: I.Elem a b => (a -> Bool) -> Matrix a b -> Matrix a b
filter p = map keep
  where
    keep x
        | p x = x
        | otherwise = 0
-- | Like 'filter', but the predicate also receives the row and the column of
-- each element (in that order).
ifilter :: I.Elem a b => (Int -> Int -> a -> Bool) -> Matrix a b -> Matrix a b
ifilter p = imap keep
  where
    keep r c x
        | p r c x = x
        | otherwise = 0
-- | Lazy left fold over the elements in storage order (see 'fold'' for the
-- strict variant).
fold :: I.Elem a b => (c -> a -> c) -> c -> Matrix a b -> c
fold f z m = VS.foldl step z (_vals m)
  where
    step acc x = f acc (I.cast x)
-- | Strict left fold over the elements in storage order.
fold' :: I.Elem a b => (c -> a -> c) -> c -> Matrix a b -> c
fold' f z m = VS.foldl' step z (_vals m)
  where
    step acc x = f acc (I.cast x)
-- | Lazy left fold over the elements in storage order; the folding function
-- also receives the row and the column of each element (in that order).
ifold :: I.Elem a b => (Int -> Int -> c -> a -> c) -> c -> Matrix a b -> c
ifold f z (Matrix nRows _ vals) = VS.ifoldl step z vals
  where
    step acc n x = f r c acc (I.cast x)
      where
        -- a storage index n splits into column = n div rows and row = n mod rows
        (c, r) = n `divMod` nRows
-- | Strict left fold over the elements in storage order; the folding
-- function also receives the row and the column of each element (in that
-- order).
ifold' :: I.Elem a b => (Int -> Int -> c -> a -> c) -> c -> Matrix a b -> c
ifold' f z (Matrix nRows _ vals) = VS.ifoldl' step z vals
  where
    step acc n x = f r c acc (I.cast x)
      where
        -- a storage index n splits into column = n div rows and row = n mod rows
        (c, r) = n `divMod` nRows
-- | Lazy left fold without a starting value, over the elements in storage
-- order.  Goes through an intermediate list and uses 'foldl1', so it fails on
-- a matrix without elements.
fold1 :: I.Elem a b => (a -> a -> a) -> Matrix a b -> a
fold1 f (Matrix _ _ vals) = foldl1 f [I.cast x | x <- VS.toList vals]
-- | Strict left fold without a starting value, over the elements in storage
-- order.  Goes through an intermediate list and uses 'L.foldl1'', so it fails
-- on a matrix without elements.
fold1' :: I.Elem a b => (a -> a -> a) -> Matrix a b -> a
fold1' f (Matrix _ _ vals) = L.foldl1' f [I.cast x | x <- VS.toList vals]
-- | Run the backend routine 'I.diagonal' into a freshly allocated result with
-- @min rows cols@ rows and a single column.
-- NOTE(review): presumably the result holds the main diagonal of the input.
diagonal :: I.Elem a b => Matrix a b -> Matrix a b
diagonal m = _unop diagDims I.diagonal m
  where
    diagDims (r, c) = (min r c, 1)
-- | Run the backend routine 'I.inverse' on a square matrix; the result has
-- the same dimensions.  Calls 'error' when the matrix is not 'square'.
inverse :: I.Elem a b => Matrix a b -> Matrix a b
inverse m =
    if square m
        then _unop id I.inverse m
        else error "Matrix.inverse: non-square matrix"
-- | Run the backend routine 'I.adjoint'; the result has the row and column
-- counts of the input swapped.
-- NOTE(review): presumably the conjugate transpose; confirm in
-- "Data.Eigen.Internal".
adjoint :: I.Elem a b => Matrix a b -> Matrix a b
adjoint m = _unop swap I.adjoint m
-- | Run the backend routine 'I.transpose'; the result has the row and column
-- counts of the input swapped.
transpose :: I.Elem a b => Matrix a b -> Matrix a b
transpose m = _unop swap I.transpose m
-- | Run the backend routine 'I.conjugate'; the result has the same
-- dimensions as the input.
-- NOTE(review): presumably the element-wise complex conjugate.
conjugate :: I.Elem a b => Matrix a b -> Matrix a b
conjugate m = _unop id I.conjugate m
-- | Copy the elements into a mutable buffer, let the backend routine
-- 'I.normalize' rewrite that buffer in place, and wrap the buffer (frozen
-- without a second copy) in a matrix of the same dimensions.  The input
-- matrix is not modified.
normalize :: I.Elem a b => Matrix a b -> Matrix a b
normalize (Matrix nRows nCols vals) = I.performIO $ do
    buf <- VS.thaw vals
    VSM.unsafeWith buf $ \ptr ->
        I.call (I.normalize ptr (I.cast nRows) (I.cast nCols))
    frozen <- VS.unsafeFreeze buf
    return (Matrix nRows nCols frozen)
-- | Apply a destructive 'ST' action to a mutable view of the matrix through
-- 'VS.modify' and return the resulting matrix; the dimensions are unchanged.
modify :: I.Elem a b => (forall s. M.MMatrix a b s -> ST s ()) -> Matrix a b -> Matrix a b
modify f (Matrix nRows nCols vals) =
    Matrix nRows nCols $ VS.modify (\mvals -> f (M.MMatrix nRows nCols mvals)) vals
-- | Convert a matrix to another element type by applying the given function
-- to every element; the dimensions are unchanged.
convert :: (I.Elem a b, I.Elem c d) => (a -> c) -> Matrix a b -> Matrix c d
convert f (Matrix nRows nCols vals) = Matrix nRows nCols converted
  where
    converted = VS.map (\x -> I.cast (f (I.cast x))) vals
-- | Produce an immutable matrix from a mutable one using 'VS.freeze' on its
-- element buffer.
freeze :: I.Elem a b => PrimMonad m => M.MMatrix a b (PrimState m) -> m (Matrix a b)
freeze (M.MMatrix mrows mcols mvals) = do
    vals <- VS.freeze mvals
    return (Matrix mrows mcols vals)
-- | Produce a mutable matrix from an immutable one using 'VS.thaw' on its
-- element vector.
thaw :: I.Elem a b => PrimMonad m => Matrix a b -> m (M.MMatrix a b (PrimState m))
thaw (Matrix nRows nCols vals) = do
    mvals <- VS.thaw vals
    return (M.MMatrix nRows nCols mvals)
-- | Produce an immutable matrix from a mutable one using 'VS.unsafeFreeze'
-- on its element buffer.
-- NOTE(review): per the vector API the buffer is presumably shared rather
-- than copied, so the mutable matrix must not be written afterwards.
unsafeFreeze :: I.Elem a b => PrimMonad m => M.MMatrix a b (PrimState m) -> m (Matrix a b)
unsafeFreeze (M.MMatrix mrows mcols mvals) = do
    vals <- VS.unsafeFreeze mvals
    return (Matrix mrows mcols vals)
-- | Produce a mutable matrix from an immutable one using 'VS.unsafeThaw' on
-- its element vector.
-- NOTE(review): per the vector API the buffer is presumably shared rather
-- than copied, so the immutable matrix must not be used afterwards.
unsafeThaw :: I.Elem a b => PrimMonad m => Matrix a b -> m (M.MMatrix a b (PrimState m))
unsafeThaw (Matrix nRows nCols vals) = do
    mvals <- VS.unsafeThaw vals
    return (M.MMatrix nRows nCols mvals)
-- | Pass a pointer to the matrix elements, together with the row and column
-- counts converted to 'CInt', to an 'IO' action.  Fails in 'IO' when the
-- matrix does not satisfy 'valid'.
unsafeWith :: I.Elem a b => Matrix a b -> (Ptr b -> CInt -> CInt -> IO c) -> IO c
unsafeWith m@(Matrix nRows nCols vals) f =
    if valid m
        then VS.unsafeWith vals (\ptr -> f ptr (I.cast nRows) (I.cast nCols))
        else fail "Matrix.unsafeWith: matrix layout is invalid"
{-# INLINE _prop #-}
-- | Run a backend routine that writes one scalar result: a temporary cell is
-- allocated, the routine receives that cell followed by the matrix data
-- pointer and dimensions, and the value left in the cell is converted to the
-- element type @a@.
-- NOTE(review): the 'CString' returned by the routine is handed to 'I.call',
-- which presumably turns a reported error message into a failure.
_prop :: I.Elem a b => (Ptr b -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> a
_prop f m = I.cast raw
  where
    raw = I.performIO $ alloca $ \out ->
        I.call (unsafeWith m (f out)) >> peek out
{-# INLINE _binop #-}
-- | Run a backend routine taking two input matrices and one output matrix.
--
-- The first argument computes the output dimensions from the dimensions of
-- the two inputs.  The routine receives pointer, rows and columns of the
-- freshly allocated output first, then of the left input, then of the right
-- input; the output is finally frozen without copying.
_binop :: I.Elem a b => ((Int, Int) -> (Int, Int) -> (Int, Int)) -> (Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> Matrix a b -> Matrix a b
_binop f g lhs rhs = I.performIO $ do
    let (outRows, outCols) = f (dims lhs) (dims rhs)
    out <- M.new outRows outCols
    M.unsafeWith out $ \pOut rOut cOut ->
        unsafeWith lhs $ \pL rL cL ->
            unsafeWith rhs $ \pR rR cR ->
                I.call (g pOut rOut cOut pL rL cL pR rR cR)
    unsafeFreeze out
{-# INLINE _unop #-}
-- | Run a backend routine taking one input matrix and one output matrix.
--
-- The first argument computes the output dimensions from the input
-- dimensions.  The routine receives pointer, rows and columns of the freshly
-- allocated output first, then of the input; the output is finally frozen
-- without copying.
_unop :: I.Elem a b => ((Int,Int) -> (Int,Int)) -> (Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> Matrix a b
_unop f g src = I.performIO $ do
    let (outRows, outCols) = f (dims src)
    out <- M.new outRows outCols
    M.unsafeWith out $ \pOut rOut cOut ->
        unsafeWith src $ \pIn rIn cIn ->
            I.call (g pOut rOut cOut pIn rIn cIn)
    unsafeFreeze out
{-# INLINE _vals #-}
-- | The underlying element vector of the matrix.
_vals :: I.Elem a b => Matrix a b -> VS.Vector b
_vals (Matrix _ _ storage) = storage