{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.WebSockets.Hybi13
( headerVersions
, finishRequest
, finishResponse
, encodeMessage
, encodeMessages
, decodeMessages
, createRequest
, encodeFrame
, parseFrame
) where
import qualified Data.ByteString.Builder as B
import Control.Applicative (pure, (<$>))
import Control.Arrow (first)
import Control.Exception (throwIO)
import Control.Monad (forM, liftM, unless,
when)
import Data.Binary.Get (Get, getInt64be,
getLazyByteString,
getWord16be, getWord8)
import Data.Binary.Put (putWord16be, runPut)
import Data.Bits ((.&.), (.|.))
import Data.ByteString (ByteString)
import qualified Data.ByteString.Base64 as B64
import Data.ByteString.Char8 ()
import qualified Data.ByteString.Lazy as BL
import Data.Digest.Pure.SHA (bytestringDigest, sha1)
import Data.IORef
import Data.Monoid (mappend, mconcat,
mempty)
import Data.Tuple (swap)
import System.Entropy as R
import System.Random (RandomGen, newStdGen)
import Network.WebSockets.Connection.Options
import Network.WebSockets.Http
import Network.WebSockets.Hybi13.Demultiplex
import Network.WebSockets.Hybi13.Mask
import Network.WebSockets.Stream (Stream)
import qualified Network.WebSockets.Stream as Stream
import Network.WebSockets.Types
-- | Protocol version strings offered by this module. The first entry is the
-- one 'createRequest' sends in the @Sec-WebSocket-Version@ header.
headerVersions :: [ByteString]
headerVersions = [hybi13]
  where
    hybi13 = "13"
-- | Build the server's answer to a client handshake.
--
-- The @Sec-WebSocket-Key@ header is looked up in the request head; if that
-- lookup fails, its 'HandshakeException' is returned unchanged. Otherwise the
-- key is run through 'hashKey', Base64-encoded, and placed in a
-- @Sec-WebSocket-Accept@ header in front of the caller-supplied headers.
finishRequest :: RequestHead
              -> Headers
              -> Either HandshakeException Response
finishRequest reqHttp headers =
    getRequestHeader reqHttp "Sec-WebSocket-Key" >>= accept
  where
    -- Bang patterns keep the key, digest and encoded token evaluated before
    -- the response is returned.
    accept !clientKey =
        let !digest = hashKey clientKey
            !token  = B64.encode digest
        in  return $
                response101 (("Sec-WebSocket-Accept", token) : headers) ""
-- | Validate the server's handshake response on the client side.
--
-- Fails with 'MalformedResponse' when the status code is not 101 or when the
-- @Sec-WebSocket-Accept@ header differs from the Base64-encoded 'hashKey' of
-- the @Sec-WebSocket-Key@ that was sent in the request. A failed header
-- lookup propagates its own 'HandshakeException'.
finishResponse :: RequestHead
               -> ResponseHead
               -> Either HandshakeException Response
finishResponse request response = do
    unless (responseCode response == 101) $
        reject "Wrong response status or message."
    clientKey  <- getRequestHeader request "Sec-WebSocket-Key"
    acceptHash <- getResponseHeader response "Sec-WebSocket-Accept"
    unless (acceptHash == B64.encode (hashKey clientKey)) $
        reject "Challenge and response hashes do not match."
    return (Response response "")
  where
    -- Abort the handshake with the offending response head attached.
    reject = Left . MalformedResponse response
-- | Encode one message as a single, final (FIN set) frame.
--
-- A 'ServerConnection' sends the frame unmasked and returns the generator
-- untouched; a 'ClientConnection' draws a fresh mask from the generator via
-- 'randomMask' and returns the advanced generator.
encodeMessage :: RandomGen g => ConnectionType -> g -> Message -> (g, B.Builder)
encodeMessage conType gen msg = (nextGen, encodeFrame frameMask frame)
  where
    (frameMask, nextGen) = case conType of
        ServerConnection -> (Nothing, gen)
        ClientConnection ->
            let (m, g) = randomMask gen in (Just m, g)

    -- Control frames never carry reserved bits.
    control = Frame True False False False

    frame = case msg of
        ControlMessage ctl -> case ctl of
            -- A close payload is the 16-bit status code followed by the reason.
            Close code pl -> control CloseFrame
                (mappend (runPut (putWord16be code)) pl)
            Ping pl       -> control PingFrame pl
            Pong pl       -> control PongFrame pl
        DataMessage rsv1 rsv2 rsv3 dat -> case dat of
            Text pl _ -> Frame True rsv1 rsv2 rsv3 TextFrame pl
            Binary pl -> Frame True rsv1 rsv2 rsv3 BinaryFrame pl
-- | Create a writer that encodes a batch of messages and sends them to the
-- stream with a single 'Stream.write'.
--
-- A random generator is seeded once with 'newStdGen' and kept in an 'IORef';
-- every message threads it through 'encodeMessage' atomically.
encodeMessages
    :: ConnectionType
    -> Stream
    -> IO ([Message] -> IO ())
encodeMessages conType stream = do
    seed   <- newStdGen
    genRef <- newIORef seed
    let encodeOne msg =
            atomicModifyIORef' genRef (\g -> encodeMessage conType g msg)
        send msgs = do
            chunks <- mapM encodeOne msgs
            Stream.write stream (B.toLazyByteString (mconcat chunks))
    return send
-- | Serialise a frame: the flag/opcode byte, the mask-bit/length byte, the
-- optional extended length, the optional mask bytes, and finally the payload
-- passed through 'maskPayload'.
--
-- Payloads of close, ping and pong frames are truncated to 125 bytes before
-- being written.
encodeFrame :: Maybe Mask -> Frame -> B.Builder
encodeFrame mask f = mconcat
    [ B.word8 (flagBits .|. opcode)
    , B.word8 (maskBit .|. lenCode)
    , extendedLen
    , maskBytes
    , B.lazyByteString (maskPayload mask body)
    ]
  where
    -- Contribute @bitValue@ when the given frame flag is set.
    flagWhen isSet bitValue = if isSet f then bitValue else 0x00

    flagBits =
        flagWhen frameFin  0x80 .|. flagWhen frameRsv1 0x40 .|.
        flagWhen frameRsv2 0x20 .|. flagWhen frameRsv3 0x10

    whole  = framePayload f
    capped = BL.take 125 whole

    -- Opcode and the payload actually sent, chosen together per frame type.
    (opcode, body) = case frameType f of
        ContinuationFrame -> (0x00, whole)
        TextFrame         -> (0x01, whole)
        BinaryFrame       -> (0x02, whole)
        CloseFrame        -> (0x08, capped)
        PingFrame         -> (0x09, capped)
        PongFrame         -> (0x0a, capped)

    (maskBit, maskBytes) =
        maybe (0x00, mempty) (\m -> (0x80, encodeMask m)) mask

    bodyLen = BL.length body

    -- 7-bit length code, plus the 16- or 64-bit extended length when needed.
    (lenCode, extendedLen)
        | bodyLen < 126     = (fromIntegral bodyLen, mempty)
        | bodyLen < 0x10000 = (126, B.word16BE (fromIntegral bodyLen))
        | otherwise         = (127, B.word64BE (fromIntegral bodyLen))
-- | Create a reader that yields the next complete message from the stream.
--
-- Frames are parsed with 'parseFrame' under the first limit and fed to
-- 'demultiplex' under the second; the demultiplexer state lives in an
-- 'IORef' shared by all calls of the returned action. The action returns
-- 'Nothing' when 'Stream.parseBin' yields no frame, and throws the error
-- reported by the demultiplexer.
decodeMessages
    :: SizeLimit
    -> SizeLimit
    -> Stream
    -> IO (IO (Maybe Message))
decodeMessages frameLimit messageLimit stream =
    nextMessage <$> newIORef emptyDemultiplexState
  where
    -- Read one frame; stop if the stream has nothing more to offer.
    nextMessage stateRef =
        Stream.parseBin stream (parseFrame frameLimit) >>=
            maybe (return Nothing) (feed stateRef)

    -- Push a frame through the demultiplexer and act on its verdict.
    feed stateRef frame = do
        outcome <- atomicModifyIORef' stateRef
            (\st -> swap (demultiplex messageLimit st frame))
        case outcome of
            DemultiplexSuccess msg -> return (Just msg)
            DemultiplexContinue    -> nextMessage stateRef
            DemultiplexError err   -> throwIO err
-- | Parse a single frame from the wire.
--
-- The parser fails when:
--
-- * the 64-bit extended payload length has its most significant bit set
--   (it would otherwise be read as a negative length),
--
-- * the payload length exceeds the given 'SizeLimit',
--
-- * the opcode is unknown, or
--
-- * a control frame (close, ping, pong) is fragmented or carries more than
--   125 bytes of payload.
--
-- If the mask bit is set, the mask is read with 'parseMask' and the payload
-- is unmasked via 'maskPayload' before being stored in the frame.
parseFrame :: SizeLimit -> Get Frame
parseFrame frameSizeLimit = do
    -- First byte: FIN, RSV1-3 and the 4-bit opcode.
    byte0 <- getWord8
    let fin    = byte0 .&. 0x80 == 0x80
        rsv1   = byte0 .&. 0x40 == 0x40
        rsv2   = byte0 .&. 0x20 == 0x20
        rsv3   = byte0 .&. 0x10 == 0x10
        opcode = byte0 .&. 0x0f

    -- Second byte: mask bit and 7-bit length code.
    byte1 <- getWord8
    let mask    = byte1 .&. 0x80 == 0x80
        lenflag = byte1 .&. 0x7f

    len <- case lenflag of
        126 -> fromIntegral <$> getWord16be
        127 -> getInt64be
        _   -> return (fromIntegral lenflag)

    -- The 64-bit length is read as a signed value, so a set top bit shows up
    -- as a negative number. That is never a legal length and it would pass
    -- the @len > 125@ control-frame check below, so reject it explicitly.
    when (len < 0) $
        fail $ "Frame has invalid negative payload length: " ++ show len

    unless (atMostSizeLimit len frameSizeLimit) $
        fail $ "Frame of size " ++ show len ++ " exceeded limit"

    ft <- case opcode of
        0x00 -> return ContinuationFrame
        0x01 -> return TextFrame
        0x02 -> return BinaryFrame
        0x08 -> enforceControlFrameRestrictions len fin >> return CloseFrame
        0x09 -> enforceControlFrameRestrictions len fin >> return PingFrame
        0x0a -> enforceControlFrameRestrictions len fin >> return PongFrame
        _    -> fail $ "Unknown opcode: " ++ show opcode

    -- The mask (if any) precedes the payload on the wire.
    masker <- maskPayload <$> if mask then Just <$> parseMask else pure Nothing

    chunks <- getLazyByteString len

    return $ Frame fin rsv1 rsv2 rsv3 ft (masker chunks)
  where
    -- Control frames must be unfragmented and at most 125 bytes long.
    enforceControlFrameRestrictions len fin
        | not fin   = fail "Control Frames must not be fragmented!"
        | len > 125 = fail "Control Frames must not carry payload > 125 bytes!"
        | otherwise = pure ()
-- | Compute the raw SHA-1 digest of the given key with the fixed handshake
-- GUID appended. The result is the unencoded digest as a strict
-- 'ByteString'; callers Base64-encode it themselves.
hashKey :: ByteString -> ByteString
hashKey key = mconcat (BL.toChunks digest)
  where
    digest    = bytestringDigest (sha1 challenge)
    -- The hash function works on lazy input, so wrap the strict bytes.
    challenge = BL.fromChunks [mappend key magicGuid]
    magicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
-- | Build the client's opening handshake request.
--
-- Sixteen bytes from 'getEntropy' are Base64-encoded and sent as the
-- @Sec-WebSocket-Key@. The fixed handshake headers come first, followed by
-- the caller's custom headers.
createRequest :: ByteString
              -> ByteString
              -> Bool
              -> Headers
              -> IO RequestHead
createRequest hostname path secure customHeaders =
    mkHead . B64.encode <$> getEntropy 16
  where
    mkHead nonce =
        RequestHead path (handshake nonce ++ customHeaders) secure

    handshake nonce =
        [ ("Host"                  , hostname )
        , ("Connection"            , "Upgrade")
        , ("Upgrade"               , "websocket")
        , ("Sec-WebSocket-Key"     , nonce    )
        , ("Sec-WebSocket-Version" , version  )
        ]

    -- First (and only listed) supported version.
    version = head headerVersions