------------------------------------------------------------------------
-- |
-- Module           : What4.Utils.Arithmetic
-- Description      : Utility functions for computing arithmetic
-- Copyright        : (c) Galois, Inc 2015-2020
-- License          : BSD3
-- Maintainer       : Joe Hendrix <jhendrix@galois.com>
-- Stability        : provisional
------------------------------------------------------------------------
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module What4.Utils.Arithmetic
  ( -- * Arithmetic utilities
    isPow2
  , lg
  , lgCeil
  , nextMultiple
  , nextPow2Multiple
  , tryIntSqrt
  , tryRationalSqrt
  , roundAway
  , ctz
  , clz
  , rotateLeft
  , rotateRight
  ) where

import Control.Exception (assert)
import Data.Bits (Bits(..))
import Data.Ratio

import Data.Parameterized.NatRepr

-- | Returns true if number is a power of two.
--
-- Uses the bit trick @x .&. (x - 1) == 0@: subtracting one clears the
-- lowest set bit of @x@ and sets every bit below it, so the conjunction is
-- zero exactly when @x@ has at most one bit set.  Note that zero has no bits
-- set and therefore also satisfies this test.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = x .&. (x - 1) == 0

-- | Returns floor of log base 2.
--
-- Calls 'error' when the argument is not strictly positive.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg i0
  | i0 > 0    = go 0 (i0 `shiftR` 1)
  | otherwise = error "lg given number that is not positive."
  where
    -- @go r n@ adds to @r@ the number of right shifts needed for @n@ to
    -- reach zero.
    go r 0 = r
    go r n = go (r + 1) (n `shiftR` 1)

-- | Returns ceil of log base 2.
--   We define @lgCeil 0 = 0@
--
-- For arguments greater than one the result is @1 + lg (i - 1)@.  Calls
-- 'error' when the argument is negative.
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil 0 = 0
lgCeil 1 = 0
lgCeil i
  | i > 1     = 1 + lg (i - 1)
  | otherwise = error "lgCeil given number that is not positive."

-- | Count trailing zeros
--
-- @ctz w x@ scans the bits of @x@ upward from bit 0 and returns the index of
-- the first set bit.  If none of the bits below index @w@ is set, the result
-- is the width @w@ itself.
ctz :: NatRepr w -> Integer -> Integer
ctz w x = go 0
  where
    -- The width is loop-invariant, so convert it once.
    width = toInteger (natValue w)

    go !i
      | i < width && not (testBit x (fromInteger i)) = go (i + 1)
      | otherwise = i

-- | Count leading zeros
--
-- @clz w x@ scans the bits of @x@ downward from bit @w - 1@ and returns how
-- many bits were examined before the first set bit.  If none of the bits
-- below index @w@ is set, the result is the width @w@ itself.
clz :: NatRepr w -> Integer -> Integer
clz w x = go 0
  where
    -- The width is loop-invariant, so convert it once.
    width = toInteger (natValue w)

    go !i
      | i < width && not (testBit x (widthVal w - fromInteger i - 1)) = go (i + 1)
      | otherwise = i

-- | Rotate a value to the right within the given bit width.
--
-- The value is first normalised with 'toUnsigned' and the rotation amount is
-- reduced modulo the width, so rotating by a multiple of the width returns
-- the normalised value unchanged.
--
-- NOTE(review): assumes a non-zero width; a zero width would make the
-- modulus below divide by zero.
rotateRight ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateRight w x n = xor (shiftR x' n') (toUnsigned w (shiftL x' (widthVal w - n')))
  where
    x' = toUnsigned w x
    -- Use 'mod' rather than 'rem': for a negative rotation amount 'rem' gives
    -- a negative result, which would then be used as a negative shift count.
    -- 'mod' with a positive width always lands in @[0, width)@.
    n' = fromInteger (n `mod` intValue w)

-- | Rotate a value to the left within the given bit width.
--
-- The value is first normalised with 'toUnsigned' and the rotation amount is
-- reduced modulo the width, so rotating by a multiple of the width returns
-- the normalised value unchanged.
--
-- NOTE(review): assumes a non-zero width; a zero width would make the
-- modulus below divide by zero.
rotateLeft ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateLeft w x n = xor (shiftR x' (widthVal w - n')) (toUnsigned w (shiftL x' n'))
  where
    x' = toUnsigned w x
    -- Use 'mod' rather than 'rem': for a negative rotation amount 'rem' gives
    -- a negative result, which would then be used as a negative shift count.
    -- 'mod' with a positive width always lands in @[0, width)@.
    n' = fromInteger (n `mod` intValue w)


-- | @nextMultiple x y@ computes the next multiple m of x s.t. m >= y.  E.g.,
-- nextMultiple 4 8 = 8 since 8 is a multiple of 4; nextMultiple 4 7 = 8;
-- nextMultiple 8 6 = 8.
--
-- The computation rounds @y@ up by adding @x - 1@ before the flooring
-- division by @x@.  NOTE(review): this presumes @x@ is positive; @x = 0@
-- divides by zero.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = ((y + x - 1) `div` x) * x

-- | @nextPow2Multiple x n@ returns the smallest multiple of @2^n@
-- not less than @x@.
--
-- The value is rounded up by adding @2^n - 1@ and then clearing the low @n@
-- bits with a right shift followed by a left shift.  Calls 'error' when
-- either argument is negative.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n
  | x >= 0 && n >= 0 = ((x + 2 ^ n - 1) `shiftR` n) `shiftL` n
  | otherwise        = error "nextPow2Multiple given negative value."

------------------------------------------------------------------------
-- Sqrt operators.

-- | This returns the sqrt of an integer if it is well-defined.
--
-- The result is @Just s@ with @s * s == n@ when @n@ is a perfect square and
-- 'Nothing' otherwise.  Negative inputs have no integer square root and
-- therefore yield 'Nothing'.
tryIntSqrt :: Integer -> Maybe Integer
tryIntSqrt n
  -- Without this guard a negative input would reach the iteration below,
  -- tripping the assertion or eventually dividing by zero.
  | n < 0 = Nothing
tryIntSqrt 0 = return 0
tryIntSqrt 1 = return 1
tryIntSqrt 2 = Nothing
tryIntSqrt 3 = Nothing
tryIntSqrt n = assert (n >= 4) $ go (n `shiftR` 1)
  where
    -- @go x@ refines the guess @x@; the first guess @n / 2@ is at least the
    -- square root for every @n >= 4@.
    go x
      | x2 < n    = Nothing   -- Guess is below sqrt, so we quit.
      | x2 == n   = return x' -- We have found sqrt
      | otherwise = go x'     -- Guess is still too large, so try again.
      where
        -- Next guess is floor(avg(x, n/x))
        x' = (x + n `div` x) `div` 2
        x2 = x' * x'

-- | Return the rational sqrt of a
--
-- The result is the quotient of the integer square roots of the numerator
-- and the denominator, and is 'Nothing' when either of them has no integer
-- square root.  Negative rationals have no rational square root and yield
-- 'Nothing'.
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r
  -- Reject negative values up front rather than passing a negative numerator
  -- to 'tryIntSqrt'.
  | r < 0     = Nothing
  | otherwise = (%) <$> tryIntSqrt (numerator r) <*> tryIntSqrt (denominator r)

------------------------------------------------------------------------
-- Conversion

-- | Evaluate a real to the nearest integer, with ties (fractional part of
-- exactly one half) rounded away from zero.
--
-- The value is split with 'properFraction' and the fractional part is
-- compared against one half.  This avoids computing @r + signum r * 0.5@,
-- which is inexact for floating-point types and can push a value just below
-- a half-way point up to the next integer.
roundAway :: (RealFrac a) => a -> Integer
roundAway r
  | frac >= 0.5    = whole + 1
  | frac <= (-0.5) = whole - 1
  | otherwise      = whole
  where
    -- 'properFraction' gives the fractional part the same sign as @r@.
    (whole, frac) = properFraction r