{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Hermitian.Basic (
   Hermitian,
   HermitianP,
   Transposition(..),
   diagonal,
   takeDiagonal,

   sumRank1,
   sumRank2,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Hermitian.Private (Diagonal(..), TakeDiagonal(..))
import Numeric.LAPACK.Matrix.Symmetric.Unified (complement)
import Numeric.LAPACK.Matrix.Mosaic.Private
         (forPointers, diagonalPointerPairs,
          rowMajorPointers, columnMajorPointers,
          withPacking, noLabel, applyFuncPair, triArg)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor,ColumnMajor), uploFromOrder)
import Numeric.BLAS.Matrix.Modifier
         (Transposition(NonTransposed, Transposed),
          Conjugation(Conjugated), conjugatedOnRowMajor)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero)
import Numeric.LAPACK.Private (fill, realPtr, condConjugate)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import Control.Monad.Trans.Cont (ContT, evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (forM_)


-- | Hermitian matrix over the shape @sh@, stored with the layout
-- 'Layout.Hermitian'; the element type is the remaining parameter.
type Hermitian sh = Array (Layout.Hermitian sh)
-- | Hermitian matrix whose storage packing is selected by the type
-- parameter @pack@ (layout 'Layout.HermitianP').
type HermitianP pack sh = Array (Layout.HermitianP pack sh)


-- | Build a Hermitian matrix in the given storage order whose diagonal
-- holds the given real values and whose remaining entries are zero.
--
-- The element type is dispatched through 'Class.switchFloating'; every
-- one of its four cases is served by the same worker 'diagonalAux'.
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Vector sh (RealOf a) -> Hermitian sh a
diagonal order =
   runDiagonal
      (Class.switchFloating
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order)))

-- | Worker for 'diagonal'.
--
-- Zeroes every element of the freshly allocated triangle buffer. Then,
-- for each diagonal position, it stores the matching vector entry
-- through 'realPtr' of the destination element.
diagonalAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Order -> Vector sh ar -> Hermitian sh a
diagonalAux order (Array sh x) =
   Array.unsafeCreateWithSize (Layout.hermitian order sh) $ \triSize aPtr -> do
      -- Everything not overwritten below (off-diagonal entries and, for
      -- complex elements, whatever 'realPtr' does not address) stays zero.
      fill zero triSize aPtr
      withForeignPtr x $ \xPtr ->
         mapM_
            (\(srcPtr, dstPtr) -> peek srcPtr >>= poke (realPtr dstPtr))
            (diagonalPointerPairs order (Shape.size sh) xPtr aPtr)


-- | Extract the diagonal of a Hermitian matrix as a real vector. Only
-- the part of each diagonal element addressed by 'realPtr' is read.
--
-- The element type is dispatched through 'Class.switchFloating'; every
-- one of its four cases uses the worker 'takeDiagonalAux'.
takeDiagonal ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Vector sh (RealOf a)
takeDiagonal m =
   runTakeDiagonal
      (Class.switchFloating
         (TakeDiagonal takeDiagonalAux) (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux) (TakeDiagonal takeDiagonalAux))
      m

-- | Worker for 'takeDiagonal'.
--
-- Allocates a vector of the matrix shape @sh@ and fills it from the
-- diagonal elements of the stored triangle, read through 'realPtr'.
-- Packing, mirroring and upper/lower flags of the layout are not
-- consulted here; only the order and the shape are.
takeDiagonalAux ::
   (Shape.C sh, Storable a, RealOf a ~ ar, Storable ar) =>
   Hermitian sh a -> Vector sh ar
takeDiagonalAux (Array (Layout.Mosaic _pack _mirror _upper order sh) a) =
   Array.unsafeCreateWithSize sh $ \n xPtr ->
      withForeignPtr a $ \aPtr ->
         mapM_
            (\(dstPtr, srcPtr) -> peek (realPtr srcPtr) >>= poke dstPtr)
            (diagonalPointerPairs order n xPtr aPtr)


-- | Allocate the scalar BLAS arguments shared by the rank-update
-- routines, then run a sequence of updates on the buffer @aPtr@ of
-- @triSize@ elements.
--
-- The steps, whose order matters, are:
--
-- 1. zero all @triSize@ elements of @aPtr@;
--
-- 2. run @act uploPtr n nPtr incxPtr@, where @uploPtr@ is derived from
--    @order@ by 'uploFromOrder', @n = Shape.size sh@, @nPtr@ holds @n@
--    and @incxPtr@ holds 1;
--
-- 3. conjugate the whole buffer if 'conjugatedOnRowMajor' requests it
--    for @order@;
--
-- 4. call 'complement' with 'Conjugated' on the result.
--
-- NOTE(review): step 3 presumably exists because a row-major triangle
-- viewed as column-major is the transpose, and for a Hermitian matrix
-- the transpose equals the elementwise conjugate. Step 4 presumably
-- fills the mirrored triangle when @pack@ is unpacked; confirm in
-- "Numeric.LAPACK.Matrix.Symmetric.Unified".
withConjBuffer ::
   (Shape.C sh, Class.Floating a) =>
   Layout.PackingSingleton pack -> Order -> sh -> Int -> Ptr a ->
   (Ptr CChar -> Int -> Ptr CInt -> Ptr CInt -> IO ()) -> ContT r IO ()
withConjBuffer pack order sh triSize aPtr act = do
   uploPtr <- Call.char $ uploFromOrder order
   let n = Shape.size sh
   nPtr <- Call.cint n
   incxPtr <- Call.cint 1
   sizePtr <- Call.cint triSize
   liftIO $ do
      -- The BLAS rank updates add into the buffer, so it must start at zero.
      fill zero triSize aPtr
      act uploPtr n nPtr incxPtr
      condConjugate (conjugatedOnRowMajor order) sizePtr aPtr incxPtr
      complement pack Conjugated order n aPtr


{-
This is not easy to generalize to Symmetric matrices,
because LapackComplex.spr and LapackComplex.syr
expect a complex parameter 'alpha',
whereas the weights passed to sumRank1 are real.
-}
-- | Sum of weighted rank-1 products. Each pair @(alpha, x)@ adds
-- @alpha * x * x^H@ to an initially zero matrix; @alpha@ is real.
--
-- The packing of the result is chosen by 'Layout.autoPacking' from the
-- result type. Every vector must have exactly the shape @sh@; otherwise
-- 'sumRank1Aux' fails its assertion.
--
-- The element type is dispatched through 'Class.switchFloating' to the
-- shared worker 'sumRank1Aux'.
sumRank1 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(RealOf a, Vector sh a)] -> HermitianP pack sh a
sumRank1 order sh xs =
   getSumRank1
      (Class.switchFloating
         (SumRank1 sumRank1Aux) (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux) (SumRank1 sumRank1Aux))
      order sh xs

-- | Signature of the 'sumRank1' worker. The weight type @ar@ is a separate
-- parameter from the element type @a@; 'sumRank1Aux' relates the two via
-- the constraint @RealOf a ~ ar@.
type SumRank1_ pack sh ar a =
      Order -> sh -> [(ar, Vector sh a)] -> HermitianP pack sh a

-- | Wrapper that puts the element type @a@ last, so that
-- 'Class.switchFloating' can pick one 'sumRank1Aux' instance per element
-- type; the weights are fixed to @'RealOf' a@.
newtype SumRank1 pack sh a =
      SumRank1 {getSumRank1 :: SumRank1_ pack sh (RealOf a) a}

-- | Worker for 'sumRank1', shared by all element types.
--
-- Starts from the zeroed buffer prepared by 'withConjBuffer'. For each
-- pair @(alpha, x)@ it asserts that @x@ has the shape @sh@, stores
-- @alpha@ in a scratch cell and calls a BLAS rank-1 update:
-- @spr@/@syr@ for real elements, @hpr@/@her@ for complex elements.
--
-- NOTE(review): 'applyFuncPair' receives the packed (@spr@/@hpr@) and the
-- full-storage (@syr@/@her@) routine; presumably 'withPacking' picks one
-- according to @pack@.
sumRank1Aux ::
   (Layout.Packing pack, Shape.C sh, Eq sh,
    Class.Floating a, RealOf a ~ ar, Storable ar) =>
   SumRank1_ pack sh ar a
sumRank1Aux order sh xs =
   let pack = Layout.autoPacking
   in Array.unsafeCreateWithSize (Layout.hermitianP pack order sh) $
      \triSize aPtr ->
   evalContT $ do
      -- One cell for the real weight, reused by every update.
      alphaPtr <- Call.alloca
      withConjBuffer pack order sh triSize aPtr $ \uploPtr n nPtr incxPtr -> do
         forM_ xs $ \(alpha, Array shX x) ->
            withForeignPtr x $ \xPtr -> do
               Call.assert
                  "Hermitian.sumRank1: non-matching vector size" (sh==shX)
               poke alphaPtr alpha
               evalContT $ withPacking pack $
                  -- GADT match: refines the element type to real or
                  -- complex so that the corresponding BLAS routines apply.
                  case Scalar.complexSingletonOfFunctor aPtr of
                     Scalar.Real ->
                        applyFuncPair
                           (noLabel BlasReal.spr) (noLabel BlasReal.syr)
                           uploPtr nPtr alphaPtr xPtr incxPtr (triArg aPtr n)
                     Scalar.Complex ->
                        applyFuncPair
                           (noLabel BlasComplex.hpr) (noLabel BlasComplex.her)
                           uploPtr nPtr alphaPtr xPtr incxPtr (triArg aPtr n)


{-
This is not easy to generalize to Symmetric matrices,
because there are no Complex.spr2 and Complex.syr2 routines.
However, BlasComplex.syr2k does exist.
-}
-- | Sum of weighted Hermitian rank-2 products.
--
-- Each entry @(alpha, (x, y))@ is passed to BLAS @her2@/@hpr2@, which
-- computes @A := alpha*x*y^H + conj(alpha)*y*x^H + A@. The accumulation
-- starts from the zero matrix prepared by 'withConjBuffer'.
--
-- The packing of the result is chosen by 'Layout.autoPacking' from the
-- result type. Both vectors of every entry must have exactly the shape
-- @sh@; otherwise an assertion fails.
sumRank2 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, (Vector sh a, Vector sh a))] -> HermitianP pack sh a
sumRank2 order sh xys =
   Array.unsafeCreateWithSize (Layout.hermitianP pack order sh) $
      \triSize aPtr -> evalContT $ do
         -- One cell for the weight, overwritten before every update.
         alphaPtr <- Call.alloca
         withConjBuffer pack order sh triSize aPtr $
            \uploPtr n nPtr incxPtr ->
               forM_ xys $ \(alpha, (Array shX x, Array shY y)) ->
                  withForeignPtr x $ \xPtr ->
                  withForeignPtr y $ \yPtr -> do
                     Call.assert
                        "Hermitian.sumRank2: non-matching x vector size"
                        (sh==shX)
                     Call.assert
                        "Hermitian.sumRank2: non-matching y vector size"
                        (sh==shY)
                     poke alphaPtr alpha
                     -- Packed (hpr2) or full-storage (her2) update,
                     -- selected via 'withPacking' for @pack@.
                     evalContT $ withPacking pack $
                        applyFuncPair
                           (noLabel BlasGen.hpr2) (noLabel BlasGen.her2)
                           uploPtr nPtr alphaPtr
                           xPtr incxPtr yPtr incxPtr
                           (triArg aPtr n)
   where
      pack = Layout.autoPacking


-- | Copy the runs described by 'rowMajorPointers' / 'columnMajorPointers'
-- for an @n@-dimensional matrix from each pair's source pointer to its
-- destination pointer, using BLAS @copy@ with unit strides.
--
-- Judging by the parameter names, this presumably packs the triangle of
-- the full buffer @fullPtr@ into @packedPtr@. It is not exported, and the
-- leading underscore keeps GHC from warning that it is unused.
_pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_pack order n fullPtr packedPtr =
   evalContT $ do
      incxPtr <- Call.cint 1
      let copyRun nPtr srcPtr dstPtr =
             BlasGen.copy nPtr srcPtr incxPtr dstPtr incxPtr
      liftIO $
         case order of
            RowMajor ->
               forPointers (rowMajorPointers n fullPtr packedPtr) $
                  \nPtr (srcPtr,dstPtr) -> copyRun nPtr srcPtr dstPtr
            ColumnMajor ->
               forPointers (columnMajorPointers n fullPtr packedPtr) $
                  \nPtr ((_,srcPtr),dstPtr) -> copyRun nPtr srcPtr dstPtr