{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE EmptyDataDecls #-}
module Numeric.LAPACK.Matrix.Type where

import qualified Numeric.LAPACK.Matrix.Array.Format as ArrFormat
import qualified Numeric.LAPACK.Output as Output
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Scalar as Scalar
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import Numeric.LAPACK.Output (Output)
import Numeric.LAPACK.Scalar (RealOf, ComplexOf)

import qualified Numeric.Netlib.Class as Class

import qualified Hyper

import qualified Control.DeepSeq as DeepSeq

import qualified Data.Array.Comfort.Shape as Shape

import Data.Semigroup (Semigroup, (<>))



-- | Matrices indexed by a type-level tag @typ@, which selects the
-- representation, and by the element type @a@. Each tag provides its own
-- data or newtype instance.
data family Matrix typ a


-- | Tag for a matrix described by a shape and a single scalar. It is
-- formatted as a diagonal repeating that scalar, so presumably it denotes
-- the scalar times the identity of the given shape.
data Scale shape
data instance Matrix (Scale shape) a = Scale shape a

-- | Tag marking the inverse of a matrix of type @typ@.
-- The constructor only wraps the original matrix; no inversion is
-- computed here.
data Inverse typ
newtype instance Matrix (Inverse typ) a = Inverse (Matrix typ a)


-- | A permutation viewed as a matrix. Only the permutation itself is
-- stored, so the element type @a@ is a phantom parameter.
newtype instance Matrix (Perm.Permutation sh) a =
   Permutation (Perm.Permutation sh)
      deriving (Show)



-- | Deep evaluation for every matrix whose tag implements this module's
-- own 'NFData' class.
--
-- The left-hand @rnf@ is the 'DeepSeq.NFData' method being defined. The
-- right-hand @rnf@ is the unqualified method of the local 'NFData' class
-- ("Control.DeepSeq" is only imported qualified). This equation therefore
-- delegates to the per-tag implementation; it is not self-recursion.
instance (NFData typ, DeepSeq.NFData a) => DeepSeq.NFData (Matrix typ a) where
   rnf = rnf

-- | Per-tag deep evaluation. Each matrix tag defines how to force a
-- @'Matrix' typ a@ to normal form, given that its elements can be forced.
-- The 'DeepSeq.NFData' instance for 'Matrix' delegates to this method.
class NFData typ where
   rnf :: (DeepSeq.NFData a) => Matrix typ a -> ()



-- | Display a matrix via "Hyper", formatting its elements with the
-- default array format specification 'ArrFormat.deflt'.
instance
   (FormatMatrix typ, Class.Floating a) =>
      Hyper.Display (Matrix typ a) where
   display m = Output.hyper $ formatMatrix ArrFormat.deflt m


-- | Matrix types that can be rendered to any 'Output' target.
class FormatMatrix typ where
   {-
   We use the constraint @(Class.Floating a)@ rather than @(Format a)@
   because it lets us align the components of complex numbers.
   -}
   -- | Render a matrix. The 'String' is the element format specification,
   -- e.g. 'ArrFormat.deflt' as used by the "Hyper" display instance.
   formatMatrix ::
      (Class.Floating a, Output out) => String -> Matrix typ a -> out

-- | A 'Scale' matrix is rendered in row-major order as a diagonal whose
-- entries all equal the stored scalar; there is one entry per element of
-- the shape.
instance (Shape.C sh) => FormatMatrix (Scale sh) where
   formatMatrix spec (Scale shp x) =
      let diagonal = replicate (Shape.size shp) x
      in  ArrFormat.formatDiagonal spec MatrixShape.RowMajor shp diagonal

-- | Permutations are rendered by the permutation module's own formatter.
-- The element format specification is ignored.
instance (Shape.C sh) => FormatMatrix (Perm.Permutation sh) where
   formatMatrix _ m = case m of Permutation perm -> Perm.format perm



-- | The semigroup operation is the product given by 'multiplySame'.
instance (MultiplySame typ, Class.Floating a) => Semigroup (Matrix typ a) where
   x <> y = multiplySame x y

-- | Matrix types closed under multiplication: multiplying two matrices
-- of type @typ@ yields a matrix of the same type @typ@.
class MultiplySame typ where
   multiplySame ::
      (Class.Floating a) => Matrix typ a -> Matrix typ a -> Matrix typ a

-- | Two 'Scale' matrices multiply to a 'Scale' matrix whose scalar is the
-- product of both scalars. The shapes must be equal; otherwise
-- 'scaleWithCheck' raises an error.
instance (Eq shape) => MultiplySame (Scale shape) where
   multiplySame x y =
      scaleWithCheck "Scale.multiplySame" height combine x y
      where combine s (Scale shp t) = Scale shp (s*t)

-- | Since @x^-1 * y^-1 = (y * x)^-1@, the wrapped matrices are multiplied
-- in swapped order and the result is wrapped again.
instance (MultiplySame typ) => MultiplySame (Inverse typ) where
   multiplySame (Inverse x) (Inverse y) = Inverse (multiplySame y x)

-- | Product of two permutations, delegated to 'Perm.multiply' with the
-- operands passed in reverse order.
-- NOTE(review): the swap presumably adapts Perm.multiply's argument
-- convention to that of '<>' — confirm in Permutation.Private.
instance (Shape.C sh, Eq sh) => MultiplySame (Perm.Permutation sh) where
   multiplySame (Permutation p) (Permutation q) =
      Permutation (Perm.multiply q p)


-- | Combine the scalar of a 'Scale' matrix with a second operand.
-- First checks that the 'Scale' shape equals the shape extracted from
-- the second operand by the given accessor.
--
-- On mismatch, calls 'error' with the given name followed by
-- @": dimensions mismatch"@.
scaleWithCheck :: (Eq shape) =>
   String -> (b -> shape) ->
   (a -> b -> c) -> Matrix (Scale shape) a -> b -> c
scaleWithCheck opName sizeOf op (Scale shp x) y
   | shp == sizeOf y = op x y
   | otherwise = error $ opName ++ ": dimensions mismatch"


-- | Matrix types with a height and a width, each described by a shape.
class Box typ where
   -- | Shape type of the matrix height (rows).
   type HeightOf typ
   -- | Shape type of the matrix width (columns).
   type WidthOf typ
   -- | Height shape of a matrix.
   height :: Matrix typ a -> HeightOf typ
   -- | Width shape of a matrix.
   width :: Matrix typ a -> WidthOf typ

-- | A 'Scale' matrix is square: height and width are both the stored
-- shape.
instance Box (Scale sh) where
   type HeightOf (Scale sh) = sh
   type WidthOf (Scale sh) = sh
   height (Scale shp _) = shp
   width m = height m

-- | An 'Inverse' reports the height and width of the wrapped matrix
-- unchanged.
-- NOTE(review): for a non-square @typ@ the inverse would have height and
-- width swapped. This instance presumably relies on 'Inverse' only
-- wrapping square matrix types — confirm.
instance (Box typ) => Box (Inverse typ) where
   type HeightOf (Inverse typ) = HeightOf typ
   type WidthOf (Inverse typ) = WidthOf typ
   height = \(Inverse inner) -> height inner
   width = \(Inverse inner) -> width inner

-- | A permutation matrix is square: both dimensions are given by
-- 'Perm.size'.
instance Box (Perm.Permutation sh) where
   type HeightOf (Perm.Permutation sh) = sh
   type WidthOf (Perm.Permutation sh) = sh
   height (Permutation p) = Perm.size p
   width m = height m

-- | All (height index, width index) pairs of a matrix, in the order
-- produced by 'Shape.indices' on the pair shape @(height m, width m)@.
indices ::
   (Box typ,
    HeightOf typ ~ height, Shape.Indexed height,
    WidthOf typ ~ width, Shape.Indexed width) =>
   Matrix typ a -> [(Shape.Index height, Shape.Index width)]
indices m = Shape.indices dims
   where dims = (height m, width m)


-- | Matrix types supporting conjugation and conversion between real and
-- complex element types.
class Complex typ where
   -- | Complex conjugation of the matrix.
   conjugate :: (Class.Floating a) => Matrix typ a -> Matrix typ a
   -- | Convert a matrix with real elements to element type @a@.
   fromReal :: (Class.Floating a) => Matrix typ (RealOf a) -> Matrix typ a
   -- | Convert a matrix to the complex counterpart of its element type.
   toComplex :: (Class.Floating a) => Matrix typ a -> Matrix typ (ComplexOf a)

-- | Each operation is applied to the wrapped matrix, and the result is
-- wrapped again as an 'Inverse'.
instance (Complex typ) => Complex (Inverse typ) where
   conjugate (Inverse inner) = Inverse (conjugate inner)
   fromReal (Inverse inner) = Inverse (fromReal inner)
   toComplex (Inverse inner) = Inverse (toComplex inner)

-- | Only the stored scalar is converted; the shape is kept as is.
instance (Shape.C shape) => Complex (Scale shape) where
   conjugate (Scale shp x) = Scale shp (Scalar.conjugate x)
   fromReal (Scale shp x) = Scale shp (Scalar.fromReal x)
   toComplex (Scale shp x) = Scale shp (Scalar.toComplex x)

-- | A permutation stores no elements of type @a@. All three operations
-- therefore only change the element type, and 'conjugate' returns its
-- argument.
instance (Shape.C shape) => Complex (Perm.Permutation shape) where
   conjugate m = m
   fromReal (Permutation perm) = Permutation perm
   toComplex (Permutation perm) = Permutation perm