{-# LANGUAGE Safe, PatternGuards, MultiWayIf #-}
module Cryptol.TypeCheck.Solver.Numeric
( cryIsEqual, cryIsNotEqual, cryIsGeq
) where
import Control.Applicative(Alternative(..))
import Control.Monad (guard,mzero)
import Data.List (sortBy)
import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Type hiding (tMul)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType as Simp
-- | Try to solve the numeric equality @t1 = t2@.
--
-- The strategies below are tried in the listed order and the first one
-- that matches determines the result; if none matches the constraint is
-- left 'Unsolved'.
cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual ctxt t1 t2 = matchDefault Unsolved (foldr (<|>) mzero strategies)
  where
  strategies =
    [ pBin PEqual (==) t1 t2
    , aNat' t1 >>= tryEqK ctxt t2
    , aNat' t2 >>= tryEqK ctxt t1
    , aTVar t1 >>= tryEqVar t2
    , aTVar t2 >>= tryEqVar t1
    , SolvedIf [] <$ guard (t1 == t2)
    , tryEqMin t1 t2
    , tryEqMin t2 t1
    , tryEqMins t1 t2
    , tryEqMins t2 t1
    , tryEqMulConst t1 t2
    , tryEqAddInf ctxt t1 t2
    , tryAddConst (=#=) t1 t2
    , tryCancelVar ctxt (=#=) t1 t2
    , tryLinearSolution t1 t2
    , tryLinearSolution t2 t1
    ]
-- | Try to solve @t1 /= t2@.  This is decided only when 'pBin' can decide
-- it: an error type on either side, or two constant arguments.  Otherwise
-- the constraint is left 'Unsolved'.
cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual _ lhs rhs =
  matchDefault Unsolved $ pBin PNeq (/=) lhs rhs
-- | Try to solve the numeric constraint @t1 >= t2@.
--
-- Rules are tried in the listed order; the first one that matches wins.
-- If none matches the constraint is left 'Unsolved'.
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq i t1 t2 = matchDefault Unsolved (foldr (<|>) mzero rules)
  where
  rules =
    [ pBin PGeq (>=) t1 t2
    , aNat' t1 >>= tryGeqKThan i t2
    , aNat' t2 >>= tryGeqThanK i t1
    , aTVar t2 >>= tryGeqThanVar i t1
    , tryGeqThanSub i t1 t2
    , geqByInterval i t1 t2
    , SolvedIf [] <$ guard (t1 == t2)
    , tryAddConst (>==) t1 t2
    , tryCancelVar i (>==) t1 t2
    , tryMinIsGeq t1 t2
    ]
-- | Decide a binary numeric predicate outright when that is possible:
--
--   * an error type (of kind @#@) on either side makes it 'Unsolvable';
--   * two constant arguments are compared with @p@.
--
-- @tf@ is only used to render the predicate in the error message.
pBin :: PC -> (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
pBin tf p t1 t2 =
      fmap Unsolvable (anError KNum t1)
  <|> fmap Unsolvable (anError KNum t2)
  <|> (decide <$> aNat' t1 <*> aNat' t2)
  where
  decide x y
    | p x y     = SolvedIf []
    | otherwise =
        Unsolvable $ TCErrorMessage $
          "Unsolvable constraint: " ++
            show (pp (TCon (PC tf) [ tNat' x, tNat' y ]))
-- | Solve @K >= a * b@ where @a@ is a constant.
--
--   * @inf >= anything@ always holds.
--   * @K >= inf * b@      ~~> @b = 0@
--   * @K >= 0 * b@        always holds.
--   * @K >= c * b@, c > 0 ~~> @K div c >= b@
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqKThan _ _ Inf = return (SolvedIf [])
tryGeqKThan _ ty (Nat n) =
  do (a, b) <- aMul ty
     coeff  <- aNat' a
     return (SolvedIf (goalsFor coeff b))
  where
  goalsFor Inf     b = [ b =#= tZero ]
  goalsFor (Nat 0) _ = []
  goalsFor (Nat k) b = [ tNum (div n k) >== b ]
-- | Solve @t >= K@.
--
--   * @t >= inf@       ~~> @t = inf@
--   * @c + b >= K@     ~~> nothing to do if @c >= K@, else @b >= K - c@
tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqThanK _ t bound =
  case bound of
    Inf   -> return (SolvedIf [ t =#= tInf ])
    Nat k ->
      do (a, b) <- anAdd t
         n <- aNat a
         -- Empty when the constant alone already reaches the bound.
         return (SolvedIf [ b >== tNum (k - n) | n < k ])
-- | @x >= x - c@ always holds, so it is solved with no further goals.
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =
  do (minuend, _) <- (|-|) y
     if x == minuend
       then return (SolvedIf [])
       else mzero
-- | @a + b >= x@ always holds when the variable @x@ is one of the summands.
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ctxt ty x =
  do (a, b) <- anAdd ty
     isX a <|> isX b
  where
  isX t =
    do v <- aTVar t
       guard (x == v)
       return (SolvedIf [])
-- | Solve @x >= y@ from interval information: the constraint holds when
-- the lower bound of @x@ is at least the (known) upper bound of @y@.
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
  case iUpper (typeInterval ctxt y) of
    Just hi | iLower (typeInterval ctxt x) >= hi -> return (SolvedIf [])
    _ -> mzero
-- | Solve @min K1 b >= K2@ for constants @K1@ and @K2@: when @K1 >= K2@ it
-- reduces to @b >= K2@; otherwise the constraint can never hold.
tryMinIsGeq :: Type -> Type -> Match Solved
tryMinIsGeq t1 t2 =
  do (a, b) <- aMin t1
     k1 <- aNat a
     k2 <- aNat t2
     return (verdict k1 k2 b)
  where
  verdict k1 k2 b
    | k1 >= k2  = SolvedIf [ b >== t2 ]
    | otherwise =
        Unsolvable $ TCErrorMessage $
          show k1 ++ " can't be greater than " ++ show k2
-- | Cancel a common variable factor from both sides of a relation:
-- @x * A `p` x * B@ ~~> @A `p` B@.
--
-- Each side is split into its multiplicative factors.  Only type variables
-- whose interval is positive and finite ('iIsPosFin') are candidates for
-- cancellation; after sorting, such candidates come first on each side.
-- The two sorted candidate lists are walked in step until a shared
-- variable is found, which is then removed from both sides.  An empty
-- product is rebuilt as @1@.
--
-- @p@ builds the resulting constraint (e.g. '=#=' or '>==').
tryCancelVar :: Ctxt -> (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryCancelVar ctxt p t1 t2 =
  -- Signal "no match" with 'mzero' (as the rest of this module does)
  -- rather than relying on a 'MonadFail' instance via @fail ""@.
  case check [] [] (preproc t1) (preproc t2) of
    Nothing -> mzero
    Just x  -> return x
  where
  -- Walk both sorted factor lists; factors skipped on the way are kept
  -- in the @done@ accumulators so they reappear in the residual products.
  check doneLHS doneRHS lhs@((a,mbA) : moreLHS) rhs@((b,mbB) : moreRHS) =
    do x <- mbA     -- stops at the first non-cancellable factor
       y <- mbB
       case compare x y of
         LT -> check (a : doneLHS) doneRHS moreLHS rhs
         EQ -> return $ SolvedIf [ p (term (doneLHS ++ map fst moreLHS))
                                     (term (doneRHS ++ map fst moreRHS)) ]
         GT -> check doneLHS (b : doneRHS) lhs moreRHS
  check _ _ _ _ = Nothing

  -- Rebuild a product from its factors; the empty product is 1.
  term xs =
    case xs of
      [] -> tNum (1::Int)
      _  -> foldr1 tMul xs

  -- Factors of a type, each paired with the variable it may be cancelled
  -- as, with cancellable factors sorted to the front.
  preproc t =
    let fs = splitMul t []
    in sortBy cmpFact (zip fs (map cancelVar fs))

  splitMul t rest =
    case matchMaybe (aMul t) of
      Just (a,b) -> splitMul a (splitMul b rest)
      Nothing    -> t : rest

  -- A factor can be cancelled only if it is a variable known to be
  -- positive and finite (cancelling 0 or inf would be unsound).
  cancelVar t =
    matchMaybe $ do x <- aTVar t
                    guard (iIsPosFin (tvarInterval ctxt x))
                    return x

  cmpFact (_,mbA) (_,mbB) =
    case (mbA,mbB) of
      (Just x, Just y)  -> compare x y
      (Just _, Nothing) -> LT
      (Nothing, Just _) -> GT
      _                 -> EQ
-- | Solve @min a b = y@ when one argument of the @min@ is @y@ itself: the
-- equation then holds exactly when the other argument is @>= y@.
tryEqMin :: Type -> Type -> Match Solved
tryEqMin x y =
  do (a, b) <- aMin x
     keeps a b <|> keeps b a
  where
  keeps same other
    | same == y = return (SolvedIf [ other >== same ])
    | otherwise = mzero
-- | Solve @x = min (t1, ..., tn)@ by dropping every argument of the form
-- @K + x@ with @K > 0@.
--
-- Such an argument can never be the unique minimum when @x@ is finite
-- (it is strictly larger than @x@).  When @x@ is @inf@ it equals @inf@, so
-- dropping it does not change the minimum; if every argument is dropped,
-- the minimum of the empty set is taken to be @inf@.
--
-- The rule only matches when at least one argument was dropped.  Otherwise
-- it fails instead of committing to 'Unsolved', so the caller's remaining
-- alternatives (including the symmetric @tryEqMins t2 t1@) still get a
-- chance to solve the constraint.
tryEqMins :: Type -> Type -> Match Solved
tryEqMins x y =
  do (a, b) <- aMin y
     let ys  = splitMin a ++ splitMin b
         ys' = filter (not . isGt) ys
         y'  = if null ys' then tInf else foldr1 Simp.tMin ys'
     guard (any isGt ys)
     return (SolvedIf [ x =#= y' ])
  where
  -- Flatten nested mins into the list of their arguments.
  splitMin :: Type -> [Type]
  splitMin ty =
    case matchMaybe (aMin ty) of
      Just (t1, t2) -> splitMin t1 ++ splitMin t2
      Nothing       -> [ty]

  -- Is the term of the shape @K + x@ with @K > 0@?
  isGt :: Type -> Bool
  isGt t =
    case matchMaybe (asAddK t) of
      Just (k, t') -> k > 0 && t' == x
      Nothing      -> False

  asAddK :: Type -> Match (Integer, Type)
  asAddK t =
    do (t1, t2) <- anAdd t
       k <- aNat t1
       return (k, t2)
-- | Solve @ty = x@ for a type variable @x@ that also occurs in @ty@:
--
--   * @K + x = x@, @K >= 1@          ~~> @x = inf@
--   * @min (K + x) t = x@, @K >= 1@  ~~> @x = t@
--   * @K + min x t = x@, @K >= 1@    ~~> @x = K + t@
tryEqVar :: Type -> TVar -> Match Solved
tryEqVar ty x = shifted <|> minOfShifted <|> shiftedMin
  where
  shifted =
    do (k, v) <- matches ty (anAdd, aNat, aTVar)
       guard (v == x && k >= 1)
       return (SolvedIf [ TVar x =#= tInf ])

  minOfShifted =
    do (l, r) <- aMin ty
       let pick this other =
             do (k, v) <- matches this (anAdd, aNat', aTVar)
                guard (x == v && k >= Nat 1)
                return (SolvedIf [ TVar x =#= other ])
       pick l r <|> pick r l

  shiftedMin =
    do (k, (l, r)) <- matches ty (anAdd, aNat, aMin)
       guard (k >= 1)
       let pick a b =
             do v <- aTVar a
                guard (v == x)
                return (SolvedIf [ TVar x =#= tAdd (tNum k) b ])
       pick l r <|> pick r l
-- | Solve @ty = lk@ where @lk@ is a constant (possibly @inf@).
--
-- Alternatives, tried in order:
--
--   * @t1 + t2 = inf@, @fin t1@ ~~> @t2 = inf@
--   * @K1 + t = K2@             ~~> @t = K2 - K1@ (error if impossible)
--   * @t - K1 = K2@             ~~> @t = K2 + K1@
--   * @K1 * t = K2@             ~~> solved by case analysis on K1, K2
--   * @K1 ^^ t = K2@            ~~> solved via 'genLog'
tryEqK :: Ctxt -> Type -> Nat' -> Match Solved
tryEqK ctxt ty lk =
      finPlusEqInf
  <|> plusConst
  <|> minusConst
  <|> timesConst
  <|> powConst
  where
  -- (t1 + t2 = inf, fin t1) ~~> t2 = inf
  finPlusEqInf =
    do guard (lk == Inf)
       (a, b) <- anAdd ty
       let check x y = do guard (iIsFin (typeInterval ctxt x))
                          return (SolvedIf [ y =#= tInf ])
       check a b <|> check b a

  -- (K1 + t = K2) ~~> t = K2 - K1
  plusConst =
    do (rk, b) <- matches ty (anAdd, aNat', __)
       return $
         case (lk, rk) of
           -- inf + t = inf holds for every t; 'nSub' has no answer for
           -- inf - inf, so this must not be reported as an error.
           (Inf, Inf) -> SolvedIf []
           _ ->
             case nSub lk rk of
               Nothing -> Unsolvable
                            $ TCErrorMessage
                            $ "Adding " ++ showNat' rk ++ " will always exceed "
                                        ++ showNat' lk
               Just r  -> SolvedIf [ b =#= tNat' r ]

  -- (t - K1 = K2) ~~> t = K2 + K1
  minusConst =
    do (t, rk) <- matches ty ((|-|), __, aNat')
       return (SolvedIf [ t =#= tNat' (nAdd lk rk) ])

  -- (K1 * t = K2)
  timesConst =
    do (rk, b) <- matches ty (aMul, aNat', __)
       return $
         case (lk, rk) of
           -- inf * t = inf ~~> t >= 1
           (Inf, Inf)   -> SolvedIf [ b >== tOne ]
           -- 0 * t = inf has no solution: 0 * t is 0 (consistent with the
           -- rk' == 0 case below).  Previously this produced the unsound
           -- goal t = inf.
           (Inf, Nat 0) -> mulErr rk
           -- K * t = inf, K > 0 ~~> t = inf
           (Inf, Nat _) -> SolvedIf [ b =#= tInf ]
           -- inf * t = 0 ~~> t = 0
           (Nat 0, Inf) -> SolvedIf [ b =#= tZero ]
           -- inf * t = K, K /= 0 ~~> error
           (Nat k, Inf) -> Unsolvable
                             $ TCErrorMessage
                             $ show k ++ " != inf * anything"
           (Nat lk', Nat rk')
             -- 0 * t = K2 ~~> K2 = 0
             | rk' == 0 -> SolvedIf [ tNat' lk =#= tZero ]
             -- K1 * t = K2 ~~> t = K2 / K1 when the division is exact
             | (q, 0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
             | otherwise -> mulErr rk

  mulErr rk = Unsolvable
                $ TCErrorMessage
                $ showNat' lk ++ " != " ++ showNat' rk ++ " * anything"

  -- (K1 ^^ t = K2)
  powConst =
    do (rk, b) <- matches ty ((|^|), aNat, __)
       return $
         case lk of
           Inf | rk > 1 -> SolvedIf [ b =#= tInf ]
           -- 0 ^^ t = 0 exactly when t >= 1 ('genLog' cannot express this).
           Nat 0 | rk == 0 -> SolvedIf [ b >== tOne ]
           -- 1 ^^ t = 1 for every t ('genLog' rejects base 1).
           Nat 1 | rk == 1 -> SolvedIf []
           Nat n | Just (a, True) <- genLog n rk -> SolvedIf [ b =#= tNum a ]
           _ -> Unsolvable
                  $ TCErrorMessage
                  $ show rk ++ " ^^ anything != " ++ showNat' lk
-- | Divide both sides of a linear equation by the gcd of all of its
-- coefficients and constants when that gcd exceeds 1, e.g.
-- @2 + 4 * a = 6 * b@ ~~> @1 + 2 * a = 3 * b@.
tryEqMulConst :: Type -> Type -> Match Solved
tryEqMulConst l r =
  do (lc, ls) <- matchLinear l
     (rc, rs) <- matchLinear r
     let coeffs = lc : rc : map fst (ls ++ rs)
         d      = foldr gcd 0 coeffs
     guard (d > 1)
     return (SolvedIf [ rebuild d lc ls =#= rebuild d rc rs ])
  where
  rebuild d k ts = foldr (tAdd . scaleDown d) (tNum (k `div` d)) ts
  scaleDown d (k, t) = tMul (tNum (k `div` d)) t
-- | Solve @a + b = inf@ (in either orientation).  If one summand is known
-- to be finite, the other must be @inf@.
--
-- When neither summand is known to be finite this returns 'Unsolved' as a
-- *successful* match rather than failing.  Assuming '<|>' on 'Match' picks
-- the first success, this keeps later strategies in the caller from being
-- tried on this constraint.
tryEqAddInf :: Ctxt -> Type -> Type -> Match Solved
tryEqAddInf ctxt l r = check l r <|> check r l
  where
  finite t = iIsFin (typeInterval ctxt t)

  check x y =
    do (x1, x2) <- anAdd x
       aInf y
       return $! decide x1 x2 y

  decide x1 x2 y
    | finite x1 = SolvedIf [ x2 =#= y ]
    | finite x2 = SolvedIf [ x1 =#= y ]
    | otherwise = Unsolved
-- | Cancel constants from both sides of a relation:
-- @K1 + a `rel` K2 + b@ ~~> the smaller constant is subtracted from both
-- sides, leaving the difference on the side that had the larger one.
tryAddConst :: (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryAddConst rel l r =
  do (x1, x2) <- anAdd l
     (y1, y2) <- anAdd r
     k1 <- aNat x1
     k2 <- aNat y1
     let goal
           | k1 > k2   = tAdd (tNum (k1 - k2)) x2 `rel` y2
           | otherwise = x2 `rel` tAdd (tNum (k2 - k1)) y2
     return (SolvedIf [ goal ])
-- | Solve @s1 = t@ where @s1@ has no free variables and @t@ is a sum of
-- terms without free variables plus one variable @a@ (as recognised by
-- 'matchLinearUnifier').  Produces @a = s1 - rest@ together with the side
-- condition @s1 >= rest@, where @rest@ is the sum of the other terms.
tryLinearSolution :: Type -> Type -> Match Solved
tryLinearSolution s1 t =
  do guard (noFreeVariables s1)
     solution <$> matchLinearUnifier t
  where
  solution (v, rest) =
    let s2 = foldr1 Simp.tAdd rest   -- non-empty by construction
    in SolvedIf [ TVar v =#= Simp.tSub s1 s2, s1 >== s2 ]
-- | Recognise a right-nested sum @c1 + (c2 + ... )@ whose terms have no
-- free variables except for exactly one variable matched by 'aFreeTVar'.
-- That variable must be either the final summand (after at least one other
-- term) or the left operand of the final @+@, whose right operand has no
-- free variables.  Returns the variable and the list of the other terms,
-- which is never empty.
matchLinearUnifier :: Pat Type (TVar,[Type])
matchLinearUnifier = go []
  where
  go acc t = trailingVar acc t <|> leadingTerm acc t

  -- The variable ends the sum; require at least one term before it.
  trailingVar acc t =
    do v <- aFreeTVar t
       guard (not (null acc))
       return (v, acc)

  leadingTerm acc t =
    do (x, y) <- anAdd t
       varThenClosed acc x y <|> closedThenMore acc x y

  -- @v + y@ with @y@ free of variables: stop here.
  varThenClosed acc x y =
    do v <- aFreeTVar x
       guard (noFreeVariables y)
       return (v, reverse (y : acc))

  -- @x + y@ with @x@ free of variables: remember @x@ and keep going.
  closedThenMore acc x y =
    do guard (noFreeVariables x)
       go (x : acc) y
-- | Recognise a sum whose leaves are constants @K@ or scaled terms @K * t@
-- (constant on the left).  Returns the total of the constant leaves and
-- the @(K, t)@ pairs of the scaled leaves, in reverse order of occurrence.
-- Fails if any leaf has a different shape.
matchLinear :: Pat Type (Integer, [(Integer,Type)])
matchLinear = walk (0, [])
  where
  walk acc@(c, ts) t =
        (\n -> (n + c, ts)) <$> aNat t
    <|> scaled
    <|> summed
    where
      scaled =
        do (x, y) <- aMul t
           n <- aNat x
           return (c, (n, y) : ts)
      summed =
        do (l, r) <- anAdd t
           acc' <- walk acc l
           walk acc' r
-- | Render an extended natural for error messages: @inf@ or the decimal
-- value.
showNat' :: Nat' -> String
showNat' n =
  case n of
    Inf   -> "inf"
    Nat k -> show k