{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      : Control.Monad.STM.Class
-- Copyright   : (c) 2016--2021 Michael Walker
-- License     : MIT
-- Maintainer  : Michael Walker <mike@barrucadu.co.uk>
-- Stability   : experimental
-- Portability : CPP, GeneralizedNewtypeDeriving, RankNTypes, StandaloneDeriving, TypeFamilies
--
-- This module provides an abstraction over 'STM', which can be used
-- with 'MonadConc'.
--
-- This module only defines the 'MonadSTM' class; you probably want to
-- import "Control.Concurrent.Classy.STM" (which exports
-- "Control.Monad.STM.Class").
--
-- __Deriving instances:__ If you have a newtype wrapper around a type
-- with an existing @MonadSTM@ instance, you should be able to derive
-- an instance for your type automatically, in simple cases.
--
-- For example:
--
-- > {-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- > {-# LANGUAGE StandaloneDeriving #-}
-- > {-# LANGUAGE UndecidableInstances #-}
-- >
-- > data Env = Env
-- >
-- > newtype MyMonad m a = MyMonad { runMyMonad :: ReaderT Env m a }
-- >   deriving (Functor, Applicative, Monad, Alternative, MonadPlus)
-- >
-- > deriving instance MonadThrow m => MonadThrow (MyMonad m)
-- > deriving instance MonadCatch m => MonadCatch (MyMonad m)
-- >
-- > deriving instance MonadSTM m => MonadSTM (MyMonad m)
--
-- Do not be put off by the use of @UndecidableInstances@; it is safe
-- here.
--
-- __Deviations:__ An instance of @MonadSTM@ is not required to be a
-- @MonadFix@, unlike @STM@.
module Control.Monad.STM.Class
  ( MonadSTM(..)
  , retry
  , check
  , orElse
  , throwSTM
  , catchSTM

    -- * Utilities for type shenanigans
  , IsSTM
  , toIsSTM
  , fromIsSTM
) where

import           Control.Applicative          (Alternative(..))
import           Control.Exception            (Exception)
import           Control.Monad                (MonadPlus(..), unless)
import           Control.Monad.Fail           (MonadFail(..))
import           Control.Monad.Reader         (ReaderT)
import           Control.Monad.Trans          (lift)
import           Control.Monad.Trans.Identity (IdentityT)
import           Data.Kind                    (Type)

import qualified Control.Concurrent.STM       as STM
import qualified Control.Monad.Catch          as Ca
import qualified Control.Monad.RWS.Lazy       as RL
import qualified Control.Monad.RWS.Strict     as RS
import qualified Control.Monad.State.Lazy     as SL
import qualified Control.Monad.State.Strict   as SS
import qualified Control.Monad.Writer.Lazy    as WL
import qualified Control.Monad.Writer.Strict  as WS

-- | @MonadSTM@ is an abstraction over 'STM'.
--
-- This class does not provide any way to run transactions; rather,
-- each 'MonadConc' has an associated @MonadSTM@ from which it can
-- atomically run a transaction.
--
-- @since 1.2.0.0
class (Ca.MonadCatch stm, MonadPlus stm) => MonadSTM stm where
  {-# MINIMAL
        (newTVar | newTVarN)
      , readTVar
      , writeTVar
    #-}
  -- The defaults of 'newTVar' and 'newTVarN' are written in terms of
  -- each other, so an instance must define at least one of them (as
  -- the MINIMAL pragma requires); defining neither makes both loop.

  -- | The mutable reference type. These behave like 'TVar's, in that
  -- they always contain a value and updates are non-blocking and
  -- synchronised.
  --
  -- @since 1.0.0.0
  type TVar stm :: Type -> Type

  -- | Create a new @TVar@ containing the given value.
  --
  -- > newTVar = newTVarN ""
  --
  -- @since 1.0.0.0
  newTVar :: a -> stm (TVar stm a)
  newTVar = newTVarN ""

  -- | Create a new @TVar@ containing the given value, but it is
  -- given a name which may be used to present more useful debugging
  -- information.
  --
  -- If an empty name is given, a counter starting from 0 is used. If
  -- names conflict, successive @TVar@s with the same name are given
  -- a numeric suffix, counting up from 1.
  --
  -- > newTVarN _ = newTVar
  --
  -- @since 1.0.0.0
  newTVarN :: String -> a -> stm (TVar stm a)
  newTVarN _ = newTVar

  -- | Return the current value stored in a @TVar@.
  --
  -- @since 1.0.0.0
  readTVar :: TVar stm a -> stm a

  -- | Write the supplied value into the @TVar@.
  --
  -- @since 1.0.0.0
  writeTVar :: TVar stm a -> a -> stm ()

-- | Retry execution of this transaction because it has seen values in
-- @TVar@s that it should not have. This will result in the thread
-- running the transaction being blocked until any @TVar@s referenced
-- in it have been mutated.
--
-- This is just 'mzero'.
--
-- @since 1.2.0.0
retry :: MonadSTM stm => stm a
retry = mzero

-- | Check whether a condition is true and, if not, call @retry@.
--
-- When the condition is @True@ this returns @()@ without any other
-- effect; when it is @False@ the transaction retries.
--
-- @since 1.0.0.0
check :: MonadSTM stm => Bool -> stm ()
check b = unless b retry

-- | Run the first transaction and, if it @retry@s, run the second
-- instead.
--
-- This is just 'mplus'.
--
-- @since 1.2.0.0
orElse :: MonadSTM stm => stm a -> stm a -> stm a
orElse = mplus

-- | Throw an exception. This aborts the transaction and propagates
-- the exception.
--
-- This is just 'Ca.throwM' from the 'Ca.MonadThrow' superclass.
--
-- @since 1.0.0.0
throwSTM :: (MonadSTM stm, Exception e) => e -> stm a
throwSTM = Ca.throwM

-- | Handle exceptions thrown by 'throwSTM': run the transaction and,
-- if it throws an exception of type @e@, run the handler instead.
--
-- This is just 'Ca.catch' from the 'Ca.MonadCatch' superclass.
--
-- @since 1.0.0.0
catchSTM :: (MonadSTM stm, Exception e) => stm a -> (e -> stm a) -> stm a
catchSTM = Ca.catch

-- | The real 'STM.STM' monad, using 'STM.TVar' as its @TVar@ type.
-- Every operation delegates directly to "Control.Concurrent.STM";
-- 'newTVarN' uses the class default, which ignores the name.
--
-- @since 1.0.0.0
instance MonadSTM STM.STM where
  type TVar STM.STM = STM.TVar

  newTVar   = STM.newTVar
  readTVar  = STM.readTVar
  writeTVar = STM.writeTVar

-------------------------------------------------------------------------------
-- Type shenanigans

-- | A value of type @IsSTM m a@ can only be constructed if @m@ has a
-- @MonadSTM@ instance.
--
-- The constructor and 'unIsSTM' are not exported; outside this module
-- values are built with 'toIsSTM' (which demands @MonadSTM m@) and
-- unwrapped with 'fromIsSTM'. All listed classes are derived by
-- reusing the instances of the wrapped monad @m@.
--
-- @since 1.2.2.0
newtype IsSTM m a = IsSTM { unIsSTM :: m a }
  deriving (Functor, Applicative, Alternative, Monad, MonadPlus, Ca.MonadThrow, Ca.MonadCatch)

-- | Reuses the 'MonadFail' instance of the wrapped monad @m@
-- (standalone-derived, so only available when @m@ is a 'MonadFail').
--
-- @since 1.8.0.0
deriving instance MonadFail m => MonadFail (IsSTM m)

-- | Wrap an @m a@ value inside an @IsSTM@ if @m@ has a @MonadSTM@
-- instance.
--
-- The @MonadSTM m@ constraint is what enforces the @IsSTM@ invariant:
-- this is the only exported way to build an @IsSTM@ value.
--
-- @since 1.2.2.0
toIsSTM :: MonadSTM m => m a -> IsSTM m a
toIsSTM = IsSTM

-- | Unwrap an @IsSTM@ value, recovering the underlying @m a@.
--
-- @since 1.2.2.0
fromIsSTM :: MonadSTM m => IsSTM m a -> m a
fromIsSTM = unIsSTM

-- | @IsSTM m@ shares its @TVar@ type with @m@; every operation runs
-- the corresponding operation of @m@ and wraps the result.
instance MonadSTM m => MonadSTM (IsSTM m) where
  type TVar (IsSTM m) = TVar m

  newTVar     = toIsSTM . newTVar
  newTVarN n  = toIsSTM . newTVarN n
  readTVar    = toIsSTM . readTVar
  writeTVar v = toIsSTM . writeTVar v

-------------------------------------------------------------------------------
-- Transformer instances

-- INSTANCE(T, C) generates a MonadSTM instance for the transformer T
-- applied to an inner monad stm, under the instance context C. The
-- TVar type is that of the inner monad and every operation is lifted
-- with lift. (A former third parameter was never used and was removed.)
#define INSTANCE(T,C)                    \
instance C => MonadSTM (T stm) where { \
  type TVar (T stm) = TVar stm      ; \
                                      \
  newTVar     = lift . newTVar      ; \
  newTVarN n  = lift . newTVarN n   ; \
  readTVar    = lift . readTVar     ; \
  writeTVar v = lift . writeTVar v  }

-- | @since 1.0.0.0
INSTANCE(ReaderT r, MonadSTM stm)

-- | @since 1.0.0.0
INSTANCE(IdentityT, MonadSTM stm)

-- | @since 1.0.0.0
INSTANCE(WL.WriterT w, (MonadSTM stm, Monoid w))

-- | @since 1.0.0.0
INSTANCE(WS.WriterT w, (MonadSTM stm, Monoid w))

-- | @since 1.0.0.0
INSTANCE(SL.StateT s, MonadSTM stm)

-- | @since 1.0.0.0
INSTANCE(SS.StateT s, MonadSTM stm)

-- | @since 1.0.0.0
INSTANCE(RL.RWST r w s, (MonadSTM stm, Monoid w))

-- | @since 1.0.0.0
INSTANCE(RS.RWST r w s, (MonadSTM stm, Monoid w))

#undef INSTANCE