{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
module Language.Fixpoint.Solver.Rewrite
( getRewrite
, subExprs
, unify
, RewriteArgs(..)
, RWTerminationOpts(..)
, SubExpr
, TermOrigin(..)
) where
import Control.Monad.State
import Control.Monad.Trans.Maybe
import GHC.Generics
import Data.Hashable
import qualified Data.HashMap.Strict as M
import qualified Data.HashSet as S
import qualified Data.List as L
import qualified Data.Maybe as Mb
import Language.Fixpoint.Types hiding (simplify)
import qualified Data.Text as TX
type Op = Symbol
type OpOrdering = [Symbol]
data Term = Term Symbol [Term] deriving (Eq, Generic)
instance Hashable Term
termSym :: Term -> Symbol
termSym (Term s _) = s
instance Show Term where
show (Term op []) = TX.unpack $ symbolText op
show (Term op args) =
TX.unpack (symbolText op) ++ "(" ++ L.intercalate ", " (map show args) ++ ")"
data SCDir =
SCUp
| SCEq
| SCDown
deriving (Eq, Ord, Show, Generic)
instance Hashable SCDir
type SCPath = ((Op, Int), (Op, Int), [SCDir])
type SubExpr = (Expr, Expr -> Expr)
data SCEntry = SCEntry {
from :: (Op, Int)
, to :: (Op, Int)
, dir :: SCDir
} deriving (Eq, Ord, Show, Generic)
instance Hashable SCEntry
getDir :: OpOrdering -> Term -> Term -> SCDir
getDir o from to =
case (synGTE o from to, synGTE o to from) of
(True, True) -> SCEq
(True, False) -> SCDown
(False, _) -> SCUp
getSC :: OpOrdering -> Term -> Term -> S.HashSet SCEntry
getSC o (Term op ts) (Term op' us) =
S.fromList $ do
(i, from) <- zip [0..] ts
(j, to) <- zip [0..] us
return $ SCEntry (op, i) (op', j) (getDir o from to)
scp :: OpOrdering -> [Term] -> S.HashSet SCPath
scp _ [] = S.empty
scp _ [_] = S.empty
scp o [t1, t2] = S.fromList $ do
(SCEntry a b d) <- S.toList $ getSC o t1 t2
return (a, b, [d])
scp o (t1:t2:trms) = S.fromList $ do
(SCEntry a b' d) <- S.toList $ getSC o t1 t2
(a', b, ds) <- S.toList $ scp o (t2:trms)
guard $ b' == a'
return (a, b, d:ds)
synEQ :: OpOrdering -> Term -> Term -> Bool
synEQ o l r = synGTE o l r && synGTE o r l
opGT :: OpOrdering -> Op -> Op -> Bool
opGT ordering op1 op2 = case (L.elemIndex op1 ordering, L.elemIndex op2 ordering) of
(Just index1, Just index2) -> index1 < index2
(Just _, Nothing) -> True
_ -> False
removeSynEQs :: OpOrdering -> [Term] -> [Term] -> ([Term], [Term])
removeSynEQs _ [] ys = ([], ys)
removeSynEQs ordering (x:xs) ys
| Just yIndex <- L.findIndex (synEQ ordering x) ys
= removeSynEQs ordering xs $ take yIndex ys ++ drop (yIndex + 1) ys
| otherwise =
let
(xs', ys') = removeSynEQs ordering xs ys
in
(x:xs', ys')
synGTEM :: OpOrdering -> [Term] -> [Term] -> Bool
synGTEM ordering xs ys =
case removeSynEQs ordering xs ys of
(_ , []) -> True
(xs', ys') -> any (\x -> all (synGT ordering x) ys') xs'
synGT :: OpOrdering -> Term -> Term -> Bool
synGT o t1 t2 = synGTE o t1 t2 && not (synGTE o t2 t1)
synGTM :: OpOrdering -> [Term] -> [Term] -> Bool
synGTM o t1 t2 = synGTEM o t1 t2 && not (synGTEM o t2 t1)
synGTE :: OpOrdering -> Term -> Term -> Bool
synGTE ordering t1@(Term x tms) t2@(Term y tms') =
if opGT ordering x y then
synGTM ordering [t1] tms'
else if opGT ordering y x then
synGTEM ordering tms [t2]
else
synGTEM ordering tms tms'
subsequencesOfSize :: Int -> [a] -> [[a]]
subsequencesOfSize n xs = let l = length xs
in if n>l then [] else subsequencesBySize xs !! (l-n)
where
subsequencesBySize [] = [[[]]]
subsequencesBySize (x:xs) = let next = subsequencesBySize xs
in zipWith (++) ([]:next) (map (map (x:)) next ++ [[]])
data TermOrigin = PLE | RW OpOrdering deriving (Show, Eq)
data DivergeResult = Diverging | NotDiverging OpOrdering
fromRW :: TermOrigin -> Bool
fromRW (RW _) = True
fromRW PLE = False
getOrdering :: TermOrigin -> Maybe OpOrdering
getOrdering (RW o) = Just o
getOrdering PLE = Nothing
diverges :: Maybe Int -> [(Term, TermOrigin)] -> Term -> DivergeResult
diverges maxOrderingConstraints path term = go 0
where
path' = map fst path ++ [term]
go n | n > length syms'
|| n > Mb.fromMaybe (length syms') maxOrderingConstraints = Diverging
go n = case L.find (not . diverges') (orderings' n) of
Just ordering -> NotDiverging ordering
Nothing -> go (n + 1)
ops (Term o xs) = o:concatMap ops xs
syms' = L.nub $ concatMap ops path'
suggestedOrderings :: [OpOrdering]
suggestedOrderings =
reverse $ Mb.catMaybes $ map (getOrdering . snd) path
orderings' n =
suggestedOrderings ++ concatMap L.permutations ((subsequencesOfSize n) syms')
diverges' o = divergesFor o path term
divergesFor :: OpOrdering -> [(Term, TermOrigin)] -> Term -> Bool
divergesFor o path term = any diverges' terms'
where
terms = map fst path ++ [term]
lastRWIndex =
Mb.fromMaybe 0 (fmap fst $ L.find (fromRW . snd . snd) $ reverse $ zip [1..] path)
okTerms = take lastRWIndex terms
checkTerms = drop lastRWIndex terms
terms' = L.subsequences checkTerms ++ do
firstpart <- L.tails okTerms
secondpart <- L.inits checkTerms
return $ firstpart ++ secondpart
diverges' :: [Term] -> Bool
diverges' trms' =
if length trms' <= 1 || termSym (head trms') /= termSym (last trms') then
False
else
any ascending (scp o trms') && all (not . descending) (scp o trms')
descending :: SCPath -> Bool
descending (a, b, ds) = a == b && L.elem SCDown ds && L.notElem SCUp ds
ascending :: SCPath -> Bool
ascending (a, b, ds) = a == b && L.elem SCUp ds
data RWTerminationOpts =
RWTerminationCheckEnabled (Maybe Int)
| RWTerminationCheckDisabled
data RewriteArgs = RWArgs
{ isRWValid :: Expr -> IO Bool
, rwTerminationOpts :: RWTerminationOpts
}
getRewrite :: RewriteArgs -> [(Expr, TermOrigin)] -> SubExpr -> AutoRewrite -> MaybeT IO (Expr, TermOrigin)
getRewrite rwArgs path (subE, toE) (AutoRewrite args lhs rhs) =
do
su <- MaybeT $ return $ unify freeVars lhs subE
let subE' = subst su rhs
let expr' = toE subE'
guard $ all ( (/= expr') . fst) path
mapM_ (check . subst su) exprs
let termPath = map (\(t, o) -> (convert t, o)) path
case rwTerminationOpts rwArgs of
RWTerminationCheckEnabled maxConstraints ->
case diverges maxConstraints termPath (convert expr') of
NotDiverging opOrdering ->
return (expr', RW opOrdering)
Diverging ->
mzero
RWTerminationCheckDisabled -> return (expr', RW [])
where
convert (EIte i t e) = Term "$ite" $ map convert [i,t,e]
convert (EApp (EVar s) (EVar var))
| dcPrefix `isPrefixOfSym` s
= Term (symbol $ TX.concat [symbolText s, "$", symbolText var]) []
convert e@(EApp{}) | (EVar fName, terms) <- splitEApp e
= Term fName $ map convert terms
convert (EVar s) = Term s []
convert (PAnd es) = Term "$and" $ map convert es
convert (POr es) = Term "$or" $ map convert es
convert (PAtom s l r) = Term (symbol $ "$atom" ++ show s) [convert l, convert r]
convert (EBin o l r) = Term (symbol $ "$ebin" ++ show o) [convert l, convert r]
convert (ECon c) = Term (symbol $ "$econ" ++ show c) []
convert e = error (show e)
check :: Expr -> MaybeT IO ()
check e = do
valid <- MaybeT $ Just <$> isRWValid rwArgs e
guard valid
dcPrefix = "lqdc"
freeVars = [s | RR _ (Reft (s, _)) <- args ]
exprs = [e | RR _ (Reft (_, e)) <- args ]
subExprs :: Expr -> [SubExpr]
subExprs e = (e,id):subExprs' e
subExprs' :: Expr -> [SubExpr]
subExprs' (EIte c lhs rhs) = c'' ++ l'' ++ r''
where
c' = subExprs c
l' = subExprs lhs
r' = subExprs rhs
c'' = map (\(e, f) -> (e, \e' -> EIte (f e') lhs rhs)) c'
l'' = map (\(e, f) -> (e, \e' -> EIte c (f e') rhs)) l'
r'' = map (\(e, f) -> (e, \e' -> EIte c lhs (f e'))) r'
subExprs' (EBin op lhs rhs) = lhs'' ++ rhs''
where
lhs' = subExprs lhs
rhs' = subExprs rhs
lhs'' :: [SubExpr]
lhs'' = map (\(e, f) -> (e, \e' -> EBin op (f e') rhs)) lhs'
rhs'' :: [SubExpr]
rhs'' = map (\(e, f) -> (e, \e' -> EBin op lhs (f e'))) rhs'
subExprs' (PImp lhs rhs) = lhs'' ++ rhs''
where
lhs' = subExprs lhs
rhs' = subExprs rhs
lhs'' :: [SubExpr]
lhs'' = map (\(e, f) -> (e, \e' -> PImp (f e') rhs)) lhs'
rhs'' :: [SubExpr]
rhs'' = map (\(e, f) -> (e, \e' -> PImp lhs (f e'))) rhs'
subExprs' (PAtom op lhs rhs) = lhs'' ++ rhs''
where
lhs' = subExprs lhs
rhs' = subExprs rhs
lhs'' :: [SubExpr]
lhs'' = map (\(e, f) -> (e, \e' -> PAtom op (f e') rhs)) lhs'
rhs'' :: [SubExpr]
rhs'' = map (\(e, f) -> (e, \e' -> PAtom op lhs (f e'))) rhs'
subExprs' _ = []
unifyAll :: [Symbol] -> [Expr] -> [Expr] -> Maybe Subst
unifyAll _ [] [] = Just (Su M.empty)
unifyAll freeVars (template:xs) (seen:ys) =
do
rs@(Su s1) <- unify freeVars template seen
let xs' = map (subst rs) xs
let ys' = map (subst rs) ys
(Su s2) <- unifyAll (freeVars L.\\ M.keys s1) xs' ys'
return $ Su (M.union s1 s2)
unifyAll _ _ _ = undefined
unify :: [Symbol] -> Expr -> Expr -> Maybe Subst
unify _ template seenExpr | template == seenExpr = Just (Su M.empty)
unify freeVars template seenExpr = case (template, seenExpr) of
(EVar rwVar, _) | rwVar `elem` freeVars ->
return $ Su (M.singleton rwVar seenExpr)
(EApp templateF templateBody, EApp seenF seenBody) ->
unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
(ENeg rw, ENeg seen) ->
unify freeVars rw seen
(EBin op rwLeft rwRight, EBin op' seenLeft seenRight) | op == op' ->
unifyAll freeVars [rwLeft, rwRight] [seenLeft, seenRight]
(EIte cond rwLeft rwRight, EIte seenCond seenLeft seenRight) ->
unifyAll freeVars [cond, rwLeft, rwRight] [seenCond, seenLeft, seenRight]
(ECst rw _, ECst seen _) ->
unify freeVars rw seen
(ETApp rw _, ETApp seen _) ->
unify freeVars rw seen
(ETAbs rw _, ETAbs seen _) ->
unify freeVars rw seen
(PAnd rw, PAnd seen ) ->
unifyAll freeVars rw seen
(POr rw, POr seen ) ->
unifyAll freeVars rw seen
(PNot rw, PNot seen) ->
unify freeVars rw seen
(PImp templateF templateBody, PImp seenF seenBody) ->
unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
(PIff templateF templateBody, PIff seenF seenBody) ->
unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
(PAtom rel templateF templateBody, PAtom rel' seenF seenBody) | rel == rel' ->
unifyAll freeVars [templateF, templateBody] [seenF, seenBody]
(PAll _ rw, PAll _ seen) ->
unify freeVars rw seen
(PExist _ rw, PExist _ seen) ->
unify freeVars rw seen
(PGrad _ _ _ rw, PGrad _ _ _ seen) ->
unify freeVars rw seen
(ECoerc _ _ rw, ECoerc _ _ seen) ->
unify freeVars rw seen
_ -> Nothing