{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Hermitian.Basic (
Hermitian,
Transposition(..),
fromList,
autoFromList,
identity,
diagonal,
takeDiagonal,
forceOrder,
stack,
takeTopLeft,
takeTopRight,
takeBottomRight,
multiplyVector,
square,
multiplyFull,
outer,
sumRank1,
sumRank2,
toSquare,
gramian,
congruenceDiagonal,
congruence,
scaledAnticommutator,
addAdjoint,
) where
import qualified Numeric.LAPACK.Matrix.Symmetric.Private as Symmetric
import qualified Numeric.LAPACK.Matrix.Triangular.Private as Triangular
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Matrix.Hermitian.Private (Diagonal(..), TakeDiagonal(..))
import Numeric.LAPACK.Matrix.Triangular.Private
(forPointers, pack, packRect, unpack, unpackToTemp,
diagonalPointers, diagonalPointerPairs,
rowMajorPointers, columnMajorPointers)
import Numeric.LAPACK.Matrix.Shape.Private
(Order(RowMajor,ColumnMajor), flipOrder, sideSwapFromOrder,
uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier
(Transposition(NonTransposed, Transposed), transposeOrder,
Conjugation(Conjugated), conjugatedOnRowMajor)
import Numeric.LAPACK.Matrix.Private
(Full, General, argGeneral, Square, argSquare, ZeroInt, zeroInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, one)
import Numeric.LAPACK.Private
(fill, lacgv, realPtr,
copyConjugate, condConjugate, conjugateToTemp, condConjugateToTemp)
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))
import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Data.Foldable (forM_)
-- | A Hermitian matrix with rows and columns indexed by @sh@.
--
-- Only one triangle is stored, in packed form.  The storage order is part
-- of the shape: 'MatrixShape.Hermitian' carries an 'Order' together with
-- @sh@.
type Hermitian sh = Array (MatrixShape.Hermitian sh)
-- | Build a Hermitian matrix from the elements of its packed triangle,
-- listed in the given storage order.
--
-- Delegates to the checked 'CheckedArray.fromList' (not the unchecked
-- variant) on the corresponding Hermitian shape.
fromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Hermitian sh a
fromList order = CheckedArray.fromList . MatrixShape.Hermitian order
-- | Like 'fromList', but the matrix dimension is derived from the number
-- of list elements.
--
-- The element count is converted to a dimension by
-- 'MatrixShape.triangleExtent'.  It is expected to be a triangular number.
autoFromList :: (Storable a) => Order -> [a] -> Hermitian ZeroInt a
autoFromList order xs =
   fromList order (zeroInt dim) xs
   where
      dim = MatrixShape.triangleExtent "Hermitian.autoFromList" (length xs)
-- | The identity matrix in packed Hermitian storage.
--
-- All packed elements are set to zero, then each diagonal entry
-- (located with 'diagonalPointers') is set to one.
identity :: (Shape.C sh, Class.Floating a) => Order -> sh -> Hermitian sh a
identity order sh =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \packedSize ptr -> do
         fill zero packedSize ptr
         forM_ (diagonalPointers order (Shape.size sh) ptr) $ \diagPtr ->
            poke diagPtr one
-- | Hermitian matrix with the given real vector on its diagonal and zeros
-- everywhere else.
--
-- The element type is dispatched with 'Class.switchFloating'.  Presumably
-- this is needed because the worker 'diagonalAux' requires
-- @Storable (RealOf a)@, which only follows from a concrete element type.
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Vector sh (RealOf a) -> Hermitian sh a
diagonal order =
   runDiagonal $
   Class.switchFloating
      (Diagonal $ diagonalAux order) (Diagonal $ diagonalAux order)
      (Diagonal $ diagonalAux order) (Diagonal $ diagonalAux order)
-- | Worker for 'diagonal' at a concrete element type.
--
-- The packed storage is zeroed.  Each vector element is then written to
-- the real part (via 'realPtr') of the corresponding diagonal entry.
diagonalAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Order -> Vector sh ar -> Hermitian sh a
diagonalAux order (Array sh x) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \packedSize dst -> do
         fill zero packedSize dst
         withForeignPtr x $ \src ->
            mapM_ (\(s,d) -> peek s >>= poke (realPtr d)) $
               diagonalPointerPairs order (Shape.size sh) src dst
-- | Extract the diagonal of a Hermitian matrix as a vector of real
-- numbers.
--
-- The element type is dispatched with 'Class.switchFloating', presumably
-- so that the worker 'takeDiagonalAux' can use @Storable (RealOf a)@.
takeDiagonal ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Vector sh (RealOf a)
takeDiagonal =
   runTakeDiagonal $
   Class.switchFloating
      (TakeDiagonal takeDiagonalAux) (TakeDiagonal takeDiagonalAux)
      (TakeDiagonal takeDiagonalAux) (TakeDiagonal takeDiagonalAux)
-- | Worker for 'takeDiagonal' at a concrete element type.
--
-- Reads the real part (via 'realPtr') of every diagonal entry of the
-- packed storage.
takeDiagonalAux ::
   (Shape.C sh, Storable a, RealOf a ~ ar, Storable ar) =>
   Hermitian sh a -> Vector sh ar
takeDiagonalAux (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreateWithSize sh $ \n dst ->
      withForeignPtr a $ \src ->
         mapM_ (\(d,s) -> peek (realPtr s) >>= poke d) $
            diagonalPointerPairs order n dst src
-- | Return the matrix stored in the requested order.
--
-- If the matrix already has that order, it is returned unchanged.
-- Otherwise it is:
--
-- * expanded to a full square matrix,
-- * converted with 'Basic.forceOrder', and
-- * packed again from its upper part.
forceOrder ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Hermitian sh a -> Hermitian sh a
forceOrder newOrder a
   | currentOrder == newOrder = a
   | otherwise = fromUpperPart $ Basic.forceOrder newOrder $ toSquare a
   where
      currentOrder = MatrixShape.hermitianOrder $ Array.shape a
-- | Convert a full matrix to packed Hermitian storage in the same order,
-- using the part that corresponds to its upper triangle.
--
-- The copy is done by 'packRect'.  The leading dimension of the source
-- depends on the order:
--
-- * row-major: the width;
-- * column-major: the height.
fromUpperPart ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert Extent.Small height width a -> Hermitian width a
fromUpperPart (Array (MatrixShape.Full order extent) a) =
   Array.unsafeCreate (MatrixShape.Hermitian order width) $ \bPtr ->
      withForeignPtr a $ \aPtr -> packRect order n ld aPtr bPtr
   where
      (height,width) = Extent.dimensions extent
      n = Shape.size width
      ld =
         case order of
            RowMajor -> n
            ColumnMajor -> Shape.size height
-- | Assemble the Hermitian block matrix
--
-- > [ a    b ]
-- > [ b^H  c ]
--
-- from the diagonal blocks @a@ and @c@ and the off-diagonal block @b@.
-- Both diagonal blocks are converted to the storage order of @b@, which
-- also becomes the order of the result.
stack ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   Hermitian sh0 a -> General sh0 sh1 a -> Hermitian sh1 a ->
   Hermitian (sh0:+:sh1) a
stack a b c =
   Triangular.stack "Hermitian" (MatrixShape.Hermitian order)
      (forceOrder order a) b (forceOrder order c)
   where
      order = MatrixShape.fullOrder (Array.shape b)
-- | The top-left diagonal block of a Hermitian block matrix.
--
-- The local shape function gives 'Triangular.takeTopLeft' two things:
-- the shape of the sub-block, and the order and shape of the whole matrix.
takeTopLeft ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> Hermitian sh0 a
takeTopLeft = Triangular.takeTopLeft shapes
   where
      shapes (MatrixShape.Hermitian order sh@(sh0:+:_)) =
         (MatrixShape.Hermitian order sh0, (order, sh))
-- | The top-right off-diagonal block of a Hermitian block matrix, as a
-- general matrix.
takeTopRight ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> General sh0 sh1 a
takeTopRight = Triangular.takeTopRight orderAndShape
   where
      orderAndShape (MatrixShape.Hermitian order sh) = (order, sh)
-- | The bottom-right diagonal block of a Hermitian block matrix.
--
-- The local shape function gives 'Triangular.takeBottomRight' two things:
-- the shape of the sub-block, and the order and shape of the whole matrix.
takeBottomRight ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> Hermitian sh1 a
takeBottomRight = Triangular.takeBottomRight shapes
   where
      shapes (MatrixShape.Hermitian order sh@(_:+:sh1)) =
         (MatrixShape.Hermitian order sh1, (order, sh))
-- | Multiply a Hermitian matrix, optionally transposed, with a vector.
-- Uses BLAS @hpmv@ on the packed storage.
--
-- The flag @conj@ is set by 'conjugatedOnRowMajor' applied to the
-- transposition-adjusted order.  When it is set:
--
-- * the input vector is conjugated into a temporary, and
-- * the result vector is conjugated in place afterwards.
--
-- NOTE(review): this presumably compensates for BLAS reading the packed
-- data column-major.  For a Hermitian matrix that view is the transpose,
-- i.e. the conjugate.
multiplyVector ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Transposition -> Hermitian sh a -> Vector sh a -> Vector sh a
multiplyVector transposed
      (Array (MatrixShape.Hermitian order shA) a) (Array shX x) =
   Array.unsafeCreateWithSize shX $ \n yPtr -> do
      Call.assert "Hermitian.multiplyVector: width shapes mismatch" (shA == shX)
      evalContT $ do
         let conj = conjugatedOnRowMajor $ transposeOrder transposed order
         uploPtr <- Call.char $ uploFromOrder order
         nPtr <- Call.cint n
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr a
         -- conjugated copy only when @conj@ says so; otherwise x is used as is
         xPtr <- condConjugateToTemp conj n x
         incxPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incyPtr <- Call.cint 1
         liftIO $ do
            -- y := 1 * A * x + 0 * y
            BlasGen.hpmv
               uploPtr nPtr alphaPtr aPtr xPtr incxPtr betaPtr yPtr incyPtr
            -- undo the conjugation of the input on the result
            condConjugate conj nPtr yPtr incyPtr
-- | The square @a*a@ of a Hermitian matrix, which is again Hermitian.
--
-- Computed by 'Symmetric.square' in 'Conjugated' mode, directly on the
-- packed data and with the same shape.
square ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Hermitian sh a -> Hermitian sh a
square (Array shape@(MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreate shape $ Symmetric.square Conjugated order n a
   where
      n = Shape.size sh
-- | Multiply a Hermitian matrix, optionally transposed, from the left
-- onto a full matrix @b@.  Uses BLAS @hemm@.
--
-- The packed Hermitian matrix is first unpacked into a temporary
-- @m0*m0@ array.  That temporary is conjugated with @lacgv@ when the
-- order of @b@ differs from the transposition-adjusted order of the
-- Hermitian matrix: BLAS then sees the transpose of the intended factor,
-- and for a Hermitian matrix the transpose equals the conjugate.
--
-- 'sideSwapFromOrder' picks the BLAS side and swaps the dimensions.
-- Presumably this makes a row-major @b@ be treated as a right-hand
-- multiplication in the column-major view.
multiplyFull ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Transposition -> Hermitian height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
multiplyFull transposed
      (Array (MatrixShape.Hermitian orderA shA) a)
      (Array shapeB@(MatrixShape.Full orderB extentB) b) =
   Array.unsafeCreate shapeB $ \cPtr -> do
      let (height,width) = Extent.dimensions extentB
      Call.assert "Hermitian.multiplyFull: shapes mismatch" (shA == height)
      let m0 = Shape.size height
      let n0 = Shape.size width
      -- element count of the unpacked temporary, needed for lacgv
      let size = m0*m0
      evalContT $ do
         let (side,(m,n)) = sideSwapFromOrder orderB (m0,n0)
         sidePtr <- Call.char side
         uploPtr <- Call.char $ uploFromOrder orderA
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         alphaPtr <- Call.number one
         aPtr <- unpackToTemp (unpack orderA) m0 a
         ldaPtr <- Call.leadingDim m0
         incaPtr <- Call.cint 1
         sizePtr <- Call.cint size
         bPtr <- ContT $ withForeignPtr b
         ldbPtr <- Call.leadingDim m
         betaPtr <- Call.number zero
         ldcPtr <- Call.leadingDim m
         liftIO $ do
            when (transposeOrder transposed orderA /= orderB) $
               lacgv sizePtr aPtr incaPtr
            BlasGen.hemm sidePtr uploPtr
               mPtr nPtr alphaPtr aPtr ldaPtr
               bPtr ldbPtr betaPtr cPtr ldcPtr
-- | Run a BLAS-style packed update on a freshly zeroed buffer.
--
-- Steps:
--
-- * allocate the common BLAS arguments: the uplo character derived from
--   @order@, the dimension @Shape.size sh@, and a unit increment;
-- * zero the @triSize@ packed elements at @aPtr@;
-- * run @act@ with those arguments;
-- * if @order@ is row-major, conjugate the whole buffer.
--
-- NOTE(review): the final conjugation presumably converts the column-major
-- triangle written by BLAS into the row-major triangle of the same
-- Hermitian matrix.  Reading the data in the other order gives the
-- transpose, which equals the conjugate.
withConjBuffer ::
   (Shape.C sh, Class.Floating a) =>
   Order -> sh -> Int -> Ptr a ->
   (Ptr CChar -> Ptr CInt -> Ptr CInt -> IO ()) -> ContT r IO ()
withConjBuffer order sh triSize aPtr act = do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint $ Shape.size sh
   incxPtr <- Call.cint 1
   sizePtr <- Call.cint triSize
   liftIO $ do
      fill zero triSize aPtr
      act uploPtr nPtr incxPtr
      condConjugate (conjugatedOnRowMajor order) sizePtr aPtr incxPtr
-- | The rank-1 Hermitian matrix @x * x^H@ of a vector.
--
-- Computed as a single packed rank-1 update ('hpr') with real weight one
-- on the zeroed buffer prepared by 'withConjBuffer'.
outer ::
   (Shape.C sh, Class.Floating a) => Order -> Vector sh a -> Hermitian sh a
outer order (Array sh x) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \packedSize packedPtr -> evalContT $ do
         onePtr <- realOneArg packedPtr
         vecPtr <- ContT $ withForeignPtr x
         withConjBuffer order sh packedSize packedPtr $
            \uplo dim inc -> hpr uplo dim onePtr vecPtr inc packedPtr
-- | Weighted sum of rank-1 Hermitian matrices with real weights:
--
-- > sum [alpha * x * x^H | (alpha, x) <- xs]
--
-- Every vector must have shape @sh@.  The element type is dispatched with
-- 'Class.switchFloating' to the worker 'sumRank1Aux'.
sumRank1 ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(RealOf a, Vector sh a)] -> Hermitian sh a
sumRank1 =
   getSumRank1 $
   Class.switchFloating
      (SumRank1 sumRank1Aux) (SumRank1 sumRank1Aux)
      (SumRank1 sumRank1Aux) (SumRank1 sumRank1Aux)
-- | Type of 'sumRank1', with the real scalar type abstracted as @ar@.
type SumRank1_ sh ar a = Order -> sh -> [(ar, Vector sh a)] -> Hermitian sh a
-- | Wrapper that lets 'Class.switchFloating' select a 'SumRank1_'
-- implementation per element type.
newtype SumRank1 sh a = SumRank1 {getSumRank1 :: SumRank1_ sh (RealOf a) a}
-- | Worker for 'sumRank1' at a concrete element type.
--
-- A single scratch cell receives each real weight in turn.  The packed
-- rank-1 update ('hpr') is then applied to the buffer prepared by
-- 'withConjBuffer'.
sumRank1Aux ::
   (Shape.C sh, Eq sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   SumRank1_ sh ar a
sumRank1Aux order sh xs =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr ->
         evalContT $ do
            -- scratch cell for the current weight, reused for every term
            alphaPtr <- Call.alloca
            withConjBuffer order sh triSize aPtr $ \uploPtr nPtr incxPtr ->
               forM_ xs $ \(alpha, Array shX x) ->
                  withForeignPtr x $ \xPtr -> do
                     Call.assert
                        "Hermitian.sumRank1: non-matching vector size" (sh==shX)
                     poke alphaPtr alpha
                     hpr uploPtr nPtr alphaPtr xPtr incxPtr aPtr
-- | Argument list shared by BLAS @spr@ (real) and @hpr@ (complex), in
-- this order: uplo, n, real alpha, x, incx, packed matrix.
type HPR_ a =
   Ptr CChar -> Ptr CInt ->
   Ptr (RealOf a) -> Ptr a -> Ptr CInt -> Ptr a -> IO ()
-- | Wrapper for selecting an 'HPR_' implementation via
-- 'Class.switchFloating'.
newtype HPR a = HPR {getHPR :: HPR_ a}
-- | Packed rank-1 update @A := alpha*x*x^H + A@ with real @alpha@.
--
-- Uses 'BlasReal.spr' for real element types and 'BlasComplex.hpr' for
-- complex element types.
hpr :: Class.Floating a => HPR_ a
hpr =
   getHPR $
   Class.switchFloating
      (HPR BlasReal.spr) (HPR BlasReal.spr)
      (HPR BlasComplex.hpr) (HPR BlasComplex.hpr)
-- | Sum of rank-2 Hermitian matrices:
--
-- > sum [alpha*x*y^H + conj alpha*y*x^H | (alpha, (x, y)) <- xys]
--
-- Each term is one BLAS @hpr2@ update on the buffer prepared by
-- 'withConjBuffer'.  All vectors must have shape @sh@.
sumRank2 ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, (Vector sh a, Vector sh a))] -> Hermitian sh a
sumRank2 order sh xys =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr ->
         evalContT $ do
            -- scratch cell for the current coefficient, reused for every term
            alphaPtr <- Call.alloca
            withConjBuffer order sh triSize aPtr $ \uploPtr nPtr incPtr ->
               forM_ xys $ \(alpha, (Array shX x, Array shY y)) ->
                  withForeignPtr x $ \xPtr ->
                  withForeignPtr y $ \yPtr -> do
                     Call.assert
                        "Hermitian.sumRank2: non-matching x vector size" (sh==shX)
                     Call.assert
                        "Hermitian.sumRank2: non-matching y vector size" (sh==shY)
                     poke alphaPtr alpha
                     BlasGen.hpr2 uploPtr nPtr alphaPtr xPtr incPtr yPtr incPtr aPtr
-- | Expand a packed Hermitian matrix to a full square matrix in the same
-- order.
--
-- 'toSquare' delegates to 'Symmetric.unpack' in 'Conjugated' mode.
--
-- '_toSquare' is an alternative implementation.  It is not exported and
-- not referenced in this module's visible code.  It unpacks a conjugated
-- copy of the packed data in the flipped order, then unpacks the original
-- data in the original order.
toSquare, _toSquare ::
   (Shape.C sh, Class.Floating a) => Hermitian sh a -> Square sh a
_toSquare (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreate (MatrixShape.square order sh) $ \bPtr ->
      evalContT $ do
         let n = Shape.size sh
         aPtr <- ContT $ withForeignPtr a
         conjPtr <- conjugateToTemp (Shape.triangleSize n) a
         liftIO $ do
            unpack (flipOrder order) n conjPtr bPtr
            unpack order n aPtr bPtr
toSquare (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreate (MatrixShape.square order sh) $ \bPtr ->
      withForeignPtr a $ \aPtr ->
         Symmetric.unpack Conjugated order (Shape.size sh) aPtr bPtr
-- | The Gramian @a^H * a@ of a general matrix.  Uses BLAS @herk@, or
-- @syrk@ for real element types.
--
-- The result is computed in a full @n*n@ temporary and then packed in the
-- order of @a@.  BLAS works column-major, and a row-major matrix is seen
-- by BLAS as its transpose.  So for 'RowMajor':
--
-- * the non-transposed form is requested (@trans = 'N'@);
-- * the lower triangle is computed (@uplo = 'L'@);
-- * that triangle, read row-major, is the upper triangle of @a^H * a@.
gramian ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Hermitian width a
gramian = argGeneral $ \order height width a ->
   Array.unsafeCreate (MatrixShape.Hermitian order width) $ \bPtr -> do
      let n = Shape.size width
      let k = Shape.size height
      evalContT $ do
         let (uplo,trans,lda) =
               case order of
                  ColumnMajor -> ('U', 'C', k)
                  RowMajor -> ('L', 'N', n)
         uploPtr <- Call.char uplo
         transPtr <- Call.char trans
         nPtr <- Call.cint n
         kPtr <- Call.cint k
         alphaPtr <- realOneArg a
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim lda
         betaPtr <- realZeroArg a
         -- full n*n temporary; only the uplo triangle is written by BLAS
         cPtr <- Call.allocaArray (n*n)
         ldcPtr <- Call.leadingDim n
         liftIO $ do
            herk uploPtr transPtr
               nPtr kPtr alphaPtr aPtr ldaPtr betaPtr cPtr ldcPtr
            pack order n cPtr bPtr
-- | Argument list shared by BLAS @syrk@ (real) and @herk@ (complex), in
-- this order: uplo, trans, n, k, real alpha, A, lda, real beta, C, ldc.
type HERK_ a =
   Ptr CChar -> Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr (RealOf a) -> Ptr a ->
   Ptr CInt -> Ptr (RealOf a) -> Ptr a -> Ptr CInt -> IO ()
-- | Wrapper for selecting a 'HERK_' implementation via
-- 'Class.switchFloating'.
newtype HERK a = HERK {getHERK :: HERK_ a}
-- | Hermitian rank-k update.
--
-- Uses 'BlasReal.syrk' for real element types and 'BlasComplex.herk' for
-- complex element types.
herk :: Class.Floating a => HERK_ a
herk =
   getHERK $
   Class.switchFloating
      (HERK BlasReal.syrk)
      (HERK BlasReal.syrk)
      (HERK BlasComplex.herk)
      (HERK BlasComplex.herk)
-- | @congruenceDiagonal d a@ computes @a^H * diag d * a@ for a real
-- diagonal @d@.
--
-- It is evaluated as @0.5 * (a^H*(diag d*a) + (diag d*a)^H*a)@ via
-- 'scaledAnticommutator'.  This is exact because @d@ is real.
congruenceDiagonal ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Vector height (RealOf a) -> General height width a -> Hermitian width a
congruenceDiagonal d a =
   let da = Basic.scaleRowsReal d a
   in scaledAnticommutator 0.5 a da
-- | @congruence b a@ computes @a^H * b * a@.
--
-- @b@ is split as @r + r^H@, where @r@ is the stored triangle of @b@ with
-- halved diagonal (see 'takeHalf').  Then
--
-- > a^H*b*a = a^H*(r*a) + (r*a)^H*a
--
-- which is a single 'scaledAnticommutator' evaluation with factor one.
congruence ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Hermitian height a -> General height width a -> Hermitian width a
congruence b a = scaledAnticommutator one a ra
   where
      ra = Split.tallMultiplyR NonTransposed (takeHalf b) a
-- | Tag stored in the 'MatrixShape.Split' shape produced by 'takeHalf'.
--
-- NOTE(review): the name suggests the array is not a genuine split
-- decomposition.  Presumably only the triangle written by 'takeHalf' is
-- valid; confirm before using such arrays elsewhere.
data Corrupt = Corrupt
   deriving (Eq)
-- | Unpack the stored triangle of a Hermitian matrix into a square
-- 'Split.Square' array, then halve its diagonal.
--
-- The diagonal is halved with BLAS @scal@ using stride @n+1@, which walks
-- the diagonal of the @n*n@ array.
--
-- NOTE(review): presumably only the triangle written by 'unpack' is
-- initialised.  The rest of the array is left as allocated, hence the
-- 'Corrupt' tag.  It is used with 'Split.tallMultiplyR' in 'congruence'.
takeHalf ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Split.Square Corrupt sh a
takeHalf (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreate (MatrixShape.Split Corrupt order (Extent.square sh)) $
      \bPtr -> evalContT $ do
         let n = Shape.size sh
         aPtr <- ContT $ withForeignPtr a
         nPtr <- Call.cint n
         alphaPtr <- Call.number 0.5
         -- consecutive diagonal elements of an n*n array are n+1 apart
         incxPtr <- Call.cint (n+1)
         liftIO $ do
            unpack order n aPtr bPtr
            BlasGen.scal nPtr alphaPtr bPtr incxPtr
-- | @scaledAnticommutator alpha a b@ computes the Hermitian matrix
--
-- > alpha * a^H * b + conj alpha * b^H * a
--
-- using BLAS @her2k@, or @syr2k@ for real element types.
--
-- @a@ is converted to the storage order of @b@, and the result is stored
-- in that order as well.  Heights and widths of @a@ and @b@ must match.
-- The result does not depend on the storage order, including for complex
-- @alpha@.
scaledAnticommutator ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   a ->
   Full vert horiz height width a ->
   Full vert horiz height width a -> Hermitian width a
scaledAnticommutator alpha arr (Array (MatrixShape.Full order extentB) b) =
   Array.unsafeCreate (MatrixShape.Hermitian order widthB) $ \cpPtr -> do
      Call.assert "Hermitian.anticommutator: heights mismatch"
         (heightA==heightB)
      Call.assert "Hermitian.anticommutator: widths mismatch"
         (widthA==widthB)
      evalContT $ do
         -- BLAS works column-major.  A row-major k*n matrix is seen by BLAS
         -- as its n*k transpose, so for RowMajor we use the 'N' form and
         -- the lower triangle, which reads as the upper triangle row-major.
         let (uplo,trans,lda) =
               case order of
                  ColumnMajor -> ('U', 'C', k)
                  RowMajor -> ('L', 'N', n)
         uploPtr <- Call.char uplo
         transPtr <- Call.char trans
         nPtr <- Call.cint n
         kPtr <- Call.cint k
         alphaPtr <- Call.number alpha
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim lda
         bPtr <- ContT $ withForeignPtr b
         -- a has been converted to the order of b, so both have the same
         -- leading dimension
         let ldbPtr = ldaPtr
         -- In the RowMajor case, BLAS receives the transposes X = a^T and
         -- Y = b^T and computes alpha*X*Y^H + conj alpha*Y*X^H.  Read back
         -- transposed, this is alpha*b^H*a + conj alpha*a^H*b, i.e. alpha
         -- would effectively be conjugated.  Swapping the operands restores
         -- alpha*a^H*b + conj alpha*b^H*a, matching the ColumnMajor case.
         -- For real alpha both forms coincide.
         let (xPtr,yPtr) =
               case order of
                  ColumnMajor -> (aPtr,bPtr)
                  RowMajor -> (bPtr,aPtr)
         betaPtr <- realZeroArg aPtr
         -- full n*n temporary; only the uplo triangle is written by BLAS
         cPtr <- Call.allocaArray (n*n)
         ldcPtr <- Call.leadingDim n
         liftIO $ do
            her2k uploPtr transPtr nPtr kPtr alphaPtr
               xPtr ldaPtr yPtr ldbPtr betaPtr cPtr ldcPtr
            pack order n cPtr cpPtr
   where
      (Array (MatrixShape.Full _ extentA) a) = Basic.forceOrder order arr
      (heightA,widthA) = Extent.dimensions extentA
      (heightB,widthB) = Extent.dimensions extentB
      n = Shape.size widthB
      k = Shape.size heightB
-- | Argument list shared by BLAS @syr2k@ (real) and @her2k@ (complex), in
-- this order: uplo, trans, n, k, alpha, A, lda, B, ldb, real beta, C, ldc.
type HER2K_ a =
   Ptr CChar -> Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr a ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr (RealOf a) -> Ptr a -> Ptr CInt -> IO ()
-- | Wrapper for selecting a 'HER2K_' implementation via
-- 'Class.switchFloating'.
newtype HER2K a = HER2K {getHER2K :: HER2K_ a}
-- | Hermitian rank-2k update.
--
-- Uses 'BlasReal.syr2k' for real element types and 'BlasComplex.her2k'
-- for complex element types.
her2k :: Class.Floating a => HER2K_ a
her2k =
   getHER2K $
   Class.switchFloating
      (HER2K BlasReal.syr2k)
      (HER2K BlasReal.syr2k)
      (HER2K BlasComplex.her2k)
      (HER2K BlasComplex.her2k)
-- | @addAdjoint a@ computes the Hermitian matrix @a + a^H@ of a square
-- matrix, directly in packed storage.
--
-- 'addAdjoint' handles one packed segment at a time: one per row for
-- row-major, one per column for column-major.  For each segment:
--
-- * the conjugated mirrored entries are copied with stride @n@
--   ('copyConjugate');
-- * the directly corresponding entries of @a@ are added with @axpy@.
--
-- '_addAdjoint' is an alternative implementation.  It is not exported and
-- not referenced in this module's visible code.  It packs @a@ in both
-- orders into separate buffers, conjugates one of them with @lacgv@, and
-- adds the two.
addAdjoint, _addAdjoint ::
   (Shape.C sh, Class.Floating a) => Square sh a -> Hermitian sh a
_addAdjoint =
   argSquare $ \order sh a ->
      Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $ \bSize bPtr -> do
         let n = Shape.size sh
         evalContT $ do
            alphaPtr <- Call.number one
            incxPtr <- Call.cint 1
            aPtr <- ContT $ withForeignPtr a
            sizePtr <- Call.cint bSize
            conjPtr <- Call.allocaArray bSize
            liftIO $ do
               pack order n aPtr bPtr
               pack (flipOrder order) n aPtr conjPtr
               lacgv sizePtr conjPtr incxPtr
               BlasGen.axpy sizePtr alphaPtr conjPtr incxPtr bPtr incxPtr
addAdjoint =
   argSquare $ \order sh a ->
      Array.unsafeCreate (MatrixShape.Hermitian order sh) $ \bPtr -> do
         let n = Shape.size sh
         evalContT $ do
            alphaPtr <- Call.number one
            incxPtr <- Call.cint 1
            -- stride n walks across the other index of the full matrix
            incnPtr <- Call.cint n
            aPtr <- ContT $ withForeignPtr a
            liftIO $ case order of
               RowMajor ->
                  forPointers (rowMajorPointers n aPtr bPtr) $
                     \nPtr (srcPtr,dstPtr) -> do
                        copyConjugate nPtr srcPtr incnPtr dstPtr incxPtr
                        BlasGen.axpy nPtr alphaPtr srcPtr incxPtr dstPtr incxPtr
               ColumnMajor ->
                  forPointers (columnMajorPointers n aPtr bPtr) $
                     \nPtr ((srcRowPtr,srcColumnPtr),dstPtr) -> do
                        copyConjugate nPtr srcRowPtr incnPtr dstPtr incxPtr
                        BlasGen.axpy nPtr alphaPtr srcColumnPtr incxPtr dstPtr incxPtr
-- | Alternative packing routine.  It is not exported and not referenced
-- in this module's visible code.
--
-- Copies, one contiguous segment at a time, the part of a full @n*n@
-- matrix selected by 'columnMajorPointers' / 'rowMajorPointers' into
-- packed storage.  Each copy is a BLAS @copy@ with unit stride.
_pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_pack order n fullPtr packedPtr =
   evalContT $ do
      incxPtr <- Call.cint 1
      liftIO $
         case order of
            ColumnMajor ->
               forPointers (columnMajorPointers n fullPtr packedPtr) $
                  \nPtr ((_,srcPtr),dstPtr) ->
                     BlasGen.copy nPtr srcPtr incxPtr dstPtr incxPtr
            RowMajor ->
               forPointers (rowMajorPointers n fullPtr packedPtr) $
                  \nPtr (srcPtr,dstPtr) ->
                     BlasGen.copy nPtr srcPtr incxPtr dstPtr incxPtr
-- | Allocate a pointer to the real scalar zero ('realZeroArg') or one
-- ('realOneArg') of type @RealOf a@.
--
-- These are used for real-valued BLAS scalar arguments, such as the alpha
-- and beta of @herk@.  The argument only fixes the element type @a@; its
-- value is ignored (see the use of 'const').
realZeroArg, realOneArg ::
   (Class.Floating a) => f a -> ContT r IO (Ptr (RealOf a))
realZeroArg =
   runRealArg $
   Class.switchFloating
      (RealArg $ const $ Call.number zero)
      (RealArg $ const $ Call.number zero)
      (RealArg $ const $ Call.number zero)
      (RealArg $ const $ Call.number zero)
realOneArg =
   runRealArg $
   Class.switchFloating
      (RealArg $ const $ Call.number one)
      (RealArg $ const $ Call.number one)
      (RealArg $ const $ Call.number one)
      (RealArg $ const $ Call.number one)
-- | Wrapper that lets 'Class.switchFloating' pick, per element type @a@,
-- an action producing a pointer to a @RealOf a@ value.
newtype RealArg f g a = RealArg {runRealArg :: f a -> g (Ptr (RealOf a))}