{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings  #-}
module Network.WebSockets.Hybi13.Demultiplex
    ( FrameType (..)
    , Frame (..)
    , DemultiplexState
    , emptyDemultiplexState
    , DemultiplexResult (..)
    , demultiplex
    ) where
import           Data.ByteString.Builder               (Builder)
import qualified Data.ByteString.Builder               as B
import           Control.Exception                     (Exception)
import           Data.Binary.Get                       (getWord16be, runGet)
import qualified Data.ByteString.Lazy                  as BL
import           Data.Int                              (Int64)
import           Data.Monoid                           (mappend)
import           Data.Typeable                         (Typeable)
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Types
-- | One WebSocket frame, as handed to 'demultiplex'.
data Frame = Frame
    { frameFin     :: !Bool
      -- ^ FIN flag: set on the final (or only) frame of a message.
    , frameRsv1    :: !Bool
      -- ^ Reserved bit RSV1.
    , frameRsv2    :: !Bool
      -- ^ Reserved bit RSV2.
    , frameRsv3    :: !Bool
      -- ^ Reserved bit RSV3.
    , frameType    :: !FrameType
      -- ^ Opcode of the frame.
    , framePayload :: !BL.ByteString
      -- ^ Payload data carried by this frame.
    }
    deriving (Show, Eq)
-- | Opcode of a 'Frame'. Continuation, text and binary frames carry message
-- data; close, ping and pong frames are control frames.
data FrameType
    = ContinuationFrame  -- ^ Next fragment of a fragmented data message.
    | TextFrame          -- ^ Start (or whole) of a text message.
    | BinaryFrame        -- ^ Start (or whole) of a binary message.
    | CloseFrame         -- ^ Close control frame.
    | PingFrame          -- ^ Ping control frame.
    | PongFrame          -- ^ Pong control frame.
    deriving (Show, Eq)
-- | Exception type for demultiplexing failures.
--
-- NOTE(review): this type is not in the module's export list, and
-- 'demultiplex' reports failures through 'DemultiplexError' rather than
-- throwing; confirm whether it is still needed.
data DemultiplexException
    = DemultiplexException
    deriving (Typeable, Show)

instance Exception DemultiplexException
-- | Reassembly state threaded through successive calls to 'demultiplex'.
data DemultiplexState
    = EmptyDemultiplexState
      -- ^ No fragmented data message is currently being assembled.
    | DemultiplexState
        !Int64                 -- payload bytes received so far
        !Builder               -- payload accumulated so far
        !(Builder -> Message)  -- turns the complete payload into the final message
      -- ^ A fragmented data message is in progress.

-- | Starting state: nothing is buffered.
emptyDemultiplexState :: DemultiplexState
emptyDemultiplexState = EmptyDemultiplexState
-- | What 'demultiplex' produced for the frame it was given.
data DemultiplexResult
    = DemultiplexSuccess Message
      -- ^ A complete message (data or control) is available.
    | DemultiplexError ConnectionException
      -- ^ The frame was rejected with the given error.
    | DemultiplexContinue
      -- ^ The frame was buffered; more fragments are needed.
-- | Feed one decoded 'Frame' to the demultiplexer. Returns the outcome
-- together with the state to use for the next frame.
--
-- * Ping and pong frames are returned immediately as control messages. They
--   leave the state untouched, so they may arrive between the fragments of a
--   data message.
-- * A close frame is returned immediately and discards any partially
--   received data message.
-- * Text and binary frames start a new data message. When the FIN bit is
--   clear, the payload is buffered. Continuation frames are then appended
--   until one with FIN set completes the message.
--
-- Control frames must be unfragmented, have all RSV bits clear and carry at
-- most 125 bytes of payload (RFC 6455, section 5.5). Payloads that exceed
-- @sizeLimit@ yield a 'ParseException'. When no message is in progress, this
-- limit is checked before the frame type is examined. Every other unexpected
-- frame yields a 'CloseRequest' with code 1002. Examples are a new data frame
-- while a fragmented message is still in progress, or a stray continuation
-- frame. All error results reset the state.
demultiplex :: SizeLimit
            -> DemultiplexState
            -> Frame
            -> (DemultiplexResult, DemultiplexState)
demultiplex _ state (Frame True False False False PingFrame pl)
    | controlPayloadTooLarge pl = demultiplexProtocolError
    | otherwise                 =
        (DemultiplexSuccess $ ControlMessage (Ping pl), state)
demultiplex _ state (Frame True False False False PongFrame pl)
    -- Pongs are control frames too and are subject to the same size rule.
    | controlPayloadTooLarge pl = demultiplexProtocolError
    | otherwise                 =
        (DemultiplexSuccess (ControlMessage (Pong pl)), state)
demultiplex _ _ (Frame True False False False CloseFrame pl)
    | controlPayloadTooLarge pl = demultiplexProtocolError
    | otherwise                 =
        ( DemultiplexSuccess (ControlMessage (uncurry Close parsedClose))
        , emptyDemultiplexState
        )
  where
    -- A close frame MAY carry a body. If present, its first two bytes are
    -- the status code (big-endian) and the remainder is the close reason.
    --
    -- * An empty body is reported as code 1000 with no reason.
    -- * A 1-byte body cannot hold a status code, so it is reported as 1002.
    -- * Codes below 1000, and the listed reserved or unassigned codes, are
    --   reported as 1002 with the reason dropped.
    --
    -- NOTE(review): the list is an explicit denylist, not a full range check
    -- against RFC 6455 section 7.4. Codes such as 1017 or 6000 are passed
    -- through unchanged.
    parsedClose
       | BL.length pl >= 2 = case runGet getWord16be pl of
              a | a < 1000 || a `elem` [1004,1005,1006
                                       ,1014,1015,1016
                                       ,1100,2000,2999
                                       ,5000,65535] -> (1002, BL.empty)
              a -> (a, BL.drop 2 pl)
       | BL.length pl == 1 = (1002, BL.empty)
       | otherwise         = (1000, BL.empty)
-- A new data message can only begin when nothing is buffered. A text or
-- binary frame that arrives mid-message falls through to the final
-- equation. The RSV bits are passed through into the resulting DataMessage.
demultiplex sizeLimit EmptyDemultiplexState (Frame fin rsv1 rsv2 rsv3 tp pl) = case tp of
    _ | not (atMostSizeLimit size sizeLimit) ->
        ( DemultiplexError $ ParseException $
            "Message of size " ++ show size ++ " exceeded limit"
        , emptyDemultiplexState
        )
    TextFrame
        | fin       ->
            (DemultiplexSuccess (text pl), emptyDemultiplexState)
        | otherwise ->
            (DemultiplexContinue, DemultiplexState size plb (text . B.toLazyByteString))
    BinaryFrame
        | fin       -> (DemultiplexSuccess (binary pl), emptyDemultiplexState)
        | otherwise -> (DemultiplexContinue, DemultiplexState size plb (binary . B.toLazyByteString))
    _ -> demultiplexProtocolError
  where
    size     = BL.length pl
    plb      = B.lazyByteString pl
    text   x = DataMessage rsv1 rsv2 rsv3 (Text x Nothing)
    binary x = DataMessage rsv1 rsv2 rsv3 (Binary x)
-- A continuation frame (all RSV bits clear) extends the message in progress.
-- The running total is checked against the size limit on every fragment.
demultiplex sizeLimit (DemultiplexState size0 b f) (Frame fin False False False ContinuationFrame pl)
    | not (atMostSizeLimit size1 sizeLimit) =
        ( DemultiplexError $ ParseException $
            "Message of size " ++ show size1 ++ " exceeded limit"
        , emptyDemultiplexState
        )
    | fin         = (DemultiplexSuccess (f b'), emptyDemultiplexState)
    | otherwise   = (DemultiplexContinue, DemultiplexState size1 b' f)
  where
    size1 = size0 + BL.length pl
    b'    = b `mappend` plb
    plb   = B.lazyByteString pl
-- Anything else is a framing violation. Examples are fragmented or
-- RSV-flagged control frames, or a new data frame mid-message.
demultiplex _ _ _ = demultiplexProtocolError

-- | Maximum payload length of any control frame (RFC 6455, section 5.5).
maxControlPayloadLength :: Int64
maxControlPayloadLength = 125

-- | Whether a control frame's payload exceeds 'maxControlPayloadLength'.
controlPayloadTooLarge :: BL.ByteString -> Bool
controlPayloadTooLarge pl = BL.length pl > maxControlPayloadLength

-- | The result for a framing violation: close with code 1002 and reset the
-- reassembly state.
demultiplexProtocolError :: (DemultiplexResult, DemultiplexState)
demultiplexProtocolError =
    (DemultiplexError (CloseRequest 1002 "Protocol Error"), emptyDemultiplexState)