{-# LANGUAGE BlockArguments        #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternGuards         #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ViewPatterns          #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.Data.Either
-- Copyright   : [2018..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- @since 1.2.0.0
--

module Data.Array.Accelerate.Data.Either (

  Either(..), pattern Left_, pattern Right_,
  either, isLeft, isRight, fromLeft, fromRight, lefts, rights,

) where

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Pattern.Either
import Data.Array.Accelerate.Prelude
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Array                            ( Array, Vector )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape                            ( Shape, Slice, (:.) )
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.Classes.Eq
import Data.Array.Accelerate.Classes.Ord

import Data.Array.Accelerate.Data.Functor
import Data.Array.Accelerate.Data.Monoid
import Data.Array.Accelerate.Data.Semigroup

import Data.Either                                                  ( Either(..) )
import Prelude                                                      ( (.), ($) )


-- | Test whether an embedded 'Either' value was built with 'Left'.
--
-- Defined as the negation of 'isRight'.
--
isLeft :: (Elt a, Elt b) => Exp (Either a b) -> Exp Bool
isLeft e = not (isRight e)

-- | Test whether an embedded 'Either' value was built with 'Right'.
--
-- No comparison is generated. The constructor tag of the 'Either' (the first
-- component of its representation) is paired with unit and reinterpreted
-- directly as the representation of a 'Bool'.
--
isRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp Bool
isRight (Exp e) =
  let ctorTag = SmartExp (Prj PairIdxLeft e)    -- tag of the Either value
  in  Exp (SmartExp (Pair ctorTag (SmartExp Nil)))
  -- TLM: This is a sneaky hack because we know that the tag bits for Right
  -- and True are identical.

-- | Extract the payload stored in the 'Left' constructor.
--
-- The field is projected out without inspecting the constructor tag. If the
-- argument was actually a 'Right', you will get an undefined value instead.
--
fromLeft :: (Elt a, Elt b) => Exp (Either a b) -> Exp a
fromLeft (Exp e) = Exp (prjR (prjL (prjR e)))
  where
    -- The representation is (TAG, (((), a), b)): drop the tag, step into
    -- the left half of the field tuple, then take the 'Left' payload.
    prjL :: SmartExp (x, y) -> SmartExp x
    prjL = SmartExp . Prj PairIdxLeft

    prjR :: SmartExp (x, y) -> SmartExp y
    prjR = SmartExp . Prj PairIdxRight

-- | Extract the payload stored in the 'Right' constructor.
--
-- The field is projected out without inspecting the constructor tag. If the
-- argument was actually a 'Left', you will get an undefined value instead.
--
fromRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp b
fromRight (Exp e) = Exp (prjR (prjR e))
  where
    -- The representation is (TAG, (((), a), b)): drop the tag, then take
    -- the last field, which holds the 'Right' payload.
    prjR :: SmartExp (x, y) -> SmartExp y
    prjR = SmartExp . Prj PairIdxRight

-- | Case analysis on an embedded 'Either'.
--
-- For a value @'Left' a@ the first function is applied to @a@. For a value
-- @'Right' b@ the second function is applied to @b@.
--
either :: (Elt a, Elt b, Elt c) => (Exp a -> Exp c) -> (Exp b -> Exp c) -> Exp (Either a b) -> Exp c
either onLeft onRight = match go
  where
    go (Left_  l) = onLeft l
    go (Right_ r) = onRight r

-- | Keep only the 'Left' payloads of an array of 'Either' values.
--
-- Filtering happens along the innermost dimension via 'compact'. The second
-- result is a segment descriptor of shape @sh@. It records, for each
-- innermost row, how many elements were kept.
--
lefts :: (Shape sh, Slice sh, Elt a, Elt b)
      => Acc (Array (sh:.Int) (Either a b))
      -> Acc (Vector a, Array sh Int)
lefts es = compact leftFlags leftVals
  where
    leftFlags = map isLeft   es    -- which elements to keep
    leftVals  = map fromLeft es    -- candidate payloads, valid where flagged

-- | Keep only the 'Right' payloads of an array of 'Either' values.
--
-- Filtering happens along the innermost dimension via 'compact'. The second
-- result is a segment descriptor of shape @sh@. It records, for each
-- innermost row, how many elements were kept.
--
rights :: (Shape sh, Slice sh, Elt a, Elt b)
       => Acc (Array (sh:.Int) (Either a b))
       -> Acc (Vector b, Array sh Int)
rights es = compact rightFlags rightVals
  where
    rightFlags = map isRight   es    -- which elements to keep
    rightVals  = map fromRight es    -- candidate payloads, valid where flagged


instance Elt a => Functor (Either a) where
  -- A 'Left' is rebuilt unchanged. A 'Right' has the function applied to its
  -- payload.
  fmap f = match \case
    Left_  l -> Left_ l
    Right_ r -> Right_ (f r)

instance (Eq a, Eq b) => Eq (Either a b) where
  (==) = match sameEither
    where
      -- Two values are equal only when they use the same constructor and
      -- their payloads compare equal. Mixed constructors are never equal.
      sameEither (Left_  l1) (Left_  l2) = l1 == l2
      sameEither (Right_ r1) (Right_ r2) = r1 == r2
      sameEither _           _           = False_

instance (Ord a, Ord b) => Ord (Either a b) where
  compare = match cmpEither
    where
      -- Matching constructors are ordered by payload. Any 'Left' orders
      -- before any 'Right'.
      cmpEither (Left_  l1) (Left_  l2) = compare l1 l2
      cmpEither (Right_ r1) (Right_ r2) = compare r1 r2
      cmpEither Left_{}     Right_{}    = LT_
      cmpEither Right_{}    Left_{}     = GT_

instance (Elt a, Elt b) => Semigroup (Exp (Either a b)) where
  -- Return the first operand when it is a 'Right'. Otherwise (it is a
  -- 'Left') return the second operand.
  x <> y = pickSecond ? (y, x)
    where
      pickSecond = isLeft x

instance (Lift Exp a, Lift Exp b, Elt (Plain a), Elt (Plain b)) => Lift Exp (Either a b) where
  type Plain (Either a b) = Either (Plain a) (Plain b)
  -- Lift the payload, then rebuild the same constructor as an embedded
  -- expression.
  lift = \case
    Left  l -> Left_  (lift l)
    Right r -> Right_ (lift r)