{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Private where
import Numeric.LAPACK.Matrix.Shape.Private
(Order(RowMajor, ColumnMajor), transposeFromOrder)
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))
import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Numeric.LAPACK.Scalar (zero, one, isZero)
import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.C.String as CStr
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CChar, CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)
import Text.Printf (printf)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT, runContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when, foldM)
import Control.Applicative ((<$>))
import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import qualified Data.Complex as Complex
import Data.Complex (Complex)
import Data.Tuple.HT (swap)
import Prelude hiding (sum)
-- | Store the value @a@ in @n@ consecutive elements beginning at @dstPtr@.
--
-- The work is delegated to 'BlasGen.copy' with a source increment of 0 and a
-- destination increment of 1, so the one temporary holding @a@ serves as the
-- source for every destination element.
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr =
   evalContT $ do
      countPtr <- Call.cint n
      valuePtr <- Call.number a
      srcIncPtr <- Call.cint 0
      dstIncPtr <- Call.cint 1
      liftIO (BlasGen.copy countPtr valuePtr srcIncPtr dstPtr dstIncPtr)
-- | Copy @n@ contiguous elements from @srcPtr@ to @dstPtr@
-- using 'BlasGen.copy' with both increments set to 1.
copyBlock :: (Class.Floating a) => Int -> Ptr a -> Ptr a -> IO ()
copyBlock n srcPtr dstPtr =
   evalContT $ do
      lenPtr <- Call.cint n
      srcIncPtr <- Call.cint 1
      dstIncPtr <- Call.cint 1
      liftIO (BlasGen.copy lenPtr srcPtr srcIncPtr dstPtr dstIncPtr)
-- | Allocate a temporary array of @n@ elements, copy the first @n@ elements
-- of @fptr@ into it and hand the temporary pointer to the continuation.
--
-- Both the foreign pointer and the temporary array are obtained inside the
-- surrounding 'ContT' computation.
copyToTemp :: (Storable a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyToTemp n fptr =
   ContT (withForeignPtr fptr) >>= \srcPtr ->
   Call.allocaArray n >>= \scratchPtr ->
   liftIO (copyArray scratchPtr srcPtr n) >>
   return scratchPtr
-- | Wrapper that lets 'Class.switchFloating' select an implementation of
-- type @ForeignPtr a -> ContT r IO (Ptr a)@ per element type.
newtype CopyToTemp r a =
   CopyToTemp {runCopyToTemp :: ForeignPtr a -> ContT r IO (Ptr a)}

-- | Provide a pointer to the complex conjugate of an @n@-element array.
--
-- For the two real element types the original storage is exposed directly
-- via 'withForeignPtr' (no copy is made); for the two complex element types
-- the data is conjugated into a temporary by 'complexConjugateToTemp'.
conjugateToTemp ::
   (Class.Floating a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
conjugateToTemp n =
   runCopyToTemp
      (Class.switchFloating
         (CopyToTemp borrow)
         (CopyToTemp borrow)
         (CopyToTemp (complexConjugateToTemp n))
         (CopyToTemp (complexConjugateToTemp n)))
   where
      borrow :: ForeignPtr b -> ContT s IO (Ptr b)
      borrow fptr = ContT (withForeignPtr fptr)
-- | Copy @n@ complex elements of @x@ into a temporary array (see 'copyToTemp')
-- and run 'LapackComplex.lacgv' with increment 1 on that temporary.
-- The pointer to the temporary is returned; @x@ itself is not written to.
complexConjugateToTemp ::
   Class.Real a =>
   Int -> ForeignPtr (Complex a) -> ContT r IO (Ptr (Complex a))
complexConjugateToTemp n x = do
   lenPtr <- Call.cint n
   tmpPtr <- copyToTemp n x
   stridePtr <- Call.cint 1
   liftIO (LapackComplex.lacgv lenPtr tmpPtr stridePtr)
   return tmpPtr
-- | Copy a strided vector with 'BlasGen.copy' and then apply 'lacgv'
-- to the destination vector.
--
-- All sizes and increments are passed as pointers, as the underlying
-- routines expect them.
copyConjugate ::
   (Class.Floating a) =>
   Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyConjugate nPtr xPtr incxPtr yPtr incyPtr =
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
   >>
   lacgv nPtr yPtr incyPtr
-- | Copy a strided vector with 'BlasGen.copy';
-- if @conj@ is 'True', additionally apply 'lacgv' to the destination.
copyCondConjugate ::
   (Class.Floating a) =>
   Bool -> Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyCondConjugate conj nPtr xPtr incxPtr yPtr incyPtr =
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr >> postprocess
   where
      postprocess
         | conj = lacgv nPtr yPtr incyPtr
         | otherwise = return ()
-- | With @conj@ set, behave like 'conjugateToTemp';
-- otherwise expose the storage of @x@ unchanged via 'withForeignPtr'.
condConjugateToTemp ::
   (Class.Floating a) =>
   Bool -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
condConjugateToTemp conj n x
   | conj = conjugateToTemp n x
   | otherwise = ContT (withForeignPtr x)
-- | Allocate a temporary array of @n@ elements and fill it with a copy of
-- @a@, conjugated if @conj@ is 'True' (see 'copyCondConjugate').
--
-- Unlike 'condConjugateToTemp' the result is always a fresh temporary.
-- The foreign pointer is only held while the copy is running.
copyCondConjugateToTemp ::
   (Class.Floating a) =>
   Bool -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyCondConjugateToTemp conj n a = do
   tmpPtr <- Call.allocaArray n
   liftIO (fillFrom tmpPtr)
   return tmpPtr
   where
      fillFrom dstPtr =
         withForeignPtr a $ \srcPtr ->
            evalContT $ do
               lenPtr <- Call.cint n
               unitPtr <- Call.cint 1
               liftIO $
                  copyCondConjugate conj lenPtr srcPtr unitPtr dstPtr unitPtr
-- | Copy a full rectangular block between two matrices with independent
-- leading dimensions. This is 'copySubTrapezoid' with the part selector @\'A\'@.
copySubMatrix ::
   (Class.Floating a) =>
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubMatrix m n lda aPtr ldb bPtr =
   copySubTrapezoid 'A' m n lda aPtr ldb bPtr
-- | Thin wrapper around 'LapackGen.lacpy'.
--
-- @side@ is passed as the first (character) argument of @lacpy@,
-- @m@ and @n@ are the block dimensions, and @lda@ / @ldb@ are the leading
-- dimensions of the source @aPtr@ and the destination @bPtr@.
copySubTrapezoid ::
   (Class.Floating a) =>
   Char -> Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubTrapezoid side m n lda aPtr ldb bPtr =
   evalContT $ do
      partPtr <- Call.char side
      rowsPtr <- Call.cint m
      colsPtr <- Call.cint n
      srcLdPtr <- Call.leadingDim lda
      dstLdPtr <- Call.leadingDim ldb
      liftIO $
         LapackGen.lacpy partPtr rowsPtr colsPtr aPtr srcLdPtr bPtr dstLdPtr
-- | Transposing copy built from @m@ strided BLAS copies of @n@ elements each.
--
-- The @k@-th copy reads from @aPtr + k@ with stride @m@ and writes
-- contiguously starting at @bPtr + k*ldb@. Note the argument order:
-- the first 'Int' is @n@, the second is @m@.
copyTransposed ::
   (Class.Floating a) =>
   Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyTransposed n m aPtr ldb bPtr =
   evalContT $ do
      lenPtr <- Call.cint n
      srcStridePtr <- Call.cint m
      dstStridePtr <- Call.cint 1
      let copyLine (srcPtr, dstPtr) =
             BlasGen.copy lenPtr srcPtr srcStridePtr dstPtr dstStridePtr
      liftIO $ mapM_ copyLine $ take m $
         zip (pointerSeq 1 aPtr) (pointerSeq ldb bPtr)
-- | Copy a matrix given in the stated 'Order' into a destination buffer.
--
-- 'RowMajor' input goes through 'copyTransposed' with destination leading
-- dimension @m@; 'ColumnMajor' input is copied verbatim as one block of
-- @m*n@ elements.
copyToColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyToColumnMajor RowMajor m n aPtr bPtr =
   copyTransposed m n aPtr m bPtr
copyToColumnMajor ColumnMajor m n aPtr bPtr =
   copyBlock (m*n) aPtr bPtr
-- | Like 'copyToColumnMajor', but the destination has its own leading
-- dimension @ldb@.
--
-- 'RowMajor' input is handled by 'copyTransposed'. For 'ColumnMajor' input
-- a single block copy is used when @m == ldb@, and 'copySubMatrix' otherwise.
copyToSubColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyToSubColumnMajor RowMajor m n aPtr ldb bPtr =
   copyTransposed m n aPtr ldb bPtr
copyToSubColumnMajor ColumnMajor m n aPtr ldb bPtr
   | m == ldb = copyBlock (m*n) aPtr bPtr
   | otherwise = copySubMatrix m n m aPtr ldb bPtr
-- | Infinite list of pointers starting at @ptr@, each one advanced by
-- @k@ elements relative to its predecessor.
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k = iterate (`advancePtr` k)
-- | Create an array of shape @shapeX@ that is filled by @act@ and return
-- the result of @act@ together with the array.
--
-- If @m > n@, @act@ receives a temporary buffer of @m*nrhs@ elements with
-- leading dimension @m@; afterwards an @n@-by-@nrhs@ block of it is copied
-- into the result array (leading dimension @n@) via 'copySubMatrix'.
-- Otherwise @act@ writes straight into the result array with leading
-- dimension @n@.
createHigherArray ::
   (Shape.C sh, Class.Floating a) =>
   sh -> Int -> Int -> Int ->
   ((Ptr a, Int) -> IO rank) -> IO (rank, Array sh a)
createHigherArray shapeX m n nrhs act =
   swap <$>
      ArrayIO.unsafeCreateWithSizeAndResult shapeX
         (\ _ xPtr -> solveInto xPtr)
   where
      solveInto xPtr
         | m > n =
            runContT (Call.allocaArray (m*nrhs)) $ \scratchPtr -> do
               result <- act (scratchPtr, m)
               copySubMatrix n nrhs m scratchPtr n xPtr
               return result
         | otherwise = act (xPtr, n)
-- | Wrapper that lets 'Class.switchFloating' pick a summation routine
-- per element type.
newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

-- | Sum of @n@ elements read from a pointer with the given increment.
-- Real element types use 'sumReal', complex ones use 'sumComplex'.
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum =
   runSum
      (Class.switchFloating
         (Sum sumReal)
         (Sum sumReal)
         (Sum sumComplex)
         (Sum sumComplex))
-- | Sum @n@ real elements of a strided vector.
--
-- Computed as the 'BlasReal.dot' product of the vector with a single
-- element 'one' that is addressed with increment 0.
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx =
   evalContT $ do
      lenPtr <- Call.cint n
      stridePtr <- Call.cint incx
      unitPtr <- Call.real one
      fixedPtr <- Call.cint 0
      liftIO (BlasReal.dot lenPtr xPtr stridePtr unitPtr fixedPtr)
-- | Sum @n@ complex elements of a strided vector.
--
-- A temporary array of @n@ ones is created (a 'BlasGen.copy' of 'one' with
-- source increment 0) and used as a 1-by-@n@ matrix that is multiplied with
-- the input vector through 'gemv' (alpha = 'one', beta = 'zero').
-- The single element of the product is the result.
sumComplex :: Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   evalContT $ do
      transPtr <- Call.char 'N'
      rowsPtr <- Call.cint 1
      colsPtr <- Call.cint n
      alphaPtr <- Call.number one
      onePtr <- Call.number one
      zeroIncPtr <- Call.cint 0
      onesPtr <- Call.allocaArray n
      ldOnesPtr <- Call.leadingDim 1
      incxPtr <- Call.cint incx
      betaPtr <- Call.number zero
      resultPtr <- Call.alloca
      unitIncPtr <- Call.cint 1
      liftIO $
         BlasGen.copy colsPtr onePtr zeroIncPtr onesPtr unitIncPtr
         >>
         gemv
            transPtr rowsPtr colsPtr alphaPtr onesPtr ldOnesPtr
            xPtr incxPtr betaPtr resultPtr unitIncPtr
         >>
         peek resultPtr
-- | Product of @n@ elements read from @xPtr@ with increment @incx@.
--
-- The elements are read one by one and multiplied into a strictly
-- evaluated accumulator that starts at 'one'.
product :: Class.Floating a => Int -> Ptr a -> Int -> IO a
product n xPtr incx =
   foldM step one (take n (pointerSeq incx xPtr))
   where
      step acc ptr = do
         y <- peek ptr
         return $! acc*y
-- | Wrapper that lets 'Class.switchFloating' pick a conjugation routine
-- per element type.
newtype LACGV a = LACGV {getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()}

-- | Conjugate a strided vector in place.
--
-- For real element types this ignores its arguments and does nothing;
-- for complex element types it calls 'LapackComplex.lacgv'.
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv =
   getLACGV
      (Class.switchFloating
         (LACGV skip)
         (LACGV skip)
         (LACGV LapackComplex.lacgv)
         (LACGV LapackComplex.lacgv))
   where
      skip :: Ptr CInt -> Ptr b -> Ptr CInt -> IO ()
      skip _ _ _ = return ()
{-# INLINE gemv #-}
-- | 'BlasGen.gemv' preceded by 'initializeMV' on the output vector.
-- The argument list is the same as that of 'BlasGen.gemv'.
gemv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
gemv transPtr mPtr nPtr alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   >>
   BlasGen.gemv
      transPtr mPtr nPtr alphaPtr aPtr ldaPtr
      xPtr incxPtr betaPtr yPtr incyPtr
{-# INLINE gbmv #-}
-- | 'BlasGen.gbmv' preceded by 'initializeMV' on the output vector.
-- The argument list is the same as that of 'BlasGen.gbmv'.
gbmv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gbmv transPtr mPtr nPtr klPtr kuPtr alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   >>
   BlasGen.gbmv
      transPtr mPtr nPtr klPtr kuPtr alphaPtr aPtr ldaPtr
      xPtr incxPtr betaPtr yPtr incyPtr
-- | Prepare the output vector of a matrix-vector product.
--
-- Depending on whether the transposition flag equals @\'N\'@, one of
-- @mPtr@ / @nPtr@ is the output length and the other the input length.
-- If the input length is 0 and @beta@ is zero, the output vector is
-- overwritten with @beta@ (a 'BlasGen.copy' with source increment 0).
-- In all other cases nothing is written.
--
-- NOTE(review): presumably this compensates for BLAS returning early
-- without touching @y@ when a dimension is zero -- confirm against the
-- BLAS reference.
initializeMV ::
   Class.Floating a =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr = do
   transCode <- peek transPtr
   let untransposed = transCode == CStr.castCharToCChar 'N'
       outLenPtr = if untransposed then mPtr else nPtr
       inLenPtr = if untransposed then nPtr else mPtr
   inLen <- peek inLenPtr
   beta <- peek betaPtr
   when (inLen == 0 && isZero beta) $
      Marshal.with 0 $ \zeroIncPtr ->
         BlasGen.copy outLenPtr betaPtr zeroIncPtr yPtr incyPtr
-- | Multiply two matrices with 'BlasGen.gemm' (alpha = 'one', beta = 'zero')
-- and store the product at @cPtr@ with leading dimension @m@.
--
-- The dimensions are passed to @gemm@ as @m@, @n@, @k@. The transposition
-- flags come from 'transposeFromOrder'. The leading dimensions are
-- @k@ (row major) or @m@ (column major) for @a@, and
-- @n@ (row major) or @k@ (column major) for @b@.
multiplyMatrix ::
   (Class.Floating a) =>
   Order -> Order -> Int -> Int -> Int ->
   ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyMatrix orderA orderB m k n a b cPtr =
   evalContT $ do
      transaPtr <- Call.char (transposeFromOrder orderA)
      transbPtr <- Call.char (transposeFromOrder orderB)
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number one
      aPtr <- ContT (withForeignPtr a)
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT (withForeignPtr b)
      ldbPtr <- Call.leadingDim ldb
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim m
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr
            alphaPtr aPtr ldaPtr bPtr ldbPtr
            betaPtr cPtr ldcPtr
   where
      selectDim order rowMajorDim columnMajorDim =
         case order of
            RowMajor -> rowMajorDim
            ColumnMajor -> columnMajorDim
      lda = selectDim orderA k m
      ldb = selectDim orderB n k
-- | Combination of 'withInfo' and 'withAutoWorkspace': the computation
-- receives the workspace pointer, the workspace size pointer and the
-- info pointer, in that order. @msg@ and @name@ are forwarded to 'withInfo'.
withAutoWorkspaceInfo ::
   (Class.Floating a) =>
   String -> String -> (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspaceInfo msg name computation =
   withInfo msg name runWithInfo
   where
      runWithInfo infoPtr =
         withAutoWorkspace
            (\workPtr lworkPtr -> computation workPtr lworkPtr infoPtr)
-- | Run a computation twice: first as a workspace size query, then for real.
--
-- In the first run the size argument is -1 and the workspace is a single
-- element; the value the computation stores there is rounded up with
-- 'ceilingSize' (and raised to at least 1) to give the workspace size.
-- The second run gets a temporary array of that size and the size itself.
withAutoWorkspace ::
   (Class.Floating a) =>
   (Ptr a -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspace computation =
   evalContT $ do
      lworkPtr <- Call.cint (-1)
      lwork <- liftIO (queryWorkspaceSize lworkPtr)
      workPtr <- Call.allocaArray lwork
      liftIO $ do
         pokeCInt lworkPtr lwork
         computation workPtr lworkPtr
   where
      queryWorkspaceSize lworkPtr =
         alloca $ \probePtr -> do
            computation probePtr lworkPtr
            reported <- peek probePtr
            return (max 1 (ceilingSize reported))
-- | Run a computation with a freshly allocated info cell and inspect
-- the value it leaves there.
--
-- * zero: success, nothing happens
-- * negative: 'error' with 'argMsg', naming @name@ and the negated code
-- * positive: 'error' with @name@ followed by @msg@, which must be a
--   @printf@ format taking the code as its only argument
withInfo :: String -> String -> (Ptr CInt -> IO ()) -> IO ()
withInfo msg name computation =
   alloca $ \infoPtr -> do
      computation infoPtr
      info <- peekCInt infoPtr
      report info
   where
      report :: Int -> IO ()
      report info
         | info == 0 = return ()
         | info < 0 = error (printf argMsg name (negate info))
         | otherwise = error (name ++ ": " ++ printf msg info)
-- | @printf@ format used by 'withInfo' for negative info codes;
-- takes the routine name and the argument position.
argMsg :: String
argMsg = "%s: illegal value in %d-th argument"
-- | @printf@ format with one integer argument.
-- NOTE(review): presumably passed as the @msg@ argument of 'withInfo'
-- by callers outside this view -- confirm.
errorCodeMsg :: String
errorCodeMsg = "unknown error code %d"
-- | @printf@ format with one integer argument, reporting a deficient rank.
rankMsg :: String
rankMsg = "deficient rank %d"
-- | @printf@ format with one integer argument, reporting a minor
-- that is not positive definite.
definiteMsg :: String
definiteMsg = "minor of order %d not positive definite"
-- | @printf@ format with one integer argument, reporting the number of
-- off-diagonal elements that did not converge.
eigenMsg :: String
eigenMsg = "%d off-diagonal elements not converging"
-- | Store an 'Int' in a 'CInt' cell, converting with 'fromIntegral'.
pokeCInt :: Ptr CInt -> Int -> IO ()
pokeCInt ptr value = poke ptr (fromIntegral value)
-- | Read a 'CInt' cell and convert the value to 'Int' with 'fromIntegral'.
peekCInt :: Ptr CInt -> IO Int
peekCInt ptr = fmap fromIntegral (peek ptr)
-- | Round a floating point number up to an 'Int'.
-- For complex element types only the real part is considered.
ceilingSize :: (Class.Floating a) => a -> Int
ceilingSize =
   getFlip
      (Class.switchFloating
         (Flip ceiling)
         (Flip ceiling)
         (Flip (\z -> ceiling (Complex.realPart z)))
         (Flip (\z -> ceiling (Complex.realPart z))))