{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}

{-|
Module      : Reflex.Backend.Socket
Description : Wrap a TCP socket to communicate through @Event t ByteString@
Copyright   : (c) 2018-2019, Commonwealth Scientific and Industrial Research Organisation
License     : BSD3
Maintainer  : dave.laing.80@gmail.com, jack.kelly@data61.csiro.au
Stability   : experimental
Portability : non-portable

Use 'socket' to wrap a network 'Socket' so that it sends out the
firings of an @'Event' t 'ByteString'@, and fires any data that it
receives on another @'Event' t 'ByteString'@.
-}

module Reflex.Backend.Socket
  ( socket

    -- * Socket configuration
  , SocketConfig(..)

    -- * Socket output events
  , Socket(..)

    -- * Lenses
    -- ** 'SocketConfig'
  , scInitSocket
  , scMaxRx
  , scSend
  , scClose

    -- ** 'Socket'
  , sReceive
  , sOpen
  , sClose
  , sError

    -- * Convenience re-exports
  , module Reflex.Backend.Socket.Accept
  , module Reflex.Backend.Socket.Connect
  , module Reflex.Backend.Socket.Error
  ) where

import           Control.Concurrent (forkIO)
import qualified Control.Concurrent.STM as STM
import           Control.Exception (IOException, try)
import           Control.Lens.TH (makeLenses)
import           Control.Monad.IO.Class (MonadIO(..))
import           Control.Monad.STM (atomically)
import           Data.Align (align)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import           Data.Functor (($>), (<&>), void)
import           Data.These
import qualified Network.Socket as NS
import           Network.Socket.ByteString (sendAll, recv)
import           Reflex
import           Reflex.Backend.Socket.Accept
import           Reflex.Backend.Socket.Connect
import           Reflex.Backend.Socket.Error

-- | Holds the socket to wire into the FRP network, and events that
-- drive it.
data SocketConfig t = SocketConfig
  { _scInitSocket :: NS.Socket
    -- ^ Socket to wrap.
  , _scMaxRx :: Int
    -- ^ Maximum number of bytes to read at a time.
  , _scSend :: Event t ByteString
    -- ^ Data to send out on this socket.
  , _scClose :: Event t ()
    -- ^ Ask to close the socket. The socket will stop trying to
    -- receive data (and the '_sReceive' event will stop firing, though
    -- a receive that is already in progress may still deliver its
    -- data), and the socket will be "drained": future events on
    -- '_scSend' will be ignored, and it will close after writing all
    -- pending data. If '_scSend' and '_scClose' fire in the same frame,
    -- the data will nevertheless be queued for sending.
  }

$(makeLenses ''SocketConfig)

-- | Events produced by an active socket.
data Socket t = Socket
  { _sReceive :: Event t ByteString
    -- ^ Data has arrived.
  , _sOpen :: Event t ()
    -- ^ The socket has opened, and its receive/send loops are running.
  , _sClose :: Event t ()
    -- ^ The socket has closed. This will fire exactly once when the
    -- socket closes for any reason, including if your '_scClose'
    -- event fires, the other end disconnects, or if the socket closes
    -- in response to a caught exception.
  , _sError :: Event t IOException
    -- ^ An exception occurred. Treat the socket as closed after you
    -- see this. If the socket was open, you will see the '_sClose'
    -- event fire as well, but not necessarily in the same frame.
  }

$(makeLenses ''Socket)

-- | Where a wrapped socket is in its lifecycle, as tracked by 'socket'.
data SocketState
  -- | Data flows in both directions.
  = Open
  -- | We've been asked to close, but will transmit all pending data
  -- first (and not accept any more).
  | Draining
  -- | Hard close. Don't transmit pending data.
  | Closed

-- | Wire a socket into the FRP network. You will likely use this to
-- attach events to a socket that you just connected (from
-- 'Reflex.Backend.Socket.Connect.connect'), or a socket that you just
-- accepted (from the 'Reflex.Backend.Socket.Accept._aAcceptSocket'
-- event you got when you called
-- 'Reflex.Backend.Socket.Accept.accept').
--
-- The transmit, receive and close-watching threads are forked when the
-- post-build event fires; '_sOpen' fires once they have been started.
socket
  :: forall t m.
     ( Reflex t
     , PerformEvent t m
     , PostBuild t m
     , TriggerEvent t m
     , MonadIO (Performable m)
     , MonadIO m
     )
  => SocketConfig t
  -> m (Socket t)
socket (SocketConfig sock maxRx eTx eClose) = do
  (eRx, onRx) <- newTriggerEvent
  (eOpen, onOpen) <- newTriggerEvent
  (eClosed, onClosed) <- newTriggerEvent
  (eError, onError) <- newTriggerEvent

  -- Payloads accepted from '_scSend', waiting for the transmit thread.
  payloadQueue <- liftIO STM.newTQueueIO
  -- Lifecycle state, read and written by the worker threads below and
  -- by the FRP-side handlers further down.
  state <- liftIO $ STM.newTVarIO Open

  let
    -- Fork the worker threads, then announce that the socket is open.
    start :: Performable m ()
    start = liftIO $ do
      void $ forkIO txLoop
      void $ forkIO rxLoop
      void $ forkIO closeSentinel
      onOpen ()

      where
        -- Send queued payloads. While 'Open', block until something is
        -- queued; while 'Draining', send what remains and shut down once
        -- the queue is empty; once 'Closed', stop immediately.
        txLoop :: IO ()
        txLoop =
          let
            loop :: IO ()
            loop = do
              mBytes <- atomically $
                STM.readTVar state >>= \case
                  Closed -> pure Nothing
                  Draining -> STM.tryReadTQueue payloadQueue
                  Open -> STM.tryReadTQueue payloadQueue
                    >>= maybe STM.retry (pure . Just)

              case mBytes of
                Nothing -> shutdown
                Just bs ->
                  try (sendAll sock bs) >>= \case
                    Left exc -> onError exc *> shutdown
                    Right () -> loop
          in loop

        -- Receive data while the socket is 'Open'. An empty read means
        -- the other end has closed, so shut down.
        rxLoop :: IO ()
        rxLoop =
          let
            loop :: IO ()
            loop = atomically (STM.readTVar state) >>= \case
              Open -> try (recv sock maxRx) >>= \case
                Left exc -> onRxError exc
                Right bs
                  | B.null bs -> shutdown
                  | otherwise -> onRx bs *> loop
              _ -> pure ()

            -- 'closeSentinel' only closes the socket after the state has
            -- become 'Closed', and this thread may still be blocked in
            -- 'recv' at that moment. An exception seen after the state is
            -- already 'Closed' is presumably caused by that close (e.g.
            -- after a graceful '_scClose' drain), so don't report it as
            -- a spurious or duplicate '_sError'.
            onRxError :: IOException -> IO ()
            onRxError exc = atomically (STM.readTVar state) >>= \case
              Closed -> pure ()
              _ -> onError exc *> shutdown
          in loop

        -- Wait until the state is 'Closed', then close the underlying
        -- socket (ignoring any 'IOException' from doing so) and fire
        -- '_sClose'.
        closeSentinel :: IO ()
        closeSentinel = do
          atomically $ STM.readTVar state >>= \case
            Closed -> pure ()
            _ -> STM.retry

          void . try @IOException $ NS.close sock
          onClosed ()

        -- Hard close: stops both loops and releases 'closeSentinel'.
        shutdown :: IO ()
        shutdown = atomically $ STM.writeTVar state Closed

  ePostBuild <- getPostBuild
  performEvent_ $ ePostBuild $> start

  -- If we see a tx and a close event in the same frame, we want to
  -- process the tx before the close, so it doesn't get lost.
  let
    eTxOrClose :: Event t (These ByteString ())
    eTxOrClose = align eTx eClose

    -- New payloads are only accepted while fully 'Open'; once draining
    -- or closed, further sends are dropped.
    queueSend :: ByteString -> STM.STM ()
    queueSend bs = STM.readTVar state >>= \case
      Open -> STM.writeTQueue payloadQueue bs
      _ -> pure ()

    -- Request a graceful close; a socket that is already draining or
    -- closed is left as it is.
    queueClose :: STM.STM ()
    queueClose = STM.modifyTVar' state $ \case
      Open -> Draining
      s -> s

  performEvent_ $ eTxOrClose <&> liftIO . atomically . \case
    This bs -> queueSend bs
    That () -> queueClose
    These bs () -> queueSend bs *> queueClose

  pure $ Socket eRx eOpen eClosed eError