{-# LANGUAGE LambdaCase #-}

-- |
-- Module      :   Grisette.Core.Control.Monad.Class.MonadParallelUnion
-- Copyright   :   (c) Sirui Lu 2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Control.Monad.Class.MonadParallelUnion
  ( MonadParallelUnion (..),
  )
where

import Control.DeepSeq (NFData)
import Control.Monad.Except (ExceptT (ExceptT), runExceptT)
import Control.Monad.Identity (IdentityT (IdentityT, runIdentityT))
import qualified Control.Monad.RWS.Lazy as RWSLazy
import qualified Control.Monad.RWS.Strict as RWSStrict
import Control.Monad.Reader (ReaderT (ReaderT, runReaderT))
import qualified Control.Monad.State.Lazy as StateLazy
import qualified Control.Monad.State.Strict as StateStrict
import Control.Monad.Trans.Maybe (MaybeT (MaybeT, runMaybeT))
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import Grisette.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Core.Data.Class.SimpleMergeable
  ( UnionLike,
    merge,
  )

-- | Parallel union monad.
--
-- With the @QualifiedDo@ extension and the "Grisette.Qualified.ParallelUnionDo"
-- module, one can execute the paths in parallel and merge the results with:
--
-- > :set -XQualifiedDo -XOverloadedStrings
-- > import Grisette
-- > import qualified Grisette.Qualified.ParallelUnionDo as P
-- > P.do
-- >   x <- mrgIf "a" (return 1) (return 2) :: UnionM Int
-- >   return $ x + 1
-- >
-- > -- {If a 2 3}
class (UnionLike m, Monad m) => MonadParallelUnion m where
  -- | Bind-like operation with the same argument shape as '>>=': it takes a
  -- computation and a continuation and produces the continuation's result
  -- type.
  --
  -- Unlike '>>=', the result type @b@ must be both 'Mergeable' and 'NFData'.
  --
  -- NOTE(review): per the class description this is meant to run the paths in
  -- parallel and merge the results; the 'NFData' constraint is presumably
  -- there so that each path's result can be fully evaluated -- confirm
  -- against the base instance, which is not defined in this module.
  parBindUnion :: (Mergeable b, NFData b) => m a -> (a -> m b) -> m b

-- | Lifts 'parBindUnion' through 'MaybeT'.
--
-- The bind is performed by the underlying monad on the unwrapped
-- @m (Maybe a)@ value. A 'Nothing' is propagated without calling the
-- continuation; a 'Just' payload is handed to the continuation, whose result
-- is unwrapped again with 'runMaybeT'.
instance (MonadParallelUnion m) => MonadParallelUnion (MaybeT m) where
  parBindUnion (MaybeT x) f =
    MaybeT $
      x `parBindUnion` \case
        -- Short-circuit: there is no value to feed to the continuation.
        Nothing -> return Nothing
        Just x'' -> runMaybeT $ f x''
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through 'ExceptT'.
--
-- The bind is performed by the underlying monad on the unwrapped
-- @m (Either e a)@ value. A 'Left' error is re-returned unchanged without
-- calling the continuation; a 'Right' payload is handed to the continuation,
-- whose result is unwrapped again with 'runExceptT'.
instance (MonadParallelUnion m, Mergeable e, NFData e) => MonadParallelUnion (ExceptT e m) where
  parBindUnion (ExceptT x) f =
    ExceptT $
      x `parBindUnion` \case
        -- Short-circuit: keep the error, skip the continuation.
        Left e -> return $ Left e
        Right x'' -> runExceptT $ f x''
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the lazy 'StateLazy.StateT'.
--
-- The first computation is run on the incoming state; its result and updated
-- state are then passed to the continuation via the underlying monad's
-- 'parBindUnion'.
instance (MonadParallelUnion m, Mergeable s, NFData s) => MonadParallelUnion (StateLazy.StateT s m) where
  parBindUnion (StateLazy.StateT x) f = StateLazy.StateT $ \s ->
    x s `parBindUnion` \case
      -- Irrefutable pattern: the pair is not forced by the match itself,
      -- which is what distinguishes this from the strict StateT instance.
      ~(a, s') -> StateLazy.runStateT (f a) s'
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the strict 'StateStrict.StateT'.
--
-- The first computation is run on the incoming state; its result and updated
-- state are then passed to the continuation via the underlying monad's
-- 'parBindUnion'.
instance (MonadParallelUnion m, Mergeable s, NFData s) => MonadParallelUnion (StateStrict.StateT s m) where
  parBindUnion (StateStrict.StateT x) f = StateStrict.StateT $ \s ->
    x s `parBindUnion` \case
      -- Plain (refutable) tuple pattern: the pair is forced by the match,
      -- unlike in the lazy StateT instance.
      (a, s') -> StateStrict.runStateT (f a) s'
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the lazy 'WriterLazy.WriterT'.
--
-- Two nested binds in the underlying monad are used: the outer one feeds the
-- first result to the continuation, the inner one appends the continuation's
-- log @w'@ to the first computation's log @w@ with '<>'.
instance (MonadParallelUnion m, Mergeable s, Monoid s, NFData s) => MonadParallelUnion (WriterLazy.WriterT s m) where
  parBindUnion (WriterLazy.WriterT x) f =
    WriterLazy.WriterT $
      x `parBindUnion` \case
        -- Irrefutable patterns: the pairs are not forced by the matches,
        -- which is what distinguishes this from the strict WriterT instance.
        ~(a, w) ->
          WriterLazy.runWriterT (f a) `parBindUnion` \case
            ~(b, w') -> return (b, w <> w')
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the strict 'WriterStrict.WriterT'.
--
-- Two nested binds in the underlying monad are used: the outer one feeds the
-- first result to the continuation, the inner one appends the continuation's
-- log @w'@ to the first computation's log @w@ with '<>'.
instance (MonadParallelUnion m, Mergeable s, Monoid s, NFData s) => MonadParallelUnion (WriterStrict.WriterT s m) where
  parBindUnion (WriterStrict.WriterT x) f =
    WriterStrict.WriterT $
      x `parBindUnion` \case
        -- Plain (refutable) tuple patterns: the pairs are forced by the
        -- matches, unlike in the lazy WriterT instance.
        (a, w) ->
          WriterStrict.runWriterT (f a) `parBindUnion` \case
            (b, w') -> return (b, w <> w')
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through 'ReaderT'.
--
-- The same environment is supplied both to the first computation and to the
-- computation produced by the continuation; the bind itself is performed by
-- the underlying monad.
instance (MonadParallelUnion m, Mergeable a, NFData a) => MonadParallelUnion (ReaderT a m) where
  parBindUnion (ReaderT x) f = ReaderT $ \a ->
    -- @a@ is the environment, @a'@ is the result of the first computation.
    x a `parBindUnion` \a' -> runReaderT (f a') a
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through 'IdentityT'.
--
-- The wrapper is removed, the bind is delegated to the underlying monad, and
-- the result is re-wrapped. The continuation's unwrapped result is passed
-- through 'merge' before being handed to the underlying 'parBindUnion'.
instance (MonadParallelUnion m) => MonadParallelUnion (IdentityT m) where
  parBindUnion (IdentityT x) f =
    IdentityT $ x `parBindUnion` (merge . runIdentityT . f)
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the strict 'RWSStrict.RWST'.
--
-- The environment @r@ is passed unchanged to both the first computation and
-- the continuation, the state is threaded from one to the other, and the two
-- logs are combined with '<>' (first computation's log on the left). Both
-- steps are bound with the underlying monad's 'parBindUnion'.
instance
  (MonadParallelUnion m, Mergeable s, Mergeable r, Mergeable w, Monoid w, NFData r, NFData w, NFData s) =>
  MonadParallelUnion (RWSStrict.RWST r w s m)
  where
  parBindUnion m k = RWSStrict.RWST $ \r s ->
    RWSStrict.runRWST m r s `parBindUnion` \case
      -- Plain (refutable) tuple patterns: the triples are forced by the
      -- matches, unlike in the lazy RWST instance.
      (a, s', w) ->
        RWSStrict.runRWST (k a) r s' `parBindUnion` \case
          (b, s'', w') -> return (b, s'', w <> w')
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the lazy 'RWSLazy.RWST'.
--
-- The environment @r@ is passed unchanged to both the first computation and
-- the continuation, the state is threaded from one to the other, and the two
-- logs are combined with '<>' (first computation's log on the left). Both
-- steps are bound with the underlying monad's 'parBindUnion'.
instance
  (MonadParallelUnion m, Mergeable s, Mergeable r, Mergeable w, Monoid w, NFData r, NFData w, NFData s) =>
  MonadParallelUnion (RWSLazy.RWST r w s m)
  where
  parBindUnion m k = RWSLazy.RWST $ \r s ->
    RWSLazy.runRWST m r s `parBindUnion` \case
      -- Irrefutable patterns: the triples are not forced by the matches,
      -- which is what distinguishes this from the strict RWST instance.
      ~(a, s', w) ->
        RWSLazy.runRWST (k a) r s' `parBindUnion` \case
          ~(b, s'', w') -> return (b, s'', w <> w')
  {-# INLINE parBindUnion #-}