--------------------------------------------------------------------------------
{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.WebSockets.Hybi13
    ( headerVersions
    , finishRequest
    , finishResponse
    , encodeMessage
    , encodeMessages
    , decodeMessages
    , createRequest

      -- Internal (used for testing)
    , encodeFrame
    , parseFrame
    ) where


--------------------------------------------------------------------------------
import qualified Data.ByteString.Builder               as B
import           Control.Applicative                   (pure, (<$>))
import           Control.Arrow                         (first)
import           Control.Exception                     (throwIO)
import           Control.Monad                         (forM, liftM, unless,
                                                        when)
import           Data.Binary.Get                       (Get, getInt64be,
                                                        getLazyByteString,
                                                        getWord16be, getWord8)
import           Data.Binary.Put                       (putWord16be, runPut)
import           Data.Bits                             ((.&.), (.|.))
import           Data.ByteString                       (ByteString)
import qualified Data.ByteString.Base64                as B64
import           Data.ByteString.Char8                 ()
import qualified Data.ByteString.Lazy                  as BL
import           Data.Digest.Pure.SHA                  (bytestringDigest, sha1)
import           Data.IORef
import           Data.Monoid                           (mappend, mconcat,
                                                        mempty)
import           Data.Tuple                            (swap)
import           System.Entropy                        as R
import           System.Random                         (RandomGen, newStdGen)


--------------------------------------------------------------------------------
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Http
import           Network.WebSockets.Hybi13.Demultiplex
import           Network.WebSockets.Hybi13.Mask
import           Network.WebSockets.Stream             (Stream)
import qualified Network.WebSockets.Stream             as Stream
import           Network.WebSockets.Types


--------------------------------------------------------------------------------
-- | Protocol version strings handled by this module. The first entry is the
-- one sent as @Sec-WebSocket-Version@ by 'createRequest'.
headerVersions :: [ByteString]
headerVersions = [hybi13]
  where
    hybi13 = "13"


--------------------------------------------------------------------------------
-- | Server side of the handshake: build the @101@ response for a request.
--
-- Looks up the client's @Sec-WebSocket-Key@, hashes it with 'hashKey',
-- Base64-encodes the digest and prepends it as @Sec-WebSocket-Accept@ to the
-- supplied headers. A failed header lookup is returned unchanged as 'Left'.
finishRequest :: RequestHead
              -> Headers
              -> Either HandshakeException Response
finishRequest reqHttp headers =
    case getRequestHeader reqHttp "Sec-WebSocket-Key" of
        Left err   -> Left err
        Right !key ->
            -- Bang patterns keep the hashing work strict, so it happens as
            -- soon as the result is inspected rather than hiding in a thunk.
            let !digest = hashKey key
                !accept = B64.encode digest
            in  return $
                    response101 (("Sec-WebSocket-Accept", accept) : headers) ""


--------------------------------------------------------------------------------
-- | Client side of the handshake: validate the server's response.
--
-- The response is rejected with 'MalformedResponse' when its status code is
-- not @101@, or when its @Sec-WebSocket-Accept@ header does not equal the
-- Base64-encoded 'hashKey' of the @Sec-WebSocket-Key@ we sent. Failures to
-- look up either header are propagated as-is.
--
-- The response message should be either \"WebSocket Protocol Handshake\" or
-- \"Switching Protocols\", but only the numeric code is checked for now.
finishResponse :: RequestHead
               -> ResponseHead
               -> Either HandshakeException Response
finishResponse request response
    | responseCode response /= 101 =
        malformed "Wrong response status or message."
    | otherwise = do
        key      <- getRequestHeader  request  "Sec-WebSocket-Key"
        received <- getResponseHeader response "Sec-WebSocket-Accept"
        if received == B64.encode (hashKey key)
            then return $ Response response ""
            else malformed "Challenge and response hashes do not match."
  where
    -- Reject the handshake, attaching the offending response head.
    malformed reason = Left (MalformedResponse response reason)


--------------------------------------------------------------------------------
-- | Encode one message as a single, unfragmented frame.
--
-- On a 'ServerConnection' the frame is sent unmasked and the generator is
-- returned untouched; on a 'ClientConnection' a fresh mask is drawn with
-- 'randomMask' and the advanced generator is returned.
--
-- Control messages never set the reserved bits; data messages carry their own
-- @rsv1@..@rsv3@ flags into the frame. A close message's payload is the
-- big-endian 16-bit status code followed by the supplied bytes.
encodeMessage :: RandomGen g => ConnectionType -> g -> Message -> (g, B.Builder)
encodeMessage conType gen msg = (nextGen, encodeFrame frameMask frame)
  where
    (frameMask, nextGen) = case conType of
        ServerConnection -> (Nothing, gen)
        ClientConnection ->
            let (m, g) = randomMask gen
            in  (Just m, g)

    -- Final frame with all reserved bits cleared.
    control = Frame True False False False

    frame = case msg of
        ControlMessage ctl -> case ctl of
            Close code pl ->
                control CloseFrame (runPut (putWord16be code) `mappend` pl)
            Ping pl       -> control PingFrame pl
            Pong pl       -> control PongFrame pl
        DataMessage rsv1 rsv2 rsv3 dat -> case dat of
            Text pl _ -> Frame True rsv1 rsv2 rsv3 TextFrame   pl
            Binary pl -> Frame True rsv1 rsv2 rsv3 BinaryFrame pl


--------------------------------------------------------------------------------
-- | Create a sender for batches of messages on the given stream.
--
-- A random generator is seeded once with 'newStdGen' and kept in an 'IORef';
-- every message encoded through the returned function advances it atomically.
-- All frames of one batch are concatenated and handed to 'Stream.write' in a
-- single call.
encodeMessages
    :: ConnectionType
    -> Stream
    -> IO ([Message] -> IO ())
encodeMessages conType stream = do
    seed   <- newStdGen
    genRef <- newIORef seed
    let encodeOne msg =
            atomicModifyIORef' genRef (\g -> encodeMessage conType g msg)
        sendAll msgs = do
            frames <- mapM encodeOne msgs
            Stream.write stream (B.toLazyByteString (mconcat frames))
    return sendAll


--------------------------------------------------------------------------------
-- | Serialise a frame to its wire format.
--
-- Layout: one byte with FIN, RSV1-3 and the opcode; one byte with the mask
-- flag and the 7-bit length field; an optional 16- or 64-bit big-endian
-- extended length; the mask bytes (when a mask is given); and finally the
-- payload run through 'maskPayload'.
--
-- Payloads of close, ping and pong frames are truncated to 125 bytes before
-- the length is computed.
encodeFrame :: Maybe Mask -> Frame -> B.Builder
encodeFrame mask f = mconcat
    [ B.word8 headerByte
    , B.word8 lengthByte
    , extendedLength
    , maskKey
    , B.lazyByteString (maskPayload mask body)
    ]
  where
    -- Contribute @b@ to a header byte only when the flag is set.
    flagBit cond b = if cond then b else 0x00

    headerByte =
        flagBit (frameFin f)  0x80 .|.
        flagBit (frameRsv1 f) 0x40 .|.
        flagBit (frameRsv2 f) 0x20 .|.
        flagBit (frameRsv3 f) 0x10 .|.
        opcode

    -- Opcode and the payload that actually goes on the wire.
    (opcode, body) = case frameType f of
        ContinuationFrame -> (0x00, whole)
        TextFrame         -> (0x01, whole)
        BinaryFrame       -> (0x02, whole)
        CloseFrame        -> (0x08, capped)
        PingFrame         -> (0x09, capped)
        PongFrame         -> (0x0a, capped)
    whole  = framePayload f
    capped = BL.take 125 whole

    (maskBit, maskKey) = case mask of
        Just m  -> (0x80, encodeMask m)
        Nothing -> (0x00, mempty)

    lengthByte = maskBit .|. lengthField
    size       = BL.length body

    -- Short lengths live in the 7-bit field; 126 and 127 announce a 16-bit
    -- or 64-bit extended length respectively.
    (lengthField, extendedLength)
        | size < 126     = (fromIntegral size, mempty)
        | size < 0x10000 = (126, B.word16BE (fromIntegral size))
        | otherwise      = (127, B.word64BE (fromIntegral size))


--------------------------------------------------------------------------------
-- | Create a reader that yields the next complete message from the stream.
--
-- The first 'SizeLimit' is passed to 'parseFrame' for each frame, the second
-- to 'demultiplex' for the assembled message. Demultiplexer state is kept in
-- an 'IORef' shared by all invocations of the returned action.
--
-- The returned action keeps reading frames until 'demultiplex' reports a
-- finished message. It yields 'Nothing' when 'Stream.parseBin' yields
-- 'Nothing', and throws the error carried by 'DemultiplexError'.
decodeMessages
    :: SizeLimit
    -> SizeLimit
    -> Stream
    -> IO (IO (Maybe Message))
decodeMessages frameLimit messageLimit stream =
    nextMessage <$> newIORef emptyDemultiplexState
  where
    nextMessage stateRef = readFrame
      where
        readFrame =
            Stream.parseBin stream (parseFrame frameLimit) >>=
                maybe (return Nothing) absorb

        -- Feed one frame to the demultiplexer and act on its verdict.
        absorb frame = do
            outcome <- atomicModifyIORef' stateRef $ \st ->
                swap (demultiplex messageLimit st frame)
            case outcome of
                DemultiplexSuccess msg -> return (Just msg)
                DemultiplexContinue    -> readFrame
                DemultiplexError err   -> throwIO err


--------------------------------------------------------------------------------
-- | Parse a single frame from its wire representation.
--
-- Reads the two header bytes (FIN\/RSV bits and opcode, then the mask flag
-- and 7-bit length field), the optional 16- or 64-bit extended length, the
-- optional mask, and finally the payload, which is passed through
-- 'maskPayload' with the parsed mask (if any).
--
-- The parser fails when:
--
-- * a 64-bit extended length is negative, i.e. has its most significant bit
--   set;
--
-- * the announced length is rejected by 'atMostSizeLimit' for the given
--   limit;
--
-- * the opcode is not one of the six known frame types;
--
-- * a control frame (close, ping, pong) is fragmented or announces more than
--   125 bytes of payload.
parseFrame :: SizeLimit -> Get Frame
parseFrame frameSizeLimit = do
    byte0 <- getWord8
    let fin    = byte0 .&. 0x80 == 0x80
        rsv1   = byte0 .&. 0x40 == 0x40
        rsv2   = byte0 .&. 0x20 == 0x20
        rsv3   = byte0 .&. 0x10 == 0x10
        opcode = byte0 .&. 0x0f

    byte1 <- getWord8
    let mask = byte1 .&. 0x80 == 0x80
        lenflag = byte1 .&. 0x7f

    len <- case lenflag of
        126 -> fromIntegral <$> getWord16be
        127 -> getInt64be
        _   -> return (fromIntegral lenflag)

    -- RFC 6455 section 5.2 requires the most significant bit of the 64-bit
    -- length to be 0. Because the value is read as a signed Int64, a set bit
    -- shows up as a negative length, which is not a meaningful payload size
    -- and must not reach the limit check or the payload read below.
    when (len < 0) $
        fail $ "Frame of size " ++ show len ++ " is invalid (negative length)"

    -- Check size against limit.
    unless (atMostSizeLimit len frameSizeLimit) $
        fail $ "Frame of size " ++ show len ++ " exceeded limit"

    ft <- case opcode of
        0x00 -> return ContinuationFrame
        0x01 -> return TextFrame
        0x02 -> return BinaryFrame
        0x08 -> enforceControlFrameRestrictions len fin >> return CloseFrame
        0x09 -> enforceControlFrameRestrictions len fin >> return PingFrame
        0x0a -> enforceControlFrameRestrictions len fin >> return PongFrame
        _    -> fail $ "Unknown opcode: " ++ show opcode

    -- The mask, when present, sits between the header and the payload.
    masker <- maskPayload <$> if mask then Just <$> parseMask else pure Nothing

    chunks <- getLazyByteString len

    return $ Frame fin rsv1 rsv2 rsv3 ft (masker chunks)

    where
        -- Control frames must arrive in one piece and stay small.
        enforceControlFrameRestrictions len fin
          | not fin   = fail "Control Frames must not be fragmented!"
          | len > 125 = fail "Control Frames must not carry payload > 125 bytes!"
          | otherwise = pure ()

--------------------------------------------------------------------------------
-- | Compute the raw SHA-1 digest of a handshake key followed by the fixed
-- WebSocket GUID. The result is the binary digest; callers Base64-encode it.
hashKey :: ByteString -> ByteString
hashKey key = BL.toStrict digest
  where
    digest    = bytestringDigest (sha1 (BL.fromChunks [key, magicGuid]))
    magicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


--------------------------------------------------------------------------------
-- | Build the client's opening handshake request.
--
-- Arguments, in order: the value for the @Host@ header, the request path,
-- the flag stored as the request's secure field, and extra headers that are
-- appended after the mandatory ones.
--
-- A fresh @Sec-WebSocket-Key@ is generated from 16 bytes of entropy,
-- Base64-encoded. The advertised @Sec-WebSocket-Version@ is the first entry
-- of 'headerVersions'.
createRequest :: ByteString
              -> ByteString
              -> Bool
              -> Headers
              -> IO RequestHead
createRequest hostname path secure customHeaders = do
    key <- B64.encode <$> getEntropy 16
    return $ RequestHead path (headers key ++ customHeaders) secure
  where
    headers key =
        [ ("Host"                   , hostname     )
        , ("Connection"             , "Upgrade"    )
        , ("Upgrade"                , "websocket"  )
        , ("Sec-WebSocket-Key"      , key          )
        , ("Sec-WebSocket-Version"  , versionNumber)
        ]

    -- Total pattern match instead of the partial 'head': should the version
    -- list ever become empty, the failure names its origin.
    versionNumber = case headerVersions of
        (v : _) -> v
        []      -> error
            "Network.WebSockets.Hybi13.createRequest: headerVersions is empty"