{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings         #-}
{-# LANGUAGE PatternGuards             #-}
{-# LANGUAGE ScopedTypeVariables       #-}

{-# OPTIONS_GHC -Wno-name-shadowing    #-}

module Language.Fixpoint.Solver.Rewrite
  ( getRewrite
  , subExprs
  , unify
  , ordConstraints
  , convert
  , passesTerminationCheck
  , RewriteArgs(..)
  , RWTerminationOpts(..)
  , SubExpr
  , TermOrigin(..)
  , OCType
  , RESTOrdering(..)
  ) where

import           Control.Monad (guard)
import           Control.Monad.Trans.Maybe
import           Data.Hashable
import qualified Data.HashMap.Strict  as M
import qualified Data.List            as L
import qualified Data.Text as TX
import           GHC.IO.Handle.Types (Handle)
import           GHC.Generics
import           Text.PrettyPrint (text)
import           Language.Fixpoint.Types.Config (RESTOrdering(..))
import           Language.Fixpoint.Types hiding (simplify)
import           Language.REST
import           Language.REST.KBO (kbo)
import           Language.REST.LPO (lpo)
import           Language.REST.OCAlgebra as OC
import           Language.REST.OCToAbstract (lift)
import           Language.REST.Op
import           Language.REST.SMT (SMTExpr)
import           Language.REST.WQOConstraints.ADT (ConstraintsADT, adtOC)
import qualified Language.REST.RuntimeTerm as RT

-- | A subexpression paired with the context it occurs in.
--
-- @(e, f)@ asserts that @e@ is a subexpression of @f e@: applying @f@ to a
-- replacement @e'@ rebuilds the enclosing expression with @e'@ in place of
-- @e@.
type SubExpr = (Expr, Expr -> Expr)

-- | Origin tag for a term: 'PLE' or 'RW'.
data TermOrigin = PLE | RW deriving (Show, Eq)

-- | Pretty-prints a 'TermOrigin' as its constructor name (via 'show').
instance PPrint TermOrigin where
  pprintTidy _ = text . show


-- | Whether rewriting consults the ordering-constraint algebra.
--
-- With 'RWTerminationCheckDisabled', 'passesTerminationCheck' always returns
-- 'True' and 'getRewrite' returns the constraints unchanged instead of
-- refining them.
data RWTerminationOpts =
    RWTerminationCheckEnabled
  | RWTerminationCheckDisabled

-- | Parameters controlling a single rewrite step.
data RewriteArgs = RWArgs
 { isRWValid          :: Expr -> IO Bool
   -- ^ Validity check for an instantiated argument refinement; 'getRewrite'
   --   rejects a rewrite when this returns 'False'.
 , rwTerminationOpts  :: RWTerminationOpts
   -- ^ Whether the ordering-constraint (termination) check is performed.
 }

-- | Ordering constraints for every supported 'RESTOrdering', monomorphized
-- into a single type so we don't litter PLE with type variables.
-- This also helps since GHC doesn't support impredicative polymorphism (yet).
data OCType =
    RPO (ConstraintsADT Op)
  | LPO (ConstraintsADT Op)
  | KBO (SMTExpr Bool)
  | Fuel Int
  deriving (Eq, Show, Generic, Hashable)

-- | Builds the ordering-constraint algebra for the requested 'RESTOrdering'.
-- The algebra's constraints are wrapped in the matching 'OCType' constructor.
--
-- The solver handles are passed to the RPO, KBO and LPO constructions; the
-- fuel-based ordering ignores them.
--
-- Each algebra unwraps only its own 'OCType' constructor. Handing it
-- constraints built by a different ordering is a programming error and
-- aborts with a message naming the ordering involved.
ordConstraints :: RESTOrdering -> (Handle, Handle) -> OCAlgebra OCType RT.RuntimeTerm IO
ordConstraints RESTRPO solver = bimapConstraints RPO asRPO (adtRPO solver)
  where
    asRPO :: OCType -> ConstraintsADT Op
    asRPO (RPO t) = t
    asRPO _       = error "ordConstraints: RPO algebra received non-RPO constraints"

ordConstraints RESTKBO solver = bimapConstraints KBO asKBO (kbo solver)
  where
    asKBO :: OCType -> SMTExpr Bool
    asKBO (KBO t) = t
    asKBO _       = error "ordConstraints: KBO algebra received non-KBO constraints"

ordConstraints RESTLPO solver = bimapConstraints LPO asLPO (lift (adtOC solver) lpo)
  where
    asLPO :: OCType -> ConstraintsADT Op
    asLPO (LPO t) = t
    asLPO _       = error "ordConstraints: LPO algebra received non-LPO constraints"

ordConstraints (RESTFuel n) _ = bimapConstraints Fuel asFuel (fuelOC n)
  where
    asFuel :: OCType -> Int
    asFuel (Fuel m) = m
    asFuel _        = error "ordConstraints: fuel algebra received non-fuel constraints"


-- | Translates a fixpoint 'Expr' into a REST 'RT.RuntimeTerm'.
--
-- * An application whose head (ignoring casts) is a variable becomes an
--   application of that variable's name.
-- * Variables and string literals become nullary applications.
-- * Logical connectives, relations, binary operators and constants are
--   encoded with reserved @$@-prefixed operator names. The relation,
--   operator or constant is appended via 'show'.
-- * Casts are dropped, @p <=> q@ is encoded as the atom @p = q@, and
--   @p => q@ as @not p || q@.
--
-- Calls 'error' on any other expression form.
convert :: Expr -> RT.RuntimeTerm
convert expr = case expr of
  EIte i t e     -> RT.App "$ite" (map convert [i, t, e])
  EApp{}
    | (f, terms) <- splitEAppThroughECst expr
    , EVar fName <- dropECst f
    -> RT.App (Op (symbolText fName)) (map convert terms)
  EVar s         -> RT.App (Op (symbolText s)) []
  PNot e         -> RT.App "$not" [convert e]
  PAnd es        -> RT.App "$and" (map convert es)
  POr es         -> RT.App "$or" (map convert es)
  PAtom rel l r  -> RT.App (tagged "$atom" rel) [convert l, convert r]
  EBin op l r    -> RT.App (tagged "$ebin" op) [convert l, convert r]
  ECon c         -> RT.App (tagged "$econ" c) []
  ESym (SL tx)   -> RT.App (Op tx) []
  ECst e _       -> convert e
  PIff e0 e1     -> convert (PAtom Eq e0 e1)
  PImp e0 e1     -> convert (POr [PNot e0, e1])
  -- An EApp whose head is not a variable also ends up here.
  _              -> error ("convert: unsupported expression " ++ show expr)
  where
    -- Operator named by a reserved prefix followed by the shown payload.
    tagged :: Show a => TX.Text -> a -> Op
    tagged prefix x = Op (prefix `TX.append` TX.pack (show x))

-- | Checks whether the ordering constraints @c@ are satisfiable according to
-- the algebra @aoc@ ('isSat').
--
-- When the termination check is disabled in @rwArgs@, the algebra is not
-- consulted and the result is always 'True'.
passesTerminationCheck :: OCAlgebra oc a IO -> RewriteArgs -> oc -> IO Bool
passesTerminationCheck aoc rwArgs c =
  case rwTerminationOpts rwArgs of
    RWTerminationCheckEnabled  -> isSat aoc c
    RWTerminationCheckDisabled -> return True

-- | Yields the result of rewriting an expression with an autorewrite equation.
--
-- Inputs:
--
-- * @aoc@: the ordering-constraint algebra.
-- * @c@: the current constraints.
-- * @(subE, toE)@: a subexpression and its context.
-- * The equation @lhs = rhs@ with arguments @args@.
--
-- Let @su@ be the substitution unifying @lhs@ with @subE@, binding only the
-- argument symbols. The result contains:
--
--  * the instantiated equation @(subst su lhs, subst su rhs)@;
--  * the whole expression, with @subE@ replaced by @subst su rhs@;
--  * the constraints refined with the step @subE -> subst su rhs@, or @c@
--    unchanged when the termination check is disabled.
--
-- Yields nothing if:
--
--  * @lhs@ does not unify with @subE@;
--  * the result of the rewrite is identical to the original expression;
--  * any of the arguments of the autorewrite has a refinement which is not
--    satisfied in the current context, according to 'isRWValid'.
--
getRewrite ::
     OCAlgebra oc Expr IO
  -> RewriteArgs
  -> oc
  -> SubExpr
  -> AutoRewrite
  -> MaybeT IO ((Expr, Expr), Expr, oc)
getRewrite aoc rwArgs c (subE, toE) (AutoRewrite args lhs rhs) =
  do
    su <- MaybeT $ return $ unify freeVars lhs subE
    let subE' = subst su rhs
    guard $ subE /= subE'
    let expr' = toE subE'
        eqn   = (subst su lhs, subE')
    mapM_ (checkSubst su) exprs
    return $ case rwTerminationOpts rwArgs of
      RWTerminationCheckEnabled  -> (eqn, expr', refine aoc c subE subE')
      RWTerminationCheckDisabled -> (eqn, expr', c)
  where
    -- Fails (yields nothing) unless isRWValid accepts the predicate.
    check :: Expr -> MaybeT IO ()
    check e = do
      valid <- MaybeT $ Just <$> isRWValid rwArgs e
      guard valid

    -- The argument symbols: the only variables unification may bind.
    freeVars :: [Symbol]
    freeVars = [s | RR _ (Reft (s, _)) <- args]

    -- Each argument symbol paired with its refinement predicate.
    exprs :: [(Symbol, Expr)]
    exprs = [(s, e) | RR _ (Reft (s, e)) <- args]

    -- Check one argument refinement under the matched substitution.
    -- "VV" is also bound to the argument's instantiation.
    checkSubst :: Subst -> (Symbol, Expr) -> MaybeT IO ()
    checkSubst su (s, e) =
      do
        let su' = catSubst su $ mkSubst [("VV", subst su (EVar s))]
        check $ subst (catSubst su su') e


-- | The expression itself (with the identity context), followed by the
-- subexpressions that 'subExprs'' considers for rewriting.
-- Each subexpression is paired with the context that rebuilds the whole
-- expression around it.
subExprs :: Expr -> [SubExpr]
subExprs e = (e, id) : subExprs' e

-- | Proper subexpressions considered for rewriting (the expression itself is
-- excluded). Each is paired with the context that rebuilds the expression
-- around it.
--
-- Only these positions are traversed:
--
--  * if-then-else: the condition only, not the branches;
--  * binary operators, implications, iffs and relations: both operands,
--    left before right;
--  * applications: every argument, left to right, unless the head is one
--    of the ProofCombinators operators @===@, @==.@ or @?@, in which case
--    nothing;
--  * casts: the proper subexpressions of the cast expression, re-wrapped in
--    the same cast;
--  * conjunctions and disjunctions: every conjunct or disjunct.
--
-- All other expressions yield no subexpressions.
subExprs' :: Expr -> [SubExpr]
subExprs' (EIte c lhs rhs) =
  [ (e, \e' -> EIte (f e') lhs rhs) | (e, f) <- subExprs c ]
subExprs' (EBin op lhs rhs)  = binarySubExprs (EBin op) lhs rhs
subExprs' (PImp lhs rhs)     = binarySubExprs PImp lhs rhs
subExprs' (PIff lhs rhs)     = binarySubExprs PIff lhs rhs
subExprs' (PAtom op lhs rhs) = binarySubExprs (PAtom op) lhs rhs
subExprs' e@EApp{}
  | f `elem` proofCombinators = []
  | otherwise                 = concatMap replace (zip [0 ..] es)
  where
    (f, es) = splitEApp e

    proofCombinators :: [Expr]
    proofCombinators =
      [ EVar "Language.Haskell.Liquid.ProofCombinators.==="
      , EVar "Language.Haskell.Liquid.ProofCombinators.==."
      , EVar "Language.Haskell.Liquid.ProofCombinators.?"
      ]

    -- Subexpressions of the i-th argument, rebuilt into the full application.
    replace :: (Int, Expr) -> [SubExpr]
    replace (i, arg) =
      [ (subArg, \subArg' -> eApps f (take i es ++ toArg subArg' : drop (i + 1) es))
      | (subArg, toArg) <- subExprs arg
      ]
subExprs' (ECst e t) =
  [ (e', \subE -> ECst (toE subE) t) | (e', toE) <- subExprs' e ]
subExprs' (PAnd es) = [ (e, PAnd . f) | (e, f) <- subs es ]
subExprs' (POr es)  = [ (e, POr . f) | (e, f) <- subs es ]
subExprs' _ = []

-- | Subexpressions of both operands of a binary node built by @mk@, with the
-- left operand's first. Each context rebuilds the node around the rewritten
-- operand and leaves the other operand unchanged.
binarySubExprs :: (Expr -> Expr -> Expr) -> Expr -> Expr -> [SubExpr]
binarySubExprs mk lhs rhs =
     [ (e, \e' -> mk (f e') rhs) | (e, f) <- subExprs lhs ]
  ++ [ (e, \e' -> mk lhs (f e')) | (e, f) <- subExprs rhs ]

-- | Computes the subexpressions of a list of expressions, in list order.
-- Each subexpression comes with a function that rebuilds the
-- context in which the subexpression occurs.
--
-- > and [ es == f e | (e, f) <- subs es ]
--
subs :: [Expr] -> [(Expr, Expr -> [Expr])]
subs []       = []
subs (x : xs) =
     [ (s, \e -> f e : xs) | (s, f) <- subExprs x ]
  ++ [ (s, \e -> x : f e) | (s, f) <- subs xs ]


-- | Unifies templates with expressions pairwise, left to right, threading
-- the substitution.
--
-- After each pair unifies, its substitution is applied to the remaining
-- templates and expressions. The variables it bound are no longer treated as
-- free.
--
-- Yields @Nothing@ if some pair fails to unify or the two lists have
-- different lengths.
unifyAll :: [Symbol] -> [Expr] -> [Expr] -> Maybe Subst
unifyAll _ [] [] = Just (Su M.empty)
unifyAll freeVars (template : xs) (seen : ys) = do
  rs@(Su s1) <- unify freeVars template seen
  let xs' = map (subst rs) xs
      ys' = map (subst rs) ys
  Su s2 <- unifyAll (freeVars L.\\ M.keys s1) xs' ys'
  return $ Su (M.union s1 s2)
-- Lists of different lengths (e.g. PAnd/POr with different numbers of
-- operands) cannot unify. This used to be 'undefined', which crashed.
unifyAll _ _ _ = Nothing

-- | @unify vs template e = Just su@ yields a substitution @su@
-- such that @subst su template == e@.
--
-- Moreover, @su@ is constrained to only substitute variables in @vs@.
-- Casts on the template are dropped. A template variable in @vs@ is bound to
-- @e@ including any casts on @e@; otherwise casts on @e@ are dropped too.
-- Variables not in @vs@ match when their names agree after stripping
-- everything up to the last @.@.
--
-- Yields @Nothing@ if no substitution exists.
--
unify :: [Symbol] -> Expr -> Expr -> Maybe Subst
unify _ template seenExpr | template == seenExpr = Just (Su M.empty)
unify freeVars template seenExpr = case (dropECst template, seenExpr) of
  -- preserve seen casts if possible
  (EVar rwVar, _) | rwVar `elem` freeVars ->
    return $ Su (M.singleton rwVar seenExpr)
  -- otherwise discard the seen casts
  (template', _) -> case (template', dropECst seenExpr) of
    (EVar lhs, EVar rhs) | removeModName lhs == removeModName rhs ->
      Just (Su M.empty)
    (EApp templateF templateBody, EApp seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (ENeg rw, ENeg seen) ->
      unify freeVars rw seen
    (EBin op rwLeft rwRight, EBin op' seenLeft seenRight) | op == op' ->
      unifyAll freeVars [rwLeft, rwRight] [seenLeft, seenRight]
    (EIte cond rwLeft rwRight, EIte seenCond seenLeft seenRight) ->
      unifyAll freeVars [cond, rwLeft, rwRight] [seenCond, seenLeft, seenRight]
    (ECst rw _, seen) ->
      unify freeVars rw seen
    (ETApp rw _, ETApp seen _) ->
      unify freeVars rw seen
    (ETAbs rw _, ETAbs seen _) ->
      unify freeVars rw seen
    (PAnd rw, PAnd seen) ->
      unifyAll freeVars rw seen
    (POr rw, POr seen) ->
      unifyAll freeVars rw seen
    (PNot rw, PNot seen) ->
      unify freeVars rw seen
    (PImp templateF templateBody, PImp seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PIff templateF templateBody, PIff seenF seenBody) ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PAtom rel templateF templateBody, PAtom rel' seenF seenBody) | rel == rel' ->
      unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
    (PAll _ rw, PAll _ seen) ->
      unify freeVars rw seen
    (PExist _ rw, PExist _ seen) ->
      unify freeVars rw seen
    (PGrad _ _ _ rw, PGrad _ _ _ seen) ->
      unify freeVars rw seen
    (ECoerc _ _ rw, ECoerc _ _ seen) ->
      unify freeVars rw seen
    _ -> Nothing
  where
    -- The part of the symbol after its last '.', or the whole symbol if it
    -- contains no '.'.
    -- NOTE(review): a name ending in '.' (e.g. the operator "==.") reduces to
    -- "", so all such names compare equal here. Confirm this is intended.
    removeModName :: Symbol -> String
    removeModName = reverse . takeWhile (/= '.') . reverse . symbolString