{-# LANGUAGE CPP                                                           #-}
{-# LANGUAGE DataKinds                                                     #-}
{-# LANGUAGE FlexibleContexts                                              #-}
{-# LANGUAGE FlexibleInstances                                             #-}
{-# LANGUAGE GeneralizedNewtypeDeriving                                    #-}
{-# LANGUAGE FunctionalDependencies                                        #-}
{-# LANGUAGE MultiParamTypeClasses                                         #-}
{-# LANGUAGE OverloadedStrings                                             #-}
{-# LANGUAGE ScopedTypeVariables                                           #-}
{-# LANGUAGE TupleSections                                                 #-}
{-# LANGUAGE TypeFamilies                                                  #-}
{-# LANGUAGE TypeSynonymInstances                                          #-}
module Test.Hspec.Snap (
  
    snap
  , modifySite
  , modifySite'
  , afterEval
  , beforeEval
  
  , TestResponse(..)
  , RespCode(..)
  , SnapHspecM
  
  , Factory(..)
  
  , delete
  , get
  , get'
  , post
  , postJson
  , put
  , put'
  , params
  
  , restrictResponse
  
  , recordSession
  , HasSession(..)
  , sessionShouldContain
  , sessionShouldNotContain
  
  , eval
  
  , shouldChange
  , shouldEqual
  , shouldNotEqual
  , shouldBeTrue
  , shouldNotBeTrue
  
  , should200
  , shouldNot200
  , should404
  , shouldNot404
  , should300
  , shouldNot300
  , should300To
  , shouldNot300To
  , shouldHaveSelector
  , shouldNotHaveSelector
  , shouldHaveText
  , shouldNotHaveText
  
  , FormExpectations(..)
  , form
  
  , SnapHspecState(..)
  , setResult
  , runRequest
  , runHandlerSafe
  , evalHandlerSafe
  ) where
import           Control.Applicative     ((<$>))
import           Control.Concurrent.MVar (MVar, newEmptyMVar, newMVar, putMVar,
                                          readMVar, takeMVar)
import           Control.Exception       (SomeException, catch)
import           Control.Monad           (void)
import           Control.Monad.State     (StateT (..), runStateT)
import qualified Control.Monad.State     as S (get, put)
import           Control.Monad.Trans     (liftIO)
import           Data.Aeson              (ToJSON, encode)
import           Data.ByteString         (ByteString)
import           Data.ByteString.Lazy    (fromStrict, toStrict)
import qualified Data.ByteString.Lazy    as LBS (ByteString)
import qualified Data.Map                as M
import           Data.Maybe              (fromMaybe)
import           Data.Text               (Text)
import qualified Data.Text               as T
import qualified Data.Text.Encoding      as T
import           Snap.Core               (Response (..), getHeader)
import qualified Snap.Core               as Snap
import           Snap.Snaplet            (Handler, Snaplet, SnapletInit,
                                          SnapletLens, with)
import           Snap.Snaplet.Session    (SessionManager, commitSession,
                                          sessionToList, setInSession)
import           Snap.Snaplet.Test       (InitializerState, closeSnaplet,
                                          evalHandler', getSnaplet, runHandler')
import           Snap.Test               (RequestBuilder, getResponseBody)
import qualified Snap.Test               as Test
import           Test.Hspec
import           Test.Hspec.Core.Spec
import qualified Text.Digestive          as DF
import qualified Text.HandsomeSoup       as HS
import qualified Text.XML.HXT.Core       as HXT
-- | An HTTP status code.  Deriving @Num@ (via GeneralizedNewtypeDeriving)
-- lets numeric literals such as @200@ be used directly as patterns.
newtype RespCode = RespCode Int deriving (Show, Read, Eq, Num, Ord)
-- | Simplified classification of a response, as produced by @runRequest@.
data TestResponse = Html RespCode Text
                  -- ^ Status 200 response not classified as JSON; body decoded as UTF-8.
                  | Json RespCode LBS.ByteString
                  -- ^ Status 200 response classified as JSON by its content-type header.
                  | NotFound
                  -- ^ Status 404.
                  | Redirect RespCode Text
                  -- ^ Status 3xx, carrying the Location header (empty when absent).
                  | Other RespCode
                  -- ^ Any other status code.
                  | Empty
                  -- ^ Not produced by @runRequest@.
                  deriving (Show, Eq)
-- | The monad examples run in: @IO@ with a threaded @SnapHspecState@.
type SnapHspecM b = StateT (SnapHspecState b) IO
-- | Per-example state threaded through @SnapHspecM@.
data SnapHspecState b = SnapHspecState
#if MIN_VERSION_hspec(2,5,0)
                                       ResultStatus
#else
                                       Result
#endif
                                       -- ^ Result recorded so far for the current example.
                                       (Handler b b ())
                                       -- ^ The site handler that requests are run against.
                                       (Snaplet b)
                                       -- ^ The initialized snaplet.
                                       (InitializerState b)
                                       -- ^ Initializer state for the snaplet.
                                       (MVar [(Text, Text)])
                                       -- ^ Session key/value pairs saved by @recordSession@.
                                       (Handler b b ())
                                       -- ^ Hook run before the handler on every request or eval.
                                       (Handler b b ())
                                       -- ^ Hook run after the handler on every request or eval.
-- | Lets a @SnapHspecM b ()@ action be used directly as an hspec example.
-- The example receives the @SnapHspecState@ supplied by the surrounding
-- hooks, runs the action with @runStateT@, and reports the result status
-- stored in the final state.
instance Example (SnapHspecM b ()) where
  type Arg (SnapHspecM b ()) = SnapHspecState b
  evaluateExample s _ cb _ =
    -- The hook callback @cb@ invokes the action; the MVar carries the final
    -- status out of that callback.  NOTE(review): if @cb@ never invokes the
    -- action, the takeMVar below blocks - confirm hspec always calls it.
    do mv <- newEmptyMVar
       cb $ \st -> do ((),SnapHspecState r' _ _ _ _ _ _) <- runStateT s st
                      putMVar mv r'
#if MIN_VERSION_hspec(2,5,0)
       rs <- takeMVar mv
       return $ Result "" rs
#else
       takeMVar mv
#endif
-- | Builds test fixtures of type @a@ from a field description @d@ inside
-- @SnapHspecM b@.
class Factory b a d | a -> b, a -> d, d -> a where
  -- | Default field description used by @create@.
  fields :: d
  -- | Turn a field description into a value (presumably persisting it).
  save   :: d -> SnapHspecM b a
  -- | Adjust the default @fields@ with a function, then @save@ the result.
  create :: (d -> d) -> SnapHspecM b a
  -- | Re-fetch a value; by default returns it unchanged.
  reload :: a -> SnapHspecM b a

  create adjust = save (adjust fields)
  reload = return
-- | Build an hspec @Spec@ whose examples run against @site@ using the
-- snaplet built by @app@.  The snaplet is initialized once, with
-- @Just "test"@ passed to @getSnaplet@ (presumably the environment name),
-- and closed after all examples.  Every example starts with a @Success@
-- result and no-op before/after hooks.  Initialization failure calls
-- @error@ with the shown failure.
snap :: Handler b b () -> SnapletInit b b -> SpecWith (SnapHspecState b) -> Spec
snap site app spec = do
  initResult <- runIO $ getSnaplet (Just "test") app
  sessionVar <- runIO (newMVar [])
  either (error . show) (withSnaplet sessionVar) initResult
  where
    withSnaplet sessionVar (snaplet, initState) =
      let freshState = SnapHspecState Success site snaplet initState sessionVar (return ()) (return ())
      in afterAll (const $ closeSnaplet initState) $ before (return freshState) spec
-- | Transform the site handler used by every example in the given spec tree.
modifySite :: (Handler b b () -> Handler b b ())
           -> SpecWith (SnapHspecState b)
           -> SpecWith (SnapHspecState b)
modifySite f = beforeWith rewrap
  where
    rewrap (SnapHspecState res site snaplet initSt sessVar bef aft) =
      return (SnapHspecState res (f site) snaplet initSt sessVar bef aft)
-- | Transform the site handler from inside an example, then run @a@.  The
-- modified site is not restored afterwards; it stays in effect for the rest
-- of the example.
modifySite' :: (Handler b b () -> Handler b b ())
            -> SnapHspecM b a
            -> SnapHspecM b a
modifySite' f a = do
  SnapHspecState res site snaplet initSt sessVar bef aft <- S.get
  S.put $ SnapHspecState res (f site) snaplet initSt sessVar bef aft
  a
-- | Evaluate handler @h@ after each example in the spec tree.  An error from
-- the handler is printed to stdout rather than failing the example.
afterEval :: Handler b b () -> SpecWith (SnapHspecState b) -> SpecWith (SnapHspecState b)
afterEval h = after runAfter
  where
    runAfter (SnapHspecState _ _ snaplet initSt _ _ _) =
      evalHandlerSafe h snaplet initSt >>= either (liftIO . print) (const (return ()))
-- | Evaluate handler @h@ before each example in the spec tree, then pass
-- the state through unchanged.  As with @afterEval@, an error from the
-- handler is printed to stdout instead of failing the example; it is no
-- longer silently discarded.
beforeEval :: Handler b b () -> SpecWith (SnapHspecState b) -> SpecWith (SnapHspecState b)
beforeEval h = beforeWith (\state@(SnapHspecState _r _site s i _ _ _) ->
                             do res <- evalHandlerSafe h s i
                                case res of
                                  Right _ -> return ()
                                  Left msg -> liftIO $ print msg
                                return state)
-- | Applications whose session manager can be reached from the base
-- snaplet; required by @recordSession@.
class HasSession b where
  -- | Lens from the base snaplet to its @SessionManager@.
  getSessionLens :: SnapletLens b SessionManager
-- | Run @a@ with session contents carried between requests.  While @a@ runs,
-- the before/after hooks are replaced (not wrapped): the before hook writes
-- the pairs stored in the session MVar into the session and commits it; the
-- after hook reads the session back into the MVar.  Afterwards the MVar is
-- reset to an empty list.  The state is then restored with the original hooks
-- and site, keeping only the result recorded by @a@.
recordSession :: HasSession b => SnapHspecM b a -> SnapHspecM b a
recordSession a =
  do (SnapHspecState r site s i mv bef aft) <- S.get
     S.put (SnapHspecState r site s i mv
                             -- before hook: replay stored pairs into the session
                             (do ps <- liftIO $ readMVar mv
                                 with getSessionLens $ mapM_ (uncurry setInSession) ps
                                 with getSessionLens commitSession)
                             -- after hook: overwrite the MVar with current session contents
                             (do ps' <- with getSessionLens sessionToList
                                 void . liftIO $ takeMVar mv
                                 liftIO $ putMVar mv ps'))
     res <- a
     -- Keep the result recorded by @a@; everything else is restored below.
     (SnapHspecState r' _ _ _ _ _ _) <- S.get
     void . liftIO $ takeMVar mv
     liftIO $ putMVar mv []
     S.put (SnapHspecState r' site s i mv bef aft)
     return res
-- | All recorded session pairs flattened into one @Text@, each key
-- immediately followed by its value.  Used for substring assertions.
sessContents :: SnapHspecM b Text
sessContents = do
  SnapHspecState _ _ _ _ sessVar _ _ <- S.get
  pairs <- liftIO (readMVar sessVar)
  return (T.concat [T.append key val | (key, val) <- pairs])
-- | Assert that @t@ occurs in the flattened recorded session contents.
sessionShouldContain :: Text -> SnapHspecM b ()
sessionShouldContain t = do
  contents <- sessContents
  setResult $
    if t `T.isInfixOf` contents
      then Success
      else Failure Nothing $ Reason $ "Session did not contain: " ++ T.unpack t
                                        ++ "\n\nSession was:\n" ++ T.unpack contents
-- | Assert that @t@ does not occur in the flattened recorded session contents.
sessionShouldNotContain :: Text -> SnapHspecM b ()
sessionShouldNotContain t = do
  contents <- sessContents
  setResult $
    if not (t `T.isInfixOf` contents)
      then Success
      else Failure Nothing $ Reason $ "Session should not have contained: " ++ T.unpack t
                                        ++ "\n\nSession was:\n" ++ T.unpack contents
-- | Issue a DELETE request for @path@ with no parameters.
delete :: Text -> SnapHspecM b TestResponse
delete path = runRequest request
  where request = Test.delete (T.encodeUtf8 path) M.empty
-- | Issue a GET request for @path@ with no parameters.
get :: Text -> SnapHspecM b TestResponse
get path = get' path noParams
  where noParams = M.empty
-- | Issue a GET request for @path@ with the given parameters.
get' :: Text -> Snap.Params -> SnapHspecM b TestResponse
get' path ps = runRequest $ Test.get encodedPath ps
  where encodedPath = T.encodeUtf8 path
-- | Build request parameters from name/value pairs, one value per name.
-- With duplicate names, @M.fromList@ keeps the last value.
params :: [(ByteString, ByteString)]
       -> Snap.Params
params pairs = M.fromList [ (name, [value]) | (name, value) <- pairs ]
-- | Issue a POST request for @path@ with url-encoded parameters.
post :: Text -> Snap.Params -> SnapHspecM b TestResponse
post path ps = runRequest request
  where request = Test.postUrlEncoded (T.encodeUtf8 path) ps
-- | POST the JSON encoding of @json@ to @path@ with content type
-- @application/json@.
postJson :: ToJSON tj => Text -> tj -> SnapHspecM b TestResponse
postJson path json = runRequest request
  where
    payload = toStrict (encode json)
    request = Test.postRaw (T.encodeUtf8 path) "application/json" payload
-- | Issue a PUT request for @path@ with content type
-- @application/x-www-form-urlencoded@; see @put'@.
put :: Text -> Snap.Params -> SnapHspecM b TestResponse
put path ps = put' path urlEncoded ps
  where urlEncoded = "application/x-www-form-urlencoded"
-- | Issue a PUT request for @path@ with the given content type and an empty
-- body.  The parameters are sent in the query string, not in the body.
put' :: Text -> Text -> Snap.Params -> SnapHspecM b TestResponse
put' path mime params' =
  runRequest (Test.put (T.encodeUtf8 path) (T.encodeUtf8 mime) "" >> Test.setQueryString params')
-- | Narrow an @Html@ response body to the serialized markup matching the CSS
-- @selector@ (empty text when nothing matches).  Other responses are
-- returned unchanged.
restrictResponse :: Text -> TestResponse -> TestResponse
restrictResponse selector (Html code body) =
  Html code . T.concat . map T.pack $ HXT.runLA extract (T.unpack body)
  where extract = HXT.xshow $ HXT.hread HXT.>>> HS.css (T.unpack selector)
restrictResponse _ other = other
-- | Run @act@ as a handler against the current snaplet, wrapped in the
-- current before/after hooks, and return its result.
--
-- If the handler throws, this raises (via @error@) as soon as @eval@ runs.
-- Previously the error was returned as a lazy thunk, so a failing handler
-- whose result was discarded (e.g. @eval h@ with @h :: Handler b b ()@)
-- was silently ignored.
eval :: Handler b b a -> SnapHspecM b a
eval act = do
  SnapHspecState _ _site app is _mv bef aft <- S.get
  outcome <- liftIO $ evalHandlerSafe (do bef
                                          r <- act
                                          aft
                                          return r) app is
  -- Matching on the Either here forces the failure into the StateT action.
  either (error . T.unpack) return outcome
-- | Record a result for the current example.  The new result is stored only
-- while the current one is still @Success@, so the first non-success
-- result sticks.
#if MIN_VERSION_hspec(2,5,0)
setResult :: ResultStatus -> SnapHspecM b ()
#else
setResult :: Result -> SnapHspecM b ()
#endif
setResult new = do
  SnapHspecState current site snaplet initSt sessVar bef aft <- S.get
  case current of
    Success -> S.put $ SnapHspecState new site snaplet initSt sessVar bef aft
    _       -> return ()
-- | Assert that running @act@ changes the value computed by handler @v@
-- from @x@ to @f x@.
shouldChange :: (Show a, Eq a)
             => (a -> a)
             -> Handler b b a
             -> SnapHspecM b c
             -> SnapHspecM b ()
shouldChange f v act = do
  prior <- eval v
  _ <- act
  current <- eval v
  f prior `shouldEqual` current
-- | Assert that the two values are equal.
shouldEqual :: (Show a, Eq a)
            => a
            -> a
            -> SnapHspecM b ()
shouldEqual a b
  | a == b    = setResult Success
  | otherwise = setResult (Failure Nothing $ Reason ("Should have held: " ++ show a ++ " == " ++ show b))
-- | Assert that the two values differ.
shouldNotEqual :: (Show a, Eq a)
               => a
               -> a
               -> SnapHspecM b ()
shouldNotEqual a b
  | a == b    = setResult (Failure Nothing $ Reason ("Should not have held: " ++ show a ++ " == " ++ show b))
  | otherwise = setResult Success
-- | Assert that the value is @True@.
shouldBeTrue :: Bool
             -> SnapHspecM b ()
shouldBeTrue b = setResult $
  if b then Success else Failure Nothing $ Reason "Value should have been True."
-- | Assert that the value is @False@.  The failure message now says False;
-- it previously said "should have been True".
shouldNotBeTrue :: Bool
                 -> SnapHspecM b ()
shouldNotBeTrue False = setResult Success
shouldNotBeTrue True = setResult (Failure Nothing $ Reason "Value should have been False.")
-- | Assert a successful (200) response: any @Html@, or @Json@/@Other@ with
-- code 200.
should200 :: TestResponse -> SnapHspecM b ()
should200 r = setResult $ case r of
  Html _ _   -> Success
  Json 200 _ -> Success
  Other 200  -> Success
  _          -> Failure Nothing $ Reason (show r)
-- | Assert that the response is not one @should200@ accepts: any @Html@, or
-- @Json@/@Other@ with code 200.  The @Json 200@ case was previously missing,
-- so a successful JSON response wrongly passed.
shouldNot200 :: TestResponse -> SnapHspecM b ()
shouldNot200 (Html _ _) = setResult (Failure Nothing $ Reason "Got Html back.")
shouldNot200 (Json 200 _) = setResult (Failure Nothing $ Reason "Got Json with 200 back.")
shouldNot200 (Other 200) = setResult (Failure Nothing $ Reason "Got Other with 200 back.")
shouldNot200 _ = setResult Success
-- | Assert a @NotFound@ response.
should404 :: TestResponse -> SnapHspecM b ()
should404 r = setResult $ case r of
  NotFound -> Success
  _        -> Failure Nothing $ Reason (show r)
-- | Assert that the response is not @NotFound@.
shouldNot404 :: TestResponse -> SnapHspecM b ()
shouldNot404 r = setResult $ case r of
  NotFound -> Failure Nothing $ Reason "Got NotFound back."
  _        -> Success
-- | Assert a @Redirect@ response with any 3xx code.
should300 :: TestResponse -> SnapHspecM b ()
should300 r = setResult $ case r of
  Redirect _ _ -> Success
  _            -> Failure Nothing $ Reason (show r)
-- | Assert that the response is not a @Redirect@.
shouldNot300 :: TestResponse -> SnapHspecM b ()
shouldNot300 r = setResult $ case r of
  Redirect _ _ -> Failure Nothing $ Reason "Got Redirect back."
  _            -> Success
-- | Assert a @Redirect@ whose target starts with @pth@.
should300To :: Text -> TestResponse -> SnapHspecM b ()
should300To pth r = setResult $ case r of
  Redirect _ to | pth `T.isPrefixOf` to -> Success
  _ -> Failure Nothing $ Reason (show r)
-- | Assert that the response is not a @Redirect@ whose target starts with
-- @pth@.
shouldNot300To :: Text -> TestResponse -> SnapHspecM b ()
shouldNot300To pth r = setResult $ case r of
  Redirect _ to | pth `T.isPrefixOf` to -> Failure Nothing $ Reason "Got Redirect back."
  _ -> Success
-- | Assert that an HTML response contains an element matching the CSS
-- @selector@.  Non-HTML responses always fail.
shouldHaveSelector :: Text -> TestResponse -> SnapHspecM b ()
shouldHaveSelector selector r@(Html _ body)
  | haveSelector' selector r = setResult Success
  | otherwise                = setResult (Failure Nothing $ Reason failure)
  where failure = T.unpack $ T.concat ["Html should have contained selector: ", selector, "\n\n", body]
shouldHaveSelector match _ =
  setResult (Failure Nothing $ Reason (T.unpack $ T.concat ["Non-HTML body should have contained css selector: ", match]))
-- | Assert that an HTML response has no element matching the CSS
-- @selector@.  Non-HTML responses always pass.
shouldNotHaveSelector :: Text -> TestResponse -> SnapHspecM b ()
shouldNotHaveSelector selector r@(Html _ body)
  | haveSelector' selector r = setResult (Failure Nothing $ Reason failure)
  | otherwise                = setResult Success
  where failure = T.unpack $ T.concat ["Html should not have contained selector: ", selector, "\n\n", body]
shouldNotHaveSelector _ _ = setResult Success
-- | True when the response is @Html@ and at least one element matches the
-- CSS @selector@.
haveSelector' :: Text -> TestResponse -> Bool
haveSelector' selector (Html _ body) =
  not . null $ HXT.runLA (HXT.hread HXT.>>> HS.css (T.unpack selector)) (T.unpack body)
haveSelector' _ _ = False
-- | Assert that an HTML response body contains @match@ as a substring.
-- Non-HTML responses always fail.  The failure messages previously claimed
-- the body "contains" the text, which is the opposite of what happened; they
-- now match the wording of the selector assertions.
shouldHaveText :: Text -> TestResponse -> SnapHspecM b ()
shouldHaveText match (Html _ body)
  | match `T.isInfixOf` body = setResult Success
  | otherwise = setResult (Failure Nothing $ Reason $ T.unpack $
                             T.concat ["Html should have contained text: ", match, "\n\n", body])
shouldHaveText match _ =
  setResult (Failure Nothing $ Reason (T.unpack $ T.concat ["Non-HTML body should have contained text: ", match]))
-- | Assert that an HTML response body does not contain @match@.  Non-HTML
-- responses always pass.  The failure message was previously missing its
-- opening quote; it now matches the wording of the selector assertions.
shouldNotHaveText :: Text -> TestResponse -> SnapHspecM b ()
shouldNotHaveText match (Html _ body)
  | match `T.isInfixOf` body = setResult (Failure Nothing $ Reason $ T.unpack $
                                 T.concat ["Html should not have contained text: ", match, "\n\n", body])
  | otherwise = setResult Success
shouldNotHaveText _ _ = setResult Success
-- | What @form@ should expect from submitting a form.
data FormExpectations a = Value a
                          -- ^ The form validates to exactly this value.
                        | Predicate (a -> Bool)
                          -- ^ The form validates, and the predicate holds on the value.
                        | ErrorPaths [Text]
                          -- ^ Exactly these error paths are reported (compared as a set by membership and count).
-- | Submit @theParams@ to @theForm@ (posted under the name @form@, so each
-- key is looked up with a @form.@ prefix) and check the outcome against
-- @expected@.
form :: (Eq a, Show a)
     => FormExpectations a           -- ^ What the submission should produce
     -> DF.Form Text (Handler b b) a -- ^ The form under test
     -> M.Map Text Text              -- ^ Submitted values keyed by path without the @form.@ prefix
     -> SnapHspecM b ()
form expected theForm theParams = do
  -- Lazy pattern: the pair is only forced when a branch inspects it.
  ~(formView, formResult) <- eval $ DF.postForm "form" theForm (const $ return lookupParam)
  case expected of
    Value a -> shouldEqual formResult (Just a)
    Predicate f -> checkPredicate f formView formResult
    ErrorPaths expectedPaths ->
      checkErrorPaths expectedPaths (map (DF.fromPath . fst) (DF.viewErrors formView))
  where
    lookupParam pth =
      return $ maybe [] (\v -> [DF.TextInput v]) (M.lookup (DF.fromPath pth) prefixedParams)
    prefixedParams = M.mapKeys (T.append "form.") theParams
    failWith msg = setResult (Failure Nothing (Reason msg))
    checkPredicate f formView formResult = case formResult of
      Nothing -> failWith . T.unpack $
        T.append "Expected form to validate. Resulted in errors: "
                 (T.pack (show (DF.viewErrors formView)))
      Just v
        | f v       -> setResult Success
        | otherwise -> failWith . T.unpack $
            T.append "Expected predicate to pass on value: " (T.pack (show v))
    checkErrorPaths expectedPaths actualPaths
      | not (all (`elem` actualPaths) expectedPaths) =
          failWith $ "Did not have all errors specified. Got:\n\n"
                  ++ show actualPaths
                  ++ "\n\nBut expected:\n\n"
                  ++ show expectedPaths
      | length actualPaths /= length expectedPaths =
          failWith $ "Number of errors did not match test. Got:\n\n "
                  ++ show actualPaths
                  ++ "\n\nBut expected:\n\n"
                  ++ show expectedPaths
      | otherwise = setResult Success
-- | Run a request against the site, wrapped in the before/after hooks, and
-- classify the response.  Classification: 404 gives @NotFound@; 200 is parsed
-- by @parse200@; 3xx gives @Redirect@ with the Location header (empty if
-- missing); anything else gives @Other@.  A handler exception calls @error@.
runRequest :: RequestBuilder IO () -> SnapHspecM b TestResponse
runRequest req = do
  SnapHspecState _ site app is _ bef aft <- S.get
  outcome <- liftIO $ runHandlerSafe req (bef >> site >> aft) app is
  case outcome of
    Left err       -> error (T.unpack err)
    Right response -> classify (respStatus response) response
  where
    classify 404 _        = return NotFound
    classify 200 response = liftIO (parse200 response)
    classify code response
      | code >= 300 && code < 400 =
          return . Redirect code . T.decodeUtf8 . fromMaybe "" $ getHeader "Location" response
      | otherwise = return (Other code)
-- | The status code of a response.
respStatus :: Response -> RespCode
respStatus resp = RespCode (rspStatus resp)
-- | Read the body of a 200 response.  It becomes @Json@ when the
-- content-type media type is @application/json@, and @Html@ (UTF-8 decoded)
-- otherwise.
--
-- The media type is compared case-insensitively with any parameters
-- removed, so @application/json; charset=utf-8@ is recognized.  Previously
-- only the exact string matched, and such responses were misclassified as
-- Html.  The header is decoded as Latin-1, which cannot fail on any bytes.
parse200 :: Response -> IO TestResponse
parse200 resp
  | isJson    = Json 200 . fromStrict <$> body
  | otherwise = Html 200 . T.decodeUtf8 <$> body
  where
    body = getResponseBody resp
    mediaType = T.toLower . T.strip . T.takeWhile (/= ';') . T.decodeLatin1
    isJson = maybe False ((== "application/json") . mediaType) (getHeader "content-type" resp)
-- | Run a request through a handler, turning any exception into @Left@ with
-- its shown text.
runHandlerSafe :: RequestBuilder IO ()
               -> Handler b b v
               -> Snaplet b
               -> InitializerState b
               -> IO (Either Text Response)
runHandlerSafe req site s is = runHandler' s is req site `catch` asLeft
  where
    asLeft :: SomeException -> IO (Either Text Response)
    asLeft e = return (Left (T.pack (show e)))
-- | Evaluate a handler for its result, using an empty GET request as the
-- request context.  Any exception becomes @Left@ with its shown text.
evalHandlerSafe :: Handler b b v
                -> Snaplet b
                -> InitializerState b
                -> IO (Either Text v)
evalHandlerSafe act s is = evalHandler' s is (Test.get "" M.empty) act `catch` asLeft
  where
    asLeft :: SomeException -> IO (Either Text r)
    asLeft e = return (Left (T.pack (show e)))
{-# ANN put ("HLint: ignore Eta reduce"::String)                            #-}