{-# LANGUAGE OverloadedStrings #-}
module Snap.Util.CORS
(
applyCORS
, CORSOptions(..)
, defaultOptions
, OriginList(..)
, OriginSet, mkOriginSet, origins
, HashableURI(..), HashableMethod (..)
) where
import Control.Applicative
import Control.Monad (join, when)
import Data.CaseInsensitive (CI)
import Data.Hashable (Hashable(..))
import Data.Maybe (fromMaybe)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Network.URI (URI (..), URIAuth (..), parseURI)
import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec
import qualified Data.ByteString.Char8 as S
import qualified Data.CaseInsensitive as CI
import qualified Data.HashSet as HashSet
import qualified Data.Text as Text
import qualified Snap.Core as Snap
import Snap.Internal.Parsing (pTokens)
-- | A set of origins, stored as 'HashableURI' values.
--
-- The constructor is not in the module's export list; callers build a set
-- with 'mkOriginSet' and read it back with 'origins'.
newtype OriginSet = OriginSet { origins :: HashSet.HashSet HashableURI }
-- | Describes which request origins are accepted for CORS processing.
data OriginList
  = Everywhere
  -- ^ Every origin is accepted.
  | Nowhere
  -- ^ No origin is accepted.
  | Origins OriginSet
  -- ^ Only origins that are members of the given set are accepted.
-- | Configuration consumed by 'applyCORS'. Every setting is an action in the
-- handler monad @m@, so it is evaluated per request.
data CORSOptions m = CORSOptions
  { corsAllowOrigin :: m OriginList
  -- ^ The origins whose requests receive CORS headers.
  , corsAllowCredentials :: m Bool
  -- ^ When this yields 'True', @Access-Control-Allow-Credentials: true@ is
  -- added to the response.
  , corsExposeHeaders :: m (HashSet.HashSet (CI S.ByteString))
  -- ^ Header names listed in @Access-Control-Expose-Headers@ on
  -- non-preflight responses; nothing is sent when the set is empty.
  , corsAllowedMethods :: m (HashSet.HashSet HashableMethod)
  -- ^ Methods a preflight request may ask for; also the contents of
  -- @Access-Control-Allow-Methods@.
  , corsAllowedHeaders :: HashSet.HashSet S.ByteString -> m (HashSet.HashSet S.ByteString)
  -- ^ Receives the header names from @Access-Control-Request-Headers@ and
  -- returns the allowed header names. A preflight only succeeds when every
  -- requested name is in the returned set (compared as raw byte strings).
  }
-- | Permissive default configuration:
--
-- * every origin is accepted ('Everywhere');
-- * credentials are allowed;
-- * no response headers are exposed;
-- * the allowed methods are those in 'defaultAllowedMethods';
-- * whatever headers a preflight asks for are allowed (the requested set is
--   returned unchanged).
--
-- Note that, taken together, these defaults accept credentialed requests from
-- any origin; override the fields you need to restrict.
defaultOptions :: Monad m => CORSOptions m
defaultOptions =
  CORSOptions
    { corsAllowOrigin = return Everywhere
    , corsAllowCredentials = return True
    , corsExposeHeaders = return HashSet.empty
      -- Force the shared method set before wrapping it in the monad.
    , corsAllowedMethods =
        defaultAllowedMethods `seq` return defaultAllowedMethods
    , corsAllowedHeaders = \requested -> return requested
    }
-- | The method set used by 'defaultOptions': GET, POST, PUT, DELETE and HEAD.
defaultAllowedMethods :: HashSet.HashSet HashableMethod
defaultAllowedMethods =
  HashSet.fromList
    [ HashableMethod method
    | method <- [Snap.GET, Snap.POST, Snap.PUT, Snap.DELETE, Snap.HEAD]
    ]
-- | Wrap a Snap handler with CORS processing driven by the request's
-- @Origin@ header.
--
-- * If there is no @Origin@ header, it does not parse as a URI, or the origin
--   is not accepted by 'corsAllowOrigin', the wrapped handler runs and no CORS
--   headers are added.
--
-- * For an accepted origin on an @OPTIONS@ request, the request is treated as
--   a preflight. When @Access-Control-Request-Method@ is a member of
--   'corsAllowedMethods' and every name in @Access-Control-Request-Headers@
--   is in the set returned by 'corsAllowedHeaders', the allow-origin,
--   allow-credentials, allow-headers and allow-methods response headers are
--   added and the wrapped handler is /not/ run. If any of those checks fails,
--   the wrapped handler runs without CORS headers.
--
-- * Any other request from an accepted origin (or a preflight branch that
--   fails and falls through '<|>') gets the allow-origin, allow-credentials
--   and expose-headers response headers, after which the wrapped handler runs.
applyCORS :: Snap.MonadSnap m => CORSOptions m -> m () -> m ()
applyCORS options m =
  (join . fmap decodeOrigin <$> getHeader "Origin") >>= maybe m corsRequestFrom
  where
    -- Decide what to do with a request that carries a parseable origin.
    corsRequestFrom origin = do
      originList <- corsAllowOrigin options
      if origin `inOriginList` originList
        then
          Snap.method Snap.OPTIONS (preflightRequestFrom origin)
            <|> handleRequestFrom origin
        else m

    -- Answer a preflight request; every failed check falls back to the
    -- wrapped handler without adding CORS headers.
    preflightRequestFrom origin = do
      maybeMethod <-
        fmap (parseMethod . S.unpack)
          <$> getHeader "Access-Control-Request-Method"
      case maybeMethod of
        Nothing -> m
        Just method -> do
          allowedMethods <- corsAllowedMethods options
          if method `HashSet.member` allowedMethods
            then do
              -- A missing header means "no extra headers requested"; a header
              -- that fails to parse yields Nothing and aborts the preflight.
              maybeHeaders <-
                fromMaybe (Just HashSet.empty) . fmap splitHeaders
                  <$> getHeader "Access-Control-Request-Headers"
              case maybeHeaders of
                Nothing -> m
                Just headers -> do
                  allowedHeaders <- corsAllowedHeaders options headers
                  if HashSet.null (headers `HashSet.difference` allowedHeaders)
                    then do
                      addAccessControlAllowOrigin origin
                      addAccessControlAllowCredentials
                      commaSepHeader
                        "Access-Control-Allow-Headers"
                        id
                        (HashSet.toList allowedHeaders)
                      commaSepHeader
                        "Access-Control-Allow-Methods"
                        renderMethod
                        (HashSet.toList allowedMethods)
                    else m
            else m

    -- Decorate an actual (non-preflight) request, then run the handler.
    handleRequestFrom origin = do
      addAccessControlAllowOrigin origin
      addAccessControlAllowCredentials
      exposeHeaders <- corsExposeHeaders options
      when (not $ HashSet.null exposeHeaders) $
        commaSepHeader
          "Access-Control-Expose-Headers"
          CI.original
          (HashSet.toList exposeHeaders)
      m

    -- Echo the (simplified) request origin back to the client.
    addAccessControlAllowOrigin origin =
      addHeader
        "Access-Control-Allow-Origin"
        (encodeUtf8 $ Text.pack $ show origin)

    addAccessControlAllowCredentials = do
      allowCredentials <- corsAllowCredentials options
      when allowCredentials $
        addHeader "Access-Control-Allow-Credentials" "true"

    -- Parse the raw Origin header into a simplified URI.
    --
    -- The header is untrusted input and 'decodeUtf8' throws on malformed
    -- UTF-8, so anything containing a non-ASCII byte is rejected up front
    -- (pure ASCII is always valid UTF-8, which makes the decode total here).
    decodeOrigin :: S.ByteString -> Maybe URI
    decodeOrigin raw
      | S.all (< '\x80') raw =
          fmap simplifyURI (parseURI (Text.unpack (decodeUtf8 raw)))
      | otherwise = Nothing

    -- Render a method as the bare token expected in
    -- Access-Control-Allow-Methods. Custom methods are emitted by name
    -- directly rather than through 'show'.
    -- NOTE(review): if snap-core derives Show for Method, 'show' would print
    -- the constructor (e.g. @Method "PURGE"@) -- confirm against the
    -- snap-core version in use.
    renderMethod :: HashableMethod -> S.ByteString
    renderMethod (HashableMethod (Snap.Method name)) = name
    renderMethod method = S.pack (show method)

    addHeader k v = Snap.modifyResponse (Snap.addHeader k v)

    -- Add header @k@ holding the comma-separated rendering of @vs@; an empty
    -- list adds no header at all.
    commaSepHeader k f vs =
      case vs of
        [] -> return ()
        _ -> addHeader k $ S.intercalate ", " (map f vs)

    getHeader = Snap.getsRequest . Snap.getHeader

    -- Parse a header-name list with 'pTokens'; Nothing on a parse failure.
    splitHeaders =
      either (const Nothing) (Just . HashSet.fromList)
        . Attoparsec.parseOnly pTokens
-- | Build an 'OriginSet' from a list of URIs. Each URI is passed through
-- 'simplifyURI' before it is inserted.
mkOriginSet :: [URI] -> OriginSet
mkOriginSet uris =
  OriginSet (HashSet.fromList [HashableURI (simplifyURI uri) | uri <- uris])
-- | Reduce a URI to the parts kept for origin comparison: the scheme and the
-- authority's host and port. User info, path, query and fragment are blanked.
simplifyURI :: URI -> URI
simplifyURI (URI scheme authority _ _ _) =
  URI scheme (fmap dropUserInfo authority) "" "" ""
  where
    dropUserInfo (URIAuth _ host port) = URIAuth "" host port
-- | Turn a method name into a 'HashableMethod'. The nine names below map to
-- their dedicated constructors (exact, case-sensitive match); any other
-- string becomes a custom 'Snap.Method'.
parseMethod :: String -> HashableMethod
parseMethod name =
  HashableMethod (fromMaybe custom (lookup name standardMethods))
  where
    custom = Snap.Method (S.pack name)

    standardMethods :: [(String, Snap.Method)]
    standardMethods =
      [ ("GET", Snap.GET)
      , ("POST", Snap.POST)
      , ("HEAD", Snap.HEAD)
      , ("PUT", Snap.PUT)
      , ("DELETE", Snap.DELETE)
      , ("TRACE", Snap.TRACE)
      , ("OPTIONS", Snap.OPTIONS)
      , ("CONNECT", Snap.CONNECT)
      , ("PATCH", Snap.PATCH)
      ]
-- | Wrapper around 'URI' that carries the 'Hashable' instance needed to store
-- URIs in a 'HashSet.HashSet'. Equality is that of the wrapped 'URI'.
newtype HashableURI = HashableURI URI
  deriving (Eq)
-- | Shown exactly like the wrapped 'URI'.
instance Show HashableURI where
  show wrapped = case wrapped of
    HashableURI uri -> show uri
-- | Hashes all five URI components in order: scheme, authority, path, query,
-- fragment. The authority, when present, is first reduced to an 'Int' by
-- hashing its user info, host name and port starting from the same salt.
instance Hashable HashableURI where
  hashWithSalt salt (HashableURI uri) =
    salt
      `hashWithSalt` uriScheme uri
      `hashWithSalt` fmap authorityHash (uriAuthority uri)
      `hashWithSalt` uriPath uri
      `hashWithSalt` uriQuery uri
      `hashWithSalt` uriFragment uri
    where
      authorityHash auth =
        salt
          `hashWithSalt` uriUserInfo auth
          `hashWithSalt` uriRegName auth
          `hashWithSalt` uriPort auth
-- | Test whether an origin is accepted by an 'OriginList'. For 'Origins' the
-- URI is looked up as given, with no simplification applied here.
inOriginList :: URI -> OriginList -> Bool
inOriginList origin allowed =
  case allowed of
    Nowhere -> False
    Everywhere -> True
    Origins (OriginSet permitted) ->
      HashSet.member (HashableURI origin) permitted
-- | Wrapper around 'Snap.Method' that carries the 'Hashable' instance needed
-- to store methods in a 'HashSet.HashSet'. Equality is that of the wrapped
-- 'Snap.Method'.
newtype HashableMethod = HashableMethod Snap.Method
  deriving (Eq)
-- | Each standard method hashes to a fixed small tag; a custom method hashes
-- tag 9 followed by its name.
--
-- A custom method whose name spells a standard one (e.g. @Method \"GET\"@) is
-- hashed exactly like the standard constructor. This keeps the 'Hashable'
-- law (equal values hash equally) intact whether or not the wrapped type's
-- 'Eq' instance identifies the two spellings.
-- NOTE(review): snap-core's @Eq Method@ appears to treat @Method \"GET\"@ as
-- equal to @GET@ -- confirm against the snap-core version in use. If it does
-- not, the only effect of this normalisation is a harmless hash collision.
instance Hashable HashableMethod where
  hashWithSalt salt (HashableMethod method) =
    case method of
      Snap.GET -> tagged 0
      Snap.HEAD -> tagged 1
      Snap.POST -> tagged 2
      Snap.PUT -> tagged 3
      Snap.DELETE -> tagged 4
      Snap.TRACE -> tagged 5
      Snap.OPTIONS -> tagged 6
      Snap.CONNECT -> tagged 7
      Snap.PATCH -> tagged 8
      Snap.Method name ->
        case lookup name standardTags of
          Just tag -> tagged tag
          Nothing -> tagged 9 `hashWithSalt` name
    where
      tagged :: Int -> Int
      tagged = hashWithSalt salt

      -- Names of the standard methods paired with the tags used above.
      standardTags :: [(S.ByteString, Int)]
      standardTags =
        [ ("GET", 0)
        , ("HEAD", 1)
        , ("POST", 2)
        , ("PUT", 3)
        , ("DELETE", 4)
        , ("TRACE", 5)
        , ("OPTIONS", 6)
        , ("CONNECT", 7)
        , ("PATCH", 8)
        ]
-- | Shown exactly like the wrapped 'Snap.Method'.
instance Show HashableMethod where
  show wrapped = case wrapped of
    HashableMethod method -> show method