module Control.Subcategory.Bind
  (CBind(..), CMonad, creturn, (-<<)) where
import Control.Subcategory.Functor
import Control.Subcategory.Pointed

import           Control.Monad                   (join)
import qualified Control.Monad.ST.Lazy           as LST
import qualified Control.Monad.ST.Strict         as SST
import           Data.Coerce                     (coerce)
import           Data.Functor.Identity           (Identity)
import qualified Data.Functor.Product            as SOP
import           Data.Hashable                   (Hashable)
import qualified Data.HashMap.Strict             as HM
import qualified Data.HashSet                    as HS
import qualified Data.IntMap                     as IM
import qualified Data.IntSet                     as IS
import           Data.List.NonEmpty              (NonEmpty)
import qualified Data.Map                        as Map
import           Data.MonoTraversable
import qualified Data.Semigroup                  as Sem
import qualified Data.Sequence                   as Seq
import qualified Data.Set                        as Set
import qualified Data.Tree                       as Tree
import           GHC.Conc                        (STM)
import           Text.ParserCombinators.ReadP    (ReadP)
import           Text.ParserCombinators.ReadPrec (ReadPrec)

-- | Bind for constrained functors: the subcategory analogue of 'Monad'
--   without a unit (see 'CMonad' for the version with 'creturn').
--
--   Instances must define at least one of '(>>-)' and 'cjoin'.  Each default
--   is written in terms of the other, so an instance defining neither would
--   loop at runtime; the MINIMAL pragma turns that into a compile-time warning.
class CFunctor m => CBind m where
  {-# MINIMAL (>>-) | cjoin #-}

  -- | Sequentially compose two computations, feeding the result of the
  --   first into the second.  Constrained analogue of '(>>=)'.
  (>>-) :: (Dom m a, Dom m b) => m a -> (a -> m b) -> m b
  -- The default goes through 'cjoin', so the nested type @m b@ must also be
  -- in the domain of @m@.
  default (>>-) :: (Dom m a, Dom m b, Dom m (m b)) => m a -> (a -> m b) -> m b
  m >>- f = cjoin (cmap f m)

  -- | Collapse one layer of nesting.  Constrained analogue of 'join'.
  cjoin :: (Dom m (m a), Dom m a) => m (m a) -> m a
  cjoin = (>>- id)

-- | Any ordinary 'Monad', wrapped in 'WrapFunctor', binds with its own
--   '(>>=)' and joins with 'join'.
instance (Monad m) => CBind (WrapFunctor m) where
  (>>-) :: forall a b.
           WrapFunctor m a
        -> (a -> WrapFunctor m b) -> WrapFunctor m b
  -- 'WrapFunctor' is a newtype, so the underlying '(>>=)' is reused by
  -- 'coerce'; the type application pins which '(>>=)' is being coerced.
  (>>-) = coerce @(m a -> (a -> m b) -> m b) (>>=)

  cjoin :: forall a. WrapFunctor m (WrapFunctor m a) -> WrapFunctor m a
  -- The inner layer sits under an arbitrary @m@ whose role is unknown, so it
  -- is unwrapped with 'fmap coerce' rather than a single outer 'coerce'.
  cjoin (WrapFunctor mma) = WrapFunctor $ join (fmap coerce mma)

-- | Lists: bind is the list monad's '(>>=)', and join is 'concat'.
instance CBind [] where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}
  cjoin = concat
  {-# INLINE cjoin #-}

-- | 'IO' binds with its 'Monad' instance.
instance CBind IO where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | 'STM' binds with its 'Monad' instance.
instance CBind STM where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | Strict 'SST.ST' binds with its 'Monad' instance.
instance CBind (SST.ST s) where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | Lazy 'LST.ST' binds with its 'Monad' instance.
instance CBind (LST.ST s) where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | 'Identity' binds with its 'Monad' instance.
instance CBind Identity where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | @'Either' e@ binds with its 'Monad' instance (short-circuits on 'Left').
instance CBind (Either a) where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | Rose trees bind with their 'Monad' instance.
instance CBind Tree.Tree where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | 'Maybe' binds with its 'Monad' instance.
instance CBind Maybe where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | 'IM.IntMap' binds along the diagonal.  For each key @k@ with value @v@
--   in the input, the result holds @'IM.lookup' k (f v)@ at @k@.  Keys that
--   are absent from @f v@ are dropped.
instance CBind IM.IntMap where
  m >>- f = IM.mapMaybeWithKey (\k -> IM.lookup k . f) m
  {-# INLINE (>>-) #-}

-- | 'Map.Map' binds along the diagonal.  For each key @k@ with value @v@ in
--   the input, the result holds @'Map.lookup' k (f v)@ at @k@.  Keys that are
--   absent from @f v@ are dropped.
instance Ord k => CBind (Map.Map k) where
  m >>- f = Map.mapMaybeWithKey (\k -> Map.lookup k . f) m
  {-# INLINE (>>-) #-}

-- | 'HM.HashMap' binds along the diagonal.  For each key @k@ with value @v@
--   in the input, the result holds @'HM.lookup' k (f v)@ at @k@.  Keys that
--   are absent from @f v@ are dropped.
instance (Hashable k, Eq k) => CBind (HM.HashMap k) where
  m >>- f = HM.mapMaybeWithKey (\k -> HM.lookup k . f) m
  {-# INLINE (>>-) #-}

-- | 'Set.Set' binds by mapping every element to a set and combining the
--   results with the 'Set.Set' 'Monoid' (union).
instance CBind Set.Set where
  (>>-) = flip foldMap
  {-# INLINE (>>-) #-}
  cjoin = foldMap id
  {-# INLINE cjoin #-}

-- | 'IS.IntSet' viewed through 'WrapMono' binds by folding every element
--   through the continuation with 'ofoldMap'.  The results are combined with
--   the 'Monoid' instance of @'WrapMono' 'IS.IntSet'@.
instance CBind (WrapMono IS.IntSet) where
  -- NOTE(review): nothing here fixes the @mono@ type argument of
  -- 'withMonoCoercible', so the Coercible evidence it supplies appears
  -- unused.  Consider dropping the wrapper or pinning it with @\@IS.IntSet@.
  (>>-) = withMonoCoercible $ flip ofoldMap
  {-# INLINE (>>-) #-}

-- | 'NonEmpty' binds with its 'Monad' instance.
instance CBind NonEmpty where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | 'Seq.Seq' binds with its 'Monad' instance.
instance CBind Seq.Seq where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | 'Sem.Option' binds with its 'Monad' instance.
--
--   NOTE(review): 'Sem.Option' was removed from @base@ in 4.16 (GHC 9.2).
--   Building against newer compilers will need this instance guarded by CPP.
instance CBind Sem.Option where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | Functions from a fixed argument (the reader monad) bind with their
--   'Monad' instance.
instance CBind ((->) a) where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | 'HS.HashSet' binds by mapping every element to a set and combining the
--   results with the 'HS.HashSet' 'Monoid' (union).
instance CBind HS.HashSet where
  (>>-) = flip foldMap
  {-# INLINE (>>-) #-}
  cjoin = foldMap id
  {-# INLINE cjoin #-}

-- | 'ReadP' parsers bind with their 'Monad' instance.
instance CBind ReadP where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | 'ReadPrec' parsers bind with their 'Monad' instance.
instance CBind ReadPrec where
  (>>-) = (>>=)
  {-# INLINE (>>-) #-}

-- | Writer-style pairs.  The first component is accumulated with '(<>)'
--   (earlier log on the left), and the second carries the value.
instance Semigroup w => CBind ((,) w) where
  (m, a) >>- f =
    let (w, b) = f a
    in (m <> w, b)
  {-# INLINE (>>-) #-}
  cjoin (w, (m, a)) = (w <> m, a)
  {-# INLINE cjoin #-}

infixl 1 >>-
infixr 1 -<<

-- | '(>>-)' with its arguments flipped.  Constrained analogue of '(=<<)'.
(-<<) :: (Dom m b, Dom m a, CBind m) => (a -> m b) -> m a -> m b
(-<<) = flip (>>-)
{-# INLINE (-<<) #-}

-- | Functor products bind componentwise.  The continuation's result is
--   projected onto each side, and each side is bound independently.
instance (CBind m, CBind n) => CBind (SOP.Product m n) where
  (SOP.Pair a b) >>- f = SOP.Pair (a >>- fstP . f) (b >>- sndP . f)
    where
      -- Left projection of a product.
      fstP :: SOP.Product p q x -> p x
      fstP (SOP.Pair x _) = x
      -- Right projection of a product.
      sndP :: SOP.Product p q x -> q x
      sndP (SOP.Pair _ y) = y
  {-# INLINE (>>-) #-}

-- | Constrained monads: anything that is both 'CPointed' and 'CBind'.
--   This is a class synonym; the catch-all instance below means no type
--   ever needs a hand-written 'CMonad' instance.
class    (CPointed m, CBind m) => CMonad m
instance (CPointed m, CBind m) => CMonad m

-- | Inject a value into a constrained monad.  Constrained analogue of
--   'return'; it is simply 'cpure'.
creturn :: (Dom m a, CMonad m) => a -> m a
creturn = cpure
{-# INLINE creturn #-}