{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : Jikka.Common.Alpha
-- Description : provides a monad to run alpha-conversion. / alpha 変換用のモナドを提供します。
-- Copyright   : (c) Kimiyuki Onaka, 2020
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- "Jikka.Common.Alpha" provides a monad for running alpha-conversion. Its only feature is generating unique numbers.
module Jikka.Common.Alpha where

import Control.Arrow (first)
import Control.Monad.Except
import Control.Monad.Identity (Identity (..))
import Control.Monad.Reader
import Control.Monad.Signatures
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict

-- | Monads that can hand out numbers from an internal counter. Alpha-conversion
-- uses these numbers to make names unique.
class (Monad m) => MonadAlpha m where
  -- | Returns the next number from the counter.
  nextCounter :: m Int

-- | The alpha-conversion monad transformer.
--
-- It threads an 'Int' counter through a computation in @m@. 'runAlphaT' takes
-- the counter value to start from and returns the result paired with the
-- counter value after the computation.
newtype AlphaT m a = AlphaT {runAlphaT :: Int -> m (a, Int)}

-- | 'nextCounter' returns the current counter value and advances the counter
-- by one, so successive calls return distinct numbers.
instance Monad m => MonadAlpha (AlphaT m) where
  nextCounter = AlphaT $ \i ->
    -- Force the incremented counter so that many consecutive calls do not
    -- accumulate a chain of unevaluated @+ 1@ thunks.
    let i' = i + 1
     in i' `seq` return (i, i')

-- | Maps over the result of the computation; the counter is passed through
-- unchanged.
instance Functor m => Functor (AlphaT m) where
  fmap f (AlphaT x) = AlphaT (fmap (first f) . x)

-- | 'pure' leaves the counter unchanged. '<*>' runs the function action
-- first and then the argument action, threading the counter from left to
-- right.
instance Monad m => Applicative (AlphaT m) where
  pure x = AlphaT $ \i -> pure (x, i)
  AlphaT mf <*> AlphaT mx = AlphaT $ \i0 -> do
    (f, i1) <- mf i0
    (x, i2) <- mx i1
    return (f x, i2)

-- | Runs the first action, then feeds its result and the updated counter to
-- the continuation.
instance Monad m => Monad (AlphaT m) where
  AlphaT mx >>= f = AlphaT $ \i0 -> do
    (x, i1) <- mx i0
    runAlphaT (f x) i1

-- | Ties the knot through the underlying monad's 'mfix'. The result is taken
-- from the fixed-point pair with the lazy 'fst', so @f@ receives it without
-- the pair being forced up front.
instance MonadFix m => MonadFix (AlphaT m) where
  mfix f = AlphaT $ \i -> mfix (\p -> runAlphaT (f (fst p)) i)

-- | Lifts a @catch@ operation of the underlying monad to 'AlphaT'.
--
-- Both the protected action and the handler start from the counter value
-- in effect when the catch begins. Any counter increments the protected
-- action made before throwing are therefore discarded, and the handler
-- continues from the earlier value.
liftCatch :: Catch e m (a, Int) -> Catch e (AlphaT m) a
liftCatch catchE m h = AlphaT $ \i ->
  runAlphaT m i `catchE` \e -> runAlphaT (h e) i

-- | Runs an action of the underlying monad without touching the counter.
instance MonadTrans AlphaT where
  lift m = AlphaT $ \i -> fmap (\a -> (a, i)) m

-- | 'throwError' is lifted from the underlying monad. 'catchError' goes
-- through 'liftCatch', so the handler restarts from the counter value that
-- was in effect when the catch began.
instance MonadError e m => MonadError e (AlphaT m) where
  throwError = lift . throwError
  catchError = liftCatch catchError

-- | Lifts 'IO' actions through the underlying monad; the counter is unchanged.
instance MonadIO m => MonadIO (AlphaT m) where
  liftIO = lift . liftIO

-- | Runs an 'AlphaT' computation starting from the given counter value and
-- returns only its result, discarding the final counter.
evalAlphaT :: Functor m => AlphaT m a -> Int -> m a
evalAlphaT f i = fst <$> runAlphaT f i

-- | Forwards 'nextCounter' to the inner monad; this layer keeps no counter of
-- its own.
instance MonadAlpha m => MonadAlpha (ExceptT e m) where
  nextCounter = lift nextCounter

-- | Forwards 'nextCounter' to the inner monad; this layer keeps no counter of
-- its own.
instance MonadAlpha m => MonadAlpha (ReaderT r m) where
  nextCounter = lift nextCounter

-- | Forwards 'nextCounter' to the inner monad; this layer keeps no counter of
-- its own.
instance MonadAlpha m => MonadAlpha (StateT s m) where
  nextCounter = lift nextCounter

-- | Forwards 'nextCounter' to the inner monad; this layer keeps no counter of
-- its own. The 'Monoid' constraint is needed for 'WriterT' to be a monad.
instance (MonadAlpha m, Monoid w) => MonadAlpha (WriterT w m) where
  nextCounter = lift nextCounter

-- | Runs a pure 'AlphaT' computation (over 'Identity') starting from the given
-- counter value and returns its result.
evalAlpha :: AlphaT Identity a -> Int -> a
evalAlpha f i = runIdentity (evalAlphaT f i)

-- | Replaces the counter with the given value, discarding the current one.
--
-- Later 'nextCounter' calls continue from @i@. Passing a value at or below
-- numbers that were already handed out can therefore produce duplicate
-- numbers.
resetAlphaT :: Monad m => Int -> AlphaT m ()
resetAlphaT i = AlphaT $ \_ -> return ((), i)