module Database.MySQL.TLS (
connect
, connectDetail
, module Data.TLSSetting
) where
import Control.Exception (bracketOnError, throwIO)
import qualified Data.Binary as Binary
import qualified Data.Binary.Put as Binary
import qualified Data.Connection as Conn
import Data.IORef (newIORef)
import Data.TLSSetting
import Database.MySQL.Connection hiding (connect, connectDetail)
import Database.MySQL.Protocol.Auth
import Database.MySQL.Protocol.Packet
import qualified Network.TLS as TLS
import qualified System.IO.Streams.TCP as TCP
import qualified Data.Connection as TCP
import qualified System.IO.Streams.TLS as TLS
-- | Open a TLS-protected connection to a MySQL server.
--
-- This is 'connectDetail' with the server 'Greeting' discarded; see
-- 'connectDetail' for the meaning of the @(ClientParams, String)@ pair and
-- for the exceptions that may be raised.
connect :: ConnectInfo -> (TLS.ClientParams, String) -> IO MySQLConn
connect c cp = snd <$> connectDetail c cp
-- | Open a TLS-protected connection to a MySQL server and return the
-- server's 'Greeting' together with the ready-to-use 'MySQLConn'.
--
-- The second argument pairs the TLS client parameters with the name that is
-- installed as the host part of 'TLS.clientServerIdentification' (the
-- service part is left empty).  Server Name Indication is switched off on
-- the parameters actually used.
--
-- Steps, in order:
--
-- 1. Open a plain TCP connection (buffer size 'bUFSIZE') and read the
--    server greeting from it.
-- 2. If the greeting's capability flags satisfy 'supportTLS', send an
--    'sslRequest' packet over the plain connection and run the TLS
--    handshake on the same socket.
-- 3. Send the authentication packet over the TLS connection and read the
--    reply.
--
-- Raises 'ERRException' when the reply to the authentication packet is not
-- an OK packet, and calls 'error' when the server does not advertise TLS
-- support.  Both resources are guarded with 'bracketOnError', so any
-- exception raised while connecting releases the TLS context and the socket.
connectDetail :: ConnectInfo -> (TLS.ClientParams, String) -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset) (cparams, subName) =
    bracketOnError (connectWithBufferSize host port bUFSIZE) TCP.close $ \ c -> do
        -- The greeting is read from the plain connection, before any TLS
        -- negotiation takes place.
        is' <- decodeInputStream (TCP.source c)
        p <- readPacket is'
        greet <- decodeFromPacket p
        if supportTLS (greetingCaps greet)
            then do
                let cparams' = cparams
                        { TLS.clientUseServerNameIndication = False
                        , TLS.clientServerIdentification = (subName, "")
                        }
                    (sock, sockAddr) = Conn.connExtraInfo c
                -- Ask the server to switch to TLS; the leading 1 is the
                -- Word8 handed to encodeToPacket (presumably the packet
                -- sequence number -- confirm against encodeToPacket).
                write c (encodeToPacket 1 $ sslRequest charset)
                bracketOnError (TLS.contextNew sock cparams')
                               (\ ctx -> TLS.bye ctx >> TCP.close c) $ \ ctx -> do
                    TLS.handshake ctx
                    tc <- TLS.tLsToConnection (ctx, sockAddr)
                    tlsIs' <- decodeInputStream (TCP.source tc)
                    write tc (encodeToPacket 2 (mkAuth db user pass charset greet))
                    q <- readPacket tlsIs'
                    if isOK q
                        then do
                            consumed <- newIORef True
                            return (greet, MySQLConn tlsIs' (write tc) (TCP.close tc) consumed)
                        -- Do not close the socket here: the bracketOnError
                        -- handlers above already send the TLS close alert
                        -- and close the socket once the exception
                        -- propagates.  Closing first would make TLS.bye run
                        -- on a closed socket, and an IO error raised there
                        -- would replace the ERRException reported to the
                        -- caller.
                        else decodeFromPacket q >>= throwIO . ERRException
            else error "Database.MySQL.TLS: server doesn't support TLS connection"
  where
    -- Open a TCP socket to the given host/port and wrap it in a connection
    -- that uses the given buffer size.
    connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs

    -- Serialise a value with its Binary instance and send it over the
    -- connection.  Left without a signature so it stays usable with both
    -- the plain TCP and the TLS connection.
    write c a = TCP.send c $ Binary.runPut . Binary.put $ a