module Data.Packed.Internal.Matrix(
Matrix(..), rows, cols, cdat, fdat,
MatrixOrder(..), orderOf,
createMatrix, mat,
cmat, fmat,
toLists, flatten, reshape,
Element(..),
trans,
fromRows, toRows, fromColumns, toColumns,
matrixFromVector,
subMatrix,
liftMatrix, liftMatrix2,
(@@>), atM',
singleton,
emptyM,
size, shSize, conformVs, conformMs, conformVTo, conformMTo
) where
import Data.Packed.Internal.Common
import Data.Packed.Internal.Signatures
import Data.Packed.Internal.Vector
import Foreign.Marshal.Alloc(alloca, free)
import Foreign.Marshal.Array(newArray)
import Foreign.Ptr(Ptr, castPtr)
import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf)
import Data.Complex(Complex)
import Foreign.C.Types
import System.IO.Unsafe(unsafePerformIO)
import Control.DeepSeq
-- | Layout of the elements of a 'Matrix' inside its flat data vector.
data MatrixOrder = RowMajor | ColumnMajor
    deriving (Show, Eq)

-- | The layout seen when the same buffer is read as the transposed matrix:
-- row-major data read transposed is column-major, and vice versa.
transOrder :: MatrixOrder -> MatrixOrder
transOrder o = case o of
    RowMajor    -> ColumnMajor
    ColumnMajor -> RowMajor
-- | A dense matrix: its dimensions, a flat data vector, and the layout
-- ('MatrixOrder') in which the elements are stored in that vector.
data Matrix t = Matrix
    { irows :: !Int          -- ^ number of rows
    , icols :: !Int          -- ^ number of columns
    , xdat  :: !(Vector t)   -- ^ the elements, laid out according to 'order'
    , order :: !MatrixOrder  -- ^ storage layout of 'xdat'
    }

-- | The raw data vector, returned as stored (no reordering); same as 'xdat'.
cdat :: Matrix t -> Vector t
cdat m = xdat m

-- | The raw data vector, returned as stored (no reordering); same as 'xdat'.
fdat :: Matrix t -> Vector t
fdat m = xdat m

-- | Number of rows.
rows :: Matrix t -> Int
rows m = irows m

-- | Number of columns.
cols :: Matrix t -> Int
cols m = icols m

-- | Storage layout of the matrix data.
orderOf :: Matrix t -> MatrixOrder
orderOf m = order m
-- | Transpose without copying: the dimensions are swapped and the storage
-- order flipped, while the same data vector is reused.
trans :: Matrix t -> Matrix t
trans m = m { irows = icols m, icols = irows m, order = transOrder (order m) }
-- | An equivalent matrix stored in 'RowMajor' order. The data is copied
-- through 'transdata' only when the matrix is currently 'ColumnMajor'.
cmat :: (Element t) => Matrix t -> Matrix t
cmat m = case order m of
    RowMajor    -> m
    ColumnMajor -> m { xdat = transdata (irows m) (xdat m) (icols m), order = RowMajor }

-- | An equivalent matrix stored in 'ColumnMajor' order. The data is copied
-- through 'transdata' only when the matrix is currently 'RowMajor'.
fmat :: (Element t) => Matrix t -> Matrix t
fmat m = case order m of
    ColumnMajor -> m
    RowMajor    -> m { xdat = transdata (icols m) (xdat m) (irows m), order = ColumnMajor }
-- | Run an action with direct access to the matrix buffer. The continuation
-- @f@ receives a function that applies its argument to the row count, the
-- column count (both as 'CInt') and the data pointer. The storage order is
-- not passed along, so callers must already know it.
mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
mat a f = unsafeWith (xdat a) $ \p ->
    let withDims g = g (fi (rows a)) (fi (cols a)) p
    in  f withDims
-- | All elements as a single vector, in row-major order.
flatten :: Element t => Matrix t -> Vector t
flatten m = xdat (cmat m)
-- | The matrix as a list of rows, each a list of elements.
--
-- A matrix with zero columns yields one empty list per row. This matches
-- 'toRows'; without the special case the data vector is empty, so
-- splitting it would return @[]@ and lose the row count.
toLists :: (Element t) => Matrix t -> [[t]]
toLists m
    | c == 0    = replicate (rows m) []
    | otherwise = splitEvery c . toList . flatten $ m
  where
    c = cols m
-- | Build a row-major matrix from its rows. All rows must have the same
-- length, except that rows of length 1 are expanded to the common length
-- (as decided by 'compatdim'). Any other mismatch is an error.
fromRows :: Element t => [Vector t] -> Matrix t
fromRows [] = emptyM 0 0
fromRows vs =
    case compatdim dims of
        Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show dims
        Just 0  -> emptyM nr 0
        Just nc -> matrixFromVector RowMajor nr nc (vjoin (map (expand nc) vs))
  where
    dims = map dim vs
    nr   = length vs
    -- stretch a row to the common width nc (singletons are replicated)
    expand nc v
        | nc == 0     = fromList []
        | dim v == nc = v
        | otherwise   = constantD (v @> 0) nc
-- | The rows of the matrix as vectors. They are taken as slices of the
-- row-major data. A matrix with zero columns yields one empty vector per row.
toRows :: Element t => Matrix t -> [Vector t]
toRows m
    | c == 0    = replicate r (fromList [])
    | otherwise = [ subVector (k * c) c v | k <- [0 .. r - 1] ]
  where
    v = flatten m
    r = rows m
    c = cols m
-- | Build a matrix from its columns (the transpose of 'fromRows').
fromColumns :: Element t => [Vector t] -> Matrix t
fromColumns = trans . fromRows

-- | The columns of the matrix as vectors (the rows of its transpose).
toColumns :: Element t => Matrix t -> [Vector t]
toColumns = toRows . trans
-- | Element at (row, column). When 'safe' is set, an index outside the
-- matrix raises an error; otherwise the access is delegated unchecked.
(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
infixl 9 @@>
m@Matrix {irows = r, icols = c} @@> (i,j)
    | safe && outside = error "matrix indexing out of range"
    | otherwise       = atM' m i j
  where
    outside = i < 0 || i >= r || j < 0 || j >= c
-- | Element at (row, column), computing the flat offset from the storage
-- order. This function performs no range check of its own.
atM' :: Storable t => Matrix t -> Int -> Int -> t
atM' m i j = xdat m `at'` offset
  where
    offset = case order m of
        RowMajor    -> i * icols m + j
        ColumnMajor -> j * irows m + i
-- | View a vector as an @r@ x @c@ matrix with the given storage order.
-- Raises an error if the vector length is not exactly @r*c@.
matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
matrixFromVector o r c v
    | r * c /= dim v = error $ "can't reshape vector dim = " ++ show (dim v) ++ " to matrix " ++ shSize m
    | otherwise      = m
  where
    m = Matrix { irows = r, icols = c, xdat = v, order = o }
-- | Allocate an @r@ x @c@ matrix with the given order. The buffer comes from
-- 'createVector'; its contents are whatever that function leaves there.
createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
createMatrix ord r c = fmap (matrixFromVector ord r c) (createVector (r * c))
-- | Interpret a vector as a row-major matrix with @c@ columns. The length
-- must be a multiple of @c@. @c = 0@ is accepted only for an empty vector;
-- any other case is rejected by 'matrixFromVector'.
reshape :: Storable t => Int -> Vector t -> Matrix t
reshape c v
    | c == 0    = matrixFromVector RowMajor 0 0 v
    | otherwise = matrixFromVector RowMajor (dim v `div` c) c v

-- | A 1x1 matrix holding a single element.
singleton :: Storable a => a -> Matrix a
singleton x = reshape 1 (fromList [x])
-- | Apply a vector function to the raw data, keeping dimensions and order.
-- The result must have the same length, or 'matrixFromVector' will fail.
liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
liftMatrix f m = matrixFromVector (order m) (irows m) (icols m) (f (xdat m))
-- | Combine two same-sized matrices element-wise through a vector function.
-- The second matrix is rearranged, if needed, to the storage order of the
-- first, and the result keeps the first matrix's order.
liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
liftMatrix2 f m1 m2
    | compat m1 m2 = matrixFromVector o (rows m1) (cols m1) (f (xdat m1) arg2)
    | otherwise    = error "nonconformant matrices in liftMatrix2"
  where
    o = orderOf m1
    -- m2's data laid out in the same order as m1
    arg2 = case o of
        RowMajor    -> flatten m2
        ColumnMajor -> xdat (fmat m2)

-- | True when both matrices have the same number of rows and columns.
compat :: Matrix a -> Matrix b -> Bool
compat m1 m2 = and [rows m1 == rows m2, cols m1 == cols m2]
-- | Element types supported by the matrix internals. The default methods
-- work for any 'Storable' type through generic C helpers. Instances may
-- override them with type-specific kernels.
class (Storable a) => Element a where
    -- | Extract a submatrix, given the starting (row, column) and the
    -- (rows, columns) size. Bounds are validated by 'subMatrix' before it
    -- calls this method.
    subMatrixD :: (Int,Int) -> (Int,Int) -> Matrix a -> Matrix a
    subMatrixD = subMatrix'

    -- | @transdata c1 v c2@ reinterprets the row-major buffer @v@, which has
    -- @c1@ columns, and returns its transpose as a row-major buffer with
    -- @c2@ columns.
    transdata :: Int -> Vector a -> Int -> Vector a
    transdata = transdataP

    -- | @constantD x n@ builds a vector of length @n@ filled with @x@.
    constantD :: a -> Int -> Vector a
    constantD = constantP
-- Float, Double and their complex counterparts replace the generic defaults
-- with dedicated C kernels for constant filling and transposition;
-- subMatrixD keeps its default.
instance Element Float where
    constantD = constantAux cconstantF
    transdata = transdataAux ctransF

instance Element Double where
    constantD = constantAux cconstantR
    transdata = transdataAux ctransR

instance Element (Complex Float) where
    constantD = constantAux cconstantQ
    transdata = transdataAux ctransQ

instance Element (Complex Double) where
    constantD = constantAux cconstantC
    transdata = transdataAux ctransC
-- | Transpose a row-major buffer with @c1@ columns into a new row-major
-- buffer with @c2@ columns, using a type-specific C kernel @fun@. The kernel
-- receives (rows, cols, src) of the input and (rows, cols, dst) of the output.
-- The input is returned unchanged when it is empty, or when it has a single
-- row or a single column (the layout is then identical).
-- Uses unsafePerformIO: the output is a freshly created vector written only
-- here, so the result depends only on the arguments.
transdataAux fun c1 d c2
    | noneed    = d
    | otherwise = unsafePerformIO $ do
        v <- createVector (dim d)
        unsafeWith d $ \pd ->
            unsafeWith v $ \pv ->
                fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
        return v
  where
    r1 = dim d `div` c1
    r2 = dim d `div` c2
    noneed = dim d == 0 || r1 == 1 || c1 == 1
-- | Generic version of 'transdataAux' for any 'Storable' element. The buffers
-- go to the C routine as untyped pointers, together with the element size in
-- bytes. Empty, single-row and single-column inputs are returned unchanged.
transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
transdataP c1 d c2
    | noneed    = d
    | otherwise = unsafePerformIO $ do
        v <- createVector n
        unsafeWith d $ \pd ->
            unsafeWith v $ \pv ->
                ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP"
        return v
  where
    n  = dim d
    r1 = n `div` c1
    r2 = n `div` c2
    -- only the type of the element matters here; it is never forced
    sz = sizeOf (d @> 0)
    noneed = n == 0 || r1 == 1 || c1 == 1
-- Type-specific C transpose kernels, used through 'transdataAux' by the
-- Element instances for Float (F), Double (R), Complex Float (Q) and
-- Complex Double (C).
foreign import ccall unsafe "transF" ctransF :: TFMFM
foreign import ccall unsafe "transR" ctransR :: TMM
foreign import ccall unsafe "transQ" ctransQ :: TQMQM
foreign import ccall unsafe "transC" ctransC :: TCMCM
-- Element-size-generic C transpose used by 'transdataP'. Arguments:
-- (rows, cols, src, elemSize) of the input, then (rows, cols, dst, elemSize)
-- of the output; it returns a status code checked by the caller.
foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
-- | Build a vector of length @n@ filled with @x@, using a type-specific C
-- kernel @fun@ that receives a pointer to the fill value.
--
-- The temporary cell holding @x@ is obtained with 'alloca', so it is released
-- even when 'app1' reports a failing status from the C call. A manual
-- newArray/free pair would skip the free in that case and leak the cell.
constantAux fun x n = unsafePerformIO $ do
    v <- createVector n
    alloca $ \px -> do
        poke px x
        app1 (fun px) vec v "constantAux"
    return v
-- Type-specific C fill kernels used through 'constantAux' by the Element
-- instances. Each takes a pointer to the fill value, followed by the output
-- vector arguments described by the corresponding signature alias.
foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF
foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV
foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
-- | Generic fill for any 'Storable' element. The value is written to a
-- temporary cell, and the C routine copies it @n@ times, given the element
-- size in bytes.
constantP :: Storable a => a -> Int -> Vector a
constantP a n = unsafePerformIO $ do
    v <- createVector n
    unsafeWith v $ \p ->
        alloca $ \k -> do
            poke k a
            cconstantP (castPtr k) (fi n) (castPtr p) (fi (sizeOf a)) // check "constantP"
    return v

-- Arguments: pointer to the value, element count, destination, element size.
foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
-- | Extract a submatrix after checking that it lies inside the matrix.
subMatrix :: Element a
          => (Int,Int)  -- ^ (r0,c0): starting row and column
          -> (Int,Int)  -- ^ (rt,ct): number of rows and columns to take
          -> Matrix a   -- ^ input matrix
          -> Matrix a   -- ^ the rt x ct block starting at (r0,c0)
subMatrix (r0,c0) (rt,ct) m
    | rowsOk && colsOk = subMatrixD (r0,c0) (rt,ct) m
    | otherwise        = error $ "wrong subMatrix " ++ show ((r0,c0),(rt,ct)) ++ " of " ++ show (rows m) ++ "x" ++ show (cols m)
  where
    rowsOk = 0 <= r0 && 0 <= rt && r0 + rt <= rows m
    colsOk = 0 <= c0 && 0 <= ct && c0 + ct <= cols m
-- | Copy the @rt@ x @ct@ block starting at row @r0@, column @c0@ out of a
-- row-major buffer @v@ with @c@ columns, into a fresh row-major vector.
--
-- The loop walks rows from @rt-1@ down to 0 and, within each row, columns
-- from @ct-1@ down to 0. Termination uses @< 0@ guards rather than matching
-- on the literal @-1@, so an empty block (@rt == 0@ or @ct == 0@) simply
-- copies nothing.
subMatrix'' :: Storable t => (Int,Int) -> (Int,Int) -> Int -> Vector t -> Vector t
subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
    w <- createVector (rt*ct)
    unsafeWith v $ \p ->
        unsafeWith w $ \q -> do
            let go i j
                    | i < 0     = return ()
                    | j < 0     = go (i-1) (ct-1)
                    | otherwise = do
                        x <- peekElemOff p ((i+r0)*c + j + c0)
                        pokeElemOff q (i*ct + j) x
                        go i (j-1)
            go (rt-1) (ct-1)
    return w
-- | Default submatrix extraction (no bounds checks). A row-major matrix is
-- copied directly. A column-major one is handled by extracting the
-- transposed block from its transpose, which is row-major, and transposing
-- the result back.
subMatrix' :: Storable t => (Int,Int) -> (Int,Int) -> Matrix t -> Matrix t
subMatrix' (r0,c0) (rt,ct) m = case m of
    Matrix { icols = c, xdat = v, order = RowMajor } ->
        Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor
    _ ->
        trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
-- | Common broadcast size: 0 if any size is 0, otherwise the largest one.
-- Partial: fails on an empty list.
maxZ :: (Ord a, Num a) => [a] -> a
maxZ xs
    | minimum xs == 0 = 0
    | otherwise       = maximum xs
-- | Broadcast a list of matrices to a common shape. Each dimension of the
-- target is the largest one present, or 0 if any matrix has 0 there.
conformMs :: Element t => [Matrix t] -> [Matrix t]
conformMs ms = map (conformMTo target) ms
  where
    target = (maxZ (map rows ms), maxZ (map cols ms))

-- | Broadcast a list of vectors to a common length, chosen the same way.
conformVs :: Element t => [Vector t] -> [Vector t]
conformVs vs = map (conformVTo target) vs
  where
    target = maxZ (map dim vs)
-- | Expand a matrix to shape (r,c). A 1x1 matrix becomes a constant matrix,
-- an r x 1 column is replicated across columns, and a 1 x c row is
-- replicated across rows. Any other mismatch is an error.
conformMTo :: Element t => (Int,Int) -> Matrix t -> Matrix t
conformMTo (r,c) m = case size m of
    s | s == (r,c) -> m
      | s == (1,1) -> matrixFromVector RowMajor r c (constantD (m @@> (0,0)) (r*c))
      | s == (r,1) -> repCols c m
      | s == (1,c) -> repRows r m
      | otherwise  -> error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><" ++ show c ++ ")"
-- | Expand a vector to length @n@. A length-1 vector becomes a constant
-- vector; any other mismatch is an error.
conformVTo :: Element t => Int -> Vector t -> Vector t
conformVTo n v = case dim v of
    d | d == n    -> v
      | d == 1    -> constantD (v @> 0) n
      | otherwise -> error $ "vector of dim=" ++ show d ++ " cannot be expanded to dim=" ++ show n
-- | @n@ copies of the flattened matrix @x@, stacked as rows.
repRows :: Element t => Int -> Matrix t -> Matrix t
repRows n x = fromRows . replicate n . flatten $ x

-- | @n@ copies of the flattened matrix @x@, placed side by side as columns.
repCols :: Element t => Int -> Matrix t -> Matrix t
repCols n x = fromColumns . replicate n . flatten $ x

-- | Dimensions as (rows, columns).
size :: Matrix t -> (Int,Int)
size m = (rows m, cols m)

-- | Dimensions rendered as @(r><c)@, for error messages.
shSize :: Matrix t -> String
shSize m = concat ["(", show (rows m), "><", show (cols m), ")"]

-- | A row-major matrix with an empty data vector. It is only valid when
-- @r * c == 0@; otherwise 'matrixFromVector' raises an error.
emptyM :: Storable t => Int -> Int -> Matrix t
emptyM r c = matrixFromVector RowMajor r c (fromList [])
-- | Normal-form evaluation of a matrix. Only the first stored element is
-- forced, and only when the data vector is non-empty. The remaining elements
-- are presumably already evaluated because they live in a Storable buffer;
-- confirm this against the Vector implementation.
instance (Storable t, NFData t) => NFData (Matrix t) where
    rnf m = if dim v > 0 then rnf (v @> 0) else ()
      where
        v = xdat m