{-# Language BlockArguments #-}
module Hookup
(
Connection,
connect,
connectWithSocket,
close,
upgradeTls,
recv,
recvLine,
send,
putBuf,
ConnectionParams(..),
SocksParams(..),
TlsParams(..),
PEM.PemPasswordSupply(..),
defaultTlsParams,
ConnectionFailure(..),
CommandReply(..)
, getClientCertificate
, getPeerCertificate
, getPeerCertFingerprintSha1
, getPeerCertFingerprintSha256
, getPeerCertFingerprintSha512
, getPeerPubkeyFingerprintSha1
, getPeerPubkeyFingerprintSha256
, getPeerPubkeyFingerprintSha512
) where
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Concurrent
import Control.Exception
import Control.Monad
import System.IO.Error (isDoesNotExistError, ioeGetErrorString)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import Data.Foldable
import Data.List (intercalate, partition)
import Data.Maybe (fromMaybe)
import Foreign.C.String (withCStringLen)
import Foreign.Ptr (nullPtr)
import Network.Socket (AddrInfo, HostName, PortNumber, SockAddr, Socket, Family)
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as SocketB
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL as SSL
import qualified OpenSSL.Session as SSL
import OpenSSL.X509.SystemStore
import OpenSSL.X509 (X509)
import qualified OpenSSL.X509 as X509
import qualified OpenSSL.PEM as PEM
import qualified OpenSSL.EVP.Digest as Digest
import Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as Parser
import Hookup.Concurrent (concurrentAttempts)
import Hookup.OpenSSL
import Hookup.Socks5
-- | Parameters for 'connect'.
--
-- When 'cpSocks' is set the stream is tunnelled through that SOCKS proxy;
-- when 'cpTls' is set TLS is negotiated once the stream is established.
data ConnectionParams = ConnectionParams
  { cpHost  :: HostName          -- ^ destination host name or address
  , cpPort  :: PortNumber        -- ^ destination port
  , cpSocks :: Maybe SocksParams -- ^ optional SOCKS proxy to connect through
  , cpTls   :: Maybe TlsParams   -- ^ optional TLS configuration
  , cpBind  :: Maybe HostName    -- ^ optional local address to bind the outgoing socket to
  }
-- | Address of the SOCKS proxy used by 'connect'.
data SocksParams = SocksParams
  { spHost :: HostName   -- ^ proxy host name or address
  , spPort :: PortNumber -- ^ proxy port
  }
-- | TLS settings consumed when a session is negotiated
-- ('connect', 'connectWithSocket', 'upgradeTls').
data TlsParams = TlsParams
  { tpClientCertificate        :: Maybe FilePath   -- ^ PEM client certificate to present, if any
  , tpClientPrivateKey         :: Maybe FilePath   -- ^ private key file for the client certificate
  , tpClientPrivateKeyPassword :: Maybe ByteString -- ^ password used while loading the private key
  , tpServerCertificate        :: Maybe FilePath   -- ^ CA file to trust; 'Nothing' uses the system store
  , tpCipherSuite              :: String           -- ^ OpenSSL cipher list string
  , tpInsecure                 :: Bool             -- ^ 'True' disables peer certificate verification
  }
-- | Failures raised by this module; see the 'Exception' instance for the
-- human-readable rendering of each one.
data ConnectionFailure
  = HostnameResolutionFailure HostName String -- ^ address lookup failed: host and explanation
  | ConnectionFailure [IOError]               -- ^ every connection attempt failed
  | LineTooLong                               -- ^ 'recvLine' hit its length limit without a newline
  | LineTruncated                             -- ^ stream ended part-way through a line
  | SocksError CommandReply                   -- ^ SOCKS proxy rejected the CONNECT request
  | SocksAuthenticationError                  -- ^ SOCKS proxy refused the offered auth method
  | SocksProtocolError                        -- ^ SOCKS proxy sent an unparsable reply
  | SocksBadDomainName                        -- ^ host name too long for a SOCKS request
  deriving Show
-- | Render each failure as a short, user-facing message.
instance Exception ConnectionFailure where
  displayException failure =
    case failure of
      LineTruncated -> "connection closed while reading line"
      LineTooLong   -> "line length exceeded maximum"
      ConnectionFailure errs ->
        "connection attempt failed due to: " ++
        intercalate ", " (map displayException errs)
      HostnameResolutionFailure host msg ->
        "hostname resolution failed (" ++ host ++ "): " ++ msg
      SocksAuthenticationError -> "SOCKS authentication method rejected"
      SocksProtocolError       -> "SOCKS server protocol error"
      SocksBadDomainName       -> "SOCKS domain name length limit exceeded"
      SocksError reply         -> "SOCKS command rejected: " ++ describeReply reply
    where
      -- The final CommandReply alternative is a catch-all for reply codes
      -- without a named pattern, so it must stay last.
      describeReply reply =
        case reply of
          Succeeded         -> "succeeded"
          GeneralFailure    -> "general SOCKS server failure"
          NotAllowed        -> "connection not allowed by ruleset"
          NetUnreachable    -> "network unreachable"
          HostUnreachable   -> "host unreachable"
          ConnectionRefused -> "connection refused"
          TTLExpired        -> "TTL expired"
          CmdNotSupported   -> "command not supported"
          AddrNotSupported  -> "address type not supported"
          CommandReply n    -> "unknown reply " ++ show n
-- | Baseline TLS settings: no client certificate or key, trust the system
-- CA store, OpenSSL's @HIGH@ cipher list, and peer verification enabled.
defaultTlsParams :: TlsParams
defaultTlsParams =
  TlsParams
    { tpInsecure                 = False
    , tpCipherSuite              = "HIGH"
    , tpServerCertificate        = Nothing
    , tpClientCertificate        = Nothing
    , tpClientPrivateKey         = Nothing
    , tpClientPrivateKeyPassword = Nothing
    }
-- | Open a connected stream socket for the given parameters.
--
-- Without a proxy this connects straight to 'cpHost'/'cpPort'. With a
-- proxy it connects to the proxy and then issues a SOCKS CONNECT for the
-- real destination. If that negotiation throws, the proxy socket is
-- closed before the exception propagates.
openSocket :: ConnectionParams -> IO Socket
openSocket params
  | Just proxy <- cpSocks params =
      do sock <- openSocket' (spHost proxy) (spPort proxy) bindAddr
         (socksConnect sock (cpHost params) (cpPort params) >> pure sock)
           `onException` Socket.close sock
  | otherwise = openSocket' (cpHost params) (cpPort params) bindAddr
  where
    bindAddr = cpBind params
-- | Run an attoparsec parser over bytes read from the socket.
--
-- Input is requested one byte at a time, presumably so that nothing
-- beyond the end of the parsed message is taken off the socket. Throws
-- 'SocksProtocolError' if the parse fails or leaves unconsumed input.
netParse :: Show a => Socket -> Parser a -> IO a
netParse sock parser =
  do let refill = SocketB.recv sock 1
     outcome <- Parser.parseWith refill parser B.empty
     case outcome of
       Parser.Done leftover value
         | B.null leftover -> pure value
       _ -> throwIO SocksProtocolError
-- | Perform a SOCKS CONNECT to @host@:@port@ over a socket that is
-- already connected to the proxy.
--
-- Only the "no authentication required" method is offered, and the target
-- is sent as a domain name. 'B8.pack' keeps only the low 8 bits of each
-- character, so the host is assumed to be ASCII (NOTE(review): confirm
-- callers IDNA-encode). Throws 'SocksAuthenticationError',
-- 'SocksBadDomainName' (encoded name of 256 bytes or more),
-- 'SocksError' or 'SocksProtocolError'.
socksConnect :: Socket -> HostName -> PortNumber -> IO ()
socksConnect sock host port =
  do sendMessage (buildClientHello hello)
     validateHello =<< netParse sock parseServerHello
     when (B.length hostBytes >= 256) (throwIO SocksBadDomainName)
     sendMessage (buildRequest request)
     validateResponse =<< netParse sock parseResponse
  where
    sendMessage = SocketB.sendAll sock
    hostBytes   = B8.pack host
    hello       = ClientHello { cHelloMethods = [AuthNoAuthenticationRequired] }
    request     =
      Request
        { reqCommand = Connect
        , reqAddress = Address (DomainName hostBytes) port
        }
-- | Accept the proxy's method choice only if it is "no authentication
-- required"; otherwise throw 'SocksAuthenticationError'.
validateHello :: ServerHello -> IO ()
validateHello hello
  | sHelloMethod hello == AuthNoAuthenticationRequired = pure ()
  | otherwise = throwIO SocksAuthenticationError
-- | Accept a CONNECT response only if its reply code is 'Succeeded';
-- otherwise throw 'SocksError' carrying the reply code.
validateResponse :: Response -> IO ()
validateResponse response
  | reply == Succeeded = pure ()
  | otherwise = throwIO (SocksError reply)
  where
    reply = rspReply response
-- | Resolve the destination (and optional bind address) and connect to
-- one of the candidate addresses using 'concurrentAttempts' with a
-- 'connAttemptDelay' stagger (see "Hookup.Concurrent").
--
-- Each destination address is paired with a bind address of the same
-- family when a bind host is given, and IPv6 candidates are interleaved
-- with the others. Throws 'HostnameResolutionFailure' when no
-- source/destination family matches. If no attempt succeeds it throws
-- 'ConnectionFailure' carrying the 'IOError's from the failed attempts;
-- exceptions of other types are not included in that list.
openSocket' ::
  HostName       {- ^ destination host          -} ->
  PortNumber     {- ^ destination port          -} ->
  Maybe HostName {- ^ optional local bind host  -} ->
  IO Socket
openSocket' h p mbBind =
  do srcAddrs <- traverse (resolve Nothing) mbBind
     dstAddrs <- resolve (Just p) h
     let candidates = interleaveAddressFamilies (matchBindAddrs srcAddrs dstAddrs)
         attempts   = map (uncurry connectToAddrInfo) candidates
     when (null candidates) $
       throwIO (HostnameResolutionFailure h "No source/destination address family match")
     outcome <- concurrentAttempts connAttemptDelay Socket.close attempts
     either (throwIO . ConnectionFailure . ioErrorsOf) pure outcome
  where
    ioErrorsOf errs = [ioe | Just ioe <- map fromException errs]
-- | Lookup hints for 'resolve'. Only stream sockets are requested, and the
-- service is always given as a numeric port ('Socket.AI_NUMERICSERV').
hints :: AddrInfo
hints =
  Socket.defaultHints
    { Socket.addrFlags      = [Socket.AI_NUMERICSERV]
    , Socket.addrSocketType = Socket.Stream
    }
-- | Look up the addresses of a host, with an optional numeric port.
--
-- A "does not exist" lookup error is rethrown as
-- 'HostnameResolutionFailure' carrying the resolver's message. Any other
-- 'IOError' is rethrown unchanged.
resolve :: Maybe PortNumber -> HostName -> IO [AddrInfo]
resolve mbPort host =
  Socket.getAddrInfo (Just hints) (Just host) (fmap show mbPort)
    `catch` \ioe ->
      if isDoesNotExistError ioe
        then throwIO (HostnameResolutionFailure host (ioeGetErrorString ioe))
        else throwIO ioe
-- | Pair each destination with the address to bind before connecting.
--
-- With no bind addresses every destination is paired with 'Nothing'.
-- Otherwise each destination is paired with the first bind address of the
-- same family, and destinations with no such bind address are dropped.
matchBindAddrs :: Maybe [AddrInfo] -> [AddrInfo] -> [(Maybe SockAddr, AddrInfo)]
matchBindAddrs mbSrc dst =
  case mbSrc of
    Nothing  -> map ((,) Nothing) dst
    Just src -> concatMap (pairWith src) dst
  where
    pairWith src d =
      maybe [] (\s -> [(Just (Socket.addrAddress s), d)])
            (find (\s -> Socket.addrFamily d == Socket.addrFamily s) src)
-- | Stagger between successive connection attempts, passed to
-- 'concurrentAttempts'. The value 150000 is presumably in microseconds
-- (150 ms); TODO confirm against "Hookup.Concurrent".
connAttemptDelay :: Int
connAttemptDelay = 150000
-- | Reorder candidates so IPv6 and non-IPv6 destinations alternate,
-- starting with IPv6. When one family runs out, the remainder of the
-- other follows in its original order.
interleaveAddressFamilies :: [(Maybe SockAddr, AddrInfo)] -> [(Maybe SockAddr, AddrInfo)]
interleaveAddressFamilies candidates = alternate v6 notV6
  where
    (v6, notV6) = partition (isIPv6 . snd) candidates
    isIPv6 ai = Socket.addrFamily ai == Socket.AF_INET6
    alternate (l:ls) (r:rs) = l : r : alternate ls rs
    alternate []     rs     = rs
    alternate ls     []     = ls
-- | Create a socket for the address, optionally bind it, then connect.
-- The socket is closed if binding or connecting throws.
connectToAddrInfo :: Maybe SockAddr -> AddrInfo -> IO Socket
connectToAddrInfo mbSrc info =
  bracketOnError (socket' info) Socket.close connectIt
  where
    connectIt sock =
      do for_ mbSrc (bind' sock)
         Socket.connect sock (Socket.addrAddress info)
         pure sock
-- | Bind the socket to the address unless the host part is the all-zero
-- wildcard (IPv4 0 or IPv6 ::). Binding is then skipped entirely,
-- whatever port is given.
bind' :: Socket -> SockAddr -> IO ()
bind' sock addr
  | isWildcard addr = pure ()
  | otherwise       = Socket.bind sock addr
  where
    isWildcard (Socket.SockAddrInet _ 0)              = True
    isWildcard (Socket.SockAddrInet6 _ _ (0,0,0,0) _) = True
    isWildcard _                                      = False
-- | Create an unconnected socket using the family, socket type and
-- protocol recorded in the 'AddrInfo'.
socket' :: AddrInfo -> IO Socket
socket' ai = Socket.socket family sockType proto
  where
    family   = Socket.addrFamily ai
    sockType = Socket.addrSocketType ai
    proto    = Socket.addrProtocol ai
-- | The transport underneath a 'Connection'.
data NetworkHandle
  = SSL (Maybe X509) SSL -- ^ TLS session, with the client certificate loaded for it (if any)
  | Socket Socket        -- ^ plain, unencrypted socket
-- | Build the transport for a connection. The socket action is run once,
-- and when TLS parameters are present a TLS session is negotiated over
-- the resulting socket using 'cpHost' as the expected server name.
openNetworkHandle ::
  ConnectionParams {- ^ connection parameters  -} ->
  IO Socket        {- ^ action yielding socket -} ->
  IO NetworkHandle
openNetworkHandle params mkSocket =
  maybe plain secure (cpTls params)
  where
    plain = Socket <$> mkSocket
    secure tls = uncurry SSL <$> startTls tls (cpHost params) mkSocket
-- | Close the transport underneath a 'NetworkHandle'.
--
-- For a TLS handle a unidirectional TLS shutdown is attempted first. The
-- underlying socket is closed even if that shutdown throws (for example
-- because the peer already dropped the connection), so the file
-- descriptor is never leaked. The shutdown exception is still rethrown
-- afterwards.
closeNetworkHandle :: NetworkHandle -> IO ()
closeNetworkHandle (Socket s) = Socket.close s
closeNetworkHandle (SSL _ s) =
  SSL.shutdown s SSL.Unidirectional
    `finally` traverse_ Socket.close (SSL.sslSocket s)
-- | Write the whole byte string to the transport: 'SocketB.sendAll' for
-- plain sockets, 'SSL.write' for TLS sessions.
networkSend :: NetworkHandle -> ByteString -> IO ()
networkSend handle' bytes =
  case handle' of
    Socket s -> SocketB.sendAll s bytes
    SSL _ s  -> SSL.write s bytes
-- | Read at most the requested number of bytes from the transport.
-- Callers in this module treat an empty result as end of stream.
networkRecv :: NetworkHandle -> Int -> IO ByteString
networkRecv handle' limit =
  case handle' of
    Socket s -> SocketB.recv s limit
    SSL _ s  -> SSL.read s limit
-- | An open stream connection, either plain or TLS-protected.
--
-- The first field holds bytes received but not yet returned to the
-- caller: bytes pushed back by 'putBuf' or left after a line by
-- 'recvLine'. The second holds the transport, which 'upgradeTls' may
-- replace.
data Connection =
  Connection
    {-# UNPACK #-} !(MVar ByteString)
    {-# UNPACK #-} !(MVar NetworkHandle)
-- | Open a connection described by the parameters, optionally through a
-- SOCKS proxy and optionally negotiating TLS.
--
-- If TLS negotiation fails (certificate verification, handshake error,
-- ...), the already-connected socket is closed before the exception
-- propagates, so a failed attempt does not leak a file descriptor.
-- Failures while opening the socket itself are already cleaned up by
-- 'openSocket'.
connect ::
  ConnectionParams {- ^ parameters -} ->
  IO Connection
connect params =
  do opened <- newEmptyMVar
     let -- remember the socket so it can be closed if TLS setup throws
         mkSocket =
           do sock <- openSocket params
              putMVar opened sock
              pure sock
         cleanup = traverse_ Socket.close =<< tryTakeMVar opened
     h <- openNetworkHandle params mkSocket `onException` cleanup
     Connection <$> newMVar B.empty <*> newMVar h
-- | Wrap an already-connected socket in a 'Connection', negotiating TLS
-- over it when 'cpTls' is set. 'cpSocks' and 'cpBind' are not consulted
-- here, since the socket is used as given.
connectWithSocket ::
  ConnectionParams {- ^ parameters       -} ->
  Socket           {- ^ connected socket -} ->
  IO Connection
connectWithSocket params sock =
  do h      <- openNetworkHandle params (pure sock)
     bufVar <- newMVar B.empty
     hVar   <- newMVar h
     pure (Connection bufVar hVar)
-- | Close the connection's transport while holding the transport lock.
close ::
  Connection {- ^ open connection -} ->
  IO ()
close (Connection _ hVar) = withMVar hVar closeNetworkHandle
-- | Return the next chunk of input.
--
-- If buffered bytes exist, the entire buffer is returned (regardless of
-- the size argument) and cleared. Otherwise up to the given number of
-- bytes are read from the transport. An empty result signals end of
-- stream.
recv ::
  Connection    {- ^ open connection      -} ->
  Int           {- ^ maximum read size    -} ->
  IO ByteString {- ^ next chunk of input  -}
recv (Connection bufVar hVar) n =
  modifyMVar bufVar \pending ->
    if not (B.null pending)
      then pure (B.empty, pending)
      else do
        h     <- readMVar hVar
        chunk <- networkRecv h n
        pure (B.empty, chunk)
-- | Receive one newline-terminated line, without the @\\n@ and without a
-- single trailing @\\r@ if present. Bytes after the newline are kept in
-- the buffer for the next read.
--
-- Returns 'Nothing' when the stream ends with nothing pending. Throws
-- 'LineTruncated' when the stream ends part-way through a line, and
-- 'LineTooLong' when the given limit or more bytes are pending without a
-- newline. The limit is only checked before each additional read, so a
-- returned line can be somewhat longer than the limit. On an exception
-- the buffer keeps its previous contents.
recvLine ::
  Connection            {- ^ open connection          -} ->
  Int                   {- ^ line length limit / read size -} ->
  IO (Maybe ByteString) {- ^ next line, or end of stream   -}
recvLine (Connection bufVar hVar) n =
  modifyMVar bufVar \initial ->
    do nh <- readMVar hVar
       let -- total:   combined length of current and all older chunks
           -- current: newest chunk; the only one not yet searched for '\n'
           -- older:   earlier chunks, newest first
           loop total current older =
             case B8.elemIndex '\n' current of
               Just i ->
                 let (lineEnd, rest) = B.splitAt i current
                     line = cleanEnd (B.concat (reverse (lineEnd : older)))
                 in pure (B.drop 1 rest, Just line)
               Nothing
                 | total >= n -> throwIO LineTooLong
                 | otherwise ->
                     do more <- networkRecv nh n
                        if not (B.null more)
                          then loop (total + B.length more) more (current : older)
                          else if total == 0
                            then pure (B.empty, Nothing)
                            else throwIO LineTruncated
       loop (B.length initial) initial []
-- | Push bytes back onto the front of the receive buffer so the next
-- 'recv' or 'recvLine' sees them before any data already buffered.
putBuf ::
  Connection {- ^ open connection -} ->
  ByteString {- ^ bytes to prepend -} ->
  IO ()
putBuf (Connection bufVar _) bs =
  modifyMVar_ bufVar \old ->
    let combined = bs <> old
    in combined `seq` pure combined
-- | Remove one trailing carriage return, if present; otherwise return
-- the input unchanged.
cleanEnd :: ByteString -> ByteString
cleanEnd bs = fromMaybe bs (B.stripSuffix (B8.pack "\r") bs)
-- | Send the whole byte string over the connection's current transport.
send ::
  Connection {- ^ open connection -} ->
  ByteString {- ^ bytes to send   -} ->
  IO ()
send (Connection _ hVar) bs =
  readMVar hVar >>= \h -> networkSend h bs
-- | Negotiate TLS over a connection that is currently a plain socket,
-- e.g. after a STARTTLS-style exchange. It does nothing if the connection
-- is already TLS.
--
-- On a successful upgrade any previously buffered plaintext is discarded,
-- presumably so that bytes received before the handshake cannot be read
-- as if they had been protected. Both the buffer and transport locks are
-- held for the duration.
upgradeTls ::
  TlsParams  {- ^ TLS settings            -} ->
  String     {- ^ expected server hostname -} ->
  Connection {- ^ open connection          -} ->
  IO ()
upgradeTls tp hostname (Connection bufVar hVar) =
  modifyMVar_ bufVar \buf ->
    modifyMVar hVar \h ->
      case h of
        Socket s ->
          do (cert, ssl) <- startTls tp hostname (pure s)
             pure (SSL cert ssl, B.empty)
        SSL{} -> pure (h, buf)
-- | Negotiate a TLS client session over the socket produced by the given
-- action.
--
-- A fresh 'SSLContext' is configured from the 'TlsParams' (cipher list,
-- verification hook and mode, options, CA certificates, optional client
-- certificate and private key). The socket is then obtained, SNI is set
-- to the given hostname, and the handshake is run. Returns the client
-- certificate that was loaded (if any) together with the connected
-- 'SSL' session. Exceptions from any step propagate. This function does
-- not itself close the socket when the handshake fails.
startTls ::
  TlsParams   {- ^ TLS settings                      -} ->
  String      {- ^ server hostname (SNI/verification) -} ->
  IO Socket   {- ^ action yielding the socket        -} ->
  IO (Maybe X509, SSL)
startTls tp hostname mkSocket = SSL.withOpenSSL $
  do ctx <- SSL.context
     SSL.contextSetCiphers ctx (tpCipherSuite tp)
     -- NOTE(review): presumably installs hostname verification against
     -- 'hostname' (defined in Hookup.OpenSSL) — confirm.
     installVerification ctx hostname
     SSL.contextSetVerificationMode ctx (verificationMode (tpInsecure tp))
     -- Enable OpenSSL's bundle of interoperability workarounds, but remove
     -- DONT_INSERT_EMPTY_FRAGMENTS from it so the empty-fragment
     -- countermeasure stays enabled (per OpenSSL SSL_CTX_set_options docs).
     SSL.contextAddOption ctx SSL.SSL_OP_ALL
     SSL.contextRemoveOption ctx SSL.SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS
     setupCaCertificates ctx (tpServerCertificate tp)
     clientCert <- traverse (setupCertificate ctx) (tpClientCertificate tp)
     for_ (tpClientPrivateKey tp) $ \path ->
       withDefaultPassword ctx (tpClientPrivateKeyPassword tp) $
         SSL.contextSetPrivateKeyFile ctx path
     -- The socket is only obtained after all context setup above, so
     -- failures while loading certificates or keys happen before any
     -- socket is requested.
     ssl <- SSL.connection ctx =<< mkSocket
     -- Server Name Indication for the handshake.
     SSL.setTlsextHostName ssl hostname
     SSL.connect ssl
     return (clientCert, ssl)
-- | Configure the trusted CA certificates: the system store when no path
-- is given, otherwise the given CA file (loaded via 'withDefaultPassword'
-- with no password).
setupCaCertificates :: SSLContext -> Maybe FilePath -> IO ()
setupCaCertificates ctx =
  maybe (contextLoadSystemCerts ctx) loadCaFile
  where
    loadCaFile path = withDefaultPassword ctx Nothing (SSL.contextSetCAFile ctx path)
-- | Read a PEM certificate from the file, install it as the context's
-- client certificate, and return it.
setupCertificate :: SSLContext -> FilePath -> IO X509
setupCertificate ctx path =
  do pemText <- readFile path
     cert    <- PEM.readX509 pemText
     SSL.contextSetCertificate ctx cert
     pure cert
-- | Peer verification policy. Insecure mode disables verification.
-- Otherwise the peer must present a certificate and no custom callback
-- is installed.
verificationMode :: Bool -> SSL.VerificationMode
verificationMode insecure =
  if insecure
    then SSL.VerifyNone
    else SSL.VerifyPeer
           { SSL.vpCallback         = Nothing
           , SSL.vpClientOnce       = True
           , SSL.vpFailIfNoPeerCert = True
           }
-- | Certificate presented by the peer, or 'Nothing' for a plain
-- connection. The transport lock is held while querying the session.
getPeerCertificate :: Connection -> IO (Maybe X509.X509)
getPeerCertificate (Connection _ hVar) =
  withMVar hVar \h ->
    case h of
      SSL _ ssl -> SSL.getPeerCertificate ssl
      Socket{}  -> pure Nothing
-- | Client certificate that was loaded for this connection's TLS session,
-- or 'Nothing' for a plain connection or when none was configured.
getClientCertificate :: Connection -> IO (Maybe X509.X509)
getClientCertificate (Connection _ hVar) = clientCertOf <$> readMVar hVar
  where
    clientCertOf (SSL cert _) = cert
    clientCertOf (Socket _)   = Nothing
-- | SHA-1 digest of the peer certificate's DER encoding.
getPeerCertFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha1 = getPeerCertFingerprint "sha1"

-- | SHA-256 digest of the peer certificate's DER encoding.
getPeerCertFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha256 = getPeerCertFingerprint "sha256"

-- | SHA-512 digest of the peer certificate's DER encoding.
getPeerCertFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha512 = getPeerCertFingerprint "sha512"
-- | Digest, using the named OpenSSL digest, of the DER encoding of the
-- peer certificate. Returns 'Nothing' for a plain connection, when the
-- peer sent no certificate, or when the digest name is unknown. The digest
-- is forced before it is returned.
getPeerCertFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerCertFingerprint name h =
  getPeerCertificate h >>= \mbCert ->
    case mbCert of
      Nothing   -> pure Nothing
      Just cert ->
        do der      <- X509.writeDerX509 cert
           mbDigest <- Digest.getDigestByName name
           traverse (\digest -> evaluate (Digest.digestLBS digest der)) mbDigest
-- | SHA-1 digest of the peer's public key, as encoded by 'getPubKeyDer'.
getPeerPubkeyFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha1 = getPeerPubkeyFingerprint "sha1"

-- | SHA-256 digest of the peer's public key, as encoded by 'getPubKeyDer'.
getPeerPubkeyFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha256 = getPeerPubkeyFingerprint "sha256"

-- | SHA-512 digest of the peer's public key, as encoded by 'getPubKeyDer'.
getPeerPubkeyFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha512 = getPeerPubkeyFingerprint "sha512"
-- | Digest, using the named OpenSSL digest, of the peer certificate's
-- public key as encoded by 'getPubKeyDer' (from "Hookup.OpenSSL").
-- Returns 'Nothing' for a plain connection, when the peer sent no
-- certificate, or when the digest name is unknown. The digest is forced
-- before it is returned.
getPeerPubkeyFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprint name h =
  getPeerCertificate h >>= \mbCert ->
    case mbCert of
      Nothing   -> pure Nothing
      Just cert ->
        do der      <- getPubKeyDer cert
           mbDigest <- Digest.getDigestByName name
           traverse (\digest -> evaluate (Digest.digestBS digest der)) mbDigest