-- author: Benjamin Surma <benjamin.surma@gmail.com>

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}

module Test.Sandbox.HUnit (
    assertFailure
  , assertBool
  , assertEqual
  , assertString
  , assertException
  ) where

import Test.Sandbox

import Control.Exception.Lifted

import qualified Test.HUnit

#if MIN_VERSION_HUnit(1,5,0)
import Test.HUnit.Lang (HUnitFailure (..), formatFailureReason)
#else
import Test.HUnit.Lang (HUnitFailure (..))
#endif

-- | Unconditionally signals that a failure has occurred.
--
-- The HUnit failure raised in 'IO' is converted into a 'Sandbox' error by
-- 'wrap'.
assertFailure :: String     -- ^ A message that is displayed with the assertion failure
              -> Sandbox ()
assertFailure = wrap . liftIO . Test.HUnit.assertFailure

-- | Asserts that the specified condition holds.
--
-- When the condition is 'False', the HUnit failure carrying the message is
-- converted into a 'Sandbox' error by 'wrap'.
assertBool :: String     -- ^ The message that is displayed if the assertion fails
           -> Bool       -- ^ The condition
           -> Sandbox ()
assertBool s b = wrap $ liftIO (Test.HUnit.assertBool s b)

-- | Asserts that the specified actual value is equal to the expected value.
--
-- On mismatch, the failure produced by 'Test.HUnit.assertEqual' is
-- converted into a 'Sandbox' error by 'wrap'.
assertEqual :: (Eq a, Show a)
            => String     -- ^ The message prefix
            -> a          -- ^ The expected value
            -> a          -- ^ The actual value
            -> Sandbox ()
assertEqual s expected actual =
  wrap $ liftIO (Test.HUnit.assertEqual s expected actual)

-- | Signals an assertion failure if a non-empty message (i.e., a message other than "") is passed.
--
-- The HUnit failure is converted into a 'Sandbox' error by 'wrap'.
assertString :: String     -- ^ The message that is displayed with the assertion failure
             -> Sandbox ()
assertString s = wrap $ liftIO (Test.HUnit.assertString s)

-- | Signals an assertion failure if *no* exception is raised.
--
-- The action is run and its result discarded. If it completes normally, the
-- assertion fails with the given message (via 'assertBool'). If it aborts
-- with an error caught by 'catchError', the assertion passes.
--
-- NOTE(review): only errors raised through the Sandbox's 'catchError'
-- channel are recognised here; exceptions thrown directly in 'IO' are
-- presumably not intercepted unless the Sandbox converts them — confirm.
assertException :: String     -- ^ The message that is displayed with the assertion failure
                -> Sandbox a
                -> Sandbox ()
assertException s action = do
  -- True when the action aborted with an error, False when it completed.
  raised <- (action >> return False) `catchError` const (return True)
  -- The assertion itself runs outside catchError, so its own failure
  -- is not swallowed.
  assertBool s raised

-- | Runs a 'Sandbox' action and converts any 'HUnitFailure' exception it
-- throws into a 'Sandbox' error (via 'throwError') carrying the failure
-- message. This lets HUnit's 'IO' assertions report through the Sandbox's
-- error channel.
wrap :: Sandbox () -> Sandbox ()
#if MIN_VERSION_HUnit(1,5,0)
-- HUnit >= 1.5 stores a structured FailureReason (plus a source location,
-- which is discarded here); render it to a String before rethrowing.
wrap action =
  action `catch` (\ (HUnitFailure _ e :: HUnitFailure) ->
                    throwError $ formatFailureReason e)
#elif MIN_VERSION_HUnit(1,3,0)
-- HUnitFailure has two fields here; the first one is ignored.
wrap action =
  action `catch` (\ (HUnitFailure _ e :: HUnitFailure) -> throwError e)
#else
wrap action =
  action `catch` (\ (HUnitFailure e :: HUnitFailure) -> throwError e)
#endif