{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Hedgehog.Extras.Stock.IO.Network.Socket
  ( doesSocketExist
  , isPortOpen
  , canConnect
  , listenOn
  , allocateRandomPorts
  ) where

import           Control.Exception (IOException, handle)
import           Control.Monad
import           Data.Bool
import           Data.Function
import           Data.Functor
import           Data.Int
import           Data.Maybe
import           Network.Socket (Family (AF_INET), SockAddr (..), Socket, SocketType (Stream))
import           Prelude (fromIntegral)
import           System.IO (FilePath, IO)
import           Text.Show (show)

import qualified Network.Socket as IO
import qualified System.Directory as IO
import qualified UnliftIO.Exception as IO

-- | Check whether a TCP port on @127.0.0.1@ is accepting connections.
--
-- Resolves @127.0.0.1:port@ and attempts a connection to the first resolved
-- address via 'canConnect'.  Returns 'True' when the connection succeeds, and
-- 'False' when connecting fails with an 'IOException' or when resolution
-- yields no addresses.  Exceptions raised by address resolution itself are not
-- caught here and propagate to the caller.
isPortOpen :: Int -> IO Bool
isPortOpen port = do
  addrInfos <- IO.getAddrInfo Nothing (Just "127.0.0.1") (Just (show port))
  case addrInfos of
    [] -> return False
    firstAddrInfo : _ ->
      handle onConnectFailure $ do
        canConnect (IO.addrAddress firstAddrInfo)
        return True
  where
    -- A failed connection attempt surfaces as an IOException; report "closed".
    onConnectFailure :: IOException -> IO Bool
    onConnectFailure _ = return False

-- | Attempt a connection to the given socket address.
--
-- Opens a fresh @AF_INET@ stream socket (protocol number 6) and connects it to
-- @sockAddr@.  The socket is released with @close'@ whether or not the
-- connection succeeds.  Any exception raised while connecting propagates to
-- the caller; on success the result is @()@.
canConnect :: SockAddr -> IO ()
canConnect sockAddr =
  IO.bracket openStreamSocket IO.close' connectTo
  where
    openStreamSocket = IO.socket AF_INET Stream 6
    connectTo sock = IO.connect sock sockAddr

-- | Open a stream socket bound to @127.0.0.1@ at the specified port and start
-- listening on it (listen backlog of 2).
--
-- The address is resolved before the socket is created, and @ReuseAddr@ is
-- enabled before binding.  If setting the option, binding or listening throws,
-- the socket is closed before the exception is rethrown, so a failed call does
-- not leak the descriptor.  On success the caller owns the returned socket and
-- is responsible for closing it.
listenOn :: Int -> IO Socket
listenOn n = do
  -- Resolve first: if resolution fails there is no socket to clean up yet.
  sockAddrInfo : _ <- IO.getAddrInfo Nothing (Just "127.0.0.1") (Just (show n))
  -- bracketOnError only runs the release action when setup throws; on success
  -- the open socket is handed back to the caller.
  IO.bracketOnError (IO.socket AF_INET Stream 0) IO.close $ \sock -> do
    IO.setSocketOption sock IO.ReuseAddr 1
    IO.bind sock (IO.addrAddress sockAddrInfo)
    IO.listen sock 2
    return sock

-- | Check whether a filesystem entry exists at the given path, used here as a
-- proxy for the existence of a Unix domain socket file.
--
-- This delegates to 'IO.doesFileExist', so it does not verify that the entry
-- is actually a socket.
doesSocketExist :: FilePath -> IO Bool
doesSocketExist path = IO.doesFileExist path
{-# INLINE doesSocketExist #-}

-- | Allocate the specified number of random ports.
--
-- Each port is obtained by binding a stream socket on @127.0.0.1@ to port 0
-- and reading back the port the OS assigned.  All sockets stay bound until
-- every port has been read, so the returned ports are distinct from one
-- another.  The sockets are all closed before this returns, so the ports are
-- not reserved afterwards.  A non-positive count yields @[]@.
--
-- Every socket is released via 'IO.bracket'.  If an exception occurs part-way
-- through (e.g. a failing @socket@ or @bind@), the sockets opened so far are
-- closed instead of leaked.
allocateRandomPorts :: Int -> IO [Int]
allocateRandomPorts n = do
  let hints = IO.defaultHints
        { IO.addrFlags = [IO.AI_PASSIVE]
        , IO.addrSocketType = IO.Stream
        }

  addr : _ <- IO.getAddrInfo (Just hints) (Just "127.0.0.1") (Just "0")

  -- Open one socket per list element.  Recursing inside the bracket keeps the
  -- earlier sockets bound while later ones are allocated (so the OS cannot hand
  -- out the same port twice), and closes all of them on the way back out.
  let bindAll :: [Int] -> IO [Int]
      bindAll [] = return []
      bindAll (_ : rest) =
        IO.bracket
          (IO.socket (IO.addrFamily addr) (IO.addrSocketType addr) (IO.addrProtocol addr))
          IO.close
          $ \sock -> do
            IO.bind sock (IO.addrAddress addr)
            port <- IO.socketPort sock
            ports <- bindAll rest
            return (fromIntegral port : ports)

  bindAll [1 .. n]