{-# LANGUAGE TypeFamilies, PatternGuards, RecordWildCards, ScopedTypeVariables #-}
module Twee.Proof(
Proof, Derivation(..), Axiom(..),
certify, equation, derivation,
lemma, axiom, symm, trans, cong, congPath,
simplify, usedLemmas, usedAxioms, usedLemmasAndSubsts, usedAxiomsAndSubsts,
Config(..), defaultConfig, Presentation(..),
ProvedGoal(..), provedGoal, checkProvedGoal,
pPrintPresentation, present, describeEquation) where
import Twee.Base hiding (invisible)
import Twee.Equation
import Twee.Utils
import Control.Monad
import Data.Maybe
import Data.List
import Data.Ord
import qualified Data.Set as Set
import qualified Data.Map.Strict as Map
-- | A checked proof: a derivation paired with the equation it
-- establishes. The constructor is not exported, so outside this module
-- values are obtained from 'certify', which computes the equation.
data Proof f = Proof
  { equation :: !(Equation f)
    -- ^ The equation proved by 'derivation'.
  , derivation :: !(Derivation f)
    -- ^ The proof term itself.
  }
  deriving Show
-- | An unchecked proof term for an equation between terms. Use
-- 'certify' to compute (and check) the equation it proves.
data Derivation f
  = -- | An instance of a previously proved lemma.
    UseLemma {-# UNPACK #-} !(Proof f) !(Subst f)
    -- | An instance of an axiom.
  | UseAxiom {-# UNPACK #-} !(Axiom f) !(Subst f)
    -- | Reflexivity: proves @t = t@.
  | Refl !(Term f)
    -- | Symmetry: from @t = u@, proves @u = t@.
  | Symm !(Derivation f)
    -- | Transitivity: from @t = u@ and @u = v@, proves @t = v@.
  | Trans !(Derivation f) !(Derivation f)
    -- | Congruence: from @t_i = u_i@, proves @f(t_1..t_n) = f(u_1..u_n)@.
  | Cong {-# UNPACK #-} !(Fun f) ![Derivation f]
  deriving (Eq, Show)
-- | A numbered, named axiom.
data Axiom f = Axiom
  { axiom_number :: {-# UNPACK #-} !Int
    -- ^ The axiom's number.
  , axiom_name :: !String
    -- ^ The axiom's name.
  , axiom_eqn :: !(Equation f)
    -- ^ The equation the axiom asserts.
  }
  deriving (Eq, Ord, Show)
{-# INLINEABLE certify #-}
-- | Check a derivation and pair it with the equation it proves.
--
-- The equation is computed bottom-up. The only possible failure is a
-- 'Trans' step whose first part ends at a different term than the one
-- its second part starts from. In that case this calls 'error' with the
-- pretty-printed derivation.
certify :: PrettyTerm f => Derivation f -> Proof f
certify p =
  {-# SCC certify #-}
  case infer p of
    Just eqn -> Proof eqn p
    Nothing -> error ("Invalid proof created!\n" ++ prettyShow p)
  where
    -- The equation proved by a derivation, or Nothing if it is ill-formed.
    infer (UseLemma lem sub) = Just (subst sub (equation lem))
    infer (UseAxiom ax sub) = Just (subst sub (axiom_eqn ax))
    infer (Refl t) = Just (t :=: t)
    infer (Symm d) = do
      lhs :=: rhs <- infer d
      Just (rhs :=: lhs)
    infer (Trans d e) = do
      a :=: b <- infer d
      b' :=: c <- infer e
      if b == b' then Just (a :=: c) else Nothing
    infer (Cong f ds) = do
      eqns <- traverse infer ds
      let side which = build (app f (map which eqns))
      Just (side eqn_lhs :=: side eqn_rhs)
-- | Two proofs are equal when 'compare' says so, i.e. when they prove the
-- same equation. Their derivations are ignored.
instance Eq (Proof f) where
  p == q = EQ == compare p q
-- | Proofs are ordered solely by the equation they prove.
instance Ord (Proof f) where
  compare p q = compare (equation p) (equation q)
-- | Terms occurring in a derivation, and substitution into it.
--
-- 'termsDL' only looks at the substitutions of lemma and axiom uses, not
-- at the lemma or axiom equations. 'subst_' rebuilds with the smart
-- constructors 'symm', 'trans' and 'cong', so it may also simplify.
instance Symbolic (Derivation f) where
  type ConstantOf (Derivation f) = f

  termsDL d =
    case d of
      UseLemma _ s -> termsDL s
      UseAxiom _ s -> termsDL s
      Refl t -> termsDL t
      Symm e -> termsDL e
      Trans e1 e2 -> mplus (termsDL e1) (termsDL e2)
      Cong _ es -> termsDL es

  subst_ sub d =
    case d of
      UseLemma lem s -> UseLemma lem (subst_ sub s)
      UseAxiom ax s -> UseAxiom ax (subst_ sub s)
      Refl t -> Refl (subst_ sub t)
      Symm e -> symm (subst_ sub e)
      Trans e1 e2 -> trans (subst_ sub e1) (subst_ sub e2)
      Cong f es -> cong f (subst_ sub es)
-- | Prints a proof step by step using 'defaultConfig'. Axioms are
-- referred to by their number, lemmas by their equation.
instance Function f => Pretty (Proof f) where
  pPrint p =
    pPrintLemma defaultConfig
      (\ax -> prettyShow (axiom_number ax))
      (\lem -> prettyShow (equation lem))
      p
-- | Prints a derivation as a nested proof term, e.g. @trans(refl(t), ...)@.
instance PrettyTerm f => Pretty (Derivation f) where
  pPrint d =
    case d of
      UseLemma lem sub ->
        call "subst" [text "lemma" <#> pPrint (equation lem), pPrint sub]
      UseAxiom ax sub -> call "subst" [pPrint ax, pPrint sub]
      Refl t -> call "refl" [pPrint t]
      Symm e -> call "symm" [pPrint e]
      Trans e1 e2 -> call "trans" [pPrint e1, pPrint e2]
      Cong f es -> call "cong" (pPrint f : map pPrint es)
    where
      -- A rule name followed by its arguments as a tuple.
      call name args = text name <#> pPrintTuple args
-- | Prints an axiom as @axiom(number, name, equation)@.
instance PrettyTerm f => Pretty (Axiom f) where
  pPrint ax =
    text "axiom" <#>
      pPrintTuple
        [ pPrint (axiom_number ax)
        , text (axiom_name ax)
        , pPrint (axiom_eqn ax)
        ]
-- | Inline lemmas throughout a derivation.
--
-- For every lemma use, the callback either supplies a replacement
-- derivation ('Just') or keeps the lemma ('Nothing'). A replacement is
-- instantiated with the substitution of that lemma use and then
-- simplified in turn. The replacement's variables that the substitution
-- does not bind are first passed to 'erase'. The rest of the derivation
-- is rebuilt with the smart constructors.
simplify :: Minimal f => (Proof f -> Maybe (Derivation f)) -> Derivation f -> Derivation f
simplify lem = go
  where
    go d =
      case d of
        UseLemma q sub -> maybe d (expand sub) (lem q)
        Symm e -> symm (go e)
        Trans e1 e2 -> trans (go e1) (go e2)
        Cong f es -> cong f (map go es)
        _ -> d

    expand sub r = go (subst sub (erase unbound r))
      where
        unbound = usort (vars r) \\ substDomain sub
-- | Use a previously proved lemma, instantiated by the given substitution.
lemma :: Proof f -> Subst f -> Derivation f
lemma = UseLemma
-- | Use an axiom uninstantiated: each variable of its equation is mapped
-- to itself.
axiom :: Axiom f -> Derivation f
axiom ax = UseAxiom ax identity
  where
    -- fromJust assumes listToSubst accepts this identity mapping.
    identity =
      fromJust (listToSubst [ (v, build (var v)) | v <- vars (axiom_eqn ax) ])
-- | Symmetry, with simplification:
--
-- * 'Refl' is its own inverse.
-- * A double 'Symm' cancels.
-- * 'Symm' is pushed through 'Trans' (reversing the order) and 'Cong'.
symm :: Derivation f -> Derivation f
symm d =
  case d of
    Refl _ -> d
    Symm e -> e
    Trans e1 e2 -> trans (symm e2) (symm e1)
    Cong f es -> cong f (map symm es)
    _ -> Symm d
-- | Transitivity, with simplification:
--
-- * 'Refl' on either side is dropped.
-- * 'Trans' chains are right-associated.
-- * Adjacent 'Cong' steps over the same function symbol are merged
--   argument-wise.
trans :: Derivation f -> Derivation f -> Derivation f
trans d e =
  case (d, e) of
    (Refl{}, _) -> e
    (_, Refl{}) -> d
    (Trans d1 d2, _) -> Trans d1 (trans d2 e)
    (Cong f ds, Cong g es)
      | f == g -> transCong f ds es
    (Cong f ds, Trans (Cong g es) rest)
      | f == g -> trans (transCong f ds es) rest
    _ -> Trans d e
-- | Compose two congruence steps over the same function symbol by
-- composing their arguments pairwise.
transCong :: Fun f -> [Derivation f] -> [Derivation f] -> Derivation f
transCong f ps qs =
  cong f [ trans p q | (p, q) <- zip ps qs ]
-- | Congruence. If every argument is 'Refl', the result collapses to a
-- single 'Refl' of the whole term.
cong :: Fun f -> [Derivation f] -> Derivation f
cong f ps =
  maybe (Cong f ps) (\ts -> Refl (build (app f ts))) (traverse asRefl ps)
  where
    asRefl (Refl t) = Just t
    asRefl _ = Nothing
-- | The lemmas used directly by a derivation, in order of occurrence,
-- with repetitions.
usedLemmas :: Derivation f -> [Proof f]
usedLemmas = map fst . usedLemmasAndSubsts
-- | Every lemma use in a derivation together with its substitution, left
-- to right. It does not look inside the lemmas' own proofs.
usedLemmasAndSubsts :: Derivation f -> [(Proof f, Subst f)]
usedLemmasAndSubsts d = collect d []
  where
    collect (UseLemma q sub) acc = (q, sub) : acc
    collect (Symm e) acc = collect e acc
    collect (Trans e1 e2) acc = collect e1 (collect e2 acc)
    collect (Cong _ es) acc = foldr collect acc es
    collect _ acc = acc
-- | The axioms used directly by a derivation (not via lemmas), in order
-- of occurrence, with repetitions.
usedAxioms :: Derivation f -> [Axiom f]
usedAxioms = map fst . usedAxiomsAndSubsts
-- | Every axiom use in a derivation together with its substitution, left
-- to right. Lemma uses are not entered.
usedAxiomsAndSubsts :: Derivation f -> [(Axiom f, Subst f)]
usedAxiomsAndSubsts d = collect d []
  where
    collect (UseAxiom a sub) acc = (a, sub) : acc
    collect (Symm e) acc = collect e acc
    collect (Trans e1 e2) acc = collect e1 (collect e2 acc)
    collect (Cong _ es) acc = foldr collect acc es
    collect _ acc = acc
-- | Lift a derivation to a position inside a term.
--
-- @congPath path t p@ applies @p@ at the subterm of @t@ reached by
-- following @path@. Each path element is a 0-based argument index, and
-- all other arguments along the way are kept fixed with 'Refl'.
--
-- Calls @error "bad path"@ in these cases:
--
-- * An index is negative or not smaller than the number of arguments.
-- * The path continues at a term that is not a function application.
congPath :: [Int] -> Term f -> Derivation f -> Derivation f
congPath [] _ p = p
congPath (n:ns) (App f t) p
  -- splitAt accepts negative indices, hence the explicit n >= 0. A
  -- non-empty suffix guarantees n < number of arguments.
  | n >= 0, (before, arg:after) <- splitAt n (unpack t) =
    cong f (map Refl before ++ [congPath ns arg p] ++ map Refl after)
congPath _ _ _ = error "bad path"
-- | Options controlling how proofs are presented.
data Config = Config
  { cfg_all_lemmas :: !Bool
    -- ^ When set, lemmas are no longer inlined just for mentioning an
    -- equality predicate or for being used exactly once.
  , cfg_no_lemmas :: !Bool
    -- ^ When set, every lemma is inlined.
  , cfg_show_instances :: !Bool
    -- ^ When set, each proof step shows the substitution used to
    -- instantiate its lemmas and axioms.
  }
-- | The default presentation options: every flag is off.
defaultConfig :: Config
defaultConfig =
  Config
    { cfg_all_lemmas = False
    , cfg_no_lemmas = False
    , cfg_show_instances = False
    }
-- | A set of proved goals prepared for printing.
data Presentation f = Presentation
  { pres_axioms :: [Axiom f]
    -- ^ The axioms used.
  , pres_lemmas :: [Proof f]
    -- ^ The lemmas used, in presentation order.
  , pres_goals :: [ProvedGoal f]
    -- ^ The goals and their proofs.
  }
  deriving Show
-- | A goal together with its proof.
data ProvedGoal f = ProvedGoal
  { pg_number :: Int
    -- ^ Goal number (pPrintPresentation numbers goals by position instead).
  , pg_name :: String
    -- ^ Name shown when the goal is printed.
  , pg_proof :: Proof f
    -- ^ The proof.
  , pg_goal_hint :: Equation f
    -- ^ The equation to present as the goal.
  , pg_witness_hint :: Subst f
    -- ^ Instantiation of the goal; 'checkProvedGoal' requires that it
    -- maps 'pg_goal_hint' to the proof's equation.
  }
  deriving Show
-- | Wrap a proof as a goal. The goal is exactly the proof's equation and
-- the witness is empty.
provedGoal :: Int -> String -> Proof f -> ProvedGoal f
provedGoal number name proof =
  ProvedGoal number name proof (equation proof) emptySubst
-- | Return the goal unchanged if its hint, instantiated by its witness,
-- is exactly the equation its proof establishes. Otherwise call 'error'
-- with a description of the mismatch.
checkProvedGoal :: Function f => ProvedGoal f -> ProvedGoal f
checkProvedGoal pg
  | claimed == actual = pg
  | otherwise =
      error $ show $
        text "Invalid ProvedGoal!" $$
        text "Claims to prove" <+> pPrint (pg_goal_hint pg) $$
        text "with witness" <+> pPrint (pg_witness_hint pg) <#> text "," $$
        text "but actually proves" <+> pPrint actual
  where
    claimed = subst (pg_witness_hint pg) (pg_goal_hint pg)
    actual = equation (pg_proof pg)
-- | Prints a presentation with 'defaultConfig'.
instance Function f => Pretty (Presentation f) where
  pPrint pres = pPrintPresentation defaultConfig pres
-- | Build a presentation of the given goals.
--
-- Lemmas are collected transitively from the goals' proofs by a
-- depth-first traversal. Each lemma is listed after the lemmas its own
-- proof uses. Lemmas proving the same equation (by 'Ord' on 'Proof') are
-- listed once.
present :: Function f => Config -> [ProvedGoal f] -> Presentation f
present config goals =
  presentWithGoals config goals ordered
  where
    roots = concatMap (usedLemmas . derivation . pg_proof) goals
    (_, ordered) = visit Set.empty roots

    -- 'seen' holds the lemmas already placed in the output.
    visit seen [] = (seen, [])
    visit seen (l:ls)
      | l `Set.member` seen = visit seen ls
      | otherwise =
          let (seen1, deps) = visit (Set.insert l seen) (usedLemmas (derivation l))
              (seen2, rest) = visit seen1 ls
          in (seen2, deps ++ l : rest)
-- | Prepare goals and an ordered list of lemmas for presentation.
--
-- Lemmas are inlined repeatedly until none qualify. Then all proofs are
-- flattened and the goals are passed through 'decodeGoal'.
--
-- A lemma is inlined in either of these cases:
--
-- * 'shouldInline' holds for it. That is, cfg_no_lemmas is set, or its
--   derivation is a single step, or (unless cfg_all_lemmas is set)
--   either side of its equation is an equality atom per
--   'decodeEquality', or it is used exactly once.
-- * An earlier lemma in the list proves the same equation up to
--   variable renaming, possibly with the sides swapped.
presentWithGoals ::
  Function f =>
  Config -> [ProvedGoal f] -> [Proof f] -> Presentation f
presentWithGoals config@Config{..} goals lemmas
  | Map.null inlinings =
    let
      axioms = usort $
        concatMap (usedAxioms . derivation . pg_proof) goals ++
        concatMap (usedAxioms . derivation) lemmas
    in
      Presentation axioms
        (map flattenProof lemmas)
        [ decodeGoal (goal { pg_proof = flattenProof pg_proof })
        | goal@ProvedGoal{..} <- goals ]
  | otherwise =
    let
      inline lemma = Map.lookup lemma inlinings
      goals' =
        [ decodeGoal (goal { pg_proof = certify $ simplify inline (derivation pg_proof) })
        | goal@ProvedGoal{..} <- goals ]
      lemmas' =
        [ certify $ simplify inline (derivation lemma)
        | lemma <- lemmas, not (lemma `Map.member` inlinings) ]
    in
      -- Every round drops at least one lemma, so the recursion terminates.
      presentWithGoals config goals' lemmas'
  where
    -- Lemmas to inline, each mapped to the derivation that replaces it.
    inlinings =
      Map.fromList
        [ (lemma, p)
        | (i, lemma) <- zip [0..] lemmas, Just p <- [tryInline i lemma] ]

    -- tryInline i p decides whether lemma p, at position i in 'lemmas',
    -- is inlined. The position is passed in explicitly. Looking it up
    -- via 'equations' would return the same entry as the subsumption
    -- lookup below, so "m < i" could never hold.
    tryInline _ p
      | shouldInline p = Just (derivation p)
    tryInline i p
      -- An earlier lemma proves the same equation...
      | Just (m, q) <- Map.lookup (canonicalise (t :=: u)) equations, m < i =
        Just (subsume p (derivation q))
      -- ...or the same equation with its sides swapped.
      | Just (m, q) <- Map.lookup (canonicalise (u :=: t)) equations, m < i =
        Just (subsume p (Symm (derivation q)))
      where
        t :=: u = equation p
    tryInline _ _ = Nothing

    shouldInline p =
      cfg_no_lemmas ||
      oneStep (derivation p) ||
      (not cfg_all_lemmas &&
       (isJust (decodeEquality (eqn_lhs (equation p))) ||
        isJust (decodeEquality (eqn_rhs (equation p))) ||
        Map.lookup p uses == Just 1))

    -- Instantiate derivation q so that it proves exactly p's equation,
    -- by matching q's equation onto p's.
    subsume p q =
      subst sub q
      where
        t :=: u = equation p
        t' :=: u' = equation (certify q)
        Just sub = matchList (buildList [t', u']) (buildList [t, u])

    -- Each canonicalised lemma equation, mapped to the earliest lemma
    -- proving it together with that lemma's position.
    equations =
      Map.fromListWith (\_later earlier -> earlier)
        [ (canonicalise (equation p), (i, p))
        | (i, p) <- zip [(0 :: Int) ..] lemmas ]

    -- How many times each lemma is used across all goals and lemmas.
    uses =
      Map.fromListWith (+)
        [ (p, 1)
        | p <-
            concatMap usedLemmas
              (map (derivation . pg_proof) goals ++
               map derivation lemmas) ]

    oneStep Trans{} = False
    oneStep _ = True
-- | True when both sides of an equation print identically. Such
-- equations and proof steps are omitted from printed output.
invisible :: Function f => Equation f -> Bool
invisible eqn = rendered (eqn_lhs eqn) == rendered (eqn_rhs eqn)
  where
    rendered x = show (pPrint x)
-- | Print a proof as a chain of equalities.
--
-- The output starts with the left-hand side. Each step then prints a
-- @= { by ... }@ justification followed by the term it reaches.
--
-- @axiomNum@ and @lemmaNum@ supply the labels used to cite axioms and
-- lemmas. With cfg_show_instances set, each citation also lists its
-- substitution.
pPrintLemma :: Function f => Config -> (Axiom f -> String) -> (Proof f -> String) -> Proof f -> Doc
pPrintLemma Config{..} axiomNum lemmaNum p =
  ppTerm (eqn_lhs (equation q)) $$ pp (derivation q)
  where
    -- Flatten first so that the derivation is a chain of Trans steps.
    q = flattenProof p

    pp (Trans p q) = pp p $$ pp q
    -- Steps whose two sides print identically are left out.
    pp p | invisible (equation (certify p)) = pPrintEmpty
    pp p =
      (text "= { by" <+>
       ppStep
         -- Distinct citations for the lemmas, then the axioms, of this step.
         (nub (map (show . ppLemma) (usedLemmasAndSubsts p)) ++
          nub (map (show . ppAxiom) (usedAxiomsAndSubsts p))) <+>
       text "}" $$
       ppTerm (eqn_rhs (equation (certify p))))

    ppTerm t = text " " <#> pPrint t

    -- Join citations as "a, b and c"; a step citing nothing is reflexivity.
    ppStep [] = text "reflexivity"
    ppStep [x] = text x
    ppStep xs =
      hcat (punctuate (text ", ") (map text (init xs))) <+>
      text "and" <+>
      text (last xs)

    ppLemma (p, sub) =
      text "lemma" <+> text (lemmaNum p) <#> showSubst sub
    ppAxiom (axiom@Axiom{..}, sub) =
      text "axiom" <+> text (axiomNum axiom) <+> parens (text axiom_name) <#> showSubst sub

    -- Rendered only when cfg_show_instances is set and the substitution is non-empty.
    showSubst sub
      | cfg_show_instances && not (null (substToList sub)) =
        text " with " <#>
        fsep (punctuate comma
          [ pPrint x <+> text "->" <+> pPrint t
          | (x, t) <- substToList sub ])
      | otherwise = pPrintEmpty
-- | Rewrite a proof into a linear chain of single rewrite steps joined by
-- 'Trans'. The equation it proves stays the same.
--
-- The function first strips all lemma uses of their 'simplify'
-- (callback @const Nothing@ keeps every lemma) and then flattens:
--
-- * A 'Cong' over several rewritten arguments becomes a sequence of
--   'Cong' steps, each rewriting a single argument, left to right.
-- * Chains are right-associated and 'Refl' steps are removed.
-- * Segments of a chain that return to the term they started from are
--   cut out.
flattenProof :: Function f => Proof f -> Proof f
flattenProof =
  certify . flat . simplify (const Nothing) . derivation
  where
    flat (Trans p q) = trans (flat p) (flat q)
    -- One Cong step per inner step. Arguments before position i are
    -- already at their final (right-hand) term, those after i are still
    -- at their initial (left-hand) term.
    flat p@(Cong f ps) =
      foldr trans (reflAfter p)
        [ Cong f $
            map reflAfter (take i ps) ++
            [p] ++
            map reflBefore (drop (i+1) ps)
        | (i, q) <- zip [0..] qs,
          p <- steps q ]
      where
        qs = map flat ps
    flat p = p

    reflBefore p = Refl (eqn_lhs (equation (certify p)))
    reflAfter p = Refl (eqn_rhs (equation (certify p)))

    -- The individual steps of a Trans chain, without Refl steps.
    steps Refl{} = []
    steps (Trans p q) = steps p ++ steps q
    steps p = [p]

    -- Local variant of 'trans' (shadows the top-level one inside this
    -- where-block). It right-associates, drops Refl, and removes loops
    -- from its second argument via 'strip'.
    trans (Trans p q) r = trans p (trans q r)
    trans Refl{} p = p
    trans p Refl{} = p
    trans p q =
      case strip q of
        Nothing -> Trans p q
        Just q' -> trans p q'

    -- If p proves t = t, replace it by Refl t. Otherwise look for a
    -- suffix of p's Trans chain that starts again at p's left-hand side
    -- t, and return that suffix (dropping the loop before it).
    strip p
      | t == u = Just (Refl t)
      | otherwise = strip' t p
      where
        t :=: u = equation (certify p)
    strip' t (Trans _ q)
      | eqn_lhs (equation (certify q)) == t = Just q
      | otherwise = strip' t q
    strip' _ _ = Nothing
-- | The individual steps of a derivation after certifying and flattening
-- it, in order. Reflexivity steps are dropped.
derivSteps :: Function f => Derivation f -> [Derivation f]
derivSteps d = collect (derivation (flattenProof (certify d))) []
  where
    collect Refl{} acc = acc
    collect (Trans p q) acc = collect p (collect q acc)
    collect p acc = p : acc
-- | Print a presentation. Output sections are separated by blank lines,
-- in this order:
--
-- 1. The visible axioms.
-- 2. Each visible lemma with its proof.
-- 3. Each goal with its witness (if any) and proof.
--
-- Axioms are numbered from 1. Lemmas are numbered after all axioms,
-- including axioms that are not printed. Goals are numbered by position.
pPrintPresentation :: forall f. Function f => Config -> Presentation f -> Doc
pPrintPresentation config (Presentation axioms lemmas goals) =
  vcat $ intersperse (text "") $
    vcat [ describeEquation "Axiom" (axiomNum axiom) (Just name) eqn
         | axiom@(Axiom _ name eqn) <- axioms,
           not (invisible eqn) ]:
    [ pp "Lemma" (lemmaNum p) Nothing (equation p) emptySubst p
    | p <- lemmas,
      not (invisible (equation p)) ] ++
    [ pp "Goal" (show num) (Just pg_name) pg_goal_hint pg_witness_hint pg_proof
    | (num, ProvedGoal{..}) <- zip [1..] goals ]
  where
    -- Header line, optional witness, then the step-by-step proof.
    pp kind n mname eqn witness p =
      describeEquation kind n mname eqn $$
      ppWitness witness $$
      text "Proof:" $$
      pPrintLemma config axiomNum lemmaNum p

    axiomNums = Map.fromList (zip axioms [1..])
    lemmaNums = Map.fromList (zip lemmas [length axioms+1..])
    -- fromJust assumes every cited axiom/lemma is in the presentation.
    axiomNum x = show (fromJust (Map.lookup x axiomNums))
    lemmaNum x = show (fromJust (Map.lookup x lemmaNums))

    -- Lists the witness bindings. If the minimal constant occurs in
    -- them, a note explains that it stands for an arbitrary term.
    ppWitness sub
      | sub == emptySubst = pPrintEmpty
      | otherwise =
        vcat [
          text "The goal is true when:",
          nest 2 $ vcat
            [ pPrint x <+> text "=" <+> pPrint t
            | (x, t) <- substToList sub ],
          if minimal `elem` funs sub then
            text "where" <+> doubleQuotes (pPrint (minimal :: Fun f)) <+>
            text "stands for an arbitrary term of your choice."
          else pPrintEmpty,
          text ""]
-- | A header line such as @Axiom 3 (name): lhs = rhs.@ The parenthesised
-- name is included only when one is given.
describeEquation ::
  PrettyTerm f =>
  String -> String -> Maybe String -> Equation f -> Doc
describeEquation kind num mname eqn =
  text kind <+> text num <#> label <#> text ":" <+> pPrint eqn <#> text "."
  where
    label = maybe (text "") (\name -> text (" (" ++ name ++ ")")) mname
-- | View a term as an equation if it applies the equality symbol (per
-- 'isEquals') to exactly two arguments.
decodeEquality :: Function f => Term f -> Maybe (Equation f)
decodeEquality term =
  case term of
    App equals (Cons lhs (Cons rhs Empty))
      | isEquals equals -> Just (lhs :=: rhs)
    _ -> Nothing
-- | If the goal's proof is an encoded refutation (see 'maybeDecodeGoal'),
-- rewrite the goal to state the original conjecture directly, with its
-- name, witness and direct proof. The result is checked with
-- 'checkProvedGoal'. Otherwise return the goal unchanged.
decodeGoal :: Function f => ProvedGoal f -> ProvedGoal f
decodeGoal pg = maybe pg rebuild (maybeDecodeGoal pg)
  where
    rebuild (name, witness, goal, deriv) =
      checkProvedGoal
        (pg { pg_name = name
            , pg_proof = certify deriv
            , pg_goal_hint = goal
            , pg_witness_hint = witness
            })
-- | Recover a direct proof of a conjecture from a goal whose proof ends
-- in a "false" term. Returns (conjecture name, witness, conjecture
-- equation, derivation of the instantiated conjecture).
--
-- The flattened steps must have a fixed shape:
--
-- 1. A reflexivity step (see @decodeReflexivity@) that introduces an
--    equality atom with two identical sides.
-- 2. Any number of congruence steps rewriting inside that equality atom.
--    The rewrites of the two arguments are accumulated separately.
-- 3. A conjecture axiom (see @decodeConjecture@) that turns the atom
--    into false.
--
-- Any other shape yields Nothing. NOTE(review): the exact encoding
-- (e.g. which axioms introduce the true/false terms) is not visible
-- here; this is inferred from the patterns matched below.
maybeDecodeGoal :: forall f. Function f =>
  ProvedGoal f -> Maybe (String, Subst f, Equation f, Derivation f)
maybeDecodeGoal ProvedGoal{..}
  -- Orient the proof so that it ends at the false term.
  | isFalseTerm u = extract (derivSteps deriv)
  | isFalseTerm t = extract (derivSteps (symm deriv))
  | otherwise = Nothing
  where
    isFalseTerm, isTrueTerm :: Term f -> Bool
    isFalseTerm (App false _) = isFalse false
    isFalseTerm _ = False
    isTrueTerm (App true _) = isTrue true
    isTrueTerm _ = False

    t :=: u = equation pg_proof
    deriv = derivation pg_proof

    -- A reversed axiom use whose equation is "x = x" (as an equality
    -- atom) = true. Returns the instantiated x.
    decodeReflexivity :: Derivation f -> Maybe (Term f)
    decodeReflexivity (Symm (UseAxiom Axiom{..} sub)) = do
      guard (isTrueTerm (eqn_rhs axiom_eqn))
      (t :=: u) <- decodeEquality (eqn_lhs axiom_eqn)
      guard (t == u)
      return (subst sub t)
    decodeReflexivity _ = Nothing

    -- An axiom use whose equation is "a = b" (as an equality atom) =
    -- false. Returns the axiom's name, a = b, and the substitution.
    decodeConjecture :: Derivation f -> Maybe (String, Equation f, Subst f)
    decodeConjecture (UseAxiom Axiom{..} sub) = do
      guard (isFalseTerm (eqn_rhs axiom_eqn))
      eqn <- decodeEquality (eqn_lhs axiom_eqn)
      return (axiom_name, eqn, sub)
    decodeConjecture _ = Nothing

    extract (p:ps) = do
      t <- decodeReflexivity p
      cont (Refl t) (Refl t) ps
    extract [] = Nothing

    -- p1 and p2 derive the current left and right arguments of the
    -- equality atom from the common starting term. When the conjecture
    -- step is reached, symm p1 `trans` p2 links the two arguments.
    cont p1 p2 (p:ps)
      -- A new reflexivity step restarts the tracking.
      | Just t <- decodeReflexivity p =
        cont (Refl t) (Refl t) ps
      | Just (name, eqn, sub) <- decodeConjecture p =
        return (name, sub, eqn, symm p1 `trans` p2)
      | Cong eq [p1', p2'] <- p, isEquals eq =
        cont (p1 `trans` p1') (p2 `trans` p2') ps
    cont _ _ _ = Nothing