{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Mosaic.Unpacked where

import qualified Numeric.LAPACK.Matrix.Square.Basic as Square
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Mosaic.Private (uncheck, recheck)
import Numeric.LAPACK.Matrix.Private (Square, Full)
import Numeric.LAPACK.Matrix.Shape.Omni
         (TriDiag, DiagSingleton, charFromTriDiag)
import Numeric.LAPACK.Matrix.Layout.Private
         (PackingSingleton(Unpacked), MirrorSingleton,
          Order(RowMajor,ColumnMajor),
          flipOrder, transposeFromOrder,
          sideSwapFromOrder, uploFromOrder, uploOrder)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Scalar (one, zero)
import Numeric.LAPACK.Private (copyBlock, conjugateToTemp)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import qualified Data.Stream as Stream
import Data.Stream (Stream)
import Data.Function.HT (powerAssociative)
import Data.Tuple.HT (double)



-- | A mosaic matrix in unpacked storage: the whole square array is kept
-- (see 'toSquare' / 'fromSquare', which only swap the shape), and the
-- @mirror@ and @uplo@ shape parameters say how its contents are to be read
-- (e.g. 'multiplyFull' picks trmm, symm or hemm based on the mirror).
type Mosaic mirror uplo sh =
         Array (Layout.Mosaic Layout.Unpacked mirror uplo sh)
-- | A triangular matrix in unpacked storage
-- ('Layout.TriangularP' instantiated with 'Layout.Unpacked').
type Triangular uplo sh =
         Array (Layout.TriangularP Layout.Unpacked uplo sh)


-- | View a square matrix as an unpacked mosaic with the given mirror.
-- No element is touched: only the array shape is rewritten, keeping the
-- storage order and the square size and taking the uplo from
-- 'Layout.autoUplo'.
fromSquare ::
   (Layout.UpLo uplo) =>
   MirrorSingleton mirror -> Square sh a -> Mosaic mirror uplo sh a
fromSquare mirror =
   Array.mapShape $ \fullShape ->
      case fullShape of
         Layout.Full order extent ->
            let size = Extent.squareSize extent
            in Layout.Mosaic Unpacked mirror Layout.autoUplo order size

-- | Forget the mosaic structure and view the unpacked storage as a plain
-- square matrix of the same order and size. Only the shape changes.
toSquare :: Mosaic mirror uplo sh a -> Square sh a
toSquare =
   Array.mapShape $ \mosaicShape ->
      case mosaicShape of
         Layout.Mosaic Unpacked _ _ order size -> Layout.square order size

-- | Store the matrix in the requested element order.
-- The work is done by 'Basic.forceOrder' on the square view. The mirror of
-- the input is carried over to the result.
forceOrder ::
   (Layout.UpLo uplo, Shape.C sh, Class.Floating a) =>
   Order -> Mosaic mirror uplo sh a -> Mosaic mirror uplo sh a
forceOrder newOrder a =
   let mirror = Layout.mosaicMirror (Array.shape a)
       reordered = Basic.forceOrder newOrder (toSquare a)
   in fromSquare mirror reordered



-- | Multiply a mosaic matrix by itself.
-- Both operands are wrapped with 'uncheck' and the product is unwrapped
-- with 'recheck'. Presumably this skips a shape comparison that is
-- trivially satisfied because both factors are the same matrix.
square ::
   (TriDiag diag, Layout.UpLo uplo, Shape.C sh, Class.Floating a) =>
   DiagSingleton diag -> Mosaic mirror uplo sh a -> Mosaic mirror uplo sh a
square diag a =
   let au = uncheck a
   in recheck (multiplyCompatible diag au au)

-- | Raise a mosaic matrix to a non-negative integer power using
-- 'powerAssociative' (repeated squaring with 'multiplyCompatible').
-- @power diag 0 a@ yields the identity matrix in @a@'s storage order.
--
-- A negative exponent is rejected with an explicit 'error'.
-- 'powerAssociative' (utility-ht) halves the exponent with 'div', and
-- @div (-1) 2 == -1@, so it never reaches zero for negative input.
-- Without this guard the call would diverge while squaring the matrix
-- forever.
power ::
   (Layout.UpLo uplo, TriDiag diag, Shape.C sh, Class.Floating a) =>
   DiagSingleton diag ->
   Integer -> Mosaic mirror uplo sh a -> Mosaic mirror uplo sh a
power diag n
   a@(Array (Layout.Mosaic Layout.Unpacked mirror _uplo order sh) _)
   | n < 0 =
      error "Mosaic.Unpacked.power: negative exponent"
   | otherwise =
      recheck $
      powerAssociative (multiplyCompatible diag)
         -- start value: identity with the same mirror and order as a
         (fromSquare mirror $ Square.identityOrder order $ Unchecked sh)
         (uncheck a)
         n

-- | The infinite stream of positive powers @a, a^2, a^3, ...@.
-- Each element is the previous one multiplied on the right by @a@.
powers1 ::
   (Layout.UpLo uplo, TriDiag diag, Shape.C sh, Class.Floating a) =>
   DiagSingleton diag ->
   Mosaic mirror uplo sh a -> Stream (Mosaic mirror uplo sh a)
powers1 diag a =
   let au = uncheck a
       step acc = multiplyCompatible diag acc au
   in fmap recheck (Stream.iterate step au)


-- | Product @a * b@ of two mosaic matrices of identical type.
-- @b@ is viewed as a full square matrix, multiplied by @a@ via
-- 'multiplyFull', and the result is re-labelled with @b@'s mirror.
--
-- NOTE(review): the result is only a valid mosaic if the product actually
-- keeps the structure. That holds for powers of one matrix (as in 'square',
-- 'power' and 'powers1') but not, e.g., for two arbitrary symmetric
-- matrices. Callers are assumed to guarantee this.
multiplyCompatible ::
   (Layout.UpLo uplo, TriDiag diag, Shape.C sh, Eq sh, Class.Floating a) =>
   DiagSingleton diag ->
   Mosaic mirror uplo sh a ->
   Mosaic mirror uplo sh a -> Mosaic mirror uplo sh a
multiplyCompatible diag a b =
   let mirror = Layout.mosaicMirror (Array.shape b)
       product = multiplyFull diag a (toSquare b)
   in fromSquare mirror product

-- | Multiply an unpacked mosaic matrix @A@ onto a full matrix @B@,
-- computing @A * B@. The result has exactly the shape of @B@.
--
-- The BLAS Level 3 routine is selected by the mirror of @A@:
--
-- * 'Layout.NoMirror': @B@ is copied into the result buffer, which is then
--   multiplied in place with @trmm@. @diag@ is passed as trmm's diag
--   argument.
--
-- * 'Layout.SimpleMirror': @symm@ with @beta = 0@.
--
-- * 'Layout.ConjugateMirror': @hemm@ with @beta = 0@. @A@ is passed through
--   a conjugated temporary copy when its storage order differs from @B@'s.
--
-- The BLAS @side@ and the @(m,n)@ dimensions come from 'sideSwapFromOrder'
-- applied to @B@'s order. Presumably a row-major @B@ is handled as its
-- column-major transpose, multiplied from the right (confirm against
-- 'sideSwapFromOrder').
--
-- Fails via 'Call.assert' if the height of @B@ differs from the size of @A@.
multiplyFull ::
   (Layout.UpLo uplo, TriDiag diag,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   DiagSingleton diag ->
   Mosaic mirror uplo height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
multiplyFull diag
   (Array (Layout.Mosaic Layout.Unpacked mirror uploA orderA shA) a)
   (Array shapeB@(Layout.Full orderB extentB) b) =
      Array.unsafeCreateWithSize shapeB $ \size cPtr -> do
   let (height,width) = Extent.dimensions extentB
   Call.assert (show mirror ++ ".multiplyFull: shapes mismatch") (shA == height)
   -- m0: order of the square matrix A (= height of B); n0: width of B
   let m0 = Shape.size height
   let n0 = Shape.size width
   -- order used to derive the BLAS uplo character for A's stored triangle
   let realOrderA = uploOrder uploA orderA
   evalContT $ do
      let (side,(m,n)) = sideSwapFromOrder orderB (m0,n0)
      sidePtr <- Call.char side
      uploPtr <- Call.char $ uploFromOrder realOrderA
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      alphaPtr <- Call.number one
      -- A is an m0 x m0 square array
      ldaPtr <- Call.leadingDim m0
      -- asTypeOf pins bPtr to the same monad and pointer type as alphaPtr
      bPtr <- ContT (withForeignPtr b) `asTypeOf` return alphaPtr
      -- the result C has B's shape, so ldbPtr also serves as C's leading dim
      ldbPtr <- Call.leadingDim m
      case mirror of
         Layout.NoMirror -> do
            -- if B is column-major, transpose according to A's own order;
            -- if B is row-major (swapped problem), use the flipped order
            transPtr <-
               Call.char $ transposeFromOrder $
               case orderB of
                  ColumnMajor -> orderA
                  RowMajor -> flipOrder orderA
            diagPtr <- Call.char $ charFromTriDiag diag
            aPtr <- ContT $ withForeignPtr a
            liftIO $ do
               -- trmm overwrites its B operand, so work on a copy in C
               copyBlock size bPtr cPtr
               BlasGen.trmm sidePtr uploPtr transPtr diagPtr
                  mPtr nPtr alphaPtr aPtr ldaPtr cPtr ldbPtr
         Layout.SimpleMirror -> do
            -- beta = 0: C is overwritten, not accumulated
            betaPtr <- Call.number zero
            aPtr <- ContT $ withForeignPtr a
            liftIO $
               BlasGen.symm sidePtr uploPtr mPtr nPtr
                  alphaPtr aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldbPtr
         Layout.ConjugateMirror -> do
            -- If the storage orders differ, BLAS effectively sees A's
            -- transpose, which for a Hermitian A is conj A. So pass a
            -- conjugated temporary copy of all m0*m0 entries in that case.
            aPtr <-
               if orderA == orderB
                  then ContT $ withForeignPtr a
                  else conjugateToTemp (m0*m0) a
            -- beta = 0: C is overwritten, not accumulated
            betaPtr <- Call.number zero
            liftIO $
               BlasGen.hemm sidePtr uploPtr mPtr nPtr
                  alphaPtr aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldbPtr