{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.Plain.Class (
   Admissible(check),
   Homogeneous(zero, negate, scaleReal),
   ShapeOrder(forceOrder, shapeOrder), adaptOrder,
   Additive(add, sub),
   Complex(conjugate, fromReal, toComplex),
   SquareShape(toSquare, identityOrder, takeDiagonal),
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.BandedHermitian.Basic as BandedHermitian
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Triangular
import qualified Numeric.LAPACK.Matrix.Hermitian.Basic as Hermitian
import qualified Numeric.LAPACK.Matrix.Square.Basic as Square
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Private (Square)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, ComplexOf)
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary

import Control.Applicative ((<|>))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array, (!))

import qualified Data.Complex as Complex
import Data.Functor.Compose (Compose(Compose, getCompose))
import Data.Maybe.HT (toMaybe)

import Prelude hiding (negate)



{- |
Shapes whose arrays may have to satisfy invariants beyond the plain
storage layout.  'check' yields 'Nothing' for an acceptable array and
@Just msg@ with a human-readable diagnostic otherwise.
-}
class Shape.C shape => Admissible shape where
   check :: Class.Floating a => Array shape a -> Maybe String

{- |
@assert msg ok@ produces the diagnostic @Just msg@ exactly when the
condition @ok@ does not hold; a satisfied condition yields 'Nothing'.
-}
assert :: msg -> Bool -> Maybe msg
assert msg ok = if ok then Nothing else Just msg

-- | Full matrices carry no invariant beyond their shape, so every array
-- is accepted.
instance
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      Admissible (MatrixShape.Full vert horiz height width) where
   check = const Nothing

-- | A Hermitian matrix must have a real diagonal.  The diagonal is read
-- from the triangle returned by 'Hermitian.takeUpper'.
instance (Shape.C size) => Admissible (MatrixShape.Hermitian size) where
   check herm =
      let diagonal = Triangular.takeDiagonal (Hermitian.takeUpper herm)
      in  assert "Hermitian with non-real diagonal" (isReal diagonal)

-- | Triangular matrices with a unit diagonal must actually store ones on
-- the diagonal.  Non-unit triangular matrices are accepted unconditionally.
instance
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C size) =>
      Admissible (MatrixShape.Triangular lo diag up size) where
   check = getCheckDiag (MatrixShape.switchTriDiag acceptAll verifyUnit)
      where
         -- first alternative of switchTriDiag: no additional requirement
         acceptAll = CheckDiag (\_ -> Nothing)
         -- second alternative: the diagonal has to consist of ones
         verifyUnit = CheckDiag checkUnitDiagonal

{- |
Wraps a function on triangular matrices so that
'MatrixShape.switchTriDiag' can choose it by the @diag@ type.  For that
reason @diag@ is the last type parameter.
-}
newtype CheckDiag lo up sh a b diag =
   CheckDiag {
      getCheckDiag :: Triangular.Triangular lo diag up sh a -> b
   }

{- |
Report a diagnostic unless every diagonal element equals 'Scalar.one',
as decided by 'Scalar.equal'.
-}
checkUnitDiagonal ::
   (Shape.C size, Class.Floating a,
   MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up) =>
   Triangular.Triangular lo diag up size a -> Maybe String
checkUnitDiagonal tri =
   let diagonal = Vector.toList (Triangular.takeDiagonal tri)
       allOnes = all (Scalar.equal Scalar.one) diagonal
   in  assert "Triangular.Unit with non-unit diagonal elements" allOnes

-- | Every stored element whose index is not a 'MatrixShape.InsideBox'
-- index counts as unused and must be zero.
instance
   (Unary.Natural sub, Unary.Natural super, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Admissible (MatrixShape.Banded sub super vert horiz height width) where
   check arr0 =
      assert "Banded with non-zero unused elements" $
      and [ Scalar.isZero (arr ! ix)
          | ix <- Shape.indices (Array.shape arr), outsideBox ix ]
      where
         -- '(!)' needs an indexable shape while this instance only assumes
         -- 'Shape.C'; wrapping the dimensions in 'Shape.Deferred' (and
         -- generalizing the extent) presumably provides one.
         -- NOTE(review): confirm the reason for 'Extent.toGeneral'.
         arr =
            Banded.mapHeight Shape.Deferred .
            Banded.mapWidth Shape.Deferred .
            Banded.mapExtent Extent.toGeneral $ arr0
         outsideBox (MatrixShape.InsideBox _ _) = False
         outsideBox _ = True

-- | The upper part must satisfy the 'MatrixShape.Banded' invariant, and
-- the diagonal must be real.  The first failing check supplies the
-- message.
instance
   (Unary.Natural off, Shape.C size) =>
      Admissible (MatrixShape.BandedHermitian off size) where
   check arr = bandCheck <|> diagonalCheck
      where
         upper = BandedHermitian.takeUpper arr
         bandCheck = check upper
         diagonalCheck =
            assert "BandedHermitian with non-real diagonal" $
            isReal (Banded.takeDiagonal upper)

{- |
'True' when no element has a non-zero imaginary part.  For the first two
element types handled by 'Class.switchFloating' the answer is always
'True'.  The two complex element types are delegated to 'isComplexReal'.
-}
isReal :: (Shape.C sh, Class.Floating a) => Vector sh a -> Bool
isReal =
   getFlip . getCompose $
   Class.switchFloating
      (Compose (Flip (\_ -> True)))
      (Compose (Flip (\_ -> True)))
      (Compose (Flip isComplexReal))
      (Compose (Flip isComplexReal))

{- |
'True' when 'Scalar.isZero' holds for the imaginary part of every element.
-}
isComplexReal ::
   (Shape.C sh, Class.Real a) => Vector sh (Complex.Complex a) -> Bool
isComplexReal v =
   all Scalar.isZero (Vector.toList (Vector.imaginaryPart v))


{- |
Shape-preserving conversions between real and complex element types.
Every method defaults to the corresponding "Numeric.LAPACK.Vector"
function, so an instance only has to declare that the shape supports them.
-}
class Shape.C shape => Complex shape where
   -- | Default: 'Vector.conjugate'.
   conjugate :: Class.Floating a => Array shape a -> Array shape a
   conjugate = Vector.conjugate

   -- | Turn an array of 'RealOf' elements into one of element type @a@.
   -- Default: 'Vector.fromReal'.
   fromReal :: Class.Floating a => Array shape (RealOf a) -> Array shape a
   fromReal = Vector.fromReal

   -- | Turn an array into one of 'ComplexOf' elements.
   -- Default: 'Vector.toComplex'.
   toComplex :: Class.Floating a => Array shape a -> Array shape (ComplexOf a)
   toComplex = Vector.toComplex

-- All matrix shapes below rely on the default 'Complex' methods.
-- Triangular shapes accept any @diag@.
instance
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      Complex (MatrixShape.Full vert horiz height width)

instance Shape.C size => Complex (MatrixShape.Hermitian size)

instance
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C size) =>
      Complex (MatrixShape.Triangular lo diag up size)

instance
   (Unary.Natural sub, Unary.Natural super, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Complex (MatrixShape.Banded sub super vert horiz height width)

instance
   (Unary.Natural off, Shape.C size) =>
      Complex (MatrixShape.BandedHermitian off size)


{- |
Shape-preserving operations that take no second array: the zero array of
a shape, negation, and scaling by a real factor.  All methods default to
the corresponding "Numeric.LAPACK.Vector" functions.
-}
class Shape.C shape => Homogeneous shape where
   -- | Default: 'Vector.zero'.
   zero :: Class.Floating a => shape -> Array shape a
   zero = Vector.zero

   -- | Default: 'Vector.negate'.
   negate :: Class.Floating a => Array shape a -> Array shape a
   negate = Vector.negate

   -- | Default: 'Vector.scaleReal'.
   scaleReal :: Class.Floating a => RealOf a -> Array shape a -> Array shape a
   scaleReal = Vector.scaleReal


-- Shapes using the default 'Homogeneous' methods.  Triangular shapes are
-- limited to 'MatrixShape.NonUnit'.  A unit diagonal cannot be kept by
-- these operations; the zero matrix, for example, has none.
instance
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      Homogeneous (MatrixShape.Full vert horiz height width)

instance Shape.C size => Homogeneous (MatrixShape.Hermitian size)

instance
   (MatrixShape.Content lo, MatrixShape.NonUnit ~ diag, MatrixShape.Content up,
    Shape.C size) =>
      Homogeneous (MatrixShape.Triangular lo diag up size)

instance
   (Unary.Natural sub, Unary.Natural super, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Homogeneous (MatrixShape.Banded sub super vert horiz height width)

instance
   (Unary.Natural off, Shape.C size) =>
      Homogeneous (MatrixShape.BandedHermitian off size)


{- |
Shapes that record a storage order ('MatrixShape.Order').  The order can
be queried from a shape and imposed on an array.
-}
class Shape.C shape => ShapeOrder shape where
   -- | Re-arrange the array into the requested storage order.
   forceOrder ::
      Class.Floating a => MatrixShape.Order -> Array shape a -> Array shape a
   -- | The storage order stored in a shape.
   shapeOrder :: shape -> MatrixShape.Order

{- |
@adaptOrder x y@ contains the data of @y@ with the layout of @x@.
Only the shape of @x@ is consulted; its elements are never inspected.
-}
adaptOrder ::
   (ShapeOrder shape, Class.Floating a) =>
   Array shape a -> Array shape a -> Array shape a
adaptOrder x y = forceOrder (shapeOrder (Array.shape x)) y

-- Full, Hermitian and Triangular shapes take their order handling from the
-- respective shape and matrix modules.
instance
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      ShapeOrder (MatrixShape.Full vert horiz height width) where
   shapeOrder = MatrixShape.fullOrder
   forceOrder = Basic.forceOrder

instance Shape.C size => ShapeOrder (MatrixShape.Hermitian size) where
   shapeOrder = MatrixShape.hermitianOrder
   forceOrder = Hermitian.forceOrder

instance
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C size) =>
      ShapeOrder (MatrixShape.Triangular lo diag up size) where
   shapeOrder = MatrixShape.triangularOrder
   forceOrder = Triangular.forceOrder


{- |
Addition and subtraction of two arrays of the same shape type.  The
default 'sub' adds the 'negate'd second operand.
-}
class (Homogeneous shape, Eq shape) => Additive shape where
   add, sub ::
      Class.Floating a => Array shape a -> Array shape a -> Array shape a
   sub x y = add x (negate y)

-- These shapes reuse 'addGen' and 'subGen'.  Both first convert the left
-- operand to the storage order of the right one.  Triangular shapes are
-- limited to 'MatrixShape.NonUnit', matching their 'Homogeneous' superclass
-- instance.
instance
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width) =>
      Additive (MatrixShape.Full vert horiz height width) where
   sub = subGen
   add = addGen

instance (Shape.C size, Eq size) => Additive (MatrixShape.Hermitian size) where
   sub = subGen
   add = addGen

instance
   (MatrixShape.Content lo, Eq lo, MatrixShape.NonUnit ~ diag,
    MatrixShape.Content up, Eq up, Shape.C size, Eq size) =>
      Additive (MatrixShape.Triangular lo diag up size) where
   sub = subGen
   add = addGen

{- |
Add or subtract two arrays of the same shape type.  The left operand is
first converted to the storage order of the right one, so both operands
share one layout when the "Numeric.LAPACK.Vector" operation is applied.
-}
addGen, subGen ::
   (ShapeOrder shape, Eq shape, Class.Floating a) =>
   Array shape a -> Array shape a -> Array shape a
addGen = withLayoutOfSecond Vector.add
subGen = withLayoutOfSecond Vector.sub

{- |
Apply a binary operation after giving the first operand the layout of
the second one (via 'adaptOrder').
-}
withLayoutOfSecond ::
   (ShapeOrder shape, Class.Floating a) =>
   (Array shape a -> Array shape a -> c) ->
   Array shape a -> Array shape a -> c
withLayoutOfSecond op x y = op (adaptOrder y x) y


{- |
Box shapes whose height and width types coincide.  They support
conversion to a full 'Square', construction of an identity in a given
storage order, and extraction of the main diagonal.
-}
class
   (Box.Box shape, Box.HeightOf shape ~ Box.WidthOf shape) =>
      SquareShape shape where
   -- | Convert to a full square matrix.
   toSquare ::
      (Box.HeightOf shape ~ sh, Class.Floating a) =>
      Array shape a -> Square sh a
   -- | Identity of the given size, laid out in the given order.
   identityOrder ::
      (Box.HeightOf shape ~ sh, Class.Floating a) =>
      MatrixShape.Order -> sh -> Array shape a
   -- | The main diagonal as a vector.
   takeDiagonal ::
      (Box.HeightOf shape ~ sh, Class.Floating a) =>
      Array shape a -> Vector sh a

-- | A small full shape with equal height and width already is a 'Square',
-- so 'toSquare' returns its argument unchanged.
instance
   (Extent.Small ~ vert, Extent.Small ~ horiz,
    Shape.C height, height ~ width) =>
      SquareShape (MatrixShape.Full vert horiz height width) where
   toSquare square = square
   identityOrder order = Square.identityOrder order
   takeDiagonal square = Square.takeDiagonal square

-- | 'Hermitian.takeDiagonal' yields 'RealOf' elements.  These are lifted
-- to the matrix element type with 'Vector.fromReal'.
instance (Shape.C size) => SquareShape (MatrixShape.Hermitian size) where
   toSquare herm = Hermitian.toSquare herm
   identityOrder order = Hermitian.identity order
   takeDiagonal herm = Vector.fromReal (Hermitian.takeDiagonal herm)

-- | The identity from 'Triangular.identity' is passed through
-- 'Triangular.relaxUnitDiagonal'.  This presumably adapts it to the @diag@
-- type the instance is used at.
instance
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C size) =>
      SquareShape (MatrixShape.Triangular lo diag up size) where
   toSquare tri = Triangular.toSquare tri
   identityOrder order size =
      Triangular.relaxUnitDiagonal (Triangular.identity order size)
   takeDiagonal tri = Triangular.takeDiagonal tri

-- | Small banded shapes with equal height and width.  Conversion goes
-- through 'Banded.toFull', and the identity comes from
-- 'Banded.identityFatOrder'.
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Small ~ vert, Extent.Small ~ horiz,
    Shape.C height, height ~ width) =>
      SquareShape
         (MatrixShape.Banded sub super vert horiz height width) where
   toSquare band = Banded.toFull band
   identityOrder order = Banded.identityFatOrder order
   takeDiagonal band = Banded.takeDiagonal band

-- | Converted to a full square via 'BandedHermitian.toBanded'.  The
-- diagonal from 'BandedHermitian.takeDiagonal' has 'RealOf' elements and is
-- lifted with 'Vector.fromReal'.
instance
   (Unary.Natural offDiag, Shape.C size) =>
      SquareShape (MatrixShape.BandedHermitian offDiag size) where
   toSquare bandHerm = Banded.toFull (BandedHermitian.toBanded bandHerm)
   identityOrder order = BandedHermitian.identityFatOrder order
   takeDiagonal bandHerm =
      Vector.fromReal (BandedHermitian.takeDiagonal bandHerm)