{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.Square.Linear (
   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Linear.Private
         (solver, withDeterminantInfo, withInfo, diagonalMsg)
import Numeric.LAPACK.Matrix.Shape.Private (transposeFromOrder)
import Numeric.LAPACK.Matrix.Private (Full, Square, argSquare)
import Numeric.LAPACK.Private
         (withAutoWorkspaceInfo, copyBlock, copyToTemp, copyToColumnMajorTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.Marshal.Array (peekArray)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)


-- | Solve the linear system @A·X = B@ for @X@, where @A@ is square.
--
-- The right-hand side @B@ may have any number of columns (@nrhs@); the
-- result has the same type and shape as @B@.  The height of @B@ is related
-- to the size of @A@ by 'solver' (NOTE(review): presumably checked there
-- using the 'Eq' constraint on @sh@).
--
-- Implementation: @A@ is LU-factorized with LAPACK @getrf@ on a temporary
-- copy, then the system is solved with @getrs@.  '_solve' is an alternative
-- implementation of the same contract based on the combined driver @gesv@.
solve, _solve ::
   (Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Square sh a -> Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solve =
   argSquare $ \orderA shA a ->
   solver "Square.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- The matrix buffer is handed to LAPACK as-is, whatever its storage
      -- order.  The transpose flag passed to getrs is derived from that
      -- order, so that getrs compensates for the buffer being read in
      -- LAPACK's convention.  NOTE(review): presumably 'T' for row-major.
      transPtr <- Call.char $ transposeFromOrder orderA
      -- getrf overwrites its matrix argument with the LU factors,
      -- so it works on a temporary copy of A.
      aPtr <- copyToTemp (n*n) a
      ldaPtr <- Call.leadingDim n
      -- pivot indices produced by getrf and consumed by getrs
      ipivPtr <- Call.allocaArray n
      liftIO $ do
         withInfo "getrf" $
            LapackGen.getrf nPtr nPtr aPtr ldaPtr ipivPtr
         withInfo "getrs" $
            LapackGen.getrs transPtr nPtr nrhsPtr
               aPtr ldaPtr ipivPtr xPtr ldxPtr

-- Alternative implementation of 'solve' using the LAPACK driver @gesv@,
-- which performs the factorization and the solution in one call.
-- @gesv@ takes no transpose argument, so @A@ is first copied into a
-- column-major temporary, whatever its original storage order.
-- Not exported; its type is declared together with 'solve'.
_solve =
   argSquare $ \orderA shA a ->
   solver "Square.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      aPtr <- copyToColumnMajorTemp orderA n n a
      ldaPtr <- Call.leadingDim n
      ipivPtr <- Call.allocaArray n
      liftIO $
         withInfo "gesv" $
            LapackGen.gesv nPtr nrhsPtr aPtr ldaPtr ipivPtr xPtr ldxPtr


-- | Inverse of a square matrix.
--
-- The input is copied into the freshly created result buffer.  That buffer
-- is LU-factorized in place with LAPACK @getrf@ and then inverted in place
-- with @getri@.
--
-- The result keeps the input's shape, including its storage order.  This is
-- sound for either order, because inverting the transpose yields the
-- transpose of the inverse.
--
-- Nonzero LAPACK info codes are handled by 'withInfo' and
-- 'withAutoWorkspaceInfo' (the latter with 'diagonalMsg').
-- NOTE(review): presumably these raise an error for singular input.
inverse :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
inverse (Array shape@(MatrixShape.Full _order extent) a) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> do
      let n = Shape.size $ Extent.squareSize extent
      evalContT $ do
         nPtr <- Call.cint n
         -- The input is only read (copied into bPtr); LAPACK never sees it.
         aPtr <- ContT $ withForeignPtr a
         ldbPtr <- Call.leadingDim n
         ipivPtr <- Call.allocaArray n
         -- An empty matrix has nothing to copy or invert, so LAPACK is
         -- only called for n > 0.
         liftIO $ when (n>0) $ do
            copyBlock blockSize aPtr bPtr
            withInfo "getrf" $
               LapackGen.getrf nPtr nPtr bPtr ldbPtr ipivPtr
            withAutoWorkspaceInfo diagonalMsg "getri" $
               LapackGen.getri nPtr bPtr ldbPtr ipivPtr


-- | Determinant of a square matrix.
--
-- The determinant is computed from an LU factorization (@getrf@) of a
-- temporary copy of the matrix:
--
-- * take the product of the diagonal of the factorized buffer;
-- * adjust its sign from the pivot vector via 'Perm.condNegate'.
--
-- The storage order is ignored, since @det A = det (transpose A)@.
-- The outcome of @getrf@ is routed through 'withDeterminantInfo'.
determinant :: (Shape.C sh, Class.Floating a) => Square sh a -> a
determinant = argSquare $ \_order sh a -> unsafePerformIO $ do
   -- 'unsafePerformIO' justification: LAPACK only writes to the copy made
   -- by 'copyToTemp' and to locally allocated buffers, so the result
   -- depends only on the input.  NOTE(review): assumes 'copyToTemp'
   -- allocates a fresh buffer, as its name suggests — confirm.
   let n = Shape.size sh
   evalContT $ do
      nPtr <- Call.cint n
      aPtr <- copyToTemp (n*n) a
      ldaPtr <- Call.leadingDim n
      ipivPtr <- Call.allocaArray n
      liftIO $ withDeterminantInfo "getrf"
         (LapackGen.getrf nPtr nPtr aPtr ldaPtr ipivPtr)
         (do
            -- stride n+1 visits the main diagonal of the n×n buffer
            det <- Private.product n aPtr (n+1)
            ipiv <- peekArray n ipivPtr
            return $ Perm.condNegate ipiv det)