{-# LANGUAGE GADTs              #-}
{-# LANGUAGE ImpredicativeTypes #-}

-- | This module defines `Choreo`, the monad for writing choreographies.
module Choreography.Choreo where

import Choreography.Location
import Choreography.Network
import Control.Monad.Freer
import Data.List
import Data.Proxy
import GHC.TypeLits

-- * The Choreo monad

-- | A constrained version of `unwrap` that only unwraps values located at a
-- specific location.
--
-- Because @a@ is quantified inside the alias, a single function of this type
-- can unwrap values of any type, as long as they are located at @l@. This is
-- the argument type handed to local computations by 'Local' and 'locally'.
type Unwrap l = forall a. a @ l -> a

-- | Effect signature for the `Choreo` monad. @m@ is a monad that represents
-- local computations.
data ChoreoSig m a where
  -- | Run a local computation at location @l@. The computation receives an
  -- 'Unwrap' restricted to @l@, and its result is located at @l@.
  Local :: (KnownSymbol l)
        => Proxy l
        -> (Unwrap l -> m a)
        -> ChoreoSig m (a @ l)

  -- | Move a value located at sender @l@ to receiver @l'@. The @Show@ / @Read@
  -- constraints are what endpoint projection needs to call @send@ / @recv@.
  Comm :: (Show a, Read a, KnownSymbol l, KnownSymbol l')
       => Proxy l
       -> a @ l
       -> Proxy l'
       -> ChoreoSig m (a @ l')

  -- | Branch on a value located at @l@. The continuation picks the rest of the
  -- choreography. Under endpoint projection, @l@ broadcasts the scrutinee so
  -- the other locations can follow the same branch (hence @Show@ / @Read@).
  Cond :: (Show a, Read a, KnownSymbol l)
       => Proxy l
       -> a @ l
       -> (a -> Choreo m b)
       -> ChoreoSig m b

-- | Monad for writing choreographies.
--
-- A 'Freer' monad over 'ChoreoSig'. A value of this type only describes a
-- choreography. It is given meaning by an interpreter, either centrally with
-- 'runChoreo' or per location with 'epp'.
type Choreo m = Freer (ChoreoSig m)

-- | Run a `Choreo` monad directly.
--
-- Interprets the whole choreography inside the local monad @m@, with no
-- networking involved:
--
--   * local computations run here, with the unrestricted 'unwrap';
--   * communication just re-wraps the value at the receiver's location;
--   * a conditional unwraps its scrutinee and keeps interpreting the chosen
--     branch.
runChoreo :: Monad m => Choreo m a -> m a
runChoreo = interpFreer go
  where
    go :: Monad m => ChoreoSig m a -> m a
    go sig = case sig of
      Local _ act       -> fmap wrap (act unwrap)
      Comm _ val _      -> pure (wrap (unwrap val))
      Cond _ val branch -> runChoreo (branch (unwrap val))

-- | Endpoint projection.
--
-- Projects a choreography onto the single location @l'@, producing the
-- 'Network' program that @l'@ runs:
--
--   * 'Local': the computation runs only at its own location. Every other
--     location gets 'Empty'.
--   * 'Comm': the sender @send@s, the receiver @recv@s and everyone else gets
--     'Empty'. A message from @l'@ to itself is delivered directly, without
--     any @send@ / @recv@, which matches 'runChoreo'.
--   * 'Cond': the owner of the scrutinee @broadcast@s it and continues. Every
--     other location @recv@s it from the owner and continues with the same
--     branch.
epp :: Choreo m a -> LocTm -> Network m a
epp c l' = interpFreer handler c
  where
    handler :: ChoreoSig m a -> Network m a
    handler (Local l m)
      | toLocTm l == l' = wrap <$> run (m unwrap)
      | otherwise       = return Empty
    handler (Comm s a r)
      -- Self-communication. Without this guard the sender branch fires and
      -- emits a send to l' itself that no projection ever recvs (the receiver
      -- branch is shadowed), and l' is left holding Empty for a value it owns.
      -- Deliver the value locally instead, as runChoreo does.
      | sender == l' && receiver == l' = return (wrap (unwrap a))
      | sender == l'                   = send (unwrap a) receiver >> return Empty
      | receiver == l'                 = wrap <$> recv sender
      | otherwise                      = return Empty
      where
        sender   = toLocTm s
        receiver = toLocTm r
    handler (Cond l a k)
      | toLocTm l == l' = broadcast (unwrap a) >> epp (k (unwrap a)) l'
      | otherwise       = recv (toLocTm l) >>= \x -> epp (k x) l'

-- * Choreo operations

-- | Perform a local computation at a given location.
--
-- The computation receives an 'Unwrap' restricted to @l@, so it can only look
-- inside values located at @l@. Its result is located at @l@.
locally :: KnownSymbol l
        => Proxy l           -- ^ Location performing the local computation.
        -> (Unwrap l -> m a) -- ^ The local computation given a constrained
                             -- unwrap function.
        -> Choreo m (a @ l)
locally loc act = toFreer (Local loc act)

-- | Communication between a sender and a receiver.
--
-- @(l, x) ~> l'@ moves the value @x@, located at sender @l@, to receiver @l'@.
(~>) :: (Show a, Read a, KnownSymbol l, KnownSymbol l')
     => (Proxy l, a @ l)  -- ^ A pair of a sender's location and a value located
                          -- at the sender
     -> Proxy l'          -- ^ A receiver's location.
     -> Choreo m (a @ l')
(sender, val) ~> receiver = toFreer (Comm sender val receiver)

-- | Conditionally execute choreographies based on a located value.
--
-- The location that owns the scrutinee decides the branch. Every other
-- location follows the same branch.
cond :: (Show a, Read a, KnownSymbol l)
     => (Proxy l, a @ l)  -- ^ A pair of a location and a scrutinee located on
                          -- it.
     -> (a -> Choreo m b) -- ^ A function that describes the follow-up
                          -- choreographies based on the value of scrutinee.
     -> Choreo m b
cond (owner, scrutinee) branches = toFreer $ Cond owner scrutinee branches

-- | A variant of `~>` that sends the result of a local computation.
--
-- First runs the computation at the sender with 'locally', then communicates
-- its located result to the receiver with '~>'.
(~~>) :: (Show a, Read a, KnownSymbol l, KnownSymbol l')
      => (Proxy l, Unwrap l -> m a) -- ^ A pair of a sender's location and a local
                                    -- computation.
      -> Proxy l'                   -- ^ A receiver's location.
      -> Choreo m (a @ l')
(~~>) (sender, act) receiver =
  locally sender act >>= \val -> (sender, val) ~> receiver

-- | A variant of `cond` that conditionally executes choreographies based on the
-- result of a local computation.
--
-- First runs the computation at the given location with 'locally', then
-- branches on its located result with 'cond'.
cond' :: (Show a, Read a, KnownSymbol l)
      => (Proxy l, Unwrap l -> m a) -- ^ A pair of a location and a local
                                    -- computation.
      -> (a -> Choreo m b)          -- ^ A function that describes the follow-up
                                    -- choreographies based on the result of the
                                    -- local computation.
      -> Choreo m b
cond' (owner, act) branches =
  locally owner act >>= \scrutinee -> cond (owner, scrutinee) branches