{-# LANGUAGE ScopedTypeVariables #-}

{- |
Copyright : Flipstone Technology Partners 2023
License   : MIT
Stability : Stable

@since 1.0.0.0
-}
module Orville.PostgreSQL.Internal.MigrationLock
  ( MigrationLockId
  , defaultLockId
  , nextLockId
  , withMigrationLock
  , MigrationLockError
  )
where

import Control.Concurrent (threadDelay)
import Control.Exception (Exception, throwIO)
import qualified Control.Monad as Monad
import qualified Control.Monad.IO.Class as MIO
import Data.Int (Int32)

import qualified Orville.PostgreSQL.Execution as Exec
import qualified Orville.PostgreSQL.Internal.Bracket as Bracket
import qualified Orville.PostgreSQL.Marshall as Marshall
import qualified Orville.PostgreSQL.Monad as Monad
import qualified Orville.PostgreSQL.Raw.RawSql as RawSql
import qualified Orville.PostgreSQL.Raw.SqlValue as SqlValue

{- |
Identifies a PostgreSQL advisory lock to be acquired by the application. Use
'defaultLockId' to obtain the default value and 'nextLockId' to create custom
values if you need them.

@since 1.0.0.0
-}
data MigrationLockId = MigrationLockId
  { i_lockKey1 :: !Int32
  -- ^ First key of the two-key advisory lock. 'defaultLockId' sets it to
  -- Orville's lock scope and 'nextLockId' leaves it unchanged.
  , i_lockKey2 :: !Int32
  -- ^ Second key of the two-key advisory lock. 'nextLockId' increments it.
  -- Both fields are strict so that chains of 'nextLockId' calls are evaluated
  -- eagerly instead of accumulating unevaluated additions.
  }
  deriving (Eq, Show)

{- |
The lock id Orville uses out of the box to make sure only one copy of the
application attempts to run migrations at any given moment.

@since 1.0.0.0
-}
defaultLockId :: MigrationLockId
defaultLockId =
  MigrationLockId orvilleLockScope 7995632

{- |
Derives a new, distinct 'MigrationLockId' by adding one to the second key of
the given id while keeping its first key. Use this to build your own
'MigrationLockId' values when you need to control migration runs in a custom
manner.

@since 1.0.0.0
-}
nextLockId :: MigrationLockId -> MigrationLockId
nextLockId (MigrationLockId scopeKey idKey) =
  MigrationLockId scopeKey (idKey + 1)

-- | First advisory-lock key used by 'defaultLockId'. 'nextLockId' never
-- changes the first key, so every id derived from the default shares it.
orvilleLockScope :: Int32
orvilleLockScope =
  17772

{- |
  Runs an Orville action while holding a PostgreSQL advisory lock. The lock
  tells other Orville processes that a database migration is in progress, so
  they do not perform one concurrently.

  Acquiring the lock, running the action and releasing the lock all happen on
  a single connection (via 'Monad.withConnection_'). This matters because an
  advisory lock taken with @pg_try_advisory_lock@ belongs to the database
  session that took it. The release step is registered with
  'Bracket.bracketWithResult'; the result value that bracket hands to it is
  ignored. A 'MigrationLockError' is thrown if the lock cannot be acquired.

@since 1.0.0.0
-}
withMigrationLock ::
  Monad.MonadOrville m =>
  MigrationLockId ->
  m a ->
  m a
withMigrationLock lockId action =
  Monad.withConnection_ lockedAction
 where
  lockedAction =
    Bracket.bracketWithResult acquire release runAction

  acquire =
    accquireTransactionLock lockId

  release () _bracketResult =
    releaseTransactionLock lockId

  runAction () =
    action

{- |
  Acquires the advisory lock identified by the given 'MigrationLockId' on the
  current connection, polling until it succeeds.

  Each attempt runs the query built by 'tryLockExpr'
  (@pg_try_advisory_lock@), which reports success or failure instead of
  waiting for the lock. After a failed attempt the thread sleeps for a short
  delay and tries again.

  Exceptions: a 'MigrationLockError' is thrown in either of these cases:

  * the maximum number of attempts has been made without obtaining the lock;
  * the lock query returns anything other than exactly one row.

  Note: despite this function's name, @pg_try_advisory_lock@ takes a
  session-level lock. It stays held until 'releaseTransactionLock' runs on the
  same connection.
-}
accquireTransactionLock ::
  forall m.
  Monad.MonadOrville m =>
  MigrationLockId ->
  m ()
accquireTransactionLock lockId =
  go 1
 where
  -- Total number of lock queries issued before giving up. The error message
  -- is built from this value so the two always agree.
  maxAttempts :: Int
  maxAttempts = 25

  -- Pause between failed attempts, in microseconds (the unit threadDelay
  -- takes).
  retryDelayMicros :: Int
  retryDelayMicros = 10000

  -- @attempt@ is the 1-based number of the attempt about to be made. Counting
  -- from 1 means exactly maxAttempts queries run before the error is thrown.
  go :: Int -> m ()
  go attempt = do
    locked <- attemptLockAcquisition
    if locked
      then pure ()
      else do
        MIO.liftIO $ do
          -- Give up right after the final failed attempt, without sleeping.
          Monad.when (attempt >= maxAttempts) $
            throwIO . MigrationLockError $
              "Giving up after "
                <> show maxAttempts
                <> " attempts to aquire the migration lock."
          threadDelay retryDelayMicros
        go (attempt + 1)

  -- Runs one lock query and decodes its single boolean @locked@ column.
  attemptLockAcquisition :: m Bool
  attemptLockAcquisition = do
    tryLockResults <-
      Exec.executeAndDecode
        Exec.OtherQuery
        (tryLockExpr lockId)
        lockedMarshaller

    case tryLockResults of
      [locked] ->
        pure locked
      rows ->
        MIO.liftIO . throwIO . MigrationLockError $
          "Expected exactly one row from attempt to acquire migration lock, but got "
            <> show (length rows)

{- |
  Releases the advisory lock identified by the given 'MigrationLockId' by
  running the @pg_advisory_unlock@ query built by 'releaseLockExpr'. Whatever
  the query returns is not inspected here.
-}
releaseTransactionLock :: Monad.MonadOrville m => MigrationLockId -> m ()
releaseTransactionLock lockId =
  Exec.executeVoid Exec.OtherQuery (releaseLockExpr lockId)

{- |
  Decodes the single boolean @locked@ column that 'tryLockExpr' selects. The
  marshaller carries an empty annotation.
-}
lockedMarshaller :: Marshall.AnnotatedSqlMarshaller Bool Bool
lockedMarshaller =
  Marshall.annotateSqlMarshallerEmptyAnnotation lockedColumn
 where
  lockedColumn =
    Marshall.marshallField id (Marshall.booleanField "locked")

{- |
  Builds @SELECT functionName(key1, key2)@. The two keys of the given
  'MigrationLockId' are passed as query parameters rather than spliced into
  the SQL text. Both advisory-lock calls in this module have this shape, so
  they share this builder.
-}
advisoryLockCall :: String -> MigrationLockId -> RawSql.RawSql
advisoryLockCall functionName lockId =
  RawSql.fromString ("SELECT " <> functionName)
    <> RawSql.leftParen
    <> RawSql.parameter (SqlValue.fromInt32 (i_lockKey1 lockId))
    <> RawSql.comma
    <> RawSql.parameter (SqlValue.fromInt32 (i_lockKey2 lockId))
    <> RawSql.rightParen

{- |
  SQL that tries to take the advisory lock without waiting for it. The boolean
  result is aliased as @locked@, which is the column 'lockedMarshaller'
  decodes.
-}
tryLockExpr :: MigrationLockId -> RawSql.RawSql
tryLockExpr lockId =
  advisoryLockCall "pg_try_advisory_lock" lockId
    <> RawSql.fromString " as locked"

{- |
  SQL that releases the advisory lock identified by the given
  'MigrationLockId'.
-}
releaseLockExpr :: MigrationLockId -> RawSql.RawSql
releaseLockExpr =
  advisoryLockCall "pg_advisory_unlock"

{- |
  Thrown when 'withMigrationLock' cannot acquire the migration lock. This
  happens either because every lock attempt failed within the retry budget,
  or because the lock query returned an unexpected number of rows. The wrapped
  'String' describes which of these happened.

@since 1.0.0.0
-}
newtype MigrationLockError = MigrationLockError String
  deriving Show

instance Exception MigrationLockError