{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.Rounded.Hardware.Backend.ViaRational where
import Control.DeepSeq (NFData (..))
import Control.Exception (assert)
import Data.Coerce
import Data.Functor.Product
import Data.Ratio
import Data.Tagged
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import Numeric.Rounded.Hardware.Internal.Class
import Numeric.Rounded.Hardware.Internal.Constants
import Numeric.Rounded.Hardware.Internal.Conversion
import Numeric.Rounded.Hardware.Internal.FloatUtil (nextDown, nextUp)
-- | A wrapper around a floating-point type @a@.
--
-- The rounded-arithmetic instances in this module are attached to this
-- wrapper; they compute via 'Rational' (see 'toRational' uses in the
-- instances).  'Num' and 'Storable' are lifted from the wrapped type through
-- @GeneralizedNewtypeDeriving@, so ordinary arithmetic and memory layout are
-- those of @a@ itself.
newtype ViaRational a = ViaRational a
  deriving (Eq, Ord, Show, Generic, Num, Storable)
-- | Fully evaluating a 'ViaRational' means fully evaluating the wrapped value.
instance NFData a => NFData (ViaRational a) where
  rnf (ViaRational x) = rnf x
-- | Rounded ring operations for 'ViaRational'.
--
-- Operands that are NaN or infinite are handed to the native floating-point
-- operator.  Finite operands are converted with 'toRational', combined
-- exactly, and the exact result is rounded once with 'roundedFromRational'.
-- A 'Rational' cannot represent the sign of zero, so every case whose exact
-- result is zero is handled explicitly.
instance (RealFloat a, Num a, RealFloatConstants a) => RoundedRing (ViaRational a) where
  -- Addition.  @x + y@ is used as-is for NaN/infinite operands.
  roundedAdd r (ViaRational x) (ViaRational y)
    | isNaN x || isNaN y || isInfinite x || isInfinite y = ViaRational (x + y)
    | x == 0 && y == 0 =
        -- Two zeros of the same sign keep that sign; otherwise the sign of
        -- the result is chosen by the rounding mode.
        ViaRational $ if isNegativeZero x == isNegativeZero y then x else roundedZero
    | otherwise = case toRational x + toRational y of
        0 -> ViaRational roundedZero -- exact cancellation of non-zero operands
        z -> roundedFromRational r z
    where
      -- Zero produced by exact cancellation: negative only when rounding
      -- toward negative infinity.
      roundedZero = case r of
        ToNearest    -> 0
        TowardNegInf -> -0
        TowardInf    -> 0
        TowardZero   -> 0

  -- Subtraction.  Mirrors 'roundedAdd' with the sign of @y@ flipped.
  roundedSub r (ViaRational x) (ViaRational y)
    | isNaN x || isNaN y || isInfinite x || isInfinite y = ViaRational (x - y)
    | x == 0 && y == 0 =
        -- @x - y@ with zeros of opposite sign behaves like adding two zeros
        -- of the same sign, so the sign of @x@ is kept.
        ViaRational $ if isNegativeZero x /= isNegativeZero y then x else roundedZero
    | otherwise = case toRational x - toRational y of
        0 -> ViaRational roundedZero
        z -> roundedFromRational r z
    where
      roundedZero = case r of
        ToNearest    -> 0
        TowardNegInf -> -0
        TowardInf    -> 0
        TowardZero   -> 0

  -- Multiplication.  Any zero operand (of either sign) is delegated to the
  -- native operator: the product is an exact zero whose sign depends on the
  -- signs of both operands, and that sign would be lost by going through
  -- 'Rational' (e.g. @0 * (-1)@ must be negative zero).
  roundedMul r (ViaRational x) (ViaRational y)
    | isNaN x || isNaN y || isInfinite x || isInfinite y || x == 0 || y == 0 = ViaRational (x * y)
    | otherwise = roundedFromRational r (toRational x * toRational y)

  -- Fused multiply-add: @x * y + z@ with a single rounding.
  roundedFusedMultiplyAdd r (ViaRational x) (ViaRational y) (ViaRational z)
    | isNaN x || isNaN y || isNaN z || isInfinite x || isInfinite y = ViaRational (x * y + z)
      -- Finite @x@, @y@ with infinite @z@: the result is @z@.  Evaluating
      -- @x * y + z@ natively here could overflow the product to the opposite
      -- infinity and yield NaN.
    | isInfinite z = ViaRational z
    | otherwise = case toRational x * toRational y + toRational z of
        0
          -- Zero product plus a zero of the same sign keeps that sign.  The
          -- @z == 0@ test excludes exact cancellation of non-zero terms.
          | z == 0 && isNegativeZero (x * y) == isNegativeZero z -> ViaRational z
          | otherwise -> ViaRational roundedZero
        w -> roundedFromRational r w
    where
      roundedZero = case r of
        ToNearest    -> 0
        TowardNegInf -> -0
        TowardInf    -> 0
        TowardZero   -> 0

  -- Integer conversions are delegated to 'fromInt' / 'fromIntF'
  -- (presumably from the imported Conversion module -- confirm).
  roundedFromInteger r x = ViaRational (fromInt r x)
  intervalFromInteger x = case fromIntF x :: Product (Rounded 'TowardNegInf) (Rounded 'TowardInf) a of
    Pair a b -> (ViaRational <$> a, ViaRational <$> b)

  backendNameT = Tagged "via Rational"
  {-# INLINE roundedFromInteger #-}
  {-# INLINE intervalFromInteger #-}
  {-# SPECIALIZE instance RoundedRing (ViaRational Float) #-}
  {-# SPECIALIZE instance RoundedRing (ViaRational Double) #-}
-- | Rounded division and conversions for 'ViaRational'.
--
-- Finite, non-zero operands are divided exactly as 'Rational's and the
-- quotient is rounded once; every other operand combination is handed to the
-- native @/@ operator.
instance (RealFloat a, Num a, RealFloatConstants a) => RoundedFractional (ViaRational a) where
  roundedDiv r (ViaRational x) (ViaRational y)
    | any nativeOnly [x, y] = ViaRational (x / y)
    | otherwise = roundedFromRational r exactQuotient
    where
      -- NaN, infinities and zeros (of either sign) bypass the Rational path.
      nativeOnly v = isNaN v || isInfinite v || v == 0
      exactQuotient = toRational x / toRational y

  -- Rounds an exact rational by passing its numerator and denominator to
  -- 'fromRatio'.
  roundedFromRational r q = ViaRational (fromRatio r n d)
    where
      n = numerator q
      d = denominator q

  -- Special values are mapped explicitly because 'toRational' has no
  -- representation for NaN, infinities, or the sign of zero.
  roundedFromRealFloat r v
    | isNegativeZero v = ViaRational (-0)
    | isInfinite v = ViaRational (if v > 0 then 1 / 0 else -1 / 0)
    | isNaN v = ViaRational (0 / 0)
    | otherwise = roundedFromRational r (toRational v)

  -- Lower and upper results come from a single 'fromRatioF' call, as the two
  -- components of a 'Product'.
  intervalFromRational q =
    case fromRatioF (numerator q) (denominator q) :: Product (Rounded 'TowardNegInf) (Rounded 'TowardInf) a of
      Pair lo hi -> (fmap ViaRational lo, fmap ViaRational hi)

  {-# INLINE roundedFromRational #-}
  {-# INLINE intervalFromRational #-}
  {-# SPECIALIZE instance RoundedFractional (ViaRational Float) #-}
  {-# SPECIALIZE instance RoundedFractional (ViaRational Double) #-}
-- | Rounded square root for 'ViaRational'.
--
-- The native 'sqrt' result @y@ is used directly for round-to-nearest and for
-- inputs that are negative, NaN, or infinite.  For directed rounding of a
-- finite non-negative input, @y^2@ is compared exactly (as 'Rational') with
-- @x@ and @y@ is moved one step with 'nextUp' / 'nextDown' when it lies on
-- the wrong side.
instance (RealFloat a, RealFloatConstants a) => RoundedSqrt (ViaRational a) where
  roundedSqrt r (ViaRational x)
    -- Infinite @x@ is excluded: 'toRational' of an infinity is a finite
    -- number, so the exact comparison below would be meaningless and could
    -- step the (correct) infinite result down to a finite value.
    | r /= ToNearest && x >= 0 && not (isInfinite x) = ViaRational $
        case compare (toRational y ^ (2 :: Int)) (toRational x) of
          -- y is below the exact root: only upward rounding must adjust.
          LT | r == TowardInf ->
                 let z = nextUp y
                 in assert (toRational x < toRational z ^ (2 :: Int)) z
             | otherwise -> y
          -- y is the exact root.
          EQ -> y
          -- y is above the exact root: downward / toward-zero must adjust.
          GT | r == TowardInf -> y
             | otherwise ->
                 let z = nextDown y
                 in assert (toRational z ^ (2 :: Int) < toRational x) z
    | otherwise = ViaRational y
    where
      -- NOTE(review): the asserts above encode the assumption that the
      -- native 'sqrt' is off by at most one step -- confirm for all targets.
      y = sqrt x
-- Vector-level rounded operations for storable vectors of 'ViaRational'.
-- No methods are defined here, so the classes' default method
-- implementations (declared elsewhere) are used.
instance (RealFloat a, RealFloatConstants a, Storable a) => RoundedRing_Vector VS.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, Storable a) => RoundedFractional_Vector VS.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, Storable a) => RoundedSqrt_Vector VS.Vector (ViaRational a)
-- The same three instances for unboxed vectors; again only the class
-- defaults are used.
instance (RealFloat a, RealFloatConstants a, VU.Unbox a) => RoundedRing_Vector VU.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, VU.Unbox a) => RoundedFractional_Vector VU.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, VU.Unbox a) => RoundedSqrt_Vector VU.Vector (ViaRational a)
-- | Mutable unboxed vector of 'ViaRational' values, represented as a mutable
-- unboxed vector of the wrapped type.
newtype instance VUM.MVector s (ViaRational a) = MV_ViaRational (VUM.MVector s a)
-- | Immutable unboxed vector of 'ViaRational' values, represented as an
-- unboxed vector of the wrapped type.
newtype instance VU.Vector (ViaRational a) = V_ViaRational (VU.Vector a)
-- | Mutable-vector primitives for 'ViaRational'.
--
-- Every method unwraps 'MV_ViaRational' (and 'ViaRational' for elements),
-- forwards to the corresponding primitive of the underlying vector, and
-- re-wraps the result where one is produced.
instance VU.Unbox a => VGM.MVector VUM.MVector (ViaRational a) where
  basicLength (MV_ViaRational inner) = VGM.basicLength inner
  basicUnsafeSlice start len (MV_ViaRational inner) =
    MV_ViaRational $ VGM.basicUnsafeSlice start len inner
  basicOverlaps (MV_ViaRational lhs) (MV_ViaRational rhs) = VGM.basicOverlaps lhs rhs
  basicUnsafeNew len = fmap MV_ViaRational (VGM.basicUnsafeNew len)
  basicInitialize (MV_ViaRational inner) = VGM.basicInitialize inner
  basicUnsafeReplicate len (ViaRational e) =
    fmap MV_ViaRational (VGM.basicUnsafeReplicate len e)
  basicUnsafeRead (MV_ViaRational inner) ix = fmap ViaRational (VGM.basicUnsafeRead inner ix)
  basicUnsafeWrite (MV_ViaRational inner) ix (ViaRational e) = VGM.basicUnsafeWrite inner ix e
  basicClear (MV_ViaRational inner) = VGM.basicClear inner
  basicSet (MV_ViaRational inner) (ViaRational e) = VGM.basicSet inner e
  basicUnsafeCopy (MV_ViaRational dst) (MV_ViaRational src) = VGM.basicUnsafeCopy dst src
  basicUnsafeMove (MV_ViaRational dst) (MV_ViaRational src) = VGM.basicUnsafeMove dst src
  basicUnsafeGrow (MV_ViaRational inner) extra =
    fmap MV_ViaRational (VGM.basicUnsafeGrow inner extra)
-- | Immutable-vector primitives for 'ViaRational'.
--
-- Each method strips the 'V_ViaRational' / 'MV_ViaRational' wrapper, calls
-- the same primitive on the underlying vector, and wraps the result again
-- where needed.
instance VU.Unbox a => VG.Vector VU.Vector (ViaRational a) where
  basicUnsafeFreeze (MV_ViaRational inner) = fmap V_ViaRational (VG.basicUnsafeFreeze inner)
  basicUnsafeThaw (V_ViaRational inner) = fmap MV_ViaRational (VG.basicUnsafeThaw inner)
  basicLength (V_ViaRational inner) = VG.basicLength inner
  basicUnsafeSlice start len (V_ViaRational inner) =
    V_ViaRational $ VG.basicUnsafeSlice start len inner
  basicUnsafeIndexM (V_ViaRational inner) ix = fmap ViaRational (VG.basicUnsafeIndexM inner ix)
  basicUnsafeCopy (MV_ViaRational dst) (V_ViaRational src) = VG.basicUnsafeCopy dst src
  elemseq (V_ViaRational inner) (ViaRational e) rest = VG.elemseq inner e rest
-- | 'ViaRational' @a@ can be stored in unboxed vectors whenever @a@ can; the
-- required vector instances are the ones defined above in this module.
instance VU.Unbox a => VU.Unbox (ViaRational a)