{-# Language BlockArguments #-}
{-# Language ImportQualifiedPost #-}
{-# Language LambdaCase #-}
{-# Language OverloadedStrings #-}
{-# Language RecordWildCards #-}
{-# Language ViewPatterns #-}
module Client.Authentication.Scram (
Phase1,
Phase2,
initiateScram,
addServerFirst,
addServerFinal,
ScramDigest(..),
mechanismName,
) where
import Control.Monad (guard)
import Data.Bits (xor)
import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Base64 qualified as B64
import Data.ByteString.Char8 qualified as B8
import Data.List (foldl1')
import Data.Text (Text)
import OpenSSL.EVP.Digest ( Digest, digestBS, hmacBS, getDigestByName)
import Irc.Commands (AuthenticatePayload (AuthenticatePayload))
import System.IO.Unsafe (unsafePerformIO)
-- | Hash functions that can be used to drive a SCRAM exchange.
data ScramDigest
  = ScramDigestSha1     -- ^ SHA-1
  | ScramDigestSha2_256 -- ^ SHA-2, 256-bit
  | ScramDigestSha2_512 -- ^ SHA-2, 512-bit
  deriving Show
-- | Name of the SCRAM mechanism corresponding to a digest
-- (@SCRAM-SHA-1@, @SCRAM-SHA-256@ or @SCRAM-SHA-512@).
mechanismName :: ScramDigest -> Text
mechanismName digest =
  case digest of
    ScramDigestSha1     -> "SCRAM-SHA-1"
    ScramDigestSha2_256 -> "SCRAM-SHA-256"
    ScramDigestSha2_512 -> "SCRAM-SHA-512"
-- | Client state after the client-first-message has been produced by
-- 'initiateScram'; consumed by 'addServerFirst'.
data Phase1 = Phase1
  { phase1Digest          :: ScramDigest -- ^ digest selected for this exchange
  , phase1Password        :: ByteString  -- ^ password as given to 'initiateScram'
  , phase1CbindInput      :: ByteString  -- ^ base64 of the GS2 header, echoed back as @c=@
  , phase1Nonce           :: ByteString  -- ^ nonce sent by the client
  , phase1ClientFirstBare :: ByteString  -- ^ client-first-message without the GS2 header
  }
-- | Build the client-first-message and the state needed to continue the
-- exchange once the server replies.
--
-- The GS2 header always starts with @n@ (no channel binding). When an
-- authorization identity is supplied it is sent as the @a=@ attribute,
-- as required by the @authzid = "a=" saslname@ production of RFC 5802;
-- an empty identity leaves that slot empty.
initiateScram ::
  ScramDigest {- ^ digest to use for the exchange -} ->
  ByteString  {- ^ authentication identity        -} ->
  ByteString  {- ^ authorization identity, may be empty -} ->
  ByteString  {- ^ password                       -} ->
  ByteString  {- ^ client nonce                   -} ->
  (AuthenticatePayload, Phase1)
initiateScram digest user authzid pass nonce =
  ( AuthenticatePayload clientFirstMessage
  , Phase1
      { phase1Digest          = digest
      , phase1Password        = pass
      , phase1CbindInput      = B64.encode gs2Header
      , phase1Nonce           = nonce
      , phase1ClientFirstBare = clientFirstMessageBare
      }
  )
  where
    clientFirstMessage = gs2Header <> clientFirstMessageBare

    -- The attribute name is only emitted together with a value; without
    -- it a non-empty identity would not form a valid GS2 header.
    authzidAttribute
      | B.null authzid = ""
      | otherwise      = "a=" <> encodeUsername authzid

    gs2Header = "n," <> authzidAttribute <> ","

    clientFirstMessageBare = "n=" <> encodeUsername user <> ",r=" <> nonce
-- | Client state after the client-final-message has been produced by
-- 'addServerFirst'; consumed by 'addServerFinal'.
newtype Phase2 = Phase2
  { phase2ServerSignature :: ByteString -- ^ base64 of the server signature the client computed
  }
-- | Consume the server-first-message and produce the client-final-message
-- together with the state needed to verify the server's final reply.
--
-- Returns 'Nothing' when the message does not start with the @r@, @s@ and
-- @i@ attributes in that order, when the salt is not valid base64, when the
-- iteration count is not a positive integer, or when the server nonce is
-- not a proper extension of the client nonce.
addServerFirst ::
  Phase1 ->
  ByteString {- ^ server-first-message -} ->
  Maybe (AuthenticatePayload, Phase2)
addServerFirst Phase1{..} serverFirstMessage =
  do -- A failed pattern match here aborts with Nothing.
     ("r", nonce) :
       ("s", B64.decode -> Right salt) :
       ("i", B8.readInt -> Just (iterations, "")) :
       _extensions
       <- Just (parseMessage serverFirstMessage)

     -- The server must append its own material to the client's nonce.
     guard (B.isPrefixOf phase1Nonce nonce && phase1Nonce /= nonce)

     -- readInt accepts zero and negative numbers; key derivation is only
     -- defined for at least one round, so reject those here instead of
     -- failing later inside pure code.
     guard (iterations > 0)

     let clientFinalWithoutProof = "c=" <> phase1CbindInput <> ",r=" <> nonce

     let authMessage =
           phase1ClientFirstBare <> "," <>
           serverFirstMessage <> "," <>
           clientFinalWithoutProof

     let (clientProof, serverSignature) =
           crypto phase1Digest phase1Password salt iterations authMessage

     let proof = "p=" <> B64.encode clientProof
     let clientFinalMessage = clientFinalWithoutProof <> "," <> proof
     let phase2 = Phase2 { phase2ServerSignature = B64.encode serverSignature }

     Just (AuthenticatePayload clientFinalMessage, phase2)
-- | Check the server-final-message against the expected server signature.
--
-- The result is 'True' only when the first attribute is @v@ and its value
-- equals the signature stored in 'Phase2'; any other message, including an
-- empty one, yields 'False'.
addServerFinal ::
  Phase2 ->
  ByteString {- ^ server-final-message -} ->
  Bool
addServerFinal Phase2{..} serverFinalMessage =
  case parseMessage serverFinalMessage of
    ("v", sig) : _extensions -> sig == phase2ServerSignature
    _                        -> False
-- | The four bytes @00 00 00 01@, appended to the salt for the first
-- HMAC round in 'hi'.
int1 :: ByteString
int1 = B.pack [0, 0, 0, 1]
-- | Byte-wise exclusive-or of two strings. The result is as long as the
-- shorter argument; surplus bytes of the longer one are dropped.
xorBS :: ByteString -> ByteString -> ByteString
xorBS x y = B.pack (B.zipWith xor x y)
-- | Iterated, salted HMAC used to stretch the password: the first round
-- authenticates the salt followed by 'int1', every later round
-- authenticates the previous round's output, and all round outputs are
-- xor-ed together.
--
-- The iteration count must be at least 1; anything smaller is a
-- programming error and raises an exception with an explicit message.
hi ::
  Digest     {- ^ underlying digest               -} ->
  ByteString {- ^ string used as the HMAC key     -} ->
  ByteString {- ^ salt                            -} ->
  Int        {- ^ number of rounds, at least 1    -} ->
  ByteString
hi digest str salt n
  | n < 1     = error "Client.Authentication.Scram.hi: iteration count must be positive"
  | otherwise = foldl1' xorBS (take n us)
  where
    u1 = hmacBS digest str (salt <> int1)

    -- Infinite stream of round outputs; only the first n are consumed.
    us = iterate (hmacBS digest str) u1
-- | Split a comma-separated list of @key=value@ attributes into pairs.
--
-- Each entry is split at its first @=@, so values may themselves contain
-- @=@ (as base64 padding does). An entry without @=@ becomes a key paired
-- with an empty value, and an empty message yields no entries.
parseMessage :: ByteString -> [(ByteString, ByteString)]
parseMessage msg =
  [ (key, B.drop 1 rest) -- drop the '=' separator kept by break
  | entry <- B8.split ',' msg
  , let (key, rest) = B8.break ('=' ==) entry
  ]
-- | Derive the client proof and the server signature for an exchange.
--
-- The password is stretched with 'hi'; the client and server keys are
-- HMACs of the fixed strings @Client Key@ and @Server Key@ under the
-- stretched password. The proof is the client key xor-ed with an HMAC of
-- the auth message under the hashed client key, and the server signature
-- is an HMAC of the auth message under the server key.
--
-- Raises an exception if OpenSSL does not provide the requested digest.
crypto ::
  ScramDigest {- ^ digest to use      -} ->
  ByteString  {- ^ password           -} ->
  ByteString  {- ^ salt               -} ->
  Int         {- ^ iteration count    -} ->
  ByteString  {- ^ auth message       -} ->
  (ByteString, ByteString) {- ^ client proof, server signature -}
crypto digest password salt iterations authMessage =
  (clientProof, serverSignature)
  where
    saltedPassword  = hi d password salt iterations
    clientKey       = hmacBS d saltedPassword "Client Key"
    storedKey       = digestBS d clientKey
    clientSignature = hmacBS d storedKey authMessage
    clientProof     = xorBS clientKey clientSignature
    serverKey       = hmacBS d saltedPassword "Server Key"
    serverSignature = hmacBS d serverKey authMessage

    -- Names under which the digests are looked up in OpenSSL.
    digestName =
      case digest of
        ScramDigestSha1     -> "SHA1"
        ScramDigestSha2_256 -> "SHA256"
        ScramDigestSha2_512 -> "SHA512"

    -- NOTE(review): unsafePerformIO presumes that looking a digest up by
    -- name has no observable side effects and always gives the same
    -- answer for the same name -- confirm against the OpenSSL binding.
    d = case unsafePerformIO (getDigestByName digestName) of
          Just found -> found
          Nothing    -> error ("Client.Authentication.Scram.crypto: digest not available: " ++ digestName)
-- | Escape a name for use inside a SCRAM attribute: @,@ becomes @=2C@ and
-- @=@ becomes @=3D@, since both characters delimit attributes. All other
-- bytes are passed through unchanged.
encodeUsername :: ByteString -> ByteString
encodeUsername = B8.concatMap escape
  where
    escape ',' = B8.pack "=2C"
    escape '=' = B8.pack "=3D"
    escape x   = B8.singleton x