{-# LANGUAGE BangPatterns , RankNTypes, GADTs #-}


module Numerical.HBLAS.BLAS where
--(
--    dgemm
--    ,sgemm
--    ,cgemm
--    ,zgemm) 


import Numerical.HBLAS.UtilsFFI    
import Numerical.HBLAS.BLAS.FFI 
import Numerical.HBLAS.MatrixTypes
import Control.Monad.Primitive
import Data.Complex 
import qualified Data.Vector.Storable.Mutable as SM

--
flopsThreshold = 10000
gemmComplexity a b c = a * b * c  -- this will be wrong by some constant factor, albeit a small one


-- this covers the ~6 cases for checking the dimensions for GEMM quite nicely
isBadGemm tra trb  ax ay bx by cx cy = isBadGemmHelper (cds tra (ax,ay)) (cds trb (bx,by) )  (cx,cy)
    where 
    cds = coordSwapper 
    isBadGemmHelper !(ax,ay) !(bx,by) !(cx,cy) =  (minimum [ax, ay, bx, by, cx ,cy] <= 0)  
        || not (  cy ==  ay && cx == bx && ax == by)

coordSwapper :: Transpose -> (a,a)-> (a,a)
coordSwapper NoTranspose (a,b) = (a,b)
coordSwapper ConjNoTranspose (a,b) = (a,b) 
coordSwapper Transpose (a,b) = (b,a)
coordSwapper ConjTranspose (a,b) = (b,a)


encodeNiceOrder :: SOrientation x  -> CBLAS_ORDERT
encodeNiceOrder SRow= encodeOrder  BLASRowMajor
encodeNiceOrder SColumn= encodeOrder BLASColMajor


encodeFFITranpose :: Transpose -> CBLAS_TRANSPOSET
encodeFFITranpose  x=  encodeTranpose $ encodeNiceTranpose x 

encodeNiceTranpose :: Transpose -> BLAS_Transpose
encodeNiceTranpose x = case x of 
        NoTranspose -> BlasNoTranspose
        Transpose -> BlasTranspose
        ConjTranspose -> BlasConjTranspose
        ConjNoTranspose -> BlasConjNoTranspose

--data BLAS_Tranpose = BlasNoTranspose | BlasTranpose | BlasConjTranspose | BlasConjNoTranpose 
--data Tranpose = NoTranpose | Tranpose | ConjTranpose | ConjNoTranpose


type GemmFun el orient s m = Transpose ->Transpose ->  el -> el  -> MutDenseMatrix s orient el
  ->   MutDenseMatrix s orient el  ->  MutDenseMatrix s orient el -> m ()


{-
A key design goal of this ffi is to provide *safe* throughput guarantees 
for a concurrent application built on top of these apis, while evading
any overheads for providing such safety. Accordingly, on inputs sizes

-}


---- |  Matrix mult for general dense matrices
--type GemmFunFFI scale el = CBLAS_ORDERT ->   CBLAS_TRANSPOSET -> CBLAS_TRANSPOSET->
        --CInt -> CInt -> CInt -> {- scal A * B -} scale  -> {- Matrix A-} Ptr el  -> CInt -> {- B -}  Ptr el -> CInt-> 
            --scale -> {- C -}  Ptr el -> CInt -> IO ()
--type GemmFun = MutDenseMatrix or el ->  MutDenseMatrix or el ->   MutDenseMatrix or el -> m ()

{-# NOINLINE gemmAbstraction #-}
gemmAbstraction:: (SM.Storable el, PrimMonad m) =>  String -> 
    GemmFunFFI scale el -> GemmFunFFI scale el -> (el -> (scale -> m ())->m ()) -> forall orient . GemmFun el orient (PrimState m) m 
gemmAbstraction gemmName gemmSafeFFI gemmUnsafeFFI constHandler = go 
  where 
    shouldCallFast :: Int -> Int -> Int -> Bool                         
    shouldCallFast cy cx ax = flopsThreshold >= gemmComplexity cy cx ax

    go  tra trb  alpha beta 
        (MutableDenseMatrix ornta ax ay astride abuff) 
        (MutableDenseMatrix _ bx by bstride bbuff) 
        (MutableDenseMatrix _ cx cy cstride cbuff) 
            |  isBadGemm tra trb  ax ay bx by cx cy = error $! "bad dimension args to GEMM: ax ay bx by cx cy: " ++ show [ax, ay, bx, by, cx ,cy]
            | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff = 
                    error $ "the read and write inputs for: " ++ gemmName ++ " overlap. This is a programmer error. Please fix." 
            | otherwise  = 
                {-  FIXME : Add Sharing check that also errors out for now-}
                unsafeWithPrim abuff $ \ap -> 
                unsafeWithPrim bbuff $ \bp ->  
                unsafeWithPrim cbuff $ \cp  -> 
                constHandler alpha $  \alphaPtr ->   
                constHandler beta $ \betaPtr ->  
                    do  (ax,ay) <- return $ coordSwapper tra (ax,ay)
                        --- dont need to swap b, info is in a and c
                        --- c doesn't get implicitly transposed
                        blasOrder <- return $ encodeNiceOrder ornta -- all three are the same orientation
                        rawTra <- return $  encodeFFITranpose tra 
                        rawTrb <- return $   encodeFFITranpose trb
                                 -- example of why i want to switch to singletones
                        unsafePrimToPrim $!  (if shouldCallFast cy cx ax then gemmUnsafeFFI  else gemmSafeFFI ) 
                            blasOrder rawTra rawTrb (fromIntegral cy) (fromIntegral cx) (fromIntegral ax) 
                                alphaPtr ap  (fromIntegral astride) bp (fromIntegral bstride) betaPtr  cp (fromIntegral cstride)


{-pureGemm :: PrimMonad m=>
(Transpose ->Transpose ->  el -> el  -> MutDenseMatrix (PrimState m) orient el
      ->   MutDenseMatrix (PrimState m) orient el  ->  
                                    MutDenseMatrix (PrimState m) orient el -> m ())->
  Transpose ->Transpose ->  el -> el  -> DenseMatrix  orient el
  ->   DenseMatrix  orient el  ->  DenseMatrix  orient el  -}

sgemm :: PrimMonad m=> 
     Transpose ->Transpose ->  Float -> Float  -> MutDenseMatrix (PrimState m) orient Float
  ->   MutDenseMatrix (PrimState m) orient Float  ->  MutDenseMatrix (PrimState m) orient Float -> m ()
sgemm =  gemmAbstraction "sgemm" cblas_sgemm_unsafe cblas_sgemm_safe (\x f -> f x )                                 
                        

dgemm :: PrimMonad m=> 
     Transpose ->Transpose ->  Double -> Double -> MutDenseMatrix (PrimState m) orient Double
  ->   MutDenseMatrix (PrimState m) orient Double   ->  MutDenseMatrix (PrimState m) orient Double -> m ()
dgemm = gemmAbstraction "dgemm" cblas_dgemm_unsafe cblas_dgemm_safe (\x f -> f x )    
 

cgemm :: PrimMonad m=>  Transpose ->Transpose ->  (Complex Float) -> (Complex Float)  -> 
        MutDenseMatrix (PrimState m) orient (Complex Float)  ->   
        MutDenseMatrix (PrimState m) orient (Complex Float)  ->  
        MutDenseMatrix (PrimState m) orient (Complex Float) -> m ()
cgemm = gemmAbstraction "cgemm" cblas_cgemm_unsafe cblas_cgemm_safe withRStorable_                                

zgemm :: PrimMonad m=>  Transpose ->Transpose ->  (Complex Double) -> (Complex Double )  -> 
        MutDenseMatrix (PrimState m) orient (Complex Double )  ->   
        MutDenseMatrix (PrimState m) orient (Complex Double)  ->  
        MutDenseMatrix (PrimState m) orient (Complex Double) -> m ()
zgemm = gemmAbstraction "zgemm" cblas_zgemm_unsafe cblas_zgemm_safe withRStorable_