-- | Compatibility layer for the network package, including the 'PortID' data type
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving, OverloadedStrings #-}

module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo, 
                                          lookupReplicaSetName, lookupSeedList) where


#if !MIN_VERSION_network(2, 9, 0)

import qualified Network as N
import System.IO (Handle)

#else

import Control.Exception (bracketOnError)
import Network.BSD as BSD
import qualified Network.Socket as N
import System.IO (Handle, IOMode(ReadWriteMode))

#endif

import Data.ByteString.Char8 (pack, unpack)
import Data.List (dropWhileEnd, lookup)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Network.DNS.Lookup (lookupSRV, lookupTXT)
import Network.DNS.Resolver (defaultResolvConf, makeResolvSeed, withResolver)
import Network.HTTP.Types.URI (parseQueryText)


-- | Identifies the endpoint to connect to on a host.
--
-- Wraps network's 'N.PortNumber' so that the rest of the driver can use one
-- type regardless of which version of the network package is in use.
-- The 'UnixSocket' constructor is compiled in only on non-Windows platforms.
data PortID = PortNumber N.PortNumber  -- ^ A TCP port number
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
            | UnixSocket String        -- ^ Filesystem path of a Unix domain socket
#endif
            deriving (Eq, Ord, Show)


#if !MIN_VERSION_network(2, 9, 0)

-- Translate our 'PortID' into network's own 'N.PortID' and delegate to
-- network's 'N.connectTo'.
connectTo :: N.HostName         -- ^ Hostname (not used for Unix sockets)
          -> PortID             -- ^ Port identifier
          -> IO Handle          -- ^ Handle of the connected socket
connectTo hostname portId = case portId of
    PortNumber port -> N.connectTo hostname (N.PortNumber port)
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    -- A Unix socket is addressed by path only, so an empty host name is passed.
    UnixSocket path -> N.connectTo "" (N.UnixSocket path)
#endif

#else

-- Copied implementation from network 2.8's 'connectTo', but using our own 'PortID' type.
-- https://github.com/haskell/network/blob/e73f0b96c9da924fe83f3c73488f7e69f712755f/Network.hs#L120-L129
--
-- In both equations the socket is closed only if connecting (or converting the
-- socket to a 'Handle') throws; on success the returned 'Handle' is what the
-- caller holds on to.
connectTo :: N.HostName         -- ^ Hostname (ignored for Unix sockets)
          -> PortID             -- ^ Port identifier
          -> IO Handle          -- ^ Handle of the connected socket
connectTo hostname (PortNumber port) = do
    proto <- BSD.getProtocolNumber "tcp"
    bracketOnError
        (N.socket N.AF_INET N.Stream proto)
        N.close  -- only done if there's an error
        (\sock -> do
            -- Resolve the host name and connect to the address that
            -- 'BSD.hostAddress' selects from the host entry.
            he <- BSD.getHostByName hostname
            N.connect sock (N.SockAddrInet port (BSD.hostAddress he))
            N.socketToHandle sock ReadWriteMode
        )

#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
connectTo _ (UnixSocket path) =
    bracketOnError
        (N.socket N.AF_UNIX N.Stream 0)
        N.close  -- only done if there's an error
        (\sock -> do
            N.connect sock (N.SockAddrUnix path)
            N.socketToHandle sock ReadWriteMode
        )
#endif

#endif

-- * Host

-- | A server address: a host name paired with the 'PortID' to connect to.
data Host = Host N.HostName PortID
    deriving (Show, Eq, Ord)

lookupReplicaSetName :: N.HostName -> IO (Maybe Text)
-- ^ Retrieves the replica set name from the TXT DNS record for the given hostname.
--
-- Only the first TXT record returned is inspected; it is parsed as a URL
-- query string and the value of its @replicaSet@ key is returned.
-- The result is 'Nothing' when the DNS lookup fails, when no TXT record is
-- returned, or when the first record has no @replicaSet@ key or the key
-- carries no value.
lookupReplicaSetName hostname = do
  seed <- makeResolvSeed defaultResolvConf
  res <- withResolver seed $ \resolver -> lookupTXT resolver (pack hostname)
  pure $ case res of
    -- 'lookup' yields @Maybe (Maybe Text)@ because a query key may be present
    -- without a value; both kinds of absence collapse to 'Nothing'.
    Right (txt : _) -> fromMaybe Nothing (lookup "replicaSet" (parseQueryText txt))
    -- DNS error or empty answer.
    _ -> Nothing

lookupSeedList :: N.HostName -> IO [Host]
-- ^ Retrieves the replica set seed list from the SRV DNS record for the given hostname.
--
-- The name queried is the hostname prefixed with @_mongodb._tcp.@. Each SRV
-- answer becomes one 'Host' built from its target name and port.
-- A failed DNS lookup yields an empty list.
lookupSeedList hostname = do
  seed <- makeResolvSeed defaultResolvConf
  res <- withResolver seed $ \resolver ->
    lookupSRV resolver (pack ("_mongodb._tcp." ++ hostname))
  pure $ either (const []) (map toHost) res
  where
    -- The first two fields of each SRV answer are ignored (presumably priority
    -- and weight - confirm against the dns package); only the port and target
    -- are used. Trailing dots are stripped from the target, turning a
    -- fully-qualified DNS name into a plain host name.
    toHost (_, _, port, target) =
      Host (dropWhileEnd (== '.') (unpack target)) (PortNumber (fromIntegral port))