module Network.OAuth2.Experiment.Pkce (
mkPkceParam,
CodeChallenge (..),
CodeVerifier (..),
CodeChallengeMethod (..),
PkceRequestParam (..),
) where
import Control.Monad.IO.Class
import Crypto.Hash qualified as H
import Crypto.Random qualified as Crypto
import Data.Base64.Types qualified as B64
import Data.ByteArray qualified as ByteArray
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as B64
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Data.Word
-- | A PKCE @code_challenge@ value, carried as 'Text'.
--
-- In this module it is built by 'mkPkceParam' from the generated verifier.
newtype CodeChallenge = CodeChallenge
  { unCodeChallenge :: Text
  -- ^ The raw challenge text.
  }
-- | A PKCE @code_verifier@ value, carried as 'Text'.
--
-- In this module it is built by 'mkPkceParam' from random bytes.
newtype CodeVerifier = CodeVerifier
  { unCodeVerifier :: Text
  -- ^ The raw verifier text.
  }
-- | How the code challenge is derived from the code verifier.
--
-- Only 'S256' is offered. 'mkPkceParam' pairs it with a challenge computed
-- by hashing the verifier with SHA-256.
data CodeChallengeMethod = S256
  deriving Show
-- | The PKCE values for one authorization request: the secret verifier,
-- the challenge derived from it, and the method used to derive it.
data PkceRequestParam = PkceRequestParam
  { codeVerifier :: CodeVerifier
  -- ^ The generated verifier.
  , codeChallenge :: CodeChallenge
  -- ^ The challenge computed from 'codeVerifier'.
  , codeChallengeMethod :: CodeChallengeMethod
  -- ^ The method used to compute 'codeChallenge'.
  }
-- | Generate a fresh PKCE parameter set.
--
-- Draws a new random verifier with 'genCodeVerifier', derives its challenge
-- with 'encodeCodeVerifier', and tags the result with 'S256'.
mkPkceParam :: MonadIO m => m PkceRequestParam
mkPkceParam = fromVerifierBytes <$> genCodeVerifier
  where
    fromVerifierBytes rawVerifier =
      PkceRequestParam
        { -- The verifier bytes are drawn only from 'unreverseBS', which holds
          -- ASCII byte values only, so UTF-8 decoding cannot fail here.
          codeVerifier = CodeVerifier (T.decodeUtf8 rawVerifier)
        , codeChallenge = CodeChallenge (encodeCodeVerifier rawVerifier)
        , codeChallengeMethod = S256
        }
-- | Derive the S256 code challenge from raw code-verifier bytes:
-- the SHA-256 digest of the verifier, base64url-encoded without padding.
--
-- The digest is copied directly into a 'BS.ByteString' with
-- 'ByteArray.convert'. It is not unpacked into an intermediate @[Word8]@
-- list and re-packed, and the resulting bytes are identical.
encodeCodeVerifier :: BS.ByteString -> Text
encodeCodeVerifier verifier =
  B64.extractBase64 (B64.encodeBase64Unpadded digestBytes)
  where
    digestBytes :: BS.ByteString
    digestBytes = ByteArray.convert (hashSHA256 verifier)
-- | Produce a new random code verifier as raw bytes.
--
-- Runs 'getBytesInternal' from an empty buffer and lifts the 'IO' action
-- into any 'MonadIO'.
genCodeVerifier :: MonadIO m => m BS.ByteString
genCodeVerifier = liftIO (getBytesInternal mempty)
-- | Exact length, in bytes, of every verifier built by 'getBytesInternal'
-- from an empty buffer. It also controls how many random bytes are drawn
-- per round. The value matches the 128-character maximum that RFC 7636
-- sets for @code_verifier@.
cvMaxLen :: Int
cvMaxLen = 128
-- | Extend @acc@ with random bytes accepted by 'isUnreversed' until it holds
-- at least 'cvMaxLen' bytes, then return exactly its first 'cvMaxLen' bytes.
--
-- Each round draws 'cvMaxLen' random bytes and keeps only the acceptable
-- ones, so several rounds may be needed. Because bytes are rejected rather
-- than mapped onto the alphabet, each accepted byte is uniformly
-- distributed over the accepted set. Bytes already in @acc@ are kept as-is.
getBytesInternal :: BS.ByteString -> IO BS.ByteString
getBytesInternal acc
  | BS.length acc < cvMaxLen = do
      randomChunk <- Crypto.getRandomBytes cvMaxLen
      getBytesInternal (acc <> BS.filter isUnreversed randomChunk)
  | otherwise = pure (BS.take cvMaxLen acc)
-- | The SHA-256 digest of a byte string.
hashSHA256 :: BS.ByteString -> H.Digest H.SHA256
hashSHA256 = H.hashWith H.SHA256
-- | Whether a byte is a member of 'unreverseBS', i.e. whether
-- 'getBytesInternal' keeps it.
--
-- NOTE(review): the name is presumably a misspelling of "unreserved".
-- It is kept unchanged for existing callers.
isUnreversed :: Word8 -> Bool
isUnreversed = (`BS.elem` unreverseBS)
-- | The RFC 3986 "unreserved" characters as ASCII byte values. These are
-- exactly the characters RFC 7636 allows in a PKCE @code_verifier@:
-- lowercase letters (97–122), uppercase letters (65–90), digits (48–57),
-- and the four symbols @-@ (45), @.@ (46), @_@ (95) and @~@ (126).
--
-- Digits belong to the allowed set. Leaving them out shrinks the alphabet
-- from 66 to 56 characters and lowers the entropy of every generated
-- verifier for no benefit.
unreverseBS :: BS.ByteString
unreverseBS =
  BS.pack $
    [97 .. 122] -- a-z
      ++ [65 .. 90] -- A-Z
      ++ [48 .. 57] -- 0-9
      ++ [45, 46, 95, 126] -- '-', '.', '_', '~'