{-# LANGUAGE Safe, PatternGuards, MultiWayIf #-}
module Cryptol.TypeCheck.Solver.Numeric
  ( cryIsEqual, cryIsNotEqual, cryIsGeq
  ) where

import           Control.Applicative(Alternative(..))
import           Control.Monad (guard,mzero)
import qualified Control.Monad.Fail as Fail
import           Data.List (sortBy)

import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Type hiding (tMul)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType as Simp

{- Convention for comments:

  K1, K2 ...          Concrete constants
  s1, s2, t1, t2 ...  Arbitrary type expressions
  a, b, c ...         Type variables

-}


-- | Try to solve @t1 = t2@
--
-- The rules below are tried in order; the first one that matches decides
-- the result.  If no rule matches, the result is 'Unsolved'.
cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual ctxt t1 t2 =
  matchDefault Unsolved $
        (pBin PEqual (==) t1 t2)
    <|> (aNat' t1 >>= tryEqK ctxt t2)
    <|> (aNat' t2 >>= tryEqK ctxt t1)
    <|> (aTVar t1 >>= tryEqVar t2)
    <|> (aTVar t2 >>= tryEqVar t1)
    <|> ( guard (t1 == t2) >> return (SolvedIf []))
    <|> tryEqMin t1 t2
    <|> tryEqMin t2 t1
    <|> tryEqMins t1 t2
    <|> tryEqMins t2 t1
    <|> tryEqMulConst t1 t2
    <|> tryEqAddInf ctxt t1 t2
    <|> tryAddConst (=#=) t1 t2
    <|> tryCancelVar ctxt (=#=) t1 t2
    <|> tryLinearSolution t1 t2
    <|> tryLinearSolution t2 t1


-- | Try to solve @t1 /= t2@
--
-- Only evaluation ('pBin') is attempted; the context argument is unused.
cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual _i t1 t2 = matchDefault Unsolved (pBin PNeq (/=) t1 t2)


-- | Try to solve @t1 >= t2@
--
-- The rules below are tried in order; the first one that matches decides
-- the result.  If no rule matches, the result is 'Unsolved'.
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq i t1 t2 =
  matchDefault Unsolved $
        (pBin PGeq (>=) t1 t2)
    <|> (aNat' t1 >>= tryGeqKThan i t2)
    <|> (aNat' t2 >>= tryGeqThanK i t1)
    <|> (aTVar t2 >>= tryGeqThanVar i t1)
    <|> tryGeqThanSub i t1 t2
    <|> (geqByInterval i t1 t2)
    <|> (guard (t1 == t2) >> return (SolvedIf []))
    <|> tryAddConst (>==) t1 t2
    <|> tryCancelVar i (>==) t1 t2
    <|> tryMinIsGeq t1 t2
    -- XXX: k >= width e
    -- XXX: width e >= k


-- XXX: max t 10 >= 2 --> True
-- XXX: max t 2 >= 10 --> a >= 10


-- | Try to solve something by evaluation.
--
-- If either operand matches 'anError' at kind 'KNum', the constraint is
-- 'Unsolvable' with that error.  If both operands are constants, the
-- predicate @p@ is evaluated on them: success gives @SolvedIf []@, failure
-- gives an 'Unsolvable' whose message prints the failed constraint, built
-- from the predicate constructor @tf@ and the two constants.
pBin :: PC -> (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
pBin tf p t1 t2 =
      Unsolvable <$> anError KNum t1
  <|> Unsolvable <$> anError KNum t2
  <|> (do x <- aNat' t1
          y <- aNat' t2
          return $ if p x y
                      then SolvedIf []
                      else Unsolvable $ TCErrorMessage
                        $ "It is not the case that " ++
                              show (pp (TCon (PC tf) [ tNat' x, tNat' y ])))


--------------------------------------------------------------------------------
-- GEQ

-- | Try to solve @K >= t@
--
-- @inf >= t@ always holds.  For a finite @K@, only the shape
-- @K1 >= K2 * t@ is handled here.
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqKThan _ _ Inf = return (SolvedIf [])
tryGeqKThan _ ty (Nat n) =
  -- K1 >= K2 * t
  do (a,b) <- aMul ty
     m     <- aNat' a
     return $ SolvedIf
            $ case m of
                -- K1 >= inf * t  ~~>  t = 0
                Inf   -> [ b =#= tZero ]
                -- K1 >= 0 * t holds unconditionally
                Nat 0 -> []
                -- K1 >= K2 * t  ~~>  (K1 `div` K2) >= t
                Nat k -> [ tNum (div n k) >== b ]

-- | Try to solve @t >= K@
--
-- @t >= inf@ reduces to @t = inf@.  For a finite @K@, only the shape
-- @K1 + t >= K2@ is handled here.
tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqThanK _ t Inf = return (SolvedIf [ t =#= tInf ])
tryGeqThanK _ t (Nat k) =
  -- K1 + t >= K2
  do (a,b) <- anAdd t
     n     <- aNat a
     return $ SolvedIf $ if n >= k
                            then []                        -- K1 alone suffices
                            else [ b >== tNum (k - n) ]    -- t covers the rest
  -- XXX: K1 ^^ n >= K2


-- | Try to solve @t1 >= t1 - t2@: succeeds with no sub-goals when the left
-- side is equal to the first operand of a subtraction on the right side.
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =
  -- t1 >= t1 - t2
  do (a,_) <- (|-|) y
     guard (x == a)
     return (SolvedIf [])

-- | Try to solve @t >= a@ for a type variable @a@: succeeds with no
-- sub-goals when @t@ is an addition and one of its two operands is exactly
-- the variable @a@.
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ctxt ty x =
  -- (t + a) >= a
  do (a,b) <- anAdd ty
     let check y = do x' <- aTVar y
                      guard (x == x')
                      return (SolvedIf [])
     check a <|> check b

-- | Try to prove GEQ by considering the known intervals for the given types.
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
  let ix = typeInterval (intervals ctxt) x
      iy = typeInterval (intervals ctxt) y
  in case (iLower ix, iUpper iy) of
       -- the lower bound of @x@ already dominates the upper bound of @y@
       (l,Just n) | l >= n -> return (SolvedIf [])
       _ -> mzero

-- min K1 t >= K2 ~~> t >= K2, if K1 >= K2; Err otherwise
tryMinIsGeq :: Type -> Type -> Match Solved
tryMinIsGeq t1 t2 =
  do (a,b) <- aMin t1
     k1    <- aNat a
     k2    <- aNat t2
     return $ if k1 >= k2
               then SolvedIf [ b >== t2 ]
               else Unsolvable $ TCErrorMessage $
                      show k1 ++ " can't be greater than " ++ show k2

--------------------------------------------------------------------------------

-- | Cancel finite positive variables from both sides.
-- @(fin a, a >= 1) =>  a * t1 == a * t2 ~~~> t1 == t2@
-- @(fin a, a >= 1) =>  a * t1 >= a * t2 ~~~> t1 >= t2@
--
-- Both sides are split into their multiplicative factors, and the factors
-- are sorted so that cancellable variables come first.  The two sorted
-- lists are then walked in step, looking for a cancellable variable that
-- occurs on both sides.  Only the first such variable is cancelled; the
-- relation @p@ is then rebuilt from the remaining factors.
tryCancelVar :: Ctxt -> (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryCancelVar ctxt p t1 t2 =
  let lhs = preproc t1
      rhs = preproc t2
  in case check [] [] lhs rhs of
       Nothing -> Fail.fail "tryCancelVar"
       Just x  -> return x

  where
  -- Walk both sorted factor lists.  The binds on @mbA@ and @mbB@ are in
  -- the 'Maybe' monad, so the walk stops with 'Nothing' as soon as either
  -- side reaches a factor that is not a cancellable variable.
  check doneLHS doneRHS lhs@((a,mbA) : moreLHS) rhs@((b, mbB) : moreRHS) =
    do x <- mbA
       y <- mbB
       case compare x y of
         LT -> check (a : doneLHS) doneRHS moreLHS rhs
         EQ -> return $ SolvedIf [ p (term (doneLHS ++ map fst moreLHS))
                                     (term (doneRHS ++ map fst moreRHS)) ]
         GT -> check doneLHS (b : doneRHS) lhs moreRHS
  check _ _ _ _ = Nothing

  -- Rebuild a product from its factors; the empty product is 1.
  term xs = case xs of
              [] -> tNum (1::Int)
              _  -> foldr1 tMul xs

  -- Pair each factor with @Just v@ if it is a cancellable variable @v@,
  -- then sort with 'cmpFact'.
  preproc t = let fs = splitMul t []
              in sortBy cmpFact (zip fs (map cancelVar fs))

  -- Flatten nested multiplications into a list of factors.
  splitMul t rest = case matchMaybe (aMul t) of
                      Just (a,b) -> splitMul a (splitMul b rest)
                      Nothing    -> t : rest

  -- A factor is cancellable if it is a type variable whose interval
  -- satisfies 'iIsPosFin'.
  cancelVar t = matchMaybe $ do x <- aTVar t
                                guard (iIsPosFin (tvarInterval (intervals ctxt) x))
                                return x

  -- cancellable variables go first, sorted alphabetically
  cmpFact (_,mbA) (_,mbB) =
    case (mbA,mbB) of
      (Just x, Just y)  -> compare x y
      (Just _, Nothing) -> LT
      (Nothing, Just _) -> GT
      _                 -> EQ



-- min t1 t2 = t1 ~> t1 <= t2
tryEqMin :: Type -> Type -> Match Solved
tryEqMin x y =
  do (a,b) <- aMin x
     -- @m1@ is the operand of the @min@ that equals the other side @y@;
     -- the equation then holds iff the remaining operand @m2@ is >= @m1@.
     let check m1 m2 = do guard (m1 == y)
                          return $ SolvedIf [ m2 >== m1 ]
     check a b <|> check b a


-- t1 == min (K + t1) t2 ~~> t1 == t2, if K >= 1
-- (also if (K + t1) is one term in a multi-way min)
--
-- NOTE: when @y@ is a @min@ but none of its terms can be dropped, this
-- still matches and yields 'Unsolved'.
tryEqMins :: Type -> Type -> Match Solved
tryEqMins x y =
  do (a, b) <- aMin y
     let ys = splitMin a ++ splitMin b
     let ys' = filter (not . isGt) ys
     -- if every term was dropped, the remaining (empty) minimum is @inf@
     let y' = if null ys' then tInf else foldr1 Simp.tMin ys'
     return $ if length ys' < length ys
              then SolvedIf [x =#= y']
              else Unsolved
  where
  -- Flatten nested @min@s into the list of their terms.
  splitMin :: Type -> [Type]
  splitMin ty =
    case matchMaybe (aMin ty) of
      Just (t1, t2) -> splitMin t1 ++ splitMin t2
      Nothing       -> [ty]

  -- Is this term of the shape @K + x@ with @K > 0@?
  isGt :: Type -> Bool
  isGt t =
    case matchMaybe (asAddK t) of
      Just (k, t') -> k > 0 && t' == x
      Nothing      -> False

  -- Match @K + t@, with the constant as the first operand.
  asAddK :: Type -> Match (Integer, Type)
  asAddK t =
    do (t1, t2) <- anAdd t
       k <- aNat t1
       return (k, t2)


-- | Try to solve @a = t@ where @a@ is a type variable, using the three
-- rewrite rules below, tried in order.
tryEqVar :: Type -> TVar -> Match Solved
tryEqVar ty x =

  -- a = K + a --> x = inf
  (do (k,tv) <- matches ty (anAdd, aNat, aTVar)
      guard (tv == x && k >= 1)

      return $ SolvedIf [ TVar x =#= tInf ]
  )
  <|>

  -- a = min (K + a) t --> a = t
  (do (l,r) <- aMin ty
      let check this other =
            do (k,x') <- matches this (anAdd, aNat', aTVar)
               guard (x == x' && k >= Nat 1)
               return $ SolvedIf [ TVar x =#= other ]
      check l r <|> check r l
  )
  <|>
  -- a = K + min t a
  (do (k,(l,r)) <- matches ty (anAdd, aNat, aMin)
      guard (k >= 1)
      let check a b = do x' <- aTVar a
                         guard (x' == x)
                         return (SolvedIf [ TVar x =#= tAdd (tNum k) b ])
      check l r <|> check r l
  )







-- e.g., 10 = t
--
-- Try to solve @K = t@, where the constant @K@ (possibly @inf@) is the
-- last argument @lk@ and @t@ is @ty@.  The shapes of @t@ handled are
-- addition, subtraction, multiplication and exponentiation.
tryEqK :: Ctxt -> Type -> Nat' -> Match Solved
tryEqK ctxt ty lk =
  -- (t1 + t2 = inf, fin t1) ~~~> t2 = inf
  do guard (lk == Inf)
     (a,b) <- anAdd ty
     let check x y = do guard (iIsFin (typeInterval (intervals ctxt) x))
                        return $ SolvedIf [ y =#= tInf ]
     check a b <|> check b a
  <|>
  -- (K1 + t = K2, K2 >= K1) ~~~> t = (K2 - K1)
  do (rk, b) <- matches ty (anAdd, aNat', __)
     return $
       case nSub lk rk of
         -- NOTE: (Inf - Inf) shouldn't be possible
         Nothing -> Unsolvable
                      $ TCErrorMessage
                      $ "Adding " ++ showNat' rk ++ " will always exceed "
                                  ++ showNat' lk

         Just r -> SolvedIf [ b =#= tNat' r ]
  <|>
  -- (lk = t - rk) ~~> t = lk + rk
  do (t,rk) <- matches ty ((|-|) , __, aNat')
     return (SolvedIf [ t =#= tNat' (nAdd lk rk) ])

  <|>
  do (rk, b) <- matches ty (aMul, aNat', __)
     return $
       case (lk,rk) of
         -- Inf * t = Inf ~~~>  t >= 1
         (Inf,Inf)    -> SolvedIf [ b >== tOne ]

         -- K * t = Inf ~~~> t = Inf
         (Inf,Nat _)  -> SolvedIf [ b =#= tInf ]

         -- Inf * t = 0 ~~~> t = 0
         (Nat 0, Inf) -> SolvedIf [ b =#= tZero ]

         -- Inf * t = K ~~~> ERR      (K /= 0)
         (Nat k, Inf) -> Unsolvable
                           $ TCErrorMessage
                           $ show k ++ " != inf * anything"

         (Nat lk', Nat rk')
           -- 0 * t = K2 ~~> K2 = 0
           -- (shouldn't happen, as `0 * t` should already have been
           -- simplified away)
           | rk' == 0 -> SolvedIf [ tNat' lk =#= tZero ]

           -- K1 * t = K2 ~~> t = K2/K1
           | (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
           | otherwise ->
               Unsolvable
                 $ TCErrorMessage
                 $ showNat' lk ++ " != " ++ showNat' rk ++ " * anything"

  <|>
  -- K1 == K2 ^^ t    ~~> t = logBase K2 K1
  do (rk, b) <- matches ty ((|^|), aNat, __)
     return $ case lk of
                Inf | rk > 1 -> SolvedIf [ b =#= tInf ]
                Nat n | Just (a,True) <- genLog n rk -> SolvedIf [ b =#= tNum a]
                _ -> Unsolvable $ TCErrorMessage
                       $ show rk ++ " ^^ anything != " ++ showNat' lk

  -- XXX: Min, Max, etc.
  -- 2  = min (10,y)  --> y = 2
  -- 2  = min (2,y)   --> y >= 2
  -- 10 = min (2,y)   --> impossible


-- | K1 * t1 + K2 * t2 + ... = K3 * t3 + K4 * t4 + ...
tryEqMulConst :: Type -> Type -> Match Solved
tryEqMulConst l r =
  do (lc,ls) <- matchLinear l
     (rc,rs) <- matchLinear r
     -- greatest common divisor of the constant terms and all coefficients
     let d = foldr1 gcd (lc : rc : map fst (ls ++ rs))
     -- dividing by 1 would just reproduce the same equation
     guard (d > 1)
     return (SolvedIf [build d lc ls =#= build d rc rs])
  where
  -- Rebuild a linear expression with every constant divided by @d@.
  build d k ts   = foldr tAdd (cancel d k) (map (buildS d) ts)
  buildS d (k,t) = tMul (cancel d k) t
  cancel d x     = tNum (div x d)


-- | @(t1 + t2 = Inf, fin t1)  ~~> t2 = Inf@
--
-- NOTE: when one side is an addition and the other is @inf@, but neither
-- summand is known to be finite, this still matches and yields 'Unsolved'.
tryEqAddInf :: Ctxt -> Type -> Type -> Match Solved
tryEqAddInf ctxt l r = check l r <|> check r l
  where

  -- check for x = a + b /\ x = inf
  check x y =
    do (x1,x2) <- anAdd x
       aInf y

       let x1Fin = iIsFin (typeInterval (intervals ctxt) x1)
       let x2Fin = iIsFin (typeInterval (intervals ctxt) x2)

       return $!
         if | x1Fin     -> SolvedIf [ x2 =#= y ]
            | x2Fin     -> SolvedIf [ x1 =#= y ]
            | otherwise -> Unsolved



-- | Check for addition of constants to both sides of a relation.
--  @((K1 + K2) + t1) `R` (K1 + t2)  ~~>   (K2 + t1) `R` t2@
--
-- This relies on the fact that constants are floated left during
-- simplification.
tryAddConst :: (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryAddConst rel l r =
  do (x1,x2) <- anAdd l
     (y1,y2) <- anAdd r

     k1 <- aNat x1
     k2 <- aNat y1

     -- subtract the smaller constant from both sides, keeping the
     -- difference on the side that had the larger one
     if k1 > k2
        then return (SolvedIf [ tAdd (tNum (k1 - k2)) x2 `rel` y2 ])
        else return (SolvedIf [ x2 `rel` tAdd (tNum (k2 - k1)) y2 ])


-- | Check for situations where a unification variable is involved in
--   a sum of terms not containing additional unification variables,
--   and replace it with a solution and an inequality.
--   @s1 = ?a + s2 ~~> (?a = s1 - s2, s1 >= s2)@
tryLinearSolution :: Type -> Type -> Match Solved
tryLinearSolution s1 t =
  do (a,xs) <- matchLinearUnifier t
     guard (noFreeVariables s1)

     -- NB: matchLinearUnifier only matches if xs is nonempty
     let s2 = foldr1 Simp.tAdd xs
     return (SolvedIf [ TVar a =#= (Simp.tSub s1 s2), s1 >== s2 ])


-- | Match a sum of the form @(s1 + ... + ?a + ... sn)@ where
--   @s1@ through @sn@ do not contain any free variables.
--
--   Note: a successful match should only occur if @s1 ... sn@ is
--   not empty.
matchLinearUnifier :: Pat Type (TVar,[Type])
matchLinearUnifier = go []
  where
  -- @xs@ accumulates, most recent first, the summands already passed over
  go xs t =
    -- Case where a free variable occurs at the end of a sequence of additions.
    -- NB: match fails if @xs@ is empty
    do v <- aFreeTVar t
       guard (not . null $ xs)
       return (v, xs)

    <|>
    -- Next symbol is an addition
    do (x, y) <- anAdd t

        -- Case where a free variable occurs in the middle of an expression
       (do v <- aFreeTVar x
           guard (noFreeVariables y)
           return (v, reverse (y:xs))

        <|>
         -- Non-free-variable recursive case
        do guard (noFreeVariables x)
           go (x:xs) y)


-- | Is this a sum of products, where the products have constant coefficients?
--
-- On a match, the result is the sum of the constant summands, paired with
-- the list of @(coefficient, term)@ products that were found.
matchLinear :: Pat Type (Integer, [(Integer,Type)])
matchLinear = go (0, [])
  where
  -- @(c,ts)@ is the accumulator: constant so far and products so far
  go (c,ts) t =
    -- a constant summand
    do n <- aNat t
       return (n + c, ts)
    <|>
    -- a product with the constant coefficient as the first operand
    do (x,y) <- aMul t
       n     <- aNat x
       return (c, (n,y) : ts)
    <|>
    -- an addition: both operands must themselves be linear
    do (l,r) <- anAdd t
       (c',ts') <- go (c,ts) l
       go (c',ts') r


-- | Render a 'Nat'' for error messages: @inf@, or the number shown with
-- 'show'.
showNat' :: Nat' -> String
showNat' Inf = "inf"
showNat' (Nat n) = show n