{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Numeric.Rounded.Hardware.Interval
( Interval(..)
, increasing
, maxI
, minI
, powInt
, null
, inf
, sup
, width
, widthUlp
, hull
, intersection
) where
import Control.DeepSeq (NFData (..))
import Control.Monad
import Control.Monad.ST
import qualified Data.Array.Base as A
import Data.Coerce
import Data.Ix
import Data.Primitive
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import GHC.Float (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import Numeric.Rounded.Hardware.Internal
import qualified Numeric.Rounded.Hardware.Interval.Class as C
import qualified Numeric.Rounded.Hardware.Interval.NonEmpty as NE
import Prelude hiding (null)
-- | An interval that is either a pair of endpoints or empty.
data Interval a
  = I !(Rounded 'TowardNegInf a) !(Rounded 'TowardInf a)
    -- ^ Lower endpoint (tagged 'TowardNegInf) followed by the upper
    -- endpoint (tagged 'TowardInf). Both fields are strict.
  | Empty
    -- ^ The interval containing no points.
  deriving (Show, Generic)
-- | No methods are written out here, so the class defaults are used
-- ('Interval' derives 'Generic').
instance NFData a => NFData (Interval a)
-- | Apply a rounding-polymorphic function to both endpoints of an interval.
--
-- The same function is used for the lower endpoint (at 'TowardNegInf) and
-- the upper endpoint (at 'TowardInf). 'Empty' is returned unchanged.
-- NOTE(review): the name suggests the argument is expected to be a
-- monotonically increasing function; nothing here checks that.
increasing :: (forall r. Rounding r => Rounded r a -> Rounded r a) -> Interval a -> Interval a
increasing f iv = case iv of
  Empty -> Empty
  I lo hi -> I (f lo) (f hi)
{-# INLINE increasing #-}
-- | Arithmetic is forwarded to the 'NE.Interval' instance through
-- 'liftUnaryNE' / 'liftBinaryNE', so any 'Empty' operand yields 'Empty'.
-- Integer literals become the pair produced by 'intervalFromInteger'.
instance (Num a, RoundedRing a) => Num (Interval a) where
  (+) = liftBinaryNE (+)
  (-) = liftBinaryNE (-)
  (*) = liftBinaryNE (*)
  negate = liftUnaryNE negate
  abs = liftUnaryNE abs
  signum = liftUnaryNE signum
  fromInteger n = case intervalFromInteger n of
    (lo, hi) -> I lo hi
  {-# INLINE (+) #-}
  {-# INLINE (-) #-}
  {-# INLINE negate #-}
  {-# INLINE (*) #-}
  {-# INLINE abs #-}
  {-# INLINE signum #-}
  {-# INLINE fromInteger #-}
-- | Division and reciprocal are forwarded to the 'NE.Interval' instance;
-- rational literals become the pair produced by 'intervalFromRational'.
instance (Num a, RoundedFractional a) => Fractional (Interval a) where
  (/) = liftBinaryNE (/)
  recip = liftUnaryNE recip
  fromRational q = case intervalFromRational q of
    (lo, hi) -> I lo hi
  {-# INLINE recip #-}
  {-# INLINE (/) #-}
  {-# INLINE fromRational #-}
-- | Endpoint-wise maximum of two intervals.
--
-- The result is 'Empty' as soon as either argument is 'Empty'.
maxI :: Ord a => Interval a -> Interval a -> Interval a
maxI u v = case (u, v) of
  (I lo1 hi1, I lo2 hi2) -> I (max lo1 lo2) (max hi1 hi2)
  _ -> Empty
{-# SPECIALIZE maxI :: Interval Float -> Interval Float -> Interval Float #-}
{-# SPECIALIZE maxI :: Interval Double -> Interval Double -> Interval Double #-}
-- | Endpoint-wise minimum of two intervals.
--
-- The result is 'Empty' as soon as either argument is 'Empty'.
minI :: Ord a => Interval a -> Interval a -> Interval a
minI u v = case (u, v) of
  (I lo1 hi1, I lo2 hi2) -> I (min lo1 lo2) (min hi1 hi2)
  _ -> Empty
{-# SPECIALIZE minI :: Interval Float -> Interval Float -> Interval Float #-}
{-# SPECIALIZE minI :: Interval Double -> Interval Double -> Interval Double #-}
-- | Raise an interval to a non-negative 'Int' power.
--
-- The exponentiation itself is done with Prelude's '^' on the 'Rounded'
-- endpoints, so a negative exponent is an error exactly as it is for '^'.
-- 'Empty' is returned unchanged.
--
-- * Odd @n@: the map is increasing, so each endpoint maps to the
--   corresponding endpoint. For a /negative/ endpoint the power is computed
--   on the magnitude with the opposite rounding direction and then negated
--   (negation is exact). Applying '^' directly to a negative directed-rounded
--   value would round the positive intermediate squares the wrong way for
--   the final sign and could produce a bound that does not enclose the true
--   result.
-- * Even @n@, interval non-negative: endpoints map directly.
-- * Even @n@, interval non-positive: endpoints swap after taking 'abs'.
-- * Even @n@, interval straddling zero: the lower bound is @0@ and the upper
--   bound is the larger of the two endpoint powers.
powInt :: (Ord a, Num a, RoundedRing a) => Interval a -> Int -> Interval a
powInt (I a a') n
  | odd n = I lowerOdd upperOdd
  | 0 <= a = I (a ^ n) (a' ^ n)
  | a' <= 0 = I (asDown (abs a') ^ n) (asUp (abs a) ^ n)
  | otherwise = I 0 (max (asUp (abs a) ^ n) (a' ^ n))
  where
    -- Re-tag an (exact) value with the other rounding direction so that the
    -- subsequent multiplications performed by '^' round that way.
    asUp :: Rounded 'TowardNegInf b -> Rounded 'TowardInf b
    asUp = coerce
    asDown :: Rounded 'TowardInf b -> Rounded 'TowardNegInf b
    asDown = coerce
    -- Lower bound for odd n: for a < 0, a^n = -(|a|^n), so an upper bound of
    -- |a|^n negated is a valid lower bound.
    lowerOdd
      | a < 0 = asDown (negate (asUp (abs a) ^ n))
      | otherwise = a ^ n
    -- Upper bound for odd n: for a' < 0, a lower bound of |a'|^n negated is
    -- a valid upper bound.
    upperOdd
      | a' < 0 = asUp (negate (asDown (abs a') ^ n))
      | otherwise = a' ^ n
powInt Empty _ = Empty
{-# SPECIALIZE powInt :: Interval Float -> Int -> Interval Float #-}
{-# SPECIALIZE powInt :: Interval Double -> Int -> Interval Double #-}
-- | 'True' exactly when the interval is 'Empty'.
null :: Interval a -> Bool
null iv = case iv of
  Empty -> True
  I{} -> False
-- | Lower endpoint of a non-empty interval.
--
-- Partial: calls 'error' when given 'Empty'.
inf :: Interval a -> Rounded 'TowardNegInf a
inf iv = case iv of
  I lo _ -> lo
  Empty -> error "empty interval"
-- | Upper endpoint of a non-empty interval.
--
-- Partial: calls 'error' when given 'Empty'.
sup :: Interval a -> Rounded 'TowardInf a
sup iv = case iv of
  I _ hi -> hi
  Empty -> error "empty interval"
-- | Upper endpoint minus lower endpoint, with the subtraction carried out
-- at the 'TowardInf type (the lower endpoint is re-tagged via 'coerce').
-- The width of 'Empty' is @0@.
width :: (Num a, RoundedRing a) => Interval a -> Rounded 'TowardInf a
width iv = case iv of
  Empty -> 0
  I lo hi -> hi - coerce lo
-- | Result of 'distanceUlp' applied to the unwrapped lower and upper
-- endpoints. 'Empty' yields @Just 0@.
widthUlp :: (RealFloat a) => Interval a -> Maybe Integer
widthUlp iv = case iv of
  Empty -> Just 0
  I lo hi -> distanceUlp (getRounded lo) (getRounded hi)
-- | Smallest interval containing both arguments: the lesser of the lower
-- endpoints paired with the greater of the upper endpoints. When one
-- argument is 'Empty' the other is returned as is.
hull :: RoundedRing a => Interval a -> Interval a -> Interval a
hull u v = case (u, v) of
  (I lo1 hi1, I lo2 hi2) -> I (min lo1 lo2) (max hi1 hi2)
  (Empty, _) -> v
  (_, Empty) -> u
-- | Intersection of two intervals.
--
-- The candidate is the greater lower endpoint paired with the lesser upper
-- endpoint. It is returned only if its lower endpoint does not exceed its
-- upper endpoint; otherwise, or if either argument is 'Empty', the result
-- is 'Empty'.
intersection :: RoundedRing a => Interval a -> Interval a -> Interval a
intersection (I lo1 hi1) (I lo2 hi2) =
  let lo = max lo1 lo2
      hi = min hi1 hi2
   in if getRounded lo <= getRounded hi
        then I lo hi
        else Empty
intersection _ _ = Empty
-- | Lift a unary operation on non-empty intervals ('NE.Interval') to
-- 'Interval': endpoints are repackaged as an 'NE.I', the operation is
-- applied, and the result is unpacked again. 'Empty' maps to 'Empty'
-- without calling the operation.
liftUnaryNE :: (NE.Interval a -> NE.Interval a) -> Interval a -> Interval a
liftUnaryNE _ Empty = Empty
liftUnaryNE op (I lo hi) =
  case op (NE.I lo hi) of
    NE.I lo' hi' -> I lo' hi'
{-# INLINE [1] liftUnaryNE #-}
-- | Lift a binary operation on non-empty intervals ('NE.Interval') to
-- 'Interval'. The operation is only called when both arguments are
-- non-empty; if either is 'Empty' the result is 'Empty'.
liftBinaryNE :: (NE.Interval a -> NE.Interval a -> NE.Interval a) -> Interval a -> Interval a -> Interval a
liftBinaryNE op (I lo1 hi1) (I lo2 hi2) =
  case op (NE.I lo1 hi1) (NE.I lo2 hi2) of
    NE.I lo hi -> I lo hi
liftBinaryNE _ _ _ = Empty
{-# INLINE [1] liftBinaryNE #-}
-- | 'pi' is the interval built from 'pi_down' and 'pi_up'. Every other
-- method is forwarded to the 'NE.Interval' instance through 'liftUnaryNE'
-- or 'liftBinaryNE', so 'Empty' arguments give 'Empty' results.
instance (Num a, RoundedFractional a, RoundedSqrt a, Eq a, RealFloat a, RealFloatConstants a) => Floating (Interval a) where
  pi = I pi_down pi_up

  -- Exponentials, logarithms, roots and powers.
  exp = liftUnaryNE exp
  log = liftUnaryNE log
  sqrt = liftUnaryNE sqrt
  (**) = liftBinaryNE (**)
  logBase = liftBinaryNE logBase

  -- Trigonometric functions and their inverses.
  sin = liftUnaryNE sin
  cos = liftUnaryNE cos
  tan = liftUnaryNE tan
  asin = liftUnaryNE asin
  acos = liftUnaryNE acos
  atan = liftUnaryNE atan

  -- Hyperbolic functions and their inverses.
  sinh = liftUnaryNE sinh
  cosh = liftUnaryNE cosh
  tanh = liftUnaryNE tanh
  asinh = liftUnaryNE asinh
  acosh = liftUnaryNE acosh
  atanh = liftUnaryNE atanh

  -- Extra 'Floating' methods imported from "GHC.Float".
  log1p = liftUnaryNE log1p
  expm1 = liftUnaryNE expm1
  log1pexp = liftUnaryNE log1pexp
  log1mexp = liftUnaryNE log1mexp

  {-# INLINE exp #-}
  {-# INLINE log #-}
  {-# INLINE sqrt #-}
  {-# INLINE (**) #-}
  {-# INLINE logBase #-}
  {-# INLINE sin #-}
  {-# INLINE cos #-}
  {-# INLINE tan #-}
  {-# INLINE asin #-}
  {-# INLINE acos #-}
  {-# INLINE atan #-}
  {-# INLINE sinh #-}
  {-# INLINE cosh #-}
  {-# INLINE tanh #-}
  {-# INLINE asinh #-}
  {-# INLINE acosh #-}
  {-# INLINE atanh #-}
  {-# INLINE log1p #-}
  {-# INLINE expm1 #-}
  {-# INLINE log1pexp #-}
  {-# INLINE log1mexp #-}
-- | 'C.IsInterval' instance for possibly-empty intervals.
--
-- 'width', 'hull' and 'intersection' reuse the top-level functions of the
-- same name. The comparison predicates each spell out their answer for
-- 'Empty' arguments explicitly; see the individual clauses.
instance (Num a, RoundedRing a, RealFloat a) => C.IsInterval (Interval a) where
  type EndPoint (Interval a) = a

  makeInterval = I
  width = width
  hull = hull
  intersection = intersection

  -- An empty interval has no endpoints to hand to the continuation.
  withEndPoints _ Empty = Empty
  withEndPoints k (I lo hi) = k lo hi

  maybeIntersection u v = case intersection u v of
    Empty -> Nothing
    w -> Just w

  -- Equal endpoints, or both empty.
  equalAsSet (I lo hi) (I lo' hi') = lo == lo' && hi == hi'
  equalAsSet Empty Empty = True
  equalAsSet _ _ = False

  -- 'Empty' is a subset of everything; nothing non-empty fits in 'Empty'.
  subset (I lo hi) (I lo' hi') = lo' <= lo && hi <= hi'
  subset Empty _ = True
  subset I{} Empty = False

  weaklyLess (I lo hi) (I lo' hi') = lo <= lo' && hi <= hi'
  weaklyLess Empty Empty = True
  weaklyLess _ _ = False

  -- Holds whenever either side is empty.
  precedes (I _ hi) (I lo' _) = getRounded hi <= getRounded lo'
  precedes _ _ = True

  -- Strict containment of endpoints, except that equal infinite endpoints
  -- are accepted as well.
  interior (I lo hi) (I lo' hi') =
    getRounded lo' `below` getRounded lo && getRounded hi `below` getRounded hi'
    where
      s `below` t = s < t || (s == t && isInfinite s)
  interior Empty _ = True
  interior I{} Empty = False

  -- Strict ordering of endpoints, except that equal infinite endpoints
  -- are accepted as well.
  strictLess (I lo hi) (I lo' hi') =
    getRounded lo `below` getRounded lo' && getRounded hi `below` getRounded hi'
    where
      s `below` t = s < t || (s == t && isInfinite s)
  strictLess Empty Empty = True
  strictLess _ _ = False

  -- Holds whenever either side is empty.
  strictPrecedes (I _ hi) (I lo' _) = getRounded hi < getRounded lo'
  strictPrecedes _ _ = True

  -- Holds whenever either side is empty.
  disjoint (I lo hi) (I lo' hi') =
    getRounded hi < getRounded lo' || getRounded hi' < getRounded lo
  disjoint _ _ = True
-- | Mutable unboxed vectors of intervals are represented as mutable unboxed
-- vectors of endpoint pairs @(a, a)@.
newtype instance VUM.MVector s (Interval a) = MV_Interval (VUM.MVector s (a, a))
-- | Immutable unboxed vectors of intervals are represented as unboxed
-- vectors of endpoint pairs @(a, a)@.
newtype instance VU.Vector (Interval a) = V_Interval (VU.Vector (a, a))
-- | Encode an interval as a plain pair of endpoints.
--
-- 'Empty' is encoded as @(1/0, -1/0)@, a pair whose second component is
-- smaller than its first, which 'pairToInterval' decodes back to 'Empty'.
intervalToPair :: Fractional a => Interval a -> (a, a)
intervalToPair iv = case iv of
  I (Rounded lo) (Rounded hi) -> (lo, hi)
  Empty -> (1/0, -1/0)
{-# INLINE intervalToPair #-}
-- | Decode a pair of endpoints into an interval.
--
-- A pair whose second component is strictly less than its first is read
-- as 'Empty'; any other pair becomes @I@ with those endpoints.
pairToInterval :: Ord a => (a, a) -> Interval a
pairToInterval (lo, hi) =
  if hi < lo
    then Empty
    else I (Rounded lo) (Rounded hi)
{-# INLINE pairToInterval #-}
-- | Mutable-vector operations for intervals, delegated to the wrapped
-- vector of endpoint pairs. Elements are converted with 'intervalToPair'
-- when stored and with 'pairToInterval' when read.
instance (VU.Unbox a, Ord a, Fractional a) => VGM.MVector VUM.MVector (Interval a) where
  -- Structural operations: unwrap, delegate, re-wrap where needed.
  basicLength (MV_Interval pairs) = VGM.basicLength pairs
  basicUnsafeSlice off len (MV_Interval pairs) = MV_Interval (VGM.basicUnsafeSlice off len pairs)
  basicOverlaps (MV_Interval p) (MV_Interval q) = VGM.basicOverlaps p q
  basicUnsafeNew len = MV_Interval <$> VGM.basicUnsafeNew len
  basicInitialize (MV_Interval pairs) = VGM.basicInitialize pairs
  basicClear (MV_Interval pairs) = VGM.basicClear pairs
  basicUnsafeCopy (MV_Interval dst) (MV_Interval src) = VGM.basicUnsafeCopy dst src
  basicUnsafeMove (MV_Interval dst) (MV_Interval src) = VGM.basicUnsafeMove dst src
  basicUnsafeGrow (MV_Interval pairs) extra = MV_Interval <$> VGM.basicUnsafeGrow pairs extra

  -- Element operations: convert between 'Interval' and its pair encoding.
  basicUnsafeReplicate len e = MV_Interval <$> VGM.basicUnsafeReplicate len (intervalToPair e)
  basicUnsafeRead (MV_Interval pairs) ix = pairToInterval <$> VGM.basicUnsafeRead pairs ix
  basicUnsafeWrite (MV_Interval pairs) ix e = VGM.basicUnsafeWrite pairs ix (intervalToPair e)
  basicSet (MV_Interval pairs) e = VGM.basicSet pairs (intervalToPair e)

  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicOverlaps #-}
  {-# INLINE basicUnsafeNew #-}
  {-# INLINE basicInitialize #-}
  {-# INLINE basicUnsafeReplicate #-}
  {-# INLINE basicUnsafeRead #-}
  {-# INLINE basicUnsafeWrite #-}
  {-# INLINE basicClear #-}
  {-# INLINE basicSet #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE basicUnsafeMove #-}
  {-# INLINE basicUnsafeGrow #-}
-- | Immutable-vector operations for intervals, delegated to the wrapped
-- vector of endpoint pairs. Indexing decodes with 'pairToInterval'.
instance (VU.Unbox a, Ord a, Fractional a) => VG.Vector VU.Vector (Interval a) where
  basicUnsafeFreeze (MV_Interval pairs) = V_Interval <$> VG.basicUnsafeFreeze pairs
  basicUnsafeThaw (V_Interval pairs) = MV_Interval <$> VG.basicUnsafeThaw pairs
  basicLength (V_Interval pairs) = VG.basicLength pairs
  basicUnsafeSlice off len (V_Interval pairs) = V_Interval (VG.basicUnsafeSlice off len pairs)
  basicUnsafeIndexM (V_Interval pairs) ix = pairToInterval <$> VG.basicUnsafeIndexM pairs ix
  basicUnsafeCopy (MV_Interval dst) (V_Interval src) = VG.basicUnsafeCopy dst src
  -- Forces the element to weak head normal form before returning the result.
  elemseq (V_Interval _) e r = e `seq` r
  {-# INLINE basicUnsafeFreeze #-}
  {-# INLINE basicUnsafeThaw #-}
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicUnsafeIndexM #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE elemseq #-}
-- | Declares 'Interval' as unboxable; the instance has no methods of its
-- own and carries the same constraints as the vector instances for
-- 'Interval'.
instance (VU.Unbox a, Ord a, Fractional a) => VU.Unbox (Interval a)
-- | Unboxed mutable arrays of intervals in 'ST'.
--
-- Each interval occupies two consecutive @a@ slots of the byte array:
-- element @i@ stores its pair encoding at slots @2*i@ and @2*i + 1@
-- (see 'intervalToPair' / 'pairToInterval').
instance (Prim a, Ord a, Fractional a) => A.MArray (A.STUArray s) (Interval a) (ST s) where
  getBounds (A.STUArray lo hi _ _) = pure (lo, hi)

  getNumElements (A.STUArray _ _ count _) = pure count

  unsafeNewArray_ = A.newArray_

  -- Allocates room for two @a@ values per element and fills every slot
  -- with zero.
  newArray_ bnds@(lo, hi) = do
    let count = rangeSize bnds
    buf@(MutableByteArray rawBuf) <- newByteArray (2 * sizeOf (undefined :: a) * count)
    setByteArray buf 0 (2 * count) (0 :: a)
    pure (A.STUArray lo hi count rawBuf)

  unsafeRead (A.STUArray _ _ _ rawBuf) ix = do
    let buf = MutableByteArray rawBuf
    lower <- readByteArray buf (2 * ix)
    upper <- readByteArray buf (2 * ix + 1)
    pure (pairToInterval (lower, upper))

  unsafeWrite (A.STUArray _ _ _ rawBuf) ix e = do
    let buf = MutableByteArray rawBuf
        (lower, upper) = intervalToPair e
    writeByteArray buf (2 * ix) lower
    writeByteArray buf (2 * ix + 1) upper
-- | Unboxed immutable arrays of intervals.
--
-- Element @i@ is read from slots @2*i@ and @2*i + 1@ of the byte array and
-- decoded with 'pairToInterval'.
instance (Prim a, Ord a, Fractional a) => A.IArray A.UArray (Interval a) where
  bounds (A.UArray lo hi _ _) = (lo, hi)

  numElements (A.UArray _ _ count _) = count

  -- Builds the array in 'ST': allocate with 'A.newArray_', write each
  -- given (index, element) pair, then freeze without copying.
  unsafeArray bnds assocs = runST $ do
    marr <- A.newArray_ bnds
    forM_ assocs $ \(ix, e) -> A.unsafeWrite marr ix e
    A.unsafeFreezeSTUArray marr

  unsafeAt (A.UArray _ _ _ rawBuf) ix = pairToInterval (lower, upper)
    where
      buf = ByteArray rawBuf
      lower = indexByteArray buf (2 * ix)
      upper = indexByteArray buf (2 * ix + 1)