{-# LANGUAGE DeriveAnyClass     #-}
{-# LANGUAGE DeriveGeneric      #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData         #-}

module Types where


import           ClassyPrelude
import           Data.Maybe
import           System.IO (stdin, stdout)
import           Data.ByteString (hGetSome, hPutStr)
import           System.IO.Unsafe (unsafeDupablePerformIO, unsafePerformIO)

import           Data.CaseInsensitive  ( CI )
import qualified Data.Streaming.Network        as N
import qualified Network.Connection            as NC
import           Network.Socket                (HostName, PortNumber)
import qualified Network.Socket                as N hiding (recv, recvFrom, send, sendTo)
import qualified Network.WebSockets.Connection as WS


-- | Orphan 'Hashable' instance for 'PortNumber'.
--
-- The port is hashed through its 'Enum' index, so the result is the same as
-- hashing @fromEnum port :: Int@ with the same salt. Presumably needed by the
-- Generic-derived 'Hashable' instance for 'N.SockAddr', whose IPv4/IPv6
-- constructors carry a 'PortNumber'.
instance Hashable PortNumber where
  hashWithSalt salt = hashWithSalt salt . fromEnum
  
-- Orphan instances that make 'N.SockAddr' hashable: 'Generic' is derived
-- standalone (StandaloneDeriving + DeriveGeneric), and 'Hashable' is then
-- derived via DeriveAnyClass, i.e. it uses the class's Generic-based
-- default methods.
deriving instance Generic N.SockAddr
deriving instance Hashable N.SockAddr


{-# NOINLINE defaultRecvBufferSize #-}
-- | Receive-buffer size ('N.RecvBuffer', i.e. @SO_RCVBUF@) that the OS gives
-- a fresh IPv4 stream socket. It is computed once, on first use, by opening
-- a throw-away socket, reading the option, and closing the socket again.
--
-- This must use 'unsafePerformIO', not 'unsafeDupablePerformIO'. The dupable
-- variant may run the action more than once, or abandon it midway without
-- raising an exception. In that case 'bracket' cannot guarantee that the
-- probe socket gets closed. NOINLINE keeps this a single shared value.
defaultRecvBufferSize :: Int
defaultRecvBufferSize = unsafePerformIO $
  bracket (N.socket N.AF_INET N.Stream 0) N.close $ \sock ->
    N.getSocketOption sock N.RecvBuffer

-- | Raw socket option for Linux @SO_MARK@, given as a numeric
-- (level, option-name) pair: level 1 (@SOL_SOCKET@), option 36 (@SO_MARK@).
--
-- NOTE(review): the reference link below points at the alpha architecture's
-- header. The generic Linux definitions live in
-- include/uapi/asm-generic/socket.h, and these numbers may differ on some
-- architectures. Confirm them for non-x86/ARM targets.
sO_MARK :: N.SocketOption
sO_MARK = N.SockOpt solSocket soMark
  where
    -- socket-level option namespace (SOL_SOCKET)
    solSocket = 1
    -- https://elixir.bootlin.com/linux/latest/source/arch/alpha/include/uapi/asm/socket.h#L64
    soMark = 36

{-# NOINLINE sO_MARK_Value #-}
-- | Process-wide mutable cell, initialised to 0. Presumably it holds the
-- mark value applied with 'sO_MARK'; confirm with callers.
--
-- This is the standard "global IORef" idiom: 'unsafePerformIO' plus
-- NOINLINE, so that exactly one 'IORef' is ever allocated. With
-- 'unsafeDupablePerformIO', two threads forcing this value at the same time
-- could each run 'newIORef'. They would then hold different cells, and
-- writes could be lost.
sO_MARK_Value :: IORef Int
sO_MARK_Value = unsafePerformIO (newIORef 0)

-- | Kind of traffic a tunnel forwards to its destination.
--
-- The 'Show' instance of 'TunnelSettings' renders 'SOCKS5' as 'TCP'.
data Protocol
  = UDP
  | TCP
  | STDIO
  | SOCKS5
  deriving (Show, Read, Eq)

-- | Field-less marker for a connection over the process's standard streams.
-- Its 'ToConnection' instance reads from 'stdin' and writes to 'stdout'.
data StdioAppData = StdioAppData

-- | Per-peer UDP state. 'appRead' and 'appWrite' back both this type's
-- 'N.HasReadWrite' instance and its 'ToConnection' instance.
data UdpAppData = UdpAppData
  { appAddr  :: N.SockAddr
    -- ^ Socket address for this entry (presumably the remote UDP peer; confirm).
  , appSem   :: MVar ByteString
    -- ^ NOTE(review): its role is not visible here. Presumably it hands
    -- received datagrams to 'appRead'; confirm against the UDP server code.
  , appRead  :: IO ByteString
    -- ^ Action returning the next chunk of data for this entry.
  , appWrite :: ByteString -> IO ()
    -- ^ Action sending a chunk of data for this entry.
  }

-- | Van Laarhoven lenses onto the 'appRead' and 'appWrite' fields. They let
-- streaming-commons accessors such as 'N.appRead' and 'N.appWrite' operate
-- on a 'UdpAppData'.
instance N.HasReadWrite UdpAppData where
  readLens focus udp =
    (\newRead -> udp { appRead = newRead }) <$> focus (appRead udp)
  writeLens focus udp =
    (\newWrite -> udp { appWrite = newWrite }) <$> focus (appWrite udp)

-- | Proxy that traffic passes through between the local side and the tunnel
-- server (shown as @<==PROXY==>@ by the 'TunnelSettings' 'Show' instance).
--
-- NOTE(review): the derived 'Show' prints 'credentials' verbatim, so avoid
-- logging a 'ProxySettings' value directly.
data ProxySettings = ProxySettings
  { host        :: HostName
    -- ^ Proxy host name.
  , port        :: PortNumber
    -- ^ Proxy port.
  , credentials :: Maybe (ByteString, ByteString)
    -- ^ Optional credential pair (presumably user and password; confirm).
  } deriving (Show)

-- | Full description of one tunnel. It covers:
--
-- * the local listening endpoint;
-- * the server reached over WS/WSS, optionally through a proxy;
-- * the destination that traffic is forwarded to.
--
-- The hand-written 'Show' instance renders these hops in that order.
data TunnelSettings = TunnelSettings
  { proxySetting              :: Maybe ProxySettings
    -- ^ Optional proxy between the local side and the server.
  , localBind                 :: HostName
    -- ^ Local address to bind.
  , localPort                 :: PortNumber
    -- ^ Local port to bind.
  , serverHost                :: HostName
    -- ^ Tunnel server host.
  , serverPort                :: PortNumber
    -- ^ Tunnel server port.
  , destHost                  :: HostName
    -- ^ Final destination host.
  , destPort                  :: PortNumber
    -- ^ Final destination port.
  , protocol                  :: Protocol
    -- ^ Kind of traffic forwarded.
  , useTls                    :: Bool
    -- ^ 'True' for WSS, 'False' for plain WS.
  , useSocks                  :: Bool
  , upgradePrefix             :: String
  , upgradeCredentials        :: ByteString
  , tlsSNI                    :: ByteString
  , tlsVerifyCertificate      :: Bool
  , hostHeader                :: ByteString
  , udpTimeout                :: Int
    -- ^ NOTE(review): units not visible here; confirm with callers.
  , websocketPingFrequencySec :: Int
    -- ^ Websocket ping interval (seconds, per the field name).
  , customHeaders             :: [(CI ByteString, ByteString)]
    -- ^ Extra headers as case-insensitive name / value pairs.
  }

-- | One-line summary of the tunnel chain, for example
-- @127.0.0.1:8080 <==WSS==> server:443 <==TCP==> dest:22@.
--
-- * The proxy hop is inserted only when 'proxySetting' is present, and
--   proxy credentials are never printed.
-- * 'SOCKS5' is displayed as 'TCP'.
--
-- The proxy hop is rendered with 'maybe' rather than the partial
-- 'isNothing' / 'fromJust' pair; the output is unchanged.
instance Show TunnelSettings where
  show TunnelSettings{..} =
       localBind <> ":" <> show localPort
    <> maybe mempty showProxy proxySetting
    <> " <==" <> (if useTls then "WSS" else "WS") <> "==> "
    <> serverHost <> ":" <> show serverPort
    <> " <==" <> show (if protocol == SOCKS5 then TCP else protocol) <> "==> "
    <> destHost <> ":" <> show destPort
    where
      -- Only host and port are shown; credentials are deliberately omitted.
      showProxy p = " <==PROXY==> " <> host p <> ":" <> show (port p)


-- | Transport-agnostic connection, expressed as a record of actions. Every
-- 'ToConnection' instance produces one of these.
data Connection = Connection
  { read          :: IO (Maybe ByteString)
    -- ^ Receive the next chunk. NOTE(review): the instances defined with this
    -- type always return 'Just'. End of stream presumably surfaces as an
    -- empty chunk or an exception; confirm with callers.
  , write         :: ByteString -> IO ()
    -- ^ Send a chunk.
  , close         :: IO ()
    -- ^ Release or close the transport (a no-op for some instances).
  , rawConnection :: Maybe N.Socket
    -- ^ Underlying socket, when one is exposed.
  }

-- | Types that can be adapted to the common 'Connection' record of actions.
class ToConnection a where
  -- | Build the 'Connection' view of a value.
  toConnection :: a -> Connection

-- | Bridge the process's standard streams:
--
-- * each 'read' returns up to 512 bytes from 'stdin' (via 'hGetSome');
-- * 'write' sends bytes to 'stdout';
-- * 'close' is a no-op, so the standard handles stay open.
--
-- NOTE(review): 'hPutStr' does not flush. If 'stdout' is block-buffered
-- (e.g. when piped), output may be delayed unless buffering is changed
-- elsewhere; confirm.
instance ToConnection StdioAppData where
  toConnection _ =
    Connection
      { rawConnection = Nothing
      , read          = fmap Just (hGetSome stdin 512)
      , write         = hPutStr stdout
      , close         = pure ()
      }

-- | Adapt a websocket connection:
--
-- * 'read' waits for the next data message;
-- * 'write' sends a binary message;
-- * 'close' sends a close message with an empty reason.
instance ToConnection WS.Connection where
  toConnection ws =
    Connection
      { read          = Just <$> WS.receiveData ws
      , write         = WS.sendBinaryData ws
      , close         = WS.sendClose ws emptyReason
      , rawConnection = Nothing
      }
    where
      -- Empty close-frame payload.
      emptyReason = mempty :: LByteString

-- | Adapt a streaming-commons 'N.AppData' session. Reads, writes and close
-- delegate to 'N.appRead', 'N.appWrite' and 'N.appCloseConnection'.
instance ToConnection N.AppData where
  toConnection app =
    Connection
      { close         = N.appCloseConnection app
      , write         = N.appWrite app
      , read          = fmap Just (N.appRead app)
      , rawConnection = Nothing
      }

-- | Adapt a 'UdpAppData' using its own 'appRead' and 'appWrite' actions.
-- 'close' is a no-op here.
instance ToConnection UdpAppData where
  toConnection udp =
    Connection
      { read          = Just <$> appRead udp
      , write         = appWrite udp
      , close         = pure ()
      , rawConnection = Nothing
      }

-- | Adapt a "Network.Connection" connection. Each 'read' is one
-- 'NC.connectionGetChunk', writes use 'NC.connectionPut', and 'close' calls
-- 'NC.connectionClose'.
instance ToConnection NC.Connection where
  toConnection c =
    Connection
      { read          = fmap Just getChunk
      , write         = NC.connectionPut c
      , close         = NC.connectionClose c
      , rawConnection = Nothing
      }
    where
      getChunk = NC.connectionGetChunk c

-- | Failures, tagged by the stage where they occurred. Each constructor
-- carries a descriptive message.
data Error
  = ProxyConnectionError String
  | ProxyForwardError String
  | LocalServerError String
  | TunnelError String
  | WebsocketError String
  | TlsError String
  | Other String
  deriving (Show)