{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Hedgehog.Extras.Stock.IO.Network.Socket
  ( doesSocketExist
  , isPortOpen
  , canConnect
  , listenOn
  , allocateRandomPorts
  ) where

import           Control.Exception (IOException, handle)
import           Control.Monad
import           Data.Bool
import           Data.Function
import           Data.Functor
import           Data.Int
import           Data.Maybe
import           Network.Socket (Family (AF_INET), SockAddr (..), Socket, SocketType (Stream))
import           Prelude (fromIntegral)
import           System.IO (FilePath, IO)
import           Text.Show (show)

import qualified Network.Socket as IO
import qualified System.Directory as IO
import qualified UnliftIO.Exception as IO

-- | Check if a TCP port is open
--
-- Resolves @127.0.0.1@ together with the given port number and attempts a TCP
-- connection to the first resolved address using 'canConnect'.
--
-- Returns 'True' when the connection succeeds. Returns 'False' when either
-- address resolution or the connection attempt fails with an 'IOException',
-- so that callers always receive an answer rather than an exception.
isPortOpen :: Int -> IO Bool
isPortOpen port =
  -- The handler covers resolution as well as connection: a port that cannot
  -- even be resolved is certainly not open.
  handle (return . const @Bool @IOException False) $ do
    socketAddressInfos <- IO.getAddrInfo Nothing (Just "127.0.0.1") (Just (show port))
    case socketAddressInfos of
      socketAddressInfo:_ -> canConnect (IO.addrAddress socketAddressInfo) $> True
      [] -> return False

-- | Check if it is possible to connect to a socket address
--
-- Opens a stream socket whose address family matches the given address
-- (IPv4, IPv6 or Unix domain), connects it to that address and closes it
-- again. Returns @()@ on success. Any exception raised while creating,
-- connecting or closing the socket propagates to the caller.
canConnect :: SockAddr -> IO ()
canConnect sockAddr =
  IO.bracket (IO.socket family Stream protocol) IO.close' $ \sock ->
    IO.connect sock sockAddr
  where
    -- Internet addresses use TCP (protocol number 6), as before. Unix domain
    -- sockets must use the default protocol instead.
    (family, protocol) = case sockAddr of
      SockAddrInet {}  -> (AF_INET, 6)
      SockAddrInet6 {} -> (IO.AF_INET6, 6)
      SockAddrUnix {}  -> (IO.AF_UNIX, IO.defaultProtocol)

-- | Open a socket at the specified port for listening
--
-- Binds an IPv4 stream socket to @127.0.0.1@ on the given port, enables
-- @SO_REUSEADDR@, and starts listening with a backlog of 2. The caller owns
-- the returned socket and is responsible for closing it.
--
-- If enabling the option, binding or listening throws, the socket is closed
-- before the exception is rethrown, so no descriptor is leaked.
listenOn :: Int -> IO Socket
listenOn n = do
  -- Resolve before creating the socket so a resolution failure cannot leak it.
  sockAddrInfo:_ <- IO.getAddrInfo Nothing (Just "127.0.0.1") (Just (show n))
  IO.bracketOnError (IO.socket AF_INET Stream 0) IO.close $ \sock -> do
    IO.setSocketOption sock IO.ReuseAddr 1
    IO.bind sock (IO.addrAddress sockAddrInfo)
    IO.listen sock 2
    return sock

-- | Report whether a filesystem entry exists at the given path.
--
-- This simply delegates to 'IO.doesFileExist'. According to the
-- System.Directory documentation, that returns 'True' for an existing path
-- that is not a directory. It does not verify that the entry is a socket.
doesSocketExist :: FilePath -> IO Bool
doesSocketExist path = IO.doesFileExist path
{-# INLINE doesSocketExist #-}

-- | Allocate the specified number of random ports
--
-- Each port is obtained by binding a fresh stream socket on @127.0.0.1@ to
-- port @0@ and asking which port was assigned. All sockets stay bound until
-- every port has been read, so the returned ports are pairwise distinct. The
-- sockets are then closed, releasing the ports.
--
-- Because the ports are released before this function returns, another
-- process could claim one of them before the caller binds it.
--
-- Every socket opened here is closed, even when creating, binding or querying
-- a later socket throws. A non-positive count yields @[]@.
allocateRandomPorts :: Int -> IO [Int]
allocateRandomPorts n = do
  let hints :: IO.AddrInfo
      hints = IO.defaultHints
        { IO.addrFlags = [IO.AI_PASSIVE]
        , IO.addrSocketType = IO.Stream
        }

  addr:_ <- IO.getAddrInfo (Just hints) (Just "127.0.0.1") (Just "0")
  ports <- withBoundSockets addr [1..n] (mapM IO.socketPort)
  return (fmap fromIntegral ports)
  where
    -- Open and bind one socket per list element, run the action on all of
    -- them (in creation order), then close every socket. Nesting 'IO.bracket'
    -- guarantees that sockets already opened are closed if a later step fails.
    withBoundSockets :: IO.AddrInfo -> [Int] -> ([Socket] -> IO a) -> IO a
    withBoundSockets _ [] action = action []
    withBoundSockets addr (_:rest) action =
      IO.bracket
        (IO.socket (IO.addrFamily addr) (IO.addrSocketType addr) (IO.addrProtocol addr))
        IO.close
        $ \sock -> do
            -- Bind inside the bracket body so a failed bind still closes sock.
            IO.bind sock (IO.addrAddress addr)
            withBoundSockets addr rest (\socks -> action (sock : socks))