{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}
{-# LANGUAGE RankNTypes   #-}

module Snap.Internal.Http.Server.Thread
  ( SnapThread
  , fork
  , forkOn
  , cancel
  , wait
  , cancelAndWait
  , isFinished
  ) where

#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative         ((<$>))
#endif
import           Control.Concurrent          (MVar, ThreadId, killThread, newEmptyMVar, putMVar, readMVar)
#if MIN_VERSION_base(4,7,0)
import           Control.Concurrent          (tryReadMVar)
#else
import           Control.Concurrent          (tryTakeMVar)
import           Control.Monad               (when)
import           Data.Maybe                  (fromJust, isJust)
#endif
import           Control.Concurrent.Extended (forkIOLabeledWithUnmaskBs, forkOnLabeledWithUnmaskBs)
import qualified Control.Exception           as E
import           Control.Monad               (void)
import qualified Data.ByteString.Char8       as B
import           GHC.Exts                    (inline)

#if !MIN_VERSION_base(4,7,0)
-- | Compatibility shim for @base < 4.7@, which does not provide
-- 'tryReadMVar'.
--
-- Non-blocking read: returns 'Just' the current contents of the 'MVar'
-- (leaving it full) or 'Nothing' when it is empty.
--
-- This is emulated with a take followed by a put, so it is NOT atomic.
-- Another thread can briefly observe the 'MVar' as empty between the two
-- steps.
--
-- The take/put pair runs under 'E.mask_'. Without the mask, an asynchronous
-- exception (e.g. 'killThread') could arrive after the value was taken but
-- before it was put back. That would leave the 'MVar' permanently empty, and
-- anything blocked in 'readMVar' on it would hang forever.
--
-- NOTE(review): the put can still be interrupted if it blocks, i.e. if some
-- other thread filled the 'MVar' in between. This assumes no such
-- concurrent writer exists.
tryReadMVar :: MVar a -> IO (Maybe a)
tryReadMVar mv = E.mask_ $ do
    m <- tryTakeMVar mv
    -- 'fromJust' is safe: it only runs when 'isJust m' holds.
    when (isJust m) $ putMVar mv (fromJust m)
    return m
#endif

------------------------------------------------------------------------------
-- | Handle to a thread started by 'fork' or 'forkOn'.
data SnapThread = SnapThread
    { _snapThreadId       :: {-# UNPACK #-} !ThreadId
      -- ^ id of the forked thread; 'cancel' calls 'killThread' on it
    , _snapThreadFinished :: {-# UNPACK #-} !(MVar ())
      -- ^ completion flag: filled with @()@ once the thread's action has
      --   returned or thrown; read by 'wait' and 'isFinished'
    }

-- | A 'SnapThread' is shown exactly as its underlying 'ThreadId'. The
-- completion 'MVar' is not part of the rendering.
instance Show SnapThread where
  show (SnapThread tid _) = show tid


------------------------------------------------------------------------------
-- | Like 'fork', but also passes a capability number to the underlying fork
-- primitive.
--
-- NOTE(review): presumably this pins the thread to that capability, as
-- GHC's 'Control.Concurrent.forkOn' does. Confirm against
-- 'forkOnLabeledWithUnmaskBs'.
--
-- The fork happens inside 'E.uninterruptibleMask_', so the calling thread
-- cannot be interrupted between creating the thread and building its
-- handle. The action is wrapped by 'wrapAction', which fills the handle's
-- completion 'MVar' when the action returns or throws.
forkOn :: B.ByteString                          -- ^ thread label
       -> Int                                   -- ^ capability
       -> ((forall a . IO a -> IO a) -> IO ())  -- ^ user thread action, taking
                                                --   a restore function
       -> IO SnapThread
forkOn label cap action = do
    finished <- newEmptyMVar
    -- Assumed (like 'forkIOWithUnmask'): the child inherits this masking
    -- state, and the function handed to the action unmasks. TODO confirm.
    E.uninterruptibleMask_ $ do
        newTid <- forkOnLabeledWithUnmaskBs label cap
                      (wrapAction finished action)
        return $! SnapThread newTid finished


------------------------------------------------------------------------------
-- | Forks a new labelled thread running the given action and returns a
-- 'SnapThread' handle for it.
--
-- The fork happens inside 'E.uninterruptibleMask_', so the calling thread
-- cannot be interrupted between creating the thread and building its
-- handle. The action is wrapped by 'wrapAction', which fills the handle's
-- completion 'MVar' when the action returns or throws. That is what 'wait'
-- and 'isFinished' observe.
--
-- NOTE(review): this assumes 'forkIOLabeledWithUnmaskBs' behaves like
-- 'forkIOWithUnmask'. Under that assumption the child starts masked, and
-- the function handed to the action unmasks. Confirm in
-- "Control.Concurrent.Extended".
fork :: B.ByteString                          -- ^ thread label
     -> ((forall a . IO a -> IO a) -> IO ())  -- ^ user thread action, taking
                                              --   a restore function
     -> IO SnapThread
fork label action = do
    finished <- newEmptyMVar
    E.uninterruptibleMask_ $
        forkIOLabeledWithUnmaskBs label (wrapAction finished action)
            >>= \newTid -> return $! SnapThread newTid finished


------------------------------------------------------------------------------
-- | Asks the thread to stop by calling 'killThread' on its 'ThreadId'.
--
-- This does not wait for the thread's action to finish; use
-- 'cancelAndWait' for that.
--
-- NOTE(review): if the action is running masked (e.g. it never used its
-- restore function), 'killThread' may block until the target unmasks.
cancel :: SnapThread -> IO ()
cancel st = killThread $ _snapThreadId st


------------------------------------------------------------------------------
-- | Blocks until the thread's completion 'MVar' has been filled, i.e. until
-- its action has returned or thrown.
--
-- Uses 'readMVar', which leaves the value in place, so repeated or
-- concurrent waits all return. Any exception from the thread's action is
-- not rethrown here.
wait :: SnapThread -> IO ()
wait st = void $ readMVar (_snapThreadFinished st)


------------------------------------------------------------------------------
-- | 'cancel' the thread, then 'wait' until its action has finished.
cancelAndWait :: SnapThread -> IO ()
cancelAndWait st = do
    cancel st
    wait st


------------------------------------------------------------------------------
-- | Non-blocking check: 'True' if the thread's completion 'MVar' is already
-- filled, i.e. its action has returned or thrown.
--
-- On @base < 4.7@ this goes through the local, non-atomic 'tryReadMVar'
-- shim.
isFinished :: SnapThread -> IO Bool
isFinished st = hasValue <$> tryReadMVar (_snapThreadFinished st)
  where
    hasValue :: Maybe () -> Bool
    hasValue Nothing  = False
    hasValue (Just _) = True


------------------------------------------------------------------------------
-- Internal functions follow
------------------------------------------------------------------------------
-- | Wraps a user thread action so that the given 'MVar' is filled with @()@
-- when the action finishes. This happens whether the action returns
-- normally or throws.
--
-- Any exception escaping the action, including asynchronous ones such as
-- the one sent by 'cancel', is caught as 'E.SomeException' and discarded
-- after signalling. NOTE(review): this looks deliberate, since a finished
-- thread only needs to report completion, but it means failures are not
-- surfaced.
--
-- The signalling 'putMVar' runs under 'E.uninterruptibleMask_', so it
-- cannot be interrupted part-way.
wrapAction :: MVar ()
           -> ((forall a . IO a -> IO a) -> IO ())
           -> ((forall a . IO a -> IO a) -> IO ())
wrapAction finished userAction unmask =
    E.catch (userAction unmask >> inline signalDone) handler
  where
    -- The bang forces the exception value before signalling.
    handler :: E.SomeException -> IO ()
    handler !_ = inline signalDone

    signalDone :: IO ()
    signalDone = E.uninterruptibleMask_ $ putMVar finished $! ()