{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Complex numbers as an element type of embedded Accelerate expressions.
--
-- The 'Complex' type itself is re-exported from "Data.Complex". This module
-- supplies the (orphan) class instances that let @Complex a@ appear inside
-- 'Exp', together with helper functions operating on @Exp (Complex a)@.
--
module Data.Array.Accelerate.Data.Complex (
-- The complex number type and its constructor, from "Data.Complex".
Complex(..),
-- Component access.
real,
imag,
-- Polar form.
mkPolar,
cis,
polar,
magnitude,
phase,
-- Conjugation.
conjugate,
) where
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Classes as A
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type
import Prelude ( ($), undefined, fromInteger )
import Data.Complex ( Complex(..) )
import qualified Data.Complex as C
import qualified Prelude as P
-- | A complex number shares its element representation with a pair of its
-- component type.
type instance EltRepr (Complex a) = EltRepr (a, a)

-- | Every method converts between @re :+ im@ and the pair @(re, im)@ and then
-- defers to the 'Elt' instance for pairs.
instance Elt a => Elt (Complex a) where
  -- The argument is never inspected; only its type matters.
  eltType _ = eltType (undefined :: (a, a))

  toElt repr = re :+ im
    where
      (re, im) = toElt repr

  fromElt (re :+ im) = fromElt (re, im)
-- | A complex number is a product with the same shape as a pair of its
-- component type; every method goes through the pair instance.
instance cst a => IsProduct cst (Complex a) where
  type ProdRepr (Complex a) = ProdRepr (a, a)

  fromProd proxy (re :+ im) = fromProd proxy (re, im)

  toProd proxy repr = re :+ im
    where
      (re, im) = toProd proxy repr

  -- The value argument is ignored; only its type is used.
  prod proxy _ = prod proxy (undefined :: (a, a))
-- | Pack a Haskell-level complex number, whose components can themselves be
-- lifted, into a single embedded expression.
instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Complex a) where
  type Plain (Complex a) = Complex (Plain a)

  -- Builds a two-component tuple: the real part is added first, then the
  -- imaginary part.
  lift (re :+ im) =
    Exp (Tuple (NilTup `SnocTup` lift re `SnocTup` lift im))
-- | Split an embedded complex number into a Haskell-level 'Complex' whose
-- components are embedded expressions.
instance Elt a => Unlift Exp (Complex (Exp a)) where
  unlift e = re :+ im
    where
      -- The real part is projected with the successor index and the
      -- imaginary part with the zero index.
      re = Exp (Prj (SuccTupIdx ZeroTupIdx) e)
      im = Exp (Prj ZeroTupIdx e)
-- | Component-wise equality of embedded complex numbers.
instance A.Eq a => A.Eq (Complex a) where
  -- Equal when both the real parts and the imaginary parts are equal.
  lhs == rhs = reL == reR && imL == imR
    where
      reL :+ imL = unlift lhs
      reR :+ imR = unlift rhs

  -- Different when either the real parts or the imaginary parts differ.
  lhs /= rhs = reL /= reR || imL /= imR
    where
      reL :+ imL = unlift lhs
      reR :+ imR = unlift rhs
-- | Arithmetic on embedded complex numbers.
--
-- Addition, subtraction, multiplication and negation are lifted from the
-- "Data.Complex" instance at type @Complex (Exp a)@; those definitions only
-- perform arithmetic on the components.
--
-- 'signum' and 'abs' are written out against embedded operations instead of
-- being lifted. The "Data.Complex" versions pattern-match on the literal
-- @0 :+ 0@ (which needs a Haskell-level 'P.Bool' from Prelude equality) and
-- call @Data.Complex.magnitude@ (which needs a Haskell-level exponent via
-- @exponent@ / @scaleFloat@). Neither can be computed from a symbolic 'Exp'
-- term, so the zero test is expressed with 'cond' and the modulus with this
-- module's own 'magnitude'.
-- NOTE(review): the exact failure mode of the lifted versions depends on the
-- Prelude class instances for 'Exp', which are defined outside this module.
instance A.RealFloat a => P.Num (Exp (Complex a)) where
  (+)    = lift2 ((+) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a))
  (-)    = lift2 ((-) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a))
  (*)    = lift2 ((*) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a))
  negate = lift1 (negate :: Complex (Exp a) -> Complex (Exp a))

  -- Zero maps to zero; any other value is divided by its magnitude.
  signum z =
    let x :+ y = unlift z :: Complex (Exp a)
        r      = magnitude z
    in
    cond (x == 0 && y == 0)
         0
         (lift (x / r :+ y / r))

  -- The modulus, returned as a complex number with zero imaginary part.
  abs z = lift (magnitude z :+ 0)

  fromInteger n = lift ((fromInteger n :: Exp a) :+ 0)
-- | Division and rational literals for embedded complex numbers.
instance A.RealFloat a => P.Fractional (Exp (Complex a)) where
  -- Multiply the numerator by the conjugate of the denominator and divide
  -- both components by the squared magnitude of the denominator.
  numer / denom = lift (re :+ im)
    where
      nr :+ ni = unlift numer
      dr :+ di = unlift denom :: Complex (Exp a)
      scale    = dr P.^ (2 :: Int) + di P.^ (2 :: Int)
      re       = (nr * dr + ni * di) / scale
      im       = (ni * dr - nr * di) / scale

  -- A rational literal becomes a complex number with zero imaginary part.
  fromRational q = lift ((fromRational q :: Exp a) :+ 0)
-- | Elementary functions on embedded complex numbers.
--
-- 'pi', 'exp', 'sin', 'cos', 'sinh' and 'cosh' are lifted from the
-- "Data.Complex" instance at type @Complex (Exp a)@; those definitions only
-- use arithmetic and elementary functions on the components.
--
-- The remaining methods are written against the embedded operations of this
-- module. The "Data.Complex" definitions of 'tan' and 'tanh' go through its
-- complex division, and those of the inverse functions go through its 'sqrt'
-- and 'log'. These rely on literal pattern matches, Prelude comparisons or
-- @exponent@ / @scaleFloat@, all of which need Haskell-level 'P.Bool' or
-- 'Int' results that cannot be computed from a symbolic 'Exp' term.
-- NOTE(review): the exact failure mode of the lifted versions depends on the
-- Prelude class instances for 'Exp', which are defined outside this module.
instance A.RealFloat a => P.Floating (Exp (Complex a)) where
  pi = lift (pi :: Complex (Exp a))

  -- Square root with the case analysis expressed through 'cond': zero maps
  -- to zero; otherwise the two candidate components are swapped when the real
  -- part is negative, and the imaginary part of the result takes the sign of
  -- the imaginary part of the argument.
  sqrt z =
    let
      x :+ y = unlift z
      v'     = abs y / (u' * 2)
      u'     = sqrt ((magnitude z + abs x) / 2)
      (u, v) = unlift $ cond (x < 0) (lift (v', u')) (lift (u', v'))
    in
    cond (x == 0 && y == 0)
         0
         (lift (u :+ cond (y < 0) (-v) v))

  -- Logarithm of the magnitude as real part, phase as imaginary part.
  log z = lift (log (magnitude z) :+ phase z)

  exp  = lift1 (exp  :: Complex (Exp a) -> Complex (Exp a))
  sin  = lift1 (sin  :: Complex (Exp a) -> Complex (Exp a))
  cos  = lift1 (cos  :: Complex (Exp a) -> Complex (Exp a))
  sinh = lift1 (sinh :: Complex (Exp a) -> Complex (Exp a))
  cosh = lift1 (cosh :: Complex (Exp a) -> Complex (Exp a))

  -- sin z / cos z, with numerator and denominator built component-wise and
  -- divided using the embedded complex division of this module.
  tan z =
    let x :+ y = unlift z :: Complex (Exp a)
        sinx   = sin x
        cosx   = cos x
        sinhy  = sinh y
        coshy  = cosh y
    in
    lift (sinx * coshy :+ cosx * sinhy)
      / lift (cosx * coshy :+ negate (sinx * sinhy))

  -- sinh z / cosh z, built the same way as 'tan'.
  tanh z =
    let x :+ y = unlift z :: Complex (Exp a)
        cosy   = cos y
        siny   = sin y
        sinhx  = sinh x
        coshx  = cosh x
    in
    lift (cosy * sinhx :+ siny * coshx)
      / lift (cosy * coshx :+ siny * sinhx)

  -- asin z = -i * log (i*z + sqrt (1 - z^2))
  asin z =
    let x  :+ y  = unlift z :: Complex (Exp a)
        x' :+ y' = unlift (log (lift (negate y :+ x) + sqrt (1 - z * z)))
                     :: Complex (Exp a)
    in
    lift (y' :+ negate x')

  -- acos z = -i * log (z + i * sqrt (1 - z^2))
  acos z =
    let x'  :+ y'  = unlift (sqrt (1 - z * z)) :: Complex (Exp a)
        x'' :+ y'' = unlift (log (z + lift (negate y' :+ x')))
                       :: Complex (Exp a)
    in
    lift (y'' :+ negate x'')

  -- atan z = -i * log ((1 + i*z) / sqrt (1 + z^2))
  atan z =
    let x  :+ y  = unlift z :: Complex (Exp a)
        x' :+ y' = unlift (log (lift ((1 - y) :+ x) / sqrt (1 + z * z)))
                     :: Complex (Exp a)
    in
    lift (y' :+ negate x')

  asinh z = log (z + sqrt (z * z + 1))
  acosh z = log (z + sqrt (z + 1) * sqrt (z - 1))
  atanh z = 0.5 * log ((1 + z) / (1 - z))
-- | Convert an embedded integral value to a complex number: the converted
-- value becomes the real part and the imaginary part is zero.
instance (A.FromIntegral a b, A.Num b) => A.FromIntegral a (Complex b) where
  fromIntegral n = lift (re :+ im)
    where
      re = fromIntegral n :: Exp b
      im = 0              :: Exp b
-- | The magnitude of a complex number, computed as the square root of the sum
-- of the squared real and imaginary parts.
magnitude :: RealFloat a => Exp (Complex a) -> Exp a
magnitude z = sqrt (re * re + im * im)
  where
    re :+ im = unlift z
-- | The phase of a complex number, computed as 'atan2' of the imaginary part
-- and the real part.
phase :: RealFloat a => Exp (Complex a) -> Exp a
phase z = atan2 im re
  where
    re :+ im = unlift z
-- | The polar form of a complex number, as a pair of its 'magnitude' and its
-- 'phase'.
polar :: RealFloat a => Exp (Complex a) -> Exp (a,a)
polar z = lift (r, theta)
  where
    r     = magnitude z
    theta = phase z
-- | Build a complex number from its polar form by applying
-- @Data.Complex.mkPolar@ to the two embedded arguments (magnitude first, then
-- phase) and lifting the result.
--
-- The class constraint in the signature is selected by compiler version.
#if __GLASGOW_HASKELL__ <= 708
mkPolar :: forall a. RealFloat a => Exp a -> Exp a -> Exp (Complex a)
#else
mkPolar :: forall a. Floating a => Exp a -> Exp a -> Exp (Complex a)
#endif
mkPolar r theta = lift2 fromPolar r theta
  where
    fromPolar :: Exp a -> Exp a -> Complex (Exp a)
    fromPolar = C.mkPolar
-- | Build a complex number from a phase by applying @Data.Complex.cis@ to the
-- embedded argument and lifting the result.
--
-- The class constraint in the signature is selected by compiler version.
#if __GLASGOW_HASKELL__ <= 708
cis :: forall a. RealFloat a => Exp a -> Exp (Complex a)
#else
cis :: forall a. Floating a => Exp a -> Exp (Complex a)
#endif
cis theta = lift1 onUnitCircle theta
  where
    onUnitCircle :: Exp a -> Complex (Exp a)
    onUnitCircle = C.cis
-- | Extract the real part of an embedded complex number.
real :: Elt a => Exp (Complex a) -> Exp a
real z = re
  where
    re :+ _ = unlift z
-- | Extract the imaginary part of an embedded complex number.
imag :: Elt a => Exp (Complex a) -> Exp a
imag z = im
  where
    _ :+ im = unlift z
-- | The complex conjugate: the real part is kept and the imaginary part is
-- negated.
conjugate :: Num a => Exp (Complex a) -> Exp (Complex a)
conjugate z = lift (re :+ negate im)
  where
    re = real z
    im = imag z