{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE StrictData            #-}
{-# LANGUAGE ViewPatterns          #-}

module HttpProxy () where



import           ClassyPrelude
import qualified Data.ByteString.Char8     as BC

import           Control.Monad.Except
import qualified Data.Conduit.Network.TLS  as N
import qualified Data.Streaming.Network    as N

import qualified Data.ByteString.Base64    as B64
import           Network.Socket            (HostName, PortNumber)
import qualified Network.Socket            as N hiding (recv, recvFrom, send,
                                                 sendTo)
import qualified Network.Socket.ByteString as N

import           Logger
import           Types


-- | Settings describing the HTTP proxy to tunnel through.
--
-- NOTE(review): the derived 'Show' instance renders 'credentials' verbatim,
-- so showing or logging a value of this type exposes the proxy password.
data HttpProxySettings = HttpProxySettings
  { proxyHost   :: HostName
    -- ^ Host name of the proxy the TCP connection is opened to.
  , proxyPort   :: PortNumber
    -- ^ TCP port of the proxy.
  , credentials :: Maybe (ByteString, ByteString)
    -- ^ Optional @(user, password)@ pair; when present it is sent as a
    -- @Proxy-Authorization: Basic@ header with the CONNECT request.
  } deriving (Show)


-- | Open a TCP connection to the configured HTTP proxy, ask it to tunnel to
-- the destination @(host, port)@ with a @CONNECT@ request and, once the proxy
-- answers with status 200, run @app@ on the established connection.
--
-- Outcomes:
--
-- * proxy answers with status 200: the result of @app@ is returned;
-- * proxy answers with anything else: 'ProxyForwardError' carrying the raw
--   response is thrown inside @m@;
-- * no complete answer within 10 seconds: 'ProxyForwardError' with a timeout
--   message is thrown inside @m@;
-- * a synchronous exception escapes the block (including from @app@): it is
--   converted to 'ProxyConnectionError' inside @m@ instead of propagating.
httpProxyConnection :: MonadError Error m => HttpProxySettings -> (HostName, PortNumber) -> (Connection -> IO (m a)) -> IO (m a)
httpProxyConnection HttpProxySettings{..} (host, port) app = onError $ do
  debug $ "Opening tcp connection to proxy " <> proxyAddr

  ret <- N.runTCPClient (N.clientSettingsTCP (fromIntegral proxyPort) (fromString proxyHost)) $ \conn' -> do
    let conn = toConnection conn'
    sendConnectRequest conn

    -- Wait at most 10 seconds for the proxy's reply before giving up.
    responseM <- timeout replyTimeoutUs $ readConnectResponse mempty conn

    case responseM of
      Just response
        | isAuthorized response -> app conn
        | otherwise             -> return . throwError $ ProxyForwardError (BC.unpack response)
      Nothing ->
        return . throwError $ ProxyForwardError
          ("No response from the proxy after " <> show (replyTimeoutUs `div` 1000000) <> "sec")

  debug $ "Closing tcp connection to proxy " <> proxyAddr
  return ret

  where
    -- Reply timeout in microseconds (the unit 'timeout' expects).
    replyTimeoutUs :: Int
    replyTimeoutUs = 1000000 * 10

    -- Proxy address as rendered in the debug messages.
    proxyAddr :: String
    proxyAddr = show proxyHost <> ":" <> show proxyPort

    -- Tunnel destination in @host:port@ form, as used in the request line
    -- and in the Host header.
    target :: ByteString
    target = fromString host <> ":" <> fromString (show port)

    -- Build the Basic authentication header line (CRLF-terminated).
    credentialsToHeader :: (ByteString, ByteString) -> ByteString
    credentialsToHeader (user, password) =
      "Proxy-Authorization: Basic " <> B64.encode (user <> ":" <> password) <> "\r\n"

    -- Write the CONNECT request, with the optional authorization header,
    -- terminated by an empty line.
    sendConnectRequest :: Connection -> IO ()
    sendConnectRequest h = write h $
         "CONNECT " <> target <> " HTTP/1.0\r\n"
      <> "Host: " <> target <> "\r\n"
      <> maybe mempty credentialsToHeader credentials
      <> "\r\n"

    -- Accumulate data from the connection until the end-of-headers marker
    -- has been received, or until 'read' yields 'Nothing'.
    -- The marker is searched for in the accumulated buffer, not only in the
    -- latest chunk: it may be split across two reads.
    readConnectResponse :: ByteString -> Connection -> IO ByteString
    readConnectResponse buff conn = do
      chunkM <- read conn
      case chunkM of
        Nothing    -> return buff
        Just chunk ->
          let buff' = buff <> chunk
          in if "\r\n\r\n" `isInfixOf` buff'
               then return buff'
               else readConnectResponse buff' conn

    -- True when the status line of the response carries status code 200.
    -- Only the status-code token is inspected, so a " 200 " occurring in a
    -- header value is not mistaken for success, and a status line without a
    -- reason phrase is still accepted.
    isAuthorized :: ByteString -> Bool
    isAuthorized response = case BC.words response of
      (version : "200" : _) -> "HTTP/" `BC.isPrefixOf` version
      _                     -> False

    -- Turn any synchronous exception escaping the action into a
    -- 'ProxyConnectionError' thrown inside the inner monad.
    onError :: MonadError Error n => IO (n b) -> IO (n b)
    onError action = action `catch` handler

    -- Exceptions whose rendering starts with "user error" are reported as-is;
    -- everything else is prefixed with "Unknown Error :: ".
    handler :: MonadError Error n => SomeException -> IO (n b)
    handler e = return . throwError . ProxyConnectionError $
      if take 10 (show e) == "user error"
        then show e
        else "Unknown Error :: " <> show e