{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
module Data.Array.Accelerate.LLVM.PTX.Execute.Stream (
Reservoir, new,
Stream, create, destroy, streaming,
) where
import Data.Array.Accelerate.Lifetime
import qualified Data.Array.Accelerate.Array.Remote.LRU as Remote
import Data.Array.Accelerate.LLVM.PTX.Array.Remote ( )
import Data.Array.Accelerate.LLVM.PTX.Execute.Event ( Event )
import Data.Array.Accelerate.LLVM.PTX.Target ( PTX(..) )
import Data.Array.Accelerate.LLVM.State
import qualified Data.Array.Accelerate.LLVM.PTX.Debug as Debug
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event as Event
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream.Reservoir as RSV
import Foreign.CUDA.Driver.Error
import qualified Foreign.CUDA.Driver.Stream as Stream
import Control.Exception
import Control.Monad.State
-- | A CUDA driver stream ('Stream.Stream') managed by a 'Lifetime'.
-- 'create' attaches a finalizer to it and 'destroy' runs 'finalize' on it.
type Stream = Lifetime Stream.Stream
{-# INLINEABLE streaming #-}
-- | Run @action@ on a freshly acquired 'Stream' (from 'create'). Then obtain an
-- 'Event' on that stream via 'Event.waypoint', and hand both the event and the
-- action's result to @after@.
--
-- Once @after@ returns, the stream is released with 'destroy' and the event
-- with 'Event.destroy'. The result of @after@ is returned.
--
-- NOTE(review): if @action@ or @after@ throws, neither the stream nor the
-- event is released here. The stream is then presumably only reclaimed when
-- its 'Lifetime' is collected. Consider an exception-safe variant if 'LLVM'
-- gains a bracket-style combinator.
streaming
    :: (Stream -> LLVM PTX a)           -- ^ work to enqueue on the stream
    -> (Event -> a -> LLVM PTX b)       -- ^ continuation given the waypoint event
    -> LLVM PTX b
streaming !action !after = do
  -- The previous unused @PTX{..} <- gets llvmTarget@ binding was removed:
  -- none of the target's fields are needed here.
  stream <- create
  first  <- action stream
  end    <- Event.waypoint stream
  final  <- after end first
  liftIO $ do
    destroy stream
    Event.destroy end
  return final
{-# INLINEABLE create #-}
-- | Acquire a raw CUDA stream (via 'create'') and wrap it in a 'Lifetime'.
--
-- A finalizer is attached that calls 'RSV.insert' on the target's stream
-- reservoir with the raw stream. It presumably returns the stream there
-- for reuse rather than freeing it.
create :: LLVM PTX Stream
create = do
  PTX { ptxStreamReservoir = reservoir } <- gets llvmTarget
  raw <- create'
  liftIO $ do
    lifetime <- newLifetime raw
    addFinalizer lifetime (RSV.insert reservoir raw)
    return lifetime
-- | Obtain a raw CUDA stream, trying progressively more expensive strategies:
--
--   1. take one from the target's stream reservoir ('RSV.malloc');
--   2. create a new one with 'Stream.create', treating an
--      @ExitCode OutOfMemory@ error as a soft failure;
--   3. call 'Remote.reclaim' on the target's memory table, then try
--      'Stream.create' once more.
--
-- If all three fail, @ExitCode OutOfMemory@ is thrown. Any other CUDA
-- exception raised by 'Stream.create' propagates unchanged.
create' :: LLVM PTX Stream.Stream
create' = do
  PTX{..} <- gets llvmTarget
  ms      <- attempt "create/reservoir" (liftIO $ RSV.malloc ptxStreamReservoir)
             `orElse`
             attempt "create/new"       (liftIO . catchOOM $ Stream.create [])
             `orElse` do
               Remote.reclaim ptxMemoryTable
               liftIO $ do
                 message "create/new: failed (purging)"
                 catchOOM $ Stream.create []
  case ms of
    Just s  -> return s
    Nothing -> liftIO $ do
      message "create/new: failed (non-recoverable)"
      throwIO (ExitCode OutOfMemory)
  where
    -- Run a CUDA action, mapping an out-of-memory error to 'Nothing' and
    -- rethrowing every other CUDA exception.
    catchOOM :: IO a -> IO (Maybe a)
    catchOOM it =
      (Just <$> it) `catch` \e -> case e of
        ExitCode OutOfMemory -> return Nothing
        _                    -> throwIO e

    -- Run an action and log @msg@ only when it produced a result.
    attempt :: MonadIO m => String -> m (Maybe a) -> m (Maybe a)
    attempt msg ea = do
      ma <- ea
      case ma of
        Nothing -> return Nothing
        Just a  -> do
          liftIO (message msg)
          return (Just a)

    -- Run the fallback action only if the first one produced 'Nothing'.
    -- Only 'Monad' is needed; the former 'MonadIO' constraint was redundant.
    orElse :: Monad m => m (Maybe a) -> m (Maybe a) -> m (Maybe a)
    orElse ea eb = ea >>= maybe eb (return . Just)
{-# INLINEABLE destroy #-}
-- | Release a 'Stream' by calling 'finalize' on its 'Lifetime'. For streams
-- produced by 'create', the attached finalizer is 'RSV.insert' on the
-- target's stream reservoir.
destroy :: Stream -> IO ()
destroy stream = finalize stream
{-# INLINE trace #-}
-- | Emit @msg@, prefixed with @"stream: "@, through 'Debug.traceIO' under the
-- 'Debug.dump_sched' flag. Then run @next@ and return its result.
trace :: String -> IO a -> IO a
trace msg next =
  Debug.traceIO Debug.dump_sched ("stream: " ++ msg) >> next
{-# INLINE message #-}
-- | Emit a scheduling trace line (see 'trace') and do nothing else.
message :: String -> IO ()
message s = trace s (pure ())