{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.SQLite3 (
open,
close,
exec,
execPrint,
execWithCallback,
ExecCallback,
prepare,
prepareUtf8,
step,
stepNoCB,
reset,
finalize,
clearBindings,
bindParameterCount,
bindParameterName,
columnCount,
columnName,
bindSQLData,
bind,
bindNamed,
bindInt,
bindInt64,
bindDouble,
bindText,
bindBlob,
bindZeroBlob,
bindNull,
column,
columns,
typedColumns,
columnType,
columnInt64,
columnDouble,
columnText,
columnBlob,
lastInsertRowId,
changes,
createFunction,
createAggregate,
deleteFunction,
funcArgCount,
funcArgType,
funcArgInt64,
funcArgDouble,
funcArgText,
funcArgBlob,
funcResultSQLData,
funcResultInt64,
funcResultDouble,
funcResultText,
funcResultBlob,
funcResultZeroBlob,
funcResultNull,
getFuncContextDatabase,
createCollation,
deleteCollation,
interrupt,
interruptibly,
blobOpen,
blobClose,
blobReopen,
blobBytes,
blobRead,
blobReadBuf,
blobWrite,
backupInit,
backupFinish,
backupStep,
backupRemaining,
backupPagecount,
Database,
Statement,
SQLData(..),
SQLError(..),
ColumnType(..),
FuncContext,
FuncArgs,
Blob,
Backup,
StepResult(..),
BackupStepResult(..),
Error(..),
ParamIndex(..),
ColumnIndex(..),
ColumnCount,
ArgCount(..),
ArgIndex,
) where
import Database.SQLite3.Direct
( Database
, Statement
, ColumnType(..)
, StepResult(..)
, BackupStepResult(..)
, Error(..)
, ParamIndex(..)
, ColumnIndex(..)
, ColumnCount
, Utf8(..)
, FuncContext
, FuncArgs
, ArgCount(..)
, ArgIndex
, Blob
, Backup
, clearBindings
, bindParameterCount
, columnCount
, columnType
, columnBlob
, columnInt64
, columnDouble
, funcArgCount
, funcArgType
, funcArgInt64
, funcArgDouble
, funcArgBlob
, funcResultInt64
, funcResultDouble
, funcResultBlob
, funcResultZeroBlob
, funcResultNull
, getFuncContextDatabase
, lastInsertRowId
, changes
, interrupt
, blobBytes
, backupRemaining
, backupPagecount
)
import qualified Database.SQLite3.Direct as Direct
import Prelude hiding (error)
import qualified Data.Text as T
import qualified Data.Text.IO as T
import Control.Concurrent
import Control.Exception
import Control.Monad (when, zipWithM, zipWithM_)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Text.Encoding (encodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (UnicodeException(..), lenientDecode)
import Data.Typeable
import Foreign.Ptr (Ptr)
-- | A single SQLite value, as bound to a statement parameter
-- ('bindSQLData') or read back from a result column ('typedColumn').
data SQLData
    = SQLInteger !Int64
      -- ^ integer value; bound with 'bindInt64', read with 'columnInt64'
    | SQLFloat !Double
      -- ^ floating-point value; bound with 'bindDouble', read with 'columnDouble'
    | SQLText !Text
      -- ^ text value; bound with 'bindText', read with 'columnText'
    | SQLBlob !ByteString
      -- ^ raw bytes; bound with 'bindBlob', read with 'columnBlob'
    | SQLNull
      -- ^ SQL @NULL@; bound with 'bindNull', produced for a 'NullColumn'
    deriving (Eq, Show, Typeable)
-- | Exception thrown (via 'throwIO') by the functions in this module when an
-- underlying SQLite call reports an error.  The 'Show' instance combines all
-- three fields into one human-readable message.
data SQLError = SQLError
    { sqlError        :: !Error
      -- ^ the error code reported by SQLite
    , sqlErrorDetails :: Text
      -- ^ the error message, usually obtained from the connection via
      -- 'renderDetailSource' and decoded leniently
    , sqlErrorContext :: Text
      -- ^ a short description of the operation that failed,
      -- e.g. @\"step\"@ or @\"prepare ...\"@
    }
    deriving (Eq, Typeable)
-- | Renders as
-- @SQLite3 returned \<code\> while attempting to perform \<context\>: \<details\>@.
instance Show SQLError where
    show (SQLError code details context) = concatMap T.unpack pieces
      where
        pieces =
            [ "SQLite3 returned "
            , T.pack (show code)
            , " while attempting to perform "
            , context
            , ": "
            , details
            ]
-- | Lets 'SQLError' be thrown with 'throwIO' and caught with 'catch' / 'try'.
instance Exception SQLError
-- | Decode UTF-8 inside 'IO'.  The decoded (strict) 'Text' is forced with
-- 'evaluate', so an invalid byte raises the 'DecodeError' from 'fromUtf8''
-- at this point rather than wherever the text is later inspected.  The
-- 'String' argument becomes the error description.
fromUtf8 :: String -> Utf8 -> IO Text
fromUtf8 desc = evaluate . fromUtf8' desc
-- | Pure UTF-8 decoding.  On the first invalid byte the decoding handler
-- throws a 'DecodeError' carrying @desc@.  Prefer 'fromUtf8' when the
-- failure should surface at a predictable point in 'IO'.
fromUtf8' :: String -> Utf8 -> Text
fromUtf8' desc (Utf8 bytes) = decodeUtf8With onBadByte bytes
  where
    onBadByte _ badByte = throw (DecodeError desc badByte)
-- | Encode 'Text' as UTF-8, wrapped for the "Database.SQLite3.Direct" API.
toUtf8 :: Text -> Utf8
toUtf8 txt = Utf8 (encodeUtf8 txt)
-- | Where to obtain the details text for an 'SQLError'.  It is only resolved
-- (by 'renderDetailSource') when an error is actually being thrown.
data DetailSource
    = DetailDatabase Database
      -- ^ use the connection's current error message
    | DetailStatement Statement
      -- ^ use the error message of the connection that owns the statement
    | DetailMessage Utf8
      -- ^ use this fixed message
-- | Fetch the error-message text described by a 'DetailSource'.
renderDetailSource :: DetailSource -> IO Utf8
renderDetailSource (DetailDatabase db) = Direct.errmsg db
renderDetailSource (DetailStatement stmt) =
    Direct.getStatementDatabase stmt >>= Direct.errmsg
renderDetailSource (DetailMessage msg) = return msg
-- | Throw an 'SQLError' built from an error code, a context description and
-- the message obtained from the 'DetailSource'.
--
-- The message is decoded with 'lenientDecode', so malformed bytes in it are
-- replaced rather than raising a decoding exception that would hide the
-- original error.
throwSQLError :: DetailSource -> Text -> Error -> IO a
throwSQLError source ctx code =
    renderDetailSource source >>= \(Utf8 raw) ->
        throwIO (SQLError code (decodeUtf8With lenientDecode raw) ctx)
-- | Return the 'Right' value, or throw an 'SQLError' for a 'Left' error code.
-- Details come from the given 'DetailSource'; the 'Text' is the context.
checkError :: DetailSource -> Text -> Either Error a -> IO a
checkError ds fn result = case result of
    Left err -> throwSQLError ds fn err
    Right a  -> return a
-- | Like 'checkError', for results that carry their own error message
-- alongside the error code.
checkErrorMsg :: Text -> Either (Error, Utf8) a -> IO a
checkErrorMsg fn = either raise return
  where
    raise (err, msg) = throwSQLError (DetailMessage msg) fn err
-- | Append the 'show' rendering of a value to a 'Text' prefix.
-- Used to build error contexts such as @open \"path.db\"@.
appendShow :: Show a => Text -> a -> Text
appendShow txt a = T.append txt (T.pack (show a))
-- | Open a database connection for the given path.
-- Throws 'SQLError' (using the message returned with the failure) if the
-- connection cannot be opened.
open :: Text -> IO Database
open path = do
    result <- Direct.open (toUtf8 path)
    checkErrorMsg ctx result
  where
    ctx = "open " `appendShow` path
-- | Close a database connection, throwing 'SQLError' if SQLite reports an
-- error.
close :: Database -> IO ()
close db = do
    result <- Direct.close db
    checkError (DetailDatabase db) "close" result
-- | Run an action so that an exception delivered to the calling thread while
-- it waits (typically an asynchronous one such as 'killThread' or a
-- timeout) calls 'interrupt' on the database.
--
-- When 'rtsSupportsBoundThreads' is true, the action runs in a separate
-- thread created with 'forkIO', and the caller blocks on an 'MVar' for its
-- result.  If that wait is interrupted, the database is interrupted, the
-- worker is killed, and the worker's final result is collected before the
-- exception propagates.  An exception raised by the action itself is caught
-- in the worker and rethrown in the caller.
--
-- Otherwise, and on base < 4.3, the action is simply run directly.
interruptibly :: Database -> IO a -> IO a
#if MIN_VERSION_base(4,3,0)
interruptibly db io
    | rtsSupportsBoundThreads =
        mask $ \restore -> do
            mv <- newEmptyMVar
            -- The worker inherits the masked state and unmasks only around
            -- the user action, so its outcome always reaches the MVar.
            tid <- forkIO $ try' (restore io) >>= putMVar mv
            -- Cleanup runs uninterruptibly: the worker must be stopped and
            -- reaped before the caller's exception is rethrown.
            let interruptAndWait =
                    uninterruptibleMask_ $ do
                        interrupt db
                        killThread tid
                        _ <- takeMVar mv
                        return ()
            e <- takeMVar mv `onException` interruptAndWait
            either throwIO return e
    | otherwise = io
  where
    -- Pins the exception type caught in the worker to SomeException.
    try' :: IO a -> IO (Either SomeException a)
    try' = try
#else
interruptibly _db io = io
#endif
-- | Run SQL text with 'Direct.exec', throwing 'SQLError' on failure.
-- No result rows are returned; use 'execWithCallback' to receive them.
exec :: Database -> Text -> IO ()
exec db sql = Direct.exec db (toUtf8 sql) >>= checkErrorMsg context
  where
    context = "exec " `appendShow` sql
-- | Like 'execWithCallback', but print every row the callback receives to
-- standard output.  Values are separated by @|@, and a missing ('Nothing')
-- value prints as the empty string.  The whole run is wrapped in
-- 'interruptibly'.
execPrint :: Database -> Text -> IO ()
execPrint !db !sql =
    interruptibly db (execWithCallback db sql printRow)
  where
    printRow _count _names row =
        T.putStrLn (T.intercalate "|" (map (fromMaybe "") row))
-- | Run SQL text with 'Direct.execWithCallback', passing each invocation's
-- column count, column names and values (presumably once per result row)
-- to the given 'ExecCallback'.  Names and values are decoded from UTF-8
-- with 'fromUtf8'', so invalid UTF-8 throws a 'DecodeError' when the
-- offending value is forced.  Throws 'SQLError' if execution fails.
execWithCallback :: Database -> Text -> ExecCallback -> IO ()
execWithCallback db sql cb =
    Direct.execWithCallback db (toUtf8 sql) cb'
        >>= checkErrorMsg ("execWithCallback " `appendShow` sql)
  where
    -- Adapts the user callback to the Utf8-based callback of the Direct API.
    -- @names@ is bound before the row argument so that, if @cb'@ is
    -- partially applied to the count and names, the decoded names can be
    -- shared across rows.  NOTE(review): the NOINLINE presumably stops GHC
    -- from inlining @names@ into the per-row lambda — confirm.
    cb' count namesUtf8 =
        let names = map fromUtf8'' namesUtf8
            {-# NOINLINE names #-}
         in cb count names . map (fmap fromUtf8'')
    fromUtf8'' = fromUtf8' "Database.SQLite3.execWithCallback: Invalid UTF-8"
-- | Callback for 'execWithCallback'.  It receives the column count, the
-- column names, and the values of one row, with 'Nothing' for a missing
-- value (presumably SQL @NULL@).
type ExecCallback
     = ColumnCount
    -> [Text]
    -> [Maybe Text]
    -> IO ()
-- | Compile SQL text into a prepared 'Statement'.  This encodes the text
-- and defers to 'prepareUtf8', which documents the failure modes.
prepare :: Database -> Text -> IO Statement
prepare db = prepareUtf8 db . toUtf8
-- | Compile UTF-8 encoded SQL text into a prepared 'Statement'.
--
-- * Throws 'SQLError', with the connection's error message, if
--   'Direct.prepare' reports an error.
-- * If 'Direct.prepare' succeeds but yields no statement ('Nothing'; the
--   error message calls this an empty query string), fails in 'IO' via
--   'fail', i.e. with an 'IOError'.
prepareUtf8 :: Database -> Utf8 -> IO Statement
prepareUtf8 db sql = do
    m <- Direct.prepare db sql
        >>= checkError (DetailDatabase db) ("prepare " `appendShow` sql)
    case m of
        -- Report under this module's real name, matching the other messages
        -- here (e.g. "Database.SQLite3.columnText: ...").
        Nothing   -> fail "Database.SQLite3.prepare: empty query string"
        Just stmt -> return stmt
-- | Advance the statement one step with 'Direct.step'.  On failure throws
-- 'SQLError' with the owning connection's error message.
step :: Statement -> IO StepResult
step statement = do
    outcome <- Direct.step statement
    checkError (DetailStatement statement) "step" outcome
-- | Like 'step', but uses 'Direct.stepNoCB'.  NOTE(review): presumably a
-- variant that must not invoke Haskell callbacks (user functions,
-- collations); confirm against "Database.SQLite3.Direct".
stepNoCB :: Statement -> IO StepResult
stepNoCB statement = do
    outcome <- Direct.stepNoCB statement
    checkError (DetailStatement statement) "stepNoCB" outcome
-- | Reset a statement via 'Direct.reset' (see sqlite3_reset).
--
-- The result of 'Direct.reset' is discarded, so this never throws
-- 'SQLError'.  NOTE(review): presumably because sqlite3_reset re-reports the
-- error of the previous step, which 'step' has already raised.
reset :: Statement -> IO ()
reset statement = Direct.reset statement >> return ()
-- | Finalize a prepared statement via 'Direct.finalize'.  The result is
-- discarded, so this never throws 'SQLError'.
finalize :: Statement -> IO ()
finalize statement = Direct.finalize statement >> return ()
-- | Look up the name of the parameter at the given index, if it has one.
-- Invalid UTF-8 in the name raises a 'DecodeError'.
bindParameterName :: Statement -> ParamIndex -> IO (Maybe Text)
bindParameterName stmt idx =
    Direct.bindParameterName stmt idx
        >>= maybe (return Nothing) (fmap Just . fromUtf8 desc)
  where
    desc = "Database.SQLite3.bindParameterName: Invalid UTF-8"
-- | Return the name of the result column at the given index.
--
-- Returns 'Nothing' when the index lies outside @[0, columnCount)@.  If no
-- name comes back for an in-range index, this is treated as an allocation
-- failure and an 'SQLError' with 'ErrorNoMemory' is thrown.  Invalid UTF-8
-- in the name raises a 'DecodeError'.
columnName :: Statement -> ColumnIndex -> IO (Maybe Text)
columnName stmt idx =
    Direct.columnName stmt idx >>= maybe missing (fmap Just . fromUtf8 desc)
  where
    desc = "Database.SQLite3.columnName: Invalid UTF-8"
    missing = do
        count <- Direct.columnCount stmt
        let inRange = idx >= 0 && idx < count
        if inRange then throwIO outOfMemory else return Nothing
    outOfMemory = SQLError
        { sqlError        = ErrorNoMemory
        , sqlErrorDetails = "out of memory (sqlite3_column_name returned NULL)"
        , sqlErrorContext = "column name"
        }
-- | Bind a blob to the parameter at the given index.
-- Throws 'SQLError' on failure.
bindBlob :: Statement -> ParamIndex -> ByteString -> IO ()
bindBlob statement parameterIndex byteString =
    checkError (DetailStatement statement) "bind blob"
        =<< Direct.bindBlob statement parameterIndex byteString
-- | Bind a zero-filled blob of the given length (see sqlite3_bind_zeroblob)
-- to the parameter at the given index.  Throws 'SQLError' on failure.
bindZeroBlob :: Statement -> ParamIndex -> Int -> IO ()
bindZeroBlob statement parameterIndex len =
    checkError (DetailStatement statement) "bind zeroblob"
        =<< Direct.bindZeroBlob statement parameterIndex len
-- | Bind a 'Double' to the parameter at the given index.
-- Throws 'SQLError' on failure.
bindDouble :: Statement -> ParamIndex -> Double -> IO ()
bindDouble statement parameterIndex datum =
    checkError (DetailStatement statement) "bind double"
        =<< Direct.bindDouble statement parameterIndex datum
-- | Bind an 'Int' to the parameter at the given index.  The value is
-- widened with 'fromIntegral' and bound through 'Direct.bindInt64'.
-- Throws 'SQLError' on failure.
bindInt :: Statement -> ParamIndex -> Int -> IO ()
bindInt statement parameterIndex datum = do
    let wide = fromIntegral datum
    result <- Direct.bindInt64 statement parameterIndex wide
    checkError (DetailStatement statement) "bind int" result
-- | Bind an 'Int64' to the parameter at the given index.
-- Throws 'SQLError' on failure.
bindInt64 :: Statement -> ParamIndex -> Int64 -> IO ()
bindInt64 statement parameterIndex datum =
    checkError (DetailStatement statement) "bind int64"
        =<< Direct.bindInt64 statement parameterIndex datum
-- | Bind SQL @NULL@ to the parameter at the given index.
-- Throws 'SQLError' on failure.
bindNull :: Statement -> ParamIndex -> IO ()
bindNull statement parameterIndex =
    checkError (DetailStatement statement) "bind null"
        =<< Direct.bindNull statement parameterIndex
-- | Bind 'Text' (encoded as UTF-8) to the parameter at the given index.
-- Throws 'SQLError' on failure.
bindText :: Statement -> ParamIndex -> Text -> IO ()
bindText statement parameterIndex text =
    checkError (DetailStatement statement) "bind text"
        =<< Direct.bindText statement parameterIndex (toUtf8 text)
-- | Bind an 'SQLData' value, dispatching to the binder for its constructor.
bindSQLData :: Statement -> ParamIndex -> SQLData -> IO ()
bindSQLData statement idx (SQLInteger v) = bindInt64  statement idx v
bindSQLData statement idx (SQLFloat v)   = bindDouble statement idx v
bindSQLData statement idx (SQLText v)    = bindText   statement idx v
bindSQLData statement idx (SQLBlob v)    = bindBlob   statement idx v
bindSQLData statement idx SQLNull        = bindNull   statement idx
-- | Bind values positionally to parameters 1, 2, ….
--
-- Fails in 'IO' (via 'fail') if the number of values differs from the
-- statement's parameter count.  Individual bind errors throw 'SQLError'.
bind :: Statement -> [SQLData] -> IO ()
bind statement sqlData = do
    ParamIndex nParams <- bindParameterCount statement
    let given = length sqlData
    when (nParams /= given) $
        fail ("mismatched parameter count for bind. Prepared statement "++
              "needs "++ show nParams ++ ", " ++ show given ++" given")
    mapM_ (uncurry (bindSQLData statement)) (zip [1..] sqlData)
-- | Bind values to named parameters.  Each name is resolved with
-- 'Direct.bindParameterIndex'; in SQLite the name includes its prefix,
-- e.g. @:id@.
--
-- Fails in 'IO' (via 'fail') if the number of pairs differs from the
-- statement's parameter count, or if a name is unknown.
--
-- NOTE(review): the count check alone does not catch a list that repeats a
-- name while omitting another parameter.
bindNamed :: Statement -> [(T.Text, SQLData)] -> IO ()
bindNamed statement params = do
    ParamIndex nParams <- bindParameterCount statement
    let given = length params
    when (nParams /= given) $
        fail ("mismatched parameter count for bind. Prepared statement "++
              "needs "++ show nParams ++ ", " ++ show given ++" given")
    mapM_ (uncurry bindOne) params
  where
    bindOne name val =
        Direct.bindParameterIndex statement (toUtf8 name)
            >>= maybe (fail ("unknown named parameter "++show name))
                      (\i -> bindSQLData statement i val)
-- | Read a result column as 'Text'.  Invalid UTF-8 raises a 'DecodeError'.
columnText :: Statement -> ColumnIndex -> IO Text
columnText statement columnIndex = do
    raw <- Direct.columnText statement columnIndex
    fromUtf8 "Database.SQLite3.columnText: Invalid UTF-8" raw
-- | Read a result column as 'SQLData', using the type SQLite reports for it.
column :: Statement -> ColumnIndex -> IO SQLData
column statement idx =
    columnType statement idx >>= \t -> typedColumn t statement idx
-- | Read every result column of the current row, in index order.
columns :: Statement -> IO [SQLData]
columns statement =
    columnCount statement >>= \count -> mapM (column statement) [0 .. count - 1]
-- | Read a result column using the given 'ColumnType' to pick the accessor.
typedColumn :: ColumnType -> Statement -> ColumnIndex -> IO SQLData
typedColumn IntegerColumn statement idx = SQLInteger <$> columnInt64 statement idx
typedColumn FloatColumn   statement idx = SQLFloat   <$> columnDouble statement idx
typedColumn TextColumn    statement idx = SQLText    <$> columnText statement idx
typedColumn BlobColumn    statement idx = SQLBlob    <$> columnBlob statement idx
typedColumn NullColumn    _         _   = return SQLNull
-- | Read columns 0, 1, … pairwise with the given types.  A 'Nothing' entry
-- falls back to the column's own type (as 'column' does).  Exactly one
-- column is read per list element.
typedColumns :: Statement -> [Maybe ColumnType] -> IO [SQLData]
typedColumns statement types = zipWithM readOne [0..] types
  where
    readOne idx Nothing  = column statement idx
    readOne idx (Just t) = typedColumn t statement idx
-- | Register a scalar SQL function implemented in Haskell.
-- Throws 'SQLError' if registration fails.
createFunction
    :: Database
    -> Text                                  -- ^ function name
    -> Maybe ArgCount                        -- ^ argument count (presumably 'Nothing' = any)
    -> Bool                                  -- ^ presumably: the function is deterministic
    -> (FuncContext -> FuncArgs -> IO ())    -- ^ implementation
    -> IO ()
createFunction db name nArgs isDet fun = do
    result <- Direct.createFunction db (toUtf8 name) nArgs isDet fun
    checkError (DetailDatabase db) ("createFunction " `appendShow` name) result
-- | Register an aggregate SQL function implemented in Haskell.
-- Throws 'SQLError' if registration fails.
createAggregate
    :: Database
    -> Text                                      -- ^ function name
    -> Maybe ArgCount                            -- ^ argument count (presumably 'Nothing' = any)
    -> a                                         -- ^ initial state
    -> (FuncContext -> FuncArgs -> a -> IO a)    -- ^ step: takes the state, returns the new one
    -> (FuncContext -> a -> IO ())               -- ^ finalizer: receives the accumulated state
    -> IO ()
createAggregate db name nArgs initSt xStep xFinal =
    checkError (DetailDatabase db) context
        =<< Direct.createAggregate db (toUtf8 name) nArgs initSt xStep xFinal
  where
    context = "createAggregate " `appendShow` name
-- | Remove a previously registered SQL function with the given name and
-- argument count.  Throws 'SQLError' on failure.
deleteFunction :: Database -> Text -> Maybe ArgCount -> IO ()
deleteFunction db name nArgs = do
    result <- Direct.deleteFunction db (toUtf8 name) nArgs
    checkError (DetailDatabase db) ("deleteFunction " `appendShow` name) result
-- | Read a user-function argument as 'Text'.  Invalid UTF-8 raises a
-- 'DecodeError'.
funcArgText :: FuncArgs -> ArgIndex -> IO Text
funcArgText args argIndex =
    fromUtf8 "Database.SQLite3.funcArgText: Invalid UTF-8"
        =<< Direct.funcArgText args argIndex
-- | Set a user function's result from an 'SQLData' value, dispatching to
-- the setter for its constructor.
funcResultSQLData :: FuncContext -> SQLData -> IO ()
funcResultSQLData ctx (SQLInteger v) = funcResultInt64  ctx v
funcResultSQLData ctx (SQLFloat v)   = funcResultDouble ctx v
funcResultSQLData ctx (SQLText v)    = funcResultText   ctx v
funcResultSQLData ctx (SQLBlob v)    = funcResultBlob   ctx v
funcResultSQLData ctx SQLNull        = funcResultNull   ctx
-- | Set a user function's result to 'Text', encoded as UTF-8.
funcResultText :: FuncContext -> Text -> IO ()
funcResultText ctx = Direct.funcResultText ctx . toUtf8
-- | Register a collation backed by a Haskell comparison on 'Text'.
--
-- Strings handed to the comparison are decoded with 'lenientDecode', so
-- invalid UTF-8 is replaced rather than raising an exception.  Throws
-- 'SQLError' if registration fails.
createCollation
    :: Database
    -> Text                            -- ^ collation name
    -> (Text -> Text -> Ordering)      -- ^ comparison function
    -> IO ()
createCollation db name cmp = do
    result <- Direct.createCollation db (toUtf8 name) compareUtf8
    checkError (DetailDatabase db) ("createCollation " `appendShow` name) result
  where
    lenient = decodeUtf8With lenientDecode
    compareUtf8 (Utf8 a) (Utf8 b) = cmp (lenient a) (lenient b)
-- | Remove a previously registered collation.  Throws 'SQLError' on failure.
deleteCollation :: Database -> Text -> IO ()
deleteCollation db name = do
    result <- Direct.deleteCollation db (toUtf8 name)
    checkError (DetailDatabase db) ("deleteCollation " `appendShow` name) result
-- | Open a handle for incremental I/O on a single blob value (see
-- sqlite3_blob_open).  Throws 'SQLError', with the connection's error
-- message, on failure.
blobOpen
    :: Database
    -> Text     -- ^ symbolic database name (presumably e.g. @\"main\"@)
    -> Text     -- ^ table name
    -> Text     -- ^ column name
    -> Int64    -- ^ rowid of the row holding the blob
    -> Bool     -- ^ presumably: open for writing when 'True'
    -> IO Blob
blobOpen db zDb zTable zColumn rowid rw = do
    result <- Direct.blobOpen db (toUtf8 zDb) (toUtf8 zTable) (toUtf8 zColumn) rowid rw
    checkError (DetailDatabase db) "blobOpen" result
-- | Close a blob handle.  Errors are reported with the error message of the
-- connection stored in the handle.
blobClose :: Blob -> IO ()
blobClose blob@(Direct.Blob db _) =
    checkError (DetailDatabase db) "blobClose" =<< Direct.blobClose blob
-- | Point an open blob handle at the same column of another row.
-- Throws 'SQLError' on failure.
blobReopen
    :: Blob
    -> Int64    -- ^ rowid of the new row
    -> IO ()
blobReopen blob@(Direct.Blob db _) rowid =
    checkError (DetailDatabase db) "blobReopen" =<< Direct.blobReopen blob rowid
-- | Read bytes from an open blob into a fresh 'ByteString'.
-- Throws 'SQLError' on failure.
blobRead
    :: Blob
    -> Int      -- ^ number of bytes to read
    -> Int      -- ^ offset within the blob
    -> IO ByteString
blobRead blob@(Direct.Blob db _) len offset =
    checkError (DetailDatabase db) "blobRead" =<< Direct.blobRead blob len offset
-- | Like 'blobRead', but reads @len@ bytes at @offset@ into a caller-supplied
-- buffer.  NOTE(review): presumably the buffer must hold at least @len@
-- bytes — confirm.  Throws 'SQLError' on failure.
blobReadBuf :: Blob -> Ptr a -> Int -> Int -> IO ()
blobReadBuf blob@(Direct.Blob db _) buf len offset =
    checkError (DetailDatabase db) "blobReadBuf"
        =<< Direct.blobReadBuf blob buf len offset
-- | Write bytes into an open blob.  Throws 'SQLError' on failure.
blobWrite
    :: Blob
    -> ByteString   -- ^ bytes to write
    -> Int          -- ^ offset within the blob
    -> IO ()
blobWrite blob@(Direct.Blob db _) bs offset =
    checkError (DetailDatabase db) "blobWrite" =<< Direct.blobWrite blob bs offset
-- | Start an online backup from a source database into a destination
-- database.  Failures are reported with the destination connection's
-- error message.
backupInit
    :: Database     -- ^ destination connection
    -> Text         -- ^ destination database name
    -> Database     -- ^ source connection
    -> Text         -- ^ source database name
    -> IO Backup
backupInit dstDb dstName srcDb srcName = do
    result <- Direct.backupInit dstDb (toUtf8 dstName) srcDb (toUtf8 srcName)
    checkError (DetailDatabase dstDb) "backupInit" result
-- | Finish a backup and release its resources.  Errors are reported with
-- the destination connection's error message.
backupFinish :: Backup -> IO ()
backupFinish backup@(Direct.Backup dstDb _) =
    checkError (DetailDatabase dstDb) "backupFinish" =<< Direct.backupFinish backup
-- | Copy up to the given number of pages in a running backup.
--
-- On failure the thrown 'SQLError' carries the fixed details text
-- @\"failed\"@, not a connection error message.  NOTE(review): presumably
-- because sqlite3_backup_step does not record an error message; confirm.
backupStep :: Backup -> Int -> IO BackupStepResult
backupStep backup pages = do
    result <- Direct.backupStep backup pages
    checkError (DetailMessage "failed") "backupStep" result