{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.LAPACK.Matrix.Array.Divide where

import qualified Numeric.LAPACK.Matrix.Square.Linear
                                           as Square
import qualified Numeric.LAPACK.Matrix.Square.Basic
                                           as Square
import qualified Numeric.LAPACK.Matrix.Triangular.Linear
                                           as Triangular
import qualified Numeric.LAPACK.Matrix.Triangular.Basic
                                           as Triangular
import qualified Numeric.LAPACK.Matrix.Hermitian.Linear
                                           as Hermitian
import qualified Numeric.LAPACK.Matrix.Banded.Linear
                                           as Banded
import qualified Numeric.LAPACK.Matrix.Banded.Basic
                                           as Banded
import qualified Numeric.LAPACK.Matrix.BandedHermitianPositiveDefinite.Linear
                                           as BandedHermitianPositiveDefinite

import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Extent.Private (Small)
import Numeric.LAPACK.Matrix.Basic (swapMultiply)
import Numeric.LAPACK.Matrix.Modifier (Transposition(Transposed, NonTransposed))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)


{- |
Matrix shapes that can serve as the coefficient matrix of a linear system.
The superclass constraint demands that the height type of the shape equals
its width type.

An instance must provide either 'solve', or both 'solveLeft' and
'solveRight'; the remaining methods are derived via transposition.
-}
class
   (Box.Box shape, Box.HeightOf shape ~ Box.WidthOf shape) =>
      Solve shape where
   {-# MINIMAL solve | solveLeft,solveRight #-}
   {- |
   @solve trans m rhs@ solves the system with coefficient matrix @m@
   (taken transposed when @trans@ is 'Transposed') for the right-hand
   side matrix @rhs@.
   -}
   solve ::
      (Class.Floating a, Box.HeightOf shape ~ height, Eq height,
       Extent.C horiz, Extent.C vert, Shape.C width) =>
      Transposition -> Array shape a ->
      Full vert horiz height width a -> Full vert horiz height width a
   -- The non-transposed case is exactly 'solveRight'.
   -- The transposed case is expressed as a left-sided solve
   -- on the transposed right-hand side, whose result is transposed back.
   solve trans m rhs =
      case trans of
         NonTransposed -> solveRight m rhs
         Transposed ->
            Basic.transpose (solveLeft (Basic.transpose rhs) m)

   {- |
   @solveRight m rhs@ is the same as @solve NonTransposed m rhs@:
   the coefficient matrix stands to the left of the unknown.
   -}
   solveRight ::
      (Class.Floating a, Box.HeightOf shape ~ height, Eq height,
       Extent.C horiz, Extent.C vert, Shape.C width) =>
      Array shape a ->
      Full vert horiz height width a -> Full vert horiz height width a
   solveRight m rhs = solve NonTransposed m rhs

   {- |
   @solveLeft rhs m@ solves with the coefficient matrix @m@ standing to the
   right of the unknown (presumably @rhs · m⁻¹@).
   Note the swapped argument order compared to 'solveRight'.
   -}
   solveLeft ::
      (Class.Floating a, Box.HeightOf shape ~ width, Eq width,
       Extent.C horiz, Extent.C vert, Shape.C height) =>
      Full vert horiz height width a ->
      Array shape a ->
      Full vert horiz height width a
   -- Built from the transposed solver; 'swapMultiply' adapts it
   -- to the swapped argument order.
   solveLeft rhs m = swapMultiply (solve Transposed) rhs m

{- |
Solvable matrix shapes that additionally support explicit inversion.
-}
class Solve shape => Inverse shape where
   -- | Compute the inverse of the given matrix.
   inverse :: (Class.Floating a) => Array shape a -> Array shape a

{- |
Like 'solve', but for a single right-hand side vector instead of a matrix.
'Basic.unliftColumn' (with column-major order) adapts the matrix solver
to vectors, presumably by treating the vector as a one-column matrix.
-}
solveVector ::
   (Solve shape, Box.HeightOf shape ~ height, Eq height, Class.Floating a) =>
   Transposition -> Array shape a -> Vector height a -> Vector height a
solveVector trans m =
   Basic.unliftColumn MatrixShape.ColumnMajor (solve trans m)


{- |
General matrices are solvable only when both extents are 'Small'
and the height shape equals the width shape, i.e. in the square case
handled by the @Square@ modules.
-}
instance
   (vert ~ Small, horiz ~ Small, Shape.C height, height ~ width) =>
      Solve (MatrixShape.Full vert horiz height width) where
   -- Right-sided solve goes straight to the square solver.
   solveRight m rhs = Square.solve m rhs
   -- Left-sided solve: use the transposed square matrix,
   -- adapted to the swapped argument order by 'swapMultiply'.
   solveLeft = swapMultiply (\m -> Square.solve (Square.transpose m))

{- |
Inversion of square general matrices, delegated to the @Square@ module.
-}
instance
   (vert ~ Small, horiz ~ Small, Shape.C height, height ~ width) =>
      Inverse (MatrixShape.Full vert horiz height width) where
   inverse m = Square.inverse m


-- | Hermitian matrices are solved by the Hermitian solver.
instance (Shape.C shape) => Solve (MatrixShape.Hermitian shape) where
   solveRight m rhs = Hermitian.solve m rhs
   -- The transpose of a Hermitian matrix equals its element-wise conjugate,
   -- so conjugation takes the place of transposition here.
   solveLeft = swapMultiply (\m -> Hermitian.solve (Vector.conjugate m))

-- | Inversion of Hermitian matrices, delegated to the Hermitian module.
instance (Shape.C shape) => Inverse (MatrixShape.Hermitian shape) where
   inverse m = Hermitian.inverse m


{- |
Triangular (and diagonal) matrices are solved by the triangular solver.
-}
instance
   (MatrixShape.Content lo, MatrixShape.Content up,
    MatrixShape.TriDiag diag, Shape.C shape) =>
      Solve (MatrixShape.Triangular lo diag up shape) where
   solveRight m rhs = Triangular.solve m rhs
   -- Left-sided solve works on the transposed triangular matrix.
   solveLeft =
      swapMultiply (\m -> Triangular.solve (Triangular.transpose m))

{- |
Inversion of triangular matrices, delegated to the triangular module.
It requires 'MatrixShape.DiagUpLo' on the lower/upper parameters.
-}
instance
   (MatrixShape.DiagUpLo lo up,
    MatrixShape.TriDiag diag, Shape.C shape) =>
      Inverse (MatrixShape.Triangular lo diag up shape) where
   inverse m = Triangular.inverse m


{- |
Banded matrices with 'Small' extents and equal height and width shapes
are solved by the banded solver.
-}
instance
   (Unary.Natural sub, Unary.Natural super, vert ~ Small, horiz ~ Small,
    Shape.C width, Shape.C height, width ~ height) =>
      Solve (MatrixShape.Banded sub super vert horiz height width) where
   solveRight m rhs = Banded.solve m rhs
   -- Left-sided solve works on the transposed banded matrix.
   solveLeft = swapMultiply (\m -> Banded.solve (Banded.transpose m))


{- |
Banded Hermitian matrices are solved by the positive definite solver,
because there is no solver for indefinite matrices.
Consequently, this instance fails for systems that are indefinite
but nevertheless solvable.
-}
instance
   (Unary.Natural offDiag, Shape.C size) =>
      Solve (MatrixShape.BandedHermitian offDiag size) where
   solveRight m rhs = BandedHermitianPositiveDefinite.solve m rhs
   -- Hermitian: element-wise conjugation takes the place of transposition.
   solveLeft =
      swapMultiply
         (\m -> BandedHermitianPositiveDefinite.solve (Vector.conjugate m))