{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Low-level helpers for packed triangular matrices: pointer arithmetic on
-- packed storage, packing/unpacking via LAPACK, and block-wise assembly and
-- extraction of triangles.
module Numeric.LAPACK.Matrix.Triangular.Private where

import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), flipOrder, uploFromOrder,
          Empty, Filled, NonUnit)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Scalar (zero)
import Numeric.LAPACK.Private
         (pointerSeq, copyBlock, copyCondConjugateToTemp,
          pokeCInt, fill, withInfo, errorCodeMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (forM_)


-- | Pointers to the first @n@ diagonal entries of a packed triangle
-- starting at @aPtr@.
--
-- Consecutive pointers are @n, n-1, n-2, ...@ elements apart for 'RowMajor'
-- and @2, 3, 4, ...@ elements apart for 'ColumnMajor'.  These distances match
-- a packed triangle whose rows shrink (row-major) resp. whose columns grow
-- (column-major).
diagonalPointers :: (Storable a) => Order -> Int -> Ptr a -> [Ptr a]
diagonalPointers order n aPtr =
   take n $ scanl advancePtr aPtr $
   case order of
      RowMajor -> iterate pred n
      ColumnMajor -> iterate succ 2

-- | Pair the pointer sequence @pointerSeq 1 aPtr@ with the diagonal
-- pointers of the packed triangle at @bPtr@ (see 'diagonalPointers').
-- The result has at most @n@ pairs.
diagonalPointerPairs ::
   (Storable a, Storable b) =>
   Order -> Int -> Ptr a -> Ptr b -> [(Ptr a, Ptr b)]
diagonalPointerPairs order n aPtr bPtr =
   zip (pointerSeq 1 aPtr) $ diagonalPointers order n bPtr

-- | Pointer bundles for walking an @n x n@ full array column by column
-- alongside a column-major packed triangle.
--
-- For the @i@-th entry (0-based) the result contains the count @i+1@,
-- the @i@-th elements of @pointerSeq 1 fullPtr@ and @pointerSeq n fullPtr@,
-- and @packedPtr@ advanced by @1 + 2 + ... + i@ elements.
columnMajorPointers ::
   (Storable a) =>
   Int -> Ptr a -> Ptr a -> [(Int, ((Ptr a, Ptr a), Ptr a))]
columnMajorPointers n fullPtr packedPtr =
   let ds = iterate succ 1
   in  take n $ zip ds $
       zip
         (zip (pointerSeq 1 fullPtr) (pointerSeq n fullPtr))
         (scanl advancePtr packedPtr ds)

-- | Pointer bundles for walking an @n x n@ full array along its diagonal
-- (stride @n+1@) alongside a row-major packed triangle.
--
-- For the @i@-th entry (0-based) the result contains the count @n-i@,
-- the @i@-th element of @pointerSeq (n+1) fullPtr@, and @packedPtr@
-- advanced by @n + (n-1) + ... + (n-i+1)@ elements.
rowMajorPointers ::
   (Storable a) =>
   Int -> Ptr a -> Ptr a -> [(Int, (Ptr a, Ptr a))]
rowMajorPointers n fullPtr packedPtr =
   let ds = iterate pred n
   in  take n $ zip ds $
       zip (pointerSeq (n+1) fullPtr) (scanl advancePtr packedPtr ds)

-- | Run @act@ for every @(d, ptrs)@ entry.
--
-- Before each call, @d@ is stored into a single 'CInt' cell that is
-- allocated once with 'alloca' and reused for all iterations.  A pointer to
-- that cell is handed to @act@ (handy for FFI calls taking a length by
-- reference).
forPointers :: [(Int, a)] -> (Ptr CInt -> a -> IO ()) -> IO ()
forPointers xs act =
   alloca $ \nPtr ->
   forM_ xs $ \(d,ptrs) -> do
      pokeCInt nPtr d
      act nPtr ptrs

-- | Copy a triangle into a temporary buffer via 'copyCondConjugateToTemp'.
--
-- The requested conjugation is applied only for 'RowMajor'; 'ColumnMajor'
-- always uses 'NonConjugated'.
-- NOTE(review): presumably because a row-major triangle is handed to LAPACK
-- as the transposed column-major one; confirm at call sites.
copyTriangleToTemp ::
   Class.Floating a =>
   Conjugation -> Order -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyTriangleToTemp conj order =
   copyCondConjugateToTemp $
   case order of
      RowMajor -> conj
      ColumnMajor -> NonConjugated

-- | Allocate an @n*n@ temporary, fill it by running @f n packedPtr tmpPtr@
-- on the data of @a@, and return the temporary pointer.
--
-- Both pointers are obtained inside 'ContT', so they are only valid within
-- the enclosing continuation.
unpackToTemp ::
   Storable a =>
   (Int -> Ptr a -> Ptr a -> IO ()) -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
unpackToTemp f n a = do
   apPtr <- ContT $ withForeignPtr a
   aPtr <- Call.allocaArray (n*n)
   liftIO $ f n apPtr aPtr
   return aPtr

-- | Expand the packed triangle at @packedPtr@ into the @n x n@ array at
-- @fullPtr@ using LAPACK @?tpttr@.
--
-- The triangle is selected by @uploFromOrder order@.  @?tpttr@ writes only
-- that triangle; the remaining entries of @fullPtr@ are left untouched
-- (see 'unpackZero').
unpack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpack order n packedPtr fullPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   ldaPtr <- Call.leadingDim n
   liftIO $ withInfo errorCodeMsg "tpttr" $
      LapackGen.tpttr uploPtr nPtr packedPtr fullPtr ldaPtr

-- | Pack the triangle of a contiguous @n x n@ array.  This is 'packRect'
-- with leading dimension @n@.
pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
pack order n = packRect order n n

-- | Pack the @n x n@ triangle of an array with leading dimension @ld@ into
-- packed storage using LAPACK @?trttp@.
--
-- The triangle is selected by @uploFromOrder order@.
packRect :: Class.Floating a => Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
packRect order n ld fullPtr packedPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   ldaPtr <- Call.leadingDim ld
   liftIO $ withInfo errorCodeMsg "trttp" $
      LapackGen.trttp uploPtr nPtr fullPtr ldaPtr packedPtr

-- | Unpack a packed triangle into an @n x n@ array whose other entries are
-- zero.
--
-- '_unpackZero' zeroes all @n*n@ entries first.  'unpackZero' zeroes only
-- the opposite triangle (diagonal included) and then lets 'unpack'
-- overwrite the selected triangle, including the diagonal.
unpackZero, _unpackZero ::
   Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_unpackZero order n packedPtr fullPtr = do
   fill zero (n*n) fullPtr
   unpack order n packedPtr fullPtr

unpackZero order n packedPtr fullPtr = do
   fillTriangle zero (flipOrder order) n fullPtr
   unpack order n packedPtr fullPtr

-- | Set one triangle, diagonal included, of the contiguous @n x n@ array at
-- @aPtr@ to @z@.
--
-- Uses LAPACK @?laset@ with @alpha = beta = z@.  The triangle is selected by
-- @uploFromOrder order@.
fillTriangle :: Class.Floating a => a -> Order -> Int -> Ptr a -> IO ()
fillTriangle z order n aPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   zPtr <- Call.number z
   -- LAPACK requires LDA >= max(1,M).  Passing n directly gives LDA = 0 for
   -- n = 0, so build it with Call.leadingDim, as unpack/packRect do.
   ldaPtr <- Call.leadingDim n
   liftIO $ LapackGen.laset uploPtr nPtr nPtr zPtr zPtr aPtr ldaPtr

-- | Assemble a packed @(m+n)@ triangle from three blocks:
--
-- * a packed @m@-triangle A (top-left),
-- * a general @height x width@ matrix B (top-right), whose storage order
--   is used for all three copies,
-- * a packed @n@-triangle C (bottom-right).
--
-- The shape of the result is @consShape (height :+: width)@.  The call
-- asserts that A's height matches B's height and that C's width matches
-- B's width.  @name@ prefixes the assertion messages.
stack ::
   (Box.Box sh0, Box.HeightOf sh0 ~ height, Shape.C height, Eq height,
    Box.Box sh1, Box.WidthOf sh1 ~ width, Shape.C width, Eq width,
    Shape.C sh2, Class.Floating a) =>
   String -> (height:+:width -> sh2) ->
   Array sh0 a -> Matrix.General height width a -> Array sh1 a ->
   Array sh2 a
stack name consShape
      (Array sha a) (Array (MatrixShape.Full order extent) b) (Array shc c) =
   let (height,width) = Extent.dimensions extent
   in  Array.unsafeCreate (consShape (height :+: width)) $ \xPtr -> do
          Call.assert (name++".stack: height shapes mismatch") $
             height == Box.height sha
          Call.assert (name++".stack: width shapes mismatch") $
             width == Box.width shc
          let m = Shape.size height
          let n = Shape.size width
          withForeignPtr a $ \aPtr -> copyTriangleA copyBlock order m n aPtr xPtr
          withForeignPtr b $ \bPtr -> copyRectangle copyBlock order m n bPtr xPtr
          withForeignPtr c $ \cPtr -> copyTriangleC copyBlock order m n cPtr xPtr

-- | Extract the top-right @height x width@ rectangle of a packed triangle as
-- a general matrix.
--
-- @getShapes@ yields the storage order and the split
-- @height :+: width@.  This is the inverse direction of the B part of
-- 'stack'.
takeTopRight ::
   (Shape.C sh, Shape.C height, Shape.C width, Class.Floating a) =>
   (sh -> (MatrixShape.Order, height:+:width)) ->
   Array sh a -> Matrix.General height width a
takeTopRight getShapes (Array sh x) =
   let (order, height:+:width) = getShapes sh
   in  Array.unsafeCreate (MatrixShape.general order height width) $ \bPtr -> do
          let m = Shape.size height
          let n = Shape.size width
          -- flip . copyBlock swaps source and destination, so data flows
          -- from x into bPtr.
          withForeignPtr x $ copyRectangle (flip . copyBlock) order m n bPtr

-- | Extract the top-left @height@-triangle of a packed triangle into an
-- array of shape @sha@.
--
-- @sha@, the order and the split come from @getShapes@.
takeTopLeft ::
   (Shape.C sh, Shape.C sha, Shape.C height, Shape.C width,
    Class.Floating a) =>
   (sh -> (sha, (MatrixShape.Order, height:+:width))) ->
   Array sh a -> Array sha a
takeTopLeft getShapes (Array sh x) =
   let (sha, (order, height:+:width)) = getShapes sh
   in  Array.unsafeCreate sha $ \aPtr -> do
          let m = Shape.size height
          let n = Shape.size width
          withForeignPtr x $ copyTriangleA (flip . copyBlock) order m n aPtr

-- | Extract the bottom-right @width@-triangle of a packed triangle into an
-- array of shape @shc@.
--
-- @shc@, the order and the split come from @getShapes@.
takeBottomRight ::
   (Shape.C sh, Shape.C shc, Shape.C height, Shape.C width,
    Class.Floating a) =>
   (sh -> (shc, (MatrixShape.Order, height:+:width))) ->
   Array sh a -> Array shc a
takeBottomRight getShapes (Array sh x) =
   let (shc, (order, height:+:width)) = getShapes sh
   in  Array.unsafeCreate shc $ \cPtr -> do
          let m = Shape.size height
          let n = Shape.size width
          withForeignPtr x $ copyTriangleC (flip . copyBlock) order m n cPtr

{-# INLINE copyTriangleA #-}
-- | Transfer data between a packed @m@-triangle at @aPtr@ and the top-left
-- corner of a packed @(m+n)@-triangle at @xPtr@.
--
-- @copy k p q@ is called with the A-side pointer first.  The caller decides
-- the direction (e.g. 'copyBlock' vs. @flip . copyBlock@).
copyTriangleA ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleA copy order m n aPtr xPtr =
   case order of
      -- The first m columns of X hold exactly A, contiguously.
      ColumnMajor -> copy (Shape.triangleSize m) aPtr xPtr
      -- Row i of A has m-i entries starting at its diagonal; they are the
      -- first m-i entries of row i of X, which also starts at its diagonal.
      RowMajor ->
         forM_
            (zip (iterate pred m) $
             zip (diagonalPointers order m aPtr)
                 (diagonalPointers order (m+n) xPtr)) $
         \(k,(aiPtr,xiPtr)) -> copy k aiPtr xiPtr

{-# INLINE copyTriangleC #-}
-- | Transfer data between a packed @n@-triangle at @cPtr@ and the
-- bottom-right corner of a packed @(m+n)@-triangle at @xPtr@.
--
-- The C-side pointer is passed to @copy@ first.
copyTriangleC ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleC copy order m n cPtr xPtr =
   case order of
      -- The last n rows of X hold exactly C, contiguously at the end.
      RowMajor ->
         let triSize = Shape.triangleSize n
         in  copy triSize cPtr
               (advancePtr xPtr $ Shape.triangleSize (m+n) - triSize)
      -- Column k of C has k+1 entries ending at its diagonal; they are the
      -- last k+1 entries (ending at the diagonal) of column m+k of X.
      ColumnMajor ->
         forM_
            (zip (iterate succ 0) $
             zip (diagonalPointers order n cPtr)
                 (drop m $ diagonalPointers order (m+n) xPtr)) $
         \(k,(aiPtr,xiPtr)) ->
            copy (k+1) (advancePtr aiPtr (-k)) (advancePtr xiPtr (-k))

{-# INLINE copyRectangle #-}
-- | Transfer data between a general @m x n@ matrix at @bPtr@ and the
-- top-right @m x n@ rectangle of a packed @(m+n)@-triangle at @xPtr@.
--
-- The B-side pointer is passed to @copy@ first.
copyRectangle ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyRectangle copy order m n bPtr xPtr =
   case order of
      -- Row i of X starts at its diagonal with m-i entries of A; the n
      -- entries of B's row i follow at offset k = m-i.
      RowMajor ->
         forM_
            (take m $ zip (iterate pred m) $
             zip (pointerSeq n bPtr) (diagonalPointers order (m+n) xPtr)) $
         \(k,(biPtr,xiPtr)) -> copy n biPtr (advancePtr xiPtr k)
      -- Column m+j of X ends at its diagonal and starts k = m+j entries
      -- before it; its first m entries are B's column j.
      ColumnMajor ->
         forM_
            (take n $ zip (iterate succ m) $
             zip (pointerSeq m bPtr)
                 (drop m $ diagonalPointers order (m+n) xPtr)) $
         \(k,(biPtr,xiPtr)) -> copy m biPtr (advancePtr xiPtr (-k))


-- | Array with a triangular shape, parameterised by lower part, diagonal
-- kind and upper part.
type Triangular lo diag up sh = Array (MatrixShape.Triangular lo diag up sh)

-- | Triangular array with both off-diagonal parts 'MatrixShape.Empty'.
type FlexDiagonal diag sh = Triangular MatrixShape.Empty diag MatrixShape.Empty sh

-- | Wrapper for a function consuming a triangular array.  The @lo@ and @up@
-- parameters come last so the type can be indexed by them.
newtype MultiplyRight diag sh a b lo up =
   MultiplyRight {getMultiplyRight :: Triangular lo diag up sh a -> b}

-- | Wrapper for a shape-changing map between triangular arrays with equal
-- @lo@, @diag@ and @up@.
newtype Map diag sh0 sh1 a lo up =
   Map {getMap :: Triangular lo diag up sh0 a -> Triangular lo diag up sh1 a}

-- | Wrapper for a function whose result diagonal kind is given by
-- 'PowerDiag'.
newtype Power diag sh a lo up =
   Power {
      getPower ::
         Triangular lo diag up sh a ->
         Triangular lo (PowerDiag lo up diag) up sh a
   }

-- | Resulting diagonal kind: @diag@ is kept unless both @lo@ and @up@ are
-- 'Filled', in which case it becomes 'NonUnit'.
type family PowerDiag lo up diag
type instance PowerDiag Empty up diag = diag
type instance PowerDiag Filled Empty diag = diag
type instance PowerDiag Filled Filled diag = NonUnit


-- | Convert band storage with leading dimension @k+1@ (at @a@) into a packed
-- @n@-triangle at @bPtr@.
--
-- All @bSize@ entries of @bPtr@ are zero-filled first.
-- NOTE(review): presumably @k@ is the number of off-diagonals of the band;
-- confirm at call sites.
fromBanded ::
   (Class.Floating a) =>
   Int -> Order -> Int -> ForeignPtr a -> Int -> Ptr a -> IO ()
fromBanded k order n a bSize bPtr =
   withForeignPtr a $ \aPtr -> do
      fill zero bSize bPtr
      let lda = k+1
      let pointers =
            zip [0..] $ zip (pointerSeq lda aPtr) $ diagonalPointers order n bPtr
      case order of
         -- Band column i ends with the diagonal at index k; copy its last
         -- min i k + 1 entries so they end at the packed diagonal.
         ColumnMajor ->
            forM_ pointers $ \(i,(xPtr,yPtr)) ->
               let j = min i k
               in  copyBlock (j+1) (advancePtr xPtr (k-j)) (advancePtr yPtr (-j))
         -- Band row i starts at the diagonal; near the end fewer than k+1
         -- entries remain in the packed row.
         RowMajor ->
            forM_ pointers $ \(i,(xPtr,yPtr)) ->
               copyBlock (min lda (n-i)) xPtr yPtr


-- | Array with a lower-triangular shape.
type FlexLower diag sh = Array (MatrixShape.LowerTriangular diag sh)

-- | Extract the lower @height x height@ triangle of a full matrix whose
-- vertical extent is 'Extent.Small' into packed lower-triangular storage.
--
-- The leading dimension is @width@ for 'RowMajor' and @height@ for
-- 'ColumnMajor'.  The triangle is packed with the flipped order, and
-- @fillDiag@ is then run on the result with that flipped order.
-- NOTE(review): @fillDiag@ presumably adjusts the diagonal to match @diag@
-- (e.g. sets ones for unit diagonals); confirm at call sites.
takeLower ::
   (MatrixShape.TriDiag diag, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   (diag, Order -> Int -> Ptr a -> IO ()) ->
   Full Extent.Small horiz height width a ->
   FlexLower diag height a
takeLower (diag, fillDiag) (Array (MatrixShape.Full order extent) a) =
   let (height,width) = Extent.dimensions extent
       m = Shape.size height
       n = Shape.size width
       k = case order of RowMajor -> n; ColumnMajor -> m
   in  Array.unsafeCreate
         (MatrixShape.Triangular diag MatrixShape.lower order height) $ \lPtr ->
       withForeignPtr a $ \aPtr -> do
          let dstOrder = flipOrder order
          packRect dstOrder m k aPtr lPtr
          fillDiag dstOrder m lPtr