{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE OverloadedStrings #-}
module Servant.Server.NamedArgs where
import Control.Monad (join)
import Data.Either (partitionEithers)
import Data.Functor.Identity (Identity)
import Data.Maybe (mapMaybe)
import Data.Proxy (Proxy(..))
import Data.String (IsString(..))
import GHC.TypeLits (KnownSymbol, symbolVal)

import Data.String.Conversions (cs)
import Data.Text (Text)
import qualified Data.Text as T
import Named ((:!), (:?), arg, argF, Name(..), (!), NamedF(..))
import Named.Internal (pattern Arg)
import Network.HTTP.Types (parseQueryText)
import Network.Wai (rawQueryString, Request, requestHeaders)
import Servant.API ( (:>), SBoolI, FromHttpApiData, toQueryParam, toUrlPiece
                   , parseQueryParam, parseHeader )
import Servant.API.Modifiers (FoldRequired, FoldLenient)
import Servant.Server (HasServer(..), errBody, err400)
import Servant.Server.Internal ( passToServer, addParameterCheck, withRequest, delayedFailFatal
                               , Router'(..), addCapture, delayedFail, DelayedIO, addHeaderCheck)
import Web.HttpApiData (parseUrlPiece, parseUrlPieceMaybe, parseUrlPieces)

import Servant.API.NamedArgs ( foldRequiredNamedArgument, NamedCapture', NamedFlag
                             , NamedParam, NamedParams, RequiredNamedArgument
                             , NamedCaptureAll, RequestNamedArgument, NamedHeader'
                             , unfoldRequestNamedArgument)
-- | Serve a 'NamedFlag': the handler receives @name :! Bool@.
--
-- The flag is read from the first query-string entry whose key equals the
-- type-level @name@:
--
--   * key absent                         -> 'False'
--   * key present with no value (@?f@)   -> 'True'
--   * key present with a value           -> 'True' only for the exact,
--     case-sensitive values @"true"@, @"1"@ or @""@; anything else is 'False'.
instance (KnownSymbol name, HasServer api context)
  => HasServer (NamedFlag name :> api) context where
  type ServerT (NamedFlag name :> api) m =
    (name :! Bool) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
    route (Proxy @api) context (passToServer subserver (Arg . flagSet))
    where
      flagKey :: Text
      flagKey = cs (symbolVal (Proxy @name))

      -- Outer 'maybe': is the key there at all? Inner 'maybe': does it carry
      -- a value? A bare key counts as "set".
      flagSet :: Request -> Bool
      flagSet = maybe False (maybe True truthy)
              . lookup flagKey
              . parseQueryText
              . rawQueryString

      -- Exact, case-sensitive spellings accepted as "on".
      truthy :: Text -> Bool
      truthy v = v `elem` ["true", "1", ""]
-- | Serve a 'NamedCapture'' path segment: the segment is parsed with
-- 'parseUrlPiece' and passed to the handler as @name :! a@.
--
-- A segment that fails to parse rejects the route with a 400 whose body names
-- the capture and includes the parser's error message. The failure uses
-- 'delayedFail' rather than 'delayedFailFatal' (which the query/header
-- parsers use), presumably so that alternative routes can still be tried.
--
-- NOTE(review): the @mods@ index is not consulted here (no lenient/optional
-- handling); the handler always receives a plain @a@.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
  => HasServer (NamedCapture' mods name a :> api) context where
  type ServerT (NamedCapture' mods name a :> api) m =
    (name :! a) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context d =
    CaptureRouter $
      route (Proxy @api) context $
        addCapture d $ \txt -> case parseUrlPiece txt of
          Left e -> delayedFail err400
            { errBody = cs $ "Error parsing capture "
                <> captureName
                <> " failed: " <> e
            }
          Right v -> pure (Arg v)
    where
      captureName :: Text
      captureName = cs (symbolVal (Proxy @name))
-- | Serve a 'NamedCaptureAll': every remaining path segment is parsed with
-- 'parseUrlPieces' and the handler receives them as @name :! [a]@.
--
-- If any segment fails to parse, the route is rejected with a 400 whose body
-- names the capture and includes the parser's error message. As with single
-- captures, this is a 'delayedFail' rather than a 'delayedFailFatal' —
-- presumably so alternative routes can still be tried.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
  => HasServer (NamedCaptureAll name a :> api) context where
  type ServerT (NamedCaptureAll name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context d =
    CaptureAllRouter $
      route (Proxy @api) context $
        addCapture d $ \txts -> case parseUrlPieces txts of
          Left e -> delayedFail err400
            { errBody = cs $ "Error parsing capture "
                <> captureName
                <> " failed: " <> e
            }
          Right v -> pure (Arg v)
    where
      captureName :: Text
      captureName = cs (symbolVal (Proxy @name))
-- | Serve 'NamedParams': the handler receives @name :! [a]@ holding every
-- value supplied for the query key @name@ (or its @name[]@ spelling), in
-- query-string order.
--
-- Keys that appear without a value are skipped. Every value is parsed with
-- 'parseQueryParam'. If any value fails, the request is rejected with a fatal
-- 400 whose body lists all parse errors, comma-separated.
instance (KnownSymbol name, FromHttpApiData a, HasServer api context)
  => HasServer (NamedParams name a :> api) context where
  type ServerT (NamedParams name a :> api) m =
    (name :! [a]) -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
    route (Proxy @api) context (addParameterCheck subserver (withRequest check))
    where
      key :: Text
      key = cs (symbolVal (Proxy @name))

      -- Accept both the plain key and the bracketed array-style key.
      matchesKey :: Text -> Bool
      matchesKey k = k == key || k == (key <> "[]")

      -- Values of all matching entries; valueless entries are dropped.
      valuesIn :: Request -> [Text]
      valuesIn = mapMaybe snd
               . filter (matchesKey . fst)
               . parseQueryText
               . rawQueryString

      check req = case partitionEithers (parseQueryParam <$> valuesIn req) of
        ([], parsed) -> pure (Arg parsed)
        (errs, _) -> delayedFailFatal err400
          { errBody = cs $ "Error parsing query parameter(s) "
              <> key <> " failed: "
              <> T.intercalate ", " errs
          }
-- | Serve a 'NamedHeader'' request header.
--
-- The header named by the type-level @name@ is looked up in the request's
-- headers and, if present, parsed with 'parseHeader'. The resulting
-- @Maybe (Either Text a)@ is handed to 'unfoldRequestNamedArgument' together
-- with two failure actions (both fatal 400s). That function builds the
-- handler's @RequestNamedArgument mods name a@. Which failure fires and when
-- is decided there, presumably from the @Required@/@Lenient@ flags in @mods@
-- (confirm in "Servant.API.NamedArgs").
instance ( KnownSymbol name
         , FromHttpApiData a
         , HasServer api context
         , SBoolI (FoldRequired mods)
         , SBoolI (FoldLenient mods)
         ) => HasServer (NamedHeader' mods name a :> api) context where
  type ServerT (NamedHeader' mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
    route (Proxy @api) context (addHeaderCheck subserver (withRequest headerCheck))
    where
      -- Deliberately polymorphic: the same name is used as the key for
      -- 'lookup' in the header list, inside 'errBody', and concatenated with
      -- the 'Text' parse error.
      headerName :: IsString s => s
      headerName = fromString (symbolVal (Proxy @name))

      -- Failure whose message reports the header as required (presumably
      -- chosen when the header is absent but required).
      missing = delayedFailFatal err400
        { errBody = "Header " <> headerName <> " is required"
        }

      -- Failure carrying the parser's message @e@ for an unparsable value.
      malformed e = delayedFailFatal err400
        { errBody = cs $ "Error parsing header "
            <> headerName
            <> " failed: " <> e
        }

      headerCheck :: Request -> DelayedIO (RequestNamedArgument mods name a)
      headerCheck req =
        let found :: Maybe (Either Text a)
            found = parseHeader <$> lookup headerName (requestHeaders req)
        in unfoldRequestNamedArgument @mods @name missing malformed found
-- | Serve a 'NamedParam' query parameter.
--
-- The first query entry whose key equals the type-level @name@ is read. An
-- entry without a value (@?name@) is treated exactly like an absent one. A
-- present value is parsed with 'parseQueryParam', and the resulting
-- @Maybe (Either Text a)@ is handed to 'unfoldRequestNamedArgument' along with
-- two fatal-400 failure actions to build the handler's
-- @RequestNamedArgument mods name a@. Which failure fires is decided there,
-- presumably from the @Required@/@Lenient@ flags in @mods@ (confirm in
-- "Servant.API.NamedArgs").
instance ( KnownSymbol name
         , FromHttpApiData a
         , HasServer api context
         , SBoolI (FoldRequired mods)
         , SBoolI (FoldLenient mods)
         ) => HasServer (NamedParam mods name a :> api) context where
  type ServerT (NamedParam mods name a :> api) m =
    RequestNamedArgument mods name a -> ServerT api m

  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy @api) pc nt . s

  route Proxy context subserver =
    route (Proxy @api) context (addParameterCheck subserver (withRequest parseParam))
    where
      paramName :: Text
      paramName = cs (symbolVal (Proxy @name))

      -- 'join' collapses "key present without a value" into "absent".
      rawValue :: Request -> Maybe Text
      rawValue = join . lookup paramName . parseQueryText . rawQueryString

      -- Failure whose message reports the parameter as required.
      missing = delayedFailFatal err400
        { errBody = cs $ "Query parameter " <> paramName <> " is required"
        }

      -- Failure carrying the parser's message @e@ for an unparsable value.
      malformed e = delayedFailFatal err400
        { errBody = cs $ "Error parsing query parameter "
            <> paramName
            <> " failed: " <> e
        }

      parseParam :: Request -> DelayedIO (RequestNamedArgument mods name a)
      parseParam req =
        let found :: Maybe (Either Text a)
            found = parseQueryParam <$> rawValue req
        in unfoldRequestNamedArgument @mods @name missing malformed found