{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.Rounded.Hardware.Backend.ViaRational where
import           Control.DeepSeq (NFData (..))
import           Control.Exception (assert)
import           Data.Coerce
import           Data.Functor.Product
import           Data.Ratio
import           Data.Tagged
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import           Foreign.Storable (Storable)
import           GHC.Generics (Generic)
import           Numeric.Rounded.Hardware.Internal.Class
import           Numeric.Rounded.Hardware.Internal.Constants
import           Numeric.Rounded.Hardware.Internal.Conversion
import           Numeric.Rounded.Hardware.Internal.FloatUtil (nextDown, nextUp)

newtype ViaRational a = ViaRational a
  deriving (Eq,Ord,Show,Generic,Num,Storable)

instance NFData a => NFData (ViaRational a)

instance (RealFloat a, Num a, RealFloatConstants a) => RoundedRing (ViaRational a) where
  roundedAdd r (ViaRational x) (ViaRational y)
    | isNaN x || isNaN y || isInfinite x || isInfinite y = ViaRational (x + y)
    | x == 0 && y == 0 = ViaRational $ if isNegativeZero x == isNegativeZero y
                                       then x
                                       else roundedZero
    | otherwise = case toRational x + toRational y of
                    0 -> ViaRational roundedZero
                    z -> roundedFromRational r z
    where roundedZero = case r of
            ToNearest    ->  0
            TowardNegInf -> -0
            TowardInf    ->  0
            TowardZero   ->  0
  roundedSub r (ViaRational x) (ViaRational y)
    | isNaN x || isNaN y || isInfinite x || isInfinite y = ViaRational (x - y)
    | x == 0 && y == 0 = ViaRational $ if isNegativeZero x /= isNegativeZero y
                                       then x
                                       else roundedZero
    | otherwise = case toRational x - toRational y of
                    0 -> ViaRational roundedZero
                    z -> roundedFromRational r z
    where roundedZero = case r of
            ToNearest    ->  0
            TowardNegInf -> -0
            TowardInf    ->  0
            TowardZero   ->  0
  roundedMul r (ViaRational x) (ViaRational y)
    | isNaN x || isNaN y || isInfinite x || isInfinite y || isNegativeZero x || isNegativeZero y = ViaRational (x * y)
    | otherwise = roundedFromRational r (toRational x * toRational y)
  roundedFusedMultiplyAdd r (ViaRational x) (ViaRational y) (ViaRational z)
    | isNaN x || isNaN y || isNaN z || isInfinite x || isInfinite y || isInfinite z = ViaRational (x * y + z)
    | otherwise = case toRational x * toRational y + toRational z of
                    0 -> if z == 0 && isNegativeZero (x * y) == isNegativeZero z
                         then ViaRational z
                         else ViaRational roundedZero
                    w -> roundedFromRational r w
      where roundedZero = case r of
              ToNearest    ->  0
              TowardNegInf -> -0
              TowardInf    ->  0
              TowardZero   ->  0
  roundedFromInteger r x = ViaRational (fromInt r x)
  intervalFromInteger x = case fromIntF x :: Product (Rounded 'TowardNegInf) (Rounded 'TowardInf) a of
    Pair a b -> (ViaRational <$> a, ViaRational <$> b)
  backendNameT = Tagged "via Rational"
  {-# INLINE roundedFromInteger #-}
  {-# INLINE intervalFromInteger #-}
  {-# SPECIALIZE instance RoundedRing (ViaRational Float) #-}
  {-# SPECIALIZE instance RoundedRing (ViaRational Double) #-}

instance (RealFloat a, Num a, RealFloatConstants a) => RoundedFractional (ViaRational a) where
  roundedDiv r (ViaRational x) (ViaRational y)
    | isNaN x || isNaN y || isInfinite x || isInfinite y || x == 0 || y == 0 = ViaRational (x / y)
    | otherwise = roundedFromRational r (toRational x / toRational y)
  roundedFromRational r x = ViaRational $ fromRatio r (numerator x) (denominator x)
  roundedFromRealFloat r x | isNaN x = ViaRational (0/0)
                           | isInfinite x = ViaRational (if x > 0 then 1/0 else -1/0)
                           | isNegativeZero x = ViaRational (-0)
                           | otherwise = roundedFromRational r (toRational x)
  intervalFromRational x = case fromRatioF (numerator x) (denominator x) :: Product (Rounded 'TowardNegInf) (Rounded 'TowardInf) a of
    Pair a b -> (ViaRational <$> a, ViaRational <$> b)
  {-# INLINE roundedFromRational #-}
  {-# INLINE intervalFromRational #-}
  {-# SPECIALIZE instance RoundedFractional (ViaRational Float) #-}
  {-# SPECIALIZE instance RoundedFractional (ViaRational Double) #-}

instance (RealFloat a, RealFloatConstants a) => RoundedSqrt (ViaRational a) where
  roundedSqrt r (ViaRational x)
    | r /= ToNearest && x >= 0 = ViaRational $
      case compare ((toRational y) ^ (2 :: Int)) (toRational x) of
        LT | r == TowardInf -> let z = nextUp y
                               in assert (toRational x < (toRational z) ^ (2 :: Int)) z
           | otherwise -> y
        EQ -> y
        GT | r == TowardInf -> y
           | otherwise -> let z = nextDown y
                          in assert ((toRational z) ^ (2 :: Int) < toRational x) z
    | otherwise = ViaRational y
    where y = sqrt x

instance (RealFloat a, RealFloatConstants a, Storable a) => RoundedRing_Vector VS.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, Storable a) => RoundedFractional_Vector VS.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, Storable a) => RoundedSqrt_Vector VS.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, VU.Unbox a) => RoundedRing_Vector VU.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, VU.Unbox a) => RoundedFractional_Vector VU.Vector (ViaRational a)
instance (RealFloat a, RealFloatConstants a, VU.Unbox a) => RoundedSqrt_Vector VU.Vector (ViaRational a)

--
-- instance for Data.Vector.Unboxed.Unbox
--

newtype instance VUM.MVector s (ViaRational a) = MV_ViaRational (VUM.MVector s a)
newtype instance VU.Vector (ViaRational a) = V_ViaRational (VU.Vector a)

instance VU.Unbox a => VGM.MVector VUM.MVector (ViaRational a) where
  basicLength (MV_ViaRational mv) = VGM.basicLength mv
  basicUnsafeSlice i l (MV_ViaRational mv) = MV_ViaRational (VGM.basicUnsafeSlice i l mv)
  basicOverlaps (MV_ViaRational mv) (MV_ViaRational mv') = VGM.basicOverlaps mv mv'
  basicUnsafeNew l = MV_ViaRational <$> VGM.basicUnsafeNew l
  basicInitialize (MV_ViaRational mv) = VGM.basicInitialize mv
  basicUnsafeReplicate i x = MV_ViaRational <$> VGM.basicUnsafeReplicate i (coerce x)
  basicUnsafeRead (MV_ViaRational mv) i = coerce <$> VGM.basicUnsafeRead mv i
  basicUnsafeWrite (MV_ViaRational mv) i x = VGM.basicUnsafeWrite mv i (coerce x)
  basicClear (MV_ViaRational mv) = VGM.basicClear mv
  basicSet (MV_ViaRational mv) x = VGM.basicSet mv (coerce x)
  basicUnsafeCopy (MV_ViaRational mv) (MV_ViaRational mv') = VGM.basicUnsafeCopy mv mv'
  basicUnsafeMove (MV_ViaRational mv) (MV_ViaRational mv') = VGM.basicUnsafeMove mv mv'
  basicUnsafeGrow (MV_ViaRational mv) n = MV_ViaRational <$> VGM.basicUnsafeGrow mv n

instance VU.Unbox a => VG.Vector VU.Vector (ViaRational a) where
  basicUnsafeFreeze (MV_ViaRational mv) = V_ViaRational <$> VG.basicUnsafeFreeze mv
  basicUnsafeThaw (V_ViaRational v) = MV_ViaRational <$> VG.basicUnsafeThaw v
  basicLength (V_ViaRational v) = VG.basicLength v
  basicUnsafeSlice i l (V_ViaRational v) = V_ViaRational (VG.basicUnsafeSlice i l v)
  basicUnsafeIndexM (V_ViaRational v) i = coerce <$> VG.basicUnsafeIndexM v i
  basicUnsafeCopy (MV_ViaRational mv) (V_ViaRational v) = VG.basicUnsafeCopy mv v
  elemseq (V_ViaRational v) x y = VG.elemseq v (coerce x) y

instance VU.Unbox a => VU.Unbox (ViaRational a)