module Serv.Wai.Type (
Server (..)
, ServerResult (..)
, returnServer
, notFound
, badRequest
, methodNotAllowed
, orElse
, mapServer
, serverApplication
, serverApplication'
, serverApplication''
, defaultRoutingErrorResponse
, Context (..)
, makeContext
, Contextual (..)
) where
import Control.Monad.Morph
import Control.Monad.State
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as Sl
import qualified Data.CaseInsensitive as CI
import Data.IORef
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import Data.Set (Set)
import Data.Singletons
import Data.Singletons.TypeLits
import Data.String
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Network.HTTP.Kinder.Header (HeaderDecode, HeaderName,
Sing (SAllow), SomeHeaderName,
headerDecodeBS, headerEncodePair,
headerName, parseHeaderName)
import Network.HTTP.Kinder.Query (QueryDecode (..),
QueryKeyState (..))
import qualified Network.HTTP.Kinder.Status as St
import Network.HTTP.Kinder.Verb (Verb, parseVerb)
import Network.HTTP.Types.URI (queryToQueryText)
import Network.Wai
import Serv.Wai.Error (RoutingError)
import qualified Serv.Wai.Error as Error
-- | A server over base monad @m@: a stateful computation over the request
-- 'Context' that produces a 'ServerResult'.
newtype Server m = Server { runServer :: StateT Context m ServerResult }
-- | Build a 'Server' from an action in the base monad; the action neither
-- reads nor changes the request 'Context'.
returnServer :: Monad m => m ServerResult -> Server m
returnServer = Server . lift
-- | Change the base monad of a 'Server' by applying a natural
-- transformation underneath its 'StateT' layer.
mapServer :: Monad m => (forall x . m x -> n x) -> Server m -> Server n
mapServer phi = Server . hoist phi . runServer
-- | Try the first server. If it fails with an ignorable routing error, run
-- the second server from the original context. Otherwise keep the first
-- server's result together with the context it finished in.
orElse :: Monad m => Server m -> Server m -> Server m
orElse primary fallback = Server $ do
  -- 'fork' leaves the current context untouched and hands back the
  -- context the primary server ended with.
  (result, ctxAfter) <- fork (runServer primary)
  if fallsThrough result
    then runServer fallback
    else result <$ restore ctxAfter
  where
    fallsThrough (RoutingError e) = Error.ignorable e
    fallsThrough _ = False
-- | A server that always fails with a "not found" routing error.
notFound :: Monad m => Server m
notFound = Server . pure $ RoutingError Error.NotFound
-- | A server that always fails with "method not allowed", carrying the
-- set of verbs that would have been accepted.
methodNotAllowed :: Monad m => Set Verb -> Server m
methodNotAllowed = Server . pure . RoutingError . Error.MethodNotAllowed
-- | A server that always fails with "bad request" and an optional
-- explanatory message.
badRequest :: Monad m => Maybe String -> Server m
badRequest = Server . pure . RoutingError . Error.BadRequest
-- | Turn a 'Server' into a WAI 'Application', leaving every response
-- exactly as the server (or the default error renderer) produced it.
serverApplication :: Server IO -> Application
serverApplication server =
  serverApplication' server (\_ response -> response)
-- | Turn a 'Server' into a WAI 'Application'. Routing errors are rendered
-- with 'defaultRoutingErrorResponse'. Every response then passes through
-- @xform@, which also sees the final 'Context'.
serverApplication' :: Server IO -> (Context -> Response -> Response) -> Application
serverApplication' server xform = serverApplication'' server render
  where
    render ctx (RoutingError err) = xform ctx (defaultRoutingErrorResponse err)
    render ctx (WaiResponse resp) = xform ctx resp
    -- 'serverApplication''' dispatches 'Application' results itself and
    -- never passes them here.
    render _ (Application _) =
      error "Recieved 'Application' value in 'serverApplication'' impl"
-- | Most general way to run a 'Server' as a WAI 'Application'.
--
-- The steps are:
--
-- * build a 'Context' from the request;
-- * run the server;
-- * if the result is 'Application', forward the (context's) request to it;
-- * otherwise send the response built by @xform@ from the final context
--   and the result.
serverApplication''
  :: Server IO
  -> (Context -> ServerResult -> Response)
  -> Application
serverApplication'' server xform request respond =
  makeContext request >>= runStateT (runServer server) >>= finish
  where
    finish (Application app, finalCtx) = app finalCtx (ctxRequest finalCtx) respond
    finish (result, finalCtx) = respond (xform finalCtx result)
-- | Render a 'RoutingError' as a plain WAI 'Response'.
--
-- Each error gets its matching status code. A 'Error.BadRequest' message
-- (if any) becomes the UTF-8 encoded body. A 'Error.MethodNotAllowed'
-- response carries an @Allow@ header listing the accepted verbs.
defaultRoutingErrorResponse :: RoutingError -> Response
defaultRoutingErrorResponse err =
  case err of
    Error.NotFound ->
      responseLBS (St.httpStatus St.SNotFound) [] ""
    Error.BadRequest e ->
      responseLBS (St.httpStatus St.SBadRequest) [] (maybe "" utf8Body e)
    Error.UnsupportedMediaType ->
      responseLBS (St.httpStatus St.SUnsupportedMediaType) [] ""
    Error.MethodNotAllowed verbs ->
      responseLBS
        (St.httpStatus St.SMethodNotAllowed)
        (catMaybes [headerEncodePair SAllow verbs])
        ""
  where
    -- 'fromString' at lazy ByteString type truncates every Char to a single
    -- byte, which garbles non-ASCII messages. Go through 'Text' and encode
    -- as UTF-8 instead; ASCII messages are byte-for-byte unchanged.
    utf8Body :: String -> Sl.ByteString
    utf8Body msg = Sl.fromStrict (Text.encodeUtf8 (fromString msg))
-- | The outcome of running a 'Server'.
data ServerResult
  = RoutingError RoutingError
    -- ^ Routing failed. 'serverApplication'' renders this with
    -- 'defaultRoutingErrorResponse'.
  | WaiResponse Response
    -- ^ A finished WAI response to send back.
  | Application (Context -> Application)
    -- ^ Hand the request off to a WAI application. 'serverApplication'''
    -- calls it with the final 'Context' and that context's request.
-- | Operations a routing monad uses to inspect and consume the request
-- 'Context'. This module provides an instance for @'StateT' 'Context' m@.
class Contextual m where
  -- | Run an action and also return the 'Context' it finished in. In the
  -- 'StateT' instance, the current context is left unchanged.
  fork :: m a -> m (a, Context)
  -- | Replace the current 'Context'.
  restore :: Context -> m ()
  -- | The request method parsed as a 'Verb' ('Nothing' if it does not parse).
  getVerb :: m (Maybe Verb)
  -- | 'True' once no path segments remain to be consumed.
  endOfPath :: m Bool
  -- | Consume the next path segment, if there is one.
  popSegment :: m (Maybe Text)
  -- | Consume and return all remaining path segments.
  popAllSegments :: m [Text]
  -- | Decode a request header, recording that it was inspected.
  getHeader
    :: forall a (n :: HeaderName)
    . HeaderDecode n a => Sing n -> m (Either String a)
  -- | Check whether a request header equals the given value, recording the
  -- expected value.
  expectHeader
    :: forall (n :: HeaderName)
    . Sing n -> Text -> m Bool
  -- | Decode a query parameter, recording that it was inspected.
  getQuery :: QueryDecode s a => Sing s -> m (Either String a)
-- | Look up the raw bytes of a request header. The lookup is not recorded
-- in 'ctxHeaderAccess'.
getHeaderRaw
  :: forall m (n :: HeaderName)
   . Monad m => Sing n -> StateT Context m (Maybe S.ByteString)
getHeaderRaw s = gets (Map.lookup (headerName s) . ctxHeaders)
-- | Record that a header was inspected. 'Just' carries the value it was
-- expected to equal. A later call for the same header overwrites an
-- earlier one.
declareHeader
  :: forall m (n :: HeaderName)
   . Monad m => Sing n -> Maybe Text -> StateT Context m ()
declareHeader s val = modify $ \ctx ->
  let accessed = ctxHeaderAccess ctx
  in ctx { ctxHeaderAccess = Map.insert (headerName s) val accessed }
-- | Classify a query key as absent, present without a value, or present
-- with a value. The lookup is not recorded in 'ctxQueryAccess'.
getQueryRaw
  :: forall m (n :: Symbol)
   . Monad m => Sing n -> StateT Context m (QueryKeyState Text)
getQueryRaw s = gets (classify . Map.lookup key . ctxQuery)
  where
    key = withKnownSymbol s (fromString (symbolVal s))
    classify = maybe QueryKeyAbsent (maybe QueryKeyPresent QueryKeyValued)
-- | Record that a query key was inspected by prepending it to
-- 'ctxQueryAccess'. Repeated calls add repeated entries.
declareQuery
  :: forall m (n :: Symbol)
   . Monad m => Sing n -> StateT Context m ()
declareQuery s = modify recordKey
  where
    key = withKnownSymbol s (fromString (symbolVal s))
    recordKey ctx = ctx { ctxQueryAccess = key : ctxQueryAccess ctx }
-- | Routing operations implemented directly on the request 'Context' state.
instance Monad m => Contextual (StateT Context m) where
  -- Run the action from the current context, then put the original context
  -- back. The context the action ended with is returned next to its result
  -- so the caller can 'restore' it.
  fork m = StateT $ \ctx -> do
    (a, newCtx) <- runStateT m ctx
    return ((a, newCtx), ctx)

  restore = put

  getVerb = parseVerb <$> gets (requestMethod . ctxRequest)

  -- The path zipper is (popped segments, most recent first; remaining
  -- segments).
  endOfPath = do
    path <- gets ctxPathZipper
    case path of
      (_, []) -> return True
      _ -> return False

  popSegment =
    state $ \ctx ->
      case ctxPathZipper ctx of
        (_past, []) -> (Nothing, ctx)
        (past, seg:future) ->
          (Just seg, ctx { ctxPathZipper = (seg:past, future) })

  -- Remaining segments are reversed onto the popped stack so it stays
  -- most-recent-first.
  popAllSegments =
    state $ \ctx ->
      case ctxPathZipper ctx of
        (past, fut) ->
          (fut, ctx { ctxPathZipper = (reverse fut ++ past, []) })

  getHeader s = do
    declareHeader s Nothing
    mayVal <- getHeaderRaw s
    return (headerDecodeBS s mayVal)

  -- Compare in the byte domain. Running 'Text.decodeUtf8' on the
  -- client-supplied header throws on invalid UTF-8, whereas encoding
  -- 'expected' is total. For valid UTF-8 the two comparisons agree; invalid
  -- bytes now simply fail to match.
  expectHeader s expected = do
    declareHeader s (Just expected)
    mayRaw <- getHeaderRaw s
    return (mayRaw == Just (Text.encodeUtf8 expected))

  getQuery s = do
    declareQuery s
    qks <- getQueryRaw s
    return (queryDecode s qks)
-- | Per-request routing state, built by 'makeContext'.
data Context =
  Context
  { ctxRequest :: Request
    -- ^ The WAI request; 'makeContext' replaces its body reader.
  , ctxPathZipper :: ([Text], [Text])
    -- ^ (segments already popped, most recent first; segments remaining).
  , ctxHeaders :: Map SomeHeaderName S.ByteString
    -- ^ Raw request headers keyed by parsed header name.
  , ctxHeaderAccess :: Map SomeHeaderName (Maybe Text)
    -- ^ Headers inspected so far. 'Just' holds the value a header was
    -- expected to equal.
  , ctxQuery :: Map Text (Maybe Text)
    -- ^ Query parameters. 'getQueryRaw' treats 'Nothing' as a key that is
    -- present but has no value.
  , ctxQueryAccess :: [Text]
    -- ^ Query keys inspected so far, most recent first (may repeat).
  , ctxBody :: S.ByteString
    -- ^ The whole request body, read strictly by 'makeContext'.
  }
-- | Build the initial routing 'Context' for a request.
--
-- The request body is read strictly up front. The request stored in the
-- context gets a replacement body reader that replays that body, so
-- downstream WAI applications can still read it.
makeContext :: Request -> IO Context
makeContext theRequest = do
  theBody <- strictRequestBody theRequest
  let strictBody = Sl.toStrict theBody
  -- WAI body readers return successive chunks and signal the end with an
  -- empty chunk. Hand out the buffered body once, then empty chunks.
  -- Returning the full body on every call (as a bare 'readIORef' would)
  -- makes any reader that reads until the end loop forever.
  ref <- newIORef strictBody
  let replayBody = atomicModifyIORef' ref (\chunk -> (S.empty, chunk))
      headerSet =
        map (\(name, value) -> (parseHeaderName (ciBsToText name), value))
            (requestHeaders theRequest)
      querySet = queryToQueryText (queryString theRequest)
  return Context
    { ctxRequest = theRequest { requestBody = replayBody }
    , ctxPathZipper = ([], pathInfo theRequest)
    , ctxHeaders = Map.fromList headerSet
    , ctxQuery = Map.fromList querySet
    , ctxHeaderAccess = Map.empty
    , ctxQueryAccess = []
    , ctxBody = strictBody
    }
-- | Convert a case-insensitive header name from bytes to 'Text'.
--
-- Header names come straight from the client. 'Text.decodeUtf8' would
-- throw on invalid UTF-8, so valid UTF-8 is decoded as before and any
-- other bytes fall back to Latin-1. The Latin-1 fallback is total.
ciBsToText :: CI.CI S.ByteString -> CI.CI Text
ciBsToText = CI.mk . decodeName . CI.original
  where
    decodeName raw = either (const (Text.decodeLatin1 raw)) id (Text.decodeUtf8' raw)