{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Banded (
   Banded,
   General,
   Square,
   Upper,
   Lower,
   Diagonal,
   Hermitian,
   height, width,
   fromList,
   squareFromList,
   lowerFromList,
   upperFromList,
   mapExtent,
   diagonal,
   takeDiagonal,
   toFull,
   toLowerTriangular,
   toUpperTriangular,
   transpose,
   adjoint,
   multiplyVector,
   multiply,
   multiplyFull,

   solve,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Banded.Linear as Linear
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Basic

import qualified Numeric.LAPACK.Matrix.Array.Triangular as Tri
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Array.Banded
         (Banded, General, Square, Lower, Upper, Diagonal, Hermitian)
import Numeric.LAPACK.Matrix.Array (Full)
import Numeric.LAPACK.Matrix.Shape.Private (Order, UnaryProxy)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Unary ((:+:))

import qualified Data.Array.Comfort.Shape as Shape

import Foreign.Storable (Storable)


-- | The height (row shape) of a banded matrix, read from its matrix shape
-- via 'MatrixShape.bandedHeight'.
height ::
   (Extent.C vert, Extent.C horiz) =>
   Banded sub super vert horiz height width a -> height
height m = MatrixShape.bandedHeight (ArrMatrix.shape m)

-- | The width (column shape) of a banded matrix, read from its matrix shape
-- via 'MatrixShape.bandedWidth'.
width ::
   (Extent.C vert, Extent.C horiz) =>
   Banded sub super vert horiz height width a -> width
width m = MatrixShape.bandedWidth (ArrMatrix.shape m)



-- | Build a 'General' banded matrix from a list of elements.
--
-- The pair of proxies fixes the numbers of sub- and super-diagonals at the
-- type level. The storage 'Order', the height and the width are passed on
-- unchanged to 'Basic.fromList'.
--
-- NOTE(review): the expected layout and length of the element list are
-- defined by 'Basic.fromList'; confirm there.
fromList ::
   (Unary.Natural sub, Unary.Natural super,
    Shape.C height, Shape.C width, Storable a) =>
   (UnaryProxy sub, UnaryProxy super) -> Order -> height -> width -> [a] ->
   General sub super height width a
fromList offDiag order height_ width_ xs =
   ArrMatrix.lift0 $ Basic.fromList offDiag order height_ width_ xs

-- | Like 'fromList', but builds a 'Square' banded matrix.
--
-- A single @size@ shape serves as both height and width. The element list
-- is interpreted by 'Basic.squareFromList'.
squareFromList ::
   (Unary.Natural sub, Unary.Natural super, Shape.C size, Storable a) =>
   (UnaryProxy sub, UnaryProxy super) -> Order -> size -> [a] ->
   Square sub super size a
squareFromList offDiag order size xs =
   ArrMatrix.lift0 $ Basic.squareFromList offDiag order size xs

-- | Build a 'Lower' banded matrix from a list of elements.
--
-- Only the sub-diagonal count is given, via the proxy. The element list is
-- interpreted by 'Basic.lowerFromList'.
lowerFromList ::
   (Unary.Natural sub, Shape.C size, Storable a) =>
   UnaryProxy sub -> Order -> size -> [a] -> Lower sub size a
lowerFromList numOff order size xs =
   ArrMatrix.lift0 $ Basic.lowerFromList numOff order size xs

-- | Build an 'Upper' banded matrix from a list of elements.
--
-- Only the super-diagonal count is given, via the proxy. The element list
-- is interpreted by 'Basic.upperFromList'.
upperFromList ::
   (Unary.Natural super, Shape.C size, Storable a) =>
   UnaryProxy super -> Order -> size -> [a] -> Upper super size a
upperFromList numOff order size xs =
   ArrMatrix.lift0 $ Basic.upperFromList numOff order size xs

-- | Change the extent type of a banded matrix from @vertA/horizA@ to
-- @vertB/horizB@ using the given 'Extent.Map'.
--
-- Band counts, shapes and element type are unchanged, as the signature
-- shows. The type variables follow the module-wide @Banded sub super ...@
-- order. Unlike 'transpose', nothing is swapped here.
mapExtent ::
   (Extent.C vertA, Extent.C horizA) =>
   (Extent.C vertB, Extent.C horizB) =>
   Extent.Map vertA horizA vertB horizB height width ->
   Banded sub super vertA horizA height width a ->
   Banded sub super vertB horizB height width a
mapExtent = ArrMatrix.lift1 . Basic.mapExtent

-- | Transpose a banded matrix.
--
-- As the type shows, the sub- and super-diagonal counts trade places, as do
-- the vertical/horizontal extents and the height/width shapes. The
-- array-level work is done by 'Basic.transpose'.
transpose ::
   (Extent.C vert, Extent.C horiz) =>
   Banded sub super vert horiz height width a ->
   Banded super sub horiz vert width height a
transpose m = ArrMatrix.lift1 Basic.transpose m

-- | Adjoint of a banded matrix.
--
-- The result has the same swapped type structure as 'transpose'. The
-- element-level work is delegated to 'Basic.adjoint'. Unlike 'transpose',
-- it requires 'Class.Floating' elements (presumably for complex
-- conjugation; see 'Basic.adjoint').
adjoint ::
   (Unary.Natural super, Unary.Natural sub, Extent.C vert, Extent.C horiz,
    Shape.C width, Shape.C height, Class.Floating a) =>
   Banded sub super vert horiz height width a ->
   Banded super sub horiz vert width height a
adjoint m = ArrMatrix.lift1 Basic.adjoint m

-- | Build a 'Diagonal' banded matrix from a vector of diagonal elements,
-- using the given storage 'Order'.
diagonal ::
   (Shape.C sh, Class.Floating a) => Order -> Vector sh a -> Diagonal sh a
diagonal order v = ArrMatrix.lift0 (Basic.diagonal order v)

-- | Extract the diagonal of a 'Square' banded matrix as a vector indexed by
-- the matrix's size shape. This is done by 'Basic.takeDiagonal' on the
-- underlying array.
takeDiagonal ::
   (Unary.Natural sub, Unary.Natural super, Shape.C sh, Class.Floating a) =>
   Square sub super sh a -> Vector sh a
takeDiagonal m = Basic.takeDiagonal (ArrMatrix.toVector m)

-- | Multiply a banded matrix by a vector.
--
-- The vector is indexed by the matrix's width shape and the result by its
-- height shape. The @Eq width@ constraint suggests 'Basic.multiplyVector'
-- compares the shapes at runtime (NOTE(review): confirm there).
multiplyVector ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width, Eq width,
    Class.Floating a) =>
   Banded sub super vert horiz height width a ->
   Vector width a -> Vector height a
multiplyVector m = Basic.multiplyVector (ArrMatrix.toVector m)

-- | Multiply two banded matrices.
--
-- Band widths add up, as the type equalities state. The product has
-- @subA :+: subB@ sub-diagonals and @superA :+: superB@ super-diagonals.
-- Both factors share the extent types @vert@/@horiz@ and are joined along
-- the @fuse@ shape. @Eq fuse@ presumably lets 'Basic.multiply' check that
-- the inner shapes agree.
multiply ::
   (Unary.Natural subA, Unary.Natural superA,
    Unary.Natural subB, Unary.Natural superB,
    (subA :+: subB) ~ subC,
    (superA :+: superB) ~ superC,
    Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   Banded subA superA vert horiz height fuse a ->
   Banded subB superB vert horiz fuse width a ->
   Banded subC superC vert horiz height width a
multiply x y = ArrMatrix.lift2 Basic.multiply x y

-- | Multiply a banded matrix by a 'Full' matrix, giving a 'Full' result.
--
-- The two operands are joined along the @fuse@ shape and share the extent
-- types @vert@/@horiz@. The array-level work is done by
-- 'Basic.multiplyFull'.
multiplyFull ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   Banded sub super vert horiz height fuse a ->
   Full vert horiz fuse width a -> Full vert horiz height width a
multiplyFull banded full = ArrMatrix.lift2 Basic.multiplyFull banded full

-- | Convert a 'Lower' banded matrix into a lower triangular matrix
-- ('Tri.Lower') of the same size shape, via 'Basic.toLowerTriangular'.
toLowerTriangular ::
   (Unary.Natural sub, Shape.C sh, Class.Floating a) =>
   Lower sub sh a -> Tri.Lower sh a
toLowerTriangular m = ArrMatrix.lift1 Basic.toLowerTriangular m

-- | Convert an 'Upper' banded matrix into an upper triangular matrix
-- ('Tri.Upper') of the same size shape, via 'Basic.toUpperTriangular'.
toUpperTriangular ::
   (Unary.Natural super, Shape.C sh, Class.Floating a) =>
   Upper super sh a -> Tri.Upper sh a
toUpperTriangular m = ArrMatrix.lift1 Basic.toUpperTriangular m

-- | Convert a banded matrix into a 'Full' matrix with the same extent types
-- and the same height/width shapes, via 'Basic.toFull'.
toFull ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Banded sub super vert horiz height width a ->
   Full vert horiz height width a
toFull m = ArrMatrix.lift1 Basic.toFull m



-- | Solve a linear system whose coefficient matrix is a 'Square' banded
-- matrix.
--
-- The second argument is a 'Full' matrix with @nrhs@ columns, presumably
-- the right-hand sides. The result has the same shape as that argument.
-- The computation is delegated to 'Linear.solve'.
solve ::
   (Unary.Natural sub, Unary.Natural super, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Square sub super sh a ->
   Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solve coeffs rhs = ArrMatrix.lift2 Linear.solve coeffs rhs

-- | Determinant of a 'Square' banded matrix.
--
-- The computation is delegated to 'Linear.determinant', which operates on
-- the matrix's underlying array.
determinant ::
   (Unary.Natural sub, Unary.Natural super, Shape.C sh, Class.Floating a) =>
   Square sub super sh a -> a
determinant m = Linear.determinant (ArrMatrix.toVector m)