{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}

-- | A version of 'GraphulaT' that @'remove'@s all @'insert'@ed data afterward
module Graphula.Idempotent
  ( GraphulaIdempotentT
  , runGraphulaIdempotentT
  ) where

import Prelude

import Control.Monad.IO.Unlift
import Control.Monad.Reader (MonadReader, ReaderT, ask, runReaderT)
import Control.Monad.Trans (MonadTrans, lift)
import Data.Foldable (for_)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import Database.Persist (Entity (..))
import Graphula.Class
import UnliftIO.Exception (SomeException, catch, mask, throwIO)

newtype GraphulaIdempotentT m a = GraphulaIdempotentT
  { runGraphulaIdempotentT' :: ReaderT (IORef (m ())) m a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadReader (IORef (m ()))
    )

instance MonadUnliftIO m => MonadUnliftIO (GraphulaIdempotentT m) where
  {-# INLINE withRunInIO #-}
  withRunInIO inner = GraphulaIdempotentT $ withRunInIO $ \run ->
    inner $ run . runGraphulaIdempotentT'

instance MonadTrans GraphulaIdempotentT where
  lift = GraphulaIdempotentT . lift

instance
  (MonadIO m, MonadGraphulaFrontend m)
  => MonadGraphulaFrontend (GraphulaIdempotentT m)
  where
  insert mKey n = do
    finalizersRef <- ask
    mEnt <- lift $ insert mKey n
    for_ (entityKey <$> mEnt) $
      \key -> liftIO $ modifyIORef' finalizersRef (remove key >>)
    pure mEnt
  insertKeyed key n = do
    finalizersRef <- ask
    mEnt <- lift $ insertKeyed key n
    for_ (entityKey <$> mEnt) $
      \key' -> liftIO $ modifyIORef' finalizersRef (remove key' >>)
    pure mEnt
  remove = lift . remove

runGraphulaIdempotentT :: MonadUnliftIO m => GraphulaIdempotentT m a -> m a
runGraphulaIdempotentT action = mask $ \unmasked -> do
  finalizersRef <- liftIO . newIORef $ pure ()
  x <-
    unmasked $
      runReaderT (runGraphulaIdempotentT' action) finalizersRef
        `catch` rollbackRethrow finalizersRef
  rollback finalizersRef $ pure x
 where
  rollback :: MonadIO m => IORef (m a) -> m b -> m b
  rollback finalizersRef x = do
    finalizers <- liftIO $ readIORef finalizersRef
    finalizers >> x

  rollbackRethrow :: MonadIO m => IORef (m a) -> SomeException -> m b
  rollbackRethrow finalizersRef (e :: SomeException) =
    rollback finalizersRef (throwIO e)
