{-# LANGUAGE Safe, PatternGuards, MultiWayIf #-}
module Cryptol.TypeCheck.Solver.Numeric
( cryIsEqual, cryIsNotEqual, cryIsGeq
) where
import Control.Applicative(Alternative(..))
import Control.Monad (guard,mzero)
import qualified Control.Monad.Fail as Fail
import Data.List (sortBy)
import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Type hiding (tMul)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType as Simp
-- | Try to solve @t1 = t2@.
--
-- The strategies are tried in the order listed; the first one that matches
-- decides the outcome.  If none of them matches, the constraint is left
-- 'Unsolved'.
cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual ctxt t1 t2 = matchDefault Unsolved (foldr (<|>) empty strategies)
  where
  strategies =
    [ pBin PEqual (==) t1 t2
    , aNat' t1 >>= tryEqK ctxt t2
    , aNat' t2 >>= tryEqK ctxt t1
    , aTVar t1 >>= tryEqVar t2
    , aTVar t2 >>= tryEqVar t1
    , guard (t1 == t2) >> return (SolvedIf [])
    , tryEqMin t1 t2
    , tryEqMin t2 t1
    , tryEqMins t1 t2
    , tryEqMins t2 t1
    , tryEqMulConst t1 t2
    , tryEqAddInf ctxt t1 t2
    , tryAddConst (=#=) t1 t2
    , tryCancelVar ctxt (=#=) t1 t2
    , tryLinearSolution t1 t2
    , tryLinearSolution t2 t1
    ]
-- | Try to solve @t1 /= t2@.
--
-- Only decided by evaluation, i.e. when both sides are known constants
-- (see 'pBin'); otherwise the constraint stays 'Unsolved'.
cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual _ t1 t2 = matchDefault Unsolved $ pBin PNeq (/=) t1 t2
-- | Try to solve @t1 >= t2@.
--
-- The strategies are tried in the order listed; the first one that matches
-- decides the outcome.  If none of them matches, the constraint is left
-- 'Unsolved'.
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq ctxt t1 t2 = matchDefault Unsolved (foldr (<|>) empty strategies)
  where
  strategies =
    [ pBin PGeq (>=) t1 t2
    , aNat' t1 >>= tryGeqKThan ctxt t2
    , aNat' t2 >>= tryGeqThanK ctxt t1
    , aTVar t2 >>= tryGeqThanVar ctxt t1
    , tryGeqThanSub ctxt t1 t2
    , geqByInterval ctxt t1 t2
    , guard (t1 == t2) >> return (SolvedIf [])
    , tryAddConst (>==) t1 t2
    , tryCancelVar ctxt (>==) t1 t2
    , tryMinIsGeq t1 t2
    ]
-- | Decide a binary numeric predicate by evaluation.
--
-- An error type (as matched by 'anError') on either side makes the
-- constraint unsolvable.  When both sides are constants, the predicate @p@
-- is evaluated directly; @tf@ is only used to pretty-print the failing
-- constraint.
pBin :: PC -> (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
pBin tf p t1 t2 = errorIn t1 <|> errorIn t2 <|> evaluate
  where
  errorIn t = Unsolvable <$> anError KNum t

  evaluate =
    do x <- aNat' t1
       y <- aNat' t2
       let failure = Unsolvable $ TCErrorMessage $
                       "It is not the case that " ++
                       show (pp (TCon (PC tf) [ tNat' x, tNat' y ]))
       return (if p x y then SolvedIf [] else failure)
-- | Try to solve @K >= t@.
--
-- @inf >= t@ always holds.  For a finite @K@ only products with a constant
-- left factor are handled:
--
-- @K >= inf * t ~~~> t = 0@, @K >= 0 * t ~~~> True@,
-- @K >= K1 * t ~~~> K div K1 >= t@.
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqKThan _ ty bound =
  case bound of
    Inf   -> return (SolvedIf [])
    Nat n ->
      do (a, b) <- aMul ty
         m      <- aNat' a
         return (SolvedIf (boundFactor n m b))
  where
  boundFactor _ Inf     b = [ b =#= tZero ]
  boundFactor _ (Nat 0) _ = []
  boundFactor n (Nat k) b = [ tNum (div n k) >== b ]
-- | Try to solve @t >= K@.
--
-- @t >= inf ~~~> t = inf@.  For a finite @K@, sums with a constant left
-- operand are handled: @K1 + t >= K@ holds outright if @K1 >= K@, and
-- otherwise reduces to @t >= K - K1@.
tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqThanK _ t bound =
  case bound of
    Inf   -> return (SolvedIf [ t =#= tInf ])
    Nat k ->
      do (a, b) <- anAdd t
         n      <- aNat a
         let remaining | n >= k    = []
                       | otherwise = [ b >== tNum (k - n) ]
         return (SolvedIf remaining)
-- | Solve @x >= x - t@, which always holds.
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =
  do (minuend, _) <- (|-|) y
     if x == minuend then return (SolvedIf []) else mzero
-- | Solve @t + x >= x@ or @x + t >= x@, which always hold.
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ctxt ty x =
  do (a, b) <- anAdd ty
     let isX t = do v <- aTVar t
                    guard (v == x)
     (isX a <|> isX b) >> return (SolvedIf [])
-- | Prove @x >= y@ from the known intervals of the two types.
--
-- Succeeds when the lower bound of @x@'s interval is at least the upper
-- bound of @y@'s interval.  The match fails when @y@'s interval has no
-- upper bound ('Nothing') or the bounds do not separate.
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
  do upperY <- maybe mzero return (iUpper (rangeOf y))
     guard (iLower (rangeOf x) >= upperY)
     return (SolvedIf [])
  where
  rangeOf = typeInterval (intervals ctxt)
-- | Solve @min K1 t >= K2@.
--
-- If @K1 >= K2@ this reduces to @t >= K2@; otherwise the minimum is at most
-- @K1 < K2@ and the constraint is unsolvable.
tryMinIsGeq :: Type -> Type -> Match Solved
tryMinIsGeq t1 t2 =
  do (a, b) <- aMin t1
     k1 <- aNat a
     k2 <- aNat t2
     return $
       if k1 < k2
         then Unsolvable (TCErrorMessage
                (show k1 ++ " can't be greater than " ++ show k2))
         else SolvedIf [ b >== t2 ]
-- | Cancel a common variable factor from both sides of a relation.
--
-- @(fin a, a >= 1) =>  a * t1 == a * t2 ~~~> t1 == t2@
--
-- @(fin a, a >= 1) =>  a * t1 >= a * t2 ~~~> t1 >= t2@
--
-- Both sides are flattened into lists of multiplicative factors.  A factor
-- is cancellable when it is a type variable whose known interval satisfies
-- 'iIsPosFin' (presumably: finite and at least 1 -- TODO confirm).  At most
-- one common variable is cancelled per call, and @p@ rebuilds the relation
-- between the remaining factors.  When no common cancellable variable is
-- found, the match fails (via 'Fail.fail').
tryCancelVar :: Ctxt -> (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryCancelVar ctxt p t1 t2 =
  let lhs = preproc t1
      rhs = preproc t2
  in case check [] [] lhs rhs of
       Nothing -> Fail.fail "tryCancelVar"
       Just x  -> return x

  where
  -- Merge-style walk over the two sorted factor lists.  Factors skipped on
  -- either side are remembered in @doneLHS@ / @doneRHS@ so they can be put
  -- back into the residual products.  Cancellable factors are sorted first,
  -- so hitting a non-cancellable factor ('Nothing') on either side means no
  -- common variable remains, and the search gives up.
  check doneLHS doneRHS lhs@((a,mbA) : moreLHS) rhs@((b, mbB) : moreRHS) =
    do x <- mbA
       y <- mbB
       case compare x y of
         -- The remaining RHS variables are all >= y > x, so x cannot occur
         -- on the RHS: skip it.
         LT -> check (a : doneLHS) doneRHS moreLHS rhs
         -- Common variable found: drop one occurrence from each side.
         EQ -> return $ SolvedIf [ p (term (doneLHS ++ map fst moreLHS))
                                     (term (doneRHS ++ map fst moreRHS)) ]
         GT -> check doneLHS (b : doneRHS) lhs moreRHS
  -- One side ran out of factors without finding a common variable.
  check _ _ _ _ = Nothing

  -- Rebuild a product from a list of factors; the empty product is 1.
  term xs = case xs of
              [] -> tNum (1::Int)
              _  -> foldr1 tMul xs

  -- Flatten a product and pair each factor with its cancellable variable
  -- (if any), with cancellable factors first.
  preproc t = let fs = splitMul t []
              in sortBy cmpFact (zip fs (map cancelVar fs))

  -- Flatten nested multiplications into a list of factors.
  splitMul t rest = case matchMaybe (aMul t) of
                      Just (a,b) -> splitMul a (splitMul b rest)
                      Nothing    -> t : rest

  -- 'Just' the variable if this factor is a variable we may cancel.
  cancelVar t = matchMaybe $ do x <- aTVar t
                                guard (iIsPosFin (tvarInterval (intervals ctxt) x))
                                return x

  -- Cancellable variables go first, ordered by 'compare' on 'TVar'.  Since
  -- 'sortBy' is stable, non-cancellable factors keep their relative order.
  cmpFact (_,mbA) (_,mbB) =
    case (mbA,mbB) of
      (Just x, Just y)  -> compare x y
      (Just _, Nothing) -> LT
      (Nothing, Just _) -> GT
      _                 -> EQ
-- | Solve @min a b = a@ (or @min a b = b@) by requiring the other operand to
-- be at least as large: @min a b = a ~~~> b >= a@.
tryEqMin :: Type -> Type -> Match Solved
tryEqMin x y =
  do (a, b) <- aMin x
     let sameAs m other =
           do guard (m == y)
              return (SolvedIf [ other >== m ])
     sameAs a b <|> sameAs b a
-- | Try to solve @x = min y1 (min y2 ...)@.
--
-- Any operand of the (possibly nested) minimum that has the form @K + x@
-- with @K >= 1@ can be dropped:
--
-- @x = min (K + x) t ~~~> x = t@
--
-- If every operand has that form, the only solution is @x = inf@.
--
-- When no operand can be dropped, the match fails instead of producing
-- 'Unsolved'.  'Unsolved' would count as a successful match, and the
-- caller's later strategies would never run.
tryEqMins :: Type -> Type -> Match Solved
tryEqMins x y =
  do (a, b) <- aMin y
     let ys  = splitMin a ++ splitMin b
         ys' = filter (not . isGt) ys
     guard (any isGt ys)
     let y' = if null ys' then tInf else foldr1 Simp.tMin ys'
     return (SolvedIf [ x =#= y' ])
  where
    -- Flatten nested minimums into the list of their operands.
    splitMin :: Type -> [Type]
    splitMin ty =
      case matchMaybe (aMin ty) of
        Just (t1, t2) -> splitMin t1 ++ splitMin t2
        Nothing       -> [ty]

    -- Is this operand of the form @K + x@ with @K >= 1@?
    isGt :: Type -> Bool
    isGt t =
      case matchMaybe (asAddK t) of
        Just (k, t') -> k > 0 && t' == x
        Nothing      -> False

    -- Match @K + t@ with a constant @K@ as the left operand.
    asAddK :: Type -> Match (Integer, Type)
    asAddK t =
      do (t1, t2) <- anAdd t
         k <- aNat t1
         return (k, t2)
-- | Try to solve @ty = x@ for a type variable @x@ that also occurs in @ty@.
tryEqVar :: Type -> TVar -> Match Solved
tryEqVar ty x = selfPlusK <|> minWithSelfPlusK <|> kPlusMinWithSelf
  where
  -- x = K + x, K >= 1  ~~~>  x = inf
  selfPlusK =
    do (k, tv) <- matches ty (anAdd, aNat, aTVar)
       guard (tv == x && k >= 1)
       return (SolvedIf [ TVar x =#= tInf ])

  -- x = min (K + x) y, K >= 1  ~~~>  x = y
  minWithSelfPlusK =
    do (l, r) <- aMin ty
       let attempt this other =
             do (k, x') <- matches this (anAdd, aNat', aTVar)
                guard (x == x' && k >= Nat 1)
                return (SolvedIf [ TVar x =#= other ])
       attempt l r <|> attempt r l

  -- x = K + min x y, K >= 1  ~~~>  x = K + y
  kPlusMinWithSelf =
    do (k, (l, r)) <- matches ty (anAdd, aNat, aMin)
       guard (k >= 1)
       let attempt a b =
             do x' <- aTVar a
                guard (x' == x)
                return (SolvedIf [ TVar x =#= tAdd (tNum k) b ])
       attempt l r <|> attempt r l
-- | Try to solve @K = t@, where @lk@ is the constant @K@ and @ty@ is the
-- term on the other side of the equation.
--
-- Each helper below handles one shape of @ty@; the first one that matches
-- decides the result.
tryEqK :: Ctxt -> Type -> Nat' -> Match Solved
tryEqK ctxt ty lk =
      finPlusIsInf
  <|> plusConst
  <|> minusConst
  <|> timesConst
  <|> constPower
  where
  -- (t1 + t2 = inf, fin t1) ~~~> t2 = inf
  finPlusIsInf =
    do guard (lk == Inf)
       (a,b) <- anAdd ty
       let check x y = do guard (iIsFin (typeInterval (intervals ctxt) x))
                          return $ SolvedIf [ y =#= tInf ]
       check a b <|> check b a

  -- (K1 + t = K2, K2 >= K1) ~~~> t = (K2 - K1)
  -- (K1 + t = K2, K2 < K1)  ~~~> error
  plusConst =
    do (rk, b) <- matches ty (anAdd, aNat', __)
       return $
         case nSub lk rk of
           Nothing -> Unsolvable
                        $ TCErrorMessage
                        $ "Adding " ++ showNat' rk ++ " will always exceed "
                                    ++ showNat' lk
           Just r  -> SolvedIf [ b =#= tNat' r ]

  -- (t - K1 = K2) ~~~> t = (K2 + K1)
  minusConst =
    do (t,rk) <- matches ty ((|-|) , __, aNat')
       return (SolvedIf [ t =#= tNat' (nAdd lk rk) ])

  -- K2 = K1 * t
  timesConst =
    do (rk, b) <- matches ty (aMul, aNat', __)
       return $
         case (lk,rk) of
           -- inf = inf * t ~~~> t >= 1
           (Inf,Inf)    -> SolvedIf [ b >== tOne ]

           -- K2 = 0 * t ~~~> K2 = 0.  This must precede the
           -- (inf = K1 * t) rule: 0 * t is 0 even for t = inf, so
           -- "inf = 0 * t" has no solution and must not become t = inf.
           (_, Nat 0)   -> SolvedIf [ tNat' lk =#= tZero ]

           -- inf = K1 * t ~~~> t = inf   (K1 /= 0)
           (Inf,Nat _)  -> SolvedIf [ b =#= tInf ]

           -- 0 = inf * t ~~~> t = 0
           (Nat 0, Inf) -> SolvedIf [ b =#= tZero ]

           -- K2 = inf * t ~~~> ERR   (K2 /= 0)
           (Nat k, Inf) -> Unsolvable
                         $ TCErrorMessage
                         $ show k ++ " != inf * anything"

           -- K2 = K1 * t ~~~> t = K2 / K1   (K1 /= 0, exact division only)
           (Nat lk', Nat rk')
             | (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
             | otherwise ->
                 Unsolvable
                 $ TCErrorMessage
                 $ showNat' lk ++ " != " ++ showNat' rk ++ " * anything"

  -- K1 = K2 ^^ t ~~~> t = logBase K2 K1
  constPower =
    do (rk, b) <- matches ty ((|^|), aNat, __)
       return $ case lk of
         Inf | rk > 1 -> SolvedIf [ b =#= tInf ]
         -- 0 ^^ t is 0 exactly when t >= 1 (0 ^^ 0 = 1); genLog has no
         -- answer for this case, so handle it explicitly.
         Nat 0 | rk == 0 -> SolvedIf [ b >== tOne ]
         -- 1 ^^ t is 1 for every t; genLog rejects base 1.
         Nat 1 | rk == 1 -> SolvedIf []
         Nat n | Just (a,True) <- genLog n rk -> SolvedIf [ b =#= tNum a]
         _ -> Unsolvable $ TCErrorMessage
                $ show rk ++ " ^^ anything != " ++ showNat' lk
-- | Divide both sides of an equation between linear forms
-- (@K0 + K1 * t1 + ...@) by the gcd of all their constants and coefficients,
-- when that gcd is greater than 1.
tryEqMulConst :: Type -> Type -> Match Solved
tryEqMulConst l r =
  do (lc, ls) <- matchLinear l
     (rc, rs) <- matchLinear r
     let coeffs = lc : rc : map fst (ls ++ rs)
         d      = foldr1 gcd coeffs
     guard (d > 1)
     let scale k      = tNum (div k d)
         rebuild k ts = foldr tAdd (scale k) [ tMul (scale c) t | (c, t) <- ts ]
     return (SolvedIf [ rebuild lc ls =#= rebuild rc rs ])
-- | Solve @a + b = inf@ (in either orientation) when one summand is known
-- to be finite: the other summand must then be @inf@.
--
-- If neither summand is known to be finite, this still matches and yields
-- 'Unsolved'.
tryEqAddInf :: Ctxt -> Type -> Type -> Match Solved
tryEqAddInf ctxt l r = sumIsInf l r <|> sumIsInf r l
  where
  isFinTy t = iIsFin (typeInterval (intervals ctxt) t)

  sumIsInf s i =
    do (s1, s2) <- anAdd s
       aInf i
       let result
             | isFinTy s1 = SolvedIf [ s2 =#= i ]
             | isFinTy s2 = SolvedIf [ s1 =#= i ]
             | otherwise  = Unsolved
       return $! result
-- | Cancel constants added on both sides of a relation:
--
-- @(K1 + t1) `R` (K2 + t2) ~~~> ((K1 - K2) + t1) `R` t2@   when K1 > K2
--
-- @(K1 + t1) `R` (K2 + t2) ~~~> t1 `R` ((K2 - K1) + t2)@   otherwise
tryAddConst :: (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryAddConst rel l r =
  do (x1, x2) <- anAdd l
     (y1, y2) <- anAdd r
     k1 <- aNat x1
     k2 <- aNat y1
     let prop | k1 > k2   = tAdd (tNum (k1 - k2)) x2 `rel` y2
              | otherwise = x2 `rel` tAdd (tNum (k2 - k1)) y2
     return (SolvedIf [ prop ])
-- | Solve @s1 = ?a + s2@, where @s1@ and @s2@ contain no free variables, by
-- choosing @?a = s1 - s2@ together with the side condition @s1 >= s2@.
tryLinearSolution :: Type -> Type -> Match Solved
tryLinearSolution s1 t =
  do (v, summands) <- matchLinearUnifier t
     guard (noFreeVariables s1)
     -- matchLinearUnifier never returns an empty summand list, so
     -- foldr1 is safe here.
     let rest = foldr1 Simp.tAdd summands
     return (SolvedIf [ TVar v =#= Simp.tSub s1 rest, s1 >== rest ])
-- | Match a sum, walked along its right spine, that contains a single free
-- type variable (as matched by 'aFreeTVar').  All other summands must
-- contain no free variables.
--
-- Returns that variable together with the other summands; the summand
-- list is never empty.
matchLinearUnifier :: Pat Type (TVar,[Type])
matchLinearUnifier = walk []
  where
  walk acc ty = varAlone acc ty <|> viaAdd acc ty

  -- The remaining term is the variable itself; this is only accepted once
  -- at least one other summand has been collected.
  varAlone acc ty =
    do v <- aFreeTVar ty
       guard (not (null acc))
       return (v, acc)

  -- The term is a sum: either its left operand is the variable, or the
  -- left operand is variable-free and the search continues on the right.
  viaAdd acc ty =
    do (lhs, rhs) <- anAdd ty
       (do v <- aFreeTVar lhs
           guard (noFreeVariables rhs)
           return (v, reverse (rhs : acc)))
         <|>
         (do guard (noFreeVariables lhs)
             walk (lhs : acc) rhs)
-- | Match a linear form: a sum whose leaves are constants or products
-- @K * t@ with a constant left factor.
--
-- Returns the total of the constants and the list of (coefficient, term)
-- pairs.  Any other leaf (e.g. a bare variable) makes the match fail.
matchLinear :: Pat Type (Integer, [(Integer,Type)])
matchLinear = step (0, [])
  where
  step acc@(c, ts) ty = constant <|> scaled <|> summed
    where
    -- a literal constant is added to the running total
    constant =
      do n <- aNat ty
         return (n + c, ts)

    -- a product @K * t@ contributes a (coefficient, term) pair
    scaled =
      do (k, t) <- aMul ty
         n <- aNat k
         return (c, (n, t) : ts)

    -- a sum is processed left operand first, then right operand
    summed =
      do (l, r) <- anAdd ty
         acc' <- step acc l
         step acc' r
-- | Render a possibly-infinite natural for error messages.
showNat' :: Nat' -> String
showNat' n =
  case n of
    Inf   -> "inf"
    Nat k -> show k