{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE UndecidableInstances #-}

-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at https://mozilla.org/MPL/2.0/.

-- The code before modification is MIT licensed; (c) 2023 Casper Bach Poulsen and Cas van der Rest.

{- |
Copyright   :  (c) 2023 Yamada Ryo
               (c) 2023 Casper Bach Poulsen and Cas van der Rest
License     :  MPL-2.0 (see the file LICENSE)
Maintainer  :  ymdfield@outlook.jp
Stability   :  experimental
Portability :  portable

An implementation of an open union for first-order effects using recursively nested binary sums.
-}
module Data.Free.Sum (module Data.Free.Sum, pattern L1, pattern R1) where

import Control.Effect.Class (Instruction, NopI, type (~>))
import Data.Free.Union (HasMembership, Union, absurdUnion, comp, decomp, inject, project)
import GHC.Generics (type (:+:) (L1, R1))

infixr 6 +

-- | A type synonym for the first-order sum, used to disambiguate it from the
-- sum on the higher-order side.
--
-- It is declared with no parameters, so (like @(':+:')@ itself) it may be used
-- partially applied. Because of @infixr 6@, @f + g + h@ means @f + (g + h)@.
type (+) = (:+:)

-- | Eliminate a binary sum by supplying one handler per side.
--
-- @caseF f g@ applies @f@ to an 'L1' payload and @g@ to an 'R1' payload.
-- Both handlers are bound on the left-hand side and the sum itself is taken by
-- the @\\case@. This makes a two-argument application such as @caseF R1 L1@
-- count as saturated for the @INLINE@ pragma.
caseF :: (f a -> r) -> (g a -> r) -> (f + g) a -> r
caseF f g = \case
    L1 x -> f x
    R1 x -> g x
{-# INLINE caseF #-}

-- | Collapse a sum whose left summand is 'NopI' to its right summand.
--
-- The 'NopI' side is handled by an empty case. It is therefore treated as
-- having no values, and only the right side can actually occur.
absurdL :: (NopI + f) ~> f
absurdL = caseF (\case {}) id
{-# INLINE absurdL #-}

-- | Collapse a sum whose right summand is 'NopI' to its left summand.
--
-- The 'NopI' side is handled by an empty case. It is therefore treated as
-- having no values, and only the left side can actually occur.
absurdR :: (f + NopI) ~> f
absurdR = caseF id (\case {})
{-# INLINE absurdR #-}

-- | Exchange the two sides of a binary sum: 'L1' becomes 'R1' and vice versa.
swapSum :: (f + g) a -> (g + f) a
swapSum = caseF R1 L1
{-# INLINE swapSum #-}

-- | The right-nested binary sum of a type-level list of instructions.
--
-- The empty list maps to 'NopI', and each cons cell adds one ':+:' layer.
-- For example, @Sum '[f, g]@ reduces to @f :+: (g :+: NopI)@.
type family Sum fs where
    Sum '[] = NopI
    Sum (f ': fs) = f :+: Sum fs

{- |
An implementation of an open union for first-order effects using recursively nested binary sums.

@'SumUnion' fs a@ is a newtype over @'Sum' fs a@. For example, for
@fs = '[f1, f2]@ the wrapped value has type @(f1 ':+:' (f2 ':+:' 'NopI')) a@.
-}
newtype SumUnion fs a = SumUnion {unSumUnion :: Sum fs a}

-- Each class is derived separately for the empty list and for a cons cell.
-- The cons case requires the class both for the head instruction @f@ and for
-- the nested sum of the tail, @'Sum' fs@.
deriving newtype instance Functor (SumUnion '[])
deriving newtype instance (Functor f, Functor (Sum fs)) => Functor (SumUnion (f ': fs))

deriving newtype instance Foldable (SumUnion '[])
deriving newtype instance (Foldable f, Foldable (Sum fs)) => Foldable (SumUnion (f ': fs))

-- Traversable uses the stock strategy. Newtype deriving would have to coerce
-- under the applicative in the result type of 'traverse', which the role
-- system does not permit.
deriving stock instance Traversable (SumUnion '[])
deriving stock instance (Traversable f, Traversable (Sum fs)) => Traversable (SumUnion (f ': fs))

-- | 'SumUnion' as a 'Union'.
--
-- * Membership of @f@ in @fs@ is @f '<' 'Sum' fs@.
-- * 'inject' and 'project' delegate to 'inj' and 'proj' on the wrapped sum.
-- * 'comp' and 'decomp' convert between a cons-shaped union and 'Either' its
--   head instruction or the union of its tail.
instance Union SumUnion where
    type HasMembership _ f fs = f < Sum fs

    inject sig = SumUnion (inj sig)
    project (SumUnion sig) = proj sig

    -- Unwrap the newtype and eliminate the underlying @'Sum' '[] a@ (= 'NopI')
    -- with an empty case.
    absurdUnion (SumUnion x) = case x of {}

    comp =
        SumUnion . \case
            Left x -> L1 x
            Right (SumUnion x) -> R1 x

    decomp (SumUnion sig) = case sig of
        L1 x -> Left x
        R1 x -> Right (SumUnion x)

    {-# INLINE inject #-}
    {-# INLINE project #-}
    {-# INLINE absurdUnion #-}
    {-# INLINE comp #-}
    {-# INLINE decomp #-}

-- | Membership of instruction @f@ in a right-nested sum @g@.
--
-- The @isHead@ index is fixed to @f \`IsHeadInsOf\` g@ by the superclass
-- equality. It lets two non-overlapping instances exist: one for when @f@ is
-- the left summand (@'True@) and one for when it is not (@'False@).
class isHead ~ f `IsHeadInsOf` g => SumMember isHead (f :: Instruction) g where
    -- | Embed an @f@ instruction into the sum @g@.
    injSum :: f a -> g a
    -- | Extract an @f@ instruction from the sum @g@, if that is what it holds.
    projSum :: g a -> Maybe (f a)

-- | @'True@ when @g@ is a sum @f + g'@ whose left summand is @f@ itself, and
-- @'False@ otherwise.
--
-- This is a closed family, so the catch-all equation applies only once the
-- first equation is known not to match.
type family (f :: Instruction) `IsHeadInsOf` g where
    f `IsHeadInsOf` f + g = 'True
    _ `IsHeadInsOf` _ = 'False

-- | @f < g@: instruction @f@ is one of the summands of the nested sum @g@.
type f < g = SumMember (IsHeadInsOf f g) f g

-- | Inject an instruction into a nested sum that contains it.
--
-- The 'SumMember' instance is chosen by the 'IsHeadInsOf' index, which is
-- passed explicitly.
inj :: forall f g. f < g => f ~> g
inj = injSum @(IsHeadInsOf f g)

-- | Project an instruction out of a nested sum, returning 'Nothing' when the
-- sum holds a different summand.
--
-- The 'IsHeadInsOf' index is passed explicitly, mirroring 'inj'.
proj :: forall f g a. f < g => g a -> Maybe (f a)
proj = projSum @(IsHeadInsOf f g)

-- | @f@ is the left summand: inject with 'L1', and project by matching 'L1'.
instance SumMember 'True f (f + g) where
    injSum = L1

    projSum = \case
        L1 x -> Just x
        R1 _ -> Nothing

    {-# INLINE injSum #-}
    {-# INLINE projSum #-}

-- | @f@ is not the left summand: skip the left side and recurse into @h@.
--
-- The equality in the context restates that 'IsHeadInsOf' is @'False@ here.
-- The closed family cannot reduce @f \`IsHeadInsOf\` (g + h)@ by itself while
-- @f@ and @g@ are still type variables, because they might be equal.
instance (f `IsHeadInsOf` (g + h) ~ 'False, f < h) => SumMember 'False f (g + h) where
    injSum = R1 . inj

    projSum = \case
        L1 _ -> Nothing
        R1 x -> proj x

    {-# INLINE injSum #-}
    {-# INLINE projSum #-}