{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{- |
Module: Control.Concurrent.Actor
Description: A basic actor model in Haskell
Copyright: (c) Samuel Schlesinger 2020
License: MIT
Maintainer: sgschlesinger@gmail.com
Stability: experimental
Portability: POSIX, Windows
-}
module Control.Concurrent.Actor
( ActionT
, Actor
, send
, addAfterEffect
, threadId
, actFinally
, act
, receiveSTM
, receive
, hoistActionT
, link
, linkSTM
, LinkKill(..)
, self
, murder
  -- Exported alongside 'murder' so that callers can catch and pattern match
  -- on the exception it throws, mirroring 'LinkKill' for 'link'.
, MurderKill(..)
) where

-- This list was generated by compiling with the -ddump-minimal-imports flag.
import Control.Concurrent
    ( forkFinally, myThreadId, throwTo, ThreadId )
import Control.Monad.IO.Class ( MonadIO(..) )
import Control.Monad.Trans ( MonadTrans(..) )
import Control.Monad.Reader
    ( MonadReader(local, ask), ReaderT(ReaderT) )
import Control.Monad.State.Class ( MonadState )
import Control.Monad.Reader.Class ()
import Control.Monad.Writer.Class ( MonadWriter )
import Control.Monad.RWS.Class ( MonadRWS )
import Control.Monad.Error.Class ( MonadError )
import Control.Monad.Cont.Class ( MonadCont )
import Control.Concurrent.STM
    ( STM, atomically, newTVar, readTVar, modifyTVar, modifyTVar' )
import Control.Exception ( SomeException, Exception, finally )
import Data.Functor.Contravariant ( Contravariant(contramap) )
import Data.Queue ( dequeue, enqueue, newQueue, Queue )

-- | The monad transformer in which an 'Actor' describes what it does.
--
-- An action is a function from the running actor's 'ActorContext' to a
-- computation in the base monad @m@; @message@ is the type of message the
-- actor receives.
newtype ActionT message m a = ActionT
  { runActionT :: ActorContext message -> m a }

-- The instances below are all borrowed from 'ReaderT' over the
-- 'ActorContext', whose representation (@ActorContext message -> m a@) is
-- the same as that of 'ActionT'. Each one lifts the corresponding
-- capability of the base monad @m@ through the context argument.
deriving via ReaderT (ActorContext message) m instance Functor m => Functor (ActionT message m)
deriving via ReaderT (ActorContext message) m instance Applicative m => Applicative (ActionT message m)
deriving via ReaderT (ActorContext message) m instance Monad m => Monad (ActionT message m)
deriving via ReaderT (ActorContext message) m instance MonadIO m => MonadIO (ActionT message m)
deriving via ReaderT (ActorContext message) instance MonadTrans (ActionT message)
deriving via ReaderT (ActorContext message) m instance MonadError e m => MonadError e (ActionT message m)
deriving via ReaderT (ActorContext message) m instance MonadWriter w m => MonadWriter w (ActionT message m)
deriving via ReaderT (ActorContext message) m instance MonadState s m => MonadState s (ActionT message m)
deriving via ReaderT (ActorContext message) m instance MonadCont m => MonadCont (ActionT message m)

-- | The reader environment of the base monad is passed straight through:
-- 'ask' and 'local' act on @m@'s environment and ignore the 'ActorContext',
-- which is why this instance is written by hand rather than derived via
-- 'ReaderT'.
instance MonadReader r m => MonadReader r (ActionT message m) where
  ask = ActionT (\_ -> ask)
  local f action = ActionT (\ctx -> local f (runActionT action ctx))

-- No methods are defined here: the instance only bundles the 'MonadReader',
-- 'MonadWriter' and 'MonadState' instances that 'ActionT' already has.
instance (MonadWriter w m, MonadReader r m, MonadState s m) => MonadRWS r w s (ActionT message m)

-- | The environment handed to every 'ActionT' when it is run: what an
-- action needs to know about the actor executing it.
data ActorContext message = ActorContext
  { messageQueue :: Queue message
    -- ^ The queue that 'receive' and 'receiveSTM' take messages from.
  , actorHandle :: Actor message
    -- ^ The handle of the running actor, as returned by 'self'.
  }

-- | A handle to do things to actors, like sending them messages, fiddling
-- with their threads, or adding an effect that will occur after they've
-- finished executing.
--
-- The fields are not used directly by callers; they are exposed through
-- 'addAfterEffect', 'threadId' and 'send'.
data Actor message = Actor
  { addAfterEffect' :: (Maybe SomeException -> IO ()) -> STM ()
    -- ^ Registers an after-effect; see 'addAfterEffect'.
  , threadId' :: ThreadId
    -- ^ The thread the actor runs on; see 'threadId'.
  , send' :: message -> STM ()
    -- ^ Delivers a message to the actor; see 'send'.
  }

-- | Register an effect to be run after the given actor's thread has
-- finished. For actors created by 'actFinally' the effect receives 'Just'
-- the exception if the thread ended with one, and 'Nothing' otherwise.
addAfterEffect :: Actor message -> (Maybe SomeException -> IO ()) -> STM ()
addAfterEffect Actor{addAfterEffect' = register} = register

-- | The 'ThreadId' stored in the given actor's handle, i.e. the thread the
-- actor runs on.
threadId :: Actor message -> ThreadId
threadId Actor{threadId' = tid} = tid

-- | Deliver a message to the given actor. For actors created by
-- 'actFinally' this enqueues the message on the actor's queue, which is
-- the queue 'receive' and 'receiveSTM' read from.
send :: Actor message -> message -> STM ()
send Actor{send' = deliver} = deliver

-- | Two handles are equal exactly when they carry the same 'ThreadId'.
instance Eq (Actor message) where
  left == right = threadId' left == threadId' right

-- | A handle is shown as the 'ThreadId' it carries.
instance Show (Actor message) where
  show = show . threadId'

-- | Adapt the message type of a handle: a message sent through the new
-- handle is first converted with the given function and then delivered
-- through the original handle. The thread and the after-effect
-- registration are those of the original handle.
instance Contravariant Actor where
  contramap f actor = actor { send' = send' actor . f }

-- | Perform some 'ActionT' in a thread, with some cleanup afterwards.
--
-- A fresh message queue is created for the actor and the action is run on
-- a new thread started with 'forkFinally'. When that thread finishes,
-- whether normally or with an exception, the given handler is run with the
-- outcome, followed by every after-effect registered through
-- 'addAfterEffect', in the order they were registered.
--
-- Every after-effect is run even if the handler or an earlier after-effect
-- throws, so that notifications such as the ones installed by 'linkSTM'
-- are not lost; the exception is propagated once they have run.
--
-- The after-effects are read once, when the thread finishes: one that is
-- registered after that point is never run.
actFinally :: (Either SomeException a -> IO ()) -> ActionT message IO a -> IO (Actor message)
actFinally errorHandler (ActionT actionT) = do
  finalizerTVar <- atomically $ newTVar errorHandler
  messageQueue <- atomically newQueue
  let -- Append an after-effect behind everything registered so far.
      -- 'finally' (rather than '<*') guarantees it still runs when an
      -- earlier finalizer throws; the strict modify avoids piling up
      -- unevaluated updates inside the 'TVar'.
      addAfterEffect' afterEffect =
        modifyTVar' finalizerTVar $ \runEarlier result ->
          runEarlier result `finally` afterEffect (leftToMaybe result)
      send' = enqueue messageQueue
      -- The handle given to the action is built inside the new thread,
      -- where its own 'ThreadId' is available via 'myThreadId'.
      runBody = do
        threadId' <- myThreadId
        actionT (ActorContext messageQueue Actor{..})
      -- Read whatever finalizer chain has accumulated and run it once.
      runFinalizer result = atomically (readTVar finalizerTVar) >>= ($ result)
  threadId' <- forkFinally runBody runFinalizer
  pure Actor{..}
  where
    -- After-effects only care whether the thread died with an exception.
    leftToMaybe = either Just (const Nothing)

-- | Perform some 'ActionT' in a thread, discarding its outcome: this is
-- 'actFinally' with a handler that does nothing.
act :: ActionT message IO a -> IO (Actor message)
act = actFinally ignoreOutcome
  where
    ignoreOutcome _ = pure ()

-- | Take the next message from this actor's queue, in a transaction of its
-- own, and continue with the action the given function builds from it.
receive :: MonadIO m => (message -> ActionT message m a) -> ActionT message m a
receive handler = ActionT $ \ctx ->
  liftIO (atomically (dequeue (messageQueue ctx))) >>= \message ->
    runActionT (handler message) ctx

-- | Take the next message from this actor's queue and feed it to the given
-- 'STM' function, all within a single transaction; the function's result
-- is the result of the action.
receiveSTM :: MonadIO m => (message -> STM a) -> ActionT message m a
receiveSTM f = ActionT $ \ctx ->
  let takeAndProcess = dequeue (messageQueue ctx) >>= f
   in liftIO (atomically takeAndProcess)

-- | Change the base monad of an 'ActionT' by applying a natural
-- transformation to the computation it produces for each context.
hoistActionT :: (forall x. m x -> n x) -> ActionT message m a -> ActionT message n a
hoistActionT nat action = ActionT (nat . runActionT action)

-- | The exception thrown when an actor we've 'link'ed with has died.
--
-- The 'ThreadId' it carries is that of the actor that died, as filled in
-- by 'linkSTM'.
data LinkKill = LinkKill ThreadId
  deriving Show

-- Relies on the default methods of 'Exception'.
instance Exception LinkKill

-- | Link the lifetime of the given actor to the actor running this action:
-- when the given actor's thread finishes, a 'LinkKill' exception carrying
-- its 'ThreadId' is thrown to us. This is 'linkSTM' applied to 'self' and
-- the given actor, run in its own transaction.
link :: MonadIO m => Actor message -> ActionT message' m ()
link you = self >>= \me -> liftIO (atomically (linkSTM me you))

-- | Links the lifetime of the first actor to the second: an after-effect
-- is registered on the second actor which, once that actor's thread has
-- finished, throws a 'LinkKill' exception carrying the second actor's
-- 'ThreadId' to the first actor's thread.
linkSTM :: Actor message -> Actor message' -> STM ()
linkSTM alice bob = addAfterEffect bob notifyAlice
  where
    -- Fires however bob ended; the outcome itself is not inspected.
    notifyAlice _ = throwTo (threadId alice) (LinkKill (threadId bob))

-- | Returns the 'Actor' handle of the actor executing this action (use
-- 'threadId' on it to obtain the 'ThreadId'). The handle is read straight
-- from the action's context, so this works in arbitrary 'Applicative'
-- contexts rather than only in 'MonadIO' ones, which 'liftIO'ing
-- 'myThreadId' would require.
self :: Applicative m => ActionT message m (Actor message)
self = ActionT $ \(ActorContext{actorHandle = handle}) -> pure handle

-- | The exception thrown when we 'murder' an 'Actor'.
--
-- The 'ThreadId' it carries is that of the thread that called 'murder',
-- not that of the actor being killed.
data MurderKill = MurderKill ThreadId
  deriving Show

-- Relies on the default methods of 'Exception'.
instance Exception MurderKill

-- | Throws a 'MurderKill' exception to the given 'Actor''s thread. The
-- exception carries the 'ThreadId' of the calling thread.
murder :: MonadIO m => Actor message -> m ()
murder Actor{threadId' = victim} = liftIO $ do
  killer <- myThreadId
  throwTo victim (MurderKill killer)