-- | Network connections over plain sockets or TLS, optionally established
-- through a SOCKS proxy, with a small buffered line-reading interface.
module Hookup
(
-- Connection type and the operations on it
Connection,
connect,
connectWithSocket,
close,
recv,
recvLine,
send,
putBuf,
-- Connection configuration
ConnectionParams(..),
SocksParams(..),
TlsParams(..),
defaultFamily,
defaultTlsParams,
-- Failure reporting
ConnectionFailure(..),
CommandReply(..)
-- Inspection of the TLS peer certificate
, getPeerCertificate
, getPeerCertFingerprintSha1
, getPeerCertFingerprintSha256
, getPeerCertFingerprintSha512
, getPeerPubkeyFingerprintSha1
, getPeerPubkeyFingerprintSha256
, getPeerPubkeyFingerprintSha512
) where
import Control.Concurrent
import Control.Exception
import Control.Monad
import System.IO.Error (isDoesNotExistError, ioeGetErrorString)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import Data.Foldable
import Data.List (intercalate)
import Network.Socket (Socket, AddrInfo, PortNumber, HostName, Family)
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as SocketB
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL as SSL
import qualified OpenSSL.Session as SSL
import OpenSSL.X509.SystemStore
import OpenSSL.X509 (X509)
import qualified OpenSSL.X509 as X509
import qualified OpenSSL.PEM as PEM
import qualified OpenSSL.EVP.Digest as Digest
import Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as Parser
import Hookup.OpenSSL (installVerification, getPubKeyDer)
import Hookup.Socks5
-- | Parameters describing where and how 'connect' opens a connection.
data ConnectionParams = ConnectionParams
{ cpFamily :: Family
-- ^ Address family given as a hint when resolving the host to connect to
, cpHost :: HostName
-- ^ Destination host; also the hostname handed to the TLS setup
, cpPort :: PortNumber
-- ^ Destination port
, cpSocks :: Maybe SocksParams
-- ^ When present, the connection is made through this SOCKS proxy
, cpTls :: Maybe TlsParams
-- ^ When present, TLS is negotiated on the socket using these settings
}
-- | Location of the SOCKS proxy that 'openSocket' connects to before asking
-- it to relay to the real destination.
data SocksParams = SocksParams
{ spHost :: HostName
-- ^ Proxy host
, spPort :: PortNumber
-- ^ Proxy port
}
-- | TLS settings applied by 'startTls'.
data TlsParams = TlsParams
{ tpClientCertificate :: Maybe FilePath
-- ^ Path to a PEM file holding the client certificate, if any
, tpClientPrivateKey :: Maybe FilePath
-- ^ Path to a PEM file holding the client private key (read without a password), if any
, tpServerCertificate :: Maybe FilePath
-- ^ Path to a CA file; when 'Nothing' the system certificate store is loaded
, tpCipherSuite :: String
-- ^ Cipher list string handed to the SSL context
, tpInsecure :: Bool
-- ^ When 'True' the verification mode is 'SSL.VerifyNone'
}
-- | Exceptions thrown by this module while connecting or reading lines.
data ConnectionFailure
= HostnameResolutionFailure HostName String
-- ^ Address lookup for the given host failed; carries the error text
| ConnectionFailure [IOError]
-- ^ Every resolved address was tried and none could be connected
| LineTooLong
-- ^ 'recvLine' found no newline within the permitted length
| LineTruncated
-- ^ The connection closed in the middle of a line
| SocksError CommandReply
-- ^ The SOCKS server answered the connect request with a non-success reply
| SocksAuthenticationError
-- ^ The SOCKS server did not select the no-authentication method
| SocksProtocolError
-- ^ A SOCKS server message could not be parsed
| SocksBadDomainName
-- ^ The destination hostname is 256 bytes or longer
deriving Show
-- | 'displayException' renders each failure as a short human-readable
-- message; nested 'IOError's are rendered with their own 'displayException'.
instance Exception ConnectionFailure where
  displayException failure =
    case failure of
      HostnameResolutionFailure host reason ->
        "hostname resolution failed (" ++ host ++ "): " ++ reason
      ConnectionFailure errs ->
        "connection attempt failed due to: "
          ++ intercalate ", " (map displayException errs)
      LineTooLong ->
        "line length exceeded maximum"
      LineTruncated ->
        "connection closed while reading line"
      SocksError reply ->
        "SOCKS command rejected: " ++ describeReply reply
      SocksAuthenticationError ->
        "SOCKS authentication method rejected"
      SocksProtocolError ->
        "SOCKS server protocol error"
      SocksBadDomainName ->
        "SOCKS domain name length limit exceeded"
    where
      -- Text for each SOCKS reply code.
      describeReply reply =
        case reply of
          Succeeded         -> "succeeded"
          GeneralFailure    -> "general SOCKS server failure"
          NotAllowed        -> "connection not allowed by ruleset"
          NetUnreachable    -> "network unreachable"
          HostUnreachable   -> "host unreachable"
          ConnectionRefused -> "connection refused"
          TTLExpired        -> "TTL expired"
          CmdNotSupported   -> "command not supported"
          AddrNotSupported  -> "address type not supported"
          CommandReply n    -> "unknown reply " ++ show n
-- | Default address family: 'Socket.AF_UNSPEC', i.e. no particular family
-- is requested when resolving addresses.
defaultFamily :: Family
defaultFamily = Socket.AF_UNSPEC
-- | Default TLS settings: certificate verification enabled, the @"HIGH"@
-- cipher list, no client certificate or key, and no custom CA file.
defaultTlsParams :: TlsParams
defaultTlsParams =
  TlsParams
    { tpInsecure          = False
    , tpCipherSuite       = "HIGH"
    , tpServerCertificate = Nothing
    , tpClientCertificate = Nothing
    , tpClientPrivateKey  = Nothing
    }
-- | Open a connected socket for the given parameters.
--
-- Without SOCKS settings this connects directly to 'cpHost' / 'cpPort'.
-- With SOCKS settings it connects to the proxy and then asks the proxy to
-- connect to 'cpHost' / 'cpPort'.
--
-- The proxy socket is acquired with 'bracketOnError' so that it is closed
-- if the SOCKS negotiation fails or an asynchronous exception arrives at
-- any point after the socket has been created.
openSocket :: ConnectionParams -> IO Socket
openSocket params =
  case cpSocks params of
    Nothing -> openSocket' family (cpHost params) (cpPort params)
    Just proxy ->
      bracketOnError
        (openSocket' family (spHost proxy) (spPort proxy))
        Socket.close
        (\sock -> sock <$ socksConnect sock (cpHost params) (cpPort params))
  where
    family = cpFamily params
-- | Run a parser against bytes read from the socket.
--
-- Input is pulled from the socket one byte per refill, so the parser is
-- never handed more bytes than it asked for. Any result other than a
-- successful parse with no leftover input is reported as
-- 'SocksProtocolError'.
netParse :: Show a => Socket -> Parser a -> IO a
netParse sock parser =
  Parser.parseWith readByte parser B.empty >>= finish
  where
    readByte = SocketB.recv sock 1

    finish (Parser.Done leftover value)
      | B.null leftover = return value
    finish _ = throwIO SocksProtocolError
-- | Negotiate a SOCKS connect request on an already-connected proxy socket.
--
-- Sends a client hello offering only the no-authentication method, checks
-- the server's choice, then sends a connect request for the given domain
-- name and port and checks the reply. Throws 'SocksBadDomainName' when the
-- packed hostname is 256 bytes or longer.
socksConnect :: Socket -> HostName -> PortNumber -> IO ()
socksConnect sock host port =
  do SocketB.sendAll sock (buildClientHello hello)
     netParse sock parseServerHello >>= validateHello
     when (B.length domain >= 256) (throwIO SocksBadDomainName)
     SocketB.sendAll sock (buildRequest request)
     netParse sock parseResponse >>= validateResponse
  where
    hello = ClientHello { cHelloMethods = [AuthNoAuthenticationRequired] }

    domain = B8.pack host

    request =
      Request
        { reqCommand = Connect
        , reqAddress = Address (DomainName domain) port
        }
-- | Accept the server hello only if it selected the no-authentication
-- method; otherwise throw 'SocksAuthenticationError'.
validateHello :: ServerHello -> IO ()
validateHello hello
  | sHelloMethod hello == AuthNoAuthenticationRequired = return ()
  | otherwise = throwIO SocksAuthenticationError
-- | Accept the response only if its reply is 'Succeeded'; otherwise throw
-- 'SocksError' carrying the reply.
validateResponse :: Response -> IO ()
validateResponse response
  | reply == Succeeded = return ()
  | otherwise = throwIO (SocksError reply)
  where
    reply = rspReply response
-- | Resolve a host and port to stream-socket addresses and connect to the
-- first one that works.
--
-- A lookup failure for which 'isDoesNotExistError' holds is rethrown as
-- 'HostnameResolutionFailure'; any other 'IOError' from the lookup is
-- rethrown unchanged.
openSocket' :: Family -> HostName -> PortNumber -> IO Socket
openSocket' family h p =
  do resolved <- try (Socket.getAddrInfo (Just hints) (Just h) (Just (show p)))
     either resolutionFailed (attemptConnections []) resolved
  where
    hints =
      Socket.defaultHints
        { Socket.addrFamily     = family
        , Socket.addrSocketType = Socket.Stream
        , Socket.addrFlags      = [Socket.AI_ADDRCONFIG, Socket.AI_NUMERICSERV]
        }

    resolutionFailed ioe
      | isDoesNotExistError ioe =
          throwIO (HostnameResolutionFailure h (ioeGetErrorString ioe))
      | otherwise = throwIO ioe
-- | Try each address in turn and return the first socket that connects.
--
-- The first argument accumulates the errors of the attempts made so far,
-- most recent first. When no addresses remain, 'ConnectionFailure' is
-- thrown with the errors reversed, so they are listed in the order the
-- attempts were made.
attemptConnections ::
  [IOError] ->
  [Socket.AddrInfo] ->
  IO Socket
attemptConnections failures [] =
  throwIO (ConnectionFailure (reverse failures))
attemptConnections failures (ai:ais) =
  try (connectToAddrInfo ai) >>= either tryNext return
  where
    tryNext ex = attemptConnections (ex : failures) ais
-- | Create a socket for the address and connect it. The socket is closed
-- if connecting throws.
connectToAddrInfo :: AddrInfo -> IO Socket
connectToAddrInfo info =
  bracketOnError (socket' info) Socket.close $ \sock ->
    do Socket.connect sock (Socket.addrAddress info)
       return sock
-- | Create an unconnected socket using the family, socket type and
-- protocol recorded in the address info.
socket' :: AddrInfo -> IO Socket
socket' ai = Socket.socket family sockType protocol
  where
    family   = Socket.addrFamily ai
    sockType = Socket.addrSocketType ai
    protocol = Socket.addrProtocol ai
-- | The transport underneath a 'Connection': either a TLS session or a
-- plain socket.
data NetworkHandle = SSL SSL | Socket Socket
-- | Build a 'NetworkHandle' from a socket-producing action.
--
-- Without TLS settings the socket is used directly; with TLS settings the
-- socket action is passed to 'startTls' together with 'cpHost'.
openNetworkHandle ::
  ConnectionParams ->
  IO Socket ->
  IO NetworkHandle
openNetworkHandle params mkSocket =
  maybe plain secured (cpTls params)
  where
    plain       = Socket <$> mkSocket
    secured tls = SSL <$> startTls tls (cpHost params) mkSocket
-- | Close a network handle.
--
-- For a TLS handle a unidirectional shutdown is attempted first. The
-- underlying socket (if the session has one) is closed even when the
-- shutdown throws, so a failed TLS shutdown cannot leak the socket; the
-- shutdown exception still propagates afterwards.
closeNetworkHandle :: NetworkHandle -> IO ()
closeNetworkHandle (Socket s) = Socket.close s
closeNetworkHandle (SSL s) =
  SSL.shutdown s SSL.Unidirectional
    `finally` traverse_ Socket.close (SSL.sslSocket s)
-- | Send bytes over the handle: 'SocketB.sendAll' for a plain socket,
-- 'SSL.write' for a TLS session.
networkSend :: NetworkHandle -> ByteString -> IO ()
networkSend h bytes =
  case h of
    Socket sock -> SocketB.sendAll sock bytes
    SSL ssl     -> SSL.write ssl bytes
-- | Receive from the handle with the given size argument:
-- 'SocketB.recv' for a plain socket, 'SSL.read' for a TLS session.
networkRecv :: NetworkHandle -> Int -> IO ByteString
networkRecv h size =
  case h of
    Socket sock -> SocketB.recv sock size
    SSL ssl     -> SSL.read ssl size
-- | A connection: a buffer of bytes that were received or pushed back but
-- not yet consumed, paired with the underlying network handle.
data Connection = Connection (MVar ByteString) NetworkHandle
-- | Open a new connection according to the parameters, creating the socket
-- with 'openSocket'. The connection starts with an empty receive buffer.
connect ::
  ConnectionParams ->
  IO Connection
connect params =
  openNetworkHandle params (openSocket params) >>= \handle ->
    (`Connection` handle) <$> newMVar B.empty
-- | Build a connection on top of a socket supplied by the caller.
--
-- The socket is used as-is ('openSocket' and hence the SOCKS settings are
-- not involved); TLS is still started when 'cpTls' is set. The connection
-- starts with an empty receive buffer.
connectWithSocket ::
  ConnectionParams ->
  Socket ->
  IO Connection
connectWithSocket params sock =
  openNetworkHandle params (return sock) >>= \handle ->
    (`Connection` handle) <$> newMVar B.empty
-- | Close the connection's network handle. The receive buffer is left
-- untouched.
close ::
  Connection ->
  IO ()
close conn =
  case conn of
    Connection _ handle -> closeNetworkHandle handle
-- | Receive the next chunk of data.
--
-- If the connection's buffer is non-empty, the whole buffer is returned
-- (regardless of the requested size) and the buffer is emptied. Otherwise
-- the handle is asked for data using the given size.
recv ::
  Connection ->
  Int ->
  IO ByteString
recv (Connection buf h) n =
  swapMVar buf B.empty >>= deliver
  where
    deliver pending
      | B.null pending = networkRecv h n
      | otherwise      = return pending
-- | Receive one line, without its terminating @\\n@ and without a
-- @\\r@ immediately before it.
--
-- Bytes following the newline are kept in the connection's buffer for
-- later reads. Returns 'Nothing' when the stream ends with no pending
-- bytes. Throws 'LineTooLong' when no newline has been found and at least
-- @n@ bytes have been accumulated, and 'LineTruncated' when the stream
-- ends after some bytes of an unfinished line.
recvLine ::
  Connection ->
  Int ->
  IO (Maybe ByteString)
recvLine (Connection buf h) n =
  modifyMVar buf (\buffered -> scan (B.length buffered) buffered [])
  where
    -- total:   number of bytes accumulated so far for this line
    -- chunk:   most recently obtained chunk, not yet searched
    -- earlier: previously searched chunks, most recent first
    scan total chunk earlier =
      case B8.elemIndex '\n' chunk of
        Just i ->
          let (lineEnd, rest) = B.splitAt i chunk
              line = cleanEnd (B.concat (reverse (lineEnd : earlier)))
          in -- rest begins with the newline itself, which is dropped
             return (B.drop 1 rest, Just line)
        Nothing
          | total >= n -> throwIO LineTooLong
          | otherwise ->
              do more <- networkRecv h n
                 case (B.null more, total) of
                   (False, _) ->
                     scan (total + B.length more) more (chunk : earlier)
                   (True, 0) -> return (B.empty, Nothing)
                   (True, _) -> throwIO LineTruncated
-- | Push bytes back onto the front of the connection's buffer so that they
-- are seen before anything already buffered. The combined buffer is forced
-- before being stored.
putBuf ::
  Connection ->
  ByteString ->
  IO ()
putBuf (Connection buf _) bs =
  modifyMVar_ buf $ \old ->
    let combined = B.append bs old
    in combined `seq` return combined
-- | Drop a single trailing carriage return, if there is one. Any other
-- input, including the empty string, is returned unchanged.
cleanEnd :: ByteString -> ByteString
cleanEnd bs
  | not (B.null bs), B8.last bs == '\r' = B.init bs
  | otherwise = bs
-- | Send bytes over the connection's network handle.
send ::
  Connection ->
  ByteString ->
  IO ()
send (Connection _ h) bytes = networkSend h bytes
-- | Set up an SSL context from the TLS parameters, obtain the socket, and
-- perform the TLS handshake on it.
--
-- The context is fully configured (ciphers, verification, CA
-- certificates, optional client certificate and key) before the socket
-- action is run. Once the socket exists it is guarded with
-- 'bracketOnError', so a failure while creating the SSL connection,
-- setting the server-name extension, or handshaking closes the socket
-- instead of leaking it. This also applies to a socket handed in through
-- 'connectWithSocket'.
startTls ::
  TlsParams ->
  String ->
  IO Socket ->
  IO SSL
startTls tp hostname mkSocket = SSL.withOpenSSL $
  do ctx <- SSL.context
     SSL.contextSetCiphers ctx (tpCipherSuite tp)
     installVerification ctx hostname
     SSL.contextSetVerificationMode ctx (verificationMode (tpInsecure tp))
     SSL.contextAddOption ctx SSL.SSL_OP_ALL
     SSL.contextRemoveOption ctx SSL.SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS
     setupCaCertificates ctx (tpServerCertificate tp)
     traverse_ (setupCertificate ctx) (tpClientCertificate tp)
     traverse_ (setupPrivateKey ctx) (tpClientPrivateKey tp)

     -- The socket is acquired last so that failures in the context setup
     -- above happen before there is any socket to clean up.
     bracketOnError mkSocket Socket.close $ \sock ->
       do ssl <- SSL.connection ctx sock
          SSL.setTlsextHostName ssl hostname
          SSL.connect ssl
          return ssl
-- | Configure the trusted CA certificates: the given CA file when a path
-- is supplied, otherwise the system certificate store.
setupCaCertificates :: SSLContext -> Maybe FilePath -> IO ()
setupCaCertificates ctx mbPath =
  maybe (contextLoadSystemCerts ctx) (SSL.contextSetCAFile ctx) mbPath
-- | Read a PEM-encoded X509 certificate from the file and install it as
-- the context's certificate.
setupCertificate :: SSLContext -> FilePath -> IO ()
setupCertificate ctx path =
  do pem  <- readFile path
     cert <- PEM.readX509 pem
     SSL.contextSetCertificate ctx cert
-- | Read a PEM-encoded private key from the file (no password is supplied)
-- and install it as the context's private key.
setupPrivateKey :: SSLContext -> FilePath -> IO ()
setupPrivateKey ctx path =
  readFile path
    >>= \pem -> PEM.readPrivateKey pem PEM.PwNone
    >>= SSL.contextSetPrivateKey ctx
-- | Choose the verification mode from the insecure flag: 'SSL.VerifyNone'
-- when insecure, otherwise 'SSL.VerifyPeer' that fails when the peer
-- presents no certificate and installs no custom callback.
verificationMode :: Bool -> SSL.VerificationMode
verificationMode insecure =
  if insecure
    then SSL.VerifyNone
    else
      SSL.VerifyPeer
        { SSL.vpFailIfNoPeerCert = True
        , SSL.vpClientOnce       = True
        , SSL.vpCallback         = Nothing
        }
-- | Fetch the peer's certificate from a TLS connection. Always 'Nothing'
-- for a plain socket connection.
getPeerCertificate :: Connection -> IO (Maybe X509.X509)
getPeerCertificate (Connection _ (SSL ssl))  = SSL.getPeerCertificate ssl
getPeerCertificate (Connection _ (Socket _)) = return Nothing
-- | Digest of the peer certificate using the digest named @"sha1"@.
getPeerCertFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha1 conn = getPeerCertFingerprint "sha1" conn

-- | Digest of the peer certificate using the digest named @"sha256"@.
getPeerCertFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha256 conn = getPeerCertFingerprint "sha256" conn

-- | Digest of the peer certificate using the digest named @"sha512"@.
getPeerCertFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha512 conn = getPeerCertFingerprint "sha512" conn
-- | Digest of the DER encoding of the peer certificate, using the digest
-- looked up by name. 'Nothing' when there is no peer certificate or the
-- digest name is not found. The digest is forced before being returned.
getPeerCertFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerCertFingerprint name h =
  getPeerCertificate h >>= maybe (return Nothing) fingerprint
  where
    fingerprint x509 =
      do der <- X509.writeDerX509 x509
         mbdigest <- Digest.getDigestByName name
         case mbdigest of
           Nothing -> return Nothing
           Just digest ->
             let fp = Digest.digestLBS digest der
             in fp `seq` (return $! Just fp)
-- | Digest of the peer certificate's public key using the digest named @"sha1"@.
getPeerPubkeyFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha1 conn = getPeerPubkeyFingerprint "sha1" conn

-- | Digest of the peer certificate's public key using the digest named @"sha256"@.
getPeerPubkeyFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha256 conn = getPeerPubkeyFingerprint "sha256" conn

-- | Digest of the peer certificate's public key using the digest named @"sha512"@.
getPeerPubkeyFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha512 conn = getPeerPubkeyFingerprint "sha512" conn
-- | Digest of the DER bytes that 'getPubKeyDer' produces for the peer
-- certificate, using the digest looked up by name. 'Nothing' when there is
-- no peer certificate or the digest name is not found. The digest is
-- forced before being returned.
getPeerPubkeyFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprint name h =
  getPeerCertificate h >>= maybe (return Nothing) fingerprint
  where
    fingerprint x509 =
      do der <- getPubKeyDer x509
         mbdigest <- Digest.getDigestByName name
         case mbdigest of
           Nothing -> return Nothing
           Just digest ->
             let fp = Digest.digestBS digest der
             in fp `seq` (return $! Just fp)