{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}

{- |
Copyright : Flipstone Technology Partners 2023
License   : MIT
Stability : Stable

@since 1.0.0.0
-}
module Orville.PostgreSQL.Internal.MonadOrville
  ( MonadOrville
  , MonadOrvilleControl (liftWithConnection, liftCatch, liftMask)
  , withConnection
  , withConnection_
  , withConnectedState
  )
where

import Control.Exception (Exception)
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Trans.Reader (ReaderT (ReaderT), mapReaderT, runReaderT)

import Orville.PostgreSQL.Internal.OrvilleState
  ( ConnectedState (ConnectedState, connectedConnection, connectedTransaction)
  , ConnectionState (Connected, NotConnected)
  , OrvilleState
  , connectState
  , orvilleConnectionPool
  , orvilleConnectionState
  )
import Orville.PostgreSQL.Monad.HasOrvilleState (HasOrvilleState (askOrvilleState, localOrvilleState))
import Orville.PostgreSQL.Raw.Connection (Connection, withPoolConnection)

{- |
  'MonadOrville' is the constraint that most Orville operations require in
  order to do anything that talks to the database. The class has no methods
  of its own. It exists to bundle the necessary superclass constraints
  ('HasOrvilleState', 'MonadOrvilleControl' and 'MonadIO') under a single
  name, so a function can say @MonadOrville m@ instead of spelling out each
  constraint separately.

  To run Orville operations directly in your own application's monad stack,
  a good first step is to write

  @
    instance MonadOrville MyApplicationMonad
  @

  in your module and then let the compiler report which superclass
  instances are still missing.

@since 1.0.0.0
-}
class
  ( MonadIO m
  , MonadOrvilleControl m
  , HasOrvilleState m
  ) =>
  MonadOrville m

{- |
  'MonadOrvilleControl' is the interface Orville uses to lift low-level 'IO'
  operations into the application monad when they cannot be lifted with
  'Control.Monad.IO.Class.liftIO'. These are operations that accept 'IO'
  actions as arguments (the 'IO' occurs in contravariant position) rather
  than merely producing one (covariant position).

  For application monads built only from 'ReaderT' and 'IO', an instance is
  trivial to write (or derive) on top of the 'ReaderT' instance provided in
  this module. A more complicated monad stack may need the @unliftio@
  package as a stepping stone towards an instance. If your monad relies on
  features @unliftio@ cannot support (for example the State monad or
  continuations), you may need @monad-control@ instead.

  'Orville.PostgreSQL.UnliftIO' offers functions that can be used directly
  as the method implementations below for any monad with a
  'Control.Monad.IO.Unlift.MonadUnliftIO' instance.

@since 1.0.0.0
-}
class MonadOrvilleControl m where
  -- |
  --     Used by Orville to lift the acquisition of a connection from the
  --     resource pool into the application monad. The first argument is an
  --     'IO'-level function that hands a 'Connection' to a callback; the
  --     second is the callback, written in the application monad.
  --
  -- @since 1.0.0.0
  liftWithConnection ::
    (forall a. (Connection -> IO a) -> IO a) ->
    (Connection -> m b) ->
    m b

  -- |
  --     Used by Orville to lift exception catching into the application
  --     monad. The first argument is an 'IO'-level catch function (with the
  --     shape of 'Control.Exception.catch'), followed by the protected action
  --     and the handler, both in the application monad.
  --
  -- @since 1.0.0.0
  liftCatch ::
    Exception e =>
    (forall a. IO a -> (e -> IO a) -> IO a) ->
    m b ->
    (e -> m b) ->
    m b

  -- |
  --     Used by Orville to lift 'Control.Exception.mask' calls into the
  --     application monad, so that resource cleanup is guaranteed to run even
  --     when asynchronous exceptions are thrown.
  --
  -- @since 1.0.0.0
  liftMask ::
    (forall b. ((forall a. IO a -> IO a) -> IO b) -> IO b) ->
    ((forall a. m a -> m a) -> m c) ->
    m c

{- |
  In plain 'IO' there is nothing to translate. Each method simply returns
  the 'IO'-level operation it is handed, instantiated at the requested
  result type.
-}
instance MonadOrvilleControl IO where
  -- The rank-2 argument is bound and returned explicitly (rather than using
  -- e.g. @id@) so that its quantified type variable is instantiated at the
  -- method's result type.
  liftWithConnection ioWithConn =
    ioWithConn

  liftCatch ioCatch =
    ioCatch

  liftMask ioMask =
    ioMask

{- |
  Lifts each operation through 'ReaderT'. The wrapped action is run in the
  underlying monad @m@ with the current environment, and the work is
  delegated to @m@'s own 'MonadOrvilleControl' instance.
-}
instance MonadOrvilleControl m => MonadOrvilleControl (ReaderT state m) where
  liftWithConnection ioWithConn action =
    ReaderT $ \env ->
      -- Supply the environment to the callback so it becomes an @m@ action
      -- that the underlying instance knows how to lift.
      liftWithConnection ioWithConn (flip runReaderT env . action)

  liftCatch ioCatch action handler =
    ReaderT $ \env ->
      liftCatch
        ioCatch
        (runReaderT action env)
        (\e -> runReaderT (handler e) env)

  liftMask ioMask action =
    ReaderT $ \env ->
      -- 'mapReaderT' turns @m@'s restore function into one that works on
      -- 'ReaderT' actions, so the caller can unmask parts of its own
      -- 'ReaderT' computation. The lambda is parenthesised rather than
      -- passed via '$' because its argument is rank-2.
      liftMask
        ioMask
        ( \restore ->
            runReaderT (action (mapReaderT restore)) env
        )

-- | A 'ReaderT' over 'OrvilleState' is a 'MonadOrville' whenever the
-- underlying monad supports both 'MonadIO' and 'MonadOrvilleControl'.
instance
  (MonadIO m, MonadOrvilleControl m) =>
  MonadOrville (ReaderT OrvilleState m)

{- |
  'withConnection' provides a 'Connection' handle for executing queries
  against the database from within an application monad using Orville.

  On the "outermost" call of 'withConnection', a connection is acquired from
  the resource pool. Any further calls to 'withConnection' made inside the
  @m a@ that uses the connection receive that same 'Connection'. When the
  @m a@ finishes, the connection is returned to the pool. If the @m a@
  throws an exception, the pool's exception handling takes effect, which
  generally destroys the connection in case it was the source of the error.

@since 1.0.0.0
-}
withConnection :: MonadOrville m => (Connection -> m a) -> m a
withConnection connectedAction =
  withConnectedState (connectedAction . connectedConnection)

{- |
  'withConnection_' is a convenience version of 'withConnection' for callers
  that don't need the connection handle itself. It is still useful without
  the handle because it ensures that every Orville operation performed by
  the given action runs on the same connection. Orville uses connection
  pooling, so unless you use either 'withConnection' or
  'Orville.PostgreSQL.withTransaction', each database operation may be
  performed on a different connection.

@since 1.0.0.0
-}
withConnection_ :: MonadOrville m => m a -> m a
withConnection_ action =
  withConnection (const action)

{- |
  INTERNAL: This is an internal version of 'withConnection' that gives access
  to the entire 'ConnectedState' value to allow for transaction management.

  If the current 'OrvilleState' is already 'Connected', the action receives
  the existing 'ConnectedState'. Otherwise a connection is checked out of the
  state's pool via 'withPoolConnection' (lifted with 'liftWithConnection').
  In that case the action receives a fresh 'ConnectedState' with no
  transaction (@Nothing@) and runs under an 'OrvilleState' updated with
  'connectState'. That update is what lets nested calls reuse the same
  connection.

@since 1.0.0.0
-}
withConnectedState :: MonadOrville m => (ConnectedState -> m a) -> m a
withConnectedState connectedAction = do
  state <- askOrvilleState

  case orvilleConnectionState state of
    Connected connectedState ->
      connectedAction connectedState
    NotConnected ->
      let
        pool = orvilleConnectionPool state
      in
        liftWithConnection (withPoolConnection pool) $ \conn ->
          let
            connectedState =
              ConnectedState
                { connectedConnection = conn
                , connectedTransaction = Nothing
                }
          in
            -- Record the connection in the local state so nested
            -- 'withConnection' calls take the 'Connected' branch above.
            localOrvilleState (connectState connectedState) $
              connectedAction connectedState