{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Numeric.LAPACK.Matrix.Block where

import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Matrix.Class as MatrixClass
import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix.Array.Unpacked as Unpacked
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.Netlib.Class as Class
import Numeric.LAPACK.Matrix.Divide (determinant)
import Numeric.LAPACK.Matrix.Type (Matrix, Quadratic, extent, squareSize)
import Numeric.LAPACK.Matrix.Layout.Private (Filled)
import Numeric.LAPACK.Matrix.Extent.Private (Size, Big)
import Numeric.LAPACK.Matrix ((#*#), (#-#), (===), (|||))
import Numeric.LAPACK.Vector ((|+|))
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked), deconsUnchecked)

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((::+)((::+)))



{- |
Type tag for a square matrix that is composed of four blocks.
The type parameters are the type tags of the individual blocks:
@typ00@ and @typ11@ belong to the two quadratic blocks,
@typ01@ and @typ10@ belong to the two off-diagonal blocks.
-}
data Square typ00 typ01 typ10 typ11
data instance
   Matrix (Square typ00 typ01 typ10 typ11) xl xu
      lower upper meas vert horiz height width a where
   -- The constructor arguments are, in this order:
   --   * a quadratic block of size sh0,
   --   * a block of height sh0 and width sh1,
   --   * a block of height sh1 and width sh0,
   --   * a quadratic block of size sh1.
   -- All four blocks and the composed matrix are Filled
   -- in both the lower and the upper part.
   -- The composed matrix is quadratic with size sh0::+sh1.
   -- The two off-diagonal blocks share the measure
   -- and have exchanged vertical and horizontal extent tags.
   --
   -- Mind the differing order of the off-diagonal entries
   -- in the two tuples of the result type:
   -- (xl00,xl10,xl01,xl11) versus (xu00,xu01,xu10,xu11).
   -- NOTE(review): presumably this is what allows 'transposeSquare'
   -- to plainly exchange xl and xu - confirm against Matrix.Transpose.
   Square ::
      (Extent.Measure measOff, Extent.C vertOff, Extent.C horizOff,
       Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1) =>
      Quadratic typ00 xl00 xu00 Filled Filled sh0 a ->
      Matrix typ01 xl01 xu01 Filled Filled measOff vertOff horizOff sh0 sh1 a ->
      Matrix typ10 xl10 xu10 Filled Filled measOff horizOff vertOff sh1 sh0 a ->
      Quadratic typ11 xl11 xu11 Filled Filled sh1 a ->
      Quadratic
         (Square typ00 typ01 typ10 typ11)
         (xl00,xl10,xl01,xl11) (xu00,xu01,xu10,xu11)
         Filled Filled (sh0::+sh1) a

{- |
The extent of the block matrix is a square extent
whose size is the sum shape of the sizes of the two quadratic blocks.
The off-diagonal blocks are not inspected.
-}
instance
   (Matrix.Box typ00, Matrix.Box typ11) =>
      Matrix.Box (Square typ00 typ01 typ10 typ11) where
   extent (Square topLeft _ _ bottomRight) =
      Extent.square $ squareSize topLeft ::+ squareSize bottomRight

{- |
Transpose a four-block matrix.
Every block is transposed on its own
and the two off-diagonal blocks exchange their positions.
Accordingly, the type tags @typ01@ and @typ10@ are exchanged
in the result type, as are @xl@ and @xu@, @lower@ and @upper@,
@vert@ and @horiz@, @height@ and @width@.
-}
transposeSquare ::
   (Matrix.Transpose typ00, Matrix.Transpose typ01) =>
   (Matrix.Transpose typ10, Matrix.Transpose typ11) =>
   (Class.Floating a) =>
   Matrix (Square typ00 typ01 typ10 typ11) xl xu
      lower upper meas vert horiz height width a ->
   Matrix (Square typ00 typ10 typ01 typ11) xu xl
      upper lower meas horiz vert width height a
transposeSquare (Square topLeft topRight bottomLeft bottomRight) =
   Square
      (Matrix.transpose topLeft)
      -- the transposed bottom left block becomes the new top right block
      (Matrix.transpose bottomLeft)
      -- and vice versa
      (Matrix.transpose topRight)
      (Matrix.transpose bottomRight)

{- |
Block-wise matrix-vector products.
The vector is split according to the sum shape of the block matrix,
each part is multiplied with the matching blocks
and the partial products are added and appended again.
-}
instance
   (xl ~ (xl00,xl10,xl01,xl11),
    xu ~ (xu00,xu01,xu10,xu11),
    Multiply.MultiplyVector typ00 xl00 xu00,
    Multiply.MultiplyVector typ01 xl01 xu01,
    Multiply.MultiplyVector typ10 xl10 xu10,
    Multiply.MultiplyVector typ11 xl11 xu11) =>
      Multiply.MultiplyVector (Square typ00 typ01 typ10 typ11) xl xu where
   -- (a b) (v0)   (a*v0 + b*v1)
   -- (c d) (v1) = (c*v0 + d*v1)
   matrixVector (Square a b c d) x =
      Array.append top bottom
      where
         (v0,v1) = Array.split x
         top    = Multiply.matrixVector a v0 |+| Multiply.matrixVector b v1
         bottom = Multiply.matrixVector c v0 |+| Multiply.matrixVector d v1
   --         (a b)
   -- (v0 v1) (c d) = (v0*a + v1*c, v0*b + v1*d)
   vectorMatrix x (Square a b c d) =
      Array.append left right
      where
         (v0,v1) = Array.split x
         left  = Multiply.vectorMatrix v0 a |+| Multiply.vectorMatrix v1 c
         right = Multiply.vectorMatrix v0 b |+| Multiply.vectorMatrix v1 d


{- |
Type tag of array-based matrices
with 'Layout.Unpacked' layout and 'Omni.Arbitrary' property.
In this module it is the block type that 'schurComplement'
and the 'Divide.Determinant' and 'Divide.Solve' instances of 'Square'
demand for all blocks except the bottom right one.
-}
type TypeFull = ArrMatrix.Array Layout.Unpacked Omni.Arbitrary

{- |
Schur complement of the bottom right block.
For a block matrix

> (a b)
> (c d)

it computes @a - b * d^-1 * c@,
where @d^-1 * c@ is obtained by solving with @d@,
that is, @d@ is never inverted explicitly.
The result is a plain square matrix of the size of @a@.
-}
schurComplement ::
   (Divide.Solve typ11 xl11 xu11,
    Omni.Strip lower, Omni.Strip upper, Class.Floating a) =>
   Quadratic
      (Square TypeFull TypeFull TypeFull typ11)
      ((),(),(),xl11) ((),(),(),xu11)
      lower upper (sh0::+sh1) a ->
   Square.Square sh0 a
schurComplement (Square a b c d) =
   Unpacked.fillBoth a #-# Square.fromFull correction
   where
      -- d^-1 * c
      dInvC = Matrix.fromFull $ Divide.solveRight d c
      -- b * d^-1 * c
      correction = Matrix.fromFull b #*# dInvC

{- |
Determinant of a four-block matrix,
computed as the product of the determinant of the bottom right block
and the determinant of the Schur complement of that block.

Requires that the bottom right sub-matrix is invertible,
because 'schurComplement' solves with it.
-}
instance
   (xl ~ ((),(),(),xl11),
    xu ~ ((),(),(),xu11),
    typ00 ~ TypeFull,
    typ01 ~ TypeFull,
    typ10 ~ TypeFull,
    Divide.Solve typ11 xl11 xu11,
    Divide.Determinant typ11 xl11 xu11) =>
      Divide.Determinant (Square typ00 typ01 typ10 typ11) xl xu where
   determinant sq@(Square _ _ _ bottomRight) =
      detBottomRight * detSchur
      where
         detBottomRight = determinant bottomRight
         detSchur = determinant (schurComplement sq)


{- |
Apply a binary matrix operation with the heights hidden from it.
The heights of both operands are wrapped in 'Unchecked'
before the operation is applied
and the height of the result is unwrapped with 'deconsUnchecked'.
The widths are passed through unaltered.

NOTE(review): presumably a comparison of 'Unchecked' shapes
cannot fail, such that a height check within the operation
is effectively disabled - confirm in "Numeric.LAPACK.Shape.Private".
-}
withoutHeightCheck ::
   (MatrixClass.MapSize typ0, Extent.C vert0, Extent.C horiz0) =>
   (MatrixClass.MapSize typ1, Extent.C vert1, Extent.C horiz1) =>
   (MatrixClass.MapSize typ2, Extent.C vert2, Extent.C horiz2) =>
   (Matrix typ0 xl0 xu0 lower0 upper0 Size vert0 horiz0 ~ matrix0) =>
   (Matrix typ1 xl1 xu1 lower1 upper1 Size vert1 horiz1 ~ matrix1) =>
   (Matrix typ2 xl2 xu2 lower2 upper2 Size vert2 horiz2 ~ matrix2) =>
   (Shape.C height0, Shape.C width0) =>
   (Shape.C height1, Shape.C width1) =>
   (Shape.C height2, Shape.C width2) =>
   (matrix0 (Unchecked height0) width0 a0 ->
    matrix1 (Unchecked height1) width1 a1 ->
    matrix2 (Unchecked height2) width2 a2) ->
   matrix0 height0 width0 a0 ->
   matrix1 height1 width1 a1 ->
   matrix2 height2 width2 a2
withoutHeightCheck op a b =
   Matrix.mapHeight deconsUnchecked (op uncheckedA uncheckedB)
   where
      uncheckedA = Matrix.mapHeight Unchecked a
      uncheckedB = Matrix.mapHeight Unchecked b

{- |
Apply a binary matrix operation with the widths hidden from it.
The widths of both operands are wrapped in 'Unchecked'
before the operation is applied
and the width of the result is unwrapped with 'deconsUnchecked'.
The heights are passed through unaltered.

NOTE(review): presumably a comparison of 'Unchecked' shapes
cannot fail, such that a width check within the operation
is effectively disabled - confirm in "Numeric.LAPACK.Shape.Private".
-}
withoutWidthCheck ::
   (MatrixClass.MapSize typ0, Extent.C vert0, Extent.C horiz0) =>
   (MatrixClass.MapSize typ1, Extent.C vert1, Extent.C horiz1) =>
   (MatrixClass.MapSize typ2, Extent.C vert2, Extent.C horiz2) =>
   (Matrix typ0 xl0 xu0 lower0 upper0 Size vert0 horiz0 ~ matrix0) =>
   (Matrix typ1 xl1 xu1 lower1 upper1 Size vert1 horiz1 ~ matrix1) =>
   (Matrix typ2 xl2 xu2 lower2 upper2 Size vert2 horiz2 ~ matrix2) =>
   (Shape.C height0, Shape.C width0) =>
   (Shape.C height1, Shape.C width1) =>
   (Shape.C height2, Shape.C width2) =>
   (matrix0 height0 (Unchecked width0) a0 ->
    matrix1 height1 (Unchecked width1) a1 ->
    matrix2 height2 (Unchecked width2) a2) ->
   matrix0 height0 width0 a0 ->
   matrix1 height1 width1 a1 ->
   matrix2 height2 width2 a2
withoutWidthCheck op a b =
   Matrix.mapWidth deconsUnchecked (op uncheckedA uncheckedB)
   where
      uncheckedA = Matrix.mapWidth Unchecked a
      uncheckedB = Matrix.mapWidth Unchecked b

{- |
Solve linear systems with a four-block matrix

> (a b)
> (c d)

by block elimination:
The first part of the solution is obtained
by solving with the Schur complement of @d@ (see 'schurComplement'),
the second part by solving with @d@ itself.

Requires that the bottom right sub-matrix is invertible.
-}
instance
   (xl ~ ((),(),(),xl11),
    xu ~ ((),(),(),xu11),
    typ00 ~ TypeFull,
    typ01 ~ TypeFull,
    typ10 ~ TypeFull,
    Divide.Solve typ11 xl11 xu11) =>
      Divide.Solve (Square typ00 typ01 typ10 typ11) xl xu where
   solveRight sq@(Square _ b c d) x =
      -- ToDo: does it always have the correct order?
      ArrMatrix.reshape (ArrMatrix.shape x) $
      withoutWidthCheck (===) solTop solBottom
      where
         -- right-hand side, split into the rows matching a and d
         rhsTop    = Matrix.takeTop    $ Matrix.fromFull x
         rhsBottom = Matrix.takeBottom $ Matrix.fromFull x
         -- rhsTop - b * d^-1 * rhsBottom
         reducedRhs =
            withoutWidthCheck (#-#) rhsTop $
               Matrix.fromFull b #*# Divide.solveRight d rhsBottom
         solTop = Divide.solveRight (schurComplement sq) reducedRhs
         -- d^-1 * (rhsBottom - c * solTop)
         solBottom =
            Divide.solveRight d $
            withoutWidthCheck (#-#) rhsBottom $ Matrix.fromFull c #*# solTop
   solveLeft x sq@(Square _ b c d) =
      -- ToDo: does it always have the correct order?
      ArrMatrix.reshape (ArrMatrix.shape x) $
      withoutHeightCheck (|||) solLeft solRight
      where
         -- left-hand side, split into the columns matching a and d
         lhsLeft  = Matrix.takeLeft  $ Matrix.fromFull x
         lhsRight = Matrix.takeRight $ Matrix.fromFull x
         -- lhsLeft - lhsRight * d^-1 * c
         reducedLhs =
            withoutHeightCheck (#-#) lhsLeft $
               Divide.solveLeft lhsRight d #*# Matrix.fromFull c
         solLeft = Divide.solveLeft reducedLhs (schurComplement sq)
         -- (lhsRight - solLeft * b) * d^-1
         solRight =
            Divide.solveLeft
               (withoutHeightCheck (#-#) lhsRight $
                  solLeft #*# Matrix.fromFull b)
               d

{-
instance
   (xl ~ ((),(),(),xl11),
    xu ~ ((),(),(),xu11),
    typ00 ~ TypeFull,
    typ01 ~ TypeFull,
    typ10 ~ TypeFull,
    Divide.Inverse typ11 xl11 xu11) =>
      Divide.Inverse (Square typ00 typ01 typ10 typ11) xl xu where
   inverse (Square a b c d) =
      let as = a #-# b #*# Divide.solveRight d c
          bdinv = Divide.solveLeft b d
          dinvc = Divide.solveRight d c
          br = Divide.solveRight as bdinv
          cr = Divide.solveLeft dinvc as
      in Square
            (inverse as)
            (Matrix.negate br)
            (Matrix.negate cr)
            (inverse d #+# dinvc #*# br)
-}


{- |
Type tag for a quadratic matrix that is composed
of two quadratic blocks and a single off-diagonal block.
The type parameters are the type tags of the blocks:
@typ0@ and @typ1@ for the quadratic blocks,
@typOff@ for the off-diagonal block.
-}
data Mosaic typ0 typOff typ1
data instance
   Matrix (Mosaic typ0 typOff typ1) xl xu
      lower upper meas vert horiz height width a where
   -- The constructor arguments are, in this order:
   --   * a quadratic block of size sh0,
   --   * a Filled block of height sh0 and width sh1,
   --   * a quadratic block of size sh1.
   -- Both quadratic blocks share the lower and upper tags
   -- with the composed matrix, which is quadratic with size sh0::+sh1.
   -- Only one off-diagonal block is stored;
   -- this declaration alone does not determine
   -- how the opposite off-diagonal block is to be interpreted.
   Mosaic ::
      Quadratic typ0 xl0 xu0 lower upper sh0 a ->
      Matrix typOff xlOff xuOff Filled Filled Size Big Big sh0 sh1 a ->
      Quadratic typ1 xl1 xu1 lower upper sh1 a ->
      Quadratic
         (Mosaic typ0 typOff typ1)
         (xl0,xlOff,xl1) (xu0,xuOff,xu1)
         lower upper (sh0::+sh1) a