{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.Square.Linear (
   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Linear.Private
         (solver, withDeterminantInfo, withInfo, diagonalMsg)
import Numeric.LAPACK.Matrix.Layout.Private (transposeFromOrder)
import Numeric.LAPACK.Matrix.Private (Full, Square, SquareMeas, argSquare)
import Numeric.LAPACK.Private
         (withAutoWorkspaceInfo, copyBlock, copyToTemp, copyToColumnMajorTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.Marshal.Array (peekArray)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)


-- | Solve the square linear system @A·X = B@ for @X@.
--
-- @A@ is LU-factorized with LAPACK @getrf@. The factors are then handed to
-- @getrs@, which overwrites the right-hand sides in place with the solution.
-- Non-zero @info@ results of either routine are handled by 'withInfo'.
--
-- @_solve@ (not exported) implements the same contract via the single
-- driver routine @gesv@.
solve, _solve ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Square sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve =
   argSquare $ \orderA shA a ->
   solver "Square.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- getrf overwrites its matrix argument with the LU factors, so it is
      -- given a temporary copy of A rather than A's own storage.
      luPtr <- copyToTemp (n*n) a
      ldluPtr <- Call.leadingDim n
      pivotPtr <- Call.allocaArray n
      -- The copy is taken without reordering, so the trans flag for getrs
      -- is derived from A's storage order instead.
      -- NOTE(review): presumably this compensates for a row-major buffer
      -- looking like A transposed to column-major LAPACK — confirm.
      transPtr <- Call.char (transposeFromOrder orderA)
      let factorize =
            withInfo "getrf" $
               LapackGen.getrf nPtr nPtr luPtr ldluPtr pivotPtr
          substitute =
            withInfo "getrs" $
               LapackGen.getrs transPtr nPtr nrhsPtr
                  luPtr ldluPtr pivotPtr xPtr ldxPtr
      liftIO (factorize >> substitute)

_solve =
   argSquare $ \orderA shA a ->
   solver "Square.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- gesv factorizes its matrix argument in place, so it receives a
      -- temporary copy of A rearranged into column-major order.
      luPtr <- copyToColumnMajorTemp orderA n n a
      ldluPtr <- Call.leadingDim n
      pivotPtr <- Call.allocaArray n
      -- One driver call performs both the LU factorization and the solve.
      liftIO . withInfo "gesv" $
         LapackGen.gesv nPtr nrhsPtr luPtr ldluPtr pivotPtr xPtr ldxPtr


-- | Invert a square matrix.
--
-- The input is copied into the freshly allocated result buffer. That buffer
-- is LU-factorized in place by LAPACK @getrf@, and @getri@ then turns the
-- factors into the inverse. As the type shows, the result has the height
-- and width shapes swapped. The result layout is produced by
-- 'Layout.inverse'.
--
-- Non-zero @info@ from @getrf@ goes through 'withInfo'. Non-zero @info@
-- from @getri@ goes through 'withAutoWorkspaceInfo' with 'diagonalMsg'.
inverse ::
   (Extent.Measure meas, Shape.C height, Shape.C width, Class.Floating a) =>
   SquareMeas meas height width a -> SquareMeas meas width height a
inverse (Array shape@(Layout.Full _order extent) a) =
   Array.unsafeCreateWithSize (Layout.inverse shape) $ \blockSize invPtr ->
   evalContT $ do
      srcPtr <- ContT $ withForeignPtr a
      nPtr <- Call.cint n
      ldinvPtr <- Call.leadingDim n
      pivotPtr <- Call.allocaArray n
      -- A 0×0 matrix leaves nothing to copy or invert, so LAPACK is not
      -- called at all in that case.
      -- NOTE(review): presumably this also avoids getri's workspace handling
      -- for n = 0 — confirm.
      liftIO $ when (0 < n) $ do
         copyBlock blockSize srcPtr invPtr
         withInfo "getrf" $
            LapackGen.getrf nPtr nPtr invPtr ldinvPtr pivotPtr
         withAutoWorkspaceInfo diagonalMsg "getri" $
            LapackGen.getri nPtr invPtr ldinvPtr pivotPtr
   where
      -- Order of the square matrix (number of rows = number of columns).
      n = Shape.size $ Extent.height extent


-- | Determinant of a square matrix.
--
-- The matrix is LU-factorized with LAPACK @getrf@. The result is the
-- product of the diagonal of the factors, passed through 'Perm.condNegate'
-- together with the pivot vector. That call presumably flips the sign when
-- the row interchanges form an odd permutation (TODO confirm).
--
-- The storage order is ignored, which is sound because
-- @det (transpose A) = det A@.
--
-- NOTE(review): 'unsafePerformIO' is presumably safe here because the IO
-- only factorizes the temporary copy made by 'copyToTemp' and reads scratch
-- buffers, so the result depends on the argument alone. This relies on
-- 'copyToTemp' not aliasing the input array — confirm.
-- NOTE(review): 'withDeterminantInfo' presumably yields zero instead of
-- running the continuation when getrf reports a singular matrix — confirm.
determinant :: (Shape.C sh, Class.Floating a) => Square sh a -> a
determinant = argSquare $ \_order sh a ->
   unsafePerformIO $ evalContT $ do
      let n = Shape.size sh
      luPtr <- copyToTemp (n*n) a
      nPtr <- Call.cint n
      ldluPtr <- Call.leadingDim n
      pivotPtr <- Call.allocaArray n
      -- Private.product is called with count n and stride n+1. Offsets
      -- 0, n+1, 2(n+1), … of an n×n buffer are exactly its diagonal
      -- entries, in either storage order.
      let diagonalProduct = Private.product n luPtr (n+1)
          signedDeterminant =
            flip Perm.condNegate
               <$> diagonalProduct
               <*> peekArray n pivotPtr
      liftIO $
         withDeterminantInfo "getrf"
            (LapackGen.getrf nPtr nPtr luPtr ldluPtr pivotPtr)
            signedDeterminant