{-# language DerivingVia #-}
{-# language OverloadedStrings #-}
{-# language ScopedTypeVariables #-}

module Database.Sqlite.Easy.Internal where

import Database.SQLite3
import Data.String (IsString)
import Data.Text (Text)
import Control.Exception
import Data.Typeable
import Data.Pool

-- * Connection

-- | A SQLite3 connection string, as passed to 'open'.
--
--   With @OverloadedStrings@ a string literal can be used directly, because
--   the 'IsString' instance is borrowed from 'Text'.
newtype ConnectionString
  = ConnectionString
    { unConnectionString :: Text
    }
  deriving IsString via Text
  deriving Show

-- | Create a pool of connections to the SQLite3 database named by the
--   connection string.
--
--   Each pooled connection comes from its own call to 'open' on the
--   connection string and is released with 'close'. Unused connections are
--   kept for up to 180 seconds ('poolCacheTTL'), and at most 50 are open at
--   once ('poolMaxResources').
--
--   NOTE(review): because every connection is a separate 'open', a
--   connection string such as @:memory:@ yields an independent database per
--   connection rather than one shared database. Confirm this is intended if
--   the pool is used in-memory.
createSqlitePool :: ConnectionString -> IO (Pool Database)
createSqlitePool (ConnectionString connStr) =
  newPool PoolConfig
    { createResource = open connStr
    , freeResource = close
    , poolCacheTTL = 180
    , poolMaxResources = 50
    }

-- | Open a database, run an action with the connection, then close it.
--
--   'bracket' guarantees the connection is closed even if the action throws.
withDb :: ConnectionString -> (Database -> IO a) -> IO a
withDb (ConnectionString connStr) = bracket (open connStr) close

-- * Execution

-- | A SQL statement.
--
--   String literals become 'SQL' through 'IsString' (with
--   @OverloadedStrings@), and fragments can be concatenated with '<>'. Both
--   instances are borrowed from the underlying 'Text'.
newtype SQL
  = SQL
    { unSQL :: Text
    }
  deriving (Semigroup, IsString) via Text
  deriving Show

-- | Run a SQL statement on a database and fetch all of its result rows.
--
--   The statement is prepared and then handed to 'fetchAll', which steps it
--   to completion and finalizes it.
run :: SQL -> Database -> IO [[SQLData]]
run (SQL stmt) db = fetchAll =<< prepare db stmt

-- | Run a SQL statement with the given parameters on a database and fetch
--   all of its result rows.
--
--   The parameters are bound to the prepared statement before it is stepped
--   by 'fetchAll'. If binding fails (for example, when the number of
--   parameters does not match the statement), the prepared statement is
--   finalized before the exception propagates, so it is not leaked.
runWith :: SQL -> [SQLData] -> Database -> IO [[SQLData]]
runWith (SQL stmt) params db = do
  preparedStmt <- prepare db stmt
  -- 'fetchAll' owns finalization once it runs. Until then, a failing bind
  -- would otherwise leave the statement allocated forever.
  bind preparedStmt params `onException` finalize preparedStmt
  fetchAll preparedStmt

-- | Step a prepared statement to completion and collect every result row.
--
--   The statement is finalized exactly once:
--
--   * on success, after the last row has been read;
--   * on failure of 'step' or 'columns', before the original exception is
--     rethrown.
fetchAll :: Statement -> IO [[SQLData]]
fetchAll stmt = do
  rows <- collectRows `onException` finalizeQuietly
  finalize stmt
  pure rows
  where
    -- Read rows until 'step' reports 'Done'. Does not finalize.
    collectRows = do
      res <- step stmt
      case res of
        Row -> do
          row <- columns stmt
          rest <- collectRows
          pure (row : rest)
        Done -> pure []

    -- Error-path cleanup. After a failed step, 'finalize' may report that
    -- same error again; dropping it keeps the caller's original exception.
    finalizeQuietly =
      finalize stmt `catch` \(_ :: SQLError) -> pure ()

-- * Transaction

-- | Run operations as a transaction.
--
--   The action runs between @BEGIN@ and @COMMIT@. If the action (or the
--   @COMMIT@) throws, the transaction is rolled back and the exception is
--   rethrown. If the action calls 'cancelTransaction' with a value of the
--   result type, the transaction is rolled back and that value is returned.
--
--   On the rethrow path, a failing @ROLLBACK@ is ignored so that the
--   caller sees the exception that caused the rollback. SQLite may already
--   have rolled the transaction back automatically, which makes an explicit
--   @ROLLBACK@ fail.
--
--   __Note__: Transactions do not nest.
--
--   For more information, visit: <https://www.sqlite.org/lang_transaction.html>
asTransaction :: Typeable a => Database -> IO a -> IO a
asTransaction db action = do
  [] <- run "BEGIN" db
  catches
    (action <* run "COMMIT" db)
    [ Handler $ \(CancelTransaction result) ->
        run "ROLLBACK" db *> pure result
    , Handler $ \(ex :: SomeException) ->
        rollbackKeepingCause *> throwIO ex
    ]
  where
    -- A ROLLBACK failure must not replace the exception being propagated.
    rollbackKeepingCause =
      (() <$ run "ROLLBACK" db) `catch` \(_ :: SQLError) -> pure ()

-- | Run operations as a transaction, without special handling for
--   'cancelTransaction'.
--
--   The action runs between @BEGIN@ and @COMMIT@. Any exception from the
--   action or the @COMMIT@ rolls the transaction back and is rethrown. This
--   includes the exception thrown by 'cancelTransaction'. Unlike
--   'asTransaction', no 'Typeable' constraint is required.
--
--   A failing @ROLLBACK@ is ignored so that the caller sees the exception
--   that caused the rollback. SQLite may already have rolled the
--   transaction back automatically, which makes an explicit @ROLLBACK@ fail.
--
--   __Note__: Transactions do not nest.
asTransaction' :: Database -> IO a -> IO a
asTransaction' db action = do
  [] <- run "BEGIN" db
  catches
    (action <* run "COMMIT" db)
    [ Handler $ \(ex :: SomeException) ->
        rollbackKeepingCause *> throwIO ex
    ]
  where
    -- A ROLLBACK failure must not replace the exception being propagated.
    rollbackKeepingCause =
      (() <$ run "ROLLBACK" db) `catch` \(_ :: SQLError) -> pure ()

-- | Cancel a transaction by supplying the value it should return.
--   To be used inside 'asTransaction'.
--
--   This throws 'CancelTransaction'. 'asTransaction' turns it into a
--   rollback plus a normal return only when the value's type is the
--   transaction's result type. Otherwise it is handled like any other
--   exception: rolled back and rethrown.
cancelTransaction :: Typeable a => a -> IO a
cancelTransaction result = throwIO (CancelTransaction result)

-- | Exception used by 'cancelTransaction' to carry the value a cancelled
--   transaction should return.
data CancelTransaction a
  = CancelTransaction a

-- | Shows only the constructor name. The carried value needs no 'Show'
--   instance.
instance Show (CancelTransaction a) where
  showsPrec _ CancelTransaction{} = showString "CancelTransaction"

instance (Typeable a) => Exception (CancelTransaction a)