{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Test.Generator where

import qualified Test.Utility as Util
import Test.Utility (Match(Match,Mismatch))

import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Hermitian (Hermitian)
import Numeric.LAPACK.Matrix (ZeroInt, zeroInt)
import Numeric.LAPACK.Scalar (RealOf, fromReal, one)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import qualified Control.Monad.Trans.RWS as MRWS
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Functor.HT as FuncHT
import Control.Applicative (liftA2, (<$>))

import Data.Traversable (for)
import Data.Tuple.HT (mapFst, mapSnd, mapPair, swap)

import qualified Test.QuickCheck as QC



{- |
@Cons generator@ wraps a generator of type @required -> ExtGen (array, actual)@.
The @required@ argument may fix some dimensions of the array,
e.g. its height or width;
the generator returns the array together with its actual dimensions.
The maximum element magnitude @maxElem@, the maximum dimension @maxDim@
and the 'MatchMode' are read from the 'ExtGen' environment.
Non-fixed dimensions will be chosen arbitrarily from the range @(0,maxDim)@.
Elements are chosen from the range @(-maxElem,maxElem)@.

@tag@ is a phantom type parameter;
the generators in this module instantiate it with their scalar type.
-}
newtype T tag required actual array = Cons (required -> ExtGen (array,actual))

-- | Map a pure function over the generated array;
-- the requirement and the reported dimensions are passed through unchanged.
instance Functor (T tag required actual) where
   fmap f (Cons gen) = Cons (fmap (mapFst f) . gen)

{- |
Generator monad used by all 'T' generators.
The reader environment is @(maxElem, maxDim, matchMode)@,
the writer accumulates the 'Match' results recorded by 'matchDim',
and the state is unused.
-}
type ExtGen = MRWS.RWST (Integer,Int,MatchMode) Match () QC.Gen

-- | Controls 'matchDim': with 'ForceMatch' the given dimension is reproduced,
-- with 'DontForceMatch' a fresh, possibly different, dimension is chosen.
data MatchMode = DontForceMatch | ForceMatch
   deriving (Eq, Show)

-- | Requirement types that have a value meaning \"nothing is fixed\".
class Required required where
   nothingRequired :: required

instance Required () where
   nothingRequired = ()

-- | 'Nothing' leaves the corresponding dimension unconstrained.
instance Required (Maybe a) where
   nothingRequired = Nothing

-- | Neither component of a pair of requirements fixes anything.
instance (Required a, Required b) => Required (a,b) where
   nothingRequired = (nothingRequired, nothingRequired)

{- |
Run a generator with no fixed dimensions.
The 'MatchMode' is chosen at random,
so both branches of 'matchDim' are exercised.
The result contains the generated array and the accumulated 'Match';
the actual dimensions are dropped.
-}
run ::
   (Required required) =>
   T tag required actual array -> Integer -> Int ->
   Util.TaggedGen tag (array, Match)
run (Cons gen) maxElem maxDim =
   Util.Tagged $
   QC.elements [DontForceMatch, ForceMatch] >>= \mode ->
   fmap (\((array, _actualDims), match) -> (array, match)) $
   MRWS.evalRWST (gen nothingRequired) (maxElem, maxDim, mode) ()

{- |
Pair every array produced by @genB@ with an independently generated
extra value from @genA@.
@checkForAll@ receives the combined generator and the uncurried two-argument test.
-}
withExtra ::
   (T tag required actual (a,b) -> ((a,b) -> c) -> io) ->
   QC.Gen a -> T tag required actual b -> (a -> b -> c) -> io
withExtra checkForAll genA genB test =
   checkForAll (mapGen pairWithExtra genB) (uncurry test)
   where
      pairWithExtra _maxElem b = fmap (\a -> (a, b)) genA


{- |
Post-process the generated array with a 'QC.Gen' action
that may depend on the maximum element magnitude @maxElem@.
The dimensions are passed through unchanged.
This is 'mapGenDim' with the maximum dimension ignored.
-}
mapGen ::
   (Integer -> a -> QC.Gen b) ->
   T tag required actual a -> T tag required actual b
mapGen f = mapGenDim (\maxElem _maxDim -> f maxElem)

{- |
Post-process the generated array with a 'QC.Gen' action
that may depend on both @maxElem@ and @maxDim@ from the environment.
The dimensions are passed through unchanged.
-}
mapGenDim ::
   (Integer -> Int -> a -> QC.Gen b) ->
   T tag required actual a -> T tag required actual b
mapGenDim f (Cons gen) =
   Cons $ \fixed -> do
      (maxElem, maxDim, _match) <- MRWS.ask
      generated <- gen fixed
      MT.lift $ FuncHT.mapFst (f maxElem maxDim) generated


-- | Choose a dimension with 'QC.choose' from the range @(k, maxDim)@,
-- where @maxDim@ is taken from the environment.
chooseDimMin :: Int -> ExtGen Int
chooseDimMin k =
   MRWS.ask >>= \(_maxElem, maxDim, _match) ->
   MT.lift $ QC.choose (k, maxDim)


-- | Dimension types that can be chosen at random within the environment's bounds.
class Dim dim where
   chooseDim :: ExtGen dim

-- | A plain dimension from the range @(0, maxDim)@.
instance Dim Int where
   chooseDim = chooseDimMin 0

-- | Both summands are chosen independently, the left one first.
instance (Dim dimA, Dim dimB) => Dim (dimA:+:dimB) where
   chooseDim = (:+:) <$> chooseDim <*> chooseDim


{- |
Choose a dimension for an operand that is supposed to agree with @size@.
With 'ForceMatch', @size@ is returned unchanged and nothing is recorded.
With 'DontForceMatch', a fresh dimension is chosen, and either 'Match' or
'Mismatch' is recorded, depending on whether it equals @size@.
-}
matchDim :: (Dim i, Eq i) => i -> ExtGen i
matchDim size = do
   (_maxElem, _maxDim, mode) <- MRWS.ask
   if mode == ForceMatch
      then return size
      else do
         candidate <- chooseDim
         MRWS.tell $ if candidate == size then Match else Mismatch
         return candidate


-- | Generator without dimensions: nothing can be required and nothing is reported.
type Scalar tag = T tag () ()

-- | Generate a single element, using @maxElem@ from the environment.
scalar :: (Class.Floating a) => Scalar a a
scalar =
   Cons $ \ _fixed -> do
      (maxElem, _maxDim, _match) <- MRWS.ask
      element <- MT.lift $ Util.genElement maxElem
      return (element, ())

{- |
Combine two vector generators into a scalar generator.
The first vector gets an arbitrary size and the second one always gets
exactly that size; unlike the other combinators, this one never
calls 'matchDim'.
-}
(<.*.>) ::
   Vector tag size (a -> b) ->
   Vector tag size a ->
   Scalar tag b
(<.*.>) (Cons genA) (Cons genB) =
   Cons $ \() ->
      genA Nothing >>= \(f, size) ->
      fmap (\(a, _) -> (f a, ())) (genB (Just size))


-- | Vector generator: the size may be fixed via 'Just', and the actual size is reported.
type Vector tag size = T tag (Maybe size) size

-- | Generate only the shape of a vector, using the required size if one is given.
vectorDim :: (Class.Floating a) => Vector a Int ZeroInt
vectorDim =
   Cons $ \ fixed -> do
      size <-
         case fixed of
            Nothing -> chooseDim
            Just requested -> return requested
      return (zeroInt size, size)

-- | Generate a vector with random elements.
vector :: (Class.Floating a) => Vector a Int (Vector.Vector ZeroInt a)
vector = flip mapGen vectorDim $ \maxElem shape -> Util.genArray maxElem shape

-- | Generate a vector with random real elements of type @RealOf a@,
-- while still being tagged with @a@.
vectorReal ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector a Int (Vector.Vector ZeroInt ar)
vectorReal =
   flip mapGen vectorDim $ \maxElem shape -> Util.genArray maxElem shape

{- |
Combine a vector generator with a matrix generator,
shaped like a vector-matrix product.
The matrix is generated first, using a required width;
the vector then gets a size chosen by 'matchDim' from the matrix height.
The result has the width of the matrix.
-}
(<.*|>) ::
   (Dim height, Eq height) =>
   Vector tag height (a -> b) ->
   Matrix tag height width a ->
   Vector tag width b
(<.*|>) (Cons genA) (Cons genB) =
   Cons $ \fixed -> do
      (a, (height, width)) <- genB (fmap Right fixed)
      vectorSize <- matchDim height
      (f, _) <- genA (Just vectorSize)
      return (f a, width)

{- |
Combine a matrix generator with a vector generator,
shaped like a matrix-vector product.
The matrix is generated first, using a required height;
the vector then gets a size chosen by 'matchDim' from the matrix width.
The result has the height of the matrix.
-}
(<|*.>) ::
   (Dim width, Eq width) =>
   Matrix tag height width (a -> b) ->
   Vector tag width a ->
   Vector tag height b
(<|*.>) (Cons genA) (Cons genB) =
   Cons $ \fixed -> do
      (f, (height, width)) <- genA (fmap Left fixed)
      vectorSize <- matchDim width
      (a, _) <- genB (Just vectorSize)
      return (f a, height)

{- |
Combine two vector generators whose sizes are supposed to agree.
The first vector uses a required size;
the size of the second one is chosen by 'matchDim'.
The result reports the size of the first vector.
-}
(<.=.>) ::
   (Dim size, Eq size) =>
   Vector tag size (a -> b) ->
   Vector tag size a ->
   Vector tag size b
(<.=.>) (Cons genA) (Cons genB) =
   Cons $ \fixed -> do
      (f, size) <- genA fixed
      otherSize <- matchDim size
      (a, _) <- genB (Just otherSize)
      return (f a, size)


{- |
Matrix generator: either the height ('Left') or the width ('Right')
may be fixed, and the actual @(height,width)@ is reported.
-}
type Matrix tag height width =
      T tag (Maybe (Either height width)) (height,width)

-- | Generate the shape of a general matrix, using a fixed height or width if given.
matrixDims ::
   (Class.Floating a) => Matrix a Int Int (ZeroInt, ZeroInt)
matrixDims =
   Cons $ \ fixed -> do
      let fixedHeight = either Just (const Nothing) =<< fixed
          fixedWidth = either (const Nothing) Just =<< fixed
          orChoose = maybe chooseDim return
      dims <- liftA2 (,) (orChoose fixedHeight) (orChoose fixedWidth)
      return (mapPair (zeroInt,zeroInt) dims, dims)

-- | Generate a general matrix with random storage order and random elements.
matrix ::
   (Class.Floating a) => Matrix a Int Int (Matrix.General ZeroInt ZeroInt a)
matrix =
   flip mapGen matrixDims $ \maxElem (height, width) ->
      Util.genOrder >>= \order ->
      Util.genArray maxElem (MatrixShape.general order height width)


-- | Generate the shape of a square matrix;
-- a fixed height or a fixed width determines its size.
squareDim :: (Class.Floating a) => Matrix a Int Int ZeroInt
squareDim =
   Cons $ \ fixed -> do
      size <- maybe chooseDim (return . either id id) fixed
      return (zeroInt size, (size,size))

{- |
Generate an array whose shape is built by the given constructor
from a random storage order and a square size,
with random elements.
-}
squareShaped ::
   (Shape.C sh, Class.Floating a) =>
   (MatrixShape.Order -> ZeroInt -> sh) -> Matrix a Int Int (Array sh a)
squareShaped shape =
   flip mapGen squareDim $ \maxElem size ->
      Util.genOrder >>= \order -> Util.genArray maxElem (shape order size)

-- | Generate a square matrix with random storage order and random elements.
square :: (Class.Floating a) => Matrix a Int Int (Square.Square ZeroInt a)
square = squareShaped $ \order size -> MatrixShape.square order size

{- |
Generate a square matrix satisfying @cond@.
The storage order is chosen once;
only the elements are regenerated until @cond@ holds.
-}
squareCond ::
   (Class.Floating a) =>
   (Square.Square ZeroInt a -> Bool) ->
   Matrix a Int Int (Square.Square ZeroInt a)
squareCond cond =
   flip mapGen squareDim $ \maxElem size ->
      Util.genOrder >>= \order ->
      QC.suchThat (Util.genArray maxElem (MatrixShape.square order size)) cond

-- | Generate a square matrix accepted by 'Util.invertible'.
invertible ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix a Int Int (Square.Square ZeroInt a)
invertible = squareCond $ \m -> Util.invertible m

-- | Generate a diagonal matrix with random storage order and random elements.
diagonal ::
   (Class.Floating a) => Matrix a Int Int (Triangular.Diagonal ZeroInt a)
diagonal = squareShaped $ \order size -> MatrixShape.diagonal order size

-- | The identity matrix of a random size and random storage order;
-- the maximum element magnitude is not used.
identity ::
   (MatrixShape.Content lo, MatrixShape.Content up, Class.Floating a) =>
   Matrix a Int Int (Triangular.Triangular lo MatrixShape.Unit up ZeroInt a)
identity =
   flip mapGen squareDim $ \ _maxElem size ->
      (\order -> Triangular.identity order size) <$> Util.genOrder

{- |
Generate a triangular matrix satisfying @cond@.
The storage order is chosen once;
only the elements are regenerated until @cond@ holds.
-}
triangularCond ::
   (MatrixShape.Content up, MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Class.Floating a) =>
   (Triangular.Triangular lo diag up ZeroInt a -> Bool) ->
   Matrix a Int Int (Triangular.Triangular lo diag up ZeroInt a)
triangularCond cond =
   flip mapGen squareDim $ \maxElem size ->
      Util.genOrder >>= \order ->
      QC.suchThat
         (genTriangularArray maxElem $
            MatrixShape.Triangular
               MatrixShape.autoDiag MatrixShape.autoUplo order size)
         cond

-- | Generate an unconstrained triangular matrix.
triangular ::
   (MatrixShape.Content up, MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Class.Floating a) =>
   Matrix a Int Int (Triangular.Triangular lo diag up ZeroInt a)
triangular = triangularCond $ \_ -> True


{- |
Wrapper around a generator of triangular arrays for a given shape.
The diagonal type @diag@ is the last type parameter, so the wrapper can be
passed to 'MatrixShape.switchTriDiag' (see 'genTriangularArray').
-}
newtype GenTriangularDiag lo up a diag =
   GenTriangularDiag {
      runGenTriangularDiag ::
         MatrixShape.Triangular lo diag up ZeroInt ->
         QC.Gen (Triangular.Triangular lo diag up ZeroInt a)
   }

{- |
Generate a triangular array of the given shape.
In the first alternative passed to 'MatrixShape.switchTriDiag'
(the unit-diagonal case, judging by the 'one' entries)
diagonal entries are set to 'one' and only off-diagonal entries are random;
in the second alternative all stored entries are random.
-}
genTriangularArray ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Class.Floating a) =>
   Integer ->
   MatrixShape.Triangular lo diag up ZeroInt ->
   QC.Gen (Triangular.Triangular lo diag up ZeroInt a)
genTriangularArray maxElem =
   runGenTriangularDiag $
   MatrixShape.switchTriDiag
      (GenTriangularDiag $ \shape ->
         fmap (Array.fromList shape) $
         traverse
            (\(r,c) -> if r==c then return one else Util.genElement maxElem)
            (Shape.indices shape))
      (GenTriangularDiag $ \shape -> Util.genArray maxElem shape)


{- |
Generate the shape of a tall matrix, that is, height >= width.
A fixed height bounds the width from above,
and a fixed width bounds the height from below.
-}
tallDims :: (Class.Floating a) => Matrix a Int Int (ZeroInt, ZeroInt)
tallDims =
   Cons $ \ fixed -> do
      let withWidthUpTo :: Int -> ExtGen (Int, Int)
          withWidthUpTo h = (,) h <$> MT.lift (QC.choose (0,h))
      dims <-
         case fixed of
            Nothing -> withWidthUpTo =<< chooseDim
            Just (Left h) -> withWidthUpTo h
            Just (Right w) -> flip (,) w <$> chooseDimMin w
      return (mapPair (zeroInt,zeroInt) dims, dims)

-- | Generate a tall matrix with random storage order and random elements.
tall ::
   (Class.Floating a) =>
   Matrix a Int Int (Matrix.Tall ZeroInt ZeroInt a)
tall =
   flip mapGen tallDims $ \maxElem (height, width) ->
      Util.genOrder >>= \order ->
      Util.genArray maxElem (MatrixShape.tall order height width)

{- |
Generate a tall matrix accepted by 'Util.fullRankTall'.
The storage order is chosen once;
only the elements are regenerated until the predicate holds.
-}
fullRankTall ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix a Int Int (Matrix.Tall ZeroInt ZeroInt a)
fullRankTall =
   flip mapGen tallDims $ \maxElem (height, width) ->
      Util.genOrder >>= \order ->
      QC.suchThat
         (Util.genArray maxElem (MatrixShape.tall order height width))
         Util.fullRankTall


{- |
Generate the shape of a wide matrix, that is, height <= width.
A fixed height bounds the width from below,
and a fixed width bounds the height from above.
-}
wideDims :: (Class.Floating a) => Matrix a Int Int (ZeroInt, ZeroInt)
wideDims =
   Cons $ \ fixed -> do
      let withHeightUpTo :: Int -> ExtGen (Int, Int)
          withHeightUpTo w = flip (,) w <$> MT.lift (QC.choose (0,w))
      dims <-
         case fixed of
            Nothing -> withHeightUpTo =<< chooseDim
            Just (Left h) -> (,) h <$> chooseDimMin h
            Just (Right w) -> withHeightUpTo w
      return (mapPair (zeroInt,zeroInt) dims, dims)

-- | Generate a wide matrix with random storage order and random elements.
wide ::
   (Class.Floating a) =>
   Matrix a Int Int (Matrix.Wide ZeroInt ZeroInt a)
wide =
   flip mapGen wideDims $ \maxElem (height, width) ->
      Util.genOrder >>= \order ->
      Util.genArray maxElem (MatrixShape.wide order height width)

{- |
Generate a wide matrix by transposing a tall matrix
(with swapped dimensions) accepted by 'Util.fullRankTall'.
The storage order is chosen once;
only the elements are regenerated until the predicate holds.
-}
fullRankWide ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix a Int Int (Matrix.Wide ZeroInt ZeroInt a)
fullRankWide =
   flip mapGen wideDims $ \maxElem (height, width) ->
      Util.genOrder >>= \order ->
      Matrix.transpose <$>
      QC.suchThat
         (Util.genArray maxElem (MatrixShape.tall order width height))
         Util.fullRankTall


-- | Generate an unconstrained Hermitian matrix.
hermitian ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix a Int Int (Hermitian ZeroInt a)
hermitian = hermitianCond $ \_ -> True

{- |
Generate a Hermitian matrix satisfying @cond@.
Diagonal entries are generated as reals and converted with 'fromReal';
other stored entries are arbitrary elements.
The storage order is chosen once;
only the entries are regenerated until @cond@ holds.
-}
hermitianCond ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Hermitian ZeroInt a -> Bool) ->
   Matrix a Int Int (Hermitian ZeroInt a)
hermitianCond cond =
   flip mapGen squareDim $ \maxElem size -> do
      order <- Util.genOrder
      let shape = MatrixShape.hermitian order size
      QC.suchThat
         (fmap (Array.fromList shape) $
          traverse
            (\(r,c) ->
               if r==c
                  then fromReal <$> Util.genReal maxElem
                  else Util.genElement maxElem)
            (Shape.indices shape))
         cond


{-
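There cannot be a pure/point function:
a generator always has to report actual dimensions,
even when no dimension is required.
Hence the operators below combine generators in the style of <*>,
but T has no Applicative instance.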
-}
{- |
Combine two matrix generators that share the dimension @fuse@,
shaped like a matrix-matrix product.
If the width is fixed, the right operand is generated first;
otherwise the left operand is generated first.
The shared dimension of the operand generated second is chosen via 'matchDim'.
-}
(<|*|>) ::
   (Dim fuse, Eq fuse) =>
   Matrix tag height fuse (a -> b) ->
   Matrix tag fuse width a ->
   Matrix tag height width b
(<|*|>) (Cons genA) (Cons genB) =
   Cons $ \fixed ->
      case fixed of
         Just (Right width) -> do
            (a, (fuse, _)) <- genB $ Just $ Right width
            fuseA <- matchDim fuse
            (f, (height, _)) <- genA $ Just $ Right fuseA
            return (f a, (height, width))
         Just (Left height) -> do
            (f, (_, fuse)) <- genA $ Just $ Left height
            (a, width) <- rightAfter fuse
            return (f a, (height, width))
         Nothing -> do
            (f, (height, fuse)) <- genA Nothing
            (a, width) <- rightAfter fuse
            return (f a, (height, width))
   where
      -- generate the right operand with its height matched against @fuse@
      rightAfter fuse = do
         fuseB <- matchDim fuse
         (a, (_, width)) <- genB $ Just $ Left fuseB
         return (a, width)

{- |
Swap the roles of height and width in the dimension bookkeeping.
A fixed height becomes a fixed width of the wrapped generator and vice versa,
and the reported dimensions are swapped.
The generated array itself is passed through unchanged.
-}
transpose ::
   Matrix tag height width a ->
   Matrix tag width height a
transpose (Cons gen) =
   Cons $ \fixed -> mapSnd swap <$> gen (either Right Left <$> fixed)

-- | Like '<|*|>', but the shared dimension is the height of the left operand.
-- The type variable name @nrhs@ suggests a matrix of right-hand sides.
(<|\|>) ::
   (Dim height, Eq height) =>
   Matrix tag height width (a -> b) ->
   Matrix tag height nrhs a ->
   Matrix tag width nrhs b
(<|\|>) = (<|*|>) . transpose

{- |
Combine two vector generators into a matrix generator,
shaped like an outer product.
A fixed height goes to the left generator and a fixed width to the right one.
No dimension is shared, so 'matchDim' is not involved.
-}
(<***>) ::
   Vector tag height (a -> b) ->
   Vector tag width a ->
   Matrix tag height width b
(<***>) (Cons genA) (Cons genB) =
   Cons $ \fixed -> do
      let (fixedHeight, fixedWidth) =
            case fixed of
               Nothing -> (Nothing, Nothing)
               Just (Left height) -> (Just height, Nothing)
               Just (Right width) -> (Nothing, Just width)
      (f, height) <- genA fixedHeight
      (a, width) <- genB fixedWidth
      return (f a, (height, width))


{-
We need this type because the test stackRowsColumnsCommutative
requires fixing both the height and the width of the bottom-right matrix.

Conversely, we cannot use this type for, e.g., Square matrices,
because Square does not allow independent choice of height and width.
-}
-- | Matrix generator where height and width can be fixed independently,
-- each via 'Just'; the actual @(height,width)@ is reported.
type Matrix2 tag height width =
      T tag (Maybe height, Maybe width) (height,width)

-- | Generate the shape of a general matrix; each dimension is either the
-- required one or chosen at random, the height first.
matrix2Dims :: (Class.Floating a) => Matrix2 a Int Int (ZeroInt, ZeroInt)
matrix2Dims =
   Cons $ \ (fixedHeight,fixedWidth) -> do
      height <- maybe chooseDim return fixedHeight
      width <- maybe chooseDim return fixedWidth
      return ((zeroInt height, zeroInt width), (height, width))

-- | Like 'matrix', but for generators whose height and width can be
-- fixed independently.
matrix2 ::
   (Class.Floating a) => Matrix2 a Int Int (Matrix.General ZeroInt ZeroInt a)
matrix2 =
   flip mapGen matrix2Dims $ \maxElem (height, width) ->
      Util.genOrder >>= \order ->
      Util.genArray maxElem (MatrixShape.general order height width)

{- |
Stack two matrix generators vertically.
A fixed height is split between the operands and a fixed width goes to the
top operand; the width of the bottom operand is chosen via 'matchDim'.
The result reports the width of the top operand.
-}
(<===>) ::
   (Dim width, Eq width) =>
   Matrix2 tag heightA width (a -> b) ->
   Matrix2 tag heightB width a ->
   Matrix2 tag (heightA:+:heightB) width b
(<===>) (Cons genA) (Cons genB) =
   Cons $ \(fixedHeight,fixedWidth) -> do
      let topPart (top:+:_) = top
          bottomPart (_:+:bottom) = bottom
      (f, (heightA, width)) <- genA (topPart <$> fixedHeight, fixedWidth)
      widthB <- matchDim width
      (a, (heightB, _)) <- genB (bottomPart <$> fixedHeight, Just widthB)
      return (f a, (heightA:+:heightB, width))

{- |
Place two matrix generators side by side.
A fixed width is split between the operands and a fixed height goes to the
left operand; the height of the right operand is chosen via 'matchDim'.
The result reports the height of the left operand.
-}
(<|||>) ::
   (Dim height, Eq height) =>
   Matrix2 tag height widthA (a -> b) ->
   Matrix2 tag height widthB a ->
   Matrix2 tag height (widthA:+:widthB) b
(<|||>) (Cons genA) (Cons genB) =
   Cons $ \(fixedHeight,fixedWidth) -> do
      let leftPart (left:+:_) = left
          rightPart (_:+:right) = right
      (f, (height, widthA)) <- genA (fixedHeight, leftPart <$> fixedWidth)
      heightB <- matchDim height
      (a, (_, widthB)) <- genB (Just heightB, rightPart <$> fixedWidth)
      return (f a, (height, widthA:+:widthB))


-- All combinators are left-associative with precedence 4,
-- so a chain like  x <|*|> y <|*|> z  parses as  (x <|*|> y) <|*|> z.
infixl 4 <.*.>, <.*|>, <|*.>, <|*|>, <|\|>, <***>, <.=.>, <===>, <|||>