{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# OPTIONS_HADDOCK not-home #-}
module Polysemy.Internal.Union
( Union (..)
, Weaving (..)
, Member
, MemberWithError
, weave
, hoist
, inj
, injWeaving
, weaken
, decomp
, prj
, extract
, absurdU
, decompCoerce
, SNat (..)
, Nat (..)
) where
import Control.Monad
import Data.Functor.Compose
import Data.Functor.Identity
import Data.Kind
import Data.Type.Equality
import Polysemy.Internal.Kind
#ifndef NO_ERROR_MESSAGES
import Polysemy.Internal.CustomErrors
#endif
-- | An extensible union holding one effect action drawn from the row @r@.
--
-- The 'SNat' field records the position of the active effect inside @r@;
-- 'IndexOf' turns that position into the effect type carried by the
-- 'Weaving'. With @StrictData@ enabled in this module, both fields are
-- strict.
data Union (r :: EffectRow) (m :: Type -> Type) a where
  Union :: SNat n -> Weaving (IndexOf r n) m a -> Union r m a
-- | Mapping over a 'Union' maps over the 'Weaving' it holds; the effect
-- index is left unchanged.
instance Functor (Union r m) where
  fmap g (Union idx wv) = Union idx (g <$> wv)
  {-# INLINE fmap #-}
-- | An effect action packaged so it can be run in a different monad.
--
-- A value of type @Weaving e n b@ hides an effect action @e m a@ (for some
-- existential monad @m@ and result @a@) together with an existential state
-- functor @f@ and its 'Functor' instance. The state functor can only be used
-- through the fields below.
data Weaving e m a where
  Weaving
    :: Functor f
    => { weaveEffect :: e m a
         -- ^ The effect action itself, written against the hidden monad @m@.
       , weaveState :: f ()
         -- ^ The current state, carrying no value.
       , weaveDistrib :: forall x. f (m x) -> n (f x)
         -- ^ Turn a stateful @m@ computation into an @n@ computation whose
         -- result is still wrapped in the state.
       , weaveResult :: f a -> b
         -- ^ Convert a final state holding the action's result @a@ into @b@.
       , weaveInspect :: forall x. f x -> Maybe x
         -- ^ Try to read a value out of a state; 'Nothing' if none can be
         -- read.
       }
    -> Weaving e n b
-- | Mapping over a 'Weaving' post-composes the function with the result
-- extractor. The effect, state, distributor and inspector are unchanged.
instance Functor (Weaving e m) where
  fmap g (Weaving eff st dist res ins) = Weaving eff st dist (g . res) ins
  {-# INLINE fmap #-}
-- | Thread an extra state functor @s@ through the effect held by a 'Union',
-- moving it from monad @m@ to monad @n@.
--
-- The resulting 'Weaving' uses the layered state @Compose s f@, where @f@ is
-- the state it already carried:
--
--   * the initial state puts the old @f ()@ inside the supplied @s ()@;
--   * distribution maps the old distributor over @s@, then applies the new one;
--   * the result extractor maps the old extractor under @s@;
--   * inspection runs the new inspector, then the old one.
weave
  :: (Functor s, Functor m, Functor n)
  => s ()
  -> (∀ x. s (m x) -> n (s x))
  -> (∀ x. s x -> Maybe x)
  -> Union r m a
  -> Union r n (s a)
weave s0 distrib inspect (Union idx (Weaving eff st nt res ins)) =
  Union idx $
    Weaving
      eff
      (Compose (st <$ s0))
      (\(Compose sfm) -> Compose <$> distrib (nt <$> sfm))
      (\(Compose sfa) -> res <$> sfa)
      (\(Compose sfx) -> inspect sfx >>= ins)
{-# INLINE weave #-}
-- | Change the monad of a 'Union' with a natural transformation.
--
-- The transformation is applied after the existing distributor, so the
-- state, result extractor and inspector are carried over untouched.
hoist :: (Functor m, Functor n) => (∀ x. m x -> n x) -> Union r m a -> Union r n a
hoist nat (Union idx (Weaving eff st dist res ins)) =
  Union idx (Weaving eff st (\fmx -> nat (dist fmx)) res ins)
{-# INLINE hoist #-}
-- | @'Member' e r@ holds when the effect @e@ can be found in the effect row
-- @r@. It is an alias for 'MemberNoError'.
type Member e r = MemberNoError e r
-- | Like 'Member'. Unless @NO_ERROR_MESSAGES@ is defined, it additionally
-- requires a 'WhenStuck' constraint on the looked-up effect.
--
-- NOTE(review): 'WhenStuck' and 'AmbiguousSend' come from
-- "Polysemy.Internal.CustomErrors"; presumably they turn an unresolved
-- effect lookup into a readable type error. Confirm there.
type MemberWithError e r =
  ( MemberNoError e r
#ifndef NO_ERROR_MESSAGES
  , WhenStuck (IndexOf r (Found r e)) (AmbiguousSend r e)
#endif
  )
-- | The core membership constraint. 'Find' can locate @e@ in @r@, and the
-- effect stored at the position it finds is indeed @e@.
type MemberNoError e r = (Find r e, e ~ IndexOf r (Found r e))
-- | Peano naturals. They are used (promoted) to index positions in an
-- effect row.
data Nat
  = Z
  | S Nat
-- | Singleton for 'Nat'. The run-time constructor mirrors the type-level
-- index, so matching on it refines the index.
data SNat (n :: Nat) where
  SZ :: SNat 'Z
  SS :: SNat n -> SNat ('S n)
-- | Two 'SNat's are equal exactly when they have the same number of 'SS'
-- layers. A successful comparison yields a proof that the indices agree.
instance TestEquality SNat where
  testEquality SZ SZ = Just Refl
  testEquality (SS a) (SS b)
    | Just Refl <- testEquality a b = Just Refl
  testEquality _ _ = Nothing
  {-# INLINE testEquality #-}
-- | The element at position @n@ of a type-level list.
--
-- There is no equation for the empty list, so an out-of-range index leaves
-- the family stuck instead of reducing.
type family IndexOf (ts :: [k]) (n :: Nat) :: k where
  IndexOf (t ': _)  'Z     = t
  IndexOf (_ ': ts) ('S n) = IndexOf ts n
-- | Position of the first occurrence of @t@ in the row @ts@.
--
-- Equations of a closed type family are tried in order. So when the head of
-- the row is @t@, the @(t ': ts) t@ equation is chosen before the recursive
-- one.
type family Found (ts :: [k]) (t :: k) :: Nat where
#ifndef NO_ERROR_MESSAGES
  -- Empty row: hand off to 'UnhandledEffect' (from
  -- "Polysemy.Internal.CustomErrors"; presumably a custom type error).
  -- Without this equation the family is simply stuck on '[].
  Found '[] t = UnhandledEffect t
#endif
  Found (t ': ts) t = 'Z
  Found (u ': ts) t = 'S (Found ts t)
-- | Locate the effect @t@ in the row @r@.
--
-- 'finder' is the run-time singleton for the position computed by 'Found',
-- so that it can be stored in a 'Union'.
class Find (r :: [k]) (t :: k) where
  finder :: SNat (Found r t)
-- | Base case: @t@ is the head of the row, at position 'SZ'.
--
-- Marked @OVERLAPPING@ so it is preferred over the inductive instance
-- whenever the head of the row matches @t@.
instance {-# OVERLAPPING #-} Find (t ': z) t where
  finder = SZ
  {-# INLINE finder #-}
-- | Inductive case: search the tail of the row and add one to the index.
--
-- The equality constraint is needed because the closed family 'Found' can
-- only take its recursive equation once GHC knows the head differs from @t@.
-- Stating the constraint lets this instance typecheck generically.
instance ( Find rest t
         , Found (h ': rest) t ~ 'S (Found rest t)
         ) => Find (h ': rest) t where
  finder = SS (finder @_ @rest @t)
  {-# INLINE finder #-}
-- | Split off the first effect of a row.
--
-- Returns 'Right' if the union holds the head effect. Otherwise returns
-- 'Left' with the union re-indexed into the tail of the row.
decomp :: Union (e ': r) m a -> Either (Union r m a) (Weaving e m a)
decomp (Union SZ wv) = Right wv
decomp (Union (SS n) wv) = Left (Union n wv)
{-# INLINE decomp #-}
-- | Unwrap a union over a single-effect row.
--
-- Only the index 'SZ' can carry a real effect here. An 'SS' index would
-- point past the end of the row, where 'IndexOf' is stuck, so that branch is
-- unreachable for well-typed values.
extract :: Union '[e] m a -> Weaving e m a
extract (Union idx wv) =
  case idx of
    SZ -> wv
    SS _ -> error "impossible"
{-# INLINE extract #-}
-- | An empty union contains nothing, so this function should never be
-- called.
--
-- A well-typed @'Union' '[] m a@ cannot be built: 'IndexOf' has no equation
-- for the empty row, so the 'Weaving' payload has no genuine inhabitant. If
-- one is forged anyway (e.g. with 'undefined'), this fails with a clear
-- message instead of diverging through unbounded self-recursion. The
-- argument is intentionally not forced.
absurdU :: Union '[] m a -> b
absurdU _ =
  error "Polysemy.Internal.Union.absurdU: impossible, the empty union is uninhabited"
-- | Embed a union into a row with one more effect @e@ at the front. Every
-- existing index shifts up by one.
weaken :: forall e r m a. Union r m a -> Union (e ': r) m a
weaken (Union idx wv) = Union (SS idx) wv
{-# INLINE weaken #-}
-- | Inject a plain effect action into a union that contains its effect.
--
-- The action is wrapped in a trivial 'Weaving' whose state functor is
-- 'Identity': distribution just rewraps results, and inspection always
-- succeeds.
inj :: forall e r m a. (Functor m , Member e r) => e m a -> Union r m a
inj act = injWeaving $
  Weaving
    act
    (Identity ())
    (\(Identity mx) -> Identity <$> mx)
    runIdentity
    (\(Identity x) -> Just x)
{-# INLINE inj #-}
-- | Inject an already-woven effect action, tagging it with the position
-- 'finder' computes for @e@ in @r@.
injWeaving :: forall e r m a. Member e r => Weaving e m a -> Union r m a
injWeaving wv = Union (finder @_ @r @e) wv
{-# INLINE injWeaving #-}
-- | Attempt to project a specific effect @e@ out of a union.
--
-- Succeeds only when the union's stored index equals the position of @e@
-- found by 'finder'. The resulting type equality is what lets the payload
-- be returned at type @'Weaving' e m a@.
prj :: forall e r m a. Member e r => Union r m a -> Maybe (Weaving e m a)
prj (Union idx wv)
  | Just Refl <- testEquality idx (finder @_ @r @e) = Just wv
  | otherwise = Nothing
{-# INLINE prj #-}
-- | Like 'decomp', but the 'Left' result keeps a row of the same length,
-- with the head effect replaced by an arbitrary @f@.
--
-- This is sound because the 'Left' branch is only taken for 'SS' indices,
-- which never refer to the head slot.
decompCoerce :: Union (e ': r) m a -> Either (Union (f ': r) m a) (Weaving e m a)
decompCoerce (Union SZ wv) = Right wv
decompCoerce (Union (SS n) wv) = Left (Union (SS n) wv)
{-# INLINE decompCoerce #-}