module Database.Persist.GenericSql.Internal
( Connection (..)
, Statement (..)
, withSqlConn
, withSqlPool
, RowPopper
, mkColumns
, Column (..)
, UniqueDef
, refName
, tableColumns
, rawFieldName
, rawTableName
, RawName (..)
) where
import qualified Data.Map as Map
import Data.IORef
import Control.Monad.IO.Class
import Database.Persist.Pool
import Database.Persist.Base
import Data.Maybe (fromJust)
import Control.Arrow
import Control.Monad.Invert (MonadInvertIO, bracket)
-- | An action in @m@ that yields either a list of 'PersistValue's or
-- 'Nothing'. It is the argument handed to the callback of 'withStmt'.
--
-- NOTE(review): presumably each 'Just' is one result row and 'Nothing' signals
-- that the rows are exhausted — confirm against a backend implementation.
type RowPopper m = m (Maybe [PersistValue])
-- | Record of the operations and state that make up one database connection.
--
-- NOTE(review): the per-field notes below are derived from the field types and
-- from the uses visible in this module; backend-specific semantics should be
-- confirmed against the concrete backend implementations.
data Connection = Connection
-- Builds a Statement from a String (presumably the SQL text — confirm).
{ prepare :: String -> IO Statement
-- Takes a RawName and a list of RawNames and yields either a single String or
-- a pair of Strings. NOTE(review): presumably table name, column names and the
-- resulting INSERT statement(s) — confirm what the Left/Right cases mean.
, insertSql :: RawName -> [RawName] -> Either String (String, String)
-- Mutable map from String to Statement. close' finalizes every Statement held
-- here before closing the connection. NOTE(review): presumably a cache of
-- prepared statements keyed by their SQL text — confirm.
, stmtMap :: IORef (Map.Map String Statement)
-- Backend close action; close' runs it after finalizing the stored statements.
, close :: IO ()
-- Given a statement-preparing function and an entity value, yields either a
-- list of Strings or a list of (Bool, String) pairs. NOTE(review): presumably
-- error messages versus migration statements with a flag — confirm.
, migrateSql :: forall v. PersistEntity v
=> (String -> IO Statement) -> v
-> IO (Either [String] [(Bool, String)])
-- The next three fields each receive a statement-preparing function.
-- NOTE(review): presumably they start, commit and roll back a transaction.
, begin :: (String -> IO Statement) -> IO ()
, commit :: (String -> IO Statement) -> IO ()
, rollback :: (String -> IO Statement) -> IO ()
-- Renders a RawName as a String (presumably quoted for the backend — confirm).
, escapeName :: RawName -> String
-- NOTE(review): presumably the SQL fragment meaning "no LIMIT" — confirm.
, noLimit :: String
}
-- | Record of the operations available on a single statement, as produced by
-- the 'prepare' field of a 'Connection'.
data Statement = Statement
-- Run by close' on every statement stored in a connection's stmtMap.
-- NOTE(review): presumably releases the statement's backend resources.
{ finalize :: IO ()
-- NOTE(review): presumably makes the statement ready to be run again — confirm.
, reset :: IO ()
-- Takes a list of PersistValues and produces no result value.
-- NOTE(review): presumably binds them as parameters and runs the statement.
, execute :: [PersistValue] -> IO ()
-- Takes a list of PersistValues and a callback; the callback receives a
-- RowPopper in the same monad and its result is returned.
-- NOTE(review): presumably the RowPopper pulls result rows one at a time.
, withStmt :: forall a m. MonadInvertIO m
=> [PersistValue] -> (RowPopper m -> m a) -> m a
}
-- | Run an action that is given a 'Pool' of connections.
--
-- The pool is built by 'createPool' from the connection-opening action and
-- @close'@, which is supplied as the routine for disposing of a connection.
-- The 'Int' is passed straight through to 'createPool'
-- (presumably the pool size — confirm against "Database.Persist.Pool").
withSqlPool :: MonadInvertIO m
            => IO Connection -> Int -> (Pool Connection -> m a) -> m a
withSqlPool mkConn size action = createPool mkConn close' size action
-- | Open a connection, run the given action with it, and hand the connection
-- to @close'@ via 'bracket' once the action is done.
withSqlConn :: MonadInvertIO m => IO Connection -> (Connection -> m a) -> m a
withSqlConn open inner =
    bracket (liftIO open) (\conn -> liftIO (close' conn)) inner
-- | Finalize every statement currently stored in the connection's 'stmtMap',
-- then run the connection's own 'close' action.
close' :: Connection -> IO ()
close' conn = do
    stmts <- readIORef (stmtMap conn)
    mapM_ finalize (Map.elems stmts)
    close conn
-- | Build the column and uniqueness descriptions for an entity type.
--
-- @val@ is handed to 'entityDef' to obtain the entity's metadata and pins the
-- type of 'halfDefined', whose persist fields supply each column's 'SqlType'.
--
-- Returns the columns (one per entry of 'tableColumns', paired positionally
-- with the persist fields) together with the uniqueness definitions, whose
-- field names are translated to raw column names.
--
-- Column attributes interpreted here, scanned left to right:
--
-- * @null@ sets 'cNull'.
--
-- * The first @default=X@ sets 'cDefault' to @X@.
--
-- * @noreference@ or @reference=T@, whichever comes first, decides
--   'cReference'. If neither is present, a field whose type name ends in
--   @Id@ references the table named by the type name minus that suffix.
mkColumns :: PersistEntity val => val -> ([Column], [UniqueDef])
mkColumns val =
    (cols, uniqs)
  where
    t = entityDef val
    tn = rawTableName t

    -- Association list: declared field name -> raw column name.
    colNameMap = map ((\(x, _, _) -> x) &&& rawFieldName) $ entityColumns t

    -- 'fromJust' fails here if a uniqueness definition names a field that is
    -- not among 'entityColumns'.
    uniqs = map (RawName *** map (fromJust . flip lookup colNameMap))
          $ entityUniques t

    cols = zipWith go (tableColumns t) $ toPersistFields $ halfDefined `asTypeOf` val

    go (name, t', as) p =
        Column name ("null" `elem` as) (sqlType p) (def as) (ref name t' as)

    -- First "default=" attribute wins.
    def [] = Nothing
    def (('d':'e':'f':'a':'u':'l':'t':'=':d):_) = Just d
    def (_:rest) = def rest

    -- Reached only when no reference-related attribute was found: fall back
    -- to the type-name convention by splitting off the last two characters.
    ref c t' [] =
        let l = length t'
            (f, b) = splitAt (l - 2) t'
         in if b == "Id"
                then Just (RawName f, refName tn c)
                else Nothing
    ref _ _ ("noreference":_) = Nothing
    ref c _ (('r':'e':'f':'e':'r':'e':'n':'c':'e':'=':x):_) =
        Just (RawName x, refName tn c)
    ref c x (_:y) = ref c x y
-- | Name of the foreign-key constraint for a column of a table:
-- @\<table\>_\<column\>_fkey@.
refName :: RawName -> RawName -> RawName
refName (RawName tbl) (RawName col) =
    RawName (concat [tbl, "_", col, "_fkey"])
-- | Description of one table column, as assembled by 'mkColumns'.
data Column = Column
-- Raw column name (the name component produced by tableColumns).
{ cName :: RawName
-- True when the field's attributes contain "null".
, cNull :: Bool
-- SQL type obtained via sqlType from the field's persist value.
, cType :: SqlType
-- Text following the first "default=" attribute, if there is one.
, cDefault :: Maybe String
-- Optional reference: the referenced table's name paired with the
-- constraint name built by refName.
, cReference :: (Maybe (RawName, RawName))
}
-- | Find the first attribute of the form @sql=X@ and return @X@, or 'Nothing'
-- when no attribute carries that prefix.
getSqlValue :: [String] -> Maybe String
getSqlValue = foldr pick Nothing
  where
    -- A match discards whatever later attributes would have produced, so the
    -- leftmost @sql=@ attribute wins.
    pick ('s':'q':'l':'=':rest) _ = Just rest
    pick _ later = later
-- | The entity's columns with each declared field name replaced by its raw
-- name (see 'rawFieldName'); the other two components are kept unchanged.
tableColumns :: EntityDef -> [(RawName, String, [String])]
tableColumns entity = map toRaw (entityColumns entity)
  where
    toRaw field@(_, ty, attrs) = (rawFieldName field, ty, attrs)
-- | A uniqueness definition as produced by 'mkColumns': the name taken from
-- the entity's unique definition wrapped in 'RawName', paired with the raw
-- names of the columns it covers.
type UniqueDef = (RawName, [RawName])
-- | Raw name of a field: the value of its @sql=@ attribute when one is
-- present (see 'getSqlValue'), otherwise the declared field name.
rawFieldName :: (String, String, [String]) -> RawName
rawFieldName (name, _, attrs) = RawName (maybe name id (getSqlValue attrs))
-- | Raw name of an entity's table: the value of the @sql=@ entry among the
-- entity's attributes when one is present (see 'getSqlValue'), otherwise the
-- entity's own name.
rawTableName :: EntityDef -> RawName
rawTableName entity =
    RawName (maybe (entityName entity) id (getSqlValue (entityAttribs entity)))
-- | Wrapper around a 'String' holding a table, column or constraint name as
-- produced by 'rawTableName', 'rawFieldName' and 'refName'. Turning it into
-- text for a backend goes through the 'escapeName' field of a 'Connection'.
newtype RawName = RawName { unRawName :: String }
deriving (Eq, Ord)