module Network.BufferedSocket
    ( BufferedSocket()
    , makeBuffered
    , socketPort
    ) where

import           Control.Monad                  ( unless )
import           Control.Monad.IO.Class         ( MonadIO(..), liftIO )
import qualified Data.ByteString                as BS ( ByteString, append, empty, length, null, splitAt )
import qualified Data.ByteString.Lazy           as LBS ( ByteString )
import           Data.IORef                     ( IORef, atomicModifyIORef', newIORef, writeIORef )
import qualified Network.Socket                 as S ( PortNumber, Socket, close, socketPort )
import qualified Network.Socket.ByteString      as NBS ( recv )
import qualified Network.Socket.ByteString.Lazy as NBL ( sendAll )
import           Util.BufferedIOx

--------------------------------------------------------------------------------
-- | A socket paired with a mutable byte buffer.
--
-- The buffer starts empty ('makeBuffered') and is filled only by 'pushback'.
-- 'socketRecv' drains it before reading from the socket again, and
-- 'socketClose' clears it.
newtype BufferedSocket = BufferedSocket (S.Socket, IORef BS.ByteString)

-- | Each method runs the matching 'IO' helper from this module, lifted into
-- the caller's 'MonadIO' monad.
instance BufferedIOx BufferedSocket where
    -- Up to @n@ bytes; buffered (pushed-back) bytes are served before the socket.
    readBuffered bsock n = liftIO (socketRecv bsock n)
    -- Hand bytes back to the buffer so a later read returns them again.
    unreadBuffered bsock bytes = liftIO (pushback bsock bytes)
    -- Send everything on the underlying socket.
    writeBuffered bsock payload = liftIO (socketSend bsock payload)
    -- Drop buffered bytes and close the socket.
    closeBuffered bsock = liftIO (socketClose bsock)

-- | Wrap @sock@ together with a freshly allocated, empty buffer.
makeBuffered :: S.Socket -> IO BufferedSocket
makeBuffered sock =
    (\buffer -> BufferedSocket (sock, buffer)) <$> newIORef BS.empty

-- | Port number of the wrapped socket, as reported by 'S.socketPort'.
-- The buffer is not involved.
socketPort :: BufferedSocket -> IO S.PortNumber
socketPort (BufferedSocket conn) = S.socketPort (fst conn)

-- | Read at most @len@ bytes.
--
-- If the buffer is non-empty, up to @len@ bytes are taken from its front and
-- returned without touching the socket. That may be fewer than @len@ when the
-- buffer holds less. Only when the buffer is empty does this call 'NBS.recv'.
--
-- A @len@ of 0 returns an empty string immediately. A negative @len@ raises
-- an 'error'.
socketRecv :: BufferedSocket -> Int -> IO BS.ByteString
socketRecv (BufferedSocket (sock, bufIO)) len
    | len < 0   = error $ "Bad length: " ++ show len
    | len == 0  = return BS.empty
    | otherwise = do
          fromBuffer <- atomicModifyIORef' bufIO takeFromBuffer
          case fromBuffer of
              Just chunk -> return chunk
              Nothing    -> NBS.recv sock len
  where
    -- Split up to 'len' bytes off the front of the buffer and keep the rest.
    -- 'BS.splitAt' returns the whole buffer when it is shorter than 'len'.
    -- 'Nothing' means the buffer was empty and the socket must be read.
    takeFromBuffer buffered
        | BS.null buffered = (BS.empty, Nothing)
        | otherwise        =
            let (front, rest) = BS.splitAt len buffered
            in (rest, Just front)

-- | Return @bytes@ to the buffer so the next reads see them first.
--
-- The bytes go in front of whatever is still buffered. 'socketRecv' can leave
-- a remainder in the buffer after a short read. Bytes pushed back from that
-- read came earlier in the stream than the remainder, so they must be read
-- before it. Appending them instead would reorder the stream.
-- (Assumes callers only push back bytes they previously read.)
--
-- Pushing back an empty string is a no-op.
pushback :: BufferedSocket -> BS.ByteString -> IO ()
pushback (BufferedSocket (_, bufIO)) bytes =
    unless (BS.null bytes) $
        atomicModifyIORef' bufIO (\buf -> (bytes `BS.append` buf, ()))

-- | Send the entire lazy bytestring on the underlying socket via
-- 'NBL.sendAll'. The buffer is neither read nor modified.
socketSend :: BufferedSocket -> LBS.ByteString -> IO ()
socketSend (BufferedSocket (rawSock, _)) payload =
    NBL.sendAll rawSock payload

-- | Discard any buffered bytes, then close the underlying socket.
socketClose :: BufferedSocket -> IO ()
socketClose (BufferedSocket (rawSock, buffer)) =
    writeIORef buffer BS.empty >> S.close rawSock