{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Simple tools for establishing and using insecure WebSockets connections on top of
-- TCP (i.e., @ws:\/\/@).
--
-- See the
-- [network-simple-wss](https://hackage.haskell.org/package/network-simple-wss)
-- package for Secure WebSockets (i.e., @wss:\/\/@) support.
--
-- Notice that, currently, this package offers tools that are mostly
-- interesting from a client's point of view. Server side support will come
-- later.
module Network.Simple.WS
 ( -- * Sending and receiving
   W.Connection
 , recv
 , send
 , close
   -- * Client side
 , connect
 , connectSOCKS5
   -- * Low level
 , clientConnectionFromStream
 , streamFromSocket
 ) where


import Control.Concurrent.Async (Async)
import qualified Control.Concurrent.Async as Async
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Control.Exception.Safe as Ex
import Data.Bifunctor (first)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
import Data.Foldable (traverse_)
import Data.Word
import GHC.IO.Exception as IO

import qualified Network.Simple.TCP as T
import qualified Network.WebSockets as W
import qualified Network.WebSockets.Connection as W (pingThread)
import qualified Network.WebSockets.Stream as W (Stream, makeStream, close)

--------------------------------------------------------------------------------

-- | Connect to the specified WebSockets server.
--
-- The 'W.Stream' built on top of the TCP socket is closed with 'W.close' when
-- the given computation finishes, whether normally or by an exception. While
-- the computation runs, a background thread runs 'W.pingThread' with an
-- interval of 30 (seconds) on the 'W.Connection'; that thread is cancelled
-- as soon as the computation finishes.
connect
  :: (MonadIO m, Ex.MonadMask m)
  => T.HostName
  -- ^ WebSockets server host name (e.g., @\"www.example.com\"@ or IP
  -- address).
  -> T.ServiceName
  -- ^ WebSockets server port (e.g., @\"443\"@ or @\"www\"@).
  -> B.ByteString
  -- ^ WebSockets resource (e.g., @\"/foo\/qux?bar=wat&baz\"@).
  --
  -- Leading @\'\/\'@ is optional.
  -> [(B.ByteString, B.ByteString)]
  -- ^ Extra HTTP Headers
  -- (e.g., @[(\"Authorization\", \"Basic dXNlcjpwYXNzd29yZA==\")]@).
  -> ((W.Connection, T.SockAddr) -> m r)
  -- ^ Computation to run after establishing a WebSockets connection to the
  -- remote server. Takes the WebSockets 'W.Connection' and remote end address.
  -> m r
connect hn sn res hds act =
  T.connect hn sn $ \(sock, saddr) ->
    Ex.bracket (streamFromSocket sock) (liftIO . W.close) $ \stream -> do
      conn <- clientConnectionFromStream stream hn sn res hds
      -- Keep pinging the remote end for exactly as long as 'act' runs.
      withAsync (W.pingThread conn 30 (pure ())) $ \_ ->
        act (conn, saddr)

-- | Like 'connect', but connects to the destination server through a SOCKS5
-- proxy.
--
-- The WebSockets handshake is addressed to the destination server, not to
-- the proxy.
connectSOCKS5
  :: (MonadIO m, Ex.MonadMask m)
  => T.HostName
  -- ^ SOCKS5 proxy server hostname or IP address.
  -> T.ServiceName
  -- ^ SOCKS5 proxy server service port name or number.
  -> T.HostName
  -- ^ Destination WebSockets server hostname or IP address. We connect to this
  -- host /through/ the SOCKS5 proxy specified in the previous arguments.
  --
  -- Note that if hostname resolution on this 'T.HostName' is necessary, it
  -- will happen on the proxy side for security reasons, not locally.
  -> T.ServiceName
  -- ^ Destination WebSockets server port (e.g., @\"443\"@ or @\"www\"@).
  -> B.ByteString
  -- ^ WebSockets resource (e.g., @\"/foo\/qux?bar=wat&baz\"@).
  --
  -- Leading @\'\/\'@ is optional.
  -> [(B.ByteString, B.ByteString)]
  -- ^ Extra HTTP Headers
  -- (e.g., @[(\"Authorization\", \"Basic dXNlcjpwYXNzd29yZA==\")]@).
  -> ((W.Connection, T.SockAddr, T.SockAddr) -> m r)
  -- ^ Computation taking a 'W.Connection' for communicating with the
  -- destination WebSockets server through the SOCKS5 server, the address
  -- of that SOCKS5 server, and the address of the destination WebSockets
  -- server, in that order.
  -> m r
connectSOCKS5 phn psn dhn dsn res hds act =
  T.connectSOCKS5 phn psn dhn dsn $ \(sock, pa, da) ->
    Ex.bracket (streamFromSocket sock) (liftIO . W.close) $ \stream -> do
      -- Handshake with the destination host/port, not the proxy's.
      conn <- clientConnectionFromStream stream dhn dsn res hds
      -- Keep pinging the remote end for exactly as long as 'act' runs.
      withAsync (W.pingThread conn 30 (pure ())) $ \_ ->
        act (conn, pa, da)

-- | Obtain a 'W.Connection' to the specified URI over the given 'W.Stream',
-- connected to either a WebSockets server, or a Secure WebSockets server.
--
-- This performs the client side of the WebSockets opening handshake on the
-- given 'W.Stream', requesting @permessage-deflate@ compression and with
-- strict UTF-8 checking of text messages disabled.
clientConnectionFromStream
  :: MonadIO m
  => W.Stream
  -- ^ Stream on which to establish the WebSockets connection.
  -> T.HostName
  -- ^ WebSockets server host name (e.g., @\"www.example.com\"@ or IP address).
  --
  -- IPv6 address literals (e.g., @\"::1\"@) are enclosed in square brackets
  -- when building the HTTP @Host@ header.
  -> T.ServiceName
  -- ^ WebSockets server port (e.g., @\"443\"@ or @\"www\"@).
  -> B.ByteString
  -- ^ WebSockets resource (e.g., @\"/foo\/qux?bar=wat&baz\"@).
  --
  -- Leading @\'\/\'@ is optional.
  -> [(B.ByteString, B.ByteString)]
  -- ^ Extra HTTP Headers
  -- (e.g., @[(\"Authorization\", \"Basic dXNlcjpwYXNzd29yZA==\")]@).
  -> m W.Connection
  -- ^ Established WebSockets connection
clientConnectionFromStream stream hn sn res hds = liftIO $ do
  let -- Request target, always with exactly one leading slash.
      res' :: String = '/' : dropWhile (== '/') (B8.unpack res)
      hds' :: W.Headers = map (first CI.mk) hds
      -- An IPv6 literal must be bracketed in the Host header (RFC 3986,
      -- section 3.2.2); otherwise its colons cannot be told apart from the
      -- port separator. Already-bracketed input is left untouched.
      hn' :: String =
        if ':' `elem` hn && take 1 hn /= "["
          then "[" ++ hn ++ "]"
          else hn
      -- NOTE(review): 'sn' is used verbatim, so a symbolic service name such
      -- as "www" ends up in the Host header instead of a numeric port.
      hnsn :: String = hn' ++ ":" ++ sn
      wopts :: W.ConnectionOptions = W.defaultConnectionOptions
        { W.connectionStrictUnicode = False -- Slows stuff down. And see 'recv'.
        , W.connectionCompressionOptions =
            W.PermessageDeflateCompression W.defaultPermessageDeflate }
  W.newClientConnection stream hnsn res' wopts hds'

-- | Obtain a 'W.Stream' implemented using the network 'T.Socket'. You can
-- use the
-- [network-simple](https://hackage.haskell.org/package/network-simple)
-- library to get one of those.
--
-- Each read asks the socket for at most 4096 bytes. The write action ignores
-- 'Nothing', so this 'W.Stream' never closes the 'T.Socket' itself; closing
-- the socket remains the caller's responsibility.
streamFromSocket :: MonadIO m => T.Socket -> m W.Stream
streamFromSocket sock =
  liftIO $ W.makeStream (T.recv sock 4096) (traverse_ (T.sendLazy sock))

-- | Receive a single full WebSockets message from the remote end as a lazy
-- 'BL.ByteString' (potentially 'BL.empty').
--
-- Throws 'IO.IOException' if there is an unexpected 'W.Connection' error.
--
-- If the remote end requested the 'W.Connection' to be closed, then 'Left'
-- will be returned instead, with a close code and reason description.
--
-- * See https://datatracker.ietf.org/doc/html/rfc6455#section-7.4 for details
-- about the close codes.
--
-- * Do not use 'recv' after receiving a close request.
--
-- * If you receive a close request after having sent a close request
-- yourself (see 'close'), then the WebSocket 'W.Connection' is
-- considered closed and you can proceed to close the underlying transport.
--
-- * If you didn't send a close request before, then you may continue to use
-- 'send', but you are expected to perform 'close' as soon as possible in order
-- to indicate a graceful closing of the connection.
--
-- Text messages are not decoded: their raw payload bytes are returned exactly
-- like those of binary messages.

-- Note: The WebSockets protocol supports the silly idea of sending text
-- rather than bytes. We don't support that. If necessary, you can find support
-- for this in the `websockets` library.
recv :: MonadIO m
     => W.Connection
     -> m (Either (Word16, BL.ByteString) BL.ByteString)
recv conn = liftIO $ Ex.try (W.receiveDataMessage conn) >>= \case
  -- Data messages: return the payload bytes, ignoring any decoded text.
  Right (W.Binary !bl) -> pure (Right bl)
  Right (W.Text !bl _) -> pure (Right bl)
  -- A close request from the peer is an expected outcome, not an error.
  Left (W.CloseRequest !w !bl) -> pure (Left (w, bl))
  -- Every other connection failure is reported as an 'IO.IOError'.
  Left W.ConnectionClosed ->
    Ex.throw $ ioe "recv" IO.ResourceVanished "Connection closed"
  Left (W.ParseException s) ->
    Ex.throw $ ioe "recv" IO.ProtocolError ("WebSocket parsing error: " <> s)
  Left (W.UnicodeException s) ->
    Ex.throw $ ioe "recv" IO.ProtocolError ("WebSocket UTF-8 error: " <> s)

-- | Send a lazy 'BL.ByteString' (potentially 'BL.empty') to the remote end as
-- a single WebSockets message, in potentially multiple frames.
--
-- The payload is always sent as a binary message.
--
-- If there is an issue with the 'W.Connection', an exception originating from
-- the underlying 'W.Stream' will be thrown.

-- Note: The WebSockets protocol supports the silly idea of sending text rather
-- than bytes. We don't support that. If necessary, users can
-- find support for this in the `websockets` library.
send :: MonadIO m => W.Connection -> BL.ByteString -> m ()
send conn = liftIO . W.sendDataMessage conn . W.Binary

-- | Send a close request to the remote end.
--
-- After sending this request you should not use 'send' anymore, but you
-- should still continue to call 'recv' to process any pending incoming
-- messages. As soon as 'recv' returns 'Left', you can consider the WebSocket
-- 'W.Connection' closed and can proceed to close the underlying transport.
--
-- If there is an issue with the 'W.Connection', an exception originating from
-- the underlying 'W.Stream' will be thrown.
close :: MonadIO m
      => W.Connection
      -> Word16        -- ^ Close code.
      -> BL.ByteString -- ^ Reason for closing.
      -> m ()
close conn w bl = liftIO $ W.sendCloseCode conn w bl

-- | Like 'Async.withAsync', but generalized to 'Ex.MonadMask' and 'MonadIO'.
--
-- Runs the given 'IO' action in a separate thread while the given
-- computation runs. When that computation finishes, normally or by an
-- exception, the thread is cancelled with 'Async.uninterruptibleCancel'.
withAsync
  :: (Ex.MonadMask m, MonadIO m)
  => IO a
  -> (Async a -> m b)
  -> m b
withAsync io = Ex.bracket
  -- 'Ex.bracket' runs acquisition with async exceptions masked and the forked
  -- thread inherits that state, so 'io' is explicitly run unmasked in order
  -- to remain cancellable.
  (liftIO $ Async.asyncWithUnmask (\unmask -> unmask io))
  (liftIO . Async.uninterruptibleCancel)

-- | Construct an 'IO.IOError' relevant to this module.
--
-- The location is prefixed with @\"Network.Simple.WS.\"@; no errno, handle
-- or file name is attached.
ioe :: String  -- ^ Location
    -> IO.IOErrorType
    -> String  -- ^ Description
    -> IO.IOError
ioe l t s = IO.IOError
  { IO.ioe_type = t
  , IO.ioe_location = "Network.Simple.WS." <> l
  , IO.ioe_description = s
  , IO.ioe_errno = Nothing
  , IO.ioe_handle = Nothing
  , IO.ioe_filename = Nothing
  }