{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.Classes.Ord (
Ord(..),
Ordering(..), pattern LT_, pattern EQ_, pattern GT_,
) where
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Pattern.Ordering
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Classes.Eq hiding ( (==) )
import qualified Data.Array.Accelerate.Classes.Eq as A
import Data.Char
import Language.Haskell.TH hiding ( Exp )
import Language.Haskell.TH.Extra
import Prelude ( ($), (>>=), Ordering(..), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM )
import Text.Printf
import qualified Prelude as P
infix 4 <
infix 4 >
infix 4 <=
infix 4 >=

-- | Ordering comparisons over embedded Accelerate expressions.
--
-- Minimal complete definition: either '(<=)' or 'compare'. Each default is
-- written in terms of one of those two methods.
class Eq a => Ord a where
  {-# MINIMAL (<=) | compare #-}
  (<)     :: Exp a -> Exp a -> Exp Bool
  (>)     :: Exp a -> Exp a -> Exp Bool
  (<=)    :: Exp a -> Exp a -> Exp Bool
  (>=)    :: Exp a -> Exp a -> Exp Bool
  min     :: Exp a -> Exp a -> Exp a
  max     :: Exp a -> Exp a -> Exp a
  compare :: Exp a -> Exp a -> Exp Ordering

  -- Equality on 'Ordering' already produces an @Exp Bool@, so it is returned
  -- as-is instead of being wrapped in a conditional that selects
  -- @constant True@ / @constant False@ (which only adds a 'Cond' node).
  x <  y = compare x y A.== constant LT
  x <= y = compare x y A./= constant GT
  x >  y = compare x y A.== constant GT
  x >= y = compare x y A./= constant LT

  -- On ties both return the first argument for 'min' and the second for 'max'.
  min x y = if x <= y then x else y
  max x y = if x <= y then y else x

  -- Checks equality first, so an instance that only supplies '(<=)' gets a
  -- three-way comparison.
  compare x y =
    if x A.== y then constant EQ else
    if x <= y   then constant LT
                else constant GT
-- | Target of @if ... then ... else@ in this module (RebindableSyntax is on):
-- builds an embedded 'Cond' node whose condition is the 'mkCoerce''d
-- boolean and whose branches are the two alternatives.
ifThenElse :: Elt a => Exp Bool -> Exp a -> Exp a -> Exp a
ifThenElse (Exp p) (Exp t) (Exp e) =
  Exp (SmartExp (Cond (mkCoerce' p) t e))
-- | Every comparison on the unit type is decided without looking at the
-- arguments: the two values are always equal.
instance Ord () where
  compare _ _ = constant EQ
  min _ _     = constant ()
  max _ _     = constant ()
  _ <= _      = constant True
  _ >= _      = constant True
  _ <  _      = constant False
  _ >  _      = constant False
-- | Every comparison on 'Z' is decided without looking at the arguments.
--
-- 'compare' is given explicitly, matching the unit instance. Otherwise the
-- class default would build a conditional expression out of '==' and '<='
-- that always evaluates to 'EQ'.
instance Ord Z where
  (<)  _ _    = constant False
  (>)  _ _    = constant False
  (<=) _ _    = constant True
  (>=) _ _    = constant True
  min  _ _    = constant Z
  max  _ _    = constant Z
  compare _ _ = constant EQ
-- | Component-wise ordering of shapes: a relation holds only when it holds
-- for the innermost component ('indexHead') and for the remaining shape
-- ('indexTail').
--
-- For the strict operators, when the tail is 'Z' (a one-dimensional index)
-- the tail test is replaced by @constant True@, since @Z < Z@ and @Z > Z@ are
-- always False and would otherwise make the whole comparison False.
instance Ord sh => Ord (sh :. Int) where
  x >= y = indexHead x >= indexHead y && indexTail x >= indexTail y
  x <= y = indexHead x <= indexHead y && indexTail x <= indexTail y

  x > y =
    let rest = case matchTypeR (eltR @sh) (eltR @Z) of
                 Just Refl -> constant True
                 Nothing   -> indexTail x > indexTail y
    in  indexHead x > indexHead y && rest

  x < y =
    let rest = case matchTypeR (eltR @sh) (eltR @Z) of
                 Just Refl -> constant True
                 Nothing   -> indexTail x < indexTail y
    in  indexHead x < indexHead y && rest
-- | Equality on embedded 'Ordering' values, decided by coercing both sides
-- to their 'TAG' representation and comparing those.
instance Eq Ordering where
  (==) x y = (mkCoerce x :: Exp TAG) A.== mkCoerce y
  (/=) x y = (mkCoerce x :: Exp TAG) A./= mkCoerce y
-- | Ordering on embedded 'Ordering' values, obtained by coercing both sides
-- to 'TAG' and using the 'TAG' instance. 'min' and 'max' coerce the chosen
-- tag back to 'Ordering'.
instance Ord Ordering where
  (<)  x y = (mkCoerce x :: Exp TAG) <  mkCoerce y
  (>)  x y = (mkCoerce x :: Exp TAG) >  mkCoerce y
  (<=) x y = (mkCoerce x :: Exp TAG) <= mkCoerce y
  (>=) x y = (mkCoerce x :: Exp TAG) >= mkCoerce y
  min x y  = mkCoerce (min (mkCoerce x :: Exp TAG) (mkCoerce y))
  max x y  = mkCoerce (max (mkCoerce x :: Exp TAG) (mkCoerce y))
-- | Prelude 'P.Ord' instance for embedded expressions, present only so that
-- @Exp a@ can satisfy Prelude superclass constraints.
--
-- The Prelude comparison operators cannot return an @Exp Bool@, so they
-- raise an error naming the Accelerate operator to use instead. 'P.min' and
-- 'P.max' can be implemented faithfully, so they forward to this module's
-- own 'Ord' class (named with the full module qualifier so the forwarding
-- target is explicit).
instance Ord a => P.Ord (Exp a) where
  (<=) = preludeError "Ord.(<=)" "(<=)"
  (<)  = preludeError "Ord.(<)" "(<)"
  (>=) = preludeError "Ord.(>=)" "(>=)"
  (>)  = preludeError "Ord.(>)" "(>)"
  min  = Data.Array.Accelerate.Classes.Ord.min
  max  = Data.Array.Accelerate.Classes.Ord.max
-- | Abort with an explanation that a Prelude 'P.Ord' method was applied to
-- an embedded expression.
--
-- The first argument is the Prelude method name (printed after
-- @Prelude.@); the second is the Accelerate operator to suggest (printed
-- after @Data.Array.Accelerate.@).
preludeError :: String -> String -> a
preludeError method replacement = error (unlines (headline : explanation))
  where
    headline :: String
    headline =
      printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" method replacement

    explanation :: [String]
    explanation =
      [ ""
      , "These Prelude.Ord instances are present only to fulfil superclass"
      , "constraints for subsequent classes in the standard Haskell numeric"
      , "hierarchy."
      ]
-- Generate 'Ord' instances for the primitive scalar types (delegating to the
-- mkLt/mkGt/mkLtEq/mkGtEq/mkMin/mkMax primitives) and for the tuple types
-- T2 .. T16 (lexicographic comparison of components).
$(runQ $ do
    let
        integralTypes :: [Name]
        integralTypes =
          [ ''Int
          , ''Int8
          , ''Int16
          , ''Int32
          , ''Int64
          , ''Word
          , ''Word8
          , ''Word16
          , ''Word32
          , ''Word64
          ]

        floatingTypes :: [Name]
        floatingTypes =
          [ ''Half
          , ''Float
          , ''Double
          ]

        nonNumTypes :: [Name]
        nonNumTypes =
          [ ''Char
          ]

        cTypes :: [Name]
        cTypes =
          [ ''CInt
          , ''CUInt
          , ''CLong
          , ''CULong
          , ''CLLong
          , ''CULLong
          , ''CShort
          , ''CUShort
          , ''CChar
          , ''CUChar
          , ''CSChar
          , ''CFloat
          , ''CDouble
          ]

        -- Instance for a primitive type: every method maps onto the
        -- corresponding primitive smart constructor.
        mkPrim :: Name -> Q [Dec]
        mkPrim t =
          [d| instance Ord $(conT t) where
                (<)  = mkLt
                (>)  = mkGt
                (<=) = mkLtEq
                (>=) = mkGtEq
                min  = mkMin
                max  = mkMax
            |]

        -- Lexicographic comparison builders over paired component lists.
        -- Each earlier component decides the result unless the two sides are
        -- equal there, in which case the remaining components are compared.
        -- Only the final component uses the operator's own (possibly
        -- non-strict) relation; the lists must be non-empty and equally long.
        mkLt' :: [ExpQ] -> [ExpQ] -> ExpQ
        mkLt' [x]    [y]    = [| $x < $y |]
        mkLt' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLt' xs ys) ) |]
        mkLt' _      _      = error "mkLt'"

        mkGt' :: [ExpQ] -> [ExpQ] -> ExpQ
        mkGt' [x]    [y]    = [| $x > $y |]
        mkGt' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGt' xs ys) ) |]
        mkGt' _      _      = error "mkGt'"

        -- The last component must be compared with (<=): using (<) here
        -- made tuple (<=) strict, so that e.g. T2 1 2 <= T2 1 2 was False.
        mkLtEq' :: [ExpQ] -> [ExpQ] -> ExpQ
        mkLtEq' [x]    [y]    = [| $x <= $y |]
        mkLtEq' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLtEq' xs ys) ) |]
        mkLtEq' _      _      = error "mkLtEq'"

        -- Likewise, the last component must be compared with (>=), not (>).
        mkGtEq' :: [ExpQ] -> [ExpQ] -> ExpQ
        mkGtEq' [x]    [y]    = [| $x >= $y |]
        mkGtEq' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGtEq' xs ys) ) |]
        mkGtEq' _      _      = error "mkGtEq'"

        -- Instance for the n-tuple pattern synonym Tn. The context requires
        -- Ord on every component; min, max and compare use the class
        -- defaults.
        mkTup :: Int -> Q [Dec]
        mkTup n =
          let
              xs     = [ mkName ('x':show i) | i <- [0 .. n-1] ]
              ys     = [ mkName ('y':show i) | i <- [0 .. n-1] ]
              cst    = tupT (map (\x -> [t| Ord $(varT x) |]) xs)
              res    = tupT (map varT xs)
              pat vs = conP (mkName ('T':show n)) (map varP vs)
          in
          [d| instance $cst => Ord $res where
                $(pat xs) <  $(pat ys) = $( mkLt'   (map varE xs) (map varE ys) )
                $(pat xs) >  $(pat ys) = $( mkGt'   (map varE xs) (map varE ys) )
                $(pat xs) >= $(pat ys) = $( mkGtEq' (map varE xs) (map varE ys) )
                $(pat xs) <= $(pat ys) = $( mkLtEq' (map varE xs) (map varE ys) )
            |]
    --
    is <- mapM mkPrim integralTypes
    fs <- mapM mkPrim floatingTypes
    ns <- mapM mkPrim nonNumTypes
    cs <- mapM mkPrim cTypes
    ts <- mapM mkTup [2..16]
    return $ concat (concat [is,fs,ns,cs,ts])
 )