{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Symmetric.Basic (
   Symmetric,

   gramian,              gramianTransposed,
   congruenceDiagonal,   congruenceDiagonalTransposed,
   congruence,           congruenceTransposed,
   scaledAnticommutator, scaledAnticommutatorTransposed,
   ) where

import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Matrix.Triangular.Private (pack, recheck)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), NonUnit(NonUnit))
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed, Transposed))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Symmetric matrix over the square shape @size@: an 'Array' whose shape is
-- 'MatrixShape.FlexSymmetric' with a 'NonUnit' diagonal.  The element type
-- is supplied as the remaining type argument.
type Symmetric size = Array (MatrixShape.FlexSymmetric NonUnit size)


-- | Gramian @A^T·A@ of a general matrix @A@.  The result is a symmetric
-- matrix over the width of @A@, stored in the same order as @A@.
--
-- cf. Hermitian.Basic
gramian ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Matrix.General height width a -> Symmetric width a
gramian (Array (MatrixShape.Full order extent) a) =
   let shape = symmetricShape order (Extent.width extent)
       params = gramianParameters order extent
   in  Array.unsafeCreate shape (\bPtr -> gramianIO order a bPtr params)

-- | BLAS arguments for forming @A^T·A@ from a general matrix @A@ with the
-- given storage order and extent.
--
-- Returns @((n, k), (uplo, trans, lda))@, where @n@ is the width and @k@ the
-- height of @A@.  With 'ColumnMajor' storage BLAS sees @A@ itself (@k×n@,
-- leading dimension @k@) and transposition flag T produces @A^T·A@.  With
-- 'RowMajor' storage BLAS sees @A^T@ (@n×k@, leading dimension @n@), so flag
-- N yields the same product.  @uplo@ selects which triangle BLAS fills.
gramianParameters ::
   (Extent.C horiz, Extent.C vert, Shape.C height, Shape.C width) =>
   Order ->
   Extent.Extent vert horiz height width ->
   ((Int, Int), (Char, Char, Int))
gramianParameters order extent = ((n, k), blasFlags order)
   where
      (rows, cols) = Extent.dimensions extent
      n = Shape.size cols
      k = Shape.size rows
      blasFlags ColumnMajor = ('U', 'T', k)
      blasFlags RowMajor = ('L', 'N', n)


-- | Gramian @A·A^T@ of a general matrix @A@.  The result is a symmetric
-- matrix over the height of @A@, stored in the same order as @A@.
gramianTransposed ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Matrix.General height width a -> Symmetric height a
gramianTransposed (Array (MatrixShape.Full order extent) a) =
   let shape = symmetricShape order (Extent.height extent)
       params = gramianTransposedParameters order extent
   in  Array.unsafeCreate shape (\bPtr -> gramianIO order a bPtr params)

-- | BLAS arguments for forming @A·A^T@ from a general matrix @A@ with the
-- given storage order and extent.
--
-- Returns @((n, k), (uplo, trans, lda))@, where @n@ is the height and @k@
-- the width of @A@.  With 'ColumnMajor' storage BLAS sees @A@ itself
-- (@n×k@, leading dimension @n@) and flag N gives @A·A^T@.  With 'RowMajor'
-- storage BLAS sees @A^T@ (@k×n@, leading dimension @k@), and flag T
-- yields the same product.
gramianTransposedParameters ::
   (Extent.C horiz, Extent.C vert, Shape.C height, Shape.C width) =>
   Order ->
   Extent.Extent vert horiz height width ->
   ((Int, Int), (Char, Char, Int))
gramianTransposedParameters order extent = ((n, k), blasFlags order)
   where
      (rows, cols) = Extent.dimensions extent
      n = Shape.size rows
      k = Shape.size cols
      blasFlags ColumnMajor = ('U', 'N', n)
      blasFlags RowMajor = ('L', 'T', k)

-- | Run BLAS @syrk@ with @alpha = 1@ and @beta = 0@ on the matrix @a@.
-- The result goes into a temporary @n×n@ buffer (leading dimension @n@).
-- The triangle is then 'pack'ed into the destination @bPtr@ according to
-- @order@.
--
-- The parameter tuple is the one produced by 'gramianParameters' or
-- 'gramianTransposedParameters'.
gramianIO ::
   (Class.Floating a) =>
   Order ->
   ForeignPtr a -> Ptr a ->
   ((Int, Int), (Char, Char, Int)) -> IO ()
gramianIO order a bPtr ((n,k), (uplo,trans,lda)) = evalContT $ do
   pUplo <- Call.char uplo
   pTrans <- Call.char trans
   pN <- Call.cint n
   pK <- Call.cint k
   pAlpha <- Call.number one
   pA <- ContT (withForeignPtr a)
   pLda <- Call.leadingDim lda
   pBeta <- Call.number zero
   pFull <- Call.allocaArray (n*n)
   pLdc <- Call.leadingDim n
   liftIO $
      BlasGen.syrk pUplo pTrans pN pK pAlpha pA pLda pBeta pFull pLdc
      >> pack order n pFull bPtr

-- | Wrap one dimension of the input in 'Unchecked' using @mapSize@.  Then run
-- @f@ on the wrapped matrix and strip the wrapper from the symmetric result
-- with 'recheck'.
--
-- NOTE(review): presumably this lets @f@ skip shape-equality checks on the
-- wrapped dimension; confirm against the 'Unchecked' instances.
skipCheckCongruence ::
   ((sh -> Unchecked sh) -> matrix0 -> matrix1) ->
   (matrix1 -> Symmetric (Unchecked sh) a) -> matrix0 -> Symmetric sh a
skipCheckCongruence mapSize f = recheck . f . mapSize Unchecked


-- | Congruence @A^T·D·A@ with the diagonal matrix @D@ given by the vector @d@.
--
-- Computed as @(1/2)·(A^T·(D·A) + (D·A)^T·A)@, using
-- 'scaledAnticommutator' on @A@ and the row-scaled @D·A@.
congruenceDiagonal ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix.General height width a -> Symmetric width a
congruenceDiagonal d = skipCheckCongruence Basic.mapWidth halfAnticommute
   where
      halfAnticommute a = scaledAnticommutator 0.5 a (Basic.scaleRows d a)

-- | Congruence @A·D·A^T@ with the diagonal matrix @D@ given by the vector
-- @d@.
--
-- Computed as @(1/2)·(A·(A·D)^T + (A·D)·A^T)@, using
-- 'scaledAnticommutatorTransposed' on @A@ and the column-scaled @A·D@.
congruenceDiagonalTransposed ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix.General height width a -> Vector width a -> Symmetric height a
congruenceDiagonalTransposed a0 d =
   skipCheckCongruence Basic.mapHeight halfAnticommute a0
   where
      halfAnticommute a =
         scaledAnticommutatorTransposed 0.5 a (Basic.scaleColumns d a)


-- | Congruence transform of the symmetric matrix @B@ by @A@, presumably
-- @A^T·B·A@.
--
-- @B@ is split with 'Split.takeHalf' into a triangular part @H@.  The result
-- is @A^T·(H·A) + (H·A)^T·A@.  NOTE(review): this equals @A^T·B·A@ provided
-- @H + H^T = B@; confirm against 'Split.takeHalf'.
congruence ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Symmetric height a -> Matrix.General height width a -> Symmetric width a
congruence b = skipCheckCongruence Basic.mapWidth anticommuteHalf
   where
      halfB = Split.takeHalf MatrixShape.triangularOrder b
      anticommuteHalf a =
         scaledAnticommutator one a
            (Split.tallMultiplyR NonTransposed halfB a)

-- | Transposed congruence transform of the symmetric matrix @B@ by @A@,
-- presumably @A·B·A^T@.
--
-- @B@ is split with 'Split.takeHalf' into a triangular part.  Its product
-- with @A@, formed through 'Basic.swapMultiply', is anticommuted with @A@
-- via 'scaledAnticommutatorTransposed'.
congruenceTransposed ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix.General height width a -> Symmetric width a -> Symmetric height a
congruenceTransposed a0 b =
   skipCheckCongruence Basic.mapHeight anticommuteHalf a0
   where
      halfB = Split.takeHalf MatrixShape.triangularOrder b
      anticommuteHalf a =
         scaledAnticommutatorTransposed one a
            (Basic.swapMultiply (Split.tallMultiplyR Transposed) a halfB)


-- | @alpha·(A^T·B + B^T·A)@ for equally shaped matrices @A@ (first matrix
-- argument) and @B@ (second).  The result is a symmetric matrix over their
-- common width, stored in the order of @B@.
--
-- @A@ is first brought into @B@'s storage order with 'Basic.forceOrder'.
-- This lets both operands share one leading dimension in the BLAS @syr2k@
-- call.  Fails via 'Call.assert' when the extents of @A@ and @B@ differ.
scaledAnticommutator ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   a ->
   Full vert horiz height width a ->
   Full vert horiz height width a -> Symmetric width a
scaledAnticommutator alpha arr (Array (MatrixShape.Full order extentB) b) =
   -- Pure code: a plain let/in instead of a redundant do-block.
   let Array (MatrixShape.Full _ extentA) a = Basic.forceOrder order arr
   in Array.unsafeCreate (symmetricShape order $ Extent.width extentB) $
      \cpPtr -> do
         Call.assert "Symmetric.anticommutator: extents mismatch"
            (extentA==extentB)
         scaledAnticommutatorIO alpha order a b cpPtr $
            gramianParameters order extentB

-- | @alpha·(A·B^T + B·A^T)@ for equally shaped matrices @A@ (first matrix
-- argument) and @B@ (second).  The result is a symmetric matrix over their
-- common height, stored in the order of @B@.
--
-- @A@ is first brought into @B@'s storage order with 'Basic.forceOrder'.
-- This lets both operands share one leading dimension in the BLAS @syr2k@
-- call.  Fails via 'Call.assert' when the extents of @A@ and @B@ differ.
scaledAnticommutatorTransposed ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   a ->
   Full vert horiz height width a ->
   Full vert horiz height width a -> Symmetric height a
scaledAnticommutatorTransposed
      alpha arr (Array (MatrixShape.Full order extentB) b) =
   -- Pure code: a plain let/in instead of a redundant do-block.
   let Array (MatrixShape.Full _ extentA) a = Basic.forceOrder order arr
   in Array.unsafeCreate (symmetricShape order $ Extent.height extentB) $
      \cpPtr -> do
         Call.assert "Symmetric.anticommutatorTransposed: extents mismatch"
            (extentA==extentB)
         scaledAnticommutatorIO alpha order a b cpPtr $
            gramianTransposedParameters order extentB

-- | Run BLAS @syr2k@ with scaling factor @alpha@ and @beta = 0@ on the
-- matrices @a@ and @b@.  The result goes into a temporary @n×n@ buffer
-- (leading dimension @n@).  The triangle is then 'pack'ed into @cpPtr@
-- according to @order@.
--
-- The same leading dimension is passed for @a@ and @b@.  Callers in this
-- module bring both operands to the same storage order and check that their
-- extents match before calling this.
scaledAnticommutatorIO ::
   (Class.Floating a) =>
   a ->
   Order -> ForeignPtr a -> ForeignPtr a -> Ptr a ->
   ((Int, Int), (Char, Char, Int)) -> IO ()
scaledAnticommutatorIO alpha order a b cpPtr ((n,k), (uplo,trans,lda)) =
   evalContT $ do
      pUplo <- Call.char uplo
      pTrans <- Call.char trans
      pN <- Call.cint n
      pK <- Call.cint k
      pAlpha <- Call.number alpha
      pA <- ContT (withForeignPtr a)
      pLd <- Call.leadingDim lda
      pB <- ContT (withForeignPtr b)
      pBeta <- Call.number zero
      pFull <- Call.allocaArray (n*n)
      pLdc <- Call.leadingDim n
      liftIO $
         BlasGen.syr2k pUplo pTrans pN pK pAlpha
            pA pLd pB pLd pBeta pFull pLdc
         >> pack order n pFull cpPtr


-- | Shape of a symmetric matrix with the given storage order and size.  It
-- is built as a 'MatrixShape.Triangular' shape with a 'NonUnit' diagonal,
-- with the triangle chosen by 'MatrixShape.autoUplo'.
symmetricShape :: Order -> size -> MatrixShape.Symmetric size
symmetricShape order sh =
   MatrixShape.Triangular NonUnit MatrixShape.autoUplo order sh