{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE PatternSynonyms     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.Classes.Eq
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Classes.Eq (

  Bool(..), pattern True_, pattern False_,
  Eq(..),
  (&&), (&&!),
  (||), (||!),
  not,

) where

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Pattern.Bool
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Type

import Data.Bool                                                    ( Bool(..) )
import Data.Char                                                    ( Char )
import Text.Printf
import Prelude                                                      ( ($), String, Num(..), show, error, return, concat, map, zipWith, foldr1, mapM )
import Language.Haskell.TH                                          hiding ( Exp )
import Language.Haskell.TH.Extra
import qualified Prelude                                            as P


-- Equality and inequality are non-associative operators at precedence 4.
infix 4 ==, /=

-- | Conjunction: True if both arguments are true. This is a short-circuit
-- operator, so the second argument will be evaluated only if the first is true.
--
-- The expression is built directly on the representation of 'Bool' (a 'Word8'
-- tag paired with unit): a 'Cond' on the tag of the first argument selects
-- between the tag of the second argument and the constant @0@, and the
-- selected tag is re-paired with 'Nil'.
--
infixr 3 &&
(&&) :: Exp Bool -> Exp Bool -> Exp Bool
(&&) (Exp x) (Exp y) =
  mkExp $ SmartExp (Cond (SmartExp $ Prj PairIdxLeft x)
                         (SmartExp $ Prj PairIdxLeft y)
                         (SmartExp $ Const scalarTypeWord8 0))
          `Pair` SmartExp Nil

-- | Conjunction: True if both arguments are true. This is a strict version of
-- '(&&)': it will always evaluate both arguments, even when the first is false.
--
-- Implemented directly as the smart constructor 'mkLAnd'.
--
-- @since 1.3.0.0
--
infixr 3 &&!
(&&!) :: Exp Bool -> Exp Bool -> Exp Bool
(&&!) = mkLAnd

-- | Disjunction: True if either argument is true. This is a short-circuit
-- operator, so the second argument will be evaluated only if the first is
-- false.
--
-- The expression is built directly on the representation of 'Bool' (a 'Word8'
-- tag paired with unit): a 'Cond' on the tag of the first argument selects
-- between the constant @1@ and the tag of the second argument, and the
-- selected tag is re-paired with 'Nil'.
--
infixr 2 ||
(||) :: Exp Bool -> Exp Bool -> Exp Bool
(||) (Exp x) (Exp y) =
  mkExp $ SmartExp (Cond (SmartExp $ Prj PairIdxLeft x)
                         (SmartExp $ Const scalarTypeWord8 1)
                         (SmartExp $ Prj PairIdxLeft y))
          `Pair` SmartExp Nil


-- | Disjunction: True if either argument is true. This is a strict version of
-- '(||)': it will always evaluate both arguments, even when the first is true.
--
-- Implemented directly as the smart constructor 'mkLOr'.
--
-- @since 1.3.0.0
--
infixr 2 ||!
(||!) :: Exp Bool -> Exp Bool -> Exp Bool
(||!) = mkLOr

-- | Logical negation
--
-- Implemented directly as the smart constructor 'mkLNot'.
--
not :: Exp Bool -> Exp Bool
not = mkLNot


-- | The 'Eq' class defines equality '==' and inequality '/=' for scalar
-- Accelerate expressions.
--
-- For convenience, we include 'Elt' as a superclass.
--
-- Minimal complete definition: either '==' or '/='. Each method defaults to
-- the logical negation ('mkLNot') of the other, so an instance must supply at
-- least one of them.
--
class Elt a => Eq a where
  (==) :: Exp a -> Exp a -> Exp Bool
  (/=) :: Exp a -> Exp a -> Exp Bool
  {-# MINIMAL (==) | (/=) #-}
  x == y = mkLNot (x /= y)
  x /= y = mkLNot (x == y)


-- Unit has a single value, so the arguments are ignored: equality is the
-- constant 'True_' and inequality the constant 'False_'.
instance Eq () where
  _ == _ = True_
  _ /= _ = False_

-- As for unit, the arguments are ignored: equality on 'Z' is the constant
-- 'True_' and inequality the constant 'False_'.
instance Eq Z where
  _ == _ = True_
  _ /= _ = False_

-- Shapes are compared component-wise: first the innermost index
-- ('indexHead'), then, via the short-circuit '&&' / '||', the remaining
-- shape ('indexTail').
instance Eq sh => Eq (sh :. Int) where
  x == y = indexHead x == indexHead y && indexTail x == indexTail y
  x /= y = indexHead x /= indexHead y || indexTail x /= indexTail y

-- Booleans are compared by coercing both sides to their primitive
-- representation ('PrimBool') and using the 'Eq' instance at that type.
instance Eq Bool where
  x == y = mkCoerce x == (mkCoerce y :: Exp PrimBool)
  x /= y = mkCoerce x /= (mkCoerce y :: Exp PrimBool)

-- Instances of 'Prelude.Eq' don't make sense with the standard signatures as
-- the return type is fixed to 'Bool'. This instance exists only to give
-- a useful error message: both methods raise an error, via 'preludeError',
-- that points at the Accelerate operator of the same name.
--
instance P.Eq (Exp a) where
  (==) = preludeError "Eq.(==)" "(==)"
  (/=) = preludeError "Eq.(/=)" "(/=)"

-- | Raise an 'error' explaining that a Prelude operation was applied to an
-- embedded-language type.
--
-- The first argument is the name of the offending Prelude operation (rendered
-- after @Prelude.@) and the second is the name of the replacement (rendered
-- after @Data.Array.Accelerate.@). This function never returns a value.
--
preludeError :: String -> String -> a
preludeError x y = error (printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y)

-- Generate the 'Eq' instances for every primitive scalar type listed below
-- (each one delegating to 'mkEq' / 'mkNEq') and for tuples of 2 to 16
-- components (each one comparing component-wise).
$(runQ $ do
    let
        integralTypes :: [Name]
        integralTypes =
          [ ''Int,  ''Int8,  ''Int16,  ''Int32,  ''Int64
          , ''Word, ''Word8, ''Word16, ''Word32, ''Word64
          ]

        floatingTypes :: [Name]
        floatingTypes = [ ''Half, ''Float, ''Double ]

        nonNumTypes :: [Name]
        nonNumTypes = [ ''Char ]

        cTypes :: [Name]
        cTypes =
          [ ''CInt,   ''CUInt
          , ''CLong,  ''CULong
          , ''CLLong, ''CULLong
          , ''CShort, ''CUShort
          , ''CChar,  ''CUChar, ''CSChar
          , ''CFloat, ''CDouble
          ]

        -- Instance for a single primitive type, using the primitive
        -- comparison smart constructors.
        mkPrim :: Name -> Q [Dec]
        mkPrim ty =
          [d| instance Eq $(conT ty) where
                (==) = mkEq
                (/=) = mkNEq
            |]

        -- Instance for the n-tuple: equal when all components are equal,
        -- different when any component differs.
        mkTup :: Int -> Q [Dec]
        mkTup n =
          [d| instance ($context) => Eq $tupleTy where
                $(tuplePat lhs) == $(tuplePat rhs) = $(foldr1 (\a b -> [| $a && $b |]) allEqual)
                $(tuplePat lhs) /= $(tuplePat rhs) = $(foldr1 (\a b -> [| $a || $b |]) anyDiffer)
            |]
          where
            lhs         = [ mkName ('x':show i) | i <- [0 .. n-1] ]
            rhs         = [ mkName ('y':show i) | i <- [0 .. n-1] ]
            lhsE        = map varE lhs
            rhsE        = map varE rhs
            context     = tupT (map (\v -> [t| Eq $(varT v) |]) lhs)
            tupleTy     = tupT (map varT lhs)
            -- Pattern synonym @Tn@ applied to one variable per component
            tuplePat vs = conP (mkName ('T':show n)) (map varP vs)
            allEqual    = zipWith (\l r -> [| $l == $r |]) lhsE rhsE
            anyDiffer   = zipWith (\l r -> [| $l /= $r |]) lhsE rhsE

    -- Same generation order as listing the groups separately: integral,
    -- floating, non-numeric, C types, then tuples.
    prims  <- mapM mkPrim (concat [integralTypes, floatingTypes, nonNumTypes, cTypes])
    tuples <- mapM mkTup [2..16]
    return $ concat (concat [prims, tuples])
 )