{-# LANGUAGE RecordWildCards #-}
module Data.Array.Accelerate.LLVM.PTX.Context (
Context(..),
new, raw, withContext,
) where
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX.Analysis.Device
import qualified Data.Array.Accelerate.LLVM.PTX.Debug as Debug
import qualified Foreign.CUDA.Analysis as CUDA
import qualified Foreign.CUDA.Driver as CUDA
import qualified Foreign.CUDA.Driver.Device as CUDA
import Control.Exception
import Control.Monad
import Text.PrettyPrint
-- | A CUDA context together with the properties of the device it belongs to.
data Context = Context
  { deviceProperties :: {-# UNPACK #-} !CUDA.DeviceProperties
    -- ^ properties of the device this context was created for
  , deviceContext    :: {-# UNPACK #-} !(Lifetime CUDA.Context)
    -- ^ the raw driver context wrapped in a 'Lifetime'; when built via
    -- 'raw', a finalizer that destroys the context is attached to it
  }
-- | Contexts are compared only by their 'deviceContext' lifetimes; the
-- device properties play no part in equality.
instance Eq Context where
  Context _ lhs == Context _ rhs = lhs == rhs
-- | Create a new context on the given device with the given flags, and wrap
-- it with 'raw'.
--
-- 'CUDA.create' leaves the freshly created context current (pushed) on the
-- calling thread. It is popped again before returning, so the caller's
-- context stack is unchanged.
--
-- The pop is performed with 'finally' so that it also happens when 'raw'
-- throws (e.g. from 'CUDA.setCache'). Otherwise the thread would be left
-- with a dangling context on its stack.
--
-- NOTE(review): on that error path the context is not destroyed here. If
-- 'raw' had already attached its finalizer, the finalizer remains
-- responsible for destroying it. Destroying it here as well would risk a
-- double destroy.
new :: CUDA.Device
    -> CUDA.DeviceProperties
    -> [CUDA.ContextFlag]
    -> IO Context
new dev prp flags = do
  ctx <- CUDA.create dev flags
  raw dev prp ctx `finally` void CUDA.pop
-- | Wrap an existing driver context as a 'Context'.
--
-- The context is placed in a 'Lifetime' whose finalizer logs a message and
-- calls 'CUDA.destroy' on it. On devices of compute capability 2.0 or
-- later, the L1-preferring cache configuration is requested via
-- 'CUDA.setCache'.
--
-- NOTE(review): 'CUDA.setCache' takes no context argument, so it
-- presumably applies to whichever context is current on the calling thread.
-- Callers such as 'new' arrange for that to be @ctx@; confirm for other
-- callers.
--
-- Finally, a description of the device is written to the verbose debug
-- trace.
raw :: CUDA.Device
    -> CUDA.DeviceProperties
    -> CUDA.Context
    -> IO Context
raw dev prp ctx = do
  lifetime <- newLifetime ctx
  addFinalizer lifetime finalise
  when preferL1 $
    CUDA.setCache CUDA.PreferL1
  Debug.traceIO Debug.verbose (deviceInfo dev prp)
  return $! Context prp lifetime
  where
    -- Run when the lifetime is finalised: log, then release the context.
    finalise = do
      message ("finalise context " ++ showContext ctx)
      CUDA.destroy ctx

    -- Only request the cache preference on compute capability >= 2.0.
    preferL1 = CUDA.computeCapability prp >= CUDA.Compute 2 0
{-# INLINE withContext #-}
-- | Run an action with the given context made current.
--
-- Inside 'withLifetime' on the context's lifetime, the raw context is pushed
-- before the action and popped afterwards. Because 'bracket_' is used, the
-- pop also happens if the action throws.
withContext :: Context -> IO a -> IO a
withContext context action =
  withLifetime (deviceContext context) $ \cudaCtx ->
    bracket_ (push cudaCtx) pop action
{-# INLINE push #-}
-- | Log the context being pushed, then make it current with 'CUDA.push'.
push :: CUDA.Context -> IO ()
push ctx = message ("push context: " ++ showContext ctx) >> CUDA.push ctx
{-# INLINE pop #-}
-- | Pop the current context with 'CUDA.pop', then log which context was
-- removed.
pop :: IO ()
pop = CUDA.pop >>= \popped -> message ("pop context: " ++ showContext popped)
-- | Render a two-line, human-readable summary of a device for the debug
-- trace.
--
-- The summary shows the device number, name and compute capability,
-- followed by the multiprocessor count, clock rate, total core count and
-- total global memory.
deviceInfo :: CUDA.Device -> CUDA.DeviceProperties -> String
deviceInfo dev prp = render $ reset <>
  devID <> colon <+> vcat [ name <+> parens compute
                          , processors <+> at <+> text clock <+> parens cores <> comma <+> memory
                          ]
  where
    name       = text (CUDA.deviceName prp)
    compute    = text "compute capability" <+> text (show $ CUDA.computeCapability prp)
    devID      = text "Device" <+> int (fromIntegral $ CUDA.useDevice dev)
    processors = int (CUDA.multiProcessorCount prp) <+> text "multiprocessors"
    -- Total cores = multiprocessors * cores per multiprocessor for this device.
    cores      = int (CUDA.multiProcessorCount prp * coresPerMultiProcessor prp) <+> text "cores"
    memory     = text mem <+> text "global memory"
    -- 'clockRate' is multiplied by 1000 before formatting in Hz.
    -- NOTE(review): this assumes the driver reports it in kHz; confirm.
    clock      = Debug.showFFloatSIBase (Just 2) 1000 (fromIntegral $ CUDA.clockRate prp * 1000 :: Double) "Hz"
    -- Memory is formatted with a base of 1024.
    mem        = Debug.showFFloatSIBase (Just 0) 1024 (fromIntegral $ CUDA.totalGlobalMem prp :: Double) "B"
    at         = char '@'
    -- A zero-width carriage return placed at the very start of the output,
    -- so it does not affect the pretty-printer's layout.
    reset      = zeroWidthText "\r"
{-# INLINE trace #-}
-- | Write @msg@, prefixed with @"gc: "@, to the 'Debug.dump_gc' trace
-- channel, then run @next@.
trace :: String -> IO a -> IO a
trace msg next = Debug.traceIO Debug.dump_gc ("gc: " ++ msg) >> next
{-# INLINE message #-}
-- | Emit a message on the GC trace channel (via 'trace'), producing @()@.
message :: String -> IO ()
message s = trace s (return ())
{-# INLINE showContext #-}
-- | Render a context for log messages by 'show'ing the value wrapped by the
-- 'CUDA.Context' constructor.
showContext :: CUDA.Context -> String
showContext ctx =
  case ctx of
    CUDA.Context handle -> show handle