module Database.MySQL.TLS (
connect
, connectDetail
, module Data.TLSSetting
) where
import Control.Exception (bracketOnError, throwIO)
import qualified Data.Binary as Binary
import qualified Data.Binary.Put as Binary
import qualified Data.Connection as Conn
import Data.IORef (newIORef)
import Data.TLSSetting
import Database.MySQL.Connection hiding (connect, connectDetail)
import Database.MySQL.Protocol.Auth
import Database.MySQL.Protocol.Packet
import qualified Network.TLS as TLS
import qualified System.IO.Streams.TCP as TCP
import qualified Data.Connection as TCP
import qualified System.IO.Streams.TLS as TLS
-- | Establish a MySQL connection over TLS.
--
-- This is 'connectDetail' with the server 'Greeting' dropped: only the
-- resulting 'MySQLConn' is returned. The second argument pairs the TLS client
-- parameters with the name that 'connectDetail' installs as the TLS server
-- identification.
connect :: ConnectInfo -> (TLS.ClientParams, String) -> IO MySQLConn
connect c cp = snd <$> connectDetail c cp
-- | Establish a MySQL connection over TLS, also returning the server 'Greeting'.
--
-- Steps, in order:
--
-- 1. Open a plain TCP connection to @host@ / @port@ with buffer size 'bUFSIZE'.
-- 2. Read the first packet and decode it as the server 'Greeting'.
-- 3. If 'supportTLS' rejects the greeting's capability flags, fail via 'error'.
-- 4. Otherwise send an 'sslRequest' packet over the plain connection, create a
--    TLS context on the same socket, and run the TLS handshake.
-- 5. Send the 'mkAuth' packet over the TLS connection and read the reply.
-- 6. If the reply 'isOK', build the 'MySQLConn'; otherwise decode the reply and
--    throw it wrapped in 'ERRException'.
--
-- The supplied 'TLS.ClientParams' are used with two fields overridden: server
-- name indication is switched off, and the server identification is set to
-- @(subName, "")@.
--
-- Both stages are wrapped in 'bracketOnError', so an exception closes the TCP
-- connection (and, once a TLS context exists, first sends 'TLS.bye').
connectDetail :: ConnectInfo -> (TLS.ClientParams, String) -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset) (cparams, subName) =
    bracketOnError (connectWithBufferSize host port bUFSIZE) TCP.close $ \ c -> do
        -- The greeting arrives on the plain (pre-TLS) connection.
        let is = TCP.source c
        is' <- decodeInputStream is
        p <- readPacket is'
        greet <- decodeFromPacket p
        if supportTLS (greetingCaps greet)
            then do
                let cparams' = cparams
                        { TLS.clientUseServerNameIndication = False
                        , TLS.clientServerIdentification = (subName, "")
                        }
                let (sock, sockAddr) = Conn.connExtraInfo c
                -- First Word8 of encodeToPacket is presumably the packet
                -- sequence number (1 here, 2 for the auth packet) -- TODO confirm.
                write c (encodeToPacket 1 $ sslRequest charset)
                bracketOnError (TLS.contextNew sock cparams')
                               (\ ctx -> TLS.bye ctx >> TCP.close c) $ \ ctx -> do
                    TLS.handshake ctx
                    tc <- TLS.tLsToConnection (ctx, sockAddr)
                    -- From here on, all traffic goes through the TLS connection.
                    let tlsIs = TCP.source tc
                    tlsIs' <- decodeInputStream tlsIs
                    let auth = mkAuth db user pass charset greet
                    write tc (encodeToPacket 2 auth)
                    q <- readPacket tlsIs'
                    if isOK q
                        then do
                            consumed <- newIORef True
                            let conn = MySQLConn tlsIs' (write tc) (TCP.close tc) consumed
                            return (greet, conn)
                        -- NOTE(review): c is closed here and then again by both
                        -- bracketOnError handlers once throwIO fires (the inner one
                        -- also calls TLS.bye on the already-closed socket). Whether
                        -- that is harmless depends on the Connection/TLS
                        -- implementations -- confirm.
                        else TCP.close c >> decodeFromPacket q >>= throwIO . ERRException
            else error "Database.MySQL.TLS: server doesn't support TLS connection"
  where
    -- Open a TCP socket to the given host/port and wrap it as a connection
    -- using the given buffer size.
    connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs
    -- Serialise a value with its Binary instance and send it over a connection.
    -- Left without a local signature so it stays usable on both the TCP and
    -- the TLS connection.
    write c a = TCP.send c $ Binary.runPut . Binary.put $ a