{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Private where
import Numeric.LAPACK.Matrix.Shape.Private
(Order(RowMajor, ColumnMajor), transposeFromOrder)
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))
import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Scalar (RealOf, zero, one, isZero)
import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.C.String as CStr
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CChar, CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, poke, peek, pokeElemOff, peekElemOff)
import Text.Printf (printf)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT, runContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (Const(Const,getConst), liftA2, (<$>))
import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import qualified Data.Complex as Complex
import Data.Complex (Complex)
import Data.Tuple.HT (swap)
import Prelude hiding (sum)
-- | Reinterpret a pointer to @a@ as a pointer to @RealOf a@.
-- This is a plain 'castPtr': the address is unchanged and no data is touched.
realPtr :: Ptr a -> Ptr (RealOf a)
realPtr ptr = castPtr ptr
-- | Store the value @a@ into @n@ consecutive elements starting at @dstPtr@.
--
-- Implemented as a BLAS @copy@ with source increment 0 and destination
-- increment 1, so the single source element is read for every destination.
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr =
  evalContT $ do
    countPtr <- Call.cint n
    valuePtr <- Call.number a
    srcIncPtr <- Call.cint 0
    dstIncPtr <- Call.cint 1
    liftIO (BlasGen.copy countPtr valuePtr srcIncPtr dstPtr dstIncPtr)
-- | Copy @n@ contiguous elements from @srcPtr@ to @dstPtr@
-- using BLAS @copy@ with both increments set to 1.
copyBlock :: (Class.Floating a) => Int -> Ptr a -> Ptr a -> IO ()
copyBlock n srcPtr dstPtr =
  evalContT $ do
    countPtr <- Call.cint n
    srcIncPtr <- Call.cint 1
    dstIncPtr <- Call.cint 1
    liftIO (BlasGen.copy countPtr srcPtr srcIncPtr dstPtr dstIncPtr)
-- | Copy the first @n@ elements behind @fptr@ into a freshly allocated
-- temporary array and hand out a pointer to that copy.
--
-- Both the foreign pointer access and the temporary array are scoped to the
-- surrounding 'ContT' continuation.
copyToTemp :: (Storable a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyToTemp n fptr = do
  srcPtr <- ContT (withForeignPtr fptr)
  dstPtr <- Call.allocaArray n
  liftIO $ do
    copyArray dstPtr srcPtr n
    return dstPtr
-- | Provide a pointer to the complex conjugate of the @n@ elements behind
-- the given foreign pointer.
--
-- For the two real element types the original memory is handed out directly
-- (no copy is made); for the two complex element types the data is copied and
-- conjugated via 'complexConjugateToTemp'.
conjugateToTemp ::
  (Class.Floating a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
conjugateToTemp n fptr =
  runCopyToTemp
    (Class.switchFloating
      (CopyToTemp borrowPtr)
      (CopyToTemp borrowPtr)
      (CopyToTemp (complexConjugateToTemp n))
      (CopyToTemp (complexConjugateToTemp n)))
    fptr
  where
    -- Real numbers are their own conjugates, so the source can be used as is.
    borrowPtr :: ForeignPtr b -> ContT r IO (Ptr b)
    borrowPtr p = ContT (withForeignPtr p)
-- | Wrapper around an action that turns a 'ForeignPtr' into a plain 'Ptr'
-- inside 'ContT'. The element type is the last type parameter so that the
-- wrapper can be used as the functor argument of @Class.switchFloating@.
newtype CopyToTemp r a =
  CopyToTemp
    { runCopyToTemp :: ForeignPtr a -> ContT r IO (Ptr a)
    }
-- | Copy @n@ complex elements into a temporary array (see 'copyToTemp') and
-- conjugate the copy in place with LAPACK's @lacgv@ (increment 1).
-- The returned pointer refers to the temporary copy, not to the input.
complexConjugateToTemp ::
  Class.Real a =>
  Int -> ForeignPtr (Complex a) -> ContT r IO (Ptr (Complex a))
complexConjugateToTemp n x = do
  tmpPtr <- copyToTemp n x
  sizePtr <- Call.cint n
  stridePtr <- Call.cint 1
  liftIO $ do
    LapackComplex.lacgv sizePtr tmpPtr stridePtr
    return tmpPtr
-- | Conjugate the vector described by @nPtr@, @yPtr@ and @incyPtr@ in place
-- if and only if the flag is 'Conjugated'; otherwise do nothing.
condConjugate ::
  (Class.Floating a) => Conjugation -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
condConjugate conj nPtr yPtr incyPtr =
  case conj of
    Conjugated -> lacgv nPtr yPtr incyPtr
    NonConjugated -> return ()
-- | Copy vector @x@ to vector @y@ with BLAS @copy@ and then conjugate @y@
-- in place. All sizes and increments are passed as pointers, BLAS style.
copyConjugate ::
  (Class.Floating a) =>
  Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyConjugate nPtr xPtr incxPtr yPtr incyPtr =
  BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
    >> lacgv nPtr yPtr incyPtr
-- | Copy vector @x@ to vector @y@ with BLAS @copy@ and, when the flag is
-- 'Conjugated', conjugate @y@ in place afterwards.
copyCondConjugate ::
  (Class.Floating a) =>
  Conjugation -> Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyCondConjugate conj nPtr xPtr incxPtr yPtr incyPtr =
  BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
    >> condConjugate conj nPtr yPtr incyPtr
-- | Provide a pointer to the (optionally conjugated) data behind @x@.
--
-- With 'NonConjugated' the original memory is exposed directly and @n@ is
-- not used; with 'Conjugated' the work is delegated to 'conjugateToTemp'.
condConjugateToTemp ::
  (Class.Floating a) =>
  Conjugation -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
condConjugateToTemp NonConjugated _ x = ContT (withForeignPtr x)
condConjugateToTemp Conjugated n x = conjugateToTemp n x
-- | Copy the @n@ elements behind @a@ into a temporary array, conjugating the
-- copy when the flag is 'Conjugated', and return a pointer to the copy.
--
-- In contrast to 'condConjugateToTemp' a copy is always made, so the result
-- may be overwritten without affecting the source.
copyCondConjugateToTemp ::
  (Class.Floating a) =>
  Conjugation -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyCondConjugateToTemp conj n a = do
  dstPtr <- Call.allocaArray n
  -- The source pointer and the BLAS arguments are only needed for the copy
  -- itself, hence they live in their own nested scope.
  liftIO $
    withForeignPtr a $ \srcPtr ->
      evalContT $ do
        countPtr <- Call.cint n
        unitPtr <- Call.cint 1
        liftIO $
          copyCondConjugate conj countPtr srcPtr unitPtr dstPtr unitPtr
  return dstPtr
-- | Copy an @m@-by-@n@ matrix with leading dimension @lda@ to a matrix with
-- leading dimension @ldb@. This is 'copySubTrapezoid' with part selector
-- @\'A\'@, which LAPACK's @lacpy@ documents as copying the whole matrix.
copySubMatrix ::
  (Class.Floating a) =>
  Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubMatrix m n lda aPtr ldb bPtr =
  copySubTrapezoid 'A' m n lda aPtr ldb bPtr
-- | Thin wrapper around LAPACK's @lacpy@.
--
-- Arguments in order: the part selector character (passed as @lacpy@'s
-- @uplo@ argument), row count @m@, column count @n@, leading dimension and
-- pointer of the source, leading dimension and pointer of the destination.
copySubTrapezoid ::
  (Class.Floating a) =>
  Char -> Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubTrapezoid side m n lda aPtr ldb bPtr =
  evalContT $ do
    partPtr <- Call.char side
    rowsPtr <- Call.cint m
    colsPtr <- Call.cint n
    srcLdPtr <- Call.leadingDim lda
    dstLdPtr <- Call.leadingDim ldb
    liftIO $
      LapackGen.lacpy partPtr rowsPtr colsPtr aPtr srcLdPtr bPtr dstLdPtr
-- | Transposing copy built from @m@ strided BLAS @copy@ calls.
--
-- The @k@-th call (for @k@ from 0 to @m-1@) reads @n@ elements starting at
-- @aPtr + k@ with stride @m@ and writes them contiguously starting at
-- @bPtr + k*ldb@.
copyTransposed ::
  (Class.Floating a) =>
  Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyTransposed n m aPtr ldb bPtr =
  evalContT $ do
    lenPtr <- Call.cint n
    srcStridePtr <- Call.cint m
    dstStridePtr <- Call.cint 1
    let copyOne (srcPtr, dstPtr) =
          BlasGen.copy lenPtr srcPtr srcStridePtr dstPtr dstStridePtr
    liftIO $
      mapM_ copyOne $
        take m $ zip (pointerSeq 1 aPtr) (pointerSeq ldb bPtr)
-- | Copy an @m@-by-@n@ matrix stored in the given order into a densely
-- packed column-major destination.
--
-- Row-major input is transposed via 'copyTransposed' (destination leading
-- dimension @m@); column-major input is copied as one block of @m*n@ elements.
copyToColumnMajor ::
  (Class.Floating a) =>
  Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyToColumnMajor RowMajor m n aPtr bPtr =
  copyTransposed m n aPtr m bPtr
copyToColumnMajor ColumnMajor m n aPtr bPtr =
  copyBlock (m*n) aPtr bPtr
-- | Like 'copyToColumnMajor', but the column-major destination has an
-- explicit leading dimension @ldb@.
--
-- Row-major input is transposed via 'copyTransposed'. Column-major input is
-- copied as one block when the destination is dense (@m == ldb@) and
-- column by column via 'copySubMatrix' otherwise.
copyToSubColumnMajor ::
  (Class.Floating a) =>
  Order -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyToSubColumnMajor order m n aPtr ldb bPtr =
  case order of
    RowMajor -> copyTransposed m n aPtr ldb bPtr
    ColumnMajor
      | m == ldb -> copyBlock (m*n) aPtr bPtr
      | otherwise -> copySubMatrix m n m aPtr ldb bPtr
-- | Infinite list of pointers starting at @ptr@, each one @k@ elements
-- (not bytes) beyond its predecessor.
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k = iterate (\p -> advancePtr p k)
-- | Create an array of shape @shapeX@ that is filled by @act@ and return it
-- together with the result of @act@.
--
-- @act@ receives a buffer pointer and that buffer's leading dimension.
-- If @m > n@ the action runs on a temporary @m*nrhs@ buffer with leading
-- dimension @m@, and afterwards the leading @n@ rows of its @nrhs@ columns
-- are copied into the result array (leading dimension @n@).
-- Otherwise the action writes directly into the result array with leading
-- dimension @n@.
createHigherArray ::
  (Shape.C sh, Class.Floating a) =>
  sh -> Int -> Int -> Int ->
  ((Ptr a, Int) -> IO rank) -> IO (rank, Array sh a)
createHigherArray shapeX m n nrhs act =
  swap <$>
    ArrayIO.unsafeCreateWithSizeAndResult shapeX (\_ xPtr -> fillResult xPtr)
  where
    fillResult xPtr
      | m > n =
          runContT (Call.allocaArray (m*nrhs)) $ \tmpPtr -> do
            result <- act (tmpPtr, m)
            copySubMatrix n nrhs m tmpPtr n xPtr
            return result
      | otherwise = act (xPtr, n)
-- | Wrapper for a summation routine taking element count, data pointer and
-- increment; used to dispatch on the element type in 'sum'.
newtype Sum a =
  Sum
    { runSum :: Int -> Ptr a -> Int -> IO a
    }
-- | Sum of @n@ elements starting at the given pointer with the given
-- increment. Dispatches to 'sumReal' for the real element types and to
-- 'sumComplex' for the complex ones.
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum n xPtr incx =
  runSum
    (Class.switchFloating
      (Sum sumReal)
      (Sum sumReal)
      (Sum sumComplex)
      (Sum sumComplex))
    n xPtr incx
-- | Sum @n@ real elements with stride @incx@.
--
-- Computed as a BLAS dot product with a single element 1 that is read with
-- increment 0, i.e. a dot product with an all-ones vector.
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx =
  evalContT $ do
    countPtr <- Call.cint n
    stridePtr <- Call.cint incx
    onePtr <- Call.real one
    zeroIncPtr <- Call.cint 0
    liftIO (BlasReal.dot countPtr xPtr stridePtr onePtr zeroIncPtr)
-- | Two ways of summing @n@ complex elements with the given element stride.
--
-- 'sumComplex' views the data as real numbers and runs two real dot products
-- against a constant 1 (stride @2*incx@ in real elements): one starting at
-- the first real part, one starting at the first imaginary part.
--
-- 'sumComplexAlt' views the data as a real matrix with 2 rows, @n@ columns
-- and leading dimension @2*inca@, and multiplies it with a vector of @n@
-- ones using 'gemv'; the two-element result is read back as one complex
-- number.
sumComplex, sumComplexAlt ::
  Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
  evalContT $ do
    countPtr <- Call.cint n
    let rePtr = realPtr xPtr
    stridePtr <- Call.cint (2*incx)
    onePtr <- Call.real one
    zeroIncPtr <- Call.cint 0
    liftIO $ do
      re <- BlasReal.dot countPtr rePtr stridePtr onePtr zeroIncPtr
      im <-
        BlasReal.dot countPtr (advancePtr rePtr 1) stridePtr onePtr zeroIncPtr
      return (re Complex.:+ im)

sumComplexAlt n aPtr inca =
  evalContT $ do
    noTransPtr <- Call.char 'N'
    rowsPtr <- Call.cint 2
    colsPtr <- Call.cint n
    unitPtr <- Call.number one
    zeroIncPtr <- Call.cint 0
    let matPtr = realPtr aPtr
    ldPtr <- Call.leadingDim (2*inca)
    onesPtr <- Call.allocaArray n
    onesIncPtr <- Call.cint 1
    nullPtr' <- Call.number zero
    resultPtr <- Call.alloca
    let resultRealPtr = realPtr resultPtr
    resultIncPtr <- Call.cint 1
    liftIO $ do
      -- Fill the temporary vector with ones (source increment 0).
      BlasGen.copy colsPtr unitPtr zeroIncPtr onesPtr onesIncPtr
      gemv
        noTransPtr rowsPtr colsPtr unitPtr matPtr ldPtr
        onesPtr onesIncPtr nullPtr' resultRealPtr resultIncPtr
      peek resultPtr
-- | Element-wise multiplication of two strided vectors, expressed as a BLAS
-- @hbmv@ call with bandwidth 0, @alpha = 1@ and @beta = 0@.
--
-- The vector @a@ (stride @inca@) plays the role of the band matrix, whose
-- leading dimension is set to @inca@; @x@ and @y@ are input and output.
-- NOTE(review): @hbmv@ is the Hermitian band routine, so for complex element
-- types presumably only the real parts of @a@ take part - confirm.
mulReal ::
  (Class.Floating a) =>
  Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulReal n aPtr inca xPtr incx yPtr incy =
  evalContT $ do
    upperPtr <- Call.char 'U'
    sizePtr <- Call.cint n
    bandPtr <- Call.cint 0
    unitPtr <- Call.number one
    diagStridePtr <- Call.leadingDim inca
    xIncPtr <- Call.cint incx
    nullPtr' <- Call.number zero
    yIncPtr <- Call.cint incy
    liftIO $
      BlasGen.hbmv upperPtr
        sizePtr bandPtr unitPtr aPtr diagStridePtr
        xPtr xIncPtr nullPtr' yPtr yIncPtr
-- | Element-wise multiplication of two strided vectors, expressed as a BLAS
-- @gbmv@ call on an @n@-by-@n@ band matrix without sub- or super-diagonals,
-- with @alpha = 1@ and @beta = 0@.
--
-- The vector @a@ (stride @inca@) plays the role of the band matrix, whose
-- leading dimension is set to @inca@; @x@ and @y@ are input and output.
mul ::
  (Class.Floating a) =>
  Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mul n aPtr inca xPtr incx yPtr incy =
  evalContT $ do
    noTransPtr <- Call.char 'N'
    sizePtr <- Call.cint n
    lowerBandPtr <- Call.cint 0
    upperBandPtr <- Call.cint 0
    unitPtr <- Call.number one
    diagStridePtr <- Call.leadingDim inca
    xIncPtr <- Call.cint incx
    nullPtr' <- Call.number zero
    yIncPtr <- Call.cint incy
    liftIO $
      BlasGen.gbmv noTransPtr
        sizePtr sizePtr lowerBandPtr upperBandPtr unitPtr aPtr diagStridePtr
        xPtr xIncPtr nullPtr' yPtr yIncPtr
-- | Product of @n@ elements starting at @aPtr@ with stride @inca@.
--
-- For @n < 1@ the result is 'one'; for @n == 1@ it is the single element.
-- Otherwise adjacent pairs are multiplied into a temporary buffer of
-- @2*(n - div n 2) - 1@ elements (an unpaired last element is copied over
-- unchanged) and the reduction is continued by 'productLoop'.
product :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
product n aPtr inca
  | n < 1 = return one
  | n == 1 = peek aPtr
  | otherwise =
      ForeignArray.alloca (2*remaining - 1) $ \xPtr -> do
        mulPairs pairs aPtr inca xPtr 1
        when (odd n) $ do
          lastElem <- peekElemOff aPtr ((n-1)*inca)
          pokeElemOff xPtr pairs lastElem
        productLoop remaining xPtr
  where
    pairs = div n 2
    remaining = n - pairs
-- | Reduce @n@ contiguous elements at @xPtr@ to their product.
--
-- Each round multiplies adjacent pairs and stores the results directly
-- behind the current @n@ elements; the next round then starts at offset
-- @2*div n 2@, which is the unpaired element (if any) followed by the new
-- products. The memory behind @xPtr@ must therefore provide room for the
-- intermediate results. Terminates once a single element is left.
productLoop :: (Class.Floating a) => Int -> Ptr a -> IO a
productLoop n xPtr
  | n == 1 = peek xPtr
  | otherwise = do
      let pairs = div n 2
      mulPairs pairs xPtr 1 (advancePtr xPtr n) 1
      productLoop (n - pairs) (advancePtr xPtr (2*pairs))
-- | Multiply @n@ adjacent pairs of the strided vector @a@ and store the
-- products in @x@: the elements at even positions are multiplied with their
-- successors at odd positions, using 'mul' with the doubled stride.
mulPairs ::
  (Class.Floating a) =>
  Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulPairs n aPtr inca xPtr incx =
  mul n aPtr pairStride oddPtr pairStride xPtr incx
  where
    pairStride = 2*inca
    oddPtr = advancePtr aPtr inca
-- | Wrapper for a conjugation routine with the calling convention of
-- LAPACK's @lacgv@ (size pointer, data pointer, increment pointer);
-- used to dispatch on the element type in 'lacgv'.
newtype LACGV a =
  LACGV
    { getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
    }
-- | In-place conjugation of a vector for every supported element type.
-- Does nothing for the two real types and calls LAPACK's @lacgv@ for the
-- two complex types.
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv nPtr xPtr incxPtr =
  getLACGV
    (Class.switchFloating
      (LACGV skip)
      (LACGV skip)
      (LACGV LapackComplex.lacgv)
      (LACGV LapackComplex.lacgv))
    nPtr xPtr incxPtr
  where
    -- Real numbers are their own conjugates.
    skip :: Ptr CInt -> Ptr b -> Ptr CInt -> IO ()
    skip _ _ _ = return ()
-- | BLAS @gemv@ preceded by 'initializeMV', which takes care of the output
-- vector in the case that the matrix has no elements to contribute.
-- The argument list is exactly that of BLAS @gemv@.
{-# INLINE gemv #-}
gemv ::
  (Class.Floating a) =>
  Ptr CChar -> Ptr CInt -> Ptr CInt ->
  Ptr a -> Ptr a -> Ptr CInt ->
  Ptr a -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
gemv transPtr mPtr nPtr
    alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
  initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
    >> BlasGen.gemv transPtr mPtr nPtr
          alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr
-- | BLAS @gbmv@ preceded by 'initializeMV', which takes care of the output
-- vector in the case that the matrix has no elements to contribute.
-- The argument list is exactly that of BLAS @gbmv@.
{-# INLINE gbmv #-}
gbmv ::
  (Class.Floating a) =>
  Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr CInt -> Ptr CInt ->
  Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
  Ptr a -> Ptr a -> Ptr CInt -> IO ()
gbmv transPtr mPtr nPtr klPtr kuPtr
    alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
  initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
    >> BlasGen.gbmv transPtr mPtr nPtr klPtr kuPtr
          alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr
-- | Prepare the output vector of a matrix-vector product.
--
-- Depending on the transposition flag (@\'N\'@ or anything else) the length
-- of the output vector is @m@ or @n@ and the length of the input vector is
-- the other one. If the input length is 0 and @beta@ is zero, the output
-- vector is explicitly filled with @beta@; in all other cases nothing is
-- done.
-- NOTE(review): presumably this compensates for BLAS returning early and
-- leaving @y@ untouched for empty input - confirm against the BLAS in use.
initializeMV ::
  Class.Floating a =>
  Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr = do
  trans <- peek transPtr
  let notTransposed = trans == CStr.castCharToCChar 'N'
      outLenPtr = if notTransposed then mPtr else nPtr
      inLenPtr = if notTransposed then nPtr else mPtr
  inLen <- peek inLenPtr
  beta <- peek betaPtr
  when (inLen == 0 && isZero beta) $
    -- Source increment 0: every element of y receives the value of beta.
    Marshal.with 0 $ \zeroIncPtr ->
      BlasGen.copy outLenPtr betaPtr zeroIncPtr yPtr incyPtr
-- | Matrix product @C = A * B@ via BLAS @gemm@ with @alpha = 1@ and
-- @beta = 0@, where @A@ is @m@-by-@k@ and @B@ is @k@-by-@n@.
--
-- The storage order of each input selects both its transposition flag
-- (via 'transposeFromOrder') and its leading dimension. The result is
-- written to @cPtr@ with leading dimension @m@.
multiplyMatrix ::
  (Class.Floating a) =>
  Order -> Order -> Int -> Int -> Int ->
  ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyMatrix orderA orderB m k n a b cPtr =
  evalContT $ do
    transAPtr <- Call.char (transposeFromOrder orderA)
    transBPtr <- Call.char (transposeFromOrder orderB)
    rowsPtr <- Call.cint m
    colsPtr <- Call.cint n
    innerPtr <- Call.cint k
    unitPtr <- Call.number one
    aPtr <- ContT (withForeignPtr a)
    ldaPtr <- Call.leadingDim lda
    bPtr <- ContT (withForeignPtr b)
    ldbPtr <- Call.leadingDim ldb
    nullPtr' <- Call.number zero
    ldcPtr <- Call.leadingDim ldc
    liftIO $
      BlasGen.gemm
        transAPtr transBPtr rowsPtr colsPtr innerPtr unitPtr aPtr ldaPtr
        bPtr ldbPtr nullPtr' cPtr ldcPtr
  where
    lda =
      case orderA of
        RowMajor -> k
        ColumnMajor -> m
    ldb =
      case orderB of
        RowMajor -> n
        ColumnMajor -> k
    ldc = m
-- | Combination of 'withInfo' and 'withAutoWorkspace': run a LAPACK routine
-- that needs a workspace (pointer and size) and reports its status in an
-- @info@ argument. @msg@ and @name@ are passed on to 'withInfo'.
withAutoWorkspaceInfo ::
  (Class.Floating a) =>
  String -> String -> (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspaceInfo msg name computation =
  withInfo msg name runWithInfo
  where
    runWithInfo infoPtr =
      withAutoWorkspace
        (\workPtr lworkPtr -> computation workPtr lworkPtr infoPtr)
-- | Run a LAPACK routine with a workspace of the size the routine asks for.
--
-- The computation is executed twice: first with @lwork = -1@ and a
-- one-element workspace, which by LAPACK convention makes the routine report
-- its preferred workspace size in the first workspace element; then with a
-- workspace of that size (at least 1) and @lwork@ set accordingly.
withAutoWorkspace ::
  (Class.Floating a) =>
  (Ptr a -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspace computation =
  evalContT $ do
    lworkPtr <- Call.cint (-1)
    optimalSize <-
      liftIO $
        alloca $ \queryPtr -> do
          computation queryPtr lworkPtr
          reported <- peek queryPtr
          return (max 1 (ceilingSize reported))
    workPtr <- Call.allocaArray optimalSize
    liftIO $ do
      pokeCInt lworkPtr optimalSize
      computation workPtr lworkPtr
-- | Run a LAPACK routine that reports its status through an @info@ pointer
-- and turn a non-zero status into a call to 'error'.
--
-- A negative status is reported with 'argMsg' (routine @name@ and argument
-- position); a positive status is reported as @name@, a colon and @msg@,
-- where @msg@ is a 'printf' template with one integer placeholder.
withInfo :: String -> String -> (Ptr CInt -> IO ()) -> IO ()
withInfo msg name computation =
  alloca $ \infoPtr -> do
    computation infoPtr
    info <- peekCInt infoPtr
    report info
  where
    report :: Int -> IO ()
    report info
      | info == 0 = return ()
      | info < 0 = error (printf argMsg name (-info))
      | otherwise = error (name ++ ": " ++ printf msg info)
-- | 'printf' templates for error reports.
--
-- 'argMsg' expects the routine name and an argument position and is used by
-- 'withInfo' for negative status codes. The other templates take a single
-- integer and are meant to be passed as the @msg@ argument of 'withInfo'.
argMsg, errorCodeMsg, rankMsg, definiteMsg, eigenMsg :: String
argMsg = "%s: illegal value in %d-th argument"
errorCodeMsg = "unknown error code %d"
rankMsg = "deficient rank %d"
definiteMsg = "minor of order %d not positive definite"
eigenMsg = "%d off-diagonal elements not converging"
-- | Store an 'Int' in a 'CInt' location, converting with 'fromIntegral'.
pokeCInt :: Ptr CInt -> Int -> IO ()
pokeCInt ptr n = poke ptr (fromIntegral n)
-- | Read a 'CInt' location and convert the value to 'Int'
-- with 'fromIntegral'.
peekCInt :: Ptr CInt -> IO Int
peekCInt ptr = fmap fromIntegral (peek ptr)
-- | Round a floating point number up to an 'Int'.
-- For complex element types only the real part is considered.
ceilingSize :: (Class.Floating a) => a -> Int
ceilingSize x =
  getFlip
    (Class.switchFloating
      (Flip ceiling)
      (Flip ceiling)
      (Flip (\z -> ceiling (Complex.realPart z)))
      (Flip (\z -> ceiling (Complex.realPart z))))
    x
-- | Choose between two values depending on whether the element type @a@ is
-- real (result @r@) or complex (result @c@). The first argument is used only
-- to fix the type @a@.
caseRealComplexFunc :: (Class.Floating a) => f a -> b -> b -> b
caseRealComplexFunc f r c =
  let selected =
        Class.switchFloating (Const r) (Const r) (Const c) (Const c)
  in getConstFunc f selected
-- | Unwrap a 'Const' whose phantom type is tied to the type argument of
-- the first parameter. The first parameter itself is never inspected.
getConstFunc :: f c -> Const a c -> a
getConstFunc _ (Const x) = x
-- | Selector for one of the two components of a complex number.
data ComplexPart
  = RealPart
  | ImaginaryPart
  deriving (Eq, Ord, Show, Enum, Bounded)