{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Multiply where

import qualified Numeric.LAPACK.Matrix.Array.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Array.Unpacked as Unpacked
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Class as MatrixClass
import qualified Numeric.LAPACK.Matrix.Type as Matrix
import qualified Numeric.LAPACK.Matrix.Basic as FullBasic
import qualified Numeric.LAPACK.Matrix.Modifier as Mod
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array.Private (Full)
import Numeric.LAPACK.Matrix.Type (Matrix, scaleWithCheck)
import Numeric.LAPACK.Matrix.Modifier (Transposition(NonTransposed,Transposed))
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape

import qualified Data.Stream as Stream
import Data.Stream (Stream)
import Data.Maybe (fromMaybe)



-- Vector products bind like ordinary multiplication; the matrix operand
-- sits on the side of the '#'.
infixl 7 -*#
infixr 7 #*|

-- | Matrix times column vector, @m #*| x@.  The vector is indexed by the
-- matrix width and the result by the matrix height; see 'matrixVector'.
(#*|) ::
   (MultiplyVector typ xl xu, Omni.Strip lower, Omni.Strip upper,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Vector width a -> Vector height a
m #*| x = matrixVector m x

-- | Row vector times matrix, @x -*# m@.  The vector is indexed by the
-- matrix height and the result by the matrix width; see 'vectorMatrix'.
(-*#) ::
   (MultiplyVector typ xl xu, Omni.Strip lower, Omni.Strip upper,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
   Vector height a ->
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Vector width a
x -*# m = vectorMatrix x m


-- | Products of a matrix of box type @typ@ with a vector, from either side.
class (Matrix.Box typ) => MultiplyVector typ xl xu where
   -- | Matrix times column vector.  The vector is indexed by the matrix
   -- width; the result is indexed by the matrix height.
   matrixVector ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width) =>
      (Eq width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Vector width a -> Vector height a
   -- | Row vector times matrix.  The vector is indexed by the matrix
   -- height; the result is indexed by the matrix width.
   vectorMatrix ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width) =>
      (Eq height, Class.Floating a) =>
      Vector height a ->
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Vector width a

-- A scaling matrix multiplies the vector by its scalar via 'Vector.scale'.
-- 'Array.shape' is handed to 'scaleWithCheck' (presumably so that it can
-- compare the vector shape with the matrix size).
instance (xl ~ (), xu ~ ()) => MultiplyVector Matrix.Scale xl xu where
   matrixVector m@(Matrix.Scale _ _) x =
      scaleWithCheck "Matrix.Multiply.matrixVector Scale"
         Array.shape Vector.scale m x
   vectorMatrix x m@(Matrix.Scale _ _) =
      scaleWithCheck "Matrix.Multiply.vectorMatrix Scale"
         Array.shape Vector.scale m x

-- A permutation matrix reorders the vector.  Multiplying the vector from
-- the left uses the inverted permutation (the transpose of a permutation
-- matrix is its inverse).
instance (xl ~ (), xu ~ ()) => MultiplyVector Matrix.Permutation xl xu where
   matrixVector m@(Matrix.Permutation _) x =
      PermMatrix.multiplyVector Mod.NonInverted m x
   vectorMatrix x m@(Matrix.Permutation _) =
      PermMatrix.multiplyVector Mod.Inverted m x

-- Array-backed matrices delegate both products to
-- "Numeric.LAPACK.Matrix.Array.Multiply".
instance
   (Layout.Packing pack, Omni.Property property, xl ~ (), xu ~ ()) =>
      MultiplyVector (ArrMatrix.Array pack property) xl xu where
   vectorMatrix x m = Multiply.vectorMatrix x m
   matrixVector m x = Multiply.matrixVector m x



-- | Multiplication of a square matrix of box type @typ@ with a full matrix.
--
-- Minimal definition: either 'transposableSquare', or both 'squareFull'
-- and 'fullSquare'.  The defaults derive each method from the others via
-- the identities @A^T B = (B^T A)^T@ and @B A = (A^T B^T)^T@.
class (Matrix.Box typ) => MultiplySquare typ xl xu where
   {-# MINIMAL transposableSquare | fullSquare,squareFull #-}
   -- | Square matrix, optionally transposed, times a full matrix.
   transposableSquare ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
      Transposition ->
      Matrix.Quadratic typ xl xu lower upper height a ->
      Full meas vert horiz height width a ->
      Full meas vert horiz height width a
   transposableSquare trans sq full =
      case trans of
         NonTransposed -> squareFull sq full
         Transposed ->
            -- A^T B is obtained as (B^T A)^T
            Unpacked.transpose (fullSquare (Unpacked.transpose full) sq)

   -- | Square matrix times a full matrix.
   squareFull ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper height a ->
      Full meas vert horiz height width a ->
      Full meas vert horiz height width a
   squareFull sq full = transposableSquare NonTransposed sq full

   -- | Full matrix times a square matrix.
   fullSquare ::
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
      Full meas vert horiz height width a ->
      Matrix.Quadratic typ xl xu lower upper width a ->
      Full meas vert horiz height width a
   fullSquare full sq =
      -- B A is obtained as (A^T B^T)^T
      Unpacked.transpose
         (transposableSquare Transposed sq (Unpacked.transpose full))


-- | Shorthand for 'Unpacked.Unpacked' with the property fixed to
-- 'MatrixShape.Arbitrary'.
type Unpacked lower upper meas vert horiz height width =
   Unpacked.Unpacked MatrixShape.Arbitrary
      lower upper meas vert horiz height width

-- Matrix-matrix products share the precedence of ordinary multiplication.
infixl 7 ##*#, #*#
infixr 7 #*##

-- | A 'Matrix.QuadraticMeas' matrix times an unpacked matrix.
--
-- The left factor is split into an identity carrying its extent and a
-- quadratic matrix over the shared dimension @fuse@.  The quadratic part
-- is applied with 'squareFull', and the product is re-labelled with the
-- original height by 'reshapeHeight'.
(#*##) ::
   (MultiplySquare typ xl xu, Matrix.ToQuadratic typ,
    Omni.Strip lowerA, Omni.Strip upperA,
    Omni.Strip lowerB, Omni.Strip upperB,
    Omni.Strip lowerC, Omni.Strip upperC) =>
   (Omni.MultipliedBands lowerA lowerB ~ lowerC,
    Omni.MultipliedBands lowerB lowerA ~ lowerC,
    Omni.MultipliedBands upperA upperB ~ upperC,
    Omni.MultipliedBands upperB upperA ~ upperC) =>
   (Extent.Measure measA, Extent.Measure measB, Extent.Measure measC,
    Extent.MultiplyMeasure measA measB ~ measC,
    Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C fuse, Eq fuse, Shape.C width,
    Class.Floating a) =>
   Matrix.QuadraticMeas typ xl xu lowerA upperA measA height fuse a ->
   Unpacked lowerB upperB measB vert horiz fuse width a ->
   Unpacked lowerC upperC measC vert horiz height width a
a #*## b =
   case factorIdentityLeft a of
      (ident, quad) ->
         ArrMatrix.liftUnpacked1 id .
         reshapeHeight ident .
         squareFull quad $
         ArrMatrix.liftUnpacked1 id b

-- | An unpacked matrix times a 'Matrix.QuadraticMeas' matrix.
--
-- The right factor is split into a quadratic matrix over the shared
-- dimension @fuse@ and an identity carrying its extent.  The quadratic
-- part is applied with 'fullSquare', and the product is re-labelled with
-- the original width by 'reshapeWidth'.
(##*#) ::
   (MultiplySquare typ xl xu, Matrix.ToQuadratic typ,
    Omni.Strip lowerA, Omni.Strip upperA,
    Omni.Strip lowerB, Omni.Strip upperB,
    Omni.Strip lowerC, Omni.Strip upperC) =>
   (Omni.MultipliedBands lowerA lowerB ~ lowerC,
    Omni.MultipliedBands lowerB lowerA ~ lowerC,
    Omni.MultipliedBands upperA upperB ~ upperC,
    Omni.MultipliedBands upperB upperA ~ upperC) =>
   (Extent.Measure measA, Extent.Measure measB, Extent.Measure measC,
    Extent.MultiplyMeasure measA measB ~ measC,
    Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C fuse, Eq fuse, Shape.C width,
    Class.Floating a) =>
   Unpacked lowerB upperB measB vert horiz height fuse a ->
   Matrix.QuadraticMeas typ xl xu lowerA upperA measA fuse width a ->
   Unpacked lowerC upperC measC vert horiz height width a
b ##*# a =
   case factorIdentityRight a of
      (quad, ident) ->
         ArrMatrix.liftUnpacked1 id $
         (`reshapeWidth` ident) $
         fullSquare (ArrMatrix.liftUnpacked1 id b) quad


-- | Identity matrix type of measure @meas@.  It is used to carry the
-- original extent of a 'Matrix.QuadraticMeas' matrix while that matrix is
-- treated as quadratic.  The remaining parameters are height, width and
-- element type, e.g. @IdentityMeas meas height width a@.
type IdentityMeas meas =
         Matrix.QuadraticMeas Matrix.Identity () ()
            MatrixShape.Empty MatrixShape.Empty meas

-- | Misspelled name of 'IdentityMeas', kept so existing references keep
-- compiling.
type IdentityMaes meas = IdentityMeas meas

-- | Split a 'Matrix.QuadraticMeas' matrix into two parts:
--
-- * an identity matrix holding the original extent (height by width);
--
-- * the same matrix re-typed as quadratic over its width.
factorIdentityLeft ::
   (Matrix.ToQuadratic typ, Extent.Measure meas) =>
   Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
   (IdentityMaes meas height width a,
    Matrix.Quadratic typ xl xu lower upper width a)
factorIdentityLeft m = (ident, quad)
   where
      ident = Matrix.Identity (Matrix.extent m)
      quad = Matrix.widthToQuadratic m

-- | Split a 'Matrix.QuadraticMeas' matrix into two parts:
--
-- * the same matrix re-typed as quadratic over its height;
--
-- * an identity matrix holding the original extent (height by width).
factorIdentityRight ::
   (Matrix.ToQuadratic typ, Extent.Measure meas) =>
   Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
   (Matrix.Quadratic typ xl xu lower upper height a,
    IdentityMaes meas height width a)
factorIdentityRight m = (quad, ident)
   where
      quad = Matrix.heightToQuadratic m
      ident = Matrix.Identity (Matrix.extent m)

-- | Re-label the height of a full matrix with the height of the extent
-- stored in an identity matrix, combining measures via
-- 'Extent.MultiplyMeasure'.  Raises an error if 'Extent.fuse' rejects the
-- two extents.
reshapeHeight ::
   (Extent.Measure measA, Extent.Measure measB, Extent.Measure measC,
    Extent.MultiplyMeasure measA measB ~ measC,
    Extent.C vert, Extent.C horiz, Eq fuse) =>
   IdentityMaes measA height fuse a ->
   Full measB vert horiz fuse width a ->
   Full measC vert horiz height width a
reshapeHeight (Matrix.Identity extentA) b =
   ArrMatrix.lift1
      (FullBasic.mapExtent $ \extentB ->
         case Extent.fuse
                 (ExtentPriv.unifyLeft extentA extentB)
                 (ExtentPriv.relaxMeasureWith extentA extentB) of
            Just fused -> fused
            Nothing -> error "Multiply.reshapeHeight: shapes mismatch")
      b

-- | Re-label the width of a full matrix with the width of the extent
-- stored in an identity matrix, combining measures via
-- 'Extent.MultiplyMeasure'.  The width is unified by transposing both
-- extents around 'ExtentPriv.unifyLeft'.  Raises an error if 'Extent.fuse'
-- rejects the two extents.
reshapeWidth ::
   (Extent.Measure measA, Extent.Measure measB, Extent.Measure measC,
    Extent.MultiplyMeasure measA measB ~ measC,
    Extent.C vert, Extent.C horiz, Eq fuse) =>
   Full measB vert horiz height fuse a ->
   IdentityMaes measA fuse width a ->
   Full measC vert horiz height width a
reshapeWidth b (Matrix.Identity extentA) =
   ArrMatrix.lift1
      (FullBasic.mapExtent $ \extentB ->
         case Extent.fuse
                 (ExtentPriv.relaxMeasureWith extentA extentB)
                 (Extent.transpose $
                  ExtentPriv.unifyLeft
                     (Extent.transpose extentA) (Extent.transpose extentB)) of
            Just fused -> fused
            Nothing -> error "Multiply.reshapeWidth: shapes mismatch")
      b


-- A scaling matrix is symmetric, so the transposition flag is ignored.
-- The underlying array of the full matrix is scaled with 'Vector.scale'.
instance (xl ~ (), xu ~ ()) => MultiplySquare Matrix.Scale xl xu where
   transposableSquare _ s b =
      scaleWithCheck "Matrix.Multiply.transposableSquare" Matrix.height
         (\k -> ArrMatrix.lift1 (Vector.scale k)) s b

-- Transposing a permutation matrix inverts it, so the transposition flag
-- is translated into an inversion flag.
instance (xl ~ (), xu ~ ()) => MultiplySquare Matrix.Permutation xl xu where
   transposableSquare trans =
      PermMatrix.multiplyFull (Perm.inversionFromTransposition trans)

-- Array-backed matrices delegate all three methods to
-- "Numeric.LAPACK.Matrix.Array.Multiply".
instance
   (Layout.Packing pack, Omni.Property property, xl ~ (), xu ~ ()) =>
      MultiplySquare (ArrMatrix.Array pack property) xl xu where
   squareFull sq b = Multiply.squareFull sq b
   fullSquare b sq = Multiply.fullSquare b sq
   transposableSquare trans = Multiply.transposableSquare trans


-- | Powers of a quadratic matrix.  The results keep the box type and band
-- strips of the input.
class (Matrix.Box typ) => Power typ xl xu where
   -- | The matrix multiplied by itself.
   square ::
      (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper) =>
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a ->
      Matrix.Quadratic typ xl xu lower upper sh a
   -- | @power n m@ is the @n@-th power of @m@.
   -- NOTE(review): the 'Matrix.Scale' instance implements this with '(^)',
   -- which throws on negative @n@; behaviour of other instances for
   -- negative @n@ should be confirmed.
   power ::
      (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper) =>
      (Shape.C sh, Class.Floating a) =>
      Integer ->
      Matrix.Quadratic typ xl xu lower upper sh a ->
      Matrix.Quadratic typ xl xu lower upper sh a
   -- | Infinite stream of powers.  The 'Matrix.Scale' and
   -- 'Matrix.Permutation' instances start with the first power (the
   -- matrix itself) and go upward.
   powers1 ::
      (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper) =>
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a ->
      Stream (Matrix.Quadratic typ xl xu lower upper sh a)

-- Powers of a scaling matrix are scaling matrices of the powered scalar.
instance (xl ~ (), xu ~ ()) => Power Matrix.Scale xl xu where
   square (Matrix.Scale sh k) = Matrix.Scale sh (k * k)
   -- NOTE(review): (^) throws for negative exponents.
   power n (Matrix.Scale sh k) = Matrix.Scale sh (k ^ n)
   powers1 (Matrix.Scale sh k) =
      Matrix.Scale sh <$> Stream.iterate (* k) k

-- Powers of a permutation matrix are computed on the permutation itself.
instance (xl ~ (), xu ~ ()) => Power Matrix.Permutation xl xu where
   square (Matrix.Permutation p) = Matrix.Permutation (Perm.square p)
   power n (Matrix.Permutation p) = Matrix.Permutation (Perm.power n p)
   powers1 (Matrix.Permutation p) =
      Matrix.Permutation <$>
      Stream.iterate (\q -> Perm.multiplyUnchecked q p) p

-- Array-backed matrices delegate all powers to
-- "Numeric.LAPACK.Matrix.Array.Multiply".
instance
   (Layout.Packing pack, Omni.Property property, xl ~ (), xu ~ ()) =>
      Power (ArrMatrix.Array pack property) xl xu where
   powers1 m = Multiply.powers1 m
   power n m = Multiply.power n m
   square m = Multiply.square m

-- | Stream of powers of a quadratic matrix.  The first element is the
-- identity built by 'MatrixClass.identityFrom'; the rest come from
-- 'powers1'.
powers ::
   (Power typ xl xu, MatrixClass.SquareShape typ,
    MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
    Shape.C sh, Class.Floating a) =>
   Matrix.Quadratic typ xl xu lower upper sh a ->
   Stream (Matrix.Quadratic typ xl xu lower upper sh a)
powers m =
   let ident = MatrixClass.identityFrom m
   in Stream.Cons ident (powers1 m)


-- | General matrix-matrix product over a shared inner dimension @fuse@.
-- The result's box type, extra parameters, strips and extent types are
-- computed by the 'Multiply' class and the type families constrained
-- below; the product itself is 'matrixMatrix'.
(#*#) ::
   (Matrix.Box typA, Omni.Strip lowerA, Omni.Strip upperA,
    Matrix.Box typB, Omni.Strip lowerB, Omni.Strip upperB,
    Matrix.Box typC, Omni.Strip lowerC, Omni.Strip upperC) =>
   (Multiply typA xlA xuA typB xlB xuB lowerC upperC measC,
    Multiplied typA xlA xuA typB xlB xuB lowerC upperC measC ~ typC,
    MultipliedExtra typA xlA xuA typB xlB xuB ~ xlC,
    MultipliedExtra typA xuA xlA typB xuB xlB ~ xuC) =>
   (Omni.MultipliedStrip lowerA lowerB ~ lowerC,
    Omni.MultipliedStrip lowerB lowerA ~ lowerC,
    Omni.MultipliedStrip upperA upperB ~ upperC,
    Omni.MultipliedStrip upperB upperA ~ upperC,
    Omni.MultipliedBands lowerA lowerB ~ lowerC,
    Omni.MultipliedBands lowerB lowerA ~ lowerC,
    Omni.MultipliedBands upperA upperB ~ upperC,
    Omni.MultipliedBands upperB upperA ~ upperC) =>
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB,
    ExtentPriv.MultiplyMeasure measA measB ~ measC,
    ExtentPriv.MultiplyMeasure measB measA ~ measC,
    ExtentPriv.Multiply vertA vertB ~ vertC,
    ExtentPriv.Multiply vertB vertA ~ vertC,
    ExtentPriv.Multiply horizA horizB ~ horizC,
    ExtentPriv.Multiply horizB horizA ~ horizC) =>
   (Shape.C height, Shape.C fuse, Eq fuse, Shape.C width,
    Class.Floating a) =>
   Matrix typA xlA xuA lowerA upperA measA vertA horizA height fuse a ->
   Matrix typB xlB xuB lowerB upperB measB vertB horizB fuse width a ->
   Matrix typC xlC xuC lowerC upperC measC vertC horizC height width a
a #*# b = matrixMatrix a b

-- | General matrix-matrix multiplication across box types.
--
-- 'Multiplied' gives the box type of the product.  It depends on both
-- operand box types and on the result strips @lowerC@, @upperC@ and
-- measure @measC@.  'MultipliedExtra' gives the product's extra type
-- parameters.
class
   (Matrix.Box typA, Matrix.Box typB) =>
      Multiply typA xlA xuA typB xlB xuB lowerC upperC measC where
   -- | Box type of the product.
   type Multiplied typA xlA xuA typB xlB xuB lowerC upperC measC
   -- | Extra type parameter of the product.  'matrixMatrix' applies it
   -- once with the lower extras (giving @xlC@) and once with lower and
   -- upper swapped (giving @xuC@).
   type MultipliedExtra typA xlA xuA typB xlB xuB
   -- | Product of an @height@ by @fuse@ matrix with a @fuse@ by @width@
   -- matrix.
   matrixMatrix ::
      (Matrix.Box typA, Omni.Strip lowerA, Omni.Strip upperA) =>
      (Matrix.Box typB, Omni.Strip lowerB, Omni.Strip upperB) =>
      (Matrix.Box typC, Omni.Strip lowerC, Omni.Strip upperC) =>
      (Multiplied typA xlA xuA typB xlB xuB lowerC upperC measC ~ typC) =>
      (MultipliedExtra typA xlA xuA typB xlB xuB ~ xlC) =>
      (MultipliedExtra typA xuA xlA typB xuB xlB ~ xuC) =>
      (Omni.MultipliedStrip lowerA lowerB ~ lowerC) =>
      (Omni.MultipliedStrip lowerB lowerA ~ lowerC) =>
      (Omni.MultipliedStrip upperA upperB ~ upperC) =>
      (Omni.MultipliedStrip upperB upperA ~ upperC) =>
      (Omni.MultipliedBands lowerA lowerB ~ lowerC) =>
      (Omni.MultipliedBands lowerB lowerA ~ lowerC) =>
      (Omni.MultipliedBands upperA upperB ~ upperC) =>
      (Omni.MultipliedBands upperB upperA ~ upperC) =>
      (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
      (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
      (ExtentPriv.MultiplyMeasure measA measB ~ measC) =>
      (ExtentPriv.MultiplyMeasure measB measA ~ measC) =>
      (ExtentPriv.Multiply vertA  vertB  ~ vertC)  =>
      (ExtentPriv.Multiply vertB  vertA  ~ vertC)  =>
      (ExtentPriv.Multiply horizA horizB ~ horizC) =>
      (ExtentPriv.Multiply horizB horizA ~ horizC) =>
      (Shape.C height, Shape.C fuse, Eq fuse, Shape.C width) =>
      (Class.Floating a) =>
      Matrix typA xlA xuA lowerA upperA measA vertA horizA height fuse a ->
      Matrix typB xlB xuB lowerB upperB measB vertB horizB fuse width a ->
      Matrix typC xlC xuC lowerC upperC measC vertC horizC height width a

-- Product of two array-backed matrices.  The result packing and property
-- are computed from the operands' packings and properties and from the
-- result strips.  The multiplication itself is delegated to
-- "Numeric.LAPACK.Matrix.Array.Multiply".
instance
   (xlA ~ (), xuA ~ (), xlB ~ (), xuB ~ (),
    Layout.Packing packA, Omni.Property propertyA,
    Layout.Packing packB, Omni.Property propertyB,
    Layout.Packing packC, Omni.Property propertyC,
    Multiply.MultipliedPacking packA packB ~ pack,
    Multiply.MultipliedPacking packB packA ~ pack,
    Omni.MultipliedProperty propertyA propertyB ~ propertyAB,
    Omni.MultipliedProperty propertyB propertyA ~ propertyAB,
    Omni.UnitIfTriangular lowerC upperC ~ diag,
    Omni.UnitIfTriangular upperC lowerC ~ diag,
    Multiply.PackingByStrip lowerC upperC measC pack ~ packC,
    Multiply.PackingByStrip upperC lowerC measC pack ~ packC,
    Omni.MergeUnit propertyAB diag ~ propertyC,
    Omni.MergeUnit diag propertyAB ~ propertyC) =>
      Multiply
         (ArrMatrix.Array packA propertyA) xlA xuA
         (ArrMatrix.Array packB propertyB) xlB xuB
         lowerC upperC measC where
   type Multiplied
            (ArrMatrix.Array packA propertyA) xlA xuA
            (ArrMatrix.Array packB propertyB) xlB xuB
            lowerC upperC measC =
         ArrMatrix.Array
            (Multiply.PackingByStrip lowerC upperC measC
               (Multiply.MultipliedPacking packA packB))
            (Omni.MergeUnit
               (Omni.MultipliedProperty propertyA propertyB)
               (Omni.UnitIfTriangular lowerC upperC))
   type MultipliedExtra
            (ArrMatrix.Array packA propertyA) xlA xuA
            (ArrMatrix.Array packB propertyB) xlB xuB = ()
   matrixMatrix a b = Multiply.matrixMatrix a b


-- The product of two scaling matrices is again a scaling matrix with empty
-- strips; it is computed by 'Matrix.multiplySame'.
instance
   (lowerC ~ MatrixShape.Empty, upperC ~ MatrixShape.Empty,
    xlA ~ (), xuA ~ (), xlB ~ (), xuB ~ ()) =>
      Multiply Matrix.Scale xlA xuA Matrix.Scale xlB xuB lowerC upperC measC
         where
   type Multiplied
            Matrix.Scale xlA xuA Matrix.Scale xlB xuB lowerC upperC measC =
               Matrix.Scale
   type MultipliedExtra Matrix.Scale xlA xuA Matrix.Scale xlB xuB = ()
   matrixMatrix s@(Matrix.Scale _ _) t@(Matrix.Scale _ _) =
      Matrix.multiplySame s t

-- Scaling matrix times array matrix.  The array matrix is scaled with
-- 'ArrMatrix.scale'.  'Matrix.height' is handed to 'scaleWithCheck'
-- (presumably to compare it with the scaling size).
instance
   (ArrMatrix.Scale property, xlA ~ (), xuA ~ (), xlB ~ (), xuB ~ ()) =>
      Multiply Matrix.Scale xlA xuA (ArrMatrix.Array pack property) xlB xuB
            lowerC upperC measC
         where
   type Multiplied Matrix.Scale xlA xuA (ArrMatrix.Array pack property) xlB xuB
            lowerC upperC measC
         = ArrMatrix.Array pack property
   type MultipliedExtra
            Matrix.Scale xlA xuA (ArrMatrix.Array pack property) xlB xuB = ()
   matrixMatrix s@(Matrix.Scale _ _) m =
      scaleWithCheck "Matrix.Multiply.multiply Scale" Matrix.height
         ArrMatrix.scale s m

-- Array matrix times scaling matrix.  The array matrix is scaled with
-- 'ArrMatrix.scale'.  'Matrix.width' is handed to 'scaleWithCheck'
-- (presumably to compare it with the scaling size).
instance
   (ArrMatrix.Scale property, xlA ~ (), xuA ~ (), xlB ~ (), xuB ~ ()) =>
      Multiply (ArrMatrix.Array pack property) xlA xuA Matrix.Scale xlB xuB
            lowerC upperC measC
         where
   type Multiplied (ArrMatrix.Array pack property) xlA xuA Matrix.Scale xlB xuB
            lowerC upperC measC
         = ArrMatrix.Array pack property
   type MultipliedExtra
            (ArrMatrix.Array pack property) xlA xuA Matrix.Scale xlB xuB = ()
   matrixMatrix m@(ArrMatrix.Array _) s@(Matrix.Scale _ _) =
      scaleWithCheck "Matrix.Multiply.multiply Scale" Matrix.width
         ArrMatrix.scale s m


-- The product of two permutation matrices is a permutation matrix.  Note
-- the operand order passed to 'Perm.multiply': right factor first.
instance
   (xlA ~ (), xuA ~ (), xlB ~ (), xuB ~ ()) =>
      Multiply Matrix.Permutation xlA xuA Matrix.Permutation xlB xuB
         lowerC upperC measC where
   type Multiplied Matrix.Permutation xlA xuA Matrix.Permutation xlB xuB
            lowerC upperC measC  =  Matrix.Permutation
   type MultipliedExtra
            Matrix.Permutation xlA xuA Matrix.Permutation xlB xuB = ()
   matrixMatrix (Matrix.Permutation p) (Matrix.Permutation q) =
      Matrix.Permutation (Perm.multiply q p)