module Network.HaskellNet.Auth
where

import Crypto.Hash.MD5
import qualified Codec.Binary.Base64.String as B64 (encode, decode)

import Data.Word
import Data.List
import Data.Bits
import Data.Array
import qualified Data.ByteString as B

-- | Login name passed to the authentication helpers in this module.
type UserName = String

-- | Secret passed to the authentication helpers in this module.
type Password = String

-- | Authentication mechanisms that 'auth' knows how to produce a response for.
data AuthType
    = PLAIN     -- ^ answered by 'plain'
    | LOGIN     -- ^ answered by 'login'
    | CRAM_MD5  -- ^ answered by 'cramMD5'
    deriving (Eq)

-- | Renders each mechanism under its protocol name. This instance is written
-- by hand because the name for 'CRAM_MD5' is @CRAM-MD5@, with a hyphen that a
-- derived instance could not produce.
instance Show AuthType where
    showsPrec prec mechanism = showParen (prec > appPrec) (showString name)
      where
        -- Precedence of function application. When the surrounding precedence
        -- exceeds it, the name is parenthesised, e.g. showsPrec 11 PLAIN "" is
        -- "(PLAIN)".
        appPrec :: Int
        appPrec = 10

        name :: String
        name = case mechanism of
            PLAIN    -> "PLAIN"
            LOGIN    -> "LOGIN"
            CRAM_MD5 -> "CRAM-MD5"

-- | Base64-encode a string with 'B64.encode' and return it as one line.
--
-- Hotfix for https://github.com/jtdaugherty/HaskellNet/issues/61: line breaks
-- emitted by the encoder must not reach the protocol, because the encoded
-- value is used as a single-line command argument.
--
-- The previous fix used @delete '\n'@, which removes only the FIRST newline.
-- If the encoder emits more than one break (presumably when it wraps long
-- inputs such as lengthy credentials; TODO confirm against base64-string),
-- the rest would leak through. Every CR and LF is now filtered out. Neither
-- character belongs to the Base64 alphabet, so this cannot alter the encoded
-- data.
b64Encode :: String -> String
b64Encode = filter (`notElem` "\r\n") . B64.encode

-- | Decode a Base64 string with 'B64.decode'; the counterpart of 'b64Encode'.
--
-- The input and output are handed to and from the decoder unchanged. The
-- former @map (toEnum . fromEnum)@ wrappers had type @Char -> Char@ and were
-- identity functions.
b64Decode :: String -> String
b64Decode = B64.decode

-- | Render bytes as lowercase hexadecimal, two digits per byte with no
-- separators (e.g. @[0, 171, 255]@ becomes @"00abff"@).
showOctet :: [Word8] -> String
showOctet octets = foldr render [] octets
    where
        -- Emit the high nibble then the low nibble in front of the rest.
        render :: Word8 -> String -> String
        render w rest =
            let (hi, lo) = w `quotRem` 16
            in digit hi : digit lo : rest

        digit :: Word8 -> Char
        digit nibble = hexDigits ! nibble

        hexDigits :: Array Word8 Char
        hexDigits = listArray (0, 15) "0123456789abcdef"

-- | MD5 digest of a list of bytes.
--
-- The bytes are packed into a strict 'B.ByteString', hashed with 'hash' from
-- "Crypto.Hash.MD5", and the digest is unpacked back into a list.
hashMD5 :: [Word8] -> [Word8]
hashMD5 octets = B.unpack digest
    where digest = hash (B.pack octets)

-- | HMAC-MD5 (RFC 2104) of @text@ under @key@, returned as the raw digest
-- bytes.
--
-- Both strings are turned into bytes one 'Char' at a time with
-- @toEnum . fromEnum@. The 'Word8' 'toEnum' in GHC's base is bounds-checked,
-- so a character above code point 255 makes this function call 'error'.
--
-- Per RFC 2104, a key longer than the 64-byte block size is first replaced
-- by its MD5 hash. Every key is then right-padded with zero bytes to exactly
-- one block. The previous code instead hashed the key with 48 zero bytes
-- appended, producing a 16-byte key that truncated both XORed pads.
hmacMD5 :: String -> String -> [Word8]
hmacMD5 text key = hashMD5 (okey ++ hashMD5 (ikey ++ toOctets text))
    where
        -- MD5 block size in bytes (B in RFC 2104).
        blockSize :: Int
        blockSize = 64

        toOctets :: String -> [Word8]
        toOctets = map (toEnum . fromEnum)

        koc :: [Word8]
        koc = toOctets key

        -- Over-long keys are hashed down before padding.
        shortKey :: [Word8]
        shortKey = if length koc > blockSize then hashMD5 koc else koc

        -- Key zero-padded to exactly one block.
        key' :: [Word8]
        key' = shortKey ++ replicate (blockSize - length shortKey) 0

        ikey :: [Word8]
        ikey = zipWith xor key' (replicate blockSize 0x36)

        okey :: [Word8]
        okey = zipWith xor key' (replicate blockSize 0x5c)

-- | Response for the PLAIN mechanism: the Base64 encoding of
-- NUL, user name, NUL, password. The first field (the authorization identity
-- in the PLAIN message layout) is left empty.
plain :: UserName -> Password -> String
plain user pass = b64Encode ('\0' : user ++ '\0' : pass)

-- | Responses for the LOGIN mechanism: the user name and the password, each
-- Base64-encoded separately, in that order.
login :: UserName -> Password -> (String, String)
login user pass = (encodedUser, encodedPass)
    where
        encodedUser = b64Encode user
        encodedPass = b64Encode pass

-- | Response for the CRAM-MD5 mechanism.
--
-- The result is the Base64 encoding of the user name, a single space, and the
-- lowercase hex HMAC-MD5 of the challenge keyed with the password. The
-- challenge is used exactly as given; presumably the caller has already
-- Base64-decoded the server's challenge (verify at call sites).
cramMD5 :: String -> UserName -> Password -> String
cramMD5 challenge user pass = b64Encode response
    where
        digest   = showOctet (hmacMD5 challenge pass)
        response = unwords [user, digest]

-- | Build the client response for the given mechanism.
--
-- The challenge argument is only consulted for 'CRAM_MD5'; 'PLAIN' and
-- 'LOGIN' ignore it. For 'LOGIN', the two Base64 strings from 'login' are
-- joined with a single space.
auth :: AuthType -> String -> UserName -> Password -> String
auth mechanism challenge user pass =
    case mechanism of
        PLAIN    -> plain user pass
        LOGIN    -> let (encUser, encPass) = login user pass
                    in encUser ++ " " ++ encPass
        CRAM_MD5 -> cramMD5 challenge user pass