{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.PlainUnion
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.PlainUnion
  ( PlainUnion (..),
    pattern Single,
    pattern If,
    simpleMerge,
    symIteMerge,
    (.#),
    onUnion,
    onUnion2,
    onUnion3,
    onUnion4,
  )
where

import Data.Bifunctor (Bifunctor (first))
import Data.Kind (Type)
import Grisette.Internal.Core.Data.Class.Function (Function ((#)))
import Grisette.Internal.Core.Data.Class.ITEOp (ITEOp (symIte))
import Grisette.Internal.Core.Data.Class.LogicalOp
  ( LogicalOp (symNot, (.&&)),
  )
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable
  ( SimpleMergeable,
    UnionMergeable1,
    mrgIf,
  )
import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con))
import Grisette.Internal.Core.Data.Class.TryMerge
  ( mrgSingle,
    tryMerge,
  )
import Grisette.Internal.SymPrim.SymBool (SymBool)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim

-- | Union containers that can be viewed back either as a single value or as
-- an if-guarded pair of sub-unions.
class (Applicative u, UnionMergeable1 u) => PlainUnion (u :: Type -> Type) where
  -- | View the union as a single value, if it is one.
  --
  -- >>> singleView (return 1 :: UnionM Integer)
  -- Just 1
  -- >>> singleView (mrgIfPropagatedStrategy "a" (return 1) (return 2) :: UnionM Integer)
  -- Nothing
  singleView :: u a -> Maybe a

  -- | View the union as an if-node: its condition and its two branch unions.
  --
  -- >>> ifView (return 1 :: UnionM Integer)
  -- Nothing
  -- >>> ifView (mrgIfPropagatedStrategy "a" (return 1) (return 2) :: UnionM Integer)
  -- Just (a,<1>,<2>)
  -- >>> ifView (mrgIf "a" (return 1) (return 2) :: UnionM Integer)
  -- Just (a,{1},{2})
  ifView :: u a -> Maybe (SymBool, u a, u a)

  -- | Flatten the union into a list of @(guard, value)@ pairs.
  --
  -- In the default implementation, a single value gets the guard @con True@.
  -- For an if-node, every entry from the then-branch has the condition
  -- conjoined onto its guard, and every entry from the else-branch has the
  -- negated condition conjoined onto its guard. Then-branch entries come first.
  --
  -- >>> toGuardedList (mrgIf "a" (return 1) (mrgIf "b" (return 2) (return 3)) :: UnionM Integer)
  -- [(a,1),((&& b (! a)),2),((! (|| b a)),3)]
  toGuardedList :: u a -> [(SymBool, a)]
  toGuardedList u
    | Just x <- singleView u = [(con True, x)]
    | Just (c, l, r) <- ifView u = guarded c l ++ guarded (symNot c) r
    | otherwise = error "Should not happen"
    where
      -- Flatten a sub-union and append @cond@ (on the right of '.&&') to
      -- the guard of each entry it yields.
      guarded cond = fmap (first (.&& cond)) . toGuardedList

-- | Bidirectional pattern for a union that holds exactly one value.
--
-- Matching goes through 'singleView'; construction delegates to 'mrgSingle'.
--
-- >>> case (return 1 :: UnionM Integer) of Single v -> v
-- 1
pattern Single :: (PlainUnion u, Mergeable a) => a -> u a
pattern Single x <- (singleView -> Just x)
  where
    Single value = mrgSingle value

-- | Bidirectional pattern for an if-guarded union.
--
-- Matching goes through 'ifView', yielding the condition and the two branch
-- unions; construction delegates to 'mrgIf'.
--
-- >>> case (mrgIfPropagatedStrategy "a" (return 1) (return 2) :: UnionM Integer) of If c t f -> (c,t,f)
-- (a,<1>,<2>)
pattern If :: (PlainUnion u, Mergeable a) => SymBool -> u a -> u a -> u a
pattern If c t f <- (ifView -> Just (c, t, f))
  where
    If cond thenBranch elseBranch = mrgIf cond thenBranch elseBranch

-- | Merge a union of simply mergeable values and return the single merged
-- value.
--
-- The union is first merged with 'tryMerge'; the result is then expected to
-- match 'Single'. Any other shape triggers an 'error'.
--
-- In the following example, 'mrgIfPropagatedStrategy' will not merge the results, and
-- 'simpleMerge' will merge it and extract the single merged value.
--
-- >>> mrgIfPropagatedStrategy (ssym "a") (return $ ssym "b") (return $ ssym "c") :: UnionM SymBool
-- <If a b c>
-- >>> simpleMerge $ (mrgIfPropagatedStrategy (ssym "a") (return $ ssym "b") (return $ ssym "c") :: UnionM SymBool)
-- (ite a b c)
simpleMerge :: forall u a. (SimpleMergeable a, PlainUnion u) => u a -> a
simpleMerge u
  | Single merged <- tryMerge u = merged
  | otherwise = error "Should not happen"
{-# INLINE simpleMerge #-}

-- | Merge the mergeable values in a union, using `symIte`, and extract the
-- merged value.
--
-- This function exists because some types only have an `ITEOp` instance
-- (whose `symIte` may throw an error) and no `SimpleMergeable` instance, so
-- 'simpleMerge' cannot be used on them. For such types, `symIteMerge` merges
-- the values instead.
--
-- Unlike 'simpleMerge', this does not call 'tryMerge' first. It walks the
-- union as given: a 'Single' leaf yields its value, and an 'If' node yields
-- @symIte cond@ applied to the recursively merged branches.
symIteMerge :: (ITEOp a, Mergeable a, PlainUnion u) => u a -> a
symIteMerge (Single x) = x
symIteMerge (If cond l r) = symIte cond (symIteMerge l) (symIteMerge r)
-- Reached only if a union matches neither 'Single' nor 'If'.
symIteMerge _ = error "Should not happen"
-- NOTE(review): symIteMerge is self-recursive, so GHC will likely treat it as
-- a loop breaker and not inline it; this pragma may have no effect.
{-# INLINE symIteMerge #-}

-- | Helper for applying functions on 'PlainUnion' and 'SimpleMergeable'.
--
-- The function is applied, through the 'Function' class, to every value in
-- the union. The resulting union of 'SimpleMergeable' results is collapsed
-- into a single value with 'simpleMerge'.
--
-- >>> let f :: Integer -> UnionM Integer = \x -> mrgIf (ssym "a") (mrgSingle $ x + 1) (mrgSingle $ x + 2)
-- >>> f .# (mrgIf (ssym "b" :: SymBool) (mrgSingle 0) (mrgSingle 2) :: UnionM Integer)
-- {If (&& b a) 1 (If b 2 (If a 3 4))}
(.#) ::
  (Function f a r, SimpleMergeable r, PlainUnion u) =>
  f ->
  u a ->
  r
f .# u = simpleMerge (fmap (f #) u)
{-# INLINE (.#) #-}

infixl 9 .#

-- | Lift a unary function to work on union values.
--
-- The argument union is merged with 'tryMerge', the function is mapped over
-- it, and the resulting union is collapsed with 'simpleMerge'.
--
-- >>> sumU = onUnion sum
-- >>> sumU (mrgIfPropagatedStrategy "cond" (return ["a"]) (return ["b","c"]) :: UnionM [SymInteger])
-- (ite cond a (+ b c))
onUnion ::
  forall u a r.
  (SimpleMergeable r, UnionMergeable1 u, PlainUnion u, Mergeable a) =>
  (a -> r) ->
  (u a -> r)
onUnion f u = simpleMerge (f <$> tryMerge u)

-- | Lift a binary function to work on union values.
--
-- Each argument union is merged with 'tryMerge'. The function is applied
-- through the union's 'Applicative' instance, and the resulting union is
-- collapsed with 'simpleMerge'.
onUnion2 ::
  forall u a b r.
  ( SimpleMergeable r,
    UnionMergeable1 u,
    PlainUnion u,
    Mergeable a,
    Mergeable b
  ) =>
  (a -> b -> r) ->
  (u a -> u b -> r)
onUnion2 f ua ub = simpleMerge (fmap f mergedA <*> mergedB)
  where
    mergedA = tryMerge ua
    mergedB = tryMerge ub

-- | Lift a ternary function to work on union values.
--
-- Each argument union is merged with 'tryMerge'. The function is applied
-- through the union's 'Applicative' instance, and the resulting union is
-- collapsed with 'simpleMerge'.
onUnion3 ::
  forall u a b c r.
  ( SimpleMergeable r,
    UnionMergeable1 u,
    PlainUnion u,
    Mergeable a,
    Mergeable b,
    Mergeable c
  ) =>
  (a -> b -> c -> r) ->
  (u a -> u b -> u c -> r)
onUnion3 f ua ub uc =
  simpleMerge (fmap f mergedA <*> mergedB <*> mergedC)
  where
    mergedA = tryMerge ua
    mergedB = tryMerge ub
    mergedC = tryMerge uc

-- | Lift a four-argument function to work on union values.
--
-- Each argument union is merged with 'tryMerge'. The function is applied
-- through the union's 'Applicative' instance, and the resulting union is
-- collapsed with 'simpleMerge'.
onUnion4 ::
  forall u a b c d r.
  ( SimpleMergeable r,
    UnionMergeable1 u,
    PlainUnion u,
    Mergeable a,
    Mergeable b,
    Mergeable c,
    Mergeable d
  ) =>
  (a -> b -> c -> d -> r) ->
  (u a -> u b -> u c -> u d -> r)
onUnion4 f ua ub uc ud =
  simpleMerge (fmap f mergedA <*> mergedB <*> mergedC <*> mergedD)
  where
    mergedA = tryMerge ua
    mergedB = tryMerge ub
    mergedC = tryMerge uc
    mergedD = tryMerge ud