{-
  Copyright (c) Meta Platforms, Inc. and affiliates.
  All rights reserved.

  This source code is licensed under the BSD-style license found in the
  LICENSE file in the root directory of this source tree.
-}

module ServerTest (main) where

import Control.Exception hiding (DivideByZero)
import Control.Monad
import Control.Monad.Trans.Class
import Data.Either

import Facebook.Init
-- import Network (testServerHost)
import Test.HUnit
import TestRunner

import Thrift.Api
import Thrift.Monad
import Thrift.Protocol.ApplicationException.Types
import Thrift.Protocol.Id
import Thrift.Channel.HTTP
import Thrift.Server.HTTP

import Math.Adder.Client
import Math.Calculator.Client
import Math.Types
import Echoer.Echoer.Client
import EchoHandler

-- | Runs @action@ against a background server that serves 'echoHandler'
-- over a freshly initialised Echoer state. The server is started with the
-- given options, and the port it reports ('serverPort') is handed to
-- @action@.
withTestServer :: ServerOptions -> (Int -> IO a) -> IO a
withTestServer serverOptions action =
  initEchoerState >>= \echoerState ->
    withBackgroundServer (echoHandler echoerState) serverOptions $
      \Server{serverPort = boundPort} -> action boundPort

-- | Client-side HTTP channel configuration for talking to a test server on
-- @localhost@ at the given port, using the given wire protocol and no
-- response timeout.
--
-- NOTE(review): the host was previously meant to come from
-- @Network.testServerHost@ (see the commented-out import); it is hard-coded
-- to "localhost" here.
mkHTTPConfig :: Int -> ProtocolId -> HTTPConfig t
mkHTTPConfig port protId = HTTPConfig
  { httpPort = port
  , httpProtocolId = protId
  , httpHost = "localhost"
  , httpResponseTimeout = Nothing
  }

-- | Builds a labelled HUnit test that:
--
--   1. starts a test server with 'defaultOptions',
--   2. opens an HTTP channel to the port it reports, using @protId@,
--   3. runs @action@ over that channel.
--
-- The label is @pname@ and @label@ joined by a single space.
mkServerTest
  :: String           -- ^ protocol name, used as a label prefix
  -> String           -- ^ test label
  -> ProtocolId       -- ^ wire protocol for the client channel
  -> Thrift Echoer () -- ^ client-side test body
  -> Test
mkServerTest pname label protId action =
  TestLabel (unwords [pname, label]) . TestCase $
    withTestServer defaultOptions $ \port ->
      withHTTPChannel (mkHTTPConfig port protId) action

-- | Calls @add 5 2@ on the test server and expects @7@.
addTest :: String -> ProtocolId -> Test
addTest pname protId =
  mkServerTest pname "add test" protId $
    add 5 2 >>= lift . assertEqual "5 + 2 = 7" 7

-- | Calls @divide 9 3@ on the test server and expects @3@.
divideTest :: String -> ProtocolId -> Test
divideTest pname protId =
  mkServerTest pname "divide test" protId $
    divide 9 3 >>= lift . assertEqual "9 / 3 = 3" 3

-- | Calls @divide 1 0@ and expects the client to receive 'DivideByZero'.
--
-- The result is forced with 'evaluate' so any lazily deferred error surfaces
-- inside the handler's scope. If the call returns normally, the test fails
-- explicitly. Without that check the handler would simply never run and the
-- test would pass without ever observing the exception. The handler only
-- matches 'DivideByZero', so the HUnit failure is not swallowed by it.
divideExceptionTest :: String -> ProtocolId -> Test
divideExceptionTest pname protId =
  mkServerTest pname "divide exception" protId $
    (divide 1 0 >>= lift . evaluate
       >> lift (assertFailure "divide 1 0 should throw DivideByZero"))
      `catchThrift` \DivideByZero -> return ()

-- | Issues several calls over a single channel:
--
--   * stores 100 with @put@,
--   * interleaves @add@ and @divide@ calls,
--   * checks that @get@ returns the stored value.
--
-- Each result is asserted immediately after its call.
multiTest :: String -> ProtocolId -> Test
multiTest pname protId = mkServerTest pname "multiple requests" protId $ do
  put 100
  add 2 2 >>= lift . assertEqual "2 + 2 = 4" 4
  divide 64 16 >>= lift . assertEqual "64 / 16 = 4" 4
  get >>= lift . assertEqual "put = get" 100
  divide 100 10 >>= lift . assertEqual "100 / 10 = 10" 10

-- | Calls @unimplemented@ and expects the client to receive an
-- 'ApplicationException'.
--
-- A normal return is reported as a failure. Without that check the handler
-- would simply never run and the test would pass without observing the
-- exception. The handler only matches 'ApplicationException', so the HUnit
-- failure is not swallowed by it.
unimplementedTest :: String -> ProtocolId -> Test
unimplementedTest pname protId =
  mkServerTest pname "unimplemented test" protId $
    (unimplemented
       >> lift (assertFailure "unimplemented should throw ApplicationException"))
      `catchThrift` \ApplicationException{} -> return ()

-- | Sends a fixed string through @echo@ and expects it back unchanged.
echoTest :: String -> ProtocolId -> Test
echoTest pname protId =
  mkServerTest pname "echo" protId $
    echo payload >>= lift . assertEqual "echo echoed" payload
  where
    payload = "AAAAAAAAAA_DO_NOT_DELETE"

-- | Checks that starting a second server on a port that is already in use
-- fails with an exception.
--
-- The first server is started with 'defaultOptions' and reports the port it
-- is using, as in 'mkServerTest'. The test therefore does not depend on a
-- hard-coded port being free on the host.
--
-- Only the second server start, on that same port, is wrapped in 'try'. If
-- the first server cannot start, the test errors out instead of passing for
-- the wrong reason.
portAlreadyBoundTest :: String -> ProtocolId -> Test
portAlreadyBoundTest pname protId =
  TestLabel (pname ++ " portAlreadyBoundTest") $ TestCase $
    withTestServer defaultOptions $ \port -> do
      let serverOptions = defaultOptions { desiredPort = Just port }
      (result :: Either SomeException ()) <- try $
        withTestServer serverOptions $ const $
          withHTTPChannel (mkHTTPConfig port protId) $
            return ()
      assertBool "should fail" (isLeft result)

-- | Every server test, instantiated for one wire protocol. @pname@ is used
-- as the label prefix of each test.
tests :: String -> ProtocolId -> [Test]
tests pname protId =
  [ mkTest pname protId
  | mkTest <-
      [ addTest
      , divideTest
      , divideExceptionTest
      , multiTest
      , unimplementedTest
      , echoTest
      , portAlreadyBoundTest
      ]
  ]

-- | Runs the full suite once per wire protocol: compact first, then binary.
main :: IO ()
main =
  withFacebookUnitTest . testRunner . TestList $
    concatMap (uncurry tests)
      [ ("compact", compactProtocolId)
      , ("binary", binaryProtocolId)
      ]