-----------------------------------------------------------------------------
-- |
-- Module      : Network.Connection
-- Copyright   : Adam Langley
-- License     : BSD3-style (see LICENSE)
--
-- Maintainer  : Adam Langley
-- Stability   : experimental
--
-- Helpful functions to deal with stream-like connections
-----------------------------------------------------------------------------

module Network.Connection
  ( -- * Base connections
    BaseConnection(..)
  , baseConnectionFromSocket
    -- * Connection functions
  , Connection
  , new
  , newSTM
  , forkWriterThread
  , forkInConnection
  , close
  , write
  , writeAtLowWater
  , read
  , reada
  , pushBack
  ) where

import Prelude hiding (foldl, read, catch)

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Data.Foldable (foldl)

import qualified Data.ByteString as B
import qualified Data.Sequence as Seq
import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString

-- | A BaseConnection abstracts a stream like connection.
data BaseConnection = BaseConnection
  { -- | Read, at most, the given number of bytes from the connection and
    -- return a ByteString of the data.
    --
    -- End of stream: 'read' and 'reada' in this module treat a zero length
    -- result as EOF, so an implementation may signal EOF either by returning
    -- an empty ByteString or by throwing an exception.
    -- NOTE(review): 'baseConnectionFromSocket' forwards 'recv' unchanged;
    -- confirm which of the two conventions the installed network library uses.
    baseRead  :: Int -> IO B.ByteString
    -- | Write the given ByteString to the connection and return the number of
    -- bytes actually written. The write may write fewer than the requested
    -- number of bytes, but must always write at least one byte.
  , baseWrite :: B.ByteString -> IO Int
    -- | Close the connection.
  , baseClose :: IO ()
  }

-- | Return a BaseConnection for the given socket: reads use 'recv', writes use
-- 'send' and closing uses 'sClose'.
baseConnectionFromSocket :: Socket -> BaseConnection
baseConnectionFromSocket sock = BaseConnection
  { baseRead  = recv sock
  , baseWrite = send sock
  , baseClose = sClose sock
  }

-- | A Connection uses the functions from a BaseConnection and wraps them with
-- a number of commonly needed behaviours.
--
-- Firstly, a write queue is introduced so that writes can be non-blocking.
--
-- Secondly, the Connection can manage a number of threads. Almost always
-- there will be a writer thread which takes items from the write queue and
-- writes them to the BaseConnection. In addition, there can be zero or more
-- other threads managed by the Connection. If a managed thread dies, by
-- throwing an exception or otherwise, it will close the connection and all
-- other managed threads will be killed.
--
-- There is also the concept of pushing data back into the Connection. This
-- is useful in a chain of reader functions where, for efficiency reasons,
-- you would want to read large blocks at a time, but the data is
-- self-delimiting so you would otherwise end up in a situation where you had
-- read too much. See the pushBack function for details.
data Connection = Connection
  { connbase     :: BaseConnection
    -- ^ the underlying stream
  , connoutq     :: TVar (Seq.Seq B.ByteString)
    -- ^ outgoing queue: 'write' adds at the front, the writer drains the back
  , connthreads  :: TVar [ThreadId]
    -- ^ the threads managed by this connection
  , connpushback :: TVar (Seq.Seq B.ByteString)
    -- ^ pushed-back data; both pushed and read at the front
  , conndeath    :: IO ()
    -- ^ action run when the connection is closed
  , conndead     :: TVar Bool
    -- ^ set to True once the connection has been closed
  }

-- | Apply a function to the contents of a TVar.
--
-- The new value is forced to weak head normal form before it is written, so
-- repeated updates do not build up a chain of thunks inside the TVar.
updateTVar :: TVar a -> (a -> a) -> STM ()
updateTVar tvar f = do
  v <- readTVar tvar
  writeTVar tvar $! f v

-- | Create a new Connection from a BaseConnection object and start the writer
-- thread which drains its outgoing queue.
new :: IO ()           -- ^ the action to run when the connection closes
    -> BaseConnection  -- ^ the socket-like object to make a connection from
    -> IO Connection
new deathaction baseconn = do
  conn <- atomically $ newSTM deathaction baseconn
  forkWriterThread conn
  return conn

-- | This creates most of a Connection, purely in the STM monad. The Connection
-- returned from this must be passed to forkWriterThread, otherwise nothing
-- will ever get written.
newSTM :: IO ()           -- ^ the action run when the connection closes
       -> BaseConnection  -- ^ the socket-like object to make a connection from
       -> STM Connection
newSTM deathaction baseconn = do
  deadFlag  <- newTVar False
  outQueue  <- newTVar Seq.empty
  pushQueue <- newTVar Seq.empty
  threadIds <- newTVar []
  return Connection
    { connbase     = baseconn
    , connoutq     = outQueue
    , connthreads  = threadIds
    , connpushback = pushQueue
    , conndeath    = deathaction
    , conndead     = deadFlag
    }

-- | If you created the Connection in the STM monad using newSTM, you need to
-- call this on it in order to create the thread which processes the outgoing
-- queue.
--
-- The writer is an ordinary managed thread (started through forkInConnection)
-- which repeatedly takes data off the outgoing queue and hands it to the
-- connection's baseWrite.
forkWriterThread :: Connection  -- ^ the connection to fork the writer thread for
                 -> IO ()
forkWriterThread conn =
  forkInConnection conn $ seqToSocket (connoutq conn) (baseWrite (connbase conn))

-- | Run the given action, as if by forkIO, and manage the thread.
-- If the given action completes or throws an exception, the connection will
-- be closed and all other managed threads will be killed.
--
-- The new thread blocks until its ThreadId has been added to the
-- connection's thread list, so it is registered (and therefore reachable by
-- killThreads) before the action starts running.
forkInConnection :: Connection  -- ^ the connection to close on death
                 -> IO ()       -- ^ the action to run
                 -> IO ()
forkInConnection conn action = do
  sync <- atomically $ newTVar False
  thread <- forkIO $ waitForReadySignal sync $ connectionThreadWrapper conn action
  -- Register the thread first, then release it.
  atomically $ do
    updateTVar (connthreads conn) (thread :)
    writeTVar sync True

-- | Wait for the given TVar to become True and then run the given action.
waitForReadySignal :: TVar Bool -> IO a -> IO a
waitForReadySignal sync action = do
  atomically $ readTVar sync >>= \go -> unless go retry
  action

-- | Close the connection at most once.
--
-- The first caller to set the dead flag kills every managed thread except
-- itself, closes the BaseConnection and runs the death action. Later calls
-- do nothing. The death action runs even if closing the BaseConnection
-- throws; that exception is then re-raised.
killThreads :: Connection -> IO ()
killThreads conn = do
  alreadyDead <- atomically $ do
    dead <- readTVar (conndead conn)
    unless dead $ writeTVar (conndead conn) True
    return dead
  unless alreadyDead $ do
    threads <- atomically $ readTVar (connthreads conn)
    me <- myThreadId
    mapM_ killThread $ filter (/= me) threads
    baseClose (connbase conn) `finally` conndeath conn

-- | Not all exceptions are safe to catch because of the way the GC works. If a
-- thread is killed because it's waiting on a TVar which is now garbage (e.g.
-- our writer thread when the Connection goes out of scope), all ForeignPtrs
-- held by the thread are also garbage, /at the same time/. Thus we can end
-- up holding invalid ForeignPtrs if we catch unsafe exceptions and try to
-- cleanup.
safeException :: Exception -> Maybe Exception
safeException (AsyncException _) = Nothing
safeException BlockedOnDeadMVar = Nothing
safeException BlockedIndefinitely = Nothing
safeException x = Just x

-- | Wrap a connection thread so that, when the thread dies, it races to set
-- the dead flag.
-- If it wins the race, it closes the connection and kills the other managed
-- threads.
--
-- Cleanup happens however the action ends. If it throws an exception accepted
-- by safeException, the connection is closed and the exception is re-thrown.
-- If it returns normally, the connection is closed before the result is
-- returned. Exceptions rejected by safeException propagate without cleanup.
connectionThreadWrapper :: Connection -> IO a -> IO a
connectionThreadWrapper conn action = do
  result <- handleJust safeException (\e -> killThreads conn >> throwIO e) action
  -- A managed thread finishing normally also ends the connection.
  killThreads conn
  return result

-- | Close a connection. Only the first call has any effect.
close :: Connection -> IO ()
close = killThreads

-- | Enqueue a ByteString to a connection. This does not block.
write :: Connection -> B.ByteString -> STM ()
write conn bs = do
  s <- readTVar $ connoutq conn
  writeTVar (connoutq conn) (bs Seq.<| s)

-- | Block (via 'retry') until the write queue holds no more than the given
-- number of bytes, then enqueue a new ByteString.
writeAtLowWater :: Int           -- ^ the max number of bytes in the queue before we enqueue anything
                -> Connection    -- ^ the connection to write to
                -> B.ByteString  -- ^ the data to enqueue
                -> STM ()
writeAtLowWater lw conn bs = do
  q <- readTVar $ connoutq conn
  let size = foldl (\sz chunk -> sz + B.length chunk) 0 q
  if size > lw
    then retry
    else writeTVar (connoutq conn) $ bs Seq.<| q

-- | Read some number of bytes from a connection. The size is only a hint:
-- the returned data may be shorter. A zero length read is EOF.
--
-- Pushed-back data is returned first: at most @sz@ bytes of the front chunk,
-- with any excess staying pushed back. Only when nothing has been pushed back
-- is the BaseConnection read.
read :: Connection -> Int -> IO B.ByteString
read conn sz = do
  pb <- atomically $ do
    pushback <- readTVar $ connpushback conn
    case Seq.viewl pushback of
      Seq.EmptyL -> return Nothing
      chunk Seq.:< rest
        | B.length chunk <= sz -> do
            writeTVar (connpushback conn) rest
            return $ Just chunk
        | otherwise -> do
            let (left, right) = B.splitAt sz chunk
            writeTVar (connpushback conn) $ right Seq.<| rest
            return $ Just left
  case pb of
    Nothing -> baseRead (connbase conn) sz
    Just bs -> return bs

-- | Read exactly the given number of bytes.
--
-- Fails (via 'fail') with "EOF in reada" if the connection reaches EOF first.
-- A request for zero or a negative number of bytes returns an empty
-- ByteString without reading anything.
reada :: Connection -> Int -> IO B.ByteString
reada conn n0 = go n0 []
  where
    -- Chunks are collected in reverse and concatenated once at the end,
    -- instead of re-appending at every step.
    go n acc
      | n <= 0 = return $ B.concat (reverse acc)
      | otherwise = do
          bytes <- read conn n
          when (B.null bytes) $ fail "EOF in reada"
          go (n - B.length bytes) (bytes : acc)

-- | Unread some amount of data. It will be returned in the next call to read.
--
-- The function pushes data to the front of the queue. Thus you need to push
-- all the data back in one go, or the order of future reads will be wrong.
--
-- This might seem like an error, but consider the case of two actions: the
-- first reads 20 bytes and pushes back the last 10 of them. The second reads
-- 5 bytes and pushes back the last 4. If we appended to the push back queue,
-- the second action would put those 4 bytes after the remaining 5 from the
-- first action.
--
-- Pushing back an empty ByteString does nothing.
pushBack :: Connection -> B.ByteString -> STM ()
pushBack conn bs
  | B.null bs = return ()
  | otherwise = do
      pushback <- readTVar $ connpushback conn
      writeTVar (connpushback conn) $ bs Seq.<| pushback

-- | Repeatedly take elements from the end of the given sequence, atomically,
-- and write them out with the given write function. The loop never returns
-- normally.
--
-- Blocks (via 'retry') while the sequence is empty. Exceptions thrown by the
-- write function propagate out of the loop.
seqToSocket :: TVar (Seq.Seq B.ByteString)  -- ^ data is removed from the end
            -> (B.ByteString -> IO Int)     -- ^ the write function
            -> IO ()
seqToSocket q write = forever $ do
  bs <- atomically $ do
    q' <- readTVar q
    case Seq.viewr q' of
      Seq.EmptyR -> retry
      rest Seq.:> lastChunk -> do
        writeTVar q rest
        return lastChunk
  writea write bs

-- | Write the whole of a ByteString using a write function which may write
-- fewer than the requested number of bytes per call. It calls the write
-- function until everything has been written.
--
-- The write function must report writing at least one byte per call. If it
-- reports zero or a negative count, this fails instead of looping forever.
writea :: (B.ByteString -> IO Int)  -- ^ the write function
       -> B.ByteString              -- ^ the data to write
       -> IO ()
writea write bytes
  | B.null bytes = return ()
  | otherwise = do
      n <- write bytes
      when (n <= 0) $ fail "writea: write function made no progress"
      unless (n >= B.length bytes) $ writea write (B.drop n bytes)