{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ConstraintKinds #-}
module Numeric.LAPACK.Matrix.Triangular.Basic (
Triangular, MatrixShape.UpLo,
Diagonal,
Upper, FlexUpper, UnitUpper,
Lower, FlexLower, UnitLower,
Symmetric, FlexSymmetric,
fromList, autoFromList,
relaxUnitDiagonal, strictNonUnitDiagonal,
identity,
diagonal,
takeDiagonal,
transpose,
adjoint,
stackDiagonal,
stackLower,
stackUpper,
stackSymmetric,
takeTopLeft,
takeTopRight,
takeBottomLeft,
takeBottomRight,
toSquare,
takeLower,
takeUpper,
fromLowerRowMajor, toLowerRowMajor,
fromUpperRowMajor, toUpperRowMajor,
forceOrder, adaptOrder,
add, sub,
Tri.PowerDiag,
Tri.PowerContentDiag,
multiplyVector,
square, power,
multiply,
multiplyFull,
) where
import qualified Numeric.LAPACK.Matrix.Symmetric.Private as Symmetric
import qualified Numeric.LAPACK.Matrix.Triangular.Private as Tri
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector.Private as VectorPriv
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Triangular.Private
(Triangular, FlexDiagonal, diagonalPointers, diagonalPointerPairs,
pack, packRect, unpack, unpackZero, unpackToTemp, uncheck, recheck)
import Numeric.LAPACK.Matrix.Shape.Private
(Order(RowMajor,ColumnMajor),
flipOrder, transposeFromOrder, uploFromOrder, uploOrder,
Unit(Unit), NonUnit(NonUnit), charFromTriDiag)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private (Full, Square, ShapeInt, shapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private (fill, copyBlock)
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))
import Foreign.C.Types (CChar, CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Data.Function.HT (powerAssociative)
import Data.Foldable (forM_)
import Data.Tuple.HT (double)
-- | Lower triangular matrix whose diagonal kind ('Unit' or 'NonUnit') is a
--   type parameter.
type FlexLower diag sh = Array (MatrixShape.LowerTriangular diag sh)
-- | Upper triangular matrix whose diagonal kind is a type parameter.
type FlexUpper diag sh = Array (MatrixShape.UpperTriangular diag sh)
-- | Symmetric matrix whose diagonal kind is a type parameter.
type FlexSymmetric diag sh = Array (MatrixShape.FlexSymmetric diag sh)

-- Specializations with a 'NonUnit' diagonal.
type Diagonal sh = FlexDiagonal NonUnit sh
type Lower sh = FlexLower NonUnit sh
type Upper sh = FlexUpper NonUnit sh
type Symmetric sh = FlexSymmetric NonUnit sh

-- Specializations with a 'Unit' diagonal.
type UnitLower sh = FlexLower Unit sh
type UnitUpper sh = FlexUpper Unit sh
-- | Transpose a triangular matrix.
--
-- Only the shape descriptor is rewritten (by
-- 'MatrixShape.triangularTranspose'). The element buffer is reused
-- unchanged.
transpose ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    MatrixShape.TriDiag diag) =>
   Triangular lo diag up sh a -> Triangular up diag lo sh a
transpose = Array.mapShape MatrixShape.triangularTranspose
-- | Conjugate transpose: 'transpose', followed by element-wise
--   'Vector.conjugate' of the stored elements.
adjoint ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    MatrixShape.TriDiag diag, Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Triangular up diag lo sh a
adjoint m = Vector.conjugate (transpose m)
-- | Build a matrix with a 'NonUnit' diagonal from a list of packed
--   elements, stored in the given 'Order' over the size shape @sh@.
--
-- The list is consumed by the checked 'CheckedArray.fromList'.
fromList ::
   (MatrixShape.Content lo, MatrixShape.Content up, Shape.C sh, Storable a) =>
   Order -> sh -> [a] -> Triangular lo NonUnit up sh a
fromList order sh xs = CheckedArray.fromList shape xs
   where
      shape = MatrixShape.Triangular NonUnit MatrixShape.autoUplo order sh
-- | Like 'fromList', but the matrix dimension is derived from the list.
--
-- For the diagonal shape the dimension is the list length itself. For all
-- other shapes it is computed from the packed element count by
-- 'MatrixShape.triangleExtent'.
autoFromList ::
   (MatrixShape.Content lo, MatrixShape.Content up, Storable a) =>
   Order -> [a] -> Triangular lo NonUnit up ShapeInt a
autoFromList order xs = Array.fromList shape xs
   where
      uplo = MatrixShape.autoUplo
      count = length xs
      packedDim = MatrixShape.triangleExtent "Triangular.autoFromList" count
      dim =
         MatrixShape.caseDiagUpLoSym uplo
            count packedDim packedDim packedDim
      shape =
         MatrixShape.Triangular
            MatrixShape.autoDiag uplo order (shapeInt dim)
-- | Expand a diagonal, triangular or symmetric matrix into a full square
--   matrix with the same storage 'Order'.
--
-- * Diagonal: the square is zero-filled, then the diagonal is written by a
--   BLAS @copy@ with destination stride @n+1@.
-- * Triangular: unpacked with 'unpackZero'. The lower-triangular case uses
--   the flipped storage order.
-- * Symmetric: unpacked by 'Symmetric.unpack'.
toSquare ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Square sh a
toSquare (Array (MatrixShape.Triangular _diag uplo order sh) a) =
   Array.unsafeCreateWithSize (MatrixShape.square order sh) $ \size bPtr ->
      withForeignPtr a $ \aPtr ->
         let n = Shape.size sh
             scatterDiagonal = do
                fill zero size bPtr
                evalContT $ do
                   nPtr <- Call.cint n
                   incxPtr <- Call.cint 1
                   incyPtr <- Call.cint (n+1)
                   liftIO $ BlasGen.copy nPtr aPtr incxPtr bPtr incyPtr
         in MatrixShape.caseDiagUpLoSym uplo
               scatterDiagonal
               (unpackZero order n aPtr bPtr)
               (unpackZero (flipOrder order) n aPtr bPtr)
               (Symmetric.unpack NonConjugated order n aPtr bPtr)
-- | Extract the lower triangle of a full matrix as a 'Lower' matrix over
--   the matrix height.
--
-- Delegates to 'Tri.takeLower' with the 'NonUnit' diagonal tag and a
-- three-argument action that does nothing.
takeLower ::
   (Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   Full Extent.Small horiz height width a -> Lower height a
takeLower =
   Tri.takeLower (MatrixShape.NonUnit, \_ _ _ -> return ())
-- | Extract the upper triangle of a full matrix as a packed 'Upper' matrix
--   over the matrix width, keeping the storage order.
takeUpper ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert Extent.Small height width a -> Upper width a
takeUpper (Array (MatrixShape.Full order extent) a) =
   Array.unsafeCreate shape $ \bPtr ->
      withForeignPtr a $ \aPtr -> packRect order n ld aPtr bPtr
   where
      (height, width) = Extent.dimensions extent
      n = Shape.size width
      -- Leading dimension of the source buffer: the row length for
      -- row-major storage, the column length for column-major storage.
      ld =
         case order of
            ColumnMajor -> Shape.size height
            RowMajor -> n
      shape =
         MatrixShape.Triangular MatrixShape.NonUnit
            MatrixShape.upper order width
-- | Reinterpret a comfort-array row-major lower triangle as a 'Lower'
--   matrix.
--
-- Only the shape is replaced. The element buffer is reused.
fromLowerRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Array (Shape.Triangular Shape.Lower sh) a -> Lower sh a
fromLowerRowMajor =
   Array.mapShape
      (\shape ->
         MatrixShape.Triangular MatrixShape.NonUnit MatrixShape.lower RowMajor
            (Shape.triangularSize shape))
-- | Reinterpret a comfort-array row-major upper triangle as an 'Upper'
--   matrix.
--
-- Only the shape is replaced. The element buffer is reused.
fromUpperRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Array (Shape.Triangular Shape.Upper sh) a -> Upper sh a
fromUpperRowMajor =
   Array.mapShape
      (\shape ->
         MatrixShape.Triangular MatrixShape.NonUnit MatrixShape.upper RowMajor
            (Shape.triangularSize shape))
-- | Convert a 'Lower' matrix to a comfort-array row-major lower triangle.
--
-- The storage is first brought into 'RowMajor' order with 'forceOrder'.
toLowerRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Lower sh a -> Array (Shape.Triangular Shape.Lower sh) a
toLowerRowMajor m =
   Array.mapShape
      (\shape -> Shape.Triangular Shape.Lower (MatrixShape.triangularSize shape))
      (forceOrder RowMajor m)
-- | Convert an 'Upper' matrix to a comfort-array row-major upper triangle.
--
-- The storage is first brought into 'RowMajor' order with 'forceOrder'.
toUpperRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Upper sh a -> Array (Shape.Triangular Shape.Upper sh) a
toUpperRowMajor m =
   Array.mapShape
      (\shape -> Shape.Triangular Shape.Upper (MatrixShape.triangularSize shape))
      (forceOrder RowMajor m)
-- | Convert a matrix to the requested storage 'Order'.
--
-- Diagonal matrices only get their shape's order field replaced. For
-- triangular and symmetric matrices, see 'forceOrderMap':
--
-- * If the order already matches, the matrix is returned unchanged.
-- * Otherwise the matrix is expanded to a square, reordered, and the
--   relevant triangle is extracted again. For symmetric matrices this is
--   the upper triangle, relabelled as symmetric.
forceOrder ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Order -> Triangular lo diag up sh a -> Triangular lo diag up sh a
forceOrder newOrder =
   Tri.getMap $
   MatrixShape.switchDiagUpLoSym
      (Tri.Map relabelDiagonal)
      (forceOrderMap newOrder takeUpper)
      (forceOrderMap newOrder takeLower)
      (forceOrderMap newOrder takeSymmetricUpper)
   where
      relabelDiagonal =
         Array.mapShape (\sh -> sh{MatrixShape.triangularOrder = newOrder})
      takeSymmetricUpper full =
         Array.mapShape
            (\sh -> sh{MatrixShape.triangularUplo = MatrixShape.autoUplo})
            (takeUpper full)
-- | Build the 'Tri.Map' that 'forceOrder' uses for non-diagonal shapes.
--
-- * A matrix already stored in @newOrder@ is returned unchanged.
-- * Otherwise it goes through 'toSquare', 'Basic.forceOrder' and the
--   triangle extractor @f@.
--
-- @f@ yields a 'NonUnit' diagonal tag, which is then retagged to the
-- caller's diagonal type without any check.
forceOrderMap ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Order ->
   (Square sh a -> Triangular lo NonUnit up sh a) ->
   Tri.Map diag sh sh a lo up
forceOrderMap newOrder f = Tri.Map reorder
   where
      reorder a
         | MatrixShape.triangularOrder (Array.shape a) == newOrder = a
         | otherwise =
            uncheckedRelaxNonUnitDiagonal $
            f $ Basic.forceOrder newOrder $ toSquare a
-- | Replace the 'NonUnit' diagonal tag with any diagonal type (via
--   'MatrixShape.autoDiag'), leaving the stored elements untouched.
--
-- No check is made that the stored diagonal fits the new tag.
uncheckedRelaxNonUnitDiagonal ::
   (MatrixShape.TriDiag diag) =>
   Triangular lo NonUnit up sh a -> Triangular lo diag up sh a
uncheckedRelaxNonUnitDiagonal (Array shape buffer) =
   Array (shape{MatrixShape.triangularDiag = MatrixShape.autoDiag}) buffer
-- | @adaptOrder x y@ returns @y@ converted to the storage order of @x@.
--   Only the order of @x@ is inspected.
adaptOrder ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Triangular lo diag up sh a ->
   Triangular lo diag up sh a
adaptOrder x y = forceOrder (MatrixShape.triangularOrder (Array.shape x)) y
-- | Element-wise sum and difference of two matrices of the same kind.
--   The first operand is converted to the storage order of the second
--   before combining.
add, sub ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Eq lo, Eq up, Eq sh, Shape.C sh, Class.Floating a) =>
   Triangular lo NonUnit up sh a ->
   Triangular lo NonUnit up sh a ->
   Triangular lo NonUnit up sh a
add = withCommonOrder Vector.add
sub = withCommonOrder Vector.sub

-- | Apply a binary operation after converting the first argument to the
--   storage order of the second.
withCommonOrder ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   (Triangular lo NonUnit up sh a -> Triangular lo NonUnit up sh a -> b) ->
   Triangular lo NonUnit up sh a -> Triangular lo NonUnit up sh a -> b
withCommonOrder op x y = op (adaptOrder y x) y
-- | Identity matrix with a 'Unit' diagonal tag.
--
-- * Diagonal shape: all @n@ elements are set to one.
-- * Other shapes: the packed buffer is zeroed, then ones are written at
--   the positions returned by 'diagonalPointers'.
identity ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   Order -> sh -> Triangular lo Unit up sh a
identity order sh =
   Array.unsafeCreateWithSize (MatrixShape.Triangular Unit uplo order sh) $
      \size aPtr ->
         let n = Shape.size sh
             fillTriangle = do
                fill zero size aPtr
                forM_ (diagonalPointers realOrder n aPtr) (\ptr -> poke ptr one)
         in MatrixShape.caseDiagUpLoSym uplo
               (fill one n aPtr)
               fillTriangle
               fillTriangle
               fillTriangle
   where
      (realOrder, uplo) = autoUploOrder order
-- | Matrix with the given vector on its diagonal and zeros elsewhere.
--
-- * Diagonal shape: the vector's buffer is reused as is.
-- * Triangular and symmetric shapes ('diagonalAux'): a packed buffer is
--   zero-filled, then each vector element is copied to its diagonal
--   position.
diagonal, diagonalAux ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   Order -> Vector sh a -> Triangular lo NonUnit up sh a
diagonal order x@(Array sh xBuf) =
   MatrixShape.caseDiagUpLoSym uplo
      (Array (MatrixShape.Triangular NonUnit uplo order sh) xBuf)
      packed packed packed
   where
      uplo = MatrixShape.autoUplo
      packed = diagonalAux order x

diagonalAux order (Array sh x) =
   Array.unsafeCreateWithSize
      (MatrixShape.Triangular NonUnit uplo order sh) $ \size aPtr -> do
         fill zero size aPtr
         withForeignPtr x $ \xPtr ->
            mapM_ (\(src, dst) -> peek src >>= poke dst)
               (diagonalPointerPairs realOrder (Shape.size sh) xPtr aPtr)
   where
      (realOrder, uplo) = autoUploOrder order
-- | Extract the diagonal as a vector.
--
-- * Diagonal shape: the element buffer is shared.
-- * Other shapes ('takeDiagonalAux'): the diagonal entries are copied out
--   of the packed storage.
takeDiagonal, takeDiagonalAux ::
   (MatrixShape.Content lo, MatrixShape.Content up,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Vector sh a
takeDiagonal a@(Array (MatrixShape.Triangular _diag uplo _order sh) buf) =
   MatrixShape.caseDiagUpLoSym uplo (Array sh buf) copied copied copied
   where
      copied = takeDiagonalAux a

takeDiagonalAux (Array (MatrixShape.Triangular _diag uplo order sh) a) =
   Array.unsafeCreate sh $ \xPtr ->
      withForeignPtr a $ \aPtr ->
         forM_
            (diagonalPointerPairs (uploOrder uplo order) (Shape.size sh) xPtr aPtr)
            (\(dst, src) -> peek src >>= poke dst)
-- | Retag a 'Unit'-diagonal matrix with any diagonal type, using
--   'MatrixShape.relaxUnitDiagonal' on the shape. The buffer is reused.
relaxUnitDiagonal ::
   (MatrixShape.TriDiag diag) =>
   Triangular lo Unit up sh a -> Triangular lo diag up sh a
relaxUnitDiagonal (Array shape buffer) =
   Array (MatrixShape.relaxUnitDiagonal shape) buffer
-- | Retag a matrix of any diagonal type as 'NonUnit', using
--   'MatrixShape.strictNonUnitDiagonal' on the shape. The buffer is reused.
strictNonUnitDiagonal ::
   (MatrixShape.TriDiag diag) =>
   Triangular lo diag up sh a -> Triangular lo NonUnit up sh a
strictNonUnitDiagonal (Array shape buffer) =
   Array (MatrixShape.strictNonUnitDiagonal shape) buffer
-- | Apply a vector transformation to the diagonal of a diagonal matrix.
--
-- The diagonal tag, uplo and order fields are kept. The size is taken
-- from the transformed vector.
liftDiagonal ::
   (Vector sh0 a -> Vector sh1 a) ->
   FlexDiagonal diag sh0 a -> FlexDiagonal diag sh1 a
liftDiagonal f (Array (MatrixShape.Triangular diag uplo order sh0) a) =
   case f (Array sh0 a) of
      Array sh1 b -> Array (MatrixShape.Triangular diag uplo order sh1) b
-- | Block-diagonal concatenation of two diagonal matrices.
--
-- The diagonal of @a@ is followed by the diagonal of @b@. Tag, uplo and
-- order are taken from @b@.
stackDiagonal ::
   (MatrixShape.TriDiag diag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexDiagonal diag sh0 a ->
   FlexDiagonal diag sh1 a ->
   FlexDiagonal diag (sh0:+:sh1) a
stackDiagonal a b = liftDiagonal (Vector.append (takeDiagonal a)) b
-- | Assemble a lower triangular block matrix from the top-left block @a@,
--   the bottom-left block @b@ and the bottom-right block @c@.
--
-- Works by transposing everything, stacking as upper triangular, and
-- transposing back.
stackLower ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexLower diag sh0 a ->
   Matrix.General sh1 sh0 a ->
   FlexLower diag sh1 a ->
   FlexLower diag (sh0:+:sh1) a
stackLower a b c = transpose stackedUpper
   where
      stackedUpper =
         stackAux "LowerTriangular"
            (transpose a) (Basic.transpose b) (transpose c)
-- | Assemble an upper triangular block matrix from the top-left block @a@,
--   the top-right block @b@ and the bottom-right block @c@.
stackUpper ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexUpper diag sh0 a ->
   Matrix.General sh0 sh1 a ->
   FlexUpper diag sh1 a ->
   FlexUpper diag (sh0:+:sh1) a
stackUpper a b c = stackAux "UpperTriangular" a b c
-- | Assemble a symmetric block matrix from the top-left block @a@,
--   the top-right block @b@ and the bottom-right block @c@.
stackSymmetric ::
   (MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexSymmetric diag sh0 a ->
   Matrix.General sh0 sh1 a ->
   FlexSymmetric diag sh1 a ->
   FlexSymmetric diag (sh0:+:sh1) a
stackSymmetric a b c = stackAux "Symmetric" a b c
-- | Common implementation of the stacking functions for shapes with a
--   filled upper part.
--
-- Both diagonal blocks are first converted to the storage order of the
-- off-diagonal block @b@. The result shape is @a@'s shape with that order
-- and the combined size. @name@ is passed on to 'Tri.stack'.
stackAux ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   String ->
   Triangular lo diag MatrixShape.Filled sh0 a ->
   Matrix.General sh0 sh1 a ->
   Triangular lo diag MatrixShape.Filled sh1 a ->
   Triangular lo diag MatrixShape.Filled (sh0:+:sh1) a
stackAux name a b c =
   Tri.stack name reshape (forceOrder order a) b (forceOrder order c)
   where
      order = MatrixShape.fullOrder (Array.shape b)
      reshape sh =
         (Array.shape a) {
            MatrixShape.triangularOrder = order,
            MatrixShape.triangularSize = sh}
-- | Top-left diagonal block of a matrix whose size shape is @sh0:+:sh1@.
--
-- * Diagonal matrices: the diagonal vector is cut with 'Vector.takeLeft'.
-- * Lower triangular matrices: handled through transposition.
-- * Other shapes: 'takeTopLeftAux'.
takeTopLeft ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular lo diag up (sh0:+:sh1) a ->
   Triangular lo diag up sh0 a
takeTopLeft =
   Tri.getMap $
   MatrixShape.switchDiagUpLoSym
      (Tri.Map (liftDiagonal Vector.takeLeft))
      (Tri.Map takeTopLeftAux)
      (Tri.Map (\m -> transpose (takeTopLeftAux (transpose m))))
      (Tri.Map takeTopLeftAux)
-- | Top-left block for shapes with a filled upper part, via
--   'Tri.takeTopLeft'.
--
-- The shape function yields two things: the result shape, restricted to
-- @sh0@, and the original order and combined size.
takeTopLeftAux ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular lo diag MatrixShape.Filled (sh0:+:sh1) a ->
   Triangular lo diag MatrixShape.Filled sh0 a
takeTopLeftAux = Tri.takeTopLeft splitShape
   where
      splitShape (MatrixShape.Triangular diag uplo order sh@(sh0 :+: _)) =
         (MatrixShape.Triangular diag uplo order sh0, (order, sh))
-- | Bottom-left rectangular block of a matrix with a filled lower part,
--   obtained by transposing, taking the top-right block and transposing
--   back.
takeBottomLeft ::
   (MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular MatrixShape.Filled diag up (sh0:+:sh1) a ->
   Matrix.General sh1 sh0 a
takeBottomLeft m = Basic.transpose (takeTopRight (transpose m))
-- | Top-right rectangular block of a matrix with a filled upper part.
--
-- Delegates to 'Tri.takeTopRight', passing a function that returns the
-- order and size of the shape.
takeTopRight ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular lo diag MatrixShape.Filled (sh0:+:sh1) a ->
   Matrix.General sh0 sh1 a
takeTopRight = Tri.takeTopRight orderAndSize
   where
      orderAndSize (MatrixShape.Triangular _ _ order sh) = (order, sh)
-- | Bottom-right diagonal block of a matrix whose size shape is
--   @sh0:+:sh1@.
--
-- * Diagonal matrices: the diagonal vector is cut with 'Vector.takeRight'.
-- * Lower triangular matrices: handled through transposition.
-- * Other shapes: 'takeBottomRightAux'.
takeBottomRight ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular lo diag up (sh0:+:sh1) a ->
   Triangular lo diag up sh1 a
takeBottomRight =
   Tri.getMap $
   MatrixShape.switchDiagUpLoSym
      (Tri.Map (liftDiagonal Vector.takeRight))
      (Tri.Map takeBottomRightAux)
      (Tri.Map (\m -> transpose (takeBottomRightAux (transpose m))))
      (Tri.Map takeBottomRightAux)
-- | Bottom-right block for shapes with a filled upper part, via
--   'Tri.takeBottomRight'.
--
-- The shape function yields two things: the result shape, restricted to
-- @sh1@, and the original order and combined size.
takeBottomRightAux ::
   (MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Triangular lo diag MatrixShape.Filled (sh0:+:sh1) a ->
   Triangular lo diag MatrixShape.Filled sh1 a
takeBottomRightAux = Tri.takeBottomRight splitShape
   where
      splitShape (MatrixShape.Triangular diag uplo order sh@(_ :+: sh1)) =
         (MatrixShape.Triangular diag uplo order sh1, (order, sh))
-- | Matrix-vector product.
--
-- Diagonal matrices multiply element-wise with their diagonal. All other
-- shapes go through 'multiplyVectorTriangular'.
multiplyVector ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Vector sh a -> Vector sh a
multiplyVector =
   Tri.getMultiplyRight $
   MatrixShape.switchDiagUpLoSym
      (Tri.MultiplyRight (\m -> Vector.mul (takeDiagonal m)))
      (Tri.MultiplyRight multiplyVectorTriangular)
      (Tri.MultiplyRight multiplyVectorTriangular)
      (Tri.MultiplyRight multiplyVectorTriangular)
-- | Product of a packed triangular or symmetric matrix with a vector.
--
-- The matrix size shape must equal the vector shape; this is checked with
-- 'Call.assert'.
--
-- * Triangular matrices: @x@ is copied into the result buffer, and BLAS
--   @tpmv@ multiplies that buffer in place.
-- * Symmetric matrices: 'spmv' with @alpha = 1@ and @beta = 0@ writes the
--   product into the result buffer.
multiplyVectorTriangular ::
(MatrixShape.UpLoSym lo up, MatrixShape.TriDiag diag,
Shape.C sh, Eq sh, Class.Floating a) =>
Triangular lo diag up sh a -> Vector sh a -> Vector sh a
multiplyVectorTriangular
(Array (MatrixShape.Triangular diag uplo order shA) a) (Array shX x) =
Array.unsafeCreate shX $ \yPtr -> do
Call.assert "Triangular.multiplyVector: width shapes mismatch" (shA == shX)
let n = Shape.size shA
evalContT $ do
-- BLAS flags derived from the shape: uplo from the uplo tag combined with
-- the storage order, transposition from the storage order.
uploPtr <- Call.char $ uploFromOrder $ uploOrder uplo order
transPtr <- Call.char $ transposeFromOrder order
diagPtr <- Call.char $ charFromTriDiag diag
nPtr <- Call.cint n
aPtr <- ContT $ withForeignPtr a
xPtr <- ContT $ withForeignPtr x
alphaPtr <- Call.number one
betaPtr <- Call.number zero
incxPtr <- Call.cint 1
incyPtr <- Call.cint 1
-- tpmv overwrites its vector argument, so y is first seeded with x.
let runTPMV = do
copyBlock n xPtr yPtr
BlasGen.tpmv uploPtr transPtr diagPtr nPtr aPtr yPtr incyPtr
-- Both triangular cases use runTPMV; the symmetric case uses spmv.
liftIO $
MatrixShape.caseUpLoSym uplo
runTPMV
runTPMV
(spmv uploPtr nPtr alphaPtr aPtr xPtr incxPtr betaPtr yPtr incyPtr)
-- | Wrapper giving the @spmv@ routines a common type, so that
--   'Class.switchFloating' can select one per element type (see 'spmv').
newtype SPMV a = SPMV
   { getSPMV ::
        Ptr CChar -> Ptr CInt -> Ptr a -> Ptr a ->
        Ptr a -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
   }
-- | Symmetric packed matrix-vector product, following the BLAS @?SPMV@
--   argument convention (uplo, n, alpha, ap, x, incx, beta, y, incy).
--
-- Real element types use 'BlasReal.spmv'; complex element types use
-- 'LapackComplex.spmv'.
spmv :: Class.Floating a =>
   Ptr CChar -> Ptr CInt -> Ptr a -> Ptr a ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
spmv =
   getSPMV
      (Class.switchFloating
         (SPMV BlasReal.spmv)
         (SPMV BlasReal.spmv)
         (SPMV LapackComplex.spmv)
         (SPMV LapackComplex.spmv))
-- | Square of a matrix.
--
-- * Diagonal: 'squareDiagonal'.
-- * Triangular: 'squareTriangular', which uses BLAS @trmm@.
-- * Symmetric: retagged as 'NonUnit', then 'squareSymmetric'.
square ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Triangular lo (Tri.PowerDiag lo up diag) up sh a
square =
   Tri.getPower $
   MatrixShape.switchDiagUpLoSym
      (Tri.Power squareDiagonal)
      (Tri.Power squareTriangular)
      (Tri.Power squareTriangular)
      (Tri.Power (\m -> squareSymmetric (strictNonUnitDiagonal m)))
-- | Square of a diagonal matrix.
--
-- The first diagonal kind selected by 'MatrixShape.switchTriDiag' is
-- returned unchanged. The other has its elements multiplied with
-- themselves; the shape check is bypassed via 'VectorPriv.uncheck' and
-- restored with 'VectorPriv.recheck'.
squareDiagonal ::
   (MatrixShape.TriDiag diag, Shape.C sh, Class.Floating a) =>
   FlexDiagonal diag sh a -> FlexDiagonal diag sh a
squareDiagonal =
   getMapDiag $
   MatrixShape.switchTriDiag
      (MapDiag id)
      (MapDiag $ \m ->
         let v = VectorPriv.uncheck m
         in VectorPriv.recheck (Vector.mul v v))
-- | Endomorphism on triangular matrices. The diagonal type is the last
--   parameter, so that 'MatrixShape.switchTriDiag' can select an
--   implementation by diagonal kind (see 'squareDiagonal').
newtype MapDiag lo up sh a diag = MapDiag
   { getMapDiag ::
        Triangular lo diag up sh a ->
        Triangular lo diag up sh a
   }
-- | Square of a triangular (non-symmetric) matrix.
--
-- The packed input is unpacked twice into temporary square buffers:
--
-- * with 'unpack', as the triangular factor @A@;
-- * with 'unpackZero', as the right-hand side @B@.
--
-- BLAS @trmm@ then computes @B := A*B@ from the left with no
-- transposition. The result is packed back into the input's shape.
squareTriangular ::
(MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
Shape.C sh, Class.Floating a) =>
Triangular lo diag up sh a -> Triangular lo diag up sh a
squareTriangular
(Array shape@(MatrixShape.Triangular diag uplo order sh) a) =
Array.unsafeCreate shape $ \bpPtr -> do
let n = Shape.size sh
evalContT $ do
sidePtr <- Call.char 'L'
-- A single effective order, derived from the uplo tag and the storage
-- order, is used for unpack, pack and the BLAS uplo flag.
let realOrder = uploOrder uplo order
uploPtr <- Call.char $ uploFromOrder realOrder
transPtr <- Call.char 'N'
diagPtr <- Call.char $ charFromTriDiag diag
nPtr <- Call.cint n
ldPtr <- Call.leadingDim n
aPtr <- unpackToTemp (unpack realOrder) n a
-- trmm treats B as a general matrix, hence unpackZero rather than unpack.
bPtr <- unpackToTemp (unpackZero realOrder) n a
alphaPtr <- Call.number one
liftIO $ do
BlasGen.trmm sidePtr uploPtr transPtr diagPtr
nPtr nPtr alphaPtr aPtr ldPtr bPtr ldPtr
pack realOrder n bPtr bpPtr
-- | Square of a symmetric matrix, computed by 'Symmetric.square' (without
--   conjugation) into a fresh buffer with the input's shape.
squareSymmetric ::
   (Shape.C sh, Class.Floating a) => Symmetric sh a -> Symmetric sh a
squareSymmetric (Array shape@(MatrixShape.Triangular _ _ order sh) a) =
   Array.unsafeCreate shape (Symmetric.square NonConjugated order n a)
   where
      n = Shape.size sh
-- | @power n a@ raises @a@ to the non-negative integer power @n@.
--
-- * Diagonal matrices: exponentiated element-wise with '(^)'.
-- * Triangular and symmetric matrices: repeated squaring via
--   'powerAssociative', starting from the identity, so @power 0@ yields
--   the identity.
--
-- A negative exponent is rejected explicitly with 'error'. Neither '(^)'
-- nor 'powerAssociative' is meant for negative exponents: '(^)' fails,
-- and the repeated halving done by 'powerAssociative' does not terminate
-- for a negative 'Integer'.
power ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Int ->
   Triangular lo diag up sh a ->
   Triangular lo (Tri.PowerDiag lo up diag) up sh a
power n
   | n < 0 = error "Triangular.power: negative exponent"
   | otherwise =
      Tri.getPower $
      MatrixShape.switchDiagUpLoSym
         (Tri.Power $ Array.map (^n))
         (Tri.Power $ powerTriangular k)
         (Tri.Power $ powerTriangular k)
         (Tri.Power $ powerSymmetric k . strictNonUnitDiagonal)
   where
      k = toInteger n
-- | Integer power of a triangular matrix by repeated squaring
--   ('powerAssociative') with 'multiplyTriangular'.
--
-- The accumulator starts from the identity in the input's storage order.
-- Shape checks are bypassed with 'uncheck' and restored with 'recheck'.
powerTriangular ::
   (MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Integer -> Triangular lo diag up sh a -> Triangular lo diag up sh a
powerTriangular n a@(Array (MatrixShape.Triangular _diag _uplo order sh) _) =
   recheck (powerAssociative multiplyTriangular unitMatrix (uncheck a) n)
   where
      unitMatrix = relaxUnitDiagonal (identity order (Unchecked sh))
-- | Integer power of a symmetric matrix by repeated squaring.
--
-- Each step expands one factor with 'toSquare', multiplies it with
-- 'multiplyFullTriangular', and keeps the upper part of the product
-- ('Tri.fromUpperPart'). The accumulator starts from the identity in the
-- input's storage order.
powerSymmetric ::
   (Shape.C sh, Class.Floating a) => Integer -> Symmetric sh a -> Symmetric sh a
powerSymmetric n a0@(Array (MatrixShape.Triangular _diag _uplo order sh) _) =
   recheck $
   powerAssociative
      (\a b ->
         Tri.fromUpperPart
            (MatrixShape.Triangular NonUnit MatrixShape.autoUplo)
            (multiplyFullTriangular a (toSquare b)))
      unitMatrix
      (uncheck a0)
      n
   where
      unitMatrix = relaxUnitDiagonal (identity order (Unchecked sh))
-- | Product of two matrices of the same (diagonal or triangular) kind.
--
-- * Diagonal: element-wise product of the diagonals, keeping the second
--   operand's shape fields.
-- * Triangular: 'multiplyTriangular'.
multiply ::
   (MatrixShape.DiagUpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Triangular lo diag up sh a ->
   Triangular lo diag up sh a
multiply =
   getMultiply $
   MatrixShape.switchDiagUpLo
      (Multiply $ \a -> liftDiagonal (Vector.mul (takeDiagonal a)))
      (Multiply multiplyTriangular)
      (Multiply multiplyTriangular)
-- | Binary operation on equally shaped matrices. @lo@ and @up@ are the
--   last type parameters, so that 'MatrixShape.switchDiagUpLo' can select
--   an implementation by shape (see 'multiply').
newtype Multiply diag sh a lo up = Multiply
   { getMultiply ::
        Triangular lo diag up sh a ->
        Triangular lo diag up sh a ->
        Triangular lo diag up sh a
   }
-- | Product of two triangular matrices of the same kind.
--
-- Both size shapes must agree; this is checked with 'Call.assert'. Both
-- operands are unpacked into temporary square buffers, @A@ with 'unpack'
-- and @B@ with 'unpackZero'. BLAS @trmm@ writes the product into @B@'s
-- buffer, which is then packed into @B@'s shape.
--
-- The BLAS side depends on @B@'s storage order:
--
-- * column-major @B@: @A@ is applied from the left, with the
--   transposition flag taken from @A@'s order;
-- * row-major @B@: @A@ is applied from the right, with the flag taken
--   from @A@'s flipped order.
multiplyTriangular ::
(MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
Shape.C sh, Eq sh, Class.Floating a) =>
Triangular lo diag up sh a ->
Triangular lo diag up sh a -> Triangular lo diag up sh a
multiplyTriangular
(Array (MatrixShape.Triangular diag uploA orderA shA) a)
(Array shapeB@(MatrixShape.Triangular _diag uploB orderB shB) b) =
Array.unsafeCreate shapeB $ \cpPtr -> do
Call.assert "Triangular.multiply: width shapes mismatch" (shA == shB)
let n = Shape.size shA
evalContT $ do
-- Side and transposition of A are chosen by B's storage order.
let (side,trans) =
case orderB of
ColumnMajor -> ('L', orderA)
RowMajor -> ('R', flipOrder orderA)
sidePtr <- Call.char side
let realOrderA = uploOrder uploA orderA
let realOrderB = uploOrder uploB orderB
uploPtr <- Call.char $ uploFromOrder realOrderA
transPtr <- Call.char $ transposeFromOrder trans
diagPtr <- Call.char $ charFromTriDiag diag
nPtr <- Call.cint n
ldPtr <- Call.leadingDim n
aPtr <- unpackToTemp (unpack realOrderA) n a
-- trmm treats B as a general matrix, hence unpackZero rather than unpack.
bPtr <- unpackToTemp (unpackZero realOrderB) n b
alphaPtr <- Call.number one
liftIO $ do
BlasGen.trmm sidePtr uploPtr transPtr diagPtr
nPtr nPtr alphaPtr aPtr ldPtr bPtr ldPtr
pack realOrderB n bPtr cpPtr
-- | Product of a triangular, symmetric or diagonal matrix with a full
--   matrix.
--
-- Diagonal matrices scale the rows of the full matrix
-- ('Basic.scaleRows'). All other shapes go through
-- 'multiplyFullTriangular'.
multiplyFull ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Triangular lo diag up height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
multiplyFull =
   Tri.getMultiplyRight $
   MatrixShape.switchDiagUpLoSym
      (Tri.MultiplyRight (\m -> Basic.scaleRows (takeDiagonal m)))
      (Tri.MultiplyRight multiplyFullTriangular)
      (Tri.MultiplyRight multiplyFullTriangular)
      (Tri.MultiplyRight multiplyFullTriangular)
-- | Product of a packed triangular or symmetric matrix with a full matrix.
--
-- The matrix size shape must equal the height of the full matrix; this is
-- checked with 'Call.assert'.
--
-- The BLAS side depends on the full matrix's storage order:
--
-- * column-major: the left side, with dimensions (height, width);
-- * row-major: the right side, with the dimensions swapped and the
--   transposition flag taken from the flipped order.
--
-- Computation per shape:
--
-- * Triangular: the full matrix is copied into the result, and BLAS
--   @trmm@ works in place.
-- * Symmetric: BLAS @symm@ with @alpha = 1@ and @beta = 0@ writes the
--   product directly into the result.
multiplyFullTriangular ::
(MatrixShape.UpLoSym lo up, MatrixShape.TriDiag diag,
Extent.C vert, Extent.C horiz,
Shape.C height, Eq height, Shape.C width,
Class.Floating a) =>
Triangular lo diag up height a ->
Full vert horiz height width a ->
Full vert horiz height width a
multiplyFullTriangular
(Array (MatrixShape.Triangular diag uploA orderA shA) a)
(Array shapeB@(MatrixShape.Full orderB extentB) b) =
Array.unsafeCreateWithSize shapeB $ \size cPtr -> do
let (height,width) = Extent.dimensions extentB
Call.assert "Triangular.multiplyFull: shapes mismatch" (shA == height)
let m0 = Shape.size height
let n0 = Shape.size width
evalContT $ do
-- (m,n) are the dimensions as passed to BLAS: swapped for row-major input.
let (side,trans,(m,n)) =
case orderB of
ColumnMajor -> ('L', orderA, (m0,n0))
RowMajor -> ('R', flipOrder orderA, (n0,m0))
sidePtr <- Call.char side
let realOrderA = uploOrder uploA orderA
uploPtr <- Call.char $ uploFromOrder realOrderA
transPtr <- Call.char $ transposeFromOrder trans
diagPtr <- Call.char $ charFromTriDiag diag
mPtr <- Call.cint m
nPtr <- Call.cint n
alphaPtr <- Call.number one
aPtr <- unpackToTemp (unpack realOrderA) m0 a
ldaPtr <- Call.leadingDim m0
betaPtr <- Call.number zero
bPtr <- ContT $ withForeignPtr b
-- Leading dimension of the full matrix buffer (and of the result) as BLAS
-- sees it.
ldbPtr <- Call.leadingDim m
-- trmm overwrites its B argument, so the input is copied to the result first.
let runTRMM = do
copyBlock size bPtr cPtr
BlasGen.trmm sidePtr uploPtr transPtr diagPtr
mPtr nPtr alphaPtr aPtr ldaPtr cPtr ldbPtr
-- Both triangular cases use runTRMM; the symmetric case uses symm.
liftIO $
MatrixShape.caseUpLoSym uploA
runTRMM
runTRMM
(BlasGen.symm sidePtr uploPtr
mPtr nPtr alphaPtr aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldbPtr)
-- | Pick the uplo tag for the requested type via 'MatrixShape.autoUplo'.
--   Returns it together with the effective order that 'uploOrder' derives
--   from it and the given storage order.
autoUploOrder ::
   (MatrixShape.Content lo, MatrixShape.Content up) => Order -> (Order, (lo,up))
autoUploOrder order = (uploOrder uplo order, uplo)
   where
      uplo = MatrixShape.autoUplo