{-# LANGUAGE MagicHash #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context
-- Copyright   : [2017..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context (

  withBLAS

) where

import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base

import Control.Monad.State
import Control.Concurrent.MVar
import Data.IntMap.Strict                                           ( IntMap )
import System.IO.Unsafe
import qualified Data.IntMap.Strict                                 as IM

import qualified Foreign.CUDA.Driver.Context                        as CUDA
import qualified Foreign.CUDA.BLAS                                  as BLAS

import GHC.Ptr
import GHC.Base
import Prelude                                                      hiding ( lookup )


-- | Execute an operation with a cuBLAS handle appropriate for the current
-- execution context.
--
-- Handles are cached in the global @handles@ table, keyed by the address of
-- the CUDA context (see @toKey@). The lookup-or-create step runs inside
-- 'modifyMVar', so initial creation of a handle is atomic with respect to
-- other callers; subsequently multiple threads may use the handle
-- concurrently.
--
-- <http://docs.nvidia.com/cuda/cublas/index.html#thread-safety2>
--
withBLAS :: (BLAS.Handle -> LLVM PTX b) -> LLVM PTX b
withBLAS k = do
  lc <- gets (deviceContext . ptxContext)
  h  <- liftIO $
          withLifetime lc    $ \ctx ->
          modifyMVar handles $ \im  ->
            let key = toKey ctx in
            case IM.lookup key im of
              -- handle does not exist yet; create it and add to the global
              -- state for reuse
              Nothing -> do
                hdl <- BLAS.create
                l   <- newLifetime hdl
                -- BLAS.setPointerMode hdl BLAS.Device
                BLAS.setAtomicsMode hdl BLAS.Allowed
                -- When the device context is finalised, drop the cached entry
                -- so a stale handle is never returned for a recycled address.
                addFinalizer lc $ modifyMVar_ handles (return . IM.delete key)
                -- Destroy the cuBLAS handle once its own lifetime ends.
                addFinalizer l  $ BLAS.destroy hdl
                return ( IM.insert key l im, l )

              -- return existing handle
              Just l -> return (im, l)
  --
  withLifetime' h k


-- | Derive the 'IntMap' key for a CUDA context from the address of the
-- pointer wrapped by the 'CUDA.Context' constructor.
--
-- Uses the unboxed primop 'addr2Int#' directly (hence @MagicHash@) to turn the
-- raw address into an 'Int'.
--
toKey :: CUDA.Context -> IM.Key
toKey (CUDA.Context (Ptr addr#)) = I# (addr2Int# addr#)

-- | Process-wide table of cuBLAS handles, indexed by the key produced by
-- @toKey@ for the owning CUDA context.
--
-- This is a top-level mutable variable created with 'unsafePerformIO'. The
-- @NOINLINE@ pragma is required: if the definition were inlined, each use
-- site could evaluate 'newMVar' separately and end up with its own, distinct
-- table instead of sharing one.
--
{-# NOINLINE handles #-}
handles :: MVar (IntMap (Lifetime BLAS.Handle))
handles = unsafePerformIO $ newMVar IM.empty