{-# LANGUAGE DeriveGeneric             #-}
{-# LANGUAGE OverloadedStrings         #-}
{-# LANGUAGE PatternGuards             #-}
{-# LANGUAGE ScopedTypeVariables       #-}

module Language.Fixpoint.Solver.Rewrite
  ( getRewrite
  -- , getRewrite'
  , subExprs
  , unify
  , ordConstraints
  , convert
  , passesTerminationCheck
  , RewriteArgs(..)
  , RWTerminationOpts(..)
  , SubExpr
  , TermOrigin(..)
  ) where

import           Control.Monad.State
import           Control.Monad.Trans.Maybe
import qualified Data.HashMap.Strict  as M
import qualified Data.List            as L
import qualified Data.Text as TX
import           GHC.IO.Handle.Types (Handle)
import           Text.PrettyPrint (text)
import           Language.Fixpoint.Types hiding (simplify)
import           Language.REST
import           Language.REST.AbstractOC
import qualified Language.REST.RuntimeTerm as RT
import           Language.REST.Op
import           Language.REST.OrderingConstraints.ADT (ConstraintsADT)

-- | A focused subexpression together with its context: a function that
-- rebuilds the enclosing expression when given a replacement for the
-- subexpression (see 'subExprs').
type SubExpr = (Expr, Expr -> Expr)

-- | Tag recording where a term came from.
-- NOTE(review): judging by the names, 'PLE' marks terms produced by PLE
-- unfolding and 'RW' terms produced by rewriting -- confirm with callers.
data TermOrigin
  = PLE
  | RW
  deriving (Eq, Show)

-- | Pretty-prints a 'TermOrigin' as its constructor name; the tidiness
-- setting is ignored.
instance PPrint TermOrigin where
  pprintTidy _tidy origin = text (show origin)


-- | Whether 'getRewrite' refines, and 'passesTerminationCheck' consults, the
-- termination ordering constraints.
data RWTerminationOpts
  = RWTerminationCheckEnabled
  | RWTerminationCheckDisabled

-- | Caller-supplied configuration for 'getRewrite'.
data RewriteArgs = RWArgs
  { isRWValid         :: Expr -> IO Bool
    -- ^ Validity oracle: 'getRewrite' rejects a rewrite unless this returns
    -- 'True' for every instantiated argument refinement of the rule.
  , rwTerminationOpts :: RWTerminationOpts
    -- ^ Whether rewriting consults and refines the termination ordering.
  }

-- | Ordering constraints over 'Expr': REST's 'adtRPO' ordering, built from
-- the given handle pair (presumably the pipes of an external solver process
-- -- confirm), applied to each expression's 'RT.RuntimeTerm' encoding as
-- produced by 'convert'.
ordConstraints :: (Handle, Handle) -> AbstractOC (ConstraintsADT Op) Expr IO
ordConstraints = contramap convert . adtRPO


-- | Encode a fixpoint 'Expr' as a REST 'RT.RuntimeTerm', the representation
-- that the termination ordering built by 'ordConstraints' works on.
--
--   * An application whose head (after 'splitEApp') is a variable becomes an
--     application of an operator named after that variable.
--   * Variables and string literals become nullary operators named by their
--     text; constants become nullary @$econ...@ operators.
--   * Connectives and operators become applications of reserved
--     @$@-prefixed operators; for atoms and binary operators the relation or
--     operator is folded into the operator name with 'show'.
--   * Sort casts ('ECst') are erased.
--
-- Any other expression is rejected with 'error'.
convert :: Expr -> RT.RuntimeTerm
convert (EIte i t e)   = RT.App "$ite" (map convert [i, t, e])
convert e@EApp{}
  | (EVar fName, terms) <- splitEApp e
                       = RT.App (Op (symbolText fName)) (map convert terms)
convert (EVar s)       = RT.App (Op (symbolText s)) []
convert (PNot e)       = RT.App "$not" [convert e]
convert (PAnd es)      = RT.App "$and" (map convert es)
convert (POr es)       = RT.App "$or" (map convert es)
-- 'unify' matches implications and bi-implications, and 'subExprs''
-- descends into implications. Both sides of a rewrite step are converted
-- when the ordering from 'ordConstraints' is refined, so these connectives
-- must be encodable. Previously they hit the 'error' fallback and crashed
-- the termination check.
convert (PImp l r)     = RT.App "$imp" [convert l, convert r]
convert (PIff l r)     = RT.App "$iff" [convert l, convert r]
convert (PAtom s l r)  =
  RT.App (Op ("$atom" `TX.append` TX.pack (show s))) [convert l, convert r]
convert (EBin o l r)   =
  RT.App (Op ("$ebin" `TX.append` TX.pack (show o))) [convert l, convert r]
convert (ECon c)       =
  RT.App (Op ("$econ" `TX.append` TX.pack (show c))) []
convert (ESym (SL tx)) = RT.App (Op tx) []
convert (ECst t _)     = convert t
convert e              = error (show e)

-- | Ask the ordering @aoc@ whether the constraints @c@ are satisfiable. When
-- the termination check is disabled, every constraint set is accepted.
passesTerminationCheck :: AbstractOC oc a IO -> RewriteArgs -> oc -> IO Bool
passesTerminationCheck aoc rwArgs c
  | RWTerminationCheckDisabled <- rwTerminationOpts rwArgs = pure True
  | otherwise                                               = isSat aoc c

-- | Try to apply the rewrite rule @AutoRewrite args lhs rhs@ at the focused
-- subexpression @subE@. The context @toE@ plugs a replacement back into the
-- enclosing expression.
--
-- The step succeeds only if all of the following hold:
--
--   * @lhs@ unifies with @subE@, with the symbols bound by @args@ acting as
--     pattern variables;
--   * the instantiated @rhs@ differs from @subE@;
--   * for every argument, its refinement is reported valid by 'isRWValid'.
--     The refinement is instantiated by the unifier, with @VV@ bound to the
--     argument's instantiation.
--
-- On success, returns the rewritten whole expression and the ordering
-- constraints. When the termination check is enabled, the constraints are
-- refined with the step from @subE@ to the instantiated @rhs@; otherwise
-- they are returned unchanged.
getRewrite ::
     AbstractOC oc Expr IO
  -> RewriteArgs
  -> oc
  -> SubExpr
  -> AutoRewrite
  -> MaybeT IO (Expr, oc)
getRewrite aoc rwArgs c (subE, toE) (AutoRewrite args lhs rhs) = do
  su <- MaybeT (pure (unify patternVars lhs subE))
  let rewritten     = subst su rhs
      rewrittenExpr = toE rewritten
  guard (subE /= rewritten)
  mapM_ (validArg su) argRefts
  -- The case stays inside 'pure' so the ordering is only refined on demand.
  pure $ case rwTerminationOpts rwArgs of
    RWTerminationCheckEnabled  -> (rewrittenExpr, refine aoc c subE rewritten)
    RWTerminationCheckDisabled -> (rewrittenExpr, c)
  where
    -- (binder, refinement) for each rule argument.
    argRefts    = [ (s, e) | RR _ (Reft (s, e)) <- args ]
    patternVars = map fst argRefts

    -- Fail the rewrite unless the argument's instantiated refinement is
    -- valid according to the caller's oracle.
    validArg su (s, e) = do
      let su'          = catSubst su (mkSubst [("VV", subst su (EVar s))])
          instantiated = subst (catSubst su su') e
      -- liftIO $ printf "Substitute %s in %s\n" (show su') (show e)
      ok <- MaybeT (fmap Just (isRWValid rwArgs instantiated))
      guard ok


-- | All positions of an expression at which a rewrite may be attempted.
-- The whole expression comes first, with the identity context, followed by
-- the positions found by 'subExprs''.
subExprs :: Expr -> [SubExpr]
subExprs e = [(e, id)] ++ subExprs' e

-- | Proper rewrite positions of an expression. Each position is paired with
-- a function that rebuilds the whole expression from a replacement for it.
--
-- Only some constructors are descended into:
--
--   * if-then-else: positions inside the condition only; the branches are
--     not explored;
--   * binary operators, implications and atoms: positions in the left
--     operand, then positions in the right operand;
--   * applications: positions inside each argument, left to right. If the
--     head is one of the ProofCombinators operators @===@, @==.@ or @?@,
--     nothing is explored.
--
-- Every other expression has no proper positions.
subExprs' :: Expr -> [SubExpr]
subExprs' (EIte c lhs rhs) =
  [ (e, \e' -> EIte (f e') lhs rhs) | (e, f) <- subExprs c ]

subExprs' (EBin op lhs rhs) =
     [ (e, \e' -> EBin op (f e') rhs) | (e, f) <- subExprs lhs ]
  ++ [ (e, \e' -> EBin op lhs (f e')) | (e, f) <- subExprs rhs ]

subExprs' (PImp lhs rhs) =
     [ (e, \e' -> PImp (f e') rhs) | (e, f) <- subExprs lhs ]
  ++ [ (e, \e' -> PImp lhs (f e')) | (e, f) <- subExprs rhs ]

subExprs' (PAtom op lhs rhs) =
     [ (e, \e' -> PAtom op (f e') rhs) | (e, f) <- subExprs lhs ]
  ++ [ (e, \e' -> PAtom op lhs (f e')) | (e, f) <- subExprs rhs ]

subExprs' e@EApp{}
  | f `elem` proofCombinators = []
  | otherwise =
      [ (subArg, \subArg' -> eApps f (take i es ++ toArg subArg' : drop (i + 1) es))
      | (i, arg)        <- zip [0 ..] es
      , (subArg, toArg) <- subExprs arg
      ]
  where
    (f, es) = splitEApp e
    -- Heads whose arguments are never rewritten.
    proofCombinators =
      map EVar
        [ "Language.Haskell.Liquid.ProofCombinators.==="
        , "Language.Haskell.Liquid.ProofCombinators.==."
        , "Language.Haskell.Liquid.ProofCombinators.?"
        ]

subExprs' _ = []

-- | Unify two lists of expressions pairwise, left to right.
--
-- After each pair unifies, its substitution is applied to the remaining
-- templates and seen expressions, and the variables it bound are dropped
-- from the free set. Later occurrences of an already-bound pattern variable
-- are therefore matched against its binding rather than rebound.
--
-- Returns the union of the substitutions. Returns 'Nothing' if some pair
-- fails to unify or if the lists have different lengths.
unifyAll :: [Symbol] -> [Expr] -> [Expr] -> Maybe Subst
unifyAll _ [] [] = Just (Su M.empty)
unifyAll freeVars (template : xs) (seen : ys) = do
  rs@(Su s1) <- unify freeVars template seen
  let xs' = map (subst rs) xs
      ys' = map (subst rs) ys
  Su s2 <- unifyAll (freeVars L.\\ M.keys s1) xs' ys'
  return (Su (M.union s1 s2))
-- Lists of different lengths cannot match, e.g. 'PAnd'/'POr' with a
-- different number of conjuncts/disjuncts. This is an ordinary unification
-- failure, not an impossible case; it used to be 'undefined' and crashed.
unifyAll _ _ _ = Nothing

-- | @unify freeVars template seen@ looks for a substitution of the symbols
-- in @freeVars@ that makes @template@ match @seen@.
--
--   * Structurally equal expressions unify with the empty substitution.
--   * A template variable listed in @freeVars@ binds to the whole
--     corresponding seen expression.
--   * Otherwise, two variables unify (with no bindings) when their names
--     agree after dropping everything up to and including the last @'.'@.
--   * Composite expressions unify when their constructors agree, including
--     the operator or relation where present, and their children unify.
--     Several children are unified pairwise via 'unifyAll'.
--   * Sort annotations, type arguments, binders, and gradual/coercion
--     metadata are ignored.
--   * Every other pairing yields 'Nothing'.
unify :: [Symbol] -> Expr -> Expr -> Maybe Subst
unify freeVars template seenExpr
  | template == seenExpr = noBindings
  | otherwise            = go template seenExpr
  where
    noBindings = Just (Su M.empty)
    pairwise   = unifyAll freeVars
    under      = unify freeVars

    -- Clause order matters only for the two 'EVar' clauses: a failing guard
    -- falls through to the later clauses and finally to 'Nothing'.
    go (EVar rwVar) seen
      | rwVar `elem` freeVars               = Just (Su (M.singleton rwVar seen))
    go (EVar l) (EVar r)
      | removeModName l == removeModName r  = noBindings
    go (EApp tf tb) (EApp sf sb)            = pairwise [tf, tb] [sf, sb]
    go (ENeg t) (ENeg s)                    = under t s
    go (EBin op tl tr) (EBin op' sl sr)
      | op == op'                           = pairwise [tl, tr] [sl, sr]
    go (EIte tc tl tr) (EIte sc sl sr)      = pairwise [tc, tl, tr] [sc, sl, sr]
    go (ECst t _) (ECst s _)                = under t s
    go (ETApp t _) (ETApp s _)              = under t s
    go (ETAbs t _) (ETAbs s _)              = under t s
    go (PAnd ts) (PAnd ss)                  = pairwise ts ss
    go (POr ts) (POr ss)                    = pairwise ts ss
    go (PNot t) (PNot s)                    = under t s
    go (PImp tl tr) (PImp sl sr)            = pairwise [tl, tr] [sl, sr]
    go (PIff tl tr) (PIff sl sr)            = pairwise [tl, tr] [sl, sr]
    go (PAtom rel tl tr) (PAtom rel' sl sr)
      | rel == rel'                         = pairwise [tl, tr] [sl, sr]
    go (PAll _ t) (PAll _ s)                = under t s
    go (PExist _ t) (PExist _ s)            = under t s
    go (PGrad _ _ _ t) (PGrad _ _ _ s)      = under t s
    go (ECoerc _ _ t) (ECoerc _ _ s)        = under t s
    go _ _                                  = Nothing

    -- The text after the last '.', or the whole name if it has no '.'.
    -- NOTE(review): a name ending in '.' (e.g. the operator "==.") maps to
    -- "", so any two such names compare equal here.
    removeModName = reverse . takeWhile (/= '.') . reverse . symbolString