{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module What4.Utils.Arithmetic
(
isPow2
, lg
, lgCeil
, nextMultiple
, nextPow2Multiple
, tryIntSqrt
, tryRationalSqrt
, roundAway
, ctz
, clz
, rotateLeft
, rotateRight
) where
import Control.Exception (assert)
import Data.Bits (Bits(..))
import Data.Ratio
import Data.Parameterized.NatRepr
-- | Returns 'True' if the argument is a power of two (1, 2, 4, ...).
--
-- Zero is explicitly excluded: the bit trick @x .&. (x - 1) == 0@ alone also
-- accepts 0, which is not a power of two (and for which 'lg' is undefined).
--
-- NOTE: for fixed-width signed types the minimum value (a lone sign bit)
-- still passes the bit test, because @minBound - 1@ wraps to @maxBound@.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = x /= 0 && x .&. (x - 1) == 0
-- | Floor of the base-2 logarithm of a positive number, i.e. the index of
-- its most significant set bit.  Non-positive arguments are an error.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg i0
  | i0 > 0 =
      -- Count how many additional right shifts it takes to reach zero.
      length (takeWhile (/= 0) (iterate (`shiftR` 1) (i0 `shiftR` 1)))
  | otherwise = error "lg given number that is not positive."
-- | Ceiling of the base-2 logarithm: for @i >= 1@, the smallest @k@ with
-- @2^k >= i@.
--
-- By convention @lgCeil 0 = 0@ (the same result as @lgCeil 1@); negative
-- arguments are rejected with an error.
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil i
  | i > 1 = 1 + lg (i - 1)
  | i == 0 || i == 1 = 0
  | otherwise = error "lgCeil given number that is not positive."
-- | Count the trailing zero bits of @x@ within its low @w@ bits.
--
-- Bits are examined from index 0 upward. The result is the index of the
-- first set bit, or @w@ itself when none of the low @w@ bits is set.
ctz :: NatRepr w -> Integer -> Integer
ctz w x = until done (+ 1) 0
  where
    width = toInteger (natValue w)
    -- Stop at the end of the field or at the first set bit.
    done i = i >= width || testBit x (fromInteger i)
-- | Count the leading zero bits of @x@ viewed as a @w@-bit field.
--
-- Bits are examined from index @w - 1@ downward. The result is the number
-- of clear bits before the first set bit, or @w@ itself when none of the
-- low @w@ bits is set.
clz :: NatRepr w -> Integer -> Integer
clz w x = until done (+ 1) 0
  where
    width = toInteger (natValue w)
    -- After @i@ steps, the bit under inspection is at index @w - i - 1@.
    done i = i >= width || testBit x (widthVal w - fromInteger i - 1)
-- | Rotate the @w@-bit unsigned value @x@ right by @n@ positions.
--
-- @x@ is first truncated to its low @w@ bits with 'toUnsigned'. The rotation
-- amount is reduced with 'mod' so that it lies in @[0, w)@. Using 'mod'
-- instead of 'rem' means a negative amount rotates the other way. The
-- alternative is passing a negative count to 'shiftR' / 'shiftL', which
-- recent versions of @base@ reject with an arithmetic overflow error.
-- A zero-width value rotates to 0 rather than failing with a division by zero.
rotateRight ::
  NatRepr w {- ^ width -} ->
  Integer   {- ^ value to rotate -} ->
  Integer   {- ^ amount to rotate -} ->
  Integer
rotateRight w x n
  | width == 0 = 0
  | otherwise =
      -- The two parts occupy disjoint bit ranges ([0, w-k) and [w-k, w)),
      -- so 'xor' combines them exactly like a bitwise or.
      (x' `shiftR` k) `xor` toUnsigned w (x' `shiftL` (widthVal w - k))
  where
    width = intValue w
    x' :: Integer
    x' = toUnsigned w x
    k :: Int
    k = fromInteger (n `mod` width)
-- | Rotate the @w@-bit unsigned value @x@ left by @n@ positions.
--
-- @x@ is first truncated to its low @w@ bits with 'toUnsigned'. The rotation
-- amount is reduced with 'mod' so that it lies in @[0, w)@. Using 'mod'
-- instead of 'rem' means a negative amount rotates the other way. The
-- alternative is passing a negative count to 'shiftL' / 'shiftR', which
-- recent versions of @base@ reject with an arithmetic overflow error.
-- A zero-width value rotates to 0 rather than failing with a division by zero.
rotateLeft ::
  NatRepr w {- ^ width -} ->
  Integer   {- ^ value to rotate -} ->
  Integer   {- ^ amount to rotate -} ->
  Integer
rotateLeft w x n
  | width == 0 = 0
  | otherwise =
      -- The two parts occupy disjoint bit ranges ([0, k) and [k, w)),
      -- so 'xor' combines them exactly like a bitwise or.
      (x' `shiftR` (widthVal w - k)) `xor` toUnsigned w (x' `shiftL` k)
  where
    width = intValue w
    x' :: Integer
    x' = toUnsigned w x
    k :: Int
    k = fromInteger (n `mod` width)
-- | @nextMultiple x y@ is the smallest multiple of @x@ that is @>= y@.
-- For example, @nextMultiple 4 8 = 8@, @nextMultiple 4 7 = 8@ and
-- @nextMultiple 8 6 = 8@.
--
-- NOTE(review): the formula assumes @x > 0@. A negative @x@ does not give
-- the smallest multiple, and @x = 0@ divides by zero.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = x * quotient
  where
    -- Ceiling of y / x for positive x.
    quotient = (y + x - 1) `div` x
-- | @nextPow2Multiple x n@ is the smallest multiple of @2^n@ that is not less
-- than @x@. Both arguments must be non-negative; otherwise this is an error.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n
  | x < 0 || n < 0 = error "nextPow2Multiple given negative value."
  | otherwise = (roundedUp `shiftR` n) `shiftL` n
  where
    -- Adding 2^n - 1 before clearing the low n bits rounds up, not down.
    roundedUp = x + 2 ^ n - 1
-- | Return the exact square root of @n@ when @n@ is a perfect square, and
-- 'Nothing' otherwise.
--
-- Negative inputs have no integer square root, so they also give 'Nothing'.
-- Without that check they would enter the Newton loop and eventually divide
-- by zero.
tryIntSqrt :: Integer -> Maybe Integer
tryIntSqrt n
  | n < 0 = Nothing
  | n < 2 = Just n   -- 0 and 1 are their own square roots
  | n < 4 = Nothing  -- 2 and 3 are not perfect squares
  | otherwise = assert (n >= 4) $ go (n `shiftR` 1)
  where
    -- Integer Newton iteration from above. The start guess n/2 is >= sqrt n
    -- when n >= 4. Each step floor((x + n/x) / 2) never drops below the
    -- integer square root and strictly decreases while it is above it. So
    -- the first guess whose square is <= n decides the answer.
    go :: Integer -> Maybe Integer
    go x
      | x2 < n = Nothing   -- reached floor(sqrt n) without an exact match
      | x2 == n = Just x'  -- found the exact square root
      | otherwise = go x'  -- guess still too large; refine it
      where
        x' = (x + n `div` x) `div` 2
        x2 = x' * x'
-- | Return the exact rational square root of @r@ if it has one.
--
-- A 'Rational' is kept in lowest terms, so its square root is rational
-- exactly when both the numerator and the denominator are perfect squares.
-- Negative inputs have no real square root and give 'Nothing'. They are
-- rejected up front rather than passing a negative numerator to
-- 'tryIntSqrt'.
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r
  | r < 0 = Nothing
  | otherwise =
      (%) <$> tryIntSqrt (numerator r)
          <*> tryIntSqrt (denominator r)
-- | Round to the nearest integer, with ties (exact halves) rounded away from
-- zero: @roundAway 2.5 = 3@ and @roundAway (-2.5) = -3@.
--
-- 'properFraction' splits the value exactly into an integer part (truncated
-- toward zero) and a fractional part with the same sign. This avoids
-- computing @r + 0.5@, which rounds for floating-point types and can give
-- the wrong answer. For example, 0.49999999999999994 would become 1, and
-- large integral doubles would be bumped to the next value.
roundAway :: (RealFrac a) => a -> Integer
roundAway r
  | abs f >= 0.5 = if f < 0 then n - 1 else n + 1
  | otherwise = n
  where
    (n, f) = properFraction r