```------------------------------------------------------------------------
-- |
-- Module           : What4.Utils.Arithmetic
-- Description      : Utility functions for computing arithmetic
-- Copyright        : (c) Galois, Inc 2015-2020
-- License          : BSD3
-- Maintainer       : Joe Hendrix <jhendrix@galois.com>
-- Stability        : provisional
------------------------------------------------------------------------
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module What4.Utils.Arithmetic
( -- * Arithmetic utilities
isPow2
, lg
, lgCeil
, nextMultiple
, nextPow2Multiple
, tryIntSqrt
, tryRationalSqrt
, roundAway
, ctz
, clz
, rotateLeft
, rotateRight
) where

import Control.Exception (assert)
import Data.Bits (Bits(..))
import Data.Ratio

import Data.Parameterized.NatRepr

-- | Returns true if number is a power of two.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = x .&. (x-1) == 0

-- | Returns floor of log base 2.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg i0 | i0 > 0 = go 0 (i0 `shiftR` 1)
| otherwise = error "lg given number that is not positive."
where go r 0 = r
go r n = go (r+1) (n `shiftR` 1)

-- | Returns ceil of log base 2.
--   We define @lgCeil 0 = 0@
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil 0 = 0
lgCeil 1 = 0
lgCeil i | i > 1 = 1 + lg (i-1)
| otherwise = error "lgCeil given number that is not positive."

-- | Count trailing zeros
ctz :: NatRepr w -> Integer -> Integer
ctz w x = go 0
where
go !i
| i < toInteger (natValue w) && testBit x (fromInteger i) == False = go (i+1)
| otherwise = i

-- | Count leading zeros
clz :: NatRepr w -> Integer -> Integer
clz w x = go 0
where
go !i
| i < toInteger (natValue w) && testBit x (widthVal w - fromInteger i - 1) == False = go (i+1)
| otherwise = i

rotateRight ::
NatRepr w {- ^ width -} ->
Integer {- ^ value to rotate -} ->
Integer {- ^ amount to rotate -} ->
Integer
rotateRight w x n = xor (shiftR x' n') (toUnsigned w (shiftL x' (widthVal w - n')))
where
x' = toUnsigned w x
n' = fromInteger (n `rem` intValue w)

rotateLeft ::
NatRepr w {- ^ width -} ->
Integer {- ^ value to rotate -} ->
Integer {- ^ amount to rotate -} ->
Integer
rotateLeft w x n = xor (shiftR x' (widthVal w - n')) (toUnsigned w (shiftL x' n'))
where
x' = toUnsigned w x
n' = fromInteger (n `rem` intValue w)

-- | @nextMultiple x y@ computes the next multiple m of x s.t. m >= y.  E.g.,
-- nextMultiple 4 8 = 8 since 8 is a multiple of 8; nextMultiple 4 7 = 8;
-- nextMultiple 8 6 = 8.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = ((y + x - 1) `div` x) * x

-- | @nextPow2Multiple x n@ returns the smallest multiple of @2^n@
-- not less than @x@.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n | x >= 0 && n >= 0 = ((x+2^n -1) `shiftR` n) `shiftL` n
| otherwise = error "nextPow2Multiple given negative value."

------------------------------------------------------------------------
-- Sqrt operators.

-- | This returns the sqrt of an integer if it is well-defined.
tryIntSqrt :: Integer -> Maybe Integer
tryIntSqrt 0 = return 0
tryIntSqrt 1 = return 1
tryIntSqrt 2 = Nothing
tryIntSqrt 3 = Nothing
tryIntSqrt n = assert (n >= 4) \$ go (n `shiftR` 1)
where go x | x2 < n  = Nothing   -- Guess is below sqrt, so we quit.
| x2 == n = return x' -- We have found sqrt
| True    = go x'     -- Guess is still too large, so try again.
where -- Next guess is floor(avg(x, n/x))
x' = (x + n `div` x) `div` 2
x2 = x' * x'

-- | Return the rational sqrt of a
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r = do
(%) <\$> tryIntSqrt (numerator   r)
<*> tryIntSqrt (denominator r)

------------------------------------------------------------------------
-- Conversion

-- | Evaluate a real to an integer with rounding away from zero.
roundAway :: (RealFrac a) => a -> Integer
roundAway r = truncate (r + signum r * 0.5)
```