{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.HermitianPositiveDefinite.Linear (
   solve,
   solveDecomposed,
   inverse,
   decompose,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Hermitian.Basic (Hermitian)
import Numeric.LAPACK.Matrix.Hermitian.Private (Determinant(..))
import Numeric.LAPACK.Matrix.Triangular.Basic (Upper, takeDiagonal)
import Numeric.LAPACK.Matrix.Triangular.Private (copyTriangleToTemp)
import Numeric.LAPACK.Matrix.Shape.Private (NonUnit(NonUnit), uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver)
import Numeric.LAPACK.Scalar (RealOf, realPart)
import Numeric.LAPACK.Private (copyBlock, withInfo, rankMsg, definiteMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape (triangleSize)

import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


{- |
Solve a linear system whose coefficient matrix is given in packed
Hermitian storage, using LAPACK's @ppsv@ driver.

The first argument is the coefficient matrix, the second one holds the
right-hand sides; the result has the same shape as the right-hand sides.
Dimension checks and the allocation of the result are delegated to
'solver', which calls back with the order @n@ and the pointers for
@n@, @nrhs@, the solution buffer and its leading dimension.

@ppsv@ is run under @withInfo definiteMsg@ - presumably that message is
reported when LAPACK signals that the matrix is not positive definite.
-}
solve ::
   (Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Hermitian sh a ->
   Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solve (Array (MatrixShape.Hermitian orderA shA) a) =
   solver "Hermitian.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder orderA
      -- LAPACK's ppsv overwrites the packed matrix with its Cholesky factor,
      -- so it must operate on a temporary copy of the triangle.
      -- NOTE(review): 'Conjugated' presumably makes the copy conjugate the
      -- data where the storage order requires it - confirm in
      -- 'copyTriangleToTemp'.
      apPtr <- copyTriangleToTemp Conjugated orderA (triangleSize n) a
      liftIO $
         withInfo definiteMsg "ppsv" $
            LapackGen.ppsv uploPtr nPtr nrhsPtr apPtr xPtr ldxPtr

{- |
Solve a linear system from an already computed triangular factor,
using LAPACK's @pptrs@ routine.

The first argument is an upper triangular factor with non-unit diagonal
in packed storage (the shape that 'decompose' produces), the second one
holds the right-hand sides. Dimension checks and the allocation of the
result are delegated to 'solver'.

@pptrs@ is run under @withInfo rankMsg@.
-}
solveDecomposed ::
   (Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Upper sh a -> Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solveDecomposed (Array (MatrixShape.Triangular NonUnit _uplo orderA shA) a) =
   solver "Hermitian.solveDecomposed" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- The uplo flag is derived from the storage order;
      -- the shape's own triangle tag (_uplo) is not consulted.
      uploPtr <- Call.char $ uploFromOrder orderA
      -- NOTE(review): the factor is passed via a temporary copy, presumably
      -- so that it can be conjugated depending on the storage order -
      -- confirm in 'copyTriangleToTemp'.
      apPtr <- copyTriangleToTemp Conjugated orderA (triangleSize n) a
      liftIO $
         withInfo rankMsg "pptrs" $
            LapackGen.pptrs uploPtr nPtr nrhsPtr apPtr xPtr ldxPtr


{- |
Invert a matrix given in packed Hermitian storage.

The packed data is copied into a freshly created array of the same shape.
On that copy LAPACK's @pptrf@ computes the Cholesky factorization in place
and @pptri@ then turns the factor into the inverse, again in place.
The input array is only read.

@pptrf@ is run under @withInfo definiteMsg@,
@pptri@ under @withInfo rankMsg@.
-}
inverse :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> Hermitian sh a
inverse (Array shape@(MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreateWithSize shape $ \triSize bPtr ->
      evalContT $ do
         uploPtr <- Call.char $ uploFromOrder order
         nPtr <- Call.cint $ Shape.size sh
         aPtr <- ContT $ withForeignPtr a
         liftIO $ do
            -- work on the result buffer, never on the input
            copyBlock triSize aPtr bPtr
            withInfo definiteMsg "pptrf" $ LapackGen.pptrf uploPtr nPtr bPtr
            withInfo rankMsg "pptri" $ LapackGen.pptri uploPtr nPtr bPtr

{- |
Cholesky decomposition of a matrix given in packed Hermitian storage.

The packed data is copied into a freshly created array that is tagged as
upper triangular with non-unit diagonal and keeps the storage order of
the input. LAPACK's @pptrf@ then factorizes that copy in place.
The input array is only read.

@pptrf@ is run under @withInfo definiteMsg@ - presumably that message is
reported when LAPACK signals that the matrix is not positive definite.
-}
decompose :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> Upper sh a
decompose (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreateWithSize
      (MatrixShape.Triangular NonUnit MatrixShape.upper order sh) $
         \triSize bPtr ->
      evalContT $ do
         uploPtr <- Call.char $ uploFromOrder order
         nPtr <- Call.cint $ Shape.size sh
         aPtr <- ContT $ withForeignPtr a
         liftIO $ do
            -- work on the result buffer, never on the input
            copyBlock triSize aPtr bPtr
            withInfo definiteMsg "pptrf" $ LapackGen.pptrf uploPtr nPtr bPtr


{- |
Determinant of a matrix given in packed Hermitian storage,
returned as a value of the real type belonging to the element type.

The computation itself is done by 'determinantAux'.
'Class.switchFloating' selects one instantiation of it for each of
'Float', 'Double', @Complex Float@ and @Complex Double@;
the 'Determinant' wrapper gives the four alternatives a common type.
NOTE(review): the dispatch is presumably needed because the constraints of
'determinantAux' on @RealOf a@ can only be discharged for concrete
element types.
-}
determinant :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> RealOf a
determinant =
   getDeterminant $
   Class.switchFloating
      (Determinant determinantAux) (Determinant determinantAux)
      (Determinant determinantAux) (Determinant determinantAux)

{- |
Worker for 'determinant' with the real result type exposed as @ar@.

Reading the pipeline from right to left:
factorize with 'decompose', extract the diagonal of the factor,
take the real part of every diagonal element, multiply them
and square the product.
-}
determinantAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Hermitian sh a -> ar
determinantAux =
   (^(2::Int)) . product . map realPart .
   Array.toList . takeDiagonal . decompose