{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
Operations on full (dense) matrices stored in row-major or column-major order:
construction, slicing, concatenation, element-wise arithmetic, scaling,
rank-1 products, and re-exports of multiplication, indexing and solving.
-}
module Numeric.LAPACK.Matrix (
   -- matrix types and index shapes
   Full,
   General, Tall, Wide,
   ZeroInt, zeroInt,
   -- transposition and dimension queries
   transpose, adjoint,
   Matrix.height, Matrix.width,
   caseTallWide,
   -- construction and shape conversion
   fromScalar, toScalar,
   fromList,
   mapExtent, fromFull,
   generalizeTall, generalizeWide,
   mapHeight, mapWidth,
   identity,
   diagonal,
   -- conversion from and to rows and columns
   fromRowsNonEmpty, fromRowArray, fromRows,
   fromColumnsNonEmpty, fromColumnArray, fromColumns,
   Basic.singleRow, Basic.singleColumn,
   Basic.flattenRow, Basic.flattenColumn,
   Basic.liftRow, Basic.liftColumn,
   Basic.unliftRow, Basic.unliftColumn,
   toRows, toColumns,
   toRowArray, toColumnArray,
   -- slicing and permuting
   takeRow, takeColumn,
   takeRows, takeColumns, takeEqually,
   dropRows, dropColumns, dropEqually,
   takeTopRows, takeBottomRows,
   takeLeftColumns, takeRightColumns,
   takeRowArray, takeColumnArray,
   swapRows, swapColumns,
   reverseRows, reverseColumns,
   -- storage order
   fromRowMajor, toRowMajor, flatten,
   forceOrder, adaptOrder,
   -- concatenation
   (|||),
   (===),
   -- rank-1 products
   tensorProduct,
   outer,
   sumRank1,
   -- arithmetic and scaling
   RealOf,
   add, sub,
   rowSums, columnSums,
   scaleRows, scaleColumns,
   scaleRowsComplex, scaleColumnsComplex,
   scaleRowsReal, scaleColumnsReal,
   -- multiplication, indexing, solving (re-exported)
   multiply,
   multiplyVector,
   Multiply, (<#>),
   MultiplyLeft, (<#),
   MultiplyRight, (#>),
   Indexed, (#!),
   Solve, solve, solveVector,
   Inverse, inverse,
   ) where
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Square.Basic as Square
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Multiply
(Multiply((<#>)), MultiplyLeft((<#)), MultiplyRight((#>)),
multiplyVector, multiply, multiplyVectorUnchecked)
import Numeric.LAPACK.Matrix.Indexed (Indexed((#!)))
import Numeric.LAPACK.Matrix.Divide
(Solve(solve), solveVector, Inverse(inverse))
import Numeric.LAPACK.Matrix.Basic
(transpose, forceOrder, forceRowMajor, scaleRows, scaleColumns)
import Numeric.LAPACK.Matrix.Private
(Full, Tall, Wide, General, argGeneral, ZeroInt, zeroInt,
mapExtent, fromFull, generalizeTall, generalizeWide)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, one)
import Numeric.LAPACK.Private
(pointerSeq, fill, copyTransposed, copySubMatrix, copyBlock)
import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Boxed as BoxedArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))
import Foreign.Marshal.Array (copyArray, advancePtr, pokeArray)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, poke, peek)
import System.IO.Unsafe (unsafePerformIO)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when, mfilter)
import qualified Data.NonEmpty.Mixed as NonEmptyM
import qualified Data.NonEmpty as NonEmpty
import Data.Complex (Complex)
import Data.Foldable (forM_)
import Data.Maybe (listToMaybe)
import Data.Bool.HT (if')
-- | Conjugate all elements with 'Vector.conjugate' and then 'transpose'
-- the result, i.e. the conjugate transpose of the matrix.
adjoint ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a -> Full horiz vert width height a
adjoint matrix = transpose (Vector.conjugate matrix)
-- | Classify a matrix as tall or wide according to
-- 'MatrixShape.caseTallWide' applied to its shape.
-- The element buffer is reused unchanged; only the shape is replaced.
caseTallWide ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
   Full vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (Array shape buffer) =
   case MatrixShape.caseTallWide shape of
      Left tallShape -> Left (Array tallShape buffer)
      Right wideShape -> Right (Array wideShape buffer)
-- | Wrap a single value into a matrix with unit height and width,
-- going through 'Square.fromScalar'.
fromScalar :: (Storable a) => a -> General () () a
fromScalar x = Square.toGeneral (Square.fromScalar x)
-- | Read the only element of a matrix with unit height and width.
--
-- The buffer is read via 'peek' inside 'unsafePerformIO'.
toScalar :: (Storable a) => General () () a -> a
toScalar =
   argGeneral
      (\_order () () fptr -> unsafePerformIO (withForeignPtr fptr peek))
-- | Build a row-major general matrix of the given height and width
-- from a list of elements, using the checked 'CheckedArray.fromList'.
fromList ::
   (Shape.C height, Shape.C width, Storable a) =>
   height -> width -> [a] -> General height width a
fromList height width elems =
   CheckedArray.fromList (MatrixShape.general RowMajor height width) elems
-- | The result of 'Square.identity' for the given shape,
-- converted to a general matrix.
identity ::
   (Shape.C sh, Class.Floating a) =>
   sh -> General sh sh a
identity sh = Square.toGeneral (Square.identity sh)
-- | The result of 'Square.diagonal' for the given vector,
-- converted to a general matrix.
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> General sh sh a
diagonal v = Square.toGeneral (Square.diagonal v)
-- | Replace the height shape of a matrix.
--
-- The mapping is wrapped in 'withSizeCheck', so a mapping that changes
-- the number of rows raises an error labelled @mapHeight@.
mapHeight ::
   (Shape.C heightA, Shape.C heightB,
    Extent.GeneralTallWide vert horiz,
    Extent.GeneralTallWide horiz vert) =>
   (heightA -> heightB) ->
   Full vert horiz heightA width a ->
   Full vert horiz heightB width a
mapHeight f = Basic.mapHeight checked
   where
      checked = withSizeCheck "mapHeight" f
-- | Replace the width shape of a matrix.
--
-- The mapping is wrapped in 'withSizeCheck', so a mapping that changes
-- the number of columns raises an error labelled @mapWidth@.
mapWidth ::
   (Shape.C widthA, Shape.C widthB,
    Extent.GeneralTallWide vert horiz,
    Extent.GeneralTallWide horiz vert) =>
   (widthA -> widthB) ->
   Full vert horiz height widthA a ->
   Full vert horiz height widthB a
mapWidth f = Basic.mapWidth checked
   where
      checked = withSizeCheck "mapWidth" f
-- | Apply a shape mapping and verify that it preserves 'Shape.size'.
--
-- Returns the mapped shape if both sizes agree; otherwise calls 'error'
-- with the given function name as prefix.
withSizeCheck ::
   (Shape.C sha, Shape.C shb) =>
   String -> (sha -> shb) -> sha -> shb
withSizeCheck name f sizeA
   | Shape.size sizeA == Shape.size sizeB = sizeB
   | otherwise = error (name ++ ": sizes mismatch")
   where
      sizeB = f sizeA
-- | Stack a non-empty list of row vectors into a matrix.
-- The width is taken from the shape of the first row.
fromRowsNonEmpty ::
   (Shape.C width, Eq width, Storable a) =>
   NonEmpty.T [] (Vector width a) -> General ZeroInt width a
fromRowsNonEmpty (NonEmpty.Cons firstRow otherRows) =
   fromRows (Array.shape firstRow) (firstRow : otherRows)
-- | Stack the row vectors of a boxed array into a matrix.
--
-- The rows are assembled with 'fromRows' and the resulting height shape
-- is then replaced by the shape of the boxed array.
fromRowArray ::
   (Shape.C height, Shape.C width, Eq width, Storable a) =>
   width -> BoxedArray.Array height (Vector width a) -> General height width a
fromRowArray width rows =
   Basic.mapHeight
      (const (BoxedArray.shape rows))
      (fromRows width (BoxedArray.toList rows))
-- | Stack a list of row vectors into a row-major matrix
-- whose height is the length of the list.
-- The buffer is filled by 'gather', which asserts that every row
-- has the given width.
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> General ZeroInt width a
fromRows width rows =
   Array.unsafeCreate shape (gather width rows)
   where
      shape = MatrixShape.general RowMajor (zeroInt (length rows)) width
-- | Put a non-empty list of column vectors side by side.
-- The height is taken from the shape of the first column.
fromColumnsNonEmpty ::
   (Shape.C height, Eq height, Storable a) =>
   NonEmpty.T [] (Vector height a) -> General height ZeroInt a
fromColumnsNonEmpty (NonEmpty.Cons firstColumn otherColumns) =
   fromColumns (Array.shape firstColumn) (firstColumn : otherColumns)
-- | Put the column vectors of a boxed array side by side.
--
-- The columns are assembled with 'fromColumns' and the resulting width shape
-- is then replaced by the shape of the boxed array.
fromColumnArray ::
   (Shape.C height, Eq height, Shape.C width, Storable a) =>
   height -> BoxedArray.Array width (Vector height a) -> General height width a
fromColumnArray height columns =
   Basic.mapWidth
      (const (BoxedArray.shape columns))
      (fromColumns height (BoxedArray.toList columns))
-- | Put a list of column vectors side by side into a column-major matrix
-- whose width is the length of the list.
-- The buffer is filled by 'gather', which asserts that every column
-- has the given height.
fromColumns ::
   (Shape.C height, Eq height, Storable a) =>
   height -> [Vector height a] -> General height ZeroInt a
fromColumns height columns =
   Array.unsafeCreate shape (gather height columns)
   where
      shape = MatrixShape.general ColumnMajor height (zeroInt (length columns))
-- | Copy a list of equally shaped vectors one after another
-- into a destination buffer.
--
-- Each vector is checked with 'Call.assert' to have exactly the expected
-- shape before its elements are copied to the next slot of the destination.
gather ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Array width a] -> Ptr a -> IO ()
gather width rows dstPtr =
   sequence_ $ zipWith copyRow (pointerSeq rowLen dstPtr) rows
   where
      rowLen = Shape.size width
      copyRow dstRowPtr (Array.Array rowShape srcFPtr) =
         withForeignPtr srcFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows/fromColumnsNonEmpty: non-matching vector size"
               (width == rowShape)
            copyArray dstRowPtr srcPtr rowLen
-- | All rows of a matrix as a list,
-- in the order given by 'Shape.indices' of the height.
toRows ::
   (Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.C width, Class.Floating a) =>
   Full vert horiz height width a -> [Vector width a]
toRows a = [ takeRow a ix | ix <- Shape.indices (Matrix.height a) ]
-- | All columns of a matrix as a list,
-- in the order given by 'Shape.indices' of the width.
toColumns ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.Indexed width, Class.Floating a) =>
   Full vert horiz height width a -> [Vector height a]
toColumns a = [ takeColumn a ix | ix <- Shape.indices (Matrix.width a) ]
-- | All rows of a matrix as a boxed array indexed by the height shape.
--
-- The height is wrapped in 'Shape.Deferred' for enumerating the rows,
-- so the height only needs to satisfy 'Shape.C'.
toRowArray ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert horiz height width a -> BoxedArray.Array height (Vector width a)
toRowArray a =
   BoxedArray.fromList (Matrix.height a) rows
   where
      deferred = Basic.mapHeight Shape.Deferred (fromFull a)
      rows = map (takeRow deferred) (Shape.indices (Matrix.height deferred))
-- | All columns of a matrix as a boxed array indexed by the width shape.
--
-- The width is wrapped in 'Shape.Deferred' for enumerating the columns,
-- so the width only needs to satisfy 'Shape.C'.
toColumnArray ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert horiz height width a -> BoxedArray.Array width (Vector height a)
toColumnArray a =
   BoxedArray.fromList (Matrix.width a) columns
   where
      deferred = Basic.mapWidth Shape.Deferred (fromFull a)
      columns =
         map (takeColumn deferred) (Shape.indices (Matrix.width deferred))
-- | Extract the row with the given index as a vector.
--
-- In row-major order the row is a contiguous run of elements;
-- in column-major order its elements are picked with a stride.
takeRow ::
   (Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Class.Floating a) =>
   Full vert horiz height width a -> ix -> Vector width a
takeRow (Array (MatrixShape.Full order extent) x) ix =
   case order of
      RowMajor -> pickConsecutive height width x ix
      ColumnMajor -> pickScattered width height x ix
   where
      (height,width) = Extent.dimensions extent
-- | Extract the column with the given index as a vector.
--
-- In column-major order the column is a contiguous run of elements;
-- in row-major order its elements are picked with a stride.
takeColumn ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   Full vert horiz height width a -> ix -> Vector height a
takeColumn (Array (MatrixShape.Full order extent) x) ix =
   case order of
      RowMajor -> pickScattered height width x ix
      ColumnMajor -> pickConsecutive width height x ix
   where
      (height,width) = Extent.dimensions extent
-- | Copy one contiguous slice out of a buffer that consists of
-- consecutive slices of shape @width@, selected by an index into @height@.
--
-- The slice starts at @size width * offset@ elements into the buffer
-- and is copied with unit stride via BLAS @copy@.
pickConsecutive ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Class.Floating a) =>
   height -> width -> ForeignPtr a -> ix -> Vector width a
pickConsecutive height width x ix =
   Array.unsafeCreateWithSize width $ \len dstPtr ->
   let start = len * Shape.offset height ix
   in evalContT $ do
         lenPtr <- Call.cint len
         srcPtr <- ContT $ withForeignPtr x
         incSrcPtr <- Call.cint 1
         incDstPtr <- Call.cint 1
         liftIO $
            BlasGen.copy lenPtr
               (advancePtr srcPtr start) incSrcPtr dstPtr incDstPtr
-- | Copy a strided slice out of a buffer: starting at the offset of the
-- given index within @width@, every @size width@-th element is taken
-- until @size height@ elements are collected (BLAS @copy@ with a
-- non-unit source increment).
pickScattered ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   height -> width -> ForeignPtr a -> ix -> Vector height a
pickScattered height width x ix =
   Array.unsafeCreateWithSize height $ \len dstPtr ->
   let start = Shape.offset width ix
       stride = Shape.size width
   in evalContT $ do
         lenPtr <- Call.cint len
         srcPtr <- ContT $ withForeignPtr x
         stridePtr <- Call.cint stride
         incDstPtr <- Call.cint 1
         liftIO $
            BlasGen.copy lenPtr
               (advancePtr srcPtr start) stridePtr dstPtr incDstPtr
-- | Keep the upper part of a matrix whose height is a sum shape.
--
-- In row-major order the upper rows form the leading block of the buffer
-- and are copied as a whole; in column-major order a sub-matrix is copied
-- with the source leading dimension set to the full height.
takeTopRows ::
   (Extent.C vert, Shape.C height0, Shape.C height1, Shape.C width,
    Class.Floating a) =>
   Full vert Extent.Big (height0:+:height1) width a ->
   Full vert Extent.Big height0 width a
takeTopRows (Array (MatrixShape.Full order extentA) a) =
   Array.unsafeCreateWithSize (MatrixShape.Full order extentB) $
      \blockSize bPtr ->
         withForeignPtr a $ \aPtr ->
            case order of
               RowMajor -> copyBlock blockSize aPtr bPtr
               ColumnMajor ->
                  copySubMatrix dstRows cols srcRows aPtr dstRows bPtr
   where
      (fullHeight@(upperHeight:+:_), width) = Extent.dimensions extentA
      extentB = Extent.reduceWideHeight upperHeight extentA
      srcRows = Shape.size fullHeight
      dstRows = Shape.size upperHeight
      cols = Shape.size width
-- | Keep the lower part of a matrix whose height is a sum shape.
--
-- The source pointer is advanced past the upper part: by whole rows in
-- row-major order, by the number of skipped rows within the first column
-- in column-major order.
takeBottomRows ::
   (Extent.C vert, Shape.C height0, Shape.C height1, Shape.C width,
    Class.Floating a) =>
   Full vert Extent.Big (height0:+:height1) width a ->
   Full vert Extent.Big height1 width a
takeBottomRows (Array (MatrixShape.Full order extentA) a) =
   Array.unsafeCreateWithSize (MatrixShape.Full order extentB) $
      \blockSize bPtr ->
         withForeignPtr a $ \aPtr ->
            case order of
               RowMajor ->
                  copyBlock blockSize (advancePtr aPtr (skipped*cols)) bPtr
               ColumnMajor ->
                  copySubMatrix dstRows cols srcRows
                     (advancePtr aPtr skipped) dstRows bPtr
   where
      (fullHeight@(upperHeight:+:lowerHeight), width) =
         Extent.dimensions extentA
      extentB = Extent.reduceWideHeight lowerHeight extentA
      skipped = Shape.size upperHeight
      srcRows = Shape.size fullHeight
      dstRows = Shape.size lowerHeight
      cols = Shape.size width
-- | Keep the left part of a matrix whose width is a sum shape.
-- Implemented as 'takeTopRows' on the transposed matrix.
takeLeftColumns ::
   (Extent.C vert, Shape.C height, Shape.C width0, Shape.C width1,
    Class.Floating a) =>
   Full Extent.Big vert height (width0:+:width1) a ->
   Full Extent.Big vert height width0 a
takeLeftColumns matrix = transpose (takeTopRows (transpose matrix))
-- | Keep the right part of a matrix whose width is a sum shape.
-- Implemented as 'takeBottomRows' on the transposed matrix.
takeRightColumns ::
   (Extent.C vert, Shape.C height, Shape.C width0, Shape.C width1,
    Class.Floating a) =>
   Full Extent.Big vert height (width0:+:width1) a ->
   Full Extent.Big vert height width1 a
takeRightColumns matrix = transpose (takeBottomRows (transpose matrix))
-- | Reinterpret the zero-based height of a matrix as a sum of two
-- zero-based heights, the first having at most @k@ rows.
--
-- Only the shape is changed; the element buffer is left untouched.
-- A negative @k@ yields an error once the new shape is evaluated.
splitRows ::
   (Extent.C vert, Shape.C width, Class.Floating a) =>
   Int ->
   Full vert Extent.Big ZeroInt width a ->
   Full vert Extent.Big (ZeroInt:+:ZeroInt) width a
splitRows k = Array.mapShape resplit
   where
      resplit (MatrixShape.Full order extentA)
         | k < 0 = error "split: negative number"
         | otherwise =
            MatrixShape.Full order $
            Extent.reduceWideHeight
               (Shape.ZeroBased upper :+: Shape.ZeroBased (total-upper))
               extentA
         where
            Shape.ZeroBased total = Extent.height extentA
            upper = min k total
-- | 'takeRows' keeps the first @k@ rows, 'dropRows' removes them.
-- If @k@ exceeds the height, it is clipped to the height by 'splitRows'.
takeRows, dropRows ::
   (Extent.C vert, Shape.C width, Class.Floating a) =>
   Int ->
   Full vert Extent.Big ZeroInt width a ->
   Full vert Extent.Big ZeroInt width a
takeRows k matrix = takeTopRows (splitRows k matrix)
dropRows k matrix = takeBottomRows (splitRows k matrix)
-- | 'takeColumns' keeps the first @k@ columns, 'dropColumns' removes them.
-- Both delegate to 'takeRows' and 'dropRows' on the transposed matrix.
takeColumns, dropColumns ::
   (Extent.C horiz, Shape.C height, Class.Floating a) =>
   Int ->
   Full Extent.Big horiz height ZeroInt a ->
   Full Extent.Big horiz height ZeroInt a
takeColumns k matrix = transpose (takeRows k (transpose matrix))
dropColumns k matrix = transpose (dropRows k (transpose matrix))
-- | Keep the upper left corner of a matrix, limited to at most @k@ rows
-- and at most @k@ columns.
--
-- Raises an error for negative @k@.
takeEqually ::
   (Extent.C vert, Extent.C horiz, Class.Floating a) =>
   Int ->
   Full vert horiz ZeroInt ZeroInt a ->
   Full vert horiz ZeroInt ZeroInt a
takeEqually k (Array (MatrixShape.Full order extentA) a)
   | k < 0 = error "take: negative number"
   | otherwise =
      Array.unsafeCreate (MatrixShape.Full order extentB) $ \bPtr ->
         withForeignPtr a $ \aPtr ->
            case order of
               RowMajor ->
                  copySubMatrix colsB rowsB colsA aPtr colsB bPtr
               ColumnMajor ->
                  copySubMatrix rowsB colsB rowsA aPtr rowsB bPtr
   where
      (Shape.ZeroBased rowsA, Shape.ZeroBased colsA) =
         Extent.dimensions extentA
      rowsB = min k rowsA
      colsB = min k colsA
      extentB =
         Extent.reduceConsistent
            (Shape.ZeroBased rowsB) (Shape.ZeroBased colsB) extentA
-- | Remove up to @k@ leading rows and up to @k@ leading columns,
-- keeping the lower right remainder.
--
-- Raises an error for negative @k@.
dropEqually ::
   (Extent.C vert, Extent.C horiz, Class.Floating a) =>
   Int ->
   Full vert horiz ZeroInt ZeroInt a ->
   Full vert horiz ZeroInt ZeroInt a
dropEqually k (Array (MatrixShape.Full order extentA) a)
   | k < 0 = error "drop: negative number"
   | otherwise =
      Array.unsafeCreate (MatrixShape.Full order extentB) $ \bPtr ->
         withForeignPtr a $ \aPtr ->
            case order of
               RowMajor ->
                  copySubMatrix colsB rowsB
                     colsA (advancePtr aPtr (skipRows*colsA+skipCols))
                     colsB bPtr
               ColumnMajor ->
                  copySubMatrix rowsB colsB
                     rowsA (advancePtr aPtr (skipCols*rowsA+skipRows))
                     rowsB bPtr
   where
      (Shape.ZeroBased rowsA, Shape.ZeroBased colsA) =
         Extent.dimensions extentA
      skipRows = min k rowsA
      skipCols = min k colsA
      rowsB = rowsA - skipRows
      colsB = colsA - skipCols
      extentB =
         Extent.reduceConsistent
            (Shape.ZeroBased rowsB) (Shape.ZeroBased colsB) extentA
-- | Return a copy of the matrix with the two given rows exchanged.
--
-- The whole buffer is copied first; the exchange is then performed
-- in the copy with BLAS @swap@ and is skipped if both indices refer
-- to the same row.
swapRows ::
   (Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.C width, Class.Floating a) =>
   Shape.Index height -> Shape.Index height ->
   Full vert horiz height width a -> Full vert horiz height width a
swapRows i j (Array shape@(MatrixShape.Full order extent) a) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> evalContT $ do
      let (height,width) = Extent.dimensions extent
          rowCount = Shape.size height
          colCount = Shape.size width
          posI = Shape.offset height i
          posJ = Shape.offset height j
          -- distance between row starts, distance between row elements
          (rowStride,elemStride) =
             case order of
                RowMajor -> (colCount,1)
                ColumnMajor -> (1,rowCount)
      lenPtr <- Call.cint colCount
      aPtr <- ContT $ withForeignPtr a
      stridePtr <- Call.cint elemStride
      liftIO $ do
         copyBlock blockSize aPtr bPtr
         when (posI/=posJ) $
            BlasGen.swap lenPtr
               (advancePtr bPtr (rowStride*posI)) stridePtr
               (advancePtr bPtr (rowStride*posJ)) stridePtr
-- | Return a copy of the matrix with the two given columns exchanged.
-- Implemented as 'swapRows' on the transposed matrix.
swapColumns ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.Indexed width, Class.Floating a) =>
   Shape.Index width -> Shape.Index width ->
   Full vert horiz height width a -> Full vert horiz height width a
swapColumns i j matrix = transpose (swapRows i j (transpose matrix))
-- | Return a copy of the matrix with its rows permuted according to the
-- index vector @n, n-1, ..., 1@, i.e. in reversed order.
--
-- The source buffer is copied unchanged and the copy is then permuted
-- in place by a single LAPACK call.
reverseRows ::
   (Extent.C vert, Extent.C horiz, Shape.C width, Class.Floating a) =>
   Full vert horiz ZeroInt width a -> Full vert horiz ZeroInt width a
reverseRows (Array shape@(MatrixShape.Full order extent) a) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> evalContT $ do
      let (height,width) = Extent.dimensions extent
      -- n is the number of rows, m the number of columns
      let n = Shape.size height
      let m = Shape.size width
      fwdPtr <- Call.bool True
      nPtr <- Call.cint n
      mPtr <- Call.cint m
      -- permutation vector with one entry per row
      kPtr <- Call.allocaArray n
      aPtr <- ContT $ withForeignPtr a
      liftIO $ do
         copyBlock blockSize aPtr bPtr
         -- one-based indices counting down from n to 1
         pokeArray kPtr $ take n $ iterate (subtract 1) $ fromIntegral n
         -- A row-major buffer is passed to LAPACK as an m-by-n matrix,
         -- where the rows to be reversed appear as columns (lapmt);
         -- a column-major buffer is passed as n-by-m (lapmr).
         case order of
            RowMajor -> LapackGen.lapmt fwdPtr mPtr nPtr bPtr mPtr kPtr
            ColumnMajor -> LapackGen.lapmr fwdPtr nPtr mPtr bPtr nPtr kPtr
-- | Return a copy of the matrix with its columns in reversed order.
-- Implemented as 'reverseRows' on the transposed matrix.
reverseColumns ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Class.Floating a) =>
   Full vert horiz height ZeroInt a -> Full vert horiz height ZeroInt a
reverseColumns matrix = transpose (reverseRows (transpose matrix))
-- | Select rows by a boxed array of row indices.
-- The result has the shape of the index array as height
-- and the same storage order as the input.
--
-- The requested row offsets are grouped into blocks
-- ('chopRowBlocks' or 'chopColumnBlocks', depending on the order)
-- and each block is transferred by one 'copySubMatrix' call.
takeRowArray ::
   (Shape.Indexed height, Shape.C width, Shape.C sh, Class.Floating a) =>
   BoxedArray.Array sh (Shape.Index height) ->
   General height width a -> General sh width a
takeRowArray ixs (Array (MatrixShape.Full order extent) a) =
   let (heightA, width) = Extent.dimensions extent
       heightB = BoxedArray.shape ixs
       -- linear row numbers within the source matrix
       offsets = map (Shape.offset heightA) $ BoxedArray.toList ixs
       -- pair every block with the destination row where it starts,
       -- i.e. the running sum of the preceding block lengths
       startBlocks blocks = zip (scanl (+) 0 $ map fst blocks) blocks
       ma = Shape.size heightA
       mb = Shape.size heightB
       n = Shape.size width
   in Array.unsafeCreate (MatrixShape.general order heightB width) $ \bPtr ->
      withForeignPtr a $ \aPtr ->
      case order of
         RowMajor -> do
            -- a block is a run of rows with constant distance 'step';
            -- the source leading dimension step*n skips the rows in between
            forM_ (startBlocks $ chopRowBlocks offsets) $
                  \(dest, (numRows, (start,step))) ->
               copySubMatrix n numRows
                  (step*n) (advancePtr aPtr (start*n))
                  n (advancePtr bPtr (dest*n))
         ColumnMajor -> do
            -- a block is a run of adjacent rows,
            -- copied across all n columns at once
            forM_ (startBlocks $ chopColumnBlocks offsets) $
                  \(dest, (numRows, start)) ->
               copySubMatrix numRows n
                  ma (advancePtr aPtr start)
                  mb (advancePtr bPtr dest)
-- | Split a list of indices into maximal runs with constant positive step.
--
-- Every result item is @(length, (start, step))@.
-- An index that is not followed by a larger one forms a run of length one,
-- reported with step zero.
chopRowBlocks :: (Integral i) => [i] -> [(Int,(i,i))]
chopRowBlocks [] = []
chopRowBlocks indices@(start:rest) =
   case mfilter (start<) (listToMaybe rest) of
      Nothing -> (1,(start,0)) : chopRowBlocks rest
      Just next ->
         let (matching,leftover) =
                span (\(expected,actual) -> expected == actual) $
                zip [start,next..] indices
         in (length matching, (start,next-start))
               : chopRowBlocks (map snd leftover)
-- | Split a list of indices into runs of directly successive numbers.
--
-- Every result item is @(length, start)@ of one such run.
chopColumnBlocks :: (Integral i) => [i] -> [(Int,i)]
chopColumnBlocks offsets =
   [ (length (NonEmpty.flatten run), NonEmpty.head run)
   | run <- NonEmptyM.groupBy (\prev next -> prev+1 == next) offsets ]
-- | Select columns by a boxed array of column indices.
-- Implemented as 'takeRowArray' on the transposed matrix.
takeColumnArray ::
   (Shape.C height, Shape.Indexed width, Shape.C sh, Class.Floating a) =>
   BoxedArray.Array sh (Shape.Index width) ->
   General height width a -> General height sh a
takeColumnArray ixs matrix = transpose (takeRowArray ixs (transpose matrix))
-- | Reinterpret an array with a pair shape as a row-major general matrix.
-- Only the shape is converted; the elements are not touched.
fromRowMajor ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Array (height,width) a -> General height width a
fromRowMajor =
   Array.mapShape
      (\ ~(height,width) -> MatrixShape.general RowMajor height width)
-- | Convert a matrix to an array with a pair shape.
-- The matrix is passed through 'forceRowMajor' before its shape
-- is replaced by the pair of height and width.
toRowMajor ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert horiz height width a -> Array (height,width) a
toRowMajor matrix = Array.mapShape dims (forceRowMajor matrix)
   where
      dims shape = (MatrixShape.fullHeight shape, MatrixShape.fullWidth shape)
-- | @adaptOrder x y@ applies 'forceOrder' to @y@
-- with the storage order taken from the shape of @x@.
adaptOrder ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
adaptOrder x = forceOrder targetOrder
   where
      targetOrder = MatrixShape.fullOrder (Array.shape x)
-- | All elements of a matrix as one zero-based vector,
-- in the element order produced by 'toRowMajor'.
flatten ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert horiz height width a -> Vector ZeroInt a
flatten matrix = Array.mapShape (zeroInt . Shape.size) (toRowMajor matrix)
-- Horizontal concatenation binds tighter than vertical concatenation,
-- so @a ||| b === c ||| d@ groups as @(a ||| b) === (c ||| d)@.
infixl 3 |||
infixl 2 ===
-- | Horizontal concatenation: place the second matrix to the right
-- of the first one. The width of the result is the sum shape of both widths.
--
-- Raises an error if the heights differ.
-- The result is row-major only if both operands are row-major;
-- in all other cases it is column-major.
(|||) ::
   (Extent.C vert, Shape.C height, Eq height, Shape.C widtha, Shape.C widthb,
    Class.Floating a) =>
   Full vert Extent.Big height widtha a ->
   Full vert Extent.Big height widthb a ->
   Full vert Extent.Big height (widtha:+:widthb) a
(|||)
      (Array (MatrixShape.Full orderA extentA) a)
      (Array (MatrixShape.Full orderB extentB) b) =
   let (heightA,widthA) = Extent.dimensions extentA
       (heightB,widthB) = Extent.dimensions extentB
       extent = Extent.widen (widthA:+:widthB) extentA
       shape order = MatrixShape.Full order extent
   in
    if heightA /= heightB
      then error "(|||): mismatching heights"
      else
         case (orderA,orderB) of
            -- both row-major: interleave the rows of a and b
            (RowMajor,RowMajor) ->
               Array.unsafeCreate (shape RowMajor) $
                     \cPtr -> evalContT $ do
                  let n = Shape.size heightA
                  let ma = Shape.size widthA
                  let mb = Shape.size widthB
                  let m = ma+mb
                  maPtr <- Call.cint ma
                  mbPtr <- Call.cint mb
                  aPtr <- ContT $ withForeignPtr a
                  bPtr <- ContT $ withForeignPtr b
                  incxPtr <- Call.cint 1
                  incyPtr <- Call.cint 1
                  liftIO $
                     sequence_ $ take n $
                     zipWith3
                        (\akPtr bkPtr ckPtr -> do
                           -- row k of a, directly followed by row k of b
                           BlasGen.copy maPtr akPtr incxPtr ckPtr incyPtr
                           BlasGen.copy mbPtr bkPtr incxPtr
                              (ckPtr `advancePtr` ma) incyPtr)
                        (pointerSeq ma aPtr)
                        (pointerSeq mb bPtr)
                        (pointerSeq m cPtr)
            -- mixed orders: the row-major operand is copied transposed
            -- into the column-major result, the other one as a plain block
            (RowMajor,ColumnMajor) ->
               Array.unsafeCreate (shape ColumnMajor) $
                     \cPtr -> evalContT $ do
                  let n = Shape.size heightA
                  let ma = Shape.size widthA
                  let mb = Shape.size widthB
                  aPtr <- ContT $ withForeignPtr a
                  bPtr <- ContT $ withForeignPtr b
                  liftIO $ do
                     copyTransposed n ma aPtr n cPtr
                     copyBlock (n*mb) bPtr (advancePtr cPtr (n*ma))
            (ColumnMajor,RowMajor) ->
               Array.unsafeCreate (shape ColumnMajor) $
                     \cPtr -> evalContT $ do
                  let n = Shape.size heightA
                  let ma = Shape.size widthA
                  let mb = Shape.size widthB
                  let volA = n*ma
                  aPtr <- ContT $ withForeignPtr a
                  bPtr <- ContT $ withForeignPtr b
                  liftIO $ do
                     copyBlock volA aPtr cPtr
                     copyTransposed n mb bPtr n (advancePtr cPtr volA)
            -- both column-major: the buffer of b is appended to that of a
            (ColumnMajor,ColumnMajor) ->
               Array.unsafeCreate (shape ColumnMajor) $
                     \cPtr -> evalContT $ do
                  let n = Shape.size heightA
                  let na = n * Shape.size widthA
                  let nb = n * Shape.size widthB
                  naPtr <- Call.cint na
                  nbPtr <- Call.cint nb
                  aPtr <- ContT $ withForeignPtr a
                  bPtr <- ContT $ withForeignPtr b
                  incxPtr <- Call.cint 1
                  incyPtr <- Call.cint 1
                  liftIO $ do
                     BlasGen.copy naPtr aPtr incxPtr cPtr incyPtr
                     BlasGen.copy nbPtr bPtr incxPtr
                        (cPtr `advancePtr` na) incyPtr
-- | Vertical concatenation: place the second matrix below the first one.
-- Implemented as '|||' on the transposed operands.
(===) ::
   (Extent.C horiz, Shape.C width, Eq width, Shape.C heighta, Shape.C heightb,
    Class.Floating a) =>
   Full Extent.Big horiz heighta width a ->
   Full Extent.Big horiz heightb width a ->
   Full Extent.Big horiz (heighta:+:heightb) width a
upper === lower = transpose (transpose upper ||| transpose lower)
-- | Element-wise sum and difference of two matrices.
--
-- The first operand is brought to the storage order of the second one
-- with 'adaptOrder' before the flat vector operation is applied.
add, sub ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq height, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
add x y = Vector.add xAdapted y
   where
      xAdapted = adaptOrder y x
sub x y = Vector.sub xAdapted y
   where
      xAdapted = adaptOrder y x
-- | Sum of the elements of every row,
-- computed as the product of the matrix with a vector of ones.
rowSums ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert horiz height width a -> Vector height a
rowSums m = multiplyVectorUnchecked m ones
   where
      ones = Vector.constant (MatrixShape.fullWidth (Array.shape m)) one
-- | Sum of the elements of every column,
-- computed as the product of the transposed matrix with a vector of ones.
columnSums ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert horiz height width a -> Vector width a
columnSums m = multiplyVectorUnchecked (transpose m) ones
   where
      ones = Vector.constant (MatrixShape.fullHeight (Array.shape m)) one
-- | Scale every row of a complex matrix by the corresponding element
-- of a real vector.
--
-- The complex buffers are accessed as buffers of real numbers via 'castPtr',
-- which is why element counts are doubled below.
-- Asserts that the vector shape equals the matrix height.
scaleRowsComplex ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Real a) =>
   Vector height a ->
   Full vert horiz height width (Complex a) ->
   Full vert horiz height width (Complex a)
scaleRowsComplex
   (Array heightX x) (Array shape@(MatrixShape.Full order extent) a) =
      Array.unsafeCreate shape $ \bComplexPtr -> do
         let (height,width) = Extent.dimensions extent
         Call.assert "scaleRowsComplex: sizes mismatch" (heightX == height)
         let bPtr = castPtr bComplexPtr
         case order of
            -- row-major: copy each row and scale it by its factor
            RowMajor -> evalContT $ do
               let m = Shape.size height
               -- number of real values per row
               let n = Shape.size width * 2
               alphaPtr <- Call.alloca
               nPtr <- Call.cint n
               xPtr <- ContT $ withForeignPtr x
               aPtr <- fmap castPtr $ ContT $ withForeignPtr a
               incaPtr <- Call.cint 1
               incbPtr <- Call.cint 1
               liftIO $ sequence_ $ take m $
                  zipWith3
                     (\xkPtr akPtr bkPtr -> do
                        poke alphaPtr =<< peek xkPtr
                        BlasGen.copy nPtr akPtr incaPtr bkPtr incbPtr
                        BlasGen.scal nPtr alphaPtr bkPtr incbPtr)
                     (pointerSeq 1 xPtr)
                     (pointerSeq n aPtr)
                     (pointerSeq n bPtr)
            -- column-major: multiply each column element-wise
            -- by a vector of duplicated factors
            ColumnMajor -> evalContT $ do
               let m = Shape.size width
               let nr = Shape.size height
               -- number of real values per column
               let n = 2*nr
               transPtr <- Call.char 'N'
               nrPtr <- Call.cint nr
               nPtr <- Call.cint n
               klPtr <- Call.cint 0
               kuPtr <- Call.cint 0
               alphaPtr <- Call.number one
               xrPtr <- ContT $ withForeignPtr x
               xPtr <- Call.allocaArray n
               incxrPtr <- Call.cint 1
               incxPtr <- Call.cint 2
               ldxPtr <- Call.leadingDim 1
               aPtr <- fmap castPtr $ ContT $ withForeignPtr a
               incaPtr <- Call.cint 1
               betaPtr <- Call.number zero
               incbPtr <- Call.cint 1
               liftIO $ do
                  -- write every factor to an even and to the following odd
                  -- position, matching real and imaginary part
                  BlasGen.copy nrPtr xrPtr incxrPtr xPtr incxPtr
                  BlasGen.copy nrPtr xrPtr incxrPtr (advancePtr xPtr 1) incxPtr
                  -- gbmv with kl=ku=0 and leading dimension 1;
                  -- presumably this treats xPtr as a diagonal matrix
                  -- (TODO confirm against Private.gbmv)
                  sequence_ $ take m $
                     zipWith
                        (\akPtr bkPtr ->
                           Private.gbmv transPtr
                              nPtr nPtr klPtr kuPtr alphaPtr xPtr ldxPtr
                              akPtr incaPtr betaPtr bkPtr incbPtr)
                        (pointerSeq n aPtr)
                        (pointerSeq n bPtr)
-- | Scale every column of a complex matrix by the corresponding element
-- of a real vector. Implemented as 'scaleRowsComplex' on the transposed
-- matrix.
scaleColumnsComplex ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Real a) =>
   Vector width a ->
   Full vert horiz height width (Complex a) ->
   Full vert horiz height width (Complex a)
scaleColumnsComplex x matrix =
   transpose (scaleRowsComplex x (transpose matrix))
-- | Scale every row by the corresponding element of a vector
-- of the real type that belongs to the element type.
--
-- The implementation is selected by 'Class.switchFloating':
-- the first two alternatives use 'scaleRows',
-- the last two use 'scaleRowsComplex'
-- (presumably the real and the complex element types, respectively).
scaleRowsReal ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Vector height (RealOf a) ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleRowsReal =
   getScaleRowsReal $
   Class.switchFloating
      (ScaleRowsReal scaleRows)
      (ScaleRowsReal scaleRows)
      (ScaleRowsReal scaleRowsComplex)
      (ScaleRowsReal scaleRowsComplex)
-- | Wrapper that makes the element type @a@ the last type parameter
-- of a scaling function, as needed for passing it to 'Class.switchFloating'
-- in 'scaleRowsReal'.
newtype ScaleRowsReal f g a =
   ScaleRowsReal {getScaleRowsReal :: f (RealOf a) -> g a -> g a}
-- | Scale every column by the corresponding element of a vector
-- of the real type that belongs to the element type.
-- Implemented as 'scaleRowsReal' on the transposed matrix.
scaleColumnsReal ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width (RealOf a) ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleColumnsReal x matrix = transpose (scaleRowsReal x (transpose matrix))
-- | Rank-1 matrix built from two vectors via 'tensorProd' with the
-- transposition flag @\'T\'@, delivered in the requested storage order.
--
-- For row-major order the product is computed with swapped operands
-- and then transposed.
tensorProduct ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Order -> Vector height a -> Vector width a -> General height width a
tensorProduct ColumnMajor x y = tensorProd 'T' ColumnMajor x y
tensorProduct RowMajor x y = transpose (tensorProd 'T' RowMajor y x)
-- | Rank-1 matrix built from two vectors via 'tensorProd' with the
-- conjugate transposition flag @\'C\'@, delivered in the requested
-- storage order.
--
-- For row-major order the product is computed with swapped operands
-- and then transposed.
outer ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Order -> Vector height a -> Vector width a -> General height width a
outer ColumnMajor x y = tensorProd 'C' ColumnMajor x y
outer RowMajor x y = transpose (tensorProd 'C' RowMajor y x)
{-# INLINE tensorProd #-}
-- | Common worker of 'tensorProduct' and 'outer'.
--
-- Computes a matrix product with inner dimension 1 using BLAS @gemm@
-- and always returns a column-major matrix of shape @shX@ by @shY@.
-- The @trans@ flag is applied to the second operand for 'ColumnMajor'
-- and to the first operand for 'RowMajor'.
tensorProd ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Char -> Order ->
   Vector height a -> Vector width a -> General height width a
tensorProd trans order (Array shX x) (Array shY y) =
   Array.unsafeCreate (MatrixShape.general MatrixShape.ColumnMajor shX shY) $
      \cPtr -> do
         let m = Shape.size shX
         let n = Shape.size shY
         -- transposition flags and leading dimensions of both operands
         let ((transa,transb),(lda,ldb)) =
               case order of
                  ColumnMajor -> (('N',trans),(m,n))
                  RowMajor -> ((trans,'N'),(1,1))
         evalContT $ do
            transaPtr <- Call.char transa
            transbPtr <- Call.char transb
            mPtr <- Call.cint m
            nPtr <- Call.cint n
            -- inner dimension of the product
            kPtr <- Call.cint 1
            alphaPtr <- Call.number one
            aPtr <- ContT $ withForeignPtr x
            ldaPtr <- Call.leadingDim lda
            bPtr <- ContT $ withForeignPtr y
            ldbPtr <- Call.leadingDim ldb
            -- beta is zero, so the destination buffer is only written
            betaPtr <- Call.number zero
            ldcPtr <- Call.leadingDim m
            liftIO $
               BlasGen.gemm
                  transaPtr transbPtr mPtr nPtr kPtr alphaPtr
                  aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
-- | Weighted sum of rank-1 matrices.
--
-- The column-major result of the given height and width is initialized
-- with zeros; then every list item @(alpha, (x, y))@ is accumulated
-- with one BLAS @gerc@ call.
-- Asserts for every item that the vector shapes match height and width.
sumRank1 ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   (height,width) ->
   [(a, (Vector height a, Vector width a))] -> General height width a
sumRank1 (height,width) xys =
   Array.unsafeCreateWithSize (MatrixShape.general ColumnMajor height width) $
      \size aPtr ->
         evalContT $ do
            let m = Shape.size height
            let n = Shape.size width
            mPtr <- Call.cint m
            nPtr <- Call.cint n
            -- scratch cell for the weight, overwritten for every item
            alphaPtr <- Call.alloca
            incxPtr <- Call.cint 1
            incyPtr <- Call.cint 1
            ldaPtr <- Call.leadingDim m
            liftIO $ do
               -- start from the zero matrix, since gerc only accumulates
               fill zero size aPtr
               forM_ xys $ \(alpha, (Array shX x, Array shY y)) ->
                  withForeignPtr x $ \xPtr ->
                  withForeignPtr y $ \yPtr -> do
                     Call.assert "Matrix.sumRank1: non-matching height" (height==shX)
                     Call.assert "Matrix.sumRank1: non-matching width" (width==shY)
                     poke alphaPtr alpha
                     BlasGen.gerc mPtr nPtr
                        alphaPtr xPtr incxPtr yPtr incyPtr aPtr ldaPtr