-- | Auxiliary functions for the IO monad.

module Agda.Utils.IO where

import Control.Exception
import Control.Monad.State
import Control.Monad.Writer

-- | Monads in which 'IOException's can be caught.
--
-- @catchIO action handler@ is intended to run @action@ and, should it raise
-- an 'IOException', continue with @handler@ applied to that exception.
class CatchIO m where
  catchIO
    :: m b                   -- ^ action that may raise an 'IOException'
    -> (IOException -> m b)  -- ^ handler invoked with the raised exception
    -> m b

-- | Alias of 'catch' for the IO monad.
--
-- The class method's type pins the exception type of 'catch' to
-- 'IOException', so only I/O exceptions reach the handler; any other
-- exception propagates unchanged.
instance CatchIO IO where
  catchIO = catch

-- | Upon exception, the written output is lost.
--
-- The action is unwrapped with 'runWriterT' and the exception is caught in
-- the underlying monad. If an 'IOException' is raised, the output the failed
-- action had accumulated is discarded along with its result. Only the
-- handler's output (and result) remain.
instance CatchIO m => CatchIO (WriterT w m) where
  catchIO m h = WriterT $
    runWriterT m `catchIO` \e -> runWriterT (h e)

-- | Upon exception, the state is reset.
--
-- Both the action and the handler are started from the state @s@ that was
-- current when 'catchIO' was entered. If the action raises an
-- 'IOException', any state changes it made are discarded, and the handler
-- runs from @s@.
instance CatchIO m => CatchIO (StateT s m) where
  catchIO m h = StateT $ \s ->
    runStateT m s `catchIO` \e -> runStateT (h e) s