{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Class (
   SquareShape(toSquare, takeDiagonal, mapSquareSize, identityFrom),
   MapSize(mapHeight, mapWidth),
   trace,
   Complex(conjugate, fromReal, toComplex),
   adjoint,
   Unpack(unpack), toFull,
   ) where

import qualified Numeric.LAPACK.Matrix.Array.Basic as OmniMatrix
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Type as Matrix
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Permutation as Permutation
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Permutation as PermPub
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Type (Matrix)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, ComplexOf)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape


{- |
Operations that relate a matrix to its conjugate
and to its real or complex counterpart.

In every method the equality constraint merely names
the matrix type without its element type as @matrix@,
such that argument and result can differ in the element type only.
-}
class Complex typ where
   {- |
   The result has the same matrix type and element type as the argument.
   The instances in this module either conjugate the stored elements
   or, for permutations, return the matrix unchanged.
   -}
   conjugate ::
      (Matrix typ xl xu lower upper meas vert horiz height width ~ matrix,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width, Class.Floating a) =>
      matrix a -> matrix a
   {- |
   Turn a matrix with elements of type @RealOf a@
   into a matrix with elements of type @a@.
   -}
   fromReal ::
      (Matrix typ xl xu lower upper meas vert horiz height width ~ matrix,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width, Class.Floating a) =>
      matrix (RealOf a) -> matrix a
   {- |
   Turn a matrix with elements of type @a@
   into a matrix with elements of type @ComplexOf a@.
   -}
   toComplex ::
      (Matrix typ xl xu lower upper meas vert horiz height width ~ matrix,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width, Class.Floating a) =>
      matrix a -> matrix (ComplexOf a)

{- |
For array-backed matrices every operation unwraps the array,
applies the function of the same name from "Numeric.LAPACK.Vector"
and wraps the result again.
-}
instance Complex (ArrMatrix.Array pack property) where
   conjugate (ArrMatrix.Array arr) = ArrMatrix.Array (Vector.conjugate arr)
   fromReal (ArrMatrix.Array arr) = ArrMatrix.Array (Vector.fromReal arr)
   toComplex (ArrMatrix.Array arr) = ArrMatrix.Array (Vector.toComplex arr)

{- |
For a scale matrix the shape is kept
and only the scalar factor is converted
by the function of the same name from "Numeric.LAPACK.Scalar".
-}
instance Complex Matrix.Scale where
   conjugate (Matrix.Scale size factor) =
      Matrix.Scale size (Scalar.conjugate factor)
   fromReal (Matrix.Scale size factor) =
      Matrix.Scale size (Scalar.fromReal factor)
   toComplex (Matrix.Scale size factor) =
      Matrix.Scale size (Scalar.toComplex factor)

{- |
A permutation matrix is returned unchanged by 'conjugate'.
'fromReal' and 'toComplex' pass the wrapped permutation on untouched;
the constructor is applied anew
because argument and result differ in the element type.
-}
instance Complex Matrix.Permutation where
   conjugate m = m
   fromReal (Matrix.Permutation perm) = Matrix.Permutation perm
   toComplex (Matrix.Permutation perm) = Matrix.Permutation perm

{- |
Transpose the matrix and conjugate the result.

As the type shows, height and width, @vert@ and @horiz@,
@lower@ and @upper@ as well as @xl@ and @xu@ are swapped.
-}
adjoint ::
   (Matrix.Transpose typ, Complex typ) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xu xl upper lower meas horiz vert width height a
adjoint m = conjugate (Matrix.transpose m)


{- |
Operations on quadratic matrices
that are available for every matrix representation @typ@
with an instance of this class.
-}
class (Matrix.Box typ) => SquareShape typ where
   {- |
   Convert to an @ArrMatrix.Square@ matrix over the same shape @sh@.
   -}
   toSquare ::
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a -> ArrMatrix.Square sh a
   {- |
   The diagonal of the matrix as a vector over the shape @sh@.
   -}
   takeDiagonal ::
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a -> Vector sh a
   {- |
   Replace the shape of the matrix using the given function.
   The shape mapping function must maintain
   the number of rows and columns.
   -}
   mapSquareSize ::
      (Shape.C shA, Shape.C shB) =>
      (shA -> shB) ->
      Matrix.Quadratic typ xl xu lower upper shA a ->
      Matrix.Quadratic typ xl xu lower upper shB a
   {- |
   Identity matrix of the same type and shape as the argument.
   The @Matrix.Scale@ and @Matrix.Permutation@ instances in this module
   use nothing but the size of the argument.
   -}
   identityFrom ::
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a ->
      Matrix.Quadratic typ xl xu lower upper sh a

{- |
Every method forwards to the function of the same name
(@toFull@ in the case of 'toSquare') from the array matrix module.

The constructor is matched although its field is not used.
NOTE(review): presumably the match is needed
to bring the type refinement of @ArrMatrix.Array@ into scope
-- confirm against the constructor's declaration.
-}
instance SquareShape (ArrMatrix.Array pack property) where
   toSquare m =
      case m of
         ArrMatrix.Array _ -> OmniMatrix.toFull m
   takeDiagonal m =
      case m of
         ArrMatrix.Array _ -> OmniMatrix.takeDiagonal m
   mapSquareSize f m =
      case m of
         ArrMatrix.Array _ -> OmniMatrix.mapSquareSize f m
   identityFrom m =
      case m of
         ArrMatrix.Array _ -> OmniMatrix.identityFrom m

{- |
A scale matrix consists of a shape and a scalar factor.
Its diagonal is the constant vector of that factor.
-}
instance SquareShape Matrix.Scale where
   -- Build a banded diagonal matrix from the constant vector
   -- and expand it to a full matrix.
   toSquare (Matrix.Scale size factor) =
      (ArrMatrix.lift0 . Banded.toFull . Banded.diagonal Layout.RowMajor)
         (Vector.constant size factor)
   takeDiagonal (Matrix.Scale size factor) = Vector.constant size factor
   -- The shape is converted by Layout.mapChecked,
   -- which is given the name of this method as a label.
   mapSquareSize f (Matrix.Scale size factor) = Matrix.Scale newSize factor
      where newSize = Layout.mapChecked "Scale.mapSquareSize" f size
   -- The factor of the argument is ignored.
   identityFrom (Matrix.Scale size _) = Matrix.Scale size Scalar.one

{- |
The methods unwrap the permutation
and delegate to the permutation modules.
-}
instance SquareShape Matrix.Permutation where
   toSquare (Matrix.Permutation p) = PermPub.toMatrix p
   takeDiagonal m@(Matrix.Permutation _) =
      Perm.takeDiagonal (Permutation.toPermutation m)
   mapSquareSize f (Matrix.Permutation p) =
      Matrix.Permutation (Perm.mapSize f p)
   -- Only the size of the given permutation enters the result.
   identityFrom (Matrix.Permutation p) =
      (Matrix.Permutation . Perm.identity . Perm.size) p


{- |
Sum of the elements of the diagonal as returned by 'takeDiagonal'.
-}
trace ::
   (SquareShape typ, Shape.C sh, Class.Floating a) =>
   Matrix.Quadratic typ xl xu lower upper sh a -> a
trace m = Vector.sum (takeDiagonal m)



{- |
Replace the height or the width shape of a matrix
without changing the matrix type otherwise.
Both methods are restricted to matrices with measure @Extent.Size@.
-}
class (Matrix.Box typ) => MapSize typ where
   {- |
   Replace the height shape using the given function.
   The number of rows and columns
   must be maintained by the shape mapping function.
   -}
   mapHeight ::
      (Extent.C vert, Extent.C horiz,
       Shape.C heightA, Shape.C heightB, Shape.C width) =>
      (heightA -> heightB) ->
      Matrix typ extraLower extraUpper lower upper
         Extent.Size vert horiz heightA width a ->
      Matrix typ extraLower extraUpper lower upper
         Extent.Size vert horiz heightB width a
   {- |
   Replace the width shape using the given function.
   NOTE(review): presumably the shape mapping function
   must maintain the number of rows and columns here as well
   -- confirm against the implementation in the array matrix module.
   -}
   mapWidth ::
      (Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C widthA, Shape.C widthB) =>
      (widthA -> widthB) ->
      Matrix typ extraLower extraUpper lower upper
         Extent.Size vert horiz height widthA a ->
      Matrix typ extraLower extraUpper lower upper
         Extent.Size vert horiz height widthB a

{- |
Both methods forward to the function of the same name
from the array matrix module.
The constructor is matched although its field is not used.
-}
instance MapSize (ArrMatrix.Array pack property) where
   mapHeight f m =
      case m of
         ArrMatrix.Array _ -> OmniMatrix.mapHeight f m
   mapWidth f m =
      case m of
         ArrMatrix.Array _ -> OmniMatrix.mapWidth f m


{- |
Matrix representations
that can be converted to an unpacked array matrix.
-}
class Unpack typ where
   {-
   Convert to an array matrix with layout @Layout.Unpacked@.
   Height, width, extent and the @lower@ and @upper@ parameters are kept.

   In contrast to OmniMatrix.unpack it cannot maintain the matrix property:
   the property of the result is always @Omni.Arbitrary@.
   -}
   unpack ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      ArrMatrix.ArrayMatrix Layout.Unpacked Omni.Arbitrary
         lower upper meas vert horiz height width a

{- |
Unpack with @OmniMatrix.unpack@
and pass the result through @ArrMatrix.liftUnpacked1 id@.
NOTE(review): presumably the latter step serves
to replace the matrix property by @Omni.Arbitrary@
-- confirm against @ArrMatrix.liftUnpacked1@.
-}
instance (Omni.Property prop) => Unpack (ArrMatrix.Array pack prop) where
   unpack m@(ArrMatrix.Array _) =
      ArrMatrix.liftUnpacked1 id (OmniMatrix.unpack m)

{- |
Build a banded diagonal matrix from the constant vector of the factor,
expand it to a full matrix and lift that to an unpacked array matrix.
-}
instance Unpack Matrix.Scale where
   unpack (Matrix.Scale size factor) =
      (ArrMatrix.liftUnpacked0 . Banded.toFull .
       Banded.diagonal Layout.RowMajor)
         (Vector.constant size factor)

{- |
Convert the wrapped permutation to a matrix with @PermPub.toMatrix@
and pass the result through @ArrMatrix.liftUnpacked1 id@.
-}
instance Unpack Matrix.Permutation where
   unpack (Matrix.Permutation p) =
      ArrMatrix.liftUnpacked1 id (PermPub.toMatrix p)

{- |
Convert to a full matrix:
first 'unpack' the matrix, then apply @OmniMatrix.toFull@ to the result.
-}
toFull ::
   (Unpack typ) =>
   (Omni.Strip lower, Omni.Strip upper) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   ArrMatrix.Full meas vert horiz height width a
toFull m = OmniMatrix.toFull (unpack m)