-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.

module Network.HProx.Util
  ( parseHostPort
  , parseHostPortWithDefault
  , randomPadding
  , randomPaddingLength
  , responseKnownLength
  ) where

import Control.Monad          (replicateM)
import Data.ByteString        qualified as BS
import Data.ByteString.Char8  qualified as BS8
import Data.ByteString.Lazy   qualified as LBS
import Data.Maybe             (fromMaybe)
import System.Random          (uniformR)
import System.Random.Stateful
    (applyAtomicGen, globalStdGen, runStateGen, uniformRM)

import Network.HTTP.Types (ResponseHeaders, Status)
import Network.Wai

-- | Split a @host:port@ string at its last colon.
--
-- Returns 'Nothing' when:
--
--   * there is no colon at all;
--   * the port part is empty, is not made up purely of ASCII digits, or is
--     outside @1..65535@;
--   * the host part contains a colon but is not a bracketed IP literal such
--     as @[::1]@.
--
-- The host part is returned verbatim, brackets included.
parseHostPort :: BS.ByteString -> Maybe (BS.ByteString, Int)
parseHostPort hostPort = do
    lastColon <- BS8.elemIndexEnd ':' hostPort
    let (host, colonPort) = BS.splitAt lastColon hostPort
        portStr           = BS.drop 1 colonPort
    port <- checkPort portStr
    if validHost host then Just (host, port) else Nothing
  where
    -- RFC 3986 defines port as *DIGIT. 'BS8.readInt' alone would also accept
    -- a leading sign (e.g. "+443"), so digits are required explicitly. The
    -- significant-digit bound keeps the value far from Int overflow, which
    -- older bytestring versions of readInt do not detect.
    checkPort bs
        | BS.null bs                                 = Nothing
        | not (BS8.all isAsciiDigit bs)              = Nothing
        | BS.length (BS8.dropWhile (== '0') bs) > 5  = Nothing
        | otherwise = case BS8.readInt bs of
            Just (p, rest) | BS.null rest && 1 <= p && p <= 65535 -> Just p
            _                                                     -> Nothing

    isAsciiDigit c = c >= '0' && c <= '9'

    -- A colon may appear in the host only inside a bracketed IP literal.
    -- Without this check, an unbracketed address such as "::1" would be
    -- misread as host ":" with port 1.
    validHost h
        | BS8.notElem ':' h = True
        | otherwise = case (BS8.uncons h, BS8.unsnoc h) of
            (Just ('[', _), Just (_, ']')) -> True
            _                              -> False

-- | Total variant of 'parseHostPort'. When the input cannot be split into a
-- host and a valid port, the whole input is treated as the host and
-- @defaultPort@ is used.
parseHostPortWithDefault :: Int -> BS.ByteString -> (BS.ByteString, Int)
parseHostPortWithDefault defaultPort hostPort = fromMaybe fallback parsed
  where
    fallback = (hostPort, defaultPort)
    parsed   = parseHostPort hostPort

-- | Produce a random padding value.
--
-- The result is 32 to 63 bytes long (inclusive). The first 24 bytes are
-- drawn uniformly from a fixed 16-character alphabet; every remaining byte
-- is '~'. Randomness comes from 'globalStdGen', which is updated atomically.
--
-- NOTE(review): the alphabet was originally named "nonHuffman". That
-- suggests its characters were chosen for their long HPACK Huffman codes,
-- so the padding does not shrink under header compression. Confirm.
randomPadding :: IO BS.ByteString
randomPadding = applyAtomicGen buildPadding globalStdGen
  where
    alphabet :: String
    alphabet = "!#$()+<>?@[]^`{}"

    alphabetSize :: Int
    alphabetSize = length alphabet

    -- Pure generator step. Draw order is fixed: the total length first,
    -- then the 24 prefix characters.
    buildPadding gen0 = runStateGen gen0 $ \stGen -> do
        total <- uniformRM (32, 63) stGen
        let pickChar = (alphabet !!) <$> uniformRM (0, alphabetSize - 1) stGen
        randomPart <- replicateM 24 pickChar
        pure . BS8.pack $ randomPart <> replicate (total - 24) '~'

-- | Draw a random padding length, uniformly distributed over @[1, 255]@,
-- from 'globalStdGen' (updated atomically).
randomPaddingLength :: IO Int
randomPaddingLength = applyAtomicGen (uniformR lengthRange) globalStdGen
  where
    lengthRange :: (Int, Int)
    lengthRange = (1, 255)

-- | Build a WAI response from a lazy body, appending a @Content-Length@
-- header computed from the body's length after the given headers.
--
-- NOTE(review): a @Content-Length@ already present in @headers@ is kept, so
-- a caller that supplies one would produce a duplicate header. Confirm that
-- callers never do this.
responseKnownLength :: Status -> ResponseHeaders -> LBS.ByteString -> Response
responseKnownLength status headers body =
    responseLBS status (headers <> [contentLength]) body
  where
    contentLength = ("Content-Length", BS8.pack . show $ LBS.length body)