{-# LANGUAGE MagicHash #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context (
withBLAS
) where
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Control.Monad.State
import Control.Concurrent.MVar
import Data.IntMap.Strict ( IntMap )
import System.IO.Unsafe
import qualified Data.IntMap.Strict as IM
import qualified Foreign.CUDA.Driver.Context as CUDA
import qualified Foreign.CUDA.BLAS as BLAS
import GHC.Ptr
import GHC.Base
import Prelude hiding ( lookup )
-- | Execute an operation with a cuBLAS handle belonging to the CUDA context
-- of the current PTX execution target.
--
-- Handles are cached in the global 'handles' table, keyed by 'toKey' of the
-- underlying CUDA context. The lookup and, if needed, the creation of a new
-- handle both happen inside 'modifyMVar' on that table, so creation of a
-- handle for a given context is serialised across threads. Later calls reuse
-- the cached handle.
--
-- Two finalizers tie the cache entry to resource lifetimes:
--
--   * when the handle's own 'Lifetime' is finalised, the cuBLAS handle is
--     destroyed with 'BLAS.destroy';
--   * when the device context's 'Lifetime' is finalised, the entry for that
--     context is removed from the table.
--
-- NOTE(review): the key is the raw address of the context. If a context is
-- destroyed and a new one is allocated at the same address before the
-- removal finalizer has run, the stale handle could be returned. Confirm
-- whether this can happen with the PTX backend's context management.
withBLAS :: (BLAS.Handle -> LLVM PTX b) -> LLVM PTX b
withBLAS k = do
  lc <- gets (deviceContext . ptxContext)
  lh <- liftIO $
    withLifetime lc $ \ctx ->
      modifyMVar handles $ \im ->
        let key = toKey ctx
        in case IM.lookup key im of
             -- No handle exists yet for this context: create one and add
             -- it to the global table for reuse.
             Nothing -> do
               h <- BLAS.create
               l <- newLifetime h
               -- Register destruction immediately, so the handle is still
               -- released (once 'l' is collected) if configuring it below
               -- throws and 'modifyMVar' rolls the table back.
               addFinalizer l $ BLAS.destroy h
               BLAS.setAtomicsMode h BLAS.Allowed
               -- Once the owning context goes away, drop its cache entry.
               addFinalizer lc $
                 modifyMVar_ handles (return . IM.delete key)
               return (IM.insert key l im, l)
             -- Reuse the existing handle.
             Just l -> return (im, l)
  withLifetime' lh k
-- | Derive an 'IntMap' key from a CUDA context by reinterpreting the address
-- held in its underlying pointer as an 'Int'.
toKey :: CUDA.Context -> IM.Key
toKey (CUDA.Context (Ptr addr#)) = I# (addr2Int# addr#)
-- | Process-wide cache of cuBLAS handles, keyed by 'toKey' of the owning
-- CUDA context. Each handle is wrapped in a 'Lifetime' so that it can carry
-- a finalizer.
--
-- This is a top-level mutable global created with 'unsafePerformIO'. The
-- NOINLINE pragma is required: without it GHC may inline the definition and
-- run 'newMVar' more than once, giving different call sites different
-- tables.
{-# NOINLINE handles #-}
handles :: MVar (IntMap (Lifetime BLAS.Handle))
handles = unsafePerformIO (newMVar IM.empty)