{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE UndecidableInstances  #-}

module Metro.UDPServer
  ( UDPServer
  , udpServer
  , newClient
  ) where

import           Control.Monad             (void)
import           Data.ByteString           (empty)
import           Data.Hashable
import           Metro.Class               (GetPacketId, RecvPacket,
                                            Servable (..), Transport,
                                            TransportConfig)
import           Metro.Conn
import           Metro.IOHashMap           (IOHashMap, newIOHashMap)
import qualified Metro.IOHashMap           as HM (delete, insert, lookup)
import           Metro.Node                (NodeEnv1)
import           Metro.Server              (ServerT, getServ, handleConn,
                                            serverEnv)
import           Metro.Session             (SessionT)
import           Metro.Socket              (bindTo, getDatagramAddr)
import           Metro.TP.BS               (BSHandle, bsTransportConfig,
                                            closeBSHandle, feed, newBSHandle)
import           Metro.TP.UDPSocket        (UDPSocket, udpSocket_)
import           Network.Socket            (SockAddr, Socket, addrAddress)
import qualified Network.Socket            as Socket (close)
import           Network.Socket.ByteString (recvFrom, sendAllTo)
import           System.Log.Logger         (errorM)
import           UnliftIO

-- | A running UDP server.
--
-- Fields, in order:
--
-- * the bound 'Socket' that all datagrams are received on and sent from;
-- * the peer table: one 'BSHandle' per remote peer, keyed by @show@ of the
--   peer's 'SockAddr'; incoming datagrams are fed to the peer's handle.
data UDPServer
  = UDPServer
      Socket
      (IOHashMap String BSHandle)

-- | A UDP server multiplexes every peer over a single bound socket.
-- Sessions are identified by the peer's 'SockAddr'.
instance Servable UDPServer where
  -- The string is passed unchanged to 'bindTo' by 'newServer'.
  data ServerConfig UDPServer = UDPConfig String
  type SID UDPServer = SockAddr
  type STP UDPServer = UDPSocket

  -- Bind the server socket and start with an empty peer table.
  newServer (UDPConfig hostPort) = do
    sock <- liftIO $ bindTo hostPort
    UDPServer sock <$> newIOHashMap

  -- Receive one datagram and route it by sender address.
  --
  -- * Datagrams from a known peer are fed to that peer's handle.
  -- * The first datagram from a new peer creates a handle, registers it,
  --   and runs the session via @done@ on a separate thread.
  servOnce us@(UDPServer serv handleList) done = do
    (bs, addr) <- liftIO $ recvFrom serv maxDatagramSize
    known <- HM.lookup handleList $ show addr
    case known of
      Just h  -> feed h bs
      Nothing -> do
        -- Register the handle on this thread, before forking. Registering
        -- it inside the forked thread left a window in which more datagrams
        -- from the same peer missed the lookup above. Each of those spawned
        -- a duplicate session and overwrote the peer's table entry.
        h <- newBSHandle bs
        config <- newTransportConfig us addr h
        -- 'finally' closes the handle even if the session throws.
        void . async $ done (Just (addr, config)) `finally` closeBSHandle h
    where
      -- The UDP length field is 16 bits, so a datagram never exceeds 65535
      -- bytes (IPv6 jumbograms, RFC 2675, aside). 'recvFrom' allocates the
      -- full requested size on every call, so a larger buffer only adds
      -- allocation.
      maxDatagramSize :: Int
      maxDatagramSize = 65535

  onConnEnter _ _ = pure ()

  -- Forget the peer, so that its next datagram starts a fresh session.
  onConnLeave (UDPServer _ handleList) addr = HM.delete handleList (show addr)

  servClose (UDPServer serv _) = liftIO $ Socket.close serv

-- | Build the configuration for a 'UDPServer'.
--
-- The host/port string is stored as-is. It is handed to 'bindTo' when the
-- server is created.
udpServer :: String -> ServerConfig UDPServer
udpServer hostPort = UDPConfig hostPort

-- | Attach a peer to the server and build its transport configuration.
--
-- 1. @h@ is registered in the server's peer table under @show addr@.
--    Datagrams later received from @addr@ are then fed to it.
-- 2. The returned configuration wraps @h@. Outgoing bytes are sent to
--    @addr@ through the server's own socket with 'sendAllTo'.
newTransportConfig
  :: (MonadIO m)
  => UDPServer
  -> SockAddr
  -> BSHandle
  -> m (TransportConfig UDPSocket)
newTransportConfig (UDPServer sock handleList) addr h = do
  HM.insert handleList (show addr) h
  pure . udpSocket_ $ bsTransportConfig h sendToPeer
  where
    -- Every outgoing payload goes to this one peer.
    sendToPeer payload = sendAllTo sock payload addr

-- | Open an outgoing session to the UDP peer at @hostPort@.
--
-- The session shares the server's socket: its receive handle is registered
-- in the server's peer table, so replies from the peer are routed to it by
-- the server loop. @mk@ adapts the raw UDP transport configuration to the
-- session's transport.
--
-- Returns 'Nothing' after logging an error when 'getDatagramAddr' yields no
-- address. Otherwise it returns the node environment created by
-- 'handleConn'; the accompanying 'Async' is discarded.
newClient
  :: (MonadUnliftIO m, Transport tp, Show nid, Eq nid, Hashable nid, Eq k, Hashable k, GetPacketId k rpkt, RecvPacket rpkt)
  => (TransportConfig UDPSocket -> TransportConfig tp)
  -> String
  -> nid
  -> u
  -> (rpkt -> m Bool)
  -> SessionT u nid k rpkt tp m ()
  -> ServerT UDPServer u nid k rpkt tp m (Maybe (NodeEnv1 u nid k rpkt tp))
newClient mk hostPort nid uEnv preprocess sess =
  liftIO (getDatagramAddr hostPort) >>= maybe onFailure connectTo
  where
    -- No address for hostPort: log it and report failure.
    onFailure = do
      liftIO . errorM "Metro.UDP" $ "Connect UDP Server " ++ hostPort ++ " failed"
      pure Nothing

    -- Register a fresh, empty handle for the peer and start the session.
    connectTo addrInfo = do
      let peer = addrAddress addrInfo
      us <- getServ <$> serverEnv
      h <- newBSHandle empty
      config <- mk <$> newTransportConfig us peer h
      connEnv <- initConnEnv config
      Just . fst <$> handleConn "Server" peer connEnv nid uEnv preprocess sess