{-# LANGUAGE CPP #-}
{-# LANGUAGE InstanceSigs #-}
-- |
-- Module: Language.KURE.MonadCatch
-- Copyright: (c) 2012--2021 The University of Kansas
-- License: BSD3
--
-- Maintainer: Neil Sculthorpe <neil.sculthorpe@ntu.ac.uk>
-- Stability: beta
-- Portability: ghc
--
-- This module provides classes for catch-like operations on 'Monad's.

module Language.KURE.MonadCatch
           ( -- * Monads with a Catch
             MonadCatch(..)
             -- ** The KURE Monad
           , KureM
           , runKureM
           , fromKureM
           , liftKureM
             -- ** The IO Monad
           , liftAndCatchIO
             -- ** Combinators
           , (<+)
           , catchesM
           , tryM
           , mtryM
           , attemptM
           , testM
           , notM
           , modFailMsg
           , setFailMsg
           , prefixFailMsg
           , withPatFailMsg
) where

import Prelude hiding (foldr)

import Control.Exception (catch, SomeException)

#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail)
import qualified Control.Monad.Fail
#endif

import Control.Monad (liftM, ap, join)
import Control.Monad.IO.Class

import Data.Foldable
import Data.List (isPrefixOf)

import Language.KURE.Combinators.Monad

-- '<+' is left-associative with precedence 3, so @a <+ b <+ c@ parses as @(a <+ b) <+ c@.
infixl 3 <+

------------------------------------------------------------------------------------------

-- | 'Monad's that can catch a failure raised by 'fail'.
--   The handler is given the failure message.
--   The following laws are expected to hold:
--
-- > fail msg `catchM` f == f msg
-- > return a `catchM` f == return a

class MonadFail m => MonadCatch m where
  -- | Catch a failing monadic computation, passing its error message to the handler.
  catchM :: m a -> (String -> m a) -> m a

------------------------------------------------------------------------------------------

-- | 'KureM' is the minimal structure that can be an instance of 'MonadCatch'.
--   The KURE user is free to either use 'KureM' or provide their own monad.
--   'KureM' is essentially the same as 'Either' 'String', except that it supports a 'MonadCatch' instance which 'Either' 'String' does not (because its 'fail' method calls 'error').
--   A major advantage of this is that monadic pattern match failures are caught safely.
--
--   A value is either a 'Failure' carrying an error message, or a 'Success' carrying a result.
data KureM a = Failure String | Success a deriving (Eq, Show)

-- | Eliminator for 'KureM'.
--   The first argument is applied to the result of a 'Success';
--   the second argument is applied to the message of a 'Failure'.
runKureM :: (a -> b) -> (String -> b) -> KureM a -> b
runKureM _ f (Failure msg) = f msg
runKureM s _ (Success a)   = s a
{-# INLINE runKureM #-}

-- | Get the value from a 'KureM', providing a function to handle the error case.
--   A 'Success' yields its value unchanged; a 'Failure' yields the handler applied to its message.
fromKureM :: (String -> a) -> KureM a -> a
fromKureM = runKureM id
{-# INLINE fromKureM #-}

-- | Lift a 'KureM' computation to any other monad.
--   A 'Success' becomes a 'return' and a 'Failure' becomes a 'fail' with the same message.
liftKureM :: MonadFail m => KureM a -> m a
liftKureM = runKureM return fail
{-# INLINE liftKureM #-}

-- | Sequencing stops at the first 'Failure', whose message is propagated unchanged.
instance Monad KureM where
   return :: a -> KureM a
   return = Success
   {-# INLINE return #-}

   (>>=) :: KureM a -> (a -> KureM b) -> KureM b
   (Success a)   >>= f = f a
   (Failure msg) >>= _ = Failure msg
   {-# INLINE (>>=) #-}

-- Only compiled against base < 4.13, where 'fail' is defined in this instance as well.
#if !MIN_VERSION_base(4,13,0)
   fail :: String -> KureM a
   fail = Failure
   {-# INLINE fail #-}
#endif

-- | 'fail' records its message in a 'Failure' rather than throwing.
instance MonadFail KureM where
   fail :: String -> KureM a
   fail = Failure
   {-# INLINE fail #-}

-- | A 'Success' is returned untouched; a 'Failure' hands its message to the handler.
instance MonadCatch KureM where
   catchM :: KureM a -> (String -> KureM a) -> KureM a
   (Success a)   `catchM` _ = Success a
   (Failure msg) `catchM` f = f msg
   {-# INLINE catchM #-}

-- | Mapping is defined via the 'Monad' instance ('liftM').
instance Functor KureM where
   fmap :: (a -> b) -> KureM a -> KureM b
   fmap = liftM
   {-# INLINE fmap #-}

-- | 'pure' wraps a value in 'Success'; application is defined via the 'Monad' instance ('ap').
instance Applicative KureM where
   pure :: a -> KureM a
   -- Defined directly on the constructor so that 'pure' does not depend on 'return'.
   pure = Success
   {-# INLINE pure #-}

   (<*>) :: KureM (a -> b) -> KureM a -> KureM b
   (<*>) = ap
   {-# INLINE (<*>) #-}

-------------------------------------------------------------------------------

-- | A monadic catch that ignores the error message.
--   If the first computation fails, the second is run in its place.
(<+) :: MonadCatch m => m a -> m a -> m a
ma <+ mb = ma `catchM` (\_ -> mb)
{-# INLINE (<+) #-}

-- | Select the first monadic computation that succeeds, discarding any thereafter.
--   If the structure is empty, or every computation fails, the result fails with the message @\"catchesM failed\"@.
catchesM :: (Foldable f, MonadCatch m) => f (m a) -> m a
catchesM = foldr (<+) (fail "catchesM failed")
{-# INLINE catchesM #-}

-- | Catch a failing monadic computation, making it succeed with a constant value.
--   A successful computation keeps its own result.
tryM :: MonadCatch m => a -> m a -> m a
tryM a ma = ma <+ return a
{-# INLINE tryM #-}

-- | Catch a failing monadic computation, making it succeed with 'mempty'.
--   A successful computation keeps its own result.
mtryM :: (MonadCatch m, Monoid a) => m a -> m a
mtryM = tryM mempty
{-# INLINE mtryM #-}

-- | Catch a failing monadic computation, making it succeed with an error message.
--   A success is wrapped in 'Right'; a failure is turned into a 'Left' holding the error message.
attemptM :: MonadCatch m => m a -> m (Either String a)
attemptM ma = liftM Right ma `catchM` (return . Left)
{-# INLINE attemptM #-}

-- | Determine if a monadic computation succeeds.
--   The result of the computation is discarded: 'True' on success, 'False' on failure.
testM :: MonadCatch m => m a -> m Bool
testM ma = liftM (const True) ma <+ return False
{-# INLINE testM #-}

-- | Fail if the monadic computation succeeds; succeed with @()@ if it fails.
--   The failure message used when the argument succeeds is @\"notM of success\"@.
notM :: MonadCatch m => m a -> m ()
notM ma = ifM (testM ma) (fail "notM of success") (return ())
{-# INLINE notM #-}

-- | Modify the error message of a failing monadic computation.
--   Successful computations are unaffected.
--   The failure is caught and re-raised with 'fail' using the transformed message.
modFailMsg :: MonadCatch m => (String -> String) -> m a -> m a
modFailMsg f ma = ma `catchM` (fail . f)
{-# INLINE modFailMsg #-}

-- | Set the error message of a failing monadic computation.
--   Successful computations are unaffected.
--   The original error message is discarded.
setFailMsg :: MonadCatch m => String -> m a -> m a
setFailMsg msg = modFailMsg (const msg)
{-# INLINE setFailMsg #-}

-- | Add a prefix to the error message of a failing monadic computation.
--   Successful computations are unaffected.
--   The prefix is prepended as-is; no separator is inserted.
prefixFailMsg :: MonadCatch m => String -> m a -> m a
prefixFailMsg msg = modFailMsg (msg ++)
{-# INLINE prefixFailMsg #-}

-- | Use the given error message whenever a monadic pattern match failure occurs.
--   A failure is treated as a pattern match failure when its message starts with
--   @\"Pattern match failure\"@; any other failure message is left unchanged.
withPatFailMsg :: MonadCatch m => String -> m a -> m a
withPatFailMsg msg = modFailMsg (\ e -> if "Pattern match failure" `isPrefixOf` e then msg else e)
{-# INLINE withPatFailMsg #-}

------------------------------------------------------------------------------------------

-- | The String is generated by 'show'ing the exception.
--   The handler is typed at 'SomeException', so every exception raised by the action is caught.
instance MonadCatch IO where
  catchM :: IO a -> (String -> IO a) -> IO a
  catchM io f = io `catch` (\ e -> f $ show (e :: SomeException))
  {-# INLINE catchM #-}

-- | Lift a computation from the 'IO' monad, catching failures in the target monad.
--   The 'IO' action is run with its exceptions caught by the 'IO' instance of 'catchM'.
--   Its outcome is returned as a computation in the target monad ('return' of the result,
--   or 'fail' with the exception's message), which is then run after lifting with 'liftIO'.
liftAndCatchIO :: (MonadCatch m, MonadIO m) => IO a -> m a
liftAndCatchIO io = join $ liftIO (liftM return io `catchM` (return . fail))
{-# INLINE liftAndCatchIO #-}

------------------------------------------------------------------------------------------