{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Equality.Matching
( ematch
, eGraphToDatabase
, Match(..)
, compileToQuery
, module Data.Equality.Matching.Pattern
)
where
import Control.Monad
import Data.Foldable (toList)
import Data.Maybe (fromMaybe, mapMaybe)

import Control.Monad.Trans.State.Strict
import Data.Containers.ListUtils
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import qualified Data.Map.Strict as M

import Data.Equality.Graph
import Data.Equality.Graph.Lens
import Data.Equality.Matching.Database
import Data.Equality.Matching.Pattern
-- | One result of matching a pattern against a 'Database'.
--
-- Built by 'ematch' from a substitution returned by 'genericJoin'.
data Match = Match
    { matchSubst   :: !Subst
    -- ^ The full substitution found for the compiled query's variables.
    , matchClassId :: {-# UNPACK #-} !ClassId
    -- ^ The e-class bound to the root variable of the compiled pattern.
    }
-- | Match a pattern against a 'Database'. A 'Database' can be obtained from
-- an 'EGraph' with 'eGraphToDatabase'.
--
-- The pattern is compiled with 'compileToQuery' and evaluated with
-- 'genericJoin'. Every non-empty substitution produced becomes one 'Match'.
-- The match's class id is the value bound to the query's root variable.
-- Empty substitutions are discarded.
ematch :: Language l
       => Database l
       -> Pattern l
       -> [Match]
ematch db patr = mapMaybe toMatch (genericJoin db q)
  where
    (q, root) = compileToQuery patr

    toMatch :: Subst -> Maybe Match
    toMatch s
      | IM.null s = Nothing
      | otherwise =
          case IM.lookup root s of
            -- compileToQuery always lists the root among the query's
            -- variables; presumably genericJoin binds every query variable,
            -- so this branch should be unreachable.
            Nothing    -> error "how is root not in map?"
            Just found -> Just (Match s found)
-- | Convert an e-graph into a 'Database' that 'ematch' can query.
--
-- Each entry of the e-graph's memo map ('_memo') is an e-node paired with
-- its e-class id. Each entry becomes the tuple @classid : children enode@,
-- which is inserted into the 'IntTrie' stored under @operator enode@.
eGraphToDatabase :: Language l => EGraph a l -> Database l
eGraphToDatabase egr = foldrWithKeyNM' addENodeToDB (DB mempty) (egr ^. _memo)
  where
    -- Add one e-node's tuple to the trie for its operator. The trie is
    -- created if this is the first e-node seen with that operator.
    addENodeToDB :: Language l => ENode l -> ClassId -> Database l -> Database l
    addENodeToDB enode classid (DB m) =
      DB $ M.alter (Just . populate (classid : children enode)) (operator enode) m

    -- Insert a path of class ids into a trie. A missing trie is treated as
    -- the empty trie, so the same two equations handle both creating a
    -- trie and extending an existing one.
    populate :: [ClassId] -> Maybe IntTrie -> IntTrie
    populate []       existing = fromMaybe emptyTrie existing
    populate (x : xs) existing =
      let MkIntTrie keys next = fromMaybe emptyTrie existing
       in MkIntTrie (IS.insert x keys) (IM.alter (Just . populate xs) x next)

    emptyTrie :: IntTrie
    emptyTrie = MkIntTrie mempty mempty
{-# INLINABLE eGraphToDatabase #-}
-- | Intermediate result of compiling one sub-pattern in 'compileToQuery'.
-- It pairs the variable standing for that sub-pattern with the atoms
-- generated for it and its descendants.
data AuxResult lang
  = (:~) {-# UNPACK #-} !Var [Atom lang]
-- | Compile a 'Pattern' into a 'Query' for 'genericJoin'. Also returns the
-- variable that stands for the root of the pattern.
--
-- * A bare variable pattern compiles to a 'SelectAllQuery' on that variable.
-- * A non-variable pattern is flattened into one 'Atom' per pattern node.
--   Each node gets a fresh variable. Its atom relates that variable to the
--   variables of the node's children: pattern variables for leaves, fresh
--   variables for nested nodes. The query's variable list is the root
--   followed by the pattern's own variables, with duplicates removed.
--
-- NOTE(review): fresh variables are numbered from 0 upwards in the same
-- 'Var' space as pattern variables. Confirm that pattern variables can never
-- take these values; otherwise the two would be conflated in the query.
compileToQuery :: (Traversable lang) => Pattern lang -> (Query lang, Var)
compileToQuery (VariablePattern x) = (SelectAllQuery x, x)
compileToQuery pa@(NonVariablePattern _) =
    let root :~ atoms = evalState (aux pa) 0
     in (Query (nubInt $ root : vars pa) atoms, root)
  where
    -- Flatten a (sub-)pattern. Returns the variable representing it and the
    -- atoms generated for it and its descendants, parent atom first.
    aux :: (Traversable lang) => Pattern lang -> State Int (AuxResult lang)
    aux (VariablePattern x) = return (x :~ [])
    aux (NonVariablePattern p) = do
      v <- get
      modify' (+1)
      subResults <- traverse aux p
      let resultVar   (b :~ _) = b
          resultAtoms (_ :~ a) = a
          -- The node with each child replaced by the variable for that child.
          -- Reading the variables straight off the traversal result keeps
          -- them aligned with their positions; no list indexing is needed.
          p' = fmap resultVar subResults
          atom = Atom (CVar v) (fmap CVar p')
      return (v :~ (atom : concatMap resultAtoms subResults))

    -- The distinct pattern variables occurring in a pattern, in
    -- first-occurrence order.
    vars :: Foldable lang => Pattern lang -> [Var]
    vars (VariablePattern x) = [x]
    vars (NonVariablePattern p) = nubInt (foldMap vars p)
{-# INLINABLE compileToQuery #-}