{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module What4.Utils.Arithmetic
(
isPow2
, lg
, lgCeil
, nextMultiple
, nextPow2Multiple
, tryIntSqrt
, tryRationalSqrt
, roundAway
, ctz
, clz
, rotateLeft
, rotateRight
) where
import Control.Exception (assert)
import Data.Bits (Bits(..))
import Data.Ratio
import Data.Parameterized.NatRepr
-- | Test whether a number has at most one bit set.
--
-- The expression @x .&. (x - 1)@ clears the lowest set bit of @x@, so the
-- result is zero exactly when @x@ had a single set bit or none at all.
-- Consequently every positive power of two yields 'True', and so does @0@.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = x .&. (x - 1) == 0
-- | Floor of the base-2 logarithm of a strictly positive number.
--
-- Calls 'error' when the argument is zero or negative.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg i0
  | i0 > 0 = go 0 (i0 `shiftR` 1)
  | otherwise = error "lg given number that is not positive."
  where
    -- Count how many further one-bit right shifts are needed to reach zero.
    -- The counter is forced on every step so no chain of thunks builds up.
    go !r 0 = r
    go !r n = go (r + 1) (n `shiftR` 1)
-- | Ceiling of the base-2 logarithm.
--
-- By definition @lgCeil 0 = 0@ and @lgCeil 1 = 0@. For larger arguments the
-- result is @1 + lg (i - 1)@. Calls 'error' when the argument is negative.
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil 0 = 0
lgCeil 1 = 0
lgCeil i
  | i > 1 = 1 + lg (i - 1)
  | otherwise = error "lgCeil given number that is not positive."
-- | Count trailing zeros.
--
-- Scans bit positions @0, 1, 2, ...@ of the value and returns the index of
-- the first set bit. If none of the low @w@ bits is set, the result is the
-- width itself.
ctz ::
  NatRepr w -- ^ width in bits
  -> Integer -- ^ value to inspect
  -> Integer
ctz w x = go 0
  where
    go :: Integer -> Integer
    go !i
      | i < toInteger (natValue w) && not (testBit x (fromInteger i)) = go (i + 1)
      | otherwise = i
-- | Count leading zeros.
--
-- Scans bit positions @w - 1, w - 2, ...@ of the value and returns how many
-- were clear before the first set bit. If none of the low @w@ bits is set,
-- the result is the width itself.
clz ::
  NatRepr w -- ^ width in bits
  -> Integer -- ^ value to inspect
  -> Integer
clz w x = go 0
  where
    go :: Integer -> Integer
    go !i
      | i < toInteger (natValue w) && not (testBit x bitIndex) = go (i + 1)
      | otherwise = i
      where
        -- Position examined on this step, counted down from the top bit.
        bitIndex :: Int
        bitIndex = widthVal w - fromInteger i - 1
-- | Rotate a value to the right within a fixed bit width.
--
-- The value is first passed through 'toUnsigned' (presumably reducing it to
-- its unsigned @w@-bit representation -- confirm against
-- "Data.Parameterized.NatRepr"). The rotation amount is reduced with 'rem'
-- by the width. The result combines the bits shifted down by the amount with
-- the bits shifted up by @width - amount@, the latter again passed through
-- 'toUnsigned'.
--
-- NOTE(review): because 'rem' is used, a negative amount yields a negative
-- shift count; callers are presumably expected to pass a non-negative amount.
rotateRight ::
  NatRepr w -- ^ width
  -> Integer -- ^ value to rotate
  -> Integer -- ^ amount to rotate
  -> Integer
rotateRight w x n =
  xor (shiftR x' n') (toUnsigned w (shiftL x' (widthVal w - n')))
  where
    x' :: Integer
    x' = toUnsigned w x

    n' :: Int
    n' = fromInteger (n `rem` intValue w)
-- | Rotate a value to the left within a fixed bit width.
--
-- The value is first passed through 'toUnsigned' (presumably reducing it to
-- its unsigned @w@-bit representation -- confirm against
-- "Data.Parameterized.NatRepr"). The rotation amount is reduced with 'rem'
-- by the width. The result combines the bits shifted down by
-- @width - amount@ with the bits shifted up by the amount, the latter again
-- passed through 'toUnsigned'.
--
-- NOTE(review): because 'rem' is used, a negative amount yields a negative
-- shift count; callers are presumably expected to pass a non-negative amount.
rotateLeft ::
  NatRepr w -- ^ width
  -> Integer -- ^ value to rotate
  -> Integer -- ^ amount to rotate
  -> Integer
rotateLeft w x n =
  xor (shiftR x' (widthVal w - n')) (toUnsigned w (shiftL x' n'))
  where
    x' :: Integer
    x' = toUnsigned w x

    n' :: Int
    n' = fromInteger (n `rem` intValue w)
-- | @nextMultiple x y@ rounds @y@ up to a multiple of @x@.
--
-- Computed as @((y + x - 1) `div` x) * x@, so for positive @x@ the result is
-- the least multiple of @x@ that is not smaller than @y@. A zero @x@ makes
-- the underlying 'div' fail.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = ((y + x - 1) `div` x) * x
-- | @nextPow2Multiple x n@ rounds @x@ up to a multiple of @2^n@.
--
-- Adds @2^n - 1@ and then clears the low @n@ bits by shifting right and back
-- left. Both arguments must be non-negative; otherwise 'error' is called.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n
  | x >= 0 && n >= 0 = ((x + 2 ^ n - 1) `shiftR` n) `shiftL` n
  | otherwise = error "nextPow2Multiple given negative value."
-- | Exact integer square root.
--
-- Returns @Just s@ when the argument equals @s * s@ for some non-negative
-- integer @s@, and 'Nothing' otherwise. Negative arguments have no integer
-- square root and give 'Nothing'.
tryIntSqrt :: Integer -> Maybe Integer
tryIntSqrt n
  | n < 0 = Nothing
  | n < 2 = Just n -- 0 and 1 are their own square roots
  | n < 4 = Nothing -- 2 and 3 are not perfect squares
  | otherwise = assert (n >= 4) $ go (n `shiftR` 1)
  where
    -- Newton iteration starting from n / 2, which is at least the square
    -- root for every n >= 4. Each step replaces the guess by the floor of
    -- the average of the guess and n divided by the guess.
    go :: Integer -> Maybe Integer
    go x
      | x2 < n = Nothing -- the guess dropped below the root: not a square
      | x2 == n = Just x' -- exact root found
      | otherwise = go x' -- guess is still too large, refine again
      where
        x' :: Integer
        x' = (x + n `div` x) `div` 2

        x2 :: Integer
        x2 = x' * x'
-- | Exact rational square root.
--
-- Applies 'tryIntSqrt' to the numerator and to the denominator and rebuilds
-- the quotient with '%'. The result is 'Nothing' as soon as either component
-- has no exact integer square root.
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r =
  (%) <$> tryIntSqrt (numerator r)
      <*> tryIntSqrt (denominator r)
-- | Round to the nearest integer, with ties going away from zero.
--
-- Implemented by moving the argument half a unit further from zero (in the
-- direction given by 'signum') and then truncating toward zero.
roundAway :: (RealFrac a) => a -> Integer
roundAway r = truncate (r + signum r * 0.5)