{-# LANGUAGE Safe, PatternGuards, MultiWayIf #-}
module Cryptol.TypeCheck.Solver.Numeric
  ( cryIsEqual, cryIsNotEqual, cryIsGeq
  ) where

import           Control.Applicative(Alternative(..))
import           Control.Monad (guard,mzero)
import           Data.List (sortBy)

import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Type hiding (tMul)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType as Simp

{- Convention for comments:

  K1, K2 ...          Concrete constants
  s1, s2, t1, t2 ...  Arbitrary type expressions
  a, b, c ...         Type variables

-}


-- | Try to solve @t1 = t2@.
--
-- The individual rewrite rules are combined left-to-right with '<|>', in
-- the order listed; if none of them matches, the constraint is left
-- 'Unsolved'.
cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual ctxt lhs rhs = matchDefault Unsolved equalityRules
  where
    equalityRules =
          pBin PEqual (==) lhs rhs
      <|> (tryEqK ctxt rhs =<< aNat' lhs)
      <|> (tryEqK ctxt lhs =<< aNat' rhs)
      <|> (tryEqVar rhs =<< aTVar lhs)
      <|> (tryEqVar lhs =<< aTVar rhs)
      <|> (SolvedIf [] <$ guard (lhs == rhs))   -- syntactically equal sides
      <|> tryEqMin lhs rhs
      <|> tryEqMin rhs lhs
      <|> tryEqMins lhs rhs
      <|> tryEqMins rhs lhs
      <|> tryEqMulConst lhs rhs
      <|> tryEqAddInf ctxt lhs rhs
      <|> tryAddConst (=#=) lhs rhs
      <|> tryCancelVar ctxt (=#=) lhs rhs
      <|> tryLinearSolution lhs rhs
      <|> tryLinearSolution rhs lhs

-- | Try to solve @t1 /= t2@.
--
-- Disequalities are only decided by evaluation (see 'pBin').
cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual _ lhs rhs = matchDefault Unsolved (pBin PNeq (/=) lhs rhs)

-- | Try to solve @t1 >= t2@.
--
-- As for 'cryIsEqual', the rules are combined left-to-right with '<|>' and
-- the constraint is left 'Unsolved' when none of them matches.
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq ctxt lhs rhs = matchDefault Unsolved geqRules
  where
    geqRules =
          pBin PGeq (>=) lhs rhs
      <|> (tryGeqKThan ctxt rhs =<< aNat' lhs)
      <|> (tryGeqThanK ctxt lhs =<< aNat' rhs)
      <|> (tryGeqThanVar ctxt lhs =<< aTVar rhs)
      <|> tryGeqThanSub ctxt lhs rhs
      <|> geqByInterval ctxt lhs rhs
      <|> (SolvedIf [] <$ guard (lhs == rhs))   -- syntactically equal sides
      <|> tryAddConst (>==) lhs rhs
      <|> tryCancelVar ctxt (>==) lhs rhs
      <|> tryMinIsGeq lhs rhs

-- XXX: k >= width e
-- XXX: width e >= k
-- XXX: max t 10 >= 2 --> True
-- XXX: max t 2 >= 10 --> t >= 10

-- | Try to solve something by evaluation.
--
-- An error type on either side makes the constraint unsolvable.  Otherwise,
-- when both sides are numeric constants, the predicate @p@ decides the
-- constraint; the constructor @tf@ is only used to pretty-print a failing
-- constraint in the error message.
pBin :: PC -> (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
pBin tf p t1 t2 =
      isError t1
  <|> isError t2
  <|> evaluate
  where
    isError t = Unsolvable <$> anError KNum t

    evaluate =
      do x <- aNat' t1
         y <- aNat' t2
         pure (verdict x y)

    verdict x y
      | p x y     = SolvedIf []
      | otherwise =
          Unsolvable $ TCErrorMessage $
            "Unsolvable constraint: " ++ show (pp (TCon (PC tf) [ tNat' x, tNat' y ]))


--------------------------------------------------------------------------------
-- GEQ

-- | Try to solve @K >= t@
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqKThan _ ty bound =
  case bound of
    -- inf is at least as large as anything.
    Inf -> pure (SolvedIf [])

    -- K1 >= K2 * t
    Nat n ->
      do (factor, rest) <- aMul ty
         coeff <- aNat' factor
         pure $ SolvedIf $
           case coeff of
             Inf   -> [ rest =#= tZero ]                 -- K1 >= inf * t  needs  t = 0
             Nat 0 -> []                                 -- K1 >= 0 * t  always holds
             Nat k -> [ tNum (n `div` k) >== rest ]      -- t <= floor (K1 / K2)

-- | Try to solve @t >= K@
tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqThanK _ t bound =
  case bound of
    -- Only inf is at least inf.
    Inf -> pure (SolvedIf [ t =#= tInf ])

    -- K1 + t >= K2: trivially true when K1 >= K2, otherwise t >= K2 - K1.
    Nat k ->
      do (constTy, rest) <- anAdd t
         n <- aNat constTy
         pure (SolvedIf [ rest >== tNum (k - n) | n < k ])

-- XXX: K1 ^^ n >= K2

-- | Solve @t1 >= t1 - t2@, which always holds.
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =
  do (minuend, _) <- (|-|) y
     SolvedIf [] <$ guard (x == minuend)

-- | Solve @(t + a) >= a@ (or @(a + t) >= a@), which always holds.
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ ty v =
  do (l, r) <- anAdd ty
     isVar l <|> isVar r
  where
    isVar t =
      do v' <- aTVar t
         SolvedIf [] <$ guard (v == v')

-- | Try to prove GEQ by considering the known intervals for the given types.
--
-- Succeeds when the lower bound of the interval for the left-hand side is
-- at least the upper bound of the interval for the right-hand side; the
-- match fails (so other rules may be tried) otherwise.
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
  let ix = typeInterval ctxt x
      iy = typeInterval ctxt y
  in case (iLower ix, iUpper iy) of
       (l,Just n) | l >= n -> return (SolvedIf [])
       _                   -> mzero

-- | @min K1 t >= K2 ~~> t >= K2@ if @K1 >= K2@; unsolvable otherwise
-- (then @min K1 t <= K1 < K2@).
tryMinIsGeq :: Type -> Type -> Match Solved
tryMinIsGeq t1 t2 =
  do (a,b) <- aMin t1
     k1    <- aNat a
     k2    <- aNat t2
     return $ if k1 >= k2
                then SolvedIf [ b >== t2 ]
                else Unsolvable $ TCErrorMessage $
                       show k1 ++ " can't be greater than " ++ show k2

--------------------------------------------------------------------------------

-- | Cancel finite positive variables from both sides.
-- @(fin a, a >= 1) => a * t1 == a * t2 ~~~> t1 == t2@
-- @(fin a, a >= 1) => a * t1 >= a * t2 ~~~> t1 >= t2@
--
-- Both sides are split into their multiplicative factors.  Factors that are
-- type variables whose interval (from the context) is finite and positive
-- are sorted to the front, and the first such variable found on both sides
-- is cancelled.  The relation @p@ rebuilds the reduced constraint.
tryCancelVar :: Ctxt -> (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryCancelVar ctxt p t1 t2 =
  let lhs = preproc t1
      rhs = preproc t2
  in case check [] [] lhs rhs of
       -- Use 'mzero' rather than @fail ""@: it is how every other rule in
       -- this module signals "does not apply", and it needs no 'MonadFail'
       -- instance for 'Match'.
       Nothing -> mzero
       Just x  -> return x
  where
    -- Merge-style walk over both sorted factor lists, looking for a
    -- cancellable variable that occurs on both sides.  Factors skipped on
    -- the way are accumulated in @doneLHS@ / @doneRHS@.
    check doneLHS doneRHS lhs@((a,mbA) : moreLHS) rhs@((b, mbB) : moreRHS) =
      do x <- mbA
         y <- mbB
         case compare x y of
           LT -> check (a : doneLHS) doneRHS moreLHS rhs
           EQ -> return $ SolvedIf
                   [ p (term (doneLHS ++ map fst moreLHS))
                       (term (doneRHS ++ map fst moreRHS)) ]
           GT -> check doneLHS (b : doneRHS) lhs moreRHS
    check _ _ _ _ = Nothing

    -- Rebuild a product from its factors; the empty product is 1.
    term xs =
      case xs of
        [] -> tNum (1::Int)
        _  -> foldr1 tMul xs

    -- Split a type into factors, each paired with its cancellable variable.
    preproc t =
      let fs = splitMul t []
      in sortBy cmpFact (zip fs (map cancelVar fs))

    splitMul t rest =
      case matchMaybe (aMul t) of
        Just (a,b) -> splitMul a (splitMul b rest)
        Nothing    -> t : rest

    -- A factor is cancellable when it is a type variable whose interval
    -- is finite and positive.
    cancelVar t = matchMaybe $
      do x <- aTVar t
         guard (iIsPosFin (tvarInterval ctxt x))
         return x

    -- Cancellable variables go first, ordered by their 'Ord' instance.
    cmpFact (_,mbA) (_,mbB) =
      case (mbA,mbB) of
        (Just x, Just y)  -> compare x y
        (Just _, Nothing) -> LT
        (Nothing, Just _) -> GT
        _                 -> EQ


-- min t1 t2 = t1 ~> t1 <= t2
tryEqMin :: Type -> Type -> Match Solved
tryEqMin x y =
  do (a,b) <- aMin x
     let check m1 m2 =
           do guard (m1 == y)
              return $ SolvedIf [ m2 >== m1 ]
     check a b <|> check b a


-- t1 == min (K + t1) t2 ~~> t1 == t2, if K >= 1
-- (also if (K + t1) is one term in a multi-way min)
--
-- NOTE(review): when @y@ is a @min@ but no operand can be dropped, this
-- returns 'Unsolved' as a successful match instead of failing with
-- 'mzero'; with a left-biased '<|>' that stops the later rules of
-- 'cryIsEqual' from being tried.  Confirm this short-circuit is intended.
tryEqMins :: Type -> Type -> Match Solved
tryEqMins x y =
  do (a, b) <- aMin y
     let ys  = splitMin a ++ splitMin b
         ys' = filter (not . isGt) ys
         y'  = if null ys' then tInf else foldr1 Simp.tMin ys'
     return $ if length ys' < length ys
                then SolvedIf [x =#= y']
                else Unsolved
  where
    -- Flatten nested applications of @min@ into a list of operands.
    splitMin :: Type -> [Type]
    splitMin ty =
      case matchMaybe (aMin ty) of
        Just (t1, t2) -> splitMin t1 ++ splitMin t2
        Nothing       -> [ty]

    -- Is this operand of the form @K + x@ with @K > 0@ (so it exceeds @x@)?
    isGt :: Type -> Bool
    isGt t =
      case matchMaybe (asAddK t) of
        Just (k, t') -> k > 0 && t' == x
        Nothing      -> False

    asAddK :: Type -> Match (Integer, Type)
    asAddK t =
      do (t1, t2) <- anAdd t
         k <- aNat t1
         return (k, t2)


tryEqVar :: Type -> TVar -> Match Solved
tryEqVar ty x =

  -- a = K + a --> a = inf
  (do (k,tv) <- matches ty (anAdd, aNat, aTVar)
      guard (tv == x && k >= 1)
      return $ SolvedIf [ TVar x =#= tInf ]
  )
  <|>

  -- a = min (K + a) t --> a = t
  (do (l,r) <- aMin ty
      let check this other =
            do (k,x') <- matches this (anAdd, aNat', aTVar)
               guard (x == x' && k >= Nat 1)
               return $ SolvedIf [ TVar x =#= other ]
      check l r <|> check r l
  )
  <|>

  -- a = K + min t a --> a = K + t
  (do (k,(l,r)) <- matches ty (anAdd, aNat, aMin)
      guard (k >= 1)
      let check a b =
            do x' <- aTVar a
               guard (x' == x)
               return (SolvedIf [ TVar x =#= tAdd (tNum k) b ])
      check l r <|> check r l
  )


-- e.g., 10 = t
tryEqK :: Ctxt -> Type -> Nat' -> Match Solved
tryEqK ctxt ty lk =

  -- (t1 + t2 = inf, fin t1) ~~~> t2 = inf
  do guard (lk == Inf)
     (a,b) <- anAdd ty
     let check x y =
           do guard (iIsFin (typeInterval ctxt x))
              return $ SolvedIf [ y =#= tInf ]
     check a b <|> check b a
  <|>

  -- (K1 + t = K2, K2 >= K1) ~~~> t = (K2 - K1)
  do (rk, b) <- matches ty (anAdd, aNat', __)
     return $
       case nSub lk rk of
         -- NOTE: (Inf - Inf) shouldn't be possible
         Nothing -> Unsolvable $ TCErrorMessage $
                      "Adding " ++ showNat' rk ++ " will always exceed "
                        ++ showNat' lk
         Just r  -> SolvedIf [ b =#= tNat' r ]
  <|>

  -- (lk = t - rk) ~~> t = lk + rk
  do (t,rk) <- matches ty ((|-|) , __, aNat')
     return (SolvedIf [ t =#= tNat' (nAdd lk rk) ])
  <|>

  do (rk, b) <- matches ty (aMul, aNat', __)
     return $
       case (lk,rk) of
         -- Inf * t = Inf ~~~>  t >= 1
         (Inf,Inf) -> SolvedIf [ b >== tOne ]

         -- K * t = Inf ~~~> t = Inf
         (Inf,Nat _) -> SolvedIf [ b =#= tInf ]

         -- Inf * t = 0 ~~~> t = 0
         (Nat 0, Inf) -> SolvedIf [ b =#= tZero ]

         -- Inf * t = K ~~~> ERR      (K /= 0)
         (Nat k, Inf) -> Unsolvable $ TCErrorMessage $
                           show k ++ " != inf * anything"

         (Nat lk', Nat rk')
           -- 0 * t = K2 ~~> K2 = 0
           -- (shouldn't happen, as @0 * t@ should have been simplified to @0@)
           | rk' == 0 -> SolvedIf [ tNat' lk =#= tZero ]

           -- K1 * t = K2 ~~> t = K2/K1
           | (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]

           | otherwise -> Unsolvable $ TCErrorMessage $
                            showNat' lk ++ " != " ++ showNat' rk ++ " * anything"
  <|>

  -- K1 == K2 ^^ t    ~~> t = logBase K2 K1
  do (rk, b) <- matches ty ((|^|), aNat, __)
     return $
       case lk of
         Inf | rk > 1 -> SolvedIf [ b =#= tInf ]

         -- Bases 0 and 1 are special: the 'genLog' case below can only fix
         -- @t@ to a single value, which is wrong for them.  (The @Inf@ case
         -- above already treats @0 ^^ t@ and @1 ^^ t@ as never infinite.)

         -- 1 == 1 ^^ t holds for every exponent: nothing left to solve.
         Nat 1 | rk == 1 -> SolvedIf []

         -- 0 == 0 ^^ t holds exactly when t >= 1 (since 0 ^^ 0 == 1).
         Nat 0 | rk == 0 -> SolvedIf [ b >== tOne ]

         Nat n | Just (a,True) <- genLog n rk -> SolvedIf [ b =#= tNum a ]

         _ -> Unsolvable $ TCErrorMessage $
                show rk ++ " ^^ anything != " ++ showNat' lk

  -- XXX: Min, Max, etx
  -- 2  = min (10,y)  --> y = 2
  -- 2  = min (2,y)   --> y >= 2
  -- 10 = min (2,y)   --> impossible


-- | K1 * t1 + K2 * t2 + ... = K3 * t3 + K4 * t4 + ...
--
-- When all constants and coefficients on both sides share a common factor
-- greater than 1, divide both sides by it.
tryEqMulConst :: Type -> Type -> Match Solved
tryEqMulConst l r =
  do (lc,ls) <- matchLinear l
     (rc,rs) <- matchLinear r
     let d = foldr1 gcd (lc : rc : map fst (ls ++ rs))
     guard (d > 1)
     return (SolvedIf [build d lc ls =#= build d rc rs])
  where
    build d k ts   = foldr tAdd (cancel d k) (map (buildS d) ts)
    buildS d (k,t) = tMul (cancel d k) t
    cancel d x     = tNum (div x d)


-- | @(t1 + t2 = Inf, fin t1) ~~> t2 = Inf@
--
-- When neither summand is known to be finite, the result is 'Unsolved'.
tryEqAddInf :: Ctxt -> Type -> Type -> Match Solved
tryEqAddInf ctxt l r = check l r <|> check r l
  where
    -- check for @x = inf@ where @x = x1 + x2@
    check x y =
      do (x1,x2) <- anAdd x
         aInf y

         let x1Fin = iIsFin (typeInterval ctxt x1)
             x2Fin = iIsFin (typeInterval ctxt x2)

         return $!
           if | x1Fin     -> SolvedIf [ x2 =#= y ]
              | x2Fin     -> SolvedIf [ x1 =#= y ]
              | otherwise -> Unsolved


-- | Check for addition of constants to both sides of a relation.
--  @((K1 + K2) + t1) `R` (K1 + t2)  ~~>   (K2 + t1) `R` t2@
--
-- This relies on the fact that constants are floated left during
-- simplification.
tryAddConst :: (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryAddConst rel l r =
  do (x1,x2) <- anAdd l
     (y1,y2) <- anAdd r

     k1 <- aNat x1
     k2 <- aNat y1

     if k1 > k2
        then return (SolvedIf [ tAdd (tNum (k1 - k2)) x2 `rel` y2 ])
        else return (SolvedIf [ x2 `rel` tAdd (tNum (k2 - k1)) y2 ])


-- | Check for situations where a unification variable is involved in
--   a sum of terms not containing additional unification variables,
--   and replace it with a solution and an inequality.
--   @s1 = ?a + s2 ~~> (?a = s1 - s2, s1 >= s2)@
tryLinearSolution :: Type -> Type -> Match Solved
tryLinearSolution s1 t =
  do (a,xs) <- matchLinearUnifier t
     guard (noFreeVariables s1)

     -- NB: matchLinearUnifier only matches if xs is nonempty
     let s2 = foldr1 Simp.tAdd xs
     return (SolvedIf [ TVar a =#= (Simp.tSub s1 s2), s1 >== s2 ])


-- | Match a sum of the form @(s1 + ... + ?a + ... sn)@ where
--   @s1@ through @sn@ do not contain any free variables.
--
--   Note: a successful match should only occur if @s1 ... sn@ is
--   not empty.
matchLinearUnifier :: Pat Type (TVar,[Type])
matchLinearUnifier = go []
  where
    go xs t =
      -- Case where a free variable occurs at the end of a sequence of
      -- additions.  NB: match fails if @xs@ is empty.
      do v <- aFreeTVar t
         guard (not (null xs))
         return (v, xs)
      <|>
      -- Next symbol is an addition
      do (x, y) <- anAdd t
         -- Case where a free variable occurs in the middle of an expression
         (do v <- aFreeTVar x
             guard (noFreeVariables y)
             return (v, reverse (y:xs)))
           <|>
           -- Non-free-variable recursive case
           (do guard (noFreeVariables x)
               go (x:xs) y)


-- | Is this a sum of products, where the products have constant coefficients?
--
-- On success, the result pairs the sum of all constant summands with the
-- @(coefficient, term)@ pairs of every @K * t@ summand.  Any other kind of
-- summand makes the match fail.
matchLinear :: Pat Type (Integer, [(Integer,Type)])
matchLinear = walk (0, [])
  where
    walk acc@(k, terms) ty = constant <|> scaled <|> sumOf
      where
        -- a bare constant K
        constant =
          do n <- aNat ty
             pure (n + k, terms)

        -- a scaled term K * t
        scaled =
          do (coeffTy, body) <- aMul ty
             n <- aNat coeffTy
             pure (k, (n, body) : terms)

        -- a sum l + r: accumulate the left operand first, then the right
        sumOf =
          do (l, r) <- anAdd ty
             acc' <- walk acc l
             walk acc' r

-- | Render a @Nat'@ value for error messages, writing infinity as @inf@.
showNat' :: Nat' -> String
showNat' n =
  case n of
    Inf   -> "inf"
    Nat m -> show m