{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.BLAS.Matrix.RowMajor (
Matrix,
Vector,
takeRow,
takeColumn,
fromRows,
tensorProduct,
decomplex,
recomplex,
scaleRows,
scaleColumns,
) where
import qualified Numeric.BLAS.Private as Private
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated,Conjugated))
import Numeric.BLAS.Scalar (zero, one)
import Numeric.BLAS.Private (ShapeInt, shapeInt, ComplexShape, pointerSeq)
import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Complex (Complex)
import Data.Foldable (forM_)
-- | A dense matrix stored in row-major order: a storable array whose shape
-- is the pair @(height, width)@, so the elements of one row are adjacent in
-- memory.  The element type parameter is left off (eta-reduced), so
-- @Matrix height width@ may also be used without it.
type Matrix height width = Array (height,width)
-- | A one-dimensional storable array, indexed by its shape; the vector
-- arguments and results of this module use this alias.
type Vector = Array
-- | Copy out the row at index @ix@.
--
-- In row-major storage a row is one contiguous run of @Shape.size width@
-- elements.  It starts at @Shape.size width@ times the linear position of
-- @ix@ (as given by 'Shape.offset'), so a single 'copyArray' suffices.  No
-- BLAS call is involved, so only 'Storable' is required of the elements.
takeRow ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Storable a) =>
   ix -> Matrix height width a -> Vector width a
takeRow ix (Array (height,width) src) =
   Array.unsafeCreateWithSize width $ \rowLen dstPtr ->
      let rowStart = rowLen * Shape.offset height ix
      in  withForeignPtr src $ \srcPtr ->
             copyArray dstPtr (advancePtr srcPtr rowStart) rowLen
-- | Copy out the column at index @ix@.
--
-- Consecutive elements of a column lie @Shape.size width@ apart in
-- row-major storage.  The column is therefore gathered with a strided BLAS
-- @copy@ that starts at the column's offset within the first row and writes
-- into a densely packed result vector.
takeColumn ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   ix -> Matrix height width a -> Vector height a
takeColumn ix (Array (height,width) src) =
   Array.unsafeCreateWithSize height $ \colLen dstPtr ->
   evalContT $ do
      srcPtr <- ContT $ withForeignPtr src
      countPtr <- Call.cint colLen
      -- successive column elements are one full row apart
      srcStridePtr <- Call.cint $ Shape.size width
      -- the result vector is packed
      dstStridePtr <- Call.cint 1
      let firstPtr = advancePtr srcPtr (Shape.offset width ix)
      liftIO $
         Blas.copy countPtr firstPtr srcStridePtr dstPtr dstStridePtr
-- | Build a matrix whose rows are the given vectors, in list order.
--
-- The result height is @shapeInt (length rows)@, so an empty list gives a
-- matrix with no rows.  The list must be finite, since its length is taken.
-- Every vector must have exactly the shape @width@.  Otherwise the
-- @Call.assert@ check fails with the message
-- /Matrix.fromRows: non-matching vector size/.
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> Matrix ShapeInt width a
fromRows width rows =
   Array.unsafeCreate (shapeInt $ length rows, width) $ \dstPtr ->
      let rowLen = Shape.size width
          copyRow rowPtr (Array.Array rowWidth srcFPtr) =
             withForeignPtr srcFPtr $ \srcPtr -> do
                Call.assert
                   "Matrix.fromRows: non-matching vector size"
                   (width == rowWidth)
                copyArray rowPtr srcPtr rowLen
      in  sequence_ $ zipWith copyRow (pointerSeq rowLen dstPtr) rows
-- | Outer (tensor) product of two vectors.
--
-- Element @(i, j)@ of the result is element @i@ of the first
-- (height-indexed) vector times element @j@ of the second (width-indexed)
-- vector.  The 'Conjugation' in @side@ decides whether one factor is
-- complex-conjugated:
--
-- * @Left c@ applies @c@ to the second vector;
--
-- * @Right c@ applies @c@ to the first vector.
--
-- With 'NonConjugated' both alternatives describe the same matrix.
--
-- Implementation: the row-major @height x width@ result has the memory
-- layout of a column-major @width x height@ matrix.  BLAS @gemm@ therefore
-- computes it as the column-major rank-1 product @y * x^T@, with inner
-- dimension 1, @alpha = 1@ and @beta = 0@.  @side@ only selects which
-- operand carries the transpose (and thus possibly conjugate) flag.  The
-- leading dimensions follow from that choice.
tensorProduct ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Either Conjugation Conjugation ->
   Vector height a -> Vector width a -> Matrix height width a
tensorProduct side (Array height x) (Array width y) =
   Array.unsafeCreate (height,width) $ \cPtr ->
   evalContT $ do
      -- dimensions of the column-major view of the result
      let cmRows = Shape.size width
          cmCols = Shape.size height
          flagFor NonConjugated = 'T'
          flagFor Conjugated = 'C'
          (transa, lda, transb, ldb) =
             either
                (\c -> (flagFor c, 1, 'N', 1))
                (\c -> ('N', cmRows, flagFor c, cmCols))
                side
      transaPtr <- Call.char transa
      transbPtr <- Call.char transb
      mPtr <- Call.cint cmRows
      nPtr <- Call.cint cmCols
      -- inner dimension of a rank-1 product
      kPtr <- Call.cint 1
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr y
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT $ withForeignPtr x
      ldbPtr <- Call.leadingDim ldb
      -- beta = 0: the freshly created result buffer is only written
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim cmRows
      liftIO $
         Blas.gemm transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
-- | View a matrix of complex numbers as a real matrix, without copying.
--
-- The width becomes @(width, ComplexShape)@, so every complex entry is seen
-- as a group of real entries.  The underlying buffer is reused through
-- 'castForeignPtr'.  This presumes that 'ComplexShape' matches the
-- 'Storable' layout of 'Complex' (real part, then imaginary part) -- TODO
-- confirm against "Numeric.BLAS.Private".
decomplex ::
   (Class.Real a) =>
   Matrix height width (Complex a) ->
   Matrix height (width, ComplexShape) a
decomplex (Array (height,width) buf) =
   let realShape = (height, (width, Shape.static))
   in  Array realShape (castForeignPtr buf)
-- | Inverse of 'decomplex': view a real matrix whose width is
-- @(width, ComplexShape)@ as a matrix of complex numbers, without copying.
--
-- The 'Shape.NestedTuple' component is matched (and discarded) before the
-- buffer is reinterpreted with 'castForeignPtr'.
recomplex ::
   (Class.Real a) =>
   Matrix height (width, ComplexShape) a ->
   Matrix height width (Complex a)
recomplex (Array shape buf) =
   case shape of
      (height, (width, Shape.NestedTuple _)) ->
         Array (height,width) (castForeignPtr buf)
-- | Scale each row by the corresponding vector element: row @i@ of the
-- result is element @i@ of the vector times row @i@ of the matrix.
--
-- The vector's shape must equal the matrix height.  Otherwise the
-- @Call.assert@ check fails with the message /scaleRows: sizes mismatch/.
-- Each row is first copied into the result with BLAS @copy@ and then
-- scaled in place with BLAS @scal@.
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Matrix height width a
scaleRows (Array heightX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleRows: sizes mismatch" (heightX == height)
      let numRows = Shape.size height
          rowLen = Shape.size width
      evalContT $ do
         countPtr <- Call.cint rowLen
         factorsPtr <- ContT $ withForeignPtr x
         srcPtr <- ContT $ withForeignPtr a
         srcIncPtr <- Call.cint 1
         dstIncPtr <- Call.cint 1
         let scaleRow factorPtr srcRowPtr dstRowPtr = do
                Blas.copy countPtr srcRowPtr srcIncPtr dstRowPtr dstIncPtr
                Blas.scal countPtr factorPtr dstRowPtr dstIncPtr
         liftIO $ sequence_ $ take numRows $
            zipWith3 scaleRow
               (pointerSeq 1 factorsPtr)
               (pointerSeq rowLen srcPtr)
               (pointerSeq rowLen bPtr)
-- | Scale each column by the corresponding vector element: column @j@ of
-- the result is element @j@ of the vector times column @j@ of the matrix.
--
-- The vector's shape must equal the matrix width.  Otherwise the
-- @Call.assert@ check fails with the message /scaleColumns: sizes mismatch/.
--
-- Every row is produced by one banded matrix-vector product
-- ('Private.gbmv'), with no transposition, a square band matrix of order
-- @width@, and no sub- or super-diagonals.  The vector serves as the band
-- storage (leading dimension 1) of its own diagonal matrix, so each output
-- row is the diagonal matrix times the input row.  @beta = 0@, so the
-- output rows are only written.  This assumes 'Private.gbmv' follows the
-- argument order of BLAS @?gbmv@ -- TODO confirm.
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix height width a -> Matrix height width a
scaleColumns (Array widthX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleColumns: sizes mismatch" (widthX == width)
      let numRows = Shape.size height
          rowLen = Shape.size width
      evalContT $ do
         noTransPtr <- Call.char 'N'
         orderPtr <- Call.cint rowLen
         subDiagsPtr <- Call.cint 0
         superDiagsPtr <- Call.cint 0
         alphaPtr <- Call.number one
         diagPtr <- ContT $ withForeignPtr x
         -- only the main diagonal is stored, one element per column
         diagLdPtr <- Call.leadingDim 1
         srcPtr <- ContT $ withForeignPtr a
         srcIncPtr <- Call.cint 1
         betaPtr <- Call.number zero
         dstIncPtr <- Call.cint 1
         let scaleRow srcRowPtr dstRowPtr =
                Private.gbmv noTransPtr
                   orderPtr orderPtr subDiagsPtr superDiagsPtr alphaPtr
                   diagPtr diagLdPtr
                   srcRowPtr srcIncPtr betaPtr dstRowPtr dstIncPtr
         liftIO $ sequence_ $ take numRows $
            zipWith scaleRow
               (pointerSeq rowLen srcPtr)
               (pointerSeq rowLen bPtr)