{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE OverloadedStrings     #-}

-- | Provides 'HasServer' instances that make the combinators from
-- "Servant.API.NamedArgs" usable with "Servant.Server": handlers receive
-- flags, captures, query parameters, headers and request bodies as named
-- parameters from "Named".
module Servant.Server.NamedArgs where

import Control.Monad (join)
import Control.Monad.IO.Class (liftIO)
import Data.Either (partitionEithers)
import Data.Functor.Identity (Identity)
import Data.Maybe (mapMaybe, fromMaybe)
import Data.Proxy (Proxy(..))
import Data.String (IsString(..))
import GHC.TypeLits (KnownSymbol, symbolVal)

import qualified Data.ByteString.Lazy as BL
import Data.String.Conversions (cs)
import Data.Text (Text)
import qualified Data.Text as T
import Named ((:!), (:?), arg, argF, Name(..), (!), NamedF(..))
import Named.Internal (pattern Arg)
import Network.HTTP.Types (parseQueryText, hContentType)
import Network.Wai (rawQueryString, Request, requestHeaders, lazyRequestBody)
import Servant.API ( (:>), SBoolI, FromHttpApiData, toQueryParam, toUrlPiece
                   , parseQueryParam, parseHeader, If, SBool(..), sbool)
import Servant.API.ContentTypes (AllCTUnrender(..))
import Servant.API.Modifiers (FoldRequired, FoldLenient)
import Servant.API.NamedArgs ( foldRequiredNamedArgument, NamedCapture', NamedFlag
                             , NamedParam, NamedParams, RequiredNamedArgument
                             , NamedCaptureAll, RequestNamedArgument, NamedHeader'
                             , unfoldRequestNamedArgument, NamedBody')
import Servant.Server (HasServer(..), errBody, err400, err415)
import Servant.Server.Internal ( passToServer, addParameterCheck, withRequest, delayedFailFatal
                               , Router'(..), addCapture, delayedFail, DelayedIO, addHeaderCheck
                               , addBodyCheck)
import Web.HttpApiData (parseUrlPiece, parseUrlPieceMaybe, parseUrlPieces)

-- | A 'NamedFlag' becomes a required named 'Bool' argument of the handler.
--
-- The flag is 'True' when the query string contains the key with no value,
-- or with the value @true@, @1@ or an empty value. It is 'False' when the key
-- is absent or carries any other value. Only the first occurrence of the key
-- is consulted.
instance (KnownSymbol name, HasServer api context)
    => HasServer (NamedFlag name :> api) context where

  type ServerT (NamedFlag name :> api) m
      = (name :! Bool) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
      route (Proxy @api) context (passToServer subserver (Arg . flagValue))
    where
      flagName :: Text
      flagName = cs (symbolVal (Proxy @name))

      -- Outer Maybe: is the key present at all?
      -- Inner Maybe: was a value given after it?
      flagValue req =
          maybe False (maybe True isTruthy)
            (lookup flagName (parseQueryText (rawQueryString req)))

      isTruthy v = v `elem` ["true", "1", ""]

-- | A 'NamedCapture'' becomes a required named argument of the handler,
-- holding the path segment decoded with 'parseUrlPiece'.
--
-- If the segment cannot be decoded, the route fails via 'delayedFail' (the
-- recoverable variant, not 'delayedFailFatal'). The 400 body names the
-- capture and includes the parser's error message, in the same style as the
-- header and query-parameter instances in this module.
--
-- NOTE(review): @mods@ is not consulted here. A lenient capture is still
-- handed to the handler as a plain @a@, which matches the 'ServerT' type below.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
      => HasServer (NamedCapture' mods name a :> api) context where

  type ServerT (NamedCapture' mods name a :> api) m =
     (name :! a) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context d =
    CaptureRouter $
        route (Proxy @api)
              context
              (addCapture d $ \ txt -> case parseUrlPiece txt of
                 Left e  -> delayedFail err400
                   { errBody = cs $ "Error parsing capture "
                                    <> captureName
                                    <> " failed: " <> e
                   }
                 Right v -> pure $ Arg v
              )
    where
      captureName :: Text
      captureName = cs $ symbolVal (Proxy @name)

-- | A 'NamedCaptureAll' becomes a required named argument of the handler,
-- holding every remaining path segment decoded with 'parseUrlPieces'.
--
-- If any segment cannot be decoded, the route fails via 'delayedFail' (the
-- recoverable variant). The 400 body names the capture and includes the
-- error reported by 'parseUrlPieces'; previously that message was discarded.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
      => HasServer (NamedCaptureAll name a :> api) context where

  type ServerT (NamedCaptureAll name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context d =
      CaptureAllRouter $
          route (Proxy @api)
                context
                (addCapture d $ \ txts -> case parseUrlPieces txts of
                   Left e  -> delayedFail err400
                     { errBody = cs $ "Error parsing capture(s) "
                                      <> captureName
                                      <> " failed: " <> e
                     }
                   Right v -> pure $ Arg v
                )
    where
      captureName :: Text
      captureName = cs $ symbolVal (Proxy @name)

-- | A 'NamedParams' becomes a required named argument holding a list.
--
-- Every query-string entry whose key is the parameter name, or the name with
-- @[]@ appended, and which carries a value is decoded with 'parseQueryParam'.
-- Order is preserved, and keys given without a value are skipped. An absent
-- parameter yields an empty list. If any value fails to decode, the request is
-- rejected fatally with a 400 that lists every parse error.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
    => HasServer (NamedParams name a :> api) context where

  type ServerT (NamedParams name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
      route (Proxy @api) context (addParameterCheck subserver (withRequest collect))
    where
      key :: Text
      key = cs (symbolVal (Proxy @name))

      collect req =
          let rawValues = [ v | (k, Just v) <- parseQueryText (rawQueryString req)
                              , k == key || k == key <> "[]" ]
          in case partitionEithers (map parseQueryParam rawValues) of
               ([], parsed) -> pure (Arg parsed)
               (errs, _)    -> delayedFailFatal err400
                   { errBody = cs $ "Error parsing query parameter(s) "
                                    <> key <> " failed: "
                                    <> T.intercalate ", " errs
                   }

-- | A 'NamedHeader'' becomes a handler argument of type
-- @'RequestNamedArgument' mods name a@. The modifiers in @mods@ decide whether
-- it is required or optional ('Servant.API.Modifiers.Required' /
-- 'Servant.API.Modifiers.Optional') and whether it holds @a@ or
-- @'Either' 'Text' a@ ('Servant.API.Modifiers.Strict' /
-- 'Servant.API.Modifiers.Lenient').
--
-- The header value is decoded with 'parseHeader'. Two fatal 400 responses are
-- built here, one for a missing header and one for a parse error.
-- 'unfoldRequestNamedArgument' chooses which, if any, to use based on @mods@.
instance ( KnownSymbol name
         , FromHttpApiData a
         , HasServer api context
         , SBoolI (FoldRequired mods)
         , SBoolI (FoldLenient mods)
         ) => HasServer (NamedHeader' mods name a :> api) context where

  type ServerT (NamedHeader' mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
      route (Proxy @api) context (addHeaderCheck subserver (withRequest readHeader))
    where
      -- Kept polymorphic so one definition serves both as the case-insensitive
      -- lookup key and as text inside the error bodies.
      hdr :: IsString s => s
      hdr = fromString (symbolVal (Proxy @name))

      readHeader :: Request -> DelayedIO (RequestNamedArgument mods name a)
      readHeader req =
          unfoldRequestNamedArgument @mods @name missing malformed parsed
        where
          parsed :: Maybe (Either T.Text a)
          parsed = parseHeader <$> lookup hdr (requestHeaders req)

          missing = delayedFailFatal err400
            { errBody = "Header " <> hdr <> " is required" }

          malformed e = delayedFailFatal err400
            { errBody = cs ("Error parsing header " <> hdr <> " failed: " <> e) }

-- | A 'NamedParam' becomes a handler argument of type
-- @'RequestNamedArgument' mods name a@. The modifiers in @mods@ decide whether
-- it is required or optional ('Servant.API.Modifiers.Required' /
-- 'Servant.API.Modifiers.Optional') and whether it holds @a@ or
-- @'Either' 'Text' a@ ('Servant.API.Modifiers.Strict' /
-- 'Servant.API.Modifiers.Lenient').
--
-- Only the first occurrence of the key is consulted, and a key given without
-- a value counts as absent. The value is decoded with 'parseQueryParam'.
-- 'unfoldRequestNamedArgument' chooses, based on @mods@, whether to use the
-- fatal 400 responses built here for a missing parameter or a parse error.
instance ( KnownSymbol name
         , FromHttpApiData a
         , HasServer api context
         , SBoolI (FoldRequired mods)
         , SBoolI (FoldLenient mods)
         ) => HasServer (NamedParam mods name a :> api) context where

  type ServerT (NamedParam mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
      route (Proxy @api) context (subserver `addParameterCheck` withRequest parseParam)
    where
      key :: Text
      key = cs $ symbolVal (Proxy @name)

      parseParam :: Request -> DelayedIO (RequestNamedArgument mods name a)
      parseParam req =
          unfoldRequestNamedArgument @mods @name missing malformed found
        where
          -- 'join' folds "key present but no value" into 'Nothing', so such a
          -- parameter is treated exactly like an absent one.
          found :: Maybe (Either T.Text a)
          found = parseQueryParam <$> join (lookup key (parseQueryText (rawQueryString req)))

          missing = delayedFailFatal err400
            { errBody = cs $ "Query parameter " <> key <> " is required"
            }

          malformed e = delayedFailFatal err400
            { errBody = cs $ "Error parsing query parameter "
                             <> key
                             <> " failed: " <> e
            }

-- | A 'NamedBody'' becomes a required named argument holding the decoded
-- request body.
--
-- The request's @Content-Type@ selects a decoder from the content types in
-- @list@ and defaults to @application/octet-stream@ when the header is absent.
-- If no decoder matches, the route fails with 415 via 'delayedFail'.
-- With a lenient modifier, the decoder's @'Either' 'String' a@ result is
-- passed to the handler unchanged. Otherwise a decoding failure is rejected
-- fatally with a 400 whose body is the decoder's error message.
instance ( KnownSymbol name, AllCTUnrender list a, HasServer api context
         , SBoolI (FoldLenient mods)
         ) => HasServer (NamedBody' mods name list a :> api) context where

  type ServerT (NamedBody' mods name list a :> api) m =
    (name :! (If (FoldLenient mods) (Either String a) a)) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route _ context subserver =
      route (Proxy @api) context (addBodyCheck subserver pickDecoder decodeBody)
    where
      -- Pick a decoder for the request's Content-Type. A mismatch is a
      -- recoverable 415 ('delayedFail'), not a fatal error.
      pickDecoder = withRequest $ \req ->
        let ctype = fromMaybe "application/octet-stream"
                      (lookup hContentType (requestHeaders req))
            decoder :: Maybe (BL.ByteString -> Either String a)
            decoder = canHandleCTypeH (Proxy @list) (cs ctype)
        in maybe (delayedFail err415) pure decoder

      -- Apply the chosen decoder to the body. The SBool match tells the type
      -- checker which branch of the 'If' in 'ServerT' is in effect.
      decodeBody decode = withRequest $ \req -> do
        result <- decode <$> liftIO (lazyRequestBody req)
        case sbool :: SBool (FoldLenient mods) of
          STrue  -> pure (Arg result)
          SFalse -> either (\e -> delayedFailFatal err400 { errBody = cs e })
                           (pure . Arg)
                           result