{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

-- | SQLite3 driver implementation for HDBI, built on top of
-- "Database.SQLite3.Direct".
module Database.HDBI.SQlite.Implementation
  ( -- * types
    SQliteConnection (..)
  , SQliteStatement (..)
  , SQState (..)
    -- * connecting
  , connectSqlite3
    -- * auxiliary functions
  , encodeUTF8
  , encodeLUTF8
  , encodeSUTF8
  , decodeUTF8
  , decodeLUTF8
  , fetchValue
  , bindParam
  , throwErrMsg
  , sqliteMsg
  , withConnectionUnlocked
  ) where

import Blaze.ByteString.Builder (toByteString)
import Blaze.ByteString.Builder.Char.Utf8 (fromText, fromLazyText, fromString)
import Control.Applicative
import Control.Concurrent.MVar
import Control.Exception
import Control.Monad (forM, forM_)
import Data.Int
import Data.Text.Encoding.Error (lenientDecode)
import Data.Typeable
import Database.HDBI.DriverUtils
import Database.HDBI.Formaters
import Database.HDBI.SqlValue
import Database.HDBI.Types
import qualified Data.ByteString.Lazy as BL
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TL
import qualified Database.SQLite3.Direct as SD

-- | Encode strict 'T.Text' as UTF-8 for passing to SQLite.
encodeUTF8 :: T.Text -> SD.Utf8
encodeUTF8 x = SD.Utf8 $ toByteString $ fromText x

-- | Encode lazy 'TL.Text' as UTF-8 for passing to SQLite.
encodeLUTF8 :: TL.Text -> SD.Utf8
encodeLUTF8 x = SD.Utf8 $ toByteString $ fromLazyText x

-- | Encode a 'String' as UTF-8 for passing to SQLite.
encodeSUTF8 :: String -> SD.Utf8
encodeSUTF8 x = SD.Utf8 $ toByteString $ fromString x

-- | Decode UTF-8 coming from SQLite into strict 'T.Text'.
--
-- Malformed byte sequences are replaced with U+FFFD instead of raising an
-- exception. SQLite does not guarantee that TEXT data is well-formed UTF-8,
-- and a strict decoder would throw a 'UnicodeException' from pure code at
-- whatever point the result happens to be forced.
decodeUTF8 :: SD.Utf8 -> T.Text
decodeUTF8 (SD.Utf8 x) = T.decodeUtf8With lenientDecode x

-- | Decode UTF-8 coming from SQLite into lazy 'TL.Text'.
--
-- Like 'decodeUTF8', malformed byte sequences become U+FFFD rather than a
-- lazily thrown decoding exception.
decodeLUTF8 :: SD.Utf8 -> TL.Text
decodeLUTF8 (SD.Utf8 x) = TL.decodeUtf8With lenientDecode $ BL.fromChunks [x]

-- | Connection to the database
data SQliteConnection = SQliteConnection
  { scDatabase   :: MVar (Maybe SD.Database)
    -- ^ 'Nothing' once the connection has been closed by 'disconnect'
  , scConnString :: T.Text
    -- ^ connection string the database was opened with (reused by 'clone')
  , scStatements :: ChildList SQliteStatement
    -- ^ List of statements to finish before disconnect
  } deriving (Typeable)

-- | Prepared statement
data SQliteStatement = SQliteStatement
  { ssState      :: MVar SQState
    -- ^ current lifecycle state of the statement
  , ssConnection :: SQliteConnection
    -- ^ connection the statement was prepared on
  , ssQuery      :: Query
    -- ^ original query; used to re-prepare the statement in 'reset'
  } deriving (Typeable)

-- | Internal state of the statement. There are two similar constructors,
-- 'SQFetching' and 'SQExecuted', to simulate the behaviour expected by the
-- tests: 'SQExecuted' means the statement was executed but no row was fetched
-- yet, 'SQFetching' means at least one 'fetch' has happened. Both store the
-- result of the most recent 'SD.step'. 'SQFinished' means the underlying
-- statement has been finalized and must not be used again.
data SQState = SQNew { sqStatement :: SD.Statement }
             | SQExecuted { sqStatement :: SD.Statement
                          , sqResult    :: SD.StepResult }
             | SQFetching { sqStatement :: SD.Statement
                          , sqResult    :: SD.StepResult }
             | SQFinished

-- | Connect to SQlite3 database. Throws 'SqlError' if the database cannot be
-- opened.
connectSqlite3 :: T.Text -- ^ Connection string
               -> IO SQliteConnection
connectSqlite3 connstr = do
  res <- SD.open $ encodeUTF8 connstr
  case res of
    Left (err, errmsg) ->
      throwIO $ SqlError (show err) $ T.unpack $ decodeUTF8 errmsg
    Right r ->
      SQliteConnection <$> newMVar (Just r) <*> return connstr <*> newChildList

instance Connection SQliteConnection where
  type ConnStatement SQliteConnection = SQliteStatement

  -- Finishes every statement prepared on this connection, then closes the
  -- database handle. Closing an already closed connection is a no-op.
  -- NOTE(review): closeAllChildren runs while scDatabase is taken here; if a
  -- child's finish has to report an error it goes through
  -- withConnectionUnlocked, which reads this same MVar - confirm this cannot
  -- block.
  disconnect conn = modifyMVar_ (scDatabase conn) $ \c -> case c of
    Nothing -> return Nothing
    Just con -> do
      closeAllChildren (scStatements conn)
      res <- SD.close con
      case res of
        Left err -> throwErrMsg con $ show err
        Right () -> return Nothing

  begin conn    = run conn "begin" ()
  commit conn   = run conn "commit" ()
  rollback conn = run conn "rollback" ()

  inTransaction conn = withConnectionUnlocked conn $ \con -> do
    ac <- SD.getAutoCommit con
    -- we are not in autocommit mode, so we are inside a transaction
    return $ not ac

  connStatus conn = do
    val <- readMVar $ scDatabase conn
    return $ case val of
      Nothing -> ConnDisconnected
      Just _  -> ConnOK

  prepare conn query = withConnectionUnlocked conn $ \con -> do
    res <- SD.prepare con $ encodeLUTF8 $ unQuery query
    case res of
      Left err -> throwErrMsg con $ show err
      -- SD.prepare yields Nothing when there was no statement to prepare
      Right Nothing -> throwErrMsg con ""
      Right (Just st) -> do
        ret <- SQliteStatement <$> newMVar (SQNew st) <*> return conn <*> return query
        -- register the statement so disconnect can finish it
        addChild (scStatements conn) ret
        return ret

  -- run: using default implementation

  runRaw conn (Query t) = withConnectionUnlocked conn $ \con -> do
    res <- SD.exec con $ encodeLUTF8 t
    case res of
      Left (err, errmsg) ->
        throwIO $ SqlError (show err) $ T.unpack $ decodeUTF8 errmsg
      Right () -> return ()

  -- runMany: using default implementation

  clone conn = connectSqlite3 $ scConnString conn

  hdbiDriverName = const "sqlite3"

  dbTransactionSupport = const True

instance Statement SQliteStatement where
  execute stmt vals = modifyMVar_ (ssState stmt) $ \state ->
      case state of
        SQNew st -> execute' st
        SQExecuted {} ->
          throwIO $ SqlDriverError $ sqliteMsg "Statement is already executed"
        SQFetching {} ->
          throwIO $ SqlDriverError $ sqliteMsg "Statement is already executed (fetching)"
        SQFinished ->
          throwIO $ SqlDriverError $ sqliteMsg "Statement is already finished"
    where
      -- bind parameters (1-based) and perform the first step
      execute' st = do
        withConnectionUnlocked (ssConnection stmt) $ \con ->
          forM_ (zip [1..] $ toRow vals) $ uncurry $ bindParam con st
        -- if this is an INSERT or UPDATE query we need a step to execute it;
        -- for a SELECT the step positions on the first row
        res <- getRight (ssConnection stmt) $ SD.step st
        return $ SQExecuted st res

  -- executeRaw: use default
  -- executeMany: use default implementation

  statementStatus stmt = do
    res <- readMVar $ ssState stmt
    return $ case res of
      SQNew {}             -> StatementNew
      SQExecuted {}        -> StatementExecuted
      SQFetching _ SD.Row  -> StatementExecuted
      SQFetching _ SD.Done -> StatementFetched
      SQFinished           -> StatementFinished

  -- Finalizes the underlying statement. Finishing an already finished
  -- statement is a no-op.
  finish stmt = do
    res <- modifyMVar (ssState stmt) $ \st -> case st of
      SQFinished -> return (SQFinished, Right ())
      x -> do
        r <- SD.finalize $ sqStatement x
        -- sqlite3_finalize destroys the statement even when it reports an
        -- error, so the state must become SQFinished unconditionally;
        -- keeping the old handle would make a later finish (for example
        -- from disconnect) finalize freed memory a second time.
        return (SQFinished, r)
    -- report a finalize error only after the state has been updated
    getRight (ssConnection stmt) $ return res

  reset stmt = modifyMVar_ (ssState stmt) $ \st -> case st of
    -- the old handle was finalized, so prepare the query again
    SQFinished -> withConnectionUnlocked (ssConnection stmt) $ \con -> do
      res <- SD.prepare con $ encodeLUTF8 $ unQuery $ ssQuery stmt
      case res of
        Left err -> throwErrMsg con $ show err
        Right Nothing -> throwErrMsg con ""
        Right (Just s) -> return $ SQNew s
    x -> do
      let rst = sqStatement x
      getRight (ssConnection stmt) $ SD.reset rst
      SD.clearBindings rst
      return $ SQNew rst

  fetch stmt = modifyMVar (ssState stmt) $ \st ->
      case st of
        SQNew _ ->
          throwIO $ SqlDriverError $ sqliteMsg "Statement is not executed to fetch rows from"
        SQExecuted x SD.Row -> fetch' x
        SQExecuted x SD.Done -> return (SQFetching x SD.Done, Nothing)
        SQFetching x SD.Row -> fetch' x
        r@(SQFetching _ SD.Done) -> return (r, Nothing)
        SQFinished ->
          throwIO $ SqlDriverError $ sqliteMsg "Statement is already finished to fetch rows from"
    where
      -- read the current row, then step so the next call knows whether
      -- another row is available
      fetch' ss = do
        cc <- SD.columnCount ss
        res <- forM [0..cc-1] $ \col -> fetchValue ss col
        sres <- getRight (ssConnection stmt) $ SD.step ss
        return (SQFetching ss sres, Just $ fromRow res)

  getColumnNames stmt = do
    st <- readMVar $ ssState stmt
    case st of
      SQNew {} ->
        throwIO $ SqlDriverError $ sqliteMsg "Statement is not executed to get column names"
      SQFinished ->
        throwIO $ SqlDriverError $ sqliteMsg "Statement is already finished to get column names"
      x -> do
        let sqst = sqStatement x
        cols <- SD.columnCount sqst
        forM [0..cols-1] $ \col -> do
          res <- SD.columnName sqst col
          -- a column without a name is reported as the empty string
          case res of
            Nothing -> return ""
            Just val -> return $ decodeLUTF8 val

  getColumnsCount stmt = do
    st <- readMVar $ ssState stmt
    case st of
      SQNew {} ->
        throwIO $ SqlDriverError $ sqliteMsg "Statement is not executed to get columns count"
      SQFinished ->
        throwIO $ SqlDriverError $ sqliteMsg "Statement is already finished to get columns count"
      x -> do
        (SD.ColumnIndex idx) <- SD.columnCount $ sqStatement x
        return idx

  originalQuery stmt = ssQuery stmt

-- | fetch value from particular column of current row in particular
-- statement. The 'SqlValue' constructor is chosen from the column's storage
-- type as reported by 'SD.columnType'.
fetchValue :: SD.Statement -> SD.ColumnIndex -> IO SqlValue
fetchValue ss col = do
  ct <- SD.columnType ss col
  case ct of
    SD.IntegerColumn -> SqlInteger . toInteger <$> SD.columnInt64 ss col
    SD.FloatColumn   -> SqlDouble <$> SD.columnDouble ss col
    SD.TextColumn    -> SqlText . decodeLUTF8 <$> SD.columnText ss col
    SD.BlobColumn    -> SqlBlob <$> SD.columnBlob ss col
    SD.NullColumn    -> return SqlNull

-- | If action returns (Left error) then get the description from the
-- database and throw it as 'SqlError'. Otherwise return the value.
getRight :: SQliteConnection
         -> IO (Either SD.Error a) -- ^ action to execute
         -> IO a
getRight con act = do
  res <- act
  case res of
    Left err -> withConnectionUnlocked con $ \c -> throwErrMsg c $ show err
    Right a -> return a

-- | bind SqlValue to the particular parameter of query of particular
-- statement. Throws 'SqlError' if SQLite rejects the binding.
bindParam :: SD.Database -> SD.Statement -> SD.ParamIndex -> SqlValue -> IO ()
bindParam con st idx val = do
  res <- bind val
  case res of
    Left err -> throwErrMsg con $ show err
    Right () -> return ()
  where
    -- bind a value through its textual representation
    binds x = SD.bindText st idx $ encodeSUTF8 x

    bindShow :: (Show a) => a -> IO (Either SD.Error ())
    bindShow x = binds $ show x

    bindi i = SD.bindInt64 st idx i

    -- Nothing when the integer does not fit into an Int64; such values are
    -- bound as text instead
    downInt64 :: Integer -> Maybe Int64
    downInt64 i = if i > imax || i < imin then Nothing else Just $ fromInteger i
      where
        imax = toInteger (maxBound :: Int64)
        imin = toInteger (minBound :: Int64)

    bind (SqlDecimal d) = bindShow d
    bind (SqlInteger i) = case downInt64 i of
      Nothing -> bindShow i
      Just i64 -> bindi i64
    bind (SqlDouble d) = SD.bindDouble st idx d
    bind (SqlText t) = SD.bindText st idx $ encodeLUTF8 t
    bind (SqlBlob b) = SD.bindBlob st idx b
    bind (SqlBool b) = bindi $ if b then 1 else 0
    bind (SqlBitField bf) = bindShow bf
    bind (SqlUUID u) = bindShow u
    bind (SqlUTCTime ut) = binds $ formatIsoUTCTime ut
    bind (SqlLocalDate d) = binds $ formatIsoDay d
    bind (SqlLocalTimeOfDay td) = binds $ formatIsoTimeOfDay td
    bind (SqlLocalTime lt) = binds $ formatIsoLocalTime lt
    bind SqlNull = SD.bindNull st idx

-- | Get error description from the database and throw 'SqlError' with it.
-- The first argument becomes the error code, the database's current error
-- message becomes the error message.
throwErrMsg :: SD.Database -> String -> IO a
throwErrMsg db err = do
  errmsg <- SD.errmsg db
  throwIO $ SqlError err $ T.unpack $ decodeUTF8 errmsg

-- | prepend package name to the string for error reporting
sqliteMsg :: String -> String
sqliteMsg = ("hdbi-sqlite: " ++)

-- | Get internal 'SD.Database' from the 'SQliteConnection' and execute an
-- action with it, or throw 'SqlDriverError' if the connection is already
-- closed. The connection 'MVar' is only read, not held, while the action
-- runs.
withConnectionUnlocked :: SQliteConnection -> (SD.Database -> IO a) -> IO a
withConnectionUnlocked conn fun = do
  val <- readMVar $ scDatabase conn
  case val of
    Nothing -> throwIO $ SqlDriverError $ sqliteMsg $ "connection is closed"
    Just x -> fun x