{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{- |
Module: Control.Concurrent.Actor
Description: A basic actor model in Haskell
Copyright: (c) Samuel Schlesinger 2020
License: MIT
Maintainer: sgschlesinger@gmail.com
Stability: experimental
Portability: POSIX, Windows
-}
module Control.Concurrent.Actor
( ActionT
, Actor
, send
, addAfterEffect
, threadId
, livenessCheck
, withLivenessCheck
, Liveness(..)
, ActorDead(..)
, actFinally
, act
, receiveSTM
, receive
, hoistActionT
, link
, linkSTM
, LinkKill(..)
, self
, murder
, MurderKill(..)
) where

-- This list was generated by compiling with the -ddump-minimal-imports flag.
import Control.Concurrent
    ( forkFinally, myThreadId, throwTo, ThreadId )
import Control.Monad.IO.Class ( MonadIO(..) )
import Control.Monad.Trans ( MonadTrans(..) )
import Control.Monad.Reader
    ( MonadReader(local, ask), ReaderT(ReaderT) )
import Control.Monad.State.Class ( MonadState )
import Control.Monad.Reader.Class ()
import Control.Monad.Writer.Class ( MonadWriter )
import Control.Monad.RWS.Class ( MonadRWS )
import Control.Monad.Error.Class ( MonadError )
import Control.Monad.Cont.Class ( MonadCont )
import Control.Concurrent.STM
    ( STM, atomically, newTVar, newTVarIO, readTVar, modifyTVar, writeTVar, TVar, throwSTM )
import Control.Exception ( SomeException, Exception )
import Data.Functor.Contravariant ( Contravariant(contramap) )
import Data.Queue ( dequeue, enqueue, newQueue, Queue )

-- | The monad in which an 'Actor' runs. A value of type
-- @'ActionT' message m a@ is a computation in the base monad @m@ that can
-- see the running actor's own mailbox and handle, and yields an @a@.
newtype ActionT message m a = ActionT { runActionT :: ActorContext message -> m a }

-- Every instance below is borrowed from @ReaderT (ActorContext message) m@,
-- which has exactly the same representation as 'ActionT'.

-- Core functor/monad hierarchy.
deriving via (ReaderT (ActorContext message) m) instance Functor m => Functor (ActionT message m)
deriving via (ReaderT (ActorContext message) m) instance Applicative m => Applicative (ActionT message m)
deriving via (ReaderT (ActorContext message) m) instance Monad m => Monad (ActionT message m)
deriving via (ReaderT (ActorContext message)) instance MonadTrans (ActionT message)
deriving via (ReaderT (ActorContext message) m) instance MonadIO m => MonadIO (ActionT message m)

-- mtl capabilities lifted from the base monad.
deriving via (ReaderT (ActorContext message) m) instance MonadError e m => MonadError e (ActionT message m)
deriving via (ReaderT (ActorContext message) m) instance MonadWriter w m => MonadWriter w (ActionT message m)
deriving via (ReaderT (ActorContext message) m) instance MonadState s m => MonadState s (ActionT message m)
deriving via (ReaderT (ActorContext message) m) instance MonadCont m => MonadCont (ActionT message m)

-- | Reader capabilities of the base monad pass straight through: 'ask'
-- returns the base monad's environment, not the actor's own context.
instance MonadReader r m => MonadReader r (ActionT message m) where
  ask = ActionT (\_ -> ask)
  local f action = ActionT (\ctx -> local f (runActionT action ctx))

-- | 'ActionT' is an RWS monad whenever its base monad provides reader,
-- writer and state capabilities.
instance (MonadReader r m, MonadWriter w m, MonadState s m) => MonadRWS r w s (ActionT message m)

-- | Everything an 'ActionT' computation can see about the actor running it.
data ActorContext message = ActorContext
  { messageQueue :: Queue message
    -- ^ The mailbox that 'receive' and 'receiveSTM' take messages from.
  , actorHandle :: Actor message
    -- ^ This actor's own handle, as returned by 'self'.
  }

-- | A handle for interacting with an actor: sending it messages, reaching
-- its thread, or registering effects that run once it has finished
-- executing.
data Actor message = Actor
  { addAfterEffect' :: (Maybe SomeException -> IO ()) -> STM ()
    -- ^ Register an effect to run after the actor's thread finishes; the
    -- effect receives the exception the actor died with, if any.
  , threadId' :: ThreadId
    -- ^ The thread executing the actor's action.
  , send' :: message -> STM ()
    -- ^ Deliver a message to the actor's mailbox.
  , status :: TVar (Maybe (Maybe SomeException))
    -- ^ 'Nothing' while the actor runs, @'Just' 'Nothing'@ once it has
    -- completed normally, and @'Just' ('Just' e)@ once it died throwing @e@.
  }

-- | The liveness state of a particular 'Actor', as reported by
-- 'livenessCheck'.
data Liveness
  = Alive
    -- ^ The actor's thread has not finished yet.
  | Completed
    -- ^ The actor's action returned normally.
  | ThrewException SomeException
    -- ^ The actor's action was terminated by this exception.
  deriving Show

-- | Read an 'Actor''s status and report it as a 'Liveness'.
livenessCheck :: Actor message -> STM Liveness
livenessCheck actor = toLiveness <$> readTVar (status actor)
  where
    -- Decode the nested-Maybe status encoding used by 'Actor'.
    toLiveness Nothing = Alive
    toLiveness (Just Nothing) = Completed
    toLiveness (Just (Just err)) = ThrewException err

-- | The exception thrown by combinators wrapped in 'withLivenessCheck' when
-- the target 'Actor' has already died. It carries the exception the actor
-- died with, or 'Nothing' if the actor completed normally.
--
-- Plain 'send' and 'addAfterEffect' do not check liveness and never throw
-- this. Wrap them in 'withLivenessCheck' to get it.
data ActorDead = ActorDead (Maybe SomeException)
  deriving Show

instance Exception ActorDead

-- | Wraps 'addAfterEffect', 'send', or any other custom combinator in a
-- liveness check. The wrapped combinator runs only while the 'Actor' is
-- alive. Otherwise an 'ActorDead' exception carrying the actor's final
-- outcome is thrown in 'STM'.
--
-- This reads the 'TVar' holding the actor's status, which causes
-- contention on it, so avoid it where possible. It is still useful for
-- not sending messages or adding after effects to dead actors, where they
-- would certainly be lost forever.
--
-- The wrapped combinator may produce any result, not only @()@.
withLivenessCheck :: (Actor message -> x -> STM a) -> Actor message -> x -> STM a
withLivenessCheck f actor x =
  readTVar (status actor) >>= maybe (f actor x) (throwSTM . ActorDead)

-- | Register an effect to run once the 'Actor' dies. The effect receives
-- the exception the actor died with, or 'Nothing' if it completed
-- normally. This is how you can implement your own functions like 'link'
-- or 'linkSTM'.
--
-- No liveness check is performed here, so an effect registered after the
-- actor has already died never runs. Wrap this in 'withLivenessCheck' to
-- get an 'ActorDead' exception instead.
addAfterEffect :: Actor message -> (Maybe SomeException -> IO ()) -> STM ()
addAfterEffect Actor{addAfterEffect' = register} = register

-- | The 'ThreadId' of the thread running this 'Actor'.
threadId :: Actor message -> ThreadId
threadId Actor{threadId' = tid} = tid

-- | Deliver a message to this 'Actor''s mailbox. This does not check
-- whether the actor is still alive; use 'withLivenessCheck' for that.
send :: Actor message -> message -> STM ()
send Actor{send' = deliver} = deliver

-- | Actors are identified by their thread: two handles are equal exactly
-- when they have the same 'ThreadId'.
instance Eq (Actor message) where
  Actor{threadId' = x} == Actor{threadId' = y} = x == y

-- | Shows an 'Actor' as the 'ThreadId' it runs on.
--
-- 'showsPrec' (rather than only 'show') is defined so the surrounding
-- precedence is forwarded to 'ThreadId''s own instance. With only 'show'
-- defined, the default 'showsPrec' ignores precedence, and nested output
-- such as @show (Just actor)@ comes out unparenthesized. Top-level
-- 'show' output is unchanged.
instance Show (Actor message) where
  showsPrec d Actor{threadId' = tid} = showsPrec d tid

-- | Adapt an actor handle to accept a different message type. Each message
-- is converted with the given function before being sent. The thread,
-- status and after-effect registration are shared with the original
-- handle.
instance Contravariant Actor where
  contramap f actor = actor { send' = send' actor . f }

-- | Fork a thread running the given 'ActionT' and return a handle to the
-- new 'Actor'.
--
-- When the action finishes, its outcome ('Right' result or 'Left'
-- exception) is first recorded in the actor's status. Then
-- @errorHandler@ is run with it, followed by every effect registered via
-- 'addAfterEffect', in registration order.
--
-- Recording the status and reading the registered effects happen in one
-- transaction. So every effect is either run here, or registered after the
-- status change, where 'withLivenessCheck' can detect it.
--
-- NOTE: the handler and effects are chained with '<*' in 'IO'. If one of
-- them throws, the ones after it do not run.
actFinally :: (Either SomeException a -> IO ()) -> ActionT message IO a -> IO (Actor message)
actFinally errorHandler (ActionT body) = do
  handlerVar <- atomically (newTVar errorHandler)
  mailbox <- atomically newQueue
  statusVar <- newTVarIO Nothing
  let
    -- Append a new effect after whatever cleanup is already registered.
    register effect = modifyTVar handlerVar $ \previous outcome ->
      previous outcome <* effect (failure outcome)
    handleFor tid = Actor
      { addAfterEffect' = register
      , threadId' = tid
      , send' = enqueue mailbox
      , status = statusVar
      }
    -- Runs on the forked thread, so it can discover its own ThreadId.
    runBody = do
      me <- myThreadId
      body (ActorContext mailbox (handleFor me))
    finish outcome = do
      cleanup <- atomically $ do
        writeTVar statusVar (Just (failure outcome))
        readTVar handlerVar
      cleanup outcome
  handleFor <$> forkFinally runBody finish
  where
    failure = either Just (const Nothing)

-- | Fork a thread running the given 'ActionT' with a no-op completion
-- handler. Effects registered through 'addAfterEffect' still run when it
-- finishes.
act :: ActionT message IO a -> IO (Actor message)
act = actFinally (\_ -> pure ())

-- | Take the next message from this actor's mailbox and continue with the
-- 'ActionT' the handler builds from it. Whether this blocks on an empty
-- mailbox depends on 'dequeue' from "Data.Queue".
receive :: MonadIO m => (message -> ActionT message m a) -> ActionT message m a
receive handler =
  ActionT (\ctx -> liftIO (atomically (dequeue (messageQueue ctx)))) >>= handler

-- | Take the next message from this actor's mailbox and feed it to an 'STM'
-- action. The dequeue and the action commit together as one transaction.
receiveSTM :: MonadIO m => (message -> STM a) -> ActionT message m a
receiveSTM f = ActionT (liftIO . atomically . takeWith)
  where
    takeWith ctx = f =<< dequeue (messageQueue ctx)

-- | Move an 'ActionT' from base monad @m@ to base monad @n@ by applying a
-- natural transformation to the underlying computation.
hoistActionT :: (forall x. m x -> n x) -> ActionT message m a -> ActionT message n a
hoistActionT nat action = ActionT (nat . runActionT action)

-- | Thrown to an actor when another actor it is 'link'ed to finishes. It
-- carries the 'ThreadId' of the actor that finished.
data LinkKill
  = LinkKill ThreadId
  deriving Show

instance Exception LinkKill

-- | Tie the current actor's lifetime to the given actor's. When the given
-- actor dies, a 'LinkKill' carrying its 'ThreadId' is thrown to the
-- current actor.
link :: MonadIO m => Actor message -> ActionT message' m ()
link you = self >>= \me -> liftIO (atomically (linkSTM me you))

-- | Tie the first actor's lifetime to the second's. When the second
-- actor's thread dies, for any reason, a 'LinkKill' carrying its
-- 'ThreadId' is thrown to the first actor.
--
-- This registers an after effect without a liveness check, so it has no
-- effect if the second actor is already dead.
linkSTM :: Actor message -> Actor message' -> STM ()
linkSTM alice bob = addAfterEffect bob notifyAlice
  where
    notifyAlice _ = throwTo (threadId alice) (LinkKill (threadId bob))

-- | The handle of the actor that is running this action.
self :: Applicative m => ActionT message m (Actor message)
self = ActionT (\ctx -> case ctx of ActorContext{actorHandle = me} -> pure me)

-- | Thrown to an 'Actor' by 'murder'. It carries the 'ThreadId' of the
-- thread that called 'murder'.
data MurderKill
  = MurderKill ThreadId
  deriving Show

instance Exception MurderKill

-- | Asynchronously throw a 'MurderKill' to the given 'Actor's thread. The
-- exception is tagged with the calling thread's 'ThreadId'.
murder :: MonadIO m => Actor message -> m ()
murder Actor{threadId' = victim} = liftIO (myThreadId >>= throwTo victim . MurderKill)