{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Hermitian.Basic (
Hermitian,
HermitianP,
Transposition(..),
diagonal,
takeDiagonal,
sumRank1,
sumRank2,
) where
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Hermitian.Private (Diagonal(..), TakeDiagonal(..))
import Numeric.LAPACK.Matrix.Symmetric.Unified (complement)
import Numeric.LAPACK.Matrix.Mosaic.Private
(forPointers, diagonalPointerPairs,
rowMajorPointers, columnMajorPointers,
withPacking, noLabel, applyFuncPair, triArg)
import Numeric.LAPACK.Matrix.Layout.Private
(Order(RowMajor,ColumnMajor), uploFromOrder)
import Numeric.BLAS.Matrix.Modifier
(Transposition(NonTransposed, Transposed),
Conjugation(Conjugated), conjugatedOnRowMajor)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero)
import Numeric.LAPACK.Private (fill, realPtr, condConjugate)
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)
import Control.Monad.Trans.Cont (ContT, evalContT)
import Control.Monad.IO.Class (liftIO)
import Data.Foldable (forM_)
-- | Hermitian matrix over the shape @sh@, stored in an 'Array' whose
-- layout is described by 'Layout.Hermitian'.
type Hermitian sh = Array (Layout.Hermitian sh)
-- | Hermitian matrix whose storage packing is selected by the type
-- parameter @pack@ (see 'Layout.HermitianP').
type HermitianP pack sh = Array (Layout.HermitianP pack sh)
-- | Build a Hermitian matrix from a real-valued vector.
--
-- The vector supplies the diagonal; every other entry, including the
-- imaginary part of each diagonal entry for complex element types, is
-- zero (see 'diagonalAux').
--
-- The same worker appears in all four arms of 'Class.switchFloating'.
-- This lets @'RealOf' a@ reduce to a concrete type in each arm, which
-- satisfies the @Storable ar@ constraint of 'diagonalAux'.
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Vector sh (RealOf a) -> Hermitian sh a
diagonal order vec =
   runDiagonal
      (Class.switchFloating
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order)))
      vec
-- | Worker for 'diagonal' at a concrete element type.
--
-- It first allocates the Hermitian storage and zero-fills all of it.
-- It then copies each vector element into the real part of the matching
-- diagonal slot, as enumerated by 'diagonalPointerPairs'. For complex
-- element types, the imaginary parts therefore remain zero.
diagonalAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Order -> Vector sh ar -> Hermitian sh a
diagonalAux order (Array sh x) =
   Array.unsafeCreateWithSize (Layout.hermitian order sh) $ \size matPtr -> do
      fill zero size matPtr
      withForeignPtr x $ \vecPtr ->
         mapM_
            (\(fromPtr, toPtr) -> peek fromPtr >>= poke (realPtr toPtr))
            (diagonalPointerPairs order (Shape.size sh) vecPtr matPtr)
-- | Extract the diagonal of a Hermitian matrix as a real-valued vector.
--
-- Only the real part of each diagonal entry is read (see
-- 'takeDiagonalAux'). The worker is repeated in all four
-- 'Class.switchFloating' arms so that @'RealOf' a@ is concrete in each.
takeDiagonal ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Vector sh (RealOf a)
takeDiagonal herm =
   runTakeDiagonal
      (Class.switchFloating
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux))
      herm
-- | Worker for 'takeDiagonal' at a concrete element type.
--
-- Only the storage order and the shape are taken from the matrix
-- layout. For each diagonal position yielded by 'diagonalPointerPairs',
-- the real part of the matrix entry is copied into the result vector.
takeDiagonalAux ::
   (Shape.C sh, Storable a, RealOf a ~ ar, Storable ar) =>
   Hermitian sh a -> Vector sh ar
takeDiagonalAux (Array (Layout.Mosaic _ _ _ order sh) a) =
   Array.unsafeCreateWithSize sh $ \n vecPtr ->
      withForeignPtr a $ \matPtr ->
         mapM_
            (\(toPtr, fromPtr) -> peek (realPtr fromPtr) >>= poke toPtr)
            (diagonalPointerPairs order n vecPtr matPtr)
-- | Shared driver for the rank-k update functions below.
--
-- First it allocates the scalar BLAS arguments:
--
-- * the @uplo@ character derived from @order@,
-- * the dimension @n = Shape.size sh@,
-- * a unit increment,
-- * the buffer length @triSize@.
--
-- Then, in this order, it:
--
-- 1. zero-fills the @triSize@ elements at @aPtr@;
-- 2. runs @act@;
-- 3. conjugates the whole buffer when 'conjugatedOnRowMajor' requests it
--    for @order@;
-- 4. applies 'complement' with 'Conjugated' for the given packing.
--
-- @act@ receives the @uplo@ pointer, @n@ as an 'Int', a pointer to @n@,
-- and a pointer to the increment @1@.
withConjBuffer ::
   (Shape.C sh, Class.Floating a) =>
   Layout.PackingSingleton pack -> Order -> sh -> Int -> Ptr a ->
   (Ptr CChar -> Int -> Ptr CInt -> Ptr CInt -> IO ()) -> ContT r IO ()
withConjBuffer pack order sh triSize aPtr act = do
   let dim = Shape.size sh
   uploArg <- Call.char (uploFromOrder order)
   dimArg <- Call.cint dim
   incArg <- Call.cint 1
   lenArg <- Call.cint triSize
   liftIO $ do
      fill zero triSize aPtr
      act uploArg dim dimArg incArg
      -- The post-processing order matters: conjugate the freshly
      -- accumulated buffer first, then let 'complement' finish it.
      condConjugate (conjugatedOnRowMajor order) lenArg aPtr incArg
      complement pack Conjugated order dim aPtr
-- | Sum of scaled Hermitian rank-1 updates, starting from the zero
-- matrix.
--
-- Each pair @(alpha, x)@ contributes one BLAS rank-1 update with the
-- real factor @alpha@ (see 'sumRank1Aux'). Every vector must have shape
-- @sh@. Dispatch over the four 'Class.Floating' instances gives the
-- worker a concrete @'RealOf' a@.
sumRank1 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(RealOf a, Vector sh a)] -> HermitianP pack sh a
sumRank1 order sh xs =
   getSumRank1
      (Class.switchFloating
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux))
      order sh xs
-- | Signature shared by 'sumRank1' and its per-type worker 'sumRank1Aux'.
-- @ar@ is the type of the scaling factors.
type SumRank1_ pack sh ar a =
   Order -> sh -> [(ar, Vector sh a)] -> HermitianP pack sh a
-- | Wrapper that lets 'Class.switchFloating' pick a 'SumRank1_'
-- implementation per element type, with @ar@ fixed to @'RealOf' a@.
newtype SumRank1 pack sh a =
   SumRank1 {getSumRank1 :: SumRank1_ pack sh (RealOf a) a}
-- | Worker for 'sumRank1' at a concrete element type.
--
-- The packing is chosen automatically from the result type. The buffer
-- starts zeroed and is post-processed by 'withConjBuffer'.
--
-- For each @(alpha, x)@ the function:
--
-- * asserts that @x@ has shape @sh@,
-- * stores @alpha@ in a scratch cell,
-- * applies one BLAS rank-1 update.
--
-- The update is @spr@/@syr@ for real element types and @hpr@/@her@ for
-- complex ones. 'applyFuncPair' under 'withPacking' presumably picks the
-- packed or full-storage routine of each pair.
sumRank1Aux ::
   (Layout.Packing pack, Shape.C sh, Eq sh,
    Class.Floating a, RealOf a ~ ar, Storable ar) =>
   SumRank1_ pack sh ar a
sumRank1Aux order sh xs =
   let pack = Layout.autoPacking
   in Array.unsafeCreateWithSize (Layout.hermitianP pack order sh) $
         \size bufPtr -> evalContT $ do
            alphaPtr <- Call.alloca
            withConjBuffer pack order sh size bufPtr $
               \uploPtr n nPtr incPtr ->
                  forM_ xs $ \(alpha, Array shX x) -> do
                     Call.assert
                        "Hermitian.sumRank1: non-matching vector size" (sh==shX)
                     poke alphaPtr alpha
                     withForeignPtr x $ \xPtr ->
                        evalContT $ withPacking pack $
                           case Scalar.complexSingletonOfFunctor bufPtr of
                              Scalar.Real ->
                                 applyFuncPair
                                    (noLabel BlasReal.spr)
                                    (noLabel BlasReal.syr)
                                    uploPtr nPtr alphaPtr xPtr incPtr
                                    (triArg bufPtr n)
                              Scalar.Complex ->
                                 applyFuncPair
                                    (noLabel BlasComplex.hpr)
                                    (noLabel BlasComplex.her)
                                    uploPtr nPtr alphaPtr xPtr incPtr
                                    (triArg bufPtr n)
-- | Sum of Hermitian rank-2 updates, starting from the zero matrix.
--
-- For each @(alpha, (x, y))@ the function:
--
-- * asserts that both vectors have shape @sh@,
-- * stores @alpha@ in a scratch cell,
-- * applies the generic BLAS @hpr2@/@her2@ pair via 'applyFuncPair'.
--
-- For complex types this computes
-- @A := alpha*x*y^H + conj(alpha)*y*x^H + A@. 'withConjBuffer' handles
-- zero-initialisation and the final conjugation/completion of the
-- buffer.
sumRank2 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, (Vector sh a, Vector sh a))] -> HermitianP pack sh a
sumRank2 order sh xys =
   let pack = Layout.autoPacking
   in Array.unsafeCreateWithSize (Layout.hermitianP pack order sh) $
         \size bufPtr -> evalContT $ do
            alphaPtr <- Call.alloca
            withConjBuffer pack order sh size bufPtr $
               \uploPtr n nPtr incPtr ->
                  forM_ xys $ \(alpha, (Array shX x, Array shY y)) -> do
                     Call.assert
                        "Hermitian.sumRank2: non-matching x vector size" (sh==shX)
                     Call.assert
                        "Hermitian.sumRank2: non-matching y vector size" (sh==shY)
                     poke alphaPtr alpha
                     withForeignPtr x $ \xPtr ->
                        withForeignPtr y $ \yPtr ->
                           evalContT $ withPacking pack $
                              applyFuncPair
                                 (noLabel BlasGen.hpr2) (noLabel BlasGen.her2)
                                 uploPtr nPtr alphaPtr
                                 xPtr incPtr yPtr incPtr
                                 (triArg bufPtr n)
-- | Copy runs of elements from @fullPtr@ to @packedPtr@ with BLAS @copy@.
-- The runs are enumerated by 'columnMajorPointers' or 'rowMajorPointers'
-- according to @order@.
--
-- The helper names suggest this packs one triangle of a full @n x n@
-- matrix; confirm against those helpers.
--
-- This function is not exported, and the leading underscore suppresses
-- GHC's unused-binding warning.
_pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_pack order n fullPtr packedPtr =
   evalContT $ do
      incPtr <- Call.cint 1
      let copyRun cntPtr src dst =
            BlasGen.copy cntPtr src incPtr dst incPtr
      liftIO $ case order of
         ColumnMajor ->
            forPointers (columnMajorPointers n fullPtr packedPtr) $
               \cntPtr ((_, src), dst) -> copyRun cntPtr src dst
         RowMajor ->
            forPointers (rowMajorPointers n fullPtr packedPtr) $
               \cntPtr (src, dst) -> copyRun cntPtr src dst