{-# LANGUAGE TypeFamilies #-}
module Numeric.BLAS.Private where

import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.BLAS.Scalar (RealOf, zero, one, minusOne, isZero)

import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.C.String as CStr
import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CChar, CInt)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, peek, pokeElemOff, peekElemOff)

import Control.Monad.Trans.Cont (evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (liftA2)

import qualified Data.Array.Comfort.Shape as Shape

import qualified Data.Complex as Complex
import Data.Complex (Complex((:+)))

import Prelude hiding (sum)


-- | Integer shape whose indices start at zero.
type ShapeInt = Shape.ZeroBased Int

-- | Wrap an 'Int' size into a 'ShapeInt'.
shapeInt :: Int -> ShapeInt
shapeInt size = Shape.ZeroBased size


-- | Reinterpret a pointer to elements of type @a@ as a pointer to the
-- associated real type, e.g. to address a complex array as interleaved reals.
realPtr :: Ptr a -> Ptr (RealOf a)
realPtr ptr = castPtr ptr


-- | Infinite list of pointers @ptr, ptr+k, ptr+2k, ...@,
-- where the step @k@ is counted in elements of type @a@ (it may be negative).
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k = iterate step
   where step p = advancePtr p k


-- | Write the value @a@ into @n@ consecutive elements starting at @dstPtr@.
-- This is a BLAS @copy@ whose source increment is 0, so the single source
-- value is repeated into every destination slot.
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr =
   evalContT $ do
      countPtr <- Call.cint n
      valuePtr <- Call.number a
      srcStridePtr <- Call.cint 0
      dstStridePtr <- Call.cint 1
      liftIO (Blas.copy countPtr valuePtr srcStridePtr dstPtr dstStridePtr)


-- | Copy the vector @x@ into @y@ with BLAS @copy@, then conjugate @y@ in
-- place with 'lacgv'. For real element types, 'lacgv' does nothing.
copyConjugate ::
   (Class.Floating a) =>
   Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyConjugate nPtr xPtr incxPtr yPtr incyPtr =
   Blas.copy nPtr xPtr incxPtr yPtr incyPtr >>
   lacgv nPtr yPtr incyPtr



-- | Wrapper that lets 'Class.switchFloating' choose a summation routine
-- per element type.
newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

-- | Sum of the @n@ elements at the given pointer with the given stride.
-- Real types use 'sumReal', complex types use 'sumComplex'.
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum =
   runSum
      (Class.switchFloating
         (Sum sumReal) (Sum sumReal)
         (Sum sumComplex) (Sum sumComplex))

-- | Sum of the @n@ real elements at @xPtr@ with stride @incx@.
-- Computed as the BLAS dot product with a single @1@ that is repeated
-- by using increment 0.
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx = evalContT $ do
   countPtr <- Call.cint n
   strideXPtr <- Call.cint incx
   unitPtr <- Call.real one
   strideUnitPtr <- Call.cint 0
   liftIO (BlasReal.dot countPtr xPtr strideXPtr unitPtr strideUnitPtr)

{- |
Sum of the @n@ complex elements at the given pointer with the given stride.
The stride is counted in complex elements.

'sumComplex' views the vector as interleaved reals and sums the real parts
and the imaginary parts separately. Each is a BLAS dot product with a
single @1@ that is repeated by using increment 0.

'sumComplexAlt' views the vector as a column-major real 2×n matrix
(row 0 = real parts, row 1 = imaginary parts, leading dimension @2*inca@)
and multiplies it by a vector of @n@ ones using one @gemv@ call.
gemv requires a leading dimension of at least 2 (the row count), so a
stride below 1 cannot be expressed that way. In that case it falls back
to 'sumComplex'.
-}
sumComplex, sumComplexAlt ::
   Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   evalContT $ do
      nPtr <- Call.cint n
      let sxPtr = realPtr xPtr
      -- stride in units of reals: each complex element holds two reals
      incxPtr <- Call.cint (2*incx)
      yPtr <- Call.real one
      incyPtr <- Call.cint 0
      liftIO $
         liftA2 (Complex.:+)
            (BlasReal.dot nPtr sxPtr incxPtr yPtr incyPtr)
            -- imaginary parts start one real behind the real parts
            (BlasReal.dot nPtr (advancePtr sxPtr 1) incxPtr yPtr incyPtr)

sumComplexAlt n aPtr inca =
   if inca < 1
      -- 2*inca would be an invalid leading dimension for a 2-row gemv
      then sumComplex n aPtr inca
      else evalContT $ do
         transPtr <- Call.char 'N'
         mPtr <- Call.cint 2
         nPtr <- Call.cint n
         onePtr <- Call.number one
         inc0Ptr <- Call.cint 0
         let saPtr = realPtr aPtr
         ldaPtr <- Call.leadingDim (2*inca)
         sxPtr <- Call.allocaArray n
         incxPtr <- Call.cint 1
         betaPtr <- Call.number zero
         yPtr <- Call.alloca
         let syPtr = realPtr yPtr
         incyPtr <- Call.cint 1
         liftIO $ do
            -- x := (1, 1, ..., 1)
            Blas.copy nPtr onePtr inc0Ptr sxPtr incxPtr
            -- y := A*x = (sum of real parts, sum of imaginary parts)
            gemv
               transPtr mPtr nPtr onePtr saPtr ldaPtr
               sxPtr incxPtr betaPtr syPtr incyPtr
            peek yPtr


-- | Element-wise product @y := a .* x@, with @a@ conjugated for
-- 'Conjugated'. This is 'mulAdd' with @beta = 0@.
mul ::
   (Class.Floating a) =>
   Conjugation -> Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mul conj n aPtr inca xPtr incx =
   mulAdd conj n aPtr inca xPtr incx zero

{- |
Element-wise @y := a .* x + beta * y@, with @a@ conjugated for 'Conjugated'.

The vector @a@ (stride @inca@) is treated as the diagonal of an n×n band
matrix with no sub- or super-diagonals (@kl = ku = 0@, leading dimension
@inca@). BLAS @gbmv@ is then called with @alpha = 1@. Transposing a
diagonal matrix leaves it unchanged, so trans @'C'@ only conjugates it.
-}
mulAdd ::
   (Class.Floating a) =>
   Conjugation ->
   Int -> Ptr a -> Int -> Ptr a -> Int -> a -> Ptr a -> Int -> IO ()
mulAdd conj n aPtr inca xPtr incx beta yPtr incy =
   evalContT $ do
      transPtr <- Call.char transChar
      sizePtr <- Call.cint n
      subDiagsPtr <- Call.cint 0
      superDiagsPtr <- Call.cint 0
      alphaPtr <- Call.number one
      ldaPtr <- Call.leadingDim inca
      incxPtr <- Call.cint incx
      betaPtr <- Call.number beta
      incyPtr <- Call.cint incy
      liftIO $
         Blas.gbmv transPtr
            sizePtr sizePtr subDiagsPtr superDiagsPtr alphaPtr aPtr ldaPtr
            xPtr incxPtr betaPtr yPtr incyPtr
   where
      transChar =
         case conj of
            NonConjugated -> 'N'
            Conjugated -> 'C'

{- |
Product of the @n@ elements at @aPtr@ with stride @inca@.
Returns 'one' for @n < 1@.

Uses the foldBalanced trick. Neighbouring elements are multiplied pairwise
into a scratch buffer, and 'productLoop' keeps halving that sequence until
one value remains.
-}
product :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
product n aPtr inca
   | n < 1 = return one
   | n == 1 = peek aPtr
   | otherwise =
      let half = div n 2
          remaining = n - half
      in  ForeignArray.alloca (2*remaining-1) $ \xPtr -> do
             -- x[i] := a[2i] * a[2i+1] for every complete pair
             mulPairs half aPtr inca xPtr 1
             -- with an odd count the last element stays unpaired
             when (odd n) $
                pokeElemOff xPtr half =<< peekElemOff aPtr ((n-1)*inca)
             productLoop remaining xPtr

{- |
Reduce the @n@ values at @xPtr@ to their product by repeatedly multiplying
neighbouring pairs. Returns 'one' for @n < 1@; without that guard the
halving would never reach 1, because @div 0 2 == 0@.

If 'mul' were based on a scalar loop,
we would not need to cut the vector into chunks.

The invariant is:
When calling @productLoop n xPtr@,
starting from @xPtr@ there is storage allocated for @2*n-1@ elements.
Each round writes the @div n 2@ pair products directly behind the @n@
current values. The next round starts at the last unpaired value (if any),
which is immediately followed by those products.
-}
productLoop :: (Class.Floating a) => Int -> Ptr a -> IO a
productLoop n xPtr
   | n < 1 = return one
   | n == 1 = peek xPtr
   | otherwise = do
      let half = div n 2
      mulPairs half xPtr 1 (advancePtr xPtr n) 1
      productLoop (n-half) (advancePtr xPtr (2*half))

-- | Compute @x[i*incx] := a[2*i*inca] * a[(2*i+1)*inca]@ for @i = 0 .. n-1@.
-- In other words, multiply each element at an even position of the strided
-- vector @a@ by its right neighbour.
mulPairs ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulPairs n aPtr inca =
   mul NonConjugated n leftPtr pairStride rightPtr pairStride
   where
      pairStride = 2*inca
      leftPtr = aPtr
      rightPtr = advancePtr aPtr inca


-- | Wrapper that lets 'Class.switchFloating' choose a conjugation routine
-- per element type.
newtype LACGV a = LACGV {getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()}

-- | Conjugate a vector in place (length, data, increment).
-- For real types this does nothing; for complex types it delegates to 'clacgv'.
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv =
   getLACGV
      (Class.switchFloating
         (LACGV (\_ _ _ -> return ()))
         (LACGV (\_ _ _ -> return ()))
         (LACGV clacgv)
         (LACGV clacgv))

{- |
Conjugate the complex elements at @xPtr@ in place.
It negates their imaginary parts with BLAS @scal@, using the stride
2*increment in units of reals.

BLAS @scal@ returns without doing anything for a non-positive increment.
A BLAS vector with a negative increment still occupies the memory starting
at @xPtr@, and conjugation does not depend on traversal order. So the
absolute value of the increment is used: a negative increment conjugates
the same elements as the corresponding positive one.
With increment 0, @scal@ still does nothing.
-}
clacgv :: Class.Real a => Ptr CInt -> Ptr (Complex a) -> Ptr CInt -> IO ()
clacgv nPtr xPtr incxPtr =
   Marshal.with minusOne $ \saPtr -> do
      incx <- peek incxPtr
      -- imaginary parts start one real behind the first real part
      Marshal.with (2 * abs incx) $ \incyPtr ->
         BlasReal.scal nPtr saPtr (advancePtr (realPtr xPtr) 1) incyPtr


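{-
Work around an inconsistency of BLAS:
If the input vector is empty
(a zero-column matrix without transposition, a zero-row matrix otherwise),
BLAS's gemv and gbmv return without touching the target vector,
so it is not even initialized when beta is zero.
In contrast, these work-arounds fill the target vector with zeros in that case.
-}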
{-# INLINE gemv #-}
-- | BLAS @gemv@ preceded by 'initializeMV'. If the input vector is empty
-- and @beta@ is zero, the target vector is zero-filled instead of being
-- left untouched.
gemv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
gemv transPtr mPtr nPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr >>
   Blas.gemv transPtr mPtr nPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{-# INLINE gbmv #-}
-- | BLAS @gbmv@ preceded by 'initializeMV'. If the input vector is empty
-- and @beta@ is zero, the target vector is zero-filled instead of being
-- left untouched.
gbmv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gbmv transPtr mPtr nPtr klPtr kuPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr >>
   Blas.gbmv transPtr mPtr nPtr klPtr kuPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{- |
Helper for the 'gemv' and 'gbmv' wrappers.

It first works out which dimension is the length of the target vector @y@:
@m@ without transposition, @n@ otherwise. The other dimension is the length
of the input vector. If the input vector is empty and @beta@ is zero, @y@
is filled with zeros by broadcasting @beta@ with a zero increment.

The transposition flag is compared case-insensitively, as BLAS does.
Otherwise a lowercase @'n'@ would be taken for a transposition, @m@ and @n@
would be swapped, and zeros could be written into @y@ using the wrong length.
-}
initializeMV ::
   Class.Floating a =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr = do
   trans <- peek transPtr
   let noTrans = trans `elem` map CStr.castCharToCChar "Nn"
       (mtPtr,ntPtr) = if noTrans then (mPtr,nPtr) else (nPtr,mPtr)
   n <- peek ntPtr
   beta <- peek betaPtr
   when (n == 0 && isZero beta) $
      Marshal.with 0 $ \incbPtr ->
      Blas.copy mtPtr betaPtr incbPtr yPtr incyPtr


{-
ToDo:

type ComplexShape =
         Shape.NestedTuple Shape.TupleAccessor (Complex Shape.Element)

This would allow using Complex.realPart as the accessor,
but it requires a GHC newer than 7.6.3 (or thereabouts),
where realPart no longer has a RealFloat constraint.
-}
-- | Static tuple shape with one element for the real part and one for the
-- imaginary part of a complex number.
type ComplexShape = Shape.NestedTuple Shape.TupleIndex (Complex Shape.Element)

-- | Element indices of the real and the imaginary part within 'ComplexShape'.
ixReal, ixImaginary :: Shape.ElementIndex (Complex Shape.Element)
(ixReal, ixImaginary) =
   case Shape.indexTupleFromShape (Shape.static :: ComplexShape) of
      re :+ im -> (re, im)