--------------------------------------------------------------------------------
-- | Masking of frames using a simple XOR algorithm
{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings        #-}
{-# LANGUAGE ScopedTypeVariables      #-}
module Network.WebSockets.Hybi13.Mask
    ( Mask
    , parseMask
    , encodeMask
    , randomMask

    , maskPayload
    ) where


--------------------------------------------------------------------------------
import qualified Data.ByteString.Builder       as Builder
import qualified Data.ByteString.Builder.Extra as Builder
import           Data.Binary.Get               (Get, getWord32host)
import qualified Data.ByteString.Internal      as B
import qualified Data.ByteString.Lazy          as BL
import qualified Data.ByteString.Lazy.Internal as BL
import           Data.Word                     (Word32, Word8)
import           Foreign.C.Types               (CChar (..), CInt (..),
                                                CSize (..))
import           Foreign.ForeignPtr            (withForeignPtr)
import           Foreign.Ptr                   (Ptr, plusPtr)
import           System.Random                 (RandomGen, random)


--------------------------------------------------------------------------------
-- | Binding to the C routine @_hs_mask_chunk@, imported as an @unsafe@ call.
--
-- Arguments, in the order 'maskPayload' supplies them: the mask word, the
-- running mask offset (kept in the range 0-3 by the caller), a pointer to the
-- source bytes, the number of bytes, and a pointer to the destination buffer.
--
-- NOTE(review): the C implementation is not visible from this module;
-- presumably it writes the XOR-masked source bytes to the destination buffer.
-- Confirm against the C source.
foreign import ccall unsafe "_hs_mask_chunk" c_mask_chunk
    :: Word32 -> CInt -> Ptr CChar -> CSize -> Ptr Word8 -> IO ()


--------------------------------------------------------------------------------
-- | A mask is a sequence of 4 bytes.  We store it in a 'Word32' in the host's
-- native byte ordering: it is read with a host-order parser and written with
-- a host-order builder, so the bytes round-trip unchanged on the same
-- machine.
newtype Mask = Mask {unMask :: Word32}


--------------------------------------------------------------------------------
-- | Parser for a 'Mask': reads one 32-bit word in host byte order and wraps
-- it as-is.
parseMask :: Get Mask
parseMask = Mask <$> getWord32host


--------------------------------------------------------------------------------
-- | Serialise a 'Mask' as a single 32-bit word in host byte order.
encodeMask :: Mask -> Builder.Builder
encodeMask (Mask word) = Builder.word32Host word


--------------------------------------------------------------------------------
-- | Draw a random 'Mask' from the given generator, returning it together with
-- the updated generator.
randomMask :: forall g. RandomGen g => g -> (Mask, g)
randomMask gen =
    -- The pattern binding itself stays lazy; the bangs mean that as soon as
    -- either component is demanded, both are forced.
    let (!word, !nextGen) = random gen :: (Word32, g)
    in  (Mask word, nextGen)


--------------------------------------------------------------------------------
-- | Apply a mask to a lazy bytestring, one strict chunk at a time, using
-- 'c_mask_chunk' for the actual byte work.
--
-- Without a mask, or with an all-zero mask, the payload is returned as-is.
maskPayload :: Maybe Mask -> BL.ByteString -> BL.ByteString
maskPayload maybeMask = case maybeMask of
    Nothing         -> id
    Just (Mask 0)   -> id
    Just (Mask key) -> maskFrom key 0
  where
    -- Walk the chunks lazily.  @offset@ is the number of bytes seen so far,
    -- modulo 4, and is handed to the C routine along with each chunk.
    maskFrom _   _       BL.Empty                              = BL.Empty
    maskFrom key !offset (BL.Chunk chunk@(B.PS _ _ size) rest) =
        BL.Chunk (maskChunk key offset chunk) (maskFrom key nextOffset rest)
      where
        nextOffset = (offset + size) `rem` 4

    -- Produce a freshly created strict bytestring of the same length as the
    -- input chunk, letting 'c_mask_chunk' fill the destination buffer.
    maskChunk key offset (B.PS fptr start size) =
        B.unsafeCreate size $ \dst ->
        withForeignPtr fptr $ \base ->
            c_mask_chunk key
                (fromIntegral offset)
                (base `plusPtr` start)
                (fromIntegral size)
                dst