{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Grisette.Core.Data.Class.SafeSymShift
  ( SafeSymShift (..),
  )
where

import Control.Exception (ArithException (Overflow))
import Control.Monad.Error.Class (MonadError)
import Data.Bits (Bits (shiftL, shiftR), FiniteBits (finiteBitSize))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeLits (KnownNat, type (<=))
import Grisette.Core.Control.Monad.Union (MonadUnion)
import Grisette.Core.Data.BV (IntN, WordN)
import Grisette.Core.Data.Class.LogicalOp
  ( LogicalOp ((.&&), (.||)),
  )
import Grisette.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Core.Data.Class.SOrd
  ( SOrd ((.<), (.>=)),
  )
import Grisette.Core.Data.Class.SimpleMergeable
  ( UnionLike,
    mrgIf,
  )
import Grisette.Core.Data.Class.SymShift (SymShift)
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits
  ( pevalShiftLeftTerm,
    pevalShiftRightTerm,
  )
import Grisette.IR.SymPrim.Data.SymPrim (SymIntN (SymIntN), SymWordN (SymWordN))
import Grisette.Lib.Control.Monad (mrgReturn)
import Grisette.Lib.Control.Monad.Except (mrgThrowError)

-- | Safe versions of `shiftL` and `shiftR`.
--
-- The `safeSymShiftL` and `safeSymShiftR` family (including the primed
-- variants) accepts every non-negative shift amount:
--
-- * Shifting by a negative amount is an error.
-- * Shifting left by an amount greater than or equal to the bit size of the
-- number yields 0.
-- * Shifting right by an amount greater than or equal to the bit size of the
-- number yields 0 when the number is unsigned or signed non-negative.
-- * Shifting right by an amount greater than or equal to the bit size of the
-- number yields -1 when the number is signed negative.
--
-- The `safeSymStrictShiftL` and `safeSymStrictShiftR` family (including the
-- primed variants) only accepts non-negative shift amounts that are less than
-- the bit size. Shifting by an amount greater than or equal to the bit size is
-- an error; otherwise they behave exactly like the non-strict versions.
--
-- Each primed variant takes an extra function that maps the error of type @e@
-- into the error type of the monad. By default, each unprimed variant is the
-- corresponding primed variant applied to `id`.
class (SymShift a) => SafeSymShift e a | a -> e where
  -- | Non-strict left shift that reports errors with the class error type.
  safeSymShiftL :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymShiftL = safeSymShiftL' id

  -- | Non-strict left shift; errors are mapped with the given function.
  safeSymShiftL' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  -- | Non-strict right shift that reports errors with the class error type.
  safeSymShiftR :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymShiftR = safeSymShiftR' id

  -- | Non-strict right shift; errors are mapped with the given function.
  safeSymShiftR' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  -- | Strict left shift that reports errors with the class error type.
  safeSymStrictShiftL :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymStrictShiftL = safeSymStrictShiftL' id

  -- | Strict left shift; errors are mapped with the given function.
  safeSymStrictShiftL' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  -- | Strict right shift that reports errors with the class error type.
  safeSymStrictShiftR :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymStrictShiftR = safeSymStrictShiftR' id

  -- | Strict right shift; errors are mapped with the given function.
  safeSymStrictShiftR' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  {-# MINIMAL
    safeSymShiftL',
    safeSymShiftR',
    safeSymStrictShiftL',
    safeSymStrictShiftR'
    #-}

-- | Safely shift a concrete number left by a concrete amount.
--
-- Shift amounts that are out of the range of `Int` are handled correctly: the
-- amount is compared against the bit size as an `Integer` before it is
-- narrowed to `Int` for `shiftL`.
--
-- * A negative amount throws the given error.
-- * An amount greater than or equal to the bit size returns 0 when the flag is
--   `True`, and throws the given error when it is `False`.
-- * Otherwise the result is the ordinary `shiftL`.
safeSymShiftLConcreteNum ::
  (MonadError e m, MonadUnion m, Integral a, FiniteBits a, Mergeable a) =>
  e ->
  Bool ->
  a ->
  a ->
  m a
safeSymShiftLConcreteNum err allowLarge value amount
  | amount < 0 = mrgThrowError err
  | amountTooLarge =
      if allowLarge then mrgReturn 0 else mrgThrowError err
  | otherwise = mrgReturn (value `shiftL` fromIntegral amount)
  where
    -- Compared in Integer so that amounts wider than Int are not truncated.
    amountTooLarge = toInteger amount >= toInteger (finiteBitSize value)

-- | Safely shift a concrete number right by a concrete amount.
--
-- Shift amounts that are out of the range of `Int` are handled correctly: the
-- amount is compared against the bit size as an `Integer` before it is
-- narrowed to `Int` for `shiftR`.
--
-- * A negative amount throws the given error.
-- * An amount greater than or equal to the bit size throws the given error
--   when the flag is `False`. When the flag is `True`, it returns the
--   saturated result: @-1@ for a negative number and @0@ otherwise. This
--   matches the contract documented on `SafeSymShift` and the arithmetic
--   behavior of `shiftR`.
-- * Otherwise the result is the ordinary `shiftR`.
safeSymShiftRConcreteNum ::
  (MonadError e m, MonadUnion m, Integral a, FiniteBits a, Mergeable a) =>
  -- | Error thrown for an invalid shift amount.
  e ->
  -- | Whether amounts greater than or equal to the bit size are allowed.
  Bool ->
  -- | The number to shift.
  a ->
  -- | The shift amount.
  a ->
  m a
safeSymShiftRConcreteNum e _ _ s | s < 0 = mrgThrowError e
safeSymShiftRConcreteNum e allowLargeShiftAmount a s
  | (fromIntegral s :: Integer) >= fromIntegral (finiteBitSize a) =
      if allowLargeShiftAmount
        then -- Every bit has been shifted out, so only sign bits remain.
          mrgReturn $ if a < 0 then -1 else 0
        else mrgThrowError e
safeSymShiftRConcreteNum _ _ a s = mrgReturn $ shiftR a (fromIntegral s)

-- Generates the instance for a fixed-width machine integer type T. All four
-- methods delegate to the concrete helpers. The Bool flag selects whether an
-- amount at or beyond the bit size yields a value (True, relaxed shifts) or
-- the mapped Overflow error (False, strict shifts).
#define SAFE_SYM_SHIFT_CONCRETE(T) \
  instance SafeSymShift ArithException T where \
    safeSymShiftL' toErr = safeSymShiftLConcreteNum (toErr Overflow) True; \
    safeSymStrictShiftL' toErr = safeSymShiftLConcreteNum (toErr Overflow) False; \
    safeSymShiftR' toErr = safeSymShiftRConcreteNum (toErr Overflow) True; \
    safeSymStrictShiftR' toErr = safeSymShiftRConcreteNum (toErr Overflow) False

-- Instances for the signed and unsigned fixed-width integer types from base.
#if 1
SAFE_SYM_SHIFT_CONCRETE(Int8)
SAFE_SYM_SHIFT_CONCRETE(Int16)
SAFE_SYM_SHIFT_CONCRETE(Int32)
SAFE_SYM_SHIFT_CONCRETE(Int64)
SAFE_SYM_SHIFT_CONCRETE(Int)
SAFE_SYM_SHIFT_CONCRETE(Word8)
SAFE_SYM_SHIFT_CONCRETE(Word16)
SAFE_SYM_SHIFT_CONCRETE(Word32)
SAFE_SYM_SHIFT_CONCRETE(Word64)
SAFE_SYM_SHIFT_CONCRETE(Word)
#endif

-- | Concrete unsigned bit-vectors. This instance delegates to the concrete
-- helpers with the same flags as the fixed-width base types: `True` for the
-- relaxed shifts and `False` for the strict ones. Every failure is reported
-- as the mapped `Overflow` error.
instance (KnownNat n, 1 <= n) => SafeSymShift ArithException (WordN n) where
  safeSymShiftL' toErr x amt =
    safeSymShiftLConcreteNum (toErr Overflow) True x amt
  safeSymStrictShiftL' toErr x amt =
    safeSymShiftLConcreteNum (toErr Overflow) False x amt
  safeSymShiftR' toErr x amt =
    safeSymShiftRConcreteNum (toErr Overflow) True x amt
  safeSymStrictShiftR' toErr x amt =
    safeSymShiftRConcreteNum (toErr Overflow) False x amt

-- | Concrete signed bit-vectors. This instance delegates to the concrete
-- helpers with the same flags as the fixed-width base types: `True` for the
-- relaxed shifts and `False` for the strict ones. Every failure, including a
-- negative shift amount, is reported as the mapped `Overflow` error.
instance (KnownNat n, 1 <= n) => SafeSymShift ArithException (IntN n) where
  safeSymShiftL' toErr x amt =
    safeSymShiftLConcreteNum (toErr Overflow) True x amt
  safeSymStrictShiftL' toErr x amt =
    safeSymShiftLConcreteNum (toErr Overflow) False x amt
  safeSymShiftR' toErr x amt =
    safeSymShiftRConcreteNum (toErr Overflow) True x amt
  safeSymStrictShiftR' toErr x amt =
    safeSymShiftRConcreteNum (toErr Overflow) False x amt

-- | Symbolic unsigned bit-vectors.
--
-- An unsigned shift amount can never be negative, so the relaxed shifts never
-- fail and simply build the symbolic shift term. The strict shifts reject any
-- amount greater than or equal to the bit width. Converting the bit width @n@
-- into @SymWordN n@ is exact because @n < 2^n@ for every @n >= 1@.
--
-- Results are produced with `mrgReturn` rather than a bare `return`, so they
-- are merged. This makes the instance consistent with the concrete instances,
-- which use `mrgReturn`, and with the `SymIntN` instance, whose results pass
-- through `mrgIf`.
instance (KnownNat n, 1 <= n) => SafeSymShift ArithException (SymWordN n) where
  safeSymShiftL' _ (SymWordN a) (SymWordN s) =
    mrgReturn $ SymWordN $ pevalShiftLeftTerm a s
  safeSymShiftR' _ (SymWordN a) (SymWordN s) =
    mrgReturn $ SymWordN $ pevalShiftRightTerm a s
  safeSymStrictShiftL' f a@(SymWordN ta) s@(SymWordN ts) =
    mrgIf
      (s .>= fromIntegral (finiteBitSize a))
      (mrgThrowError $ f Overflow)
      (mrgReturn $ SymWordN $ pevalShiftLeftTerm ta ts)
  safeSymStrictShiftR' f a@(SymWordN ta) s@(SymWordN ts) =
    mrgIf
      (s .>= fromIntegral (finiteBitSize a))
      (mrgThrowError $ f Overflow)
      (mrgReturn $ SymWordN $ pevalShiftRightTerm ta ts)

-- | Symbolic signed bit-vectors.
--
-- The relaxed shifts reject only negative amounts. For amounts at or beyond
-- the bit width, the result is whatever the symbolic shift term builder
-- produces. The strict shifts also reject amounts greater than or equal to
-- the bit width.
instance (KnownNat n, 1 <= n) => SafeSymShift ArithException (SymIntN n) where
  safeSymShiftL' toErr (SymIntN numTerm) amt@(SymIntN amtTerm) =
    mrgIf
      (amt .< 0)
      (mrgThrowError (toErr Overflow))
      (return (SymIntN (pevalShiftLeftTerm numTerm amtTerm)))
  safeSymShiftR' toErr (SymIntN numTerm) amt@(SymIntN amtTerm) =
    mrgIf
      (amt .< 0)
      (mrgThrowError (toErr Overflow))
      (return (SymIntN (pevalShiftRightTerm numTerm amtTerm)))
  -- The width check is guarded by `width .>= 0`. For n = 1 or n = 2, the bit
  -- width does not fit in IntN n and presumably wraps to a negative value, yet
  -- every non-negative amount representable in such a type is already less
  -- than n.
  safeSymStrictShiftL' toErr num@(SymIntN numTerm) amt@(SymIntN amtTerm) =
    let width = fromIntegral (finiteBitSize num)
     in mrgIf
          (amt .< 0 .|| (width .>= 0 .&& amt .>= width))
          (mrgThrowError (toErr Overflow))
          (return (SymIntN (pevalShiftLeftTerm numTerm amtTerm)))
  safeSymStrictShiftR' toErr num@(SymIntN numTerm) amt@(SymIntN amtTerm) =
    let width = fromIntegral (finiteBitSize num)
     in mrgIf
          (amt .< 0 .|| (width .>= 0 .&& amt .>= width))
          (mrgThrowError (toErr Overflow))
          (return (SymIntN (pevalShiftRightTerm numTerm amtTerm)))