{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiWayIf #-}
module Network.SSH.TStreamingQueue where
import Control.Concurrent.STM.TChan
import Control.Concurrent.STM.TVar
import Control.Concurrent.STM.TMVar
import Control.Monad.STM
import Control.Applicative
import Data.Word
import qualified Data.ByteString as BS
import qualified Data.ByteString.Short as SBS
import Prelude hiding ( head
, tail
)
import qualified Network.SSH.Stream as S
import Network.SSH.Constants
-- | A byte queue operated in 'STM' whose producers are limited both by a
-- fixed capacity and by a separately tracked window.
data TStreamingQueue
    = TStreamingQueue
    { qCapacity :: Word32 -- ^ Fixed capacity; free space is computed as capacity minus 'qSize'.
    , qWindow :: TVar Word32 -- ^ Window counter supplied by the creator; reduced by 'enqueue', raised by 'addWindowSpace' and 'fillWindowSpace'.
    , qSize :: TVar Word32 -- ^ Number of bytes currently accounted as queued.
    , qEof :: TVar Bool -- ^ Set to 'True' by 'terminate'; 'enqueue' then accepts nothing more.
    , qHead :: TMVar SBS.ShortByteString -- ^ Remainder of a chunk that was only partially dequeued; read before 'qTail'.
    , qTail :: TChan SBS.ShortByteString -- ^ Chunks appended by 'enqueue', in arrival order.
    }
-- | Create an empty queue with the given capacity.
--
-- The window 'TVar' is supplied by the caller and is stored as-is; the
-- queue starts with size 0, no end-of-stream mark, no pending head chunk
-- and an empty chunk channel.
newTStreamingQueue :: Word32 -> TVar Word32 -> STM TStreamingQueue
newTStreamingQueue c window = do
    sizeVar <- newTVar 0
    eofVar <- newTVar False
    headVar <- newEmptyTMVar
    tailChan <- newTChan
    pure TStreamingQueue
        { qCapacity = c
        , qWindow = window
        , qSize = sizeVar
        , qEof = eofVar
        , qHead = headVar
        , qTail = tailChan
        }
-- | The fixed capacity the queue was created with.
capacity :: TStreamingQueue -> Word32
capacity q = qCapacity q
-- | Read the number of bytes currently accounted as queued.
getSize :: TStreamingQueue -> STM Word32
getSize q = readTVar (qSize q)
-- | Free space in the queue, computed as 'capacity' minus the current size.
getFree :: TStreamingQueue -> STM Word32
getFree q = do
    used <- getSize q
    pure (capacity q - used)
-- | Read the current value of the window counter.
getWindowSpace :: TStreamingQueue -> STM Word32
getWindowSpace q = readTVar (qWindow q)
-- | Increase the window counter by the given increment.
--
-- The sum is first computed in 'Word64'; if it would not fit into a
-- 'Word32' the transaction retries (via 'check') instead of wrapping.
addWindowSpace :: TStreamingQueue -> Word32 -> STM ()
addWindowSpace q increment = do
    current <- getWindowSpace q
    let widened = fromIntegral current + fromIntegral increment :: Word64
        ceiling32 = fromIntegral (maxBound :: Word32) :: Word64
    check (widened <= ceiling32)
    writeTVar (qWindow q) $! current + increment
-- | Tell whether a window adjustment is advisable right now.
--
-- The answer is 'True' only when both the queued size and the window
-- counter are at or below half of the capacity.
askWindowSpaceAdjustRecommended :: TStreamingQueue -> STM Bool
askWindowSpaceAdjustRecommended q =
    bothAtMostHalf <$> getSize q <*> getWindowSpace q
  where
    half = capacity q `div` 2
    bothAtMostHalf size wndw = size <= half && wndw <= half
-- | Raise the window counter so that it matches the free space of the
-- queue and return the amount by which it was raised.
--
-- The window is never raised above the free space: adding the whole free
-- space on top of a non-zero window would advertise more room than the
-- queue can hold (and could wrap around 'Word32'). If the window already
-- covers the free space, nothing is changed and 0 is returned.
fillWindowSpace :: TStreamingQueue -> STM Word32
fillWindowSpace q = do
    free <- getFree q
    wndw <- getWindowSpace q
    -- Only the part of the free space not yet covered by the window.
    let increment = if free > wndw then free - wndw else 0
    writeTVar (qWindow q) $! wndw + increment
    pure increment
-- | Mark the queue as ended by setting its end-of-stream flag.
terminate :: TStreamingQueue -> STM ()
terminate = flip writeTVar True . qEof
-- | Append as much of the given bytes as currently fits and return the
-- number of bytes accepted.
--
-- * Empty input or a terminated queue yields 0 without blocking.
-- * Otherwise the transaction retries until at least one byte fits, where
--   the room is the minimum of free space, window counter and
--   'maxBoundIntWord32'.
-- * The accepted prefix is stored as one chunk; size grows and the window
--   shrinks by exactly the accepted length.
enqueue :: TStreamingQueue -> BS.ByteString -> STM Word32
enqueue q bs
    | BS.null bs = pure 0
    | otherwise = do
        eof <- readTVar (qEof q)
        if eof
            then pure 0
            else do
                size <- getSize q
                wndw <- getWindowSpace q
                let free = capacity q - size
                    available = min (min free wndw) maxBoundIntWord32 :: Word32
                check $ available > 0
                -- Derive the accepted length from the truncated chunk
                -- rather than from the input length: converting the
                -- input length to Word32 would wrap for very large
                -- inputs. 'BS.take' returns the input unchanged when it
                -- already fits.
                let chunk = BS.take (fromIntegral available) bs
                    accepted = fromIntegral (BS.length chunk) :: Word32
                writeTVar (qSize q) $! size + accepted
                writeTVar (qWindow q) $! wndw - accepted
                writeTChan (qTail q) $! SBS.toShort chunk
                pure accepted
-- | Remove up to @maxBufSize@ bytes from the queue and return them as a
-- strict 'BS.ByteString'.
--
-- This is 'dequeueShort' followed by a conversion of the result, so the
-- blocking and end-of-stream behaviour is exactly that of 'dequeueShort':
-- the transaction retries while the queue is empty and not terminated,
-- and an empty result is returned once it is empty and terminated.
dequeue :: TStreamingQueue -> Word32 -> STM BS.ByteString
dequeue q maxBufSize = SBS.fromShort <$> dequeueShort q maxBufSize
-- | Remove up to @maxBufSize@ bytes from the queue and return them as a
-- 'SBS.ShortByteString'.
--
-- The transaction retries while the queue is empty and not terminated.
-- When it is empty and terminated, an empty result is returned. Otherwise
-- chunks are collected (pending head chunk first, then the channel) until
-- the limit @min maxBufSize maxBoundIntWord32@ is reached; a chunk that
-- does not fit entirely is split and its remainder is kept as the new
-- head chunk.
dequeueShort :: TStreamingQueue -> Word32 -> STM SBS.ShortByteString
dequeueShort q maxBufSize = do
    size <- getSize q
    eof <- readTVar (qEof q)
    check (eof || size > 0)
    -- After the check, a size of 0 implies that the queue is terminated.
    if size == 0
        then pure mempty
        else mconcat <$> collect size limit
  where
    limit :: Word32
    limit = min maxBufSize maxBoundIntWord32

    -- Account for a request that was served in full.
    settle :: Word32 -> STM ()
    settle size = writeTVar (qSize q) $! size - limit

    -- Next chunk in order, or an empty chunk when nothing is left.
    nextChunk :: STM SBS.ShortByteString
    nextChunk = takeTMVar (qHead q) <|> readTChan (qTail q) <|> pure mempty

    collect :: Word32 -> Word32 -> STM [SBS.ShortByteString]
    collect size 0 = [] <$ settle size
    collect size remaining = nextChunk >>= step
      where
        step chunk
            -- Queue drained before the limit was reached.
            | SBS.null chunk = [] <$ writeTVar (qSize q) 0
            | len <= remaining = (chunk :) <$> collect size (remaining - len)
            | otherwise = do
                settle size
                putTMVar (qHead q) $! SBS.toShort rest
                pure [SBS.toShort taken]
          where
            len = fromIntegral (SBS.length chunk)
            (taken, rest) = BS.splitAt (fromIntegral remaining) (SBS.fromShort chunk)
-- | Return up to @maxBufSize@ bytes from the front of the queue without
-- removing them.
--
-- The transaction retries while the queue is empty and not terminated and
-- returns an empty result once it is empty and terminated. Only the first
-- available chunk (pending head chunk, otherwise the first chunk in the
-- channel) is inspected, so fewer bytes than queued may be returned.
lookAhead :: TStreamingQueue -> Word32 -> STM BS.ByteString
lookAhead q maxBufSize = do
    size <- getSize q
    eof <- readTVar (qEof q)
    check $ size > 0 || eof
    if size == 0 && eof
        then pure mempty
        else do
            bs <- readTMVar (qHead q) <|> peekTChan (qTail q)
            pure $ BS.take (fromIntegral limit) (SBS.fromShort bs)
  where
    -- Same clamp as used by 'dequeue' before converting to 'Int'.
    -- NOTE(review): presumably 'maxBoundIntWord32' is the largest value
    -- that fits both 'Int' and 'Word32' -- confirm in Network.SSH.Constants.
    limit = min maxBufSize maxBoundIntWord32
-- | A 'TStreamingQueue' can be used as a duplex stream; the instance
-- defines no methods of its own.
instance S.DuplexStream TStreamingQueue
-- | Sending runs 'enqueue' in its own transaction and reports the number
-- of bytes accepted.
instance S.OutputStream TStreamingQueue where
    send q bs = do
        accepted <- atomically (enqueue q bs)
        pure (fromIntegral accepted)
-- | Peeking and receiving run 'lookAhead' and 'dequeue' in their own
-- transactions; the requested length is clamped to 'maxBoundIntWord32'
-- before it is converted.
instance S.InputStream TStreamingQueue where
    peek q i = atomically (lookAhead q limit)
      where
        limit = fromIntegral (min i maxBoundIntWord32)
    receive q i = atomically (dequeue q limit)
      where
        limit = fromIntegral (min i maxBoundIntWord32)