{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TupleSections #-}
module System.IO.Streams.Network.HAProxy
(
behindHAProxy
, behindHAProxyWithLocalInfo
, decodeHAProxyHeaders
, ProxyInfo
, socketToProxyInfo
, makeProxyInfo
, getSourceAddr
, getDestAddr
, getFamily
, getSocketType
) where
import Control.Applicative ((<|>))
import Control.Monad (void, when)
import Data.Attoparsec.ByteString (anyWord8)
import Data.Attoparsec.ByteString.Char8 (Parser, char, decimal, skipWhile, string, take, takeWhile1)
import Data.Bits (unsafeShiftR, (.&.))
import qualified Data.ByteString as S8
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as S
import qualified Data.ByteString.Unsafe as S
import Data.Word (Word16, Word32, Word8)
import Foreign.C.Types (CUInt (..), CUShort (..))
import Foreign.Ptr (castPtr)
import Foreign.Storable (peek)
import qualified Network.Socket as N
import Prelude hiding (take)
import System.IO.Streams (InputStream, OutputStream)
import qualified System.IO.Streams as Streams
import qualified System.IO.Streams.Attoparsec as Streams
import System.IO.Streams.Network.Internal.Address (getSockAddr)
import System.IO.Unsafe (unsafePerformIO)
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
socketToProxyInfo :: N.Socket -> N.SockAddr -> IO ProxyInfo
socketToProxyInfo s sa = do
da <- N.getSocketName s
!sty <- getSockType
return $! makeProxyInfo sa da (addrFamily sa) sty
where
#if MIN_VERSION_network(2,7,0)
getSockType = do
c <- N.getSocketOption s N.Type
case c of
1 -> return N.Stream
2 -> return N.Datagram
_ -> error ("bad socket type: " ++ show c)
#else
getSockType = let (N.MkSocket _ _ sty _ _) = s in return sty
#endif
behindHAProxy :: N.Socket
-> N.SockAddr
-> (ProxyInfo
-> InputStream ByteString
-> OutputStream ByteString
-> IO a)
-> IO a
behindHAProxy socket sa m = do
pinfo <- socketToProxyInfo socket sa
sockets <- Streams.socketToStreams socket
behindHAProxyWithLocalInfo pinfo sockets m
behindHAProxyWithLocalInfo
:: ProxyInfo
-> (InputStream ByteString, OutputStream ByteString)
-> (ProxyInfo
-> InputStream ByteString
-> OutputStream ByteString
-> IO a)
-> IO a
behindHAProxyWithLocalInfo localProxyInfo (is, os) m = do
proxyInfo <- decodeHAProxyHeaders localProxyInfo is
m proxyInfo is os
decodeHAProxyHeaders :: ProxyInfo -> (InputStream ByteString) -> IO ProxyInfo
decodeHAProxyHeaders localProxyInfo is0 = do
is <- Streams.throwIfProducesMoreThan 536 is0
(!isOld, !mbOldInfo) <- Streams.parseFromStream
(((True,) <$> parseOldHaProxy)
<|> return (False, Nothing)) is
if isOld
then maybe (return localProxyInfo)
(\(srcAddr, srcPort, destAddr, destPort, f) -> do
(_, s) <- getSockAddr srcPort srcAddr
(_, d) <- getSockAddr destPort destAddr
return $! makeProxyInfo s d f $ getSocketType localProxyInfo)
mbOldInfo
else Streams.parseFromStream (parseNewHaProxy localProxyInfo) is
data ProxyInfo = ProxyInfo {
_sourceAddr :: N.SockAddr
, _destAddr :: N.SockAddr
, _family :: N.Family
, _sockType :: N.SocketType
} deriving (Show)
getFamily :: ProxyInfo -> N.Family
getFamily p = _family p
getSocketType :: ProxyInfo -> N.SocketType
getSocketType p = _sockType p
getSourceAddr :: ProxyInfo -> N.SockAddr
getSourceAddr p = _sourceAddr p
getDestAddr :: ProxyInfo -> N.SockAddr
getDestAddr p = _destAddr p
makeProxyInfo :: N.SockAddr
-> N.SockAddr
-> N.Family
-> N.SocketType
-> ProxyInfo
makeProxyInfo srcAddr destAddr f st = ProxyInfo srcAddr destAddr f st
parseFamily :: Parser (Maybe N.Family)
parseFamily = (string "TCP4" >> return (Just N.AF_INET))
<|> (string "TCP6" >> return (Just N.AF_INET6))
<|> (string "UNKNOWN" >> return Nothing)
parseOldHaProxy :: Parser (Maybe (ByteString, Int, ByteString, Int, N.Family))
parseOldHaProxy = do
string "PROXY "
gotFamily <- parseFamily
case gotFamily of
Nothing -> skipWhile (/= '\r') >> string "\r\n" >> return Nothing
(Just f) -> do
char ' '
srcAddress <- takeWhile1 (/= ' ')
char ' '
destAddress <- takeWhile1 (/= ' ')
char ' '
srcPort <- decimal
char ' '
destPort <- decimal
string "\r\n"
return $! Just $! (srcAddress, srcPort, destAddress, destPort, f)
protocolHeader :: ByteString
protocolHeader = S8.pack [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D
, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A ]
{-# NOINLINE protocolHeader #-}
parseNewHaProxy :: ProxyInfo -> Parser ProxyInfo
parseNewHaProxy localProxyInfo = do
string protocolHeader
versionAndCommand <- anyWord8
let version = (versionAndCommand .&. 0xF0) `unsafeShiftR` 4
let command = (versionAndCommand .&. 0xF) :: Word8
when (version /= 0x2) $ fail $ "Invalid protocol version: " ++ show version
when (command > 1) $ fail $ "Invalid command: " ++ show command
protocolAndFamily <- anyWord8
let family = (protocolAndFamily .&. 0xF0) `unsafeShiftR` 4
let protocol = (protocolAndFamily .&. 0xF) :: Word8
socketType <- toSocketType protocol
addressLen <- ntohs <$> snarf16
case () of
!_ | command == 0x0 || family == 0x0 || protocol == 0x0
-> handleLocal addressLen
| family == 0x1 -> handleIPv4 addressLen socketType
| family == 0x2 -> handleIPv6 addressLen socketType
#ifndef WINDOWS
| family == 0x3 -> handleUnix addressLen socketType
#endif
| otherwise -> fail $ "Bad family " ++ show family
where
toSocketType 0 = return $! N.Stream
toSocketType 1 = return $! N.Stream
toSocketType 2 = return $! N.Datagram
toSocketType _ = fail "bad protocol"
handleLocal addressLen = do
when (addressLen > 500) $ fail $ "suspiciously long address "
++ show addressLen
void $ take (fromIntegral addressLen)
return localProxyInfo
handleIPv4 addressLen socketType = do
when (addressLen < 12) $ fail $ "bad address length "
++ show addressLen
++ " for IPv4"
let nskip = addressLen - 12
srcAddr <- snarf32
destAddr <- snarf32
srcPort <- ntohs <$> snarf16
destPort <- ntohs <$> snarf16
void $ take $ fromIntegral nskip
let sa = N.SockAddrInet (fromIntegral srcPort) srcAddr
let sb = N.SockAddrInet (fromIntegral destPort) destAddr
return $! makeProxyInfo sa sb (addrFamily sa) socketType
handleIPv6 addressLen socketType = do
let scopeId = 0
let flow = 0
when (addressLen < 36) $ fail $ "bad address length "
++ show addressLen
++ " for IPv6"
let nskip = addressLen - 36
s1 <- ntohl <$> snarf32
s2 <- ntohl <$> snarf32
s3 <- ntohl <$> snarf32
s4 <- ntohl <$> snarf32
d1 <- ntohl <$> snarf32
d2 <- ntohl <$> snarf32
d3 <- ntohl <$> snarf32
d4 <- ntohl <$> snarf32
sp <- ntohs <$> snarf16
dp <- ntohs <$> snarf16
void $ take $ fromIntegral nskip
let sa = N.SockAddrInet6 (fromIntegral sp) flow (s1, s2, s3, s4) scopeId
let sb = N.SockAddrInet6 (fromIntegral dp) flow (d1, d2, d3, d4) scopeId
return $! makeProxyInfo sa sb (addrFamily sa) socketType
#ifndef WINDOWS
handleUnix addressLen socketType = do
when (addressLen < 216) $ fail $ "bad address length "
++ show addressLen
++ " for unix"
addr1 <- take 108
addr2 <- take 108
void $ take $ fromIntegral $ addressLen - 216
let sa = N.SockAddrUnix (toUnixPath addr1)
let sb = N.SockAddrUnix (toUnixPath addr2)
return $! makeProxyInfo sa sb (addrFamily sa) socketType
toUnixPath = S.unpack . fst . S.break (=='\x00')
#endif
foreign import ccall unsafe "iostreams_ntohs" c_ntohs :: CUShort -> CUShort
foreign import ccall unsafe "iostreams_ntohl" c_ntohl :: CUInt -> CUInt
ntohs :: Word16 -> Word16
ntohs = fromIntegral . c_ntohs . fromIntegral
ntohl :: Word32 -> Word32
ntohl = fromIntegral . c_ntohl . fromIntegral
snarf32 :: Parser Word32
snarf32 = do
s <- take 4
return $! unsafePerformIO $! S.unsafeUseAsCString s $ peek . castPtr
snarf16 :: Parser Word16
snarf16 = do
s <- take 2
return $! unsafePerformIO $! S.unsafeUseAsCString s $ peek . castPtr
addrFamily :: N.SockAddr -> N.Family
addrFamily s = case s of
(N.SockAddrInet _ _) -> N.AF_INET
(N.SockAddrInet6 _ _ _ _) -> N.AF_INET6
#ifndef WINDOWS
(N.SockAddrUnix _ ) -> N.AF_UNIX
#endif
_ -> error "unknown family"