{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE StandaloneDeriving #-}
module Numeric.LAPACK.Matrix.Type where

import qualified Numeric.LAPACK.Matrix.Plain.Format as ArrFormat
import qualified Numeric.LAPACK.Output as Output
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Extent.Strict as ExtentStrict
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Layout.Private (Empty, Filled)
import Numeric.LAPACK.Matrix.Extent.Private (Extent, Shape, Small)
import Numeric.LAPACK.Output (Output)

import qualified Numeric.Netlib.Class as Class

import qualified Hyper

import qualified Control.DeepSeq as DeepSeq

import qualified Data.Array.Comfort.Shape as Shape

import Data.Function.HT (Id)
import Data.Monoid (Monoid, mempty, mappend)
import Data.Semigroup (Semigroup, (<>))



-- | Data family underlying every matrix kind of this module.
--
-- @typ@ is a type tag selecting the representation ('Product', 'Scale',
-- 'Identity', 'Permutation', ...). @extraLower@ / @extraUpper@ carry
-- representation-specific type information (@()@ for the simple tags,
-- tuples for 'Product'). @lower@ / @upper@ are the strip types below and
-- above the diagonal (e.g. 'Empty' or 'Filled'). @meas@, @vert@ and @horiz@
-- are the extent parameters, @height@ / @width@ the row and column shapes,
-- and @a@ the element type.
data family
   Matrix typ extraLower extraUpper lower upper meas vert horiz height width a

-- | Square matrix: height and width share the shape @sh@, and the extent
-- is measured with 'Shape'.
type Quadratic typ xl xu lower upper sh =
      QuadraticMeas typ xl xu lower upper Shape sh sh

-- | Matrix with 'Small' vertical and horizontal extents and a generic
-- measure; height, width and element type remain to be applied.
type QuadraticMeas typ xl xu lower upper meas =
      Matrix typ xl xu lower upper meas Small Small


-- | Identity function whose only purpose is to fix the type of its
-- argument to 'QuadraticMeas' (handy as an inline type annotation).
asQuadratic ::
   Id (QuadraticMeas typ extraLower extraUpper lower upper meas height width a)
asQuadratic matrix = matrix


-- | Type tag for a matrix stored as the product of two factor matrices.
-- @fuse@ is the shape of the inner dimension shared by both factors.
data Product fuse
data instance
   Matrix (Product fuse) xl xu lower upper meas vert horiz height width a where
      -- Factors: A is (height x fuse), B is (fuse x width); both share the
      -- extent parameters meas/vert/horiz of the result.
      -- The result bands are 'Omni.MultipliedBands' of the factor bands.
      -- Requiring both argument orders to agree forces that combination
      -- to be symmetric. The factors' type tags, strips and extra
      -- parameters are kept in the xl / xu tuples of the result type.
      Product ::
         (Omni.MultipliedBands lowerA lowerB ~ lowerC,
          Omni.MultipliedBands lowerB lowerA ~ lowerC,
          Omni.MultipliedBands upperA upperB ~ upperC,
          Omni.MultipliedBands upperB upperA ~ upperC) =>
         Matrix typA xlA xuA lowerA upperA meas vert horiz height fuse a ->
         Matrix typB xlB xuB lowerB upperB meas vert horiz fuse width a ->
         Matrix (Product fuse)
            (typA,lowerA,lowerB,xlA,xlB) (typB,upperB,upperA,xuB,xuA)
            lowerC upperC meas vert horiz height width a


-- | Type tag for a scaled identity: a square matrix of shape @sh@ with
-- 'Empty' strips that stores a single scalar. The 'FormatMatrix' instance
-- renders it as that scalar repeated along the diagonal.
data Scale
data instance
   Matrix Scale xl xu lower upper meas vert horiz height width a where
      Scale :: sh -> a -> Quadratic Scale () () Empty Empty sh a

deriving instance
   (Shape.C height, Show height, Show a) =>
   Show (Matrix Scale xl xu lower upper meas vert horiz height width a)


-- | Type tag for an identity matrix that stores only its extent, with no
-- element data. The extent is 'Small' in both directions, but the measure
-- @meas@ stays generic, so @height@ and @width@ are separate parameters.
data Identity
data instance
   Matrix Identity xl xu lower upper meas vert horiz height width a where
      Identity ::
         (Extent.Measure meas) =>
         Extent meas Small Small height width ->
         QuadraticMeas Identity () () Empty Empty meas height width a


-- | Type tag for a permutation matrix represented by a 'Perm.Permutation'.
data Permutation
data instance
   Matrix Permutation xl xu lower upper meas vert horiz height width a where
   -- NOTE(review): lower and upper are unconstrained here, so a permutation
   -- can be typed with any strips (including Empty), whereas the
   -- StaticIdentity instance for permutations fixes both to Filled.
   -- Confirm whether the constructor should return
   -- Quadratic Permutation () () Filled Filled sh a.
   Permutation ::
      Perm.Permutation sh -> Quadratic Permutation () () lower upper sh a

-- Show and Eq are derived from the wrapped permutation.
deriving instance
   (Shape.C height, Show height) =>
   Show (Matrix Permutation xl xu lower upper meas vert horiz height width a)

deriving instance
   (Shape.C height, Eq height) =>
   Eq (Matrix Permutation xl xu lower upper meas vert horiz height width a)


-- | Deep evaluation of any matrix kind. The work is delegated to this
-- module's own 'NFData' class, so every type tag supplies its own strategy.
instance
   (NFData typ,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    DeepSeq.NFData height, DeepSeq.NFData width, DeepSeq.NFData a) =>
   DeepSeq.NFData (Matrix typ xl xu lower upper meas vert horiz height width a)
      where
   -- Control.DeepSeq is only imported qualified, so the unqualified rnf on
   -- the right-hand side is the method of the local NFData class.
   rnf matrix = rnf matrix

-- | Per-representation deep evaluation. The 'DeepSeq.NFData' instance for
-- 'Matrix' forwards to 'rnf' of this class for the matrix's type tag.
class NFData typ where
   rnf ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz,
       DeepSeq.NFData height, DeepSeq.NFData width, DeepSeq.NFData a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a -> ()



-- | Display any formattable matrix using the default number format.
instance
   (FormatMatrix typ, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C width, Shape.C height, Class.Floating a) =>
      Hyper.Display
         (Matrix typ xl xu lower upper meas vert horiz height width a) where
   display matrix = Output.hyper (formatMatrix ArrFormat.deflt matrix)


-- | Matrix kinds that can be rendered to any 'Output' target.
class FormatMatrix typ where
   {-
   We use constraint @(Class.Floating a)@ and not @(Format a)@
   because it allows us to align the components of complex numbers.
   -}
   -- | Render a matrix. The 'String' is a format specification handed to
   -- the element formatter (the 'Hyper.Display' instance passes
   -- 'ArrFormat.deflt'); instances may ignore it.
   formatMatrix ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C width, Shape.C height, Class.Floating a, Output out) =>
      String ->
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      out

-- | A scaled identity is printed as a diagonal matrix whose diagonal
-- entries all equal the stored scalar.
instance FormatMatrix Scale where
   formatMatrix fmt (Scale shape a) =
      let diagonal = replicate (Shape.size shape) a
      in  ArrFormat.formatDiagonal fmt Layout.RowMajor shape diagonal

-- | Permutations ignore the format string and use 'Perm.format'.
instance FormatMatrix Permutation where
   formatMatrix _ (Permutation p) = Perm.format p



-- | Matrix multiplication of square matrices of the same kind, delegated
-- to 'multiplySame'.
instance
   (MultiplySame typ xl xu,
    MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, height ~ width, Class.Floating a) =>
      Semigroup (Matrix typ xl xu lower upper meas vert horiz height width a)
         where
   x <> y = multiplySame x y

-- | Product of two square matrices (height and width both @sh@) of the same
-- kind and type. This class backs the 'Semigroup' instance of 'Matrix'.
class (Box typ) => MultiplySame typ xl xu where
   multiplySame ::
      (matrix ~ Matrix typ xl xu lower upper meas vert horiz sh sh a,
       MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C sh, Eq sh, Class.Floating a) =>
      matrix -> matrix -> matrix

-- | Multiplying two scaled identities multiplies their scalars. Shapes
-- are compared first by 'scaleWithCheck'.
instance (xl ~ (), xu ~ ()) => MultiplySame Scale xl xu where
   multiplySame =
      scaleWithCheck "Scale.multiplySame" height
         (\factor (Scale shape b) -> Scale shape (factor * b))

-- | Product of permutation matrices. Note that the right factor is passed
-- to 'Perm.multiply' first.
instance (xl ~ (), xu ~ ()) => MultiplySame Permutation xl xu where
   multiplySame (Permutation left) (Permutation right) =
      Permutation (Perm.multiply right left)


-- | For statically known square shapes, 'staticIdentity' is the neutral
-- element of matrix multiplication.
instance
   (MultiplySame typ xl xu, StaticIdentity typ xl xu lower upper,
    MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
    meas ~ Shape, vert ~ Small, horiz ~ Small,
    Shape.Static height, Eq height, height ~ width, Class.Floating a) =>
      Monoid (Matrix typ xl xu lower upper meas vert horiz height width a) where
   mempty = staticIdentity
   -- kept in canonical form to satisfy -Wnoncanonical-monoid-instances
   mappend = (<>)

-- | Identity matrix of a kind for a statically known shape
-- ('Shape.Static'). It serves as 'mempty' of the 'Monoid' instance.
class StaticIdentity typ xl xu lower upper where
   staticIdentity ::
      (Shape.Static sh, Class.Floating a) =>
      Quadratic typ xl xu lower upper sh a

-- | The identity as a scaled identity: scalar one on the static shape.
instance
   (xl ~ (), xu ~ (), lower ~ Empty, upper ~ Empty) =>
      StaticIdentity Scale xl xu lower upper where
   staticIdentity = Scale Shape.static 1

-- | The identity as a permutation: the identity permutation on the static
-- shape, typed with 'Filled' strips.
instance
   (xl ~ (), xu ~ (), lower ~ Filled, upper ~ Filled) =>
      StaticIdentity Permutation xl xu lower upper where
   staticIdentity = Permutation (Perm.identity Shape.static)


-- | Combine the scalar of a 'Scale' matrix with a second operand @b@.
--
-- The shape of the 'Scale' matrix must equal @getSize b@. If it does,
-- the result is @f scalar b@. Otherwise 'error' is called with the
-- message @name ++ ": dimensions mismatch"@.
scaleWithCheck :: (Eq shape) =>
   String -> (b -> shape) ->
   (a -> b -> c) ->
   Matrix Scale xl xu lower upper meas vert horiz shape shape a ->
   b -> c
scaleWithCheck name getSize f (Scale shape a) b
   | shape == getSize b = f a b
   | otherwise = error (name ++ ": dimensions mismatch")


-- | Matrix kinds whose dimensions can be queried. 'height' and 'width'
-- default to the corresponding components of 'extent'; instances may
-- override them with more direct definitions.
class Box typ where
   extent ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Extent.Extent meas vert horiz height width
   height ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      Matrix typ xl xu lower upper meas vert horiz height width a -> height
   height matrix = Extent.height (extent matrix)
   width ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      Matrix typ xl xu lower upper meas vert horiz height width a -> width
   width matrix = Extent.width (extent matrix)

-- Scaled identities are square over their stored shape.
instance Box Scale where
   extent (Scale sh _) = Extent.square sh
   height (Scale sh _) = sh
   width (Scale sh _) = sh

-- Identity matrices store their extent directly.
instance Box Identity where
   extent (Identity ext) = ext

-- Permutation matrices are square over the permutation's size.
instance Box Permutation where
   extent (Permutation p) = Extent.square (Perm.size p)
   height (Permutation p) = Perm.size p
   width (Permutation p) = Perm.size p

{- ToDo: requires parameters xl and xu for Box class

instance (Eq fuse) => Box (Product fuse) where
   extent (Product a b) =
      fromMaybe (error "Matrix.Product: shapes mismatch") $
      Extent.fuse (extent a) (extent b)
-}

-- | Shape shared by the rows and columns of a square matrix.
squareSize :: (Box typ) => Quadratic typ xl xu lower upper sh a -> sh
squareSize matrix = height matrix

-- | All (row, column) index pairs of a matrix, as enumerated by
-- 'Shape.indices' on the pair of height and width shapes.
indices ::
   (Box typ, Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.Indexed height, Shape.Indexed width) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   [(Shape.Index height, Shape.Index width)]
indices matrix =
   let rowsAndColumns = (height matrix, width matrix)
   in  Shape.indices rowsAndColumns


-- | Turn a matrix with 'Small' extents (see 'QuadraticMeas') into a
-- 'Quadratic' matrix over either its height shape or its width shape.
class (Box typ) => ToQuadratic typ where
   heightToQuadratic ::
      (Extent.Measure meas) =>
      QuadraticMeas typ xl xu lower upper meas height width a ->
      Quadratic typ xl xu lower upper height a
   widthToQuadratic ::
      (Extent.Measure meas) =>
      QuadraticMeas typ xl xu lower upper meas height width a ->
      Quadratic typ xl xu lower upper width a

-- Scaled identities are already square; just rebuild the constructor.
instance ToQuadratic Scale where
   heightToQuadratic (Scale sh a) = Scale sh a
   widthToQuadratic (Scale sh a) = Scale sh a

-- Identity matrices get a square extent over the chosen side.
instance ToQuadratic Identity where
   heightToQuadratic (Identity ext) =
      Identity (Extent.square (Extent.height ext))
   widthToQuadratic (Identity ext) =
      Identity (Extent.square (Extent.width ext))

-- Permutations are already square; just rebuild the constructor.
instance ToQuadratic Permutation where
   heightToQuadratic (Permutation p) = Permutation p
   widthToQuadratic (Permutation p) = Permutation p


-- | Change the extent parameters (measure, vertical and horizontal kind) of
-- a matrix through an 'ExtentStrict.Map'. Height and width shapes, strips,
-- extra parameters and element type stay the same.
class (Box typ) => MapExtent typ xl xu lower upper where
   mapExtent ::
      (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
      (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
      ExtentStrict.Map measA vertA horizA measB vertB horizB height width ->
      Matrix typ xl xu lower upper measA vertA horizA height width a ->
      Matrix typ xl xu lower upper measB vertB horizB height width a


-- | Matrix transposition. At the type level this swaps height with width,
-- the vertical with the horizontal extent kind, the lower with the upper
-- strip, and @xl@ with @xu@.
class (Box typ) => Transpose typ where
   transpose ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C width, Shape.C height, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xu xl upper lower meas horiz vert width height a

-- A scaled identity is symmetric.
instance Transpose Scale where
   transpose (Scale sh a) = Scale sh a

-- Transposing an identity matrix only transposes its extent.
instance Transpose Identity where
   transpose (Identity ext) = Identity (Extent.transpose ext)

-- Transposing a permutation matrix transposes the permutation.
instance Transpose Permutation where
   transpose (Permutation p) = Permutation (Perm.transpose p)

{-
instance (Shape.C fuse, Eq fuse) => Transpose (Product fuse) where
   transpose (Product a b) = Product (transpose b) (transpose a)
-}


-- | Derive a multiplication with the matrix on the left from one that takes
-- it (transposed) on the right, using @(B^T A^T)^T = A B@.
-- @multiplyTrans@ receives the other operand and the transposed @a@; its
-- result is transposed back.
swapMultiply ::
   (Transpose typA, Transpose typB) =>
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
   (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
   (Shape.C heightA, Shape.C widthA) =>
   (Shape.C heightB, Shape.C widthB) =>
   (Class.Floating a) =>
   (matrix ->
    Matrix typA xuA xlA upperA lowerA measA horizA vertA widthA heightA a ->
    Matrix typB xuB xlB upperB lowerB measB horizB vertB widthB heightB a) ->
   Matrix typA xlA xuA lowerA upperA measA vertA horizA heightA widthA a ->
   matrix ->
   Matrix typB xlB xuB lowerB upperB measB vertB horizB heightB widthB a
swapMultiply multiplyTrans a b =
   transpose (multiplyTrans b (transpose a))

-- | Singleton witnesses for the lower and upper power strips of a matrix.
-- Only the argument's type is used; the value is never inspected.
powerStrips ::
   (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   (MatrixShape.PowerStripSingleton lower,
    MatrixShape.PowerStripSingleton upper)
powerStrips _matrix =
   let lowerStrip = MatrixShape.powerStripSingleton
       upperStrip = MatrixShape.powerStripSingleton
   in  (lowerStrip, upperStrip)

-- | Singleton witnesses for the lower and upper strips of a matrix.
-- Only the argument's type is used; the value is never inspected.
strips ::
   (MatrixShape.Strip lower, MatrixShape.Strip upper) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   (MatrixShape.StripSingleton lower, MatrixShape.StripSingleton upper)
strips _matrix =
   let lowerStrip = MatrixShape.stripSingleton
       upperStrip = MatrixShape.stripSingleton
   in  (lowerStrip, upperStrip)