{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Ord -- Copyright : [2016..2020] The Accelerate Team -- License : BSD3 -- -- Maintainer : Trevor L. McDonell -- Stability : experimental -- Portability : non-portable (GHC extensions) -- module Data.Array.Accelerate.Classes.Ord ( Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, ) where import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Ordering import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Type -- We must hide (==), as that operator is used for the literals 0, 1 and 2 in the pattern synonyms for Ordering. -- As RebindableSyntax is enabled, a literal pattern is compiled to a call to (==), meaning that the Prelude.(==) should be in scope as (==). import Data.Array.Accelerate.Classes.Eq hiding ( (==) ) import qualified Data.Array.Accelerate.Classes.Eq as A import Data.Char import Language.Haskell.TH hiding ( Exp ) import Language.Haskell.TH.Extra import Prelude ( ($), (>>=), Ordering(..), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) import Text.Printf import qualified Prelude as P infix 4 < infix 4 > infix 4 <= infix 4 >= -- | The 'Ord' class for totally ordered datatypes -- class Eq a => Ord a where {-# MINIMAL (<=) | compare #-} (<) :: Exp a -> Exp a -> Exp Bool (>) :: Exp a -> Exp a -> Exp Bool (<=) :: Exp a -> Exp a -> Exp Bool (>=) :: Exp a -> Exp a -> Exp Bool min :: Exp a -> Exp a -> Exp a max :: Exp a -> Exp a -> Exp a compare :: Exp a -> Exp a -> Exp Ordering x < y = if compare x y A.== constant LT then constant True else constant False x <= y = if compare x y A.== constant GT then constant False else constant True x > y = if compare x y A.== constant GT then constant True else constant False x >= y = if compare x y A.== constant LT then constant False else constant True min x y = if x <= y then x else y max x y = if x <= y then y else x compare x y = if x A.== y then constant EQ else if x <= y then constant LT else constant GT -- Local redefinition for use with RebindableSyntax (pulled forward from Prelude.hs) -- ifThenElse :: Elt a => Exp Bool -> Exp a -> Exp a -> Exp a ifThenElse (Exp c) (Exp x) (Exp y) = Exp $ SmartExp $ Cond (mkCoerce' c) x y instance Ord () where (<) _ _ = constant False (>) _ _ = constant False (>=) _ _ = constant True (<=) _ _ = constant True min _ _ = constant () max _ _ = constant () compare _ _ = constant EQ instance Ord Z where (<) _ _ = constant False (>) _ _ = constant False (<=) _ _ = constant True (>=) _ _ = constant True min _ _ = constant Z max _ _ = constant Z instance Ord sh => Ord (sh :. Int) where x <= y = indexHead x <= indexHead y && indexTail x <= indexTail y x >= y = indexHead x >= indexHead y && indexTail x >= indexTail y x < y = indexHead x < indexHead y && case matchTypeR (eltR @sh) (eltR @Z) of Just Refl -> constant True Nothing -> indexTail x < indexTail y x > y = indexHead x > indexHead y && case matchTypeR (eltR @sh) (eltR @Z) of Just Refl -> constant True Nothing -> indexTail x > indexTail y instance Eq Ordering where x == y = mkCoerce x A.== (mkCoerce y :: Exp TAG) x /= y = mkCoerce x A./= (mkCoerce y :: Exp TAG) instance Ord Ordering where x < y = mkCoerce x < (mkCoerce y :: Exp TAG) x > y = mkCoerce x > (mkCoerce y :: Exp TAG) x <= y = mkCoerce x <= (mkCoerce y :: Exp TAG) x >= y = mkCoerce x >= (mkCoerce y :: Exp TAG) min x y = mkCoerce $ min (mkCoerce x) (mkCoerce y :: Exp TAG) max x y = mkCoerce $ max (mkCoerce x) (mkCoerce y :: Exp TAG) -- Instances of 'Prelude.Ord' (mostly) don't make sense with the standard -- signatures as the return type is fixed to 'Bool'. This instance is provided -- to provide a useful error message. -- -- Note that 'min' and 'max' are implementable, so we do hook those into the -- accelerate instances defined here. This allows us to use operations such as -- 'Prelude.minimum' and 'Prelude.maximum'. -- instance Ord a => P.Ord (Exp a) where (<) = preludeError "Ord.(<)" "(<)" (<=) = preludeError "Ord.(<=)" "(<=)" (>) = preludeError "Ord.(>)" "(>)" (>=) = preludeError "Ord.(>=)" "(>=)" min = min max = max preludeError :: String -> String -> a preludeError x y = error $ unlines [ printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y , "" , "These Prelude.Ord instances are present only to fulfil superclass" , "constraints for subsequent classes in the standard Haskell numeric" , "hierarchy." ] $(runQ $ do let integralTypes :: [Name] integralTypes = [ ''Int , ''Int8 , ''Int16 , ''Int32 , ''Int64 , ''Word , ''Word8 , ''Word16 , ''Word32 , ''Word64 ] floatingTypes :: [Name] floatingTypes = [ ''Half , ''Float , ''Double ] nonNumTypes :: [Name] nonNumTypes = [ ''Char ] cTypes :: [Name] cTypes = [ ''CInt , ''CUInt , ''CLong , ''CULong , ''CLLong , ''CULLong , ''CShort , ''CUShort , ''CChar , ''CUChar , ''CSChar , ''CFloat , ''CDouble ] mkPrim :: Name -> Q [Dec] mkPrim t = [d| instance Ord $(conT t) where (<) = mkLt (>) = mkGt (<=) = mkLtEq (>=) = mkGtEq min = mkMin max = mkMax |] mkLt' :: [ExpQ] -> [ExpQ] -> ExpQ mkLt' [x] [y] = [| $x < $y |] mkLt' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLt' xs ys) ) |] mkLt' _ _ = error "mkLt'" mkGt' :: [ExpQ] -> [ExpQ] -> ExpQ mkGt' [x] [y] = [| $x > $y |] mkGt' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGt' xs ys) ) |] mkGt' _ _ = error "mkGt'" mkLtEq' :: [ExpQ] -> [ExpQ] -> ExpQ mkLtEq' [x] [y] = [| $x < $y |] mkLtEq' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLtEq' xs ys) ) |] mkLtEq' _ _ = error "mkLtEq'" mkGtEq' :: [ExpQ] -> [ExpQ] -> ExpQ mkGtEq' [x] [y] = [| $x > $y |] mkGtEq' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGtEq' xs ys) ) |] mkGtEq' _ _ = error "mkGtEq'" mkTup :: Int -> Q [Dec] mkTup n = let xs = [ mkName ('x':show i) | i <- [0 .. n-1] ] ys = [ mkName ('y':show i) | i <- [0 .. n-1] ] cst = tupT (map (\x -> [t| Ord $(varT x) |]) xs) res = tupT (map varT xs) pat vs = conP (mkName ('T':show n)) (map varP vs) in [d| instance $cst => Ord $res where $(pat xs) < $(pat ys) = $( mkLt' (map varE xs) (map varE ys) ) $(pat xs) > $(pat ys) = $( mkGt' (map varE xs) (map varE ys) ) $(pat xs) >= $(pat ys) = $( mkGtEq' (map varE xs) (map varE ys) ) $(pat xs) <= $(pat ys) = $( mkLtEq' (map varE xs) (map varE ys) ) |] is <- mapM mkPrim integralTypes fs <- mapM mkPrim floatingTypes ns <- mapM mkPrim nonNumTypes cs <- mapM mkPrim cTypes ts <- mapM mkTup [2..16] return $ concat (concat [is,fs,ns,cs,ts]) )