{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Singular (
   values,
   valuesTall,
   valuesWide,
   decompose,
   decomposeTall,
   decomposeWide,
   determinantAbsolute,
   leastSquaresMinimumNormRCond,
   pseudoInverseRCond,
   decomposePolar,
   RealOf,
   ) where

import qualified Numeric.LAPACK.Singular.Plain as Plain

import qualified Numeric.LAPACK.Matrix.Hermitian as Hermitian
import qualified Numeric.LAPACK.Matrix.Mosaic.Private as Mos
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix as Matrix
import Numeric.LAPACK.Matrix.Array.Banded (RectangularDiagonal)
import Numeric.LAPACK.Matrix.Array.Private (ArrayMatrix, Full, General, Square)
import Numeric.LAPACK.Matrix.Multiply ((##*#), (#*##))
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape

import Data.Tuple.HT (mapFst, mapSnd, mapPair, mapSnd3, mapTriple)


-- | Vector over shape @sh@ whose elements have type @RealOf a@.
-- Used in this module for the singular values returned by
-- 'valuesTall', 'valuesWide', 'decomposeTall' and 'decomposeWide'.
type RealVector sh a = Vector sh (RealOf a)

{- |
Singular values of an arbitrary full matrix.
They are returned as a rectangular diagonal matrix that carries the same
measure, extents and shapes as the input.
-}
values ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   RectangularDiagonal meas vert horiz height width (RealOf a)
values matrix = ArrMatrix.lift1 Plain.values matrix

{- |
Singular values of a matrix whose horizontal extent is 'Extent.Small'.
They are returned as a plain vector indexed by the width.
-}
valuesTall ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert Extent.Small height width a -> RealVector width a
valuesTall matrix = Plain.valuesTall $ ArrMatrix.toVector matrix

{- |
Singular values of a matrix whose vertical extent is 'Extent.Small'.
They are returned as a plain vector indexed by the height.
-}
valuesWide ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas Extent.Small horiz height width a -> RealVector height a
valuesWide matrix = Plain.valuesWide $ ArrMatrix.toVector matrix


{- |
Absolute value of the determinant.
The underlying array is passed to 'Plain.determinantAbsolute'.

NOTE(review): the signature accepts any 'General' matrix, not only square
ones. See 'Plain.determinantAbsolute' for how the non-square case is defined.
-}
determinantAbsolute ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> RealOf a
determinantAbsolute matrix =
   Plain.determinantAbsolute (ArrMatrix.toVector matrix)


{- |
Full singular value decomposition of an arbitrary full matrix.

The result triple contains:

* a square matrix over the height,
* a rectangular diagonal matrix of real values with the same measure,
  extents and shapes as the input,
* a square matrix over the width.
-}
decompose ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   (Square height a,
    RectangularDiagonal meas vert horiz height width (RealOf a),
    Square width a)
decompose a =
   -- 'liftDecompose' passes the middle component through unchanged,
   -- so it is wrapped here.
   mapSnd3 ArrMatrix.lift0 $ liftDecompose Plain.decompose a

{- |
Reduced singular value decomposition of a matrix whose vertical extent is
'Extent.Small'.

The result triple contains:

* a square matrix @u@ over the height,
* the real values @s@, indexed by the height,
* a matrix @vt@ with the same measure, extents and shapes as the input.

These satisfy

> let (u,s,vt) = Singular.decomposeWide a
> in a  ==  u #*## Matrix.scaleRowsReal s vt
-}
decomposeWide ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas Extent.Small horiz height width a ->
   (Square height a, RealVector height a,
      Full meas Extent.Small horiz height width a)
decomposeWide a = liftDecompose Plain.decomposeWide a

{- |
Reduced singular value decomposition of a matrix whose horizontal extent is
'Extent.Small'.

The result triple contains:

* a matrix @u@ with the same measure, extents and shapes as the input,
* the real values @s@, indexed by the width,
* a square matrix @vt@ over the width.

These satisfy

> let (u,s,vt) = Singular.decomposeTall a
> in a  ==  u ##*# Matrix.scaleRowsReal s vt
-}
decomposeTall ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert Extent.Small height width a ->
   (Full meas vert Extent.Small height width a,
      RealVector width a, Square width a)
decomposeTall a = liftDecompose Plain.decomposeTall a


-- | Plain array representation of a full matrix
-- (unpacked, arbitrary properties, filled below and above the diagonal).
-- The element type is deliberately left off, so it is supplied at the use
-- site, as in the signature of 'liftDecompose'.
type FullArray meas vert horiz height width =
         ArrMatrix.PlainArray Layout.Unpacked Omni.Arbitrary
            Layout.Filled Layout.Filled
            meas vert horiz height width
-- | Matrix wrapper with the same layout parameters as 'FullArray'.
-- The element type is likewise left off.
type FullMatrix meas vert horiz height width =
         ArrayMatrix Layout.Unpacked Omni.Arbitrary
            Layout.Filled Layout.Filled
            meas vert horiz height width

{- |
Turn a three-way decomposition on plain arrays into one on matrices.

The input matrix is unwrapped with 'ArrMatrix.toVector'. The first and third
components of the resulting triple are wrapped with 'ArrMatrix.lift0'. The
middle component is passed through unchanged, so each caller can convert it
as needed (e.g. 'decompose' wraps it, 'decomposeTall' keeps the vector).
-}
liftDecompose ::
   (FullArray measA vertA horizA heightA widthA a ->
    (FullArray measB vertB horizB heightB widthB b, f,
     FullArray measC vertC horizC heightC widthC c)) ->
   FullMatrix measA vertA horizA heightA widthA a ->
    (FullMatrix measB vertB horizB heightB widthB b, f,
     FullMatrix measC vertC horizC heightC widthC c)
liftDecompose decomposeArray matrix =
   mapTriple (ArrMatrix.lift0, id, ArrMatrix.lift0) $
   decomposeArray (ArrMatrix.toVector matrix)



{- |
Minimum-norm solution @x@ of @a * x = b@ in the least-squares sense, for all
columns of @b@ at once.
The work is delegated to 'Plain.leastSquaresMinimumNormRCond'.

@rcond@ is passed through unchanged. NOTE(review): presumably it is the
relative threshold below which singular values are treated as zero, and the
returned 'Int' is the resulting effective rank. Confirm in Plain.

Note that @a@ has its vertical and horizontal extent parameters swapped
relative to @b@ and the result.
-}
leastSquaresMinimumNormRCond ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   RealOf a ->
   Full meas horiz vert height width a ->
   Full meas vert horiz height nrhs a ->
   (Int, Full meas vert horiz width nrhs a)
leastSquaresMinimumNormRCond rcond a b =
   mapSnd ArrMatrix.lift0 plainResult
   where
      plainResult =
         Plain.leastSquaresMinimumNormRCond rcond
            (ArrMatrix.toVector a)
            (ArrMatrix.toVector b)

{- |
Pseudo-inverse of @a@, delegated to 'Plain.pseudoInverseRCond'.
In the result, height and width are swapped, and so are the vertical and
horizontal extents.

@rcond@ is passed through unchanged. NOTE(review): presumably it is the
cut-off for small singular values, and the returned 'Int' is the resulting
rank. Confirm in Plain.
-}
pseudoInverseRCond ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   RealOf a ->
   Full meas vert horiz height width a ->
   (Int, Full meas horiz vert width height a)
pseudoInverseRCond rcond a =
   -- The shapes are wrapped with 'Basic.uncheck' before the Plain call and
   -- unwrapped with 'Basic.recheck' afterwards.
   -- NOTE(review): presumably this relaxes shape-equality requirements
   -- inside Plain.
   mapSnd (ArrMatrix.lift0 . Basic.recheck) $
   Plain.pseudoInverseRCond rcond $
   Basic.uncheck $
   ArrMatrix.toVector a



{- |
Polar decomposition.
In @decomposePolar a = (u,h)@,
@u@ has the same measure, extents and shapes as @a@
and is the orthogonal matrix closest to @a@
with respect to the 2- and the Frobenius norm.
@h@ is a Hermitian matrix over the width of @a@.
(Higham: Functions of Matrices - Theory and Computation.)

NOTE(review): for complex or non-square @a@, "orthogonal" presumably means
unitary, resp. having orthonormal columns (tall) or rows (wide). Presumably
@a = u * h@ holds. Confirm against the definition of
'Hermitian.congruenceDiagonal'.
-}
decomposePolar ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   (Full meas vert horiz height width a, Matrix.Hermitian width a)
decomposePolar =
   -- Undo the 'Basic.uncheck' applied to the input (last stage below) on
   -- both results.
   mapPair
      (ArrMatrix.lift1 Basic.recheck,
       ArrMatrix.lift1 Mos.recheck)
   .
   -- Select an implementation by the extent types:
   -- the first three cases get 'decomposePolarWide', whose input has a
   -- 'Small' vertical extent; the fourth gets 'decomposePolarTall', whose
   -- input has a 'Small' horizontal extent; the last one chooses tall or
   -- wide from the value via 'Matrix.caseTallWide' and generalizes the
   -- first result back with 'Matrix.fromFull'.
   getDecomposePolar
      (Extent.switchTagTriple
         (DecomposePolar decomposePolarWide)
         (DecomposePolar decomposePolarWide)
         (DecomposePolar decomposePolarWide)
         (DecomposePolar decomposePolarTall)
         (DecomposePolar $
            either
               (mapFst Matrix.fromFull . decomposePolarTall)
               (mapFst Matrix.fromFull . decomposePolarWide)
            .
            Matrix.caseTallWide))
   .
   -- 'decomposePolarTall'/'decomposePolarWide' require 'Eq' on the shapes,
   -- which this function's constraints do not provide.
   -- NOTE(review): presumably 'Basic.uncheck' wraps the shapes in a type
   -- that supplies those instances. Confirm in Matrix.Basic.
   ArrMatrix.lift1 Basic.uncheck

-- | Wrapper around a polar-decomposition function for one extent type.
-- The extent parameters @meas vert horiz@ come last, presumably so that
-- @DecomposePolar height width a@ can act as the type constructor
-- 'Extent.switchTagTriple' selects among (see 'decomposePolar').
newtype DecomposePolar height width a meas vert horiz =
   DecomposePolar {
      getDecomposePolar ::
         Full meas vert horiz height width a ->
         (Full meas vert horiz height width a, Matrix.Hermitian width a)
   }

{- |
Polar decomposition of a matrix with 'Extent.Small' horizontal extent.
It is built from the reduced SVD @(u,s,vt) = decomposeTall a@.
The first component is @u ##*# vt@. The Hermitian component is
'Hermitian.congruenceDiagonal' applied to @s@ and @vt@.
-}
decomposePolarTall ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Full meas vert Extent.Small height width a ->
   (Full meas vert Extent.Small height width a, Matrix.Hermitian width a)
decomposePolarTall a =
   (u ##*# vt, Hermitian.congruenceDiagonal s (Matrix.fromFull vt))
   where
      (u, s, vt) = decomposeTall a

{- |
Polar decomposition of a matrix with 'Extent.Small' vertical extent.
It is built from the reduced SVD @(u,s,vt) = decomposeWide a@.
The first component is @u #*## vt@. The Hermitian component is
'Hermitian.congruenceDiagonal' applied to @s@ and @vt@.
-}
decomposePolarWide ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Full meas Extent.Small horiz height width a ->
   (Full meas Extent.Small horiz height width a, Matrix.Hermitian width a)
decomposePolarWide a =
   (u #*## vt, Hermitian.congruenceDiagonal s (Matrix.fromFull vt))
   where
      (u, s, vt) = decomposeWide a