{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
module Oath
  ( Oath(..)
  , hoistOath
  , evalOath
  , tryOath
  , oath
  , delay
  , timeout) where

import Control.Applicative
import Control.Concurrent
import Control.Concurrent.STM
import Control.Concurrent.STM.Delay
import Control.Exception
import Data.Monoid

-- | An 'Oath' is a result promised in continuation-passing style: running it
-- hands the continuation an 'STM' action that waits for the result.
--
-- The 'Applicative' instance collects the results of one or more
-- computations. The 'Semigroup' and 'Monoid' instances combine results of a
-- 'Semigroup'/'Monoid' type pointwise through that 'Applicative' (via 'Ap').
newtype Oath a = Oath
  { runOath :: forall r. (STM a -> IO r) -> IO r
  }
  deriving stock Functor
  deriving (Semigroup, Monoid) via (Ap Oath a)

-- | Transform the 'STM' action that waits for an 'Oath''s result.
-- The continuation is handed the transformed action instead of the original.
hoistOath :: (STM a -> STM b) -> Oath a -> Oath b
hoistOath f o = Oath (\k -> runOath o (\wait -> k (f wait)))

-- | Run an 'Oath', executing its wait action with 'atomically' so the call
-- returns once the result is available.
evalOath :: Oath a -> IO a
evalOath (Oath run) = run atomically

-- | Catch an exception of type @e@ raised by the 'STM' action that waits for
-- the result (for example, an exception rethrown from a thread started by
-- 'oath'). A caught exception is returned as 'Left'; a successful result is
-- wrapped in 'Right'.
tryOath :: Exception e => Oath a -> Oath (Either e a)
tryOath = hoistOath guarded
  where
    guarded wait = catchSTM (Right <$> wait) (return . Left)

-- | ('<*>') starts both computations, left one first, before waiting on
-- either. The combined wait action waits for both results and applies the
-- function to the argument.
instance Applicative Oath where
  pure x = Oath (\k -> k (return x))
  mf <*> mx = Oath $ \k ->
    runOath mf $ \waitF ->
      runOath mx $ \waitX ->
        k (waitF <*> waitX)

-- | ('<|>') starts both computations and yields whichever result becomes
-- available first. When both are ready, the left one is preferred. An
-- exception raised by the left wait action propagates (STM 'orElse'
-- semantics); it does not fall back to the right one.
--
-- The losing computation is not stopped at the moment the winner is chosen.
-- For threads started by 'oath', it is killed when the continuation finishes.
-- 'empty' never produces a result.
instance Alternative Oath where
  empty = Oath (\k -> k retry)
  ma <|> mb = Oath $ \k ->
    runOath ma $ \waitA ->
      runOath mb $ \waitB ->
        k (waitA `orElse` waitB)

-- | Lift an IO action into an 'Oath' by running it in a new thread.
--
-- The continuation receives an 'STM' action that waits for the thread to
-- finish. That action returns the IO action's result, or rethrows the
-- exception the IO action terminated with. When the continuation returns or
-- throws, the thread is killed.
oath :: IO a -> Oath a
oath act = Oath $ \cont -> do
  result <- newEmptyTMVarIO
  let await = readTMVar result >>= either throwSTM pure
  -- Start the thread and register the cleanup under 'mask'. Otherwise an
  -- asynchronous exception could arrive after the thread was forked but
  -- before 'killThread' was installed, leaking the thread. 'restore' gives
  -- both the forked action and the continuation the caller's own masking
  -- state, exactly as 'forkFinally' did for the action.
  mask $ \restore -> do
    tid <- forkIO $ tryAny (restore act) >>= atomically . putTMVar result
    restore (cont await) `finally` killThread tid
  where
    -- Pins the caught exception type to 'SomeException'.
    tryAny :: IO b -> IO (Either SomeException b)
    tryAny = try

-- | An 'Oath' whose wait action completes once the given number of
-- microseconds has elapsed. The underlying timer is cancelled when the
-- continuation finishes.
delay :: Int -> Oath ()
delay micros = Oath $ \k ->
  bracket (newDelay micros) cancelDelay $ \timer -> k (waitDelay timer)

-- | Race an 'Oath' against a timer of the given number of microseconds.
-- The result is 'Just' the value if it is available first, or 'Nothing'
-- once the timer fires. If both are ready at once, the value wins.
timeout :: Int -> Oath a -> Oath (Maybe a)
timeout micros m = fmap Just m <|> fmap (const Nothing) (delay micros)