{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

module Grisette.Internal.Core.Data.Class.SymShift
  ( SymShift (..),
    DefaultFiniteBitsSymShift (..),
  )
where

import Data.Bits (Bits (isSigned, shift, shiftR), FiniteBits (finiteBitSize))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)

-- | Shifting where the shift amount has the same type as the value being
-- shifted.
--
-- When the magnitude of the shift amount reaches or exceeds the bit width of
-- the value, the result is the same as shifting by exactly the bit width.
class (Bits a) => SymShift a where
  -- | @'symShift' x n@ shifts @x@ to the left when @n@ is positive, and to
  -- the right (by the magnitude of @n@) when @n@ is negative.
  symShift :: a -> a -> a

  -- | @'symShiftNegated' x n@ shifts @x@ to the right when @n@ is positive,
  -- and to the left (by the magnitude of @n@) when @n@ is negative.
  --
  -- This is a separate method because the range of values is asymmetric:
  -- the most negative amount of a signed type has no representable
  -- negation, so @'symShift' x ('negate' n)@ cannot stand in for it.
  symShiftNegated :: a -> a -> a

-- | 'SymShift' for the native-width 'Int'.
--
-- A shift amount whose magnitude is at least the bit width saturates. A left
-- shift that large gives @0@. An arithmetic right shift that large gives @0@
-- for non-negative values and @-1@ for negative ones.
instance SymShift Int where
  symShift :: Int -> Int -> Int
  symShift a s
    | s >= width = 0
    | s <= -width = if a >= 0 then 0 else -1
    | otherwise = shift a s
    where
      width = finiteBitSize a
  {-# INLINE symShift #-}

  symShiftNegated :: Int -> Int -> Int
  symShiftNegated a s
    -- This guard must run before the negation below: @negate minBound@ is
    -- @minBound@ again, so @symShift a (negate s)@ would sign-fill instead
    -- of clearing for the most negative amount.
    | s <= -finiteBitSize a = 0
    | otherwise = symShift a (negate s)
  {-# INLINE symShiftNegated #-}

-- | A newtype wrapper that supplies a default 'SymShift' implementation for
-- any type with 'Integral' and 'FiniteBits' instances. It is meant to be used
-- with @deriving via@.
--
-- 'Eq' and 'Bits' are inherited unchanged from the wrapped type.
newtype DefaultFiniteBitsSymShift a = DefaultFiniteBitsSymShift
  { unDefaultFiniteBitsSymShift :: a
  }
  deriving newtype (Eq, Bits)

-- | Generic 'SymShift' for any 'Integral' 'FiniteBits' type.
--
-- The shift amount has the same type as the shifted value.
--
-- * Signed types accept negative amounts and shift right arithmetically
--   (sign-filling).
-- * Unsigned types never see a negative amount, so only the saturation at
--   the bit width has to be handled.
instance
  (Integral a, FiniteBits a) =>
  SymShift (DefaultFiniteBitsSymShift a)
  where
  symShift ::
    DefaultFiniteBitsSymShift a ->
    DefaultFiniteBitsSymShift a ->
    DefaultFiniteBitsSymShift a
  symShift (DefaultFiniteBitsSymShift a) (DefaultFiniteBitsSymShift s)
    | isSigned a = DefaultFiniteBitsSymShift $ shiftSigned a s
    | otherwise = DefaultFiniteBitsSymShift $ shiftUnsigned a s
    where
      -- Left shift by a non-negative amount; the result is 0 once the
      -- amount reaches the bit width.
      shiftUnsigned :: (Integral b, FiniteBits b) => b -> b -> b
      shiftUnsigned x n
        | n >= fromIntegral (finiteBitSize x) = 0
        | otherwise = shift x (fromIntegral n)
      {-# INLINE shiftUnsigned #-}

      -- Left shift for positive amounts, arithmetic right shift for negative
      -- ones, saturating at the bit width.
      shiftSigned :: (Integral b, FiniteBits b) => b -> b -> b
      shiftSigned x n
        -- With at most 2 bits, the bit width is not representable as a
        -- signed value of the type (e.g. 2 wraps to -2), so the saturation
        -- guards below would misfire. Every representable amount then has
        -- magnitude <= width, and this relies on the type's 'shift' to
        -- saturate for an amount equal to the width.
        | finiteBitSize n <= 2 = shift x (fromIntegral n)
        | n >= fromIntegral (finiteBitSize x) = 0
        | n <= fromIntegral (negate (finiteBitSize x)) =
            if x < 0 then -1 else 0
        | otherwise = shift x (fromIntegral n)
      {-# INLINE shiftSigned #-}
  {-# INLINE symShift #-}

  symShiftNegated ::
    DefaultFiniteBitsSymShift a ->
    DefaultFiniteBitsSymShift a ->
    DefaultFiniteBitsSymShift a
  symShiftNegated (DefaultFiniteBitsSymShift a) (DefaultFiniteBitsSymShift s)
    | isSigned a = DefaultFiniteBitsSymShift $ shiftNegSigned a s
    | otherwise = DefaultFiniteBitsSymShift $ shiftNegUnsigned a s
    where
      -- Logical right shift by a non-negative amount; the result is 0 once
      -- the amount reaches the bit width.
      shiftNegUnsigned :: (Integral b, FiniteBits b) => b -> b -> b
      shiftNegUnsigned x n
        | n >= fromIntegral (finiteBitSize x) = 0
        | otherwise = shiftR x (fromIntegral n)
      {-# INLINE shiftNegUnsigned #-}

      -- Arithmetic right shift for positive amounts, left shift for negative
      -- ones, saturating at the bit width. The negation happens after
      -- converting to 'Int', and the most negative amounts are caught by
      -- the saturation guard before reaching it, so it cannot overflow.
      shiftNegSigned :: (Integral b, FiniteBits b) => b -> b -> b
      shiftNegSigned x n
        -- Same narrow-width concern as in 'symShift': the width itself
        -- does not fit in a signed value of at most 2 bits.
        | finiteBitSize n <= 2 = shift x (negate (fromIntegral n))
        | n <= fromIntegral (negate (finiteBitSize x)) = 0
        | n >= fromIntegral (finiteBitSize x) =
            if x < 0 then -1 else 0
        | otherwise = shift x (negate (fromIntegral n))
      {-# INLINE shiftNegSigned #-}
  {-# INLINE symShiftNegated #-}

-- Signed fixed-width integers use the generic implementation.
deriving via (DefaultFiniteBitsSymShift Int8)
  instance SymShift Int8
deriving via (DefaultFiniteBitsSymShift Int16)
  instance SymShift Int16
deriving via (DefaultFiniteBitsSymShift Int32)
  instance SymShift Int32
deriving via (DefaultFiniteBitsSymShift Int64)
  instance SymShift Int64

-- Unsigned fixed-width integers, including the native-width 'Word', also
-- use the generic implementation.
deriving via (DefaultFiniteBitsSymShift Word8)
  instance SymShift Word8
deriving via (DefaultFiniteBitsSymShift Word16)
  instance SymShift Word16
deriving via (DefaultFiniteBitsSymShift Word32)
  instance SymShift Word32
deriving via (DefaultFiniteBitsSymShift Word64)
  instance SymShift Word64
deriving via (DefaultFiniteBitsSymShift Word)
  instance SymShift Word