{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.Rounded.Internal where
import Control.Exception (bracket, bracket_, throwIO, ArithException(Overflow))
import Data.Bits (shiftL, testBit)
import Data.Coerce (coerce)
import Data.Int (Int32)
import Data.Proxy (Proxy(..))
import Data.Ratio ((%))
import Foreign (with, alloca, allocaBytes, peek, sizeOf, nullPtr)
import Foreign.C (CInt(..), CIntMax(..))
import Foreign.C.String (peekCString)
import Numeric.LongDouble (LongDouble)
import System.IO.Unsafe (unsafePerformIO)
import GHC.Prim
( ByteArray#
, sizeofByteArray#
, copyByteArrayToAddr#
, newByteArray#
, copyAddrToByteArray#
, unsafeFreezeByteArray#
)
import GHC.Types (IO(..))
import GHC.Exts (Ptr(..), Int(..))
#if MIN_VERSION_base(4,9,0)
import Numeric (Floating(..))
#endif
import Numeric (readSigned, readFloat)
import Numeric.GMP.Utils (withInInteger, withOutInteger, withOutInteger_, withInRational)
import Numeric.GMP.Types (MPLimb)
import Numeric.MPFR.Types
import Numeric.MPFR.Raw
import Numeric.Rounded.Precision
import Numeric.Rounded.Rounding
type role Rounded phantom nominal
data Rounded (r :: RoundingMode) p = Rounded
{ roundedPrec :: !MPFRPrec
, roundedSign :: !MPFRSign
, roundedExp :: !MPFRExp
, roundedLimbs :: !ByteArray#
}
toDouble :: (Rounding r, Precision p) => Rounded r p -> Double
toDouble x = unsafePerformIO $ in_ x $ \xfr -> mpfr_get_d xfr (rnd x)
{-# RULES "realToFrac/toDouble" forall (x :: (Rounding r, Precision p) => Rounded r p) . realToFrac x = toDouble x #-}
toLongDouble :: (Rounding r, Precision p) => Rounded r p -> LongDouble
toLongDouble x = unsafePerformIO $ in_ x $ \xfr -> with 0 $ \yfr -> with 0 $ \ffr -> wrapped_mpfr_get_ld yfr xfr (rnd x) ffr >> peek yfr
{-# RULES "realToFrac/toLongDouble" forall (x :: (Rounding r, Precision p) => Rounded r p) . realToFrac x = toLongDouble x #-}
precRound :: (Rounding r, Precision p1, Precision p2) => Rounded r p1 -> Rounded r p2
precRound x = unsafePerformIO $ do
(Just y, _) <- in_ x $ \xfr -> out_ $ \yfr ->
mpfr_set yfr xfr (rnd x)
return y
toString :: (Rounding r, Precision p) => Rounded r p -> String
toString x = unsafePerformIO $ do
(s, e) <- in_ x $ \xfr -> with 0 $ \eptr -> do
s <- bracket (mpfr_get_str nullPtr eptr 10 0 xfr (fromIntegral (fromEnum TowardNearest))) mpfr_free_str peekCString
e <- peek eptr
return (s, fromIntegral e)
return $ case () of
_ | isNaN x -> "NaN"
| isInfinite x && sgn' == GT -> "Infinity"
| isInfinite x -> "-Infinity"
| isNegativeZero x -> "-0.0"
| sgn' == EQ -> "0.0"
| e < 0 ||
e >= threshold -> sign ++ take 1 digits ++ "." ++
dropTrailingZeroes (take (n - 1) (drop 1 digits0)) ++
"e" ++ show (e - 1)
| e == 0 -> sign ++ "0." ++
dropTrailingZeroes digits
| e < threshold -> sign ++ take e digits0 ++ "." ++
dropTrailingZeroes (take (n - e) (drop e digits0))
where
sgn' = sgn x
sign = case sgn' of
GT -> ""
EQ -> ""
LT -> "-"
threshold = 8
n = length digits
digits = case take 1 s of
"-" -> drop 1 s
_ -> s
digits0 = digits ++ repeat '0'
dropTrailingZeroes a = case dropWhile ('0' ==) (reverse a) of
"" -> "0"
b -> reverse b
instance (Rounding r, Precision p) => Show (Rounded r p) where
showsPrec p x = showParen (p >= 7 && take 1 s == "-") (s ++)
where s = toString x
instance (Rounding r, Precision p) => Read (Rounded r p) where
readsPrec _ = readSigned readFloat
unary
:: (Rounding r, Precision p1, Precision p2)
=> Unary -> Rounded r p1 -> Rounded r p2
unary f a = unsafePerformIO $ do
(Just c, _) <- in_ a $ \afr ->
out_ $ \cfr ->
f cfr afr (rnd a)
return c
unary' :: Rounding r => Unary -> Rounded r p -> Rounded r p
unary' f a = unsafePerformIO $ do
(Just c, _) <- in_ a $ \afr ->
out_' (roundedPrec a) $ \cfr ->
f cfr afr (rnd a)
return c
unary'' :: Unary -> Rounded r p -> Rounded r p
unary'' f a = unsafePerformIO $ do
(Just c, _) <- in_ a $ \afr ->
out_' (roundedPrec a) $ \cfr ->
f cfr afr (fromIntegral (fromEnum TowardNearest))
return c
abs' :: Rounded r p -> Rounded r p
abs' = unary'' mpfr_abs
negate' :: Rounded r p -> Rounded r p
negate' = unary'' mpfr_neg
(.-.), (.+.), (.*.) :: Rounding r => Rounded r p -> Rounded r p -> Rounded r p
(.-.) = binary' mpfr_sub
(.+.) = binary' mpfr_add
(.*.) = binary' mpfr_mul
infixl 6 .+., .-.
infixl 7 .*.
abs_, negate_, log_, exp_, sqrt_,
sin_, cos_, tan_, asin_, acos_, atan_,
sinh_, cosh_, tanh_, asinh_, acosh_, atanh_,
log1p_, expm1_
:: (Rounding r, Precision p1, Precision p2)
=> Rounded r p1 -> Rounded r p2
abs_ = unary mpfr_abs
negate_ = unary mpfr_neg
log_ = unary mpfr_log
exp_ = unary mpfr_exp
sqrt_ = unary mpfr_sqrt
sin_ = unary mpfr_sin
cos_ = unary mpfr_cos
tan_ = unary mpfr_tan
asin_ = unary mpfr_asin
acos_ = unary mpfr_acos
atan_ = unary mpfr_atan
sinh_ = unary mpfr_sinh
cosh_ = unary mpfr_cosh
tanh_ = unary mpfr_tanh
asinh_ = unary mpfr_asinh
acosh_ = unary mpfr_acosh
atanh_ = unary mpfr_atanh
log1p_ = unary mpfr_log1p
expm1_ = unary mpfr_expm1
binary
:: (Rounding r, Precision p1, Precision p2, Precision p3)
=> Binary -> Rounded r p1 -> Rounded r p2 -> Rounded r p3
binary f a b = unsafePerformIO $ do
(Just c, _) <- in_ a $ \afr ->
in_ b $ \bfr ->
out_ $ \cfr ->
f cfr afr bfr (rnd a)
return c
min_, max_, (!+!), (!-!), (!*!), (!/!), atan2_
:: (Rounding r, Precision p1, Precision p2, Precision p3)
=> Rounded r p1 -> Rounded r p2 -> Rounded r p3
min_ = binary mpfr_min
max_ = binary mpfr_max
(!+!) = binary mpfr_add
(!-!) = binary mpfr_sub
(!*!) = binary mpfr_mul
(!/!) = binary mpfr_div
atan2_ = binary mpfr_atan2
infixl 6 !+!, !-!
infixl 7 !*!, !/!
binary' :: Rounding r => Binary -> Rounded r p -> Rounded r p -> Rounded r p
binary' f a b = unsafePerformIO $ do
(Just c, _) <- in_ a $ \afr ->
in_ b $ \bfr ->
out_' (roundedPrec a) $ \cfr ->
f cfr afr bfr (rnd a)
return c
cmp' :: Comparison -> Rounded r p1 -> Rounded r p2 -> CInt
cmp' f a b = unsafePerformIO $
in_ a $ \afr ->
in_ b $ \bfr -> do
f afr bfr
cmp :: Comparison -> Rounded r p1 -> Rounded r p2 -> Bool
cmp f a b = cmp' f a b /= 0
(!==!), (!/=!), (!<=!), (!>=!), (!<!), (!>!)
:: (Precision p1, Precision p2)
=> Rounded r p1 -> Rounded r p2 -> Bool
(!==!) = cmp mpfr_equal_p
(!/=!) = cmp mpfr_lessgreater_p
(!<=!) = cmp mpfr_lessequal_p
(!>=!) = cmp mpfr_greaterequal_p
(!<!) = cmp mpfr_less_p
(!>!) = cmp mpfr_greater_p
infix 4 !==!, !/=!, !<=!, !>=!, !<!, !>!
compare_ :: (Precision p1, Precision p2) => Rounded r p1 -> Rounded r p2 -> Ordering
compare_ a b = compare (cmp' mpfr_cmp a b) 0
instance Eq (Rounded r p) where
(==) = cmp mpfr_equal_p
(/=) = cmp mpfr_lessgreater_p
instance Rounding r => Ord (Rounded r p) where
compare a b = compare (cmp' mpfr_cmp a b) 0
(<=) = cmp mpfr_lessequal_p
(>=) = cmp mpfr_greaterequal_p
(<) = cmp mpfr_less_p
(>) = cmp mpfr_greater_p
min = binary' mpfr_min
max = binary' mpfr_max
sgn :: (Rounding r, Precision p) => Rounded r p -> Ordering
sgn x = compare (unsafePerformIO $ in_ x mpfr_sgn) 0
instance (Rounding r, Precision p) => Num (Rounded r p) where
(+) = (.+.)
(-) = (.-.)
(*) = (.*.)
negate = negate'
fromInteger j = r where
r = unsafePerformIO $ do
if toInteger (minBound :: CIntMax) <= j && j <= toInteger (maxBound :: CIntMax)
then do
(Just x, _) <- out_ $ \jfr -> mpfr_set_sj jfr (fromInteger j :: CIntMax) (rnd r)
return x
else do
(Just x, _) <- withInInteger j $ \jz -> out_ $ \jfr -> mpfr_set_z jfr jz (rnd r)
return x
abs = abs'
signum x = case sgn x of
LT -> -1
EQ -> 0
GT -> 1
instance (Rounding r, Precision p) => Fractional (Rounded r p) where
fromRational q = r where
r = unsafePerformIO $ do
(Just x, _) <- withInRational q $ \qq -> out_ $ \qfr -> mpfr_set_q qfr qq (rnd r)
return x
(/) = (!/!)
fromInt :: (Rounding r, Precision p) => Int -> Rounded r p
fromInt i = r
where
r = unsafePerformIO $ do
(Just x, _) <- out_ $ \xfr -> mpfr_set_sj xfr (fromIntegral i) (rnd r)
return x
fromDouble :: (Rounding r, Precision p) => Double -> Rounded r p
fromDouble d = r
where
r = unsafePerformIO $ do
(Just x, _) <- out_ $ \xfr -> mpfr_set_d xfr d (rnd r)
return x
fromLongDouble :: (Rounding r, Precision p) => LongDouble -> Rounded r p
fromLongDouble d = r
where
r = unsafePerformIO $ do
(Just x, _) <- out_ $ \xfr -> with d $ \dp -> wrapped_mpfr_set_ld xfr dp (rnd r)
return x
inplace :: (Ptr MPFR -> IO ()) -> Rounded r p -> Rounded r p
inplace f y = unsafePerformIO $ do
(Just x, _) <- out_' (roundedPrec y) $ \xfr -> in_ y $ \yfr -> do
_ <- mpfr_set xfr yfr (fromIntegral (fromEnum TowardNearest))
f xfr
return x
succUlp, predUlp :: Rounded r p -> Rounded r p
succUlp = inplace mpfr_nextabove
predUlp = inplace mpfr_nextbelow
constant :: (Rounding r, Precision p) => Constant -> Rounded r p
constant k = r where
r = unsafePerformIO $ do
(Just x, _) <- out_ $ \xfr -> k xfr (rnd r)
return x
instance (Rounding r, Precision p) => Floating (Rounded r p) where
pi = kPi
exp = exp_
sqrt = sqrt_
log = log_
sin = sin_
tan = tan_
cos = cos_
asin = asin_
atan = atan_
acos = acos_
sinh = sinh_
tanh = tanh_
cosh = cosh_
asinh = asinh_
atanh = atanh_
acosh = acosh_
#if MIN_VERSION_base(4,9,0)
log1p = log1p_
expm1 = expm1_
#endif
toRational' :: Precision p => Rounded r p -> Rational
toRational' r
| e > 0 = fromIntegral (s `shiftL` e)
| otherwise = s % (1 `shiftL` negate e)
where (s, e) = decodeFloat' r
instance (Rounding r, Precision p) => Real (Rounded r p) where
toRational = toRational'
modf :: (Rounding r, Precision p) => Rounded r p -> (Rounded r p, Rounded r p)
modf x = unsafePerformIO $ do
(Just y, (Just z, _)) <- in_ x $ \xfr ->
out_ $ \yfr ->
out_ $ \zfr ->
mpfr_modf yfr zfr xfr (rnd x)
return (y, z)
toInteger' :: (Rounding r, Precision p) => Rounded r p -> Integer
toInteger' x = unsafePerformIO $
withOutInteger_ $ \yz ->
in_ x $ \xfr ->
with 0 $ \flagsptr -> do
e <- wrapped_mpfr_get_z yz xfr (rnd x) flagsptr
flags <- peek flagsptr
case testBit flags erangeBit of
False -> return e
True -> throwIO Overflow
instance (Rounding r, Precision p) => RealFrac (Rounded r p) where
properFraction r = (fromInteger (toInteger' i), f) where (i, f) = modf r
truncate = roundFunc truncate_
round = roundFunc round_
ceiling = roundFunc ceiling_
floor = roundFunc floor_
roundFunc :: (Integral i, Precision p) => (Rounded TowardNearest p -> Rounded TowardNearest p) -> Rounded r p -> i
roundFunc f = fromInteger . toInteger' . f . coerce
unary_ :: (Precision p1, Precision p2) => (Ptr MPFR -> Ptr MPFR -> IO CInt) -> Rounded r p1 -> Rounded r p2
unary_ f x = unsafePerformIO $ do
Just y <- withInRounded x $ \xp -> withOutRounded_ $ \yp -> f yp xp
return y
truncate_, ceiling_, floor_, round_ :: (Precision p1, Precision p2) => Rounded r p1 -> Rounded r p2
truncate_ = unary_ mpfr_trunc
ceiling_ = unary_ mpfr_ceil
floor_ = unary_ mpfr_floor
round_ = unary_ (\yp xp -> mpfr_rint yp xp (fromIntegral (fromEnum TowardNearest)))
tst :: (Precision p) => Test -> Rounded r p -> Bool
tst f x = unsafePerformIO $ in_ x $ \xfr -> do
t <- f xfr
return (t /= 0)
decodeFloat' :: Rounded r p -> (Integer, Int)
decodeFloat' x = case (unsafePerformIO $ do
in_ x $ \xfr -> withOutInteger $ \xz -> with 0 $ \flagsptr -> do
e <- wrapped_mpfr_get_z_2exp xz xfr flagsptr
flags <- peek flagsptr
case testBit flags erangeBit of
False -> return (fromIntegral e)
True -> throwIO Overflow) of
(0, _) -> (0, 0)
me -> me
encodeFloat' :: (Rounding r, Precision p) => Integer -> Int -> Rounded r p
encodeFloat' j e = r where
r = unsafePerformIO $ do
(Just x, _) <- withInInteger j $ \jz -> out_ $ \xfr -> mpfr_set_z_2exp xfr jz (fromIntegral e) (rnd r)
return x
instance (Rounding r, Precision p) => RealFloat (Rounded r p) where
floatRadix _ = 2
floatDigits = self where
self _ = p
p = precision (0 `asTypeIn` self)
asTypeIn :: a -> (a -> b) -> a
asTypeIn = const
floatRange _ = (MPFR_EMIN_DEFAULT, MPFR_EMAX_DEFAULT)
decodeFloat = decodeFloat'
encodeFloat = encodeFloat'
isNaN = tst mpfr_nan_p
isInfinite = tst mpfr_inf_p
isDenormalized _ = False
isNegativeZero r = tst mpfr_zero_p r && tst mpfr_signbit r
isIEEE _ = True
atan2 = atan2_
kPi :: (Rounding r, Precision p) => Rounded r p
kPi = constant mpfr_const_pi
kLog2 :: (Rounding r, Precision p) => Rounded r p
kLog2 = constant mpfr_const_log2
kEuler :: (Rounding r, Precision p) => Rounded r p
kEuler = constant mpfr_const_euler
kCatalan :: (Rounding r, Precision p) => Rounded r p
kCatalan = constant mpfr_const_catalan
in_' :: Rounded r p -> (MPFR -> IO a) -> IO a
in_' (Rounded p s e l) f = withByteArray l $ \ptr _bytes -> f MPFR
{ mpfrPrec = p
, mpfrSign = s
, mpfrExp = e
, mpfrD = ptr
}
in_ :: Rounded r p -> (Ptr MPFR -> IO a) -> IO a
in_ x f = in_' x $ \y -> with y f
out_' :: MPFRPrec -> (Ptr MPFR -> IO a) -> IO (Maybe (Rounded r p), a)
out_' p f = allocaBytes (precBytes p) $ \d -> with
MPFR{ mpfrPrec = p, mpfrSign = 0, mpfrExp = 0, mpfrD = d } $ \ptr -> do
a <- f ptr
MPFR{ mpfrPrec = p', mpfrSign = s', mpfrExp = e', mpfrD = d' } <- peek ptr
if p /= p' then return (Nothing, a) else
asByteArray d' (precBytes p') $ \l' -> return (Just (Rounded p' s' e' l'), a)
out_ :: Precision p => (Ptr MPFR -> IO a) -> IO (Maybe (Rounded r p), a)
out_ f = r where
r = out_' prec f
prec = fromIntegral (precision (t r))
t :: IO (Maybe t, a) -> t
t _ = undefined
withInRounded :: Rounded r p -> (Ptr MPFR -> IO a) -> IO a
withInRounded = in_
withOutRounded :: Precision p => (Ptr MPFR -> IO a) -> IO (Maybe (Rounded r p), a)
withOutRounded f = r where
r = alloca $ \ptr -> bracket_ (mpfr_init2 ptr prec) (mpfr_clear ptr) $ do
a <- f ptr
MPFR{ mpfrPrec = prec', mpfrSign = s, mpfrExp = e, mpfrD = d } <- peek ptr
if prec /= prec'
then return (Nothing, a)
else asByteArray d (precBytes prec) $ \l ->
return (Just (Rounded prec s e l), a)
prec = fromIntegral (precision (t r))
t :: IO (Maybe b, a) -> b
t _ = undefined
withOutRounded_ :: Precision p => (Ptr MPFR -> IO a) -> IO (Maybe (Rounded r p))
withOutRounded_ = fmap fst . withOutRounded
unsafeWithOutRounded :: Precision p => (Ptr MPFR -> IO a) -> IO (Maybe (Rounded r p), a)
unsafeWithOutRounded = out_
unsafeWithOutRounded_ :: Precision p => (Ptr MPFR -> IO a) -> IO (Maybe (Rounded r p))
unsafeWithOutRounded_ = fmap fst . out_
withInOutRounded :: Precision p => Rounded r p -> (Ptr MPFR -> IO a) -> IO (Maybe (Rounded r p), a)
withInOutRounded i f =
withOutRounded $ \ofr ->
in_ i $ \ifr -> do
_ <- mpfr_set ofr ifr (fromIntegral (fromEnum TowardNearest))
f ofr
withInOutRounded_ :: Precision p => Rounded r p -> (Ptr MPFR -> IO a) -> IO (Maybe (Rounded r p))
withInOutRounded_ x = fmap fst . withInOutRounded x
peekRounded :: Rounding r => Ptr MPFR -> (forall (p :: *) . Precision p => Rounded r p -> IO a) -> IO a
peekRounded ptr f = do
MPFR{ mpfrPrec = p', mpfrSign = s', mpfrExp = e', mpfrD = d' } <- peek ptr
asByteArray d' (precBytes p') $ \l' -> reifyPrecision (fromIntegral p') (wrap f (Rounded p' s' e' l'))
where
wrap :: forall (p :: *) (r :: RoundingMode) (a :: *) . (Rounding r, Precision p) => (forall (q :: *) . Precision q => Rounded r q -> IO a) -> Rounded r p -> Proxy p -> IO a
wrap g r = \_proxy -> g r
precBytes :: MPFRPrec -> Int
precBytes prec = bytesPerLimb * ((fromIntegral prec + bitsPerLimb1) `div` bitsPerLimb)
bytesPerLimb :: Int
bytesPerLimb = sizeOf (undefined :: MPLimb)
bitsPerLimb :: Int
bitsPerLimb = bytesPerLimb * 8
bitsPerLimb1 :: Int
bitsPerLimb1 = bitsPerLimb - 1
erangeBit :: Int
erangeBit = 5
rnd :: Rounding r => Rounded r p -> MPFRRnd
rnd = fromIntegral . fromEnum . rounding . proxyRounding
proxyRounding :: Rounded r p -> Proxy r
proxyRounding _ = Proxy
withByteArray :: ByteArray# -> (Ptr a -> Int -> IO r) -> IO r
withByteArray ba# f = do
let bytes = I# (sizeofByteArray# ba#)
allocaBytes bytes $ \ptr@(Ptr addr#) -> do
IO (\s -> (# copyByteArrayToAddr# ba# 0# addr# (sizeofByteArray# ba#) s, () #))
f ptr bytes
asByteArray :: Ptr a -> Int -> (ByteArray# -> IO r) -> IO r
asByteArray (Ptr addr#) (I# bytes#) f = do
IO $ \s# -> case newByteArray# bytes# s# of
(# s'#, mba# #) ->
case unsafeFreezeByteArray# mba# (copyAddrToByteArray# addr# mba# 0# bytes# s'#) of
(# s''#, ba# #) -> case f ba# of IO r -> r s''#