{-# LANGUAGE LambdaCase #-}

module Network.Serverless.Execute.Utils where

--------------------------------------------------------------------------------
import Control.Monad (forM, unless)
import Control.Concurrent.STM (newTVar, atomically, retry, readTVar, modifyTVar, modifyTVar', writeTVar)
import Control.Monad.Trans.State (evalStateT, get, put)
import Control.Monad.IO.Class (liftIO)
import Control.Concurrent.Async (async, wait, waitBoth, withAsync, mapConcurrently)
import Control.Exception (throwIO, onException)
import Data.Function (fix)
import Control.Concurrent (threadDelay)
import System.IO (hFlush, stdout)
import qualified System.Console.Terminal.Size as TS
--------------------------------------------------------------------------------
import Network.Serverless.Execute
--------------------------------------------------------------------------------

-- |
-- Runs given closures concurrently using the 'Backend' with a progress bar.
--
-- Every closure is handed to 'executeAsync' up front. This call then blocks
-- until all of them have finished, redrawing a one-line progress bar on
-- stdout whenever their states change. The bar spans the terminal width
-- (40 columns when the size cannot be determined) and uses @#@ for finished,
-- @:@ for started, @.@ for submitted and a blank for waiting closures.
--
-- Results are returned in the same order as the input closures.
--
-- Throws 'ExecutorFailedException' if something fails. In that case the
-- threads watching the remaining closures and the progress bar thread are
-- cancelled before the exception propagates; this function does not itself
-- stop work that was already handed to the backend.
mapWithProgress :: Backend
                -> Closure (Dict (Serializable a))
                -> [Closure (IO a)]
                -> IO [a]
mapWithProgress backend dict xs = do
  -- Shared progress state: a "dirty" flag asking the drawer to redraw,
  -- followed by the number of closures waiting, submitted, started, finished.
  progress <- atomically $ newTVar (True, 0::Int, 0::Int, 0::Int, 0::Int)
  tvars <- mapM (executeAsync backend dict) xs
  termWidth <- maybe 40 TS.width <$> TS.size :: IO Int
  let total = length xs
  -- withAsync/mapConcurrently cancel their threads if anything throws, so
  -- nothing keeps running (or printing) after we return. On an exception the
  -- half-drawn bar line is terminated so the caller's output starts fresh.
  flip onException (putStrLn "") $
    withAsync (mapConcurrently (watch progress) tvars) $ \result ->
      withAsync (drawProgress progress termWidth total) $ \pbar ->
        fst <$> waitBoth result pbar
  where
    -- One-hot encoding of a job's state, in the same order as the counters
    -- (waiting, submitted, started, finished), plus the result once finished.
    classify status = case status of
      ExecutorPending (ExecutorWaiting _)   -> ((1, 0, 0, 0), Nothing)
      ExecutorPending (ExecutorSubmitted _) -> ((0, 1, 0, 0), Nothing)
      ExecutorPending (ExecutorStarted _)   -> ((0, 0, 1, 0), Nothing)
      ExecutorFinished fr                   -> ((0, 0, 0, 1), Just fr)

    -- Swaps a job's previous one-hot state for its new one in the totals and
    -- marks the progress state dirty so the bar gets redrawn.
    move (o1, o2, o3, o4) (n1, n2, n3, n4) (_, s1, s2, s3, s4) =
      (True, s1 - o1 + n1, s2 - o2 + n2, s3 - o3 + n3, s4 - o4 + n4)

    -- Follows one job's status TVar until it finishes. Every observed state
    -- change is moved into the shared counters in the same STM transaction
    -- that observes it. The state starts at all zeros, which never matches a
    -- real state, so the first observation is always recorded.
    watch progress tv = flip evalStateT (0, 0, 0, 0) . fix $ \recurse -> do
      old <- get
      (new, outcome) <- liftIO . atomically $ do
        (state', fr) <- classify <$> readTVar tv
        -- Block until the job leaves the state we last recorded.
        unless (old /= state') retry
        modifyTVar' progress (move old state')
        return (state', fr)
      put new
      case outcome of
        Nothing                    -> recurse
        Just (ExecutorFailed err)  -> liftIO . throwIO $ ExecutorFailedException err
        Just (ExecutorSucceeded x) -> return x

    -- Redraws the bar each time the progress state is marked dirty, pausing
    -- 10ms between frames, and stops (ending the line) once all jobs finish.
    drawProgress progress termWidth total = fix $ \recurse -> do
      (_waiting, submitted, started, finished) <- atomically $
        readTVar progress >>= \case
          (False, _, _, _, _) -> retry
          (True, a, b, c, d) -> do
            writeTVar progress (False, a, b, c, d)
            return (a, b, c, d)
      putStr $ "\r[" ++ renderBar (termWidth - 2) total finished started submitted ++ "]"
      -- stdout is line-buffered on a terminal; without an explicit flush the
      -- bar would only show up once the final newline is printed.
      hFlush stdout
      if finished < total
        then threadDelay 10000 >> recurse
        else putStrLn ""

    -- Renders a bar exactly @max 0 width@ cells wide. Segment boundaries are
    -- scaled from cumulative counts with integer arithmetic, so the segments
    -- always sum to the full width (a "\r" redraw never leaves stale cells
    -- behind) and an empty job list cannot divide by zero. Jobs not yet
    -- counted in any state are drawn as blanks, like waiting ones.
    renderBar width total finished started submitted =
        replicate c1 '#' ++ replicate (c2 - c1) ':'
          ++ replicate (c3 - c2) '.' ++ replicate (w - c3) ' '
      where
        w = max 0 width
        cells k
          | total <= 0 = w
          | otherwise  = min w (k * w `div` total)
        c1 = cells finished
        c2 = cells (finished + started)
        c3 = cells (finished + started + submitted)