{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
module Language.Fixpoint.Solver.Rewrite
( getRewrite
, subExprs
, unify
, ordConstraints
, convert
, passesTerminationCheck
, RewriteArgs(..)
, RWTerminationOpts(..)
, SubExpr
, TermOrigin(..)
, OCType
, RESTOrdering(..)
) where
import Control.Monad.State (guard)
import Control.Monad.Trans.Maybe
import Data.Hashable
import qualified Data.HashMap.Strict as M
import qualified Data.List as L
import qualified Data.Text as TX
import GHC.IO.Handle.Types (Handle)
import GHC.Generics
import Text.PrettyPrint (text)
import Language.Fixpoint.Types.Config (RESTOrdering(..))
import Language.Fixpoint.Types hiding (simplify)
import Language.REST
import Language.REST.KBO (kbo)
import Language.REST.LPO (lpo)
import Language.REST.OCAlgebra as OC
import Language.REST.OCToAbstract (lift)
import Language.REST.Op
import Language.REST.SMT (SMTExpr)
import Language.REST.WQOConstraints.ADT (ConstraintsADT, adtOC)
import qualified Language.REST.RuntimeTerm as RT
-- | A candidate sub-expression paired with a function that rebuilds the
-- enclosing expression once that sub-expression has been replaced.
type SubExpr = (Expr, Expr -> Expr)
-- | Tag recording where a term came from. The two constructors are 'PLE'
-- and 'RW' (presumably PLE unfolding vs. auto-rewriting — confirm at the
-- use sites).
data TermOrigin
  = PLE
  | RW
  deriving (Eq, Show)
-- | Pretty-prints a 'TermOrigin' as its constructor name, exactly as 'show'
-- renders it; the 'Tidy' level is ignored.
instance PPrint TermOrigin where
  pprintTidy _ origin = text (show origin)
-- | Whether the rewriter consults the ordering constraints.
-- With 'RWTerminationCheckEnabled', 'passesTerminationCheck' asks the
-- algebra and 'getRewrite' refines the constraints after each step.
-- With 'RWTerminationCheckDisabled', both skip the algebra entirely.
data RWTerminationOpts
  = RWTerminationCheckEnabled
  | RWTerminationCheckDisabled
-- | Settings threaded through the rewriter.
data RewriteArgs = RWArgs
  { isRWValid         :: Expr -> IO Bool
    -- ^ Validity oracle; 'getRewrite' uses it to check each rewrite
    -- argument's instantiated refinement.
  , rwTerminationOpts :: RWTerminationOpts
    -- ^ Whether the ordering-based termination check is performed.
  }
-- | One monomorphic wrapper around the native constraint type of every
-- supported REST ordering. Because of it, 'ordConstraints' can return a
-- single 'OCAlgebra' type whichever ordering is selected.
data OCType
  = RPO (ConstraintsADT Op)
  | LPO (ConstraintsADT Op)
  | KBO (SMTExpr Bool)
  | Fuel Int
  deriving (Eq, Show, Generic, Hashable)
-- | Build the ordering-constraint algebra for the selected REST ordering.
--
-- Each ordering's native constraint type is wrapped into 'OCType' with
-- 'bimapConstraints'. The constructor ('RPO', 'KBO', 'LPO', 'Fuel') is the
-- forward map, and the matching projection in the @where@ clause is the
-- backward map.
--
-- The projections are partial. They presumably only ever receive values
-- produced by the same algebra's forward map (TODO confirm against
-- 'bimapConstraints'). A constructor mismatch therefore indicates an
-- internal bug, and it is reported with an explicit message instead of a
-- bare 'undefined'.
--
-- The solver handles are used by the RPO, KBO and LPO orderings; the fuel
-- ordering ignores them.
ordConstraints :: RESTOrdering -> (Handle, Handle) -> OCAlgebra OCType RT.RuntimeTerm IO
ordConstraints RESTRPO solver = bimapConstraints RPO asRPO (adtRPO solver)
  where
    asRPO :: OCType -> ConstraintsADT Op
    asRPO (RPO t) = t
    asRPO _       = error "ordConstraints: expected RPO constraints"
ordConstraints RESTKBO solver = bimapConstraints KBO asKBO (kbo solver)
  where
    asKBO :: OCType -> SMTExpr Bool
    asKBO (KBO t) = t
    asKBO _       = error "ordConstraints: expected KBO constraints"
ordConstraints RESTLPO solver = bimapConstraints LPO asLPO (lift (adtOC solver) lpo)
  where
    asLPO :: OCType -> ConstraintsADT Op
    asLPO (LPO t) = t
    asLPO _       = error "ordConstraints: expected LPO constraints"
ordConstraints (RESTFuel n) _ = bimapConstraints Fuel asFuel (fuelOC n)
  where
    asFuel :: OCType -> Int
    asFuel (Fuel m) = m
    asFuel _        = error "ordConstraints: expected Fuel constraints"
-- | Encode a fixpoint 'Expr' as a REST 'RT.RuntimeTerm' for the termination
-- orderings.
--
-- * Variables, and applications headed by a variable (looking through
--   casts), become an operator named after the symbol.
-- * 'EIte', 'PNot', 'PAnd' and 'POr' become @$ite@, @$not@, @$and@ and
--   @$or@.
-- * Atoms, binary operators and constants become @$atom@, @$ebin@ and
--   @$econ@, each suffixed with the 'show' of the relation, operator or
--   constant.
-- * String literals become an operator named by their text.
-- * Casts are dropped.
-- * @p <=> q@ is encoded as an equality atom, and @p => q@ as
--   @not p || q@.
--
-- Any other expression is a runtime 'error'.
convert :: Expr -> RT.RuntimeTerm
convert expr = case expr of
  EIte i t e     -> RT.App "$ite" (map convert [i, t, e])
  EApp{}
    | (f, terms) <- splitEAppThroughECst expr
    , EVar fName <- dropECst f
    -> RT.App (named fName) (map convert terms)
  EVar s         -> RT.App (named s) []
  PNot e         -> RT.App "$not" [convert e]
  PAnd es        -> RT.App "$and" (map convert es)
  POr es         -> RT.App "$or" (map convert es)
  PAtom rel l r  -> RT.App (tagged "$atom" rel) [convert l, convert r]
  EBin op l r    -> RT.App (tagged "$ebin" op) [convert l, convert r]
  ECon c         -> RT.App (tagged "$econ" c) []
  ESym (SL tx)   -> RT.App (Op tx) []
  ECst t _       -> convert t
  PIff e0 e1     -> convert (PAtom Eq e0 e1)
  PImp e0 e1     -> convert (POr [PNot e0, e1])
  _              -> error (show expr)
  where
    -- operator named after a fixpoint symbol
    named :: Symbol -> Op
    named = Op . symbolText

    -- operator whose name is a fixed prefix followed by the shown payload
    tagged :: Show a => TX.Text -> a -> Op
    tagged prefix x = Op (prefix `TX.append` TX.pack (show x))
-- | Check the ordering constraints @c@ before rewriting further.
-- When the termination check is enabled this is 'isSat' on the algebra.
-- When it is disabled the answer is always 'True' and the algebra is not
-- consulted.
passesTerminationCheck :: OCAlgebra oc a IO -> RewriteArgs -> oc -> IO Bool
passesTerminationCheck aoc rwArgs c = decide (rwTerminationOpts rwArgs)
  where
    decide RWTerminationCheckEnabled  = isSat aoc c
    decide RWTerminationCheckDisabled = pure True
-- | Try to rewrite the sub-expression @subE@ with the auto-rewrite equation
-- @lhs = rhs@. The rebuild function @toE@ places the result back into the
-- enclosing expression.
--
-- The rewrite fails (the 'MaybeT' yields nothing) when any of these holds:
--
--  * @lhs@ does not unify with @subE@ over the rewrite's argument binders;
--  * the instantiated @rhs@ is identical to @subE@;
--  * some argument's refinement, instantiated with the unifier (and with
--    @VV@ bound to that argument), is rejected by 'isRWValid'.
--
-- On success it returns three things:
--
--  * the instantiated equation @(lhs', rhs')@;
--  * the whole rewritten expression;
--  * the ordering constraints. These are refined with the @subE@-to-@rhs'@
--    step only when the termination check is enabled.
getRewrite ::
     OCAlgebra oc Expr IO
  -> RewriteArgs
  -> oc
  -> SubExpr
  -> AutoRewrite
  -> MaybeT IO ((Expr, Expr), Expr, oc)
getRewrite aoc rwArgs c (subE, toE) (AutoRewrite args lhs rhs) = do
  su <- MaybeT (return (unify argSymbols lhs subE))
  let subE' = subst su rhs
  guard (subE /= subE')
  mapM_ (checkArg su) argRefinements
  let eqn   = (subst su lhs, subE')
      expr' = toE subE'
      c'    = case rwTerminationOpts rwArgs of
                RWTerminationCheckEnabled  -> refine aoc c subE subE'
                RWTerminationCheckDisabled -> c
  return (eqn, expr', c')
  where
    -- (binder, refinement) of every argument of the rewrite
    argRefinements = [ (s, e) | RR _ (Reft (s, e)) <- args ]

    -- the argument binders are the only symbols unification may bind
    argSymbols = map fst argRefinements

    -- instantiate one argument's refinement (binding VV to the argument)
    -- and require it to be valid
    checkArg su (s, e) =
      let vvSu = catSubst su (mkSubst [("VV", subst su (EVar s))])
      in  isValid (subst (catSubst su vvSu) e)

    isValid e = MaybeT (Just <$> isRWValid rwArgs e) >>= guard
-- | Every rewrite candidate in an expression. The expression itself comes
-- first, with 'id' as its rebuild function. It is followed by the
-- candidates found strictly inside it by 'subExprs''.
subExprs :: Expr -> [SubExpr]
subExprs e = [(e, id)] ++ subExprs' e
-- | Rewrite candidates strictly inside an expression. Each candidate is
-- paired with a function that rebuilds the whole expression around a
-- replacement for it.
--
-- Traversal, as implemented:
--
--  * 'EIte': only the condition is traversed, not the branches.
--  * 'EBin', 'PImp', 'PIff', 'PAtom': every candidate of the left operand,
--    then every candidate of the right operand (see 'binarySubExprs').
--  * 'EApp': candidates of each argument in order. Applications headed by
--    the ProofCombinators operators @===@, @==.@ and @?@ yield none.
--  * 'ECst': the 'subExprs'' candidates of the cast body, each re-wrapped
--    in the cast. The body itself is therefore not offered here.
--  * 'PAnd', 'POr': candidates of each operand (see 'subs').
--  * anything else: no candidates.
subExprs' :: Expr -> [SubExpr]
subExprs' (EIte c lhs rhs) =
  [ (e, \e' -> EIte (f e') lhs rhs) | (e, f) <- subExprs c ]
subExprs' (EBin op lhs rhs)  = binarySubExprs (EBin op) lhs rhs
subExprs' (PImp lhs rhs)     = binarySubExprs PImp lhs rhs
subExprs' (PIff lhs rhs)     = binarySubExprs PIff lhs rhs
subExprs' (PAtom op lhs rhs) = binarySubExprs (PAtom op) lhs rhs
subExprs' e@EApp{}
  | f `elem` proofCombinators = []
  | otherwise                 = concatMap replace (zip [0 ..] es)
  where
    (f, es) = splitEApp e

    -- heads whose applications are never traversed
    proofCombinators =
      [ EVar "Language.Haskell.Liquid.ProofCombinators.==="
      , EVar "Language.Haskell.Liquid.ProofCombinators.==."
      , EVar "Language.Haskell.Liquid.ProofCombinators.?"
      ]

    -- candidates inside the i-th argument; the rebuild function splices the
    -- rewritten argument back between the untouched ones
    replace :: (Int, Expr) -> [SubExpr]
    replace (i, arg) =
      [ (subArg, \subArg' -> eApps f (take i es ++ toArg subArg' : drop (i + 1) es))
      | (subArg, toArg) <- subExprs arg ]
subExprs' (ECst e t) =
  [ (e', \subE -> ECst (toE subE) t) | (e', toE) <- subExprs' e ]
subExprs' (PAnd es) = [ (e, PAnd . f) | (e, f) <- subs es ]
subExprs' (POr es)  = [ (e, POr . f) | (e, f) <- subs es ]
subExprs' _ = []

-- | Candidates of both operands of a binary node built by @mk@. All
-- candidates of the left operand come first, then those of the right. Each
-- rebuild function keeps the other operand unchanged.
binarySubExprs :: (Expr -> Expr -> Expr) -> Expr -> Expr -> [SubExpr]
binarySubExprs mk lhs rhs =
     [ (e, \e' -> mk (f e') rhs) | (e, f) <- subExprs lhs ]
  ++ [ (e, \e' -> mk lhs (f e')) | (e, f) <- subExprs rhs ]
-- | Rewrite candidates inside a list of operands (of 'PAnd' / 'POr').
-- Each candidate comes with a function that rebuilds the whole operand
-- list with just that position replaced. Candidates of earlier operands
-- come first.
subs :: [Expr] -> [(Expr, Expr -> [Expr])]
subs []       = []
subs (x : xs) =
     map (\(s, f) -> (s, \e -> f e : xs)) (subExprs x)
  ++ map (\(s, g) -> (s, \e -> x : g e)) (subs xs)
-- | Unify a list of templates with a list of observed expressions, pairwise
-- and left to right.
--
-- Bindings found for one pair are substituted into the remaining templates
-- and observed expressions. The bound symbols are then no longer free for
-- later pairs. The result is the union of all bindings.
--
-- Yields 'Nothing' in two cases: some pair fails to unify, or the two
-- lists have different lengths.
unifyAll :: [Symbol] -> [Expr] -> [Expr] -> Maybe Subst
unifyAll _ [] [] = Just (Su M.empty)
unifyAll freeVars (template : templates) (seen : seens) = do
  rs@(Su s1) <- unify freeVars template seen
  let templates' = map (subst rs) templates
      seens'     = map (subst rs) seens
  Su s2 <- unifyAll (freeVars L.\\ M.keys s1) templates' seens'
  return (Su (M.union s1 s2))
-- Lists of different lengths cannot unify. This is reachable from 'unify'
-- on 'PAnd' / 'POr' nodes with different operand counts, so report a
-- mismatch instead of crashing.
unifyAll _ _ _ = Nothing
-- | One-way (matching) unification of a rewrite template against an
-- observed expression.
--
-- Only symbols in @freeVars@ may be bound, and only on the template side.
-- Returns the binding substitution, or 'Nothing' when the two do not match.
-- As implemented:
--
--  * Syntactically equal expressions match with the empty substitution.
--  * 'dropECst' is applied to the template. A free template variable binds
--    to the observed expression as given, casts included.
--  * Two variables match when their unqualified names agree. The
--    unqualified name is the text after the last @.@.
--  * Otherwise both sides must have the same constructor, and the same
--    operator for 'EBin' and relation for 'PAtom'. Children are matched
--    recursively via 'unify' or 'unifyAll'. Sorts, binders and the extra
--    fields of 'PGrad' / 'ECoerc' are ignored.
unify :: [Symbol] -> Expr -> Expr -> Maybe Subst
unify _ template seenExpr | template == seenExpr = Just (Su M.empty)
unify freeVars template seenExpr = case (dropECst template, seenExpr) of
  -- bind the observed expression with its casts preserved
  (EVar rwVar, _) | rwVar `elem` freeVars ->
    return $ Su (M.singleton rwVar seenExpr)
  (template', _) -> case (template', dropECst seenExpr) of
    (EVar lhs, EVar rhs) | removeModName lhs == removeModName rhs ->
      Just (Su M.empty)
    (EApp templateF templateBody, EApp seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (ENeg rw, ENeg seen) ->
      unify freeVars rw seen
    (EBin op rwLeft rwRight, EBin op' seenLeft seenRight) | op == op' ->
      unifyAll freeVars [rwLeft, rwRight] [seenLeft, seenRight]
    (EIte cond rwLeft rwRight, EIte seenCond seenLeft seenRight) ->
      unifyAll freeVars [cond, rwLeft, rwRight] [seenCond, seenLeft, seenRight]
    (ECst rw _, seen) ->
      unify freeVars rw seen
    (ETApp rw _, ETApp seen _) ->
      unify freeVars rw seen
    (ETAbs rw _, ETAbs seen _) ->
      unify freeVars rw seen
    (PAnd rw, PAnd seen) ->
      unifyAll freeVars rw seen
    (POr rw, POr seen) ->
      unifyAll freeVars rw seen
    (PNot rw, PNot seen) ->
      unify freeVars rw seen
    (PImp templateF templateBody, PImp seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PIff templateF templateBody, PIff seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PAtom rel templateF templateBody, PAtom rel' seenF seenBody) | rel == rel' ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PAll _ rw, PAll _ seen) ->
      unify freeVars rw seen
    (PExist _ rw, PExist _ seen) ->
      unify freeVars rw seen
    (PGrad _ _ _ rw, PGrad _ _ _ seen) ->
      unify freeVars rw seen
    (ECoerc _ _ rw, ECoerc _ _ seen) ->
      unify freeVars rw seen
    _ -> Nothing
  where
    -- Unqualified name: every character after the last '.'; the whole
    -- string when there is none. Computed in linear time (the previous
    -- version appended one character at a time with (++), which is
    -- quadratic).
    removeModName :: Symbol -> String
    removeModName = reverse . takeWhile (/= '.') . reverse . symbolString