{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}

module Data.Morpheus.Ext.SemigroupM
  ( SemigroupM (..),
    (<:>),
    concatTraverse,
    join,
  )
where

import qualified Data.HashMap.Lazy as HM
import Data.HashMap.Lazy (HashMap)
import Data.Morpheus.Error.NameCollision (NameCollision)
import Data.Morpheus.Ext.KeyOf (KeyOf (..))
import Data.Morpheus.Ext.Map
  ( fromListT,
    runResolutionT,
  )
import Data.Morpheus.Internal.Utils
  ( Collection (..),
    Elems (..),
    Failure,
    failOnDuplicates,
  )
import Data.Morpheus.Types.Internal.AST.Base
  ( Ref,
    ValidationErrors,
  )
import Relude
  ( ($),
    Applicative (..),
    Monad (..),
    Semigroup (..),
    Traversable (..),
  )

-- | A semigroup-like class whose combining operation runs inside a context
-- @m@ instead of being a pure function.
class SemigroupM (m :: * -> *) a where
  -- | Combine two values of type @a@, producing the result in @m@.
  --
  -- The leading @['Ref']@ argument is supplied by the caller; '<:>' passes an
  -- empty list. NOTE(review): presumably the reference path leading to the
  -- merge site, used for error reporting — confirm against the instances.
  mergeM :: [Ref] -> a -> a -> m a

-- | Merge two hash maps inside @m@.
--
-- The association lists of both maps are concatenated (left map first) and
-- handed to 'fromListT'; the resulting resolution is run with 'HM.fromList'
-- as the collection builder and 'failOnDuplicates' as the handler for
-- clashing entries. The @['Ref']@ argument is ignored by this instance.
--
-- NOTE(review): judging by its name and constraints, 'failOnDuplicates'
-- reports a name collision through 'Failure' when a key occurs in both maps —
-- confirm in "Data.Morpheus.Internal.Utils".
instance
  ( NameCollision a,
    Monad m,
    KeyOf k a,
    Failure ValidationErrors m
  ) =>
  SemigroupM m (HashMap k a)
  where
  mergeM _ x y =
    runResolutionT
      (fromListT $ HM.toList x <> HM.toList y)
      HM.fromList
      failOnDuplicates

-- | Apply an effectful function to every element of a container and merge
-- all produced collections into one.
--
-- The elements are obtained with 'elems', mapped with 'traverse', and the
-- resulting list of collections is combined with 'join'.
concatTraverse ::
  ( Monad m,
    Failure ValidationErrors m,
    Collection b cb,
    Elems a ca,
    SemigroupM m cb
  ) =>
  (a -> m cb) ->
  ca ->
  m cb
concatTraverse f smap =
  traverse f (elems smap)
    >>= join

-- | Merge a list of collections into a single one, left to right.
--
-- Starts from the 'empty' collection and folds each element into the
-- accumulator with '<:>'. An empty input list yields 'empty'.
join ::
  ( Collection e a,
    Monad m,
    Failure ValidationErrors m,
    SemigroupM m a
  ) =>
  [a] ->
  m a
join = __join empty
  where
    -- Accumulator-passing recursion: the accumulator is always the left
    -- operand of '<:>', so merge order follows list order.
    __join acc [] = pure acc
    __join acc (x : xs) = acc <:> x >>= (`__join` xs)

-- | Infix form of 'mergeM' called with an empty list of references.
(<:>) :: SemigroupM m a => a -> a -> m a
(<:>) = mergeM []