{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module What4.Utils.Arithmetic
(
isPow2
, lg
, lgCeil
, nextMultiple
, nextPow2Multiple
, tryIntSqrt
, tryRationalSqrt
, roundAway
, ctz
, clz
, rotateLeft
, rotateRight
) where
import Control.Exception (assert)
import Data.Bits (Bits(..))
import Data.Ratio
import Data.Parameterized.NatRepr
-- | Test whether a number has at most one bit set, i.e. whether it is a
--   power of two.
--
--   Clearing the lowest set bit (@x .&. (x - 1)@) leaves zero exactly when
--   there was at most one set bit.
--
--   NOTE(review): @isPow2 0@ is 'True' under this definition; callers that
--   need a strict power-of-two check must exclude zero themselves.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = withoutLowestBit == 0
  where
    withoutLowestBit = x .&. (x - 1)
-- | Floor of the base-2 logarithm of a positive number, i.e. the index of its
--   most significant set bit.
--
--   Calls 'error' if the argument is not positive.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg i0
  | i0 > 0 = go 0 (i0 `shiftR` 1)
  | otherwise = error "lg given number that is not positive."
  where
    -- Count the remaining right shifts until the value reaches zero.
    -- The accumulator is strict so no chain of (+1) thunks is built.
    go !r 0 = r
    go !r n = go (r + 1) (n `shiftR` 1)
-- | Ceiling of the base-2 logarithm.
--
--   By convention both @lgCeil 0@ and @lgCeil 1@ are @0@. For @i > 1@ this is
--   one more than the floor logarithm of @i - 1@. Calls 'error' on negative
--   input.
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil i
  | i == 0 || i == 1 = 0
  | i > 1 = lg (i - 1) + 1
  | otherwise = error "lgCeil given number that is not positive."
-- | Count the trailing (least significant) zero bits of @x@, considering only
--   its low @w@ bits. Returns @w@ when none of those bits is set.
ctz :: NatRepr w -> Integer -> Integer
ctz w x = go 0
  where
    -- The width is loop-invariant, so compute it once.
    width = toInteger (natValue w)

    -- Scan upward from bit 0 until a set bit or the width is reached.
    go !i
      | i < width && not (testBit x (fromInteger i)) = go (i + 1)
      | otherwise = i
-- | Count the leading zero bits of @x@ viewed as a @w@-bit value, scanning
--   down from bit @w - 1@. Returns @w@ when none of the low @w@ bits is set.
clz :: NatRepr w -> Integer -> Integer
clz w x = go 0
  where
    -- The width is loop-invariant, so compute it once.
    width = toInteger (natValue w)

    -- @i@ counts bits already examined; the bit tested is @w - 1 - i@.
    go !i
      | i < width && not (testBit x (widthVal w - fromInteger i - 1)) = go (i + 1)
      | otherwise = i
-- | @rotateRight w x n@ rotates the @w@-bit unsigned value of @x@ right by
--   @n@ bit positions.
--
--   Only the low @w@ bits of @x@ are used. The rotation amount is reduced
--   modulo @w@, so a negative @n@ rotates left (rotating right by @-1@ equals
--   rotating right by @w - 1@). A zero-width value always rotates to @0@.
rotateRight ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ number of positions to rotate right -} ->
  Integer
rotateRight w x n
  | width == 0 = 0 -- avoids reducing the amount modulo zero
  | otherwise =
      -- The two shifted pieces occupy disjoint bit ranges, so OR combines them.
      (x' `shiftR` n') .|. toUnsigned w (x' `shiftL` (widthVal w - n'))
  where
    width = intValue w
    x' = toUnsigned w x
    -- 'mod' (not 'rem') keeps the amount in [0, w) even for negative @n@;
    -- 'shiftL'/'shiftR' are undefined for negative shift amounts.
    n' = fromInteger (n `mod` width)
-- | @rotateLeft w x n@ rotates the @w@-bit unsigned value of @x@ left by
--   @n@ bit positions.
--
--   Only the low @w@ bits of @x@ are used. The rotation amount is reduced
--   modulo @w@, so a negative @n@ rotates right (rotating left by @-1@ equals
--   rotating left by @w - 1@). A zero-width value always rotates to @0@.
rotateLeft ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ number of positions to rotate left -} ->
  Integer
rotateLeft w x n
  | width == 0 = 0 -- avoids reducing the amount modulo zero
  | otherwise =
      -- The two shifted pieces occupy disjoint bit ranges, so OR combines them.
      (x' `shiftR` (widthVal w - n')) .|. toUnsigned w (x' `shiftL` n')
  where
    width = intValue w
    x' = toUnsigned w x
    -- 'mod' (not 'rem') keeps the amount in [0, w) even for negative @n@;
    -- 'shiftL'/'shiftR' are undefined for negative shift amounts.
    n' = fromInteger (n `mod` width)
-- | @nextMultiple x y@ rounds @y@ up to a multiple of @x@, e.g.
--   @nextMultiple 4 7 == 8@ and @nextMultiple 4 8 == 8@.
--
--   NOTE(review): this is intended for positive @x@; @x == 0@ divides by
--   zero.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = multiplier * x
  where
    -- Ceiling of y / x, computed with floor division.
    multiplier = (y + x - 1) `div` x
-- | @nextPow2Multiple x n@ is the smallest multiple of @2^n@ that is not less
--   than @x@.
--
--   Calls 'error' if either @x@ or @n@ is negative.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n
  | x < 0 || n < 0 = error "nextPow2Multiple given negative value."
  | otherwise = (roundedUp `shiftR` n) `shiftL` n
  where
    -- Adding 2^n - 1 before clearing the low n bits rounds up, not down.
    roundedUp = x + 2 ^ n - 1
-- | Return the exact integer square root of @n@, if it has one.
--
--   Returns 'Nothing' when @n@ is negative or is not a perfect square.
tryIntSqrt :: Integer -> Maybe Integer
tryIntSqrt n
  -- Checked explicitly rather than only by 'assert': asserts are compiled
  -- out under -O, and Newton's iteration below can cycle forever on
  -- negative input.
  | n < 0 = Nothing
  | n == 0 = Just 0
  | n == 1 = Just 1
  | n == 2 || n == 3 = Nothing
  | otherwise = assert (n >= 4) $ go (n `shiftR` 1)
  where
    -- Integer Newton iteration for floor (sqrt n), starting from n/2, which
    -- is at least sqrt n when n >= 4. Each iterate stays >= floor (sqrt n)
    -- and strictly decreases while above it. So the first iterate whose
    -- square is <= n decides the answer.
    go x
      | x2 < n = Nothing
      | x2 == n = Just x'
      | otherwise = go x'
      where
        x' = (x + n `div` x) `div` 2
        x2 = x' * x'
-- | Return the exact square root of a rational number, if its (reduced)
--   numerator and denominator are both perfect squares.
--
--   Negative inputs have no rational square root and yield 'Nothing'.
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r
  -- Rejected up front so a negative numerator is never handed to
  -- 'tryIntSqrt', which expects a non-negative argument.
  | r < 0 = Nothing
  | otherwise =
      (%) <$> tryIntSqrt (numerator r) <*> tryIntSqrt (denominator r)
-- | Round to the nearest integer, with ties rounded away from zero
--   (so @2.5@ becomes @3@ and @-2.5@ becomes @-3@).
roundAway :: (RealFrac a) => a -> Integer
roundAway r
  | frac >= 0.5 = whole + 1
  | frac <= -0.5 = whole - 1
  | otherwise = whole
  where
    -- 'properFraction' splits r exactly into a part truncated toward zero
    -- and a fractional part of the same sign. Adding 0.5 and truncating
    -- would be inexact for floating point: the addition itself can round,
    -- e.g. the largest Double below 0.5 would round to 1.
    (whole, frac) = properFraction r