{-|
Module      : Database.MySQL.TLS
Description : TLS support for mysql-haskell via @tls@ package.
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

This module provides secure MySQL connections using the @tls@ package. Please make sure your certificate has X.509 v3 extensions enabled.

-}

module Database.MySQL.TLS (
      connect
    , connectDetail
    , module Data.TLSSetting
    ) where

import           Control.Exception              (bracketOnError, throwIO)
import qualified Data.Binary                    as Binary
import qualified Data.Binary.Put                as Binary
import qualified Data.Connection                as Conn
import           Data.IORef                     (newIORef)
import           Data.TLSSetting
import           Database.MySQL.Connection      hiding (connect, connectDetail)
import           Database.MySQL.Protocol.Auth
import           Database.MySQL.Protocol.Packet
import qualified Network.TLS                    as TLS
import qualified System.IO.Streams.TCP          as TCP
import qualified Data.Connection                as TCP
import qualified System.IO.Streams.TLS          as TLS

--------------------------------------------------------------------------------

-- | Open a TLS-secured MySQL connection.
--
-- The second argument pairs the 'TLS.ClientParams' to use with a subject
-- name; the subject name is installed as the
-- 'TLS.clientServerIdentification' of those parameters before the TLS
-- handshake. This is 'connectDetail' with the server 'Greeting' dropped.
connect :: ConnectInfo -> (TLS.ClientParams, String) -> IO MySQLConn
connect c cp = snd <$> connectDetail c cp

-- | Like 'connect', but also return the 'Greeting' the server sent when the
-- connection was opened.
--
-- Steps performed:
--
--   1. Open a plain TCP connection and read the server's greeting packet.
--   2. If the greeting's capability flags advertise TLS support, send an
--      SSL request packet (sequence id 1), then run the TLS handshake on the
--      same socket. SNI is disabled and the given subject name is used as
--      the 'TLS.clientServerIdentification'.
--   3. Send the authentication packet (sequence id 2) over TLS and read the
--      server's reply.
--
-- Throws 'ERRException' if the reply is not an OK packet. Calls 'error' if
-- the server does not advertise TLS support. If any step throws, the
-- underlying TCP connection is closed before the exception propagates.
connectDetail :: ConnectInfo -> (TLS.ClientParams, String) -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset) (cparams, subName) =
    bracketOnError (connectWithBufferSize host port bUFSIZE) TCP.close $ \ c -> do
        is' <- decodeInputStream (TCP.source c)
        p <- readPacket is'
        greet <- decodeFromPacket p
        if supportTLS (greetingCaps greet)
        then do
            let cparams' = cparams
                    { TLS.clientUseServerNameIndication = False
                    , TLS.clientServerIdentification = (subName, "")
                    }
                (sock, sockAddr) = Conn.connExtraInfo c
            write c (encodeToPacket 1 $ sslRequest charset)
            -- On failure this handler only says goodbye to the TLS peer;
            -- the enclosing 'bracketOnError' closes the socket afterwards,
            -- even if 'TLS.bye' itself throws.
            bracketOnError (TLS.contextNew sock cparams') TLS.bye $ \ ctx -> do
                TLS.handshake ctx
                tc <- TLS.tLsToConnection (ctx, sockAddr)
                tlsIs' <- decodeInputStream (TCP.source tc)
                let auth = mkAuth db user pass charset greet
                write tc (encodeToPacket 2 auth)
                q <- readPacket tlsIs'
                -- Do NOT close the socket before throwing on a rejected
                -- login: the cleanup handlers still need it open to send the
                -- TLS close_notify, and a failure there would otherwise mask
                -- the server's error with an unrelated I/O exception.
                if isOK q
                then do
                    consumed <- newIORef True
                    let conn = MySQLConn tlsIs' (write tc) (TCP.close tc) consumed
                    return (greet, conn)
                else decodeFromPacket q >>= throwIO . ERRException
        else error "Database.MySQL.TLS: server doesn't support TLS connection"
  where
    -- Open a TCP socket and wrap it as a buffered connection of size @bs@.
    connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs

    -- Serialise a value with its 'Binary.Binary' instance and send the bytes.
    write :: Binary.Binary p => Conn.Connection a -> p -> IO ()
    write conn a = TCP.send conn $ Binary.runPut (Binary.put a)