--------------------------------------------------------------------------------
-- | Lightweight abstraction over an input/output stream.
{-# LANGUAGE CPP #-}
module Network.WebSockets.Stream
    ( Stream
    , makeStream
    , makeSocketStream
    , makeEchoStream
    , parse
    , parseBin
    , write
    , close
    ) where

import           Control.Concurrent.MVar        (MVar, newEmptyMVar, newMVar,
                                                 putMVar, takeMVar, withMVar)
import           Control.Exception              (SomeException, SomeAsyncException, throwIO, catch, try, fromException)
import           Control.Monad                  (forM_)
import qualified Data.Attoparsec.ByteString     as Atto
import qualified Data.Binary.Get                as BIN
import qualified Data.ByteString                as B
import qualified Data.ByteString.Lazy           as BL
import           Data.IORef                     (IORef, atomicModifyIORef',
                                                 newIORef, readIORef,
                                                 writeIORef)
import qualified Network.Socket                 as S
import qualified Network.Socket.ByteString      as SB (recv)

#if !defined(mingw32_HOST_OS)
import qualified Network.Socket.ByteString.Lazy as SBL (sendAll)
#else
import qualified Network.Socket.ByteString      as SB (sendAll)
#endif
import           System.IO.Error                (isResourceVanishedError)

import           Network.WebSockets.Types


--------------------------------------------------------------------------------
-- | Bookkeeping for a 'Stream': whether it is still open, together with any
-- input that has been received but not yet consumed by a parser.
data StreamState
    = Closed !B.ByteString
      -- ^ The stream has been closed; the payload is leftover input that
      -- can still be parsed.
    | Open !B.ByteString
      -- ^ The stream is open; the payload is buffered, not-yet-parsed input.


--------------------------------------------------------------------------------
-- | Lightweight abstraction over an input/output stream.
--
-- Values are created with 'makeStream', which wraps a raw receive action and
-- a raw send action with locking and open/closed bookkeeping.
data Stream = Stream
    { streamIn    :: IO (Maybe B.ByteString)
      -- ^ Read the next chunk; 'Nothing' means there is no more input.
    , streamOut   :: Maybe BL.ByteString -> IO ()
      -- ^ Write a chunk, or close the stream when given 'Nothing'.
    , streamState :: !(IORef StreamState)
      -- ^ Whether the stream is open, plus input received but not yet parsed.
    }


--------------------------------------------------------------------------------
-- | Create a stream from a \"receive\" and a \"send\" action. The resulting
-- 'Stream' has the following properties:
--
-- - Regardless of the provided actions, receiving and sending through the
--   stream are thread-safe: this function allocates one lock for receiving
--   and one for sending, and every call to the stream's receive or send
--   action holds the corresponding lock.
--
-- - Receiving from, or sending a chunk to, a closed 'Stream' always throws
--   'ConnectionClosed', even if the underlying actions would not (the
--   open/closed bookkeeping is done here).
--
-- - A synchronous exception from either underlying action marks the stream
--   as closed before the exception is rethrown; asynchronous exceptions are
--   rethrown without touching the state.
--
-- - Streams should always be closed.
makeStream
    :: IO (Maybe B.ByteString)         -- ^ Reading
    -> (Maybe BL.ByteString -> IO ())  -- ^ Writing
    -> IO Stream                       -- ^ Resulting stream
makeStream receive send = do
    stateRef  <- newIORef (Open B.empty)
    readLock  <- newMVar ()
    writeLock <- newMVar ()
    return Stream
        { streamIn    = guardedReceive stateRef readLock
        , streamOut   = guardedSend stateRef writeLock
        , streamState = stateRef
        }
  where
    -- Switch to 'Closed', keeping the buffered bytes as the remainder.
    -- Closing an already closed stream leaves it unchanged.
    markClosed :: IORef StreamState -> IO ()
    markClosed ref = atomicModifyIORef' ref $ \st -> (Closed (pending st), ())

    pending :: StreamState -> B.ByteString
    pending (Open   buf) = buf
    pending (Closed buf) = buf

    -- Throw 'ConnectionClosed' if the stream is not 'Open'.
    ensureOpen :: IORef StreamState -> IO ()
    ensureOpen ref = readIORef ref >>= check
      where
        check (Open   _) = return ()
        check (Closed _) = throwIO ConnectionClosed

    guardedReceive :: IORef StreamState -> MVar () -> IO (Maybe B.ByteString)
    guardedReceive ref lock = withMVar lock $ \() -> do
        ensureOpen ref
        mbChunk <- receive `onSyncError` markClosed ref
        -- End of input from the underlying action also closes the stream.
        maybe (markClosed ref >> return Nothing) (return . Just) mbChunk

    guardedSend :: IORef StreamState -> MVar () -> Maybe BL.ByteString -> IO ()
    guardedSend ref lock mbPayload = withMVar lock $ \() -> do
        -- 'Nothing' is a close request and is accepted on a closed stream;
        -- an actual payload requires the stream to still be open.
        maybe (markClosed ref) (const (ensureOpen ref)) mbPayload
        send mbPayload `onSyncError` markClosed ref

    -- Run 'action'. If it throws a synchronous exception, run 'cleanup'
    -- before rethrowing; asynchronous exceptions are rethrown as-is.
    onSyncError :: IO a -> IO b -> IO a
    onSyncError action cleanup = action `catch` \err -> do
        maybe (() <$ cleanup) (const (pure ()))
            (fromException (err :: SomeException) :: Maybe SomeAsyncException)
        throwIO err


--------------------------------------------------------------------------------
-- | Build a 'Stream' on top of a connected socket.
--
-- Each read receives at most 8192 bytes. A zero-length read or a
-- \"resource vanished\" 'IOError' is reported as end of input ('Nothing');
-- any other 'IOError' is rethrown. Closing the returned 'Stream' does not
-- close the socket itself (the send action ignores 'Nothing').
makeSocketStream :: S.Socket -> IO Stream
makeSocketStream socket = makeStream receive send
  where
    receive :: IO (Maybe B.ByteString)
    receive = try (SB.recv socket 8192) >>= either onError onChunk

    onError :: IOError -> IO (Maybe B.ByteString)
    onError err
        -- The peer went away: treat it like an orderly end of input.
        | isResourceVanishedError err = return Nothing
        | otherwise                   = throwIO err

    onChunk :: B.ByteString -> IO (Maybe B.ByteString)
    onChunk chunk
        | B.null chunk = return Nothing
        | otherwise    = return (Just chunk)

    send :: Maybe BL.ByteString -> IO ()
    send = maybe (return ()) sendPayload

    sendPayload :: BL.ByteString -> IO ()
#if !defined(mingw32_HOST_OS)
    sendPayload = SBL.sendAll socket
#else
    sendPayload payload = forM_ (BL.toChunks payload) (SB.sendAll socket)
#endif


--------------------------------------------------------------------------------
-- | A 'Stream' whose output is read back as its input.
--
-- Each chunk of a written lazy 'BL.ByteString' is handed over, in order,
-- through a single 'MVar', so a write blocks until the previously written
-- chunk has been taken by a read. Closing the stream puts 'Nothing' into
-- the same 'MVar'.
makeEchoStream :: IO Stream
makeEchoStream = do
    slot <- newEmptyMVar
    let echo = maybe (putMVar slot Nothing)
                     (\payload -> forM_ (BL.toChunks payload) (putMVar slot . Just))
    makeStream (takeMVar slot) echo


--------------------------------------------------------------------------------
-- | Decode one value from the stream with a "Data.Binary.Get" decoder.
--
-- Input is taken first from the stream's buffer and then, whenever the
-- decoder asks for more, from the stream's receive action. The result is:
--
-- * 'Nothing' when the stream has ended and no buffered input is left;
--
-- * @'Just' x@ on success. The unconsumed bytes are stored back into the
--   stream's state: as the 'Closed' remainder if the stream was already
--   closed or end of input was hit while decoding, otherwise as the 'Open'
--   buffer.
--
-- A decoder failure is thrown as a 'ParseException'.
parseBin :: Stream -> BIN.Get a -> IO (Maybe a)
parseBin stream parser = readIORef stateRef >>= start
  where
    stateRef = streamState stream

    -- A fresh decoder primed with the given bytes.
    feed = BIN.pushChunk (BIN.runGetIncremental parser)

    start (Closed leftover)
        | B.null leftover = return Nothing
        | otherwise       = step (feed leftover) True
    start (Open buffered)
        | B.null buffered =
            streamIn stream >>= maybe endOfInput (\chunk -> step (feed chunk) False)
        | otherwise       = step (feed buffered) False

    -- Nothing buffered and nothing more to read.
    endOfInput = writeIORef stateRef (Closed B.empty) >> return Nothing

    -- Drive the decoder; the flag records whether end of input has been
    -- seen. The bytes held in the stored state have already been fed to the
    -- decoder, so on success the state is replaced by the decoder's leftover.
    step :: BIN.Decoder b -> Bool -> IO (Maybe b)
    step (BIN.Done rest _ value) isClosed = do
        writeIORef stateRef (if isClosed then Closed rest else Open rest)
        return (Just value)
    step (BIN.Partial resume) isClosed
        | isClosed  = step (resume Nothing) True
        | otherwise = streamIn stream >>= \mbChunk -> case mbChunk of
            Nothing    -> step (resume Nothing) True
            Just chunk -> step (resume (Just chunk)) False
    step (BIN.Fail _ _ msg) _ = throwIO (ParseException msg)


-- | Run one attoparsec parser against the stream.
--
-- Input is taken first from the stream's buffer and then, whenever the
-- parser asks for more, from the stream's receive action. The result is:
--
-- * 'Nothing' when the stream has ended and no buffered input is left;
--
-- * @'Just' x@ on success. The unconsumed bytes are stored back into the
--   stream's state: as the 'Closed' remainder if the stream was already
--   closed or end of input was hit while parsing, otherwise as the 'Open'
--   buffer.
--
-- A parse failure is thrown as a 'ParseException'.
parse :: Stream -> Atto.Parser a -> IO (Maybe a)
parse stream parser = readIORef stateRef >>= start
  where
    stateRef = streamState stream

    -- Start the parser on the given bytes.
    feed = Atto.parse parser

    start (Closed leftover)
        | B.null leftover = return Nothing
        | otherwise       = step (feed leftover) True
    start (Open buffered)
        | B.null buffered =
            streamIn stream >>= maybe endOfInput (\chunk -> step (feed chunk) False)
        | otherwise       = step (feed buffered) False

    -- Nothing buffered and nothing more to read.
    endOfInput = writeIORef stateRef (Closed B.empty) >> return Nothing

    -- Drive the parser; the flag records whether end of input has been
    -- seen. The bytes held in the stored state have already been fed to the
    -- parser, so on success the state is replaced by the parser's leftover.
    -- Feeding 'B.empty' to a partial result tells attoparsec input is over.
    step :: Atto.Result b -> Bool -> IO (Maybe b)
    step (Atto.Done rest value) isClosed = do
        writeIORef stateRef (if isClosed then Closed rest else Open rest)
        return (Just value)
    step (Atto.Partial resume) isClosed
        | isClosed  = step (resume B.empty) True
        | otherwise = streamIn stream >>= \mbChunk -> case mbChunk of
            Nothing    -> step (resume B.empty) True
            Just chunk -> step (resume chunk) False
    step (Atto.Fail _ _ msg) _ = throwIO (ParseException msg)


--------------------------------------------------------------------------------
-- | Send a lazy 'BL.ByteString' over the stream.
--
-- The stream's send action (as built by 'makeStream') throws
-- 'ConnectionClosed' if the stream has already been closed.
write :: Stream -> BL.ByteString -> IO ()
write stream payload = streamOut stream (Just payload)


--------------------------------------------------------------------------------
-- | Close the stream by passing 'Nothing' to its send action.
--
-- The stream is marked closed (after which 'write' throws
-- 'ConnectionClosed'), and the underlying send action receives 'Nothing'.
close :: Stream -> IO ()
close = flip streamOut Nothing