{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}

-- |
-- Module      :   Grisette.Lib.Control.Monad
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Lib.Control.Monad
  ( -- * mrg* variants for operations in "Control.Monad"
    mrgReturnWithStrategy,
    mrgBindWithStrategy,
    mrgReturn,
    (.>>=),
    (.>>),
    mrgFoldM,
    mrgMzero,
    mrgMplus,
    mrgFmap,
  )
where

import Control.Monad (MonadPlus (mplus, mzero))
import Grisette.Core.Control.Monad.Union (MonadUnion)
import Grisette.Core.Data.Class.Mergeable
  ( Mergeable,
    MergingStrategy,
  )
import Grisette.Core.Data.Class.SimpleMergeable
  ( UnionLike (mergeWithStrategy),
    merge,
  )
import Grisette.Lib.Data.Foldable (mrgFoldlM)

-- | 'return' with 'MergingStrategy' knowledge propagation.
--
-- Wraps the value with the ordinary 'return' and then merges the resulting
-- container with the explicitly supplied strategy. Because the strategy is
-- passed in, no 'Mergeable' instance for @a@ is required.
mrgReturnWithStrategy :: (MonadUnion u) => MergingStrategy a -> a -> u a
mrgReturnWithStrategy s = mergeWithStrategy s . return
{-# INLINE mrgReturnWithStrategy #-}

-- | '>>=' with 'MergingStrategy' knowledge propagation.
--
-- Runs the ordinary '>>=' of @a@ into @f@, then merges the resulting
-- container with the explicitly supplied strategy for the result type @b@.
mrgBindWithStrategy ::
  (MonadUnion u) => MergingStrategy b -> u a -> (a -> u b) -> u b
mrgBindWithStrategy s a f = mergeWithStrategy s $ a >>= f
{-# INLINE mrgBindWithStrategy #-}

-- | 'return' with 'MergingStrategy' knowledge propagation.
--
-- Like 'mrgReturnWithStrategy', but instead of taking an explicit strategy
-- it uses 'merge', which relies on the 'Mergeable' instance of @a@.
mrgReturn :: (MonadUnion u, Mergeable a) => a -> u a
mrgReturn = merge . return
{-# INLINE mrgReturn #-}

-- NOTE(review): no fixity declaration for this operator is visible in this
-- module, so it gets Haskell's default @infixl 9@ rather than the @infixl 1@
-- of '>>='; confirm this is intended before mixing it with other operators.

-- | '>>=' with 'MergingStrategy' knowledge propagation.
--
-- Performs the ordinary '>>=' and then 'merge's the result, using the
-- 'Mergeable' instance of the result type @b@.
(.>>=) :: (MonadUnion u, Mergeable b) => u a -> (a -> u b) -> u b
a .>>= f = merge $ a >>= f
{-# INLINE (.>>=) #-}

-- | 'foldM' with 'MergingStrategy' knowledge propagation.
--
-- A direct alias for 'mrgFoldlM' from "Grisette.Lib.Data.Foldable".
-- The fold runs left to right, threading the accumulator of type @b@.
mrgFoldM ::
  (MonadUnion m, Mergeable b, Foldable t) =>
  (b -> a -> m b) ->
  b ->
  t a ->
  m b
mrgFoldM = mrgFoldlM
{-# INLINE mrgFoldM #-}

-- NOTE(review): no fixity declaration for this operator is visible in this
-- module, so it gets Haskell's default @infixl 9@ rather than the @infixl 1@
-- of '>>'; confirm this is intended before mixing it with other operators.

-- | '>>' with 'MergingStrategy' knowledge propagation.
--
-- The first action's result is discarded by mapping it to @()@ with
-- 'mrgFmap' (which also merges it), then the ordinary '>>' sequences the
-- second action, and the final result is merged using the 'Mergeable'
-- instance of @b@.
--
-- This is usually more efficient than calling the original '>>' and merging
-- the results.
(.>>) :: forall m a b. (MonadUnion m, Mergeable b) => m a -> m b -> m b
a .>> f = merge $ mrgFmap (const ()) a >> f
{-# INLINE (.>>) #-}

-- | 'mzero' with 'MergingStrategy' knowledge propagation.
--
-- The ordinary 'mzero', wrapped in 'merge' so that it carries the
-- 'Mergeable' strategy for @a@.
mrgMzero :: forall m a. (MonadUnion m, Mergeable a, MonadPlus m) => m a
mrgMzero = merge mzero
{-# INLINE mrgMzero #-}

-- | 'mplus' with 'MergingStrategy' knowledge propagation.
--
-- Combines the two alternatives with the ordinary 'mplus' and then merges
-- the result using the 'Mergeable' instance of @a@.
mrgMplus ::
  forall m a. (MonadUnion m, Mergeable a, MonadPlus m) => m a -> m a -> m a
mrgMplus a b = merge $ mplus a b
{-# INLINE mrgMplus #-}

-- | 'fmap' with 'MergingStrategy' knowledge propagation.
--
-- Applies @f@ with the ordinary 'fmap' and then merges the result using
-- the 'Mergeable' instance of the result type @b@.
mrgFmap :: (MonadUnion f, Mergeable b, Functor f) => (a -> b) -> f a -> f b
mrgFmap f a = merge $ fmap f a
{-# INLINE mrgFmap #-}