{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Symmetric.Basic (
   Symmetric,
   SymmetricP,
   sumRank1,
   congruenceDiagonal, congruenceDiagonalTransposed,
   ) where

import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import Numeric.LAPACK.Matrix.Symmetric.Unified
         (skipCheckCongruence, spr, syr, complement,
          scaledAnticommutator, scaledAnticommutatorTransposed)
import Numeric.LAPACK.Matrix.Mosaic.Private
         (withPacking, noLabel, applyFuncPair, triArg)
import Numeric.LAPACK.Matrix.Layout.Private
         (MirrorSingleton(SimpleMirror), Order, uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero)
import Numeric.LAPACK.Private (fill)

import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (poke)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (forM_)


-- | Symmetric matrix with shape @sh@ stored with 'Layout.Unpacked' packing.
type Symmetric sh = SymmetricP Layout.Unpacked sh
-- | Symmetric matrix with shape @sh@ whose storage packing is selected by the
-- type parameter @pack@; an 'Array' over the 'Layout.SymmetricP' shape.
type SymmetricP pack sh = Array (Layout.SymmetricP pack sh)


-- | Weighted sum of symmetric rank-1 terms, one per @(alpha, x)@ pair.
--
-- The output buffer is first zeroed with 'fill'. Then, for every pair in the
-- list, a BLAS-style symmetric rank-1 update is applied to it:
-- 'spr' for packed storage, 'syr' for unpacked storage. The choice is made by
-- 'applyFuncPair', whose first argument is the packed variant, under
-- 'withPacking'. With an empty list, no update is applied after the zero fill.
--
-- Every vector must have shape @sh@. A mismatch is reported through
-- 'Call.assert' with the message
-- @"Symmetric.sumRank1: non-matching vector size"@.
--
-- NOTE(review): presumably each update is @A := alpha*x*x^T + A@ on one
-- triangle (BLAS ?syr/?spr contract), and 'complement' then fills the other
-- triangle. Confirm against "Numeric.LAPACK.Matrix.Symmetric.Unified".
sumRank1 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, Vector sh a)] -> SymmetricP pack sh a
sumRank1 order sh xs =
   let pack = Layout.autoPacking
   in Array.unsafeCreateWithSize (Layout.symmetricP pack order sh) $
      \triSize aPtr -> do
         let n = Shape.size sh
         evalContT $ do
            -- Scalar arguments are passed to the Fortran routines by pointer.
            uploPtr <- Call.char $ uploFromOrder order
            nPtr <- Call.cint n
            -- One scratch cell for alpha, overwritten for every pair below.
            alphaPtr <- Call.alloca
            incxPtr <- Call.cint 1
            liftIO $ do
               fill zero triSize aPtr
               forM_ xs $ \(alpha, Array shX x) -> do
                  Call.assert "Symmetric.sumRank1: non-matching vector size"
                     (sh==shX)
                  poke alphaPtr alpha
                  evalContT $ do
                     -- Keeps x's buffer alive while its raw pointer is used.
                     xPtr <- ContT $ withForeignPtr x
                     withPacking pack $
                        applyFuncPair (noLabel spr) (noLabel syr)
                           uploPtr nPtr alphaPtr xPtr incxPtr
                           (triArg aPtr n)
         complement pack NonConjugated order n aPtr


-- | Symmetric matrix built from the diagonal @d@ and the general matrix @a@.
--
-- Computed as @scaledAnticommutator SimpleMirror 0.5 a (Basic.scaleRows d a)@,
-- that is, the anticommutator of @a@ and its row-scaled copy, scaled by 0.5.
-- NOTE(review): presumably this equals @a^T * diag d * a@ (a congruence
-- transform). Confirm against 'scaledAnticommutator' in
-- "Numeric.LAPACK.Matrix.Symmetric.Unified".
--
-- As its type shows, 'skipCheckCongruence' together with 'Basic.mapWidth'
-- does two things: it wraps the width shape in @Unchecked@ for the inner
-- computation, and it restores the width in the result. Presumably this
-- avoids redundant width-equality checks, since both operands share the
-- width of @a@.
congruenceDiagonal ::
   (Layout.Packing pack,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix.General height width a -> SymmetricP pack width a
congruenceDiagonal d =
   skipCheckCongruence Basic.mapWidth $ \a ->
      scaledAnticommutator SimpleMirror 0.5 a $ Basic.scaleRows d a

-- | Transposed counterpart of 'congruenceDiagonal': the diagonal @d@ is
-- indexed by the width of @m@, and the result is indexed by its height.
--
-- Computed as
-- @scaledAnticommutatorTransposed SimpleMirror 0.5 a (Basic.scaleColumns d a)@.
-- NOTE(review): presumably this equals @m * diag d * m^T@. Confirm against
-- 'scaledAnticommutatorTransposed' in
-- "Numeric.LAPACK.Matrix.Symmetric.Unified".
--
-- 'skipCheckCongruence' together with 'Basic.mapHeight' does two things: it
-- wraps the height in @Unchecked@ for the inner computation, and it restores
-- the height in the result.
congruenceDiagonalTransposed ::
   (Layout.Packing pack,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix.General height width a -> Vector width a -> SymmetricP pack height a
congruenceDiagonalTransposed m d =
   skipCheckCongruence Basic.mapHeight
      (\a ->
         scaledAnticommutatorTransposed SimpleMirror 0.5 a $
            Basic.scaleColumns d a)
      m