module Data.Matrix.Dense.Internal (
module BLAS.Matrix.Base,
module BLAS.Tensor,
) where
import Control.Monad ( forM_, zipWithM_ )
import Data.Ix ( inRange, range )
import Foreign
import System.IO.Unsafe
import Unsafe.Coerce
import Data.AEq
import Data.Vector.Dense.Internal hiding ( toForeignPtr, fromForeignPtr,
unsafeFreeze, unsafeThaw, fptr, offset, unsafeWithElemPtr )
import qualified Data.Vector.Dense.Internal as V
import qualified Data.Vector.Dense.Operations as V
import BLAS.Access
import BLAS.Internal ( inlinePerformIO, checkedRow, checkedCol, checkedDiag,
checkedSubmatrix, diagStart, diagLen )
import BLAS.Elem ( Elem, BLAS1 )
import qualified BLAS.Elem as E
import BLAS.Matrix.Base hiding ( Matrix )
import qualified BLAS.Matrix.Base as C
import BLAS.Tensor
import BLAS.Types
-- | A dense matrix stored in column-major order, or a conjugate-transpose
-- view of one.
--
-- The type parameters @t@ (access tag) and @mn@ (shape tag) do not appear in
-- any stored field, so they are phantom; 'unsafeFreeze', 'unsafeThaw' and
-- 'coerceMatrix' rely on this.
data DMatrix t mn e =
      DM { fptr   :: !(ForeignPtr e) -- ^ storage shared with any views
         , offset :: !Int            -- ^ index of element (0,0) in the storage
         , size1  :: !Int            -- ^ number of rows
         , size2  :: !Int            -- ^ number of columns
         , lda    :: !Int            -- ^ distance between consecutive columns
         }
      -- | Conjugate-transpose ('herm') view of another matrix: element
      -- (i,j) is the conjugate of element (j,i) of the wrapped matrix.
    | H !(DMatrix t mn e)
-- | Dense matrix with the 'Imm' access tag; the pure 'ITensor' / 'IDTensor'
-- interfaces in this module are provided for this type.
type Matrix = DMatrix Imm
-- | Dense matrix with the 'Mut' access tag; the in-place 'MTensor'
-- interface in this module is provided for this type.
type IOMatrix = DMatrix Mut
-- | Reinterpret any matrix as an immutable one without copying.
--
-- The access tag is a phantom parameter of 'DMatrix', so the coercion does
-- not change the runtime representation.  It is the caller's job to ensure
-- the storage is not mutated afterwards through another handle.
unsafeFreeze :: DMatrix t mn e -> Matrix mn e
unsafeFreeze mat = unsafeCoerce mat
-- | Reinterpret any matrix as a mutable one without copying.
--
-- Only the phantom access tag changes, so the representation is identical.
-- Writes through the result are visible to every other handle on the same
-- storage.
unsafeThaw :: DMatrix t mn e -> IOMatrix mn e
unsafeThaw mat = unsafeCoerce mat
-- | Change the phantom shape tag of a matrix.
--
-- The shape tag is not stored in any field, so this is representation
-- preserving; the actual dimensions are unchanged.
coerceMatrix :: DMatrix t mn e -> DMatrix t kl e
coerceMatrix mat = unsafeCoerce mat
-- | Build a (non-hermitian) matrix view from storage, the offset of element
-- (0,0), the shape @(rows, columns)@ and the leading dimension.
fromForeignPtr :: ForeignPtr e -> Int -> (Int,Int) -> Int -> DMatrix t (m,n) e
fromForeignPtr storage off (rs,cs) ld =
    DM { fptr = storage, offset = off, size1 = rs, size2 = cs, lda = ld }
-- | Inverse of 'fromForeignPtr': storage, offset, shape and leading
-- dimension.  A hermitian view is looked through, so the shape returned is
-- that of the underlying stored matrix.
toForeignPtr :: DMatrix t (m,n) e -> (ForeignPtr e, Int, (Int,Int), Int)
toForeignPtr (DM storage off rs cs ld) = (storage, off, (rs,cs), ld)
toForeignPtr (H inner)                 = toForeignPtr inner
-- | Leading dimension of the underlying stored matrix (hermitian views are
-- looked through).
ldaOf :: DMatrix t (m,n) e -> Int
ldaOf (DM _ _ _ _ ld) = ld
ldaOf (H inner)       = ldaOf inner
-- | Position of element @(i,j)@ within the storage array.
--
-- Storage is column-major: @offset + i + j * lda@.  For a hermitian view the
-- indices are swapped before looking into the underlying matrix.  No bounds
-- checking is performed.
indexOf :: DMatrix t (m,n) e -> (Int,Int) -> Int
indexOf (DM _ base _ _ ld) (i,j) = base + i + j*ld
indexOf (H inner)          (i,j) = indexOf inner (j,i)
-- | Storage order as seen through any hermitian views: stored matrices are
-- column-major, and each 'H' layer flips the order.
orderOf :: DMatrix t (m,n) e -> Order
orderOf mat = case mat of
    DM {}   -> ColMajor
    H inner -> flipOrder (orderOf inner)
-- | Whether the value is a hermitian view; nested 'H' layers cancel out, so
-- this is 'True' for an odd number of layers.
isHerm :: DMatrix t (m,n) e -> Bool
isHerm mat = case mat of
    DM {}   -> False
    H inner -> not (isHerm inner)
-- | Create a zero matrix of the given shape and set the listed elements,
-- using the bounds-checked 'writeElem'.
newMatrix :: (BLAS1 e) => (Int,Int) -> [((Int,Int), e)] -> IO (DMatrix t (m,n) e)
newMatrix shp assocList = newMatrixHelp writeElem shp assocList
-- | Like 'newMatrix', but writes with 'unsafeWriteElem' (no bounds check).
unsafeNewMatrix :: (BLAS1 e) => (Int,Int) -> [((Int,Int), e)] -> IO (DMatrix t (m,n) e)
unsafeNewMatrix shp assocList = newMatrixHelp unsafeWriteElem shp assocList
-- | Shared implementation of 'newMatrix' and 'unsafeNewMatrix'.
--
-- Allocates a zero-filled matrix of shape @mn@ (via 'newZero') and arranges
-- for every @((i,j), e)@ in @ijes@ to be written with @set@.
--
-- The writes are wrapped in 'unsafeInterleaveIO' and attached to the result
-- with 'seq'.  They therefore run when the returned matrix value is first
-- evaluated, not before this action returns.  NOTE(review): errors raised by
-- a checking @set@ (e.g. out-of-range indices) consequently surface at
-- evaluation time; confirm this deferral is intended.
newMatrixHelp :: (BLAS1 e) =>
    (IOMatrix (m,n) e -> (Int,Int) -> e -> IO ())
    -> (Int,Int) -> [((Int,Int),e)] -> IO (DMatrix t (m,n) e)
newMatrixHelp set mn ijes = do
    x <- newZero mn
    -- Deferred writes; forced by the 'seq' below.
    io <- unsafeInterleaveIO $ mapM_ (uncurry $ set $ unsafeThaw x) ijes
    return $ io `seq` x
-- | Allocate an @m@-by-@n@ matrix with uninitialized contents.
--
-- Storage is contiguous column-major with leading dimension @max 1 m@.
-- Fails with an 'IOError' when either dimension is negative.
newMatrix_ :: (Elem e) => (Int,Int) -> IO (DMatrix t (m,n) e)
newMatrix_ (m,n) =
    if m < 0 || n < 0
        then ioError . userError $
                 "Tried to create a matrix with shape `" ++ show (m,n) ++ "'"
        else do
            storage <- mallocForeignPtrArray (m*n)
            return (fromForeignPtr storage 0 (m,n) (max 1 m))
-- | Create an @m@-by-@n@ matrix from elements listed in column-major order.
--
-- Only the first @m*n@ list elements are used.  If the list is shorter, the
-- remaining entries are left uninitialized (the storage comes from
-- 'newMatrix_').
newListMatrix :: (Elem e) => (Int,Int) -> [e] -> IO (DMatrix t (m,n) e)
newListMatrix (m,n) es = do
    mat <- newMatrix_ (m,n)
    let payload = take (m*n) es
    unsafeWithElemPtr mat (0,0) (\p -> pokeArray p payload)
    return mat
-- | Pure counterpart of 'newListMatrix': an immutable matrix built from
-- column-major elements.  Safe to run with 'unsafePerformIO' because the
-- freshly allocated storage is not shared with anything else.
listMatrix :: (Elem e) => (Int,Int) -> [e] -> Matrix (m,n) e
listMatrix shp xs = unsafePerformIO (newListMatrix shp xs)
-- | Allocate a matrix of the given shape and fill it via 'setIdentity'.
newIdentity :: (BLAS1 e) => (Int,Int) -> IO (DMatrix t (m,n) e)
newIdentity shp =
    newMatrix_ shp >>= \mat -> setIdentity (unsafeThaw mat) >> return mat
-- | Overwrite a matrix with zeros and then put ones on the main diagonal.
-- A matrix with no elements is left untouched.
setIdentity :: (BLAS1 e) => IOMatrix (m,n) e -> IO ()
setIdentity a = do
    count <- getSize a
    if count == 0
        then return ()
        else do
            setZero a
            setConstant 1 (diag a 0)
-- | Create an @m@-by-@n@ matrix whose columns are copied from the given
-- vectors.
--
-- The matrix starts zero-filled, so columns with no corresponding vector
-- stay zero.  Vectors beyond the @n@-th are ignored.  Vector lengths are not
-- checked here; presumably 'V.copyVector' does that (TODO confirm).
newColsMatrix :: (BLAS1 e) => (Int,Int) -> [DVector t m e] -> IO (DMatrix r (m,n) e)
newColsMatrix (m,n) cs = do
    a <- newZero (m,n)
    forM_ (zip [0..(n-1)] cs) $ \(j,c) ->
        V.copyVector (unsafeCol (unsafeThaw a) j) c
    return a
-- | Create an @m@-by-@n@ matrix whose rows are copied from the given
-- vectors.
--
-- The matrix starts zero-filled, so rows with no corresponding vector stay
-- zero.  Vectors beyond the @m@-th are ignored.  Vector lengths are not
-- checked here; presumably 'V.copyVector' does that (TODO confirm).
newRowsMatrix :: (BLAS1 e) => (Int,Int) -> [DVector t n e] -> IO (DMatrix r (m,n) e)
newRowsMatrix (m,n) rs = do
    a <- newZero (m,n)
    forM_ (zip [0..(m-1)] rs) $ \(i,r) ->
        V.copyVector (unsafeRow (unsafeThaw a) i) r
    return a
-- | Run an action with a raw pointer to element @(i,j)@ of the storage.
-- Hermitian views address the transposed element of the underlying matrix
-- (no conjugation is applied to the pointer).  No bounds checking.
unsafeWithElemPtr :: (Elem e) => DMatrix t (m,n) e -> (Int,Int) -> (Ptr e -> IO a) -> IO a
unsafeWithElemPtr mat (i,j) act = case mat of
    H inner -> unsafeWithElemPtr inner (j,i) act
    DM { fptr = storage } ->
        withForeignPtr storage $ \base ->
            act (base `advancePtr` indexOf mat (i,j))
-- | Bounds-checked view of row @i@ (shares storage with the matrix).
row :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t n e
row a i = checkedRow (shape a) (unsafeRow a) i
-- | Views of all rows, top to bottom (each shares storage with the matrix).
-- Empty for a matrix with zero rows.
rows :: (Elem e) => DMatrix t (m,n) e -> [DVector t n e]
rows a = [ unsafeRow a i | i <- [0 .. numRows a - 1] ]
-- | Views of all columns, left to right (each shares storage with the
-- matrix).  Empty for a matrix with zero columns.
cols :: (Elem e) => DMatrix t (m,n) e -> [DVector t m e]
cols a = [ unsafeCol a j | j <- [0 .. numCols a - 1] ]
-- | Bounds-checked view of column @j@ (shares storage with the matrix).
col :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t m e
col a j = checkedCol (shape a) (unsafeCol a) j
-- | View of row @i@ without bounds checking.
--
-- For stored matrices this is a strided vector starting at element (i,0),
-- with 'size2' elements and stride 'lda'.  For a hermitian view it is the
-- conjugate of column @i@ of the underlying matrix.
unsafeRow :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t n e
unsafeRow a i = case a of
    H _ -> conj $ unsafeCol (herm a) i
    DM { fptr = storage, size2 = len, lda = step } ->
        V.fromForeignPtr storage (indexOf a (i,0)) len step
-- | View of column @j@ without bounds checking.
--
-- For stored matrices this is a contiguous (stride 1) vector starting at
-- element (0,j) with 'size1' elements.  For a hermitian view it is the
-- conjugate of row @j@ of the underlying matrix.
unsafeCol :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t m e
unsafeCol a j = case a of
    H _ -> conj $ unsafeRow (herm a) j
    DM { fptr = storage, size1 = len } ->
        V.fromForeignPtr storage (indexOf a (0,j)) len 1
-- | Bounds-checked view of diagonal @i@ (0 is the main diagonal).
diag :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t k e
diag a i = checkedDiag (shape a) (unsafeDiag a) i
-- | View of diagonal @i@ without bounds checking.
--
-- For stored matrices the diagonal starts at 'diagStart' @i@, has 'diagLen'
-- elements and steps by @lda + 1@ (one row and one column at a time).  For
-- a hermitian view it is the conjugate of diagonal @-i@ of the underlying
-- matrix.
unsafeDiag :: (Elem e) => DMatrix t (m,n) e -> Int -> DVector t k e
unsafeDiag a i = case a of
    H inner -> conj $ unsafeDiag inner (negate i)
    DM { fptr = storage, lda = ld } ->
        V.fromForeignPtr storage
                         (indexOf a (diagStart i))
                         (diagLen (shape a) i)
                         (ld + 1)
-- | Bounds-checked view of the submatrix with top-left corner @ij@ and
-- shape @mn@ (shares storage with the matrix).
submatrix :: (Elem e) => DMatrix t (m,n) e -> (Int,Int) -> (Int,Int) -> DMatrix t (k,l) e
submatrix a ij mn = checkedSubmatrix (shape a) (unsafeSubmatrix a) ij mn
-- | View of the submatrix with top-left corner @(i,j)@ and shape @mn'@,
-- without bounds checking.  It keeps the parent's leading dimension.  For a
-- hermitian view the transposed submatrix of the underlying matrix is taken
-- and re-wrapped with 'herm'.
unsafeSubmatrix :: (Elem e) => DMatrix t (m,n) e -> (Int,Int) -> (Int,Int) -> DMatrix t (k,l) e
unsafeSubmatrix a (i,j) mn'@(m',n') = case a of
    H _ -> herm $ unsafeSubmatrix (herm a) (j,i) (n',m')
    DM { fptr = storage, lda = ld } ->
        fromForeignPtr storage (indexOf a (i,j)) mn' ld
-- | View a vector as a 1-by-n matrix without copying, when possible.
--
-- * A plain vector with stride @s@ becomes a 1-by-n matrix with leading
--   dimension @max 1 s@.
-- * A conjugated vector with unit stride becomes the 'herm' of an n-by-1
--   column matrix.
-- * A conjugated vector with any other stride cannot be expressed this way,
--   so the result is 'Nothing'.
-- * Double conjugation cancels out.
maybeFromRow :: (Elem e) => DVector t m e -> Maybe (DMatrix t (one,m) e)
maybeFromRow (V.C (V.C x)) = maybeFromRow x
maybeFromRow (V.C x@(V.DV _ _ _ _))
    | V.stride x == 1 =
        let f = V.fptr x
            o = V.offset x
            n = V.dim x
            l = max 1 n
        in Just $ herm $ fromForeignPtr f o (n,1) l
    | otherwise =
        -- A column of a stored matrix always has stride 1, so a strided
        -- conjugated vector has no hermitian-view representation.
        Nothing
maybeFromRow x@(V.DV _ _ _ _) =
    let f = V.fptr x
        o = V.offset x
        n = V.dim x
        s = V.stride x
        l = max 1 s
    in Just $ fromForeignPtr f o (1,n) l
-- | View a vector as an n-by-1 matrix without copying, when possible.
--
-- * A plain vector with unit stride becomes a column matrix.
-- * A plain vector with any other stride gives 'Nothing': columns of a
--   stored matrix are always contiguous.
-- * A conjugated vector is handled as the 'herm' of its row view
--   ('maybeFromRow').
maybeFromCol :: (Elem e) => DVector t n e -> Maybe (DMatrix t (n,one) e)
maybeFromCol (V.C x) = fmap herm (maybeFromRow x)
maybeFromCol x@(V.DV _ _ _ _)
    | V.stride x == 1 =
        let f = V.fptr x
            o = V.offset x
            m = V.dim x
            l = max 1 m
        in Just $ fromForeignPtr f o (m,1) l
    | otherwise =
        Nothing
-- | View all of a matrix's elements as a single vector without copying,
-- together with the order in which the vector lists them.
--
-- Succeeds when the columns are packed back-to-back (@lda == rows@) or when
-- there is a single row.  Returns 'Nothing' otherwise.  A hermitian view
-- flips the order and conjugates the vector.
maybeToVector :: (Elem e) => DMatrix t (m,n) e -> Maybe (Order, DVector t k e)
maybeToVector (H a) = maybeToVector a >>= (\(o,x) -> return (flipOrder o, conj x))
maybeToVector (DM f o m n ld)
    | ld == m =
        Just $ (ColMajor, V.fromForeignPtr f o (m*n) 1)
    | m == 1 =
        Just $ (ColMajor, V.fromForeignPtr f o n ld)
    | otherwise =
        -- Gaps between columns: elements are not evenly spaced.
        Nothing
-- | Apply a vector action to every element of a matrix.
--
-- If the matrix can be viewed as one vector ('maybeToVector'), the action
-- runs once on that vector.  Otherwise it runs once per row (row-major
-- layout) or per column (column-major layout), following the storage order.
liftV :: (Elem e) => (DVector t k e -> IO ()) -> DMatrix t (m,n) e -> IO ()
liftV f a = maybe sliceWise (\(_, whole) -> f whole) (maybeToVector a)
  where
    sliceWise = case orderOf a of
        RowMajor -> mapM_ f (rows (coerceMatrix a))
        ColMajor -> mapM_ f (cols (coerceMatrix a))
-- | Apply a binary vector action elementwise to two same-shaped matrices.
--
-- When both matrices can be viewed as single vectors listed in the same
-- order, the action runs once.  Otherwise it runs on corresponding
-- rows or columns, chosen by the storage order of the first matrix.
liftV2 :: (Elem e) => (DVector s k e -> DVector t k e -> IO ())
    -> DMatrix s (m,n) e -> DMatrix t (m,n) e -> IO ()
liftV2 f a b =
    case (maybeToVector a, maybeToVector b) of
        (Just (RowMajor,x), Just (RowMajor,y)) -> f x y
        (Just (ColMajor,x), Just (ColMajor,y)) -> f x y
        _                                      -> sliceWise
  where
    sliceWise = case orderOf a of
        RowMajor -> zipWithM_ f (rows (coerceMatrix a)) (rows (coerceMatrix b))
        ColMajor -> zipWithM_ f (cols (coerceMatrix a)) (cols (coerceMatrix b))
-- | Dimensions and hermitian view.  'herm' removes an existing 'H' layer
-- instead of stacking a second one.
instance C.Matrix (DMatrix t) where
    numRows a = fst (shape a)
    numCols a = snd (shape a)
    herm (H inner) = coerceMatrix inner
    herm a         = H (coerceMatrix a)
-- | Shape and index bounds.  A hermitian view has the transposed shape of
-- the matrix it wraps.  Bounds are inclusive, so a matrix with a zero
-- dimension has an empty index range.
instance Tensor (DMatrix t (m,n)) (Int,Int) e where
    shape a = case a of
        (H a')          -> case shape a' of (m,n) -> (n,m)
        (DM _ _ m n _)  -> (m,n)
    bounds a = let (m,n) = shape a in ((0,0), (m-1,n-1))
-- | Pure, read-only interface for immutable matrices.  Elements, indices and
-- associations are all listed in column-major order.
instance (BLAS1 e) => ITensor (DMatrix Imm (m,n)) (Int,Int) e where
    size a = (numRows a * numCols a)
    unsafeAt a = inlinePerformIO . unsafeReadElem a
    indices a = [ (i,j) | j <- range (0,n-1), i <- range (0,m-1) ]
      where (m,n) = shape a
    elems = inlinePerformIO . getElems
    assocs = inlinePerformIO . getAssocs
    -- Updates copy the matrix first (see 'replaceHelp').
    (//) = replaceHelp writeElem
    unsafeReplace = replaceHelp unsafeWriteElem
    amap f a = listMatrix (shape a) (map f $ elems a)
    azipWith f a b
        | shape b /= mn =
            error ("azipWith: matrix shapes differ; first has shape `"
                   ++ show mn ++ "' and second has shape `"
                   ++ show (shape b) ++ "'")
        | otherwise =
            listMatrix mn (zipWith f (elems a) (elems b))
      where
        mn = shape a
-- | Shared implementation of '(//)' and 'unsafeReplace': copy the matrix,
-- apply every @(index, value)@ write to the copy with @set@, and return the
-- copy frozen.  The input matrix is not modified.
replaceHelp :: (BLAS1 e) =>
    (IOMatrix (m,n) e -> (Int,Int) -> e -> IO ())
    -> Matrix (m,n) e -> [((Int,Int), e)] -> Matrix (m,n) e
replaceHelp set x ies = unsafeFreeze (unsafePerformIO copyAndWrite)
  where
    copyAndWrite = do
        fresh <- newCopy (unsafeThaw x)
        mapM_ (\(ij, v) -> set fresh ij v) ies
        return fresh
-- | Pure constructors for immutable matrices, backed by freshly allocated
-- storage.
instance (BLAS1 e) => IDTensor (DMatrix Imm (m,n)) (Int,Int) e where
    zero shp = unsafePerformIO (newZero shp)
    constant shp v = unsafePerformIO (newConstant shp v)
-- | 'IO'-based read access for matrices of any access type.
instance (BLAS1 e) => RTensor (DMatrix t (m,n)) (Int,Int) e IO where
    -- Number of elements: rows times columns.
    getSize a = return (numRows a * numCols a)
    -- Copy into freshly allocated storage.  For a hermitian view, the
    -- underlying matrix is copied and the copy is re-wrapped in 'H'.
    newCopy a = case a of
        (H a') -> newCopy a' >>= return . H
        _ -> do
            a' <- newMatrix_ (shape a)
            liftV2 V.copyVector (unsafeThaw a') a
            return a'
    -- Element (i,j); a hermitian view reads (j,i) of the wrapped matrix and
    -- conjugates it.  No bounds checking.
    unsafeReadElem a (i,j) = case a of
        (H a') -> unsafeReadElem a' (j,i) >>= return . E.conj
        _ -> withForeignPtr (fptr a) $ \ptr ->
                 peekElemOff ptr (indexOf a (i,j))
    getIndices = return . indices . unsafeFreeze
    -- Elements column by column.  Each column's read is performed with
    -- 'inlinePerformIO' only when that part of the result list is demanded,
    -- so for a mutable matrix later writes may show up in the result.
    -- NOTE(review): confirm this lazy reading is intended.
    getElems a = return $ go (cols a)
      -- The first equation only forces the list to WHNF; its guard is
      -- always False, so matching falls through to the real equations.
      where go cs | cs `seq` False = undefined
            go [] = []
            go (c:cs) =
                let e = inlinePerformIO $ getElems c
                    es = go cs
                in e ++ es
    -- Same lazy scheme as 'getElems', tagging each column-relative index i
    -- with the column number j to produce ((i,j), e).
    getAssocs a = return $ go (cols a) 0
      -- Forces the list and the column counter before matching.
      where go cs j | cs `seq` j `seq` False = undefined
            go [] _ = []
            go (c:cs) j =
                let ie = inlinePerformIO $ getAssocs c
                    ije = map (\(i,e) -> ((i,j),e)) ie
                    ijes = go cs (j+1)
                in ije ++ ijes
-- | 'IO' constructors: allocate with 'newMatrix_', then fill with zeros or a
-- constant.
instance (BLAS1 e) => RDTensor (DMatrix t (m,n)) (Int,Int) e IO where
    newZero shp =
        newMatrix_ shp >>= \mat -> setZero (unsafeThaw mat) >> return mat
    newConstant shp v =
        newMatrix_ shp >>= \mat -> setConstant v (unsafeThaw mat) >> return mat
-- | In-place modification of mutable matrices.  Whole-matrix operations are
-- delegated to the vector implementations through 'liftV'.
instance (BLAS1 e) => MTensor (DMatrix Mut (m,n)) (Int,Int) e IO where
    setZero a = liftV setZero a
    setConstant v a = liftV (setConstant v) a
    -- Any index inside the matrix bounds may be written.
    canModifyElem a ij = return (inRange (bounds a) ij)
    -- A hermitian view writes the conjugated value to the transposed
    -- position of the wrapped matrix.
    unsafeWriteElem (H inner) (i,j) v = unsafeWriteElem inner (j,i) (E.conj v)
    unsafeWriteElem a@(DM { fptr = storage }) (i,j) v =
        withForeignPtr storage $ \base -> pokeElemOff base (indexOf a (i,j)) v
    modifyWith g a = liftV (modifyWith g) a
-- | Shows a matrix as a @listMatrix@ expression (column-major elements);
-- a hermitian view is shown as @herm (...)@ around the wrapped matrix.
instance (BLAS1 e, Show e) => Show (DMatrix Imm (m,n) e) where
    show (H inner) = "herm (" ++ show inner ++ ")"
    show a         = "listMatrix " ++ show (shape a) ++ " " ++ show (elems a)
-- | Compare two matrices elementwise with @cmp@; shapes must be equal.
-- When both are hermitian views, the views are stripped off first so that
-- the underlying matrices are compared directly.
compareHelp :: (BLAS1 e) =>
    (e -> e -> Bool) -> Matrix (m,n) e -> Matrix (m,n) e -> Bool
compareHelp cmp x y
    | isHerm x && isHerm y = compareHelp cmp (herm x) (herm y)
    | otherwise            =
        shape x == shape y && and (zipWith cmp (elems x) (elems y))
-- | Exact equality: same shape and equal elements.
instance (BLAS1 e, Eq e) => Eq (DMatrix Imm (m,n) e) where
    x == y = compareHelp (==) x y
-- | Elementwise exact ('===') and approximate ('~==') comparison, requiring
-- identical shapes.
instance (BLAS1 e, AEq e) => AEq (DMatrix Imm (m,n) e) where
    x === y = compareHelp (===) x y
    x ~== y = compareHelp (~==) x y