module Network.WebSockets.Extensions.PermessageDeflate
( defaultPermessageDeflate
, PermessageDeflate(..)
, negotiateDeflate
) where
import Control.Applicative ((<$>))
import Control.Exception (throwIO)
import Control.Monad (foldM)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Char8 as BL8
import qualified Data.ByteString.Lazy.Internal as BL
import Data.Monoid
import qualified Data.Streaming.Zlib as Zlib
import Network.WebSockets.Extensions
import Network.WebSockets.Extensions.Description
import Network.WebSockets.Http
import Network.WebSockets.Types
import Text.Read (readMaybe)
import Prelude
-- | Configuration for the permessage-deflate extension. The same record also
-- holds the parameters agreed during negotiation (the server's value with the
-- client's offered parameters folded in).
data PermessageDeflate = PermessageDeflate
  { serverNoContextTakeover :: Bool
    -- ^ When set, every outgoing message is compressed with a freshly
    -- created deflate stream instead of one stream shared by all messages.
  , clientNoContextTakeover :: Bool
    -- ^ When set, every incoming message is decompressed with a freshly
    -- created inflate stream instead of one stream shared by all messages.
  , serverMaxWindowBits :: Int
    -- ^ Window size, in bits, for the outgoing deflate stream. It is clamped
    -- to 9..15 before being handed to zlib; 15 is omitted from the
    -- advertised parameters.
  , clientMaxWindowBits :: Int
    -- ^ Window size, in bits, for the incoming inflate stream. It is clamped
    -- to 9..15 before being handed to zlib; 15 is omitted from the
    -- advertised parameters.
  , pdCompressionLevel :: Int
    -- ^ Compression level passed to the zlib deflater.
  } deriving (Eq, Show)
-- | Default configuration:
--
-- * context takeover allowed in both directions;
-- * 15-bit windows for both the deflater and the inflater;
-- * zlib compression level 8.
defaultPermessageDeflate :: PermessageDeflate
defaultPermessageDeflate = PermessageDeflate
  { serverNoContextTakeover = False
  , clientNoContextTakeover = False
  , serverMaxWindowBits = 15
  , clientMaxWindowBits = 15
  , pdCompressionLevel = 8
  }
-- | Render a configuration as a @permessage-deflate@ extension description.
--
-- Parameters are emitted in this order:
--
-- * the two @*_no_context_takeover@ flags, each only when set, without a
--   value;
-- * the two @*_max_window_bits@ parameters, each only when it differs from
--   15, with the decimal window size as value.
toExtensionDescription :: PermessageDeflate -> ExtensionDescription
toExtensionDescription pmd = ExtensionDescription
  { extName = "permessage-deflate"
  , extParams = concat
      [ flag "server_no_context_takeover" (serverNoContextTakeover pmd)
      , flag "client_no_context_takeover" (clientNoContextTakeover pmd)
      , window "server_max_window_bits" (serverMaxWindowBits pmd)
      , window "client_max_window_bits" (clientMaxWindowBits pmd)
      ]
  }
  where
    -- A valueless parameter, present only when the flag is on.
    flag key enabled = [(key, Nothing) | enabled]
    -- A window-size parameter, omitted when it is the 15-bit default.
    window key bits = [(key, Just (B8.pack (show bits))) | bits /= 15]
-- | The single @Sec-WebSocket-Extensions@ response header that advertises
-- the given permessage-deflate configuration.
toHeaders :: PermessageDeflate -> Headers
toHeaders pmd = [("Sec-WebSocket-Extensions", encoded)]
  where
    encoded = encodeExtensionDescriptions [toExtensionDescription pmd]
-- | Server-side negotiation of compression.
--
-- Inputs are the server's own configuration ('Nothing' disables compression)
-- and the list of extensions offered by the client.
--
-- Offers are scanned in order, and the first one named
-- @x-webkit-deflate-frame@ or @permessage-deflate@ is accepted:
--
-- * @x-webkit-deflate-frame@ is acknowledged by name, and the server
--   configuration is used unchanged.
-- * @permessage-deflate@ has each offered parameter folded into the server
--   configuration with 'setParam'. The resulting configuration is both used
--   and echoed back in the response header. An invalid parameter makes the
--   whole negotiation fail with 'Left'.
--
-- If nothing is accepted (or compression is disabled), no response header is
-- produced. Messages in both directions then go through 'rejectExtensions'.
--
-- NOTE(review): x-webkit-deflate-frame is driven by the same per-message
-- transformers as permessage-deflate here; confirm that is intended.
negotiateDeflate :: Maybe PermessageDeflate -> NegotiateExtension
negotiateDeflate pmd0 exts0 = do
  (headers, agreed) <- pickOffer exts0 pmd0
  return Extension
    { extHeaders = headers
    , extParse = \parseRaw -> do
        inflate <- makeMessageInflater agreed
        -- Inflate every message produced by the raw parser; a 'Nothing'
        -- from the raw parser is forwarded unchanged.
        return (parseRaw >>= maybe (return Nothing) (fmap Just . inflate))
    , extWrite = \writeRaw -> do
        deflate <- makeMessageDeflater agreed
        return (\msgs -> writeRaw =<< mapM deflate msgs)
    }
  where
    pickOffer
      :: ExtensionDescriptions
      -> Maybe PermessageDeflate
      -> Either String (Headers, Maybe PermessageDeflate)
    pickOffer (offer : rest) (Just cfg)
      | name == "x-webkit-deflate-frame" =
          Right
            ( [("Sec-WebSocket-Extensions", "x-webkit-deflate-frame")]
            , Just cfg
            )
      | name == "permessage-deflate" = do
          cfg' <- foldM setParam cfg (extParams offer)
          Right (toHeaders cfg', Just cfg')
      | otherwise = pickOffer rest (Just cfg)
      where
        name = extName offer
    pickOffer _ _ = Right ([], Nothing)
-- | Fold one parameter from a client's @permessage-deflate@ offer into the
-- server configuration.
--
-- * The two @*_no_context_takeover@ flags are switched on, whatever value
--   accompanies them.
-- * For the two @*_max_window_bits@ parameters, a missing value means 15,
--   and a present value must pass 'parseWindow' (whose 'Left' is propagated).
-- * Any other parameter name is ignored.
setParam
  :: PermessageDeflate -> ExtensionParam -> Either String PermessageDeflate
setParam pmd (key, value) = case key of
  "server_no_context_takeover" -> Right pmd {serverNoContextTakeover = True}
  "client_no_context_takeover" -> Right pmd {clientNoContextTakeover = True}
  "server_max_window_bits" ->
    fmap (\bits -> pmd {serverMaxWindowBits = bits}) (windowFrom value)
  "client_max_window_bits" ->
    fmap (\bits -> pmd {clientMaxWindowBits = bits}) (windowFrom value)
  _ -> Right pmd
  where
    windowFrom = maybe (Right 15) parseWindow
-- | Parse a @*_max_window_bits@ parameter value.
--
-- The value is an integer (as accepted by 'readMaybe') that must lie in
-- 8..15.
--
-- It is read as an unbounded 'Integer' before the range check. Reading
-- straight into 'Int' silently wraps oversized numbers, so a value such as
-- @18446744073709551625@ would wrap to 9 and be accepted.
parseWindow :: B.ByteString -> Either String Int
parseWindow bs8 = case readMaybe (B8.unpack bs8) :: Maybe Integer of
  Just w
    | w >= 8 && w <= 15 -> Right (fromInteger w)
    | otherwise -> Left $ "Window out of bounds: " ++ show w
  Nothing -> Left $ "Can't parse window: " ++ show bs8
-- | Clamp a window size into 9..15 bits before it is handed to zlib.
--
-- NOTE(review): values below 9 are raised to 9, presumably because zlib
-- rejects 8-bit windows for raw deflate. A peer that negotiated 8 could then
-- receive back-references beyond its window; confirm this is acceptable.
fixWindowBits :: Int -> Int
fixWindowBits = max 9 . min 15
-- | The four octets @00 00 ff ff@.
--
-- * The deflater strips them from the end of its output ('maybeStrip').
-- * The inflater appends them to each compressed payload before
--   decompressing.
--
-- Both steps follow the permessage-deflate framing (RFC 7692, section 7.2).
appTailL :: BL.ByteString
appTailL = BL.pack [0x00, 0x00, 0xff, 0xff]

-- | Remove a trailing 'appTailL' from compressed output if present;
-- otherwise return the input unchanged.
maybeStrip :: BL.ByteString -> BL.ByteString
maybeStrip x
  | appTailL `BL.isSuffixOf` x = BL.take (BL.length x - BL.length appTailL) x
  | otherwise = x
-- | Message transformer used when no compression was negotiated.
--
-- A data message with any RSV bit set causes a 1002 ("Protocol Error")
-- 'CloseRequest' to be thrown. Every other message is returned unchanged.
rejectExtensions :: Message -> IO Message
rejectExtensions msg = case msg of
  DataMessage r1 r2 r3 _
    | or [r1, r2, r3] -> throwIO $ CloseRequest 1002 "Protocol Error"
  _ -> return msg
-- | Build the outgoing-message transformer for a negotiated configuration.
--
-- * 'Nothing' (no compression): messages go through 'rejectExtensions'.
-- * @'Just' pmd@: 'Text' and 'Binary' data messages whose RSV bits are all
--   clear are deflated and marked with RSV1. Every other message passes
--   through untouched.
--
-- When 'serverNoContextTakeover' is set, a fresh deflate stream is created
-- for every message. Otherwise one stream is created up front and shared by
-- all messages, so its compression context carries over between them.
makeMessageDeflater
  :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageDeflater Nothing = return rejectExtensions
makeMessageDeflater (Just pmd)
  | serverNoContextTakeover pmd =
      return $ \msg -> do
        ptr <- initDeflate pmd
        deflateMessageWith (deflateBody ptr) msg
  | otherwise = do
      ptr <- initDeflate pmd
      return $ \msg ->
        deflateMessageWith (deflateBody ptr) msg
  where
    -- zlib treats a *negative* windowBits as a request for a raw DEFLATE
    -- stream: no zlib header and no Adler-32 trailer. That is the format
    -- permessage-deflate puts on the wire. A positive value would produce a
    -- zlib-wrapped stream that peers cannot decode.
    initDeflate :: PermessageDeflate -> IO Zlib.Deflate
    initDeflate cfg =
      Zlib.initDeflate
        (pdCompressionLevel cfg)
        (Zlib.WindowBits (negate (fixWindowBits (serverMaxWindowBits cfg))))

    -- Compress only data messages with all RSV bits clear, and set RSV1 on
    -- the result. The second 'Text' field is reset to 'Nothing'.
    -- NOTE(review): that field presumably caches a decoding of the
    -- uncompressed payload; confirm in Network.WebSockets.Types.
    deflateMessageWith
      :: (BL.ByteString -> IO BL.ByteString)
      -> Message -> IO Message
    deflateMessageWith deflater (DataMessage False False False (Text x _)) = do
      x' <- deflater x
      return (DataMessage True False False (Text x' Nothing))
    deflateMessageWith deflater (DataMessage False False False (Binary x)) = do
      x' <- deflater x
      return (DataMessage True False False (Binary x'))
    deflateMessageWith _ x = return x

    -- Feed each strict chunk of the payload to the deflater, flush it
    -- (presumably a zlib sync flush, which ends in 00 00 ff ff), and drop
    -- that trailer with 'maybeStrip'.
    deflateBody :: Zlib.Deflate -> BL.ByteString -> IO BL.ByteString
    deflateBody ptr = fmap maybeStrip . go . BL.toChunks
      where
        go [] = dePopper (Zlib.flushDeflate ptr)
        go (c : cs) = do
          bl <- Zlib.feedDeflate ptr c >>= dePopper
          (bl <>) <$> go cs
-- | Drain a zlib 'Zlib.Popper' into a lazy 'BL.ByteString', adding one strict
-- chunk per 'Zlib.PRNext' until 'Zlib.PRDone'.
--
-- A 'Zlib.PRError' is raised as a 1002 'CloseRequest' whose reason is the
-- 'show'n error.
dePopper :: Zlib.Popper -> IO BL.ByteString
dePopper popper = drain
  where
    drain = do
      result <- popper
      case result of
        Zlib.PRDone -> return BL.empty
        Zlib.PRNext piece -> BL.chunk piece <$> drain
        Zlib.PRError err ->
          throwIO $ CloseRequest 1002 (BL8.pack (show err))
-- | Build the incoming-message transformer for a negotiated configuration.
--
-- * 'Nothing' (no compression): messages go through 'rejectExtensions'.
-- * @'Just' pmd@: 'Text' and 'Binary' data messages with RSV1 set are
--   inflated and have RSV1 cleared; RSV2/RSV3 are preserved. Every other
--   message passes through untouched.
--
-- When 'clientNoContextTakeover' is set, a fresh inflate stream is created
-- for every message. Otherwise one stream is created up front and shared by
-- all messages.
makeMessageInflater :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageInflater Nothing = return rejectExtensions
makeMessageInflater (Just pmd)
  | clientNoContextTakeover pmd =
      return $ \msg -> do
        ptr <- initInflate pmd
        inflateMessageWith (inflateBody ptr) msg
  | otherwise = do
      ptr <- initInflate pmd
      return $ \msg ->
        inflateMessageWith (inflateBody ptr) msg
  where
    -- A *negative* windowBits makes zlib expect a raw DEFLATE stream (no
    -- zlib header, no Adler-32 trailer), which is what permessage-deflate
    -- peers send. A positive value makes zlib look for a zlib header and
    -- fail on every compressed message.
    initInflate :: PermessageDeflate -> IO Zlib.Inflate
    initInflate cfg =
      Zlib.initInflate
        (Zlib.WindowBits (negate (fixWindowBits (clientMaxWindowBits cfg))))

    -- Decompress only RSV1 data messages and clear RSV1 on the result. The
    -- second 'Text' field is reset to 'Nothing'.
    inflateMessageWith
      :: (BL.ByteString -> IO BL.ByteString)
      -> Message -> IO Message
    inflateMessageWith inflater (DataMessage True a b (Text x _)) = do
      x' <- inflater x
      return (DataMessage False a b (Text x' Nothing))
    inflateMessageWith inflater (DataMessage True a b (Binary x)) = do
      x' <- inflater x
      return (DataMessage False a b (Binary x'))
    inflateMessageWith _ x = return x

    -- Re-append the 00 00 ff ff trailer the sender stripped, feed the
    -- payload chunk by chunk, and finish with whatever the inflater still
    -- buffers.
    inflateBody :: Zlib.Inflate -> BL.ByteString -> IO BL.ByteString
    inflateBody ptr =
      go . BL.toChunks . (<> appTailL)
      where
        go [] = BL.fromStrict <$> Zlib.flushInflate ptr
        go (c : cs) = do
          bl <- Zlib.feedInflate ptr c >>= dePopper
          (bl <>) <$> go cs