{-# LANGUAGE MagicHash       #-}
{-# LANGUAGE RecordWildCards #-}
-- |
-- Module      : Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans
-- Copyright   : [2017..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Math.FFT.LLVM.PTX.Plans (

  Plans,
  createPlan,
  withPlan,

) where

import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX
import Data.Array.Accelerate.LLVM.PTX.Foreign

import Data.Array.Accelerate.Math.FFT.LLVM.PTX.Base

import Control.Concurrent.MVar
import Control.Monad.State
import Data.HashMap.Strict
import qualified Data.HashMap.Strict                                as Map

import qualified Foreign.CUDA.Driver.Context                        as CUDA
import qualified Foreign.CUDA.FFT                                   as FFT

import GHC.Ptr
import GHC.Base
import Prelude                                                      hiding ( lookup )


-- | A cache of cuFFT handles. Entries are keyed on a pair of an 'Int'
-- derived from the CUDA context and the user-supplied hash of the plan
-- configuration @a@.
--
data Plans a = Plans
  { plans  :: {-# UNPACK #-} !(MVar (HashMap (Int, Int) (Lifetime FFT.Handle)))
    -- ^ cached handles, keyed on (context key, configuration hash)
  , create :: a -> IO FFT.Handle
    -- ^ build a fresh handle for the given configuration
  , hash   :: a -> Int
    -- ^ hash a configuration into the second component of the cache key
  }


-- | Create a new, empty plan cache.
--
-- The first argument builds a fresh cuFFT handle for a configuration. The
-- second maps a configuration to the 'Int' that forms part of the cache key.
--
{-# INLINE createPlan #-}
createPlan :: (a -> IO FFT.Handle) -> (a -> Int) -> IO (Plans a)
createPlan via mix = do
  cache <- newMVar Map.empty
  return Plans { plans = cache, create = via, hash = mix }


-- | Execute an operation with a cuFFT handle appropriate for the current
-- execution context.
--
-- Handles are cached under the key @(toKey ctx, hash a)@. The lookup and any
-- creation of a missing handle happen while the cache 'MVar' is held, so
-- creating a handle is an atomic operation. The continuation runs after the
-- lock has been released, so multiple threads may use the same handle
-- concurrently.
--
-- TLM: check that plans can be used concurrently
--
-- <http://docs.nvidia.com/cuda/cufft/index.html#thread-safety>
--
-- NOTE(review): entries are keyed only on @hash a@, not on @a@ itself, so two
-- configurations whose hashes collide would share one handle. The supplied
-- hash function is presumably injective over the configurations in use;
-- confirm at the call sites.
--
{-# INLINE withPlan #-}
withPlan :: Plans a -> a -> (FFT.Handle -> LLVM PTX b) -> LLVM PTX b
withPlan Plans{..} a k = do
  lc   <- gets (deviceContext . ptxContext)
  plan <- liftIO $
    withLifetime lc  $ \ctx ->
    modifyMVar plans $ \pm ->
      let key :: (Int, Int)
          key = (toKey ctx, hash a)
      in  case Map.lookup key pm of
            -- handle does not exist yet; create it and add to the global
            -- state for reuse
            Nothing -> do
              h <- create a
              l <- newLifetime h
              -- once the context's lifetime is finalised, drop its entry
              -- from the cache
              addFinalizer lc $ modifyMVar_ plans (return . Map.delete key)
              -- destroy the cuFFT handle when its lifetime is finalised
              addFinalizer l  $ FFT.destroy h
              return (Map.insert key l pm, l)

            -- return existing handle
            Just l  -> return (pm, l)
  --
  withLifetime' plan k

-- | Convert a CUDA context into an 'Int' by reading the address held in its
-- underlying pointer. Used as the first component of the plan cache key.
--
{-# INLINE toKey #-}
toKey :: CUDA.Context -> Int
toKey (CUDA.Context p) = addressOf p
  where
    addressOf (Ptr addr) = I# (addr2Int# addr)