module Data.Packed.Internal.Matrix(
Matrix(..), rows, cols,
MatrixOrder(..), orderOf,
createMatrix, mat,
cmat, fmat,
toLists, flatten, reshape,
Element(..),
trans,
fromRows, toRows, fromColumns, toColumns,
matrixFromVector,
subMatrix,
liftMatrix, liftMatrix2,
(@@>),
saveMatrix,
singleton
) where
import Data.Packed.Internal.Common
import Data.Packed.Internal.Signatures
import Data.Packed.Internal.Vector
import Foreign hiding (xor)
import Data.Complex
import Foreign.C.Types
import Foreign.C.String
-- | Memory layout of a matrix's element storage.
data MatrixOrder
    = RowMajor     -- ^ element (i,j) lives at offset i*cols + j
    | ColumnMajor  -- ^ element (i,j) lives at offset j*rows + i
    deriving (Show, Eq)
-- | A dense matrix backed by a single flat 'Vector'.
--
-- The constructor records the storage order ('orderOf'): 'MC' keeps the
-- elements row by row, 'MF' keeps them column by column. All fields are
-- strict.
data Matrix t
    = MC { irows :: !Int          -- ^ number of rows
         , icols :: !Int          -- ^ number of columns
         , cdat  :: !(Vector t)   -- ^ elements in row-major order
         }
    | MF { irows :: !Int          -- ^ number of rows
         , icols :: !Int          -- ^ number of columns
         , fdat  :: !(Vector t)   -- ^ elements in column-major order
         }
-- | Number of rows of a matrix.
rows :: Matrix t -> Int
rows m = irows m

-- | Number of columns of a matrix.
cols :: Matrix t -> Int
cols m = icols m
-- | The raw element storage of a matrix, in whatever order the matrix
-- uses (see 'orderOf'). No reordering is performed.
xdat :: Matrix t -> Vector t
xdat MC {cdat = d} = d
xdat MF {fdat = d} = d
-- | Storage order of a matrix, determined solely by its constructor.
orderOf :: Matrix t -> MatrixOrder
orderOf m = case m of
    MC{} -> RowMajor
    MF{} -> ColumnMajor
-- | Transpose in O(1). The storage vector is reused unchanged: swapping the
-- dimensions and flipping the storage order gives the transposed view.
trans :: Matrix t -> Matrix t
trans m = case m of
    MC r c d -> MF c r d
    MF r c d -> MC c r d
-- | Return an equivalent matrix in row-major ('MC') storage. Row-major
-- inputs are returned as-is; column-major ones are re-laid-out with
-- 'transdata'.
cmat :: (Element t) => Matrix t -> Matrix t
cmat m = case m of
    MC{}     -> m
    MF r c d -> MC r c (transdata r d c)
-- | Return an equivalent matrix in column-major ('MF') storage.
-- Column-major inputs are returned as-is; row-major ones are re-laid-out
-- with 'transdata'.
fmat :: (Element t) => Matrix t -> Matrix t
fmat m = case m of
    MF{}     -> m
    MC r c d -> MF r c (transdata c d r)
-- | Argument adapter for foreign calls: pins the matrix storage and hands
-- the continuation a function that feeds (rows, cols, pointer) to a foreign
-- routine. The pointer refers to the matrix's own storage order
-- (see 'orderOf'); no reordering is done.
mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
mat a f = unsafeWith (xdat a) (f . withDims)
  where
    withDims p g = g (fi (rows a)) (fi (cols a)) p
-- | All elements in row-major order. Column-major matrices are first
-- converted with 'cmat'.
flatten :: Element t => Matrix t -> Vector t
flatten m = cdat (cmat m)
-- | Shape of a consumer of (rows, cols, element pointer).
type Mt e r = Int -> Int -> Ptr e -> r
-- | Convert to a list of rows, each a list of elements.
toLists :: (Element t) => Matrix t -> [[t]]
toLists m = splitEvery (cols m) (toList (flatten m))
-- | Build a row-major matrix from its rows.
--
-- The common row length comes from 'compatdim'. A row whose length differs
-- from it is replaced by a constant row built from its first element, so
-- presumably 'compatdim' only tolerates length-1 rows besides the common
-- length (TODO confirm). Errors when 'compatdim' yields 'Nothing'.
fromRows :: Element t => [Vector t] -> Matrix t
fromRows vs = maybe failure build (compatdim (map dim vs))
  where
    failure = error "fromRows applied to [] or to vectors with different sizes"
    build c = reshape c (join (map (stretch c) vs))
    stretch c v
        | dim v == c = v
        | otherwise  = constantD (v @> 0) c
-- | Split a matrix into its rows, each a 'subVector' slice of the
-- row-major flattening.
toRows :: Element t => Matrix t -> [Vector t]
toRows m = [ subVector k c v | k <- takeWhile (/= r * c) [0, c ..] ]
  where
    v = flatten m
    r = rows m
    c = cols m
-- | Build a matrix from its columns by transposing the row construction.
fromColumns :: Element t => [Vector t] -> Matrix t
fromColumns vs = trans (fromRows vs)
-- | Split a matrix into its columns, i.e. the rows of its transpose.
toColumns :: Element t => Matrix t -> [Vector t]
toColumns m = toRows (trans m)
-- | Element access by zero-based (row, column).
--
-- When the imported 'safe' flag is set, the indices are checked against the
-- matrix dimensions and an error is raised if they fall outside. Otherwise
-- the linear offset for the matrix's storage order is passed straight to
-- 'at'.
(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
infixl 9 @@>
m @@> (i,j)
    | safe && outside = error "matrix indexing out of range"
    | otherwise       = xdat m `at` offset
  where
    r = rows m
    c = cols m
    outside = i < 0 || i >= r || j < 0 || j >= c
    offset = case m of
        MC{} -> i*c + j
        MF{} -> j*r + i
-- | Element access by zero-based row and column, without the dimension
-- check done by '@@>'. Reads go through 'at'' at the linear offset that
-- matches the matrix's storage order.
atM' :: Storable t => Matrix t -> Int -> Int -> t
atM' MC {icols = c, cdat = v} i j = v `at'` (i*c+j)
atM' MF {irows = r, fdat = v} i j = v `at'` (j*r+i)
-- | Wrap a flat vector as a matrix with @c@ columns in the given storage
-- order. The row count is @dim v `div` c@. It is an error if the vector
-- length is not a multiple of @c@; the check fires lazily, when the row
-- count is demanded.
matrixFromVector :: Storable t => MatrixOrder -> Int -> Vector t -> Matrix t
matrixFromVector order c v = case order of
    RowMajor    -> MC { irows = r, icols = c, cdat = v }
    ColumnMajor -> MF { irows = r, icols = c, fdat = v }
  where
    (d, m) = dim v `divMod` c
    r | m == 0    = d
      | otherwise = error "matrixFromVector"
-- | Allocate an @r@ x @c@ matrix in the requested storage order, backed by a
-- fresh 'createVector' buffer of @r*c@ elements. Whether that buffer is
-- initialised depends on 'createVector' (presumably not; confirm).
createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
createMatrix order r c =
    fmap (matrixFromVector order c) (createVector (r*c))
-- | View a vector as a row-major matrix with the given number of columns.
-- The vector length must be a multiple of the column count.
reshape :: Storable t => Int -> Vector t -> Matrix t
reshape = matrixFromVector RowMajor
-- | A 1x1 matrix holding the given element.
singleton :: Storable a => a -> Matrix a
singleton x = reshape 1 (fromList [x])
-- | Apply a vector function to a matrix's raw storage, keeping its storage
-- order and column count. The row count of the result is recomputed from
-- the new vector's length.
liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
liftMatrix f m = matrixFromVector (orderOf m) (cols m) (f (xdat m))
-- | Combine two same-shaped matrices elementwise through a vector function.
--
-- The second matrix is converted to the first one's storage order, so both
-- vectors passed to @f@ are aligned. The result uses the first matrix's
-- order. Errors when the shapes differ.
liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
liftMatrix2 f m1 m2
    | compat m1 m2 = matrixFromVector order (cols m1) (f (xdat m1) aligned)
    | otherwise    = error "nonconformant matrices in liftMatrix2"
  where
    order = orderOf m1
    aligned = case order of
        RowMajor    -> flatten m2
        ColumnMajor -> fdat (fmat m2)
-- | True when both matrices have the same number of rows and columns.
compat :: Matrix a -> Matrix b -> Bool
compat a b = (rows a, cols a) == (rows b, cols b)
-- | Element types that can be stored in a 'Matrix'.
--
-- Every method has a generic default that works for any 'Storable' type.
-- Instances may override them with type-specific foreign routines.
class (Storable a) => Element a where
    -- | Extract the block given by its top-left corner and its size
    -- (rows, cols). Bounds are validated by 'subMatrix', not here.
    subMatrixD :: (Int,Int) -> (Int,Int) -> Matrix a -> Matrix a
    subMatrixD = subMatrix'

    -- | @transdata c1 v c2@: re-lay-out row-major storage with @c1@ columns
    -- as the row-major storage of its transpose, which has @c2@ columns.
    transdata :: Int -> Vector a -> Int -> Vector a
    transdata = transdataP

    -- | @constantD x n@: a vector of @n@ copies of @x@.
    constantD :: a -> Int -> Vector a
    constantD = constantP
-- The four BLAS-style element types replace the generic 'transdata' and
-- 'constantD' defaults with type-specific foreign routines. 'subMatrixD'
-- keeps its default.
instance Element Float where
    constantD = constantAux cconstantF
    transdata = transdataAux ctransF

instance Element Double where
    constantD = constantAux cconstantR
    transdata = transdataAux ctransR

instance Element (Complex Float) where
    constantD = constantAux cconstantQ
    transdata = transdataAux ctransQ

instance Element (Complex Double) where
    constantD = constantAux cconstantC
    transdata = transdataAux ctransC
-- | Pure-Haskell element-by-element transpose of row-major storage. It is
-- not referenced by the class defaults in this module ('transdataP' is).
--
-- @c1@ is the column count of the input; @c2@ is the column count of the
-- result, which for a transpose should equal the input's row count.
-- Single-row and single-column inputs are returned unchanged, because their
-- storage is identical in both orders.
transdata' :: Storable a => Int -> Vector a -> Int -> Vector a
transdata' c1 v c2
    | noneed    = v
    | otherwise = unsafePerformIO $ do
        -- unsafePerformIO: the output buffer is freshly allocated here and
        -- is only exposed after the loop below has written to it.
        w <- createVector (r2*c2)
        unsafeWith v $ \p ->
            unsafeWith w $ \q -> do
                -- Walk rows from r1-1 down to 0 and, within a row, columns
                -- from c1-1 down to 0. The (< 0) guards also force the
                -- indices, so no bang patterns are needed.
                let go i j
                      | i < 0     = return ()
                      | j < 0     = go (i-1) (c1-1)
                      | otherwise = do
                          x <- peekElemOff p (i*c1+j)
                          pokeElemOff q (j*c2+i) x
                          go i (j-1)
                go (r1-1) (c1-1)
        return w
  where
    r1 = dim v `div` c1
    r2 = dim v `div` c2
    noneed = r1 == 1 || c1 == 1
-- | Transpose row-major storage through a foreign routine @fun@.
--
-- @fun@ receives the source (rows, cols, pointer) followed by the
-- destination (rows, cols, pointer). Its status code is passed to 'check'.
-- @c1@ is the input's column count and @c2@ the result's. Single-row and
-- single-column inputs are returned unchanged, since their storage is
-- identical in both orders.
transdataAux :: Storable a
             => (CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> IO CInt)
             -> Int -> Vector a -> Int -> Vector a
transdataAux fun c1 d c2 =
    if noneed
        then d
        else unsafePerformIO $ do
            -- unsafePerformIO: the result buffer is freshly allocated and
            -- only exposed after the foreign call has filled it.
            v <- createVector (dim d)
            unsafeWith d $ \pd ->
                unsafeWith v $ \pv ->
                    fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
            return v
  where
    r1 = dim d `div` c1
    r2 = dim d `div` c2
    noneed = r1 == 1 || c1 == 1
-- | Generic transpose of row-major storage for any 'Storable' element type.
--
-- It delegates to the foreign @transP@, passing the element size in bytes
-- for both buffers. As with the typed variants, single-row and
-- single-column inputs are returned unchanged.
transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
transdataP c1 d c2
    | r1 == 1 || c1 == 1 = d
    | otherwise = unsafePerformIO $ do
        out <- createVector (dim d)
        unsafeWith d $ \src ->
            unsafeWith out $ \dst ->
                ctransP (fi r1) (fi c1) (castPtr src) (fi sz) (fi r2) (fi c2) (castPtr dst) (fi sz) // check "transdataP"
        return out
  where
    r1 = dim d `div` c1
    r2 = dim d `div` c2
    -- sizeOf only inspects the type, not the element value
    sz = sizeOf (d @> 0)
-- Typed foreign transposition routines, used by the 'Element' instances via
-- 'transdataAux': transF (Float), transR (Double), transQ (Complex Float)
-- and transC (Complex Double). Each returns a CInt status that the caller
-- passes to 'check'.
foreign import ccall "transF" ctransF :: TFMFM
foreign import ccall "transR" ctransR :: TMM
foreign import ccall "transQ" ctransQ :: TQMQM
foreign import ccall "transC" ctransC :: TCMCM
-- Untyped transposition used by 'transdataP': each buffer is passed as
-- Ptr () together with an element size in bytes, presumably so a single C
-- routine can serve any Storable element type.
foreign import ccall "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
-- | Pure-Haskell fill: a vector of @n@ copies of @v@, written from the last
-- index down to 0. It is not referenced by the class defaults in this
-- module ('constantP' is).
constant' :: Storable a => a -> Int -> Vector a
constant' v n = unsafePerformIO $ do
    -- unsafePerformIO: the buffer is freshly allocated and only exposed
    -- once it has been filled.
    w <- createVector n
    unsafeWith w $ \p -> do
        -- The (< 0) guard forces k and also stops for any n <= 0.
        let go k
              | k < 0     = return ()
              | otherwise = pokeElemOff p k v >> go (k-1)
        go (n-1)
    return w
-- | Build a vector of @n@ copies of @x@ using a foreign fill routine @fun@.
--
-- The value is handed to @fun@ through a temporary pointer. 'with' scopes
-- that temporary, so it is released even if the foreign call or the status
-- 'check' inside 'app1' throws.
constantAux :: Storable a => (Ptr a -> CInt -> Ptr a -> IO CInt) -> a -> Int -> Vector a
constantAux fun x n = unsafePerformIO $ do
    v <- createVector n
    with x $ \px -> app1 (fun px) vec v "constantAux"
    return v
-- | @n@ copies of a 'Float', filled by the foreign @constantF@.
constantF :: Float -> Int -> Vector Float
constantF x n = constantAux cconstantF x n
foreign import ccall "constantF" cconstantF :: Ptr Float -> TF

-- | @n@ copies of a 'Double', filled by the foreign @constantR@.
constantR :: Double -> Int -> Vector Double
constantR x n = constantAux cconstantR x n
foreign import ccall "constantR" cconstantR :: Ptr Double -> TV

-- | @n@ copies of a 'Complex' 'Float', filled by the foreign @constantQ@.
constantQ :: Complex Float -> Int -> Vector (Complex Float)
constantQ x n = constantAux cconstantQ x n
foreign import ccall "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV

-- | @n@ copies of a 'Complex' 'Double', filled by the foreign @constantC@.
constantC :: Complex Double -> Int -> Vector (Complex Double)
constantC x n = constantAux cconstantC x n
foreign import ccall "constantC" cconstantC :: Ptr (Complex Double) -> TCV
-- | Generic fill for any 'Storable' type: a vector of @n@ copies of @a@,
-- written by the foreign @constantP@. That routine receives the value
-- through a temporary pointer, together with its size in bytes.
constantP :: Storable a => a -> Int -> Vector a
constantP a n = unsafePerformIO $ do
    v <- createVector n
    unsafeWith v $ \dst ->
        with a $ \src ->
            cconstantP (castPtr src) (fi n) (castPtr dst) (fi (sizeOf a)) // check "constantP"
    return v
foreign import ccall "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
-- | Extract a submatrix, checking bounds first.
--
-- The first pair is the zero-based top-left corner (row, col). The second
-- pair is the block size (rows, cols), and both sizes must be positive.
-- Errors if the block does not fit inside the matrix.
subMatrix :: Element a
          => (Int,Int)   -- ^ (r0, c0): top-left corner
          -> (Int,Int)   -- ^ (rt, ct): number of rows and columns
          -> Matrix a
          -> Matrix a
subMatrix (r0,c0) (rt,ct) m
    | fits r0 rt (rows m) && fits c0 ct (cols m) = subMatrixD (r0,c0) (rt,ct) m
    | otherwise = error $ "wrong subMatrix " ++ show ((r0,c0),(rt,ct))
                       ++ " of " ++ show (rows m) ++ "x" ++ show (cols m)
  where
    fits start len limit = 0 <= start && 0 < len && start + len <= limit
-- | Copy the @rt@ x @ct@ block at (@r0@, @c0@) out of row-major storage
-- @v@, which has @c@ columns. The result is the block's own row-major
-- storage. No bounds checks are made here ('subMatrix' validates them).
subMatrix'' :: Storable t => (Int,Int) -> (Int,Int) -> Int -> Vector t -> Vector t
subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
    -- unsafePerformIO: the output buffer is freshly allocated and only
    -- exposed after the copy loop has filled it.
    w <- createVector (rt*ct)
    unsafeWith v $ \p ->
        unsafeWith w $ \q -> do
            -- Rows rt-1..0, and within each row columns ct-1..0. The
            -- (< 0) guards force the indices and terminate the loop.
            let go i j
                  | i < 0     = return ()
                  | j < 0     = go (i-1) (ct-1)
                  | otherwise = do
                      x <- peekElemOff p ((i+r0)*c+j+c0)
                      pokeElemOff q (i*ct+j) x
                      go i (j-1)
            go (rt-1) (ct-1)
    return w
-- | Default 'subMatrixD': copies the requested block into a new row-major
-- matrix. Column-major inputs are handled by extracting the transposed
-- block from the (row-major) transposed view and transposing back.
subMatrix' :: Storable t => (Int,Int) -> (Int,Int) -> Matrix t -> Matrix t
subMatrix' (r0,c0) (rt,ct) (MC _r c v) = MC rt ct $ subMatrix'' (r0,c0) (rt,ct) c v
subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
-- | Write a 'Double' matrix to a file via the foreign @matrix_fprintf@.
--
-- The C routine receives the file name, the format string, a flag that is
-- 1 for row-major and 0 for column-major storage, and the matrix itself
-- (via 'mat'). The format is presumably a printf-style per-element format
-- (confirm against the C side). 'withCString' scopes both C strings, so
-- they are freed even if the foreign call or 'check' throws.
saveMatrix :: FilePath       -- ^ output file name
           -> String         -- ^ format string for the C routine
           -> Matrix Double  -- ^ matrix to write
           -> IO ()
saveMatrix filename fmt m =
    withCString filename $ \charname ->
    withCString fmt $ \charfmt -> do
        let o = if orderOf m == RowMajor then 1 else 0
        app1 (matrix_fprintf charname charfmt o) mat m "matrix_fprintf"
foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM