-- | 
-- = k8s-wrapper
--
-- The k8s Wrapper is a module designed to provide an interface for running
-- applications in the Kubernetes system. This wrapper spawns the server on an
-- internal protocol, providing endpoints for startup, liveness, and
-- readiness checks, as well as optional metrics support.
--
-- There are some restrictions to be aware of when using this module.
-- First, the **server must be running in the main thread** in order to exit properly.
-- If this guarantee is not met, the application and k8s will still function, but rollout
-- update procedures may take much longer.
-- 
-- Second, the user's application must be able to tear down upon receiving
-- `AsyncCancelled` or `ThreadKilled` signals. While it's acceptable to implement
-- graceful teardowns, these should be time-bound. In general, applications that use
-- Warp handle this automatically.
--
-- To use the k8s Wrapper, include the following configuration snippet in your pod:
--
-- == __pod.yaml__
-- @
-- ...
-- spec:
--   metadata:
--     annotations:
--       prometheus.io/port: "9121"
--       prometheus.io/scrape: "true"
--       prometheus.io\/path: "\/_metrics"
--   containers:
--     - lifecycle:
--         preStop:
--           httpGet:
--            path: /stop
--            port: ${config.port}
--           # Period after which the pod will be terminated
--           # even if the stop hook has not returned.
--           terminationGracePeriodSeconds: 30
--        # When the service is considered started:
--        # if the startup probe does not return success within
--        #  `initialDelaySeconds + periodSeconds * failureThreshold` seconds
--        # the service will be restarted
--        startupProbe:
--          httpGet:
--            path: /ready
--            port: ${config.port}
--          failureThreshold: 12
--          initialDelaySeconds: 1
--          periodSeconds: 5
--          successThreshold: 1
--          timeoutSeconds: 2 
--       # When the service is considered alive; if it's not alive it will be
--       # restarted according to its policy after
--       # initialDelaySeconds + periodSeconds * failureThreshold
--       livenessProbe:
--          httpGet:
--            path: /health
--            port: ${config.port}
--          failureThreshold: 2
--          initialDelaySeconds: 1
--          periodSeconds: 10
--          successThreshold: 1
--          timeoutSeconds: 2
--        readinessProbe:
--          httpGet:
--            path: /ready
--            port: ${config.port}
--          failureThreshold: 2
--          initialDelaySeconds: 1
--          periodSeconds: 10
--          successThreshold: 1
--          timeoutSeconds: 2
-- @
--
module Network.K8s.Application
  ( withK8sEndpoint
  , Config(..)
  , defConfig
    -- * Checks
    -- $k8s-checks
  , K8sChecks(..)
  ) where

import Control.Concurrent (killThread, threadDelay, forkIO)
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception (finally, AsyncException, throwIO, fromException)
import Control.Monad
import Data.Foldable
import Network.HTTP.Types
import Network.Wai as Wai
import Network.Wai.Handler.Warp qualified as Warp
import Network.Wai.Middleware.Prometheus as Prometheus

-- | Server configuration.
data Config = Config
  { port :: Int
    -- ^ Port where the control interface is started.
  , maxTearDownPeriodSeconds :: Int
    -- ^ How much time (in seconds) to wait before a forceful teardown.
  } deriving (Show)

-- | Default configuration: the control interface listens on port @10120@
-- and the teardown period is @30@ seconds.
defConfig :: Config
defConfig = Config
  { port = 10120
  , maxTearDownPeriodSeconds = 30
  }

-- $k8s-checks
-- 
-- There are two types of health checks that can be used:
--
--  1. __Deep check__ - this verifies not only basic startup but also that the services
--     the application communicates with are up and running. It is more precise, but it increases
--     the risk of marking services as unavailable and causing cascading errors. Therefore,
--     it is recommended to use deep checks for the `startupCheck`, but use shallow checks for the `livenessCheck`.
--     As for the `readinessCheck`, it is up to the user to decide which type of check to use.
--  2. __Shallow check__ - this provides only the basic tests to ensure the application is running.
--
-- The suggested approach for implementing the health checks is as follows:
-- 
--  1. The `startupCheck` should return @True@ after the server is started, configs were checked,
--     and communication with other services was established. Once it has returned @True@, the
--     function should always return @True@.
--  2. The `readinessCheck` should return @True@ after the startup check has passed and after
--     population of the caches, preparing of the internal structures, etc. The function may
--     switch back to @False@ if the structures need repopulation and the application
--     can't serve the users. It's important to ensure that all the services in the cluster will
--     not switch to not ready state at the same time to avoid cascading failures.
--  3. The `livenessCheck` performs a shallow check of the service and returns the state accordingly.
--     Generally, the `livenessCheck` should return @False@ only in the case where the server needs to be restarted.

-- | Callbacks that the wrapper can use in order to understand the state of
-- the application and react accordingly.
data K8sChecks = K8sChecks
  { runReadynessCheck :: IO Bool
    -- ^ Checks that the application can receive requests.
  , runLivenessCheck :: IO Bool
    -- ^ Checks that the application is running (should not be restarted).
  }

-- | Application state as seen by the k8s control endpoint.
data ApplicationState
  = ApplicationStarting (Async ())
    -- ^ The initialization procedure is still running. Carries the handle of
    -- the initialization thread so that it can be cancelled if a teardown is
    -- requested before the startup has finished.
  | ApplicationRunning
    -- ^ Initialization has finished and the user code is running.
  | ApplicationTeardownConfirm (TVar Bool)
    -- ^ A teardown was requested while the application was running. The
    -- variable is set to @True@ once the readiness endpoint has been queried
    -- in this state.
  | ApplicationTearingDown
    -- ^ The application is being torn down.
  

-- | Wrap a server that allows controlling business logic server.
-- The server can be communicated with by the k8s services and can monitor the application lifecycle.
--
-- In the application run the following logic:
--
-- @
--    k8s_wrapper_server                  user code
--                                   +----------------+
--        ready=false                | initialization |             
--        started=false              |                |
--        alive=false                +----------------+
--                                          |
--        started=true  <-------------------+
--                                          |
--                                   +---------------+
--                                   | start server  |
--                                   +---------------+
--                                          |
--       ready? ---> check user thread, run callback
--       alive? ---> check user thread, run callback
-- @
--             
-- 
-- The server wrapper also provides additional logic:
--   1. When checking liveness, the code checks if the thread running the server is still
--      alive and returns @False@ if it is not, regardless of what liveness check returns.
--   2. When the `stop` action is called, the server starts to return @False@ in the readiness check.
--      Once the readiness endpoint has been queried in that state (or the teardown period has
--      elapsed), it sends an exception to the user code.
--      This ensures that no new requests will be sent to the server.
--
-- In case of an asynchronous exception, we expect that we want to terminate the program.
-- Thus, we want to ensure a similar set of actions as a call to the @/stop@ hook:
--
--   1.  Put the application in the tearing down state.
--   2.  Start replying with @ready=False@ replies.
--   3.  Once we replied with @ready=false@ at least once (or after a timeout), we trigger server stop.
--      In this place, we expect the server to stop accepting new connections and exit once all
--      current connections will be processed. This is the responsibility of the function provided by the user,
--      but servers like Warp already handle that properly.
--
-- In case of an exception in the initialization function, it will be rethrown, and the function will exit with the same exception.
-- The k8s endpoint will be torn down.
--
-- If the user code exits with an exception, it will be rethrown. Otherwise, the code exits properly.
withK8sEndpoint
  :: Config      -- ^ Static configuration of the endpoint.
  -> K8sChecks   -- ^ K8s hooks
  -> IO a        -- ^ Initialization procedure
  -> (a -> IO b) -- ^ User supplied logic, see requirements for the server to work properly.
  -> IO ()
withK8sEndpoint Config{..} k8s startup action = do
  startup_handle <- async startup
  state_box <- newTVarIO $ ApplicationStarting (void startup_handle)
  -- The user code runs in a background thread: it waits for the
  -- initialization procedure, marks the application as running and then
  -- executes the user supplied action.
  withAsync
    (do x <- wait startup_handle
        atomically (switchToRunning state_box)
        action x) $ \server -> do
    -- The k8s endpoint is started right away, so it is already reachable
    -- while the initialization procedure is still in progress.
    k8s_server <- async $
      runK8sServiceEndpoint port maxTearDownPeriodSeconds k8s state_box server
    (do -- Block until either the user code has exited or the application
        -- state has been switched to tearing down.
        result <- atomically $ asum
          [ void <$> waitCatchSTM server
          , readTVar state_box >>= \case
              ApplicationTearingDown -> pure (Right ())
              _ -> retry
          ]
        case result of
          Left se
            -- Asynchronous termination of the user code is an expected
            -- way to stop, so it is not reported as a failure.
            | Just (_ :: AsyncCancelled) <- fromException se -> pure ()
            | Just (_ :: AsyncException) <- fromException se -> pure ()
            -- Any other exception of the user code is rethrown.
            | otherwise -> throwIO se
          Right{} -> pure ()
      )
      `finally`
        -- Cancel the user code, but do not wait for it longer than half of
        -- the configured teardown period.
        (let half_interval = (maxTearDownPeriodSeconds * 1_000_000) `div` 2
         in race_ (threadDelay half_interval) (cancel server))
      `finally`
        -- The k8s endpoint is cancelled from a separate thread, so this
        -- function does not block on its termination.
        forkIO (cancel k8s_server)

-- | Run the HTTP server that exposes the k8s control endpoints.
--
-- The following paths are served:
--
--   * @\/started@ - startup probe: 400 while the application is starting,
--     200 afterwards;
--   * @\/ready@ - readiness probe: 200 only if the application is running
--     and the readiness callback returns @True@, 400 otherwise;
--   * @\/health@ - liveness probe: 400 while the application is starting;
--     afterwards 200 only if the liveness callback returns @True@ and the
--     thread running the user code has not terminated;
--   * @\/stop@ - pre-stop hook: initiates the teardown in the background and
--     replies with 200 immediately;
--   * @\/_metrics@ - delegated to 'metricsApp'.
--
-- All other paths return 404.
runK8sServiceEndpoint
  :: Int -- ^ Port
  -> Int -- ^ Shutdown interval, in seconds
  -> K8sChecks -- ^ K8s hooks
  -> TVar ApplicationState -- ^ Projection of the application state
  -> Async void -- ^ Handle of the running user code
  -> IO ()
runK8sServiceEndpoint port teardown_time_seconds K8sChecks{..} state_box server =
  Warp.run port $ \req resp -> do
    -- Every reply produced here is a short plain text message.
    let reply status body =
          resp $ responseLBS status [(hContentType, "text/plain")] body
    case Wai.pathInfo req of
      ["started"] ->
        readTVarIO state_box >>= \case
          ApplicationStarting{} -> reply status400 "starting"
          ApplicationRunning{} -> reply status200 "ok"
          -- The application had been started before the teardown began.
          _ -> reply status200 "tearing down"
      ["ready"] ->
        readTVarIO state_box >>= \case
          ApplicationStarting{} -> reply status400 "starting"
          ApplicationRunning{} -> do
            isReady <- runReadynessCheck
            if isReady
              then reply status200 "running"
              else reply status400 "not running"
          ApplicationTeardownConfirm confirmed -> do
            -- Record that the "not ready" answer has been observed, this
            -- unblocks the teardown that waits for the confirmation.
            atomically $ writeTVar confirmed True
            reply status400 "tearing down"
          ApplicationTearingDown{} -> reply status400 "tearing down"
      ["health"] ->
        readTVarIO state_box >>= \case
          ApplicationStarting{} -> reply status400 "starting"
          _ -> do
            isAlive <- runLivenessCheck
            -- 'poll' returns 'Nothing' while the user thread is still
            -- running. A terminated thread makes the application unhealthy
            -- regardless of the result of the liveness callback.
            threadRunning <- poll server >>= \case
              Nothing -> pure True
              Just _ -> pure False
            if isAlive && threadRunning
              then reply status200 "running"
              else reply status400 "unhealthy"
      ["stop"] -> do
        -- The teardown is performed in the background, so the hook can
        -- return immediately.
        _ <- async $ do
          d <- registerDelay (teardown_time_seconds * 1_000_000)
          -- Switch the state and wait until the application may be stopped.
          join $ atomically $ switchToTeardown d state_box
          -- this is interruptible, and waits until user thread will receive an
          -- exception, but does not wait for the tearing down.
          killThread $ asyncThreadId server
        reply status200 "tearing down"
      ["_metrics"] -> metricsApp req resp
      -- If it's any other path then we simply return 404
      _ -> reply status404 "Not found"
    
-- | Switches the application to the tearing down state.
--
-- If the application was running, it injects a confirmation variable that
-- is switched once the k8s system has seen that the application is not
-- ready.
--
-- Returns an action that, once finished, tells that the application can be
-- torn down. The action does not wait longer than the timeout variable
-- (the first argument) allows: it proceeds as soon as that variable becomes
-- @True@.
--
-- Depending on the current state:
--
--   * starting: the state is switched immediately and the returned action
--     cancels the initialization thread;
--   * running: the state becomes 'ApplicationTeardownConfirm' and the
--     returned action waits for the confirmation or the timeout before
--     switching to 'ApplicationTearingDown';
--   * teardown already requested: the returned action waits on the
--     existing confirmation variable in the same way;
--   * already tearing down: the returned action does nothing.
switchToTeardown :: TVar Bool -> TVar ApplicationState -> STM (IO ())
switchToTeardown timeout state = readTVar state >>= \case
  ApplicationStarting init_thread -> do
    writeTVar state ApplicationTearingDown
    pure $ cancel init_thread
  ApplicationRunning{} -> do
    confirmed <- newTVar False
    writeTVar state $ ApplicationTeardownConfirm confirmed
    pure $ awaitConfirmation confirmed
  ApplicationTeardownConfirm confirmed ->
    pure $ awaitConfirmation confirmed
  ApplicationTearingDown ->
    pure $ pure ()
  where
    -- Wait until either the confirmation or the timeout variable becomes
    -- @True@, then mark the application as tearing down.
    awaitConfirmation confirmed = do
      atomically $ asum
        [ readTVar confirmed >>= check
        , readTVar timeout >>= check
        ]
      atomically $ writeTVar state ApplicationTearingDown


-- | Switch the state to running if the application was starting.
-- In every other state this is a no-op, so a teardown that has already been
-- requested is never overwritten.
switchToRunning :: TVar ApplicationState -> STM ()
switchToRunning state = readTVar state >>= \case
  ApplicationStarting{} -> writeTVar state ApplicationRunning
  ApplicationRunning{} -> pure ()
  ApplicationTeardownConfirm{} -> pure ()
  ApplicationTearingDown -> pure ()