{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Plain.Class (
   check,
   -- for testing
   fromList,
   ) where

import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.BandedHermitian.Basic as BandedHermitian
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Triangular
import qualified Numeric.LAPACK.Matrix.Mosaic.Basic as Mosaic
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Shape.Omni (Omni)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary
import Type.Base.Proxy (Proxy(Proxy))

import Control.Applicative ((<|>))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array, (!))

import qualified Data.Complex as Complex
import Data.Functor.Compose (Compose(Compose, getCompose))
import Data.Tuple.HT (mapPair)
import Data.Maybe.HT (toMaybe)
import Data.Bool.HT (implies)



-- | Validate the element payload of a plain matrix array against the
-- invariants promised by its 'Omni' shape descriptor.
--
-- Yields 'Nothing' if no violation is found, otherwise @'Just' msg@ for the
-- first failed check (checks are chained with '<|>').  Per shape:
--
-- * 'Omni.Full': entries outside the declared strips must be zero and the
--   declared property (see 'checkProperty') must hold.
-- * 'Omni.LowerTriangular' / 'Omni.UpperTriangular': a unit diagonal, if
--   declared, must consist of ones.
-- * 'Omni.Hermitian': the stored diagonal must be real.
-- * 'Omni.Banded': unused cells of the band storage must be zero.
-- * 'Omni.UnitBandedTriangular': as 'Omni.Banded', plus a diagonal of ones.
-- * 'Omni.BandedHermitian': the upper band part is checked like 'Omni.Banded'
--   and its diagonal must be real.
--
-- Any other shape is accepted without inspecting the data.
check ::
   (Omni.Plain pack property lower upper meas vert horiz height width ~ shape,
    Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Omni pack property lower upper meas vert horiz height width ->
   Array shape a -> Maybe String
check omni arr =
   case omni of
      Omni.Full fullShape ->
         -- view the storage as a general matrix with deferred dimensions
         -- so that rows and columns can be addressed by plain Int offsets
         let general =
               Basic.mapHeight Shape.Deferred
                  (Basic.mapWidth Shape.Deferred
                     (Matrix.fromFull (Array.reshape fullShape arr)))
         in checkStrips omni general <|> checkProperty omni general
      Omni.LowerTriangular _ -> checkTriangular (diagTag omni) arr
      Omni.UpperTriangular _ -> checkTriangular (diagTag omni) arr
      Omni.Hermitian _ ->
         assertReal "Hermitian" (Mosaic.takeDiagonal arr)
      Omni.Banded _ -> checkBanded arr
      Omni.UnitBandedTriangular _ ->
         checkBanded arr <|>
         assertOnes "BandedTriangular" (Banded.takeDiagonal arr)
      Omni.BandedHermitian _ ->
         let upperPart = BandedHermitian.takeUpper arr
         in checkBanded upperPart <|>
            assertReal "BandedHermitian" (Banded.takeDiagonal upperPart)
      _ -> Nothing

-- | Build an array of the given 'Omni' shape from a list of elements using
-- the checked 'CheckedArray.fromList', then validate it with 'check'.
--
-- Calls 'error' with the message produced by 'check' if the element data
-- violates an invariant of the shape.  Exported for testing.
fromList ::
   (Omni.ToPlain pack property lower upper meas vert horiz height width,
    Omni pack property lower upper meas vert horiz height width ~ shape,
    Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   shape -> [a] -> Array shape a
fromList omni xs =
   let arr = CheckedArray.fromList omni xs
   in case check omni (Array.mapShape Omni.toPlain arr) of
         Nothing -> arr
         Just msg -> error msg


-- | @assert msg ok@ is @'Just' msg@ when @ok@ is 'False' and 'Nothing'
-- otherwise, i.e. a failed condition turns into its error message.
assert :: msg -> Bool -> Maybe msg
assert msg ok = toMaybe (not ok) msg

-- | Keep only the type argument of a witness as a 'Proxy'.  Used to feed the
-- value carried by 'Omni.StripBands' to 'Unary.integralFromProxy'.
proxyFromBands :: f n -> Proxy n
proxyFromBands = const Proxy

-- | Given a strip descriptor, decide whether a signed offset from the main
-- diagonal lies outside of it: for @'Omni.StripBands' k@ that means the
-- offset exceeds the band count encoded in @k@, while a
-- 'Omni.StripFilled' strip contains every offset.
outsideBands :: Omni.StripSingleton offDiag -> Int -> Bool
outsideBands Omni.StripFilled = const False
outsideBands (Omni.StripBands k) =
   (Unary.integralFromProxy (proxyFromBands k) <)

-- | Check that an unpacked matrix has only zero entries outside of the lower
-- and upper strips declared by the 'Omni' shape.
--
-- For an entry at row @r@ and column @c@, @r-c@ is tested against the lower
-- strip and @c-r@ against the upper strip.
checkStrips ::
   (Omni.Strip lower, Omni.Strip upper) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Omni pack property lower upper meas vert horiz height width ->
   Matrix.General (Shape.Deferred height) (Shape.Deferred width) a ->
   Maybe String
checkStrips omni arr =
   assert "non-zero elements outside of declared bands" $
   and
      [ Scalar.isZero (arr ! ix)
      | ix@(Shape.DeferredIndex r, Shape.DeferredIndex c) <-
            Shape.indices (Array.shape arr)
      , belowBands (r-c) || aboveBands (c-r)
      ]
   where
      (belowBands, aboveBands) =
         mapPair (outsideBands, outsideBands) (Omni.strips omni)

-- | Check the element property declared by the 'Omni' shape on an unpacked
-- matrix:
--
-- * 'Omni.PropArbitrary': nothing to check.
-- * 'Omni.PropUnit': every diagonal entry equals one.
-- * 'Omni.PropSymmetric': @a[r,c] == a[c,r]@ wherever @(c,r)@ is in bounds.
-- * 'Omni.PropHermitian': @a[r,c] == conjugate a[c,r]@ wherever @(c,r)@ is
--   in bounds (hence the diagonal must be real).
checkProperty ::
   (Omni.Property property) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Omni pack property lower upper meas vert horiz height width ->
   Matrix.General (Shape.Deferred height) (Shape.Deferred width) a ->
   Maybe String
checkProperty omni arr =
   case Omni.property omni of
      Omni.PropArbitrary -> Nothing
      Omni.PropUnit ->
         assert "non-unit diagonal in unpacked matrix" $
         and
            [ Scalar.equal Scalar.one (arr ! ix)
            | ix <-
                  zip
                     (Shape.indices (Matrix.height arr))
                     (Shape.indices (Matrix.width arr))
            ]
      Omni.PropSymmetric ->
         assert "symmetry violated in unpacked matrix" $
         mirrorsWith id
      Omni.PropHermitian ->
         assert "conjugated symmetry violated in unpacked matrix" $
         mirrorsWith Scalar.conjugate
   where
      shape = Array.shape arr
      -- swap row and column offsets; the phantom index types are re-chosen
      transposed (Shape.DeferredIndex r, Shape.DeferredIndex c) =
         (Shape.DeferredIndex c, Shape.DeferredIndex r)
      -- compare each entry with its (modified) mirror image; 'implies' is
      -- lazy in its second argument, so the unchecked (!) is never
      -- evaluated on an out-of-bounds transposed index
      mirrorsWith modify =
         all
            (\ix ->
               let tix = transposed ix
               in Shape.inBounds shape tix `implies`
                     Scalar.equal (arr ! ix) (modify (arr ! tix)))
            (Shape.indices shape)


-- | Obtain the singleton for the diagonal kind of a triangular shape; the
-- shape value itself is not inspected, only its @diag@ type.
diagTag ::
   (Omni.TriDiag diag) =>
   MatrixShape.Triangular lo diag up sh -> Omni.DiagSingleton diag
diagTag = const Omni.autoDiag

-- | Check a packed triangular matrix: with a unit diagonal every diagonal
-- element must be one; an arbitrary diagonal is accepted unchecked.
checkTriangular ::
   (Layout.UpLo uplo, Omni.TriDiag diag,
    Shape.C size, Class.Floating a) =>
   Omni.DiagSingleton diag ->
   Triangular.Triangular uplo size a -> Maybe String
checkTriangular Omni.Arbitrary _ = Nothing
checkTriangular Omni.Unit tri =
   assertOnes "Triangular" (Mosaic.takeDiagonal tri)

-- | Check that every cell of the band storage whose index is not an
-- 'Layout.InsideBox' index holds zero (presumably the unused padding of the
-- band layout).
--
-- The extent is converted to the general one and both dimensions are
-- wrapped in 'Shape.Deferred' before enumerating the indices.
checkBanded ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width,  Class.Floating a) =>
   Banded.Banded sub super meas vert horiz height width a -> Maybe String
checkBanded banded =
   assert "Banded with non-zero unused elements" $
   and
      [ Scalar.isZero (arr ! ix)
      | ix <- Shape.indices (Array.shape arr)
      , isPadding ix
      ]
   where
      arr =
         Banded.mapHeight Shape.Deferred
            (Banded.mapWidth Shape.Deferred
               (Banded.mapExtent ExtentPriv.toGeneral banded))
      isPadding (Layout.InsideBox _ _) = False
      isPadding _ = True

-- | Report @name ++ \" with non-real diagonal\"@ unless every element of the
-- given (diagonal) vector has zero imaginary part.
assertReal ::
   (Shape.C sh, Class.Floating a) => String -> Vector sh a -> Maybe String
assertReal name diag =
   assert (name ++ " with non-real diagonal") (isReal diag)

-- | 'True' if no element of the vector has a non-zero imaginary part.
--
-- 'Class.switchFloating' selects an implementation by element type: the two
-- real cases are trivially 'True', the two complex cases use
-- 'isComplexReal'.  'Compose' and 'Flip' wrap the predicate
-- @Vector sh a -> Bool@ so that it is indexed by the element type @a@.
isReal :: (Shape.C sh, Class.Floating a) => Vector sh a -> Bool
isReal =
   getFlip
      (getCompose
         (Class.switchFloating
            (Compose (Flip (const True)))
            (Compose (Flip (const True)))
            (Compose (Flip isComplexReal))
            (Compose (Flip isComplexReal))))

-- | 'True' if the imaginary part of every element of a complex vector is
-- zero.
isComplexReal ::
   (Shape.C sh, Class.Real a) => Vector sh (Complex.Complex a) -> Bool
isComplexReal vec =
   all Scalar.isZero (Vector.toList (Vector.imaginaryPart vec))

-- | Report @name ++ \".Unit with non-unit diagonal elements\"@ unless every
-- element of the given (diagonal) vector equals one.
assertOnes ::
   (Shape.C sh, Class.Floating a) => String -> Vector sh a -> Maybe String
assertOnes name diag =
   assert (name ++ ".Unit with non-unit diagonal elements")
      (all (Scalar.equal Scalar.one) (Vector.toList diag))