module Network.Wai.Middleware.Cors
( CorsResourcePolicy(..)
, cors
, isSimple
, simpleResponseHeaders
, simpleHeaders
, simpleContentTypes
, simpleMethods
) where
import Control.Applicative
import Control.Error
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Resource
import qualified Data.Attoparsec as AttoParsec
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy.Char8 as LB8
import qualified Data.CaseInsensitive as CI
import qualified Data.CharSet as CS
import Data.List (intersect, (\\), union)
import Data.Monoid.Unicode
import Data.String
import qualified Network.HTTP.Types as HTTP
import qualified Network.Wai as WAI
import Prelude.Unicode
import qualified Text.Parser.Char as P
import qualified Text.Parser.Combinators as P
-- | Monad in which the wrapped WAI application produces its response:
-- @ResourceT IO@ before wai 2.0, plain 'IO' from wai 2.0 onwards.
#if !MIN_VERSION_wai(2,0,0)
type ReqMonad = ResourceT IO
#else
type ReqMonad = IO
#endif

-- | Raw value of an @Origin@ request header.
type Origin = B8.ByteString
-- | Cross-origin resource sharing policy applied to a resource.
data CorsResourcePolicy = CorsResourcePolicy
    { corsOrigins ∷ !(Maybe ([Origin], Bool))
      -- ^ 'Nothing' accepts every origin and answers with
      -- @Access-Control-Allow-Origin: *@. @Just (origins, creds)@ accepts only
      -- the listed origins and echoes the request origin back; when @creds@
      -- is 'True', @Access-Control-Allow-Credentials: true@ is added as well.
    , corsMethods ∷ ![HTTP.Method]
      -- ^ Methods accepted in preflight requests, in addition to
      -- 'simpleMethods'.
    , corsRequestHeaders ∷ ![HTTP.HeaderName]
      -- ^ Request headers accepted in preflight requests, in addition to the
      -- entries of 'simpleHeaders' other than @Content-Type@.
    , corsExposedHeaders ∷ !(Maybe [HTTP.HeaderName])
      -- ^ When set, listed in @Access-Control-Expose-Headers@ on responses to
      -- actual (non-preflight) requests.
    , corsMaxAge ∷ !(Maybe Int)
      -- ^ When set, sent as @Access-Control-Max-Age@ (seconds) on preflight
      -- responses.
    , corsVaryOrigin ∷ !Bool
      -- ^ Controls whether a @Vary: Origin@ header is added to CORS
      -- responses.
    , corsVerboseResponse ∷ !Bool
      -- ^ 'True': policy violations are answered with @400@ and an
      -- explanatory body. 'False': with an empty @200@ response.
    }
    deriving (Eq, Ord, Read, Show)
-- | CORS middleware.
--
-- @policyPattern@ selects the 'CorsResourcePolicy' for a request. Requests
-- for which it returns 'Nothing' go to the wrapped application unchanged.
-- For all other requests:
--
-- * a missing @Origin@ header is answered with 'corsFailure';
--
-- * an @OPTIONS@ request is answered directly as a preflight (the wrapped
--   application is not called). The response carries the common CORS headers
--   plus @Access-Control-Allow-Methods@, and where applicable
--   @Access-Control-Allow-Headers@ and @Access-Control-Max-Age@;
--
-- * any other request is passed to the application. The common CORS headers
--   plus @Access-Control-Expose-Headers@ (if configured) are appended to its
--   response.
--
-- Every policy violation (unlisted origin, missing or unsupported requested
-- method, unsupported requested headers, unparsable header list) is answered
-- with 'corsFailure'.
--
-- NOTE(review): the CORS specification treats requests without an @Origin@
-- header as outside its scope. They are rejected here, which also affects
-- plain same-origin requests. Confirm this is intended.
cors
    ∷ (WAI.Request → Maybe CorsResourcePolicy)
    → WAI.Middleware
cors policyPattern app r = maybe (app r) withPolicy (policyPattern r)
  where
    -- A policy applies: the Origin header is mandatory, and any violation
    -- reported by the policy check becomes a failure response.
    withPolicy policy = case lookup "origin" (WAI.requestHeaders r) of
        Nothing → return $ failWith "Origin header is missing"
        Just origin →
            either (failWith ∘ B8.pack) id
                <$> runEitherT (applyCorsPolicy policy origin)
      where
        failWith = corsFailure (corsVerboseResponse policy)

    -- Validates the origin, then either answers a preflight (OPTIONS) request
    -- itself or decorates the application's response to an actual request.
    applyCorsPolicy
        ∷ CorsResourcePolicy
        → Origin
        → EitherT String ReqMonad WAI.Response
    applyCorsPolicy policy origin = do
        respOrigin ← checkOrigin (corsOrigins policy)
        let common = commonCorsHeaders respOrigin (corsVaryOrigin policy)
        if WAI.requestMethod r ≡ "OPTIONS"
            then do
                preflight ← preflightHeaders policy
                return $ WAI.responseLBS HTTP.ok200 (common ⊕ preflight) ""
            else lift $ app r >>= addHeaders (common ⊕ respCorsHeaders policy)
      where
        -- No origin list means any origin is fine; otherwise the request
        -- origin must be listed.
        checkOrigin Nothing = return Nothing
        checkOrigin (Just (allowed, withCreds))
            | origin `elem` allowed = return $ Just (origin, withCreds)
            | otherwise = left $ "Unsupported origin: " ⊕ B8.unpack origin

    -- Preflight-only headers. The checks run in this order, and the first
    -- violation aborts the preflight.
    preflightHeaders ∷ Monad μ ⇒ CorsResourcePolicy → EitherT String μ HTTP.ResponseHeaders
    preflightHeaders policy =
        concat <$> mapM ($ policy) [hdrReqMethod, hdrRequestHeader, hdrMaxAge]

    -- Access-Control-Max-Age, only when the policy configures one.
    hdrMaxAge ∷ Monad μ ⇒ CorsResourcePolicy → EitherT String μ HTTP.ResponseHeaders
    hdrMaxAge policy = return $ maybe [] maxAgeHeader (corsMaxAge policy)
      where
        maxAgeHeader secs = [("Access-Control-Max-Age", sshow secs)]

    -- The requested method must be present and be either a simple method or
    -- one listed by the policy. The answer lists all accepted methods.
    hdrReqMethod ∷ Monad μ ⇒ CorsResourcePolicy → EitherT String μ HTTP.ResponseHeaders
    hdrReqMethod policy =
        case lookup "Access-Control-Request-Method" (WAI.requestHeaders r) of
            Nothing →
                left "Access-Control-Request-Method header is missing in CORS preflight request"
            Just requested
                | requested `elem` allowed →
                    return [("Access-Control-Allow-Methods", hdrL allowed)]
                | otherwise → left $ concat
                    [ "Method requested in Access-Control-Request-Method of CORS request is not supported; requested: "
                    , B8.unpack requested
                    , "; supported are "
                    , B8.unpack (hdrL allowed)
                    , "."
                    ]
      where
        allowed = corsMethods policy `union` simpleMethods

    -- Requested headers are optional. When present, each must be listed by the
    -- policy or be a simple header other than Content-Type (whose simplicity
    -- depends on its value).
    hdrRequestHeader ∷ Monad μ ⇒ CorsResourcePolicy → EitherT String μ HTTP.ResponseHeaders
    hdrRequestHeader policy =
        maybe (return []) checkRequested $
            lookup "Access-Control-Request-Headers" (WAI.requestHeaders r)
      where
        checkRequested raw = do
            requested ← hoistEither $ AttoParsec.parseOnly httpHeaderNameListParser raw
            if requested `isSubsetOf` allowed
                then return [("Access-Control-Allow-Headers", hdrLI allowed)]
                else left $ concat
                    [ "HTTP header requested in Access-Control-Request-Headers of CORS request is not supported; requested: "
                    , B8.unpack (hdrLI requested)
                    , "; supported are "
                    , B8.unpack (hdrLI allowed)
                    , "."
                    ]
        allowed = corsRequestHeaders policy `union` (simpleHeaders \\ ["content-type"])
-- | CORS response headers shared by preflight and actual responses.
--
-- The first argument is the outcome of the origin check:
--
-- * 'Nothing' when the policy accepts any origin, answered with
--   @Access-Control-Allow-Origin: *@;
--
-- * @Just (origin, withCreds)@ when the request origin is echoed back.
--   @withCreds@ adds @Access-Control-Allow-Credentials: true@.
--
-- The second argument ('corsVaryOrigin') adds @Vary: Origin@. It matters
-- most when the origin is echoed back, because the response then differs per
-- requesting origin and caches must take @Origin@ into account. (Previously
-- the flag was ignored in exactly that case and honoured only for @*@.)
commonCorsHeaders ∷ Maybe (Origin, Bool) → Bool → HTTP.ResponseHeaders
commonCorsHeaders respOrigin varyOrigin = allowOrigin ⊕ credentials ⊕ vary
  where
    allowOrigin = [("Access-Control-Allow-Origin", maybe "*" fst respOrigin)]
    credentials = case respOrigin of
        Just (_, True) → [("Access-Control-Allow-Credentials", "true")]
        _ → []
    vary = if varyOrigin then [("Vary", "Origin")] else []
-- | CORS headers added only to responses to actual (non-preflight) requests:
-- @Access-Control-Expose-Headers@ when the policy configures exposed headers.
respCorsHeaders ∷ CorsResourcePolicy → HTTP.ResponseHeaders
respCorsHeaders policy = case corsExposedHeaders policy of
    Nothing → []
    Just exposed → [("Access-Control-Expose-Headers", hdrLI exposed)]
-- | The CORS \"simple response headers\". Exported for users of the module.
-- Not referenced elsewhere in the code shown here.
simpleResponseHeaders ∷ [HTTP.HeaderName]
simpleResponseHeaders =
    [ "Cache-Control", "Content-Language", "Content-Type"
    , "Expires", "Last-Modified", "Pragma"
    ]
-- | The CORS \"simple\" request header names. They are accepted in preflight
-- requests without being listed in the policy (except @Content-Type@) and are
-- the only header names 'isSimple' allows.
simpleHeaders ∷ [HTTP.HeaderName]
simpleHeaders = ["Accept", "Accept-Language", "Content-Language", "Content-Type"]
-- | Media types (compared case-insensitively) that keep a @Content-Type@
-- header \"simple\" in 'isSimple'.
simpleContentTypes ∷ [CI.CI B8.ByteString]
simpleContentTypes = ["application/x-www-form-urlencoded", "multipart/form-data", "text/plain"]
-- | The CORS \"simple\" methods. They are always accepted in preflight
-- requests, on top of the policy's 'corsMethods'.
simpleMethods ∷ [HTTP.Method]
simpleMethods = ["GET", "HEAD", "POST"]
-- | Whether a request with the given method and headers is a CORS \"simple\"
-- request. All of the following must hold:
--
-- * the method is one of 'simpleMethods';
--
-- * every header name is one of 'simpleHeaders';
--
-- * a @Content-Type@ header, if present, names one of 'simpleContentTypes'.
--
-- Only the media type is compared: parameters after @;@ and surrounding
-- blanks are ignored, so @text/plain;charset=UTF-8@ counts as @text/plain@.
-- The Content-Type restriction applies to every method, not only POST.
--
-- NOTE(review): every header in the list must be simple, so callers
-- presumably pass only author-supplied headers rather than the full request
-- header list (which would include e.g. @Host@). Verify at call sites.
isSimple ∷ HTTP.Method → HTTP.RequestHeaders → Bool
isSimple method headers
    = method `elem` simpleMethods
    ∧ map fst headers `isSubsetOf` simpleHeaders
    ∧ maybe True isSimpleContentType (lookup "content-type" headers)
  where
    isSimpleContentType value = CI.mk (mediaType value) `elem` simpleContentTypes
    -- Strip media-type parameters (everything from the first ';') and blanks.
    mediaType = trimBlanks ∘ B8.takeWhile (/= ';')
    trimBlanks = B8.reverse ∘ B8.dropWhile isBlank ∘ B8.reverse ∘ B8.dropWhile isBlank
    isBlank c = c ≡ ' ' ∨ c ≡ '\t'
-- | Characters allowed in an HTTP header name: visible ASCII ('!' through
-- '~') minus the separator characters.
httpHeaderNameCharSet ∷ CS.CharSet
httpHeaderNameCharSet = visibleAscii CS.\\ separators
  where
    visibleAscii = CS.range '!' '~'
    separators = CS.fromList "()<>@,;:\\\"/[]?={}"
-- | Parses one HTTP header name: a non-empty run of characters from
-- 'httpHeaderNameCharSet', labelled for error messages.
httpHeaderNameParser ∷ P.CharParsing μ ⇒ μ HTTP.HeaderName
httpHeaderNameParser = (P.<?> "HTTP Header Name") $
    fromString <$> P.some (P.oneOfSet httpHeaderNameCharSet)
-- | Parses a non-empty, comma-separated list of HTTP header names, as found
-- in an @Access-Control-Request-Headers@ header.
--
-- Whitespace is skipped before the first name, after every name and after
-- every comma. Previously whitespace after a comma was not skipped, so
-- @X-Foo, X-Bar@ parsed as just @[X-Foo]@ and the remaining names were never
-- validated.
--
-- NOTE(review): input after the list is left unconsumed rather than rejected.
-- With attoparsec's @parseOnly@ (which does not demand end of input) such
-- trailing input is silently ignored.
httpHeaderNameListParser ∷ P.CharParsing μ ⇒ μ [HTTP.HeaderName]
httpHeaderNameListParser =
    P.spaces *> P.sepBy1 (httpHeaderNameParser <* P.spaces) (P.char ',' <* P.spaces)
-- | 'show' a value into any string-like type.
sshow ∷ (IsString α, Show β) ⇒ β → α
sshow x = fromString (show x)
-- | @l1 `isSubsetOf` l2@ holds when every element of @l1@ occurs in @l2@.
-- Duplicates and order are irrelevant.
isSubsetOf ∷ Eq α ⇒ [α] → [α] → Bool
isSubsetOf l1 l2 = all (`elem` l2) l1
-- | Appends the given headers after the response's existing headers, keeping
-- status and body.
--
-- NOTE(review): on wai >= 2 the body source is extracted by passing 'return'
-- as the continuation of 'WAI.responseToSource'. If that continuation
-- brackets a resource, the resource may be released before the returned
-- source is consumed. Confirm for the response kinds in use.
addHeaders ∷ HTTP.ResponseHeaders → WAI.Response → ReqMonad WAI.Response
#if MIN_VERSION_wai(2,0,0)
addHeaders extra res =
    fmap (WAI.responseSource status (existing ⊕ extra)) (body return)
  where
    (status, existing, body) = WAI.responseToSource res
#else
addHeaders extra res = return (WAI.ResponseSource status (existing ⊕ extra) body)
  where
    (status, existing, body) = WAI.responseSource res
#endif
-- | Renders header names as a @", "@-separated list, using their original
-- spelling.
hdrLI ∷ [HTTP.HeaderName] → B8.ByteString
hdrLI = B8.intercalate ", " ∘ map CI.original
-- | Joins byte strings into a @", "@-separated header value.
hdrL ∷ [B8.ByteString] → B8.ByteString
hdrL = B8.intercalate (B8.pack ", ")
-- | Response used when a request violates the CORS policy.
--
-- * Verbose ('True'): @400 Bad Request@ whose body is the given message.
--
-- * Non-verbose ('False'): an empty @200 OK@ with no headers, so the reason
--   is not disclosed.
--
-- The message embeds request-supplied values (e.g. the rejected origin or
-- the requested method/headers). It is therefore served as @text/plain@
-- rather than @text/html@, so those values are never interpreted as markup.
-- This also fixes the malformed @charset-utf-8@ parameter.
corsFailure
    ∷ Bool
    → B8.ByteString
    → WAI.Response
corsFailure True msg = WAI.responseLBS HTTP.status400 [("Content-Type", "text/plain; charset=utf-8")] (LB8.fromStrict msg)
corsFailure False _ = WAI.responseLBS HTTP.ok200 [] ""