module Control.Monad.CatchIO ( MonadCatchIO(..),
                               E.Exception(..),
                               throw,
                               try, tryJust,
                               onException, bracket, bracket_,
                               finally, bracketOnError,
                               Handler(..), catches )

where

import Prelude hiding ( catch )

import qualified Control.Exception.Extensible as E

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Error
import Control.Monad.Writer
import Control.Monad.RWS

-- | Monads that support the exception-handling primitives of 'IO':
-- catching exceptions by type, and blocking/unblocking asynchronous
-- exceptions. Instances below cover 'IO' itself and the mtl transformers
-- layered over another 'MonadCatchIO' monad.
class MonadIO m => MonadCatchIO m where

    -- | Generalized version of 'E.catch': run the action and, if it throws
    -- an exception of type @e@, run the handler on that exception.
    catch :: E.Exception e
          => m a         -- action to protect
          -> (e -> m a)  -- handler for exceptions of type @e@
          -> m a

    -- | Generalized version of 'E.block'
    block :: m a -> m a

    -- | Generalized version of 'E.unblock'
    unblock :: m a -> m a

-- | Generalized version of 'E.throwIO': raises the exception in any
-- 'MonadIO' monad by lifting 'E.throwIO' with 'liftIO', so no
-- 'MonadCatchIO' instance is required to throw.
throw :: (MonadIO m, E.Exception e) => e -> m a

-- | Generalized version of 'E.try': runs the action and returns
-- @'Right' v@ with its result, or @'Left' e@ if it throws an exception of
-- type @e@. The exception type is chosen by the caller's expected result.
try :: (MonadCatchIO m, E.Exception e) => m a -> m (Either e a)

-- | Generalized version of 'E.tryJust'. Every exception of type @e@ is
-- passed to the predicate: @'Just' b@ is returned as @'Left' b@, while
-- 'Nothing' rethrows the original exception.
tryJust :: (MonadCatchIO m, E.Exception e)
        => (e -> Maybe b) -> m a -> m (Either b a)

-- | Generalized version of 'E.Handler': a handler for one exception type
-- @e@. The type @e@ is hidden existentially, so handlers for different
-- exception types can share one list (as consumed by 'catches').
data Handler m a = forall e . E.Exception e => Handler (e -> m a)

-- | Generalized version of 'E.catches': run the action and, if it throws,
-- give the exception to the first handler in the list whose exception type
-- matches. If no handler matches, the exception is rethrown unchanged.
catches :: MonadCatchIO m => m a -> [Handler m a] -> m a
catches action handlers = action `catch` dispatch
  where
    -- The exception is caught as 'E.SomeException' (fixed by
    -- 'E.fromException') and then narrowed per handler.
    dispatch someEx = firstMatch handlers
      where
        firstMatch [] = throw someEx
        firstMatch (Handler h : rest) =
          maybe (firstMatch rest) h (E.fromException someEx)


instance MonadCatchIO IO where
    -- IO is the base case: every method delegates directly to the
    -- exception library's IO primitive of the same name.
    catch action handler = E.catch action handler
    block action         = E.block action
    unblock action       = E.unblock action

instance MonadCatchIO m => MonadCatchIO (ReaderT r m) where
    -- The handler runs with the same environment as the protected action.
    catch action handler = ReaderT $ \env ->
        catch (runReaderT action env) (\ex -> runReaderT (handler ex) env)
    block act   = mapReaderT block act
    unblock act = mapReaderT unblock act

instance MonadCatchIO m => MonadCatchIO (StateT s m) where
    -- Both the action and the handler start from the same incoming state,
    -- so state changes made by the action before it threw are discarded.
    catch action handler = StateT $ \st ->
        catch (runStateT action st) (\ex -> runStateT (handler ex) st)
    block act   = mapStateT block act
    unblock act = mapStateT unblock act

instance (MonadCatchIO m, Error e) => MonadCatchIO (ErrorT e m) where
    -- Only exceptions thrown in the underlying monad are caught; an
    -- 'ErrorT' failure is an ordinary 'Left' result and passes through.
    catch action handler =
        ErrorT $ catch (runErrorT action) (runErrorT . handler)
    block act   = mapErrorT block act
    unblock act = mapErrorT unblock act

instance (Monoid w, MonadCatchIO m) => MonadCatchIO (WriterT w m) where
    -- If the action throws, any output it wrote is lost; only the
    -- handler's output is kept.
    catch action handler =
        WriterT (catch (runWriterT action) (runWriterT . handler))
    block act   = mapWriterT block act
    unblock act = mapWriterT unblock act

instance (Monoid w, MonadCatchIO m) => MonadCatchIO (RWST r w s m) where
    -- The handler gets the original environment and incoming state; the
    -- failed action's state changes and output are discarded.
    catch action handler = RWST $ \env st ->
        catch (runRWST action env st) (\ex -> runRWST (handler ex) env st)
    block act   = mapRWST block act
    unblock act = mapRWST unblock act

-- Lift the IO primitive; only 'MonadIO' is needed to raise.
throw ex = liftIO (E.throwIO ex)

-- Tag a normal result with 'Right' and a caught exception with 'Left'.
try action = (action >>= return . Right) `catch` (return . Left)

-- Catch exceptions of the predicate's argument type and let the predicate
-- decide: @'Just' b@ becomes @'Left' b@, 'Nothing' rethrows the original.
-- The result type is already fixed by the signature, so no type-pinning
-- helper (previously an 'asTypeOf' against 'undefined') is needed.
tryJust p a = do
  r <- try a
  case r of
    Right v -> return (Right v)
    Left  e -> maybe (throw e) (return . Left) (p e)

-- | Generalized version of 'E.bracket': acquire a resource, use it, and
-- release it whether or not the use threw. Acquisition and release run
-- under 'block'; only the use itself runs under 'unblock'.
bracket :: MonadCatchIO m => m a -> (a -> m b) -> (a -> m c) -> m c
bracket acquire release use = block $ do
    resource <- acquire
    -- On an exception 'onException' releases and rethrows, so the normal
    -- release below runs only when 'use' succeeded.
    result   <- unblock (use resource) `onException` release resource
    _void (release resource)
    return result

-- | A variant of 'bracket' where the return value from the first computation
-- is not required.
bracket_ :: MonadCatchIO m
         => m a  -- ^ computation to run first (\"acquire resource\")
         -> m b  -- ^ computation to run last (\"release resource\")
         -> m c  -- ^ computation to run in-between
         -> m c  -- ^ returns the value from the in-between computation
bracket_ before after thing =
   -- 'after' runs exactly once: via 'onException' (which then rethrows)
   -- if 'thing' throws, otherwise on the normal path below.
   block $ do _void before
              r <- unblock thing `onException` after
              _void after
              return r

-- | A specialised variant of 'bracket' with just a computation to run
-- afterward.
finally :: MonadCatchIO m
        => m a -- ^ computation to run first
        -> m b -- ^ computation to run afterward (even if an exception was
               -- raised)
        -> m a -- ^ returns the value from the first computation
thing `finally` after =
   -- 'after' runs exactly once: via 'onException' (which then rethrows)
   -- if 'thing' throws, otherwise on the normal path below.
   block $ do r <- unblock thing `onException` after
              _void after
              return r

-- | Like 'bracket', but only performs the final action if there was an
-- exception raised by the in-between computation.
bracketOnError :: MonadCatchIO m
               => m a        -- ^ computation to run first (\"acquire resource\")
               -> (a -> m b) -- ^ computation to run last (\"release resource\")
               -> (a -> m c) -- ^ computation to run in-between
               -> m c        -- ^ returns the value from the in-between
                             -- computation
bracketOnError before after thing =
   -- No release on success: only 'onException' runs 'after', and it
   -- rethrows the original exception afterwards.
   block $ do a <- before
              unblock (thing a) `onException` after a

-- | Generalized version of 'E.onException': if the action throws any
-- exception, run the cleanup action and then rethrow the original
-- exception. When the action succeeds, the cleanup is not run.
onException :: MonadCatchIO m => m a -> m b -> m a
onException action cleanup = catch action rethrowAfterCleanup
  where
    -- Catching 'E.SomeException' makes this fire for every exception type.
    rethrowAfterCleanup ex = cleanup >> throw (ex :: E.SomeException)

-- | Run an action for its effects only, discarding its result.
_void :: Monad m => m a -> m ()
_void = (>> return ())