{-# LANGUAGE CPP, LambdaCase, TupleSections, ViewPatterns #-}
{-# OPTIONS -Wno-name-shadowing #-}
module TypeLevel.Rewrite.Internal.ApplyRules where

import Control.Applicative
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Foldable (asum, for_)
import Data.Map (Map)
import Data.Maybe (listToMaybe, maybeToList)
import Data.Traversable
import qualified Data.Map as Map

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Plugins (TyVar)
#else
import Type (TyVar)
#endif

-- term-rewriting API
import Data.Rewriting.Rule (Rule(..))
import Data.Rewriting.Substitution (gApply)
import Data.Rewriting.Term (Term(..))
import qualified Data.Rewriting.Substitution.Type as Substitution

import TypeLevel.Rewrite.Internal.TypeEq
import TypeLevel.Rewrite.Internal.TypeNode
import TypeLevel.Rewrite.Internal.TypeRule
import TypeLevel.Rewrite.Internal.TypeSubst
import TypeLevel.Rewrite.Internal.TypeTerm


-- | Bindings collected while matching a rule's left-hand side: each rule
-- variable ('TyVar') is mapped to the input subterm found in its position.
type Subst = Map TyVar TypeTerm

-- | Rewrite the first element of @inputs@ (in traversal order) that one of
-- the rules can rewrite. Returns the rule that fired together with the
-- updated container, or 'Nothing' when no rule applies to any element.
-- Only a single rewrite step is performed.
applyRules
  :: Traversable t
  => TypeSubst
  -> [TypeRule]
  -> t TypeTerm
  -> Maybe (TypeRule, t TypeTerm)
applyRules typeSubst rules inputs =
  annotatedTraverseFirst rewriteTerm inputs
  where
    -- Try the rules, in order, against one element.
    rewriteTerm term = multiRewrite typeSubst rules term

-- | Try each rule, in list order, on @input@. Returns the first successful
-- rewrite paired with the rule that produced it, or 'Nothing' if no rule
-- matches anywhere in the term.
multiRewrite
  :: TypeSubst
  -> [TypeRule]
  -> TypeTerm
  -> Maybe (TypeRule, TypeTerm)
multiRewrite typeSubst rules input = asum (map tryRule rules)
  where
    -- Tag a successful rewrite with the rule responsible for it.
    tryRule rule = fmap ((,) rule) (singleRewrite typeSubst rule input)

-- | Rewrite one occurrence of a rule's left-hand side somewhere inside a
-- term. The root is tried first. If it does not match and the term is a
-- 'Fun' node, its arguments are searched left to right, recursively. Only
-- the first argument that can be rewritten is changed. Returns 'Nothing'
-- when the rule matches nowhere.
--
-- For example, with a rule written @F x (F x y) ~ F x y@ and an empty
-- 'TypeSubst', the term @G (F a (F a b))@ becomes @G (F a b)@. The root
-- @G ...@ does not match, so the rewrite happens inside its first argument.
singleRewrite
  :: TypeSubst
  -> TypeRule
  -> TypeTerm
  -> Maybe TypeTerm
singleRewrite typeSubst rule input =
  case input of
    Fun inputF inputXS ->
      rewriteHere <|> (Fun inputF <$> traverseFirst rewriteInside inputXS)
    _ ->
      rewriteHere
  where
    -- Attempt the rule at the root of the current term only.
    rewriteHere = topLevelRewrite typeSubst rule input
    -- Descend into a subterm, trying its root before its own arguments.
    rewriteInside = singleRewrite typeSubst rule


-- | Rewrite the root of a term with a rule, without looking inside its
-- subterms.
--
-- The rule's left-hand side is matched against the input:
--
-- * a rule variable ('TyVar') binds to the subterm in its position. If it
--   occurs more than once, every occurrence must meet an equal subterm.
--
-- * a 'Fun' node must meet a 'Fun' node with the same head and the same
--   number of arguments.
--
-- * where the pattern has a 'Fun' node but the input has an equality
--   variable ('TypeEq'), every term that the 'TypeSubst' pairs with that
--   variable is tried in list order. Each attempt starts from the same
--   bindings.
--
-- The resulting bindings are substituted into the right-hand side with
-- 'gApply', and its 'Maybe' result is returned.
--
-- For example, a rule written @F x (F x y) ~ F x y@ turns @F a (F a b)@
-- into @F a b@.
topLevelRewrite
  :: TypeSubst
  -> TypeRule
  -> TypeTerm
  -> Maybe TypeTerm
topLevelRewrite typeSubst (Rule lhsPattern rhsPattern) input0 = do
  bindings <- execStateT (match lhsPattern input0) Map.empty
  gApply (Substitution.fromMap bindings) rhsPattern
  where
    match
      :: Term TypeNode TyVar
      -> TypeTerm
      -> StateT Subst Maybe ()
    -- Rule variable: bind it on first sight, otherwise demand equality.
    match (Var ruleVar) term = do
      existing <- gets (Map.lookup ruleVar)
      case existing of
        Nothing    -> modify (Map.insert ruleVar term)
        Just bound -> guard (term == bound)
    -- Constructor against constructor: heads and arities must agree.
    match (Fun patHead patArgs) (Fun termHead termArgs) = do
      guard (patHead == termHead)
      guard (length patArgs == length termArgs)
      for_ (zip patArgs termArgs) (uncurry match)
    -- Constructor against an equality variable: try each known replacement.
    match pat (Var eqVar) =
      asum [ match pat candidate
           | (knownVar, candidate) <- typeSubst
           , knownVar == eqVar
           ]
possibleReplacements

-- | Replace the first element, in traversal order, for which @f@ returns
-- 'Just'. Every other element is kept unchanged. Returns 'Nothing' when @f@
-- rejects all of them.
--
-- >>> traverseFirst (\x -> if even x then Just (10 + x) else Nothing) [1,3,5]
-- Nothing
-- >>> traverseFirst (\x -> if even x then Just (10 + x) else Nothing) [1,2,4]
-- Just [1,12,4]
traverseFirst
  :: Traversable t
  => (a -> Maybe a)
  -> t a
  -> Maybe (t a)
traverseFirst f ta = listToMaybe (traverseAll f ta)
f

-- | Like 'traverseFirst', but @f@ also yields an annotation for the element
-- it replaces. That annotation is returned alongside the updated structure.
--
-- >>> annotatedTraverseFirst (\x -> if even x then Just (x, 10 + x) else Nothing) [1,2,4]
-- Just (2,[1,12,4])
annotatedTraverseFirst
  :: Traversable t
  => (a -> Maybe (annotation, a))
  -> t a
  -> Maybe (annotation, t a)
annotatedTraverseFirst f ta = listToMaybe (annotatedTraverseAll f ta)

-- | Every way of replacing exactly one element for which @f@ returns
-- 'Just', listed in traversal order. Elements rejected by @f@ contribute no
-- result, so the list is empty when @f@ rejects them all.
--
-- >>> traverseAll (\x -> if even x then Just (10 + x) else Nothing) [1,3,5]
-- []
-- >>> traverseAll (\x -> if even x then Just (10 + x) else Nothing) [1,2,4]
-- [[1,12,4],[1,2,14]]
traverseAll
  :: Traversable t
  => (a -> Maybe a)
  -> t a
  -> [t a]
traverseAll f ta = map snd (annotatedTraverseAll withUnit ta)
  where
    -- Reuse the annotated search with a trivial '()' annotation.
    withUnit a = fmap ((,) ()) (f a)

-- | All ways of replacing exactly one element of a traversable structure
-- using @f@, in traversal order. Each result pairs the annotation @f@
-- produced for the chosen element with the updated structure. Elements for
-- which @f@ returns 'Nothing' are never chosen.
--
-- >>> annotatedTraverseAll (\x -> if even x then Just (x, 10 + x) else Nothing) [1,2,4]
-- [(2,[1,12,4]),(4,[1,2,14])]
annotatedTraverseAll
  :: Traversable t
  => (a -> Maybe (annotation, a))
  -> t a
  -> [(annotation, t a)]
annotatedTraverseAll f ta = evalStateT search Nothing
  where
    -- The state records the annotation of the element chosen so far, if
    -- any. After the walk, branches that chose nothing are discarded.
    search = do
      ta' <- for ta visit
      chosen <- get
      annotation <- lift (maybeToList chosen)
      pure (annotation, ta')

    -- Once an element has been chosen, leave the rest untouched.
    -- Otherwise branch: first try choosing this element, then skip it.
    visit a = do
      chosen <- get
      case chosen of
        Just _  -> pure a
        Nothing -> choose a <|> pure a

    -- Choosing yields no branch when f rejects the element.
    choose a = do
      (annotation, a') <- lift (maybeToList (f a))
      put (Just annotation)
      pure a'