{-# LANGUAGE RecordWildCards #-}

module PostgREST.AppState
  ( AppState
  , getConfig
  , getDbStructure
  , getIsWorkerOn
  , getJsonDbS
  , getMainThreadId
  , getPgVersion
  , getPool
  , getTime
  , getRetryNextIn
  , init
  , initWithPool
  , logWithZTime
  , putConfig
  , putDbStructure
  , putIsWorkerOn
  , putJsonDbS
  , putPgVersion
  , putRetryNextIn
  , releasePool
  , signalListener
  , waitListener
  ) where

import qualified Hasql.Pool as SQL

import Control.AutoUpdate (defaultUpdateSettings, mkAutoUpdate,
                           updateAction)
import Data.IORef         (IORef, atomicWriteIORef, newIORef,
                           readIORef)
import Data.Time          (ZonedTime, defaultTimeLocale, formatTime,
                           getZonedTime)
import Data.Time.Clock    (UTCTime, getCurrentTime)

import PostgREST.Config           (AppConfig (..))
import PostgREST.Config.PgVersion (PgVersion (..), minimumPgVersion)
import PostgREST.DbStructure      (DbStructure)

import Protolude


-- | Mutable state shared across the application: the database connection
-- pool, the schema cache, the runtime config, cached clocks and the
-- primitives used to coordinate the connection worker and the listener.
data AppState = AppState
  { -- | Hasql connection pool; acquired by 'initPool', released by 'releasePool'
    statePool         :: SQL.Pool
  -- | PostgreSQL version; starts at 'minimumPgVersion' and is corrected later
  , statePgVersion    :: IORef PgVersion
  -- | No schema cache at the start. Will be filled in by the connectionWorker
  , stateDbStructure  :: IORef (Maybe DbStructure)
  -- | Cached DbStructure in json
  , stateJsonDbS      :: IORef ByteString
  -- | Helper ref to make sure just one connectionWorker can run at a time
  , stateIsWorkerOn   :: IORef Bool
  -- | Binary semaphore used to sync the listener(NOTIFY reload) with the connectionWorker.
  , stateListener     :: MVar ()
  -- | Config that can change at runtime
  , stateConf         :: IORef AppConfig
  -- | Time used for verifying JWT expiration
  , stateGetTime      :: IO UTCTime
  -- | Time with time zone used for worker logs
  , stateGetZTime     :: IO ZonedTime
  -- | Used for killing the main thread in case a subthread fails
  , stateMainThreadId :: ThreadId
  -- | Keeps track of when the next retry for connecting to database is scheduled
  , stateRetryNextIn  :: IORef Int
  }

-- | Build a fresh 'AppState' from the given config: acquire a new
-- connection pool with 'initPool', then initialise the remaining state
-- around it with 'initWithPool'.
init :: AppConfig -> IO AppState
init conf = do
  newPool <- initPool conf
  initWithPool newPool conf

-- | Build an 'AppState' around an already-acquired connection pool.
--
-- Initial values: no schema cache, an empty JSON schema cache, the worker
-- flag off, an empty listener semaphore, and a retry counter of 0.
-- 'stateMainThreadId' is captured from the calling thread with
-- 'myThreadId', so this presumably should be called from the main thread
-- (TODO confirm with callers).
initWithPool :: SQL.Pool -> AppConfig -> IO AppState
initWithPool newPool conf = do
  -- assume we're in a supported version when starting, this will be corrected on a later step
  pgVersion   <- newIORef minimumPgVersion
  dbStructure <- newIORef Nothing
  jsonDbS     <- newIORef mempty
  isWorkerOn  <- newIORef False
  listener    <- newEmptyMVar
  confRef     <- newIORef conf
  -- Cached clocks: auto-update refreshes these periodically instead of
  -- querying the system clock on every call.
  utcClock    <- mkAutoUpdate defaultUpdateSettings { updateAction = getCurrentTime }
  zonedClock  <- mkAutoUpdate defaultUpdateSettings { updateAction = getZonedTime }
  mainThread  <- myThreadId
  retryNextIn <- newIORef 0
  -- Record syntax keeps the field/value pairing explicit, so reordering the
  -- AppState fields cannot silently swap two same-typed refs.
  pure AppState
    { statePool         = newPool
    , statePgVersion    = pgVersion
    , stateDbStructure  = dbStructure
    , stateJsonDbS      = jsonDbS
    , stateIsWorkerOn   = isWorkerOn
    , stateListener     = listener
    , stateConf         = confRef
    , stateGetTime      = utcClock
    , stateGetZTime     = zonedClock
    , stateMainThreadId = mainThread
    , stateRetryNextIn  = retryNextIn
    }

-- | Acquire a Hasql connection pool from the config's pool size
-- ('configDbPoolSize'), pool timeout ('configDbPoolTimeout') and
-- connection URI ('configDbUri').
initPool :: AppConfig -> IO SQL.Pool
initPool conf =
  SQL.acquire (poolSize, poolTimeout, connSettings)
  where
    poolSize     = configDbPoolSize conf
    poolTimeout  = configDbPoolTimeout conf
    -- Hasql takes the connection settings as raw bytes, so encode the URI.
    connSettings = toUtf8 (configDbUri conf)

-- | The connection pool held by this state.
getPool :: AppState -> SQL.Pool
getPool = statePool

-- | Release the connection pool, then interrupt the main thread by throwing
-- 'UserInterrupt' to the thread recorded in 'stateMainThreadId'.
--
-- NOTE(review): despite its name, this also signals the main thread;
-- presumably it is used as a shutdown path — confirm with callers.
releasePool :: AppState -> IO ()
releasePool appState = do
  SQL.release (statePool appState)
  throwTo (stateMainThreadId appState) UserInterrupt

-- | Read the PostgreSQL version currently recorded in the state.
getPgVersion :: AppState -> IO PgVersion
getPgVersion = readIORef . statePgVersion

-- | Atomically replace the recorded PostgreSQL version.
putPgVersion :: AppState -> PgVersion -> IO ()
putPgVersion = atomicWriteIORef . statePgVersion

-- | Read the schema cache. It is 'Nothing' until a structure has been stored
-- with 'putDbStructure'.
getDbStructure :: AppState -> IO (Maybe DbStructure)
getDbStructure = readIORef . stateDbStructure

-- | Atomically store a new schema cache, replacing any previous one.
putDbStructure :: AppState -> DbStructure -> IO ()
putDbStructure appState structure =
  atomicWriteIORef (stateDbStructure appState) $ Just structure

-- | Read the JSON-encoded schema cache (empty until one has been stored).
getJsonDbS :: AppState -> IO ByteString
getJsonDbS = readIORef . stateJsonDbS

-- | Atomically replace the JSON-encoded schema cache.
putJsonDbS :: AppState -> ByteString -> IO ()
putJsonDbS appState = atomicWriteIORef (stateJsonDbS appState)

-- | Whether a connection worker is currently flagged as running.
--
-- NOTE(review): a read here followed by 'putIsWorkerOn' is two separate
-- operations, not an atomic test-and-set.
getIsWorkerOn :: AppState -> IO Bool
getIsWorkerOn = readIORef . stateIsWorkerOn

-- | Atomically set the "connection worker is running" flag.
putIsWorkerOn :: AppState -> Bool -> IO ()
putIsWorkerOn = atomicWriteIORef . stateIsWorkerOn

-- | Read when the next database-connection retry is scheduled (starts at 0).
-- NOTE(review): the unit is not visible here — presumably seconds; confirm.
getRetryNextIn :: AppState -> IO Int
getRetryNextIn = readIORef . stateRetryNextIn

-- | Atomically record when the next database-connection retry is scheduled.
putRetryNextIn :: AppState -> Int -> IO ()
putRetryNextIn = atomicWriteIORef . stateRetryNextIn

-- | Read the current runtime config.
getConfig :: AppState -> IO AppConfig
getConfig = readIORef . stateConf

-- | Atomically replace the runtime config.
putConfig :: AppState -> AppConfig -> IO ()
putConfig = atomicWriteIORef . stateConf

-- | Current UTC time, used for verifying JWT expiration.
--
-- The action is a cached clock built with 'mkAutoUpdate', so the value may
-- lag the system clock by up to the auto-update refresh interval.
getTime :: AppState -> IO UTCTime
getTime = stateGetTime

-- | Log a line to stderr, prefixed with the local (zoned) time formatted as
-- @%d/%b/%Y:%T %z: @, e.g. @01/Jan/2024:12:00:00 +0000: message@.
logWithZTime :: AppState -> Text -> IO ()
logWithZTime appState txt = do
  zTime <- stateGetZTime appState
  let prefix = toS (formatTime defaultTimeLocale "%d/%b/%Y:%T %z: " zTime) :: Text
  hPutStrLn stderr $ prefix <> txt

-- | Id of the thread that built this state (captured with 'myThreadId');
-- used for killing the main thread in case a subthread fails.
getMainThreadId :: AppState -> ThreadId
getMainThreadId = stateMainThreadId

-- | Block until 'signalListener' has filled the listener semaphore, then
-- empty it again.
--
-- As this IO action uses `takeMVar` internally, it will only return once
-- `stateListener` has been set using `signalListener`. This is currently used
-- to synchronize workers.
waitListener :: AppState -> IO ()
waitListener = takeMVar . stateListener

-- | Fill the listener semaphore, releasing a thread blocked in
-- 'waitListener'.
--
-- Uses 'tryPutMVar', so it never blocks the calling thread. If the semaphore
-- is already full, the signal is dropped (the 'Bool' result is discarded).
-- This should always succeed since the connectionWorker is the only mvar
-- producer.
signalListener :: AppState -> IO ()
signalListener appState = void $ tryPutMVar (stateListener appState) ()