--------------------------------------------------------------------------------
-- | Demultiplexing of frames into messages
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings  #-}
module Network.WebSockets.Hybi13.Demultiplex
    ( FrameType (..)
    , Frame (..)
    , DemultiplexState
    , emptyDemultiplexState
    , DemultiplexResult (..)
    , demultiplex
    ) where


--------------------------------------------------------------------------------
import           Data.ByteString.Builder               (Builder)
import qualified Data.ByteString.Builder               as B
import           Control.Exception                     (Exception)
import           Data.Binary.Get                       (getWord16be, runGet)
import qualified Data.ByteString.Lazy                  as BL
import           Data.Int                              (Int64)
import           Data.Monoid                           (mappend)
import           Data.Typeable                         (Typeable)
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Types


--------------------------------------------------------------------------------
-- | A low-level representation of a WebSocket packet.
--
-- The demultiplexer only accepts control and continuation frames whose
-- three reserved bits are all clear; for text and binary frames the bits are
-- carried through into the resulting data message.
data Frame = Frame
    { frameFin     :: !Bool
      -- ^ Set on the last frame of a message; clear on non-final fragments.
    , frameRsv1    :: !Bool
      -- ^ First reserved bit.
    , frameRsv2    :: !Bool
      -- ^ Second reserved bit.
    , frameRsv3    :: !Bool
      -- ^ Third reserved bit.
    , frameType    :: !FrameType
      -- ^ What kind of frame this is.
    , framePayload :: !BL.ByteString
      -- ^ The frame's application payload.
    } deriving (Show, Eq)


--------------------------------------------------------------------------------
-- | The type of a frame. Not all types are allowed for all protocols.
data FrameType
    = ContinuationFrame  -- ^ A further fragment of a text or binary message
    | TextFrame          -- ^ A text message (or its first fragment)
    | BinaryFrame        -- ^ A binary message (or its first fragment)
    | CloseFrame         -- ^ Control frame: close the connection
    | PingFrame          -- ^ Control frame: ping
    | PongFrame          -- ^ Control frame: pong
    deriving (Show, Eq)


--------------------------------------------------------------------------------
-- | Thrown if the client sends invalid multiplexed data.
--
-- NOTE(review): 'demultiplex' in this module reports problems through
-- 'DemultiplexError' instead of throwing this; confirm whether any other
-- module still raises it.
data DemultiplexException = DemultiplexException
    deriving (Typeable, Show)


--------------------------------------------------------------------------------
-- | Makes 'DemultiplexException' throwable. All methods use the class
-- defaults.
instance Exception DemultiplexException where


--------------------------------------------------------------------------------
-- | Internal state used by the demultiplexer.
--
-- Either no fragmented data message is being assembled, or one is and its
-- fragments received so far are kept here.
data DemultiplexState
    = EmptyDemultiplexState
      -- ^ No fragmented data message is in progress.
    | DemultiplexState
        !Int64                 -- ^ Total payload bytes accumulated so far
        !Builder               -- ^ The accumulated payload
        !(Builder -> Message)  -- ^ Turns the finished payload into a 'Message'


--------------------------------------------------------------------------------
-- | The initial state: no fragmented data message is in progress.
--
-- 'demultiplex' also returns to this state after finishing a message,
-- after a close frame, and after any error.
emptyDemultiplexState :: DemultiplexState
emptyDemultiplexState = EmptyDemultiplexState


--------------------------------------------------------------------------------
-- | Result of demultiplexing a single frame.
data DemultiplexResult
    = DemultiplexSuccess Message
      -- ^ The frame completed a message (control or data)
    | DemultiplexError ConnectionException
      -- ^ The frame was rejected: a protocol violation or a size-limit breach
    | DemultiplexContinue
      -- ^ The frame was a non-final fragment; more frames are needed


--------------------------------------------------------------------------------
-- | Feed a single 'Frame' to the demultiplexer.
--
-- Returns what the frame produced, together with the state to use for the
-- next frame:
--
-- * Ping and pong frames yield a control message immediately and keep the
--   current state, so they may arrive between the fragments of a data
--   message.
--
-- * A close frame yields a 'Close' control message and resets the state.
--
-- * A text or binary frame either yields a complete data message (@fin@ set)
--   or starts accumulating one; continuation frames extend it. A message
--   whose accumulated size is not within the 'SizeLimit' is rejected with a
--   'ParseException'.
--
-- * Any other frame is rejected, normally with a 1002 \"Protocol Error\"
--   'CloseRequest'. This covers an oversized or fragmented control frame, a
--   control or continuation frame with reserved bits set, a continuation
--   with no message in progress, and a new data frame while a message is
--   still in progress.
demultiplex :: SizeLimit
            -> DemultiplexState
            -> Frame
            -> (DemultiplexResult, DemultiplexState)

-- RFC 6455 section 5.5: every control frame MUST have a payload of at most
-- 125 bytes. This used to be checked for ping frames only, so oversized pong
-- and close frames were accepted. When the guard does not fire, matching
-- falls through to the equations below.
demultiplex _ _ (Frame _ _ _ _ tp pl)
    | tp `elem` [PingFrame, PongFrame, CloseFrame] && BL.length pl > 125 =
        (DemultiplexError $ CloseRequest 1002 "Protocol Error", emptyDemultiplexState)

demultiplex _ state (Frame True False False False PingFrame pl) =
    (DemultiplexSuccess $ ControlMessage (Ping pl), state)

demultiplex _ state (Frame True False False False PongFrame pl) =
    (DemultiplexSuccess (ControlMessage (Pong pl)), state)

demultiplex _ _ (Frame True False False False CloseFrame pl) =
    (DemultiplexSuccess (ControlMessage (uncurry Close parsedClose)), emptyDemultiplexState)
  where
    -- The Close frame MAY contain a body (the "Application data" portion of the
    -- frame) that indicates a reason for closing, such as an endpoint shutting
    -- down, an endpoint having received a frame too large, or an endpoint
    -- having received a frame that does not conform to the format expected by
    -- the endpoint. If there is a body, the first two bytes of the body MUST
    -- be a 2-byte unsigned integer (in network byte order) representing a
    -- status code with value /code/ defined in Section 7.4.
    --
    -- An invalid code, or a 1-byte body, is reported as 1002 (protocol
    -- error). An empty body is reported as 1000 (normal closure).
    parsedClose
       | BL.length pl >= 2 = case runGet getWord16be pl of
              a | validCloseCode a -> (a, BL.drop 2 pl)
                | otherwise        -> (1002, BL.empty)
       | BL.length pl == 1 = (1002, BL.empty)
       | otherwise         = (1000, BL.empty)

    -- Status codes accepted on the wire. This is a whitelist of ranges: the
    -- previous blacklist only named a handful of sample values (1100, 2000,
    -- 2999, 5000, 65535, ...), so other unassigned reserved codes slipped
    -- through. Every code that was rejected before is still rejected:
    --   * below 1000: not used;
    --   * 1004-1006, 1014-2999: reserved, not defined for use by a peer here;
    --   * 5000 and above: undefined.
    -- NOTE(review): IANA has since registered 1014; it stays rejected to keep
    -- the previous behaviour for that code.
    validCloseCode a =
        (a >= 1000 && a <= 1003) ||
        (a >= 1007 && a <= 1013) ||
        (a >= 3000 && a <= 4999)

demultiplex sizeLimit EmptyDemultiplexState (Frame fin rsv1 rsv2 rsv3 tp pl) = case tp of
    -- The size limit is checked before the frame type is inspected.
    _ | not (atMostSizeLimit size sizeLimit) ->
        ( DemultiplexError $ ParseException $
            "Message of size " ++ show size ++ " exceeded limit"
        , emptyDemultiplexState
        )

    TextFrame
        | fin       ->
            (DemultiplexSuccess (text pl), emptyDemultiplexState)
        | otherwise ->
            (DemultiplexContinue, DemultiplexState size plb (text . B.toLazyByteString))

    BinaryFrame
        | fin       -> (DemultiplexSuccess (binary pl), emptyDemultiplexState)
        | otherwise -> (DemultiplexContinue, DemultiplexState size plb (binary . B.toLazyByteString))

    -- A continuation with nothing to continue, or a control frame that did
    -- not match the well-formed control-frame equations above.
    _ -> (DemultiplexError $ CloseRequest 1002 "Protocol Error", emptyDemultiplexState)

  where
    size     = BL.length pl
    plb      = B.lazyByteString pl
    -- The reserved bits of the first frame are the ones kept for the whole
    -- message.
    text   x = DataMessage rsv1 rsv2 rsv3 (Text x Nothing)
    binary x = DataMessage rsv1 rsv2 rsv3 (Binary x)

demultiplex sizeLimit (DemultiplexState size0 b f) (Frame fin False False False ContinuationFrame pl)
    | not (atMostSizeLimit size1 sizeLimit) =
        ( DemultiplexError $ ParseException $
            "Message of size " ++ show size1 ++ " exceeded limit"
        , emptyDemultiplexState
        )
    | fin         = (DemultiplexSuccess (f b'), emptyDemultiplexState)
    | otherwise   = (DemultiplexContinue, DemultiplexState size1 b' f)
  where
    -- Size and payload of the message including this fragment.
    size1 = size0 + BL.length pl
    b'    = b `mappend` plb
    plb   = B.lazyByteString pl

-- Anything else violates the framing rules.
demultiplex _ _ _ =
    (DemultiplexError (CloseRequest 1002 "Protocol Error"), emptyDemultiplexState)