{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}

-- | Simple functions to run TCP clients and servers.
module Network.Run.TCP (
    runTCPClient
  , runTCPServer
  ) where

import Control.Concurrent (forkFinally)
import qualified Control.Exception as E
import Control.Monad (forever, void)
import Network.Socket
#if defined(mingw32_HOST_OS)
import Control.Concurrent.MVar
import Control.Concurrent
import qualified Control.Exception
#endif

import Network.Run.Core

-- | Running a TCP client with a connected socket.
--
-- Resolves @host@\/@port@ via 'resolve' (the trailing 'False' is,
-- presumably, the "passive" flag — the server passes 'True'; confirm in
-- "Network.Run.Core"), opens a socket for the resolved address, connects
-- it, and runs @client@ with the connected socket.
--
-- The socket is always released when @client@ returns or throws. With
-- network >= 3.1.1 it is released with 'gracefulClose' (5000 ms timeout);
-- otherwise with 'close'.
runTCPClient :: HostName -> ServiceName -> (Socket -> IO a) -> IO a
runTCPClient host port client = withSocketsDo $ do
    addr <- resolve Stream (Just host) port False
    E.bracket (open addr) release client
  where
    -- Open a socket for @addr@ and connect it. If 'connect' throws,
    -- 'E.bracketOnError' closes the freshly opened socket so it cannot leak.
    open :: AddrInfo -> IO Socket
    open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
        connect sock $ addrAddress addr
        return sock

    -- How the connected socket is released once @client@ is done.
    release :: Socket -> IO ()
#if MIN_VERSION_network(3,1,1)
    release sock = gracefulClose sock 5000
#else
    release = close
#endif

-- | Running a TCP server, handing each accepted socket to @server@.
--
-- Resolves the listening address via 'resolve' (the trailing 'True' is,
-- presumably, the "passive" flag; confirm in "Network.Run.Core"), opens
-- a server socket with 'openServerSocket', and listens with a backlog of
-- 1024. It then accepts connections forever.
--
-- Each accepted connection is served by @server@ in its own thread
-- ('forkFinally'). Only the connected socket is passed to @server@; the
-- peer address returned by 'accept' is discarded. Use 'getPeerName' inside
-- @server@ if it is needed.
--
-- Notes:
--
-- * This function never returns normally. The listening socket is closed
--   when the accept loop is terminated by an exception.
-- * The result of @server@, and any exception it throws, is ignored: the
--   'forkFinally' finalizer only releases the connection.
runTCPServer :: Maybe HostName -> ServiceName -> (Socket -> IO a) -> IO a
runTCPServer mhost port server = withSocketsDo $ do
    addr <- resolve Stream mhost port True
    E.bracket (open addr) close loop
  where
    -- Open the server socket and start listening. If 'listen' throws,
    -- 'E.bracketOnError' closes the socket so it cannot leak.
    open :: AddrInfo -> IO Socket
    open addr = E.bracketOnError (openServerSocket addr) close $ \sock -> do
        listen sock 1024
        return sock

    -- Accept connections forever, forking one handler thread per client.
    -- 'E.bracketOnError' closes the accepted socket if forking the handler
    -- itself fails. Once forked, the handler's finalizer owns the socket.
    loop :: Socket -> IO b
    loop sock = forever $
        E.bracketOnError (windowsThreadBlockHack (accept sock)) (close . fst) $
            \(conn, _peer) ->
                void $ forkFinally (server conn) (const $ release conn)

    -- How an accepted connection is released once @server@ is done.
    release :: Socket -> IO ()
#if MIN_VERSION_network(3,1,1)
    release conn = gracefulClose conn 5000
#else
    release = close
#endif


#if defined(mingw32_HOST_OS)
-- | Run @act@ in a freshly forked thread and wait for its outcome.
--
-- The calling thread only blocks in 'takeMVar', so it can still receive
-- asynchronous exceptions while @act@ (e.g. 'accept') is blocked.
-- NOTE(review): presumably this works around blocking FFI calls not being
-- interruptible on Windows — confirm.
--
-- Any exception thrown by @act@ is re-thrown in the calling thread.
--
-- NOTE(review): if the caller is interrupted while waiting, the forked
-- thread keeps running and its eventual result is dropped. For 'accept',
-- that would leak the accepted socket.
windowsThreadBlockHack :: IO a -> IO a
windowsThreadBlockHack act = do
    -- The annotation pins the exception type caught by 'try' to
    -- 'SomeException'.
    var <- newEmptyMVar :: IO (MVar (Either Control.Exception.SomeException a))
    void . forkIO $ Control.Exception.try act >>= putMVar var
    either Control.Exception.throwIO return =<< takeMVar var
#else
-- | No-op on non-Windows platforms: the action runs directly in the
-- calling thread.
windowsThreadBlockHack :: IO a -> IO a
windowsThreadBlockHack = id
#endif