```------------------------------------------------------------------------
-- |
-- Module           : What4.Utils.Arithmetic
-- Description      : Utility functions for computing arithmetic
-- Copyright        : (c) Galois, Inc 2015-2020
-- Maintainer       : Joe Hendrix <jhendrix@galois.com>
-- Stability        : provisional
------------------------------------------------------------------------
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module What4.Utils.Arithmetic
( -- * Arithmetic utilities
isPow2
, lg
, lgCeil
, nextMultiple
, nextPow2Multiple
, tryIntSqrt
, tryRationalSqrt
, roundAway
, ctz
, clz
, rotateLeft
, rotateRight
) where

import Control.Exception (assert)
import Data.Bits (Bits(..))
import Data.Ratio

import Data.Parameterized.NatRepr

-- | Returns true if number is a power of two.
--
-- The check clears the lowest set bit (@x .&. (x - 1)@) and tests whether
-- anything remains.  As a consequence of that formulation, @isPow2 0@ is
-- 'True' as well.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = lowestBitCleared == 0
  where
    -- x with its least-significant set bit removed.
    lowestBitCleared = x .&. (x - 1)

-- | Returns floor of log base 2.
--
-- Calls 'error' when the argument is not strictly positive.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg i0
  | i0 > 0 = countShifts 0 (i0 `shiftR` 1)
  | otherwise = error "lg given number that is not positive."
  where
    -- Number of right shifts needed to drive the value down to zero,
    -- added on top of the running count.
    countShifts :: (Bits b, Num b) => Int -> b -> Int
    countShifts acc v
      | v == 0 = acc
      | otherwise = countShifts (acc + 1) (v `shiftR` 1)

-- | Returns ceil of log base 2.
--   We define @lgCeil 0 = 0@
--
-- Calls 'error' when the argument is negative.
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil i
  | i == 0 || i == 1 = 0
  | i > 1 = lg (i - 1) + 1
  | otherwise = error "lgCeil given number that is not positive."

-- | Count trailing zeros: the number of consecutive clear bits of @x@,
-- scanning upward from bit 0 and considering only the low @w@ bits.
-- Returns the width itself when none of those bits is set.
ctz :: NatRepr w -> Integer -> Integer
ctz w x = scan 0
  where
    width :: Integer
    width = toInteger (natValue w)

    scan :: Integer -> Integer
    scan !idx
      | idx >= width = idx
      | testBit x (fromInteger idx) = idx
      | otherwise = scan (idx + 1)

-- | Count leading zeros: the number of consecutive clear bits of @x@,
-- scanning downward from bit @w - 1@ toward bit 0.  Returns the width
-- itself when none of the low @w@ bits is set.
clz :: NatRepr w -> Integer -> Integer
clz w x = scan 0
  where
    width :: Integer
    width = toInteger (natValue w)

    scan :: Integer -> Integer
    scan !idx
      | idx >= width = idx
      | testBit x (widthVal w - fromInteger idx - 1) = idx
      | otherwise = scan (idx + 1)

-- | Rotate the low @w@ bits of a value to the right.
--
-- The value is first truncated with 'toUnsigned'.  The rotation amount is
-- reduced with 'mod', so the effective shift always lies in @[0, w)@:
-- negative amounts rotate left (rotating right by @-1@ is rotating left by
-- @1@), and amounts of at least @w@ wrap around.  A width of zero yields 0
-- instead of a division-by-zero error.
rotateRight ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateRight w x n
  | intValue w == 0 = 0
  | otherwise = xor (shiftR x' n') (toUnsigned w (shiftL x' (widthVal w - n')))
  where
    x' :: Integer
    x' = toUnsigned w x
    -- 'mod' rather than 'rem': with 'rem' a negative amount produces a
    -- negative shift count, which 'shiftL'/'shiftR' do not support.
    n' :: Int
    n' = fromInteger (n `mod` intValue w)

-- | Rotate the low @w@ bits of a value to the left.
--
-- The value is first truncated with 'toUnsigned'.  The rotation amount is
-- reduced with 'mod', so the effective shift always lies in @[0, w)@:
-- negative amounts rotate right (rotating left by @-1@ is rotating right by
-- @1@), and amounts of at least @w@ wrap around.  A width of zero yields 0
-- instead of a division-by-zero error.
rotateLeft ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateLeft w x n
  | intValue w == 0 = 0
  | otherwise = xor (shiftR x' (widthVal w - n')) (toUnsigned w (shiftL x' n'))
  where
    x' :: Integer
    x' = toUnsigned w x
    -- 'mod' rather than 'rem': with 'rem' a negative amount produces a
    -- negative shift count, which 'shiftL'/'shiftR' do not support.
    n' :: Int
    n' = fromInteger (n `mod` intValue w)

-- | @nextMultiple x y@ computes the smallest multiple @m@ of @x@ such that
-- @m >= y@.  E.g., @nextMultiple 4 8 = 8@ since 8 is already a multiple of
-- 4; @nextMultiple 4 7 = 8@; @nextMultiple 8 6 = 8@.
--
-- The result is only meaningful for @x > 0@; @x == 0@ divides by zero.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = multiplesNeeded * x
  where
    -- @y / x@ rounded up (via floor division after adding @x - 1@).
    multiplesNeeded = (y + x - 1) `div` x

-- | @nextPow2Multiple x n@ returns the smallest multiple of @2^n@
-- not less than @x@.
--
-- Both arguments must be non-negative; otherwise 'error' is called.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n
  | x < 0 || n < 0 = error "nextPow2Multiple given negative value."
  | otherwise = blocks `shiftL` n
  where
    -- Number of 2^n-sized blocks needed to cover x (rounded up): bump x by
    -- 2^n - 1, then drop the low n bits.
    blocks = (x + 2 ^ n - 1) `shiftR` n

------------------------------------------------------------------------
-- Sqrt operators.

-- | This returns the sqrt of an integer if it is well-defined, i.e. when the
-- argument is a perfect square.  Negative arguments have no integer square
-- root and yield 'Nothing'.
tryIntSqrt :: Integer -> Maybe Integer
tryIntSqrt n
  -- Must be rejected up front: for negative n the Newton loop below does
  -- not terminate (e.g. n = -5 oscillates between guesses -1 and 2), and the
  -- 'assert' is compiled out under -O.
  | n < 0 = Nothing
  | n < 2 = Just n    -- 0 and 1 are their own square roots.
  | n < 4 = Nothing   -- 2 and 3 are not perfect squares.
  | otherwise = assert (n >= 4) $ go (n `shiftR` 1)
  where
    -- Integer Newton iteration starting above the root; each step moves
    -- the guess down until it reaches floor(sqrt n).
    go :: Integer -> Maybe Integer
    go x
      | x2 < n = Nothing   -- Guess is below sqrt, so we quit.
      | x2 == n = Just x'  -- We have found sqrt
      | otherwise = go x'  -- Guess is still too large, so try again.
      where
        -- Next guess is floor(avg(x, n/x))
        x' = (x + n `div` x) `div` 2
        x2 = x' * x'

-- | Return the rational square root of @r@, if it has one.
--
-- Succeeds only when both the numerator and the denominator of @r@ (which
-- 'Rational' keeps in lowest terms) are perfect squares.  Negative inputs
-- have no rational square root and yield 'Nothing'; they are rejected here
-- so a negative numerator is never handed to 'tryIntSqrt'.
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r
  | r < 0 = Nothing
  | otherwise = (%) <$> tryIntSqrt (numerator r)
                    <*> tryIntSqrt (denominator r)

------------------------------------------------------------------------
-- Conversion

-- | Evaluate a real to an integer, rounding to the nearest integer with
-- ties broken away from zero (e.g. @2.5@ gives 3 and @-2.5@ gives -3).
--
-- The fractional part is inspected directly instead of computing
-- @r + signum r * 0.5@ first.  For floating-point types that addition can
-- itself round, giving an off-by-one result.  For example, for 'Double',
-- @0.49999999999999994 + 0.5 == 1.0@, and @2^52 + 1.5@ rounds to even.
roundAway :: (RealFrac a) => a -> Integer
roundAway r
  | abs frac >= 0.5 = whole + (if frac < 0 then -1 else 1)
  | otherwise = whole
  where
    -- 'properFraction' truncates toward zero, so 'frac' has the sign of 'r'
    -- and magnitude below 1.
    (whole, frac) = properFraction r
```