-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.

module Network.HProx.Util
  ( Password(..)
  , PasswordSalted(..)
  , hashPasswordWithRandomSalt
  , parseHostPort
  , parseHostPortWithDefault
  , passwordReader
  , passwordWriter
  , responseKnownLength
  , splitBy
  , verifyPassword
  ) where

import Data.Bits             (xor, (.|.))
import Data.ByteString       qualified as BS
import Data.ByteString.Char8 qualified as BS8
import Data.ByteString.Lazy  qualified as LBS
import Data.List.NonEmpty    (NonEmpty(..), (<|))
import Data.List.NonEmpty    qualified as NE
import Data.Maybe            (fromMaybe)

import Network.HTTP.Types (ResponseHeaders, Status)
import Network.Wai

import Crypto.Error           (CryptoFailable(..))
import Crypto.KDF.Argon2      qualified as Argon2
import Crypto.Random          (MonadRandom(getRandomBytes))
import Data.ByteString.Base64 qualified as Base64

-- | A credential as it is held before being turned into a 'PasswordSalted'.
--
-- NOTE(review): the derived 'Show' instance prints the secret material
-- verbatim (for 'PlainText', the password itself); avoid @show@-ing these
-- values into logs.
data Password
    = PlainText !BS.ByteString
      -- ^ the password exactly as read, not hashed
    | Salted !BS.ByteString !BS.ByteString
      -- ^ a salt followed by a password hash computed with that salt
    deriving (Eq, Show)

-- | A password in salted form: the salt it was hashed with and the
-- resulting hash. This is what gets written out and verified against.
data PasswordSalted = PasswordSalted
    !BS.ByteString -- ^ salt
    !BS.ByteString -- ^ hash of the password under that salt
    deriving (Eq, Show)

-- | Split a list on every occurrence of a separator element.
--
-- The result is never empty: an empty input gives one empty chunk, and
-- leading, trailing or adjacent separators give empty chunks too, e.g.
-- @splitBy ',' ",a,,"@ is @"" :| ["a", "", ""]@.
splitBy :: Eq a => a -> [a] -> NonEmpty [a]
splitBy sep = foldr step (NE.singleton [])
  where
    -- The irrefutable pattern keeps the split lazy: the first chunk can be
    -- consumed before the remainder of the input has been examined.
    step x acc@(~(chunk :| chunks))
        | sep == x  = [] <| acc
        | otherwise = (x : chunk) :| chunks

-- | Parse one line of a password file, splitting it on every @:@.
--
-- Two layouts are recognised:
--
-- * @user:password@ gives a 'PlainText' password;
-- * @user:salt:hash@ gives a 'Salted' password, with the salt and hash
--   fields Base64-decoded.
--
-- Any other number of fields, or a salt or hash that is not valid Base64,
-- yields 'Nothing'. Because every colon splits, a plain-text password that
-- contains @:@ is not read back as 'PlainText'.
passwordReader :: BS.ByteString -> Maybe (BS.ByteString, Password)
passwordReader line =
    case BS8.split ':' line of
        [user, pass] ->
            Just (user, PlainText pass)
        [user, salt, hashed] ->
            either (const Nothing) (\pw -> Just (user, pw)) $
                Salted <$> Base64.decode salt <*> Base64.decode hashed
        _ ->
            Nothing

-- | Serialise a user name and salted password as a @user:salt:hash@ line.
--
-- The salt and hash are Base64-encoded, giving the three-field layout that
-- 'passwordReader' accepts.
passwordWriter :: BS.ByteString -> PasswordSalted -> BS.ByteString
passwordWriter user (PasswordSalted salt digest) =
    BS.intercalate ":" (user : map Base64.encode [salt, digest])

-- | Convert a 'Password' to its salted form.
--
-- A 'PlainText' password is hashed with Argon2 ('Argon2.defaultOptions')
-- under a freshly generated 24-byte random salt, producing a 48-byte hash.
-- An already-'Salted' password is passed through unchanged.
--
-- Calls 'error' if Argon2 reports a failure.
hashPasswordWithRandomSalt :: Password -> IO PasswordSalted
hashPasswordWithRandomSalt password =
    case password of
        Salted salt digest -> pure (PasswordSalted salt digest)
        PlainText plain    -> do
            salt <- getRandomBytes 24
            let derived = Argon2.hash Argon2.defaultOptions plain salt 48
            case derived of
                CryptoPassed digest -> pure (PasswordSalted salt digest)
                CryptoFailed err    ->
                    error ("unable to hash password with salt: " ++ show err)

-- | Check a candidate password against a stored salted hash.
--
-- The candidate is hashed with Argon2 ('Argon2.defaultOptions') under the
-- stored salt with a 48-byte output. These are the same parameters used by
-- 'hashPasswordWithRandomSalt'. Returns 'False' if Argon2 reports a failure.
--
-- The final comparison inspects every byte regardless of where the two
-- hashes first differ. Response timing therefore does not reveal how many
-- leading bytes of the stored hash a guess matched. A plain 'ByteString'
-- @==@ stops at the first mismatch.
verifyPassword :: PasswordSalted -> BS8.ByteString -> Bool
verifyPassword (PasswordSalted salt hashed) pass =
    case Argon2.hash Argon2.defaultOptions pass salt 48 of
        CryptoFailed _ -> False
        CryptoPassed h -> constantTimeEq h hashed
  where
    -- The length of the stored hash is not treated as secret, so a length
    -- mismatch may short-circuit. Otherwise all byte differences are OR-ed
    -- together and only the final accumulator is tested.
    constantTimeEq :: BS.ByteString -> BS.ByteString -> Bool
    constantTimeEq a b =
        BS.length a == BS.length b
            && foldr (.|.) 0 (BS.zipWith xor a b) == 0

-- | Split @host:port@ at the last colon.
--
-- Returns 'Nothing' unless the text after the last colon is a non-empty run
-- of ASCII digits denoting a port in 1..65535. The host part is returned
-- as-is; a bracketed IPv6 literal keeps its brackets, so @[::1]:443@ yields
-- @("[::1]", 443)@.
--
-- A host part that itself contains a colon must be bracketed. Otherwise the
-- input is a bare IPv6 address such as @::1@, whose last group is not a
-- port, and 'Nothing' is returned.
parseHostPort :: BS.ByteString -> Maybe (BS.ByteString, Int)
parseHostPort hostPort = do
    lastColon <- BS8.elemIndexEnd ':' hostPort
    let host   = BS.take lastColon hostPort
        digits = BS.drop (lastColon + 1) hostPort
    port <- if validHost host then parsePort digits else Nothing
    return (host, port)
  where
    validHost :: BS.ByteString -> Bool
    validHost h = not (BS8.elem ':' h) || isBracketed h

    isBracketed :: BS.ByteString -> Bool
    isBracketed h = case (BS8.uncons h, BS8.unsnoc h) of
        (Just ('[', _), Just (_, ']')) -> True
        _                              -> False

    -- Only plain digits are accepted ('BS8.readInt' would also take a sign).
    -- The value saturates at 65536, so an arbitrarily long digit string
    -- cannot overflow 'Int' and wrap back into the valid range.
    parsePort :: BS.ByteString -> Maybe Int
    parsePort ds
        | BS.null ds || not (BS8.all isDigitChar ds) = Nothing
        | 1 <= p && p <= 65535                       = Just p
        | otherwise                                  = Nothing
      where
        p = BS8.foldl' (\acc ch -> min 65536 (acc * 10 + digitValue ch)) 0 ds

    isDigitChar :: Char -> Bool
    isDigitChar ch = '0' <= ch && ch <= '9'

    digitValue :: Char -> Int
    digitValue ch = fromEnum ch - fromEnum '0'

-- | Like 'parseHostPort', but total.
--
-- When no valid port can be parsed, the whole input is taken as the host
-- and @defaultPort@ is used as the port.
parseHostPortWithDefault :: Int -> BS.ByteString -> (BS.ByteString, Int)
parseHostPortWithDefault defaultPort hostPort =
    case parseHostPort hostPort of
        Just parsed -> parsed
        Nothing     -> (hostPort, defaultPort)

-- | Build a response from a lazy 'LBS.ByteString' body with an explicit
-- @Content-Length@ header computed from the body's length.
--
-- Any @Content-Length@ already present in @headers@ is removed first.
-- Header names compare case-insensitively, so this also catches variants
-- like @content-length@. The response therefore never carries two,
-- possibly conflicting, length headers. The computed header is appended
-- after the remaining caller-supplied headers.
responseKnownLength :: Status -> ResponseHeaders -> LBS.ByteString -> Response
responseKnownLength status headers bs =
    responseLBS status (otherHeaders ++ [contentLength]) bs
  where
    otherHeaders  = filter ((/= "Content-Length") . fst) headers
    contentLength = ("Content-Length", BS8.pack $ show (LBS.length bs))