-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.ProofTools.AddHorn
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Example of invariant generation for a simple addition algorithm:
--
-- @
--    z = x
--    i = 0
--    assume y > 0
--
--    while (i < y)
--       z = z + 1
--       i = i + 1
--
--    assert z == x + y
-- @
--
-- We use the Horn solver to calculate the invariant and then show that it
-- indeed is a sufficient invariant to establish correctness.
-----------------------------------------------------------------------------

{-# LANGUAGE DataKinds #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.ProofTools.AddHorn where

import Data.SBV

-- $setup
-- >>> -- For doctest purposes only:
-- >>> import Data.SBV

-- | Helper type synonym for the invariant: a predicate over the program state,
-- passed as the tuple @(x, y, z, i)@ of the four variables of the algorithm.
type Inv = (SInteger, SInteger, SInteger, SInteger) -> SBool

-- | Helper type synonym for verification conditions: a boolean that is
-- universally quantified over the four program variables, named @x@, @y@,
-- @z@ and @i@ (in that order).
type VC =  Forall "x" Integer
        -> Forall "y" Integer
        -> Forall "z" Integer
        -> Forall "i" Integer
        -> SBool

-- | Helper for turning an invariant predicate to a boolean: the four components
-- of the state tuple @(x, y, z, i)@ are each bound by a universal quantifier and
-- then handed to the predicate as a tuple.
--
-- The explicit lambda is intentional: it keeps the definition at arity one, in
-- line with the @Inv -> VC@ signature (HLint is told to ignore it at the end of
-- this file).
quantify :: Inv -> VC
quantify f = \(Forall x) (Forall y) (Forall z) (Forall i) -> f (x, y, z, i)

-- | First verification condition: Before the loop starts, invariant must hold:
--
-- \(z = x \land i = 0 \land y > 0 \Rightarrow inv (x, y, z, i)\)
--
-- The antecedent encodes the initialization (@z = x@, @i = 0@) together with the
-- assumption @y > 0@.
vc1 :: Inv -> VC
vc1 inv = quantify $ \(x, y, z, i) -> z .== x .&& i .== 0 .&& y .> 0 .=> inv (x, y, z, i)

-- | Second verification condition: If the loop body executes, invariant must still hold at the end:
--
-- \(inv (x, y, z, i) \land i < y \Rightarrow inv (x, y, z+1, i+1)\)
--
-- The consequent applies the invariant to the state after one iteration, i.e.,
-- with both @z@ and @i@ incremented by one.
vc2 :: Inv -> VC
vc2 inv = quantify $ \(x, y, z, i) -> inv (x, y, z, i) .&& i .< y .=> inv (x, y, z + 1, i + 1)

-- | Third verification condition: Once the loop exits, invariant and the negation of the loop condition
-- must establish the final assertion:
--
-- \(inv (x, y, z, i) \land i \geq y \Rightarrow z = x + y\)
vc3 :: Inv -> VC
vc3 inv = quantify $ \(x, y, z, i) -> inv (x, y, z, i) .&& i .>= y .=> z .== x + y
-- | Synthesize the invariant. We use an uninterpreted function for the SMT solver to synthesize. We get:
--
-- >>> synthesize
-- Satisfiable. Model:
--   invariant :: (Integer, Integer, Integer, Integer) -> Bool
--   invariant (x, y, z, i) = x + (-z) + i > (-1) && x + (-z) + i < 1 && x + y + (-z) > (-1)
--
-- This is a bit hard to read, but you can convince yourself it is equivalent to @x + i .== z .&& x + y .>= z@:
--
-- >>> let f (x, y, z, i) = x + (-z) + i .> (-1) .&& x + (-z) + i .< 1 .&& x + y + (-z) .> (-1)
-- >>> let g (x, y, z, i) = x + i .== z .&& x + y .>= z
-- >>> f === (g :: Inv)
-- Q.E.D.
synthesize :: IO SatResult
synthesize = sat vcs
  where
    -- The unknown invariant: an uninterpreted function called "invariant",
    -- declared with the argument names x, y, z and i.
    invariant :: Inv
    invariant = uninterpretWithArgs "invariant" ["x", "y", "z", "i"]

    -- Select the custom "HORN" logic, then assert all three verification
    -- conditions over the uninterpreted invariant.
    vcs :: ConstraintSet
    vcs = do
      setLogic $ CustomLogic "HORN"
      constrain $ vc1 invariant
      constrain $ vc2 invariant
      constrain $ vc3 invariant

-- | Verify that the synthesized function does indeed work. To do so, we simply prove that the invariant found satisfies all the vcs:
--
-- >>> verify
-- Q.E.D.
verify :: IO ThmResult
verify = prove vcs
  where
    -- The invariant exactly as reported in the model returned by 'synthesize'.
    invariant :: Inv
    invariant (x, y, z, i) = x + (-z) + i .> (-1) .&& x + (-z) + i .< 1 .&& x + y + (-z) .> (-1)

    -- Conjunction of all three verification conditions: initialization ('vc1'),
    -- preservation by the loop body ('vc2'), and the final assertion ('vc3').
    -- Each must appear exactly once; leaving out 'vc2' would skip the inductive step.
    vcs :: SBool
    vcs =   quantifiedBool (vc1 invariant)
        .&& quantifiedBool (vc2 invariant)
        .&& quantifiedBool (vc3 invariant)

{- HLint ignore quantify "Redundant lambda" -}