{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.LAPACK.Matrix.Mosaic.Private where
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Layout.Private
(Order(RowMajor,ColumnMajor), flipOrder, uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Scalar (zero)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private
(pointerSeq, copyBlock, copyCondConjugateToTemp,
pokeCInt, fill, withAutoWorkspaceInfo, withInfo, errorCodeMsg)
import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((::+)((::+)))
import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (pure, (<*>))
import Data.Foldable (forM_)
-- | Array whose shape is a 'Layout.Mosaic' with the given packing,
-- mirror tag, triangle side and index shape.
type Mosaic pack mirror uplo sh = Array (Layout.Mosaic pack mirror uplo sh)
-- | 'Mosaic' using 'Layout.Packed' storage.
type MosaicPacked mirror uplo sh = Mosaic Layout.Packed mirror uplo sh
-- | 'Mosaic' using 'Layout.Unpacked' storage.
type MosaicUnpacked mirror uplo sh = Mosaic Layout.Unpacked mirror uplo sh
-- | Packed 'Mosaic' for the lower triangle ('Shape.Lower').
type MosaicLower mirror sh = MosaicPacked mirror Shape.Lower sh
-- | Packed 'Mosaic' for the upper triangle ('Shape.Upper').
type MosaicUpper mirror sh = MosaicPacked mirror Shape.Upper sh
-- | Pointers to the @n@ diagonal elements of a packed triangle of
-- dimension @n@ in the layout used for upper 'Mosaic's.
--
-- In 'RowMajor' order the distances between consecutive diagonal
-- elements are @n, n-1, ...@; in 'ColumnMajor' order they are
-- @2, 3, ...@.
diagonalPointers :: (Storable a) => Order -> Int -> Ptr a -> [Ptr a]
diagonalPointers order n aPtr =
   let steps =
         case order of
            ColumnMajor -> [2 ..]
            RowMajor -> [n, n-1 ..]
   in take n (scanl advancePtr aPtr steps)
-- | Pair the pointers of a stride-1 sequence starting at @aPtr@ with the
-- diagonal pointers of the packed triangle of dimension @n@ at @bPtr@
-- (see 'diagonalPointers'). The result has at most @n@ entries.
diagonalPointerPairs ::
   (Storable a, Storable b) =>
   Order -> Int -> Ptr a -> Ptr b -> [(Ptr a, Ptr b)]
diagonalPointerPairs order n aPtr bPtr =
   let vector = pointerSeq 1 aPtr
       diagonal = diagonalPointers order n bPtr
   in zip vector diagonal
-- | Per-column pointer tuples for moving data between full storage
-- (@fullPtr@, @n@ rows) and a packed triangle stored column by column
-- (@packedPtr@).
--
-- Entry @j@ (0-based) is @(j+1, ((p1, pn), packedStart))@, where
-- @j+1@ is the length of packed column @j@, @p1@ and @pn@ are the
-- @j@-th pointers of 'pointerSeq' with stride 1 and stride @n@ over
-- @fullPtr@, and @packedStart@ is where column @j@ begins in packed
-- storage.
columnMajorPointers ::
   (Storable a) => Int -> Ptr a -> Ptr a -> [(Int, ((Ptr a, Ptr a), Ptr a))]
columnMajorPointers n fullPtr packedPtr =
   let lengths = [1 ..]
       fullPairs = zip (pointerSeq 1 fullPtr) (pointerSeq n fullPtr)
       packedStarts = scanl advancePtr packedPtr lengths
   in take n (zip lengths (zip fullPairs packedStarts))
-- | Per-row pointer tuples for moving data between full storage
-- (@fullPtr@, dimension @n@) and a packed triangle stored row by row
-- (@packedPtr@).
--
-- Entry @i@ (0-based) is @(n-i, (fullDiag, packedStart))@, where
-- @fullDiag@ is the @i@-th pointer of 'pointerSeq' with stride @n+1@
-- (the diagonal of the full array) and @packedStart@ is where row @i@
-- begins in packed storage.
rowMajorPointers ::
   (Storable a) => Int -> Ptr a -> Ptr a -> [(Int, (Ptr a, Ptr a))]
rowMajorPointers n fullPtr packedPtr =
   let lengths = [n, n-1 ..]
       fullDiagonal = pointerSeq (n+1) fullPtr
       packedStarts = scanl advancePtr packedPtr lengths
   in take n $ zip lengths $ zip fullDiagonal packedStarts
-- | Run @act@ once per list element, in list order. Before each call the
-- element's 'Int' component is written into a single 'CInt' cell.
-- The cell is allocated once with 'alloca' and reused. A pointer to it
-- is passed to @act@ together with the element's payload.
forPointers :: [(Int, a)] -> (Ptr CInt -> a -> IO ()) -> IO ()
forPointers xs act =
   alloca $ \nPtr ->
      mapM_
         (\(d, ptrs) -> pokeCInt nPtr d >> act nPtr ptrs)
         xs
-- | Copy data into temporary storage via 'copyCondConjugateToTemp'.
-- Conjugation is requested only when @order@ is 'RowMajor'; for
-- 'ColumnMajor' the data is always copied with 'NonConjugated'.
copyTriangleToTemp ::
   Class.Floating a =>
   Conjugation -> Order -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyTriangleToTemp conj order =
   let effectiveConj =
         case order of
            ColumnMajor -> NonConjugated
            RowMajor -> conj
   in copyCondConjugateToTemp effectiveConj
-- | Allocate a temporary array of @n*n@ elements with 'Call.allocaArray'
-- inside the 'ContT' scope. Fill it by calling @f n src dst@, where
-- @src@ is the data behind @a@, and return the temporary pointer.
unpackToTemp ::
   Storable a =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Int -> ForeignPtr a -> ContT r IO (Ptr a)
unpackToTemp f n a = do
   srcPtr <- ContT $ withForeignPtr a
   dstPtr <- Call.allocaArray (n*n)
   liftIO (f n srcPtr dstPtr)
   pure dstPtr
-- | Expand a packed triangle of dimension @n@ into full storage with
-- leading dimension @n@, using LAPACK @tpttr@. The triangle side is
-- derived from @order@ via 'uploFromOrder'. A nonzero @info@ is reported
-- through 'withInfo'.
unpack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpack order n packedPtr fullPtr = evalContT $ do
   uplo <- Call.char $ uploFromOrder order
   size <- Call.cint n
   ld <- Call.leadingDim n
   liftIO . withInfo errorCodeMsg "tpttr" $
      LapackGen.tpttr uplo size packedPtr fullPtr ld
-- | Pack the triangle of full storage with dimension and leading
-- dimension @n@. This is 'packRect' with @ld = n@.
pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
pack order n fullPtr packedPtr = packRect order n n fullPtr packedPtr
-- | Pack the triangle of dimension @n@ from full storage with leading
-- dimension @ld@ into packed storage, using LAPACK @trttp@. The triangle
-- side is derived from @order@ via 'uploFromOrder'.
packRect :: Class.Floating a => Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
packRect order n ld fullPtr packedPtr = evalContT $ do
   uplo <- Call.char $ uploFromOrder order
   size <- Call.cint n
   ldPtr <- Call.leadingDim ld
   liftIO . withInfo errorCodeMsg "trttp" $
      LapackGen.trttp uplo size fullPtr ldPtr packedPtr
-- | Expand a packed triangle into full @n*n@ storage, with the entries
-- outside the triangle set to zero.
--
-- 'unpackZero' zeroes only the opposite triangle (via 'fillTriangle'
-- with the flipped order) before unpacking. '_unpackZero' is an
-- alternative that zeroes all @n*n@ elements first.
unpackZero, _unpackZero ::
   Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpackZero order n packedPtr fullPtr =
   fillTriangle zero (flipOrder order) n fullPtr
   >> unpack order n packedPtr fullPtr
_unpackZero order n packedPtr fullPtr =
   fill zero (n*n) fullPtr
   >> unpack order n packedPtr fullPtr
-- | Set every element of the triangle selected by @order@ (via
-- 'uploFromOrder') of an @n×n@ array to @z@, using LAPACK @laset@.
-- The off-diagonal value and the diagonal value are both @z@.
--
-- The leading dimension goes through 'Call.leadingDim', as in 'unpack'
-- and 'packRect'. LAPACK requires @LDA >= max(1,N)@, which the raw @n@
-- violates when @n == 0@.
fillTriangle :: Class.Floating a => a -> Order -> Int -> Ptr a -> IO ()
fillTriangle z order n aPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   ldaPtr <- Call.leadingDim n
   zPtr <- Call.number z
   liftIO $ LapackGen.laset uploPtr nPtr nPtr zPtr zPtr aPtr ldaPtr
-- | Wrap the index shape of a 'Mosaic' in 'Unchecked'. Packing, mirror,
-- triangle side, order and data are unchanged. 'recheck' undoes this.
uncheck ::
   Mosaic pack mirror uplo sh a -> Mosaic pack mirror uplo (Unchecked sh) a
uncheck = Array.mapShape wrap
   where
      wrap (Layout.Mosaic packing mirror uplo order sh) =
         Layout.Mosaic packing mirror uplo order (Unchecked sh)
-- | Remove the 'Unchecked' wrapper from the index shape of a 'Mosaic'.
-- All other layout components and the data are unchanged.
recheck ::
   Mosaic pack mirror uplo (Unchecked sh) a -> Mosaic pack mirror uplo sh a
recheck = Array.mapShape unwrap
   where
      unwrap (Layout.Mosaic packing mirror uplo order (Unchecked sh)) =
         Layout.Mosaic packing mirror uplo order sh
-- | Assemble a packed upper 'Mosaic' of shape @height::+width@ from three
-- parts: the top-left triangle @a@, the top-right rectangle @b@ and the
-- bottom-right triangle @c@.
--
-- The result takes its mirror tag from @a@ and its storage order from
-- @b@. 'Call.assert' fails if the size of @a@ differs from the height of
-- @b@, or the size of @c@ differs from its width.
-- NOTE(review): the storage orders of @a@ and @c@ are not compared with
-- that of @b@; presumably callers ensure they agree.
stack ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   MosaicUpper mirror height a ->
   Matrix.General height width a ->
   MosaicUpper mirror width a ->
   MosaicUpper mirror (height::+width) a
stack (Array sha a) (Array (Layout.Full order extent) b) (Array shc c) =
   Array.unsafeCreate resultShape $ \xPtr -> do
      Call.assert (name++".stack: height shapes mismatch") $
         height == Layout.mosaicSize sha
      Call.assert (name++".stack: width shapes mismatch") $
         width == Layout.mosaicSize shc
      let m = Shape.size height
          n = Shape.size width
      withForeignPtr a $ \aPtr -> copyTriangleA copyBlock order m n aPtr xPtr
      withForeignPtr b $ \bPtr -> copyRectangle copyBlock order m n bPtr xPtr
      withForeignPtr c $ \cPtr -> copyTriangleC copyBlock order m n cPtr xPtr
   where
      mirror = Layout.mosaicMirror sha
      name = show mirror
      (height, width) = Extent.dimensions extent
      resultShape =
         Layout.Mosaic Layout.Packed mirror Layout.Upper order
            (height ::+ width)
-- | Extract the top-right @height×width@ rectangle of a packed upper
-- 'Mosaic' of shape @height::+width@. The result is a general matrix in
-- the same storage order.
takeTopRight ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a -> Matrix.General height width a
takeTopRight
      (Array (Layout.Mosaic _packed _mirror _upper order (height::+width)) x) =
   Array.unsafeCreate (Layout.general order height width) $ \bPtr ->
      withForeignPtr x $ \xPtr ->
         copyRectangle copyFrom order
            (Shape.size height) (Shape.size width) bPtr xPtr
   where
      -- 'copyRectangle' calls its copier as @copy count rectPtr mosaicPtr@;
      -- here the mosaic is the source, so the pointers are swapped.
      copyFrom count dst src = copyBlock count src dst
-- | Extract the top-left triangle (of shape @height@) of a packed upper
-- 'Mosaic' of shape @height::+width@, keeping its layout components.
takeTopLeft ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a ->
   MosaicUpper mirror height a
takeTopLeft
      (Array (Layout.Mosaic packing mirror upper order (height::+width)) x) =
   Array.unsafeCreate (Layout.Mosaic packing mirror upper order height) $
      \aPtr -> withForeignPtr x $ \xPtr ->
         copyTriangleA copyFrom order
            (Shape.size height) (Shape.size width) aPtr xPtr
   where
      -- 'copyTriangleA' calls its copier as @copy count partPtr mosaicPtr@;
      -- here the mosaic is the source, so the pointers are swapped.
      copyFrom count dst src = copyBlock count src dst
-- | Extract the bottom-right triangle (of shape @width@) of a packed upper
-- 'Mosaic' of shape @height::+width@, keeping its layout components.
takeBottomRight ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a ->
   MosaicUpper mirror width a
takeBottomRight
      (Array (Layout.Mosaic packing mirror upper order (height::+width)) x) =
   Array.unsafeCreate (Layout.Mosaic packing mirror upper order width) $
      \cPtr -> withForeignPtr x $ \xPtr ->
         copyTriangleC copyFrom order
            (Shape.size height) (Shape.size width) cPtr xPtr
   where
      -- 'copyTriangleC' calls its copier as @copy count partPtr mosaicPtr@;
      -- here the mosaic is the source, so the pointers are swapped.
      copyFrom count dst src = copyBlock count src dst
{-# INLINE copyTriangleA #-}
-- | Transfer data between the packed top-left triangle of dimension @m@
-- (at @aPtr@) and the packed upper triangle of dimension @m+n@ (at
-- @xPtr@). Every contiguous piece is moved by @copy count aPart xPart@,
-- so the caller chooses the direction.
--
-- 'ColumnMajor': the first @m@ columns of the big triangle are exactly
-- the small triangle, so one block of @triangleSize m@ elements is
-- moved. 'RowMajor': row @i@ contributes its first @m-i@ elements,
-- starting at the diagonal.
copyTriangleA ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleA copy order m n aPtr xPtr =
   case order of
      RowMajor ->
         sequence_ $
         zipWith3 copy
            [m, m-1 ..]
            (diagonalPointers order m aPtr)
            (diagonalPointers order (m+n) xPtr)
      ColumnMajor -> copy (Shape.triangleSize m) aPtr xPtr
{-# INLINE copyTriangleC #-}
-- | Transfer data between the packed bottom-right triangle of dimension
-- @n@ (at @cPtr@) and the packed upper triangle of dimension @m+n@ (at
-- @xPtr@), via @copy count cPart xPart@.
--
-- 'RowMajor': the last @n@ rows of the big triangle are exactly the
-- small triangle, i.e. its trailing @triangleSize n@ elements.
-- 'ColumnMajor': column @j@ of the small triangle is the final @j+1@
-- elements (ending at the diagonal) of column @m+j@ of the big one.
copyTriangleC ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleC copy order m n cPtr xPtr =
   case order of
      ColumnMajor ->
         sequence_ $
         zipWith3
            (\j cDiag xDiag ->
               copy (j+1) (advancePtr cDiag (-j)) (advancePtr xDiag (-j)))
            [0 ..]
            (diagonalPointers order n cPtr)
            (drop m (diagonalPointers order (m+n) xPtr))
      RowMajor ->
         let smallSize = Shape.triangleSize n
             offset = Shape.triangleSize (m+n) - smallSize
         in copy smallSize cPtr (advancePtr xPtr offset)
{-# INLINE copyRectangle #-}
-- | Transfer data between an @m×n@ rectangle in full storage of the same
-- order (at @bPtr@) and the top-right block of the packed upper triangle
-- of dimension @m+n@ (at @xPtr@), via @copy count bPart xPart@.
--
-- 'RowMajor': row @i@ of the rectangle (@n@ elements) starts @m-i@
-- elements after the diagonal of row @i@ of the big triangle.
-- 'ColumnMajor': column @j@ of the rectangle (@m@ elements) starts @m+j@
-- elements before the diagonal of column @m+j@ of the big triangle.
copyRectangle ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyRectangle copy order m n bPtr xPtr =
   case order of
      RowMajor ->
         sequence_ $ take m $
         zipWith3
            (\k bRow xDiag -> copy n bRow (advancePtr xDiag k))
            [m, m-1 ..]
            (pointerSeq n bPtr)
            (diagonalPointers order (m+n) xPtr)
      ColumnMajor ->
         sequence_ $ take n $
         zipWith3
            (\k bCol xDiag -> copy m bCol (advancePtr xDiag (-k)))
            [m ..]
            (pointerSeq m bPtr)
            (drop m (diagonalPointers order (m+n) xPtr))
-- | Array whose shape is a 'Layout.Triangular' for triangle side @uplo@.
type Triangular uplo sh = Array (Layout.Triangular uplo sh)
-- | 'Triangular' array for the lower triangle.
type Lower sh = Triangular Shape.Lower sh
-- | 'Triangular' array for the upper triangle.
type Upper sh = Triangular Shape.Upper sh
-- | Newtype around a function consuming a 'Triangular' array, with
-- @uplo@ as the last type parameter.
-- NOTE(review): this parameter order presumably allows selecting the
-- function by @uplo@ through a type class; confirm at use sites.
newtype MultiplyRight sh a b uplo =
   MultiplyRight {getMultiplyRight :: Triangular uplo sh a -> b}
-- | Newtype around a shape-changing map between 'Mosaic' arrays, with
-- @uplo@ as the last type parameter.
newtype Map pack mirror sh0 sh1 a uplo =
   Map {
      getMap :: Mosaic pack mirror uplo sh0 a -> Mosaic pack mirror uplo sh1 a
   }
-- | Copy a banded matrix of dimension @n@ with @k@ off-diagonals (band
-- storage with leading dimension @k+1@, behind @a@) into a packed
-- triangle at @bPtr@.
--
-- The first @bSize@ elements at @bPtr@ are zeroed beforehand. Each
-- band line is then copied next to the corresponding packed diagonal
-- element:
--
-- * 'ColumnMajor': up to @k@ elements preceding the diagonal, plus the
--   diagonal itself.
-- * 'RowMajor': the diagonal and up to @k@ following elements.
fromBanded ::
   (Class.Floating a) =>
   Int -> Order -> Int -> ForeignPtr a -> Int -> Ptr a -> IO ()
fromBanded k order n a bSize bPtr =
   withForeignPtr a $ \aPtr -> do
      fill zero bSize bPtr
      let lda = k+1
          entries =
             zip3 [0 ..] (pointerSeq lda aPtr) (diagonalPointers order n bPtr)
      case order of
         RowMajor ->
            forM_ entries $ \(i, srcPtr, diagPtr) ->
               copyBlock (min lda (n-i)) srcPtr diagPtr
         ColumnMajor ->
            forM_ entries $ \(i, srcPtr, diagPtr) ->
               let above = min i k
               in copyBlock (above+1)
                     (advancePtr srcPtr (k-above))
                     (advancePtr diagPtr (-above))
-- | Replace the mirror tag of a packed upper 'Mosaic' with
-- 'Layout.NoMirror'. The result is a plain 'Upper' array with the same
-- data and the other layout components unchanged.
takeUpper :: MosaicUpper mirror sh a -> Upper sh a
takeUpper = Array.mapShape forgetMirror
   where
      forgetMirror (Layout.Mosaic packing _mirror upper order sh) =
         Layout.Mosaic packing Layout.NoMirror upper order sh
-- | Tag a plain 'Upper' array with the mirror chosen by
-- 'Layout.autoMirror' for the requested @mirror@ type. The data and the
-- other layout components are unchanged.
-- The lambda matches the GADT constructor 'Layout.NoMirror'; it is kept
-- point-free so that its argument type is fixed by the signature.
fromUpper ::
   (Layout.Mirror mirror) => Upper sh a -> MosaicUpper mirror sh a
fromUpper =
   Array.mapShape
      (\(Layout.Mosaic packing Layout.NoMirror upper order sh) ->
         Layout.Mosaic packing Layout.autoMirror upper order sh)
-- | Build a packed lower 'Mosaic' of shape @height@ from the lower part
-- of a full matrix whose vertical extent is 'Extent.Small'.
--
-- The triangle is packed with 'packRect' using the flipped storage
-- order. The source's leading dimension is the width for 'RowMajor'
-- and the height for 'ColumnMajor'. Afterwards @fillDiag@ is run on the
-- packed result with the same flipped order; 'leaveDiagonal' is the
-- do-nothing choice.
fromLowerPart ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   (Order -> Int -> Ptr a -> IO ()) ->
   Layout.MirrorSingleton mirror ->
   Full meas Extent.Small horiz height width a -> MosaicLower mirror height a
fromLowerPart fillDiag mirror (Array (Layout.Full order extent) a) =
   Array.unsafeCreate
      (Layout.Mosaic Layout.Packed mirror Layout.Lower order height) $
      \lPtr -> withForeignPtr a $ \aPtr -> do
         packRect flipped m ld aPtr lPtr
         fillDiag flipped m lPtr
   where
      (height, width) = Extent.dimensions extent
      m = Shape.size height
      ld =
         case order of
            ColumnMajor -> m
            RowMajor -> Shape.size width
      flipped = flipOrder order
-- | An action that does nothing. It has the shape of the @fillDiag@
-- argument of 'fromLowerPart'.
leaveDiagonal :: Order -> Int -> Ptr a -> IO ()
leaveDiagonal _ _ _ = pure ()
-- | A 'ContT' computation tagged with a label: a routine name for error
-- reporting (see 'runLabelledLinear'), or @()@ when unlabelled.
data Labelled r label a = Labelled label (ContT r IO a)

-- | Lift a pure value into 'Labelled' under the given label.
label :: label -> a -> Labelled r label a
label lab = Labelled lab . pure

-- | Lift a pure value into 'Labelled' under the trivial label @()@.
noLabel :: a -> Labelled r () a
noLabel = label ()

-- | Map over the computation's result; the label is left unchanged.
instance Functor (Labelled r label) where
   fmap f (Labelled lab m) = Labelled lab (fmap f m)

-- | Run an unlabelled computation producing an 'IO' action, then execute
-- that action.
runUnlabelled :: Labelled r () (IO ()) -> ContT r IO ()
runUnlabelled (Labelled () m) = m >>= liftIO
-- | Run a labelled computation that yields an action taking an @info@
-- pointer. Execute that action through 'withInfo', passing @msg@ as the
-- message and the label as the routine name.
runLabelledLinear ::
   String -> Labelled r String (Ptr CInt -> IO ()) -> ContT r IO ()
runLabelledLinear msg (Labelled name m) = do
   act <- m
   liftIO (withInfo msg name act)
-- | Run a labelled computation that yields an action expecting a
-- workspace pointer and two 'CInt' pointers. Execute that action through
-- 'withAutoWorkspaceInfo', passing @msg@ and the label as routine name.
runLabelledWorkspace ::
   (Class.Floating a) =>
   String ->
   Labelled r String (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) ->
   ContT r IO ()
runLabelledWorkspace msg (Labelled name m) = do
   act <- m
   liftIO (withAutoWorkspaceInfo msg name act)
-- | A pair of 'Labelled' computations: the packed variant first, the
-- unpacked variant second (see 'runPacking').
data Labelled2 r label a b = Labelled2 (Labelled r label a) (Labelled r label b)

-- | Map over the second (unpacked) component only.
instance Functor (Labelled2 r label a) where
   fmap f (Labelled2 packed unpacked) = Labelled2 packed (fmap f unpacked)

infixl 9 $*, $**

-- | Feed the same argument to both variants.
($*) :: Labelled2 r label (a -> f) (a -> g) -> a -> Labelled2 r label f g
Labelled2 f g $* x =
   Labelled2 (fmap (\h -> h x) f) (fmap (\h -> h x) g)

-- | Feed the argument to both variants. The unpacked variant additionally
-- receives a pointer to the leading dimension @n@, obtained from
-- 'Call.leadingDim'.
($**) ::
   Labelled2 r label (a -> f) (a -> Ptr CInt -> g) ->
   (a,Int) -> Labelled2 r label f g
Labelled2 f (Labelled lab g) $** (x, n) =
   Labelled2
      (fmap (\h -> h x) f)
      (Labelled lab (fmap (\h -> h x) g <*> Call.leadingDim n))
-- | Choose the packed or the unpacked variant according to the packing
-- singleton.
runPacking ::
   Layout.PackingSingleton pack ->
   Labelled2 r label func func -> Labelled r label func
runPacking Layout.Packed (Labelled2 packed _) = packed
runPacking Layout.Unpacked (Labelled2 _ unpacked) = unpacked

-- | Choose a variant with 'runPacking' and run it with 'runUnlabelled'.
withPacking ::
   Layout.PackingSingleton pack ->
   Labelled2 r () (IO ()) (IO ()) -> ContT r IO ()
withPacking pck variants = runUnlabelled (runPacking pck variants)

-- | Choose a variant with 'runPacking' and run it with
-- 'runLabelledLinear', using @msg@ for error reporting.
withPackingLinear ::
   (func ~ (Ptr CInt -> IO ())) =>
   String -> Layout.PackingSingleton pack ->
   Labelled2 r String func func -> ContT r IO ()
withPackingLinear msg pck variants =
   runLabelledLinear msg (runPacking pck variants)
-- | A pointer argument paired with a dimension. Via the 'FunctionArg'
-- instance, the unpacked variant also receives that dimension as a
-- leading-dimension pointer.
data TriArg a = TriArg (Ptr a) Int

-- | Construct a 'TriArg'.
triArg :: Ptr a -> Int -> TriArg a
triArg ptr n = TriArg ptr n
-- | Combine a packed and an unpacked labelled variant into a function.
-- Its subsequent arguments are fed to both variants through
-- 'FunctionPair'.
applyFuncPair ::
   (m ~ Labelled (FuncCont f) (FuncLabel f), FunctionPair f) =>
   m (FuncPacked f) -> m (FuncUnpacked f) -> f
applyFuncPair packed unpacked = apply (Labelled2 packed unpacked)
-- | Function types that can be assembled from a 'Labelled2' pair of
-- packed and unpacked variants by feeding every argument to both.
class FunctionPair f where
   -- | Result type @r@ of the underlying 'ContT'.
   type FuncCont f
   -- | Label type of both variants.
   type FuncLabel f
   -- | Type of the packed variant.
   type FuncPacked f
   -- | Type of the unpacked variant.
   type FuncUnpacked f
   -- | Turn the pair of variants into @f@.
   apply ::
      Labelled2 (FuncCont f) (FuncLabel f) (FuncPacked f) (FuncUnpacked f) -> f
-- | The result type of a 'Labelled' computation.
-- NOTE(review): not used in this chunk; presumably used elsewhere.
type family LabelResult a
type instance LabelResult (Labelled r label a) = a
-- | Base case: once all arguments are supplied, the pair itself is the
-- result.
instance FunctionPair (Labelled2 r label a b) where
   type FuncCont (Labelled2 r label a b) = r
   type FuncLabel (Labelled2 r label a b) = label
   type FuncPacked (Labelled2 r label a b) = a
   type FuncUnpacked (Labelled2 r label a b) = b
   apply pair = pair
-- | Inductive case: accept one more argument and delegate to 'applyArg'
-- of its 'FunctionArg' instance.
instance (FunctionArg a, FunctionPair f) => FunctionPair (a -> f) where
   type FuncCont (a -> f) = FuncCont f
   type FuncLabel (a -> f) = FuncLabel f
   type FuncPacked (a -> f) = FuncArgPacked a f
   type FuncUnpacked (a -> f) = FuncArgUnpacked a f
   apply pair = applyArg pair
-- | Argument types that can be fed to both variants of a 'FunctionPair'.
class FunctionArg a where
   -- | Type of the packed variant when it expects an argument of type @a@
   -- and then continues as @f@.
   type FuncArgPacked a f
   -- | Type of the unpacked variant when it expects an argument of type
   -- @a@ and then continues as @f@.
   type FuncArgUnpacked a f
   -- | Supply the argument to both variants and continue via 'apply'.
   applyArg ::
      (FunctionPair f) =>
      Labelled2 (FuncCont f)
         (FuncLabel f) (FuncArgPacked a f) (FuncArgUnpacked a f) ->
      a -> f
-- | A plain pointer is passed unchanged to both variants.
instance FunctionArg (Ptr a) where
   type FuncArgPacked (Ptr a) f = Ptr a -> FuncPacked f
   type FuncArgUnpacked (Ptr a) f = Ptr a -> FuncUnpacked f
   applyArg pair ptr = apply (pair $* ptr)
-- | A 'TriArg' passes its pointer to both variants. The unpacked variant
-- additionally receives a leading-dimension pointer built from the
-- stored dimension (see '$**').
instance FunctionArg (TriArg a) where
   type FuncArgPacked (TriArg a) f = Ptr a -> FuncPacked f
   type FuncArgUnpacked (TriArg a) f = Ptr a -> Ptr CInt -> FuncUnpacked f
   applyArg pair (TriArg ptr n) = apply (pair $** (ptr, n))