{-# LANGUAGE AllowAmbiguousTypes    #-}
{-# LANGUAGE BlockArguments         #-}
{-# LANGUAGE ConstraintKinds        #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE EmptyCase              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE InstanceSigs           #-}
{-# LANGUAGE LambdaCase             #-}
{-# LANGUAGE PolyKinds              #-}
{-# LANGUAGE RankNTypes             #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE StandaloneDeriving     #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UndecidableInstances   #-}

-- | Helpers for throwing and catching errors that are carried as an open
-- 'DV.Variant' (or 'DV.VariantF') in the error position of 'ExceptT' or of
-- any 'MonadError' instance.
--
-- The @catch*@ and @recover*@ functions are specialised to 'ExceptT' because
-- they change the error type; the @throw*@ functions work in any
-- 'MonadError'.
module Control.Monad.Oops
  ( -- * MTL/transformer utilities
    catchFM,
    catchM,

    throwFM,
    throwM,

    snatchFM,
    snatchM,

    runOops,
    suspendM,

    catchAndExitFailureM,

    throwLeftM,
    throwNothingM,
    throwNothingAsM,

    recoverM,
    recoverOrVoidM,

    DV.CouldBeF (..),
    DV.CouldBe  (..),
    DV.CouldBeAnyOfF,
    DV.CouldBeAnyOf,
    DV.Variant,
    DV.VariantF(..),

  ) where

import Control.Monad.Error.Class (MonadError (..))
import Control.Monad.Except (ExceptT(ExceptT))
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Trans.Except (mapExceptT, runExceptT)
import Data.Function ((&))
import Data.Functor.Identity (Identity (..))
import Data.Variant (Catch, CatchF(..), CouldBe, CouldBeF(..), Variant, VariantF, preposterous)
import Data.Void (Void, absurd)

import qualified Data.Variant as DV
import qualified System.Exit  as IO

-- | When working in some monadic context, using 'catch' becomes trickier. The
-- intuitive behaviour is that each 'catch' shrinks the variant in the left
-- side of our 'MonadError', but this is therefore type-changing: as we can only
-- 'throwError' and 'catchError' with a 'MonadError' type, this is impossible!
--
-- To get round this problem, we have to specialise to 'ExceptT', which allows
-- us to map over the error type and change it as we go. If the error we catch
-- is the one in the variant that we want to handle, we pluck it out and deal
-- with it. Otherwise, we "re-throw" the variant minus the one we've handled.
--
-- The first argument is the handler for the caught @f x@; the second is the
-- computation to guard. Successful results pass through unchanged.
catchFM :: forall x e e' f m a. ()
  => Monad m
  => CatchF x e e'
  => (f x -> ExceptT (VariantF f e') m a)
  -> ExceptT (VariantF f e ) m a
  -> ExceptT (VariantF f e') m a
catchFM recover xs = mapExceptT (>>= go) xs
  where
    go :: Either (VariantF f e) a -> m (Either (VariantF f e') a)
    go = \case
      Right success -> pure (Right success)
      -- 'catchF' splits the variant: 'Right' is the error we were asked to
      -- handle, 'Left' is the variant with that alternative removed.
      Left  failure -> case catchF @x failure of
        Right hit  -> runExceptT (recover hit)
        Left  miss -> pure (Left miss)

-- | Just the same as 'catchFM', but specialised for our plain 'Variant' and
-- sounding much less like a radio station.
--
-- The handler receives the bare @x@: the 'Identity' wrapper used by 'Variant'
-- is removed before the handler is called.
catchM :: forall x e e' m a. ()
  => Monad m
  => Catch x e e'
  => (x -> ExceptT (Variant e') m a)
  -> ExceptT (Variant e ) m a
  -> ExceptT (Variant e') m a
catchM recover xs
  = catchFM (recover . runIdentity) xs

-- | Same as 'catchFM' except the error is not removed from the type.
-- This is useful for writing recursive computations or computations that
-- rethrow the same error type.
--
-- The first argument is the handler for the snatched @f x@; the second is the
-- computation to guard. Successful results and non-matching errors pass
-- through unchanged.
snatchFM
  :: forall x e f m a. ()
  => Monad m
  => e `CouldBe` x
  => (f x -> ExceptT (VariantF f e) m a)
  -> ExceptT (VariantF f e) m a
  -> ExceptT (VariantF f e) m a
snatchFM recover xs = mapExceptT (>>= go) xs
  where
    go :: Either (VariantF f e) a -> m (Either (VariantF f e) a)
    go = \case
      Right success -> pure (Right success)
      -- 'snatchF' yields 'Right' when the variant holds an @f x@, and hands
      -- back the untouched variant in 'Left' otherwise.
      Left  failure -> case snatchF @_ @_ @x failure of
        Right hit  -> runExceptT (recover hit)
        Left  miss -> pure (Left miss)


-- | Same as 'catchM' except the error is not removed from the type.
-- This is useful for writing recursive computations or computations that
-- rethrow the same error type.
--
-- The handler receives the bare @x@: the 'Identity' wrapper used by 'Variant'
-- is removed before the handler is called.
snatchM :: forall x e m a. ()
  => Monad m
  => e `CouldBe` x
  => (x -> ExceptT (Variant e) m a)
  -> ExceptT (Variant e) m a
  -> ExceptT (Variant e) m a
snatchM recover xs = snatchFM (recover . runIdentity) xs

-- | Throw an error into a variant 'MonadError' context. Note that this /isn't/
-- type-changing, so this can work for any 'MonadError', rather than just
-- 'ExceptT'.
--
-- The @f x@ is injected into the variant with 'throwF' and then raised with
-- 'throwError'.
throwFM :: forall x e f m a. ()
  => MonadError (VariantF f e) m
  => e `CouldBe` x
  => f x
  -> m a
throwFM = throwError . throwF

-- | Same as 'throwFM', but without the @f@ context. Given a value of some type
-- within a 'Variant' within a 'MonadError' context, "throw" the error.
--
-- The value is wrapped in 'Identity' and passed to 'throwFM'.
throwM :: forall x e m a. ()
  => MonadError (Variant e) m
  => e `CouldBe` x
  => x
  -> m a
throwM = throwFM . Identity

-- | Run an @ExceptT (Variant '[])@ layer, removing it from the monad
-- transformer stack and returning the result in the underlying monad.
--
-- The error variant has no alternatives left, so the 'Left' case cannot carry
-- a value: 'preposterous' turns it into 'Void', which 'absurd' eliminates.
runOops :: ()
  => Monad m
  => ExceptT (Variant '[]) m a
  -> m a
runOops f = either (absurd . preposterous) pure =<< runExceptT f

-- | Suspend the 'ExceptT' monad transformer from the top of the stack so that the
-- stack can be manipulated without the 'ExceptT' layer.
--
-- The computation is unwrapped with 'runExceptT', the supplied function is
-- applied to the underlying @m (Either x a)@, and the result is wrapped back
-- up with the 'ExceptT' constructor. The error type @x@ is left as it is; the
-- base monad and the result type may both change.
suspendM :: forall x m a n b. ()
  => (m (Either x a) -> n (Either x b))
  -> ExceptT x m a
  -> ExceptT x n b
suspendM f = ExceptT . f . runExceptT

-- | Catch the error type at the head of the variant. If that error occurs,
-- exit the program with 'IO.exitFailure'.
--
-- The caught value itself is discarded. Any other error in the variant is
-- propagated, and a successful result is returned unchanged.
catchAndExitFailureM :: forall x e m a. ()
  => MonadIO m
  => ExceptT (Variant (x : e)) m a
  -> ExceptT (Variant e) m a
catchAndExitFailureM = catchM @x (const (liftIO IO.exitFailure))

-- | When the expression of type @Either x a@ evaluates to @Left x@, throw the
-- @x@ with 'throwM', otherwise return the @a@.
throwLeftM :: forall x e m a. ()
  => MonadError (Variant e) m
  => CouldBeF e x
  => Monad m
  => Either x a
  -> m a
throwLeftM = either throwM pure

-- | When the expression of type @Maybe a@ evaluates to 'Nothing', throw @()@,
-- otherwise return the @a@.
--
-- This is 'throwNothingAsM' with @()@ as the error value.
throwNothingM :: ()
  => MonadError (Variant e) m
  => CouldBeF e ()
  => Monad m
  => Maybe a
  -> m a
throwNothingM = throwNothingAsM ()

-- | When the expression of type @Maybe a@ evaluates to 'Nothing', throw the
-- specified value with 'throwM', otherwise return the @a@.
throwNothingAsM :: forall e es m a. ()
  => MonadError (Variant es) m
  => CouldBe es e
  => e
  -> Maybe a
  -> m a
throwNothingAsM e = maybe (throwM e) pure

-- | Catch the error at the head of the variant and turn it into a successful
-- result.
--
-- The supplied function converts the caught @x@ into a value of the
-- computation's result type @a@, which is then returned with 'pure'. Any
-- other error in the variant is propagated.
recoverM :: forall x e m a. ()
  => Monad m
  => (x -> a)
  -> ExceptT (Variant (x : e)) m a
  -> ExceptT (Variant e) m a
recoverM g f = f & catchM (pure . g)

-- | Catch the error at the head of the variant and return it as the result.
-- The evaluated computation must return 'Void' (i.e. it never returns
-- normally).
--
-- Internally the result is tagged 'Right' and the caught error 'Left'. The
-- 'Left' value is returned and the 'Right' one, being 'Void', is eliminated
-- with 'absurd'. Any other error in the variant is propagated.
recoverOrVoidM :: forall x e m. ()
  => Monad m
  => ExceptT (Variant (x : e)) m Void
  -> ExceptT (Variant e) m x
recoverOrVoidM f = either pure absurd =<< (fmap Right f & catchM @x (pure . Left))