module Network.OAuth2.Experiment.Pkce (
  mkPkceParam,
  CodeChallenge (..),
  CodeVerifier (..),
  CodeChallengeMethod (..),
  PkceRequestParam (..),
) where

import Control.Monad.IO.Class
import Crypto.Hash qualified as H
import Crypto.Random qualified as Crypto
import Data.Base64.Types qualified as B64
import Data.ByteArray qualified as ByteArray
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as B64
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Data.Word

-- | The PKCE @code_challenge@ value, as produced by 'mkPkceParam' from the
-- generated code verifier.
newtype CodeChallenge = CodeChallenge
  { unCodeChallenge :: Text
  -- ^ the challenge as text
  }

-- | The PKCE @code_verifier@ value; 'mkPkceParam' fills it with a random
-- string drawn from the unreserved character set.
newtype CodeVerifier = CodeVerifier
  { unCodeVerifier :: Text
  -- ^ the verifier as text
  }

-- | How the code challenge is derived from the code verifier.
-- Only @S256@ (SHA-256, then unpadded base64url) is supported.
data CodeChallengeMethod
  = S256
  deriving (Show)

-- | A complete set of PKCE parameters, as generated by 'mkPkceParam'.
data PkceRequestParam = PkceRequestParam
  { codeVerifier :: CodeVerifier
  -- ^ the randomly generated verifier
  , codeChallenge :: CodeChallenge
  -- ^ the challenge computed from 'codeVerifier'
  , codeChallengeMethod :: CodeChallengeMethod
  -- ^ spec says optional but in practice it is S256
  -- https://datatracker.ietf.org/doc/html/rfc7636#section-4.3
  }

-- | Generate a fresh PKCE parameter set.
--
-- The verifier is random, the challenge is its S256 transform, and the
-- method is always 'S256'. The verifier bytes are all ASCII (they are
-- filtered to the unreserved set), so decoding them as UTF-8 cannot fail.
mkPkceParam :: MonadIO m => m PkceRequestParam
mkPkceParam = fromVerifierBytes <$> genCodeVerifier
  where
    fromVerifierBytes raw =
      PkceRequestParam
        { codeVerifier = CodeVerifier (T.decodeUtf8 raw)
        , codeChallenge = CodeChallenge (encodeCodeVerifier raw)
        , codeChallengeMethod = S256
        }

-- | Compute the S256 code challenge for a code verifier:
-- @BASE64URL-ENCODE(SHA256(code_verifier))@ without @=@ padding
-- (RFC 7636 section 4.2).
encodeCodeVerifier :: BS.ByteString -> Text
encodeCodeVerifier =
  B64.extractBase64 . B64.encodeBase64Unpadded . digestToBytes . hashSHA256
  where
    -- Copy the 32 digest bytes straight into a ByteString rather than
    -- unpacking them into a [Word8] list and packing them back.
    digestToBytes :: H.Digest H.SHA256 -> BS.ByteString
    digestToBytes = ByteArray.convert

-- | Produce a random code verifier of exactly 'cvMaxLen' bytes, each one a
-- character accepted by 'isUnreversed'. The work runs in 'IO' and is lifted
-- into the caller's monad.
genCodeVerifier :: MonadIO m => m BS.ByteString
genCodeVerifier = liftIO (getBytesInternal mempty)

-- | Length, in characters, of every generated code verifier. RFC 7636
-- allows 43 to 128 characters; this module always uses the upper bound.
cvMaxLen :: Int
cvMaxLen =
  128

-- | Extend the given prefix with random unreserved bytes until it holds at
-- least 'cvMaxLen' bytes, then truncate it to exactly 'cvMaxLen'.
--
-- 'Crypto.getRandomBytes' returns arbitrary byte values, most of which are
-- not allowed in a code verifier. Each batch is therefore filtered through
-- 'isUnreversed', and the loop repeats until enough bytes survive.
-- Rejecting bytes (rather than remapping them) keeps the survivors
-- uniformly distributed, assuming the RNG output is uniform.
--
-- RFC 7636 grammar:
--
-- > code-verifier = 43*128unreserved
-- >   unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
-- >   ALPHA = %x41-5A / %x61-7A
-- >   DIGIT = %x30-39
getBytesInternal :: BS.ByteString -> IO BS.ByteString
getBytesInternal = go
  where
    go collected
      | BS.length collected >= cvMaxLen = pure (BS.take cvMaxLen collected)
      | otherwise = do
          batch <- Crypto.getRandomBytes cvMaxLen
          go (collected <> BS.filter isUnreversed batch)

-- | SHA-256 digest of a byte string.
hashSHA256 :: BS.ByteString -> H.Digest H.SHA256
hashSHA256 = H.hashWith H.SHA256

-- | Is this byte one of the characters listed in 'unreverseBS'?
isUnreversed :: Word8 -> Bool
isUnreversed = (`BS.elem` unreverseBS)

-- | The RFC 7636 / RFC 3986 @unreserved@ characters, as ASCII byte values.
-- Every byte of a generated code verifier comes from this set.
--
-- @
-- a-z: 97-122
-- A-Z: 65-90
-- 0-9: 48-57
-- -:   45
-- .:   46
-- _:   95
-- ~:   126
-- @
--
-- Digits are part of the RFC's @unreserved@ rule (DIGIT = %x30-39). Earlier
-- versions of this set left them out, which needlessly shrank the alphabet
-- from 66 to 56 symbols.
unreverseBS :: BS.ByteString
unreverseBS =
  BS.pack $
    lowerAlpha ++ upperAlpha ++ digits ++ symbols
  where
    lowerAlpha = [97 .. 122]
    upperAlpha = [65 .. 90]
    digits = [48 .. 57]
    symbols = [45, 46, 95, 126]