module Database.SQLite
( module Database.SQLite.Base
, module Database.SQLite.Types
, module Database.SQL.Types
, openConnection
, closeConnection
, execStatement
, execStatement_
, execParamStatement
, execParamStatement_
, insertRow
, defineTable
, defineTableOpt
, getLastRowID
, Row
, Value(..)
, addRegexpSupport
, RegexpHandler
, withPrim
, SQLiteHandle()
, newSQLiteHandle
) where
import Database.SQLite.Types
import Database.SQLite.Base
import Database.SQL.Types
import Foreign.Marshal
import Foreign.C
import Foreign.C.String (newCStringLen, peekCString)
import Foreign.Storable
import qualified Foreign.Concurrent as Conc
import Foreign.Ptr
import Foreign.ForeignPtr
import Data.List
import Data.Int
import Data.Char ( isDigit )
import Data.ByteString (ByteString, packCStringLen, useAsCStringLen)
import Data.ByteString.Unsafe (unsafePackCStringLen)
import Control.Monad ((<=<),when)
import qualified Codec.Binary.UTF8.String as UTF8
-- | A database connection handle. The wrapped 'ForeignPtr' owns the raw
-- connection; when it is finalized (explicitly via 'closeConnection' or
-- by the garbage collector) the finalizer installed by 'newSQLiteHandle'
-- calls @sqlite3_close@ on it.
newtype SQLiteHandle = SQLiteHandle (ForeignPtr ())

-- | Attach an extra Haskell finalizer to a handle so that resources tied
-- to the connection's lifetime are released together with it.
addSQLiteHandleFinalizer :: SQLiteHandle -> IO () -> IO ()
addSQLiteHandleFinalizer (SQLiteHandle fp) finalizer =
  Conc.addForeignPtrFinalizer fp finalizer

-- | Wrap a raw connection in a managed 'SQLiteHandle' whose finalizer
-- closes the connection. The status returned by @sqlite3_close@ is
-- discarded.
newSQLiteHandle :: SQLite -> IO SQLiteHandle
newSQLiteHandle db@(SQLite raw) = do
  fp <- Conc.newForeignPtr raw closeDb
  return (SQLiteHandle fp)
  where
    closeDb = do
      _ <- sqlite3_close db
      return ()
-- | Open (or create) the SQLite database at the given path.
--
-- On success the returned handle closes the connection when it is
-- finalized. On failure the partially opened connection is released and
-- an error is raised with 'fail', carrying the status returned by
-- @sqlite3_open@.
openConnection :: String -> IO SQLiteHandle
openConnection dbName =
  alloca $ \ptr -> do
    st <- withCString dbName $ \c_dbName ->
            sqlite3_open c_dbName ptr
    case st of
      0 -> do
        db <- peek ptr
        newSQLiteHandle db
      _ -> do
        -- sqlite3_open usually returns a connection object in *ppDb even
        -- when it fails, and that object must still be closed or it leaks.
        -- If allocation itself failed, *ppDb is NULL and closing it is a
        -- harmless no-op.
        db <- peek ptr
        _ <- sqlite3_close db
        fail ("openDatabase: failed to open " ++ show st)
-- | Close a connection immediately by running the handle's finalizers
-- (including the @sqlite3_close@ finalizer set up by 'newSQLiteHandle'),
-- instead of waiting for the garbage collector. The handle must not be
-- used afterwards.
closeConnection :: SQLiteHandle -> IO ()
closeConnection (SQLiteHandle fp) = finalizeForeignPtr fp

-- | Run an action on the raw connection, keeping the handle alive for the
-- duration of the action.
withPrim :: SQLiteHandle -> (SQLite -> IO a) -> IO a
withPrim (SQLiteHandle fp) action =
  withForeignPtr fp $ \raw -> action (SQLite raw)
-- | One result row: each column's name paired with that column's value.
type Row a = [(ColumnName,a)]
-- | Issue a @CREATE TABLE@ statement built from a table description.
-- When the flag is 'True' the statement uses @IF NOT EXISTS@, so an
-- existing table is not reported as an error.
--
-- Returns 'Nothing' on success, or 'Just' the SQLite error message.
defineTableOpt :: SQLiteHandle -> Bool -> SQLTable -> IO (Maybe String)
defineTableOpt h check tab = execStatement_ h sql
  where
    ifNotExists
      | check     = " IF NOT EXISTS "
      | otherwise = ""
    sql = concat
      [ "CREATE TABLE ", ifNotExists, toSQLString (tabName tab)
      , tupled (map columnDef (tabColumns tab)), ";" ]
    -- name, type and clauses of one column
    columnDef c = concat
      [ toSQLString (colName c), " ", showType (colType c)
      , " ", unwords (map showClause (colClauses c)) ]
-- | 'defineTableOpt' without @IF NOT EXISTS@: if the table already exists,
-- SQLite's error message is returned as 'Just'.
defineTable :: SQLiteHandle -> SQLTable -> IO (Maybe String)
defineTable h = defineTableOpt h False
-- | Insert one row into a table. Each pair gives a column name and the
-- value to store in that column. Values are rendered as SQL literals:
-- a value that is entirely a plain numeric literal (digits, optionally
-- followed by @.@ and more digits) is emitted unquoted, the empty string
-- becomes @''@, and everything else is quoted.
--
-- The table and column names are spliced in verbatim, so they must come
-- from trusted code.
--
-- Returns 'Nothing' on success, or 'Just' the SQLite error message.
insertRow :: SQLiteHandle -> TableName -> Row String -> IO (Maybe String)
insertRow h tab cs = execStatement_ h stmt
  where
    stmt = "INSERT INTO " ++ tab ++
           tupled (map fst cs) ++ " VALUES " ++
           tupled (map (quote . snd) cs) ++ ";"
    quote "" = "''"
    quote nm
      | isNumericLiteral nm = nm
      | otherwise           = '\'' : toSQLString nm ++ "'"
    -- Previously only the first character was checked, so a value such as
    -- "12 Main St" or "1); DROP TABLE t; --" was spliced in unquoted,
    -- which produced invalid SQL or allowed arbitrary SQL to run.
    isNumericLiteral s = case span isDigit s of
      ([], _)         -> False
      (_, [])         -> True
      (_, '.' : frac) -> all isDigit frac
      _               -> False
-- | Rowid of the most recent successful INSERT on this connection, as
-- reported by @sqlite3_last_insert_rowid@, widened to 'Integer'.
getLastRowID :: SQLiteHandle -> IO Integer
getLastRowID h = withPrim h (fmap fromIntegral . sqlite3_last_insert_rowid)
-- | A dynamically typed column or parameter value, with one constructor
-- per SQLite storage class (REAL, INTEGER, TEXT, BLOB, NULL).
data Value
  = Double Double
  | Int Int64
  | Text String
  | Blob ByteString
  | Null
    deriving Show

-- | Pointer to C's @free@. It is passed to SQLite as the destructor for
-- malloc'd buffers whose ownership is handed over in 'bindValue'.
foreign import ccall "stdlib.h &free"
  p_free :: FunPtr (Ptr a -> IO ())
-- | Bind a value to the named parameter @key@ of a prepared statement.
--
-- If the statement has no parameter with that name, the value is ignored
-- and 'sQLITE_OK' is returned. Otherwise the status of the underlying
-- @sqlite3_bind_*@ call is returned. 'Text' values are UTF-8 encoded
-- before binding.
bindValue :: SQLiteStmt -> String -> Value -> IO Status
bindValue stmt key value =
  withCString (UTF8.encodeString key) $ \ckey ->
  ensure (sqlite3_bind_parameter_index stmt ckey)
         (> 0) (return sQLITE_OK) $ \ix ->
  case value of
    Text txt -> do
      -- SQLite takes ownership of this malloc'd copy and releases it with
      -- 'p_free' even when the bind call fails. Freeing it here as well
      -- would be a double free.
      (cptr, len) <- newCStringLen (UTF8.encodeString txt)
      sqlite3_bind_text stmt ix cptr (fromIntegral len) p_free
    Null     -> sqlite3_bind_null stmt ix
    Int x    -> sqlite3_bind_int64 stmt ix x
    Double x -> sqlite3_bind_double stmt ix x
    Blob x   -> useAsCStringLen x $ \(src, bytes) -> do
      -- The buffer from 'useAsCStringLen' is only valid inside this
      -- callback, but SQLite reads blob data later, when the statement is
      -- stepped. Give SQLite its own malloc'd copy, released via 'p_free',
      -- instead of a pointer that would dangle. At least one byte is
      -- allocated so that an empty blob is never a NULL pointer, which
      -- SQLite would bind as SQL NULL.
      buf <- mallocBytes (max 1 bytes)
      copyBytes buf src bytes
      sqlite3_bind_blob stmt ix (castPtr buf) (fromIntegral bytes) p_free
-- | Fetch the connection's current error message (@sqlite3_errmsg@) and
-- return it wrapped in 'Left'.
to_error :: SQLite -> IO (Either String a)
to_error db = do
  msg <- sqlite3_errmsg db >>= peekCString
  return (Left msg)
-- | Like 'execParamStatement', but result rows are discarded (columns are
-- not even read). Returns 'Nothing' on success, or 'Just' the error
-- message on failure.
execParamStatement_ :: SQLiteHandle -> String -> [(String,Value)]
                    -> IO (Maybe String)
execParamStatement_ db q ps = do
  r <- (execParamStatement db q ps :: IO (Either String [[Row ()]]))
  return $ case r of
    Left err -> Just err
    Right _  -> Nothing
-- | Run every statement contained in @query@ and collect each statement's
-- rows: one inner list per statement, in source order.
--
-- @params@ are bound by parameter name into every statement (see
-- 'bindValue'); the statuses of those bind calls are not checked.
-- Processing stops at the first prepare or step failure, and the SQLite
-- error message is returned as 'Left'.
execParamStatement :: SQLiteResult a => SQLiteHandle -> String
                   -> [(String,Value)] -> IO (Either String [[Row a]])
execParamStatement h query params = withPrim h $ \db ->
  alloca $ \stmt_ptr ->
  alloca $ \pzTail ->
  withCString (UTF8.encodeString query) $ \zSql -> do
    poke pzTail zSql
    prepare_loop db stmt_ptr pzTail
  where
    -- Compile the statement at the tail pointer, run it, and repeat.
    -- sqlite3_prepare advances the tail past the text it consumed, so each
    -- round moves on to the next statement.
    prepare_loop db stmt_ptr sqltxt_ptr = loop []
      where
        loop xs =
          peek sqltxt_ptr >>= \sqltxt ->
          -- A NUL at the tail means the whole query has been consumed.
          ensure_ (peek sqltxt) (/= 0) (eReturn (reverse xs)) $
          -- nByte = -1: read up to the NUL terminator. A positive value
          -- (e.g. 1) would make SQLite compile only that many bytes.
          ensure_ (sqlite3_prepare db sqltxt (-1) stmt_ptr sqltxt_ptr)
                  (== sQLITE_OK) (to_error db) $
          -- A NULL statement means the remaining text was only whitespace
          -- or comments; continue without producing a result set.
          ensure (peek stmt_ptr)
                 (not . isNullStmt) (loop xs) $ \stmt ->
          then_finalize db (recv_rows db stmt) stmt `ebind` \x ->
          loop (x:xs)

    -- Bind the parameters, read the (UTF-8 decoded) column names, then
    -- step through all rows.
    recv_rows db stmt = do
      mapM_ (uncurry $ bindValue stmt) params
      col_num <- sqlite3_column_count stmt
      let cols = [0 .. col_num - 1]
      names <- mapM (peekCString <=< sqlite3_column_name stmt) cols
      let decoded_names = map UTF8.decodeString names
      get_rows db stmt cols decoded_names []

    get_rows db stmt cols col_names rows = do
      res <- sqlite3_step stmt
      if res == sQLITE_ROW
        then do
          txts <- mapM (get_sqlite_val stmt) cols
          let row = zip col_names txts
          get_rows db stmt cols col_names (row:rows)
        else if res == sQLITE_DONE
          then eReturn (reverse rows)
          else to_error db

    -- Always finalize the statement. On failure the error message is read
    -- again after finalizing. NOTE(review): presumably because the legacy
    -- sqlite3_prepare interface only reports the specific error once the
    -- statement is reset/finalized; confirm.
    then_finalize db m stmt = do
      e <- m
      _ <- sqlite3_finalize stmt
      case e of
        Left _  -> to_error db
        Right r -> return (Right r)
-- | 'execParamStatement' with no parameters to bind.
execStatement :: SQLiteResult a
              => SQLiteHandle -> String -> IO (Either String [[Row a]])
execStatement db s = execParamStatement db s noParams
  where
    noParams = []
-- | Execute SQL text with @sqlite3_exec@, without a row callback.
-- Returns 'Nothing' when the status is 'sQLITE_OK', otherwise 'Just' the
-- connection's error message.
execStatement_ :: SQLiteHandle -> String -> IO (Maybe String)
execStatement_ h sqlStmt = withPrim h $ \db ->
  withCString (UTF8.encodeString sqlStmt) $ \c_sqlStmt -> do
    st <- sqlite3_exec db c_sqlStmt noCallback nullPtr nullPtr
    if st /= sQLITE_OK
      then Just `fmap` (sqlite3_errmsg db >>= peekCString)
      else return Nothing
-- | Render a comma-separated, parenthesised list, e.g.
-- @tupled ["a", "b"] == "(a, b)"@ and @tupled [] == "()"@.
tupled :: [String] -> String
tupled xs = "(" ++ intercalate ", " xs ++ ")"
infixl 1 `ebind`

-- | Bind for actions that return 'Either': run the continuation only on
-- 'Right', and propagate a 'Left' unchanged.
ebind :: Monad m => m (Either e a) -> (a -> m (Either e b)) -> m (Either e b)
m `ebind` f = m >>= either (return . Left) f

-- | Wrap a value as a successful result.
eReturn :: Monad m => a -> m (Either e a)
eReturn = return . Right

-- | @ensure m p t f@ runs @m@. If the result satisfies @p@, it continues
-- with @f@ applied to that result; otherwise it runs @t@.
ensure :: Monad m => m a -> (a -> Bool) -> m b -> (a -> m b) -> m b
ensure m p t f = do
  x <- m
  if p x then f x else t

-- | Like 'ensure', but the continuation ignores the checked result.
ensure_ :: Monad m => m a -> (a -> Bool) -> m b -> m b -> m b
ensure_ m p t f = ensure m p t (\_ -> f)
-- | Types that a result column can be decoded into.
class SQLiteResult a where
  -- | Read column @n@ (0-based) of the statement's current result row.
  get_sqlite_val :: SQLiteStmt -> CInt -> IO a

-- | Column read as text, decoded from UTF-8.
instance SQLiteResult String where
  get_sqlite_val stmt n = get_text_val stmt n

-- | Column contents are not read at all.
instance SQLiteResult () where
  get_sqlite_val = \_ _ -> return ()

-- | Column read as a dynamically typed 'Value'.
instance SQLiteResult Value where
  get_sqlite_val stmt n = get_val stmt n
-- | Read column @n@ of the current row as text and decode its bytes as
-- UTF-8. The text pointer is fetched before the byte count, which is the
-- order SQLite recommends.
get_text_val :: SQLiteStmt -> CInt -> IO String
get_text_val stmt n =
  sqlite3_column_text stmt n >>= \txt ->
  sqlite3_column_bytes stmt n >>= \len ->
  fmap UTF8.decodeString (peekCStringLen (txt, fromIntegral len))
-- | Read column @n@ of the current row as a 'Value', choosing the
-- constructor from the value's SQLite storage class.
get_val :: SQLiteStmt -> CInt -> IO Value
get_val stmt n = do
  val <- sqlite3_column_value stmt n
  typ <- sqlite3_value_type val
  case () of
    _ | typ == sQLITE_NULL    -> return Null
      | typ == sQLITE_INTEGER -> Int `fmap` sqlite3_value_int64 val
      | typ == sQLITE_FLOAT   -> Double `fmap` sqlite3_value_double val
      | typ == sQLITE_TEXT    -> do
          -- Decode as UTF-8, matching 'get_text_val' and 'bindValue'
          -- (which UTF-8-encodes 'Text'). Previously the raw bytes were
          -- returned undecoded, so non-ASCII text did not round-trip.
          raw <- peekCStringLen =<< sqlite3_value_cstringlen val
          return (Text (UTF8.decodeString raw))
      | typ == sQLITE_BLOB    -> do
          SQLiteBLOB ptr <- sqlite3_value_blob val
          bytes <- sqlite3_value_bytes val
          -- packCStringLen copies, so the ByteString outlives SQLite's buffer.
          blob <- packCStringLen (castPtr ptr, fromIntegral bytes)
          return (Blob blob)
      | otherwise             -> error "get_val: unknown type"
-- | User-supplied matcher behind SQL's @REGEXP@ function. It receives the
-- first SQL argument (the pattern) and then the second (the subject) as
-- raw text bytes, and returns whether they match. The ByteStrings alias
-- SQLite-owned memory and must not be retained past the call.
type RegexpHandler = ByteString -> ByteString -> IO Bool
-- | Register a two-argument SQL function named @REGEXP@ on the connection,
-- backed by the given handler (SQLite routes @X REGEXP Y@ to it). The
-- callback's function pointer is released when the handle is finalized.
-- The status returned by @sqlite3_create_function@ is ignored.
addRegexpSupport :: SQLiteHandle -> RegexpHandler -> IO ()
addRegexpSupport h f =
  withCString "REGEXP" $ \name -> do
    stepFn <- mkStepHandler (regexp_callback f)
    _ <- withPrim h $ \db ->
      sqlite3_create_function db name 2 sQLITE_UTF8 nullPtr
                              stepFn noCallback noCallback
    addSQLiteHandleFinalizer h (freeCallback stepFn)
-- | SQLite function callback behind 'addRegexpSupport'. It sets the SQL
-- result to 1 when the handler reports a match, and to 0 otherwise. It
-- also returns 0 when the argument count is not 2 or when either
-- argument's text pointer is NULL.
--
-- The ByteStrings passed to the handler are built without copying and
-- alias SQLite-owned memory, so the handler must not keep them after the
-- call returns.
regexp_callback :: RegexpHandler -> StepHandler
regexp_callback f ctx argc argv
  | argc /= 2 = noMatch
  | otherwise = do
      pat  <- sqlite3_value_cstringlen =<< peekElemOff argv 0
      subj <- sqlite3_value_cstringlen =<< peekElemOff argv 1
      if isNullCStringLen pat || isNullCStringLen subj
        then noMatch
        else do
          patBytes  <- unsafePackCStringLen pat
          subjBytes <- unsafePackCStringLen subj
          matched   <- f patBytes subjBytes
          if matched then match else noMatch
  where
    noMatch = sqlite3_result_int ctx 0
    match   = sqlite3_result_int ctx 1
-- | True when the string's pointer is NULL; the length is not consulted.
isNullCStringLen :: CStringLen -> Bool
isNullCStringLen = (== nullPtr) . fst
-- | Text of an SQLite value, paired with its length in bytes. The text
-- pointer is fetched before the byte count, which is the order SQLite
-- recommends. The pointer may be NULL; callers such as 'regexp_callback'
-- check it with 'isNullCStringLen'.
sqlite3_value_cstringlen :: SQLiteValue -> IO CStringLen
sqlite3_value_cstringlen v =
  sqlite3_value_text v >>= \txt ->
  fmap (\n -> (txt, fromIntegral n)) (sqlite3_value_bytes v)