{-# LANGUAGE UndecidableInstances #-}

-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at https://mozilla.org/MPL/2.0/.

-- The code before modification is MIT licensed; (c) 2023 Casper Bach Poulsen and Cas van der Rest.

{- |
Copyright   :  (c) 2023 Yamada Ryo
               (c) 2023 Casper Bach Poulsen and Cas van der Rest
License     :  MPL-2.0 (see the file LICENSE)
Maintainer  :  ymdfield@outlook.jp
Stability   :  experimental
Portability :  portable

An implementation of an open union for higher-order effects using recursively nested binary sums.
-}
module Data.Hefty.Sum where

import Control.Effect.Class (NopS, Signature, type (~>))
import Control.Effect.Class.Machinery.HFunctor (HFunctor, caseH, (:+:) (Inl, Inr))
import Data.Hefty.Union (HasMembershipH, UnionH, absurdUnionH, compH, decompH, injectH, projectH)

-- | Eliminate an impossible left summand.
--
-- 'NopS' is matched with an empty case (it is treated as uninhabited), so the
-- only remaining alternative is 'Inr', whose payload is returned unchanged.
absurdLH :: (NopS :+: h) f ~> h f
absurdLH = caseH (\case {}) id
{-# INLINE absurdLH #-}

-- | Eliminate an impossible right summand.
--
-- Mirror image of 'absurdLH': the 'Inl' payload is returned unchanged and the
-- 'NopS' side is matched with an empty case.
absurdRH :: (h :+: NopS) f ~> h f
absurdRH = caseH id (\case {})
{-# INLINE absurdRH #-}

-- | Exchange the two summands of a binary sum.
--
-- A left value ('Inl') becomes a right value ('Inr') and vice versa; the
-- payload itself is untouched.
swapSumH :: (h1 :+: h2) f a -> (h2 :+: h1) f a
swapSumH = caseH Inr Inl
{-# INLINE swapSumH #-}

{- |
Fold a type-level list of signatures into a right-nested binary sum that is
terminated by 'NopS'.

For example, @SumH '[h1, h2]@ reduces to @h1 :+: (h2 :+: NopS)@.
-}
type family SumH (sigs :: [Signature]) :: Signature where
    SumH '[] = NopS
    SumH (sig ': rest) = sig :+: SumH rest

{- |
An implementation of an open union for higher-order effects using recursively nested binary sums.

A value of type @SumUnionH hs f a@ is a thin wrapper around the nested sum
@SumH hs f a@; 'unSumUnionH' exposes that representation.
-}
newtype SumUnionH hs f a = SumUnionH {unSumUnionH :: SumH hs f a}

-- Instances for the empty union, whose representation @SumH '[]@ reduces to
-- 'NopS'. Functor and Foldable are reused from the representation through the
-- newtype. Traversable is derived with the stock strategy instead
-- (NOTE(review): presumably because GeneralizedNewtypeDeriving cannot coerce
-- 'traverse' through an unknown applicative).
deriving newtype instance
    Functor (SumUnionH '[] f)
deriving newtype instance
    Foldable (SumUnionH '[] f)
deriving stock instance
    Traversable (SumUnionH '[] f)

{- The instances below are disabled because 'Data.Comp.Multi.Ops.:+:' lacks the
 - Functor/Foldable/Traversable instances they would require.
 - Should we open a pull request against the compdata package?
 -}
{-
deriving newtype instance
    (Functor (h f), Functor (SumH hs f)) =>
    Functor (SumUnionH (h ': hs) f)

deriving newtype instance
    (Foldable (h f), Foldable (SumH hs f)) =>
    Foldable (SumUnionH (h ': hs) f)

deriving stock instance
    (Traversable (h f), Traversable (SumH hs f)) =>
    Traversable (SumUnionH (h ': hs) f)
-}

-- Higher-order functoriality is inherited, through the newtype, from the
-- underlying nested sum @SumH hs@.
deriving newtype instance
    HFunctor (SumH hs) =>
    HFunctor (SumUnionH hs)

instance UnionH SumUnionH where
    -- Membership in the union is membership in the nested sum it wraps.
    type HasMembershipH _ h hs = h << SumH hs

    -- Inject into the nested sum, then wrap it.
    injectH sig = SumUnionH $ injH sig

    -- Unwrap, then try to project out of the nested sum.
    projectH (SumUnionH sig) = projH sig

    -- The empty union's representation is 'NopS', matched with an empty case.
    absurdUnionH = \case {}

    -- Prepend a summand: 'Left' selects the new head ('Inl'); 'Right' keeps
    -- the existing union as the tail ('Inr').
    compH = \case
        Left x -> SumUnionH (Inl x)
        Right (SumUnionH x) -> SumUnionH (Inr x)

    -- Inverse of 'compH': split the head summand off the tail.
    decompH (SumUnionH sig) = case sig of
        Inl x -> Left x
        Inr x -> Right (SumUnionH x)

    {-# INLINE injectH #-}
    {-# INLINE projectH #-}
    {-# INLINE absurdUnionH #-}
    {-# INLINE compH #-}
    {-# INLINE decompH #-}

{- |
Injection into, and projection out of, a nested binary sum.

The index @isHead@ is fixed by the superclass equality to @IsHeadSigOf h1 h2@,
so it records whether @h1@ is the left summand of @h2@.
-}
class isHead ~ (h1 `IsHeadSigOf` h2) => SumMemberH isHead (h1 :: Signature) h2 where
    -- | Embed an @h1@ value into the sum @h2@.
    injSumH :: h1 f a -> h2 f a

    -- | Recover an @h1@ value from the sum @h2@, if it holds one.
    projSumH :: h2 f a -> Maybe (h1 f a)

-- | 'True when the second argument is a sum whose left summand is the first
-- argument, 'False otherwise.
type family (h1 :: Signature) `IsHeadSigOf` h2 where
    f `IsHeadSigOf` (f :+: g) = 'True
    _ `IsHeadSigOf` _ = 'False

-- | @h1 << h2@ states that @h1@ can be injected into and projected out of @h2@.
type h1 << h2 = SumMemberH (IsHeadSigOf h1 h2) h1 h2

-- | Inject a signature into a nested sum that contains it.
--
-- The membership index is applied explicitly as @IsHeadSigOf h1 h2@, the same
-- index that the @h1 << h2@ constraint carries.
injH :: forall h1 h2 f. h1 << h2 => h1 f ~> h2 f
injH = injSumH @(IsHeadSigOf h1 h2)
{-# INLINE injH #-}

-- | Project a signature out of a nested sum that contains it, yielding
-- 'Nothing' when the value belongs to another summand.
--
-- As in 'injH', the membership index is applied explicitly instead of being
-- left for GHC to infer through the superclass equality of 'SumMemberH'.
projH :: forall h1 h2 f a. h1 << h2 => h2 f a -> Maybe (h1 f a)
projH = projSumH @(IsHeadSigOf h1 h2)
{-# INLINE projH #-}

-- Base case: @f@ is the left summand. Injection is 'Inl', and projection
-- succeeds exactly on 'Inl' values.
instance SumMemberH 'True f (f :+: g) where
    injSumH = Inl

    projSumH = \case
        Inl x -> Just x
        Inr _ -> Nothing

    {-# INLINE injSumH #-}
    {-# INLINE projSumH #-}

-- Recursive case: @f@ is not the left summand @g@, so it must live in the tail
-- @h@. Injection wraps the tail injection in 'Inr'. Projection fails on 'Inl'
-- and recurses into the tail on 'Inr'.
--
-- The @~ 'False@ constraint provides the equation demanded by the superclass
-- of 'SumMemberH'; it cannot be obtained by reduction here because
-- @IsHeadSigOf f (g :+: h)@ is stuck while @f@ and @g@ are unknown.
instance (f `IsHeadSigOf` (g :+: h) ~ 'False, f << h) => SumMemberH 'False f (g :+: h) where
    injSumH = Inr . injH

    -- Recurse through 'projH' (mirroring 'injH' above) so the membership
    -- index is applied explicitly rather than inferred.
    projSumH = \case
        Inl _ -> Nothing
        Inr x -> projH x

    {-# INLINE injSumH #-}
    {-# INLINE projSumH #-}