module System.IO.Streams.OpenSSL
(
connect
, withConnection
, accept
, withOpenSSL
, sslToStreams
, closeSSL
) where
import qualified Control.Exception as E
import Control.Monad (unless, void)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as S
import Data.Char (toLower)
import Network.Socket (HostName, PortNumber, Socket)
import qualified Network.Socket as N
import OpenSSL (withOpenSSL)
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL.Session as SSL
import qualified OpenSSL.X509 as X509
import System.IO.Streams (InputStream, OutputStream)
import qualified System.IO.Streams as Streams
import qualified System.IO.Streams.TCP as TCP
-- | Number of bytes requested from 'SSL.read' for each chunk produced by the
-- input stream of 'sslToStreams': 32 KiB less 16 bytes (i.e. 32752).
--
-- NOTE(review): presumably chosen to match the default chunk size used by
-- bytestring / io-streams; confirm before changing it.
bUFSIZ :: Int
bUFSIZ = 32 * 1024 - 16
-- | Expose an established 'SSL' connection as a pair of io-streams.
--
-- * The input stream yields chunks of at most 'bUFSIZ' bytes read with
--   'SSL.read'. An empty read ends the stream (@Nothing@).
--
-- * Exceptions raised by 'SSL.read' are /not/ turned into end-of-stream;
--   they propagate to whoever reads the input stream.
--
-- * Each chunk written to the output stream is sent with 'SSL.write'.
--   Writing @Nothing@ is a no-op and does not shut the SSL connection down.
--   Use 'closeSSL' for that.
sslToStreams :: SSL
             -> IO (InputStream ByteString, OutputStream ByteString)
sslToStreams ssl = do
    is <- Streams.makeInputStream input
    os <- Streams.makeOutputStream output
    return (is, os)
  where
    -- Read errors are deliberately left to propagate. An earlier
    -- `E.onException` handler here looked like it mapped them to EOF, but
    -- 'E.onException' rethrows after running its handler, so it never did.
    -- Treating an unclean close as EOF would also hide truncated data.
    input = do
        chunk <- SSL.read ssl bUFSIZ
        return $! if S.null chunk then Nothing else Just chunk

    output Nothing      = return ()
    output (Just chunk) = SSL.write ssl chunk
-- | Shut down an SSL connection and close its underlying socket.
--
-- 'SSL.shutdown' is first called with 'SSL.Unidirectional'. Then the socket
-- returned by 'SSL.sslSocket' is closed, if there is one.
--
-- The socket is closed even when the shutdown throws. This can happen, for
-- example, when this is used as the error handler after a failed handshake.
-- The shutdown exception is rethrown afterwards.
closeSSL :: SSL.SSL -> IO ()
closeSSL ssl =
    SSL.shutdown ssl SSL.Unidirectional `E.finally` closeSocket
  where
    closeSocket = maybe (return ()) N.close (SSL.sslSocket ssl)
-- | Open a TCP connection to @host:port@, run the TLS client handshake and
-- verify the server certificate.
--
-- The connection is accepted only when 'SSL.getVerifyResult' reports OK
-- /and/ the certificate subject's CN matches one of the following:
--
-- * the explicit subject name, compared for exact equality, when one is
--   given;
--
-- * otherwise the host name. This comparison is case-insensitive; a @*@ is
--   allowed only as the whole left-most certificate label, matches exactly
--   one non-empty host label, and needs at least two labels after it.
--
-- Only the subject CN is examined; subjectAltName entries are not consulted.
--
-- On a verification failure the connection is closed and
-- @'SSL.ProtocolError' "fail to verify certificate"@ is thrown. Any other
-- failure after the socket is opened also closes it and rethrows.
connect :: SSLContext
        -> Maybe String   -- ^ expected CN; 'Nothing' means match the host name
        -> HostName       -- ^ host to connect to
        -> PortNumber     -- ^ port to connect to
        -> IO (InputStream ByteString, OutputStream ByteString, SSL)
connect ctx subname host port = do
    sock <- TCP.connectSocket host port
    -- If creating the SSL object itself fails, 'closeSSL' never runs, so
    -- close the raw socket here to avoid leaking it.
    let open = SSL.connection ctx sock `E.onException` N.close sock
    E.bracketOnError open closeSSL $ \ssl -> do
        SSL.connect ssl
        trusted <- SSL.getVerifyResult ssl
        cert <- SSL.getPeerCertificate ssl
        subnames <- maybe (return []) (`X509.getSubjectName` False) cert
        let cnname = lookup "CN" subnames
            verified = case subname of
                Just subname' -> maybe False (== subname') cnname
                Nothing       -> maybe False (matchDomain host) cnname
        unless (trusted && verified) $
            E.throwIO $ SSL.ProtocolError "fail to verify certificate"
        (is, os) <- sslToStreams ssl
        return (is, os, ssl)
  where
    -- matchDomain hostName certName: does the certificate name cover the
    -- host? Both names must have the same number of labels, so "example.com"
    -- does not match "www.example.com". Labels are compared
    -- case-insensitively. A "*" is honoured only as the whole left-most
    -- certificate label and must be followed by at least two labels, so
    -- "*.com" and a bare "*" never match.
    matchDomain :: String -> String -> Bool
    matchDomain hostName certName =
        case (splitDot (lower hostName), splitDot (lower certName)) of
            (h : hs, "*" : cs) -> not (null h) && atLeastTwo cs && hs == cs
            (hs, cs)           -> hs == cs
      where
        lower = map toLower
        atLeastTwo (_ : _ : _) = True
        atLeastTwo _           = False

    -- Split on '.', keeping empty labels; always returns a non-empty list.
    splitDot :: String -> [String]
    splitDot "" = [""]
    splitDot x =
        let (y, z) = break (== '.') x
        in y : (if null z then [] else splitDot (drop 1 z))
-- | Run an action over a verified TLS connection opened with 'connect'.
--
-- When the action returns or throws, the connection is released in two
-- steps: the output stream is first closed by writing @Nothing@, then
-- 'closeSSL' is run.
--
-- Each step is best-effort. Any exception it raises is discarded, so the
-- caller sees the action's own result or exception. A failure in the first
-- step does not prevent the second.
withConnection :: SSLContext
               -> Maybe String
               -> HostName
               -> PortNumber
               -> (InputStream ByteString -> OutputStream ByteString -> SSL -> IO a)
               -> IO a
withConnection ctx subname host port action =
    E.bracket (connect ctx subname host port) cleanup go
  where
    go (is, os, ssl) = action is os ssl

    -- 'E.bracket' already runs this release action with asynchronous
    -- exceptions masked, so no extra 'E.mask_' is needed.
    cleanup (_, os, ssl) = do
        eatException (Streams.write Nothing os)
        eatException (closeSSL ssl)

    eatException m = void m `E.catch` (\(_ :: E.SomeException) -> return ())
-- | Accept one connection on a listening socket and run the TLS server
-- handshake on it.
--
-- Returns the io-streams for the connection (see 'sslToStreams'), the 'SSL'
-- object and the peer address. If 'SSL.getVerifyResult' does not report OK,
-- the connection is closed and
-- @'SSL.ProtocolError' "fail to verify certificate"@ is thrown. Any other
-- failure after the connection is accepted also closes it and rethrows.
--
-- NOTE(review): presumably, as with OpenSSL's SSL_get_verify_result, a
-- client that presents no certificate still counts as verified. Whether
-- client certificates are required depends on the context's verify mode;
-- confirm.
accept :: SSL.SSLContext
       -> Socket
       -> IO (InputStream ByteString, OutputStream ByteString, SSL.SSL, N.SockAddr)
accept ctx sock = do
    (sock', sockAddr) <- N.accept sock
    -- If creating the SSL object itself fails, 'closeSSL' never runs, so
    -- close the accepted socket here to avoid leaking it.
    let open = SSL.connection ctx sock' `E.onException` N.close sock'
    E.bracketOnError open closeSSL $ \ssl -> do
        SSL.accept ssl
        trusted <- SSL.getVerifyResult ssl
        unless trusted $
            E.throwIO $ SSL.ProtocolError "fail to verify certificate"
        (is, os) <- sslToStreams ssl
        return (is, os, ssl, sockAddr)