{-# LANGUAGE CPP, NoImplicitPrelude #-}
#if __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif
#if !(MIN_VERSION_base(4,18,0))
{-# LANGUAGE ScopedTypeVariables #-}
#endif
module Data.Traversable.Compat (
  module Base
, mapAccumM
, forAccumM
) where

import Data.Traversable as Base

#if !(MIN_VERSION_base(4,18,0))
import Prelude.Compat

import Control.Monad.Compat (liftM)

# if MIN_VERSION_base(4,8,0)
import Data.Coerce (Coercible, coerce)
# else
import Unsafe.Coerce (unsafeCoerce)
# endif
#endif

#if !(MIN_VERSION_base(4,18,0))
-- | A state transformer monad parameterized by the state and inner monad.
-- The implementation is copied from the transformers package with the
-- return tuple swapped: the state comes first, matching the @(s, b)@ shape
-- produced by the step functions given to 'mapAccumM' and 'forAccumM'.
--
-- /Since: 4.18.0.0/
newtype StateT s m a = StateT
  { runStateT :: s -> m (s, a)
    -- ^ Run the computation from an initial state, yielding the final
    -- state paired with the result.
  }

-- | /Since: 4.18.0.0/
--
-- 'fmap' is obtained from the 'Monad' instance via 'liftM', which is why
-- this instance requires @Monad m@ rather than just @Functor m@.
instance Monad m => Functor (StateT s m) where
    fmap = liftM
    {-# INLINE fmap #-}

-- | /Since: 4.18.0.0/
--
-- Effects run left to right: each action is run in the state left behind
-- by the previous one.
instance Monad m => Applicative (StateT s m) where
    -- Leave the state unchanged and return the value.
    pure a = StateT $ \ s -> return (s, a)
    {-# INLINE pure #-}
    -- Run the function action first, then the argument action in the
    -- updated state, and apply the function to the argument.
    StateT mf <*> StateT mx = StateT $ \ s -> do
        (s', f) <- mf s
        (s'', x) <- mx s'
        return (s'', f x)
    {-# INLINE (<*>) #-}
    -- Sequence via the Monad instance, discarding the first result.
    m *> k = m >>= \_ -> k
    {-# INLINE (*>) #-}

-- | Coercion-based composition: @f #. g@ behaves like @f . g@ provided @f@
-- is representationally the identity (e.g. a newtype constructor such as
-- 'StateT'). The value of the first argument is ignored; only its type is
-- used, so no wrapper lambda is built around @g@.
# if MIN_VERSION_base(4,8,0)
(#.) :: Coercible b c => (b -> c) -> (a -> b) -> (a -> c)
(#.) _f = coerce
# else
-- Without 'Coercible' the compiler cannot check the representation
-- equality; callers must only pass a newtype constructor as the first
-- argument, otherwise this 'unsafeCoerce' is unsound.
(#.) ::                  (b -> c) -> (a -> b) -> (a -> c)
(#.) _f = unsafeCoerce
# endif

-- | /Since: 4.18.0.0/
--
-- Binding runs the first computation, passes its result to the
-- continuation, and runs the resulting computation in the state the first
-- one left behind.
instance (Monad m) => Monad (StateT s m) where
    m >>= k  = StateT $ \ s -> do
        (s', a) <- runStateT m s
        runStateT (k a) s'
    {-# INLINE (>>=) #-}
# if !(MIN_VERSION_base(4,11,0))
    -- On base < 4.11, 'return' is defined explicitly in terms of 'pure'.
    return = pure
# endif

-- | The `mapAccumM` function behaves like a combination of `mapM` and
-- `mapAccumL` that traverses the structure while evaluating the actions
-- and passing an accumulating parameter from left to right.
-- It returns a final value of this accumulator together with the new structure.
-- The accumulator is often used for caching the intermediate results of a computation.
--
-- @since 4.18.0.0
--
-- ==== __Examples__
--
-- Basic usage:
--
-- >>> let expensiveDouble a = putStrLn ("Doubling " <> show a) >> pure (2 * a)
-- >>> :{
-- mapAccumM (\cache a -> case lookup a cache of
--     Nothing -> expensiveDouble a >>= \double -> pure ((a, double):cache, double)
--     Just double -> pure (cache, double)
--     ) [] [1, 2, 3, 1, 2, 3]
-- :}
-- Doubling 1
-- Doubling 2
-- Doubling 3
-- ([(3,6),(2,4),(1,2)],[2,4,6,2,4,6])
--
mapAccumM
  :: forall m t s a b. (Monad m, Traversable t)
  => (s -> a -> m (s, b))
  -> s -> t a -> m (s, t b)
-- Each element becomes a 'StateT' step (@flip f@ takes the element first and
-- the accumulator second); 'mapM' threads the accumulator through the steps
-- left to right, and 'runStateT' starts it from the initial value @s@.
mapAccumM f s t = runStateT (mapM (StateT #. flip f) t) s

-- | 'forAccumM' is 'mapAccumM' with the arguments rearranged: the initial
-- accumulator and the structure come first, and the step function comes
-- last.
--
-- @since 4.18.0.0
forAccumM
  :: (Monad m, Traversable t)
  => s -> t a -> (s -> a -> m (s, b)) -> m (s, t b)
{-# INLINE forAccumM #-}
forAccumM s t f = mapAccumM f s t
#endif