{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}

-- | Backwards-compatibility shim for "Test.Hspec.Wai.Matcher".
--
-- Describing the expected response with a matcher is obsolete given how sydtest works.
-- Compare the response with `shouldBe` instead.
module Test.Syd.Wai.Matcher where

import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as LB
import qualified Data.CaseInsensitive as CI
import Data.Char as Char (isPrint, isSpace)
import Data.Maybe
import Data.String
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Network.HTTP.Types as HTTP

-- | A raw HTTP response body.
type Body = LB.ByteString

-- | The expected shape of an HTTP response.
--
-- A string literal gives a matcher for status 200 with exactly that body, and
-- an integer literal gives a matcher for that status with any body.
data ResponseMatcher = ResponseMatcher
  { -- | Expected status code.
    matchStatus :: Int,
    -- | Expectations on the response headers.
    matchHeaders :: [MatchHeader],
    -- | Expectation on the response body.
    matchBody :: MatchBody
  }

-- | One expectation on a response. The wrapped function is given the response
-- headers and body and returns 'Nothing' when the expectation holds, or
-- 'Just' a failure message otherwise.
--
-- This is a @newtype@ rather than @data@ because it wraps a single field, so
-- the wrapper costs nothing at runtime.
newtype MatchHeader = MatchHeader ([Header] -> Body -> Maybe String)

-- | An expectation on the response body, with the same contract as
-- 'MatchHeader': 'Nothing' means the response matches, and 'Just' carries the
-- failure message.
newtype MatchBody = MatchBody ([Header] -> Body -> Maybe String)

-- | Expect the response body to be byte-for-byte equal to the given one.
-- The response headers are ignored.
--
-- On a mismatch, the failure message shows both bodies. They are shown as
-- plain text only when 'safeToString' accepts both of them. Otherwise both are
-- rendered with 'show', presumably so the two lines use the same notation.
bodyEquals :: Body -> MatchBody
bodyEquals body = MatchBody check
  where
    expectedBytes :: ByteString
    expectedBytes = LB.toStrict body

    check :: [Header] -> Body -> Maybe String
    check _headers actualBody
      | actualBytes == expectedBytes = Nothing
      | otherwise = Just (uncurry (actualExpected "body mismatch:") rendered)
      where
        actualBytes = LB.toStrict actualBody
        -- (actual, expected), in the order 'actualExpected' takes them.
        rendered = case (safeToString actualBytes, safeToString expectedBytes) of
          (Just actualText, Just expectedText) -> (actualText, expectedText)
          _ -> (show actualBytes, show expectedBytes)

-- | A body expectation that accepts every response: it ignores both the
-- headers and the body and always reports success ('Nothing').
matchAny :: MatchBody
matchAny = MatchBody (const (const Nothing))

-- | A string literal used as a 'MatchBody' expects the body to be exactly
-- that text, UTF-8 encoded.
instance IsString MatchBody where
  fromString str = bodyEquals (LB.fromStrict (TE.encodeUtf8 (T.pack str)))

-- | A string literal used as a 'ResponseMatcher' expects status 200, places
-- no expectations on the headers, and expects exactly that body (see the
-- 'IsString' instance of 'MatchBody').
instance IsString ResponseMatcher where
  fromString str =
    ResponseMatcher
      { matchStatus = 200,
        matchHeaders = [],
        matchBody = fromString str
      }

-- | An integer literal used as a 'ResponseMatcher' expects that status code,
-- places no expectations on the headers, and accepts any body ('matchAny').
--
-- Only 'fromInteger' is meaningful; every arithmetic method calls 'error'.
-- 'negate' keeps the class default (@0 - x@), so it fails with the @(-)@
-- message.
instance Num ResponseMatcher where
  fromInteger code =
    ResponseMatcher
      { matchStatus = fromInteger code,
        matchHeaders = [],
        matchBody = matchAny
      }

  -- Arithmetic on matchers has no meaning; these exist only to satisfy 'Num'.
  (+) = error "ResponseMatcher does not support (+)"
  (-) = error "ResponseMatcher does not support (-)"
  (*) = error "ResponseMatcher does not support (*)"
  abs = error "ResponseMatcher does not support `abs`"
  signum = error "ResponseMatcher does not support `signum`"

-- | @name <:> value@ expects the response to carry a header with this name and
-- exactly this value. The whole @(name, value)@ pair is compared with '==':
-- the name via its 'HeaderName' ('CI.CI') equality, and the value
-- byte-for-byte. The response body is ignored.
--
-- On failure the message is @missing header:@ followed by the expected header
-- as rendered by 'formatHeader'.
(<:>) :: HeaderName -> ByteString -> MatchHeader
name <:> value = MatchHeader check
  where
    expected :: Header
    expected = (name, value)

    check :: [Header] -> Body -> Maybe String
    check actualHeaders _body
      | expected `elem` actualHeaders = Nothing
      | otherwise = Just (unlines ["missing header:", formatHeader expected])

-- | Build a three-line failure report: the message, then the expected value,
-- then the actual value. Every line, including the last, ends in a newline.
--
-- Note the argument order: the actual value comes before the expected one.
actualExpected :: String -> String -> String -> String
actualExpected message actual expected =
  concatMap (++ "\n") reportLines
  where
    reportLines =
      [ message,
        "  expected: " ++ expected,
        "  but got:  " ++ actual
      ]

-- | Render a header as a @Name: value@ line, indented by two spaces and with
-- no trailing newline. The name is printed with its original casing. When
-- that line is not readable text according to 'safeToString', the @(name,
-- value)@ pair is rendered with 'show' instead.
formatHeader :: Header -> String
formatHeader header@(name, value) =
  "  " ++ case safeToString rendered of
    Just text -> text
    Nothing -> show header
  where
    rendered = B8.concat [CI.original name, ": ", value]

-- | Convert a 'ByteString' to a 'String' only when it can be printed verbatim
-- in a failure message. That requires all of the following:
--
-- * it decodes as UTF-8;
-- * it is non-empty;
-- * it consists only of 'Char.isPrint' characters (so no newlines or tabs);
-- * it does not end in whitespace, which would be invisible in the message.
--
-- Otherwise the result is 'Nothing', and callers fall back to 'show'. For
-- empty input this means it renders as @\"\"@ rather than as nothing at all.
safeToString :: ByteString -> Maybe String
safeToString bs = case TE.decodeUtf8' bs of
  Left _invalidUtf8 -> Nothing
  Right decoded -> keepIfReadable (T.unpack decoded)
  where
    keepIfReadable :: String -> Maybe String
    keepIfReadable str
      | isReadable str = Just str
      | otherwise = Nothing

    -- Inspect the final character without the partial 'last'.
    isReadable :: String -> Bool
    isReadable str = case reverse str of
      [] -> False
      lastChar : _ -> not (Char.isSpace lastChar) && all Char.isPrint str