{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{-|
Module       : ATP.Internal.Enumeration
Description  : The helper Enumeration monad, used to describe computations that
               carry a renaming of values to consecutive numbers.
Copyright    : (c) Evgenii Kotelnikov, 2019-2021
License      : GPL-3
Maintainer   : evgeny.kotelnikov@gmail.com
Stability    : experimental
-}

module ATP.Internal.Enumeration (
  EnumerationT(..),
  evalEnumerationT,
  Enumeration,
  evalEnumeration,
  next,
  enumerate,
  alias
) where

import Control.Monad.State (MonadTrans, MonadState, StateT, evalStateT, gets, modify)
import Data.Functor.Identity (Identity(..))
import Data.Map (Map)
import qualified Data.Map as M (empty, lookup, insert)


-- | A monad transformer over @m@ whose state renames values of type @a@ to
-- 'Integer's.
--
-- The state is a pair: the first component is a counter, and the second is
-- the map from values to the numbers assigned to them so far. All the
-- monadic structure is inherited from the underlying 'StateT' via
-- generalized newtype deriving.
newtype EnumerationT a m s = EnumerationT
  { runEnumerationT :: StateT (Integer, Map a Integer) m s
    -- ^ Unwrap to the underlying state computation.
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadTrans
    , MonadState (Integer, Map a Integer)
    )

-- | Run an 'EnumerationT' computation starting from the counter value @1@
-- and an empty renaming. Returns the computation's result and discards the
-- final state.
evalEnumerationT :: Monad m => EnumerationT a m e -> m e
evalEnumerationT = flip evalStateT (1, M.empty) . runEnumerationT

-- | Pure enumeration computations: 'EnumerationT' specialised to 'Identity'.
-- The synonym is deliberately left unapplied in its last argument so that
-- @Enumeration a@ can be used as a monad on its own.
type Enumeration a = EnumerationT a Identity

-- | Run a pure 'Enumeration' computation with 'evalEnumerationT' and extract
-- its result from 'Identity'.
evalEnumeration :: Enumeration a e -> e
evalEnumeration computation = runIdentity (evalEnumerationT computation)

-- | Hand out the current counter value and advance the counter by one.
--
-- The renaming map is left untouched, so the returned number is not
-- associated with any value.
next :: Monad m => EnumerationT a m Integer
next = gets fst <* modify (\(counter, renaming) -> (counter + 1, renaming))

-- | Look up the number assigned to @v@.
--
-- On the first encounter, @v@ is given a fresh number from 'next', which is
-- recorded in the renaming and returned. Subsequent calls return the same
-- number.
enumerate :: (Ord a, Monad m) => a -> EnumerationT a m Integer
enumerate v = do
  known <- gets (M.lookup v . snd)
  maybe fresh return known
  where
    -- Allocate a new number for @v@ and remember it in the renaming.
    fresh = do
      n <- next
      modify (fmap (M.insert v n))
      return n

-- | Make @a@ and @b@ share the same number in the renaming.
--
-- * If exactly one of them is already numbered, the other receives that
--   number.
-- * If both already have the same number, nothing changes.
-- * Otherwise a fresh number is taken from 'next' and assigned to both. This
--   covers the case where neither is numbered and the case where they are
--   numbered differently.
--
-- NOTE(review): when @a@ and @b@ were already numbered differently, both are
-- moved to a fresh number. Numbers previously returned for either value are
-- not updated, and neither are other values that shared their old numbers.
alias :: (Ord a, Monad m) => a -> a -> EnumerationT a m ()
alias a b = gets (\(_, m) -> (M.lookup a m, M.lookup b m)) >>= \case
  -- Already aliased (or a == b and already numbered). Falling through to
  -- the catch-all here would needlessly burn a number and renumber both
  -- values, which invalidates numbers handed out earlier.
  (Just i, Just j) | i == j -> return ()
  (Just i, Nothing) -> modify $ fmap (M.insert b i)
  (Nothing, Just j) -> modify $ fmap (M.insert a j)
  _ -> do
    k <- next
    modify $ fmap (M.insert a k . M.insert b k)