module Ribosome.Host.Listener where

import Conc (Lock, Restoration, interpretAtomic, interpretEventsChan, interpretLockReentrant, lock)
import Exon (exon)
import qualified Polysemy.Log as Log
import qualified Polysemy.Process as Process
import Polysemy.Process (Process)

import Ribosome.Host.Data.Request (RequestId (unRequestId), TrackedRequest (TrackedRequest))
import Ribosome.Host.Data.Response (Response, TrackedResponse (TrackedResponse), formatResponse)
import Ribosome.Host.Data.RpcError (RpcError)
import qualified Ribosome.Host.Data.RpcMessage as RpcMessage
import Ribosome.Host.Data.RpcMessage (RpcMessage, formatRpcMsg)
import qualified Ribosome.Host.Effect.Host as Host
import Ribosome.Host.Effect.Host (Host)
import qualified Ribosome.Host.Effect.Responses as Responses
import Ribosome.Host.Effect.Responses (Responses)
import Ribosome.Host.Text (ellipsize)

-- |Tag distinguishing the 'Lock' that serializes response sending (used as @Tagged ResponseLock Lock@ by
-- 'sendIfReady').
data ResponseLock =
  ResponseLock
  deriving stock (Eq, Show)

-- |Event published by 'sendResponse' after a response has been sent, used to wake up responses waiting in
-- 'sendWhenReady'.
data ResponseSent =
  ResponseSent
  deriving stock (Eq, Show)

-- |Whether the response for request @i@ may be sent now, i.e. whether the highest 'RequestId' sent so far (stored in
-- the 'AtomicState') is at least @i - 1@.
readyToSend ::
  Member (AtomicState RequestId) r =>
  RequestId ->
  Sem r Bool
readyToSend i =
  atomicGets \ prev -> prev >= i - 1

-- |Send a response, raise the 'RequestId' tracking the latest sent response to @i@ (keeping the larger of the stored
-- value and @i@), and publish a 'ResponseSent' event that unblocks all waiting responses.
sendResponse ::
  Members [Process RpcMessage a, AtomicState RequestId, Events res ResponseSent, Log] r =>
  RequestId ->
  Response ->
  Sem r ()
sendResponse i response = do
  Log.trace [exon|send response: <#{show (unRequestId i)}> #{formatResponse response}|]
  Process.send (RpcMessage.Response (TrackedResponse i response))
  -- 'max' rather than overwriting, so the stored ID never moves backwards.
  atomicModify' (max i)
  publish ResponseSent

-- |Check whether the last sent 'RequestId' is at least one smaller than the current response's ID (see
-- 'readyToSend').
-- If true, send the response.
-- The check and the send both run while holding the 'ResponseLock'-tagged 'Lock', so concurrent callers cannot
-- interleave between the comparison and the send.
-- Returns whether the response was sent, for 'sendWhenReady' to decide whether to retry.
sendIfReady ::
  Member (Events res ResponseSent) r =>
  Members [Tagged ResponseLock Lock, Process RpcMessage a, AtomicState RequestId, Log, Resource] r =>
  RequestId ->
  Response ->
  Sem r Bool
sendIfReady i response =
  tag $ lock do
    ifM (readyToSend i) (True <$ sendResponse i response) (pure False)

-- |Neovim doesn't permit responses to be sent out of order.
-- If multiple requests from Neovim have been sent concurrently (e.g. triggered from rpc calls themselves, since the
-- user can't achieve this through the UI due to it being single-threaded), and the first one runs longer than the rest,
-- the others have to wait for the first response to be sent.
-- Otherwise, Neovim will just terminate the client connection.
--
-- To ensure this, the last sent 'RequestId' is stored and compared to the current response's ID before sending.
-- If the last sent ID is smaller than @i - 1@, this waits until all previous responses are sent.
-- A new attempt to respond is triggered by the 'ResponseSent' event published in 'sendResponse'.
--
-- This function calls 'subscribe' before doing the initial ID comparison.
-- Otherwise, a race is possible: the comparison fails, the previous response is then sent and its 'ResponseSent'
-- event published before 'subscribe' takes effect, and this function would block indefinitely.
sendWhenReady ::
  Members [Events res ResponseSent, EventConsumer res ResponseSent] r =>
  Members [Tagged ResponseLock Lock, Process RpcMessage a, AtomicState RequestId, Log, Resource] r =>
  RequestId ->
  Response ->
  Sem r ()
sendWhenReady i response =
  subscribe trySend
  where
    -- Try to send; if not ready yet, wait for the next 'ResponseSent' event and retry.
    trySend =
      unlessM (sendIfReady i response) do
        ResponseSent <- consume
        trySend

-- |Handle a single message received from Neovim.
--
-- * 'RpcMessage.Request': run 'Host.request' in a new 'async' thread and send its result with 'sendWhenReady',
--   which preserves response order.
-- * 'RpcMessage.Response': hand the response to 'Responses.respond' under its 'RequestId'; an 'RpcError' is logged.
-- * 'RpcMessage.Notification': run 'Host.notification' in a new 'async' thread; nothing is sent back.
dispatch ::
  Members [AtomicState RequestId, Tagged ResponseLock Lock, Events res ResponseSent, EventConsumer res ResponseSent] r =>
  Members [Host, Process RpcMessage a, Responses RequestId Response !! RpcError, Log, Resource, Async] r =>
  RpcMessage ->
  Sem r ()
dispatch = \case
  RpcMessage.Request (TrackedRequest i req) ->
    void (async (sendWhenReady i =<< Host.request req))
  RpcMessage.Response (TrackedResponse i response) ->
    Responses.respond i response !! \ e -> Log.error (show e)
  RpcMessage.Notification req ->
    void (async (Host.notification req))

-- |Receive messages from the 'Process' forever and 'dispatch' each one.
--
-- Interprets the effects used for ordering responses:
--
-- * a reentrant 'Lock', tagged as 'ResponseLock';
-- * channel-based 'Events' for 'ResponseSent';
-- * an 'AtomicState' holding the last sent 'RequestId', starting at @0@.
--
-- Successfully decoded messages are trace-logged (truncated to 500 characters) before dispatching; decoding errors
-- ('Left') are logged as errors and the loop continues.
listener ::
  Members [Host, Process RpcMessage (Either Text RpcMessage)] r =>
  Members [Responses RequestId Response !! RpcError, Log, Resource, Mask Restoration, Race, Async, Embed IO] r =>
  Sem r ()
listener =
  interpretLockReentrant $ untag $ interpretEventsChan $ interpretAtomic 0 $ forever do
    Process.recv >>= \case
      Right msg -> do
        Log.trace [exon|listen: #{ellipsize 500 (formatRpcMsg msg)}|]
        dispatch msg
      Left err ->
        Log.error [exon|listen error: #{err}|]