{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Orthogonal.Basic (
   leastSquares,
   minimumNorm,
   leastSquaresMinimumNormRCond,
   pseudoInverseRCond,

   leastSquaresConstraint,
   gaussMarkovLinearModel,

   determinantAbsolute,
   complement,
   extractComplement,
   ) where

import qualified Numeric.LAPACK.Orthogonal.Plain as HH

import qualified Numeric.LAPACK.Matrix.Square.Basic as Square
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Layout.Private (Order(RowMajor,ColumnMajor))
import Numeric.LAPACK.Matrix.Private
         (Full, General, Tall, Wide, ShapeInt, shapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, absolute)
import Numeric.LAPACK.Private
         (lacgv, peekCInt,
          copySubMatrix, copyToTemp,
          copyToColumnMajorTemp, copyToSubColumnMajor,
          withAutoWorkspaceInfo, rankMsg, errorCodeMsg, createHigherArray)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.LAPACK.FFI.Real as LapackReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.Marshal.Array (pokeArray)
import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Complex (Complex)
import Data.Tuple.HT (mapSnd)


-- | Solve @A·X = B@ in the least-squares sense with LAPACK @gels@.
--
-- The type fixes the horizontal extent of @A@ to 'Extent.Small', so @A@
-- has at most as many columns as rows.  NOTE(review): this reads the
-- library's extent convention; confirm in "Numeric.LAPACK.Matrix.Extent".
-- If the heights of @A@ and @B@ do not match, this calls 'error'.
-- The result @X@ is always stored in column-major order.
-- A failure reported by LAPACK goes through 'withAutoWorkspaceInfo'
-- with 'rankMsg'.
leastSquares ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Full meas horiz Extent.Small height width a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
leastSquares
   (Array shapeA@(Layout.Full orderA extentA) a)
   (Array        (Layout.Full orderB extentB) b) =

 -- Combine the column extent of A with the extent of B.
 -- This fails exactly when the heights differ.
 case Extent.fuse (Extent.weakenWide $ Extent.transpose extentA) extentB of
  Nothing -> error "leastSquares: height shapes mismatch"
  Just extent ->
      Array.unsafeCreate (Layout.Full ColumnMajor extent) $ \xPtr -> do

   let widthA = Extent.width extentA
   let (height,widthB) = Extent.dimensions extentB
   -- NOTE(review): presumably the dimensions of A in its storage order
   -- (swapped for row-major), which is why lda = m serves both orders.
   -- Confirm against Layout.dimensions.
   let (m,n) = Layout.dimensions shapeA
   let lda = m
   let nrhs = Shape.size widthB
   let ldb = Shape.size height
   let ldx = Shape.size widthA
   evalContT $ do
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      nrhsPtr <- Call.cint nrhs
      -- A is copied to a temporary.  For row-major A, adjointA conjugates
      -- the copy and supplies a transposition flag instead of moving data.
      (transPtr,aPtr) <- adjointA orderA (m*n) a
      ldaPtr <- Call.leadingDim lda
      ldbPtr <- Call.leadingDim ldb
      -- gels overwrites its B argument with the solution, so B goes into
      -- a column-major temporary with leading dimension ldb (B's height).
      bPtr <- copyToColumnMajorTemp orderB ldb nrhs b
      liftIO $ withAutoWorkspaceInfo rankMsg "gels" $
         LapackGen.gels transPtr
            mPtr nPtr nrhsPtr aPtr ldaPtr bPtr ldbPtr
      -- The leading ldx (= width of A) rows of the overwritten B hold X.
      liftIO $ copySubMatrix ldx nrhs ldb bPtr ldx xPtr

-- | Minimum-norm solution of @A·X = B@ with LAPACK @gels@.
--
-- The type fixes the vertical extent of @A@ to 'Extent.Small', so @A@
-- has at most as many rows as columns.  NOTE(review): this reads the
-- library's extent convention; confirm in "Numeric.LAPACK.Matrix.Extent".
-- If the heights of @A@ and @B@ do not match, this calls 'error'.
-- The result @X@ is always stored in column-major order.
-- A failure reported by LAPACK goes through 'withAutoWorkspaceInfo'
-- with 'rankMsg'.
minimumNorm ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Full meas Extent.Small vert height width a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
minimumNorm
   (Array shapeA@(Layout.Full orderA extentA) a)
   (Array        (Layout.Full orderB extentB) b) =

 -- Combine the column extent of A with the extent of B.
 -- This fails exactly when the heights differ.
 case Extent.fuse (Extent.weakenTall $ Extent.transpose extentA) extentB of
  Nothing -> error "minimumNorm: height shapes mismatch"
  Just extent ->
      Array.unsafeCreate (Layout.Full ColumnMajor extent) $ \xPtr -> do

   let widthA = Extent.width extentA
   let (height,widthB) = Extent.dimensions extentB
   -- NOTE(review): presumably the dimensions of A in its storage order
   -- (swapped for row-major), which is why lda = m serves both orders.
   -- Confirm against Layout.dimensions.
   let (m,n) = Layout.dimensions shapeA
   let lda = m
   let nrhs = Shape.size widthB
   let ldb = Shape.size height
   let ldx = Shape.size widthA
   evalContT $ do
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      nrhsPtr <- Call.cint nrhs
      -- A is copied to a temporary.  For row-major A, adjointA conjugates
      -- the copy and supplies a transposition flag instead of moving data.
      (transPtr,aPtr) <- adjointA orderA (m*n) a
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT $ withForeignPtr b
      ldxPtr <- Call.leadingDim ldx
      -- gels writes the solution into its B argument in place.  The output
      -- buffer X (leading dimension ldx = width of A) therefore serves as
      -- that argument, and B is copied into its top ldb rows first.
      liftIO $ copyToSubColumnMajor orderB ldb nrhs bPtr ldx xPtr
      liftIO $ withAutoWorkspaceInfo rankMsg "gels" $
         LapackGen.gels transPtr
            mPtr nPtr nrhsPtr aPtr ldaPtr xPtr ldxPtr


-- | Prepare a matrix buffer of @size@ elements for a LAPACK routine that
-- reads column-major data and accepts a transposition flag.
--
-- The data is always copied into a temporary buffer.
--
-- * Column-major input is passed through with flag @'N'@.
-- * Row-major input is conjugated in place with @lacgv@, and the flag
--   becomes @HH.invChar a@.
--
-- NOTE(review): this presumably makes LAPACK see the original matrix
-- (conjugate plus adjoint flag equals a plain transpose).  That assumes
-- @HH.invChar@ yields @'T'@ for real and @'C'@ for complex element types;
-- confirm in "Numeric.LAPACK.Orthogonal.Plain".
--
-- Returns the pointer to the flag character and the pointer to the copy.
adjointA ::
   Class.Floating a =>
   Order -> Int -> ForeignPtr a -> ContT r IO (Ptr CChar, Ptr a)
adjointA order size a = do
   bufferPtr <- copyToTemp size a
   let conjugateBuffer = do
          countPtr <- Call.cint size
          stridePtr <- Call.cint 1
          liftIO $ lacgv countPtr bufferPtr stridePtr
   flag <-
      case order of
         ColumnMajor -> pure 'N'
         RowMajor -> HH.invChar a <$ conjugateBuffer
   flagPtr <- Call.char flag
   pure (flagPtr, bufferPtr)


-- | Solve @A·X = B@ with LAPACK @gelsy@, which handles both
-- overdetermined and underdetermined, possibly rank-deficient systems.
--
-- @rcond@ is forwarded to @gelsy@ as its threshold for the effective rank.
-- The result is that effective rank together with @X@, always stored in
-- column-major order.  If the heights of @A@ and @B@ do not match, this
-- calls 'error'.
--
-- Special cases:
--
-- * @A@ has no rows: the rank is 0 and @X@ is zero.
-- * @B@ has no columns: the rank still comes from @gelsy@, solved against
--   a single all-zero right-hand side.
--
-- NOTE(review): 'unsafePerformIO' is presumably safe here because the
-- worker only reads the input buffers.  It writes only to temporaries and
-- a freshly created result array.  Confirm in 'leastSquaresMinimumNormIO'.
leastSquaresMinimumNormRCond ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   RealOf a ->
   Full meas horiz vert height width a ->
   Full meas vert horiz height nrhs a ->
   (Int, Full meas vert horiz width nrhs a)
leastSquaresMinimumNormRCond rcond
      (Array (Layout.Full orderA extentA) a)
      (Array (Layout.Full orderB extentB) b) =
   case Extent.fuse (Extent.transpose extentA) extentB of
      Nothing -> error "leastSquaresMinimumNormRCond: height shapes mismatch"
      Just extent
         | m == 0 -> (0, Vector.zero shapeX)
         | nrhs == 0 -> (rankForEmptyRhs, Vector.zero shapeX)
         | otherwise ->
            unsafePerformIO $
            leastSquaresMinimumNormIO rcond shapeX
               orderA a orderB b m n nrhs
         where
            widthA = Extent.width extentA
            (height, widthB) = Extent.dimensions extentB
            shapeX = Layout.Full ColumnMajor extent
            m = Shape.size height
            n = Shape.size widthA
            nrhs = Shape.size widthB
            -- With no right-hand sides, solve against one zero column
            -- just to obtain the effective rank; that solution is dropped.
            rankForEmptyRhs =
               fst $ unsafePerformIO $
               case Vector.zero height of
                  Array _ zeroRhs ->
                     leastSquaresMinimumNormIO rcond
                        (Layout.general ColumnMajor widthA ())
                        orderA a orderB zeroRhs m n 1

-- | Worker for 'leastSquaresMinimumNormRCond'.
--
-- Runs LAPACK @gelsy@ on an @m×n@ matrix @a@ (stored in order @orderA@)
-- with @nrhs@ right-hand sides @b@ (stored in order @orderB@, @m@ rows).
-- Returns the effective rank reported by @gelsy@ together with the
-- solution array of shape @shapeX@.
--
-- @a@ is handed to LAPACK through 'copyToColumnMajorTemp', and @b@ is
-- copied into the buffer supplied by 'createHigherArray'.
-- NOTE(review): presumably 'createHigherArray' allocates a buffer with
-- enough rows for both @b@ (@m@) and the solution (@n@), with leading
-- dimension @ldtmp@, and extracts the result of shape @shapeX@ from it.
-- Confirm in "Numeric.LAPACK.Private".
leastSquaresMinimumNormIO ::
   (Shape.C sh, Class.Floating a) =>
   RealOf a -> sh ->
   Order -> ForeignPtr a ->
   Order -> ForeignPtr a ->
   Int -> Int -> Int -> IO (Int, Array sh a)
leastSquaresMinimumNormIO rcond shapeX orderA a orderB b m n nrhs =
   createHigherArray shapeX m n nrhs $ \(tmpPtr,ldtmp) -> do

   let lda = m
   evalContT $ do
      aPtr <- copyToColumnMajorTemp orderA m n a
      ldaPtr <- Call.leadingDim lda
      ldtmpPtr <- Call.leadingDim ldtmp
      bPtr <- ContT $ withForeignPtr b
      -- gelsy overwrites its B argument with the solution, so the
      -- right-hand sides are copied into the work buffer first.
      liftIO $ copyToSubColumnMajor orderB m nrhs bPtr ldtmp tmpPtr
      -- A zero JPVT entry leaves that column free for pivoting in gelsy.
      jpvtPtr <- Call.allocaArray n
      liftIO $ pokeArray jpvtPtr (replicate n 0)
      rankPtr <- Call.alloca
      gelsy m n nrhs aPtr ldaPtr tmpPtr ldtmpPtr jpvtPtr rcond rankPtr
      liftIO $ peekCInt rankPtr


-- | Common signature of the real and complex @gelsy@ wrappers.
--
-- The arguments are, in order:
-- @m n nrhs aPtr ldaPtr bPtr ldbPtr jpvtPtr rcond rankPtr@.
-- @rcond@ has type @realTy@, and the matrix elements have type @elemTy@.
type GELSY_ r realTy elemTy =
   Int -> Int -> Int ->
   Ptr elemTy -> Ptr CInt ->
   Ptr elemTy -> Ptr CInt ->
   Ptr CInt -> realTy -> Ptr CInt ->
   ContT r IO ()

-- | Wrapper that lets 'Class.switchFloating' choose a @gelsy@
-- implementation per element type.
newtype GELSY r a =
   GELSY {getGELSY :: GELSY_ r (RealOf a) a}

-- | LAPACK @gelsy@, dispatched on the element type.
-- Real element types use 'gelsyReal'; complex ones use 'gelsyComplex'.
gelsy :: (Class.Floating a) => GELSY_ r (RealOf a) a
gelsy =
   case Class.switchFloating
           (GELSY gelsyReal) (GELSY gelsyReal)
           (GELSY gelsyComplex) (GELSY gelsyComplex) of
      GELSY impl -> impl

-- | @gelsy@ for real element types.
--
-- The sizes and @rcond@ are marshalled into FFI cells, then
-- 'LapackReal.gelsy' runs under 'withAutoWorkspaceInfo'.  That helper
-- supplies the trailing @work@, @lwork@ and @info@ arguments.
-- Failures are formatted with 'errorCodeMsg'.
gelsyReal :: (Class.Real a) => GELSY_ r a a
gelsyReal m n nrhs aPtr ldaPtr bPtr ldbPtr jpvtPtr rcond rankPtr = do
   rowsPtr <- Call.cint m
   colsPtr <- Call.cint n
   rhsCountPtr <- Call.cint nrhs
   thresholdPtr <- Call.real rcond
   let callGelsy =
          LapackReal.gelsy rowsPtr colsPtr rhsCountPtr
             aPtr ldaPtr bPtr ldbPtr jpvtPtr thresholdPtr rankPtr
   liftIO $ withAutoWorkspaceInfo errorCodeMsg "gelsy" callGelsy

-- | @gelsy@ for complex element types.
--
-- Like the real variant, except that the complex LAPACK routine also
-- needs a real workspace RWORK of length @2*n@.  That workspace is
-- allocated here.  'withAutoWorkspaceInfo' supplies @work@, @lwork@ and
-- @info@, and failures are formatted with 'errorCodeMsg'.
gelsyComplex :: (Class.Real a) => GELSY_ r a (Complex a)
gelsyComplex m n nrhs aPtr ldaPtr bPtr ldbPtr jpvtPtr rcond rankPtr = do
   rowsPtr <- Call.cint m
   colsPtr <- Call.cint n
   rhsCountPtr <- Call.cint nrhs
   thresholdPtr <- Call.real rcond
   realWorkPtr <- Call.allocaArray (2*n)
   let callGelsy workPtr lworkPtr infoPtr =
          LapackComplex.gelsy rowsPtr colsPtr rhsCountPtr
             aPtr ldaPtr bPtr ldbPtr jpvtPtr thresholdPtr rankPtr
             workPtr lworkPtr realWorkPtr infoPtr
   liftIO $ withAutoWorkspaceInfo errorCodeMsg "gelsy" callGelsy


-- | Moore-Penrose pseudo-inverse, computed by solving against an identity
-- matrix with 'leastSquaresMinimumNormRCond'.
--
-- For matrices that 'Basic.caseTallWide' classifies as tall, the
-- transposed system is solved against an identity of the width's size,
-- and the result is transposed back.  Otherwise the system is solved
-- against an identity of the height's size.  Either way the identity is
-- built on the smaller dimension.
--
-- @rcond@ is the rank threshold forwarded to @gelsy@.  The result pairs
-- the effective rank with the pseudo-inverse.
pseudoInverseRCond ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   RealOf a ->
   Full meas vert horiz height width a ->
   (Int, Full meas horiz vert width height a)
pseudoInverseRCond rcond a =
   either
      (\_tall ->
         mapSnd Basic.transpose $
         leastSquaresMinimumNormRCond rcond (Basic.transpose a) $
         Square.toFull $ Square.identity $ Layout.fullWidth shapeA)
      (\_wide ->
         leastSquaresMinimumNormRCond rcond a $
         Square.toFull $ Square.identity $ Layout.fullHeight shapeA)
      (Basic.caseTallWide a)
   where
      shapeA = Array.shape a


-- | Linear least squares with linear equality constraints, via LAPACK
-- @gglse@: find @x@ that minimizes the norm of @c - A·x@ subject to
-- @B·x = d@.
--
-- The following shapes must agree pairwise:
--
-- * the widths of @A@ and @B@;
-- * the height of @A@ and the shape of @c@;
-- * the height of @B@ and the shape of @d@.
--
-- A mismatch raises 'error' when that shape is first needed.
--
-- NOTE(review): LAPACK also requires @p <= n <= m + p@.  The 'Wide' type
-- of @B@ appears to cover @p <= n@.  A violation of @n <= m + p@ is
-- presumably reported through 'withAutoWorkspaceInfo'; confirm.
leastSquaresConstraint ::
   (Shape.C height, Eq height,
    Shape.C width, Eq width,
    Shape.C constraints, Eq constraints, Class.Floating a) =>
   General height width a -> Vector height a ->
   Wide constraints width a -> Vector constraints a ->
   Vector width a
leastSquaresConstraint
   (Array (Layout.Full orderA extentA) a) c
   (Array (Layout.Full orderB extentB) b) d =

 -- sameShape is used at three different shape types below.
 let sameShape name shape0 shape1 =
      if shape0 == shape1
         then shape0
         else error $ "leastSquaresConstraint: " ++ name ++ " shapes mismatch"
     width = sameShape "width" (Extent.width extentA) (Extent.width extentB)
 in
   Array.unsafeCreate width $ \xPtr -> do

   let height = sameShape "height" (Extent.height extentA) (Array.shape c)
   let constraints =
         sameShape "constraints" (Extent.height extentB) (Array.shape d)
   let m = Shape.size height
   let n = Shape.size width
   let p = Shape.size constraints
   evalContT $ do
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      pPtr <- Call.cint p
      -- gglse overwrites A, B, c and d, so each is copied to a temporary.
      -- The matrices go to column-major layout with leading dimensions
      -- m and p.
      aPtr <- copyToColumnMajorTemp orderA m n a
      ldaPtr <- Call.leadingDim m
      bPtr <- copyToColumnMajorTemp orderB p n b
      ldbPtr <- Call.leadingDim p
      cPtr <- copyToTemp m (Array.buffer c)
      dPtr <- copyToTemp p (Array.buffer d)
      liftIO $ withAutoWorkspaceInfo rankMsg "gglse" $
         LapackGen.gglse
            mPtr nPtr pPtr aPtr ldaPtr bPtr ldbPtr cPtr dPtr xPtr

-- | Solve the general Gauss-Markov linear model with LAPACK @ggglm@.
--
-- Finds @x@ and @y@ that minimize the norm of @y@ subject to
-- @d = A·x + B·y@, and returns @(x, y)@.
--
-- @A@ is tall.  If the heights of @A@, @B@ and @d@ do not all agree,
-- this calls 'error'.
--
-- NOTE(review): LAPACK also requires @m <= n <= m + p@ (in the naming
-- below).  A violation is presumably reported through
-- 'withAutoWorkspaceInfo'; confirm.
gaussMarkovLinearModel ::
   (Shape.C height, Eq height,
    Shape.C width, Eq width,
    Shape.C opt, Eq opt, Class.Floating a) =>
   Tall height width a -> General height opt a -> Vector height a ->
   (Vector width a, Vector opt a)
gaussMarkovLinearModel
   (Array (Layout.Full orderA extentA) a)
   (Array (Layout.Full orderB extentB) b) d =

   let width = Extent.width extentA in
   let opt = Extent.width extentB in
   -- m is the size of x (width of A); p is the size of y (width of B).
   Array.unsafeCreateWithSizeAndResult width $ \m xPtr -> do
   ArrayIO.unsafeCreateWithSize opt $ \p yPtr -> do

     -- This inner block is indented past the enclosing do-block, so the
     -- module parses without NondecreasingIndentation.
     let sameHeight shape0 shape1 =
           if shape0 == shape1
              then shape0
              else error "gaussMarkovLinearModel: height shapes mismatch"
         height =
           sameHeight (Array.shape d) $
           sameHeight (Extent.height extentA) (Extent.height extentB)
     -- n is LAPACK's N: the common number of rows of A, B and d.
     let n = Shape.size height
     evalContT $ do
        mPtr <- Call.cint m
        nPtr <- Call.cint n
        pPtr <- Call.cint p
        -- ggglm overwrites A, B and d, so all three are copied to
        -- temporaries.  The matrices are column-major with leading
        -- dimension n.
        aPtr <- copyToColumnMajorTemp orderA n m a
        ldaPtr <- Call.leadingDim n
        bPtr <- copyToColumnMajorTemp orderB n p b
        ldbPtr <- Call.leadingDim n
        dPtr <- copyToTemp n (Array.buffer d)
        -- LAPACK's argument order is N, M, P.
        liftIO $ withAutoWorkspaceInfo rankMsg "ggglm" $
           LapackGen.ggglm
              nPtr mPtr pPtr aPtr ldaPtr bPtr ldbPtr dPtr xPtr yPtr


-- | Absolute value of a determinant-like quantity for a full matrix.
--
-- * For matrices that 'Basic.caseTallWide' classifies as tall, this is the
--   absolute determinant of the factor produced by 'HH.fromMatrix', taken
--   with 'HH.determinantR'.
-- * For matrices it classifies as wide, the result is zero.
--
-- NOTE(review): for square input this presumably equals @|det A|@, since
-- the orthogonal QR factor has a determinant of modulus one.
determinantAbsolute ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full meas vert horiz height width a -> RealOf a
determinantAbsolute matrix =
   absolute $
   case Basic.caseTallWide matrix of
      Left tall -> HH.determinantR $ HH.fromMatrix tall
      Right _wide -> zero


-- | Factor a tall matrix with 'HH.fromMatrix', then take the columns of
-- its @Q@ factor that 'extractComplement' selects.
complement ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Tall height width a -> Tall height ShapeInt a
complement tall = extractComplement (HH.fromMatrix tall)

-- | Take the full square @Q@ of a Householder factorization and drop its
-- first @k@ columns, where @k@ is the size of the split width of the
-- factorization.  Column indices are renumbered as a 'ShapeInt'.
--
-- NOTE(review): the remaining columns presumably form an orthonormal
-- basis of the orthogonal complement of the factored matrix's column
-- space, when that matrix has full column rank.  Confirm.
extractComplement ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   HH.Tall height width a -> Tall height ShapeInt a
extractComplement qr = Basic.dropColumns numSpanned fullQ
   where
      numSpanned =
         Shape.size $ Layout.splitWidth $ Array.shape $ HH.split_ qr
      fullQ =
         Basic.mapWidth (shapeInt . Shape.size) $ Square.toFull $
         HH.extractQ qr