module Cryptol.TypeCheck.Solver.FinOrd
( OrdFacts, AssertResult(..)
, noFacts, addFact
, isKnownLeq
, knownInterval
, ordFactsToGoals
, ordFactsToProps
, dumpDot
, dumpDoc
, IsFact(..)
) where
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Interval
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.InferTypes
import Cryptol.TypeCheck.TypeMap
import Cryptol.Parser.Position
import qualified Cryptol.Utils.Panic as P
import Cryptol.Utils.PP(Doc,pp,vcat,text,(<+>))
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import qualified Data.Map as Map
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Maybe(fromMaybe,maybeToList)
import Control.Monad(guard)
panic :: String -> [String] -> a
panic x y = P.panic ("Cryptol.TypeCheck.Solver.FinOrd." ++ x) y
class IsFact t where
factProp :: t -> Prop
factChangeProp :: t -> Prop -> t
factSource :: t -> EdgeSrc
instance IsFact Goal where
factProp = goal
factChangeProp g p = g { goal = p }
factSource g = FromGoal (goalSource g) (goalRange g)
instance IsFact Prop where
factProp = id
factChangeProp _ x = x
factSource _ = NoGoal
addFact :: IsFact t => t -> OrdFacts -> AssertResult
addFact g m =
case tNoUser (factProp g) of
TCon (PC PFin) [t] ->
let (x,m1) = nameTerm m t
in insertFin src x m1
TCon (PC PGeq) [t1,t2] ->
let (x,m1) = nameTerm m t1
(y,m2) = nameTerm m1 t2
in insertLeq src y x m2
_ -> OrdCannot
where
src = factSource g
data AssertResult
= OrdAdded OrdFacts
| OrdAlreadyKnown
| OrdImprove Type Type
| OrdImpossible
| OrdCannot
deriving Show
isKnownLeq :: OrdFacts -> Type -> Type -> Bool
isKnownLeq m t1 t2 =
let (x,m1) = nameTerm m t1
(y,m2) = nameTerm m1 t2
in isLeq m2 x y
knownInterval :: OrdFacts -> Type -> Interval
knownInterval m t =
fromMaybe anything $
do a <- numType t
return $
case (cvtLower (getLowerBound m a), cvtUpper (getUpperBound m a)) of
(x,(y,z)) -> Interval { lowerBound = x
, upperBound = y
, isFinite = z
}
where
cvtLower (Nat'' x) = Nat x
cvtLower FinNat'' = panic "knownInterval"
[ "`FinNat` used as a lower bound for:"
, show t
]
cvtLower Inf'' = Inf
cvtUpper (Nat'' x) = (Nat x, True)
cvtUpper FinNat'' = (Inf, True)
cvtUpper Inf'' = (Inf, False)
ordFactsToGoals :: OrdFacts -> [Goal]
ordFactsToGoals = ordFactsToList onlyGoals
where
onlyGoals (FromGoal c r) = Just $ \p -> Goal { goalSource = c
, goalRange = r
, goal = p }
onlyGoals NoGoal = Nothing
ordFactsToProps :: OrdFacts -> [Prop]
ordFactsToProps = ordFactsToList (\_ -> Just id)
ordFactsToList :: (EdgeSrc -> Maybe (Prop -> a)) -> OrdFacts -> [a]
ordFactsToList consider (OrdFacts m ts) = concatMap getGoals (Map.toList m)
where
getGoals (ty, es) =
do (lower,notNum) <-
case ty of
NTNat FinNat'' -> []
NTNat Inf'' -> []
NTNat (Nat'' n) -> guard (n > 0) >> [ (tNum n, False) ]
NTVar v -> [ (TVar v, True) ]
NTNamed x -> [ (getNamed x, True) ]
Edge { target = t, eSource = src } <- Set.toList (above es)
f <- maybeToList (consider src)
g <- case t of
NTNat FinNat'' -> guard notNum >> [ pFin lower ]
NTNat Inf'' -> []
NTNat (Nat'' n) -> guard notNum >> [ tNum n >== lower ]
NTVar x -> [ TVar x >== lower ]
NTNamed x -> [ getNamed x >== lower ]
return (f g)
getNamed x = case IntMap.lookup x (nameToType ts) of
Just t -> t
Nothing -> panic "ordFactsToList" [ "Missing name" ]
data OrdFacts = OrdFacts (Map NumType Edges) OrdTerms
deriving Show
data OrdTerms = OrdTerms
{ typeToName :: TypeMap Int
, nameToType :: IntMap Type
, nextId :: Int
} deriving Show
data Edges = Edges { above :: Set Edge, below :: Set Edge }
deriving Show
data Edge = Edge { target :: NumType, eSource :: EdgeSrc }
deriving Show
data EdgeSrc = FromGoal ConstraintSource Range
| NoGoal
deriving Show
instance Eq Edge where
x == y = target x == target y
instance Ord Edge where
compare x y = compare (target x) (target y)
data Nat'' = Nat'' Integer
| FinNat''
| Inf''
deriving (Eq,Ord,Show)
data NumType = NTNat Nat'' | NTVar TVar | NTNamed Int
deriving (Eq,Ord,Show)
nameTerm :: OrdFacts -> Type -> (NumType, OrdFacts)
nameTerm fs t | Just n <- numType t = (n, fs)
nameTerm fs@(OrdFacts xs ts) t =
case lookupTM t (typeToName ts) of
Just n -> (NTNamed n, fs)
Nothing ->
let name = nextId ts
ts1 = OrdTerms { nameToType = IntMap.insert name t (nameToType ts)
, typeToName = insertTM t name (typeToName ts)
, nextId = name + 1
}
in (NTNamed name, OrdFacts xs ts1)
zero :: NumType
zero = NTNat (Nat'' 0)
fin :: NumType
fin = NTNat FinNat''
inf :: NumType
inf = NTNat Inf''
numType :: Type -> Maybe NumType
numType ty =
case tNoUser ty of
TCon (TC (TCNum n)) _ -> Just $ NTNat $ Nat'' n
TCon (TC TCInf) _ -> Just $ NTNat $ Inf''
TVar x | kindOf x == KNum -> Just $ NTVar x
_ -> Nothing
fromNumType :: OrdTerms -> NumType -> Maybe Type
fromNumType _ (NTVar x) = Just (TVar x)
fromNumType _ (NTNat Inf'') = Just tInf
fromNumType _ (NTNat FinNat'') = Nothing
fromNumType _ (NTNat (Nat'' x)) = Just (tNum x)
fromNumType ts (NTNamed x) = IntMap.lookup x (nameToType ts)
isVar :: NumType -> Bool
isVar (NTVar _) = True
isVar (NTNamed _) = True
isVar (NTNat _) = False
noFacts :: OrdFacts
noFacts = snd $ insNode inf
$ snd $ insNode fin
$ snd $ insNode zero
$ OrdFacts Map.empty OrdTerms { typeToName = emptyTM
, nameToType = IntMap.empty
, nextId = 0
}
noEdges :: Edges
noEdges = Edges { above = Set.empty, below = Set.empty }
imm :: (Edges -> Set Edge) -> OrdFacts -> NumType -> Set Edge
imm dir (OrdFacts m _) t = maybe Set.empty dir (Map.lookup t m)
reachable :: OrdFacts -> NumType -> NumType -> Bool
reachable m smaller larger =
search Set.empty (Set.singleton Edge { target = smaller, eSource = NoGoal })
where
search visited todo
| Just (Edge { target = term }, rest) <- Set.minView todo =
if term == larger
then True
else let new = imm above m term
vis = Set.insert term visited
notDone = Set.filter (not . (`Set.member` vis) . target)
in search vis (notDone new `Set.union` notDone rest)
| otherwise = False
link :: EdgeSrc -> (NumType,Edges) -> (NumType,Edges)
-> OrdFacts -> (Edges,Edges,OrdFacts)
link src (lower, ledges) (upper, uedges) m0 =
let uus = Set.mapMonotonic target (above uedges)
lls = Set.mapMonotonic target (below ledges)
rm x = Set.filter (not . (`Set.member` x) . target)
newLedges = ledges { above = Set.insert Edge { target = upper
, eSource = src }
$ rm (Set.insert inf uus)
$ above ledges
}
newUedges = uedges { below = Set.insert Edge { target = lower
, eSource = src }
$ rm lls
$ below uedges
}
del x = Set.delete Edge { target = x, eSource = NoGoal }
adjust f t (OrdFacts m xs) = OrdFacts (Map.adjust f t m) xs
insert k x (OrdFacts m xs) = OrdFacts (Map.insert k x m) xs
adjAbove = adjust (\e -> e { above = del upper (above e) })
adjBelow = adjust (\e -> e { below = del lower (below e) })
fold f xs x = Set.fold f x xs
in ( newLedges
, newUedges
, insert lower newLedges
$ insert upper newUedges
$ fold adjAbove lls
$ fold adjBelow uus
m0
)
insNode :: NumType -> OrdFacts -> (Edges, OrdFacts)
insNode t model@(OrdFacts m0 xs) =
case Map.splitLookup t m0 of
(_, Just r, _) -> (r, model)
(left, Nothing, right) ->
let m1 = OrdFacts (Map.insert t noEdges m0) xs
in if isVar t
then
case Map.lookup zero m0 of
Just zes ->
let (_,es1,m2@(OrdFacts m2M _)) = link NoGoal (zero,zes) (t,noEdges) m1
in case Map.lookup inf m2M of
Just ies ->
let (es2,_,m3) = link NoGoal (t,es1) (inf,ies) m2
in (es2,m3)
Nothing -> panic "insNode"
[ "infinity is missing from the model"
, show m0
]
Nothing -> panic "insNode"
[ "0 is missing from the model"
, show m0
]
else
let ans2@(es2, m2) =
case toNum Map.findMax left of
Nothing -> (noEdges,m1)
Just l ->
let (_,x,y) = link NoGoal l (t,noEdges) m1
in (x,y)
in case toNum Map.findMin right of
Nothing -> ans2
Just u ->
let (x,_,y) = link NoGoal (t,es2) u m2
in (x,y)
where
toNum f x = do guard (not (Map.null x))
let it@(n,_) = f x
guard (not (isVar n))
return it
isLeq :: OrdFacts -> NumType -> NumType -> Bool
isLeq m t1 t2 = reachable m2 t1 t2
where (_,m1) = insNode t1 m
(_,m2) = insNode t2 m1
insertLeq :: EdgeSrc -> NumType -> NumType -> OrdFacts -> AssertResult
insertLeq _ (NTNat Inf'') (NTNat Inf'') _ = OrdAlreadyKnown
insertLeq _ (NTNat Inf'') (NTNat FinNat'') _ = OrdImpossible
insertLeq _ (NTNat Inf'') (NTNat (Nat'' _)) _ = OrdImpossible
insertLeq _ (NTNat FinNat'') (NTNat Inf'') _ = OrdAlreadyKnown
insertLeq _ (NTNat FinNat'') (NTNat FinNat'') _ = OrdAlreadyKnown
insertLeq _ (NTNat FinNat'') (NTNat (Nat'' _)) _= OrdImpossible
insertLeq _ (NTNat (Nat'' _)) (NTNat Inf'') _ = OrdAlreadyKnown
insertLeq _ (NTNat (Nat'' _)) (NTNat FinNat'') _ = OrdAlreadyKnown
insertLeq _ (NTNat (Nat'' x)) (NTNat (Nat'' y)) _
| x <= y = OrdAlreadyKnown
| otherwise = OrdImpossible
insertLeq src t1 t2 m0
| reachable m2 t2 t1 = case (fromNumType terms t1, fromNumType terms t2) of
(Just a, Just b) -> OrdImprove a b
_ -> OrdCannot
| otherwise =
if reachable m2 t1 t2
then OrdAlreadyKnown
else let (_,_,m3) = link src (t1,n1) (t2,n2) m2
in OrdAdded m3
where (_,m1) = insNode t1 m0
(n2,m2@(OrdFacts m2M terms)) = insNode t2 m1
Just n1 = Map.lookup t1 m2M
insertFin :: EdgeSrc -> NumType -> OrdFacts -> AssertResult
insertFin src t m = insertLeq src t (NTNat FinNat'') m
getLowerBound :: OrdFacts -> NumType -> Nat''
getLowerBound _ (NTNat n) = n
getLowerBound fs@(OrdFacts m _) t =
case Map.lookup t m of
Nothing -> Nat'' 0
Just es -> case map (getLowerBound fs . target) $ Set.toList $ below es of
[] -> Nat'' 0
xs -> maximum xs
getUpperBound :: OrdFacts -> NumType -> Nat''
getUpperBound _ (NTNat n) = n
getUpperBound fs@(OrdFacts m _) t =
case Map.lookup t m of
Nothing -> Inf''
Just es -> case map (getUpperBound fs . target) $ Set.toList $ above es of
[] -> Inf''
xs -> minimum xs
dumpDot :: Bool -> OrdFacts -> String
dumpDot isUp (OrdFacts m _) = "digraph {" ++ concatMap edges (Map.toList m) ++ "}"
where
edge x e = x ++ " -> " ++ node (target e) ++ "[color=\"blue\"];"
dir = if isUp then above else below
edges (x,es) = let n = node x
in n ++ ";" ++
concatMap (edge n) (Set.toList (dir es))
node (NTNat (Nat'' x)) = show (show x)
node (NTNat FinNat'') = show "fin"
node (NTNat Inf'') = show "inf"
node (NTVar (TVFree x _ _ _)) = show ("?v" ++ show x)
node (NTVar (TVBound x _)) = show ("v" ++ show x)
node (NTNamed x) = show ("<" ++ show x ++ ">")
dumpDoc :: OrdFacts -> Doc
dumpDoc = vcat . ordFactsToList mk
where
mk src = Just $ \x -> case src of
NoGoal -> text "[G]" <+> pp x
FromGoal {} -> text "[W]" <+> pp x