{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.Banded.Linear (
   solve,
   solveColumnMajor,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Linear.Private (solver, withDeterminantInfo, withInfo)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), transposeFromOrder)
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Private (copySubMatrix)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num (integralFromProxy)

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.Marshal.Array (peekArray, advancePtr)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Solve a linear system whose coefficient matrix is a square banded matrix.
--
-- The banded matrix is copied into a temporary work array,
-- factored with LAPACK's @gbtrf@ and the factorization is then applied
-- to the right-hand sides with @gbtrs@.
-- Both storage orders are accepted:
-- the order selects the transposition flag passed to @gbtrs@
-- (via 'transposeFromOrder') and the assignment of sub- and
-- super-diagonal counts (via 'MatrixShape.numOffDiagonals').
--
-- Right-hand side allocation and shape checking are delegated to 'solver'.
solve ::
   (Unary.Natural sub, Unary.Natural super, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Banded.Square sub super sh a ->
   Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solve (Array (MatrixShape.Banded numOff order extent) a) =
   solver "Banded.solve" (Extent.squareSize extent) $
         \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- Numbers of sub- and super-diagonals as seen by LAPACK
      -- for the given storage order.
      let (kl,ku) = MatrixShape.numOffDiagonals order numOff
      -- Number of stored diagonals of the input band.
      let k = kl+1+ku
      -- gbtrf needs kl additional rows above the band
      -- for fill-in caused by pivoting (LDAB >= 2*KL+KU+1).
      let ldab = kl+k
      transPtr <- Call.char $ transposeFromOrder order
      klPtr <- Call.cint kl
      kuPtr <- Call.cint ku
      aPtr <- ContT $ withForeignPtr a
      abPtr <- Call.allocaArray (n*ldab)
      ldabPtr <- Call.leadingDim ldab
      ipivPtr <- Call.allocaArray n
      liftIO $ do
         -- Place the band below the kl fill-in rows of the work array.
         -- NOTE(review): argument order of copySubMatrix is presumably
         -- rows, columns, source leading dim, source,
         -- destination leading dim, destination - confirm.
         copySubMatrix k n k aPtr ldab (advancePtr abPtr kl)
         withInfo "gbtrf" $
            LapackGen.gbtrf nPtr nPtr klPtr kuPtr abPtr ldabPtr ipivPtr
         withInfo "gbtrs" $
            LapackGen.gbtrs transPtr nPtr klPtr kuPtr nrhsPtr
               abPtr ldabPtr ipivPtr xPtr ldxPtr

-- | Variant of the banded solver that handles column-major storage only,
-- using LAPACK's combined driver @gbsv@
-- (factorization and solution in a single call).
--
-- Calling it with a row-major matrix raises an 'error';
-- that case is deliberately not implemented.
--
-- NOTE(review): the label passed to 'solver' is @"Banded.solve"@,
-- not @"Banded.solveColumnMajor"@.
-- It is kept unchanged here because it is not visible from this module
-- how 'solver' uses it - confirm whether it should name this function.
solveColumnMajor ::
   (Unary.Natural sub, Unary.Natural super, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Banded.Square sub super sh a ->
   Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solveColumnMajor
      (Array (MatrixShape.Banded (sub,super) ColumnMajor extent) a) =
   solver "Banded.solve" (Extent.squareSize extent) $
         \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- In column-major order the type-level diagonal counts
      -- are used directly as LAPACK's KL and KU.
      let kl = integralFromProxy sub
      let ku = integralFromProxy super
      -- Number of stored diagonals of the input band.
      let k = kl+1+ku
      -- gbsv needs kl additional rows above the band
      -- for fill-in caused by pivoting (LDAB >= 2*KL+KU+1).
      let ldab = kl+k
      klPtr <- Call.cint kl
      kuPtr <- Call.cint ku
      aPtr <- ContT $ withForeignPtr a
      abPtr <- Call.allocaArray (n*ldab)
      ldabPtr <- Call.leadingDim ldab
      ipivPtr <- Call.allocaArray n
      liftIO $ do
         -- Place the band below the kl fill-in rows of the work array.
         copySubMatrix k n k aPtr ldab (advancePtr abPtr kl)
         withInfo "gbsv" $
            LapackGen.gbsv nPtr klPtr kuPtr nrhsPtr
               abPtr ldabPtr ipivPtr xPtr ldxPtr
solveColumnMajor (Array (MatrixShape.Banded _ RowMajor _) _) =
   error "Linear.Banded.solveColumnMajor: RowMajor intentionally unimplemented"

-- | Determinant of a square banded matrix.
--
-- The matrix is copied into a temporary work array and factored with
-- LAPACK's @gbtrf@.
-- The result is the product of the @n@ elements read with stride @ldab@
-- starting at offset @kl+ku@ of the factored array
-- (the main diagonal of the factor in LAPACK band storage),
-- post-processed by 'Perm.condNegate' with the pivot indices.
--
-- NOTE(review): 'withDeterminantInfo' presumably handles the
-- singular-matrix result of @gbtrf@ instead of running the second action,
-- and 'Perm.condNegate' presumably flips the sign according to the parity
-- of the row interchanges - confirm against their definitions.
determinant ::
   (Unary.Natural sub, Unary.Natural super, Shape.C sh, Class.Floating a) =>
   Banded.Square sub super sh a -> a
determinant (Array (MatrixShape.Banded numOff order extent) a) =
      -- unsafePerformIO: all buffers written below are allocated locally;
      -- the input array is only passed as the source of copySubMatrix
      -- (presumably read-only - confirm), so the result should depend
      -- on the argument alone.
      unsafePerformIO $ do
   let n = Shape.size $ Extent.squareSize extent
   evalContT $ do
      -- Numbers of sub- and super-diagonals as seen by LAPACK
      -- for the given storage order.
      let (kl,ku) = MatrixShape.numOffDiagonals order numOff
      -- Number of stored diagonals of the input band.
      let k = kl+1+ku
      -- gbtrf needs kl additional rows above the band
      -- for fill-in caused by pivoting (LDAB >= 2*KL+KU+1).
      let ldab = kl+k
      nPtr <- Call.cint n
      klPtr <- Call.cint kl
      kuPtr <- Call.cint ku
      aPtr <- ContT $ withForeignPtr a
      abPtr <- Call.allocaArray (n*ldab)
      ldabPtr <- Call.leadingDim ldab
      ipivPtr <- Call.allocaArray n
      liftIO $ do
         -- Place the band below the kl fill-in rows of the work array.
         copySubMatrix k n k aPtr ldab (advancePtr abPtr kl)
         withDeterminantInfo "gbtrf"
            (LapackGen.gbtrf nPtr nPtr klPtr kuPtr abPtr ldabPtr ipivPtr)
            (do
               -- Walk the diagonal of the factor:
               -- it lives in row kl+ku (zero-based) of every column.
               det <- Private.product n (advancePtr abPtr (kl+ku)) ldab
               ipiv <- peekArray n ipivPtr
               return $ Perm.condNegate ipiv det)