{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.Classes.RealFrac (
RealFrac(..),
div', mod', divMod',
) where
import Data.Array.Accelerate.Language ( (^), cond, even )
import Data.Array.Accelerate.Lift ( unlift )
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Classes.Eq
import Data.Array.Accelerate.Classes.Ord
import Data.Array.Accelerate.Classes.Floating
import Data.Array.Accelerate.Classes.Fractional
import Data.Array.Accelerate.Classes.FromIntegral
import Data.Array.Accelerate.Classes.Integral
import Data.Array.Accelerate.Classes.Num
import Data.Array.Accelerate.Classes.ToFloating
import {-# SOURCE #-} Data.Array.Accelerate.Classes.RealFloat
import Data.Maybe
import Text.Printf
import Prelude ( ($), String, error, unlines, otherwise )
import qualified Prelude as P
-- | Generalisation of integral division to any instance of 'RealFrac':
-- the quotient @n / d@ is computed in the fractional type and then passed
-- through 'floor' to obtain an integral result.
--
div' :: (RealFrac a, FromIntegral Int64 b, Integral b) => Exp a -> Exp a -> Exp b
div' numer denom = floor quotient
  where
    quotient = numer / denom
-- | Generalisation of the integral modulus to any instance of 'RealFrac'.
--
-- The result is @n - k * d@, where @k@ is the 'Int64' quotient produced by
-- 'div'' and converted back to the fractional type with 'toFloating'.
--
mod' :: (Floating a, RealFrac a, ToFloating Int64 a) => Exp a -> Exp a -> Exp a
mod' numer denom =
  let wholes :: Exp Int64
      wholes = div' numer denom
  in  numer - toFloating wholes * denom
-- | Generalisation of 'divMod' to any instance of 'RealFrac'.
--
-- Returns the integral quotient from 'div'' paired with the remainder
-- @n - quotient * d@; the quotient is computed once and used for both
-- components.
--
divMod'
    :: (Floating a, RealFrac a, Integral b, FromIntegral Int64 b, ToFloating b a)
    => Exp a
    -> Exp a
    -> (Exp b, Exp a)
divMod' numer denom =
  let wholes    = div' numer denom
      remainder = numer - toFloating wholes * denom
  in  (wholes, remainder)
-- | Extraction of integral and fractional components from embedded scalar
-- expressions.
--
-- Only 'properFraction' has no default; the remaining methods fall back to
-- the @default*@ functions defined in this module.
--
class (Ord a, Fractional a) => RealFrac a where
  -- | Split a value into an integral component and a fractional component.
  properFraction :: (Integral b, FromIntegral Int64 b) => Exp a -> (Exp b, Exp a)

  -- | Integral component of the value; defaults to 'defaultTruncate'.
  truncate :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b
  truncate x = defaultTruncate x

  -- | Rounding to an integral value; defaults to 'defaultRound'.
  round :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b
  round x = defaultRound x

  -- | Rounding upwards to an integral value; defaults to 'defaultCeiling'.
  ceiling :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b
  ceiling x = defaultCeiling x

  -- | Rounding downwards to an integral value; defaults to 'defaultFloor'.
  floor :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b
  floor x = defaultFloor x
-- | Half-precision instance: 'properFraction' is the generic decoding-based
-- implementation, every other method uses the class default.
instance RealFrac Half where
  properFraction x = defaultProperFraction x
-- | Single-precision instance: 'properFraction' is the generic decoding-based
-- implementation, every other method uses the class default.
instance RealFrac Float where
  properFraction x = defaultProperFraction x
-- | Double-precision instance: 'properFraction' is the generic decoding-based
-- implementation, every other method uses the class default.
instance RealFrac Double where
  properFraction x = defaultProperFraction x
-- | Instance for the C @float@ wrapper type.
--
-- 'truncate', 'round', 'ceiling' and 'floor' are not spelled out here: the
-- class defaults already resolve to 'defaultTruncate', 'defaultRound',
-- 'defaultCeiling' and 'defaultFloor' respectively.
instance RealFrac CFloat where
  properFraction = defaultProperFraction
-- | Instance for the C @double@ wrapper type.
--
-- 'truncate', 'round', 'ceiling' and 'floor' are not spelled out here: the
-- class defaults already resolve to 'defaultTruncate', 'defaultRound',
-- 'defaultCeiling' and 'defaultFloor' respectively.
instance RealFrac CDouble where
  properFraction = defaultProperFraction
-- | Generic 'properFraction' built on 'decodeFloat'.
--
-- The value is decoded into a mantissa and an exponent. With a non-negative
-- exponent the result is the mantissa scaled by @2 ^ exponent@ together with
-- a zero fractional part. Otherwise the mantissa is divided (with 'quotRem')
-- by @2 ^ negate exponent@: the quotient becomes the integral part and the
-- remainder is re-encoded with the original exponent as the fractional part.
--
defaultProperFraction
    :: (RealFloat a, FromIntegral Int64 b, Integral b)
    => Exp a
    -> (Exp b, Exp a)
defaultProperFraction x = unlift (cond (expo >= 0) wholeOnly split)
  where
    (mant, expo)          = decodeFloat x
    (quotient, remainder) = quotRem mant (2 ^ negate expo)

    -- exponent >= 0: there is no fractional component
    wholeOnly = T2 (fromIntegral mant * (2 ^ expo)) 0.0

    -- exponent < 0: separate the mantissa at the binary point
    split     = T2 (fromIntegral quotient) (encodeFloat remainder expo)
-- | Default 'truncate'.
--
-- If the representation of @a@ is a floating-point scalar and the
-- representation of @b@ is an integral scalar (as witnessed by 'isFloating'
-- and 'isIntegral'), delegate to 'mkTruncate'. Otherwise return the integral
-- component produced by 'properFraction'.
--
defaultTruncate :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b
defaultTruncate x =
  case (isFloating @a, isIntegral @b) of
    (Just IsFloatingDict, Just IsIntegralDict) -> mkTruncate x
    _                                          -> P.fst (properFraction x)
-- | Default 'ceiling'.
--
-- If the representation of @a@ is a floating-point scalar and the
-- representation of @b@ is an integral scalar (as witnessed by 'isFloating'
-- and 'isIntegral'), delegate to 'mkCeiling'. Otherwise split the value with
-- 'properFraction' and bump the integral part by one whenever the fractional
-- part is strictly positive.
--
defaultCeiling :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b
defaultCeiling x =
  case (isFloating @a, isIntegral @b) of
    (Just IsFloatingDict, Just IsIntegralDict) -> mkCeiling x
    _ ->
      let (whole, frac) = properFraction x
      in  cond (frac > 0) (whole + 1) whole
-- | Default 'floor'.
--
-- If the representation of @a@ is a floating-point scalar and the
-- representation of @b@ is an integral scalar (as witnessed by 'isFloating'
-- and 'isIntegral'), delegate to 'mkFloor'. Otherwise split the value with
-- 'properFraction' and lower the integral part by one whenever the fractional
-- part is strictly negative.
--
defaultFloor :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b
defaultFloor x =
  case (isFloating @a, isIntegral @b) of
    (Just IsFloatingDict, Just IsIntegralDict) -> mkFloor x
    _ ->
      let (whole, frac) = properFraction x
      in  cond (frac < 0) (whole - 1) whole
-- | Default 'round'.
--
-- If the representation of @a@ is a floating-point scalar and the
-- representation of @b@ is an integral scalar (as witnessed by 'isFloating'
-- and 'isIntegral'), delegate to 'mkRound'.
--
-- Otherwise the value is split with 'properFraction' and the magnitude of the
-- fractional part is compared against one half:
--
--   * below one half: keep the integral part;
--
--   * exactly one half: keep the integral part if it is even, otherwise step
--     away from zero;
--
--   * above one half: step away from zero.
--
defaultRound :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b
defaultRound x =
  case (isFloating @a, isIntegral @b) of
    (Just IsFloatingDict, Just IsIntegralDict) -> mkRound x
    _ ->
      let (whole, frac) = properFraction x
          -- neighbouring integer in the direction of the fractional part
          away          = cond (frac < 0.0) (whole - 1) (whole + 1)
          -- ordering of |frac| relative to one half
          ordering      = compare (abs frac - 0.5) 0.0
          -- exact tie: prefer whichever candidate is even
          onTie         = cond (even whole) whole away
      in  cond (constant LT == ordering) whole
            (cond (constant EQ == ordering) onTie away)
-- | Runtime evidence that @a@ has an 'IsFloating' instance. Pattern matching
-- on the constructor brings that instance into scope.
data IsFloatingDict a where
  IsFloatingDict :: forall a. IsFloating a => IsFloatingDict a
-- | Runtime evidence that @a@ has an 'IsIntegral' instance. Pattern matching
-- on the constructor brings that instance into scope.
data IsIntegralDict a where
  IsIntegralDict :: forall a. IsIntegral a => IsIntegralDict a
-- | Inspect the representation type of @a@ and, when it is a single scalar
-- floating-point type, return the matching 'IsFloating' dictionary. Any other
-- representation yields 'Nothing'.
--
-- The inner match on the concrete type constructor is what lets the compiler
-- select the corresponding 'IsFloating' instance for each case.
--
isFloating :: forall a. Elt a => Maybe (IsFloatingDict (EltR a))
isFloating =
  case eltR @a of
    TupRsingle (SingleScalarType (NumSingleType (FloatingNumType f))) ->
      case f of
        TypeHalf{}   -> Just IsFloatingDict
        TypeFloat{}  -> Just IsFloatingDict
        TypeDouble{} -> Just IsFloatingDict
    _ -> Nothing
-- | Inspect the representation type of @a@ and, when it is a single scalar
-- integral type, return the matching 'IsIntegral' dictionary. Any other
-- representation yields 'Nothing'.
--
-- The inner match on the concrete type constructor is what lets the compiler
-- select the corresponding 'IsIntegral' instance for each case.
--
isIntegral :: forall a. Elt a => Maybe (IsIntegralDict (EltR a))
isIntegral =
  case eltR @a of
    TupRsingle (SingleScalarType (NumSingleType (IntegralNumType i))) ->
      case i of
        TypeInt{}    -> Just IsIntegralDict
        TypeInt8{}   -> Just IsIntegralDict
        TypeInt16{}  -> Just IsIntegralDict
        TypeInt32{}  -> Just IsIntegralDict
        TypeInt64{}  -> Just IsIntegralDict
        TypeWord{}   -> Just IsIntegralDict
        TypeWord8{}  -> Just IsIntegralDict
        TypeWord16{} -> Just IsIntegralDict
        TypeWord32{} -> Just IsIntegralDict
        TypeWord64{} -> Just IsIntegralDict
    _ -> Nothing
-- | Placeholder instance of the standard 'P.RealFrac' class for embedded
-- expressions. Every method is an error produced by 'preludeError' that names
-- the method in question.
instance RealFrac a => P.RealFrac (Exp a) where
  floor          = preludeError "floor"
  ceiling        = preludeError "ceiling"
  round          = preludeError "round"
  truncate       = preludeError "truncate"
  properFraction = preludeError "properFraction"
-- | Abort with a message explaining that the named Prelude function was
-- applied to an embedded-language type. The argument is substituted into the
-- first line of the message twice: once as the offending Prelude function and
-- once as the suggested replacement.
preludeError :: String -> a
preludeError name = error (unlines message)
  where
    message =
      [ printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" name name
      , ""
      , "These Prelude.RealFrac instances are present only to fulfil superclass"
      , "constraints for subsequent classes in the standard Haskell numeric hierarchy."
      ]