{-# LANGUAGE CPP #-}
module Network.Wai.Middleware.ForceDomain where
import Data.ByteString (ByteString)
#if __GLASGOW_HASKELL__ < 804
import Data.Monoid ((<>))
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (mempty)
#endif
#endif
import Network.HTTP.Types (hLocation, methodGet, status301, status307)
import Network.Wai (Middleware, Request (..), responseBuilder)
import Network.Wai.Request (appearsSecure)
-- | Force a canonical domain by redirecting.
--
-- @checkDomain@ is given the value of the request's @Host@ header. It
-- should return 'Nothing' when that host is acceptable, in which case the
-- request is handed to the wrapped application unchanged. It should return
-- @'Just' domain@ when the client should instead be redirected to @domain@.
--
-- Requests that carry no @Host@ header are passed through untouched, since
-- there is nothing to check.
--
-- The redirect has an empty body. Its @Location@ keeps, from the original
-- request:
--
--   * the scheme, as judged by 'appearsSecure';
--   * the raw path;
--   * the raw query string.
--
-- @GET@ requests are answered with @301@. Every other method gets @307@,
-- which tells the client to repeat the request with the same method and
-- body.
forceDomain :: (ByteString -> Maybe ByteString) -> Middleware
forceDomain checkDomain app req sendResponse =
    case requestHeaderHost req >>= checkDomain of
        -- Host missing, or checkDomain accepted it: no redirect.
        Nothing -> app req sendResponse
        Just domain -> sendResponse $ redirectResponse domain
  where
    -- Empty-bodied redirect pointing at the same resource on @domain@.
    redirectResponse domain =
        responseBuilder status [(hLocation, location domain)] mempty

    -- Absolute URL: scheme, then host, then raw path, then raw query string.
    -- NOTE(review): the string literals below are used as 'ByteString's.
    -- That requires OverloadedStrings, but only the CPP pragma is visible in
    -- this file. Confirm it is enabled (via pragma or cabal
    -- default-extensions).
    location h =
        let scheme = if appearsSecure req then "https://" else "http://"
         in scheme <> h <> rawPathInfo req <> rawQueryString req

    -- A 301 is only used for GET. On a 301, clients have historically
    -- rewritten a redirected POST into a GET. A 307 forbids changing the
    -- method, so non-GET requests keep their method and body.
    status
        | requestMethod req == methodGet = status301
        | otherwise = status307