{-# LANGUAGE DeriveGeneric             #-}
{-# LANGUAGE OverloadedStrings         #-}
{-# LANGUAGE PatternGuards             #-}

module Language.Fixpoint.Solver.Rewrite
  ( getRewrite
  , subExprs
  , unify
  , RewriteArgs(..)
  , RWTerminationOpts(..)
  , SubExpr
  , TermOrigin(..)
  ) where

import           Control.Monad.State
import           Control.Monad.Trans.Maybe
import           GHC.Generics
import           Data.Hashable
import qualified Data.HashMap.Strict  as M
import qualified Data.HashSet         as S
import qualified Data.List            as L
import qualified Data.Maybe           as Mb
import           Language.Fixpoint.Types hiding (simplify)
import qualified Data.Text as TX

-- | The operator symbol at the head of a 'Term'.
type Op = Symbol
-- | A precedence list of operators used by 'opGT': an operator earlier in
--   the list is greater than one later in the list, and any listed operator
--   is greater than an operator that is not listed.
type OpOrdering = [Symbol]
-- | A first-order term: an operator symbol applied to a list of argument
--   terms. Leaves (variables, constants) carry an empty argument list.
data Term
  = Term Symbol [Term]
  deriving (Eq, Generic)

-- | Uses the 'Generic'-based default implementation.
instance Hashable Term

-- | The head operator symbol of a term.
termSym :: Term -> Symbol
termSym term = case term of
  Term headSym _ -> headSym

-- | Renders an argument-less term as its bare operator name and an
--   application as @op(arg1, arg2, ...)@.
instance Show Term where
  show (Term op args)
    | null args = opName
    | otherwise = opName ++ "(" ++ L.intercalate ", " (map show args) ++ ")"
    where
      opName = TX.unpack (symbolText op)

-- | Direction of a single step between two argument terms, as classified by
--   'getDir'. Constructor order matters: the derived 'Ord' is
--   @SCUp < SCEq < SCDown@.
data SCDir = SCUp | SCEq | SCDown
  deriving (Eq, Ord, Show, Generic)

-- | Uses the 'Generic'-based default implementation.
instance Hashable SCDir

-- | A path through argument positions of a sequence of terms (see 'scp'):
--   (start position, end position, direction of each step). A position is
--   an operator paired with an argument index. "SC" presumably stands for
--   size-change — TODO confirm.
type SCPath = ((Op, Int), (Op, Int), [SCDir])
-- | A focused subexpression paired with a function that plugs a replacement
--   for it back into the enclosing expression.
type SubExpr = (Expr, Expr -> Expr)

-- | One edge between two consecutive terms (built by 'getSC'): from argument
--   @snd from@ of operator @fst from@ to argument @snd to@ of operator
--   @fst to@, labelled with the direction computed by 'getDir'.
data SCEntry = SCEntry
  { from :: (Op, Int)
  , to   :: (Op, Int)
  , dir  :: SCDir
  } deriving (Eq, Ord, Show, Generic)

-- | Uses the 'Generic'-based default implementation.
instance Hashable SCEntry

-- | Classify the step from argument term @lhs@ to argument term @rhs@ under
--   ordering @o@:
--
--   * 'SCUp'   when @lhs@ is not 'synGTE' @rhs@;
--   * 'SCEq'   when each is 'synGTE' the other;
--   * 'SCDown' when @lhs@ is 'synGTE' @rhs@ but not vice versa.
getDir :: OpOrdering -> Term -> Term -> SCDir
getDir o lhs rhs
  | not lhsGTE = SCUp        -- the reverse comparison is never evaluated here
  | rhsGTE     = SCEq
  | otherwise  = SCDown
  where
    lhsGTE = synGTE o lhs rhs
    rhsGTE = synGTE o rhs lhs

-- | All edges between the argument positions of two terms: one 'SCEntry'
--   for every pair (argument @i@ of the first term, argument @j@ of the
--   second term), labelled with 'getDir' of the two argument terms.
getSC :: OpOrdering -> Term -> Term -> S.HashSet SCEntry
getSC o (Term op ts) (Term op' us) =
  S.fromList
    [ SCEntry (op, i) (op', j) (getDir o t u)
    | (i, t) <- zip [0 ..] ts
    , (j, u) <- zip [0 ..] us
    ]

-- | All paths through the argument positions of a sequence of terms.
--
--   For consecutive terms @t1, t2@ each 'SCEntry' from 'getSC' is an edge
--   from an argument position of @t1@ to one of @t2@. A path from the first
--   term to the last chains edges whose endpoints coincide, recording the
--   direction of every step. Fewer than two terms yield no paths.
scp :: OpOrdering -> [Term] -> S.HashSet SCPath
scp _ []       = S.empty
scp _ [_]      = S.empty
scp o [t1, t2] =
  S.fromList [ (a, b, [d]) | SCEntry a b d <- S.toList (getSC o t1 t2) ]
scp o (t1 : t2 : trms) =
  S.fromList
    [ (a, b, d : ds)
    | SCEntry a b' d <- S.toList (getSC o t1 t2)
    , (b, ds)        <- Mb.fromMaybe [] (M.lookup b' byStart)
    ]
  where
    -- Paths of the remaining terms, computed exactly once and indexed by
    -- their start position. Each edge then only meets the tail paths that
    -- begin where the edge ends, instead of being paired with every tail
    -- path and filtered afterwards.
    byStart =
      M.fromListWith (++)
        [ (a', [(b, ds)]) | (a', b, ds) <- S.toList (scp o (t2 : trms)) ]

-- | Syntactic equivalence under ordering @o@: each term is 'synGTE' the
--   other. The @l >= r@ test runs first and short-circuits.
synEQ :: OpOrdering -> Term -> Term -> Bool
synEQ o l r = all (uncurry (synGTE o)) [(l, r), (r, l)]

-- | @opGT ordering op1 op2@: is @op1@ strictly greater than @op2@?
--
--   Between two listed operators the one with the smaller index wins. A
--   listed operator beats an unlisted one; an unlisted operator never wins.
opGT :: OpOrdering -> Op -> Op -> Bool
opGT ordering op1 op2 =
  case L.elemIndex op1 ordering of
    Nothing  -> False
    Just pos -> maybe True (pos <) (L.elemIndex op2 ordering)

-- | Cancel syntactically equal terms between two lists.
--
--   Each @x@ of the first list that is 'synEQ' to some remaining element of
--   the second list removes the first such element and is itself dropped;
--   unmatched @x@s are kept in order. Returns
--   @(unmatched first-list terms, remaining second-list terms)@.
removeSynEQs :: OpOrdering -> [Term] -> [Term] -> ([Term], [Term])
removeSynEQs _ [] ys = ([], ys)
removeSynEQs ordering (x : xs) ys =
  case break (synEQ ordering x) ys of
    (before, _matched : after) -> removeSynEQs ordering xs (before ++ after)
    (_, [])                    ->
      let (unmatched, rest) = removeSynEQs ordering xs ys
      in (x : unmatched, rest)

-- | "Greater or equal" between lists of terms.
--
--   After cancelling 'synEQ' pairs with 'removeSynEQs', the first list wins
--   if nothing is left of the second list, or if some remaining term of the
--   first list is 'synGT' every remaining term of the second list.
synGTEM :: OpOrdering -> [Term] -> [Term] -> Bool
synGTEM ordering xs ys =
  let (leftXs, leftYs) = removeSynEQs ordering xs ys
      dominatesAll x   = all (synGT ordering x) leftYs
  in null leftYs || any dominatesAll leftXs
    
-- | Strict syntactic ordering: @t1@ is 'synGTE' @t2@ but not the reverse.
synGT :: OpOrdering -> Term -> Term -> Bool
synGT o t1 t2
  | synGTE o t1 t2 = not (synGTE o t2 t1)
  | otherwise      = False

-- | Strict version of 'synGTEM': the first list is 'synGTEM' the second but
--   not the other way round.
synGTM :: OpOrdering -> [Term] -> [Term] -> Bool
synGTM o xs ys =
  if synGTEM o xs ys
    then not (synGTEM o ys xs)
    else False

-- | Syntactic "greater or equal" on terms under an operator ordering. Its
--   shape resembles a recursive path ordering — TODO confirm intent:
--
--   * head of @t1@ beats head of @t2@: @t1@ must beat all arguments of @t2@;
--   * head of @t2@ beats head of @t1@: the arguments of @t1@ must be
--     'synGTEM' @[t2]@;
--   * otherwise: compare the argument lists with 'synGTEM'.
synGTE :: OpOrdering -> Term -> Term -> Bool
synGTE ordering t1@(Term x xArgs) t2@(Term y yArgs)
  | opGT ordering x y = synGTM ordering [t1] yArgs
  | opGT ordering y x = synGTEM ordering xArgs [t2]
  | otherwise         = synGTEM ordering xArgs yArgs

-- | @subsequencesOfSize n xs@: every subsequence of @xs@ with exactly @n@
--   elements (element order inside each subsequence is preserved).
--
--   Returns @[]@ when @n@ is negative or larger than @length xs@. A negative
--   @n@ previously indexed past the end of the table and crashed in '!!'.
subsequencesOfSize :: Int -> [a] -> [[a]]
subsequencesOfSize n xs
  | n < 0 || n > len = []
  | otherwise        = bySize xs !! (len - n)  -- 0 <= len - n <= len: in range
  where
    len = length xs
    -- @bySize ys@ has @length ys + 1@ entries; entry @k@ holds the
    -- subsequences of @ys@ of size @length ys - k@ (entry 0 is @[ys]@, the
    -- last entry is @[[]]@).
    bySize []       = [[[]]]
    bySize (y : ys) =
      let next = bySize ys
      in zipWith (++) ([] : next) (map (map (y :)) next ++ [[]])

-- | Where a term on a rewrite path came from: 'PLE', or a rewrite step
--   ('RW') carrying an operator ordering. 'getRewrite' stores the ordering
--   found by its termination check there, or @[]@ when the check is off.
data TermOrigin
  = PLE
  | RW OpOrdering
  deriving (Show, Eq)

-- | Result of 'diverges': either no acceptable operator ordering was found
--   ('Diverging'), or the first ordering under which the path was not judged
--   divergent ('NotDiverging').
data DivergeResult = Diverging | NotDiverging OpOrdering

-- | Does this origin denote a rewrite ('RW') step?
fromRW :: TermOrigin -> Bool
fromRW origin = case origin of
  RW _ -> True
  PLE  -> False

-- | The operator ordering recorded on a rewrite ('RW') step, if any.
getOrdering :: TermOrigin -> Maybe OpOrdering
getOrdering origin
  | RW ordering <- origin = Just ordering
  | otherwise             = Nothing

-- | Decide whether extending @path@ with @term@ may diverge.
--
--   Looks for an operator ordering under which 'divergesFor' does not flag
--   the extended path. Candidates are tried in this order:
--
--   1. the orderings recorded on earlier rewrite steps of @path@, most
--      recent first;
--   2. for @n = 0, 1, ...@, every permutation of every size-@n@ subset of
--      the distinct operators occurring in the path.
--
--   @n@ ranges up to the number of distinct operators, lowered to
--   @maxOrderingConstraints@ when that is 'Just' and smaller. Returns
--   'NotDiverging' with the first acceptable ordering, else 'Diverging'.
diverges :: Maybe Int -> [(Term, TermOrigin)] -> Term -> DivergeResult
diverges maxOrderingConstraints path term
  | 0 > limit = Diverging
  | Just ordering <- L.find (not . divergesWith) suggestedOrderings
  = NotDiverging ordering
  | otherwise = go 0
  where
    -- Distinct operators of the extended path in first-occurrence order;
    -- this order fixes which candidate ordering is discovered first.
    syms = L.nub (concatMap operators (map fst path ++ [term]))
    operators (Term op args) = op : concatMap operators args

    -- @n > length syms || n > fromMaybe (length syms) max@ == @n > limit@.
    limit = maybe (length syms) (min (length syms)) maxOrderingConstraints

    -- Orderings from earlier rewrite steps, most recent first. Their verdict
    -- does not depend on @n@, so they are checked once rather than being
    -- re-run through the costly 'divergesFor' at every size.
    suggestedOrderings = reverse (Mb.mapMaybe (getOrdering . snd) path)

    go n
      | n > limit = Diverging
      | otherwise =
          case L.find (not . divergesWith) (candidatesOfSize n) of
            Just ordering -> NotDiverging ordering
            Nothing       -> go (n + 1)

    candidatesOfSize n = concatMap L.permutations (subsequencesOfSize n syms)

    divergesWith ordering = divergesFor ordering path term

-- | Heuristic divergence check of @path ++ [term]@ under ordering @o@.
--
--   The terms are split after the most recent rewrite ('RW') step of @path@.
--   @okTerms@ runs up to and including that step; @checkTerms@ is the rest,
--   including @term@. Candidate sequences are:
--
--   * every subsequence of @checkTerms@;
--   * every suffix of @okTerms@ followed by every prefix of @checkTerms@.
--
--   A candidate counts as divergent when it has at least two terms, its
--   first and last terms have the same head symbol, and 'scp' yields some
--   'ascending' path and no 'descending' path.
divergesFor :: OpOrdering -> [(Term, TermOrigin)] -> Term -> Bool
divergesFor o path term = any divergesOn candidates
  where
    terms = map fst path ++ [term]

    -- 1-based position in @path@ of the last RW step, or 0 if there is none.
    lastRWIndex =
      maybe 0 fst $ L.find (fromRW . snd . snd) $ reverse $ zip [1 ..] path

    (okTerms, checkTerms) = splitAt lastRWIndex terms

    candidates =
      L.subsequences checkTerms ++
        [ firstpart ++ secondpart
        | firstpart  <- L.tails okTerms
        , secondpart <- L.inits checkTerms
        ]

    -- Pattern matching guarantees at least two terms, so @last@ is safe; the
    -- paths from 'scp' are computed once and shared by both tests.
    divergesOn trms@(t0 : _ : _)
      | termSym t0 == termSym (last trms) =
          let paths = scp o trms
          in any ascending paths && not (any descending paths)
    divergesOn _ = False
      
-- | A path that returns to its starting position, contains at least one
--   'SCDown' step and no 'SCUp' step.
descending :: SCPath -> Bool
descending (start, end, steps)
  | start == end = SCDown `elem` steps && SCUp `notElem` steps
  | otherwise    = False

-- | A path that returns to its starting position and contains at least one
--   'SCUp' step.
ascending :: SCPath -> Bool
ascending (start, end, steps) = start == end && any (== SCUp) steps

-- | Whether 'getRewrite' runs the divergence ('diverges') check.
--
--   With 'RWTerminationCheckEnabled', the optional 'Int' caps the size of
--   the operator subsets 'diverges' tries when searching for an ordering.
--   With 'RWTerminationCheckDisabled', rewrites are accepted without the
--   check and tagged with an empty ordering.
data RWTerminationOpts =
    RWTerminationCheckEnabled (Maybe Int) -- # Of constraints to consider
  | RWTerminationCheckDisabled

-- | Configuration passed to 'getRewrite'.
data RewriteArgs = RWArgs
  { isRWValid         :: Expr -> IO Bool
    -- ^ Oracle applied to each instantiated argument refinement of a rule;
    --   a 'False' answer rejects the rewrite.
  , rwTerminationOpts :: RWTerminationOpts
    -- ^ Whether (and how) to run the divergence check.
  }

-- | Try to apply the rule @AutoRewrite args lhs rhs@ at the focused
--   subexpression @subE@, whose surrounding context is @toE@.
--
--   The result is 'Nothing' (via 'MaybeT') when:
--
--   * @lhs@ does not 'unify' with @subE@ over the rule's bound variables;
--   * the rewritten whole expression already occurs in @path@;
--   * 'isRWValid' rejects one of the rule's argument refinements
--     instantiated with the unifying substitution;
--   * the termination check is enabled and 'diverges' reports 'Diverging'.
--
--   On success it returns the rewritten whole expression and an 'RW' origin
--   holding the ordering found by the check (@[]@ when the check is off).
getRewrite
  :: RewriteArgs
  -> [(Expr, TermOrigin)]
  -> SubExpr
  -> AutoRewrite
  -> MaybeT IO (Expr, TermOrigin)
getRewrite rwArgs path (subE, toE) (AutoRewrite args lhs rhs) = do
  su <- MaybeT (return (unify freeVars lhs subE))
  let expr' = toE (subst su rhs)
  -- refuse to revisit an expression already on the path
  guard (expr' `notElem` map fst path)
  mapM_ (check . subst su) exprs
  case rwTerminationOpts rwArgs of
    RWTerminationCheckDisabled ->
      return (expr', RW [])
    RWTerminationCheckEnabled maxConstraints ->
      case diverges maxConstraints termPath (toTerm expr') of
        Diverging               -> mzero
        NotDiverging opOrdering -> return (expr', RW opOrdering)
  where
    -- bound variables of the rule and their refinement predicates
    (freeVars, exprs) = unzip [ (s, e) | RR _ (Reft (s, e)) <- args ]

    termPath = [ (toTerm t, origin) | (t, origin) <- path ]

    check :: Expr -> MaybeT IO ()
    check e = MaybeT (Just <$> isRWValid rwArgs e) >>= guard

    dcPrefix :: Symbol
    dcPrefix = "lqdc"

    -- Encode an expression as a first-order 'Term' for the divergence check.
    -- Equation order matters: a data-constructor application that fails the
    -- prefix guard falls through to the general application case.
    toTerm :: Expr -> Term
    toTerm (EIte i t e) = Term "$ite" (map toTerm [i, t, e])
    toTerm (EApp (EVar s) (EVar var))
      | dcPrefix `isPrefixOfSym` s
      = Term (symbol (TX.concat [symbolText s, "$", symbolText var])) []
    toTerm e@EApp{}
      | (EVar fName, terms) <- splitEApp e
      = Term fName (map toTerm terms)
    toTerm (EVar s)      = Term s []
    toTerm (PAnd es)     = Term "$and" (map toTerm es)
    toTerm (POr es)      = Term "$or" (map toTerm es)
    toTerm (PAtom s l r) = Term (symbol ("$atom" ++ show s)) [toTerm l, toTerm r]
    toTerm (EBin o l r)  = Term (symbol ("$ebin" ++ show o)) [toTerm l, toTerm r]
    toTerm (ECon c)      = Term (symbol ("$econ" ++ show c)) []
    toTerm e             = error (show e)

-- | Enumerate an expression together with the subexpressions produced by
-- 'subExprs''. Each entry pairs a subexpression with a context function that
-- takes a replacement for that subexpression and rebuilds the whole input
-- expression around it. The first entry is always the input itself, whose
-- context is the identity.
subExprs :: Expr -> [SubExpr]
subExprs expr = whole : subExprs' expr
  where
    whole = (expr, id)

-- | Proper subexpressions of an expression, each paired with a context that
-- rebuilds the original expression around a replacement.
--
-- Only the children of 'EIte', 'EBin', 'PImp' and 'PAtom' are descended into,
-- in left-to-right order; every other constructor yields no subexpressions.
subExprs' :: Expr -> [SubExpr]
subExprs' expr = case expr of
    EIte cond thn els ->
         inHole (\c -> EIte c thn els) cond
      ++ inHole (\t -> EIte cond t els) thn
      ++ inHole (EIte cond thn)         els
    EBin op l r ->
      inHole (\l' -> EBin op l' r) l ++ inHole (EBin op l) r
    PImp l r ->
      inHole (\l' -> PImp l' r) l ++ inHole (PImp l) r
    PAtom rel l r ->
      inHole (\l' -> PAtom rel l' r) l ++ inHole (PAtom rel l) r
    -- Descending into applications is currently disabled:
    -- EApp{} -> concatMap replace indexedArgs
    --   where
    --     (f, es)          = splitEApp expr
    --     indexedArgs      = zip [0..] es
    --     replace (i, arg) = do
    --       (subArg, toArg) <- subExprs arg
    --       return (subArg, \subArg' -> eApps f $ (take i es) ++ (toArg subArg'):(drop (i+1) es))
    _ -> []
  where
    -- Take every (subexpression, context) pair of @child@ and extend its
    -- context with @rebuild@, so that plugging a replacement in reconstructs
    -- the enclosing expression rather than just @child@.
    inHole :: (Expr -> Expr) -> Expr -> [SubExpr]
    inHole rebuild child =
      [ (sub, rebuild . plug) | (sub, plug) <- subExprs child ]

-- | Unify two lists of expressions element-wise, threading the substitution.
--
-- @unifyAll freeVars templates seens@ unifies the heads with 'unify'. It then
-- applies the resulting bindings to the remaining elements of /both/ lists,
-- removes the newly bound variables from @freeVars@, and recurses on the
-- tails. The result is the union of all bindings found.
--
-- Two empty lists unify with the empty substitution. Lists of different
-- lengths cannot be matched element-wise and yield 'Nothing'. This case used
-- to be 'undefined', which crashed on e.g. @PAnd@ nodes with differing numbers
-- of conjuncts.
unifyAll :: [Symbol] -> [Expr] -> [Expr] -> Maybe Subst
unifyAll _ [] [] = Just (Su M.empty)
unifyAll freeVars (template:xs) (seen:ys) = do
  rs@(Su s1) <- unify freeVars template seen
  -- Apply what was learned so far before matching the rest.
  let xs' = map (subst rs) xs
      ys' = map (subst rs) ys
  -- Variables bound in s1 are no longer free in the remaining templates.
  Su s2 <- unifyAll (freeVars L.\\ M.keys s1) xs' ys'
  return (Su (M.union s1 s2))
unifyAll _ _ _ = Nothing

-- | @unify freeVars template seenExpr@ tries to find a substitution for the
-- pattern variables @freeVars@ that makes @template@ match @seenExpr@.
--
-- * Structurally equal expressions unify with the empty substitution.
-- * A template 'EVar' that is one of @freeVars@ binds to the whole of
--   @seenExpr@.
-- * Otherwise both sides must have the same constructor (and, for 'EBin' and
--   'PAtom', the same operator or relation). Their children are then unified,
--   via 'unifyAll' when there are several.
-- * Sort, binder and gradual annotations on 'ECst', 'ETApp', 'ETAbs',
--   'PAll', 'PExist', 'PGrad' and 'ECoerc' are ignored; only the bodies are
--   unified.
--
-- Anything else yields 'Nothing'.
unify :: [Symbol] -> Expr -> Expr -> Maybe Subst
unify _ template seenExpr | template == seenExpr = Just (Su M.empty)
unify freeVars template seenExpr = case (template, seenExpr) of
  (EVar rwVar, _) | rwVar `elem` freeVars ->
    return $ Su (M.singleton rwVar seenExpr)
  (EApp templateF templateBody, EApp seenF seenBody) ->
    unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
  (ENeg rw, ENeg seen) ->
    unify freeVars rw seen
  (EBin op rwLeft rwRight, EBin op' seenLeft seenRight) | op == op' ->
    unifyAll freeVars [rwLeft, rwRight] [seenLeft, seenRight]
  (EIte cond rwLeft rwRight, EIte seenCond seenLeft seenRight) ->
    unifyAll freeVars [cond, rwLeft, rwRight] [seenCond, seenLeft, seenRight]
  (ECst rw _, ECst seen _) ->
    unify freeVars rw seen
  (ETApp rw _, ETApp seen _) ->
    unify freeVars rw seen
  (ETAbs rw _, ETAbs seen _) ->
    unify freeVars rw seen
  -- Conjunctions and disjunctions are matched element-wise, so they can only
  -- unify when both have the same number of operands. The guard ensures lists
  -- of different lengths are never passed to 'unifyAll'; such pairs fall
  -- through to 'Nothing'.
  (PAnd rw, PAnd seen) | length rw == length seen ->
    unifyAll freeVars rw seen
  (POr rw, POr seen) | length rw == length seen ->
    unifyAll freeVars rw seen
  (PNot rw, PNot seen) ->
    unify freeVars rw seen
  (PImp templateF templateBody, PImp seenF seenBody) ->
    unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
  (PIff templateF templateBody, PIff seenF seenBody) ->
    unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
  (PAtom rel templateF templateBody, PAtom rel' seenF seenBody) | rel == rel' ->
    unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
  (PAll _ rw, PAll _ seen) ->
    unify freeVars rw seen
  (PExist _ rw, PExist _ seen) ->
    unify freeVars rw seen
  (PGrad _ _ _ rw, PGrad _ _ _ seen) ->
    unify freeVars rw seen
  (ECoerc _ _ rw, ECoerc _ _ seen) ->
    unify freeVars rw seen
  _ -> Nothing