module Network.Wai.Middleware.RealIp (
realIp,
realIpHeader,
realIpTrusted,
defaultTrusted,
ipInRange,
) where
import qualified Data.ByteString.Char8 as B8 (split, unpack)
import qualified Data.IP as IP
import Data.Maybe (fromMaybe, listToMaybe, mapMaybe)
import Network.HTTP.Types (HeaderName, RequestHeaders)
import Network.Wai (Middleware, remoteHost, requestHeaders)
import Text.Read (readMaybe)
-- | Rewrite 'remoteHost' from the @X-Forwarded-For@ header.
--
-- Equivalent to applying 'realIpHeader' to @X-Forwarded-For@, so the header
-- is only consulted when the direct peer lies in one of the
-- 'defaultTrusted' ranges.
realIp :: Middleware
realIp = realIpHeader forwardedFor
  where
    forwardedFor :: HeaderName
    forwardedFor = "X-Forwarded-For"
-- | Rewrite 'remoteHost' from the given header.
--
-- The header is honoured only when the direct peer, and later each
-- candidate address, falls inside one of the 'defaultTrusted' ranges
-- according to 'ipInRange'. Use 'realIpTrusted' to supply a custom
-- trust predicate.
realIpHeader :: HeaderName -> Middleware
realIpHeader header = realIpTrusted header inDefaultRanges
  where
    inDefaultRanges :: IP.IP -> Bool
    inDefaultRanges ip = any (ipInRange ip) defaultTrusted
-- | Rewrite 'remoteHost' from the given header, using a caller-supplied
-- trust predicate.
--
-- The peer address is decoded with 'IP.fromSockAddr'. The request is passed
-- to the wrapped application unchanged in three cases:
--
-- * the socket address cannot be decoded;
-- * the peer does not satisfy the predicate;
-- * 'findRealIp' yields no address from the header.
--
-- Otherwise the chosen address replaces 'remoteHost', while the port from the
-- original socket address is kept.
realIpTrusted :: HeaderName -> (IP.IP -> Bool) -> Middleware
realIpTrusted header isTrusted app req = app (fromMaybe req rewritten)
  where
    -- 'Nothing' means "forward the request untouched".
    rewritten = case IP.fromSockAddr (remoteHost req) of
        Just (peer, port)
            | isTrusted peer ->
                withClient port <$> findRealIp (requestHeaders req) header isTrusted
        _ -> Nothing

    -- Substitute the client address but keep the port of the real socket.
    withClient port client = req{remoteHost = IP.toSockAddr (client, port)}
-- | Ranges trusted by 'realIp' and 'realIpHeader'.
--
-- The IPv4 entries are loopback (127.0.0.0/8) and the RFC 1918 private
-- blocks. The IPv6 entries are loopback (::1/128) and unique-local
-- addresses (fc00::/7).
defaultTrusted :: [IP.IPRange]
defaultTrusted = ipv4Ranges ++ ipv6Ranges
  where
    ipv4Ranges :: [IP.IPRange]
    ipv4Ranges =
        [ "127.0.0.0/8"
        , "10.0.0.0/8"
        , "172.16.0.0/12"
        , "192.168.0.0/16"
        ]

    ipv6Ranges :: [IP.IPRange]
    ipv6Ranges =
        [ "::1/128"
        , "fc00::/7"
        ]
-- | Check whether an address lies inside a range.
--
-- When the address family and the range family differ, they are reconciled
-- through IPv4-mapped IPv6 addresses (@::ffff:a.b.c.d@):
--
-- * An IPv4 address is tested against an IPv6 range in its mapped form.
-- * An IPv6 address is tested against an IPv4 range by mapping the range
--   into the @::ffff:0:0/96@ prefix. Only mapped IPv6 addresses can match.
--
-- The second case matters for dual-stack listeners, which report IPv4 peers
-- as mapped IPv6 addresses. Without it, such a peer never matches an IPv4
-- range like @127.0.0.0/8@.
ipInRange :: IP.IP -> IP.IPRange -> Bool
ipInRange (IP.IPv4 ip) (IP.IPv4Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv6 ip) (IP.IPv6Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv4 ip) (IP.IPv6Range r) = IP.ipv4ToIPv6 ip `IP.isMatchedTo` r
ipInRange (IP.IPv6 ip) (IP.IPv4Range r) = ip `IP.isMatchedTo` mappedRange
  where
    -- ::ffff:a.b.c.d/(96 + n) contains exactly the mapped forms of the
    -- addresses in a.b.c.d/n. The mask length is at most 32, so the result
    -- never exceeds 128.
    mappedRange = IP.makeAddrRange (IP.ipv4ToIPv6 (IP.addr r)) (96 + IP.mlen r)
-- | Choose the client address from the given header.
--
-- How the header is read:
--
-- * Every header whose name equals @header@ is consulted, in order, so
--   repeated headers are concatenated.
-- * Each value is split on commas, and each piece is parsed as an IP.
-- * Pieces that fail to parse are dropped silently.
--
-- How the address is chosen:
--
-- * If any parsed address is untrusted, the last (rightmost) untrusted one
--   is returned.
-- * Otherwise the first parsed address is returned.
-- * If nothing parses, the result is 'Nothing'.
--
-- NOTE(review): preferring the rightmost untrusted entry assumes each
-- trusted proxy appends the address it received the request from. Entries
-- further left are then client-controlled.
findRealIp :: RequestHeaders -> HeaderName -> (IP.IP -> Bool) -> Maybe IP.IP
findRealIp reqHeaders header isTrusted =
    case reverse (filter (not . isTrusted) candidates) of
        lastUntrusted : _ -> Just lastUntrusted
        [] -> listToMaybe candidates
  where
    candidates :: [IP.IP]
    candidates = concatMap parsePieces matchingValues

    matchingValues = [value | (name, value) <- reqHeaders, name == header]

    parsePieces raw = mapMaybe (readMaybe . B8.unpack) (B8.split ',' raw)