{-# LANGUAGE Safe #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# Language FlexibleInstances #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Cryptol.TypeCheck.Solver.SMT
(
Solver
, withSolver
, isNumeric
, debugBlock
, debugLog
, proveImp
, checkUnsolvable
, tryGetModel
, shrinkModel
) where
import SimpleSMT (SExpr)
import qualified SimpleSMT as SMT
import Data.Map ( Map )
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Maybe(catMaybes)
import Data.List(partition)
import Control.Exception
import Control.Monad(msum,zipWithM,void)
import Data.Char(isSpace)
import Text.Read(readMaybe)
import qualified System.IO.Strict as StrictIO
import System.FilePath((</>))
import System.Directory(doesFileExist)
import Cryptol.Prelude(cryptolTcContents)
import Cryptol.TypeCheck.Type
import Cryptol.TypeCheck.InferTypes
import Cryptol.TypeCheck.Solver.InfNat(Nat'(..))
import Cryptol.TypeCheck.TypePat hiding ((~>),(~~>))
import Cryptol.TypeCheck.Subst(Subst)
import Cryptol.Utils.Panic
import Cryptol.Utils.PP
-- | A handle on a running SMT solver, bundled with the logger that this
-- module's debugging helpers write to.
data Solver = Solver
  { solver :: SMT.Solver
    -- ^ The underlying SimpleSMT solver connection.
  , logger :: SMT.Logger
    -- ^ Destination for messages produced by 'debugLog' and 'debugBlock'.
  }
-- | Start an SMT solver as described by the configuration, load the
-- type-checker prelude into it, run the given action, and stop the solver
-- once the action finishes (normally or with an exception).
--
-- Verbosity handling:
--
--   * @solverVerbose > 0@: a real logger is created, otherwise all log
--     operations are no-ops;
--   * @solverVerbose > 1@: the logger is also handed to the solver itself.
withSolver :: SolverConfig -> (Solver -> IO a) -> IO a
withSolver SolverConfig{ .. } = bracket start shutdown
  where
  start =
    do logger <- if solverVerbose > 0 then SMT.newLogger 0
                                      else return quietLogger
       let smtDbg = if solverVerbose > 1 then Just logger else Nothing
       solver <- SMT.newSolver solverPath solverArgs smtDbg
       let sol = Solver { .. }
       -- 'bracket' only runs its release action when acquisition succeeds,
       -- so if configuring the solver or loading the prelude fails we have
       -- to stop the freshly started solver ourselves.
       (do _ <- SMT.setOptionMaybe solver ":global-decls" "false"
           loadTcPrelude sol solverPreludePath)
         `onException` SMT.stop solver
       return sol

  shutdown s = void $ SMT.stop (solver s)

  -- A logger that silently discards everything.
  quietLogger = SMT.Logger { SMT.logMessage = \_ -> return ()
                           , SMT.logLevel   = return 0
                           , SMT.logSetLevel= \_ -> return ()
                           , SMT.logTab     = return ()
                           , SMT.logUntab   = return ()
                           }
-- | Load the type-checker prelude into the solver.
--
-- The directories are tried in order; the first one that contains a file
-- named @CryptolTC.z3@ supplies the prelude.  When no directory has such a
-- file (or the list is empty), the built-in 'cryptolTcContents' is used.
loadTcPrelude :: Solver -> [FilePath] -> IO ()
loadTcPrelude s dirs =
  case dirs of
    [] -> loadString s cryptolTcContents
    dir : others ->
      do let candidate = dir </> "CryptolTC.z3"
         found <- doesFileExist candidate
         if found
           then loadFile s candidate
           else loadTcPrelude s others
-- | Read the whole file strictly and send its contents to the solver
-- via 'loadString'.
loadFile :: Solver -> FilePath -> IO ()
loadFile s file = StrictIO.readFile file >>= loadString s
-- | Send a sequence of SMT commands, given as text, to the solver.
--
-- On every line, everything from the first @;@ onwards is discarded
-- (regardless of context).  The remaining text is then read one
-- S-expression at a time, and each expression is issued as a command.
-- Panics if the text that is left cannot be parsed as an S-expression.
loadString :: Solver -> String -> IO ()
loadString s str = feed (stripComments str)
  where
  -- Consume the input expression by expression until only
  -- white space is left.
  feed remaining
    | all isSpace remaining = return ()
    | otherwise =
        case SMT.readSExpr remaining of
          Nothing ->
            panic "loadFile" [ "Failed to parse SMT file."
                             , remaining
                             ]
          Just (cmd, leftover) ->
            do _ <- SMT.command (solver s) cmd
               feed leftover

  -- Keep only the part of each line that precedes the first ';'.
  stripComments = unlines . map (takeWhile (/= ';')) . lines
-- | Run an action as a named, indented block in the debug log.
--
-- The name is logged first, then the logger's indentation is increased
-- for the duration of the action.  The indentation is restored even when
-- the action throws, so one failing block does not shift all later output.
debugBlock :: Solver -> String -> IO a -> IO a
debugBlock s@Solver { .. } name m =
  do debugLog s name
     bracket_ (SMT.logTab logger) (SMT.logUntab logger) m
-- | Things that can be written to the solver's debug log.
class DebugLog t where
  -- | Log a single value.
  debugLog :: Solver -> t -> IO ()

  -- | Log a list of values.  By default the elements are logged one after
  -- another, and an empty list is logged as @(none)@.
  debugLogList :: Solver -> [t] -> IO ()
  debugLogList s [] = debugLog s "(none)"
  debugLogList s ts = mapM_ (debugLog s) ts
-- | A lone character is logged in its 'show'n form; a list of characters
-- (i.e. a 'String') is logged verbatim as one message.
instance DebugLog Char where
  debugLog s     = SMT.logMessage (logger s) . show
  debugLogList s = SMT.logMessage (logger s)
-- | Lists defer to the element type's 'debugLogList'.
instance DebugLog a => DebugLog [a] where
  debugLog s xs = debugLogList s xs
-- | 'Nothing' is logged as @(nothing)@; 'Just' logs the wrapped value.
instance DebugLog a => DebugLog (Maybe a) where
  debugLog s = maybe (debugLog s "(nothing)") (debugLog s)
-- | Documents are rendered with 'show' and logged as strings.
instance DebugLog Doc where
  debugLog s = debugLog s . show
-- | Types are pretty-printed with 'pp' before being logged.
instance DebugLog Type where
  debugLog s = debugLog s . pp
-- | A goal is logged by logging its proposition (the 'goal' field).
instance DebugLog Goal where
  debugLog s = debugLog s . goal
-- | Substitutions are pretty-printed with 'pp' before being logged.
instance DebugLog Subst where
  debugLog s = debugLog s . pp
-- | Try to discharge goals using the given assumptions.
--
-- The goals are first flattened with 'flatGoal'.  Only the numeric ones
-- (see 'isNumeric') are sent to the solver, together with the numeric
-- conjuncts of the assumptions.  The result contains the numeric goals
-- that could not be proved, followed by all non-numeric goals, which are
-- returned untouched.
--
-- All solver work happens in a fresh assertion frame, which is popped
-- again even if an exception is raised part-way through, so that the
-- solver is not left with stale declarations and assumptions.
proveImp :: Solver -> [Prop] -> [Goal] -> IO [Goal]
proveImp sol ps gs0 =
  debugBlock sol "PROVE IMP" $
  do let gs1       = concatMap flatGoal gs0
         (gs,rest) = partition (isNumeric . goal) gs1
         numAsmp   = filter isNumeric (concatMap pSplitAnd ps)
         vs        = Set.toList (fvs (numAsmp, map goal gs))
     tvs <- debugBlock sol "VARIABLES" $
              do push sol
                 -- Undo the push if declaring the variables fails.
                 (Map.fromList <$> zipWithM (declareVar sol) [ 0 .. ] vs)
                   `onException` pop sol
     gs' <- (do debugBlock sol "ASSUMPTIONS" $
                  mapM_ (assume sol tvs) numAsmp
                mapM (prove sol tvs) gs)
            `finally` pop sol
     return (catMaybes gs' ++ rest)
-- | Check whether the numeric parts of the given goals are contradictory.
--
-- The goals are flattened with 'flatGoal' and only the numeric
-- propositions (see 'isNumeric') are considered.  Returns 'True' exactly
-- when the solver reports them unsatisfiable (see 'unsolvable').
--
-- The assertion frame opened for the variable declarations is popped even
-- if an exception is raised, so the solver state is always restored.
checkUnsolvable :: Solver -> [Goal] -> IO Bool
checkUnsolvable sol gs0 =
  debugBlock sol "CHECK UNSOLVABLE" $
  do let ps = filter isNumeric
            $ map goal
            $ concatMap flatGoal gs0
         vs = Set.toList (fvs ps)
     tvs <- debugBlock sol "VARIABLES" $
              do push sol
                 -- Undo the push if declaring the variables fails.
                 (Map.fromList <$> zipWithM (declareVar sol) [ 0 .. ] vs)
                   `onException` pop sol
     unsolvable sol tvs ps `finally` pop sol
-- | Ask the solver for values of the given variables that satisfy the
-- given propositions.
--
-- Returns 'Nothing' when the solver does not answer @sat@, or when the
-- value of some variable cannot be read back from the solver's answer.
-- With an empty variable list and a @sat@ answer the result is @Just []@.
--
-- Every variable mentioned in the propositions has to be listed among the
-- variables, because only those get declared in the solver.
--
-- The work is done in a fresh assertion frame, which is popped again even
-- if an exception is raised.
tryGetModel :: Solver -> [TVar] -> [Prop] -> IO (Maybe [(TVar,Nat')])
tryGetModel sol as ps =
  debugBlock sol "TRY GET MODEL" $
  bracket_ (push sol) (pop sol) $
  do tvs <- Map.fromList <$> zipWithM (declareVar sol) [ 0 .. ] as
     mapM_ (assume sol tvs) ps
     sat <- SMT.check (solver sol)
     case sat of
       SMT.Sat ->
         case as of
           -- Nothing to read back from the solver.
           [] -> return (Just [])
           _  -> do res <- SMT.getExprs (solver sol) (Map.elems tvs)
                    let parse x = do e <- Map.lookup x tvs
                                     t <- parseNum =<< lookup e res
                                     return (x, t)
                    return (mapM parse as)
       _ -> return Nothing
  where
  -- Recognizes values of the shape @(mk-infnat val isFin isErr)@ in which
  -- the error flag is @false@.  A finiteness flag of @false@ yields 'Inf',
  -- anything else yields the finite number @val@.
  parseNum a
    | SMT.Other s <- a
    , SMT.List [con,val,isFin,isErr] <- s
    , SMT.Atom "mk-infnat" <- con
    , SMT.Atom "false"     <- isErr
    , SMT.Atom fin         <- isFin
    , SMT.Atom v           <- val
    , Just n               <- readMaybe v
    = Just (if fin == "false" then Inf else Nat n)
  parseNum _ = Nothing
-- | Try to make the finite values of a model smaller.
--
-- The model is processed entry by entry.  For a finite value @k@ we
-- repeatedly ask the solver (via 'tryGetModel') for a model in which the
-- variable is at most half of the current value, continuing from whatever
-- value the solver picks, until no such model exists or the value is 0.
-- Once a variable has been settled, the bound @value >= variable@ is added
-- to the propositions used for the remaining variables.  Entries that are
-- not finite numbers are kept as they are.
--
-- Note that entries are accumulated by consing, so the result lists the
-- variables in the reverse of their order in the input model.
shrinkModel :: Solver -> [TVar] -> [Prop] -> [(TVar,Nat')] -> IO [(TVar,Nat')]
shrinkModel sol as ps0 mdl = go [] ps0 mdl
  where
  go acc _ [] = return acc
  go acc props ((x, val) : rest) =
    case val of
      Nat k ->
        do smallest <- shrink1 props x k
           go ((x, Nat smallest) : acc)
              ((tNum smallest >== TVar x) : props)
              rest
      _ -> go ((x, val) : acc) props rest

  -- Halve the candidate for as long as the solver still finds a model.
  shrink1 props x k
    | k == 0    = return 0
    | otherwise =
        tryGetModel sol as (halfBound : props) >>= \found ->
          case found of
            Nothing -> return k
            Just m
              | Just (Nat smaller) <- lookup x m -> shrink1 props x smaller
              | otherwise ->
                  panic "shrink" ["model is missing variable", show x]
    where
    halfBound = tNum (div k 2) >== TVar x
-- | Open a new assertion frame in the underlying solver.
push :: Solver -> IO ()
push = SMT.push . solver
-- | Discard the most recent assertion frame of the underlying solver.
pop :: Solver -> IO ()
pop = SMT.pop . solver
-- | Declare a solver constant of sort 'cryInfNat' for a type variable and
-- assert @cryVar@ of it.
--
-- The constant is named @fv<n>@ for free type variables and @kv<n>@ for
-- all others, where @<n>@ is the given number.  The result pairs the type
-- variable with the expression that stands for it in the solver.
declareVar :: Solver -> Int -> TVar -> IO (TVar, SExpr)
declareVar s x v =
  do e <- SMT.declare (solver s) (prefix ++ show x) cryInfNat
     SMT.assert (solver s) (SMT.fun "cryVar" [ e ])
     return (v,e)
  where
  prefix
    | isFreeTV v = "fv"
    | otherwise  = "kv"
-- | Assert a proposition in the solver, wrapped in @cryAssume@.
-- Type variables are translated using the given map (see 'toSMT').
assume :: Solver -> TVars -> Prop -> IO ()
assume s tvs p = SMT.assert (solver s) assumption
  where
  assumption = SMT.fun "cryAssume" [ toSMT tvs p ]
-- | Try to prove a single goal in the current solver context.
--
-- The goal's proposition is asserted wrapped in @cryProve@, inside a
-- temporary assertion frame.  If the solver answers @unsat@ the goal
-- counts as proved and 'Nothing' is returned; for any other answer the
-- goal is returned unchanged.
--
-- The temporary frame is popped even if an exception is raised, so the
-- enclosing context is never polluted by a failed attempt.
prove :: Solver -> TVars -> Goal -> IO (Maybe Goal)
prove sol tvs g =
  debugBlock sol "PROVE" $
  do let s = solver sol
     res <- bracket_ (push sol) (pop sol) $
              do SMT.assert s (SMT.fun "cryProve" [ toSMT tvs (goal g) ])
                 SMT.check s
     case res of
       SMT.Unsat -> return Nothing
       _         -> return (Just g)
-- | Check whether the given propositions are contradictory in the current
-- solver context.
--
-- The propositions are assumed inside a temporary assertion frame; the
-- result is 'True' exactly when the solver answers @unsat@.
--
-- The temporary frame is popped even if an exception is raised.
unsolvable :: Solver -> TVars -> [Prop] -> IO Bool
unsolvable sol tvs ps =
  debugBlock sol "UNSOLVABLE" $
  do res <- bracket_ (push sol) (pop sol) $
              do mapM_ (assume sol tvs) ps
                 SMT.check (solver sol)
     case res of
       SMT.Unsat -> return True
       _         -> return False
-- | Split a goal into one goal per conjunct of its proposition (as
-- computed by 'pSplitAnd').  All other fields of the goal are copied.
flatGoal :: Goal -> [Goal]
flatGoal g = map (\conjunct -> g { goal = conjunct }) (pSplitAnd (goal g))
-- | Is this one of the propositions handled by this module?  That is the
-- case when it matches the equality, disequality, @>=@, or @fin@ pattern.
isNumeric :: Prop -> Bool
isNumeric ty =
  matchDefault False $
    msum [ is (|=|)
         , is (|/=|)
         , is (|>=|)
         , is aFin
         ]
  where
  -- Succeeds with 'True' when the pattern matches, ignoring what matched.
  is f = True <$ f ty
-- | Maps each type variable to the solver expression that stands for it
-- (see 'declareVar').
type TVars = Map TVar SExpr
-- | The solver sort @InfNat@, used for every variable declared by
-- 'declareVar'.
cryInfNat :: SExpr
cryInfNat = SMT.const "InfNat"
-- | Translate a type or proposition into a solver expression.
--
-- The translation rules are tried in the order listed; the first pattern
-- that matches decides which solver function is applied, and the pieces
-- extracted by the pattern become its arguments (see 'Mk').  Type
-- variables are looked up in the given map, so the name attached to the
-- variable rule is never used.  Panics if no rule matches.
toSMT :: TVars -> Type -> SExpr
toSMT tvs ty = matchDefault unexpected (msum [ rule tvs ty | rule <- rules ])
  where
  unexpected = panic "toSMT" [ "Unexpected type", show ty ]

  rules =
    [ aInf            ~> "cryInf"
    , aNat            ~> "cryNat"

    , aFin            ~> "cryFin"
    , (|=|)           ~> "cryEq"
    , (|/=|)          ~> "cryNeq"
    , (|>=|)          ~> "cryGeq"
    , aAnd            ~> "cryAnd"
    , aTrue           ~> "cryTrue"

    , anAdd           ~> "cryAdd"
    , (|-|)           ~> "crySub"
    , aMul            ~> "cryMul"
    , (|^|)           ~> "cryExp"
    , (|/|)           ~> "cryDiv"
    , (|%|)           ~> "cryMod"
    , aMin            ~> "cryMin"
    , aMax            ~> "cryMax"
    , aWidth          ~> "cryWidth"
    , aCeilDiv        ~> "cryCeilDiv"
    , aCeilMod        ~> "cryCeilMod"
    , aLenFromThen    ~> "cryLenFromThen"
    , aLenFromThenTo  ~> "cryLenFromThenTo"

    , anError KNum    ~> "cryErr"
    , anError KProp   ~> "cryErrProp"

    , aTVar           ~> "(unused)"
    ]
-- | Build a translation rule: if the pattern matches the type, turn what
-- it extracted into a solver expression with 'mk', using the given name.
(~>) :: Mk a => (Type -> Match a) -> String -> TVars -> Type -> Match SExpr
(m ~> f) tvs t = mk tvs f <$> m t
-- | Values that can be turned into a solver expression, given the map of
-- declared type variables and the name of a solver function.
class Mk t where
  -- | Build the solver expression for a value.
  mk :: TVars -> String -> t -> SExpr
-- | Nothing was extracted: the name is used as a constant.
instance Mk () where
  mk _ f = const (SMT.const f)
-- | A number: the function is applied to the integer literal.
instance Mk Integer where
  mk _ f n = SMT.fun f [literal]
    where
    literal = SMT.int n
-- | A type variable: the function name is ignored and the variable is
-- replaced by the solver expression it was declared as.
--
-- Panics, naming the offending variable, if it is not in the map.
instance Mk TVar where
  mk tvs _ x =
    case Map.lookup x tvs of
      Just e  -> e
      Nothing -> panic "toSMT" [ "Undeclared type variable", show x ]
-- | One sub-term: the function is applied to its translation.
instance Mk Type where
  mk tvs f x = SMT.fun f (map (toSMT tvs) [x])
-- | An error: the message is dropped and the function gets no arguments.
instance Mk TCErrorMessage where
  mk _ f = const (SMT.fun f [])
-- | Two sub-terms: the function is applied to both translations, in order.
instance Mk (Type,Type) where
  mk tvs f (x,y) = SMT.fun f (map (toSMT tvs) [x, y])
-- | Three sub-terms: the function is applied to all three translations,
-- in order.
instance Mk (Type,Type,Type) where
  mk tvs f (x,y,z) = SMT.fun f (map (toSMT tvs) [x, y, z])