{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ConstraintKinds #-}
module Test.Utility where

import qualified Numeric.LAPACK.Matrix.Hermitian as Herm
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Orthogonal.Householder as HH
import qualified Numeric.LAPACK.Orthogonal as Ortho
import Numeric.LAPACK.Matrix.Array (ArrayMatrix)
import Numeric.LAPACK.Matrix.Shape.Omni (Omni)
import Numeric.LAPACK.Matrix.Layout (Order(RowMajor,ColumnMajor))
import Numeric.LAPACK.Matrix (Matrix, ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, absolute)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import Data.Array.Comfort.Shape ((::+))

import qualified Control.Monad.Trans.State as MS
import Control.Monad (replicateM)
import Control.Applicative (Applicative, liftA2, pure, (<*>), (<$>))

import qualified Data.List.HT as ListHT
import qualified Data.Complex as Complex
import Data.Complex (Complex((:+)))
import Data.List (foldl')
import Data.Traversable (traverse)
import Data.Monoid (Monoid(mempty,mappend))
import Data.Semigroup (Semigroup((<>)))
import Data.Tuple.HT (mapFst)
import Data.Eq.HT (equating)

import qualified Test.QuickCheck as QC
import Test.ChasingBottoms.IsBottom (isBottom)


-- | Compare two lists element by element with a custom equality.
--
-- The result is 'True' exactly when both lists have the same length and
-- every pair of corresponding elements satisfies @eq@. Comparison stops at
-- the first failing pair or at the first length difference.
equalListWith :: (a -> a -> Bool) -> [a] -> [a] -> Bool
equalListWith eq = go
   where
      go (x:xs) (y:ys) = eq x y && go xs ys
      go [] [] = True
      go _ _ = False


-- | Exact equality of the element lists of two arrays; shapes are not
-- compared here.
--
-- The same comparison is supplied once per concrete element type via
-- 'Class.switchFloating'. NOTE(review): presumably this is because
-- 'Class.Floating' alone does not give 'Eq' on the element type.
equalVectorBody ::
   (Shape.C shape, Class.Floating a) =>
   Array shape a -> Array shape a -> Bool
equalVectorBody =
   getEqualArray $
   Class.switchFloating
      (EqualArray $ \x y -> Array.toList x == Array.toList y)
      (EqualArray $ \x y -> Array.toList x == Array.toList y)
      (EqualArray $ \x y -> Array.toList x == Array.toList y)
      (EqualArray $ \x y -> Array.toList x == Array.toList y)

-- | Wraps a binary equality test so that 'Class.switchFloating' can select
-- one per element type.
newtype EqualArray f a = EqualArray {getEqualArray :: f a -> f a -> Bool}

-- | Exact element-wise equality of two arrays.
--
-- If the shapes differ, this calls 'error' instead of returning 'False'.
-- The message names this function so that a failure can be traced back
-- to the right helper.
equalVector ::
   (Shape.C shape, Eq shape, Class.Floating a) =>
   Array shape a -> Array shape a -> Bool
equalVector x y =
   if Array.shape x == Array.shape y
     then equalVectorBody x y
     else error "equalVector: shapes mismatch"

-- | Exact equality of two array-backed matrices of the same type.
--
-- Compares the underlying arrays, including their shapes, via
-- 'equalVector'.
equalArray ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Eq height, Eq width, Class.Floating a) =>
   ArrayMatrix pack property lower upper meas vert horiz height width a ->
   ArrayMatrix pack property lower upper meas vert horiz height width a -> Bool
equalArray x y = equalVector rawX rawY
   where
      rawX = ArrMatrix.unwrap x
      rawY = ArrMatrix.unwrap y

-- | Exact equality of two matrices that may use different storage orders.
--
-- Before comparing, @x@ is passed through 'Matrix.adaptOrder' with @y@ as
-- the first argument, and the result is compared against @y@ with
-- 'equalArray'. NOTE(review): presumably this brings @x@ into the layout
-- of @y@ so the raw arrays are comparable; confirm.
equalMatrix ::
   (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Eq height, Eq width, Class.Floating a) =>
   ArrayMatrix pack property lower upper meas vert horiz height width a ->
   ArrayMatrix pack property lower upper meas vert horiz height width a -> Bool
equalMatrix x y =
   let x' = Matrix.adaptOrder y x
   in equalArray x' y


-- | Approximate equality of two (real or complex) scalars: the absolute
-- value of their difference must not exceed @tol@.
approx ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) => ar -> a -> a -> Bool
approx tol x y = tol >= absolute (x - y)

-- | Approximate equality of two real scalars with absolute tolerance @tol@.
approxReal :: (Class.Real a) => a -> a -> a -> Bool
approxReal tol x y = distance <= tol
   where distance = abs (x - y)


-- | Element-wise approximate equality of two arrays with absolute
-- tolerance @tol@ (per element, as in 'approx').
--
-- If the shapes differ, this calls 'error' instead of returning 'False'.
-- The message names this function so that a failure points to the right
-- helper.
approxVectorTol ::
   (Shape.C shape, Eq shape, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar -> Array shape a -> Array shape a -> Bool
approxVectorTol tol x y =
   if Array.shape x == Array.shape y
     then and $ zipWith (approx tol) (Array.toList x) (Array.toList y)
     else error "approxVectorTol: shapes mismatch"

-- | 'approxVectorTol' with the module's default absolute tolerance of 1e-5.
approxVector ::
   (Shape.C shape, Eq shape, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Array shape a -> Array shape a -> Bool
approxVector x y = approxVectorTol 1e-5 x y

-- | Element-wise approximate equality of two real-valued arrays with
-- absolute tolerance @tol@ (per element, as in 'approxReal').
--
-- If the shapes differ, this calls 'error' instead of returning 'False'.
-- The message names this function rather than a non-existent
-- @approxRealArray@.
approxRealVectorTol ::
   (Shape.C shape, Eq shape, Class.Real a) =>
   a -> Array shape a -> Array shape a -> Bool
approxRealVectorTol tol x y =
   if Array.shape x == Array.shape y
     then and $ zipWith (approxReal tol) (Array.toList x) (Array.toList y)
     else error "approxRealVectorTol: shapes mismatch"


-- | Approximate equality of two array-backed matrices with absolute
-- tolerance @tol@, comparing the underlying arrays via 'approxVectorTol'.
approxArrayTol ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Eq height, Eq width) =>
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar ->
   ArrayMatrix pack property lower upper meas vert horiz height width a ->
   ArrayMatrix pack property lower upper meas vert horiz height width a -> Bool
approxArrayTol tol x y = approxVectorTol tol xs ys
   where
      xs = ArrMatrix.unwrap x
      ys = ArrMatrix.unwrap y

-- | Approximate equality of two array-backed matrices using the default
-- tolerance of 'approxVector'.
approxArray ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Eq height, Eq width) =>
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ArrayMatrix pack property lower upper meas vert horiz height width a ->
   ArrayMatrix pack property lower upper meas vert horiz height width a -> Bool
approxArray x y = ArrMatrix.unwrap x `approxVector` ArrMatrix.unwrap y


-- | Approximate equality (absolute tolerance @tol@) of two matrices that may
-- use different storage orders.
--
-- @y@ is passed through 'Matrix.adaptOrder' with @x@ as the first argument
-- before the arrays are compared. NOTE(review): presumably this converts
-- @y@ to the layout of @x@; confirm.
approxMatrix ::
   (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Eq height, Eq width) =>
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar ->
   ArrayMatrix pack property lower upper meas vert horiz height width a ->
   ArrayMatrix pack property lower upper meas vert horiz height width a -> Bool
approxMatrix tol x y =
   let y' = Matrix.adaptOrder x y
   in approxArrayTol tol x y'


-- | Apply 'Matrix.conjugate' for 'HH.Conjugated' and leave the matrix
-- unchanged for 'HH.NonConjugated'.
maybeConjugate ::
   (Matrix.Complex typ) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   HH.Conjugation ->
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xl xu lower upper meas vert horiz height width a
maybeConjugate conj =
   case conj of
      HH.NonConjugated -> id
      HH.Conjugated -> Matrix.conjugate


-- | Shape formed by appending a 'ShapeInt' to the unit shape @()@.
-- NOTE(review): presumably the @()@ part contributes exactly one element,
-- which makes this shape non-empty; confirm against comfort-array.
type NonEmptyInt = ()::+ShapeInt
-- | Either the unit value or an 'Int'.
-- NOTE(review): this looks like the index type of 'NonEmptyInt'; confirm.
type EInt = Either () Int


-- | Random integral value from the inclusive range @[-n, n]@, converted to
-- the target real type.
genReal :: (Class.Real a) => Integer -> QC.Gen a
genReal n = fmap fromInteger (QC.choose (negate n, n))

-- | Random complex number whose real and imaginary parts are each drawn
-- with 'genReal' using the bound @n@.
genComplex :: (Class.Real a) => Integer -> QC.Gen (Complex a)
genComplex n = liftA2 (:+) part part
   where part = genReal n

-- | Random element of any supported floating type with integral components
-- bounded by @n@: real types use 'genReal', complex types 'genComplex'.
genElement :: (Class.Floating a) => Integer -> QC.Gen a
genElement n =
   Class.switchFloating
      (genReal n)     -- Float
      (genReal n)     -- Double
      (genComplex n)  -- Complex Float
      (genComplex n)  -- Complex Double

-- | Random array of the given shape.
--
-- @Shape.size shape@ elements are drawn independently with
-- @genElement maxElem@.
genVector ::
   (Shape.C shape, Class.Floating a) =>
   Integer -> shape -> QC.Gen (Array shape a)
genVector maxElem shape =
   fmap (Array.fromList shape) $
   replicateM (Shape.size shape) $ genElement maxElem

-- | Random array-backed matrix of the given shape, with elements drawn by
-- 'genVector'.
genArray ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Integer ->
   Omni pack property lower upper meas vert horiz height width ->
   QC.Gen (ArrayMatrix pack property lower upper meas vert horiz height width a)
genArray maxElem = fmap ArrMatrix.Array . genVector maxElem

-- | Random array whose element at each index is produced by the
-- index-dependent generator @gen@. Generators run in the order of
-- 'Shape.indices'.
genArrayIndexed ::
   (Shape.Indexed shape, Class.Floating a) =>
   shape -> (Shape.Index shape -> QC.Gen a) -> QC.Gen (Array shape a)
genArrayIndexed shape gen =
   fmap (Array.fromList shape) $ mapM gen $ Shape.indices shape

-- | Random array-backed matrix whose diagonal entries come from a dedicated
-- generator.
--
-- Indices are enumerated on the plain shape obtained by @toPlainShape@.
-- Entries with row index equal to column index are produced by @diag@;
-- all other entries come from @genElement maxElem@. The resulting array is
-- then reshaped back to @shape@.
genArrayExtraDiag_ ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width) =>
   (Shape.Indexed shape, Shape.Index shape ~ (i,i), Eq i, Class.Floating a) =>
   Integer ->
   (Omni pack property lower upper meas vert horiz height width -> shape) ->
   Omni pack property lower upper meas vert horiz height width ->
   (i -> QC.Gen a) ->
   QC.Gen (ArrayMatrix pack property lower upper meas vert horiz height width a)
genArrayExtraDiag_ maxElem toPlainShape shape diag =
   ArrMatrix.Array . Array.reshape shape <$>
      genArrayIndexed (toPlainShape shape) entry
   where
      entry (r, c)
         | r == c = diag r
         | otherwise = genElement maxElem

-- | Random quadratic matrix whose diagonal entries come from @diag@.
--
-- 'MatrixShape.switchDiagUpLo' selects among three cases. In the first case
-- only the diagonal is generated, indexed by 'MatrixShape.squareSize', and
-- every entry comes from @diag@. The other two cases delegate to
-- 'genArrayExtraDiag_' on the plain shape, so off-diagonal entries are drawn
-- by 'genElement' with bound @maxElem@.
genArrayExtraDiag ::
   (MatrixShape.Packing pack) =>
   (MatrixShape.DiagUpLo lo up, MatrixShape.TriDiag diag) =>
   (Shape.C sh, Shape.Indexed sh, Shape.Index sh ~ i, Eq i) =>
   (Class.Floating a) =>
   Integer ->
   MatrixShape.Quadratic pack diag lo up sh ->
   (i -> QC.Gen a) ->
   QC.Gen (ArrMatrix.Quadratic pack diag lo up sh a)
genArrayExtraDiag maxElem shape0 diag =
   runGenTriangularLoUp
      (MatrixShape.switchDiagUpLo
         (GenTriangularLoUp $ \shape ->
            ArrMatrix.Array . Array.reshape shape <$>
            genArrayIndexed (MatrixShape.squareSize shape) diag)
         (GenTriangularLoUp $ \shape ->
            genArrayExtraDiag_ maxElem Omni.toPlain shape diag)
         (GenTriangularLoUp $ \shape ->
            genArrayExtraDiag_ maxElem Omni.toPlain shape diag))
      shape0

-- | Generator for a quadratic matrix, with the strip types @lo@ and @up@
-- as the last type parameters so that 'MatrixShape.switchDiagUpLo' can
-- dispatch on them.
newtype GenTriangularLoUp pack diag sh a lo up =
   GenTriangularLoUp {
      runGenTriangularLoUp ::
         MatrixShape.Quadratic pack diag lo up sh ->
         QC.Gen (ArrMatrix.Quadratic pack diag lo up sh a)
   }


-- | Pick one list element at random and return it together with the
-- remaining elements, as enumerated by 'ListHT.removeEach'.
-- NOTE(review): fails through 'QC.elements' when the list is empty.
select :: [a] -> QC.Gen (a, [a])
select xs = QC.elements (ListHT.removeEach xs)

-- | Random vector of size @n@ whose elements are drawn without replacement
-- from a pool of candidates.
--
-- For single-precision types the pool is built from @elemsS@, for
-- double-precision types from @elemsD@. Complex pools are all combinations
-- of the respective list for the real and imaginary parts. Elements are
-- distinct if the pool has no duplicates. If @n@ exceeds the pool size,
-- 'select' is eventually called on an empty list and fails.
genDistinct ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   [Integer] -> [Integer] -> ShapeInt -> QC.Gen (Vector ShapeInt a)
genDistinct elemsS elemsD size@(Shape.ZeroBased n) =
   fmap (Vector.fromList size) $
   MS.evalStateT (replicateM n (MS.StateT select)) candidates
   where
      range ks = map fromInteger ks
      candidates =
         Class.switchFloating
            (range elemsS) (range elemsD)
            (liftA2 (:+) (range elemsS) (range elemsS))
            (liftA2 (:+) (range elemsD) (range elemsD))


-- | Choose one of the two storage orders uniformly at random.
genOrder :: QC.Gen Order
genOrder = QC.elements orders
   where orders = [RowMajor, ColumnMajor]



-- | Accept a quadratic matrix if the absolute value of its determinant
-- exceeds 0.1.
invertible ::
   (Matrix.Determinant typ,
    Matrix.DeterminantExtra typ xl, Matrix.DeterminantExtra typ xu,
    MatrixShape.Strip lower, MatrixShape.Strip upper,
    Shape.C sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix.Quadratic typ xl xu lower upper sh a -> Bool
invertible m = 0.1 < absolute (Matrix.determinant m)

-- | Accept a tall matrix if 'Ortho.determinantAbsolute' yields a value
-- above 0.1.
fullRankTall ::
   (Shape.C height, Shape.C width,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix.Tall height width a -> Bool
fullRankTall m = 0.1 < Ortho.determinantAbsolute m


-- | Check that a quadratic matrix equals the identity matrix of the same
-- shape, as built by 'Matrix.identityFrom', within absolute tolerance @tol@.
isIdentity ::
   (Omni.Quadratic pack property lower upper,
    Omni.Quadratic pack property upper lower,
    Shape.C sh, Eq sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar ->
   ArrMatrix.Quadratic pack property lower upper sh a ->
   Bool
isIdentity tol eye =
   let reference = Matrix.identityFrom eye
   in approxArrayTol tol eye reference

-- | Check that the Gramian of a full matrix, computed by 'Herm.gramian' and
-- converted with 'ArrMatrix.asPacked', is the identity within tolerance
-- @tol@. NOTE(review): presumably the Gramian is A^H·A, so this tests the
-- columns for orthonormality; confirm.
isUnitary ::
   (Extent.Measure meas, Extent.C vert,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar -> Matrix.Full meas vert Extent.Small ShapeInt ShapeInt a -> Bool
isUnitary tol q =
   isIdentity tol $ ArrMatrix.asPacked $ Herm.gramian $ Matrix.fromFull q


-- | Sum a list of matrices, starting from the zero matrix of shape @sh@.
--
-- An empty list yields that zero matrix. Each step combines the running sum
-- with the next matrix via @ArrMatrix.liftOmni2 Vector.add@.
-- 'foldl'' forces each intermediate sum to weak head normal form instead of
-- building a chain of unevaluated additions, as the lazy 'foldl' did.
addMatrices ::
   (ArrMatrix.Homogeneous property) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Omni pack property lower upper meas vert horiz height width ->
   [ArrayMatrix pack property lower upper meas vert horiz height width a] ->
   ArrayMatrix pack property lower upper meas vert horiz height width a
addMatrices sh = foldl' (ArrMatrix.liftOmni2 Vector.add) (ArrMatrix.zero sh)



-- | Place two general matrices with equal heights side by side, using
-- 'Matrix.beside' with 'Matrix.leftBias' and 'Extent.appendAny'.
-- NOTE(review): 'Matrix.leftBias' presumably takes the storage order from
-- the left operand; confirm.
infixl 3 !|||
(!|||) ::
   (Shape.C height, Eq height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Matrix.General height widthA a ->
   Matrix.General height widthB a ->
   Matrix.General height (widthA::+widthB) a
a !||| b = Matrix.beside Matrix.leftBias Extent.appendAny a b

-- | Stack two general matrices with equal widths on top of each other,
-- using 'Matrix.above' with 'Matrix.leftBias' and 'Extent.appendAny'.
infixl 2 !===
(!===) ::
   (Shape.C width, Eq width, Shape.C heightA, Shape.C heightB,
    Class.Floating a) =>
   Matrix.General heightA width a ->
   Matrix.General heightB width a ->
   Matrix.General (heightA::+heightB) width a
a !=== b = Matrix.above Matrix.leftBias Extent.appendAny a b



-- | A value of type @a@ labelled with the phantom type @tag@; no value of
-- type @tag@ is stored.
newtype Tagged tag a = Tagged a deriving Show

-- | A QuickCheck generator labelled with a phantom tag.
type TaggedGen tag a = Tagged tag (QC.Gen a)

instance Functor (Tagged tag) where
   fmap f (Tagged x) = Tagged $ f x

instance Applicative (Tagged tag) where
   pure x = Tagged x
   Tagged f <*> tx = fmap f tx



-- | Run 'QC.forAll' with a tagged generator, keeping the tag on the
-- resulting property.
checkForAllPlain ::
   (Show a, QC.Testable test) =>
   TaggedGen tag a -> (a -> test) -> Tagged tag QC.Property
checkForAllPlain (Tagged gen) = Tagged . QC.forAll gen

-- | Like 'checkForAllPlain', but the generator also reports whether the
-- generated dimensions match.
--
-- On 'Match' the test itself must hold. On 'Mismatch' the property checks
-- with 'isBottom' that evaluating the test result is bottom, i.e. that the
-- tested code aborted.
checkForAll ::
   (Show a, QC.Testable test) =>
   TaggedGen tag (a, Match) -> (a -> test) -> Tagged tag QC.Property
checkForAll taggedGen test =
   checkForAllPlain taggedGen check
   where
      check (a, Match) = QC.property (test a)
      check (a, Mismatch) = QC.property (isBottom (test a))

{- |
In @DontForceMatch@ mode the test generators
need not generate matching dimensions.
If the dimensions actually mismatch, a @Mismatch@ value is returned.
In this case the test driver asserts that
the test routine aborts with an error.
However, a typical test has the form
\"generic implementation = specialized implementation\".
If the generic implementation correctly checks the sizes,
then the tester cannot detect a missing check in the specialized implementation.
So far, the proposed way to avoid this problem
is to add a test that relies solely on the function under test.
If you have no better idea, compare the implementation with itself.
-}
data Match = Mismatch | Match
   deriving (Eq, Show)

-- | The combination is 'Match' only when both operands are 'Match'.
-- The right operand is not inspected when the left one is 'Mismatch'.
instance Semigroup Match where
   Mismatch <> _ = Mismatch
   Match <> other = other

-- | 'Match' is the neutral element of the combination above.
instance Monoid Match where
   mempty = Match
   mappend = (<>)



-- | Prepend @msg@ followed by a dot to the name of every named test.
prefix :: String -> [(String, test)] -> [(String, test)]
prefix msg = map $ mapFst ((msg ++ ".") ++)

-- | Append a dot followed by @msg@ to the name of every named test.
suffix :: String -> [(String, test)] -> [(String, test)]
suffix msg = map $ mapFst (++ ("." ++ msg))