{-|
Module      : PostgREST.Cors
Description : Wai middleware that sets the CORS policy.
-}
module PostgREST.Cors (middleware) where

import qualified Data.ByteString.Char8       as BS
import qualified Data.CaseInsensitive        as CI
import qualified Network.Wai                 as Wai
import qualified Network.Wai.Middleware.Cors as Wai

import Data.List (lookup)

import Protolude

-- | Wai middleware that applies 'corsPolicy' to each incoming request.
middleware :: Wai.Middleware
middleware = Wai.cors corsPolicy

-- | CORS policy used by the Wai CORS middleware.
--
-- Requests without an @Origin@ header get no policy ('Nothing').
-- Otherwise the request's own origin is echoed back as the single allowed
-- origin (with the credentials flag set to 'True'). Every header the client
-- lists in @Access-Control-Request-Headers@ is allowed, plus @Authorization@.
corsPolicy :: Wai.Request -> Maybe Wai.CorsResourcePolicy
corsPolicy req = policyFor <$> lookup "origin" headers
  where
    headers = Wai.requestHeaders req

    -- Policy for a request whose Origin header has the given value.
    policyFor origin = Wai.CorsResourcePolicy
      { Wai.corsOrigins        = Just ([origin], True)
      , Wai.corsMethods        = ["GET", "POST", "PATCH", "PUT", "DELETE", "OPTIONS"]
      , Wai.corsRequestHeaders = "Authorization" : accHeaders
      , Wai.corsExposedHeaders = Just
          [ "Content-Encoding", "Content-Location", "Content-Range", "Content-Type"
          , "Date", "Location", "Server", "Transfer-Encoding", "Range-Unit"
          ]
        -- 60 * 60 * 24, i.e. one day when read as seconds.
      , Wai.corsMaxAge         = Just (60 * 60 * 24)
      , Wai.corsVaryOrigin     = False
      , Wai.corsRequireOrigin  = False
      , Wai.corsIgnoreFailures = True
      }

    -- Header names requested by the client: the comma-separated
    -- Access-Control-Request-Headers value, each entry whitespace-trimmed.
    -- Kept as a lazy binding so it is only evaluated on demand.
    accHeaders = case lookup "access-control-request-headers" headers of
      Just hdrs -> CI.mk . BS.strip <$> BS.split ',' hdrs
      -- Impossible case, Middleware.Cors will not evaluate this when
      -- the Access-Control-Request-Headers header is not set.
      Nothing   -> []