{-# LANGUAGE CPP #-}
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Trustworthy #-}
#endif
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
module Control.Monad.Trans.Either
( EitherT(..)
, eitherT
, bimapEitherT
, mapEitherT
, hoistEither
, bracketEitherT
, bracketEitherT_
, left
, right
, swapEitherT
) where
import Control.Applicative
import Control.Monad (liftM, MonadPlus(..))
import Control.Monad.Base (MonadBase(..), liftBaseDefault)
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.Free.Class
import Control.Monad.Catch as MonadCatch
import Control.Monad.Fix
import Control.Monad.IO.Class
import Control.Monad.Reader.Class
import Control.Monad.State (MonadState,get,put)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control (MonadBaseControl(..), MonadTransControl(..), defaultLiftBaseWith, defaultRestoreM)
import Control.Monad.Writer.Class
import Control.Monad.Random (MonadRandom,getRandom,getRandoms,getRandomR,getRandomRs)
import Control.Monad.Morph (MFunctor, hoist)
import Data.Either.Combinators ( swapEither )
import Data.Foldable
import Data.Function (on)
import Data.Functor.Bind
import Data.Functor.Plus
import Data.Traversable
import Data.Semigroup
-- | A monad transformer that layers an @'Either' e@ result on top of a base
-- computation @m@.
--
-- 'runEitherT' removes the wrapper and exposes the underlying
-- @m ('Either' e a)@ value.
newtype EitherT e m a = EitherT { runEitherT :: m (Either e a) }
-- | Shows the constructor name followed by the wrapped computation.
-- The output is parenthesised when the surrounding precedence is above 10
-- (function application), and the wrapped value is shown at precedence 11.
instance Show (m (Either e a)) => Show (EitherT e m a) where
  showsPrec prec (EitherT inner) =
    showParen (prec > 10) (showString "EitherT " . showsPrec 11 inner)
  {-# INLINE showsPrec #-}
-- | Parses the textual form produced by the 'Show' instance: the token
-- @EitherT@ followed by the wrapped computation read at precedence 11.
-- Surrounding parentheses are required when the context precedence is
-- above 10.
instance Read (m (Either e a)) => Read (EitherT e m a) where
  readsPrec prec = readParen (prec > 10) parseWrapped
    where
      -- Expect the constructor name first, then the payload.
      parseWrapped input =
        [ (EitherT inner, rest)
        | ("EitherT", afterName) <- lex input
        , (inner, rest) <- readsPrec 11 afterName
        ]
  {-# INLINE readsPrec #-}
-- | Two 'EitherT' values are equal exactly when their unwrapped
-- computations are equal.
instance Eq (m (Either e a)) => Eq (EitherT e m a) where
  EitherT lhs == EitherT rhs = lhs == rhs
  {-# INLINE (==) #-}
-- | Ordering is inherited from the unwrapped computations.
instance Ord (m (Either e a)) => Ord (EitherT e m a) where
  compare (EitherT lhs) (EitherT rhs) = compare lhs rhs
  {-# INLINE compare #-}
-- | Changes the base monad by applying the given transformation to the
-- wrapped computation; the 'Either' payload is left untouched.
instance MFunctor (EitherT e) where
  hoist nat (EitherT inner) = EitherT (nat inner)
  {-# INLINE hoist #-}
-- | Case analysis for 'EitherT', the transformer analogue of 'either'.
--
-- Runs the wrapped computation, then continues with the first handler if
-- it produced a 'Left' and with the second handler if it produced a 'Right'.
eitherT :: Monad m => (a -> m c) -> (b -> m c) -> EitherT a m b -> m c
eitherT onLeft onRight (EitherT inner) = inner >>= either onLeft onRight
{-# INLINE eitherT #-}
-- | Builds a failed computation: the given value is returned in the base
-- monad wrapped in 'Left'.
left :: Monad m => e -> EitherT e m a
left err = EitherT (return (Left err))
{-# INLINE left #-}
-- | Builds a successful computation: the given value is returned in the
-- base monad wrapped in 'Right'. This coincides with 'return' for
-- 'EitherT'.
right :: Monad m => a -> EitherT e m a
right value = EitherT (return (Right value))
{-# INLINE right #-}
-- | Maps over both sides of the result at once: the first function is
-- applied to a 'Left' payload and the second to a 'Right' payload.
-- Only a 'Functor' constraint on the base is needed.
bimapEitherT :: Functor m => (e -> f) -> (a -> b) -> EitherT e m a -> EitherT f m b
bimapEitherT onLeft onRight (EitherT inner) =
  EitherT (fmap (either (Left . onLeft) (Right . onRight)) inner)
{-# INLINE bimapEitherT #-}
-- | Applies a function to the unwrapped computation and wraps the outcome
-- again. The function may change the base monad, the error type and the
-- result type.
mapEitherT :: (m (Either e a) -> n (Either e' b)) -> EitherT e m a -> EitherT e' n b
mapEitherT transform (EitherT inner) = EitherT (transform inner)
{-# INLINE mapEitherT #-}
-- | Lifts a plain 'Either' into 'EitherT' by returning it unchanged in the
-- base monad.
hoistEither :: Monad m => Either e a -> EitherT e m a
hoistEither result = EitherT (return result)
{-# INLINE hoistEither #-}
-- | Acquire a resource, use it, and run a release action afterwards.
--
-- * @before@ acquires the resource; if it yields a 'Left' nothing else runs.
-- * @thing@ uses the resource. If it yields a 'Left', @after@ is run and the
--   original error is raised again with 'left'.
-- * On success @after@ is run once, its result is discarded, and the result
--   of @thing@ is returned.
--
-- Only 'Left' results are intercepted (through 'catchError'). A 'Left'
-- produced by @after@ itself propagates in either path.
bracketEitherT :: Monad m => EitherT e m a -> (a -> EitherT e m b) -> (a -> EitherT e m c) -> EitherT e m c
bracketEitherT before after thing =
  before >>= \resource ->
    let release = after resource
        -- After the handler runs we are in a Left state, so the release
        -- step further down is skipped and cleanup happens only once.
        guarded = thing resource `catchError` (\err -> release >> left err)
     in guarded >>= \result -> release >>= \_ -> return result
-- | Variant of 'bracketEitherT' for the case where neither the main action
-- nor the release action needs the value produced by @before@.
--
-- @before@ runs first and its result is ignored. If @thing@ yields a 'Left',
-- @after@ runs and the error is raised again; otherwise @after@ runs once
-- and the result of @thing@ is returned.
bracketEitherT_ :: Monad m => EitherT e m a -> EitherT e m b -> EitherT e m c -> EitherT e m c
bracketEitherT_ before after thing =
  bracketEitherT before (const after) (const thing)
-- | Exchanges the roles of the error and the result by applying
-- 'swapEither' inside the base functor.
swapEitherT :: (Functor m) => EitherT e m a -> EitherT a m e
swapEitherT (EitherT inner) = EitherT (fmap swapEither inner)
{-# INLINE swapEitherT #-}
-- | Maps over a 'Right' result and leaves a 'Left' untouched. The mapping is
-- performed with 'liftM', so only a 'Monad' constraint on the base is used.
instance Monad m => Functor (EitherT e m) where
  fmap f (EitherT inner) = EitherT (liftM (fmap f) inner)
  {-# INLINE fmap #-}
-- | Left-to-right application. The function side is run first; if it yields
-- a 'Left' the argument side is never run. Otherwise the argument side is
-- run and its 'Left', if any, is returned.
instance Monad m => Apply (EitherT e m) where
  EitherT mf <.> EitherT mx = EitherT $ do
    ef <- mf
    case ef of
      Left e -> return (Left e)
      Right k -> do
        ex <- mx
        case ex of
          Left e -> return (Left e)
          Right x -> return (Right (k x))
  {-# INLINE (<.>) #-}
-- | 'pure' produces a 'Right'. '<*>' runs the function side first and
-- short-circuits on its 'Left' without running the argument side.
instance Monad m => Applicative (EitherT e m) where
  pure = EitherT . return . Right
  {-# INLINE pure #-}
  EitherT mf <*> EitherT mx = EitherT (mf >>= either (return . Left) applyTo)
    where
      -- Only reached when the function side succeeded.
      applyTo k = mx >>= either (return . Left) (return . Right . k)
  {-# INLINE (<*>) #-}
-- | Choice that accumulates errors. The second computation is run only when
-- the first yields a 'Left'; if both yield a 'Left' the two errors are
-- combined with 'mappend' (first error on the left). 'empty' fails with
-- 'mempty'.
instance (Monad m, Monoid e) => Alternative (EitherT e m) where
  EitherT primary <|> EitherT fallback = EitherT $ do
    outcome <- primary
    case outcome of
      Right r -> return (Right r)
      Left l -> liftM (either (Left . mappend l) Right) fallback
  {-# INLINE (<|>) #-}
  empty = EitherT (return (Left mempty))
  {-# INLINE empty #-}
-- | Mirrors the 'Alternative' instance: 'mzero' is 'empty' and 'mplus' is
-- '<|>'.
instance (Monad m, Monoid e) => MonadPlus (EitherT e m) where
  mzero = empty
  {-# INLINE mzero #-}
  mplus lhs rhs = lhs <|> rhs
  {-# INLINE mplus #-}
-- | First success wins. The second computation is run only when the first
-- yields a 'Left', and in that case the first error is discarded.
instance Monad m => Semigroup (EitherT e m a) where
  EitherT primary <> EitherT fallback =
    EitherT (primary >>= either (const fallback) (return . Right))
  {-# INLINE (<>) #-}
-- | Choice that accumulates errors with the 'Semigroup' operation. The
-- second computation is run only when the first yields a 'Left'; if both
-- yield a 'Left' the errors are combined with '<>' (first error on the
-- left).
instance (Monad m, Semigroup e) => Alt (EitherT e m) where
  EitherT primary <!> EitherT fallback = EitherT (primary >>= recover)
    where
      recover (Right r) = return (Right r)
      recover (Left l) = liftM (merge l) fallback
      -- Combine the first error with the outcome of the fallback.
      merge l (Left l2) = Left (l <> l2)
      merge _ (Right r) = Right r
  {-# INLINE (<!>) #-}
-- | '>>-' is the monadic bind of 'EitherT'.
instance Monad m => Bind (EitherT e m) where
  m >>- k = m >>= k
  {-# INLINE (>>-) #-}
-- | Sequencing stops at the first 'Left': the continuation is run only when
-- the preceding computation produced a 'Right'.
--
-- 'fail' is delegated to the base monad; it is not turned into a 'Left'.
instance Monad m => Monad (EitherT e m) where
  return = EitherT . return . Right
  {-# INLINE return #-}
  EitherT inner >>= k =
    EitherT (inner >>= either (return . Left) (runEitherT . k))
  {-# INLINE (>>=) #-}
  fail msg = EitherT (fail msg)
  {-# INLINE fail #-}
-- | Errors are 'Left' values. 'throwError' produces a 'Left', and
-- 'catchError' runs the handler only when the guarded computation produced
-- a 'Left'; a 'Right' is passed through unchanged.
instance Monad m => MonadError e (EitherT e m) where
  throwError err = EitherT (return (Left err))
  {-# INLINE throwError #-}
  catchError (EitherT inner) handler =
    EitherT (inner >>= either (runEitherT . handler) (return . Right))
  {-# INLINE catchError #-}
-- | Exceptions are thrown in the base monad and lifted; they are not
-- converted into 'Left' values.
instance MonadThrow m => MonadThrow (EitherT e m) where
  throwM exc = lift (throwM exc)
  {-# INLINE throwM #-}
-- | Catches exceptions raised in the base monad. The handler is written in
-- 'EitherT' and is unwrapped before being handed to the base-level catch.
-- The qualified name is used for the base-level call.
instance MonadCatch m => MonadCatch (EitherT e m) where
  catch (EitherT inner) handler =
    EitherT (inner `MonadCatch.catch` (\exc -> runEitherT (handler exc)))
  {-# INLINE catch #-}
-- | Fixed point through the base monad. The value fed back to the function
-- is the 'Right' payload of the eventual result; if that result turns out
-- to be a 'Left', forcing the fed-back value raises an error.
instance MonadFix m => MonadFix (EitherT e m) where
  mfix f = EitherT (mfix (runEitherT . f . unwrap))
    where
      -- Applied lazily: the result is only inspected when the function
      -- actually forces its argument.
      unwrap (Right r) = r
      unwrap _ = error "empty mfix argument"
  {-# INLINE mfix #-}
-- | Lifts a base computation by tagging its result with 'Right'.
instance MonadTrans (EitherT e) where
  lift action = EitherT (liftM Right action)
  {-# INLINE lift #-}
-- | Runs an 'IO' action in the base monad and lifts the result.
instance MonadIO m => MonadIO (EitherT e m) where
  liftIO io = lift (liftIO io)
  {-# INLINE liftIO #-}
-- | Delegates to the base-level 'callCC'. The escape continuation handed
-- to the body wraps its argument in 'Right' before invoking the base
-- continuation.
instance MonadCont m => MonadCont (EitherT e m) where
  callCC body = EitherT (callCC (runEitherT . body . escapeWith))
    where
      -- Turn a base-level continuation into one usable inside EitherT.
      escapeWith k = EitherT . k . Right
  {-# INLINE callCC #-}
-- | The environment lives in the base monad: 'ask' is lifted, and 'local'
-- is applied to the unwrapped computation.
instance MonadReader r m => MonadReader r (EitherT e m) where
  ask = lift ask
  {-# INLINE ask #-}
  local adjust = mapEitherT (local adjust)
  {-# INLINE local #-}
-- | The state lives in the base monad: both 'get' and 'put' are lifted.
instance MonadState s m => MonadState s (EitherT e m) where
  get = lift get
  {-# INLINE get #-}
  put newState = lift (put newState)
  {-# INLINE put #-}
-- | The writer output lives in the base monad.
--
-- * 'tell' is lifted.
-- * 'listen' pairs a 'Right' result with the collected output; a 'Left' is
--   returned unchanged.
-- * 'pass' applies the supplied output transformation for a 'Right' result
--   and uses 'id' for a 'Left'.
--
-- In both 'listen' and 'pass' the value handed to 'return' is forced with
-- '$!' first.
instance MonadWriter s m => MonadWriter s (EitherT e m) where
  tell output = lift (tell output)
  {-# INLINE tell #-}
  listen (EitherT inner) = EitherT $ do
    (result, w) <- listen inner
    return $! either Left (\r -> Right (r, w)) result
  {-# INLINE listen #-}
  pass (EitherT inner) = EitherT $ pass $ do
      result <- inner
      return $! route result
    where
      -- Split the result into the value to return and the output modifier.
      route (Left l) = (Left l, id)
      route (Right (r, f)) = (Right r, f)
  {-# INLINE pass #-}
-- | Every random-generation operation is performed in the base monad and
-- lifted.
instance MonadRandom m => MonadRandom (EitherT e m) where
  getRandom = lift getRandom
  {-# INLINE getRandom #-}
  getRandoms = lift getRandoms
  {-# INLINE getRandoms #-}
  getRandomR bounds = lift (getRandomR bounds)
  {-# INLINE getRandomR #-}
  getRandomRs bounds = lift (getRandomRs bounds)
  {-# INLINE getRandomRs #-}
-- | Folds over the 'Right' results held in the base structure. Each 'Left'
-- contributes 'mempty'.
instance Foldable m => Foldable (EitherT e m) where
  foldMap f (EitherT inner) = foldMap (either (const mempty) f) inner
  {-# INLINE foldMap #-}
-- | Wraps a layer of @f@ by unwrapping every 'EitherT' inside the layer,
-- delegating to the base-level 'wrap', and wrapping the outcome again.
instance (Functor f, MonadFree f m) => MonadFree f (EitherT e m) where
  wrap layer = EitherT (wrap (fmap runEitherT layer))
-- | Traverses the 'Right' results held in the base structure. A 'Left' is
-- kept as it is, without running any effect for it.
instance (Monad f, Traversable f) => Traversable (EitherT e f) where
  traverse visit (EitherT inner) = fmap EitherT (traverse step inner)
    where
      step (Left e) = pure (Left e)
      step (Right a) = fmap Right (visit a)
  {-# INLINE traverse #-}
-- | Lifts a computation from the base monad @b@ by delegating to
-- 'liftBaseDefault'.
instance MonadBase b m => MonadBase b (EitherT e m) where
  liftBase action = liftBaseDefault action
  {-# INLINE liftBase #-}
#if MIN_VERSION_monad_control(1,0,0)
-- | For monad-control 1.0.0 and later the saved transformer state is a type
-- synonym instance: it is the plain @'Either' e a@ result.
--
-- 'liftWith' hands 'runEitherT' to the callback as the run function and
-- tags the callback result with 'Right'. 'restoreT' wraps a saved state
-- back into 'EitherT'.
instance MonadTransControl (EitherT e) where
  type StT (EitherT e) a = Either e a
  liftWith f = EitherT $ do
    result <- f runEitherT
    return (Right result)
  {-# INLINE liftWith #-}
  restoreT saved = EitherT saved
  {-# INLINE restoreT #-}

-- | The saved state is the base state of the transformer state. Both
-- methods delegate to the defaults supplied by monad-control.
instance MonadBaseControl b m => MonadBaseControl b (EitherT e m) where
  type StM (EitherT e m) a = StM m (StT (EitherT e) a)
  liftBaseWith = defaultLiftBaseWith
  {-# INLINE liftBaseWith #-}
  restoreM = defaultRestoreM
  {-# INLINE restoreM #-}
#else
-- | For monad-control versions before 1.0.0 the saved transformer state is
-- a newtype, @StEitherT@, around the @'Either' e a@ result.
--
-- 'liftWith' hands the callback a run function that unwraps the
-- transformer and tags the outcome with @StEitherT@; the callback result
-- is tagged with 'Right'. 'restoreT' strips @StEitherT@ and wraps the
-- value back into 'EitherT'.
instance MonadTransControl (EitherT e) where
  newtype StT (EitherT e) a = StEitherT { unStEitherT :: Either e a }
  liftWith f = EitherT $ do
    result <- f (\action -> liftM StEitherT (runEitherT action))
    return (Right result)
  {-# INLINE liftWith #-}
  restoreT saved = EitherT (liftM unStEitherT saved)
  {-# INLINE restoreT #-}

-- | The saved state is a newtype, @StMEitherT@, around the base state of
-- the transformer state. Both methods delegate to the defaults supplied by
-- monad-control, which take the newtype constructor and accessor.
instance MonadBaseControl b m => MonadBaseControl b (EitherT e m) where
  newtype StM (EitherT e m) a = StMEitherT { unStMEitherT :: StM m (StT (EitherT e) a) }
  liftBaseWith = defaultLiftBaseWith StMEitherT
  {-# INLINE liftBaseWith #-}
  restoreM = defaultRestoreM unStMEitherT
  {-# INLINE restoreM #-}
#endif