{-# Language BlockArguments, ScopedTypeVariables, LambdaCase #-}
{-|
Module      : Hookup.Concurrent
Description : Concurrently run actions until one succeeds or all fail
Copyright   : (c) Eric Mertens, 2020
License     : ISC
Maintainer  : emertens@gmail.com

-}
module Hookup.Concurrent (concurrentAttempts) where

import Control.Concurrent (forkIO, throwTo)
import Control.Concurrent.Async (Async, AsyncCancelled(..), async, asyncThreadId, cancel, waitCatch, waitCatchSTM)
import Control.Concurrent.STM (STM, atomically, check, orElse, readTVar, registerDelay, retry)
import Control.Exception (SomeException, finally, mask_, onException)
import Control.Monad (join, void)
import Data.Foldable (for_)

-- | Run the attempts concurrently, starting them one after another in list
-- order, and return the first result that succeeds.
--
-- Starting the next attempt:
--
-- * A new attempt is started each time the delay timer expires.
-- * It is also started immediately when a failure leaves no attempt running.
--
-- Outcomes:
--
-- * Successful results that are no longer needed are passed to the release
--   function.
-- * If every attempt fails, all collected exceptions are returned, most
--   recent failure first.
-- * The whole scheduler runs under 'mask_'.
concurrentAttempts ::
  Int {- ^ microsecond delay between attempts -} ->
  (a -> IO ()) {- ^ release unneeded success -} ->
  [IO a] {- ^ ordered list of attempts -} ->
  IO (Either [SomeException] a)
concurrentAttempts delay release actions = mask_ (loop initial)
  where
    -- nothing running, nothing failed, and no timer armed yet
    initial = St
      { threads  = []
      , errors   = []
      , work     = actions
      , delay    = delay
      , clean    = release
      , readySTM = retry
      }

-- | Bookkeeping for the attempt scheduler.
data St a = St
  { threads  :: [Async a]        -- ^ attempts that are currently running
  , errors   :: [SomeException]  -- ^ failures collected so far, most recent first
  , work     :: [IO a]           -- ^ attempts not started yet, in order
  , delay    :: !Int             -- ^ microseconds to wait before the next start
  , clean    :: a -> IO ()       -- ^ disposes of a success that is not needed
  , readySTM :: STM ()           -- ^ succeeds once the next attempt may start
  }

-- | Final outcome of the scheduler: either the collected attempt failures or
-- a successful result.
type Answer a = IO (Either [SomeException] a)

-- | One step of the scheduler's main loop.
--
-- * With no attempt in flight, start the next one or give up.
-- * Otherwise, block until something happens.
loop :: forall a. St a -> Answer a
loop st =
  if null (threads st)
    then nothingRunning st
    else waitForEvent st

-- | Handle the idle case, when no attempt is running.
--
-- Launch the next pending attempt right away. When none remain, report
-- every error that has been collected.
nothingRunning :: St a -> Answer a
nothingRunning st = go (work st)
  where
    go []            = pure (Left (errors st))
    go (next : rest) = start next st { work = rest }

-- | Launch an attempt on its own thread, then continue with the main loop.
--
-- A new start timer is armed only while more attempts are waiting. Otherwise
-- the readiness action is 'retry', so it never fires.
--
-- NOTE(review): this runs under the 'mask_' taken by 'concurrentAttempts'.
-- 'async' presumably restores that (masked) state in the child thread, so
-- attempts may only be cancellable at interruptible points. Confirm that this
-- is intended; 'asyncWithUnmask' would let attempts run unmasked.
start :: IO a -> St a -> Answer a
start io st = do
  thread <- async io
  ready  <- armTimer
  loop st { threads = thread : threads st, readySTM = ready }
  where
    armTimer
      | null (work st) = pure retry
      | otherwise      = startTimer (delay st)

-- | Block until a running attempt completes or the start timer fires.
--
-- The transaction in 'finish' picks the follow-up action, which is run once
-- the transaction has committed. If an exception interrupts the wait, a
-- background cancellation of every running attempt is started before the
-- exception is rethrown.
waitForEvent :: St a -> Answer a
waitForEvent st = join nextStep
  where
    running  = threads st
    nextStep = atomically (finish st [] running)
                 `onException` cleanup (clean st) running

-- | Look for a completed attempt, falling back to the timer event in 'fresh'
-- when none has finished.
--
-- Arguments:
--
-- * @seen@: attempts already examined.
-- * @pending@: attempts still to check.
--
-- Each candidate is passed to 'finish1' together with every other running
-- attempt, so a finished candidate drops out of the running set.
finish :: St a -> [Async a] -> [Async a] -> STM (Answer a)
finish st seen pending =
  case pending of
    []           -> fresh st
    cur : later  ->
      finish1 st (seen ++ later) cur
        `orElse` finish st (cur : seen) later

-- | React to attempt @t@ finishing. @others@ are the attempts still running.
--
-- 'waitCatchSTM' blocks the transaction until @t@ is done. The follow-up is
-- returned as an IO action, because IO cannot run inside STM:
--
-- * On success: start cancelling the others in the background, releasing
--   their results, then yield the value.
-- * On failure: record the exception and continue the loop without @t@.
finish1 :: St a -> [Async a] -> Async a -> STM (Answer a)
finish1 st others t = fmap react (waitCatchSTM t)
  where
    react (Right s) = Right s <$ cleanup (clean st) others
    react (Left e)  = loop st { errors = e : errors st, threads = others }

-- | Timer event: once the readiness action succeeds, schedule the start of
-- the next pending attempt.
--
-- With nothing left to start, the transaction retries.
fresh :: St a -> STM (Answer a)
fresh st =
  case work st of
    []          -> retry
    next : rest -> do
      readySTM st
      pure (start next st { work = rest })

-- | Arm a timer of @n@ microseconds.
--
-- The returned STM action retries (via 'check') until the timer's flag has
-- been set, so it succeeds only after at least @n@ microseconds.
--
-- NOTE(review): 'registerDelay' is documented as needing the threaded
-- runtime on some platforms. Confirm the program is built with -threaded.
startTimer :: Int -> IO (STM ())
startTimer n = fmap waitFor (registerDelay n)
  where
    waitFor flag = readTVar flag >>= check

-- | Cancel the given attempts without blocking the caller.
--
-- A background thread does the work:
--
-- 1. It sends 'AsyncCancelled' to every attempt.
-- 2. It waits for each attempt to finish.
-- 3. It passes every successful result to @release@.
--
-- The per-attempt steps are chained with 'finally'. An exception thrown while
-- releasing one result therefore cannot keep the remaining attempts from being
-- awaited and released, which would otherwise leak their results. Such an
-- exception is still rethrown once every attempt has been handled.
cleanup :: (a -> IO ()) -> [Async a] -> IO ()
cleanup release xs =
  void $ forkIO $
    do for_ xs $ \x -> throwTo (asyncThreadId x) AsyncCancelled
       foldr (\x rest -> reap x `finally` rest) (pure ()) xs
  where
    -- wait for one attempt and release its result if it succeeded
    reap x =
      do res <- waitCatch x
         for_ res release