--------------------------------------------------------------------------------
{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.WebSockets.Hybi13
    ( headerVersions
    , finishRequest
    , finishResponse
    , encodeMessage
    , encodeMessages
    , decodeMessages
    , createRequest

      -- Internal (used for testing)
    , encodeFrame
    , parseFrame
    ) where


--------------------------------------------------------------------------------
import qualified Data.ByteString.Builder               as B
import           Control.Applicative                   (pure, (<$>))
import           Control.Arrow                         (first)
import           Control.Exception                     (throwIO)
import           Control.Monad                         (forM, liftM, unless,
                                                        when)
import           Data.Binary.Get                       (Get, getInt64be,
                                                        getLazyByteString,
                                                        getWord16be, getWord8)
import           Data.Binary.Put                       (putWord16be, runPut)
import           Data.Bits                             ((.&.), (.|.))
import           Data.ByteString                       (ByteString)
import qualified Data.ByteString.Base64                as B64
import           Data.ByteString.Char8                 ()
import qualified Data.ByteString.Lazy                  as BL
import           Data.Digest.Pure.SHA                  (bytestringDigest, sha1)
import           Data.IORef
import           Data.Monoid                           (mappend, mconcat,
                                                        mempty)
import           Data.Tuple                            (swap)
import           System.Entropy                        as R
import           System.Random                         (RandomGen, newStdGen)


--------------------------------------------------------------------------------
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Http
import           Network.WebSockets.Hybi13.Demultiplex
import           Network.WebSockets.Hybi13.Mask
import           Network.WebSockets.Stream             (Stream)
import qualified Network.WebSockets.Stream             as Stream
import           Network.WebSockets.Types


--------------------------------------------------------------------------------
-- | Values of the @Sec-WebSocket-Version@ header handled by this module.
-- The first entry is the version advertised by 'createRequest'.
headerVersions :: [ByteString]
headerVersions = ["13"]


--------------------------------------------------------------------------------
-- | Complete the server side of the opening handshake.
--
-- Looks up the client's @Sec-WebSocket-Key@, hashes it with 'hashKey',
-- Base64-encodes the digest and returns a 'response101' carrying it as
-- @Sec-WebSocket-Accept@ in front of the supplied extra @headers@.
-- A missing key is reported by 'getRequestHeader' as a 'Left'.
finishRequest :: RequestHead
              -> Headers
              -> Either HandshakeException Response
finishRequest reqHttp headers =
    getRequestHeader reqHttp "Sec-WebSocket-Key" >>= accept
  where
    -- Bang patterns force the key, digest and encoded value as soon as the
    -- resulting 'Either' is evaluated, matching the strict do-block form.
    accept !key =
        let !digest  = hashKey key
            !encoded = B64.encode digest
        in  return $
                response101 (("Sec-WebSocket-Accept", encoded) : headers) ""


--------------------------------------------------------------------------------
-- | Validate the server's reply to a client opening handshake.
--
-- * A status of 400 yields 'RequestRejected'.
-- * Any other status other than 101 yields 'MalformedResponse'.
-- * The @Sec-WebSocket-Accept@ header must equal the Base64 of 'hashKey'
--   applied to our own @Sec-WebSocket-Key@; otherwise 'MalformedResponse'.
-- * Missing headers are reported by 'getRequestHeader' /
--   'getResponseHeader'.
--
-- On success the response head is wrapped in a 'Response' with an empty body.
finishResponse :: RequestHead
               -> ResponseHead
               -> Either HandshakeException Response
finishResponse request response
    -- Only the numeric status is checked; the reason phrase (e.g.
    -- "WebSocket Protocol Handshake" or "Switching Protocols") is not.
    | status == 400 = Left $ RequestRejected request response
    | status /= 101 = Left $
        MalformedResponse response "Wrong response status or message."
    | otherwise     = do
        key          <- getRequestHeader  request  "Sec-WebSocket-Key"
        responseHash <- getResponseHeader response "Sec-WebSocket-Accept"
        let expected = B64.encode (hashKey key)
        when (responseHash /= expected) $ Left $
            MalformedResponse response "Challenge and response hashes do not match."
        return $ Response response ""
  where
    status = responseCode response


--------------------------------------------------------------------------------
-- | Encode one 'Message' as a single final frame.
--
-- Client connections draw a fresh 'Mask' from the generator (via
-- 'randomMask') and return the advanced generator; server connections send
-- unmasked frames and hand the generator back untouched.
encodeMessage :: RandomGen g => ConnectionType -> g -> Message -> (g, B.Builder)
encodeMessage conType gen msg = (gen', encodeFrame mask (toFrame msg))
  where
    (mask, gen') = case conType of
        ServerConnection -> (Nothing, gen)
        ClientConnection -> first Just (randomMask gen)

    -- Control frames are final and never set reserved bits.
    control = Frame True False False False

    -- A close payload is the big-endian status code followed by the reason.
    toFrame (ControlMessage (Close code reason)) =
        control CloseFrame (runPut (putWord16be code) `mappend` reason)
    toFrame (ControlMessage (Ping pl)) = control PingFrame pl
    toFrame (ControlMessage (Pong pl)) = control PongFrame pl
    toFrame (DataMessage r1 r2 r3 (Text pl _)) =
        Frame True r1 r2 r3 TextFrame pl
    toFrame (DataMessage r1 r2 r3 (Binary pl)) =
        Frame True r1 r2 r3 BinaryFrame pl


--------------------------------------------------------------------------------
-- | Create a sender for a connection.
--
-- The returned action encodes a batch of messages with 'encodeMessage' and
-- writes the concatenated bytes with a single 'Stream.write'. One 'StdGen',
-- created here with 'newStdGen', is kept in an 'IORef' and advanced with
-- 'atomicModifyIORef'' for every message encoded.
encodeMessages
    :: ConnectionType
    -> Stream
    -> IO ([Message] -> IO ())
encodeMessages conType stream = do
    genRef <- newStdGen >>= newIORef
    let encodeOne msg =
            atomicModifyIORef' genRef $ \g -> encodeMessage conType g msg
        sendBatch msgs = do
            builders <- forM msgs encodeOne
            Stream.write stream . B.toLazyByteString $ mconcat builders
    return sendBatch


--------------------------------------------------------------------------------
-- | Serialise one 'Frame' in wire order:
--
-- 1. FIN, RSV1-3 and opcode byte;
-- 2. mask bit OR-ed with the 7-bit length field;
-- 3. a 16-bit (tag 126) or 64-bit (tag 127) extended length, when needed;
-- 4. the masking key, when a 'Mask' is given;
-- 5. the payload, passed through 'maskPayload'.
--
-- Close, ping and pong payloads are truncated to 125 bytes before encoding.
encodeFrame :: Maybe Mask -> Frame -> B.Builder
encodeFrame mask f = mconcat
    [ B.word8 headerByte
    , B.word8 (maskBit .|. lengthTag)
    , extendedLength
    , maskKey
    , B.lazyByteString (maskPayload mask body)
    ]
  where
    ftype = frameType f

    -- The given bit pattern when the condition holds, zero otherwise.
    flagIf cond bits = if cond then bits else 0x00

    headerByte = flagIf (frameFin  f) 0x80
             .|. flagIf (frameRsv1 f) 0x40
             .|. flagIf (frameRsv2 f) 0x20
             .|. flagIf (frameRsv3 f) 0x10
             .|. opcodeOf ftype

    opcodeOf ContinuationFrame = 0x00
    opcodeOf TextFrame         = 0x01
    opcodeOf BinaryFrame       = 0x02
    opcodeOf CloseFrame        = 0x08
    opcodeOf PingFrame         = 0x09
    opcodeOf PongFrame         = 0x0a

    -- Control frames are capped at 125 payload bytes.
    body = case ftype of
        CloseFrame -> BL.take 125 (framePayload f)
        PingFrame  -> BL.take 125 (framePayload f)
        PongFrame  -> BL.take 125 (framePayload f)
        _          -> framePayload f

    maskBit = maybe 0x00 (const 0x80) mask
    maskKey = maybe mempty encodeMask mask

    bodyLen = BL.length body
    (lengthTag, extendedLength)
        | bodyLen < 126     = (fromIntegral bodyLen, mempty)
        | bodyLen < 0x10000 = (126, B.word16BE (fromIntegral bodyLen))
        | otherwise         = (127, B.word64BE (fromIntegral bodyLen))


--------------------------------------------------------------------------------
-- | Create a receiver for a connection.
--
-- The returned action reads frames with 'parseFrame' (bounded by
-- @frameLimit@) and feeds them to 'demultiplex' (bounded by @messageLimit@).
-- It loops until a message is complete, returning @Just@ it. It returns
-- 'Nothing' when 'Stream.parseBin' yields no frame, and throws the
-- 'ConnectionException' reported by the demultiplexer. Demultiplexer state
-- persists in an 'IORef' across calls, so a message may span several
-- invocations' worth of frames.
decodeMessages
    :: SizeLimit
    -> SizeLimit
    -> Stream
    -> IO (IO (Maybe Message))
decodeMessages frameLimit messageLimit stream =
    fmap nextMessage (newIORef emptyDemultiplexState)
  where
    nextMessage dmRef = do
        mbFrame <- Stream.parseBin stream (parseFrame frameLimit)
        case mbFrame of
            Nothing    -> return Nothing
            Just frame -> do
                -- 'demultiplex' returns (result, state); the IORef update
                -- wants (state, result), hence 'swap'.
                result <- atomicModifyIORef' dmRef $ \st ->
                    swap (demultiplex messageLimit st frame)
                case result of
                    DemultiplexError err   -> throwIO err
                    DemultiplexContinue    -> nextMessage dmRef
                    DemultiplexSuccess msg -> return (Just msg)


--------------------------------------------------------------------------------
-- | Parse a single frame, unmasking its payload when the mask bit is set.
--
-- Fails (via 'fail') when:
--
-- * a 64-bit payload length has its most significant bit set (RFC 6455
--   section 5.2 requires it to be 0);
-- * the payload length is not within @frameSizeLimit@;
-- * the opcode is not one of the six known values;
-- * a close/ping/pong frame is fragmented or carries more than 125 bytes.
parseFrame :: SizeLimit -> Get Frame
parseFrame frameSizeLimit = do
    byte0 <- getWord8
    let fin    = byte0 .&. 0x80 == 0x80
        rsv1   = byte0 .&. 0x40 == 0x40
        rsv2   = byte0 .&. 0x20 == 0x20
        rsv3   = byte0 .&. 0x10 == 0x10
        opcode = byte0 .&. 0x0f

    byte1 <- getWord8
    let mask    = byte1 .&. 0x80 == 0x80
        lenflag = byte1 .&. 0x7f

    -- 7-bit length, or 126/127 meaning a 16-/64-bit extended length follows.
    len <- case lenflag of
        126 -> fromIntegral <$> getWord16be
        127 -> getInt64be
        _   -> return (fromIntegral lenflag)

    -- A 64-bit length with its top bit set decodes to a negative Int64.
    -- 'atMostSizeLimit' is assumed to be an upper-bound check only, so reject
    -- such lengths explicitly rather than passing a negative count on to
    -- 'getLazyByteString'.
    when (len < 0) $
        fail $ "Frame length " ++ show len ++ " is invalid (most significant bit set)"

    -- Check size against limit.
    unless (atMostSizeLimit len frameSizeLimit) $
        fail $ "Frame of size " ++ show len ++ " exceeded limit"

    ft <- case opcode of
        0x00 -> return ContinuationFrame
        0x01 -> return TextFrame
        0x02 -> return BinaryFrame
        0x08 -> enforceControlFrameRestrictions len fin >> return CloseFrame
        0x09 -> enforceControlFrameRestrictions len fin >> return PingFrame
        0x0a -> enforceControlFrameRestrictions len fin >> return PongFrame
        _    -> fail $ "Unknown opcode: " ++ show opcode

    -- The masking key (if any) precedes the payload on the wire.
    masker <- maskPayload <$> if mask then Just <$> parseMask else pure Nothing

    chunks <- getLazyByteString len

    return $ Frame fin rsv1 rsv2 rsv3 ft (masker chunks)

    where
        -- Control frames must be final and carry at most 125 payload bytes.
        enforceControlFrameRestrictions len fin
            | not fin   = fail "Control Frames must not be fragmented!"
            | len > 125 = fail "Control Frames must not carry payload > 125 bytes!"
            | otherwise = pure ()

--------------------------------------------------------------------------------
-- | Raw SHA-1 digest of the handshake key concatenated with the fixed
-- WebSocket GUID, returned as a strict 'ByteString'. The handshake code in
-- this module Base64-encodes the result to obtain the
-- @Sec-WebSocket-Accept@ value.
hashKey :: ByteString -> ByteString
hashKey key = toStrictBS . bytestringDigest . sha1 $ BL.fromChunks [key `mappend` guid]
  where
    guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    toStrictBS = mconcat . BL.toChunks


--------------------------------------------------------------------------------
-- | Build a client opening-handshake request for @path@ on @hostname@.
--
-- A fresh @Sec-WebSocket-Key@ is made by Base64-encoding 16 bytes from
-- 'getEntropy'. The mandatory Host/Connection/Upgrade/key/version headers
-- come first, followed by @customHeaders@. @secure@ is stored unchanged in
-- the 'RequestHead'.
createRequest :: ByteString
              -> ByteString
              -> Bool
              -> Headers
              -> IO RequestHead
createRequest hostname path secure customHeaders = do
    key <- B64.encode `liftM` getEntropy 16
    return $ RequestHead path (headers key ++ customHeaders) secure
  where
    headers key =
        [ ("Host"                 , hostname     )
        , ("Connection"           , "Upgrade"    )
        , ("Upgrade"              , "websocket"  )
        , ("Sec-WebSocket-Key"    , key          )
        , ("Sec-WebSocket-Version", versionNumber)
        ]

    -- 'headerVersions' is a non-empty constant; match on it explicitly
    -- instead of calling the partial 'head', so the invariant is visible.
    versionNumber = case headerVersions of
        (v : _) -> v
        []      -> error "Network.WebSockets.Hybi13.createRequest: headerVersions is empty"