--------------------------------------------------------------------------------

{-# LANGUAGE NamedFieldPuns, ViewPatterns, ExistentialQuantification, GADTs #-}
{-# LANGUAGE Safe #-}

module Copilot.Theorem.Prove
  ( Output  (..)
  , Status  (..)
  , Prover  (..)
  , PropId, PropRef (..)
  , Proof, UProof, ProofScheme (..)
  , Action (..)
  , Universal, Existential
  , check
  , prove
  , combine
  ) where

import Control.Applicative (liftA2)
import Control.Exception (bracket, finally, onException)
import Data.List (intercalate)

import Control.Monad.Writer

import qualified Copilot.Core as Core

--------------------------------------------------------------------------------

-- | What a prover reports for a single query: a verdict together with the
-- messages it attached to that verdict.
data Output = Output
  Status    -- ^ The verdict.
  [String]  -- ^ Messages reported alongside the verdict.

-- | Verdict returned by a prover for a query.
--
-- NOTE(review): the precise meaning of each verdict is defined by the
-- individual prover back-ends; confirm against them before relying on it.
data Status
  = Sat
  | Valid
  | Invalid
  | Unknown
  | Error

{- Each prover has to provide the following four fields: a name and three
   functions. The most important is `askProver`, which takes 3 arguments:
   *  The prover state returned by `startProver`
   *  A list of property names which are assumptions
   *  A list of property names which have to be deduced from these assumptions
-}

-- | A proof back-end. Its internal state type @r@ is hidden: a value of type
-- @r@ is produced by 'startProver', passed to 'askProver', and finally given
-- to 'closeProver'.
data Prover = forall r. Prover
  { proverName  :: String
    -- ^ Name identifying the prover.
  , startProver :: Core.Spec -> IO r
    -- ^ Set the prover up for the given specification.
  , askProver   :: r -> [PropId] -> [PropId] -> IO Output
    -- ^ Given the prover state, a list of assumed property names and a list
    -- of property names to deduce from them, report the prover's answer.
  , closeProver :: r -> IO ()
    -- ^ Shut the prover down once it is no longer needed.
  }

-- | Name identifying a property.
type PropId = String

-- | A reference to a property by its name. The type parameter @a@ is a
-- phantom: it carries no runtime data.
data PropRef a = PropRef PropId

-- | Uninhabited tag types, usable as phantom type arguments.
data Universal

data Existential

-- | A proof script producing no result value.
type Proof a = ProofScheme a ()

-- | The raw log of proof actions, without the 'ProofScheme' wrapper.
type UProof = Writer [Action] ()

-- | A proof script: a 'Writer' that accumulates 'Action's and yields a @b@.
-- The parameter @a@ is a phantom and carries no runtime data.
-- TODO(review): presumably instantiated with 'Universal' or 'Existential'.
data ProofScheme a b = Proof (Writer [Action] b)

-- 'ProofScheme' is a thin wrapper around @Writer [Action]@: every instance
-- unwraps, delegates to the corresponding 'Writer' operation and re-wraps,
-- so sequencing proof steps concatenates their action logs.
--
-- 'pure' is the primary definition and 'return' is left to its default
-- (@return = pure@), the canonical form requested by GHC's
-- @-Wnoncanonical-monad-instances@. The definitions deliberately avoid
-- 'liftM' / 'ap', which mtl >= 2.3 no longer re-exports from
-- "Control.Monad.Writer".
instance Functor (ProofScheme a) where
  fmap f (Proof p) = Proof (fmap f p)

instance Applicative (ProofScheme a) where
  pure x = Proof (pure x)
  -- Defined via '>>=' so the second argument is only inspected when the
  -- combined action is run, exactly like 'ap'.
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad (ProofScheme a) where
  Proof p >>= f = Proof (p >>= \x -> case f x of Proof p' -> p')

-- | One step of a proof script, as interpreted by 'prove'.
data Action
  = Check Prover   -- ^ Ask the prover to deduce the target property from the
                   --   properties established so far.
  | Assume PropId  -- ^ Add the named property to the established ones
                   --   without checking it.
  | Admit          -- ^ Accept the target property without checking it.

--------------------------------------------------------------------------------

-- | A proof step that records a 'Check' with the given prover. Nothing is
-- run here; the prover is only invoked when the script is executed by
-- 'prove'.
check :: Prover -> Proof a
check prover = Proof (tell [step])
  where
    step = Check prover

-- | Run a proof script against a specification and report whether the
-- target property @propId@ was established.
--
-- The recorded actions are executed left to right while threading a list
-- of established property names (the /context/), which is passed as the
-- assumption list to every later 'Check':
--
-- * @'Check' prover@ starts @prover@ on @spec@, asks it to deduce @propId@
--   from the context and closes it again. A 'Sat' or 'Valid' answer adds
--   @propId@ to the context; 'Invalid', 'Unknown' and 'Error' leave it
--   unchanged.
-- * @'Assume' p@ adds @p@ to the context without checking it.
-- * 'Admit' adds @propId@ to the context without checking it.
--
-- One line is printed per action plus a final summary line. The result is
-- 'True' iff @propId@ is in the context after the last action.
prove :: Core.Spec -> PropId -> UProof -> IO Bool
prove spec propId (execWriter -> actions) = do
    thms <- processActions [] actions
    let proved = propId `elem` thms
    putStr $ "Finished: " ++ propId ++ ": proof "
    putStrLn $ if proved then "checked successfully" else "failed"
    return proved
  where
    -- Execute the actions in order, threading the established properties.
    processActions :: [PropId] -> [Action] -> IO [PropId]
    processActions context [] = return context
    processActions context (action : nextActions) = case action of
      Check (Prover { startProver, askProver, closeProver }) -> do
        -- 'bracket' guarantees the prover is closed even when 'askProver'
        -- throws, so the prover's resources are not leaked.
        Output status infos <-
          bracket (startProver spec) closeProver $ \prover ->
            askProver prover context [propId]
        putStrLn $ propId ++ ": " ++ statusLabel status
          ++ " (" ++ intercalate ", " infos ++ ")"
        let context'
              | establishes status = propId : context
              | otherwise          = context
        processActions context' nextActions

      Assume propId' -> do
        putStrLn $ propId' ++ ": assumption"
        processActions (propId' : context) nextActions

      Admit -> do
        putStrLn $ propId ++ ": admitted"
        processActions (propId : context) nextActions

    -- Word printed after the property name for each prover verdict.
    statusLabel :: Status -> String
    statusLabel Sat     = "sat"
    statusLabel Valid   = "valid"
    statusLabel Invalid = "invalid"
    statusLabel Error   = "error"
    statusLabel Unknown = "unknown"

    -- Whether a verdict adds the checked property to the context.
    establishes :: Status -> Bool
    establishes Sat   = True
    establishes Valid = True
    establishes _     = False

-- | Combine two provers into one that consults both.
--
-- The combined prover is named @\<left\>_\<right\>@. Starting it starts the
-- left prover and then the right one. Asking it asks the left prover and
-- then the right one with the same arguments, and merges the two answers
-- with 'combineOutputs'. Closing it closes both provers.
combine :: Prover -> Prover -> Prover
combine
  (Prover { proverName  = proverNameL
          , startProver = startProverL
          , askProver   = askProverL
          , closeProver = closeProverL
          })
  (Prover { proverName  = proverNameR
          , startProver = startProverR
          , askProver   = askProverR
          , closeProver = closeProverR
          })
  = Prover
      { proverName  = proverNameL ++ "_" ++ proverNameR
      , startProver = \spec -> do
          proverL <- startProverL spec
          -- If the right prover fails to start, shut the already-started
          -- left one down before re-throwing so it is not leaked.
          proverR <- startProverR spec `onException` closeProverL proverL
          return (proverL, proverR)
      , askProver = \(stL, stR) assumptions toCheck ->
          liftA2 (combineOutputs proverNameL proverNameR)
            (askProverL stL assumptions toCheck)
            (askProverR stR assumptions toCheck)
      , closeProver = \(stL, stR) ->
          -- Close the right prover even if closing the left one throws.
          closeProverL stL `finally` closeProverR stR
      }

-- | Merge the answers of two provers, named @nameL@ and @nameR@.
--
-- The merged verdict is decided by the first applicable rule:
--
-- * 'Error' if either side reported 'Error', or if one side said 'Valid'
--   while the other said 'Invalid';
-- * otherwise 'Invalid' if either side said 'Invalid';
-- * otherwise 'Valid' if either side said 'Valid';
-- * otherwise 'Sat' if either side said 'Sat';
-- * otherwise 'Unknown'.
--
-- The merged messages are a disagreement notice (only when one side said
-- 'Valid' and the other 'Invalid', in either order), then @\<nameL\>@ and
-- the left messages, then @\<nameR\>@ and the right messages.
combineOutputs :: [Char] -> [Char] -> Output -> Output -> Output
combineOutputs nameL nameR (Output stL msgL) (Output stR msgR) =
  Output (combineSt stL stR) infos
  where
    -- Clauses are tried top to bottom, so earlier ones take priority.
    combineSt :: Status -> Status -> Status
    combineSt Error   _       = Error
    combineSt _       Error   = Error
    -- One prover proved the property while the other refuted it.
    combineSt Valid   Invalid = Error
    combineSt Invalid Valid   = Error
    combineSt Invalid _       = Invalid
    combineSt _       Invalid = Invalid
    combineSt Valid   _       = Valid
    combineSt _       Valid   = Valid
    combineSt Sat     _       = Sat
    combineSt _       Sat     = Sat
    combineSt Unknown Unknown = Unknown

    -- A Valid/Invalid split is a contradiction in either order, mirroring
    -- the two disagreement clauses of 'combineSt', so both get the notice.
    prefixMsg :: [String]
    prefixMsg = case (stL, stR) of
      (Valid, Invalid) -> ["The two provers don't agree"]
      (Invalid, Valid) -> ["The two provers don't agree"]
      _                -> []

    decoName :: String -> String
    decoName s = "<" ++ s ++ ">"

    infos :: [String]
    infos = prefixMsg ++ [decoName nameL] ++ msgL ++ [decoName nameR] ++ msgR

--------------------------------------------------------------------------------