{-# LANGUAGE GeneralizedNewtypeDeriving, DefaultSignatures, CPP, TypeFamilies #-}
module Database.Selda.Backend.Internal
( StmtID (..), BackendID (..)
, QueryRunner, SeldaBackend (..), SeldaConnection (..), SeldaStmt (..)
, MonadSelda (..), SeldaT (..), SeldaM
, SeldaError (..)
, Param (..), Lit (..), ColAttr (..), AutoIncType (..)
, SqlType (..), SqlValue (..), SqlTypeRep (..)
, PPConfig (..), defPPConfig
, TableInfo (..), ColumnInfo (..), tableInfo, fromColInfo
, isAutoPrimary, isPrimary, isUnique
, sqlDateTimeFormat, sqlDateFormat, sqlTimeFormat
, freshStmtId
, newConnection, allStmts
, runSeldaT, withBackend
) where
import Database.Selda.SQL (Param (..))
import Database.Selda.SqlType
import Database.Selda.Table hiding (colName, colType, colFKs)
import qualified Database.Selda.Table as Table (ColInfo (..))
import Database.Selda.SQL.Print.Config
import Database.Selda.Types (TableName, ColName)
import Control.Concurrent
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.State
import Data.Dynamic
import qualified Data.IntMap as M
import Data.IORef
import Data.Text (Text)
import System.IO.Unsafe (unsafePerformIO)
#if !MIN_VERSION_base(4, 13, 0)
import Control.Monad.Fail (MonadFail)
#endif
-- | Identifies which database backend a connection talks to.
--   'Other' carries a free-form name for any backend that has no dedicated
--   constructor here.
data BackendID = SQLite | PostgreSQL | Other Text
  deriving (Show, Eq, Ord)
-- | Errors raised by Selda. This is an 'Exception' instance, so values can
--   be thrown with 'throwM' and caught like any other exception.
data SeldaError
  = DbError String
    -- ^ Database-level error carrying a message. In this module it is
    --   thrown by 'runSeldaT' when given a closed connection.
  | SqlError String
    -- ^ Error carrying a message. Not thrown in this module.
    --   NOTE(review): presumably raised by backends when an SQL statement
    --   fails — confirm in the backend code.
  deriving (Show, Eq, Typeable)

instance Exception SeldaError
-- | Identifier of a prepared statement. Fresh values are allocated by
--   'freshStmtId'. The wrapped 'Int' is also the key under which a
--   statement is stored in a connection's statement map.
newtype StmtID = StmtID Int
  deriving (Show, Eq, Ord)
-- | Identifier of a connection, wrapping an 'Int'.
--   Not exported from this module.
newtype ConnID = ConnID Int
  deriving (Show, Eq, Ord)
{-# NOINLINE nextStmtId #-}
-- | Process-wide counter backing 'freshStmtId'. It holds the next statement
--   ID to hand out and starts at 1.
--
--   This is the top-level mutable variable idiom. 'unsafePerformIO' creates
--   the 'IORef' once. The NOINLINE pragma above stops GHC from inlining the
--   definition, which could otherwise duplicate the 'newIORef' call and give
--   different use sites different counters (and thus duplicate IDs).
nextStmtId :: IORef Int
nextStmtId = unsafePerformIO $ newIORef 1
-- | Allocate a new statement identifier, unique within this process.
--
--   The shared counter is read and bumped in one atomic step, so concurrent
--   callers never receive the same ID (short of the 'Int' counter wrapping
--   around).
freshStmtId :: MonadIO m => m StmtID
freshStmtId = liftIO (atomicModifyIORef' nextStmtId allocate)
  where
    -- Store the successor; hand out the current value.
    allocate current = (current + 1, StmtID current)
-- | A function that takes query text and its parameters and produces a
--   result of type @a@ in 'IO'.
type QueryRunner a = Text -> [Param] -> IO a
-- | A prepared statement as stored in a connection's statement map
--   ('connStmts').
data SeldaStmt = SeldaStmt
  { stmtHandle :: !Dynamic
    -- ^ Backend-specific statement handle. NOTE(review): presumably the
    --   value returned by 'prepareStmt' — confirm at the call site.
  , stmtText :: !Text
    -- ^ The statement's text.
  , stmtParams :: ![Either Int Param]
    -- ^ The statement's parameters.
    --   NOTE(review): presumably 'Left' marks a placeholder filled in at run
    --   time and 'Right' a fixed parameter — confirm against users.
  }
-- | An open connection to a database using backend type @b@.
--   Construct values with 'newConnection'.
data SeldaConnection b = SeldaConnection
  { connBackend :: !(SeldaBackend b)
    -- ^ Backend operations used for this connection.
  , connDbId :: Text
    -- ^ Database identifier passed to 'newConnection'. Unlike the other
    --   fields, this one is lazy.
  , connStmts :: !(IORef (M.IntMap SeldaStmt))
    -- ^ Prepared statements on this connection, keyed by the 'Int' inside
    --   their 'StmtID' (see 'allStmts'). Empty initially.
  , connClosed :: !(IORef Bool)
    -- ^ Whether the connection has been closed. It starts as 'False', and
    --   'runSeldaT' throws a 'DbError' if it is 'True'.
  , connLock :: !(MVar ())
    -- ^ Lock taken by 'runSeldaT' for the whole duration of a computation,
    --   so computations on one connection do not overlap.
  }
-- | Build a fresh connection record for a backend and database identifier.
--   The new connection has no prepared statements, is not marked closed,
--   and its lock is initially available.
newConnection :: MonadIO m => SeldaBackend b -> Text -> m (SeldaConnection b)
newConnection back dbid = liftIO $ do
  stmtCache <- newIORef M.empty
  closedFlag <- newIORef False
  lock <- newMVar ()
  pure (SeldaConnection back dbid stmtCache closedFlag lock)
-- | Snapshot of every prepared statement on a connection. Each entry pairs
--   the statement's ID with its backend handle, in ascending key order.
allStmts :: SeldaConnection b -> IO [(StmtID, Dynamic)]
allStmts conn = do
  cache <- readIORef (connStmts conn)
  pure [ (StmtID key, stmtHandle stmt) | (key, stmt) <- M.toList cache ]
-- | Description of a table's schema. 'tableInfo' builds one from a Selda
--   table definition, and backends return one from 'getTableInfo'.
data TableInfo = TableInfo
  { tableColumnInfos :: [ColumnInfo]
    -- ^ One entry per column.
  , tableUniqueGroups :: [[ColName]]
    -- ^ Uniqueness constraints, one inner list of column names per
    --   constraint.
  , tablePrimaryKey :: [ColName]
    -- ^ Names of the columns making up the primary key.
  } deriving (Show, Eq)
-- | Description of a single column. 'fromColInfo' builds one from a table
--   definition's column.
data ColumnInfo = ColumnInfo
  { colName :: ColName
    -- ^ Column name.
  , colType :: Either Text SqlTypeRep
    -- ^ Column type. 'fromColInfo' always produces 'Right'.
    --   NOTE(review): 'Left' presumably carries a raw type name that could
    --   not be mapped to a 'SqlTypeRep' — confirm in the backends.
  , colIsAutoPrimary :: Bool
    -- ^ Whether the column is an auto-incrementing primary key.
  , colIsNullable :: Bool
    -- ^ Whether the column may hold NULL.
  , colHasIndex :: Bool
    -- ^ Whether the column has an index.
  , colFKs :: [(TableName, ColName)]
    -- ^ Foreign keys, each given as the referenced (table, column).
  } deriving (Show, Eq)
-- | Convert a column of a table definition into a 'ColumnInfo'.
--
--   Flags are derived from the column's attribute list. The type is always
--   statically known here, so 'colType' is always 'Right'. Each foreign key
--   is reduced to the referenced table's name plus the referenced column.
fromColInfo :: Table.ColInfo -> ColumnInfo
fromColInfo ci =
  ColumnInfo
    { colName = Table.colName ci
    , colType = Right (Table.colType ci)
    , colIsAutoPrimary = any isAutoPrimary attrs
    , colIsNullable = Optional `elem` attrs
    , colHasIndex = any isIndexAttr attrs
    , colFKs = map referencedName (Table.colFKs ci)
    }
  where
    attrs = Table.colAttrs ci

    -- Matches an index attribute regardless of its payload.
    isIndexAttr (Indexed _) = True
    isIndexAttr _ = False

    referencedName (Table tbl _ _ _, col) = (tbl, col)
-- | Derive a 'TableInfo' from a table definition.
--
--   Constraint attributes in 'tableAttrs' refer to columns by position.
--   These positions are resolved to column names with '!!', so an
--   out-of-range position would be a runtime error. Every 'Unique'
--   attribute becomes its own group. All 'Primary' attributes are merged
--   into a single flat key.
tableInfo :: Table a -> TableInfo
tableInfo t = TableInfo
  { tableColumnInfos = map fromColInfo cols
  , tableUniqueGroups = [ namesAt ixs | (ixs, Unique) <- tableAttrs t ]
  , tablePrimaryKey = concat [ namesAt ixs | (ixs, Primary) <- tableAttrs t ]
  }
  where
    cols = tableCols t
    -- Turn a list of column positions into the corresponding column names.
    namesAt = map (\i -> Table.colName (cols !! i))
-- | The operations a database backend provides to Selda.
--   The type parameter @b@ appears only in 'closeConnection'.
data SeldaBackend b = SeldaBackend
  { runStmt :: Text -> [Param] -> IO (Int, [[SqlValue]])
    -- ^ Execute a statement given its text and parameters.
    --   NOTE(review): the result is presumably (affected row count, result
    --   rows) — confirm against the backends.
  , runStmtWithPK :: Text -> [Param] -> IO Int
    -- ^ Execute a statement given its text and parameters, returning an
    --   'Int'. NOTE(review): the name suggests this is the primary key
    --   generated by an insert — confirm.
  , prepareStmt :: StmtID -> [SqlTypeRep] -> Text -> IO Dynamic
    -- ^ Prepare a statement from its ID, a list of type representations
    --   (presumably the parameter types — confirm) and its text. Returns a
    --   backend-specific handle wrapped in 'Dynamic'.
  , runPrepared :: Dynamic -> [Param] -> IO (Int, [[SqlValue]])
    -- ^ Run a prepared-statement handle with the given parameters. The
    --   result has the same shape as 'runStmt'.
  , getTableInfo :: TableName -> IO TableInfo
    -- ^ Retrieve a 'TableInfo' for the named table.
  , ppConfig :: PPConfig
    -- ^ SQL printing configuration used for this backend.
  , closeConnection :: SeldaConnection b -> IO ()
    -- ^ Close the given connection.
  , backendId :: BackendID
    -- ^ Which backend this is.
  , disableForeignKeys :: Bool -> IO ()
    -- ^ Toggle foreign key enforcement. NOTE(review): presumably 'True'
    --   disables the checks and 'False' re-enables them — confirm.
  }
-- | Monads that can supply a database connection to Selda computations.
--   Only 'withConnection' must be implemented.
class MonadIO m => MonadSelda m where
  {-# MINIMAL withConnection #-}
  -- | Backend type of the connections used by @m@.
  type Backend m

  -- | Pass the current connection to the given computation.
  withConnection :: (SeldaConnection (Backend m) -> m a) -> m a

  -- | Wrap a computation. The default is 'id', i.e. no wrapping at all.
  --   NOTE(review): the name suggests instances override this to run the
  --   computation inside a database transaction — confirm.
  transact :: m a -> m a
  transact = id
-- | Run a computation with the backend of the current connection.
withBackend :: MonadSelda m => (SeldaBackend (Backend m) -> m a) -> m a
withBackend k = withConnection $ \conn -> k (connBackend conn)
-- | Monad transformer for Selda computations. It is a 'StateT' whose state
--   is the current 'SeldaConnection' for backend @b@, layered over @m@. Its
--   instances are newtype-derived from that 'StateT'. Run it with
--   'runSeldaT'.
newtype SeldaT b m a = S {unS :: StateT (SeldaConnection b) m a}
  deriving ( Functor, Applicative, Monad, MonadIO
           , MonadThrow, MonadCatch, MonadMask , MonadFail
           )
-- | 'SeldaT' supplies the connection held in its state.
instance (MonadIO m, MonadMask m) => MonadSelda (SeldaT b m) where
  type Backend (SeldaT b m) = b
  withConnection k = do
    conn <- S get
    k conn
-- | 'SeldaT' with 'IO' as the underlying monad.
type SeldaM b = SeldaT b IO
-- | Run a Selda computation on the given connection.
--
--   The connection's lock ('connLock') is taken for the whole computation.
--   It is put back afterwards, even if an exception escapes. Because the
--   lock is an 'MVar', a second 'runSeldaT' on the same connection blocks
--   until this one finishes. A nested call on the same connection from
--   inside the computation would therefore block forever.
--
--   Throws 'DbError' if the connection has already been closed.
runSeldaT :: (MonadIO m, MonadMask m)
          => SeldaT b m a
          -> SeldaConnection b
          -> m a
runSeldaT m c = bracket_ acquire release body
  where
    lock = connLock c
    acquire = liftIO (takeMVar lock)
    release = liftIO (putMVar lock ())
    -- Check for closure only once the lock is held.
    body = do
      isClosed <- liftIO (readIORef (connClosed c))
      when isClosed $
        liftIO $ throwM $ DbError "runSeldaT called with a closed connection"
      evalStateT (unS m) c