{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.LAPACK.Matrix.Mosaic.Private where

import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor,ColumnMajor), flipOrder, uploFromOrder)
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Scalar (zero)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private
         (pointerSeq, copyBlock, copyCondConjugateToTemp,
          pokeCInt, fill, withAutoWorkspaceInfo, withInfo, errorCodeMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((::+)((::+)))

import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (pure, (<*>))

import Data.Foldable (forM_)

-- | An array whose shape is a 'Layout.Mosaic', indexed by packing
-- (@pack@), mirroring (@mirror@), stored triangle (@uplo@) and size (@sh@).
type Mosaic pack mirror uplo sh = Array (Layout.Mosaic pack mirror uplo sh)
-- | 'Mosaic' with 'Layout.Packed' storage.
type MosaicPacked mirror uplo sh = Mosaic Layout.Packed mirror uplo sh
-- | 'Mosaic' with 'Layout.Unpacked' storage.
type MosaicUnpacked mirror uplo sh = Mosaic Layout.Unpacked mirror uplo sh

-- | Packed 'Mosaic' whose stored triangle is 'Shape.Lower'.
type MosaicLower mirror sh = MosaicPacked mirror Shape.Lower sh
-- | Packed 'Mosaic' whose stored triangle is 'Shape.Upper'.
type MosaicUpper mirror sh = MosaicPacked mirror Shape.Upper sh

-- | Pointers to the @n@ diagonal elements of a packed upper-triangular
-- matrix of order @n@ whose storage starts at @aPtr@.
--
-- In 'RowMajor' order the gap between consecutive diagonal elements starts
-- at @n@ and shrinks by one per row; in 'ColumnMajor' order it starts at 2
-- and grows by one per column.
diagonalPointers :: (Storable a) => Order -> Int -> Ptr a -> [Ptr a]
diagonalPointers order n aPtr = take n (scanl advancePtr aPtr gaps)
   where
      gaps =
         case order of
            ColumnMajor -> [2 ..]
            RowMajor -> [n, n-1 ..]

-- | Pair consecutive elements of a vector starting at @aPtr@ with the
-- diagonal elements of the packed matrix at @bPtr@ (see 'diagonalPointers').
-- The list has at most @n@ entries.
--
-- NOTE(review): assumes @pointerSeq 1 aPtr@ yields @aPtr, aPtr+1, …@.
diagonalPointerPairs ::
   (Storable a, Storable b) =>
   Order -> Int -> Ptr a -> Ptr b -> [(Ptr a, Ptr b)]
diagonalPointerPairs order n aPtr bPtr =
   let vectorPtrs = pointerSeq 1 aPtr
       diagPtrs = diagonalPointers order n bPtr
   in  zip vectorPtrs diagPtrs

-- | Per-column pointer tuples for converting between an @n×n@ full matrix at
-- @fullPtr@ and column-major packed storage at @packedPtr@.
--
-- Entry @j@ (0-based, @j < n@) is @(j+1, ((p1, pn), pp))@ where
--
-- * @j+1@ is the number of stored elements in packed column @j@,
-- * @p1@ and @pn@ are the @j@-th elements of @pointerSeq 1 fullPtr@ and
--   @pointerSeq n fullPtr@ (presumably @fullPtr+j@ and @fullPtr+j*n@),
-- * @pp@ is the start of packed column @j@ (offsets 0, 1, 3, 6, …).
columnMajorPointers ::
   (Storable a) => Int -> Ptr a -> Ptr a -> [(Int, ((Ptr a, Ptr a), Ptr a))]
columnMajorPointers n fullPtr packedPtr =
   let ds = iterate succ 1
   in  take n $ zip ds $
       -- The outer zip was missing, which made this definition ill-typed.
       zip
          (zip (pointerSeq 1 fullPtr) (pointerSeq n fullPtr))
          (scanl advancePtr packedPtr ds)

-- | Per-row pointer tuples for converting between an @n×n@ full matrix at
-- @fullPtr@ and row-major packed storage at @packedPtr@.
--
-- Entry @i@ (0-based, @i < n@) is @(n-i, (pd, pp))@. Here @n-i@ is the
-- length of packed row @i@, @pd@ is the @i@-th element of
-- @pointerSeq (n+1) fullPtr@ (presumably the diagonal entry @(i,i)@), and
-- @pp@ is the start of packed row @i@.
rowMajorPointers ::
   (Storable a) => Int -> Ptr a -> Ptr a -> [(Int, (Ptr a, Ptr a))]
rowMajorPointers n fullPtr packedPtr =
   take n (zip rowLengths (zip diagPtrs packedRowPtrs))
   where
      rowLengths = [n, n-1 ..]
      diagPtrs = pointerSeq (n+1) fullPtr
      packedRowPtrs = scanl advancePtr packedPtr rowLengths

-- | Run @act@ for every @(count, pointers)@ entry.
--
-- A single 'CInt' cell is allocated once and reused. Before each call the
-- count is written into it, and @act@ receives that cell's pointer together
-- with the entry's pointers.
forPointers :: [(Int, a)] -> (Ptr CInt -> a -> IO ()) -> IO ()
forPointers xs act =
   alloca $ \countPtr ->
      mapM_ (\(d, ptrs) -> pokeCInt countPtr d >> act countPtr ptrs) xs

-- | Copy a stored triangle into a temporary buffer via
-- 'copyCondConjugateToTemp'.
--
-- The requested conjugation is honoured only for 'RowMajor' data. For
-- 'ColumnMajor' data the copy is always 'NonConjugated'.
copyTriangleToTemp ::
   Class.Floating a =>
   Conjugation -> Order -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyTriangleToTemp conj order = copyCondConjugateToTemp effectiveConj
   where
      effectiveConj =
         case order of
            ColumnMajor -> NonConjugated
            RowMajor -> conj

-- | Allocate an @n*n@ temporary buffer and fill it from packed data @a@.
--
-- The filling is done by @f n packedPtr fullPtr@. The result is the
-- temporary's pointer. Its lifetime is presumably bounded by the enclosing
-- 'ContT' (via 'Call.allocaArray').
unpackToTemp ::
   Storable a =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Int -> ForeignPtr a -> ContT r IO (Ptr a)
unpackToTemp f n a = do
   packedPtr <- ContT (withForeignPtr a)
   fullPtr <- Call.allocaArray (n*n)
   liftIO (f n packedPtr fullPtr)
   pure fullPtr

-- | Expand packed triangular data at @packedPtr@ into the @n×n@ buffer at
-- @fullPtr@ using LAPACK @tpttr@.
--
-- The leading dimension is @n@, and the triangle is chosen by
-- 'uploFromOrder'. A nonzero info code is reported through 'withInfo'.
unpack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpack order n packedPtr fullPtr =
   evalContT $ do
      uploPtr <- Call.char (uploFromOrder order)
      nPtr <- Call.cint n
      ldaPtr <- Call.leadingDim n
      let call = LapackGen.tpttr uploPtr nPtr packedPtr fullPtr ldaPtr
      liftIO (withInfo errorCodeMsg "tpttr" call)

-- | Pack a triangle of the @n×n@ buffer at @fullPtr@ into @packedPtr@.
-- This is 'packRect' with leading dimension @n@.
pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
pack order n fullPtr packedPtr = packRect order n n fullPtr packedPtr

-- | Pack the order-@n@ triangle of the full matrix at @fullPtr@, which has
-- leading dimension @ld@, into @packedPtr@ using LAPACK @trttp@.
--
-- The triangle is chosen by 'uploFromOrder'. A nonzero info code is
-- reported through 'withInfo'.
packRect :: Class.Floating a => Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
packRect order n ld fullPtr packedPtr =
   evalContT $ do
      uploPtr <- Call.char (uploFromOrder order)
      nPtr <- Call.cint n
      ldaPtr <- Call.leadingDim ld
      let call = LapackGen.trttp uploPtr nPtr fullPtr ldaPtr packedPtr
      liftIO (withInfo errorCodeMsg "trttp" call)

-- | Unpack into an @n×n@ buffer so that the triangle not written by
-- 'unpack' is zero.
--
-- '_unpackZero' clears the whole buffer first. 'unpackZero' clears only the
-- triangle selected by the flipped order (via 'fillTriangle') before
-- calling 'unpack'.
unpackZero, _unpackZero ::
   Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_unpackZero order n packedPtr fullPtr =
   fill zero (n*n) fullPtr >> unpack order n packedPtr fullPtr

unpackZero order n packedPtr fullPtr =
   fillTriangle zero (flipOrder order) n fullPtr >>
   unpack order n packedPtr fullPtr

-- | Set a triangle of the @n×n@ buffer at @aPtr@ to @z@ with LAPACK @laset@.
--
-- The triangle is chosen by 'uploFromOrder'. The same value @z@ is passed
-- for both the off-diagonal and diagonal arguments, and the leading
-- dimension is @n@.
fillTriangle :: Class.Floating a => a -> Order -> Int -> Ptr a -> IO ()
fillTriangle z order n aPtr =
   evalContT $ do
      uploPtr <- Call.char (uploFromOrder order)
      nPtr <- Call.cint n
      zPtr <- Call.number z
      liftIO (LapackGen.laset uploPtr nPtr nPtr zPtr zPtr aPtr nPtr)

-- | Wrap the size shape in 'Unchecked'. All other layout parameters and the
-- element data are left unchanged.
uncheck ::
   Mosaic pack mirror uplo sh a -> Mosaic pack mirror uplo (Unchecked sh) a
uncheck = Array.mapShape wrapSize
   where
      wrapSize ::
         Layout.Mosaic pk mr ul s -> Layout.Mosaic pk mr ul (Unchecked s)
      wrapSize (Layout.Mosaic packing mirror uplo order sh) =
         Layout.Mosaic packing mirror uplo order (Unchecked sh)

-- | Remove the 'Unchecked' wrapper from the size shape. This is the inverse
-- of 'uncheck'.
recheck ::
   Mosaic pack mirror uplo (Unchecked sh) a -> Mosaic pack mirror uplo sh a
recheck = Array.mapShape unwrapSize
   where
      unwrapSize ::
         Layout.Mosaic pk mr ul (Unchecked s) -> Layout.Mosaic pk mr ul s
      unwrapSize (Layout.Mosaic packing mirror uplo order (Unchecked sh)) =
         Layout.Mosaic packing mirror uplo order sh

-- | Assemble a packed upper mosaic of size @height::+width@ from three parts:
-- the top-left triangle @A@, the top-right rectangle @B@ and the
-- bottom-right triangle @C@.
--
-- The result uses @B@'s order and @A@'s mirror. The function asserts that
-- @B@'s height and width equal the sizes of @A@ and @C@.
--
-- NOTE(review): the orders of @A@ and @C@ are not compared with @B@'s
-- order; their data is copied as if it used @B@'s order. Callers are
-- presumably expected to ensure the orders agree.
stack ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   MosaicUpper mirror height a ->
   Matrix.General height width a ->
   MosaicUpper mirror width a ->
   MosaicUpper mirror (height::+width) a
stack (Array sha a) (Array (Layout.Full order extent) b) (Array shc c) =
   Array.unsafeCreate resultShape $ \xPtr -> do
      Call.assert (name++".stack: height shapes mismatch") $
         height == Layout.mosaicSize sha
      Call.assert (name++".stack: width shapes mismatch") $
         width == Layout.mosaicSize shc
      withForeignPtr a $ \aPtr -> copyTriangleA copyBlock order m n aPtr xPtr
      withForeignPtr b $ \bPtr -> copyRectangle copyBlock order m n bPtr xPtr
      withForeignPtr c $ \cPtr -> copyTriangleC copyBlock order m n cPtr xPtr
   where
      mirror = Layout.mosaicMirror sha
      name = show mirror
      (height, width) = Extent.dimensions extent
      m = Shape.size height
      n = Shape.size width
      resultShape =
         Layout.Mosaic Layout.Packed mirror Layout.Upper order (height ::+ width)

-- | Extract the top-right @height × width@ rectangle of a packed upper
-- mosaic over @height::+width@ as a general matrix.
--
-- The result has the same order as the mosaic. This is the inverse of the
-- rectangle part of 'stack': 'copyRectangle' is run with the copy direction
-- flipped, so data flows from the mosaic into the new matrix.
takeTopRight ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a -> Matrix.General height width a
takeTopRight
   (Array (Layout.Mosaic _packed _mirror _upper order (height::+width)) x) =
   Array.unsafeCreate (Layout.general order height width) $ \bPtr -> do
      let m = Shape.size height
      let n = Shape.size width
      withForeignPtr x $ copyRectangle (flip . copyBlock) order m n bPtr

-- | Extract the top-left triangle, of size @height@, from a packed upper
-- mosaic over @height::+width@.
--
-- Packing, mirror, stored triangle and order are kept. This is the inverse
-- of the @A@ part of 'stack': 'copyTriangleA' is run with the copy
-- direction flipped.
takeTopLeft ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a ->
   MosaicUpper mirror height a
takeTopLeft
   (Array (Layout.Mosaic packing mirror upper order (height::+width)) x) =
   Array.unsafeCreate (Layout.Mosaic packing mirror upper order height) $
      \aPtr -> do
         let m = Shape.size height
         let n = Shape.size width
         withForeignPtr x $ copyTriangleA (flip . copyBlock) order m n aPtr

-- | Extract the bottom-right triangle, of size @width@, from a packed upper
-- mosaic over @height::+width@.
--
-- Packing, mirror, stored triangle and order are kept. This is the inverse
-- of the @C@ part of 'stack': 'copyTriangleC' is run with the copy
-- direction flipped.
takeBottomRight ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a ->
   MosaicUpper mirror width a
takeBottomRight
   (Array (Layout.Mosaic packing mirror upper order (height::+width)) x) =
   Array.unsafeCreate (Layout.Mosaic packing mirror upper order width) $
      \cPtr -> do
         let m = Shape.size height
         let n = Shape.size width
         withForeignPtr x $ copyTriangleC (flip . copyBlock) order m n cPtr

{-# INLINE copyTriangleA #-}
-- | Transfer between a packed upper triangle @A@ of order @m@ (at @aPtr@)
-- and the top-left order-@m@ triangle of a packed upper matrix of order
-- @m+n@ (at @xPtr@).
--
-- @copy len p q@ is invoked once per contiguous run, with the @A@-side
-- pointer first. 'stack' passes 'copyBlock' to fill @xPtr@ from @aPtr@.
-- 'takeTopLeft' passes @flip . copyBlock@ to go the other way.
copyTriangleA ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleA copy order m n aPtr xPtr =
   case order of
      -- Column-major packed upper storage puts the first m columns first,
      -- so A is one contiguous prefix of the big triangle.
      ColumnMajor -> copy (Shape.triangleSize m) aPtr xPtr
      -- Row-major: row i of A has m-i entries starting at its diagonal.
      -- They go to the start of row i of the big triangle, which also
      -- begins at its diagonal.
      RowMajor ->
         forM_ (zip (iterate pred m) $
                zip (diagonalPointers order m aPtr)
                    (diagonalPointers order (m+n) xPtr)) $
            \(k,(aiPtr,xiPtr)) -> copy k aiPtr xiPtr

{-# INLINE copyTriangleC #-}
-- | Transfer between a packed upper triangle @C@ of order @n@ (at @cPtr@)
-- and the bottom-right order-@n@ triangle of a packed upper matrix of order
-- @m+n@ (at @xPtr@).
--
-- @copy len p q@ is invoked once per contiguous run, with the @C@-side
-- pointer first. The copy direction is chosen by the caller, as for
-- 'copyTriangleA'.
copyTriangleC ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleC copy order m n cPtr xPtr =
   case order of
      -- Row-major packed upper storage ends with the last n rows, which
      -- contain exactly the bottom-right triangle: one contiguous suffix.
      RowMajor ->
         let triSize = Shape.triangleSize n
         in copy triSize cPtr
               (advancePtr xPtr $ Shape.triangleSize (m+n) - triSize)
      -- Column-major: column k of C holds k+1 entries ending at its
      -- diagonal. In the big triangle the same entries end at the diagonal
      -- of column m+k. Stepping back k from each diagonal reaches the first
      -- entry of the run.
      ColumnMajor ->
         forM_ (zip (iterate succ 0) $
                zip (diagonalPointers order n cPtr)
                    (drop m $ diagonalPointers order (m+n) xPtr)) $
            \(k,(aiPtr,xiPtr)) ->
               copy (k+1) (advancePtr aiPtr (-k)) (advancePtr xiPtr (-k))

{-# INLINE copyRectangle #-}
-- | Transfer between a dense @m × n@ matrix @B@ (at @bPtr@, in the same
-- @order@) and the top-right @m × n@ rectangle of a packed upper matrix of
-- order @m+n@ (at @xPtr@).
--
-- @copy len p q@ is invoked once per row (row-major) or column
-- (column-major), with the @B@-side pointer first.
--
-- NOTE(review): assumes @pointerSeq s p@ yields @p, p+s, p+2s, …@.
copyRectangle ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyRectangle copy order m n bPtr xPtr =
   case order of
      -- Row-major: for big row i < m, the rectangle starts k = m-i entries
      -- after that row's diagonal. B's rows are n apart.
      RowMajor ->
         forM_ (take m $ zip (iterate pred m) $
                zip (pointerSeq n bPtr) (diagonalPointers order (m+n) xPtr)) $
            \(k,(biPtr,xiPtr)) -> copy n biPtr (advancePtr xiPtr k)
      -- Column-major: big column m+j starts k = m+j entries before its
      -- diagonal, and its first m entries belong to the rectangle.
      -- B's columns are m apart.
      ColumnMajor ->
         forM_ (take n $ zip (iterate succ m) $
                zip (pointerSeq m bPtr)
                    (drop m $ diagonalPointers order (m+n) xPtr)) $
            \(k,(biPtr,xiPtr)) -> copy m biPtr (advancePtr xiPtr (-k))

-- | An array whose shape is 'Layout.Triangular' for triangle @uplo@.
type Triangular uplo sh = Array (Layout.Triangular uplo sh)
-- | 'Triangular' with the 'Shape.Lower' triangle.
type Lower sh = Triangular Shape.Lower sh
-- | 'Triangular' with the 'Shape.Upper' triangle.
type Upper sh = Triangular Shape.Upper sh

-- | A function consuming a 'Triangular' matrix, wrapped in a newtype.
-- @uplo@ is the last type parameter, presumably so the wrapper can be
-- indexed by @uplo@ in class-based dispatch (not visible here).
newtype MultiplyRight sh a b uplo =
   MultiplyRight {getMultiplyRight :: Triangular uplo sh a -> b}

-- | A shape-changing map between 'Mosaic' arrays, wrapped in a newtype. The
-- map keeps packing, mirror and @uplo@. @uplo@ is the last type parameter,
-- presumably for class-based dispatch over @uplo@ (not visible here).
newtype Map pack mirror sh0 sh1 a uplo =
   Map {
      getMap :: Mosaic pack mirror uplo sh0 a -> Mosaic pack mirror uplo sh1 a
   }

-- | Expand an upper band matrix with @k@ superdiagonals into packed
-- upper-triangular storage of order @n@.
--
-- The band data is read from @a@; each band row/column has @lda = k+1@
-- entries. The destination buffer @bPtr@ holds @bSize@ elements and is
-- zero-filled first, so entries outside the band end up zero.
fromBanded ::
   (Class.Floating a) =>
   Int -> Order -> Int -> ForeignPtr a -> Int -> Ptr a -> IO ()
fromBanded k order n a bSize bPtr =
   withForeignPtr a $ \aPtr -> do
      fill zero bSize bPtr
      let lda = k+1
      -- i is the row/column index; xPtr is its band slot and yPtr is its
      -- diagonal in the packed destination.
      let pointers =
            zip [0..] $ zip (pointerSeq lda aPtr) $
            diagonalPointers order n bPtr
      case order of
         -- Column i: the diagonal sits at band offset k. The j = min i k
         -- entries above it are copied together with the diagonal, ending
         -- at the packed diagonal.
         ColumnMajor ->
            forM_ pointers $ \(i,(xPtr,yPtr)) ->
               let j = min i k
               in copyBlock (j+1) (advancePtr xPtr (k-j)) (advancePtr yPtr (-j))
         -- Row i: the diagonal is first in its band slot. Up to lda entries
         -- are copied, truncated at the right edge of the matrix (n-i).
         RowMajor ->
            forM_ pointers $ \(i,(xPtr,yPtr)) ->
               copyBlock (min lda (n-i)) xPtr yPtr

{-
Naming is inconsistent with Triangular.takeUpper,
because here the Hermitian matrix is the input,
whereas in Triangular.takeUpper the Triangular matrix is the output.
-}
-- | Reinterpret a packed upper mosaic as an upper 'Triangular' matrix.
-- Only the mirror tag changes, to 'Layout.NoMirror'; the element data is
-- not touched.
takeUpper :: MosaicUpper mirror sh a -> Upper sh a
takeUpper =
   Array.mapShape
      (\(Layout.Mosaic packing _mirror upper order sh) ->
         Layout.Mosaic packing Layout.NoMirror upper order sh)

-- | Reinterpret an upper 'Triangular' matrix as a packed upper mosaic with
-- the mirror chosen by 'Layout.autoMirror'. The element data is not
-- touched. This is the inverse of 'takeUpper'.
fromUpper ::
   (Layout.Mirror mirror) => Upper sh a -> MosaicUpper mirror sh a
fromUpper =
   Array.mapShape
      (\(Layout.Mosaic packing Layout.NoMirror upper order sh) ->
         Layout.Mosaic packing Layout.autoMirror upper order sh)

-- | Build a packed lower mosaic of order @m@ (the matrix height) from a
-- full matrix.
--
-- The triangle is packed with 'packRect' in the flipped order. The leading
-- dimension is the width for 'RowMajor' and the height for 'ColumnMajor'.
-- Afterwards @fillDiag@ post-processes the packed result in that flipped
-- order ('leaveDiagonal' does nothing).
fromLowerPart ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   (Order -> Int -> Ptr a -> IO ()) ->
   Layout.MirrorSingleton mirror ->
   Full meas Extent.Small horiz height width a -> MosaicLower mirror height a
fromLowerPart fillDiag mirror (Array (Layout.Full order extent) a) =
   Array.unsafeCreate
      (Layout.Mosaic Layout.Packed mirror Layout.Lower order height) $ \lPtr ->
      withForeignPtr a $ \aPtr -> do
         packRect dstOrder m ld aPtr lPtr
         fillDiag dstOrder m lPtr
   where
      (height, width) = Extent.dimensions extent
      m = Shape.size height
      ld =
         case order of
            RowMajor -> Shape.size width
            ColumnMajor -> m
      dstOrder = flipOrder order

-- | A diagonal handler that does nothing. It fits the @fillDiag@ argument
-- of 'fromLowerPart'.
leaveDiagonal :: Order -> Int -> Ptr a -> IO ()
leaveDiagonal _ _ _ = pure ()

-- | A 'ContT' action producing @a@, tagged with a label. In
-- 'runLabelledLinear' the label is the name passed to 'withInfo'
-- (e.g. a LAPACK routine name).
data Labelled r label a = Labelled label (ContT r IO a)

-- | Tag a pure value with a label.
label :: label -> a -> Labelled r label a
label lab = Labelled lab . pure

-- | Wrap a pure value with the trivial label @()@.
noLabel :: a -> Labelled r () a
noLabel = Labelled () . pure

-- | Map over the result of the action; the label is kept.
instance Functor (Labelled r label) where
   fmap f (Labelled lab act) = Labelled lab (f <$> act)

-- | Run the setup action, then execute the 'IO' action it yields.
runUnlabelled :: Labelled r () (IO ()) -> ContT r IO ()
runUnlabelled (Labelled () m) = m >>= liftIO

-- | Run the setup action, then call the yielded info-taking routine through
-- 'withInfo', using @msg@ and the label as the routine name.
runLabelledLinear ::
   String -> Labelled r String (Ptr CInt -> IO ()) -> ContT r IO ()
runLabelledLinear msg (Labelled name m) = do
   routine <- m
   liftIO (withInfo msg name routine)

-- | Run the setup action, then call the yielded workspace-taking routine
-- through 'withAutoWorkspaceInfo', using @msg@ and the label as the
-- routine name.
runLabelledWorkspace ::
   (Class.Floating a) =>
   String ->
   Labelled r String (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) ->
   ContT r IO ()
runLabelledWorkspace msg (Labelled name m) = do
   routine <- m
   liftIO (withAutoWorkspaceInfo msg name routine)

-- | A pair of labelled actions: the first for packed storage, the second
-- for unpacked storage ('runPacking' picks one).
data Labelled2 r label a b = Labelled2 (Labelled r label a) (Labelled r label b)

-- | Map over the second (unpacked) component only.
instance Functor (Labelled2 r label a) where
   fmap f (Labelled2 packed unpacked) = Labelled2 packed (f <$> unpacked)

infixl 9 $*, $**

-- | Supply the same argument to both the packed and unpacked variants.
($*) :: Labelled2 r label (a -> f) (a -> g) -> a -> Labelled2 r label f g
Labelled2 packed unpacked $* x =
   Labelled2 ((\h -> h x) <$> packed) ((\h -> h x) <$> unpacked)

-- | Supply an argument to both variants. The unpacked variant additionally
-- receives a leading-dimension pointer built from the 'Int'
-- ('Call.leadingDim').
($**) ::
   Labelled2 r label (a -> f) (a -> Ptr CInt -> g) ->
   (a,Int) -> Labelled2 r label f g
Labelled2 packed (Labelled lab unpacked) $** (x, ld) =
   Labelled2
      ((\h -> h x) <$> packed)
      (Labelled lab ((\h -> h x) <$> unpacked <*> Call.leadingDim ld))

-- | Select the packed or unpacked variant according to the packing
-- singleton.
runPacking ::
   Layout.PackingSingleton pack ->
   Labelled2 r label func func -> Labelled r label func
runPacking pck (Labelled2 packed unpacked) =
   case pck of
      Layout.Unpacked -> unpacked
      Layout.Packed -> packed

-- | Select the variant for the given packing and run it without a label.
withPacking ::
   Layout.PackingSingleton pack ->
   Labelled2 r () (IO ()) (IO ()) -> ContT r IO ()
withPacking pck pair = runUnlabelled (runPacking pck pair)

-- | Select the variant for the given packing and run it through
-- 'runLabelledLinear', so a nonzero info code is reported with @msg@.
withPackingLinear ::
   (func ~ (Ptr CInt -> IO ())) =>
   String -> Layout.PackingSingleton pack ->
   Labelled2 r String func func -> ContT r IO ()
withPackingLinear msg pck pair = runLabelledLinear msg (runPacking pck pair)

-- | A matrix pointer with an 'Int'. When applied through 'FunctionArg', the
-- unpacked variant also receives the 'Int' as a leading-dimension argument.
data TriArg a = TriArg (Ptr a) Int

-- | Construct a 'TriArg' from a pointer and its leading dimension.
triArg :: Ptr a -> Int -> TriArg a
triArg ptr ld = TriArg ptr ld

-- | Combine a packed and an unpacked labelled action into a function of
-- type @f@, via 'apply'.
applyFuncPair ::
   (m ~ Labelled (FuncCont f) (FuncLabel f), FunctionPair f) =>
   m (FuncPacked f) -> m (FuncUnpacked f) -> f
applyFuncPair packed unpacked = apply (Labelled2 packed unpacked)

-- | Types @f@ that can be built from a 'Labelled2' pair of packed and
-- unpacked actions.
--
-- The associated types give the 'ContT' result type, the label type, and
-- the packed/unpacked payload types that 'apply' expects.
class FunctionPair f where
   type FuncCont f
   type FuncLabel f
   type FuncPacked f
   type FuncUnpacked f
   apply ::
      Labelled2 (FuncCont f) (FuncLabel f) (FuncPacked f) (FuncUnpacked f) -> f

-- | The result type carried by a 'Labelled' action.
-- NOTE(review): not used within the visible part of this module.
type family LabelResult a
type instance LabelResult (Labelled r label a) = a

-- | Base case: a 'Labelled2' is already the fully applied pair.
instance FunctionPair (Labelled2 cont lab packed unpacked) where
   type FuncCont (Labelled2 cont lab packed unpacked) = cont
   type FuncLabel (Labelled2 cont lab packed unpacked) = lab
   type FuncPacked (Labelled2 cont lab packed unpacked) = packed
   type FuncUnpacked (Labelled2 cont lab packed unpacked) = unpacked
   apply pair = pair

-- | Inductive case: one more argument, handled by 'applyArg' according to
-- its 'FunctionArg' instance.
instance (FunctionArg arg, FunctionPair rest) => FunctionPair (arg -> rest) where
   type FuncCont (arg -> rest) = FuncCont rest
   type FuncLabel (arg -> rest) = FuncLabel rest
   type FuncPacked (arg -> rest) = FuncArgPacked arg rest
   type FuncUnpacked (arg -> rest) = FuncArgUnpacked arg rest
   apply pair = applyArg pair

-- | Argument types that can be fed to a 'Labelled2' pair.
--
-- The associated types say what the packed and unpacked variants must
-- accept before continuing as @f@.
class FunctionArg a where
   type FuncArgPacked a f
   type FuncArgUnpacked a f
   applyArg ::
      (FunctionPair f) =>
      Labelled2 (FuncCont f)
         (FuncLabel f) (FuncArgPacked a f) (FuncArgUnpacked a f) ->
      a -> f

-- | A plain pointer is passed unchanged to both variants.
instance FunctionArg (Ptr elt) where
   type FuncArgPacked (Ptr elt) rest = Ptr elt -> FuncPacked rest
   type FuncArgUnpacked (Ptr elt) rest = Ptr elt -> FuncUnpacked rest
   applyArg pair ptr = apply (pair $* ptr)

-- | A 'TriArg' passes its pointer to both variants. The unpacked variant
-- also gets a leading-dimension pointer built from the stored 'Int'.
instance FunctionArg (TriArg elt) where
   type FuncArgPacked (TriArg elt) rest = Ptr elt -> FuncPacked rest
   type FuncArgUnpacked (TriArg elt) rest =
      Ptr elt -> Ptr CInt -> FuncUnpacked rest
   applyArg pair (TriArg ptr ld) = apply $ pair $** (ptr, ld)