--------------------------------------------------------------------------------

{-# LANGUAGE GADTs, FlexibleInstances #-}
{-# LANGUAGE Safe #-}

module Copilot.Theorem.Prover.SMTLib (SmtLib, interpret) where

import Copilot.Theorem.Prover.Backend (SmtFormat (..), SatResult (..))

import Copilot.Theorem.IL
import Copilot.Theorem.Misc.SExpr

import Text.Printf

--------------------------------------------------------------------------------

-- | A single SMT-LIB 2 command, held as an s-expression over 'String' atoms.
-- The constructor is not exported, so values are only produced by this
-- module's 'SmtFormat' instance.
newtype SmtLib = SmtLib (SExpr String)

-- | Rendering defers entirely to the 'Show' instance of the wrapped
-- s-expression.
instance Show SmtLib where
  show wrapped = case wrapped of
    SmtLib sexpr -> show sexpr

-- | Name of the SMT-LIB sort used to encode an IL 'Type'.
--
-- 'Bool' and 'Real' map to the sorts of the same name. Every other
-- constructor falls back to the unbounded @Int@ sort.
-- NOTE(review): presumably those are the fixed-width integer types, whose
-- width is therefore not modelled here — confirm this is intended.
smtTy :: Type -> String
smtTy ty = case ty of
  Bool -> "Bool"
  Real -> "Real"
  _    -> "Int"

--------------------------------------------------------------------------------

-- | SMT-LIB 2 concrete syntax for the abstract solver commands.
instance SmtFormat SmtLib where
  -- Push / pop exactly one assertion level (the numeral argument 1).
  push = SmtLib (node "push" [atom "1"])
  pop  = SmtLib (node "pop" [atom "1"])

  checkSat = SmtLib (singleton "check-sat")

  -- An empty logic name yields a 'blank' command rather than a
  -- set-logic node with an empty argument.
  setLogic logicName
    | null logicName = SmtLib blank
    | otherwise      = SmtLib (node "set-logic" [atom logicName])

  -- Shape: (declare-fun <name> (<argument sorts>) <result sort>)
  declFun fname resultTy argTys =
    SmtLib (node "declare-fun" [atom fname, list argSorts, atom (smtTy resultTy)])
    where
      argSorts = [atom (smtTy argTy) | argTy <- argTys]

  assert formula = SmtLib (node "assert" [expr formula])

-- | Interpret one line of solver output as a satisfiability verdict.
--
-- Leading and trailing whitespace (spaces, tabs, and a stray @\\r@ left
-- over from CRLF line endings) is ignored before matching. @"sat"@ maps to
-- 'Sat' and @"unsat"@ to 'Unsat'. Any other line, including solver error
-- messages and @"unknown"@, is conservatively reported as 'Unknown'.
-- This function never returns 'Nothing'.
interpret :: String -> Maybe SatResult
interpret line =
  case stripBlanks line of
    "sat"   -> Just Sat
    "unsat" -> Just Unsat
    _       -> Just Unknown
  where
    stripBlanks = reverse . dropWhile isBlank . reverse . dropWhile isBlank
    isBlank c   = c `elem` " \t\r\n"

--------------------------------------------------------------------------------

-- | Translate an IL expression into an SMT-LIB 2 term.
--
-- * Boolean constants become @true@ / @false@.
-- * Integer and real constants become numerals / decimals. SMT-LIB
--   numerals and decimals carry no sign, so negative constants are
--   encoded as the application @(- n)@. A bare @-5@ would be parsed as
--   an (undeclared) symbol.
-- * Operators are mapped to their SMT-LIB function symbols.
-- * Stream samples are flattened into symbols: @f_<i>@ for a 'Fixed'
--   index and @f_n<off>@ for a 'Var' offset.
--
-- NOTE(review): non-finite 'ConstR' values (NaN, infinities) have no
-- SMT-LIB real literal and are not handled here.
expr :: Expr -> SExpr String
expr (ConstB v) = atom (if v then "true" else "false")

expr (ConstI _ v)
  -- Unsigned numeral wrapped in unary minus.
  | v < 0     = node "-" [atom (show (negate v))]
  | otherwise = atom (show v)

expr (ConstR v)
  -- Same encoding for decimals. 'abs' also turns -0.0 into "0.0"
  -- rather than the invalid "-0.0".
  | v < 0     = node "-" [atom (printf "%f" (negate v))]
  | otherwise = atom (printf "%f" (abs v))

expr (Ite _ cond e1 e2) = node "ite" [expr cond, expr e1, expr e2]

expr (FunApp _ funName args) = node funName (map expr args)

expr (Op1 _ op e) = node smtOp [expr e]
  where
    smtOp = case op of
      Not   -> "not"
      Neg   -> "-"
      Abs   -> "abs"
      Exp   -> "exp"
      Sqrt  -> "sqrt"
      Log   -> "log"
      Sin   -> "sin"
      Tan   -> "tan"
      Cos   -> "cos"
      Asin  -> "asin"
      Atan  -> "atan"
      Acos  -> "acos"
      Sinh  -> "sinh"
      Tanh  -> "tanh"
      Cosh  -> "cosh"
      Asinh -> "asinh"
      Atanh -> "atanh"
      Acosh -> "acosh"

expr (Op2 _ op e1 e2) = node smtOp [expr e1, expr e2]
  where
    smtOp = case op of
      Eq   -> "="
      Le   -> "<="
      Lt   -> "<"
      Ge   -> ">="
      Gt   -> ">"
      And  -> "and"
      Or   -> "or"
      Add  -> "+"
      Sub  -> "-"
      Mul  -> "*"
      Mod  -> "mod"
      Fdiv -> "/"
      Pow  -> "^"

expr (SVal _ f ix) = atom $ case ix of
  Fixed i -> f ++ "_" ++ show i
  Var off -> f ++ "_n" ++ show off

--------------------------------------------------------------------------------