{-# Language BlockArguments, OverloadedStrings #-}
{-# Language BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# Language GADTs #-}
module What4.Utils.FloatHelpers where

import qualified Control.Exception as Ex
import Data.Ratio(numerator,denominator)
import Data.Hashable
import GHC.Generics (Generic)
import GHC.Stack

import LibBF

import What4.BaseTypes
import What4.Panic (panic)

-- | The IEEE-754 rounding modes available to floating-point operations.
--
-- Constructor order matters: the derived 'Ord' and 'Enum' instances follow
-- it (RNE, RNA, RTP, RTN, RTZ).
data RoundingMode
  = RNE -- ^ Round to nearest, ties to even.
  | RNA -- ^ Round to nearest, ties away from zero.
  | RTP -- ^ Round toward positive infinity.
  | RTN -- ^ Round toward negative infinity.
  | RTZ -- ^ Round toward zero.
  deriving (Show, Eq, Ord, Enum, Generic)

-- | Hashing for 'RoundingMode'.  The instance declares no methods, so it
-- relies on the class's default implementations (presumably the
-- 'Generic'-based defaults, since 'RoundingMode' derives 'Generic').
instance Hashable RoundingMode

-- | Extract the value produced by a LibBF operation, dropping its status.
--
-- A 'MemError' status is reported by throwing 'Ex.HeapOverflow'; every
-- other status flag is ignored and the value is returned as-is.
bfStatus :: HasCallStack => (a, Status) -> a
bfStatus (result, status) =
  case status of
    MemError -> Ex.throw Ex.HeapOverflow
    _        -> result

-- | Build LibBF options for a What4 floating-point precision: the exponent
-- width @eb@ and significand precision @sb@ of the format, together with
-- the given rounding mode.  Invalid widths make 'fpOpts' panic.
fppOpts :: FloatPrecisionRepr fpp -> RoundingMode -> BFOpts
fppOpts fpp r =
  case fpp of
    FloatingPointPrecisionRepr eb sb ->
      let expWidth  = intValue eb
          precWidth = intValue sb
      in fpOpts expWidth precWidth (toRoundMode r)

-- | Translate a What4 'RoundingMode' into the matching LibBF 'RoundMode'.
toRoundMode :: RoundingMode -> RoundMode
toRoundMode mode =
  case mode of
    RNE -> NearEven
    RNA -> NearAway
    RTP -> ToPosInf
    RTN -> ToNegInf
    RTZ -> ToZero

-- | Make LibBF options for the given exponent width @e@, precision @p@ and
-- rounding mode @r@.  Subnormal numbers are always allowed.
--
-- Panics if @e@ lies outside ['expBitsMin', 'expBitsMax'] or @p@ lies
-- outside ['precBitsMin', 'precBitsMax'].
fpOpts :: Integer -> Integer -> RoundMode -> BFOpts
fpOpts e p r =
  case ok of
    Just opts -> opts
    -- The location label names this function so panics point at it.
    Nothing   -> panic "fpOpts" [ "Invalid Float size"
                                , "exponent: " ++ show e
                                , "precision: " ++ show p
                                ]
  where
    ok :: Maybe BFOpts
    ok = do eb <- rng expBits expBitsMin expBitsMax e
            pb <- rng precBits precBitsMin precBitsMax p
            pure (eb <> pb <> allowSubnormal <> rnd r)

    -- Apply @f@ to @x@ only when @lo <= x <= hi@.  The range check keeps
    -- 'fromInteger' from silently truncating an out-of-range Integer
    -- (assuming the bounds themselves fit in the target type @n@).
    rng :: (Integral a, Integral b, Num n)
        => (n -> o) -> a -> b -> Integer -> Maybe o
    rng f lo hi x
      | toInteger lo <= x && x <= toInteger hi = Just (f (fromInteger x))
      | otherwise                              = Nothing


-- | Round an integer to a floating point number, using the precision and
-- rounding mode carried by the given LibBF options.
floatFromInteger :: BFOpts -> Integer -> BigFloat
floatFromInteger opts = bfStatus . bfRoundFloat opts . bfFromInteger

-- | Round a rational to a floating point number, using the precision and
-- rounding mode carried by the given LibBF options.
--
-- Integral values are rounded directly.  Anything else is computed by
-- dividing the numerator by the denominator with 'bfDiv' under the same
-- options.
floatFromRational :: BFOpts -> Rational -> BigFloat
floatFromRational opts rat
  | den == 1  = bfStatus (bfRoundFloat opts num)
  | otherwise = bfStatus (bfDiv opts num (bfFromInteger den))
  where
    num :: BigFloat
    num = bfFromInteger (numerator rat)

    den :: Integer
    den = denominator rat


-- | Convert a floating point number to the exact rational it denotes.
--
-- NaN and the infinities have no rational value and yield 'Nothing'.
-- Zero yields @Just 0@ regardless of its sign.
floatToRational :: BigFloat -> Maybe Rational
floatToRational bf =
  case bfToRep bf of
    BFNaN                    -> Nothing
    BFRep _ Inf              -> Nothing
    BFRep _ Zero             -> Just 0
    BFRep sign (Num mant ex) -> Just (withSign sign magnitude)
      where
        -- A finite nonzero number is represented as @mant * 2^ex@.
        magnitude :: Rational
        magnitude = fromInteger mant * (2 ^^ ex)
  where
    withSign :: Sign -> Rational -> Rational
    withSign Pos q = q
    withSign Neg q = negate q

-- | Convert a floating point number to an integer, if possible.
--
-- Returns 'Nothing' for NaN and the infinities.  Otherwise the exact value
-- of the float is rounded to an integer according to the rounding mode:
--
--   * 'RNE': nearest, ties to even (Haskell's 'round')
--   * 'RNA': nearest, ties away from zero
--   * 'RTP': toward positive infinity ('ceiling')
--   * 'RTN': toward negative infinity ('floor')
--   * 'RTZ': toward zero ('truncate')
floatToInteger :: RoundingMode -> BigFloat -> Maybe Integer
floatToInteger r fp =
  do rat <- floatToRational fp
     pure case r of
            RNE -> round rat
            RNA -> roundTiesAway rat
            RTP -> ceiling rat
            RTN -> floor rat
            RTZ -> truncate rat
  where
    -- Round to the nearest integer; on an exact tie pick the one farther
    -- from zero (2.5 -> 3, -2.5 -> -3, 1.2 -> 1, -1.2 -> -1).  Using
    -- ceiling/floor by sign here would round away from zero instead of to
    -- nearest.
    roundTiesAway :: Rational -> Integer
    roundTiesAway q
      | q >= 0    = floor (q + 1/2)
      | otherwise = ceiling (q - 1/2)

-- | Round a float to an integral value with the given rounding mode, then
-- round that integral value to the format described by @fpp@, again with
-- the same rounding mode.  A LibBF 'MemError' in either step is reported by
-- 'bfStatus'.
floatRoundToInt ::
  HasCallStack =>
  FloatPrecisionRepr fpp -> RoundingMode -> BigFloat -> BigFloat
floatRoundToInt fpp r bf = bfStatus (bfRoundFloat opts integral)
  where
    integral = bfStatus (bfRoundInt (toRoundMode r) bf)
    opts     = fppOpts fpp r