{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Row-major matrices.
--
-- Re-exports operations from "Numeric.BLAS.Matrix.RowMajor" and adds
-- 'kronecker', a Kronecker product whose first factor is a general full
-- matrix (stored in either row-major or column-major order).
module Numeric.LAPACK.Matrix.RowMajor (
   Matrix,
   Matrix.takeRow,
   Matrix.takeColumn,
   Matrix.fromRows,
   Matrix.tensorProduct,
   Matrix.decomplex,
   Matrix.recomplex,
   Matrix.scaleRows,
   Matrix.scaleColumns,
   kronecker,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Private (Full)

import qualified Numeric.BLAS.Matrix.RowMajor as Matrix
import Numeric.BLAS.Matrix.RowMajor (Matrix)
import Numeric.BLAS.Matrix.Layout (Order(RowMajor, ColumnMajor))
import Numeric.BLAS.Scalar (zero, one)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Foldable (forM_)


-- | Kronecker product of a general full matrix @a@ and a row-major matrix @b@.
--
-- The result has rows indexed by @(heightA,heightB)@ and columns indexed by
-- @(widthA,widthB)@. Its entry at @((i,j),(k,l))@ is @a[i,k] * b[j,l]@.
-- The first factor may be stored in either 'RowMajor' or 'ColumnMajor' order.
--
-- Every result row @(i,j)@ is the outer product of row @i@ of @a@ with
-- row @j@ of @b@. It is written by a single BLAS @gemm@ call whose inner
-- dimension is 1.
kronecker ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Full meas vert horiz heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kronecker
      (Array (Layout.Full orderA extentA) fptrA)
      (Array (heightB,widthB) fptrB) =
   Array.unsafeCreate ((heightA,heightB), (widthA,widthB)) fillRows
   where
      (heightA,widthA) = Extent.dimensions extentA
      (rowsA, colsA) = (Shape.size heightA, Shape.size widthA)
      (rowsB, colsB) = (Shape.size heightB, Shape.size widthB)
      -- number of stored elements in one result row
      rowSize = colsA * colsB
      -- strideA:  distance between neighbouring entries within one row of a
      -- rowStepA: distance between the first entries of neighbouring rows of a
      (strideA, rowStepA) =
         case orderA of
            RowMajor -> (1, colsA)
            ColumnMajor -> (rowsA, 1)
      fillRows resultPtr = evalContT $ do
         transTPtr <- Call.char 'T'
         transNPtr <- Call.char 'N'
         colsBPtr <- Call.cint colsB
         colsAPtr <- Call.cint colsA
         innerPtr <- Call.cint 1
         alphaPtr <- Call.number one
         betaPtr <- Call.number zero
         ptrA <- ContT $ withForeignPtr fptrA
         ptrB <- ContT $ withForeignPtr fptrB
         ldRowAPtr <- Call.leadingDim strideA
         ldRowBPtr <- Call.leadingDim 1
         ldResultPtr <- Call.leadingDim colsB
         liftIO $
            forM_ (liftA2 (,) [0 .. rowsA-1] [0 .. rowsB-1]) $ \(i,j) ->
               -- BLAS (column-major) computes the colsB x colsA block
               -- C := (row j of b)^T * (row i of a), i.e. C(l,k) = b[j,l]*a[i,k].
               -- Stored with leading dimension colsB, that block is exactly
               -- the row-major layout of result row (i,j).
               BlasGen.gemm
                  transTPtr transNPtr colsBPtr colsAPtr innerPtr alphaPtr
                  (advancePtr ptrB (colsB*j)) ldRowBPtr
                  (advancePtr ptrA (rowStepA*i)) ldRowAPtr
                  betaPtr
                  (advancePtr resultPtr (rowSize*(rowsB*i + j))) ldResultPtr