module Database.SQLite3 (
open,
close,
exec,
execPrint,
execWithCallback,
ExecCallback,
prepare,
prepareUtf8,
step,
reset,
finalize,
clearBindings,
bindParameterCount,
bindParameterName,
columnCount,
columnName,
bindSQLData,
bind,
bindNamed,
bindInt,
bindInt64,
bindDouble,
bindText,
bindBlob,
bindNull,
column,
columns,
typedColumns,
columnType,
columnInt64,
columnDouble,
columnText,
columnBlob,
lastInsertRowId,
changes,
interrupt,
interruptibly,
Database,
Statement,
SQLData(..),
SQLError(..),
ColumnType(..),
StepResult(..),
Error(..),
ParamIndex(..),
ColumnIndex(..),
ColumnCount,
) where
import Database.SQLite3.Direct
( Database
, Statement
, ColumnType(..)
, StepResult(..)
, Error(..)
, ParamIndex(..)
, ColumnIndex(..)
, ColumnCount
, Utf8(..)
, clearBindings
, bindParameterCount
, columnCount
, columnType
, columnBlob
, columnInt64
, columnDouble
, lastInsertRowId
, changes
, interrupt
)
import qualified Database.SQLite3.Direct as Direct
import Prelude hiding (error)
import qualified Data.Text as T
import qualified Data.Text.IO as T
import Control.Applicative ((<$>))
import Control.Concurrent
import Control.Exception
import Control.Monad (when, zipWithM, zipWithM_)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Text.Encoding (encodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (UnicodeException(..), lenientDecode)
import Data.Typeable
-- | A dynamically typed SQLite value.  Values of this type are bound to
-- statement parameters by 'bindSQLData' / 'bind' / 'bindNamed' and are
-- produced when reading result columns with 'column' / 'columns' /
-- 'typedColumns'.
data SQLData
    = SQLInteger !Int64       -- ^ bound via 'bindInt64'; read from an 'IntegerColumn'
    | SQLFloat   !Double      -- ^ bound via 'bindDouble'; read from a 'FloatColumn'
    | SQLText    !Text        -- ^ bound via 'bindText'; read from a 'TextColumn'
    | SQLBlob    !ByteString  -- ^ bound via 'bindBlob'; read from a 'BlobColumn'
    | SQLNull                 -- ^ bound via 'bindNull'; read from a 'NullColumn'
    deriving (Eq, Show, Typeable)
-- | Exception thrown by this module when an SQLite call reports a failure.
-- Most instances are built by the internal @throwSQLError@ helper.
data SQLError = SQLError
    { sqlError        :: !Error
      -- ^ The error code reported for the failed call.
    , sqlErrorDetails :: Text
      -- ^ Human-readable error message (for example the database error
      -- message, or a message returned together with the error code).
    , sqlErrorContext :: Text
      -- ^ Short description of the operation that failed, such as
      -- @step@, @bind text@, or @open@ followed by the shown path.
    }
    deriving Typeable
-- | Renders as
-- @SQLite3 returned \<code\> while attempting to perform \<context\>: \<details\>@.
instance Show SQLError where
    show err =
        T.unpack $ T.concat
            [ "SQLite3 returned ", T.pack (show (sqlError err))
            , " while attempting to perform ", sqlErrorContext err
            , ": ", sqlErrorDetails err
            ]
instance Exception SQLError
-- | Decode a 'Utf8' value to 'Text' inside 'IO'.  The decoded value is
-- passed through 'evaluate', so a 'DecodeError' raised by @fromUtf8'@
-- (tagged with @desc@) is thrown from this 'IO' action rather than later,
-- from pure code.
fromUtf8 :: String -> Utf8 -> IO Text
fromUtf8 desc = evaluate . fromUtf8' desc
-- | Pure UTF-8 decoding of a 'Utf8' value.  On invalid input it throws
-- (via 'throw') a 'DecodeError' carrying @desc@ and the offending byte.
fromUtf8' :: String -> Utf8 -> Text
fromUtf8' desc (Utf8 bytes) = decodeUtf8With onInvalid bytes
  where
    onInvalid _ badByte = throw (DecodeError desc badByte)
-- | Encode 'Text' as UTF-8 and wrap it in the 'Utf8' newtype used by
-- "Database.SQLite3.Direct".
toUtf8 :: Text -> Utf8
toUtf8 txt = Utf8 (encodeUtf8 txt)
-- | Where to obtain the human-readable details for an 'SQLError'
-- (consumed by 'renderDetailSource').
data DetailSource
    = DetailDatabase Database    -- ^ use the error message of this database
    | DetailStatement Statement  -- ^ use the error message of the database owning this statement
    | DetailMessage Utf8         -- ^ use this message as-is
-- | Produce the error-detail text described by a 'DetailSource'.  For the
-- database and statement cases this asks the (owning) database for its
-- current error message.
renderDetailSource :: DetailSource -> IO Utf8
renderDetailSource (DetailDatabase db) =
    Direct.errmsg db
renderDetailSource (DetailStatement stmt) =
    Direct.getStatementDatabase stmt >>= Direct.errmsg
renderDetailSource (DetailMessage msg) =
    return msg
-- | Throw an 'SQLError' for the given error code.  @context@ names the
-- failing operation; the details are rendered from @detailSource@ and
-- decoded leniently, so malformed UTF-8 in the message never causes a
-- second error.
throwSQLError :: DetailSource -> Text -> Error -> IO a
throwSQLError detailSource context code = do
    Utf8 raw <- renderDetailSource detailSource
    let details = decodeUtf8With lenientDecode raw
    throwIO SQLError
        { sqlError        = code
        , sqlErrorDetails = details
        , sqlErrorContext = context
        }
-- | Return the 'Right' value, or throw an 'SQLError' (see 'throwSQLError')
-- for a 'Left' error code.  @fn@ describes the operation being performed.
checkError :: DetailSource -> Text -> Either Error a -> IO a
checkError ds fn result =
    case result of
        Left err -> throwSQLError ds fn err
        Right a  -> return a
-- | Like 'checkError', but for results whose error branch already carries
-- the message text; that message becomes the 'SQLError' details.
checkErrorMsg :: Text -> Either (Error, Utf8) a -> IO a
checkErrorMsg fn =
    either (\(err, msg) -> throwSQLError (DetailMessage msg) fn err) return
-- | Append the 'show' rendering of a value to some text, e.g. for building
-- error contexts like @open "path"@.
appendShow :: Show a => Text -> a -> Text
appendShow txt a = T.append txt (T.pack (show a))
-- | Open a database connection for the given path.
--
-- Throws an 'SQLError' (context @open@ plus the shown path) on failure.
open :: Text -> IO Database
open path = do
    result <- Direct.open (toUtf8 path)
    checkErrorMsg ("open " `appendShow` path) result
-- | Close a database connection.
--
-- Throws an 'SQLError' (context @close@) on failure; the details come from
-- the database's error message.
close :: Database -> IO ()
close db = do
    result <- Direct.close db
    checkError (DetailDatabase db) "close" result
-- | Run an IO action that uses @db@ so that an asynchronous exception
-- delivered to the calling thread also interrupts the running SQLite work.
--
-- With a threaded RTS (@rtsSupportsBoundThreads@) the action runs in a
-- worker thread started with @forkIO@, and the caller waits on an MVar for
-- its outcome.  If that wait is interrupted, the caller calls
-- @interrupt db@, kills the worker, waits for the worker to finish, and
-- then lets the exception propagate.  An exception thrown by the action
-- itself is captured in the worker and rethrown in the caller.
--
-- Without a threaded RTS, or when built against base older than 4.3, the
-- action is simply run directly.
interruptibly :: Database -> IO a -> IO a
#if MIN_VERSION_base(4,3,0)
interruptibly db io
    | rtsSupportsBoundThreads =
        -- Async exceptions are masked so none can arrive between forking
        -- the worker and installing the onException handler below.
        mask $ \restore -> do
            mv <- newEmptyMVar
            -- The worker runs io with the masking state from before mask
            -- and always fills mv: either with the result or with the
            -- exception it died with (including the kill sent below).
            tid <- forkIO $ try' (restore io) >>= putMVar mv
            let interruptAndWait =
                    -- Cleanup itself must not be interrupted.
                    uninterruptibleMask_ $ do
                        interrupt db
                        killThread tid
                        -- Block until the worker has put its outcome, so it
                        -- has stopped running before the exception escapes.
                        _ <- takeMVar mv
                        return ()
            -- NOTE: relies on a blocked takeMVar being interruptible even
            -- under mask (see Control.Exception docs), so an async exception
            -- to the caller arrives here and triggers interruptAndWait.
            e <- takeMVar mv `onException` interruptAndWait
            -- Re-raise in the caller any exception the action threw.
            either throwIO return e
    | otherwise = io
  where
    -- try specialised to SomeException, so every exception is captured.
    try' :: IO a -> IO (Either SomeException a)
    try' = try
#else
interruptibly _db io = io
#endif
-- | Execute zero or more SQL statements, discarding any result rows.
--
-- Throws an 'SQLError' (context @exec@ plus the shown SQL) on failure.
exec :: Database -> Text -> IO ()
exec db sql = do
    result <- Direct.exec db (toUtf8 sql)
    checkErrorMsg (appendShow "exec " sql) result
-- | Execute SQL and print each callback invocation's values to stdout as a
-- single @|@-separated line, rendering 'Nothing' as the empty string.
-- The whole run is wrapped in 'interruptibly'.
execPrint :: Database -> Text -> IO ()
execPrint !db !sql =
    interruptibly db (execWithCallback db sql printRow)
  where
    printRow _count _colnames = T.putStrLn . renderRow
    renderRow = T.intercalate "|" . map (fromMaybe "")
-- | Execute SQL, passing the callback the column count, the column names,
-- and the values, all decoded from UTF-8 to 'Text'.
--
-- Throws an 'SQLError' (context @execWithCallback@ plus the shown SQL) on
-- failure.  Invalid UTF-8 in names or values raises a 'DecodeError' when
-- the decoded text is forced.
execWithCallback :: Database -> Text -> ExecCallback -> IO ()
execWithCallback db sql cb = do
    result <- Direct.execWithCallback db (toUtf8 sql) wrapped
    checkErrorMsg (appendShow "execWithCallback " sql) result
  where
    decode = fromUtf8' "Database.SQLite3.execWithCallback: Invalid UTF-8"
    wrapped count rawNames rawValues =
        cb count (map decode rawNames) (map (fmap decode) rawValues)
-- | Callback used by 'execWithCallback'.  Presumably invoked once per result
-- row ('execPrint' prints one line per invocation); TODO confirm against
-- "Database.SQLite3.Direct".
type ExecCallback
    = ColumnCount    -- ^ number of columns
    -> [Text]        -- ^ column names
    -> [Maybe Text]  -- ^ column values; 'Nothing' presumably means SQL NULL
    -> IO ()
-- | Compile SQL text into a prepared 'Statement'.  This is 'prepareUtf8'
-- applied to the UTF-8 encoding of the text.
prepare :: Database -> Text -> IO Statement
prepare db = prepareUtf8 db . toUtf8
-- | Compile UTF-8 SQL into a prepared 'Statement'.
--
-- Throws an 'SQLError' (context @prepare@ plus the shown SQL, details from
-- the database error message) if compilation fails.  If
-- "Database.SQLite3.Direct" yields no statement (an empty query string),
-- fails in 'IO' with a user error.
prepareUtf8 :: Database -> Utf8 -> IO Statement
prepareUtf8 db sql = do
    result <- Direct.prepare db sql
    m <- checkError (DetailDatabase db) ("prepare " `appendShow` sql) result
    case m of
        -- The message names this module (the original said
        -- "Direct.SQLite3", a module that does not exist).
        Nothing   -> fail "Database.SQLite3.prepare: empty query string"
        Just stmt -> return stmt
-- | Advance a prepared statement by one step.
--
-- Throws an 'SQLError' (context @step@) on failure, with details from the
-- statement's database.
step :: Statement -> IO StepResult
step statement = do
    result <- Direct.step statement
    checkError (DetailStatement statement) "step" result
-- | Reset a prepared statement via 'Direct.reset'.
--
-- The result of 'Direct.reset' is discarded, so this never throws an
-- 'SQLError'.  NOTE(review): presumably deliberate; confirm before adding
-- error checking.
reset :: Statement -> IO ()
reset statement = Direct.reset statement >> return ()
-- | Finalize (release) a prepared statement via 'Direct.finalize'.
--
-- The result of 'Direct.finalize' is discarded, so this never throws an
-- 'SQLError'.
finalize :: Statement -> IO ()
finalize statement = Direct.finalize statement >>= \_ -> return ()
-- | Name of the parameter at the given index, or 'Nothing' when
-- "Database.SQLite3.Direct" reports no name for it.  Throws a
-- 'DecodeError' if the name is not valid UTF-8.
bindParameterName :: Statement -> ParamIndex -> IO (Maybe Text)
bindParameterName stmt idx =
    Direct.bindParameterName stmt idx
        >>= maybe (return Nothing) (fmap Just . fromUtf8 desc)
  where
    desc = "Database.SQLite3.bindParameterName: Invalid UTF-8"
-- | Name of the result column at the given index.
--
-- Returns 'Nothing' when no name is available and the index is out of
-- range.  If no name comes back for an in-range index, that is treated as
-- an allocation failure and an 'SQLError' with 'ErrorNoMemory' is thrown.
-- Throws a 'DecodeError' if the name is not valid UTF-8.
columnName :: Statement -> ColumnIndex -> IO (Maybe Text)
columnName stmt idx = Direct.columnName stmt idx >>= interpret
  where
    interpret (Just name) = Just <$> fromUtf8 desc name
    interpret Nothing = do
        count <- Direct.columnCount stmt
        if idx >= 0 && idx < count
            then throwIO outOfMemory
            else return Nothing

    desc = "Database.SQLite3.columnName: Invalid UTF-8"

    outOfMemory = SQLError
        { sqlError        = ErrorNoMemory
        , sqlErrorDetails = "out of memory (sqlite3_column_name returned NULL)"
        , sqlErrorContext = "column name"
        }
-- | Bind a blob to a parameter.  Throws an 'SQLError' (context
-- @bind blob@) on failure.
bindBlob :: Statement -> ParamIndex -> ByteString -> IO ()
bindBlob statement parameterIndex byteString =
    checkError (DetailStatement statement) "bind blob"
        =<< Direct.bindBlob statement parameterIndex byteString
-- | Bind a floating-point value to a parameter.  Throws an 'SQLError'
-- (context @bind double@) on failure.
bindDouble :: Statement -> ParamIndex -> Double -> IO ()
bindDouble statement parameterIndex datum =
    checkError (DetailStatement statement) "bind double"
        =<< Direct.bindDouble statement parameterIndex datum
-- | Bind an 'Int' to a parameter, converting it to 'Int64' with
-- 'fromIntegral'.  Throws an 'SQLError' (context @bind int@) on failure.
bindInt :: Statement -> ParamIndex -> Int -> IO ()
bindInt statement parameterIndex datum = do
    let wide = fromIntegral datum
    result <- Direct.bindInt64 statement parameterIndex wide
    checkError (DetailStatement statement) "bind int" result
-- | Bind a 64-bit integer to a parameter.  Throws an 'SQLError'
-- (context @bind int64@) on failure.
bindInt64 :: Statement -> ParamIndex -> Int64 -> IO ()
bindInt64 statement parameterIndex datum =
    checkError (DetailStatement statement) "bind int64"
        =<< Direct.bindInt64 statement parameterIndex datum
-- | Bind SQL NULL to a parameter.  Throws an 'SQLError' (context
-- @bind null@) on failure.
bindNull :: Statement -> ParamIndex -> IO ()
bindNull statement parameterIndex =
    checkError (DetailStatement statement) "bind null"
        =<< Direct.bindNull statement parameterIndex
-- | Bind text (encoded as UTF-8) to a parameter.  Throws an 'SQLError'
-- (context @bind text@) on failure.
bindText :: Statement -> ParamIndex -> Text -> IO ()
bindText statement parameterIndex text =
    checkError (DetailStatement statement) "bind text"
        =<< Direct.bindText statement parameterIndex (toUtf8 text)
-- | Bind an 'SQLData' value to a parameter, dispatching on its constructor
-- to the matching typed bind function.
bindSQLData :: Statement -> ParamIndex -> SQLData -> IO ()
bindSQLData statement idx (SQLInteger v) = bindInt64 statement idx v
bindSQLData statement idx (SQLFloat v)   = bindDouble statement idx v
bindSQLData statement idx (SQLText v)    = bindText statement idx v
bindSQLData statement idx (SQLBlob v)    = bindBlob statement idx v
bindSQLData statement idx SQLNull        = bindNull statement idx
-- | Bind a list of values positionally; the first value goes to parameter
-- index 1.  Fails in 'IO' if the list length differs from the statement's
-- parameter count.
bind :: Statement -> [SQLData] -> IO ()
bind statement sqlData = do
    ParamIndex expected <- bindParameterCount statement
    let given = length sqlData
    when (given /= expected) $
        fail ("mismatched parameter count for bind. Prepared statement " ++
              "needs " ++ show expected ++ ", " ++ show given ++ " given")
    zipWithM_ (bindSQLData statement) [1 ..] sqlData
-- | Bind values to named parameters.  Fails in 'IO' if the number of pairs
-- differs from the statement's parameter count, or if a name is not a
-- parameter of the statement.
bindNamed :: Statement -> [(T.Text, SQLData)] -> IO ()
bindNamed statement params = do
    ParamIndex expected <- bindParameterCount statement
    let given = length params
    when (given /= expected) $
        fail ("mismatched parameter count for bind. Prepared statement " ++
              "needs " ++ show expected ++ ", " ++ show given ++ " given")
    mapM_ bindOne params
  where
    bindOne (name, val) =
        Direct.bindParameterIndex statement (toUtf8 name)
            >>= maybe (fail ("unknown named parameter " ++ show name))
                      (\i -> bindSQLData statement i val)
-- | Read a result column as 'Text'.  Throws a 'DecodeError' if the column
-- bytes are not valid UTF-8.
columnText :: Statement -> ColumnIndex -> IO Text
columnText statement columnIndex = do
    raw <- Direct.columnText statement columnIndex
    fromUtf8 "Database.SQLite3.columnText: Invalid UTF-8" raw
-- | Read a result column as 'SQLData', choosing the constructor from the
-- column's current type as reported by 'columnType'.
column :: Statement -> ColumnIndex -> IO SQLData
column statement idx =
    columnType statement idx >>= \ty -> typedColumn ty statement idx
-- | Read every column of the current result row, in index order
-- (0 to count - 1).  Yields @[]@ when the statement has no columns.
columns :: Statement -> IO [SQLData]
columns statement = do
    count <- columnCount statement
    -- The original range was @[0..count1]@, which references the unbound
    -- name @count1@ and does not compile.  The last valid index is count - 1.
    mapM (column statement) [0 .. count - 1]
-- | Read a result column assuming the given 'ColumnType', using the
-- matching typed accessor.  A 'NullColumn' yields 'SQLNull' without
-- reading the column.
typedColumn :: ColumnType -> Statement -> ColumnIndex -> IO SQLData
typedColumn IntegerColumn statement idx = fmap SQLInteger (columnInt64 statement idx)
typedColumn FloatColumn   statement idx = fmap SQLFloat (columnDouble statement idx)
typedColumn TextColumn    statement idx = fmap SQLText (columnText statement idx)
typedColumn BlobColumn    statement idx = fmap SQLBlob (columnBlob statement idx)
typedColumn NullColumn    _         _   = return SQLNull
-- | Read columns 0, 1, … with one entry per list element.  A 'Just' entry
-- reads that column as the given type; a 'Nothing' entry falls back to
-- 'column', which asks for the column's actual type.  The number of
-- columns read equals the length of the list.
typedColumns :: Statement -> [Maybe ColumnType] -> IO [SQLData]
typedColumns statement = zipWithM fetch [0 ..]
  where
    fetch idx = maybe (column statement idx) (\t -> typedColumn t statement idx)