{-# LANGUAGE CPP                   #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE RebindableSyntax      #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE ViewPatterns          #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE UndecidableInstances  #-}
#endif
-- |
-- Module      : Data.Array.Accelerate.Data.Semigroup
-- Copyright   : [2018..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Semigroup instances for Accelerate
--
-- @since 1.2.0.0
--

module Data.Array.Accelerate.Data.Semigroup (

  Semigroup(..),

  Min(..), pattern Min_,
  Max(..), pattern Max_,

) where

import Data.Array.Accelerate.Classes.Bounded
import Data.Array.Accelerate.Classes.Eq
import Data.Array.Accelerate.Classes.Num
import Data.Array.Accelerate.Classes.Ord
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt

import Data.Function
import Data.Monoid                                                  ( Monoid(..) )
import Data.Semigroup
import qualified Prelude                                            as P


-- Min
-- ---

-- | Pattern synonym for constructing and matching an embedded 'Min' value
-- from/to the embedded value it wraps.
--
pattern Min_ :: Elt a => Exp a -> Exp (Min a)
pattern Min_ x = Pattern x
{-# COMPLETE Min_ #-}

-- Empty instance body: the class's default method implementations are used.
instance Elt a => Elt (Min a)

-- | Lift a 'Min' by lifting its payload and re-wrapping it with 'Min_'.
instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Min a) where
  type Plain (Min a) = Min (Plain a)
  lift (Min a)       = Min_ (lift a)

-- | Expose the embedded payload wrapped in the host-level 'Min' constructor.
instance Elt a => Unlift Exp (Min (Exp a)) where
  unlift (Min_ a) = Min a

-- | Bounds are those of the underlying element type, wrapped in 'Min'.
instance Bounded a => P.Bounded (Exp (Min a)) where
  minBound = lift $ Min (minBound :: Exp a)
  maxBound = lift $ Min (maxBound :: Exp a)

-- | Every operation is the corresponding 'P.Num' operation at type
-- @Min (Exp a)@, lifted to operate on @Exp (Min a)@.
instance Num a => P.Num (Exp (Min a)) where
  (+)           = lift2 ((+) :: Min (Exp a) -> Min (Exp a) -> Min (Exp a))
  (-)           = lift2 ((-) :: Min (Exp a) -> Min (Exp a) -> Min (Exp a))
  (*)           = lift2 ((*) :: Min (Exp a) -> Min (Exp a) -> Min (Exp a))
  negate        = lift1 (negate :: Min (Exp a) -> Min (Exp a))
  signum        = lift1 (signum :: Min (Exp a) -> Min (Exp a))
  -- 'abs' must lift 'abs' itself; lifting 'signum' here would yield the sign
  -- of the argument rather than its magnitude.
  abs           = lift1 (abs :: Min (Exp a) -> Min (Exp a))
  fromInteger x = lift (P.fromInteger x :: Min (Exp a))

-- | Equality is decided on the wrapped values (via 'getMin').
instance Eq a => Eq (Min a) where
  (==) = lift2 ((==) `on` getMin)
  (/=) = lift2 ((/=) `on` getMin)

-- | Ordering is decided on the wrapped values (via 'getMin'); 'min' and 'max'
-- re-wrap their result in 'Min'.
instance Ord a => Ord (Min a) where
  (<)     = lift2 ((<)  `on` getMin)
  (>)     = lift2 ((>)  `on` getMin)
  (<=)    = lift2 ((<=) `on` getMin)
  (>=)    = lift2 ((>=) `on` getMin)
  min x y = lift . Min $ lift2 (min `on` getMin) x y
  max x y = lift . Min $ lift2 (max `on` getMin) x y

-- | '<>' takes the minimum of the two wrapped values.
instance Ord a => Semigroup (Exp (Min a)) where
  x <> y = lift . Min $ lift2 (min `on` getMin) x y
  stimes = stimesIdempotent

-- | The identity for taking minimums is 'maxBound'.
instance (Ord a, Bounded a) => Monoid (Exp (Min a)) where
  mempty  = maxBound
  mappend = (<>)


-- Max
-- ---

-- | Pattern synonym for constructing and matching an embedded 'Max' value
-- from/to the embedded value it wraps.
--
pattern Max_ :: Elt a => Exp a -> Exp (Max a)
pattern Max_ x = Pattern x
{-# COMPLETE Max_ #-}

-- Empty instance body: the class's default method implementations are used.
instance Elt a => Elt (Max a)

-- | Lift a 'Max' by lifting its payload and re-wrapping it with 'Max_'.
instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Max a) where
  type Plain (Max a) = Max (Plain a)
  lift (Max a)       = Max_ (lift a)

-- | Expose the embedded payload wrapped in the host-level 'Max' constructor.
instance Elt a => Unlift Exp (Max (Exp a)) where
  unlift (Max_ a) = Max a

-- | Bounds are those of the underlying element type, wrapped in 'Max_'.
instance Bounded a => P.Bounded (Exp (Max a)) where
  minBound = Max_ minBound
  maxBound = Max_ maxBound

-- | Every operation is the corresponding 'P.Num' operation at type
-- @Max (Exp a)@, lifted to operate on @Exp (Max a)@.
instance Num a => P.Num (Exp (Max a)) where
  (+)           = lift2 ((+) :: Max (Exp a) -> Max (Exp a) -> Max (Exp a))
  (-)           = lift2 ((-) :: Max (Exp a) -> Max (Exp a) -> Max (Exp a))
  (*)           = lift2 ((*) :: Max (Exp a) -> Max (Exp a) -> Max (Exp a))
  negate        = lift1 (negate :: Max (Exp a) -> Max (Exp a))
  signum        = lift1 (signum :: Max (Exp a) -> Max (Exp a))
  -- 'abs' must lift 'abs' itself; lifting 'signum' here would yield the sign
  -- of the argument rather than its magnitude.
  abs           = lift1 (abs :: Max (Exp a) -> Max (Exp a))
  fromInteger x = lift (P.fromInteger x :: Max (Exp a))

-- | Equality is decided on the wrapped values (via 'getMax').
instance Eq a => Eq (Max a) where
  (==) = lift2 ((==) `on` getMax)
  (/=) = lift2 ((/=) `on` getMax)

-- | Ordering is decided on the wrapped values (via 'getMax'); 'min' and 'max'
-- re-wrap their result with 'Max_'.
instance Ord a => Ord (Max a) where
  (<)     = lift2 ((<)  `on` getMax)
  (>)     = lift2 ((>)  `on` getMax)
  (<=)    = lift2 ((<=) `on` getMax)
  (>=)    = lift2 ((>=) `on` getMax)
  min x y = Max_ $ lift2 (min `on` getMax) x y
  max x y = Max_ $ lift2 (max `on` getMax) x y

-- | '<>' takes the maximum of the two wrapped values.
instance Ord a => Semigroup (Exp (Max a)) where
  x <> y = Max_ $ lift2 (max `on` getMax) x y
  stimes = stimesIdempotent

-- | The identity for taking maximums is 'minBound'.
instance (Ord a, Bounded a) => Monoid (Exp (Max a)) where
  mempty  = minBound
  mappend = (<>)


-- Instances for unit and tuples
-- -----------------------------

-- | Every operation ignores its arguments and returns the constant unit.
instance Semigroup (Exp ()) where
  _ <> _     = constant ()
  sconcat _  = constant ()
  stimes _ _ = constant ()

-- | Pairs are combined component-wise: '<>' is the host-level tuple '<>'
-- lifted over 'Exp', and 'stimes' is applied to each component separately.
instance (Elt a, Elt b, Semigroup (Exp a), Semigroup (Exp b)) => Semigroup (Exp (a,b)) where
  (<>) = lift2 ((<>) :: (Exp a, Exp b) -> (Exp a, Exp b) -> (Exp a, Exp b))
  stimes n (unlift -> (a,b) :: (Exp a, Exp b)) = lift (stimes n a, stimes n b)

-- | Triples combine component-wise. '<>' is the host-level tuple '<>' lifted
-- over 'Exp'; 'stimes' is applied to each component independently and the
-- results are packed back into an embedded triple.
instance
    ( Elt a, Elt b, Elt c
    , Semigroup (Exp a), Semigroup (Exp b), Semigroup (Exp c)
    )
    => Semigroup (Exp (a,b,c)) where
  (<>) = lift2 append
    where
      -- Host-level '<>' pinned to the unlifted tuple type so that 'lift2'
      -- knows which 'Lift'/'Unlift' instances to use.
      append :: (Exp a, Exp b, Exp c) -> (Exp a, Exp b, Exp c) -> (Exp a, Exp b, Exp c)
      append = (<>)

  stimes n t =
    case unlift t :: (Exp a, Exp b, Exp c) of
      (x, y, z) -> lift (stimes n x, stimes n y, stimes n z)

-- | Quadruples combine component-wise, in the same manner as triples.
instance
    ( Elt a, Elt b, Elt c, Elt d
    , Semigroup (Exp a), Semigroup (Exp b), Semigroup (Exp c), Semigroup (Exp d)
    )
    => Semigroup (Exp (a,b,c,d)) where
  (<>) = lift2 append
    where
      -- Host-level '<>' pinned to the unlifted tuple type.
      append :: (Exp a, Exp b, Exp c, Exp d) -> (Exp a, Exp b, Exp c, Exp d) -> (Exp a, Exp b, Exp c, Exp d)
      append = (<>)

  stimes n t =
    case unlift t :: (Exp a, Exp b, Exp c, Exp d) of
      (x, y, z, w) -> lift (stimes n x, stimes n y, stimes n z, stimes n w)

-- | Quintuples combine component-wise, in the same manner as triples.
instance
    ( Elt a, Elt b, Elt c, Elt d, Elt e
    , Semigroup (Exp a), Semigroup (Exp b), Semigroup (Exp c), Semigroup (Exp d), Semigroup (Exp e)
    )
    => Semigroup (Exp (a,b,c,d,e)) where
  (<>) = lift2 append
    where
      -- Host-level '<>' pinned to the unlifted tuple type.
      append :: (Exp a, Exp b, Exp c, Exp d, Exp e) -> (Exp a, Exp b, Exp c, Exp d, Exp e) -> (Exp a, Exp b, Exp c, Exp d, Exp e)
      append = (<>)

  stimes n t =
    case unlift t :: (Exp a, Exp b, Exp c, Exp d, Exp e) of
      (x, y, z, w, v) -> lift (stimes n x, stimes n y, stimes n z, stimes n w, stimes n v)