{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Testing with an in-memory sqlite database using persistent-sqlite
--
-- For a fully worked example, see sydtest-yesod/blog-example.
module Test.Syd.Persistent.Sqlite
  ( persistSqliteSpec,
    withConnectionPool,
    connectionPoolSetupFunc,
    connectionPoolSetupFunc',
    runSqlPool,
    runSqliteTest,
  )
where

import Control.Monad.Logger
import Control.Monad.Reader
import Database.Persist.Sql
import Database.Persist.Sqlite
import Test.Syd

-- | Declare a test suite that uses a database connection.
--
-- This sets up a fresh, migrated database connection pool around every test,
-- so state is not preserved across tests.
--
-- Whatever inner resource @a@ an enclosing setup provides is ignored; each
-- test only receives the 'ConnectionPool'.
persistSqliteSpec :: Migration -> SpecWith ConnectionPool -> SpecWith a
persistSqliteSpec migration = aroundWith $ \runTest _ ->
  withConnectionPool migration runTest

-- | Set up an in-memory sqlite connection pool, run the given migration on it,
-- and pass the pool to the given function.
--
-- The result of the function is returned. Older versions only allowed
-- @IO ()@; that usage is still accepted.
--
-- NOTE(review): persistent-sqlite's documentation warns against using pools
-- with @":memory:"@, because the pool may reap idle connections and thereby
-- discard the database. This is presumably acceptable for short-lived tests;
-- confirm if tests become long-running.
withConnectionPool :: Migration -> (ConnectionPool -> IO r) -> IO r
withConnectionPool migration func = unSetupFunc connectionPoolSetupFunc' func migration

-- | The 'SetupFunc' version of 'withConnectionPool'.
--
-- The 'Migration' is fixed up front, so the resulting 'SetupFunc' needs no
-- input resource and yields a migrated 'ConnectionPool'.
connectionPoolSetupFunc :: Migration -> SetupFunc () ConnectionPool
connectionPoolSetupFunc = unwrapSetupFunc connectionPoolSetupFunc'

-- | A wrapped version of 'connectionPoolSetupFunc': a 'SetupFunc' that takes
-- the 'Migration' as its input resource.
--
-- It opens an in-memory sqlite connection pool, runs the migration on it, and
-- then hands the pool to the continuation. Log messages emitted by persistent
-- are discarded ('runNoLoggingT').
connectionPoolSetupFunc' :: SetupFunc Migration ConnectionPool
connectionPoolSetupFunc' = SetupFunc $ \takeConnectionPool migration ->
  runNoLoggingT $
    -- A pool size of 1 is deliberate: every sqlite connection to ":memory:"
    -- gets its own separate database. The migration and the test must share
    -- a single connection to see the same tables.
    withSqlitePool ":memory:" 1 $ \pool -> do
      runSqlPool (migrationRunner migration) pool
      liftIO $ takeConnectionPool pool

-- | Run a migration, discarding its result.
--
-- From persistent 2.10.2 on, 'runMigrationQuiet' is used, so the executed
-- migration statements are not reported on stderr during tests. Older
-- versions fall back to 'runMigration'.
migrationRunner :: MonadIO m => Migration -> ReaderT SqlBackend m ()
#if MIN_VERSION_persistent(2,10,2)
migrationRunner = void . runMigrationQuiet
#else
migrationRunner = runMigration
#endif

-- | A flipped version of 'runSqlPool' to run your tests.
--
-- Taking the pool first lets it be partially applied to the pool each test
-- receives, e.g. @runSqliteTest pool $ insert ...@.
runSqliteTest :: ConnectionPool -> SqlPersistT IO a -> IO a
runSqliteTest pool action = runSqlPool action pool