{-# LANGUAGE LambdaCase #-}

module Http.Header
  ( Header (..)
  , decodeMany
  , parser
  , parserSmallArray
  , builder
  , builderSmallArray
  ) where

import Data.Bytes (Bytes)
import Data.Bytes.Builder (Builder)
import Data.Bytes.Parser (Parser)
import Data.Bytes.Types (Bytes (Bytes))
import Data.Primitive (ByteArray (ByteArray), SmallArray, SmallMutableArray)
import Data.Text (Text)

import Data.Bytes qualified as Bytes
import Data.Bytes.Builder qualified as Builder
import Data.Bytes.Parser qualified as Parser
import Data.Bytes.Parser.Latin qualified as Latin
import Data.Bytes.Text.Utf8 qualified as Utf8
import Data.Primitive qualified as PM
import Data.Text.Array qualified
import Data.Text.Internal qualified as Text

{- | A single HTTP header field, consisting of a name and a value.

No character-set restrictions are enforced by this type. A header built
by hand with, for instance, a colon inside its name will therefore be
encoded into a malformed request.
-}
data Header = Header
  { -- | The field name, emitted verbatim when encoded.
    name :: {-# UNPACK #-} !Text
  , -- | The field value, emitted verbatim when encoded.
    value :: {-# UNPACK #-} !Text
  }
  deriving (Show, Eq)

{- | Filler element for a freshly allocated header array.

Forcing this value throws an error. A slot that was never written
therefore fails loudly instead of producing a bogus header.
-}
uninitializedHeader :: Header
{-# NOINLINE uninitializedHeader #-}
uninitializedHeader = errorWithoutStackTrace message
  where
    message = "parserHeaders: uninitialized header"

{- | Decode a complete header section from a byte sequence.

The input must consist of zero or more headers, each terminated by CRLF,
followed by one more CRLF (the empty line that ends the section). When at
least one header is present, the input therefore ends with two CRLF
sequences in a row.

Returns 'Nothing' in any of these cases:

* a header is malformed;
* there are more than @n@ headers;
* any bytes remain after the terminating CRLF.
-}
decodeMany ::
  -- | Maximum number of headers to accept
  Int ->
  Bytes ->
  Maybe (SmallArray Header)
decodeMany !n !b = Parser.parseBytesMaybe entireInput b
  where
    -- The header section must account for every byte of the input.
    entireInput :: Parser () s (SmallArray Header)
    entireInput = parserSmallArray n <* Parser.endOfInput ()

{- | Parse a sequence of headers, stopping right after the bare CRLF that
ends the header section.

Any input that follows is left unconsumed. The parser fails if more headers
are present than the given limit allows.
-}
parserSmallArray ::
  -- | Maximum number of headers allowed (128 is recommended). This is also
  -- the size of the scratch array allocated up front. Presumably it must be
  -- non-negative — TODO confirm behaviour for negative values.
  Int ->
  Parser () s (SmallArray Header)
parserSmallArray !n =
  Parser.effect (PM.newSmallArray n uninitializedHeader)
    >>= parserHeaderStep 0 n

{- | Loop that fills the header array.

Each step does one of the following:

* If the line starts with CR, it must be the empty line that ends the
  header section. The array is shrunk to the number of headers written
  and frozen.
* Otherwise, if slots remain, one header is parsed into slot @ix@.
* Otherwise the parser fails, because there are too many headers.
-}
parserHeaderStep ::
  -- | Index of the next slot to write, which is also the count written so far
  Int ->
  -- | Number of slots still available
  Int ->
  -- | Destination array. It is frozen in place on success, so it must not be
  -- used afterwards.
  SmallMutableArray s Header ->
  Parser () s (SmallArray Header)
parserHeaderStep !ix !remaining !dst = do
  sawCR <- Latin.trySatisfy (== '\r')
  case sawCR of
    True -> do
      -- A CR at the start of a line must be followed by LF, forming the
      -- blank line that terminates the header section.
      Latin.char () '\n'
      Parser.effect $
        PM.shrinkSmallMutableArray dst ix
          *> PM.unsafeFreezeSmallArray dst
    False
      | remaining > 0 -> do
          hdr <- parser
          Parser.effect (PM.writeSmallArray dst ix hdr)
          parserHeaderStep (ix + 1) (remaining - 1) dst
      | otherwise -> Parser.fail ()

{- | Parse a single HTTP header including the trailing CRLF sequence.
From RFC 7230:

> header-field   = field-name ":" OWS field-value OWS
> field-name     = token
> field-value    = *( field-content / obs-fold )
> field-content  = field-vchar [ 1*( SP / HTAB ) field-vchar ]
> field-vchar    = VCHAR / obs-text

This parser accepts a stricter subset of that grammar:

* The field name must be non-empty. It may contain only ASCII letters,
  digits, hyphen and underscore, which is a subset of the token characters.
* The field value may contain only VCHAR, space and horizontal tab.
  @obs-text@ (bytes 0x80-0xFF) and @obs-fold@ (line folding) are rejected.
  Keeping the value ASCII is what makes the zero-copy conversion to 'Text'
  sound.
* Spaces and tabs before and after the value are stripped.
-}
parser :: Parser () s Header
parser = do
  -- Header name may contain: a-z, A-Z, 0-9, underscore, hyphen
  !name <- Parser.takeWhile $ \c ->
    (c >= 0x41 && c <= 0x5A)
      || (c >= 0x61 && c <= 0x7A)
      || (c >= 0x30 && c <= 0x39)
      || c == 0x2D
      || c == 0x5F
  -- RFC 7230 defines field-name as token = 1*tchar, so an empty name (a line
  -- beginning with ':') is malformed and must be rejected.
  if Bytes.null name then Parser.fail () else pure ()
  Latin.char () ':'
  Latin.skipWhile (\c -> c == ' ' || c == '\t')
  -- Header value allows VCHAR (0x21-0x7E), space, and tab. Restricting it to
  -- ASCII guarantees valid UTF-8 for unsafeBytesToText below.
  value0 <- Parser.takeWhile $ \c ->
    (c >= 0x20 && c <= 0x7e)
      || (c == 0x09)
  Latin.char2 () '\r' '\n'
  -- We only need to trim the end because the leading spaces and tabs
  -- were already skipped.
  let !value = Bytes.dropWhileEnd (\c -> c == 0x20 || c == 0x09) value0
  pure Header {name = unsafeBytesToText name, value = unsafeBytesToText value}

{- | Reinterpret a byte slice as 'Text' without copying or validating it.

The resulting 'Text' shares the slice's underlying byte array, offset and
length. The function relies on text's UTF-8 representation (text >= 2.0,
whose array constructor is @ByteArray@), where offsets and lengths are
counted in bytes.

The caller must guarantee that the bytes are well-formed UTF-8. Nothing is
checked here.
-}
unsafeBytesToText :: Bytes -> Text
{-# INLINE unsafeBytesToText #-}
unsafeBytesToText (Bytes (ByteArray raw) start count) =
  Text.Text textArray start count
  where
    textArray = Data.Text.Array.ByteArray raw

{- | Encode a header, including the trailing CRLF sequence.

The output is the name, a colon and a single space, the value, and then
CRLF. Both texts are copied as UTF-8 with no escaping or validation. A name
containing a colon, or a name or value containing CR or LF, therefore
produces malformed output.
-}
builder :: Header -> Builder
builder (Header key val) =
  mconcat
    [ Builder.copy (Utf8.fromText key)
    , Builder.ascii2 ':' ' '
    , Builder.copy (Utf8.fromText val)
    , Builder.ascii2 '\r' '\n'
    ]

{- | Encode every header in the array, in order, by concatenating the output
of 'builder' for each one.

No terminating empty line is added after the last header.
-}
builderSmallArray :: SmallArray Header -> Builder
builderSmallArray = foldr (\hdr rest -> builder hdr <> rest) mempty