{-
  Copyright (c) Meta Platforms, Inc. and affiliates.
  All rights reserved.

  This source code is licensed under the BSD-style license found in the
  LICENSE file in the root directory of this source tree.
-}

module TestChannel
  ( TestChannel(..), Req(..)
  , runServer
  , runTestServer
  ) where

import Control.Concurrent
import Control.Exception (SomeException)
import Control.Monad
import Data.ByteString (ByteString)
import Data.Proxy

import Thrift.Channel
import Thrift.Monad (getRpcPriority, newCounter)
import Thrift.Processor
import Thrift.Protocol

-- | An in-process Thrift channel for tests.
--
-- Clients put each outgoing request into the wrapped 'MVar'. A server loop
-- such as 'runTestServer' takes requests out of it. An 'MVar' holds at most
-- one value, so a second send blocks until the server has taken the first.
--
-- The type parameter @s@ does not appear in the representation (it is
-- phantom). Presumably it tags the Thrift service; confirm against callers.
newtype TestChannel s = TestChannel (MVar Req)

-- | A request sitting in a 'TestChannel'.
--
-- The first field holds the request message bytes handed to the server's
-- handler. The second is the callback through which the server delivers
-- the reply.
data Req = Req ByteString RecvCallback

-- | Client side of the test channel: requests are handed to the server loop
-- by putting them into the channel's 'MVar'.
--
-- Two-way requests are only accepted with no RPC priority or with
-- 'NormalPriority'. Any other priority is reported through the send
-- callback as a 'ChannelException', and the request is not enqueued. For
-- accepted requests, the send callback fires /before/ the request is
-- enqueued.
--
-- One-way requests skip the priority check. They are enqueued with a reply
-- callback that ignores its argument, and the send callback fires /after/
-- the request is enqueued.
instance ClientChannel TestChannel where
  sendRequest (TestChannel mailbox) Request{..} onSent onResponse
    | priorityAccepted (getRpcPriority reqOptions) = do
        onSent Nothing
        putMVar mailbox (Req reqMsg onResponse)
    | otherwise =
        onSent (Just (ChannelException "non-Normal priority"))
    where
      -- Only an unset priority or an explicit 'NormalPriority' is supported.
      priorityAccepted Nothing = True
      priorityAccepted (Just NormalPriority) = True
      priorityAccepted (Just _) = False

  sendOnewayRequest (TestChannel mailbox) Request{..} onSent =
    putMVar mailbox (Req reqMsg ignoreResponse) >> onSent Nothing
    where
      -- Nothing waits on a one-way reply, so whatever comes back is dropped.
      ignoreResponse _ = pure ()

-- | Serve requests arriving on a 'TestChannel' forever, dispatching them to
-- a Thrift 'Processor' via 'process'.
--
-- Each call creates one sequence-number counter. Every incoming request
-- draws the next number from it before being passed to 'process', together
-- with the handler and the header post-processing function. The actual
-- serving loop is 'runTestServer', so this never returns normally.
runServer
  :: (Processor c, Protocol p)
  => Proxy p
     -- ^ selects the 'Protocol' instance used by 'process'
  -> TestChannel s
     -- ^ channel whose queued requests are served
  -> (forall r . c r -> IO r)
     -- ^ runs a single decoded call
  -> (forall r . c r -> Either SomeException r -> Header)
     -- ^ computes reply headers from a call and its outcome
  -> IO ()
runServer proto chan handler postProcess = do
  nextSeqNum <- newCounter
  let serveOne payload = do
        seqNum <- nextSeqNum
        process proto seqNum handler postProcess payload
  runTestServer chan serveOne

-- | Serve a 'TestChannel' forever with a raw byte-level handler.
--
-- Each iteration does three things:
--
-- 1. Blocks on 'takeMVar' for the next 'Req'.
-- 2. Runs the handler on its bytes.
-- 3. Passes the handler's reply bytes and 'Header' to the request's
--    callback as a successful 'Response'.
--
-- The middle component of the handler's result is discarded.
--
-- NOTE(review): nothing here catches exceptions. If the handler or the
-- callback throws, the loop terminates and that request's callback is
-- never invoked with a reply.
runTestServer
  :: TestChannel s
     -- ^ channel to take requests from
  -> (ByteString -> IO (ByteString, a, Header))
     -- ^ maps request bytes to reply bytes, an ignored extra value, and
     -- the reply header
  -> IO ()
runTestServer (TestChannel mailbox) handler =
  forever (takeMVar mailbox >>= answer)
  where
    answer (Req payload onResponse) = do
      (replyBytes, _, replyHeader) <- handler payload
      onResponse (Right (Response replyBytes replyHeader))
