module Network.HTTP.Conduit.Manager
( Manager
, ManagerSettings (..)
, ConnKey (..)
, newManager
, closeManager
, getConn
, ConnReuse (..)
, withManager
, ConnRelease
, ManagedConn (..)
, defaultCheckCerts
) where
import Prelude hiding (catch)
import Data.Monoid (mappend)
import System.IO (hClose, hFlush, IOMode(..))
import qualified Data.IORef as I
import qualified Data.Map as Map
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L
import qualified Blaze.ByteString.Builder as Blaze
import Data.Text (Text)
import qualified Data.Text as T
import Control.Monad.Base (liftBase)
import Control.Exception.Lifted (mask)
import Control.Exception (mask_, SomeException, catch)
import Control.Monad.Trans.Resource
( ResourceT, runResourceT, ResourceIO, withIO
, register, release
, newRef, readRef', writeRef
, safeFromIOBase
)
import Control.Concurrent (forkIO, threadDelay)
import Data.Time (UTCTime, getCurrentTime, addUTCTime)
import Network (connectTo, PortID (PortNumber), HostName)
import Network.Socket (socketToHandle)
import Data.Certificate.X509 (X509, encodeCertificate)
import qualified Network.HTTP.Types as W
import Network.TLS.Extra (certificateVerifyChain, certificateVerifyDomain)
import Network.HTTP.Conduit.ConnInfo
import Network.HTTP.Conduit.Util (hGetSome)
import Network.HTTP.Conduit.Parser (parserHeadersFromByteString)
import Network.HTTP.Conduit.Request
import Network.Socks5 (SocksConf, socksConnectWith)
import Data.Default
import Data.Maybe (mapMaybe)
import System.IO (Handle)
-- | Settings consumed by 'newManager' when building a 'Manager'.
data ManagerSettings = ManagerSettings
{ managerConnCount :: Int
-- ^ Upper bound on the number of idle connections pooled per 'ConnKey'
-- (it is copied into the manager and passed to 'addToList' as the
-- maximum count). The 'Default' instance sets it to 10.
, managerCheckCerts :: W.Ascii -> [X509] -> IO TLSCertificateUsage
-- ^ Certificate check, called with the host name and the certificate
-- chain whenever 'checkCerts' finds no cached entry for that pair.
}
-- | Encoded form of a certificate chain, used as the inner key of the
-- certificate cache. 'checkCerts' builds it by concatenating
-- 'encodeCertificate' applied to every certificate in the chain.
type X509Encoded = L.ByteString
-- | Default settings: keep up to 10 idle connections per key and check
-- certificates with 'defaultCheckCerts'.
instance Default ManagerSettings where
    -- Fields in declaration order: connection count, certificate check.
    def = ManagerSettings 10 defaultCheckCerts
-- | Default certificate check.
--
-- First verifies that the chain matches the host name; only when that check
-- accepts is the chain itself verified. Any other domain verdict is returned
-- unchanged without running the chain verification.
defaultCheckCerts :: W.Ascii -> [X509] -> IO TLSCertificateUsage
defaultCheckCerts host' certs
    | CertificateUsageAccept <- domainVerdict = certificateVerifyChain certs
    | otherwise = return domainVerdict
  where
    domainVerdict = certificateVerifyDomain (S8.unpack host') certs
-- | Keeps track of idle connections so they can be reused, together with a
-- cache of certificate chains that have already been accepted.
data Manager = Manager
{ mConns :: !(I.IORef (Maybe (Map.Map ConnKey (NonEmptyList ConnInfo))))
-- ^ Idle connections grouped by key. 'closeManager' replaces the contents
-- with 'Nothing'; from then on 'takeSocket' finds nothing and 'putSocket'
-- closes whatever it is handed.
, mMaxConns :: !Int
-- ^ Maximum number of idle connections kept per key.
, mCheckCerts :: W.Ascii -> [X509] -> IO TLSCertificateUsage
-- ^ Certificate check run by 'checkCerts' on a cache miss.
, mCertCache :: !(I.IORef (Map.Map W.Ascii (Map.Map X509Encoded UTCTime)))
-- ^ Accepted certificate chains: host name, then encoded chain, mapped to
-- the time at which the entry expires.
}
-- | A list holding at least one element, where every element carries the
-- time at which it was added.
--
-- * @One x t@ is the single-element list.
--
-- * @Cons x n t rest@ puts @x@ in front of @rest@; @n@ is the number of
--   elements in the list starting at this cell ('addToList' stores 2 when
--   consing onto a 'One' and @n + 1@ when consing onto a 'Cons' whose count
--   is @n@).
data NonEmptyList a =
One !a !UTCTime |
Cons !a !Int !UTCTime !(NonEmptyList a)
-- | Key under which idle connections are pooled. The connection getters in
-- this module fill it with the host name the socket connects to, the port,
-- and a flag that is 'True' for TLS connections.
data ConnKey = ConnKey !Text !Int !Bool
deriving (Eq, Show, Ord)
-- | Atomically remove one idle connection for the given key from the pool.
--
-- Returns 'Nothing' when the manager has been closed or when nothing is
-- pooled under that key. The most recently added connection is handed out
-- first.
takeSocket :: Manager -> ConnKey -> IO (Maybe ConnInfo)
takeSocket man key = I.atomicModifyIORef (mConns man) pop
  where
    -- A closed manager stays closed and never yields a connection.
    pop Nothing = (Nothing, Nothing)
    pop (Just pool) = maybe (Just pool, Nothing) detach (Map.lookup key pool)
      where
        -- Last connection for this key: drop the map entry altogether.
        detach (One conn _) = (Just (Map.delete key pool), Just conn)
        -- Otherwise keep the remainder of the list under the same key.
        detach (Cons conn _ _ others) =
            (Just (Map.insert key others pool), Just conn)
-- | Return a connection to the pool so that a later request can reuse it.
--
-- The connection is closed instead of being pooled when the manager has
-- already been closed, or when the pool for this key is already at the
-- manager's per-key limit. The limit is also honoured for the very first
-- connection stored under a key, so a limit of zero (or less) disables
-- pooling entirely.
putSocket :: Manager -> ConnKey -> ConnInfo -> IO ()
putSocket man key ci = do
    now <- getCurrentTime
    msock <- I.atomicModifyIORef (mConns man) (go now)
    -- Whatever the pool rejected is closed outside the atomic update.
    maybe (return ()) connClose msock
  where
    maxConns = mMaxConns man

    -- Manager closed: nothing may be pooled any more.
    go _ Nothing = (Nothing, Just ci)
    go now (Just m) =
        case Map.lookup key m of
            Nothing
                -- First connection for this key; only keep it if the limit
                -- allows at least one pooled connection.
                | maxConns >= 1 ->
                    (Just $ Map.insert key (One ci now) m, Nothing)
                | otherwise -> (Just m, Just ci)
            Just l ->
                let (l', mx) = addToList now maxConns ci l
                 in (Just $ Map.insert key l' m, mx)
-- | Try to put an element at the front of a 'NonEmptyList' without exceeding
-- the given maximum length.
--
-- Returns the (possibly extended) list, plus the element itself when it was
-- rejected because the list is already full. A maximum of one or less always
-- rejects, since the list already holds at least one element.
addToList :: UTCTime -> Int -> a -> NonEmptyList a -> (NonEmptyList a, Maybe a)
addToList now maxCount x l
    | maxCount <= 1 = rejected
    | otherwise =
        case l of
            One{} -> accepted 2
            Cons _ currCount _ _
                | currCount < maxCount -> accepted (currCount + 1)
                | otherwise -> rejected
  where
    rejected = (l, Just x)
    -- The new head records the resulting length of the list.
    accepted n = (Cons x n now l, Nothing)
-- | Create a new 'Manager' with an empty connection pool and an empty
-- certificate cache, and fork the background 'reap' thread that works on
-- both of them.
newManager :: ManagerSettings -> IO Manager
newManager settings = do
    connsRef <- I.newIORef (Just Map.empty)
    certsRef <- I.newIORef Map.empty
    _ <- forkIO (reap connsRef certsRef)
    return Manager
        { mConns = connsRef
        , mMaxConns = managerConnCount settings
        , mCheckCerts = managerCheckCerts settings
        , mCertCache = certsRef
        }
-- | Background loop that evicts idle connections and expired entries of the
-- certificate cache.
--
-- Every five seconds the loop:
--
-- 1. removes from the pool every connection that was added more than 30
--    seconds ago and closes it (errors while closing are ignored);
--
-- 2. drops certificate cache entries whose expiry time has passed, keeping
--    at most 10 entries per host.
--
-- The loop ends once the connection map holds 'Nothing', i.e. after the
-- manager has been closed.
--
-- The certificate flush is performed /before/ the recursive call, so it runs
-- on every cycle and the recursion stays in tail position.
reap :: I.IORef (Maybe (Map.Map ConnKey (NonEmptyList ConnInfo)))
     -> I.IORef (Map.Map W.Ascii (Map.Map X509Encoded UTCTime))
     -> IO ()
reap mapRef certCacheRef =
    mask_ loop
  where
    loop = do
        threadDelay (5 * 1000 * 1000)
        now <- getCurrentTime
        let isNotStale time = 30 `addUTCTime` time >= now
        mtoDestroy <- I.atomicModifyIORef mapRef (findStaleWrap isNotStale)
        case mtoDestroy of
            -- The manager has been closed: stop reaping.
            Nothing -> return ()
            Just toDestroy -> do
                mapM_ safeConnClose toDestroy
                -- 'atomicModifyIORef' is lazy, so hand the new cache back
                -- and force it here; otherwise unevaluated flushes would
                -- pile up inside the reference.
                flushed <- I.atomicModifyIORef certCacheRef $ \x ->
                    let x' = flushStaleCerts now x
                     in (x', x')
                flushed `seq` loop

    -- Lift 'findStale' over the "manager closed" marker.
    findStaleWrap _ Nothing = (Nothing, Nothing)
    findStaleWrap isNotStale (Just m) =
        let (x, y) = findStale isNotStale m
         in (Just x, Just y)

    -- Split the pool into the map of connections to keep and the list of
    -- connections to close. Both results are accumulated as difference
    -- lists while walking over the map entries.
    findStale isNotStale =
        findStale' id id . Map.toList
      where
        findStale' destroy keep [] = (Map.fromList $ keep [], destroy [])
        findStale' destroy keep ((connkey, nelist):rest) =
            findStale' destroy' keep' rest
          where
            -- Entries are ordered newest first, so everything after the
            -- first stale entry is treated as stale as well.
            (notStale, stale) = span (isNotStale . fst) $ neToList nelist
            destroy' = destroy . (map snd stale++)
            keep' =
                case neFromList notStale of
                    -- Nothing left for this key: drop the key entirely.
                    Nothing -> keep
                    Just x -> keep . ((connkey, x):)

    -- Remove expired certificate entries; hosts left without any entry are
    -- removed, and at most 10 entries are retained per host.
    flushStaleCerts now =
        Map.fromList . mapMaybe flushStaleCerts' . Map.toList
      where
        flushStaleCerts' (host', inner) =
            case mapMaybe flushStaleCerts'' $ Map.toList inner of
                [] -> Nothing
                pairs -> Just (host', Map.fromList $ take 10 pairs)
        flushStaleCerts'' (certs, expires)
            | expires > now = Just (certs, expires)
            | otherwise = Nothing
-- | Flatten a 'NonEmptyList' into ordinary @(timestamp, element)@ pairs,
-- front element first. The length stored in 'Cons' cells is discarded.
neToList :: NonEmptyList a -> [(UTCTime, a)]
neToList nel =
    case nel of
        One x stamp -> [(stamp, x)]
        Cons x _ stamp more -> (stamp, x) : neToList more
-- | Rebuild a 'NonEmptyList' from @(timestamp, element)@ pairs, keeping their
-- order. Yields 'Nothing' for the empty list.
--
-- Every 'Cons' cell is given the number of elements from that cell to the
-- end of the list, matching what 'addToList' maintains.
neFromList :: [(UTCTime, a)] -> Maybe (NonEmptyList a)
neFromList entries =
    case entries of
        [] -> Nothing
        [(t, a)] -> Just (One a t)
        _ -> Just (snd (build entries))
  where
    -- Returns the count to store in the cell placed in front of the result,
    -- paired with the list built so far.
    build remaining =
        case remaining of
            -- Unreachable: 'build' is only entered with a non-empty list.
            [] -> error "neFromList.go []"
            [(t, a)] -> (2, One a t)
            (t, a) : older ->
                let (count, olderList) = build older
                    next = count + 1
                 in next `seq` (next, Cons a count t olderList)
-- | Run an action with a freshly created 'Manager' built from the default
-- settings. The manager is registered with the enclosing 'ResourceT' block
-- with 'closeManager' as its cleanup action.
withManager :: ResourceIO m => (Manager -> ResourceT m a) -> m a
withManager f =
    runResourceT $
        withIO (newManager def) closeManager >>= \(_, manager) -> f manager
-- | Close the manager: atomically mark the pool as closed and close every
-- connection that was idle in it. Errors raised while closing individual
-- connections are ignored. Closing an already closed manager does nothing.
closeManager :: Manager -> IO ()
closeManager manager = mask_ $ do
    -- Swap in the "closed" marker and take whatever was pooled before.
    previous <- I.atomicModifyIORef (mConns manager) ((,) Nothing)
    case previous of
        Nothing -> return ()
        Just pool -> mapM_ (nonEmptyMapM_ safeConnClose) (Map.elems pool)
-- | Close a connection, deliberately discarding any exception raised while
-- doing so (best-effort cleanup).
safeConnClose :: ConnInfo -> IO ()
safeConnClose ci = catch (connClose ci) ignore
  where
    ignore :: SomeException -> IO ()
    ignore _ = return ()
-- | Run an action on every element of a 'NonEmptyList', front to back,
-- discarding the results.
nonEmptyMapM_ :: Monad m => (a -> m ()) -> NonEmptyList a -> m ()
nonEmptyMapM_ f = visit
  where
    visit (One x _) = f x
    visit (Cons x _ _ remaining) = f x >> visit remaining
-- | Obtain a managed plain (non-TLS) connection to the given host and port,
-- optionally going through a SOCKS proxy. An idle pooled connection is
-- reused when one exists; otherwise a new socket is opened.
getSocketConn
    :: ResourceIO m
    => Manager
    -> String
    -> Int
    -> Maybe SocksConf
    -> ResourceT m (ConnRelease m, ConnInfo, ManagedConn)
getSocketConn man host' port' socksProxy' = getManagedConn man key open
  where
    key = ConnKey (T.pack host') port' False
    -- Only run when the pool has nothing to offer for this key.
    open = do
        sock <- getSocket host' port' socksProxy'
        socketConn (socketDesc host' port' "unsecured") sock
-- | Build a description of a connection from its host, port and a tag,
-- separated by single spaces.
socketDesc :: String -> Int -> String -> String
socketDesc h p t = h ++ " " ++ show p ++ " " ++ t
-- | Obtain a managed TLS connection to the given host and port, optionally
-- going through a SOCKS proxy. The supplied function is used to check the
-- certificates when a new connection has to be set up.
getSslConn :: ResourceIO m
            => ([X509] -> IO TLSCertificateUsage)
            -> Manager
            -> String
            -> Int
            -> Maybe SocksConf
            -> ResourceT m (ConnRelease m, ConnInfo, ManagedConn)
getSslConn checkCert man host' port' socksProxy' = getManagedConn man key open
  where
    key = ConnKey (T.pack host') port' True
    -- Only run when the pool has nothing to offer for this key.
    open = do
        h <- connectionTo host' (PortNumber (fromIntegral port')) socksProxy'
        sslClientConn (socketDesc host' port' "secured") checkCert h
-- | Obtain a managed TLS connection to @thost:tport@ that is tunnelled
-- through the HTTP proxy at @phost:pport@ by means of a @CONNECT@ request.
--
-- The pool key contains the tunnel target in addition to the proxy address.
-- A tunnel is bound to the target it was opened for, so keying on the proxy
-- alone would allow a tunnel to one host to be reused for a request to a
-- different host behind the same proxy.
--
-- When the proxy does not answer the @CONNECT@ with status 200, or its reply
-- cannot be parsed, the handle is closed and an error describing the failure
-- is raised.
getSslProxyConn
    :: ResourceIO m
    => ([X509] -> IO TLSCertificateUsage)
    -> S8.ByteString
    -> Int
    -> Manager
    -> String
    -> Int
    -> Maybe SocksConf
    -> ResourceT m (ConnRelease m, ConnInfo, ManagedConn)
getSslProxyConn checkCert thost tport man phost pport socksProxy' =
    getManagedConn man tunnelKey $
        doConnect >>= sslClientConn desc checkCert
  where
    -- Host names cannot contain spaces, so this text can never coincide
    -- with the key of a direct connection.
    tunnelKey = ConnKey (T.pack tunnelName) pport True
    tunnelName = concat [phost, " -> ", S8.unpack thost, ":", show tport]
    desc = socketDesc phost pport "secured-proxy"
    -- Connect to the proxy and ask it to open the tunnel.
    doConnect = do
        h <- connectionTo phost (PortNumber $ fromIntegral pport) socksProxy'
        L.hPutStr h $ Blaze.toLazyByteString connectRequest
        hFlush h
        r <- hGetSome h 2048
        res <- parserHeadersFromByteString r
        case res of
            Right ((_, 200, _), _) -> return h
            Right ((_, _, msg), _) -> hClose h >> proxyError (S8.unpack msg)
            Left s -> hClose h >> proxyError s
    connectRequest =
        Blaze.fromByteString "CONNECT "
            `mappend` Blaze.fromByteString thost
            `mappend` Blaze.fromByteString (S8.pack (':' : show tport))
            `mappend` Blaze.fromByteString " HTTP/1.1\r\n\r\n"
    proxyError s =
        error $ "Proxy failed to CONNECT to '"
            ++ S8.unpack thost ++ ":" ++ show tport ++ "' : " ++ s
-- | Tells where a connection returned by 'getManagedConn' came from:
-- 'Fresh' when it was newly opened, 'Reused' when it was taken from the
-- manager's pool of idle connections.
data ManagedConn = Fresh | Reused
-- | Obtain a connection for the given key, preferring an idle one from the
-- manager's pool and falling back to the supplied action to open a new one.
--
-- Returns a release function, the connection itself, and whether the
-- connection is 'Fresh' or 'Reused'. Calling the release function with
-- 'Reuse' puts the connection back into the pool; calling it with
-- 'DontReuse' closes the connection.
getManagedConn
:: ResourceIO m
=> Manager
-> ConnKey
-> IO ConnInfo
-> ResourceT m (ConnRelease m, ConnInfo, ManagedConn)
-- The whole body runs under 'mask'; only the opening of a new connection is
-- wrapped in 'restore'.
getManagedConn man key open = mask $ \restore -> do
-- Try the pool first.
mci <- liftBase $ takeSocket man key
(ci, isManaged) <-
case mci of
Nothing -> do
ci <- restore $ liftBase open
return (ci, Fresh)
Just ci -> return (ci, Reused)
-- The reference starts out as 'DontReuse', so if the registered action runs
-- before the caller has stated its intent, the connection is closed rather
-- than pooled.
toReuseRef <- newRef DontReuse
-- Registered action: decides between pooling and closing based on the
-- value found in the reference when it runs.
releaseKey <- register $ do
toReuse <- readRef' toReuseRef
case toReuse of
Reuse -> safeFromIOBase $ putSocket man key ci
DontReuse -> safeFromIOBase $ connClose ci
-- Record the caller's choice, then trigger the registered action.
let connRelease x = do
writeRef toReuseRef x
release releaseKey
return (connRelease, ci, isManaged)
-- | What should happen to a connection when it is released: 'Reuse' returns
-- it to the manager's pool, 'DontReuse' closes it.
data ConnReuse = Reuse | DontReuse
-- | Action handed out together with a managed connection; calling it with a
-- 'ConnReuse' value releases the connection accordingly.
type ConnRelease m = ConnReuse -> ResourceT m ()
-- | Obtain a managed connection suitable for sending the given request.
--
-- * Without an HTTP proxy the connection goes straight to the request's
--   host and port, using TLS when the request is secure.
--
-- * With an HTTP proxy, insecure requests use a plain connection to the
--   proxy, while secure requests are tunnelled through the proxy to the
--   request's host and port.
--
-- The request's SOCKS proxy setting is passed along in every case, and
-- certificates are checked through 'checkCerts' for the request's host.
getConn :: ResourceIO m
        => Request m
        -> Manager
        -> ResourceT m (ConnRelease m, ConnInfo, ManagedConn)
getConn req m =
    case proxy req of
        Nothing -> direct (S8.unpack targetHost) (port req)
        Just p -> viaProxy (S8.unpack (proxyHost p)) (proxyPort p)
  where
    targetHost = host req
    verify = checkCerts m targetHost
    socks = socksProxy req
    -- No HTTP proxy configured.
    direct h p
        | secure req = getSslConn verify m h p socks
        | otherwise = getSocketConn m h p socks
    -- HTTP proxy configured; @h@ and @p@ are the proxy's address.
    viaProxy h p
        | secure req = getSslProxyConn verify targetHost (port req) m h p socks
        | otherwise = getSocketConn m h p socks
-- | Check a certificate chain for a host, consulting the manager's
-- certificate cache first.
--
-- A chain already present in the cache for this host is accepted right away.
-- Otherwise the manager's certificate check is run; when it accepts, the
-- chain is stored in the cache with an expiry time one hour in the future.
-- Any other verdict is returned without being cached.
checkCerts :: Manager -> W.Ascii -> [X509] -> IO TLSCertificateUsage
checkCerts man host' certs = do
#if DEBUG
    putStrLn $ "checkCerts for host: " ++ show host'
#endif
    known <- fmap isCached (I.readIORef (mCertCache man))
    if known then cached else verify
  where
    -- The cache is keyed by the concatenated encoding of the whole chain.
    encoded = L.concat $ map encodeCertificate certs

    isCached cache =
        case Map.lookup host' cache of
            Nothing -> False
            Just perHost -> Map.member encoded perHost

    cached = do
#if DEBUG
        putStrLn $ concat ["checkCerts ", show host', " cert already cached"]
#endif
        return CertificateUsageAccept

    verify = do
#if DEBUG
        putStrLn $ concat ["checkCerts ", show host', " no cached certs found"]
#endif
        res <- mCheckCerts man host' certs
        case res of
            CertificateUsageAccept -> remember
            _ -> return ()
        return res

    remember = do
#if DEBUG
        putStrLn $ concat ["checkCerts ", show host', " valid cert, adding to cache"]
#endif
        now <- getCurrentTime
        let expire = (60 * 60) `addUTCTime` now
        I.atomicModifyIORef (mCertCache man) $ \cache ->
            -- 'Map.union' is left-biased, so the new expiry replaces any
            -- existing one for the same chain.
            ( Map.insertWith Map.union host' (Map.singleton encoded expire) cache
            , ()
            )
-- | Open a handle to the given host and port, either directly or, when a
-- SOCKS configuration is supplied, through that SOCKS proxy.
connectionTo :: HostName -> PortID -> Maybe SocksConf -> IO Handle
connectionTo host' port' socks =
    case socks of
        Nothing -> connectTo host' port'
        Just socksConf -> do
            sock <- socksConnectWith socksConf host' port'
            socketToHandle sock ReadWriteMode