{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Data.SBV.Tools.Overflow (
ArithOverflow(..), CheckedArithmetic(..)
, sFromIntegralO, sFromIntegralChecked
) where
import Data.SBV.Core.Data
import Data.SBV.Core.Symbolic
import Data.SBV.Core.Model
import Data.SBV.Core.Operations
import Data.SBV.Utils.Boolean
import GHC.Stack
import Data.Int
import Data.Word
-- | Detection of out-of-range results for bit-vector arithmetic.
--
-- Every method yields a pair @(underflow, overflow)@ of symbolic booleans:
-- the first component flags a result below the representable range, the
-- second a result above it.
class ArithOverflow a where
  -- | Binary operations: addition, subtraction, multiplication (plus a
  -- \"fast\" multiplication variant), and division.
  bvAddO, bvSubO, bvMulO, bvMulOFast, bvDivO :: a -> a -> (SBool, SBool)

  -- | Unary negation.
  bvNegO :: a -> (SBool, SBool)
-- Sized instances: strip the 'SBV' wrapper (via 'l2' / 'l1') and defer to
-- the 'SVal' instance, which selects the signed or unsigned check based on
-- the argument's kind.
instance ArithOverflow SWord8 where
  bvNegO     = l1 bvNegO
  bvAddO     = l2 bvAddO
  bvSubO     = l2 bvSubO
  bvMulO     = l2 bvMulO
  bvMulOFast = l2 bvMulOFast
  bvDivO     = l2 bvDivO

instance ArithOverflow SWord16 where
  bvNegO     = l1 bvNegO
  bvAddO     = l2 bvAddO
  bvSubO     = l2 bvSubO
  bvMulO     = l2 bvMulO
  bvMulOFast = l2 bvMulOFast
  bvDivO     = l2 bvDivO

instance ArithOverflow SWord32 where
  bvNegO     = l1 bvNegO
  bvAddO     = l2 bvAddO
  bvSubO     = l2 bvSubO
  bvMulO     = l2 bvMulO
  bvMulOFast = l2 bvMulOFast
  bvDivO     = l2 bvDivO

instance ArithOverflow SWord64 where
  bvNegO     = l1 bvNegO
  bvAddO     = l2 bvAddO
  bvSubO     = l2 bvSubO
  bvMulO     = l2 bvMulO
  bvMulOFast = l2 bvMulOFast
  bvDivO     = l2 bvDivO

instance ArithOverflow SInt8 where
  bvNegO     = l1 bvNegO
  bvAddO     = l2 bvAddO
  bvSubO     = l2 bvSubO
  bvMulO     = l2 bvMulO
  bvMulOFast = l2 bvMulOFast
  bvDivO     = l2 bvDivO

instance ArithOverflow SInt16 where
  bvNegO     = l1 bvNegO
  bvAddO     = l2 bvAddO
  bvSubO     = l2 bvSubO
  bvMulO     = l2 bvMulO
  bvMulOFast = l2 bvMulOFast
  bvDivO     = l2 bvDivO

instance ArithOverflow SInt32 where
  bvNegO     = l1 bvNegO
  bvAddO     = l2 bvAddO
  bvSubO     = l2 bvSubO
  bvMulO     = l2 bvMulO
  bvMulOFast = l2 bvMulOFast
  bvDivO     = l2 bvDivO

instance ArithOverflow SInt64 where
  bvNegO     = l1 bvNegO
  bvAddO     = l2 bvAddO
  bvSubO     = l2 bvSubO
  bvMulO     = l2 bvMulO
  bvMulOFast = l2 bvMulOFast
  bvDivO     = l2 bvDivO
-- | Untyped instance: the signedness of the first argument chooses between
-- the unsigned and the signed checker, and its width is passed along as the
-- bit-size @n@ (see 'signPick1' / 'signPick2').
instance ArithOverflow SVal where
  bvNegO     a   = signPick1 bvunego     bvsnego     a
  bvAddO     a b = signPick2 bvuaddo     bvsaddo     a b
  bvSubO     a b = signPick2 bvusubo     bvssubo     a b
  bvMulO     a b = signPick2 bvumulo     bvsmulo     a b
  bvMulOFast a b = signPick2 bvumuloFast bvsmuloFast a b
  bvDivO     a b = signPick2 bvudivo     bvsdivo     a b
-- | Arithmetic operators that carry underflow/overflow checks. Each method
-- takes the caller's 'CallStack' as the implicit parameter @?loc@, so a
-- failing check can report where the operation was used.
class (ArithOverflow (SBV a), Num a, SymWord a) => CheckedArithmetic a where
  infixl 6 +!, -!
  infixl 7 *!, /!

  -- | Checked addition, subtraction, multiplication and division.
  (+!), (-!), (*!), (/!) :: (?loc :: CallStack) => SBV a -> SBV a -> SBV a

  -- | Checked unary negation.
  negateChecked :: (?loc :: CallStack) => SBV a -> SBV a
-- Every instance pairs the plain symbolic operation with its matching
-- 'ArithOverflow' predicate; 'checkOp2' / 'checkOp1' wrap the result in
-- assertions that neither an underflow nor an overflow occurred.
instance CheckedArithmetic Word8 where
  x +! y          = checkOp2 ?loc "addition"       (+)    bvAddO x y
  x -! y          = checkOp2 ?loc "subtraction"    (-)    bvSubO x y
  x *! y          = checkOp2 ?loc "multiplication" (*)    bvMulO x y
  x /! y          = checkOp2 ?loc "division"       sDiv   bvDivO x y
  negateChecked x = checkOp1 ?loc "unary negation" negate bvNegO x

instance CheckedArithmetic Word16 where
  x +! y          = checkOp2 ?loc "addition"       (+)    bvAddO x y
  x -! y          = checkOp2 ?loc "subtraction"    (-)    bvSubO x y
  x *! y          = checkOp2 ?loc "multiplication" (*)    bvMulO x y
  x /! y          = checkOp2 ?loc "division"       sDiv   bvDivO x y
  negateChecked x = checkOp1 ?loc "unary negation" negate bvNegO x

instance CheckedArithmetic Word32 where
  x +! y          = checkOp2 ?loc "addition"       (+)    bvAddO x y
  x -! y          = checkOp2 ?loc "subtraction"    (-)    bvSubO x y
  x *! y          = checkOp2 ?loc "multiplication" (*)    bvMulO x y
  x /! y          = checkOp2 ?loc "division"       sDiv   bvDivO x y
  negateChecked x = checkOp1 ?loc "unary negation" negate bvNegO x

instance CheckedArithmetic Word64 where
  x +! y          = checkOp2 ?loc "addition"       (+)    bvAddO x y
  x -! y          = checkOp2 ?loc "subtraction"    (-)    bvSubO x y
  x *! y          = checkOp2 ?loc "multiplication" (*)    bvMulO x y
  x /! y          = checkOp2 ?loc "division"       sDiv   bvDivO x y
  negateChecked x = checkOp1 ?loc "unary negation" negate bvNegO x

instance CheckedArithmetic Int8 where
  x +! y          = checkOp2 ?loc "addition"       (+)    bvAddO x y
  x -! y          = checkOp2 ?loc "subtraction"    (-)    bvSubO x y
  x *! y          = checkOp2 ?loc "multiplication" (*)    bvMulO x y
  x /! y          = checkOp2 ?loc "division"       sDiv   bvDivO x y
  negateChecked x = checkOp1 ?loc "unary negation" negate bvNegO x

instance CheckedArithmetic Int16 where
  x +! y          = checkOp2 ?loc "addition"       (+)    bvAddO x y
  x -! y          = checkOp2 ?loc "subtraction"    (-)    bvSubO x y
  x *! y          = checkOp2 ?loc "multiplication" (*)    bvMulO x y
  x /! y          = checkOp2 ?loc "division"       sDiv   bvDivO x y
  negateChecked x = checkOp1 ?loc "unary negation" negate bvNegO x

instance CheckedArithmetic Int32 where
  x +! y          = checkOp2 ?loc "addition"       (+)    bvAddO x y
  x -! y          = checkOp2 ?loc "subtraction"    (-)    bvSubO x y
  x *! y          = checkOp2 ?loc "multiplication" (*)    bvMulO x y
  x /! y          = checkOp2 ?loc "division"       sDiv   bvDivO x y
  negateChecked x = checkOp1 ?loc "unary negation" negate bvNegO x

instance CheckedArithmetic Int64 where
  x +! y          = checkOp2 ?loc "addition"       (+)    bvAddO x y
  x -! y          = checkOp2 ?loc "subtraction"    (-)    bvSubO x y
  x *! y          = checkOp2 ?loc "multiplication" (*)    bvMulO x y
  x /! y          = checkOp2 ?loc "division"       sDiv   bvDivO x y
  negateChecked x = checkOp1 ?loc "unary negation" negate bvNegO x
-- | Zero-extend @a@ to @n@ bits: a block of @n - w@ zero bits (with @w@ the
-- current width of @a@) is joined in front of @a@. The padding is created
-- with the same signedness as @a@. Calls 'error' when @n < w@.
zx :: Int -> SVal -> SVal
zx n a =
  if n < width
    then error $ "Data.SBV: Unexpected zero extension: from: " ++ show (intSizeOf a) ++ " to: " ++ show n
    else padding `svJoin` a
  where
    width   = intSizeOf a
    padding = svInteger (KBounded (hasSign a) (n - width)) 0
-- | Sign-extend @a@ to @n@ bits. The padding joined in front of @a@ is the
-- constant 0 when @a@'s top bit is clear and the constant -1 otherwise, both
-- built at the padding width with @a@'s signedness. Calls 'error' when @n@
-- is smaller than the current width.
sx :: Int -> SVal -> SVal
sx n a
  | n >= width = fill `svJoin` a
  | otherwise  = error $ "Data.SBV: Unexpected sign extension: from: " ++ show (intSizeOf a) ++ " to: " ++ show n
  where
    width    = intSizeOf a
    constant = svInteger (KBounded (hasSign a) (n - width))
    fill     = svIte (pos a) (constant 0) (constant (-1))
-- | The most significant bit of a bit-vector.
signBit :: SVal -> SVal
signBit x = svTestBit x msb
  where msb = intSizeOf x - 1
-- | Holds when the most significant bit is set.
neg :: SVal -> SVal
neg = (`svEqual` svTrue) . signBit
-- | Holds when the most significant bit is clear.
pos :: SVal -> SVal
pos = (`svEqual` svFalse) . signBit
-- | Holds when both arguments have the same most-significant bit.
sameSign :: SVal -> SVal -> SVal
sameSign x y = bothPos `svOr` bothNeg
  where
    bothPos = pos x `svAnd` pos y
    bothNeg = neg x `svAnd` neg y
-- | Negation of 'sameSign'.
diffSign :: SVal -> SVal -> SVal
diffSign x = svNot . sameSign x
-- | Symbolic conjunction of a list; the empty list yields 'svTrue'.
svAll :: [SVal] -> SVal
svAll []       = svTrue
svAll (b : bs) = b `svAnd` svAll bs
-- | Holds when every bit of @x@ in positions @m@ down to @n@ (inclusive) is
-- clear. Calls 'error' unless @0 <= n <= m < width@.
allZero :: Int -> Int -> SBV a -> SVal
allZero m n (SBV x)
  | invalid   = error $ "Data.SBV.Tools.Overflow.allZero: Received unexpected parameters: " ++ show (m, n, sz)
  | otherwise = svAll (map isClear [m, m-1 .. n])
  where
    sz        = intSizeOf x
    invalid   = m >= sz || n < 0 || m < n
    isClear i = svTestBit x i `svEqual` svFalse
-- | Holds when every bit of @x@ in positions @m@ down to @n@ (inclusive) is
-- set. Calls 'error' unless @0 <= n <= m < width@.
allOne :: Int -> Int -> SBV a -> SVal
allOne m n (SBV x)
  | invalid   = error $ "Data.SBV.Tools.Overflow.allOne: Received unexpected parameters: " ++ show (m, n, sz)
  | otherwise = svAll (map isSet [m, m-1 .. n])
  where
    sz      = intSizeOf x
    invalid = m >= sz || n < 0 || m < n
    isSet i = svTestBit x i `svEqual` svTrue
-- | Unsigned addition of @n@-bit values. It never underflows. It overflows
-- exactly when the sum, computed on operands zero-extended to @n+1@ bits,
-- has its top bit set.
bvuaddo :: Int -> SVal -> SVal -> (SVal, SVal)
bvuaddo n x y = (svFalse, neg wideSum)
  where
    widen   = zx (n + 1)
    wideSum = widen x `svPlus` widen y
-- | Signed addition. Two negatives summing to a non-negative result is an
-- underflow; two non-negatives summing to a negative result is an overflow.
bvsaddo :: Int -> SVal -> SVal -> (SVal, SVal)
bvsaddo _n x y = (tooSmall, tooLarge)
  where
    total    = x `svPlus` y
    tooSmall = svAll [neg x, neg y, pos total]
    tooLarge = svAll [pos x, pos y, neg total]
-- | Unsigned subtraction: it underflows exactly when the subtrahend exceeds
-- the minuend, and it never overflows.
bvusubo :: Int -> SVal -> SVal -> (SVal, SVal)
bvusubo _n x y = (y `svGreaterThan` x, svFalse)
-- | Signed subtraction. Subtracting a non-negative from a negative and
-- getting a non-negative result is an underflow; subtracting a negative
-- from a non-negative and getting a negative result is an overflow.
bvssubo :: Int -> SVal -> SVal -> (SVal, SVal)
bvssubo _n x y = (tooSmall, tooLarge)
  where
    difference = x `svMinus` y
    tooSmall   = svAll [neg x, pos y, pos difference]
    tooLarge   = svAll [pos x, neg y, neg difference]
-- | Unsigned multiplication of two @n@-bit values.
--
-- Underflow is impossible. Overflow is the disjunction of two checks:
--
--   * @overflow1@: the product of the operands zero-extended to @n+1@ bits
--     has its top bit (bit @n@) set.
--
--   * @overflow2@: for some @i@ in @[1, n-1]@, bit @i@ of @y@ is set while
--     some bit of @x@ at a position @>= n-i@ is set, i.e. a single partial
--     product already lands at bit @n@ or above.
--
-- When @overflow2@ is false the true product fits in @n+1@ bits, so the
-- widened product used by @overflow1@ is exact.
bvumulo :: Int -> SVal -> SVal -> (SVal, SVal)
bvumulo 0 _ _ = (svFalse, svFalse)
bvumulo n x y = (underflow, overflow)
  where underflow = svFalse
        n1 = n+1
        overflow1 = signBit $ zx n1 x `svTimes` zx n1 y
        -- After step i, ovf' is "some bit of x in positions n-1 .. n-i is
        -- set"; v accumulates whether that ever coincided with bit i of y.
        overflow2 = go 1 svFalse svFalse
          where go i ovf v
                  | i >= n = v
                  | True = go (i+1) ovf' v'
                  where ovf' = ovf `svOr` (x `svTestBit` (n-i))
                        tmp = ovf' `svAnd` (y `svTestBit` i)
                        v' = tmp `svOr` v
        overflow = overflow1 `svOr` overflow2
-- | Signed multiplication of two @n@-bit values.
--
-- @overflowPossible@ flags a product that does not fit in @n@ signed bits.
-- That condition is reported as an overflow when the operands have the same
-- sign and as an underflow when their signs differ.
--
--   * @overflow1@: in the product of the operands sign-extended to @n+1@
--     bits, bits @n@ and @n-1@ differ.
--
--   * @overflow2@: for some @i@ in @[1, n-2]@, bit @i@ of @y@ differs from
--     @y@'s sign bit while, for some @1 <= k <= i@, bit @n-1-k@ of @x@
--     differs from @x@'s sign bit.
bvsmulo :: Int -> SVal -> SVal -> (SVal, SVal)
bvsmulo 0 _ _ = (svFalse, svFalse)
bvsmulo n x y = (underflow, overflow)
  where underflow = diffSign x y `svAnd` overflowPossible
        overflow = sameSign x y `svAnd` overflowPossible
        n1 = n+1
        overflow1 = (xy1 `svTestBit` n) `svXOr` (xy1 `svTestBit` (n-1))
          where xy1 = sx n1 x `svTimes` sx n1 y
        -- a_acc' after step i: "some bit of x in positions n-2 .. n-1-i
        -- differs from x's sign bit"; v records whether that ever coincided
        -- with bit i of y differing from y's sign bit.
        overflow2 = go 1 svFalse svFalse
          where sY = signBit y
                sX = signBit x
                go i v a_acc
                  | i + 1 >= n = v
                  | True = go (i+1) v' a_acc'
                  where b = sY `svXOr` (y `svTestBit` i)
                        a = sX `svXOr` (x `svTestBit` (n-1-i))
                        a_acc' = a `svOr` a_acc
                        tmp = a_acc' `svAnd` b
                        v' = tmp `svOr` v
        overflowPossible = overflow1 `svOr` overflow2
-- | True when the value is a concrete constant (a 'Left' payload) rather
-- than a symbolic term.
known :: SVal -> Bool
known (SVal _ payload) = case payload of
  Left  _ -> True
  Right _ -> False
-- | Unsigned multiplication check. When both operands are constants it
-- falls back to 'bvumulo'. Otherwise the underflow part still comes from
-- 'bvumulo', and the overflow part is built with 'svMkOverflow'.
bvumuloFast :: Int -> SVal -> SVal -> (SVal, SVal)
bvumuloFast n x y
  | bothConcrete = bvumulo n x y
  | otherwise    = (fst (bvumulo n x y), svMkOverflow Overflow_UMul_OVFL x y)
  where bothConcrete = known x && known y
-- | Signed multiplication check. When both operands are constants it falls
-- back to 'bvsmulo'. Otherwise both flags are built with 'svMkOverflow'.
bvsmuloFast :: Int -> SVal -> SVal -> (SVal, SVal)
bvsmuloFast n x y
  | known x && known y = bvsmulo n x y
  | otherwise          = (flagged Overflow_SMul_UDFL, flagged Overflow_SMul_OVFL)
  where flagged which = svMkOverflow which x y
-- | Unsigned division: reports neither underflow nor overflow.
bvudivo :: Int -> SVal -> SVal -> (SVal, SVal)
bvudivo _ _ _ = (svFalse, svFalse)
-- | Signed division never underflows. It overflows only when the dividend
-- equals the @n@-bit signed constant built from @2^(n-1)@ and the divisor
-- is -1.
bvsdivo :: Int -> SVal -> SVal -> (SVal, SVal)
bvsdivo n x y = (svFalse, svAll [dividendIsTop, divisorIsMinusOne])
  where
    signedConst       = svInteger (KBounded True n)
    dividendIsTop     = x `svEqual` signedConst (2^(n-1))
    divisorIsMinusOne = y `svEqual` signedConst (-1)
-- | Unsigned negation: reports neither underflow nor overflow.
--
-- NOTE(review): 'bvusubo' flags @0 - y@ as an underflow for any nonzero
-- @y@, but negating a nonzero unsigned value is not flagged here, so
-- 'negateChecked' on a word type never asserts. Confirm this modular
-- treatment is intended.
bvunego :: Int -> SVal -> (SVal, SVal)
bvunego _ _ = (svFalse, svFalse)
-- | Signed negation never underflows. It overflows only when the argument
-- equals the @n@-bit signed constant built from @2^(n-1)@.
bvsnego :: Int -> SVal -> (SVal, SVal)
bvsnego n x = (svFalse, isTop)
  where isTop = x `svEqual` svInteger (KBounded True n) (2^(n-1))
-- | Convert between integral symbolic types with 'sFromIntegral'. Alongside
-- the result, return @(underflow, overflow)@ flags stating whether the value
-- of @x@ lies below / above the range of the target type.
--
-- Bounded-to-bounded conversions are checked on the bits of @x@ (helpers
-- @u2u@, @u2s@, @s2u@, @s2s@; @n@ is the source width, @m@ the target
-- width). Unbounded-to-bounded conversions compare @x@, as an 'SInteger',
-- against the target's bounds. Conversions to an unbounded target never
-- flag. Any other kind combination is an 'error'.
sFromIntegralO :: forall a b. (Integral a, HasKind a, Num a, SymWord a, HasKind b, Num b, SymWord b) => SBV a -> (SBV b, (SBool, SBool))
sFromIntegralO x = case (kindOf x, kindOf (undefined :: b)) of
    (KBounded False n, KBounded False m) -> (res, u2u n m)
    (KBounded False n, KBounded True m) -> (res, u2s n m)
    (KBounded True n, KBounded False m) -> (res, s2u n m)
    (KBounded True n, KBounded True m) -> (res, s2s n m)
    (KUnbounded, KBounded s m) -> (res, checkBounds s m)
    (KBounded{}, KUnbounded) -> (res, (false, false))
    (KUnbounded, KUnbounded) -> (res, (false, false))
    (kFrom, kTo) -> error $ "sFromIntegralO: Expected bounded-BV types, received: " ++ show (kFrom, kTo)
  where
    -- The converted value; the checks below only inspect x itself.
    res :: SBV b
    res = sFromIntegral x

    -- Range check via SInteger: a signed target of width sz spans
    -- [-2^(sz-1), 2^(sz-1)-1], an unsigned one [0, 2^sz-1].
    checkBounds :: Bool -> Int -> (SBool, SBool)
    checkBounds signed sz = (ix .< literal lb, ix .> literal ub)
      where ix :: SInteger
            ix = sFromIntegral x
            s :: Integer
            s = fromIntegral sz
            ub :: Integer
            ub | signed = 2^(s - 1) - 1
               | True = 2^s - 1
            lb :: Integer
            lb | signed = -ub-1
               | True = 0

    -- unsigned n -> unsigned m: overflow iff some bit at position >= m is set.
    u2u :: Int -> Int -> (SBool, SBool)
    u2u n m = (underflow, overflow)
      where underflow = false
            overflow
              | n <= m = false
              | True = SBV $ svNot $ allZero (n-1) m x

    -- unsigned n -> signed m: overflow iff some bit at position >= m-1 is
    -- set (the target's top bit is its sign bit).
    u2s :: Int -> Int -> (SBool, SBool)
    u2s n m = (underflow, overflow)
      where underflow = false
            overflow
              | m > n = false
              | True = SBV $ svNot $ allZero (n-1) (m-1) x

    -- signed n -> unsigned m: underflow iff x is negative; overflow iff x is
    -- non-negative and some bit at position >= m is set.
    s2u :: Int -> Int -> (SBool, SBool)
    s2u n m = (underflow, overflow)
      where underflow = SBV $ (unSBV x `svTestBit` (n-1)) `svEqual` svTrue
            overflow
              | m >= n - 1 = false
              | True = SBV $ svAll [(unSBV x `svTestBit` (n-1)) `svEqual` svFalse, svNot $ allZero (n-1) m x]

    -- signed n -> signed m: x fits iff bits n-1 .. m-1 are all equal; a
    -- negative x that does not fit underflows, a non-negative one overflows.
    s2s :: Int -> Int -> (SBool, SBool)
    s2s n m = (underflow, overflow)
      where underflow
              | m > n = false
              | True = SBV $ svAll [(unSBV x `svTestBit` (n-1)) `svEqual` svTrue, svNot $ allOne (n-1) (m-1) x]
            overflow
              | m > n = false
              | True = SBV $ svAll [(unSBV x `svTestBit` (n-1)) `svEqual` svFalse, svNot $ allZero (n-1) (m-1) x]
-- | Convert between integral symbolic types (via 'sFromIntegralO') and wrap
-- the result in two 'sAssert' checks, tagged with the caller's location
-- @?loc@: one that the conversion does not underflow and one that it does
-- not overflow. The messages name the source and target kinds.
--
-- (The constraint list previously mentioned @HasKind b@ twice; the
-- redundant duplicate has been dropped, so the required constraints are
-- unchanged.)
sFromIntegralChecked :: forall a b. (?loc :: CallStack, Integral a, HasKind a, HasKind b, Num a, SymWord a, Num b, SymWord b) => SBV a -> SBV b
sFromIntegralChecked x = sAssert (Just ?loc) (msg "underflows") (bnot u)
                       $ sAssert (Just ?loc) (msg "overflows") (bnot o)
                         r
  where
    kFrom = show $ kindOf x
    kTo = show $ kindOf (undefined :: b)
    msg c = "Casting from " ++ kFrom ++ " to " ++ kTo ++ " " ++ c
    (r, (u, o)) = sFromIntegralO x
-- | Lift a binary 'SVal' check to 'SBV' arguments by unwrapping both.
l2 :: (SVal -> SVal -> (SBool, SBool)) -> SBV a -> SBV a -> (SBool, SBool)
l2 f = \(SBV lhs) (SBV rhs) -> f lhs rhs
-- | Lift a unary 'SVal' check to an 'SBV' argument by unwrapping it.
l1 :: (SVal -> (SBool, SBool)) -> SBV a -> (SBool, SBool)
l1 f = \(SBV v) -> f v
-- | Run the signed checker @fs@ when @a@ is signed, otherwise the unsigned
-- checker @fu@, passing @a@'s width as the bit-size. The two raw 'SVal'
-- flags are wrapped as 'SBool's.
signPick2 :: (Int -> SVal -> SVal -> (SVal, SVal)) -> (Int -> SVal -> SVal -> (SVal, SVal)) -> (SVal -> SVal -> (SBool, SBool))
signPick2 fu fs a b = case hasSign a of
    True  -> wrap (fs width a b)
    False -> wrap (fu width a b)
  where
    width = intSizeOf a
    -- Lazy pattern: the checker's result pair is not forced here.
    wrap ~(u, o) = (SBV u, SBV o)
-- | Unary counterpart of 'signPick2': choose @fs@ for a signed argument,
-- @fu@ otherwise, and wrap the resulting flags as 'SBool's.
signPick1 :: (Int -> SVal -> (SVal, SVal)) -> (Int -> SVal -> (SVal, SVal)) -> (SVal -> (SBool, SBool))
signPick1 fu fs a = case hasSign a of
    True  -> wrap (fs width a)
    False -> wrap (fu width a)
  where
    width = intSizeOf a
    -- Lazy pattern: the checker's result pair is not forced here.
    wrap ~(u, o) = (SBV u, SBV o)
-- | Apply the unary operation @op@ to @a@ and wrap the result in two
-- 'sAssert's at @loc@: the underflow flag from @cop a@ must be false, and
-- so must the overflow flag. The messages read
-- @\"<kind of a> <w> underflows\"@ / @\"... overflows\"@.
checkOp1 :: HasKind a => CallStack -> String -> (a -> SBV b) -> (a -> (SBool, SBool)) -> a -> SBV b
checkOp1 loc w op cop a = assertNo "underflows" u
                        $ assertNo "overflows" o
                        $ op a
  where
    assertNo what bad = sAssert (Just loc) (msg what) (bnot bad)
    msg c  = show (kindOf a) ++ " " ++ w ++ " " ++ c
    (u, o) = cop a
-- | Binary counterpart of 'checkOp1': compute @op a b@ and assert at @loc@
-- that neither flag returned by @cop a b@ is set. The messages mention the
-- kind of the first operand and the operation name @w@.
checkOp2 :: HasKind a => CallStack -> String -> (a -> b -> SBV c) -> (a -> b -> (SBool, SBool)) -> a -> b -> SBV c
checkOp2 loc w op cop a b = assertNo "underflows" u
                          $ assertNo "overflows" o
                          $ op a b
  where
    assertNo what bad = sAssert (Just loc) (msg what) (bnot bad)
    msg c  = show (kindOf a) ++ " " ++ w ++ " " ++ c
    (u, o) = cop a b