{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.Xmpp.Sasl.Mechanisms.Scram
  where

import           Control.Applicative ((<$>))
import           Control.Monad
import           Control.Monad.Except
import           Control.Monad.State.Strict
import qualified Crypto.Classes          as Crypto
import qualified Crypto.HMAC             as Crypto
import qualified Crypto.Hash.CryptoAPI   as Crypto
import qualified Data.ByteString         as BS
import qualified Data.ByteString.Base64  as B64
import           Data.ByteString.Char8   as BS8 (unpack)
import           Data.List (foldl1', genericTake)
import qualified Data.Text               as Text
import qualified Data.Text.Encoding      as Text
import           Network.Xmpp.Sasl.Common
import           Network.Xmpp.Sasl.Types
import           Network.Xmpp.Types

-- | A dummy value of a hash type. Its only purpose is to tell a
-- hash-polymorphic function (such as 'scram') which hash function to use,
-- e.g. @hashToken :: Crypto.SHA1@. The value itself must never be evaluated;
-- forcing it raises an error with an explanatory message.
hashToken :: (Crypto.Hash ctx hash) => hash
hashToken = error "Network.Xmpp.Sasl.Mechanisms.Scram.hashToken: dummy hash token must not be evaluated"

-- | Salted Challenge Response Authentication Mechanism (SCRAM) SASL
-- mechanism according to RFC 5802.
--
-- The cryptographic computations are polymorphic in the hash function,
-- which is selected by the type of the first (never evaluated) argument.
--
-- NOTE(review): the mechanism name sent with the initial request is
-- hard-coded to @SCRAM-SHA-1@. The exchange is therefore only consistent
-- when the token is a SHA-1 hash. Supporting other hashes would require
-- passing the mechanism name in as well.
--
-- User and authorization names are escaped as RFC 5802 @saslname@s
-- (@=@ becomes @=3D@, @,@ becomes @=2C@).
--
-- The exchange fails with 'AuthOtherFailure' in two cases:
--
-- * the server-first-message is unusable: its nonce does not extend the
--   client nonce, the salt is not valid Base64, or the iteration count is
--   missing, unparsable or not positive;
--
-- * the server signature in the server-final-message does not match the
--   expected one.
scram :: (Crypto.Hash ctx hash)
      => hash            -- ^ Dummy argument to determine the hash to use; you
                         --   can safely pass undefined or a 'hashToken' to it
      -> Text.Text       -- ^ Authentication ID (user name)
      -> Maybe Text.Text -- ^ Authorization ID
      -> Text.Text       -- ^ Password
      -> ExceptT AuthFailure (StateT StreamState IO) ()
scram hToken authcid authzid password = do
    -- Prepare the credentials first (presumably SASLprep; see prepCredentials).
    (ac, az, pw) <- prepCredentials authcid authzid password
    scramhelper ac az pw
  where
    scramhelper :: Text.Text
                -> Maybe Text.Text
                -> Text.Text
                -> ExceptT AuthFailure (StateT StreamState IO) ()
    scramhelper authcid' authzid' pwd = do
        cnonce <- liftIO makeNonce
        -- client-first-message
        saslInit "SCRAM-SHA-1" (Just $ cFirstMessage cnonce)
        -- server-first-message: combined nonce, salt and iteration count
        sFirstMessage <- saslFromJust =<< pullChallenge
        prs <- toPairs sFirstMessage
        (nonce, salt, ic) <- fromPairs prs cnonce
        -- client-final-message, plus the server signature we expect back
        let (cfm, v) = cFinalMessageAndVerifier nonce salt ic sFirstMessage cnonce
        respond $ Just cfm
        -- server-final-message: the "v" attribute must carry the expected
        -- server signature, otherwise the server is not authenticated.
        finalPairs <- toPairs =<< saslFromJust =<< pullFinalMessage
        unless (lookup "v" finalPairs == Just v) $
            throwError AuthOtherFailure -- TODO: Log
      where
        -- We need to jump through some hoops to get a polymorphic solution.
        -- The otherwise unused token argument fixes the hash type that
        -- gets serialised.
        encode :: Crypto.Hash ctx hash => hash -> hash -> BS.ByteString
        encode _hashtoken = Crypto.encode

        -- The hash function selected by hToken, as raw bytes.
        hash :: BS.ByteString -> BS.ByteString
        hash str = encode hToken $ Crypto.hash' str

        -- HMAC keyed with the first argument, using the selected hash.
        hmac :: BS.ByteString -> BS.ByteString -> BS.ByteString
        hmac key str = encode hToken $ Crypto.hmac' (Crypto.MacKey key) str

        -- RFC 5802 saslname encoding. '=' must be escaped before ',';
        -- otherwise the '=' introduced by "=2C" would be escaped again.
        escapeSaslName :: Text.Text -> BS.ByteString
        escapeSaslName = Text.encodeUtf8
                       . Text.replace "," "=2C"
                       . Text.replace "=" "=3D"

        authzid'' :: Maybe BS.ByteString
        authzid'' = ("a=" +++) . escapeSaslName <$> authzid'

        gs2CbindFlag :: BS.ByteString
        gs2CbindFlag = "n" -- we don't support channel binding yet

        gs2Header :: BS.ByteString
        gs2Header = merge $ [ gs2CbindFlag
                            , maybe "" id authzid''
                            , ""
                            ]
        -- cbindData :: BS.ByteString
        -- cbindData            = "" -- we don't support channel binding yet

        cFirstMessageBare :: BS.ByteString -> BS.ByteString
        cFirstMessageBare cnonce = merge [ "n=" +++ escapeSaslName authcid'
                                         , "r=" +++ cnonce]

        cFirstMessage :: BS.ByteString -> BS.ByteString
        cFirstMessage cnonce = gs2Header +++ cFirstMessageBare cnonce

        -- Extract (nonce, salt, iteration count) from the server-first-message.
        -- The values come from the server and are validated before use.
        fromPairs :: Pairs
                  -> BS.ByteString
                  -> ExceptT AuthFailure (StateT StreamState IO) (BS.ByteString, BS.ByteString, Integer)
        fromPairs prs cnonce | Just nonce <- lookup "r" prs
                             , cnonce `BS.isPrefixOf` nonce
                             , Just salt' <- lookup "s" prs
                             , Right salt <- B64.decode salt'
                             , Just ic <- lookup "i" prs
                             , [(i, "")] <- reads $ BS8.unpack ic
                             -- Hi() needs at least one iteration. A zero or
                             -- negative count must not reach foldl1' in hi.
                             , i > 0
                             = return (nonce, salt, i)
        fromPairs _ _ = throwError AuthOtherFailure -- TODO: Log

        -- Build the client-final-message (with ClientProof) and the Base64
        -- ServerSignature that the server is expected to send back.
        cFinalMessageAndVerifier :: BS.ByteString
                                 -> BS.ByteString
                                 -> Integer
                                 -> BS.ByteString
                                 -> BS.ByteString
                                 -> (BS.ByteString, BS.ByteString)
        cFinalMessageAndVerifier nonce salt ic sfm cnonce
            =  ( merge [ cFinalMessageWOProof
                       , "p=" +++ B64.encode clientProof
                       ]
               , B64.encode serverSignature
               )
          where
            cFinalMessageWOProof :: BS.ByteString
            cFinalMessageWOProof = merge [ "c=" +++ B64.encode gs2Header
                                         , "r=" +++ nonce]

            saltedPassword :: BS.ByteString
            saltedPassword = hi (Text.encodeUtf8 pwd) salt ic

            clientKey :: BS.ByteString
            clientKey = hmac saltedPassword "Client Key"

            storedKey :: BS.ByteString
            storedKey = hash clientKey

            authMessage :: BS.ByteString
            authMessage = merge [ cFirstMessageBare cnonce
                                , sfm
                                , cFinalMessageWOProof
                                ]

            clientSignature :: BS.ByteString
            clientSignature = hmac storedKey authMessage

            clientProof :: BS.ByteString
            clientProof = clientKey `xorBS` clientSignature

            serverKey :: BS.ByteString
            serverKey = hmac saltedPassword "Server Key"

            serverSignature :: BS.ByteString
            serverSignature = hmac serverKey authMessage

            -- Hi(str, salt, i): the XOR of the first i iterated HMACs.
            -- fromPairs guarantees i >= 1, so the list handed to foldl1'
            -- is never empty.
            hi :: BS.ByteString -> BS.ByteString -> Integer -> BS.ByteString
            hi str slt ic' = foldl1' xorBS (genericTake ic' us)
              where
                u1 = hmac str (slt +++ (BS.pack [0,0,0,1]))
                us = iterate (hmac str) u1

-- | SASL handler for the SCRAM-SHA-1 mechanism.
--
-- It runs 'scram' with a SHA-1 hash token and classifies the outcome:
--
-- * a stream-level failure ('AuthStreamFailure') is returned as 'Left';
--
-- * any other authentication failure is returned as @Right (Just failure)@;
--
-- * success is returned as @Right Nothing@.
scramSha1 :: Username  -- ^ username
          -> Maybe AuthZID -- ^ authorization ID
          -> Password   -- ^ password
          -> SaslHandler
scramSha1 authcid authzid passwd =
    ( "SCRAM-SHA-1"
    , classify <$> runExceptT (scram sha1 authcid authzid passwd)
    )
  where
    -- Selects SHA-1 as the hash used by 'scram'; never evaluated.
    sha1 = hashToken :: Crypto.SHA1

    classify :: Either AuthFailure () -> Either XmppFailure (Maybe AuthFailure)
    classify (Left (AuthStreamFailure e)) = Left e
    classify (Left failure)               = Right (Just failure)
    classify (Right ())                   = Right Nothing