{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}

-- |
-- Module      : Network.Wai.Middleware.EnforceHTTPS
-- Copyright   : (c) Marek Fajkus
-- License     : BSD3
--
-- Maintainer  : marek.faj@gmail.com
--
-- WAI middleware for safely enforcing encrypted HTTPS connections.
--
-- This module is intended to be imported @qualified@
--
-- > import qualified Network.Wai.Middleware.EnforceHTTPS as EnforceHTTPS
--
-- = Example Usage
--
-- The following is the most typical configuration.
-- It is compatible with GCP, AWS and Heroku, and uses
-- the @x-forwarded-proto@ header check with the default configuration.
--
-- > {-# LANGUAGE OverloadedStrings #-}
-- >
-- > module Main where
-- >
-- > import           Network.HTTP.Types                  (status200)
-- > import           Network.Wai                         (Application, responseLBS)
-- > import           Network.Wai.Handler.Warp            (runEnv)
-- >
-- > import qualified Network.Wai.Middleware.EnforceHTTPS as EnforceHTTPS
-- >
-- > handler :: Application
-- > handler _ respond = respond $
-- >     responseLBS status200 [] "Hello from behind proxy"
-- >
-- > app :: Application
-- > app = EnforceHTTPS.withResolver EnforceHTTPS.xForwardedProto handler
-- >
-- > main :: IO ()
-- > main = runEnv 8080 app

module Network.Wai.Middleware.EnforceHTTPS
  (
 -- * Configuration and Initialization
    EnforceHTTPSConfig(..)
  , defaultConfig
  , def
  , withResolver
  , withConfig
 -- * Provided Resolvers
 -- | This module provides implementations of the most common
 -- resolvers used by various cloud providers and
 -- reverse proxy implementations.
  , HTTPSResolver
  , xForwardedProto
  , azure
  , forwarded
  , customProtoHeader
  ) where

import           Data.ByteString        (ByteString)
import           Data.Maybe             (fromMaybe)
import           Data.Monoid            ((<>))
import           Network.HTTP.Types     (Method, Status)
import           Network.Wai            (Application, Middleware, Request)

#if __GLASGOW_HASKELL__ < 710
import           Data.Monoid            (mappend, mempty)
#endif

import qualified Data.ByteString        as ByteString
import qualified Data.CaseInsensitive   as CaseInsensitive
import qualified Data.Text              as Text
import qualified Data.Text.Encoding     as Text
import qualified Network.HTTP.Forwarded as Forwarded
import qualified Network.HTTP.Types     as HTTP
import qualified Network.Wai            as Wai


-- | === Configuration
--
-- 'EnforceHTTPSConfig' exports its constructor and field names.
-- All field names are prefixed with @https@, so they are unlikely
-- to collide with other functions and can be imported unqualified:
--
-- > import Network.Wai.Middleware.EnforceHTTPS (EnforceHTTPSConfig(..))
--
-- __The default configuration ('defaultConfig') is recommended__, but
-- you are free to override any default value if you need to.
--
-- 'httpsIsSecure' can be set using the 'withResolver' function,
-- which is the preferred way of overriding the default 'HTTPSResolver'.
data EnforceHTTPSConfig = EnforceHTTPSConfig
    { httpsIsSecure        :: !HTTPSResolver
      -- ^ Function to detect whether the request was made over a secure protocol
    , httpsHostRewrite     :: !(ByteString -> ByteString)
      -- ^ Rewrite rule for the host (useful for redirecting between domains)
    , httpsPort            :: !Int
      -- ^ Port of the secure server
    , httpsIgnoreURL       :: !Bool
      -- ^ Ignore the URL (path and query) and redirect to just the host
    , httpsTemporary       :: !Bool
      -- ^ Use a temporary redirect (@307@) instead of a permanent one (@301@)
    , httpsSkipDefaultPort :: !Bool
      -- ^ Avoid sending an explicit port if the default (@443@) is used
    , httpsRedirectMethods :: ![Method]
      -- ^ Whitelist of methods that should be redirected
    , httpsDisallowStatus  :: !Status
      -- ^ Status returned for methods that are not redirected
    }


-- | Default configuration.
-- The default resolver is a proxy to the 'Network.Wai.isSecure' function.
--
-- * uses the request @Host@ header to resolve the hostname
-- * standard HTTPS port @443@
-- * the redirect includes the path and the query string
-- * uses a permanent redirect (@301@)
-- * doesn't include the @port@ in the @Location@ header if the port is @443@
-- * redirects @GET@ and @HEAD@ methods
-- * all /other/ methods are answered with @405@ (Method Not Allowed)
--   and an appropriate @Allow@ header
defaultConfig :: EnforceHTTPSConfig
defaultConfig = EnforceHTTPSConfig
    { httpsIsSecure        = Wai.isSecure
    , httpsHostRewrite     = id
    , httpsPort            = 443
    , httpsIgnoreURL       = False
    , httpsTemporary       = False
    , httpsSkipDefaultPort = True
    , httpsRedirectMethods = [ HTTP.methodGet, HTTP.methodHead ]
    , httpsDisallowStatus  = HTTP.methodNotAllowed405
    }
{-# INLINE defaultConfig #-}


-- | Build a 'Middleware' from an 'EnforceHTTPSConfig'.
--
-- A request accepted by the configured 'httpsIsSecure' resolver is handed
-- to the wrapped application unchanged. Any other request never reaches
-- the application; it is answered by the redirect logic instead.
withConfig :: EnforceHTTPSConfig -> Middleware
withConfig conf app req =
    if httpsIsSecure conf req
        then app req
        else redirect conf req
{-# INLINE withConfig #-}


-- | Respond to a request that was not recognised as secure.
--
-- * A request without a @Host@ header is answered with @400@ (Bad Request),
--   because there is no host to redirect to.
-- * Methods listed in 'httpsRedirectMethods' are redirected to the
--   @https://@ URL with @301@, or with @307@ when 'httpsTemporary' is set.
-- * All other methods are answered with 'httpsDisallowStatus'. When that
--   status is @405@, an @Allow@ header listing the redirectable methods
--   is included.
redirect :: EnforceHTTPSConfig -> Application
redirect EnforceHTTPSConfig {..} req respond = respond $
    case Wai.requestHeaderHost req of
        -- A Host header field must be sent in all HTTP/1.1 request messages.
        -- A 400 (Bad Request) status code will be sent to any HTTP/1.1 request message
        -- that lacks a Host header field or contains more than one.
        -- source: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host
        Nothing -> Wai.responseBuilder HTTP.status400 [] mempty
        Just h  -> Wai.responseBuilder status (headers $ stripPort h) mempty
  where
    -- Status and header builder chosen by the request method.
    -- The header builder receives the port-stripped host.
    (status, headers) =
        if Wai.requestMethod req `elem` httpsRedirectMethods
            then
                ( if httpsTemporary then HTTP.status307 else HTTP.status301
                , \host -> [ redirectURL host ]
                )
            else
                ( httpsDisallowStatus
                , const allowHeader
                )

    -- Only a 405 response advertises the methods that would be accepted.
    allowHeader :: HTTP.ResponseHeaders
    allowHeader
        | httpsDisallowStatus == HTTP.methodNotAllowed405 =
            [ ("Allow", ByteString.intercalate ", " httpsRedirectMethods) ]
        | otherwise = []

    redirectURL :: ByteString -> HTTP.Header
    redirectURL host = (HTTP.hLocation, "https://" <> fullHost host <> path)

    path :: ByteString
    path
        | httpsIgnoreURL = mempty
        | otherwise      = Wai.rawPathInfo req <> Wai.rawQueryString req

    port :: ByteString
    port
        | httpsPort == 443 && httpsSkipDefaultPort = ""
        | otherwise = Text.encodeUtf8 $ mappend ":" $ Text.pack $ show httpsPort

    -- Drop the port (if any) from a Host header value.
    -- IPv6 literals are bracketed (e.g. @[::1]:8080@) and contain colons
    -- themselves, so for them everything up to and including the closing
    -- bracket is kept. Splitting on the first colon would leave just @[@.
    stripPort :: ByteString -> ByteString
    stripPort h
        | "[" `ByteString.isPrefixOf` h =
            case ByteString.elemIndex 93 h of -- ']'
                Just i  -> ByteString.take (i + 1) h
                Nothing -> h -- malformed literal: leave untouched
        | otherwise = ByteString.takeWhile (/= 58) h -- ':'

    fullHost :: ByteString -> ByteString
    fullHost h = httpsHostRewrite h <> port


-- | 'Middleware' using the /default/ configuration.
-- This is the same as @'withConfig' 'defaultConfig'@.
-- See 'defaultConfig' for the exact settings.
def :: Middleware
def app = withConfig defaultConfig app
{-# INLINE def #-}


-- | Build a 'Middleware' from 'defaultConfig', with its 'httpsIsSecure'
-- resolver replaced by the given 'HTTPSResolver'.
-- See the /Provided Resolvers/ section for ready-made resolvers.
withResolver :: HTTPSResolver -> Middleware
withResolver resolver =
    let conf = defaultConfig { httpsIsSecure = resolver }
    in  withConfig conf
{-# INLINE withResolver #-}


-- | A resolver is a function used to test
-- whether a 'Request' was made over the secure HTTPS protocol.
--
-- If a resolver returns 'True', the request is considered secure
-- and is passed through to the application.
-- If it returns 'False', the redirect logic is invoked instead.
type HTTPSResolver =
  Request -> Bool


-- | Resolver checking the value of the @x-forwarded-proto@ HTTP header.
-- This header is used by Heroku and GCP Ingress, among many others.
--
-- The request is considered secure if the header value is @https@.
-- The comparison is case-insensitive, because URI schemes are
-- case-insensitive. Otherwise, the request is considered not secure.
xForwardedProto :: HTTPSResolver
xForwardedProto req =
    maybe False isHttps maybeHeaderVal
  where
    -- Case-insensitive comparison, consistent with 'forwarded'.
    isHttps :: ByteString -> Bool
    isHttps val = CaseInsensitive.mk val == "https"

    maybeHeaderVal :: Maybe ByteString
    maybeHeaderVal = lookup "x-forwarded-proto" (Wai.requestHeaders req)
{-# INLINE xForwardedProto #-}


-- | Azure proxies requests with an additional @x-arr-ssl@ header
-- when the original protocol was HTTPS.
-- This resolver only checks whether that header is present;
-- its value is ignored.
azure :: HTTPSResolver
azure req =
    any (\(name, _) -> name == "x-arr-ssl") (Wai.requestHeaders req)
{-# INLINE azure #-}


-- | Some reverse proxies (such as Kong) send values similar to
-- @x-forwarded-proto@, but under a different header.
-- This resolver lets you specify the name of the header to check.
-- The header name is matched case-insensitively.
--
-- Like 'xForwardedProto', the request is considered secure if the
-- header value is @https@. The comparison is case-insensitive, because
-- URI schemes are case-insensitive.
customProtoHeader :: ByteString -> HTTPSResolver
customProtoHeader header req =
    maybe False isHttps maybeHeaderVal
  where
    -- Case-insensitive comparison, consistent with 'forwarded'.
    isHttps :: ByteString -> Bool
    isHttps val = CaseInsensitive.mk val == "https"

    maybeHeaderVal :: Maybe ByteString
    maybeHeaderVal =
        lookup (CaseInsensitive.mk header) (Wai.requestHeaders req)
{-# INLINE customProtoHeader #-}


-- | The @Forwarded@ HTTP header is a relatively new standard,
-- intended to replace the ad-hoc @x-*@ headers with a single
-- standardized one.
-- This resolver parses the header, takes its @proto=...@ part,
-- and checks whether it equals @https@.
-- A missing header or a missing @proto@ part means "not secure".
--
-- More information can be found on [MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded).
-- The complete implementation of @Forwarded@ parsing is located in
-- the @Network.HTTP.Forwarded@ module.
forwarded :: HTTPSResolver
forwarded req = proto == Just "https"
  where
    proto :: Maybe HTTP.HeaderName
    proto =
        lookup "forwarded" (Wai.requestHeaders req)
            >>= Forwarded.forwardedProto . Forwarded.parseForwarded
{-# INLINE forwarded #-}