{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
module Data.BitVector.Sized.Overflow
( Overflow(..)
, UnsignedOverflow(..)
, SignedOverflow(..)
, ofUnsigned
, ofSigned
, ofResult
, shlOf
, addOf
, subOf
, mulOf
, squotOf
, sremOf
, sdivOf
, smodOf
) where
import qualified Data.Bits as B
import Numeric.Natural
import GHC.TypeLits
import Data.Parameterized ( NatRepr )
import qualified Data.Parameterized.NatRepr as P
import Data.BitVector.Sized.Internal ( BV(..)
, mkBV'
, asUnsigned
, asSigned
, shiftAmount
)
data UnsignedOverflow = UnsignedOverflow
| NoUnsignedOverflow
deriving (Show, Eq)
instance Semigroup UnsignedOverflow where
NoUnsignedOverflow <> NoUnsignedOverflow = NoUnsignedOverflow
_ <> _ = UnsignedOverflow
instance Monoid UnsignedOverflow where
mempty = NoUnsignedOverflow
data SignedOverflow = SignedOverflow
| NoSignedOverflow
deriving (Show, Eq)
instance Semigroup SignedOverflow where
NoSignedOverflow <> NoSignedOverflow = NoSignedOverflow
_ <> _ = SignedOverflow
instance Monoid SignedOverflow where
mempty = NoSignedOverflow
data Overflow a =
Overflow UnsignedOverflow SignedOverflow a
deriving (Show, Eq)
ofUnsigned :: Overflow a -> Bool
ofUnsigned (Overflow UnsignedOverflow _ _) = True
ofUnsigned _ = False
ofSigned :: Overflow a -> Bool
ofSigned (Overflow _ SignedOverflow _) = True
ofSigned _ = False
ofResult :: Overflow a -> a
ofResult (Overflow _ _ res) = res
instance Foldable Overflow where
foldMap f (Overflow _ _ a) = f a
instance Traversable Overflow where
sequenceA (Overflow uof sof a) = Overflow uof sof <$> a
instance Functor Overflow where
fmap f (Overflow uof sof a) = Overflow uof sof (f a)
instance Applicative Overflow where
pure a = Overflow mempty mempty a
Overflow uof sof f <*> Overflow uof' sof' a =
Overflow (uof <> uof') (sof <> sof') (f a)
instance Monad Overflow where
Overflow uof sof a >>= k =
let Overflow uof' sof' b = k a
in Overflow (uof <> uof') (sof <> sof') b
getUof :: NatRepr w -> Integer -> UnsignedOverflow
getUof w x = if x < P.minUnsigned w || x > P.maxUnsigned w
then UnsignedOverflow
else NoUnsignedOverflow
getSof :: NatRepr w -> Integer -> SignedOverflow
getSof w x = case P.isZeroOrGT1 w of
Left P.Refl -> NoSignedOverflow
Right P.LeqProof ->
if x < P.minSigned w || x > P.maxSigned w
then SignedOverflow
else NoSignedOverflow
liftBinary :: (1 <= w) => (Integer -> Integer -> Integer)
-> NatRepr w
-> BV w -> BV w -> Overflow (BV w)
liftBinary op w xv yv =
let ux = asUnsigned xv
uy = asUnsigned yv
sx = asSigned w xv
sy = asSigned w yv
ures = ux `op` uy
sres = sx `op` sy
uof = getUof w ures
sof = getSof w sres
in Overflow uof sof (mkBV' w ures)
addOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
addOf = liftBinary (+)
subOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
subOf = liftBinary (-)
mulOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
mulOf = liftBinary (*)
shlOf :: (1 <= w) => NatRepr w -> BV w -> Natural -> Overflow (BV w)
shlOf w xv shf =
let ux = asUnsigned xv
sx = asSigned w xv
ures = ux `B.shiftL` shiftAmount w shf
sres = sx `B.shiftL` shiftAmount w shf
uof = getUof w ures
sof = getSof w sres
in Overflow uof sof (mkBV' w ures)
squotOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
squotOf w bv1 bv2 = Overflow NoUnsignedOverflow sof (mkBV' w (x `quot` y))
where x = asSigned w bv1
y = asSigned w bv2
sof = if (x == P.minSigned w && y == -1)
then SignedOverflow
else NoSignedOverflow
sremOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
sremOf w bv1 bv2 = Overflow NoUnsignedOverflow sof (mkBV' w (x `rem` y))
where x = asSigned w bv1
y = asSigned w bv2
sof = if (x == P.minSigned w && y == -1)
then SignedOverflow
else NoSignedOverflow
sdivOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
sdivOf w bv1 bv2 = Overflow NoUnsignedOverflow sof (mkBV' w (x `div` y))
where x = asSigned w bv1
y = asSigned w bv2
sof = if (x == P.minSigned w && y == -1)
then SignedOverflow
else NoSignedOverflow
smodOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
smodOf w bv1 bv2 = Overflow NoUnsignedOverflow sof (mkBV' w (x `mod` y))
where x = asSigned w bv1
y = asSigned w bv2
sof = if (x == P.minSigned w && y == -1)
then SignedOverflow
else NoSignedOverflow