module Database.MySQL.OpenSSL
( connect
, connectDetail
, module Data.OpenSSLSetting
) where
import Control.Exception (bracketOnError, throwIO)
import Control.Exception (ErrorCall (..))
import Control.Monad
import Data.IORef (newIORef)
import Database.MySQL.Connection hiding (connect, connectDetail)
import Database.MySQL.Protocol.Auth
import Database.MySQL.Protocol.Packet
import qualified Network.Socket as N
import qualified OpenSSL as SSL
import qualified OpenSSL.X509 as X509
import qualified OpenSSL.Session as Session
import qualified System.IO.Streams as Stream
import qualified System.IO.Streams.Binary as Binary
import qualified System.IO.Streams.OpenSSL as SSL
import qualified System.IO.Streams.TCP as TCP
import Data.OpenSSLSetting
-- | Open a TLS-secured connection to a MySQL server.
--
-- Equivalent to 'connectDetail' with the server greeting discarded; only
-- the resulting 'MySQLConn' is returned.
connect :: ConnectInfo -> (Session.SSLContext, String) -> IO MySQLConn
connect info tlsParams = snd <$> connectDetail info tlsParams
-- | Open a TLS-secured connection to a MySQL server using OpenSSL, returning
-- the server's initial greeting together with the established connection.
--
-- The second argument pairs the OpenSSL context used for the session with
-- the name that the peer certificate's subject @CN@ must equal.
--
-- Steps:
--
--   1. Open a TCP connection and read the server's plaintext greeting.
--   2. If the greeting's capabilities do not include TLS support, throw an
--      'ErrorCall'.
--   3. Send the SSL request packet (sequence number 1), then run the OpenSSL
--      handshake over the same socket.
--   4. Require both a successful verification result from OpenSSL and a peer
--      certificate whose subject @CN@ equals the expected name; otherwise
--      throw 'Session.ProtocolError'.
--   5. Send the authentication packet (sequence number 2) over TLS and read
--      the reply. If it is an OK packet, the connection is returned.
--      Otherwise, end-of-stream is written to the TLS output stream and the
--      reply is decoded and thrown as 'ERRException'.
--
-- If an exception escapes, the socket is closed with 'N.close'. When the
-- failure happens after the SSL session was created, the session is also
-- closed with 'SSL.close'.
connectDetail :: ConnectInfo -> (Session.SSLContext, String) -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass) (ctx, subname) =
    bracketOnError (TCP.connectWithBufferSize host port bUFSIZE)
        (\(_, _, sock) -> N.close sock) $ \(is, os, sock) -> do
            is' <- decodeInputStream is
            os' <- Binary.encodeOutputStream os
            p <- readPacket is'
            greet <- decodeFromPacket p
            -- A server without TLS is a reachable runtime condition, not a
            -- programming error, so raise it with 'throwIO' inside IO
            -- instead of the pure 'error'. The exception type stays
            -- 'ErrorCall' for callers that catch it.
            unless (supportTLS (greetingCaps greet)) $
                throwIO (ErrorCall "Database.MySQL.OpenSSL: server doesn't support TLS connection")
            SSL.withOpenSSL $ do
                Stream.write (Just (encodeToPacket 1 sslRequest)) os'
                bracketOnError (Session.connection ctx sock) SSL.close $ \ssl -> do
                    Session.connect ssl
                    trusted <- Session.getVerifyResult ssl
                    cert <- Session.getPeerCertificate ssl
                    -- No peer certificate yields no subject entries, so the
                    -- CN check below fails.
                    subnames <- maybe (return []) (`X509.getSubjectName` False) cert
                    let verified = lookup "CN" subnames == Just subname
                    -- Verify the peer before any credentials are sent.
                    unless (trusted && verified) $
                        throwIO (Session.ProtocolError "fail to verify certificate")
                    (sslIs, sslOs) <- SSL.sslToStreams ssl
                    sslIs' <- decodeInputStream sslIs
                    sslOs' <- Binary.encodeOutputStream sslOs
                    let auth = mkAuth db user pass greet
                    Stream.write (Just (encodeToPacket 2 auth)) sslOs'
                    q <- readPacket sslIs'
                    if isOK q
                        then do
                            -- NOTE(review): flag starts as True; presumably it
                            -- means "no unread result pending" - confirm
                            -- against MySQLConn's definition.
                            consumed <- newIORef True
                            let conn = MySQLConn sslIs' sslOs' (SSL.close ssl) consumed
                            return (greet, conn)
                        else do
                            Stream.write Nothing sslOs'
                            decodeFromPacket q >>= throwIO . ERRException