{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- 'Either' for embedded Accelerate expressions.
--
-- Re-exports the host 'Either' type together with the embedded pattern
-- synonyms 'Left_' and 'Right_', and provides 'Exp'-level counterparts of
-- the usual eliminators and predicates ('either', 'isLeft', 'isRight',
-- 'fromLeft', 'fromRight') as well as the array-level 'lefts' and 'rights'.
--
module Data.Array.Accelerate.Data.Either (
Either(..), pattern Left_, pattern Right_,
either, isLeft, isRight, fromLeft, fromRight, lefts, rights,
) where
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Pattern.Either
import Data.Array.Accelerate.Prelude
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Array ( Array, Vector )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape ( Shape, Slice, (:.) )
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Classes.Eq
import Data.Array.Accelerate.Classes.Ord
import Data.Array.Accelerate.Data.Functor
import Data.Array.Accelerate.Data.Monoid
import Data.Array.Accelerate.Data.Semigroup
import Data.Either ( Either(..) )
import Prelude ( (.), ($) )
-- | Embedded test for the 'Left' constructor.
--
-- Defined as the negation of 'isRight'.
--
isLeft :: (Elt a, Elt b) => Exp (Either a b) -> Exp Bool
isLeft e = not (isRight e)
-- | Embedded test for the 'Right' constructor.
--
-- No comparison is emitted: the left (tag) component of the encoded 'Either'
-- is projected out and paired with unit, and that pair is returned directly
-- as the 'Exp' 'Bool'.
--
-- NOTE(review): this relies on the constructor tag of 'Right' coinciding with
-- the tag of 'True' in the underlying encoding -- confirm against the 'Elt'
-- instances if either representation changes.
--
isRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp Bool
isRight (Exp e) = Exp (SmartExp (Pair tag unit))
  where
    -- constructor tag of the scrutinee
    tag  = SmartExp (Prj PairIdxLeft e)
    unit = SmartExp Nil
-- | Project the 'Left' payload out of an embedded 'Either'.
--
-- This is a pure projection on the encoded value: the tag is never
-- inspected, so the same component is returned whichever constructor the
-- argument was built with. The result is presumably only meaningful when
-- the argument is a 'Left'.
--
fromLeft :: (Elt a, Elt b) => Exp (Either a b) -> Exp a
fromLeft (Exp e) = Exp (SmartExp (Prj PairIdxRight leftSlot))
  where
    -- drop the tag, keeping the constructor fields
    fields   = SmartExp (Prj PairIdxRight e)
    -- the fields belonging to the first constructor
    leftSlot = SmartExp (Prj PairIdxLeft fields)
-- | Project the 'Right' payload out of an embedded 'Either'.
--
-- This is a pure projection on the encoded value: the tag is never
-- inspected, so the same component is returned whichever constructor the
-- argument was built with. The result is presumably only meaningful when
-- the argument is a 'Right'.
--
fromRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp b
fromRight (Exp e) = Exp (SmartExp (Prj PairIdxRight fields))
  where
    -- drop the tag, keeping the constructor fields
    fields = SmartExp (Prj PairIdxRight e)
-- | Case analysis on an embedded 'Either'.
--
-- The first function is applied to the payload of a 'Left_', the second to
-- the payload of a 'Right_'. The case split is wrapped in 'match' so that
-- the embedded pattern synonyms can be used.
--
either :: (Elt a, Elt b, Elt c) => (Exp a -> Exp c) -> (Exp b -> Exp c) -> Exp (Either a b) -> Exp c
either f g = match eliminate
  where
    eliminate (Left_ l)  = f l
    eliminate (Right_ r) = g r
-- | Collect the 'Left' payloads of an array of 'Either' values.
--
-- A mask built with 'isLeft' and the payloads extracted with 'fromLeft' are
-- handed to 'compact'; the resulting pair is returned unchanged.
--
-- NOTE(review): the second component is whatever 'compact' reports for the
-- outer shape @sh@ (presumably the number of elements kept along the
-- innermost dimension) -- confirm against the documentation of 'compact'.
--
lefts :: (Shape sh, Slice sh, Elt a, Elt b)
      => Acc (Array (sh:.Int) (Either a b))
      -> Acc (Vector a, Array sh Int)
lefts es = compact keep payloads
  where
    keep     = map isLeft es
    payloads = map fromLeft es
-- | Collect the 'Right' payloads of an array of 'Either' values.
--
-- A mask built with 'isRight' and the payloads extracted with 'fromRight'
-- are handed to 'compact'; the resulting pair is returned unchanged.
--
-- NOTE(review): the second component is whatever 'compact' reports for the
-- outer shape @sh@ (presumably the number of elements kept along the
-- innermost dimension) -- confirm against the documentation of 'compact'.
--
rights :: (Shape sh, Slice sh, Elt a, Elt b)
       => Acc (Array (sh:.Int) (Either a b))
       -> Acc (Vector b, Array sh Int)
rights es = compact keep payloads
  where
    keep     = map isRight es
    payloads = map fromRight es
-- | Map over the 'Right' payload of an embedded 'Either'; a 'Left' is
-- rebuilt with its payload untouched.
instance Elt a => Functor (Either a) where
  fmap f = either Left_ (\r -> Right_ (f r))
-- | Structural equality on embedded 'Either' values: two values are equal
-- when they use the same constructor and their payloads are equal. Values
-- built with different constructors are never equal.
instance (Eq a, Eq b) => Eq (Either a b) where
  (==) = match sameSide
    where
      sameSide (Left_ l1)  (Left_ l2)  = l1 == l2
      sameSide (Right_ r1) (Right_ r2) = r1 == r2
      -- constructors differ
      sameSide _           _           = False_
-- | Ordering on embedded 'Either' values. Values sharing a constructor are
-- ordered by their payloads; otherwise every 'Left' compares less than
-- every 'Right'.
instance (Ord a, Ord b) => Ord (Either a b) where
  compare = match bySide
    where
      bySide (Left_ l1)  (Left_ l2)  = compare l1 l2
      bySide (Right_ r1) (Right_ r2) = compare r1 r2
      -- mixed constructors: Left sorts first
      bySide Left_{}     Right_{}    = LT_
      bySide Right_{}    Left_{}     = GT_
-- | Combine two embedded 'Either' values: when the first operand is a 'Left'
-- the second operand is returned, otherwise the first operand is returned.
instance (Elt a, Elt b) => Semigroup (Exp (Either a b)) where
  lhs <> rhs = cond (isLeft lhs) rhs lhs
-- | Lift a host-level 'Either' into an embedded expression by lifting the
-- payload and rewrapping it with the matching embedded constructor.
instance (Lift Exp a, Lift Exp b, Elt (Plain a), Elt (Plain b)) => Lift Exp (Either a b) where
  type Plain (Either a b) = Either (Plain a) (Plain b)
  lift = \case
    Left l  -> Left_ $ lift l
    Right r -> Right_ $ lift r