{-# LANGUAGE GeneralizedNewtypeDeriving, DefaultSignatures, CPP, TypeFamilies #-}
module Database.Selda.Backend.Internal
( StmtID (..), BackendID (..)
, QueryRunner, SeldaBackend (..), SeldaConnection (..), SeldaStmt (..)
, MonadSelda (..), SeldaT (..), SeldaM
, SeldaError (..)
, Param (..), Lit (..), ColAttr (..), AutoIncType (..)
, SqlType (..), SqlValue (..), SqlTypeRep (..)
, PPConfig (..), defPPConfig
, TableInfo (..), ColumnInfo (..), tableInfo, fromColInfo
, isAutoPrimary, isPrimary, isUnique
, sqlDateTimeFormat, sqlDateFormat, sqlTimeFormat
, freshStmtId
, newConnection, allStmts
, runSeldaT, withBackend
) where
import Data.List (nub)
import Database.Selda.SQL (Param (..))
import Database.Selda.SqlType
import Database.Selda.Table hiding (colName, colType, colFKs)
import qualified Database.Selda.Table as Table (ColInfo (..))
import Database.Selda.SQL.Print.Config
import Database.Selda.Types (TableName, ColName)
import Control.Concurrent
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.State
import Data.Dynamic
import qualified Data.IntMap as M
import Data.IORef
import Data.Text (Text)
import System.IO.Unsafe (unsafePerformIO)
#if !MIN_VERSION_base(4, 13, 0)
import Control.Monad.Fail (MonadFail)
#endif
-- | Identifies which database backend a 'SeldaBackend' belongs to
--   (it is the type of the 'backendId' field).
data BackendID
  = SQLite      -- ^ The SQLite backend.
  | PostgreSQL  -- ^ The PostgreSQL backend.
  | Other Text  -- ^ Any other backend. NOTE(review): the 'Text' is presumably
                --   the backend's name; confirm with backend implementations.
  deriving (Show, Eq, Ord)
-- | Errors raised by Selda. 'SeldaError' is an 'Exception', so values can be
--   thrown with 'throwM' (as 'runSeldaT' does) and caught by callers.
data SeldaError
  = DbError String
    -- ^ Database-level error. 'runSeldaT' throws this when it is handed a
    --   connection that is already marked closed.
  | SqlError String
    -- ^ NOTE(review): presumably an error while executing a statement; it is
    --   not raised anywhere in this module.
  | UnsafeError String
    -- ^ NOTE(review): presumably signals misuse of an unsafe operation; it is
    --   not raised anywhere in this module.
  deriving (Show, Eq, Typeable)

-- | Uses the default 'Exception' methods.
instance Exception SeldaError
-- | Identifier of a prepared statement. New identifiers are handed out by
--   'freshStmtId'; the wrapped 'Int' is also the key used in a connection's
--   statement cache ('connStmts').
newtype StmtID = StmtID Int
  deriving (Show, Eq, Ord)
-- | A wrapped 'Int'; by its name, presumably a connection identifier.
--
--   NOTE(review): 'ConnID' is not in the export list and is not referenced
--   anywhere else in this module; confirm it is unused and consider removing
--   it.
newtype ConnID = ConnID Int
  deriving (Show, Eq, Ord)
-- | Process-wide counter holding the next statement id that 'freshStmtId'
--   will hand out. It starts at 1.
--
--   'unsafePerformIO' is used to create a single top-level mutable cell. The
--   @NOINLINE@ pragma is what makes this sound: without it GHC could inline
--   the definition and allocate a separate 'IORef' at each use site, breaking
--   uniqueness of the ids.
{-# NOINLINE nextStmtId #-}
nextStmtId :: IORef Int
nextStmtId = unsafePerformIO $ newIORef 1
-- | Allocate a new 'StmtID' by atomically reading the global counter
--   'nextStmtId' and storing its successor. Because the update is atomic,
--   concurrent callers never receive the same id (until the 'Int' wraps).
freshStmtId :: MonadIO m => m StmtID
freshStmtId = liftIO (atomicModifyIORef' nextStmtId bump)
  where
    -- Store the successor and hand out the current value.
    bump current = (current + 1, StmtID current)
-- | An 'IO' action that takes statement text plus its parameters and
--   produces a result of type @a@.
type QueryRunner a = Text -> [Param] -> IO a
-- | A statement cached on a connection (see 'connStmts').
data SeldaStmt = SeldaStmt
  { stmtHandle :: !Dynamic
    -- ^ Backend-specific statement handle, stored as a 'Dynamic'; this is
    --   what 'allStmts' returns for each cached statement.
  , stmtText   :: !Text
    -- ^ Text of the statement.
  , stmtParams :: ![Either Int Param]
    -- ^ Statement parameters. NOTE(review): 'Left' presumably holds the
    --   position of a parameter supplied at run time and 'Right' a fixed
    --   'Param'; confirm against the code that builds 'SeldaStmt' values.
  }
-- | A connection to a database through a particular backend, together with
--   the per-connection state Selda needs: a statement cache, a closed flag,
--   and a lock. Build one with 'newConnection'.
--
--   All fields are strict, so no unevaluated thunk is retained inside a
--   long-lived connection value.
data SeldaConnection b = SeldaConnection
  { connBackend :: !(SeldaBackend b)
    -- ^ The backend this connection runs statements through.
  , connDbId    :: !Text
    -- ^ The identifier passed to 'newConnection'.
  , connStmts   :: !(IORef (M.IntMap SeldaStmt))
    -- ^ Cached statements, keyed by the 'Int' inside their 'StmtID'
    --   ('allStmts' reads them back in that form).
  , connClosed  :: !(IORef Bool)
    -- ^ 'True' once the connection is closed. 'newConnection' initialises it
    --   to 'False'; 'runSeldaT' throws 'DbError' if it is 'True'.
  , connLock    :: !(MVar ())
    -- ^ Taken by 'runSeldaT' for the whole run and put back afterwards, so
    --   'runSeldaT' calls on the same connection execute one at a time.
  }
-- | Wrap a backend in a new 'SeldaConnection' identified by @dbid@.
--
--   The new connection has an empty statement cache, is marked open
--   ('connClosed' is 'False'), and its lock is available.
newConnection :: MonadIO m => SeldaBackend b -> Text -> m (SeldaConnection b)
newConnection back dbid = liftIO $ do
  stmtCache <- newIORef M.empty
  closedRef <- newIORef False
  lock      <- newMVar ()
  pure (SeldaConnection back dbid stmtCache closedRef lock)
-- | Snapshot of every statement cached on the connection, as pairs of its
--   'StmtID' and backend handle, in ascending key order.
allStmts :: SeldaConnection b -> IO [(StmtID, Dynamic)]
allStmts conn = do
  cache <- readIORef (connStmts conn)
  pure [ (StmtID key, stmtHandle stmt) | (key, stmt) <- M.toList cache ]
-- | Backend-facing description of a table; 'tableInfo' builds one from a
--   'Table'.
data TableInfo = TableInfo
  { tableColumnInfos  :: [ColumnInfo]
    -- ^ One entry per column.
  , tableUniqueGroups :: [[ColName]]
    -- ^ Groups of columns covered by a table-level uniqueness attribute.
  , tablePrimaryKey   :: [ColName]
    -- ^ Columns making up the primary key, without duplicates.
  } deriving (Show, Eq)
-- | Backend-facing description of a single column; 'fromColInfo' builds one
--   from a 'Table.ColInfo'.
data ColumnInfo = ColumnInfo
  { colName          :: ColName
    -- ^ Column name.
  , colType          :: Either Text SqlTypeRep
    -- ^ Column type. 'fromColInfo' always produces 'Right'. NOTE(review):
    --   'Left' presumably carries a raw type name that could not be mapped
    --   to a 'SqlTypeRep' (e.g. when read back from a database).
  , colIsAutoPrimary :: Bool
    -- ^ Whether any of the column's attributes satisfies 'isAutoPrimary'.
  , colIsNullable    :: Bool
    -- ^ Whether the column carries the 'Optional' attribute.
  , colHasIndex      :: Bool
    -- ^ Whether the column carries an 'Indexed' attribute.
  , colFKs           :: [(TableName, ColName)]
    -- ^ Foreign keys as (referenced table, referenced column) pairs.
  } deriving (Show, Eq)
-- | Convert a table column description ('Table.ColInfo') into the
--   backend-facing 'ColumnInfo'. The type is always known here, so
--   'colType' is always 'Right'.
fromColInfo :: Table.ColInfo -> ColumnInfo
fromColInfo ci = ColumnInfo
  { colName          = Table.colName ci
  , colType          = Right (Table.colType ci)
  , colIsAutoPrimary = any isAutoPrimary attrs
  , colIsNullable    = Optional `elem` attrs
  , colHasIndex      = any isIndexAttr attrs
    -- Keep only the referenced table's name from each foreign-key target.
  , colFKs           = map (\(Table tbl _ _ _, col) -> (tbl, col)) (Table.colFKs ci)
  }
  where
    attrs = Table.colAttrs ci

    isIndexAttr (Indexed _) = True
    isIndexAttr _           = False
-- | Describe a table for a backend.
--
--   * Column info is computed with 'fromColInfo' for every column.
--   * Unique groups come from table-level 'Unique' attributes.
--   * The primary key merges table-level primary attributes (listed first)
--     with column-level ones, de-duplicated with 'nub' so the first
--     occurrence of each name wins.
--
--   NOTE(review): table-level attributes refer to columns by position; an
--   index outside 'tableCols' makes '!!' throw.
tableInfo :: Table a -> TableInfo
tableInfo t = TableInfo
  { tableColumnInfos  = map fromColInfo cols
  , tableUniqueGroups = [ map nameAt ixs | (ixs, Unique) <- tableAttrs t ]
  , tablePrimaryKey   = nub (tableLevelPk ++ columnLevelPk)
  }
  where
    cols = tableCols t

    -- Name of the column at a given position in the table definition.
    nameAt i = Table.colName (cols !! i)

    tableLevelPk =
      concat [ map nameAt ixs | (ixs, attr) <- tableAttrs t, isPrimary attr ]

    columnLevelPk =
      [ Table.colName col
      | col <- cols
      , attr <- Table.colAttrs col
      , isPrimary attr
      ]
-- | The operations a database backend provides to Selda. The type parameter
--   @b@ ties a backend to the 'SeldaConnection's built on it.
data SeldaBackend b = SeldaBackend
  { runStmt            :: Text -> [Param] -> IO (Int, [[SqlValue]])
    -- ^ Execute statement text with parameters. NOTE(review): the result is
    --   presumably (affected-row count, result rows); confirm in backends.
  , runStmtWithPK      :: Text -> [Param] -> IO Int
    -- ^ Execute statement text with parameters, returning an 'Int'.
    --   NOTE(review): presumably the primary key of an inserted row.
  , prepareStmt        :: StmtID -> [SqlTypeRep] -> Text -> IO Dynamic
    -- ^ Prepare statement text under the given 'StmtID', with the given
    --   parameter types, returning a backend-specific handle as a 'Dynamic'.
  , runPrepared        :: Dynamic -> [Param] -> IO (Int, [[SqlValue]])
    -- ^ Run a statement handle with parameters. NOTE(review): the handle is
    --   presumably one produced by 'prepareStmt'.
  , getTableInfo       :: TableName -> IO TableInfo
    -- ^ Look up a 'TableInfo' for the named table.
  , ppConfig           :: PPConfig
    -- ^ SQL pretty-printer configuration for this backend.
  , closeConnection    :: SeldaConnection b -> IO ()
    -- ^ Close the given connection.
  , backendId          :: BackendID
    -- ^ Which backend this is.
  , disableForeignKeys :: Bool -> IO ()
    -- ^ Toggle foreign-key checking. NOTE(review): presumably 'True'
    --   disables and 'False' re-enables; confirm in backends.
  }
-- | Monads that can run Selda computations because they can supply a
--   'SeldaConnection'. Only 'withConnection' must be implemented.
class MonadIO m => MonadSelda m where
  {-# MINIMAL withConnection #-}
  -- | Backend type of the connection this monad supplies.
  type Backend m

  -- | Run an action with the current connection.
  withConnection :: (SeldaConnection (Backend m) -> m a) -> m a

  -- | Hook for running a computation transactionally. The default is 'id',
  --   i.e. the computation runs unchanged; any transactional behaviour comes
  --   from instances that override it.
  transact :: m a -> m a
  transact = id
-- | Run an action with the backend of the current connection.
withBackend :: MonadSelda m => (SeldaBackend (Backend m) -> m a) -> m a
withBackend act = withConnection $ \conn -> act (connBackend conn)
-- | Monad transformer for Selda computations: a 'StateT' whose state is the
--   'SeldaConnection' in use. Execute it with 'runSeldaT'.
--
--   The instances below are obtained from the underlying 'StateT' through
--   GeneralizedNewtypeDeriving. On base < 4.13, 'MonadFail' is imported from
--   "Control.Monad.Fail" (see the CPP block in the imports).
newtype SeldaT b m a = S {unS :: StateT (SeldaConnection b) m a}
  deriving ( Functor, Applicative, Monad, MonadIO
           , MonadThrow, MonadCatch, MonadMask , MonadFail
           )
-- | 'SeldaT' supplies the connection held in its state. 'transact' keeps the
--   class default.
instance (MonadIO m, MonadMask m) => MonadSelda (SeldaT b m) where
  type Backend (SeldaT b m) = b
  withConnection k = do
    conn <- S get
    k conn
-- | Lift an action of the underlying monad through the 'StateT' layer.
instance MonadTrans (SeldaT b) where
  lift act = S (lift act)
-- | 'SeldaT' over plain 'IO', for backend @b@.
type SeldaM b = SeldaT b IO
-- | Run a 'SeldaT' computation against the given connection.
--
--   The connection's lock ('connLock') is held for the whole run and is
--   released afterwards even if the computation throws, so 'runSeldaT' calls
--   sharing a connection run one at a time. Throws 'DbError' if the
--   connection is marked closed.
runSeldaT :: (MonadIO m, MonadMask m)
          => SeldaT b m a
          -> SeldaConnection b
          -> m a
runSeldaT m c = bracket_ acquire release body
  where
    lock    = connLock c
    acquire = liftIO (takeMVar lock)
    release = liftIO (putMVar lock ())

    -- The closed flag is read only after the lock is held, so it reflects
    -- the connection's state at the start of this run.
    body = do
      isClosed <- liftIO (readIORef (connClosed c))
      when isClosed $
        liftIO (throwM (DbError "runSeldaT called with a closed connection"))
      fmap fst (runStateT (unS m) c)