module Network.OAuth2.Experiment.Pkce (
mkPkceParam,
CodeChallenge (..),
CodeVerifier (..),
CodeChallengeMethod (..),
PkceRequestParam (..),
) where
import Control.Monad.IO.Class
import Crypto.Hash qualified as H
import Crypto.Random qualified as Crypto
import Data.ByteArray qualified as ByteArray
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as B64
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Data.Word
-- | Wrapper around the PKCE code-challenge text.
--
-- 'mkPkceParam' fills it with 'encodeCodeVerifier' applied to the same
-- bytes that back the accompanying 'CodeVerifier'.
newtype CodeChallenge = CodeChallenge
  { unCodeChallenge :: Text
  -- ^ The underlying challenge text.
  }
-- | Wrapper around the PKCE code-verifier text.
--
-- 'mkPkceParam' builds it by decoding freshly generated verifier bytes.
newtype CodeVerifier = CodeVerifier
  { unCodeVerifier :: Text
  -- ^ The underlying verifier text.
  }
-- | How the code challenge is derived from the verifier.
--
-- 'S256' is the only constructor. The derived 'Show' instance renders it
-- as the literal text @S256@.
data CodeChallengeMethod
  = S256
  deriving (Show)
-- | The three PKCE values that 'mkPkceParam' produces together.
--
-- Field order is significant for positional construction; keep it as is.
data PkceRequestParam = PkceRequestParam
  { codeVerifier :: CodeVerifier
  -- ^ Verifier text generated by 'mkPkceParam'.
  , codeChallenge :: CodeChallenge
  -- ^ Challenge derived from the same bytes as 'codeVerifier'.
  , codeChallengeMethod :: CodeChallengeMethod
  -- ^ Derivation method used for 'codeChallenge'; always 'S256' here.
  }
-- | Generate a fresh set of PKCE parameters.
--
-- One batch of verifier bytes feeds both outputs:
--
-- * the verifier is those bytes decoded as text;
-- * the challenge is 'encodeCodeVerifier' applied to the same bytes;
-- * the method is always 'S256'.
mkPkceParam :: MonadIO m => m PkceRequestParam
mkPkceParam = fromVerifierBytes <$> genCodeVerifier
  where
    -- The bytes are filtered through 'isUnreversed', whose accepted set
    -- ('unreverseBS') is pure ASCII, so UTF-8 decoding cannot fail here.
    fromVerifierBytes raw =
      PkceRequestParam
        { codeVerifier = CodeVerifier (T.decodeUtf8 raw)
        , codeChallenge = CodeChallenge (encodeCodeVerifier raw)
        , codeChallengeMethod = S256
        }
-- | Derive the code-challenge text from raw code-verifier bytes.
--
-- The bytes are hashed with SHA-256 ('hashSHA256'), and the digest is then
-- base64url-encoded without padding.
--
-- 'ByteArray.convert' copies the digest straight into a 'BS.ByteString'.
-- This avoids building an intermediate @[Word8]@ list via
-- @BS.pack . ByteArray.unpack@; the resulting bytes are identical.
encodeCodeVerifier :: BS.ByteString -> Text
encodeCodeVerifier =
  B64.encodeBase64Unpadded . ByteArray.convert . hashSHA256
-- | Produce fresh code-verifier bytes in any 'MonadIO' context.
--
-- Runs 'getBytesInternal' in 'IO', starting from an empty accumulator.
genCodeVerifier :: MonadIO m => m BS.ByteString
genCodeVerifier = liftIO (getBytesInternal mempty)
-- | Length, in bytes, of every generated code verifier.
--
-- 'getBytesInternal' also uses it as the size of each batch of random
-- bytes it draws.
cvMaxLen :: Int
cvMaxLen = 128
-- | Grow an accumulator of verifier bytes until it holds 'cvMaxLen' bytes.
--
-- Each round does two things:
--
-- * draws 'cvMaxLen' random bytes;
-- * keeps only those accepted by 'isUnreversed' and appends them.
--
-- The filtering is needed because the random bytes span all 256 values,
-- most of which are not allowed in a verifier. Once the accumulator
-- reaches 'cvMaxLen' bytes, it is truncated to exactly that length.
getBytesInternal :: BS.ByteString -> IO BS.ByteString
getBytesInternal acc
  | BS.length acc < cvMaxLen = do
      candidates <- Crypto.getRandomBytes cvMaxLen
      getBytesInternal (acc <> BS.filter isUnreversed candidates)
  | otherwise = pure (BS.take cvMaxLen acc)
-- | SHA-256 digest of a strict 'BS.ByteString'.
hashSHA256 :: BS.ByteString -> H.Digest H.SHA256
hashSHA256 = H.hashWith H.SHA256
-- | Whether a byte belongs to 'unreverseBS', the set of bytes allowed in a
-- generated code verifier.
--
-- NOTE(review): the name is presumably a misspelling of "unreserved". It
-- is kept as-is because other definitions in this module refer to it.
isUnreversed :: Word8 -> Bool
isUnreversed = (`BS.elem` unreverseBS)
-- | Bytes allowed in a generated code verifier: the RFC 7636 §4.1
-- @unreserved@ set, i.e. @ALPHA / DIGIT / "-" / "." / "_" / "~"@.
--
-- Previously the DIGIT range was missing, so generated verifiers never
-- contained @0-9@. Including the digits matches the spec. It also raises
-- per-character entropy and reduces the number of random bytes
-- 'getBytesInternal' discards.
unreverseBS :: BS.ByteString
unreverseBS =
  BS.pack $
    [97 .. 122] -- 'a' .. 'z'
      ++ [65 .. 90] -- 'A' .. 'Z'
      ++ [48 .. 57] -- '0' .. '9'
      ++ [45, 46, 95, 126] -- '-', '.', '_', '~'