{-# LANGUAGE CPP #-}
{-# LANGUAGE NumericUnderscores #-}
module Numeric.Rounded.Hardware.Internal.FloatUtil
( nextUp
, nextDown
, nextTowardZero
, minPositive_ieee
, maxFinite_ieee
, distanceUlp
, fusedMultiplyAdd
) where
import Data.Bits
import Data.Ratio
import GHC.Float (castDoubleToWord64, castFloatToWord32,
castWord32ToFloat, castWord64ToDouble)
-- | The smallest positive representable value, @encodeFloat 1 (emin - d)@,
-- where @emin@ is the lower end of 'floatRange' and @d@ is 'floatDigits'.
-- For IEEE types this is the smallest positive subnormal.
minPositive_ieee :: RealFloat a => a
minPositive_ieee = tiniest
  where
    -- 'floatDigits' and 'floatRange' only use the type of their argument,
    -- so referring to the result itself here is safe.
    tiniest = encodeFloat 1 (lowestExp - floatDigits tiniest)
    (lowestExp, _) = floatRange tiniest
{-# SPECIALIZE minPositive_ieee :: Double #-}
{-# SPECIALIZE minPositive_ieee :: Float #-}
-- | The largest finite representable value,
-- @(r^d - 1) * r^(emax - d)@, where @r@ is 'floatRadix', @d@ is
-- 'floatDigits' and @emax@ is the upper end of 'floatRange'.
maxFinite_ieee :: RealFloat a => a
maxFinite_ieee = largest
  where
    -- Only the type of 'largest' is inspected by these class methods.
    digits = floatDigits largest
    (_, highestExp) = floatRange largest
    largest = encodeFloat (floatRadix largest ^! digits - 1) (highestExp - digits)
{-# SPECIALIZE maxFinite_ieee :: Double #-}
{-# SPECIALIZE maxFinite_ieee :: Float #-}
-- Exponentiation with an 'Integer' base and an 'Int' exponent. It is simply
-- '(^)' under a separate name, so that the "2^!" rewrite rule below can
-- target it without touching other uses of '(^)'.
infixr 8 ^!
(^!) :: Integer -> Int -> Integer
(^!) = (^)
-- Not inlined before simplifier phase 2, so calls stay recognisable by
-- name long enough for the "2^!" rule to match them.
{-# INLINE [2] (^!) #-}
-- Turn @2 ^! y@ into a left shift when @y >= 0@ can be decided at compile
-- time (see 'staticIf'); otherwise the call remains @2 ^ y@.
{-# RULES
"2^!" forall y. 2 ^! y = staticIf (y >= 0) (1 `shiftL` y) (2 ^ y)
#-}
-- Compile-time conditional for use in rewrite rules.
--
-- Without any rule rewriting, @staticIf c t e@ ignores @c@ and @t@ and
-- returns @e@. The rules below replace it with @t@ or @e@ once @c@ has been
-- simplified to the literal 'True' or 'False'. Consequently @e@ must be a
-- correct result even when @c@ holds; choosing @t@ is only an optimisation.
staticIf :: Bool -> a -> a -> a
staticIf _ _ x = x
-- Inlined only from the final simplifier phase (0), so the rules below get
-- a chance to fire first.
{-# INLINE [0] staticIf #-}
{-# RULES
"staticIf/True" forall x y. staticIf True x y = x
"staticIf/False" forall x y. staticIf False x y = y
#-}
-- | The smallest representable value strictly greater than the argument.
--
-- NaN and positive infinity are returned unchanged. Both zeros map to the
-- smallest positive subnormal, and negative infinity maps to the most
-- negative finite value. Only IEEE radix-2 types are accepted; any other
-- type raises an error. For 'Float' and 'Double', the RULES in this module
-- substitute a bit-pattern implementation.
nextUp :: RealFloat a => a -> a
nextUp x
  | not (isIEEE x) = error "non-IEEE numbers are not supported"
  | floatRadix x /= 2 = error "non-binary types are not supported"
  | isNaN x = x
  | isInfinite x && x > 0 = x
    -- Negative input: stepping up means shrinking the magnitude.
  | x < 0 = negate (nextDown_ieee_positive (negate x))
  | otherwise = nextUp_ieee_positive x
{-# INLINE [1] nextUp #-}
-- | The largest representable value strictly smaller than the argument.
--
-- NaN and negative infinity are returned unchanged. Both zeros map to the
-- negative smallest subnormal, and positive infinity maps to the largest
-- finite value. Only IEEE radix-2 types are accepted; any other type raises
-- an error. For 'Float' and 'Double', the RULES in this module substitute a
-- bit-pattern implementation.
nextDown :: RealFloat a => a -> a
nextDown x
  | not (isIEEE x) = error "non-IEEE numbers are not supported"
  | floatRadix x /= 2 = error "non-binary types are not supported"
  | isNaN x = x
  | isInfinite x && x < 0 = x
    -- Negative input: stepping down means growing the magnitude.
  | x < 0 = negate (nextUp_ieee_positive (negate x))
  | otherwise = nextDown_ieee_positive x
{-# INLINE [1] nextDown #-}
-- | The neighbouring representable value one step closer to zero.
--
-- NaN and both zeros are returned unchanged. Infinities map to the largest
-- finite value of the same sign. Only IEEE radix-2 types are accepted; any
-- other type raises an error. For 'Float' and 'Double', the RULES in this
-- module substitute a bit-pattern implementation.
nextTowardZero :: RealFloat a => a -> a
nextTowardZero x
  | not (isIEEE x) = error "non-IEEE numbers are not supported"
  | floatRadix x /= 2 = error "non-binary types are not supported "
  | isNaN x = x
  | x == 0 = x
    -- Either sign: shrink the magnitude, then restore the sign.
  | x < 0 = negate (nextDown_ieee_positive (negate x))
  | otherwise = nextDown_ieee_positive x
{-# INLINE [1] nextTowardZero #-}
-- | Successor of a non-negative, non-NaN value. This is a helper for
-- 'nextUp' and 'nextDown'. It raises an error on NaN or negative input.
--
-- * Infinity is returned unchanged.
-- * Zero (either sign) maps to the smallest positive subnormal.
-- * Otherwise, write @x = m * 2^e@ as returned by 'decodeFloat':
--
--     * a normal value has its mantissa incremented;
--     * a subnormal value is first rescaled to the fixed subnormal exponent
--       @emin - d@ (a right shift, so radix 2 is assumed) and then
--       incremented.
--
-- If the increment carries out of the mantissa, 'encodeFloat' still
-- receives the exact value. At the largest finite input this is expected
-- to overflow to infinity.
nextUp_ieee_positive :: RealFloat a => a -> a
nextUp_ieee_positive x
  | isNaN x || x < 0 = error "nextUp_ieee_positive"
  | isInfinite x = x
  | x == 0 = encodeFloat 1 subnormalExp
  | otherwise = case decodeFloat x of
      (mant, ex)
        | ex >= subnormalExp -> encodeFloat (mant + 1) ex
        | otherwise ->
            encodeFloat ((mant `shiftR` (subnormalExp - ex)) + 1) subnormalExp
  where
    -- Exponent of the smallest positive subnormal: emin - digits.
    subnormalExp :: Int
    subnormalExp = fst (floatRange x) - floatDigits x
{-# INLINE nextUp_ieee_positive #-}
-- | Predecessor of a non-negative, non-NaN value. This is a helper for
-- 'nextDown', 'nextUp' and 'nextTowardZero'. It raises an error on NaN or
-- negative input, and assumes a radix-2 format (it builds @2^d@ with a
-- shift and rescales subnormals with a right shift).
--
-- * @+Infinity@ maps to the largest finite value.
-- * Zero (either sign) maps to the negative smallest subnormal.
-- * Otherwise, write @x = m * 2^e@ as returned by 'decodeFloat':
--
--     * A normal @x@ whose @m@ is a power of two sits at the bottom of its
--       binade, so its predecessor lies in the finer binade below:
--       @(2m - 1) * 2^(e-1)@. The exception is the smallest normal value,
--       whose predecessor is the largest subnormal, @(m - 1) * 2^e@.
--     * Any other normal @x@ just decrements @m@.
--     * A subnormal @x@ is rescaled to the fixed subnormal exponent
--       @emin - d@ and decremented (the smallest subnormal goes to @+0@).
nextDown_ieee_positive :: RealFloat a => a -> a
nextDown_ieee_positive x
  | isNaN x || x < 0 = error "nextDown_ieee_positive"
  | isInfinite x = encodeFloat ((1 `unsafeShiftL` d) - 1) (expMax - d)
  | x == 0 = encodeFloat (-1) (expMin - d)
  | otherwise = let m :: Integer
                    e :: Int
                    (m,e) = decodeFloat x
                in if expMin - d <= e
                   then -- normal
                     let m1 = m - 1
                     -- At e == expMin - d there is no finer binade below:
                     -- (2m - 1) * 2^(e-1) would fall off the subnormal grid
                     -- and round back up to x itself, so use m - 1 instead.
                     in if m .&. m1 == 0 && e /= expMin - d
                        then encodeFloat (2 * m - 1) (e - 1)
                        else encodeFloat m1 e
                   else -- subnormal
                     let m' = m `shiftR` (expMin - d - e)
                     in encodeFloat (m' - 1) (expMin - d)
  where
    d, expMin :: Int
    d = floatDigits x
    (expMin,expMax) = floatRange x
{-# INLINE nextDown_ieee_positive #-}
-- Replace the generic nextUp, nextDown and nextTowardZero with the Float-
-- and Double-specific versions defined below, which operate on the bit
-- pattern. Each rule is active only before simplifier phase 1 ([~1]). The
-- generic functions carry INLINE [1], so they are not inlined before phase
-- 1 and these rules can still match them by name.
{-# RULES
"nextUp/Float" [~1] nextUp = nextUpFloat
"nextUp/Double" [~1] nextUp = nextUpDouble
"nextDown/Float" [~1] nextDown = nextDownFloat
"nextDown/Double" [~1] nextDown = nextDownDouble
"nextTowardZero/Float" [~1] nextTowardZero = nextTowardZeroFloat
"nextTowardZero/Double" [~1] nextTowardZero = nextTowardZeroDouble
#-}
-- | 'nextUp' for 'Float', implemented on the IEEE binary32 bit pattern.
--
-- * NaN and @+Infinity@ are returned unchanged.
-- * @-0.0@ maps to the smallest positive subnormal.
-- * For other values the pattern is sign-magnitude:
--
--     * decrementing it moves a negative value toward zero (and takes
--       @-Infinity@ to the most negative finite value);
--     * incrementing it moves a non-negative value up (and takes @+0@ to
--       the smallest subnormal).
--
-- Raises an error if 'Float' does not have the binary32 parameters.
nextUpFloat :: Float -> Float
nextUpFloat x
  | not isBinary32 = error "rounded-hw assumes Float is IEEE binary32"
  | isNaN x = x
  | isNegativeZero x = encodeFloat 1 (lo - prec)
  | x < 0 = castWord32ToFloat (bits - 1)
  | bits == 0x7f80_0000 = x -- +Infinity
  | otherwise = castWord32ToFloat (bits + 1)
  where
    bits = castFloatToWord32 x
    prec, lo :: Int
    prec = floatDigits x
    (lo, hi) = floatRange x
    isBinary32 = isIEEE x && floatRadix x == 2 && prec == 24 && lo == -125 && hi == 128
-- | 'nextUp' for 'Double', implemented on the IEEE binary64 bit pattern.
--
-- * NaN (exponent all ones, except @-Infinity@) and @+Infinity@ are
--   returned unchanged.
-- * @-0.0@ maps to the smallest positive subnormal.
-- * For other values the pattern is sign-magnitude:
--
--     * decrementing it moves a negative value toward zero;
--     * incrementing it moves a non-negative value up.
--
-- Raises an error if 'Double' does not have the binary64 parameters.
nextUpDouble :: Double -> Double
nextUpDouble x
  | not isBinary64 = error "rounded-hw assumes Double is IEEE binary64"
  | expAllOnes && bits /= 0xfff0_0000_0000_0000 = x
  | bits == 0x8000_0000_0000_0000 = encodeFloat 1 (lo - prec)
  | testBit bits 63 = castWord64ToDouble (bits - 1)
  | otherwise = castWord64ToDouble (bits + 1)
  where
    bits = castDoubleToWord64 x
    expAllOnes = bits .&. 0x7ff0_0000_0000_0000 == 0x7ff0_0000_0000_0000
    prec, lo :: Int
    prec = floatDigits x
    (lo, hi) = floatRange x
    isBinary64 = isIEEE x && floatRadix x == 2 && prec == 53 && lo == -1021 && hi == 1024
-- | 'nextDown' for 'Float', implemented on the IEEE binary32 bit pattern.
--
-- * NaN and @-Infinity@ are returned unchanged.
-- * @+0.0@ maps to the negative smallest subnormal.
-- * Negative values (including @-0.0@) move away from zero by incrementing
--   the sign-magnitude pattern.
-- * Positive values (including @+Infinity@) move toward zero by
--   decrementing it.
--
-- Raises an error if 'Float' does not have the binary32 parameters.
nextDownFloat :: Float -> Float
nextDownFloat x
  | not isBinary32 = error "rounded-hw assumes Float is IEEE binary32"
  | isNaN x = x
  | isInfinite x && x < 0 = x
  | isNegativeZero x || x < 0 = castWord32ToFloat (bits + 1)
  | x == 0 = encodeFloat (-1) (lo - prec)
  | otherwise = castWord32ToFloat (bits - 1)
  where
    bits = castFloatToWord32 x
    prec, lo :: Int
    prec = floatDigits x
    (lo, hi) = floatRange x
    isBinary32 = isIEEE x && floatRadix x == 2 && prec == 24 && lo == -125 && hi == 128
-- | 'nextDown' for 'Double', implemented on the IEEE binary64 bit pattern.
--
-- * NaN (exponent all ones, except @+Infinity@) and @-Infinity@ are
--   returned unchanged.
-- * @+0.0@ maps to the negative smallest subnormal.
-- * Patterns with the sign bit set move away from zero by incrementing.
-- * All other patterns move toward zero by decrementing.
--
-- Raises an error if 'Double' does not have the binary64 parameters.
nextDownDouble :: Double -> Double
nextDownDouble x
  | not isBinary64 = error "rounded-hw assumes Double is IEEE binary64"
  | expAllOnes && bits /= 0x7ff0_0000_0000_0000 = x
  | bits == 0 = encodeFloat (-1) (lo - prec)
  | testBit bits 63 = castWord64ToDouble (bits + 1)
  | otherwise = castWord64ToDouble (bits - 1)
  where
    bits = castDoubleToWord64 x
    expAllOnes = bits .&. 0x7ff0_0000_0000_0000 == 0x7ff0_0000_0000_0000
    prec, lo :: Int
    prec = floatDigits x
    (lo, hi) = floatRange x
    isBinary64 = isIEEE x && floatRadix x == 2 && prec == 53 && lo == -1021 && hi == 1024
-- | 'nextTowardZero' for 'Float', implemented on the IEEE binary32 bit
-- pattern.
--
-- NaN and both zeros are returned unchanged. For any other value,
-- decrementing the sign-magnitude pattern shrinks the magnitude by one step
-- and keeps the sign; infinities become the largest finite value of the
-- same sign.
--
-- Raises an error if 'Float' does not have the binary32 parameters.
nextTowardZeroFloat :: Float -> Float
nextTowardZeroFloat x
  | not isBinary32 = error "rounded-hw assumes Float is IEEE binary32"
  | x == 0 || isNaN x = x
  | otherwise = castWord32ToFloat (bits - 1)
  where
    bits = castFloatToWord32 x
    prec, lo :: Int
    prec = floatDigits x
    (lo, hi) = floatRange x
    isBinary32 = isIEEE x && floatRadix x == 2 && prec == 24 && lo == -125 && hi == 128
-- | 'nextTowardZero' for 'Double', implemented on the IEEE binary64 bit
-- pattern.
--
-- NaN (exponent all ones with a non-zero fraction) and both zeros
-- (magnitude bits all clear) are returned unchanged. For any other value,
-- decrementing the sign-magnitude pattern shrinks the magnitude by one step
-- and keeps the sign.
--
-- Raises an error if 'Double' does not have the binary64 parameters.
nextTowardZeroDouble :: Double -> Double
nextTowardZeroDouble x
  | not isBinary64 = error "rounded-hw assumes Double is IEEE binary64"
  | nanBits = x
  | magnitude == 0 = x
  | otherwise = castWord64ToDouble (bits - 1)
  where
    bits = castDoubleToWord64 x
    magnitude = bits .&. 0x7fff_ffff_ffff_ffff
    nanBits = bits .&. 0x7ff0_0000_0000_0000 == 0x7ff0_0000_0000_0000
              && bits .&. 0x000f_ffff_ffff_ffff /= 0
    prec, lo :: Int
    prec = floatDigits x
    (lo, hi) = floatRange x
    isBinary64 = isIEEE x && floatRadix x == 2 && prec == 53 && lo == -1021 && hi == 1024
-- | Fused multiply-add: @x * y + z@ with a single rounding step.
--
-- Finite arguments are combined exactly as a 'Rational' and rounded once
-- by 'fromRational'. If any argument is NaN or infinite, the ordinary
-- @x * y + z@ is returned instead. When the exact result is zero, the sign
-- of that zero is taken from the ordinary @x * y + z@.
fusedMultiplyAdd :: RealFloat a => a -> a -> a -> a
fusedMultiplyAdd x y z
  | any nonFinite [x, y, z] = x * y + z
  | exact == 0, isNegativeZero (x * y + z) = -0
  | otherwise = fromRational exact
  where
    nonFinite v = isNaN v || isInfinite v
    exact = toRational x * toRational y + toRational z
{-# NOINLINE [1] fusedMultiplyAdd #-}
#ifdef USE_FFI
-- When USE_FFI is defined, the rules below rewrite Float and Double uses of
-- the generic fusedMultiplyAdd into direct calls to the C library functions
-- fmaf and fma, imported here.
foreign import ccall unsafe "fmaf"
  fusedMultiplyAddFloat :: Float -> Float -> Float -> Float
foreign import ccall unsafe "fma"
  fusedMultiplyAddDouble :: Double -> Double -> Double -> Double
{-# RULES
"fusedMultiplyAdd/Float" fusedMultiplyAdd = fusedMultiplyAddFloat
"fusedMultiplyAdd/Double" fusedMultiplyAdd = fusedMultiplyAddDouble
#-}
#endif
-- | Distance between two finite values, counted in units of the spacing of
-- representable values at the smaller of the two magnitudes.
--
-- Returns 'Nothing' if either argument is infinite or NaN. Every
-- representable value whose magnitude is at least that smaller magnitude
-- @m@ is expected to be an integer multiple of the spacing at @m@, so the
-- quotient should be integral. The 'error' branch guards that invariant.
distanceUlp :: RealFloat a => a -> a -> Maybe Integer
distanceUlp x y
  | isInfinite x || isInfinite y || isNaN x || isNaN y = Nothing
  | otherwise = let m = min (abs x) (abs y)
                    m' = nextUp m
                    -- Spacing of representable values at m. When m is the
                    -- largest finite value, nextUp m overflows to infinity
                    -- and would give a bogus spacing. In that case measure
                    -- the spacing just below m instead; m is not a power of
                    -- the radix there, so both spacings agree.
                    u | isInfinite m' = m - nextDown m
                      | otherwise = m' - m
                    v = (toRational y - toRational x) / toRational u
                in if denominator v == 1
                   then Just (abs (numerator v))
                   else error "distanceUlp"