-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.

module Network.HProx.Naive
  ( PaddingType(..)
  , addPaddingConduit
  , parseRequestForPadding
  , prepareResponseForPadding
  , removePaddingConduit
  ) where

import Control.Monad             (replicateM, unless)
import Control.Monad.IO.Class    (liftIO)
import Data.Binary.Builder       qualified as BB
import Data.ByteString           qualified as BS
import Data.ByteString.Char8     qualified as BS8
import Data.ByteString.Lazy      qualified as LBS
import Data.Conduit.Binary       qualified as CB
import Data.Maybe                (mapMaybe)
import Network.HTTP.Types.Header qualified as HT
import System.Random             (uniformR)
import System.Random.Stateful    (applyAtomicGen, globalStdGen, runStateGen, uniformRM)

import Data.Conduit
import Network.Wai

-- | Produce a random value for the legacy @Padding@ response header.
--
-- The result is 32 to 63 characters long: 24 characters picked uniformly
-- from a fixed set of symbols, followed by @~@ filler up to the chosen
-- length. Randomness is taken from the process-wide 'globalStdGen'.
randomPadding :: IO BS8.ByteString
randomPadding = applyAtomicGen draw globalStdGen
  where
    -- Alphabet for the random prefix. NOTE(review): presumably these are
    -- characters that HPACK Huffman coding does not shrink (hence the
    -- name); confirm against NaiveProxy.
    nonHuffmanChars :: String
    nonHuffmanChars = "!#$()+<>?@[]^`{}"

    lastIndex :: Int
    lastIndex = length nonHuffmanChars - 1

    -- The total length is drawn first, then the 24 prefix characters, so
    -- the generator is always consumed in the same order.
    draw g = runStateGen g $ \gen -> do
        total <- uniformRM (32, 63 :: Int) gen
        prefix <- replicateM 24 ((nonHuffmanChars !!) <$> uniformRM (0, lastIndex) gen)
        pure (BS8.pack (prefix ++ replicate (total - 24) '~'))

-- | Draw an 'Int' uniformly from the inclusive range @[lo, hi]@, using the
-- process-wide 'globalStdGen'.
randInt :: Int -> Int -> IO Int
randInt lo hi = applyAtomicGen sample globalStdGen
  where
    sample = uniformR (lo, hi)

-- | Padding schemes this module can negotiate and apply.
--
-- Mirrors NaiveProxy's padding types, see:
-- https://github.com/klzgrad/naiveproxy/blob/master/src/net/tools/naive/naive_protocol.h#L30C12-L30C23
--
-- Constructor order matters: the derived 'Ord' instance follows it.
data PaddingType
  = NoPadding  -- ^ Data passes through unchanged (wire token @"0"@).
  | Variant1   -- ^ The first frames carry a length header and padding (wire token @"1"@).
  deriving (Eq, Ord, Show)

-- | Decode one padding-type token. Only the exact tokens @"0"@
-- ('NoPadding') and @"1"@ ('Variant1') are recognised; anything else
-- yields 'Nothing'.
parsePaddingType :: BS8.ByteString -> Maybe PaddingType
parsePaddingType token = lookup token knownTokens
  where
    knownTokens = [("0", NoPadding), ("1", Variant1)]

-- | Encode a padding type as its wire token (the inverse of
-- 'parsePaddingType').
showPaddingType :: PaddingType -> BS8.ByteString
showPaddingType paddingType = case paddingType of
    NoPadding -> "0"
    Variant1  -> "1"

-- | Legacy padding header. Its mere presence on a request selects
-- 'Variant1'; responses carry a random value from 'randomPadding'.
legacyPaddingHeader :: HT.HeaderName
legacyPaddingHeader = "Padding"

-- | Request header carrying the client's comma-separated list of
-- acceptable padding-type tokens.
paddingTypeRequestHeader :: HT.HeaderName
paddingTypeRequestHeader = "Padding-Type-Request"

-- | Response header announcing the chosen padding type, encoded with
-- 'showPaddingType'.
paddingTypeReplyHeader :: HT.HeaderName
paddingTypeReplyHeader = "Padding-Type-Reply"

-- | A byte-stream transformer running in 'IO' (needed for random padding).
type PaddingConduit = ConduitT BS8.ByteString BS8.ByteString IO ()

-- | Forward every chunk unchanged until upstream is exhausted.
noPaddingConduit :: PaddingConduit
noPaddingConduit = do
    next <- await
    case next of
        Nothing    -> return ()
        Just chunk -> yield chunk >> noPaddingConduit

-- | Conduit that applies the framing of the given padding scheme to a byte
-- stream. 'NoPadding' forwards data unchanged.
addPaddingConduit :: PaddingType -> PaddingConduit
addPaddingConduit paddingType = case paddingType of
    NoPadding -> noPaddingConduit
    Variant1  -> addPaddingVariant1 countPaddingsVariant1

-- | Conduit that undoes the framing of the given padding scheme, recovering
-- the original byte stream. 'NoPadding' forwards data unchanged.
removePaddingConduit :: PaddingType -> PaddingConduit
removePaddingConduit paddingType = case paddingType of
    NoPadding -> noPaddingConduit
    Variant1  -> removePaddingVariant1 countPaddingsVariant1

-- | Work out which padding scheme the client asked for.
--
-- Preference order:
--
-- 1. @Padding-Type-Request@: a comma-separated list of padding-type tokens.
--    The first token this module understands wins. Whitespace around each
--    token is ignored, so @"1, 0"@ and @"1,0"@ are equivalent. If the header
--    is present but contains no recognised token, the result is 'Nothing'
--    and the legacy header is not consulted.
-- 2. The legacy @Padding@ header: its presence alone selects 'Variant1'.
-- 3. Otherwise no padding is negotiated.
parseRequestForPadding :: Request -> Maybe PaddingType
parseRequestForPadding req
    | Just paddingTypesStr <- lookup paddingTypeRequestHeader headers =
        -- HTTP list headers allow optional whitespace around commas, so each
        -- token is stripped before matching; otherwise " 0" would be rejected.
        case mapMaybe (parsePaddingType . BS8.strip) (BS8.split ',' paddingTypesStr) of
            (preferred : _) -> Just preferred
            []              -> Nothing
    | Just _ <- lookup legacyPaddingHeader headers = Just Variant1
    | otherwise = Nothing
  where
    headers = requestHeaders req

-- | Headers to attach to the response once padding has been negotiated.
--
-- With no negotiated padding, nothing is added. Otherwise the reply carries
-- a fresh random legacy @Padding@ value and announces the chosen type in
-- @Padding-Type-Reply@.
prepareResponseForPadding :: Maybe PaddingType -> IO [HT.Header]
prepareResponseForPadding = maybe (return []) withPadding
  where
    withPadding paddingType = do
        filler <- randomPadding
        return
            [ (legacyPaddingHeader, filler)
            , (paddingTypeReplyHeader, showPaddingType paddingType)
            ]

-- | Number of leading frames that carry a Variant1 header and padding.
-- Both adding and removing padding use this count; all data after these
-- frames passes through unchanged.
--
-- see: https://github.com/klzgrad/naiveproxy/blob/master/src/net/tools/naive/naive_protocol.h#L34
countPaddingsVariant1 :: Int
countPaddingsVariant1 = 8

-- | Frame and pad the next @n@ chunks of the stream, then forward the rest
-- untouched.
--
-- Each emitted frame is laid out as
--
-- > [payload length: 2 bytes, big-endian] [padding length: 1 byte] [payload] [padding: zero bytes]
--
-- A chunk that is too large for one frame is split, and the remainder is
-- pushed back with 'leftover'. The conduit stops at end of input, and also
-- when it receives an empty chunk.
addPaddingVariant1 :: Int -> PaddingConduit
addPaddingVariant1 0 = noPaddingConduit
addPaddingVariant1 n = await >>= maybe (return ()) frame
  where
    -- Largest payload that still fits a 65535-byte frame alongside the
    -- 3-byte header and up to 255 bytes of padding.
    maxPayload :: Int
    maxPayload = 65535 - 3 - 255

    frame chunk
        | BS.null chunk = return ()
        | otherwise = do
            let available = min (BS.length chunk) maxPayload
            -- Chunks of 401..1023 bytes are cut to a random 200..300-byte payload.
            cut <- if available > 400 && available < 1024
                       then liftIO (randInt 200 300)
                       else return available
            let (payload, rest) = BS.splitAt cut chunk
                payloadLen      = BS.length payload
            unless (BS.null rest) (leftover rest)
            -- Short payloads get enough padding to reach at least 255 bytes.
            padLen <- liftIO $ randInt (if payloadLen < 100 then 255 - payloadLen else 1) 255
            let hdr    = foldMap (BB.singleton . fromIntegral)
                                 [payloadLen `div` 256, payloadLen `mod` 256, padLen]
                framed = hdr
                      <> BB.fromByteString payload
                      <> BB.fromByteString (BS.replicate padLen 0)
            yield (LBS.toStrict (BB.toLazyByteString framed))
            addPaddingVariant1 (n - 1)

-- | Strip Variant1 framing from the next @n@ frames, then forward the rest of
-- the stream untouched.
--
-- Each frame starts with a 3-byte header: a 2-byte big-endian payload length
-- followed by a 1-byte padding length. The payload is forwarded and the
-- padding is discarded. A truncated header or frame ends the stream
-- silently.
removePaddingVariant1 :: Int -> PaddingConduit
removePaddingVariant1 0 = noPaddingConduit
removePaddingVariant1 n = do
    header <- CB.take 3
    case LBS.unpack header of
        [hi, lo, pad] -> do
            let payloadLen = fromIntegral hi * 256 + fromIntegral lo :: Int
                frameLen   = payloadLen + fromIntegral pad
            body <- CB.take frameLen
            if LBS.length body /= fromIntegral frameLen
                then return ()  -- truncated frame: stop
                else do
                    -- A zero-length payload would reach downstream as an empty
                    -- chunk. Consumers such as addPaddingVariant1 in this module
                    -- treat an empty chunk as end of stream, so drop it. The
                    -- frame still counts towards n.
                    unless (payloadLen == 0) $
                        yield (LBS.toStrict (LBS.take (fromIntegral payloadLen) body))
                    removePaddingVariant1 (n - 1)
        _ -> return ()  -- end of input, or truncated header