{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Numeric.LAPACK.Matrix.Divide where

import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Array.Divide as ArrDivide
import qualified Numeric.LAPACK.Matrix.Array.Unpacked as Unpacked
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Type as Matrix
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array.Private (Full)
import Numeric.LAPACK.Matrix.Type (scaleWithCheck)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed,Transposed),
          Inversion(Inverted))
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape

import Data.Semigroup ((<>))


-- | Matrix types whose determinant can be computed.
--
-- @xl@ and @xu@ are passed through unchanged to 'Matrix.Quadratic';
-- every instance visible in this module fixes both of them to @()@.
class (Matrix.Box typ) => Determinant typ xl xu where
   -- | Determinant of a quadratic matrix of type @typ@ whose rows and
   -- columns share the shape @sh@.
   determinant ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a -> a

-- | Matrix types that support solving linear systems against a full
-- matrix of right-hand sides (one right-hand side per column).
--
-- An instance must define either 'solve', or both 'solveRight' and
-- 'solveLeft'. The default of each is expressed through the other(s),
-- so defining neither would loop.
class (Matrix.Box typ) => Solve typ xl xu where
   {-# MINIMAL solve | solveLeft,solveRight #-}
   -- | @solve NonTransposed a b@ solves @a x = b@ for @x@;
   -- @solve Transposed a b@ solves @(transpose a) x = b@ for @x@.
   solve ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
      Transposition ->
      Matrix.Quadratic typ xl xu lower upper height a ->
      Full meas vert horiz height width a ->
      Full meas vert horiz height width a
   -- Default: the untransposed case is exactly 'solveRight'.  The
   -- transposed system (transpose a) x = b is equivalent to
   -- (transpose x) a = transpose b, which is what 'solveLeft' solves;
   -- the result is transposed back afterwards.
   solve NonTransposed a b = solveRight a b
   solve Transposed a b =
      Unpacked.transpose $ solveLeft (Unpacked.transpose b) a

   -- | @solveRight a b@ solves @a x = b@ for @x@, where the quadratic
   -- matrix stands to the left of the unknown.
   solveRight ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper height a ->
      Full meas vert horiz height width a ->
      Full meas vert horiz height width a
   -- Default: the untransposed variant of 'solve'.
   solveRight = solve NonTransposed

   -- | @solveLeft b a@ solves @x a = b@ for @x@, where the quadratic
   -- matrix stands to the right of the unknown.
   solveLeft ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
      Full meas vert horiz height width a ->
      Matrix.Quadratic typ xl xu lower upper width a ->
      Full meas vert horiz height width a
   -- Default: x a = b is equivalent to (transpose a) (transpose x) =
   -- transpose b, i.e. the transposed variant of 'solve'.
   -- NOTE(review): 'Unpacked.swapMultiply' presumably swaps the operand
   -- order and transposes the full operand and the result; confirm.
   solveLeft = Unpacked.swapMultiply $ solve Transposed

-- | Matrix types that can be inverted.  Every invertible type must also
-- be solvable ('Solve' is a superclass).
class (Solve typ xl xu) => Inverse typ xl xu where
   -- | Inverse of a quadratic matrix.  The argument may use different
   -- shape types for rows (@height@) and columns (@width@).  The result
   -- exchanges them, while @lower@, @upper@ and @meas@ stay the same.
   inverse ::
      (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
       Extent.Measure meas, Shape.C height, Shape.C width, Class.Floating a) =>
      Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
      Matrix.QuadraticMeas typ xl xu lower upper meas width height a

-- Both full-matrix division operators bind as tightly as (*) (level 7).
-- Left division takes the quadratic matrix on its left and the full
-- matrix on its right, so it is right-associative: several quadratic
-- matrices can be applied in turn to the same right-hand side without
-- parentheses.  Right division is the mirror image and is therefore
-- left-associative.
infixl 7 ##/#
infixr 7 #\##

-- | Left division with a full right-hand side: solve @a x = b@ for @x@.
--
-- @a@ may use distinct shape types for its rows (@height@) and columns
-- (@width@).  'Multiply.factorIdentityRight' splits it into a quadratic
-- factor, which is handed to 'solveRight', and an identity factor.  The
-- transpose of the identity factor relabels the rows of the solution
-- from @height@ to @width@.
(#\##) ::
   (Solve typ xl xu, Matrix.ToQuadratic typ,
    Omni.Strip lower, Omni.Strip upper,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Extent.Measure measA, Extent.Measure measB, Extent.Measure measC,
    Extent.MultiplyMeasure measA measB ~ measC,
    Extent.C vert, Extent.C horiz, Class.Floating a) =>
   Matrix.QuadraticMeas typ xl xu lower upper measA height width a ->
   Full measB vert horiz height nrhs a -> Full measC vert horiz width nrhs a
a #\## b =
   case Multiply.factorIdentityRight a of
      (square, ident) ->
         let solution = solveRight square b
             toWidth = Matrix.transpose ident
         in Multiply.reshapeHeight toWidth solution

-- | Right division with a full left-hand operand: solve @x a = b@ for @x@.
--
-- @a@ may use distinct shape types for its rows (@height@) and columns
-- (@width@).  'Multiply.factorIdentityLeft' splits it into an identity
-- factor and a quadratic factor; the quadratic factor is handed to
-- 'solveLeft'.  The transpose of the identity factor relabels the columns
-- of the solution from @width@ to @height@.
(##/#) ::
   (Solve typ xl xu, Matrix.ToQuadratic typ,
    Omni.Strip lower, Omni.Strip upper,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs,
    Extent.Measure measA, Extent.Measure measB, Extent.Measure measC,
    Extent.MultiplyMeasure measA measB ~ measC,
    Extent.C vert, Extent.C horiz, Class.Floating a) =>
   Full measB vert horiz nrhs width a ->
   Matrix.QuadraticMeas typ xl xu lower upper measA height width a ->
   Full measC vert horiz nrhs height a
b ##/# a =
   case Multiply.factorIdentityLeft a of
      (ident, square) ->
         let solution = solveLeft b square
             toHeight = Matrix.transpose ident
         in Multiply.reshapeWidth solution toHeight


-- | Variant of 'solve' for a single right-hand-side vector.
--
-- The matrix-level 'solve' is turned into a vector function with
-- 'ArrMatrix.unliftColumn', using 'Layout.ColumnMajor' layout.
solveVector ::
   (Solve typ xl xu, Omni.Strip lower, Omni.Strip upper,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Transposition ->
   Matrix.Quadratic typ xl xu lower upper sh a ->
   Vector sh a -> Vector sh a
solveVector trans matrix =
   ArrMatrix.unliftColumn Layout.ColumnMajor (solve trans matrix)


-- The vector division operators use the same precedence (7) and the same
-- associativity pattern as the full-matrix division operators: right
-- division is left-associative, left division is right-associative.
infixl 7 -/#
infixr 7 #\|

-- | Left division by a vector: solve @a x = v@ for the vector @x@.
--
-- The factorisation of @a@ happens once per partial application
-- @(a #\\| )@, so the resulting function can be reused for several vectors.
-- The quadratic factor solves the system; the transposed identity factor
-- relabels the solution from shape @height@ to shape @width@.
(#\|) ::
   (Solve typ xl xu, Matrix.ToQuadratic typ,
    Omni.Strip lower, Omni.Strip upper, Extent.Measure meas,
    Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
   Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
   Vector height a -> Vector width a
(#\|) a =
   case Multiply.factorIdentityRight a of
      (square, ident) ->
         let solveHeight = solveVector NonTransposed square
             toWidth = reshapeVector (Matrix.transpose ident)
         in toWidth . solveHeight

-- | Right division of a vector: solve @x a = v@ for the vector @x@.
--
-- Internally this solves with the transposed quadratic factor of @a@
-- (from 'Multiply.factorIdentityLeft') and then relabels the solution
-- from shape @width@ to shape @height@ using the identity factor.
(-/#) ::
   (Solve typ xl xu, Matrix.ToQuadratic typ,
    Omni.Strip lower, Omni.Strip upper, Extent.Measure meas,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a ->
   Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
   Vector height a
v -/# a =
   case Multiply.factorIdentityLeft a of
      (ident, square) -> reshapeVector ident (solveVector Transposed square v)


-- | Relabel a vector of shape @width@ as a vector of shape @height@.
--
-- The new shape is the height of the identity matrix's extent.  The
-- element data is passed to 'Array.reshape' unchanged.
reshapeVector ::
   (Extent.Measure meas, Shape.C height, Shape.C width) =>
   Multiply.IdentityMaes meas height width a ->
   Vector width a -> Vector height a
reshapeVector ident =
   case ident of
      Matrix.Identity extent -> Array.reshape (Extent.height extent)


-- | The determinant of a 'Matrix.Scale' matrix is its scalar factor raised
-- to the size of its shape.
instance (xl ~ (), xu ~ ()) => Determinant Matrix.Scale xl xu where
   determinant m =
      case m of
         Matrix.Scale shape factor -> factor ^ Shape.size shape

-- | Solving with a 'Matrix.Scale' matrix multiplies the right-hand side by
-- the reciprocal of the scalar factor.  The transposition flag is ignored.
instance (xl ~ (), xu ~ ()) => Solve Matrix.Scale xl xu where
   -- NOTE(review): 'scaleWithCheck' presumably uses the string in its
   -- error message and 'Matrix.height' for a shape-consistency check;
   -- confirm.
   solve _ =
      scaleWithCheck "Matrix.Scale.solve" Matrix.height
         (\factor -> ArrMatrix.lift1 (Vector.scale (recip factor)))

-- | The inverse of a 'Matrix.Scale' matrix keeps its shape and replaces the
-- scalar factor by its reciprocal.
instance (xl ~ (), xu ~ ()) => Inverse Matrix.Scale xl xu where
   inverse m =
      case m of
         Matrix.Scale shape factor -> Matrix.Scale shape (recip factor)


-- | Permutation matrices delegate to 'PermMatrix.determinant'.
instance (xl ~ (), xu ~ ()) => Determinant Matrix.Permutation xl xu where
   determinant perm = PermMatrix.determinant perm

-- | Solving with a permutation matrix is a multiplication by a permutation.
--
-- The inversion flag passed to 'PermMatrix.multiplyFull' combines
-- 'Inverted' with the flag derived from the requested transposition.
-- Because a permutation matrix's inverse is its transpose, the transposed
-- case presumably cancels the inversion.
instance (xl ~ (), xu ~ ()) => Solve Matrix.Permutation xl xu where
   solve trans = PermMatrix.multiplyFull inversion
      where
         inversion =
            Inverted <> PermMatrix.inversionFromTransposition trans

-- | A permutation matrix is inverted by transposing it.
instance (xl ~ (), xu ~ ()) => Inverse Matrix.Permutation xl xu where
   -- The 'Matrix.Permutation' constructor is matched first (presumably so
   -- its type information is in scope), then 'Matrix.powerStrips' is
   -- inspected.  Only when both strips are 'MatrixShape.Filled' is
   -- 'PermMatrix.transpose' applied; otherwise the argument is returned
   -- unchanged.
   inverse a@(Matrix.Permutation _) =
      case Matrix.powerStrips a of
         (MatrixShape.Filled, MatrixShape.Filled) -> PermMatrix.transpose a
         _ -> a -- identity matrix

-- | Array-backed matrices delegate to 'ArrDivide.determinant'.
instance
   (Layout.Packing pack, Omni.Property property, xl ~ (), xu ~ ()) =>
      Determinant (ArrMatrix.Array pack property) xl xu where
   determinant m = ArrDivide.determinant m

-- | Array-backed matrices define the two one-sided solvers.  'solve'
-- comes from the class default, which is built on top of them.
instance
   (Layout.Packing pack, Omni.Property property, xl ~ (), xu ~ ()) =>
      Solve (ArrMatrix.Array pack property) xl xu where
   solveRight m b = ArrDivide.solve m b
   -- x a = b is handled as (transpose a) (transpose x) = transpose b,
   -- i.e. 'ArrDivide.solve' with the transposed matrix.
   -- NOTE(review): 'Matrix.swapMultiply' presumably swaps the operand
   -- order and transposes the full operand and the result; confirm.
   solveLeft =
      Matrix.swapMultiply (\m -> ArrDivide.solve (Matrix.transpose m))

-- | Array-backed matrices delegate to 'ArrDivide.inverse'.
instance
   (Layout.Packing pack, Omni.Property property, xl ~ (), xu ~ ()) =>
      Inverse (ArrMatrix.Array pack property) xl xu where
   inverse m = ArrDivide.inverse m