{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Basic where
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.RowMajor as RowMajor
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Shape.Private
(Order(RowMajor, ColumnMajor), transposeFromOrder, flipOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private
(Full, Tall, Wide, General, ShapeInt, revealOrder)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, one)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private
(pointerSeq, copyTransposed, copySubMatrix, copyBlock)
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))
import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Data.Complex (Complex)
-- | Split a full matrix into the 'Tall' or the 'Wide' case.
--
-- The decision is delegated to 'MatrixShape.caseTallWide' on the shape;
-- the element buffer is reused unchanged in whichever result is produced.
caseTallWide ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
   Full vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (Array shape buffer) =
   case MatrixShape.caseTallWide shape of
      Left tallShape -> Left (Array tallShape buffer)
      Right wideShape -> Right (Array wideShape buffer)
-- | Swap the roles of rows and columns.
--
-- Only the shape descriptor is rewritten (with 'MatrixShape.transpose');
-- 'Array.mapShape' passes the element buffer through untouched.
transpose ::
   (Extent.C vert, Extent.C horiz) =>
   Full vert horiz height width a -> Full horiz vert width height a
transpose matrix = Array.mapShape MatrixShape.transpose matrix
-- | Conjugate transpose: every element is conjugated with
-- 'Vector.conjugate', then rows and columns are swapped with 'transpose'.
adjoint ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a -> Full horiz vert width height a
adjoint matrix = transpose (Vector.conjugate matrix)
-- | Adapt a multiplication whose full-matrix operand comes second so that
-- the full matrix can be passed first.
--
-- @swapMultiply multiplyTrans a b@ transposes @a@, calls
-- @multiplyTrans b@ on that transpose, and transposes the result back.
-- Both transpositions only rewrite shapes; the actual work is done by
-- @multiplyTrans@.
swapMultiply ::
   (Extent.C vertA, Extent.C vertB, Extent.C horizA, Extent.C horizB) =>
   (matrix ->
    Full horizA vertA widthA heightA a ->
    Full horizB vertB widthB heightB a) ->
   Full vertA horizA heightA widthA a ->
   matrix ->
   Full vertB horizB heightB widthB a
swapMultiply multiplyTrans a b =
   let aTransposed = transpose a
       resultTransposed = multiplyTrans b aTransposed
   in transpose resultTransposed
-- | Apply a function to the height shape of a full matrix.
--
-- The order, the width and the element buffer are left unchanged; the
-- height is rewritten with 'Extent.mapHeight'. Nothing here checks that the
-- new height is consistent with the buffer size.
mapHeight ::
   (Extent.GeneralTallWide vert horiz,
    Extent.GeneralTallWide horiz vert) =>
   (heightA -> heightB) ->
   Full vert horiz heightA width a ->
   Full vert horiz heightB width a
mapHeight f = Array.mapShape remapHeight
   where
      remapHeight (MatrixShape.Full order extent) =
         MatrixShape.Full order (Extent.mapHeight f extent)
-- | Apply a function to the width shape of a full matrix.
--
-- The order, the height and the element buffer are left unchanged; the
-- width is rewritten with 'Extent.mapWidth'. Nothing here checks that the
-- new width is consistent with the buffer size.
mapWidth ::
   (Extent.GeneralTallWide vert horiz,
    Extent.GeneralTallWide horiz vert) =>
   (widthA -> widthB) ->
   Full vert horiz height widthA a ->
   Full vert horiz height widthB a
mapWidth f = Array.mapShape remapWidth
   where
      remapWidth (MatrixShape.Full order extent) =
         MatrixShape.Full order (Extent.mapWidth f extent)
-- | Wrap both the height and the width shape in 'Unchecked' (via
-- 'Extent.mapWrap'); the order and the element buffer are kept.
--
-- NOTE(review): presumably 'Unchecked' relaxes shape-equality checks in
-- later operations — confirm against "Numeric.LAPACK.Shape.Private".
uncheck ::
   (Extent.C vert, Extent.C horiz) =>
   Full vert horiz height width a ->
   Full vert horiz (Unchecked height) (Unchecked width) a
uncheck matrix = Array.mapShape wrapShape matrix
   where
      wrapShape (MatrixShape.Full order extent) =
         MatrixShape.Full order (Extent.mapWrap Unchecked Unchecked extent)
-- | Counterpart of 'uncheck': strip the 'Unchecked' wrappers from height
-- and width using 'Extent.recheck'. The order and buffer are kept.
recheck ::
   (Extent.C vert, Extent.C horiz) =>
   Full vert horiz (Unchecked height) (Unchecked width) a ->
   Full vert horiz height width a
recheck matrix = Array.mapShape unwrapShape matrix
   where
      unwrapShape (MatrixShape.Full order extent) =
         MatrixShape.Full order (Extent.recheck extent)
-- | View a vector as a general matrix with the unit height @()@, i.e. a
-- single row, tagged with the given storage order. No data is copied.
singleRow :: Order -> Vector width a -> General () width a
singleRow order vector =
   Array.mapShape (\width -> MatrixShape.general order () width) vector
-- | View a vector as a general matrix with the unit width @()@, i.e. a
-- single column, tagged with the given storage order. No data is copied.
singleColumn :: Order -> Vector height a -> General height () a
singleColumn order vector =
   Array.mapShape (\height -> MatrixShape.general order height ()) vector
-- | Counterpart of 'singleRow': drop the unit height and keep only the
-- width shape (via 'MatrixShape.fullWidth'). No data is copied.
flattenRow :: General () width a -> Vector width a
flattenRow matrix = Array.mapShape MatrixShape.fullWidth matrix
-- | Counterpart of 'singleColumn': drop the unit width and keep only the
-- height shape (via 'MatrixShape.fullHeight'). No data is copied.
flattenColumn :: General height () a -> Vector height a
flattenColumn matrix = Array.mapShape MatrixShape.fullHeight matrix
-- | Lift a vector transformation to single-row matrices.
--
-- The row is flattened with 'flattenRow', transformed by @f@, and turned
-- back into a single-row matrix of the given 'Order' with 'singleRow'.
-- The type variables name the row /widths/; the height is always @()@.
liftRow ::
   Order ->
   (Vector width0 a -> Vector width1 b) ->
   General () width0 a -> General () width1 b
liftRow order f = singleRow order . f . flattenRow
-- | Lift a vector transformation to single-column matrices.
--
-- The column is flattened with 'flattenColumn', transformed by @f@, and
-- re-wrapped as a single column of the given 'Order' with 'singleColumn'.
liftColumn ::
   Order ->
   (Vector height0 a -> Vector height1 b) ->
   General height0 () a -> General height1 () b
liftColumn order f column = singleColumn order (f (flattenColumn column))
-- | Turn a transformation of single-row matrices into a vector
-- transformation.
--
-- The vector is wrapped as a row of the given 'Order' with 'singleRow',
-- passed to @f@, and the resulting row is flattened with 'flattenRow'.
-- The type variables name the row /widths/; the height is always @()@.
unliftRow ::
   Order ->
   (General () width0 a -> General () width1 b) ->
   Vector width0 a -> Vector width1 b
unliftRow order f = flattenRow . f . singleRow order
-- | Turn a transformation of single-column matrices into a vector
-- transformation.
--
-- The vector is wrapped with 'singleColumn' in the given 'Order', passed to
-- @f@, and the result is flattened with 'flattenColumn'.
unliftColumn ::
   Order ->
   (General height0 () a -> General height1 () b) ->
   Vector height0 a -> Vector height1 b
unliftColumn order f vector = flattenColumn (f (singleColumn order vector))
-- | Return the same matrix with row-major storage.
--
-- A row-major input is returned as it is. A column-major input is copied
-- into a freshly allocated row-major buffer with 'Private.copyTransposed'.
forceRowMajor ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a
forceRowMajor matrix@(Array (MatrixShape.Full order extent) x) =
   case order of
      RowMajor -> matrix
      ColumnMajor ->
         Array.unsafeCreate (MatrixShape.Full RowMajor extent) $ \yPtr ->
         withForeignPtr x $ \xPtr ->
            -- transposing copy from the column-major source into the
            -- row-major destination, whose row length is 'cols'
            Private.copyTransposed cols rows xPtr cols yPtr
   where
      (height, width) = Extent.dimensions extent
      rows = Shape.size height
      cols = Shape.size width
-- | Return the same matrix stored in the requested 'Order'.
--
-- Column-major storage is obtained by forcing row-major storage of the
-- transpose. 'transpose' only rewrites the shape, and a row-major
-- transpose has exactly the memory layout of the column-major original.
forceOrder ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Order ->
   Full vert horiz height width a ->
   Full vert horiz height width a
forceOrder RowMajor = forceRowMajor
forceOrder ColumnMajor = transpose . forceRowMajor . transpose
-- | Copy a band of consecutive rows out of the buffer of a larger matrix.
--
-- * @heightA@: height of the source matrix whose buffer is @a@.
-- * @k@: index of the first source row to copy.
-- * @shape@: shape of the result.
--
-- The source is read in the storage order given by @shape@, and its row
-- length / column count is taken from the result width. The source is
-- therefore assumed to have the same order and width as the result.
takeSub ::
   (Extent.C vert, Extent.C horiz,
    Shape.C heightA, Shape.C height, Shape.C width, Class.Floating a) =>
   heightA -> Int -> ForeignPtr a ->
   MatrixShape.Full vert horiz height width ->
   Full vert horiz height width a
takeSub heightA k a shape@(MatrixShape.Full order extentB) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr ->
   withForeignPtr a $ \aPtr ->
      case order of
         RowMajor ->
            -- rows are contiguous: skip k rows, then copy one block
            copyBlock blockSize (advancePtr aPtr (k * cols)) bPtr
         ColumnMajor ->
            -- start at row k; source columns have length rowsA,
            -- result columns have length rowsB
            copySubMatrix rowsB cols rowsA (advancePtr aPtr k) rowsB bPtr
   where
      rowsA = Shape.size heightA
      rowsB = Shape.size (Extent.height extentB)
      cols = Shape.size (Extent.width extentB)
-- | Copy the rows belonging to the first summand of a height of the form
-- @height0 :+: height1@.
takeTop ::
   (Extent.C vert, Shape.C height0, Shape.C height1, Shape.C width,
    Class.Floating a) =>
   Full vert Extent.Big (height0:+:height1) width a ->
   Full vert Extent.Big height0 width a
takeTop (Array (MatrixShape.Full order extentA) a) =
   takeSub heightA 0 a (MatrixShape.Full order extentB)
   where
      heightA@(heightTop :+: _) = Extent.height extentA
      extentB = Extent.reduceWideHeight heightTop extentA
-- | Copy the rows belonging to the second summand of a height of the form
-- @height0 :+: height1@, i.e. skip the first @Shape.size height0@ rows.
takeBottom ::
   (Extent.C vert, Shape.C height0, Shape.C height1, Shape.C width,
    Class.Floating a) =>
   Full vert Extent.Big (height0:+:height1) width a ->
   Full vert Extent.Big height1 width a
takeBottom (Array (MatrixShape.Full order extentA) a) =
   takeSub heightA rowsSkipped a (MatrixShape.Full order extentB)
   where
      heightA@(heightTop :+: heightBottom) = Extent.height extentA
      rowsSkipped = Shape.size heightTop
      extentB = Extent.reduceWideHeight heightBottom extentA
-- | Copy the columns belonging to the first summand of a width of the form
-- @width0 :+: width1@, by taking the top rows of the transpose.
takeLeft ::
   (Extent.C vert, Shape.C height, Shape.C width0, Shape.C width1,
    Class.Floating a) =>
   Full Extent.Big vert height (width0:+:width1) a ->
   Full Extent.Big vert height width0 a
takeLeft matrix = transpose (takeTop (transpose matrix))
-- | Copy the columns belonging to the second summand of a width of the
-- form @width0 :+: width1@, by taking the bottom rows of the transpose.
takeRight ::
   (Extent.C vert, Shape.C height, Shape.C width0, Shape.C width1,
    Class.Floating a) =>
   Full Extent.Big vert height (width0:+:width1) a ->
   Full Extent.Big vert height width1 a
takeRight matrix = transpose (takeBottom (transpose matrix))
-- | Re-type a 'ShapeInt' height as the two-part height produced by
-- @'Shape.zeroBasedSplit' k@; no element data is copied.
--
-- How @k@ outside the valid range is treated is decided by
-- 'Shape.zeroBasedSplit'.
splitRows ::
   (Extent.C vert, Shape.C width, Class.Floating a) =>
   Int ->
   Full vert Extent.Big ShapeInt width a ->
   Full vert Extent.Big (ShapeInt:+:ShapeInt) width a
splitRows k = Array.mapShape splitHeight
   where
      splitHeight (MatrixShape.Full order extent) =
         MatrixShape.Full order $
            Extent.reduceWideHeight
               (Shape.zeroBasedSplit k (Extent.height extent))
               extent
-- | 'takeRows' copies the first @k@ rows; 'dropRows' copies the rows after
-- them. Both split the height with 'splitRows' and then copy the
-- requested part.
takeRows, dropRows ::
   (Extent.C vert, Shape.C width, Class.Floating a) =>
   Int ->
   Full vert Extent.Big ShapeInt width a ->
   Full vert Extent.Big ShapeInt width a
takeRows k matrix = takeTop (splitRows k matrix)
dropRows k matrix = takeBottom (splitRows k matrix)
-- | 'takeColumns' copies the first @k@ columns; 'dropColumns' copies the
-- columns after them. Both are the row operations applied to the
-- transpose.
takeColumns, dropColumns ::
   (Extent.C horiz, Shape.C height, Class.Floating a) =>
   Int ->
   Full Extent.Big horiz height ShapeInt a ->
   Full Extent.Big horiz height ShapeInt a
takeColumns k matrix = transpose (takeRows k (transpose matrix))
dropColumns k matrix = transpose (dropRows k (transpose matrix))
-- | Preference for the storage order of the result of 'beside' (and,
-- through transposition, 'above') when the two operands are stored in
-- different orders. If both operands share an order, the bias is ignored.
data OrderBias
   = LeftBias       -- ^ use the storage order of the left operand
   | RightBias      -- ^ use the storage order of the right operand
   | ContiguousBias -- ^ use the order in which each operand fills one
                    -- contiguous block of the result (column-major for
                    -- 'beside')
   deriving (Eq, Ord, Enum, Show)
-- | Place two full matrices side by side: the columns of the first operand
-- are followed by the columns of the second.
--
-- The heights must be equal (compared with '/='); otherwise 'error' is
-- called. The combined extent is built by the 'Extent.AppendMode'
-- function. The storage order of the result is:
--
-- * both operands row-major: row-major, filled row by row;
-- * both operands column-major: column-major, the two buffers copied one
--   after the other;
-- * mixed orders: chosen by the 'OrderBias'. 'LeftBias' keeps the left
--   operand's order, 'RightBias' keeps the right operand's order, and any
--   other bias gives column-major, where each operand is one contiguous
--   block of the result.
beside ::
   (Extent.C vertA, Extent.C vertB, Extent.C vertC,
    Shape.C height, Eq height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   OrderBias ->
   Extent.AppendMode vertA vertB vertC height widthA widthB ->
   Full vertA Extent.Big height widthA a ->
   Full vertB Extent.Big height widthB a ->
   Full vertC Extent.Big height (widthA:+:widthB) a
beside orderBias (Extent.AppendMode appendMode)
      (Array (MatrixShape.Full orderA extentA) a)
      (Array (MatrixShape.Full orderB extentB) b)
   | heightA /= heightB = error "beside: mismatching heights"
   | otherwise =
      case (orderA, orderB) of
         (RowMajor, RowMajor) ->
            -- every result row is a row of A followed by the same row of B
            create RowMajor $ \aPtr bPtr cPtr _ -> evalContT $ do
               colsAPtr <- Call.cint colsA
               colsBPtr <- Call.cint colsB
               incxPtr <- Call.cint 1
               incyPtr <- Call.cint 1
               liftIO $ sequence_ $ take rows $
                  zipWith3
                     (\aRowPtr bRowPtr cRowPtr -> do
                        BlasGen.copy colsAPtr aRowPtr incxPtr cRowPtr incyPtr
                        BlasGen.copy colsBPtr bRowPtr incxPtr
                           (advancePtr cRowPtr colsA) incyPtr)
                     (pointerSeq colsA aPtr)
                     (pointerSeq colsB bPtr)
                     (pointerSeq cols cPtr)
         (RowMajor, ColumnMajor)
            | orderBias == LeftBias ->
               create RowMajor $ \aPtr bPtr leftPtr rightPtr -> do
                  copySubMatrix colsA rows colsA aPtr cols leftPtr
                  copyTransposed colsB rows bPtr cols rightPtr
            | otherwise ->
               create ColumnMajor $ \aPtr bPtr leftPtr rightPtr -> do
                  copyTransposed rows colsA aPtr rows leftPtr
                  copyBlock sizeB bPtr rightPtr
         (ColumnMajor, RowMajor)
            | orderBias == RightBias ->
               create RowMajor $ \aPtr bPtr leftPtr rightPtr -> do
                  copyTransposed colsA rows aPtr cols leftPtr
                  copySubMatrix colsB rows colsB bPtr cols rightPtr
            | otherwise ->
               create ColumnMajor $ \aPtr bPtr leftPtr rightPtr -> do
                  copyBlock sizeA aPtr leftPtr
                  copyTransposed rows colsB bPtr rows rightPtr
         (ColumnMajor, ColumnMajor) ->
            -- A's buffer is followed directly by B's buffer
            create ColumnMajor $ \aPtr bPtr leftPtr rightPtr -> evalContT $ do
               sizeAPtr <- Call.cint sizeA
               sizeBPtr <- Call.cint sizeB
               incxPtr <- Call.cint 1
               incyPtr <- Call.cint 1
               liftIO $ do
                  BlasGen.copy sizeAPtr aPtr incxPtr leftPtr incyPtr
                  BlasGen.copy sizeBPtr bPtr incxPtr rightPtr incyPtr
   where
      (heightA, widthA) = Extent.dimensions extentA
      (heightB, widthB) = Extent.dimensions extentB
      rows = Shape.size heightA
      colsA = Shape.size widthA
      colsB = Shape.size widthB
      cols = colsA + colsB
      sizeA = rows * colsA
      sizeB = rows * colsB
      -- Allocate the result in the given order. The worker receives the two
      -- source buffers, the start of the result, and the start of B's part:
      -- after the first row's A part for row-major storage, after all of A
      -- for column-major storage.
      create order act =
         Array.unsafeCreate
            (MatrixShape.Full order (appendMode extentA extentB)) $ \cPtr ->
         withForeignPtr a $ \aPtr ->
         withForeignPtr b $ \bPtr ->
            act aPtr bPtr cPtr (advancePtr cPtr (rightOffset order))
      rightOffset RowMajor = colsA
      rightOffset ColumnMajor = sizeA
-- | Stack two full matrices vertically: the rows of the first operand are
-- followed by the rows of the second.
--
-- Implemented as 'beside' on the transposes, so the height check and the
-- 'OrderBias' rules of 'beside' apply to the transposed operands.
above ::
   (Extent.C horizA, Extent.C horizB, Extent.C horizC,
    Shape.C width, Eq width, Shape.C heightA, Shape.C heightB,
    Class.Floating a) =>
   OrderBias ->
   Extent.AppendMode horizA horizB horizC width heightA heightB ->
   Full Extent.Big horizA heightA width a ->
   Full Extent.Big horizB heightB width a ->
   Full Extent.Big horizC (heightA:+:heightB) width a
above orderBias appendMode a b =
   let aT = transpose a
       bT = transpose b
   in transpose (beside orderBias appendMode aT bT)
-- | Run a buffer-level transformation on a full matrix, choosing the
-- variant that matches the layout reported by 'revealOrder'.
--
-- @fr@ receives the buffer viewed as a @(height, width)@ array (the 'Left'
-- case of 'revealOrder', presumably row-major storage). @fc@ receives it
-- viewed as a @(width, height)@ array (the 'Right' case, presumably
-- column-major storage). Either result is reshaped back to the original
-- matrix shape, including its order, so both functions must preserve the
-- element layout.
liftRowMajor ::
   (Extent.C vert, Extent.C horiz) =>
   (Array (height, width) a -> Array (height, width) b) ->
   (Array (width, height) a -> Array (width, height) b) ->
   Full vert horiz height width a ->
   Full vert horiz height width b
liftRowMajor fr fc a =
   case revealOrder a of
      Left rowView -> Array.reshape (Array.shape a) (fr rowView)
      Right columnView -> Array.reshape (Array.shape a) (fc columnView)
-- | Multiply every row of the matrix by the corresponding element of @x@.
--
-- For the @(height, width)@ view this is 'RowMajor.scaleRows'. In the
-- transposed @(width, height)@ view the rows of the matrix are columns,
-- so 'RowMajor.scaleColumns' is used instead.
scaleRows ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleRows x =
   liftRowMajor
      (RowMajor.scaleRows x)
      (RowMajor.scaleColumns x)
-- | Multiply every column of the matrix by the corresponding element of
-- @x@, by scaling the rows of the transpose.
scaleColumns ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleColumns x a = transpose (scaleRows x (transpose a))
-- | Multiply every row of a complex matrix by a real factor from @x@.
--
-- The complex buffer is reinterpreted as a real array with
-- 'RowMajor.decomplex', scaled, and turned back with 'RowMajor.recomplex'.
-- In the transposed @(width, height)@ view every factor is first paired
-- with 'Vector.one' over a 'Shape.Enumeration' via 'RowMajor.tensorProduct'.
-- NOTE(review): presumably this makes the real and imaginary parts receive
-- the same factor — confirm against 'RowMajor.decomplex'.
scaleRowsComplex ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Real a) =>
   Vector height a ->
   Full vert horiz height width (Complex a) ->
   Full vert horiz height width (Complex a)
scaleRowsComplex x = liftRowMajor scaleStoredRows scaleStoredColumns
   where
      scaleStoredRows =
         RowMajor.recomplex . RowMajor.scaleRows x . RowMajor.decomplex
      scaleStoredColumns =
         RowMajor.recomplex . RowMajor.scaleColumns xPerPart .
         RowMajor.decomplex
      xPerPart =
         RowMajor.tensorProduct (Left NonConjugated) x
            (Vector.one Shape.Enumeration)
-- | Multiply every column of a complex matrix by a real factor from @x@,
-- by scaling the rows of the transpose with 'scaleRowsComplex'.
scaleColumnsComplex ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Real a) =>
   Vector width a ->
   Full vert horiz height width (Complex a) ->
   Full vert horiz height width (Complex a)
scaleColumnsComplex x a = transpose (scaleRowsComplex x (transpose a))
-- | Multiply every row of the matrix by a real factor, for any supported
-- element type.
--
-- 'Class.switchFloating' selects the implementation by element type. The
-- two real cases use 'scaleRows' and the two complex cases use
-- 'scaleRowsComplex'. 'ScaleRowsReal' only provides the common type
-- constructor needed for that dispatch.
scaleRowsReal ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Vector height (RealOf a) ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleRowsReal =
   getScaleRowsReal
      (Class.switchFloating
         (ScaleRowsReal scaleRows)
         (ScaleRowsReal scaleRows)
         (ScaleRowsReal scaleRowsComplex)
         (ScaleRowsReal scaleRowsComplex))
-- | Function wrapper that lets 'Class.switchFloating' pick a row scaling
-- per element type @a@; the scaling factors have type @'RealOf' a@.
newtype ScaleRowsReal f g a =
   ScaleRowsReal {
      getScaleRowsReal :: f (RealOf a) -> g a -> g a
   }
-- | Multiply every column of the matrix by a real factor, by scaling the
-- rows of the transpose with 'scaleRowsReal'.
scaleColumnsReal ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width (RealOf a) ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleColumnsReal x a = transpose (scaleRowsReal x (transpose a))
-- | Matrix-vector product.
--
-- The vector's shape must equal the matrix width (compared with '=='),
-- otherwise 'error' is called. The work is done by
-- 'multiplyVectorUnchecked'.
multiplyVector ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a -> Vector width a -> Vector height a
multiplyVector a x
   | MatrixShape.fullWidth (Array.shape a) == Array.shape x =
      multiplyVectorUnchecked a x
   | otherwise = error "multiplyVector: width shapes mismatch"
-- | Matrix-vector product without checking the vector's shape.
--
-- Calls 'Private.gemv' with @alpha = one@ and @beta = zero@. The transpose
-- flag comes from the storage order ('transposeFromOrder'), so the matrix
-- buffer is used in its native layout. @(m, n)@ are the dimensions reported
-- by 'MatrixShape.dimensions', and the leading dimension is @m@. The
-- result has the matrix height as its shape.
multiplyVectorUnchecked ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a -> Vector width a -> Vector height a
multiplyVectorUnchecked
      (Array shape@(MatrixShape.Full order extent) a) (Array _ x) =
   Array.unsafeCreate (Extent.height extent) $ \yPtr ->
   evalContT $ do
      transPtr <- Call.char (transposeFromOrder order)
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      alphaPtr <- Call.number one
      aPtr <- ContT (withForeignPtr a)
      ldaPtr <- Call.leadingDim m
      xPtr <- ContT (withForeignPtr x)
      incxPtr <- Call.cint 1
      betaPtr <- Call.number zero
      incyPtr <- Call.cint 1
      liftIO $
         Private.gemv
            transPtr mPtr nPtr alphaPtr aPtr ldaPtr
            xPtr incxPtr betaPtr yPtr incyPtr
   where
      (m, n) = MatrixShape.dimensions shape
-- | Multiply two full matrices that share the inner ("fuse") shape.
--
-- The extents are combined with 'Extent.fuse'; if the fuse shapes do not
-- match, 'error' is called with a message naming the failing function.
--
-- * 'multiply' stores the result in the order of the right operand. For a
--   row-major result, 'Private.multiplyMatrix' is called with the operands
--   swapped and both orders flipped. Row-major @A*B@ has the memory layout
--   of column-major @(B^T)*(A^T)@.
-- * 'multiplyColumnMajor' always produces a column-major result.
multiply, multiplyColumnMajor ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height,
    Shape.C fuse, Eq fuse,
    Shape.C width,
    Class.Floating a) =>
   Full vert horiz height fuse a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
multiply
      (Array (MatrixShape.Full orderA extentA) a)
      (Array (MatrixShape.Full orderB extentB) b) =
   case Extent.fuse extentA extentB of
      Nothing -> error "multiply: fuse shapes mismatch"
      Just extent ->
         Array.unsafeCreate (MatrixShape.Full orderB extent) $ \cPtr -> do
            let (height,fuse) = Extent.dimensions extentA
            let width = Extent.width extentB
            let m = Shape.size height
            let n = Shape.size width
            let k = Shape.size fuse
            case orderB of
               RowMajor ->
                  -- compute the column-major n x m product B^T * A^T;
                  -- flipping an order presumably makes multiplyMatrix read
                  -- that buffer as its transpose
                  Private.multiplyMatrix (flipOrder orderB) (flipOrder orderA)
                     n k m b a cPtr
               ColumnMajor ->
                  Private.multiplyMatrix orderA orderB m k n a b cPtr
multiplyColumnMajor
      (Array (MatrixShape.Full orderA extentA) a)
      (Array (MatrixShape.Full orderB extentB) b) =
   case Extent.fuse extentA extentB of
      -- report this function's own name, not 'multiply'
      Nothing -> error "multiplyColumnMajor: fuse shapes mismatch"
      Just extent ->
         Array.unsafeCreate (MatrixShape.Full ColumnMajor extent) $ \cPtr -> do
            let (height,fuse) = Extent.dimensions extentA
            let width = Extent.width extentB
            let m = Shape.size height
            let n = Shape.size width
            let k = Shape.size fuse
            Private.multiplyMatrix orderA orderB m k n a b cPtr