{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
-- | Core of the database layer: the 'DbT' monad and its 'DbError' type, the
-- @runDb*@ runners that execute it against a 'DbPool', the pool constructors,
-- and the tracing helpers.
module Traction.Control (
    Db
  , DbT (..)
  , DbError (..)
  , Tracer
  , renderDbError
  , DbPool (..)
  , DbPoolConfiguration (..)
  , defaultDbPoolConfiguration
  , MonadDb (..)
  -- 'WithTransaction' is a parameter of the @runDb*With@ runners exported
  -- below, so it must be exported for them to be callable from other modules.
  , WithTransaction (..)
  , transaction
  , transactionT
  , runDb
  , runDbT
  , runDbWith
  , runDbWithT
  , runDbTracing
  , runDbTracingT
  , runDbTracingWith
  , runDbTracingWithT
  , newPool
  , newPoolWith
  , newRollbackPool
  , newRollbackPoolWith
  , withRollbackSingletonPool
  , withConnection
  , failWith
  , withTracing
  , trace
  , noTracing
  ) where

import           Control.Monad.Catch (Exception, MonadMask (..), MonadThrow, MonadCatch, Handler (..), handle, catches, bracket, bracket_, throwM)
import           Control.Monad.IO.Class (MonadIO (..))
import           Control.Monad.Morph (MFunctor (..), squash)
import           Control.Monad.Trans.Class (MonadTrans (..))
import           Control.Monad.Trans.Except (ExceptT(..))
import           Control.Monad.Trans.Reader (ReaderT (..), ask, asks, local)

import           Data.ByteString (ByteString)
import           Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Time as Time
import           Data.Typeable (Typeable)
import qualified Data.Pool as Pool

import qualified Database.PostgreSQL.Simple as Postgresql

import           System.IO (IO)

import           Traction.Prelude


-- | A source of database connections.
--
-- Wraps a single function that hands a 'Postgresql.Connection' to a callback
-- and returns the callback's result or a 'DbError'. How the connection is
-- obtained, and what happens around the callback, is decided by whoever
-- constructs the pool.
newtype DbPool =
  DbPool {
      runDbPool :: forall a. (Postgresql.Connection -> EitherT DbError IO a) -> EitherT DbError IO a
    }

-- | Where a 'DbT' computation currently gets its connection from.
data TransactionContext =
    InTransaction Postgresql.Connection -- ^ A connection has already been chosen and is reused.
  | NotInTransaction DbPool -- ^ No connection chosen yet; one is requested from this pool when needed.

-- | Environment carried by the 'ReaderT' layer of 'DbT'.
data TractionSettings =
  TractionSettings {
      transactionContext :: TransactionContext -- ^ Current connection source.
    , tracer :: Tracer -- ^ Callback that 'trace' sends its messages to.
    }

-- | Callback that receives trace messages as 'Text'.
type Tracer = Text -> IO ()

-- | A 'Tracer' that discards every message it is given.
noTracing :: Tracer
noTracing _ =
  pure ()

-- | 'DbT' specialised to 'IO'.
type Db =
  DbT IO

-- | Database monad transformer: a 'ReaderT' over 'TractionSettings' layered on
-- top of an 'EitherT' that carries 'DbError'.
newtype DbT m a =
  DbT {
      _runDb :: ReaderT TractionSettings (EitherT DbError m) a
    }  deriving (Functor, Applicative, Monad, MonadIO, MonadMask, MonadThrow, MonadCatch)

-- | Changes the base monad of a 'DbT' by hoisting through both the reader
-- layer and the error layer.
instance MFunctor DbT where
  hoist nat (DbT inner) =
    DbT (hoist (hoist nat) inner)

-- | Lifts a base-monad action through the error layer and then the reader
-- layer.
instance MonadTrans DbT where
  lift action =
    DbT (lift (lift action))

-- | Failures a database computation can end in. Every constructor records the
-- query the failure is associated with.
data DbError =
    DbSqlError Postgresql.Query Postgresql.SqlError -- ^ A 'Postgresql.SqlError' exception was caught (see 'safely').
  | DbQueryError Postgresql.Query Postgresql.QueryError -- ^ A 'Postgresql.QueryError' exception was caught.
  | DbFormatError Postgresql.Query Postgresql.FormatError -- ^ A 'Postgresql.FormatError' exception was caught.
  | DbResultError Postgresql.Query Postgresql.ResultError -- ^ A 'Postgresql.ResultError' exception was caught.
  | DbTooManyResults Postgresql.Query Int -- ^ Too many results; the 'Int' is presumably the number returned (confirm at construction sites).
  | DbNoResults Postgresql.Query -- ^ The query generated no results.
  | DbEncodingInvariant Postgresql.Query Text Text -- ^ A result could not be decoded; carries the field and the expected type, in that order.
    deriving (Show, Eq, Typeable)

-- | Render a 'DbError' as a human-readable message that ends with the
-- offending query.
renderDbError :: DbError -> Text
renderDbError e =
  Text.pack rendered
  where
    rendered =
      case e of
        DbSqlError q err ->
          bracketed "SQL Error" (show err) q
        DbQueryError q err ->
          bracketed "Query Error" (show err) q
        DbFormatError q err ->
          bracketed "Format Error" (show err) q
        DbResultError q err ->
          bracketed "Result Error" (show err) q
        DbTooManyResults q n ->
          bracketed "Too many results" (show n) q
        DbNoResults q ->
          mconcat ["Query generated no results, for query: ", show q]
        DbEncodingInvariant q field encoding ->
          mconcat [
              "Query could not decode results, expected to be able to decode ["
            , Text.unpack field
            , "], to type, ["
            , Text.unpack encoding
            , "], for query: "
            , show q
            ]

    -- Shared shape: "<label> [<detail>], for query: <query>".
    bracketed label detail q =
      mconcat [label, " [", detail, "], for query: ", show q]

-- | Monads into which an 'IO'-based 'DbT' computation can be embedded.
class MonadIO m => MonadDb m where
  -- | Embed a @'DbT' 'IO'@ computation into @m@.
  liftDb :: DbT IO a -> m a

-- | Any 'DbT' over a 'MonadIO' base can run 'IO'-based database code by
-- hoisting 'IO' into the base monad.
instance MonadIO m => MonadDb (DbT m) where
  liftDb db =
    hoist liftIO db

-- | Database code passes through an 'ExceptT' layer by lifting from the
-- underlying 'MonadDb'.
instance MonadDb m => MonadDb (ExceptT e m) where
  liftDb db =
    lift (liftDb db)

-- | Selects whether the @runDb*With@ runners wrap the whole computation in
-- 'transaction'.
data WithTransaction =
    WithTransaction -- ^ Wrap the computation in 'transaction'.
  | WithoutTransaction -- ^ Run the computation as given.

-- | Abort the current 'Db' computation with the given error.
failWith :: DbError -> Db a
failWith err =
  DbT (lift (left err))

-- | Run a 'Db' computation against the pool, wrapped in 'transaction' and
-- with 'noTracing' as the tracer.
runDb :: DbPool -> Db a -> EitherT DbError IO a
runDb pool =
  runDbWith pool WithTransaction

-- | Like 'runDb' for computations with their own error type @e@ layered over
-- 'Db'; @handler@ converts a 'DbError' into @e@.
runDbT :: DbPool -> (DbError -> e) -> EitherT e Db a -> EitherT e IO a
runDbT pool handler =
  runDbWithT pool WithTransaction handler

-- | Run a 'Db' computation with an explicit transaction mode and 'noTracing'
-- as the tracer.
runDbWith :: DbPool -> WithTransaction -> Db a -> EitherT DbError IO a
runDbWith pool =
  runDbTracingWith pool noTracing

-- | Like 'runDbWith' for computations with their own error type @e@ layered
-- over 'Db'; @handler@ converts a 'DbError' into @e@.
runDbWithT :: DbPool -> WithTransaction -> (DbError -> e) -> EitherT e Db a -> EitherT e IO a
runDbWithT pool =
  runDbTracingWithT pool noTracing

-- | Run a 'Db' computation wrapped in 'transaction', installing @tr@ as the
-- tracer in the initial settings.
runDbTracing :: DbPool -> Tracer -> Db a -> EitherT DbError IO a
runDbTracing pool tr =
  runDbTracingWith pool tr WithTransaction

-- | Like 'runDbTracing' for computations with their own error type @e@
-- layered over 'Db'; @handler@ converts a 'DbError' into @e@.
runDbTracingT :: DbPool -> Tracer -> (DbError -> e) -> EitherT e Db a -> EitherT e IO a
runDbTracingT pool tr =
  runDbTracingWithT pool tr WithTransaction

-- | The most general 'Db' runner; all other @runDb*@ functions end up here.
--
-- Starts from settings that are 'NotInTransaction' on the given pool with
-- @tr@ as the tracer, and wraps the computation in 'transaction' only when
-- asked to via 'WithTransaction'.
runDbTracingWith :: DbPool -> Tracer -> WithTransaction -> Db a -> EitherT DbError IO a
runDbTracingWith pool tr tx db =
  runReaderT (_runDb scoped) settings
  where
    settings =
      TractionSettings (NotInTransaction pool) tr

    scoped =
      case tx of
        WithTransaction ->
          transaction db
        WithoutTransaction ->
          db

-- | Like 'runDbTracingWith' for computations with their own error type @e@
-- layered over 'Db'.
--
-- The inner 'Db' layer is run with 'runDbTracingWith', its 'DbError' is
-- mapped into @e@ by @handler@, and the two resulting @e@ layers are merged
-- with 'squash'.
runDbTracingWithT :: DbPool -> Tracer -> WithTransaction -> (DbError -> e) -> EitherT e Db a -> EitherT e IO a
runDbTracingWithT pool tr tx handler db =
  squash (mapEitherT lower db)
  where
    -- Run the Db layer and re-express its failure in the caller's error type.
    lower inner =
      firstEitherT handler (runDbTracingWith pool tr tx inner)

-- | Run a 'Db' computation on a single connection.
--
-- When the surrounding context is already 'InTransaction', the computation
-- runs unchanged with the current settings. Otherwise a connection is
-- requested from the pool with 'runDbPool' and the computation runs with that
-- connection pinned as its transaction context.
--
-- The current 'Tracer' is carried into the pinned context. Previously it was
-- replaced by 'noTracing', which silenced tracing for everything run through
-- 'runDbTracing'.
transaction :: Db a -> Db a
transaction db =
  DbT $ ask >>= \cc -> lift $ case transactionContext cc of
    InTransaction _ ->
      runReaderT (_runDb db) cc
    NotInTransaction pool ->
      runDbPool pool $ \c ->
        -- Only the connection source changes; every other setting is kept.
        runReaderT (_runDb db) $ cc { transactionContext = InTransaction c }

-- | 'transaction' for computations that carry their own error type @e@ in an
-- 'EitherT' layer over 'Db'. Delegates to 'transactional', unwrapping with
-- 'runEitherT' and rewrapping with 'newEitherT'.
transactionT :: EitherT e Db a -> EitherT e Db a
transactionT db =
  transactional runEitherT newEitherT db

-- | Generalisation of 'transaction' to a monad @m@ that can be unwrapped into
-- a 'Db' action (@sifter@) and rebuilt from one (@lifter@).
--
-- When the surrounding context is already 'InTransaction', the unwrapped
-- computation runs with the current settings. Otherwise a connection is
-- requested from the pool and pinned as the transaction context.
--
-- The current 'Tracer' is preserved in both cases. Previously both branches
-- replaced it with 'noTracing', which silenced tracing inside the block.
--
-- NOTE(review): the pool only ever sees the 'Db' layer's result, so a failure
-- encoded inside @n@ is not visible to it as a 'DbError' -- confirm whether
-- callers expect such failures to abort the transaction.
transactional :: (Monad m, Monad n) => (m a -> Db (n a)) -> (Db (n a) -> m a) -> m a -> m a
transactional sifter lifter db =
  lifter . DbT $ ask >>= \cc -> lift $ case transactionContext cc of
    InTransaction _ ->
      runReaderT (_runDb $ sifter db) cc
    NotInTransaction pool ->
      runDbPool pool $ \c ->
        -- Only the connection source changes; every other setting is kept.
        runReaderT (_runDb $ sifter db) $ cc { transactionContext = InTransaction c }

-- | Sizing parameters for the connection pools built by 'newPoolWith' and
-- 'newRollbackPoolWith'. The fields are passed, in this order, as the last
-- three arguments of 'Pool.createPool'.
data DbPoolConfiguration =
  DbPoolConfiguration {
      dbPoolStripes :: Int -- ^ Stripe count given to 'Pool.createPool'.
    , dbPoolKeepAliveSeconds :: Time.NominalDiffTime -- ^ Keep-alive time given to 'Pool.createPool'.
    , dbPoolSize :: Int -- ^ Resource limit given to 'Pool.createPool' (per stripe or total depends on resource-pool; confirm).
    } deriving (Eq, Ord, Show)

-- | Default pool sizing: 4 stripes, a keep-alive of 20 and a size of 20.
defaultDbPoolConfiguration :: DbPoolConfiguration
defaultDbPoolConfiguration =
  DbPoolConfiguration {
      dbPoolSize = 20
    , dbPoolKeepAliveSeconds = 20
    , dbPoolStripes = 4
    }

-- | Internal exception carrying a 'DbError'.
--
-- 'newPoolWith' throws it from inside 'Postgresql.withTransaction' when the
-- wrapped computation ends in 'Left', and catches it again just outside to
-- turn it back into that 'Left'. Throwing is what keeps the transaction from
-- completing normally (postgresql-simple is documented to roll back when the
-- action throws).
data RollbackException =
    RollbackException DbError
    deriving (Eq, Show, Typeable)

instance Exception RollbackException

-- | Build a pool for the given connection string using
-- 'defaultDbPoolConfiguration' and an initializer that does nothing.
newPool :: ByteString -> IO DbPool
newPool connection =
  newPoolWith connection defaultDbPoolConfiguration noInitializer
  where
    noInitializer =
      pure ()

-- | Build a pool whose every use runs inside 'Postgresql.withTransaction'.
--
-- For each use a connection is taken from the underlying resource pool, the
-- @initializer@ is run on it (with 'noTracing'), and then the caller's
-- callback. If either ends in a 'DbError', a 'RollbackException' is thrown
-- out of 'Postgresql.withTransaction' and converted back into 'Left' outside
-- it; otherwise the result is returned as 'Right'.
newPoolWith :: ByteString -> DbPoolConfiguration -> Db () -> IO DbPool
newPoolWith connection configuration initializer = do
  pool <- Pool.createPool connect Postgresql.close stripes keepAlive size
  pure $ DbPool $ \db ->
    newEitherT . Pool.withResource pool $ \c ->
      rollbackToLeft . Postgresql.withTransaction c $ do
        outcome <- runEitherT $ do
          initialise c
          db c
        case outcome of
          Left e ->
            -- Leave withTransaction via an exception rather than a normal return.
            throwM (RollbackException e)
          Right a ->
            pure (Right a)
  where
    connect =
      Postgresql.connectPostgreSQL connection

    stripes =
      dbPoolStripes configuration

    keepAlive =
      dbPoolKeepAliveSeconds configuration

    size =
      dbPoolSize configuration

    -- Turn the internal rollback signal back into an ordinary 'Left'.
    rollbackToLeft action =
      handle (\(RollbackException e) -> pure (Left e)) action

    -- Run the initializer on the connection that was just checked out.
    initialise c =
      runReaderT (_runDb initializer) (TractionSettings (InTransaction c) noTracing)

-- | Build an always-rolling-back pool for the given connection string using
-- 'defaultDbPoolConfiguration' and an initializer that does nothing.
newRollbackPool :: ByteString -> IO DbPool
newRollbackPool connection =
  newRollbackPoolWith connection defaultDbPoolConfiguration noInitializer
  where
    noInitializer =
      pure ()

-- | Build a pool whose every use is bracketed by 'Postgresql.begin' and
-- 'Postgresql.rollback', so the rollback is issued whether the callback
-- succeeds or fails.
--
-- For each use a connection is taken from the underlying resource pool, the
-- @initializer@ is run on it (with 'noTracing'), and then the caller's
-- callback.
newRollbackPoolWith :: ByteString -> DbPoolConfiguration -> Db () -> IO DbPool
newRollbackPoolWith connection configuration initializer = do
  pool <- Pool.createPool connect Postgresql.close stripes keepAlive size
  pure $ DbPool $ \db ->
    newEitherT . Pool.withResource pool $ \c ->
      bracket_ (Postgresql.begin c) (Postgresql.rollback c) . runEitherT $ do
        initialise c
        db c
  where
    connect =
      Postgresql.connectPostgreSQL connection

    stripes =
      dbPoolStripes configuration

    keepAlive =
      dbPoolKeepAliveSeconds configuration

    size =
      dbPoolSize configuration

    -- Run the initializer on the connection that was just checked out.
    initialise c =
      runReaderT (_runDb initializer) (TractionSettings (InTransaction c) noTracing)

-- | Run @action@ with a 'DbPool' backed by one dedicated connection.
--
-- The connection is opened, 'Postgresql.begin' is issued, and every use of
-- the pool hands out that same connection. When @action@ finishes, normally
-- or by exception, 'Postgresql.rollback' is issued and the connection is
-- closed.
--
-- Previously the connection was never closed, leaking one connection per
-- call. The pool must not be used after this function returns.
withRollbackSingletonPool :: (MonadMask m, MonadIO m) => ByteString -> (DbPool -> m a) -> m a
withRollbackSingletonPool connection action =
  -- The outer bracket owns the connection, so it is closed even if 'begin' or
  -- 'rollback' themselves throw.
  bracket (liftIO $ Postgresql.connectPostgreSQL connection) (liftIO . Postgresql.close) $ \c ->
    bracket_ (liftIO $ Postgresql.begin c) (liftIO $ Postgresql.rollback c) $
      action $ DbPool $ \db -> db c

-- | Run a raw connection action inside 'Db'.
--
-- The connection is the pinned one when the context is 'InTransaction', and
-- is otherwise obtained from the pool for the duration of this one action.
-- The action runs under 'safely', so the postgresql-simple exceptions it
-- handles come back as a 'DbError' tagged with @query@.
withConnection :: Postgresql.Query -> (Postgresql.Connection -> IO a) -> Db a
withConnection query f =
  DbT $ do
    cc <- ask
    let
      guarded c =
        safely query (f c)
    lift $ case transactionContext cc of
      InTransaction c ->
        guarded c
      NotInTransaction pool ->
        runDbPool pool guarded

-- | Run a computation with @f@ installed as the tracer, leaving the
-- surrounding computation's tracer untouched.
--
-- The result type is fully polymorphic. It was previously fixed to @()@,
-- which made it impossible to trace a computation that returns a value;
-- existing callers at @()@ are unaffected.
withTracing :: Tracer -> DbT m a -> DbT m a
withTracing f (DbT db) =
  DbT $ local (\settings -> settings { tracer = f }) db

-- | Send the 'show' rendering of a value to the currently installed tracer.
trace :: (MonadDb m, Show a) => a -> m ()
trace a =
  liftDb . DbT $ do
    emit <- asks tracer
    liftIO (emit (Text.pack (show a)))

-- | Run an 'IO' action, converting the four postgresql-simple exception types
-- listed below into the matching 'DbError' constructor tagged with @query@.
-- Any other exception propagates unchanged.
safely :: Postgresql.Query -> IO a -> EitherT DbError IO a
safely query action =
  newEitherT (catches (Right <$> action) handlers)
  where
    -- Each handler's exception type is fixed by the constructor it feeds.
    handlers = [
        Handler (\err -> pure (Left (DbSqlError query err)))
      , Handler (\err -> pure (Left (DbQueryError query err)))
      , Handler (\err -> pure (Left (DbFormatError query err)))
      , Handler (\err -> pure (Left (DbResultError query err)))
      ]