{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Auth.AppRoot
  ( smartAppRoot
  ) where

import           Data.ByteString          (ByteString)
import           Data.CaseInsensitive     (CI, mk)
import qualified Data.HashMap.Lazy        as HM
import qualified Data.Text                as T
import           Data.Text.Encoding       (decodeUtf8With)
import           Data.Text.Encoding.Error (lenientDecode)
import           Network.HTTP.Types       (Header)
import           Network.Wai              (Request, isSecure, requestHeaderHost,
                                           requestHeaders)


-- | Determine the application root (scheme plus host, e.g.
-- @https:\/\/example.com@) for a request.
--
-- * The scheme is @https:\/\/@ when the connection itself is secure
--   ('isSecure'), or when any of the following de facto standard headers
--   reports an HTTPS front end: @X-Forwarded-Protocol@, @X-Forwarded-Ssl@,
--   @X-Url-Scheme@, @X-Forwarded-Proto@, @Front-End-Https@. (Note: this list
--   may be updated in the future without doc updates.) Otherwise it is
--   @http:\/\/@.
--
-- * The host is taken from the @Host@ header, decoded leniently as UTF-8,
--   and falls back to @localhost@ when the header is absent.
--
-- Normally trusting headers in this way is insecure; however, in the case of
-- approot, the worst that can happen is that the client will get an
-- incorrect URL. Note that this does not work in some situations, e.g.:
--
-- * Reverse proxies that do not set one of the above-mentioned headers
--
-- * Applications hosted somewhere besides the root of the domain name
--
-- * Reverse proxies that modify the Host header
--
-- @since 0.1.0.0
smartAppRoot :: Request -> T.Text
smartAppRoot req = scheme <> host
  where
    -- A proxy-reported HTTPS front end counts as secure even when the hop
    -- from the proxy to this application is plain HTTP.
    secure :: Bool
    secure = isSecure req || any isSecureHeader (requestHeaders req)

    scheme :: T.Text
    scheme
      | secure    = "https://"
      | otherwise = "http://"

    -- lenientDecode substitutes invalid UTF-8 instead of throwing, so a
    -- malformed Host header cannot crash the request.
    host :: T.Text
    host =
      maybe "localhost" (decodeUtf8With lenientDecode) (requestHeaderHost req)

-- | Request headers that reverse proxies commonly use to report that the
-- original client connection was HTTPS, each mapped to the value that
-- signals it. Names and values are 'CI', so lookups and comparisons against
-- this table ignore case.
--
-- See: http://stackoverflow.com/a/16042648/369198
httpsHeaders :: HM.HashMap (CI ByteString) (CI ByteString)
httpsHeaders =
  HM.fromList
    [ ("X-Forwarded-Protocol", "https")
    , ("X-Forwarded-Ssl",      "on")
    , ("X-Url-Scheme",         "https")
    , ("X-Forwarded-Proto",    "https")
    , ("Front-End-Https",      "on")
    ]

-- | Whether a single request header is listed in 'httpsHeaders' and carries
-- that header's HTTPS-indicating value. Both the header name lookup and the
-- value comparison are case-insensitive, so e.g. @X-Forwarded-Proto: HTTPS@
-- matches. Headers not in the table are never considered secure.
isSecureHeader :: Header -> Bool
isSecureHeader (key, value) =
  maybe False (== mk value) (HM.lookup key httpsHeaders)