module DB.HSQL.MySQL.Functions where
import Control.Concurrent.MVar (MVar, modifyMVar, newMVar, readMVar)
import Control.Exception (throw, throwIO)
import Control.Monad (when)
import Foreign ((.&.), nullPtr, peekByteOff, peekElemOff)
import Foreign.C (CInt (..), CString, peekCString)

import Database.HSQL.Types (ColDef, Connection (..), SqlError (..), Statement (..))

import DB.HSQL.MySQL.Type
  ( MYSQL
  , MYSQL_FIELD
  , MYSQL_LENGTHS
  , MYSQL_RES
  , MYSQL_ROW
  , mkSqlType
  )
-- Raw FFI bindings to the MySQL C client library, all resolved through the
-- "HsMySQL.h" header. None of them carries an explicit safety keyword, so
-- each one uses the FFI default ("safe") calling convention.
-- NOTE(review): the C prototypes are not visible here; e.g. the MySQL C API
-- documents mysql_errno as returning unsigned int while this binding uses
-- CInt — confirm the declared widths against the linked mysql.h.

-- Connection lifecycle.
foreign import ccall "HsMySQL.h mysql_init"
  mysql_init :: MYSQL -> IO MYSQL
foreign import ccall "HsMySQL.h mysql_real_connect"
  mysql_real_connect :: MYSQL -> CString -> CString -> CString -> CString -> CInt -> CString -> CInt -> IO MYSQL
foreign import ccall "HsMySQL.h mysql_close"
  mysql_close :: MYSQL -> IO ()

-- Error reporting; both are read by handleSqlError, and mysql_errno is also
-- checked by withStatement when no result set was produced.
foreign import ccall "HsMySQL.h mysql_errno"
  mysql_errno :: MYSQL -> IO CInt
foreign import ccall "HsMySQL.h mysql_error"
  mysql_error :: MYSQL -> IO CString

-- Query execution and result-set access.
foreign import ccall "HsMySQL.h mysql_query"
  mysql_query :: MYSQL -> CString -> IO CInt
foreign import ccall "HsMySQL.h mysql_use_result"
  mysql_use_result :: MYSQL -> IO MYSQL_RES
-- Called repeatedly by getFieldDefs until it yields nullPtr.
foreign import ccall "HsMySQL.h mysql_fetch_field"
  mysql_fetch_field :: MYSQL_RES -> IO MYSQL_FIELD
-- Installed as the stmtClose action by withStatement for non-NULL results.
foreign import ccall "HsMySQL.h mysql_free_result"
  mysql_free_result :: MYSQL_RES -> IO ()
-- fetch calls this and then mysql_fetch_lengths, storing both results as the
-- statement's current row.
foreign import ccall "HsMySQL.h mysql_fetch_row"
  mysql_fetch_row :: MYSQL_RES -> IO MYSQL_ROW
foreign import ccall "HsMySQL.h mysql_fetch_lengths"
  mysql_fetch_lengths :: MYSQL_RES -> IO MYSQL_LENGTHS

-- Catalogue queries (not referenced in this part of the file).
foreign import ccall "HsMySQL.h mysql_list_tables"
  mysql_list_tables :: MYSQL -> CString -> IO MYSQL_RES
foreign import ccall "HsMySQL.h mysql_list_fields"
  mysql_list_fields :: MYSQL -> CString -> CString -> IO MYSQL_RES
-- Not referenced in this part of the file; presumably used for
-- multi-statement results elsewhere — NOTE(review): confirm.
foreign import ccall "HsMySQL.h mysql_next_result"
  mysql_next_result :: MYSQL -> IO CInt
-- | Wrap a MySQL result-set handle in an HSQL 'Statement'.
--
-- Two fresh 'MVar's are created first: the current-row slot, starting as
-- @(nullPtr, nullPtr)@, and the "closed" flag, starting as 'False'.
--
-- * If @pRes@ is 'nullPtr', @mysql_errno@ is consulted. A non-zero code is
--   raised through 'handleSqlError'. Otherwise the statement has no fields
--   and closing it does nothing.
-- * Otherwise the column definitions are read eagerly with 'getFieldDefs',
--   and closing the statement calls @mysql_free_result@.
withStatement :: Connection -> MYSQL -> MYSQL_RES -> IO Statement
withStatement conn pMYSQL pRes = do
  rowVar    <- newMVar (nullPtr, nullPtr)
  closedVar <- newMVar False
  (closer, columns) <-
    if pRes /= nullPtr
      then do
        defs <- getFieldDefs pRes
        return (mysql_free_result pRes, defs)
      else do
        code <- mysql_errno pMYSQL
        when (code /= 0) (handleSqlError pMYSQL)
        return (return (), [])
  return Statement
    { stmtConn   = conn
    , stmtClose  = closer
    , stmtFetch  = fetch pRes rowVar
    , stmtGetCol = getColValue rowVar
    , stmtFields = columns
    , stmtClosed = closedVar
    }
-- | Decode column @colNumber@ (0-based) of the statement's current row.
--
-- Reads the row pointer and length array last stored by 'fetch'. It takes
-- element @colNumber@ of each and passes the raw value pointer and its
-- length (converted with 'fromIntegral') to the decoder @f@, together with
-- @fieldDef@.
--
-- The value pointer is passed through unchecked. NOTE(review): presumably it
-- is NULL for SQL NULL values, which @f@ must handle.
--
-- Throws a 'SqlError' when there is no current row, i.e. the stored row or
-- lengths pointer is NULL. That happens before the first successful 'fetch',
-- or after 'fetch' has returned 'False'. Without this check the NULL pointer
-- would be dereferenced and crash the process. @colNumber@ is not
-- bounds-checked here; callers must keep it below the column count.
getColValue :: MVar (MYSQL_ROW, MYSQL_LENGTHS)
            -> Int
            -> ColDef
            -> (ColDef -> CString -> Int -> IO a)
            -> IO a
getColValue currRow colNumber fieldDef f = do
  (row, lengths) <- readMVar currRow
  when (row == nullPtr || lengths == nullPtr) $
    throwIO (SqlError "" 0 "getColValue: no current row (fetch has not returned True)")
  pValue <- peekElemOff row colNumber
  len <- fmap fromIntegral (peekElemOff lengths colNumber)
  f fieldDef pValue len
-- | Read all column definitions of a result set.
--
-- Calls @mysql_fetch_field@ repeatedly until it returns NULL. Definitions
-- come back in the order MySQL reports them. Each entry is
-- @(name, sqlType, flag)@ where:
--
-- * @sqlType@ is built by 'mkSqlType' from the type, length and decimals
--   fields;
-- * @flag@ is 'True' when bit 0 of the field's flags is clear.
--   NOTE(review): presumably bit 0 is NOT_NULL_FLAG from mysql_com.h, making
--   @flag@ "column is nullable" — confirm.
--
-- NOTE(review): the byte offsets below are hard-coded. They appear to be
-- hsc2hs output for one particular @MYSQL_FIELD@ layout, so they are not
-- portable across MySQL versions or word sizes. Regenerate them from the
-- .hsc source against the mysql.h actually linked rather than editing them
-- by hand. The @flags :: Int@ read also assumes 'Int' matches the width of
-- the C field on the target platform.
getFieldDefs :: MYSQL_RES -> IO [ColDef]
getFieldDefs pRes = do
  pField <- mysql_fetch_field pRes
  if pField == nullPtr
    then return []
    else do
      name          <- peekByteOff pField 0 >>= peekCString
      dataType      <- peekByteOff pField 76
      columnSize    <- peekByteOff pField 28
      flags         <- peekByteOff pField 64
      decimalDigits <- peekByteOff pField 68
      let sqlType  = mkSqlType dataType columnSize decimalDigits
          nullable = ((flags :: Int) .&. notNullFlag) == 0
      defs <- getFieldDefs pRes
      return ((name, sqlType, nullable) : defs)
  where
    -- Flag bit tested above; a set bit is treated as "not nullable".
    notNullFlag :: Int
    notNullFlag = 1
-- | Advance the result set to its next row and make it the current row.
--
-- For a non-NULL @pRes@, calls @mysql_fetch_row@ and then
-- @mysql_fetch_lengths@ and stores both pointers in @currRow@, replacing the
-- previous contents. It returns whether the fetched row pointer is non-NULL.
-- When @pRes@ is 'nullPtr' (a statement without a result set), it returns
-- 'False' and leaves @currRow@ untouched.
--
-- NOTE(review): per the MySQL C API docs, @mysql_fetch_row@ also returns
-- NULL on a fetch error. This function has no connection handle with which
-- to consult @mysql_errno@, so such an error looks like end-of-data here.
fetch :: MYSQL_RES
      -> MVar (MYSQL_ROW, MYSQL_LENGTHS)
      -> IO Bool
fetch pRes currRow
  | pRes == nullPtr = return False
  -- The previous row is always overwritten, so its value is not inspected.
  | otherwise = modifyMVar currRow $ \_ -> do
      pRow     <- mysql_fetch_row pRes
      pLengths <- mysql_fetch_lengths pRes
      return ((pRow, pLengths), pRow /= nullPtr)
-- | Default client-flag value, i.e. bit 16 (decimal 65536).
--
-- NOTE(review): presumably this is passed as the final @CInt@ (client_flag)
-- argument of @mysql_real_connect@. It presumably corresponds to
-- CLIENT_MULTI_STATEMENTS in mysql_com.h — confirm against the MySQL headers
-- before changing it.
mysqlDefaultConnectFlags :: CInt
mysqlDefaultConnectFlags = 0x10000
-- | Raise the connection's most recent MySQL error as an HSQL 'SqlError'.
--
-- The numeric code is read with @mysql_errno@ and converted with
-- 'fromIntegral'. The message is read with @mysql_error@ and copied into a
-- Haskell 'String'. The first 'SqlError' field is passed as an empty string.
-- This action never returns normally.
handleSqlError :: MYSQL -> IO a
handleSqlError pMYSQL = do
  errno  <- mysql_errno pMYSQL
  errMsg <- mysql_error pMYSQL >>= peekCString
  -- throwIO (not throw) so the exception is raised exactly when this action
  -- runs, rather than whenever the IO value happens to be evaluated.
  throwIO (SqlError "" (fromIntegral errno) errMsg)