{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}

{-|
Module      : Database.MySQL.Protocol.Auth
Description : MySQL Auth Packets
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

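Authentication-related packets: the server greeting (handshake), the client
authentication packet and the SSL request packet.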

-}

module Database.MySQL.Protocol.Auth where

import           Control.Monad
import           Data.Binary
import           Data.Binary.Get
import           Data.Binary.Parser
import           Data.Binary.Put
import qualified Data.ByteString                as B
import           Data.ByteString.Char8          as BC
import           Data.Bits
import           Database.MySQL.Protocol.Packet

--------------------------------------------------------------------------------
-- Authentications

-- Capability flag bits exchanged in the handshake. They are combined with
-- bitwise OR into a 32-bit capability set; the greeting carries that set on
-- the wire as two little-endian 16-bit halves.
#define CLIENT_LONG_PASSWORD                  0x00000001
#define CLIENT_FOUND_ROWS                     0x00000002
#define CLIENT_LONG_FLAG                      0x00000004
#define CLIENT_CONNECT_WITH_DB                0x00000008
#define CLIENT_NO_SCHEMA                      0x00000010
#define CLIENT_COMPRESS                       0x00000020
#define CLIENT_ODBC                           0x00000040
#define CLIENT_LOCAL_FILES                    0x00000080
#define CLIENT_IGNORE_SPACE                   0x00000100
#define CLIENT_PROTOCOL_41                    0x00000200
#define CLIENT_INTERACTIVE                    0x00000400
#define CLIENT_SSL                            0x00000800
#define CLIENT_IGNORE_SIGPIPE                 0x00001000
#define CLIENT_TRANSACTIONS                   0x00002000
#define CLIENT_RESERVED                       0x00004000
#define CLIENT_SECURE_CONNECTION              0x00008000
-- Bits 16 and above travel in the upper 16-bit half of the greeting caps.
#define CLIENT_MULTI_STATEMENTS               0x00010000
#define CLIENT_MULTI_RESULTS                  0x00020000
#define CLIENT_PS_MULTI_RESULTS               0x00040000
#define CLIENT_PLUGIN_AUTH                    0x00080000
#define CLIENT_CONNECT_ATTRS                  0x00100000
#define CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA 0x00200000

-- | Initial handshake packet sent by the server, as read by 'getGreeting'
-- and written by 'putGreeting'.
--
-- The two salt fields together form the authentication data; the salt2 and
-- plugin-name fields are empty when the corresponding capability bit is not
-- present in 'greetingCaps'.
data Greeting = Greeting
    { greetingProtocol   :: !Word8          -- ^ protocol version byte
    , greetingVersion    :: !B.ByteString   -- ^ server version string (NUL stripped)
    , greetingConnId     :: !Word32         -- ^ connection id
    , greetingSalt1      :: !B.ByteString   -- ^ first 8 bytes of auth-plugin-data
    , greetingCaps       :: !Word32         -- ^ reassembled 32-bit capability set
    , greetingCharset    :: !Word8          -- ^ character set byte
    , greetingStatus     :: !Word16         -- ^ status flags
    , greetingSalt2      :: !B.ByteString   -- ^ remaining auth-plugin-data (may be empty)
    , greetingAuthPlugin :: !B.ByteString   -- ^ auth plugin name (may be empty)
    } deriving (Eq, Show)

-- | Serialise a 'Greeting' as a server handshake packet payload.
--
-- The layout matches what 'getGreeting' consumes, in order:
--
--   * the protocol byte, then the version string followed by a NUL byte;
--   * the 4-byte connection id, then salt1, then a 0x00 filler byte;
--   * the low 16 capability bits, the charset byte, the 16-bit status, and
--     the high 16 capability bits;
--   * the auth-plugin-data length byte, then 10 reserved zero bytes;
--   * salt2 and the plugin name, each NUL-terminated, written only when the
--     matching capability bit is set.
--
-- Because of this, @runGet getGreeting . runPut . putGreeting@ should be the
-- identity, provided that:
--
--   * salt1 is exactly 8 bytes;
--   * none of the strings contains a NUL byte;
--   * salt2 and the plugin name are empty whenever their capability bit is
--     absent.
putGreeting :: Greeting -> Put
putGreeting (Greeting pv sv cid salt1 cap charset st salt2 authPlugin) = do
    putWord8 pv
    putNulTerminated sv
    putWord32le cid
    putByteString salt1
    putWord8 0x00                       -- filler byte, skipped by getGreeting
    putWord16le capL
    putWord8 charset
    putWord16le st
    putWord16le capH
    putWord8 authDataLen
    replicateM_ 10 (putWord8 0x00)      -- reserved
    when (hasCap CLIENT_SECURE_CONNECTION) (putNulTerminated salt2)
    when (hasCap CLIENT_PLUGIN_AUTH) (putNulTerminated authPlugin)
  where
    -- Test one capability bit of the greeting capability set.
    hasCap flag = cap .&. flag /= 0

    -- Write a string followed by a NUL byte, the form getGreeting reads
    -- back with getByteStringNul.
    putNulTerminated bs = putByteString bs >> putWord8 0x00

    -- The 32-bit capability set is split into two 16-bit halves. A plain
    -- truncating conversion is required here: the previous version OR-ed
    -- each half with 0xFF, which forced eight unrelated capability bits on.
    capL, capH :: Word16
    capL = fromIntegral cap
    capH = fromIntegral (cap `shiftR` 16)

    -- Length of the whole auth-plugin-data: salt1, salt2 and its trailing
    -- NUL. It is 0x00 when the plugin-auth capability is off (HandshakeV10).
    authDataLen :: Word8
    authDataLen
        | hasCap CLIENT_PLUGIN_AUTH = fromIntegral (B.length salt1 + B.length salt2 + 1)
        | otherwise                 = 0x00

-- | Parse a server handshake packet into a 'Greeting'.
--
-- Fields are consumed in this order:
--
--   * the protocol byte, the NUL-terminated version, the 4-byte connection
--     id, 8 bytes of salt1, and one filler byte;
--   * the low 16 capability bits, the charset byte, the 16-bit status, and
--     the high 16 capability bits;
--   * one length byte (ignored) and 10 reserved bytes;
--   * salt2 and the auth plugin name.
--
-- salt2 and the plugin name are each read as NUL-terminated strings, and
-- only when the matching capability bit is set; otherwise they are empty.
getGreeting :: Get Greeting
getGreeting = do
    protocol  <- getWord8
    version   <- getByteStringNul
    connId    <- getWord32le
    scramble1 <- getByteString 8
    skipN 1                              -- filler byte (0x00)
    capsLow   <- getWord16le
    charsetId <- getWord8
    status    <- getWord16le
    capsHigh  <- getWord16le
    let caps :: Word32
        caps = fromIntegral capsHigh `shiftL` 16 .|. fromIntegral capsLow
        -- Read a NUL-terminated string only if the given capability bit is set.
        optionalNul flag
            | caps .&. flag == 0 = pure B.empty
            | otherwise          = getByteStringNul
    void getWord8                        -- auth-plugin-data length, not used
    skipN 10                             -- reserved bytes
    -- The MySQL documentation says salt2 spans MAX(13, length of
    -- auth-plugin-data - 8) bytes, but reading that many bytes prevented
    -- login in practice; a NUL-terminated read works.
    scramble2 <- optionalNul CLIENT_SECURE_CONNECTION
    plugin    <- optionalNul CLIENT_PLUGIN_AUTH
    pure Greeting
        { greetingProtocol   = protocol
        , greetingVersion    = version
        , greetingConnId     = connId
        , greetingSalt1      = scramble1
        , greetingCaps       = caps
        , greetingCharset    = charsetId
        , greetingStatus     = status
        , greetingSalt2      = scramble2
        , greetingAuthPlugin = plugin
        }

-- | Wire encoding of the server handshake, delegating to 'putGreeting'
-- and 'getGreeting'.
instance Binary Greeting where
    put = putGreeting
    get = getGreeting

-- | Client authentication packet, written by 'putAuth' and read by 'getAuth'.
data Auth = Auth
    { authCaps      :: !Word32      -- ^ client capability set
    , authMaxPacket :: !Word32      -- ^ max packet size field
    , authCharset   :: !Word8       -- ^ character set byte
    , authName      :: !ByteString  -- ^ user name (NUL-terminated on the wire)
    , authPassword  :: !ByteString  -- ^ auth response (1-byte length prefix on the wire)
    , authSchema    :: !ByteString  -- ^ schema name (NUL-terminated on the wire)
    } deriving (Eq, Show)

-- | Parse the leading part of a client authentication packet.
--
-- Reads, in order:
--
--   * the capability set and max packet size (4 bytes each, little-endian);
--   * the charset byte;
--   * 23 skipped filler bytes;
--   * the NUL-terminated user name.
--
-- NOTE(review): the password and schema are not parsed; they are always
-- returned empty, so this is not the inverse of 'putAuth'.
getAuth :: Get Auth
getAuth =
    withoutCredentials
        <$> getWord32le
        <*> getWord32le
        <*> getWord8
        <*  skipN 23
        <*> getByteStringNul
  where
    withoutCredentials caps maxPkt charsetId user =
        Auth caps maxPkt charsetId user B.empty B.empty

-- | Serialise an 'Auth' packet payload.
--
-- Writes, in order:
--
--   * the capability set and max packet size (4 bytes each, little-endian);
--   * the charset byte, then 23 zero bytes;
--   * the NUL-terminated user name;
--   * a 1-byte password length followed by the password bytes;
--   * the NUL-terminated schema.
--
-- Every field is written regardless of the capability bits.
--
-- NOTE(review): the length prefix is produced with 'fromIntegral', so a
-- password longer than 255 bytes would get a truncated length byte.
putAuth :: Auth -> Put
putAuth (Auth caps maxPkt charsetId user pass schema) = do
    mapM_ putWord32le [caps, maxPkt]
    putWord8 charsetId
    putByteString (B.replicate 23 0x00)   -- filler
    putNulTerminated user
    putWord8 (fromIntegral (B.length pass))
    putByteString pass
    putNulTerminated schema
  where
    -- String followed by a single NUL byte.
    putNulTerminated bs = putByteString bs >> putWord8 0x00

-- | Wire encoding of the authentication packet via 'putAuth' / 'getAuth'.
instance Binary Auth where
    put = putAuth
    get = getAuth

-- | SSL request packet: the fixed-size prefix of an authentication packet,
-- written by 'putSSLRequest' and read by 'getSSLRequest'.
data SSLRequest = SSLRequest
    { sslReqCaps      :: !Word32  -- ^ client capability set
    , sslReqMaxPacket :: !Word32  -- ^ max packet size field
    , sslReqCharset   :: !Word8   -- ^ character set byte
    } deriving (Eq, Show)

-- | Parse an SSL request packet.
--
-- Reads the capability set and max packet size (4 bytes each,
-- little-endian) and the charset byte, then skips 23 filler bytes.
getSSLRequest :: Get SSLRequest
getSSLRequest = do
    caps      <- getWord32le
    maxPkt    <- getWord32le
    charsetId <- getWord8
    skipN 23
    pure (SSLRequest caps maxPkt charsetId)

-- | Serialise an SSL request packet.
--
-- Writes the capability set and max packet size (4 bytes each,
-- little-endian) and the charset byte, followed by 23 zero bytes.
putSSLRequest :: SSLRequest -> Put
putSSLRequest (SSLRequest caps maxPkt charsetId) = do
    mapM_ putWord32le [caps, maxPkt]
    putWord8 charsetId
    putByteString (B.replicate 23 0x00)   -- filler

-- | Wire encoding of the SSL request via 'putSSLRequest' / 'getSSLRequest'.
instance Binary SSLRequest where
    put = putSSLRequest
    get = getSSLRequest

--------------------------------------------------------------------------------
-- default Capability Flags

-- | Default capability set advertised by this client.
--
-- The SSL bit is not included here; 'sslRequest' adds it on top of this
-- value. The plugin-auth bit is not requested.
clientCap :: Word32
clientCap = foldr (.|.) 0
    [ CLIENT_LONG_PASSWORD
    , CLIENT_LONG_FLAG
    , CLIENT_CONNECT_WITH_DB
    , CLIENT_IGNORE_SPACE
    , CLIENT_PROTOCOL_41
    , CLIENT_TRANSACTIONS
    , CLIENT_MULTI_STATEMENTS
    , CLIENT_MULTI_RESULTS
    , CLIENT_SECURE_CONNECTION
    ]

-- | Max packet size value sent in SSL requests: 0xFFFFFF, i.e. 2^24 - 1.
clientMaxPacketSize :: Word32
clientMaxPacketSize = 0xFFFFFF


-- | True when the given capability set has the SSL bit (0x800) set.
supportTLS :: Word32 -> Bool
supportTLS caps = caps .&. CLIENT_SSL /= 0

-- | Build an SSL request for the given character set.
--
-- The request uses 'clientCap' with the SSL bit added and
-- 'clientMaxPacketSize'.
sslRequest :: Word8 -> SSLRequest
sslRequest charset = SSLRequest
    { sslReqCaps      = clientCap .|. CLIENT_SSL
    , sslReqMaxPacket = clientMaxPacketSize
    , sslReqCharset   = charset
    }