{- |
Module      :  Control.Monad.Trans.Indexed
Copyright   :  (C) 2024 Eitan Chatav
License     :  BSD 3-Clause License (see the file LICENSE)
Maintainer  :  Eitan Chatav <eitan.chatav@gmail.com>

Indexed monad transformers.
-}

module Control.Monad.Trans.Indexed
  ( IxMonadTrans (..)
  , Indexed (..)
  , (&)
  ) where

import Control.Category (Category (..))
import Control.Monad
import Control.Monad.Trans
import Data.Function ((&))
import Data.Kind
import Prelude hiding (id, (.))

{- |
An [Atkey indexed monad]
(https://bentnib.org/paramnotions-jfp.pdf)
is a `Functor` [enriched category]
(https://ncatlab.org/nlab/show/enriched+category).
An indexed monad transformer transforms a `Monad` into an indexed monad.
It is a monad and a monad transformer when its source and target indices
are the same, enabling use of standard @do@ notation in that case.
In the general case, qualified @Indexed.do@ notation can be used,
even when the source and target indices differ.

>>> :set -XQualifiedDo
>>> import qualified Control.Monad.Trans.Indexed.Do as Indexed
-}
type IxMonadTrans
  :: (k -> k -> (Type -> Type) -> Type -> Type)
  -> Constraint
-- Superclasses: @t i j m@ is a 'Functor' for any indices @i@ and @j@
-- (given a 'Monad' @m@); when the indices coincide (@i ~ j@) it is also
-- a 'Monad', and @t i j@ is a 'MonadTrans'.
class
  ( forall i j m. Monad m => Functor (t i j m)
  , forall i j m. (i ~ j, Monad m) => Monad (t i j m)
  , forall i j. i ~ j => MonadTrans (t i j)
  ) => IxMonadTrans t where

  -- The default 'joinIx' is defined via 'bindIx' and the default
  -- 'bindIx' via 'joinIx', so an instance must define at least one of
  -- them; defining neither would make both loop.
  {-# MINIMAL joinIx | bindIx #-}

  {- |
  indexed analog of `<*>`

  prop> (<*>) = apIx
  -}
  apIx
    :: Monad m
    => t i j m (x -> y)
    -> t j k m x
    -> t i k m y
  -- Run @tf@ (index @i@ to @j@), then map the function it produced
  -- over @tx@ (index @j@ to @k@).
  apIx tf tx = bindIx (<$> tx) tf

  {- |
  indexed analog of `join`

  prop> join = joinIx
  prop> joinIx = bindIx id
  -}
  joinIx
    :: Monad m
    => t i j m (t j k m y)
    -> t i k m y
  -- Binding with the identity continuation runs the outer computation,
  -- then the inner computation it returned.
  joinIx = bindIx id

  {- |
  indexed analog of `=<<`

  prop> (=<<) = bindIx
  prop> bindIx f x = joinIx (f <$> x)
  prop> x & bindIx return = x
  prop> x & bindIx f & bindIx g = x & bindIx (f & andThenIx g)
  -}
  bindIx
    :: Monad m
    => (x -> t j k m y)
    -> t i j m x
    -> t i k m y
  -- Only the 'Functor' superclass is needed to build the nested
  -- computation; 'joinIx' then flattens it.
  bindIx f t = joinIx (f <$> t)

  {- |
  indexed analog of flipped `>>`

  Note the argument order: the second argument (which starts at
  index @i@) runs first and its result is discarded; the first
  argument runs afterwards and supplies the result.

  prop> (>>) = flip thenIx
  prop> return () & thenIx y = y
  -}
  thenIx
    :: Monad m
    => t j k m y
    -> t i j m x
    -> t i k m y
  thenIx ix2 ix1 = ix1 & bindIx (\_ -> ix2)

  {- |
  indexed analog of `<=<`

  prop> (<=<) = andThenIx
  prop> andThenIx g f x = bindIx g (f x)
  prop> f & andThenIx return = f
  prop> return & andThenIx f = f
  prop> f & andThenIx g & andThenIx h = f & andThenIx (g & andThenIx h)
  -}
  andThenIx
    :: Monad m
    => (y -> t j k m z)
    -> (x -> t i j m y)
    -> x -> t i k m z
  -- Kleisli composition: apply @f@ first, then feed its result to @g@.
  andThenIx g f x = bindIx g (f x)

{- |
`Indexed` reshuffles the type parameters of an `IxMonadTrans`,
exposing its `Category` instance.

The source and target indices @i@ and @j@ are moved to the end so that
@Indexed t m r@ has the two-index shape that `Category` expects.
-}
newtype Indexed t m r i j = Indexed
  { runIndexed :: t i j m r
  -- ^ Unwrap to the underlying computation, which goes from index @i@
  -- to index @j@ over the base monad @m@ and returns an @r@.
  }
-- | Arrows from index @a@ to index @b@ are computations returning an @r@.
-- Composition sequences them and combines their results with '<>'.
instance
  ( IxMonadTrans t
  , Monad m
  , Monoid r
  ) => Category (Indexed t m r) where
    -- The identity arrow stays at the same index and returns 'mempty'.
    id = Indexed (pure mempty)
    -- @f@ goes from index @a@ to @b@ and runs first; @g@ goes from @b@
    -- to @c@ and runs second. The composite returns
    -- @resultOfF <> resultOfG@.
    Indexed g . Indexed f = Indexed $ apIx (fmap (<>) f) g