module Cryptol.TypeCheck.Solver.CrySAT
( withScope, withSolver
, assumeProps, simplifyProps, getModel
, check
, Solver, logger, getIntervals
, DefinedProp(..)
, debugBlock
, DebugLog(..)
, knownDefined, numericRight
, minimizeContradictionSimpDef
) where
import Prelude hiding (any,all)
import qualified Control.Exception as X
import Data.Foldable ( any, all )
import Data.IORef ( IORef, newIORef, readIORef, modifyIORef',
                    atomicModifyIORef' )
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe ( fromMaybe )
import qualified Data.Set as Set

import MonadLib
import qualified SimpleSMT as SMT

import qualified Cryptol.TypeCheck.AST as Cry
import Cryptol.TypeCheck.InferTypes(Goal(..), SolverConfig(..), Solved(..))
import qualified Cryptol.TypeCheck.Subst as Cry
import Cryptol.TypeCheck.Solver.Numeric.AST
import Cryptol.TypeCheck.Solver.Numeric.Defined
import Cryptol.TypeCheck.Solver.Numeric.Fin
import Cryptol.TypeCheck.Solver.Numeric.ImportExport
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.Solver.Numeric.NonLin
import Cryptol.TypeCheck.Solver.Numeric.Simplify
import Cryptol.TypeCheck.Solver.Numeric.SMT
import Cryptol.Utils.PP
import Cryptol.Utils.Panic ( panic )
-- | A numeric property that is meant to be in simplified form.
-- 'simpProp' builds one by running 'crySimplify'.  Note that the raw
-- constructor is also applied directly in this module to properties that
-- were not simplified here (e.g. the output of 'nonLinProp').
newtype SimpProp = SimpProp { unSimpProp :: Prop }

-- | Simplify a property with 'crySimplify' and tag the result.
simpProp :: Prop -> SimpProp
simpProp = SimpProp . crySimplify
-- | Things that carry an underlying Cryptol type-level property.
class HasProp a where getProp :: a -> Cry.Prop

-- | A property is its own underlying property.
instance HasProp Cry.Prop where getProp = id

-- | The property of a goal is its 'goal' field.
instance HasProp Goal where getProp = goal
-- | A payload paired with its numeric property, kept in two forms.
data DefinedProp a = DefinedProp
  { dpData :: a
    -- ^ The original object, e.g. a 'Goal' or a Cryptol property.
  , dpSimpProp :: SimpProp
    -- ^ The simplified numeric property; this is what 'assert' sends
    -- on to the solver.
  , dpSimpExprProp :: Prop
    -- ^ The numeric property used for cheap syntactic checks
    -- ('eliminateSimpleGEQ') and in error messages.  When built by
    -- 'knownDefined' this is the property exactly as given, i.e. not
    -- simplified.
  }
-- | Package a payload with its numeric property.  'dpSimpExprProp' keeps
-- the property as given, and 'dpSimpProp' holds its simplified form.
-- Nothing is checked here.  The name suggests the caller already knows
-- the property is well-defined.
knownDefined :: (a,Prop) -> DefinedProp a
knownDefined (payload, prop) =
  DefinedProp { dpSimpExprProp = prop
              , dpSimpProp     = simpProp prop
              , dpData         = payload
              }
-- | Classify a goal by whether its property can be exported to the
-- numeric representation.  On success the result is 'Right', pairing the
-- goal with the exported property.  On failure the goal is returned
-- unchanged in 'Left'.
numericRight :: Goal -> Either Goal (Goal, Prop)
numericRight g = maybe (Left g) (\p -> Right (g, p)) (exportProp (goal g))
-- | Simplify a collection of goals, dropping those implied by the others.
--
-- The goals are first pre-filtered with 'eliminateSimpleGEQ' and then
-- processed one at a time.  For each goal @p@, the goals not yet processed
-- are asserted in a throw-away scope, and the solver tries to prove @p@
-- from them together with anything already asserted.  If that fails, the
-- @fin@ solver ('cryIsFin') is consulted with the current intervals.
--
-- Returns the goals that could not be discharged.  All solver state added
-- here is discarded on return, because everything runs inside 'withScope'.
simplifyProps :: Solver -> [DefinedProp Goal] -> IO [Goal]
simplifyProps s props =
  debugBlock s "Simplifying properties" $
  withScope s (go [] (eliminateSimpleGEQ props))
  where
  -- 'survived' accumulates the goals we could not discharge (newest first).
  go survived [] = return survived

  go survived (p : more) =
    case dpSimpProp p of
      -- Trivially true goals are dropped outright.
      SimpProp PTrue -> go survived more
      SimpProp p' ->
        -- Try to prove p' from the remaining goals in a temporary scope.
        -- 'Nothing' means the intervals became inconsistent.
        do mbProved <-
             withScope s $
               do mapM_ (assert s) more
                  e <- getIntervals s
                  case e of
                    Left _ -> return Nothing
                    Right ints ->
                      do b <- prove s p'
                         return (Just (ints,b))
           case mbProved of
             Just (_,True) -> go survived more

             Just (ints,False) ->
               debugLog s ("Using the fin solver:" ++ show (pp (goal (dpData p)))) >>
               case cryIsFin ints (dpData p) of
                 Solved _ gs' ->
                   -- The goal was replaced by sub-goals; only the ones
                   -- that export to the numeric form are fed back in.
                   do debugLog s "solved"
                      let more' = [ knownDefined g | Right g <- map numericRight gs' ]
                      go survived (more' ++ more)
                 Unsolved ->
                   -- Keep the goal, and assume it while processing the rest.
                   do debugLog s "unsolved"
                      assert s p
                      go (dpData p : survived) more
                 Unsolvable ->
                   -- NOTE(review): unlike 'Unsolved', the goal is kept but
                   -- not asserted; confirm this asymmetry is intended.
                   do debugLog s "unsolvable"
                      go (dpData p : survived) more

             Nothing -> go (dpData p : survived) more
-- | Cheaply discard goals that are trivially true or subsumed by another
-- goal in the list, without consulting the SMT solver.
--
-- * @K a :== K b@ with @a == b@, anything @:>= K (Nat 0)@, and
--   @K (Nat m) :>= K (Nat n)@ with @m >= n@ are dropped.
-- * Among goals of the form @Var v :>= K (Nat n)@, only the one with the
--   largest @n@ is kept for each variable, since it implies the others.
--
-- The surviving lower-bound goals come first (in the key order of the
-- variable map).  They are followed by the remaining goals in reverse of
-- their input order.
eliminateSimpleGEQ :: [DefinedProp a] -> [DefinedProp a]
eliminateSimpleGEQ = loop Map.empty []
  where
  loop bounds kept [] = map snd (Map.elems bounds) ++ kept
  loop bounds kept (d : ds) =
    case dpSimpExprProp d of
      K x :== K y
        | x == y -> loop bounds kept ds
      _ :>= K (Nat 0) ->
        loop bounds kept ds
      K (Nat m) :>= K (Nat n)
        | m >= n -> loop bounds kept ds
      Var v :>= K (Nat n) ->
        loop (addLowerBound v (n, d) bounds) kept ds
      _ -> loop bounds (d : kept) ds

  -- Record the goal @v >= n@, keeping only the strongest bound per variable.
  addLowerBound v entry = Map.insertWith stronger v entry

  -- 'Map.insertWith' passes the new entry first; on a tie the old one stays.
  stronger new old
    | fst new > fst old = new
    | otherwise         = old
-- | Assert a list of Cryptol properties in the current solver scope.
--
-- Properties that cannot be exported to the numeric representation are
-- skipped.  For every exported property, its definedness condition
-- ('cryDefinedProp') is asserted too.  All definedness conditions are
-- asserted before the properties themselves.  Returns the simplified form
-- of every property handed to 'assert', in that order.
assumeProps :: Solver -> [Cry.Prop] -> IO [SimpProp]
assumeProps s props =
  do let exported = [ (p, q) | p <- props, Just q <- [exportProp p] ]
         defs     = map knownDefined
                      ([ (p, cryDefinedProp q) | (p, q) <- exported ] ++ exported)
     mapM_ (assert s) defs
     return (map dpSimpProp defs)
-- | Find a subset of the given properties that, together with whatever is
-- already asserted in the solver, is unsatisfiable.
--
-- The search is linear.  Properties are asserted one at a time in a fresh
-- scope until the solver reports 'SMT.Unsat'.  The scope is then popped,
-- and the property that triggered the contradiction is re-asserted in the
-- base scope and added to the result.  The search restarts over the
-- properties examined before it, and ends as soon as the base scope alone
-- is unsatisfiable.
--
-- Any answer other than 'SMT.Unsat', including 'SMT.Unknown', counts as
-- "no contradiction yet".  Panics if asserting all remaining properties
-- does not produce a contradiction.
--
-- NOTE(review): the culprits stay asserted in the caller's scope on
-- return.  Callers presumably wrap this in 'withScope'; confirm.
minimizeContradictionSimpDef :: HasProp a => Solver -> [DefinedProp a] -> IO [a]
minimizeContradictionSimpDef s ps = start [] ps
  where
  -- 'bad' holds the properties already known to be part of the
  -- contradiction; they are asserted in the current base scope.
  start bad todo =
    do res <- SMT.check (solver s)
       case res of
         SMT.Unsat -> return (map dpData bad)
         _ ->
           do solPush s
              go bad [] todo

  go _ _ [] = panic "minimizeContradiction"
            $ ("No contradiction" : map (show . ppProp . dpSimpExprProp) ps)

  -- 'prev' collects, in reverse, the properties asserted so far in the
  -- temporary scope that have not yet caused a contradiction.
  go bad prev (d : more) =
    do assert s d
       res <- SMT.check (solver s)
       case res of
         SMT.Unsat ->
           do solPop s
              -- Re-assert the culprit outside the temporary scope.
              assert s d
              start (d : bad) prev
         _ -> go bad (d : prev) more
-- | Look for an assignment to the free variables of the given properties
-- that satisfies all of them.
--
-- The properties are asserted in a temporary scope.  If the SMT solver
-- finds a model, the model is substituted back into the simplified
-- properties, and each of them must then simplify to 'PTrue'.  This
-- re-check is presumably needed because the solver only sees the problem
-- after the non-linear abstraction.  Only user-named variables appear in
-- the returned substitution.
getModel :: Solver -> [Cry.Prop] -> IO (Maybe Cry.Subst)
getModel s props =
  withScope s $
  do assumed <- assumeProps s props
     answer  <- SMT.check (solver s)
     let freeVars = Set.toList (Set.unions [ cryPropFVS q | SimpProp q <- assumed ])
     case answer of
       SMT.Sat ->
         do model <- getVals (solver s) freeVars
            let asExpr    = fmap K model
                grounded  = [ fromMaybe q (apSubst asExpr q) | SimpProp q <- assumed ]
                holds q   = case crySimplify q of
                              PTrue -> True
                              _     -> False
                cryptolSu = Cry.listSubst
                              [ (x, toType v) | (UserName x, v) <- Map.toList model ]
            return (if all holds grounded then Just cryptolSu else Nothing)
       _ -> return Nothing
  where
  toType Inf     = Cry.tInf
  toType (Nat k) = Cry.tNum k
-- | A handle to the numeric constraint solver.
data Solver = Solver
  { solver :: SMT.Solver
    -- ^ The underlying SMT solver.
  , declared :: IORef VarInfo
    -- ^ Bookkeeping for declared variables and asserted facts, with one
    -- 'Scope' per solver push level.
  , logger :: SMT.Logger
    -- ^ Destination for debug output ('debugLog', 'debugBlock').
  }

-- | Scope stack that mirrors the SMT solver's push/pop stack.
data VarInfo = VarInfo
  { curScope :: Scope
    -- ^ The innermost (current) scope.
  , otherScopes :: [Scope]
    -- ^ The enclosing scopes, innermost first.
  } deriving Show

-- | Information tracked for a single solver scope.
data Scope = Scope
  { scopeNames :: [Name]
    -- ^ Variables declared in the SMT solver while this scope was current.
  , scopeNonLinS :: NonLinS
    -- ^ State of the non-linear term abstraction ('nonLinProp').
  , scopeIntervals :: Either Cry.TVar (Map.Map Cry.TVar Interval)
    -- ^ Known intervals for type variables, or ('Left') a variable whose
    -- interval became invalid.
  , scopeAsserted :: [Cry.Prop]
    -- ^ Cryptol properties asserted so far, newest first, including those
    -- inherited from enclosing scopes.
  } deriving Show
-- | A scope with no declared names, no asserted properties, no known
-- intervals, and the initial non-linear state.
scopeEmpty :: Scope
scopeEmpty = Scope { scopeAsserted  = []
                   , scopeIntervals = Right Map.empty
                   , scopeNonLinS   = initialNonLinS
                   , scopeNames     = []
                   }

-- | Was the variable declared in this particular scope?
scopeElem :: Name -> Scope -> Bool
scopeElem x sc = x `elem` scopeNames sc

-- | Record that the variable was declared in this scope.
scopeInsert :: Name -> Scope -> Scope
scopeInsert x sc = sc { scopeNames = x : scopeNames sc }
-- | Record a newly asserted Cryptol property, and recompute the interval
-- information from all properties asserted so far.
--
-- Once the intervals have become invalid ('Left'), they stay invalid.
scopeAssertNew :: Cry.Prop -> Scope -> Scope
scopeAssertNew prop sc =
  sc { scopeAsserted = asserted, scopeIntervals = intervals }
  where
  asserted = prop : scopeAsserted sc

  intervals =
    case scopeIntervals sc of
      Left bad -> Left bad
      Right known ->
        case computePropIntervals known asserted of
          NoChange             -> Right known
          NewIntervals updated -> Right updated
          InvalidInterval bad  -> Left bad
-- | Run a simplified property through the scope's non-linear abstraction.
-- Returns the resulting properties (the ones the callers send to the SMT
-- solver) together with the updated scope.
scopeAssertSimpProp :: SimpProp -> Scope -> ([SimpProp],Scope)
scopeAssertSimpProp (SimpProp prop) sc =
  (map SimpProp rewritten, sc { scopeNonLinS = nlState })
  where
  (rewritten, nlState) = nonLinProp (scopeNonLinS sc) prop

-- | Like 'scopeAssertSimpProp' for the simplified property of a
-- 'DefinedProp'.  The underlying Cryptol property is also recorded via
-- 'scopeAssertNew', which may update the intervals.
scopeAssert :: HasProp a => DefinedProp a -> Scope -> ([SimpProp],Scope)
scopeAssert dp sc = (toSend, scopeAssertNew (getProp (dpData dp)) sc')
  where
  (toSend, sc') = scopeAssertSimpProp (dpSimpProp dp) sc
-- | Bookkeeping with a single, empty scope.
viEmpty :: VarInfo
viEmpty = VarInfo { otherScopes = [], curScope = scopeEmpty }

-- | Has the variable been declared in the current or any enclosing scope?
viElem :: Name -> VarInfo -> Bool
viElem x vi = any (scopeElem x) (curScope vi : otherScopes vi)

-- | Record a variable declaration in the current scope.
viInsert :: Name -> VarInfo -> VarInfo
viInsert x vi = vi { curScope = scopeInsert x (curScope vi) }

-- | Apply 'scopeAssertSimpProp' to the current scope.  The result is
-- ordered (new state, output) to suit 'atomicModifyIORef''.
viAssertSimpProp :: SimpProp -> VarInfo -> (VarInfo, [SimpProp])
viAssertSimpProp p vi = (vi { curScope = sc' }, toSend)
  where
  (toSend, sc') = scopeAssertSimpProp p (curScope vi)

-- | Apply 'scopeAssert' to the current scope.  The result is ordered
-- (new state, output) to suit 'atomicModifyIORef''.
viAssert :: HasProp a => DefinedProp a -> VarInfo -> (VarInfo, [SimpProp])
viAssert d vi = (vi { curScope = sc' }, toSend)
  where
  (toSend, sc') = scopeAssert d (curScope vi)
-- | Open a nested scope.  The new current scope starts with no declared
-- names.  It inherits the non-linear state, the intervals and the
-- asserted properties of the scope it is nested in.
viPush :: VarInfo -> VarInfo
viPush vi = VarInfo { curScope = nested, otherScopes = outer : otherScopes vi }
  where
  outer  = curScope vi
  nested = scopeEmpty { scopeNonLinS   = scopeNonLinS outer
                      , scopeIntervals = scopeIntervals outer
                      , scopeAsserted  = scopeAsserted outer
                      }

-- | Discard the current scope and make the enclosing one current.
-- Panics when there is no enclosing scope.
viPop :: VarInfo -> VarInfo
viPop vi =
  case otherScopes vi of
    []     -> panic "viPop" ["no more scopes"]
    c : cs -> vi { curScope = c, otherScopes = cs }

-- | Every variable name declared in any scope, current scope first.
viUnmarkedNames :: VarInfo -> [ Name ]
viUnmarkedNames vi = concatMap scopeNames (curScope vi : otherScopes vi)
-- | Interval information of the current scope.  'Left' names a type
-- variable whose interval became invalid.
getIntervals :: Solver -> IO (Either Cry.TVar (Map Cry.TVar Interval))
getIntervals s = (scopeIntervals . curScope) `fmap` readIORef (declared s)

-- | The non-linear substitution of the current scope.  'declareVar' looks
-- up variables in it to find the term a variable stands for.
getNLSubst :: Solver -> IO Subst
getNLSubst s = (nonLinSubst . scopeNonLinS . curScope) `fmap` readIORef (declared s)
-- | Start an SMT solver as described by the 'SolverConfig', run the
-- continuation with it, and stop the solver afterwards.
--
-- The solver is stopped even if the set-up or the continuation throws an
-- exception, so the external solver process is never leaked.  A real
-- logger is used only when 'solverVerbose' is positive.  When it is
-- greater than 1, the logger is also handed to 'SMT.newSolver',
-- presumably so the raw SMT interaction is logged.
withSolver :: SolverConfig -> (Solver -> IO a) -> IO a
withSolver SolverConfig { .. } k =
  do logger <- if solverVerbose > 0 then SMT.newLogger 0 else return quietLogger
     let smtDbg = if solverVerbose > 1 then Just logger else Nothing
     X.bracket (SMT.newSolver solverPath solverArgs smtDbg)
               (\solver -> SMT.stop solver >> return ())
               (\solver ->
                  do _ <- SMT.setOptionMaybe solver ":global-decls" "false"
                     SMT.setLogic solver "QF_LIA"
                     declared <- newIORef viEmpty
                     k Solver { .. })
  where
  quietLogger = SMT.Logger { SMT.logMessage = \_ -> return ()
                           , SMT.logLevel   = return 0
                           , SMT.logSetLevel= \_ -> return ()
                           , SMT.logTab     = return ()
                           , SMT.logUntab   = return ()
                           }
-- | Open a new scope, both in the SMT solver and in the bookkeeping, and
-- indent subsequent log output.
solPush :: Solver -> IO ()
solPush s =
  do SMT.push (solver s)
     SMT.logTab (logger s)
     modifyIORef' (declared s) viPush

-- | Undo 'solPush', performing its steps in reverse order.
solPop :: Solver -> IO ()
solPop s =
  do modifyIORef' (declared s) viPop
     SMT.logUntab (logger s)
     SMT.pop (solver s)

-- | Run an action in a fresh scope.  Everything it asserts or declares is
-- discarded when it finishes.
--
-- NOTE(review): the scope is not popped if the action throws.
withScope :: Solver -> IO a -> IO a
withScope s k =
  do solPush s
     k >>= \result -> solPop s >> return result
-- | Declare a variable in the SMT solver, unless it has already been
-- declared in the current or an enclosing scope.
--
-- A variable @a@ becomes an integer constant (asserted to be >= 0) plus a
-- boolean constant named by 'smtFinName'.  If @a@ has an entry in the
-- current non-linear substitution, the boolean is defined to equal the
-- simplified @Fin@ condition of the corresponding term.  The variables
-- mentioned by that condition are declared first.  Otherwise the boolean
-- is left unconstrained here.
declareVar :: Solver -> Name -> IO ()
declareVar s@Solver { .. } a =
  do done <- fmap (a `viElem`) (readIORef declared)
     unless done $
       do e <- SMT.declare solver (smtName a) SMT.tInt
          let fin_a = smtFinName a
          fin <- SMT.declare solver fin_a SMT.tBool
          SMT.assert solver (SMT.geq e (SMT.int 0))

          nlSu <- getNLSubst s
          -- Mark the variable as declared before recursing, so the
          -- recursive calls below never declare it a second time.
          modifyIORef' declared (viInsert a)
          case Map.lookup a nlSu of
            Nothing -> return ()
            Just e' ->
              do let finDef = crySimplify (Fin e')
                 mapM_ (declareVar s) (Set.toList (cryPropFVS finDef))
                 SMT.assert solver $
                    SMT.eq fin (ifPropToSmtLib (desugarProp finDef))
-- | Assert a 'DefinedProp' in the current scope.
--
-- The underlying Cryptol property is recorded in the scope, which may
-- refine the intervals.  The simplified numeric property goes through the
-- non-linear abstraction, the variables of the resulting properties are
-- declared, and those properties are sent to the SMT solver.  A property
-- that simplified to 'PTrue' is skipped entirely, including the recording
-- of its Cryptol property.
assert :: HasProp a => Solver -> DefinedProp a -> IO ()
assert _ DefinedProp { dpSimpProp = SimpProp PTrue } = return ()
assert s@Solver { .. } def@DefinedProp { dpSimpProp = p } =
  do debugLog s ("Assuming: " ++ show (ppProp (unSimpProp p)))
     -- The intervals are read before and after only for the debug log.
     a <- getIntervals s
     debugLog s ("Intervals before:" ++ show (either pp ppIntervals a))
     ps1' <- atomicModifyIORef' declared (viAssert def)
     b <- getIntervals s
     debugLog s ("Intervals after:" ++ show (either pp ppIntervals b))
     let ps1 = map unSimpProp ps1'
         vs = Set.toList $ Set.unions $ map cryPropFVS ps1
     mapM_ (declareVar s) vs
     mapM_ (SMT.assert solver . ifPropToSmtLib . desugarProp) ps1
-- | Assert an already simplified numeric property in the current scope.
-- No Cryptol property is recorded, so the intervals are left unchanged.
-- 'PTrue' is ignored.
assertSimpProp :: Solver -> SimpProp -> IO ()
assertSimpProp _ (SimpProp PTrue) = return ()
assertSimpProp s sp@(SimpProp raw) =
  do debugLog s ("Assuming: " ++ show (ppProp raw))
     toSend <- map unSimpProp `fmap` atomicModifyIORef' (declared s) (viAssertSimpProp sp)
     mapM_ (declareVar s) (Set.toList (Set.unions (map cryPropFVS toSend)))
     mapM_ (SMT.assert (solver s) . ifPropToSmtLib . desugarProp) toSend
-- | Does the property follow from what is currently asserted?  Its
-- negation is asserted in a temporary scope, and the property counts as
-- proved only if the solver answers 'SMT.Unsat'.  Any other answer,
-- including 'SMT.Unknown', means "not proved".
prove :: Solver -> Prop -> IO Bool
prove _ PTrue = return True
prove s p =
  debugBlock s ("Proving: " ++ show (ppProp p)) $
  withScope s $
  do assertSimpProp s (simpProp (Not p))
     answer <- SMT.check (solver s)
     case answer of
       SMT.Unsat -> debugLog s "Proved" >> return True
       _         -> debugLog s "Not proved" >> return False
-- | Check the consistency of what is asserted in the current scope.
--
-- Returns 'Nothing' if the intervals are invalid or the solver answers
-- 'SMT.Unsat'.  If the solver answers 'SMT.Unknown', the result is an
-- empty improvement with no side conditions.  If a model exists, the
-- result holds the improving substitution and its side conditions, as
-- computed by 'getImpSubst'.
check :: Solver -> IO (Maybe (Subst, [Prop]))
check s =
  do intervals <- getIntervals s
     case intervals of
       Left tv ->
         do debugLog s ("Invalid interval: " ++ show (pp tv))
            return Nothing
       Right ints ->
         do debugLog s ("Intervals:" ++ show (ppIntervals ints))
            answer <- SMT.check (solver s)
            case answer of
              SMT.Unsat ->
                do debugLog s "Not satisfiable"
                   return Nothing
              SMT.Unknown ->
                do debugLog s "Unknown"
                   return (Just (Map.empty, []))
              SMT.Sat ->
                do debugLog s "Satisfiable"
                   fmap Just (debugBlock s "Computing improvements" (getImpSubst s))
-- | Compute improvements from the current SMT model: a substitution for
-- user variables plus a list of side conditions.
--
-- 'cryImproveModel' proposes equations for the declared variables, along
-- with its own side conditions.  The equations are split as follows:
--
--   * Equations for variables introduced by the non-linear abstraction
--     ('SysName's) become side conditions, with the non-linear terms
--     substituted back in.
--   * Equations for user variables whose right-hand side mentions no
--     'SysName' go straight into the substitution.
--   * The remaining user-variable equations are rewritten in terms of the
--     non-linear definitions.  They go into the substitution unless the
--     variable occurs in its own right-hand side, in which case they
--     become side conditions.
getImpSubst :: Solver -> IO (Subst,[Prop])
getImpSubst s@Solver { .. } =
  do names <- viUnmarkedNames `fmap` readIORef declared
     m     <- getVals solver names
     (impSu,sideConditions) <- cryImproveModel solver logger m
     nlSu <- getNLSubst s

     let (nlFacts, vFacts) = Map.partitionWithKey (\k _ -> isNLVar k) impSu
         (vV, vNL)         = Map.partition noNLVars vFacts
         nlSu1             = fmap (doAppSubst vV) nlSu

         (vNL_su,vNL_eqs)  = Map.partitionWithKey goodDef
                           $ fmap (doAppSubst nlSu1) vNL

         nlSu2             = fmap (doAppSubst vNL_su) nlSu1
         nlLkp x           = case Map.lookup x nlSu2 of
                               Just e  -> e
                               Nothing -> panic "getImpSubst"
                                            [ "Missing NL variable:", show x ]

         allSides =
           [ Var a :== e | (a,e) <- Map.toList vNL_eqs ] ++
           [ nlLkp x :== doAppSubst nlSu2 e | (x,e) <- Map.toList nlFacts ] ++
           [ doAppSubst nlSu2 si | si <- sideConditions ]

         theImpSu = composeSubst vNL_su vV

     debugBlock s "Improvements" $
       do debugBlock s "substitution" $
            mapM_ (debugLog s . dump) (Map.toList theImpSu)
          debugBlock s "side-conditions" $ debugLog s allSides

     return (theImpSu, allSides)
  where
  -- The single test for "introduced by the non-linear abstraction", used
  -- for both map keys and free variables.  It replaces the duplicate
  -- let-bound 'isNonLinName', which had no catch-all case.
  isNLVar (SysName {}) = True
  isNLVar _            = False

  -- An equation @k = e@ is usable as a substitution only if @k@ is not in @e@.
  goodDef k e = not (k `Set.member` cryExprFVS e)
  noNLVars e  = all (not . isNLVar) (cryExprFVS e)
  dump (x,e)  = show (ppProp (Var x :== e))
-- | Log a heading, then run the action with log output indented one
-- level deeper.  The indentation is undone only when the action returns
-- normally.
debugBlock :: Solver -> String -> IO a -> IO a
debugBlock s name m =
  debugLog s name >> SMT.logTab (logger s) >> m >>= \result ->
  SMT.logUntab (logger s) >> return result
-- | Values that can be written to the solver's debug log.
class DebugLog t where
  debugLog :: Solver -> t -> IO ()

  -- | Log a list of values.  By default each element is logged on its
  -- own, and an empty list is logged as @(none)@.
  debugLogList :: Solver -> [t] -> IO ()
  debugLogList s ts = case ts of
                        [] -> debugLog s "(none)"
                        _  -> mapM_ (debugLog s) ts
-- | A string is logged verbatim as one message; a lone character is
-- logged via 'show'.
instance DebugLog Char where
  debugLog s c       = SMT.logMessage (logger s) (show c)
  debugLogList s str = SMT.logMessage (logger s) str

-- | Lists defer to the element type's 'debugLogList'.
instance DebugLog a => DebugLog [a] where
  debugLog = debugLogList

-- | 'Nothing' is logged as @(nothing)@.
instance DebugLog a => DebugLog (Maybe a) where
  debugLog s = maybe (debugLog s "(nothing)") (debugLog s)

-- | Documents are rendered with 'show'.
instance DebugLog Doc where
  debugLog s = debugLog s . show

instance DebugLog Cry.Type where
  debugLog s = debugLog s . pp

-- | A goal is logged as its property.
instance DebugLog Goal where
  debugLog s = debugLog s . goal

instance DebugLog Cry.Subst where
  debugLog s = debugLog s . pp

instance DebugLog Prop where
  debugLog s = debugLog s . ppProp