{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module      : Data.Binding.Hobbits.MonadBind
-- Copyright   : (c) 2020 Edwin Westbrook
--
-- License     : BSD3
--
-- Maintainer  : westbrook@galois.com
-- Stability   : experimental
-- Portability : GHC
--
-- This module defines monads that are compatible with the notion of
-- name-binding, where a monad is compatible with name-binding iff it can
-- intuitively run computations that are inside name-bindings. More formally, a
-- /binding monad/ is a monad with an operation 'mbM' that commutes name-binding
-- with the monadic operations, meaning:
--
-- > mbM (nuMulti $ \ns -> return a) == return (nuMulti $ \ns -> a)
-- > mbM (nuMulti $ \ns -> m >>= f)
-- >   == mbM (nuMulti $ \ns -> m) >>= \mb_x ->
-- >      mbM ((nuMulti $ \ns x -> f x) `mbApply` mb_x)

module Data.Binding.Hobbits.MonadBind (MonadBind(..), MonadStrongBind(..)) where

import Data.Binding.Hobbits.Closed
import Data.Binding.Hobbits.Liftable (mbLift)
import Data.Binding.Hobbits.Mb
import Data.Binding.Hobbits.NuMatching
import Data.Binding.Hobbits.QQ

import Control.Monad.Identity (Identity(..))
import Control.Monad.Reader (ReaderT(..))
import Control.Monad.State (StateT(..), get, lift, put, runStateT)

-- | Monads that are compatible with name-binding: a computation sitting under
-- a multi-binding can be run, producing a multi-binding of its result.
class Monad m => MonadBind m where
  -- | Commute a multi-binding with the monad, moving the monadic structure
  -- from inside the binding to outside of it.
  mbM :: NuMatching a
      => Mb ctx (m a)
      -> m (Mb ctx a)

-- | Bind a fresh name inside a computation and return the name-binding whose
-- body is the result of that computation. This is 'mbM' applied to the
-- single-name binding built by 'nu'.
nuM :: (MonadBind m, NuMatching b) => (Name a -> m b) -> m (Binding a b)
nuM = mbM . nu

-- | 'Identity' carries no effects: the body is unwrapped under the binding
-- with 'runIdentity', and the whole binding is re-wrapped in 'Identity'.
instance MonadBind Identity where
  mbM = Identity . fmap runIdentity

-- | A binding of a 'Just' yields 'Just' the binding of its contents. A binding
-- of 'Nothing' yields 'Nothing'. The 'nuP' patterns match the constructor
-- underneath the binding.
instance MonadBind Maybe where
  mbM [nuP| Just x |] = Just x
  mbM [nuP| Nothing |] = Nothing

-- | Supply the environment to the computation under the binding, then commute
-- the binding with the underlying monad using its own 'mbM'.
instance MonadBind m => MonadBind (ReaderT r m) where
  mbM mb = ReaderT $ \r -> mbM (fmap (`runReaderT` r) mb)

-- | Binding monads whose commuting operation works at every element type:
-- unlike 'MonadBind', no 'NuMatching' instance is needed for the type of the
-- value computed under the multi-binding.
class MonadBind m => MonadStrongBind m where
  -- | Like 'mbM', but without a 'NuMatching' constraint on the result type.
  strongMbM :: Mb ctx (m a)
            -> m (Mb ctx a)

-- | Same construction as 'mbM' for 'Identity': unwrap the body under the
-- binding with 'runIdentity', then re-wrap the binding in 'Identity'.
instance MonadStrongBind Identity where
  strongMbM = Identity . fmap runIdentity

-- | Supply the environment to the computation under the binding, then commute
-- the binding with the underlying monad using its own 'strongMbM'.
instance MonadStrongBind m => MonadStrongBind (ReaderT r m) where
  strongMbM mb = ReaderT $ \r -> strongMbM (fmap (`runReaderT` r) mb)

-- | State types that can absorb name-bindings: a state value living under a
-- multi-binding can be collapsed back into a single state value.
class NuMatching s => BindState s where
  -- | Collapse a state value out from under a multi-binding.
  bindState :: Mb ctx s
            -> s

-- | A 'Closed' state is lifted out of the binding with 'mbLift'. This relies
-- on the 'Liftable' instance for 'Closed'.
instance BindState (Closed s) where
  bindState = mbLift

-- | Run the stateful computation under the binding from the current state, and
-- commute the binding with @m@ using the inner 'mbM'. The result stays under
-- the binding, while the output state is collapsed out of it with
-- 'bindState'.
instance (MonadBind m, BindState s) => MonadBind (StateT s m) where
  mbM mb_m = StateT $ \s ->
    -- The inner mbM is used at element type (a, s), so it needs
    -- NuMatching (a, s); NuMatching s comes from the BindState superclass.
    splitBinding <$> mbM (fmap (`runStateT` s) mb_m)
    where
      -- Keep the result under the binding; pull the output state out of it.
      splitBinding mb_as = (fmap fst mb_as, bindState (fmap snd mb_as))

-- | Same construction as the 'MonadBind' instance for 'StateT', but using
-- 'strongMbM' so that no 'NuMatching' instance is needed on the result type.
instance (MonadStrongBind m, BindState s) => MonadStrongBind (StateT s m) where
  strongMbM mb_m = StateT $ \s ->
    splitBinding <$> strongMbM (fmap (`runStateT` s) mb_m)
    where
      -- Keep the result under the binding; pull the output state out of it.
      splitBinding mb_as = (fmap fst mb_as, bindState (fmap snd mb_as))


-- | Monads whose effects are closed: running a closed computation gives a
-- closed result.
class Monad m => MonadClosed m where
  -- | Move the 'Closed' wrapper from the computation onto its result.
  closedM :: Closed (m a)
          -> m (Closed a)

-- | Apply a closed 'runIdentity' to the closed computation, then wrap the
-- closed result in 'Identity'.
instance MonadClosed Identity where
  closedM = Identity . clApply $(mkClosed [| runIdentity |])

-- | Thread the current state through a closed stateful computation:
--
-- 1. the state is made 'Closed' with 'toClosed';
-- 2. the computation is run in the inner monad via its 'closedM';
-- 3. the output state replaces the current one, and the result stays 'Closed'.
instance (MonadClosed m, Closable s) => MonadClosed (StateT s m) where
  closedM clm =
    do s <- get
       cl_a_s <- lift $ closedM (($(mkClosed [| runStateT |]) `clApply` clm)
                                 `clApply` toClosed s)
       put (snd (unClosed cl_a_s))
       return ($(mkClosed [| fst |]) `clApply` cl_a_s)