module Network.HTTP.Semantics.ReadN (
    -- * Reading n bytes
    ReadN,
    defaultReadN,
)
where

import qualified Data.ByteString as B
import Data.IORef
import Network.Socket
import qualified Network.Socket.ByteString as N

-- | Reading n bytes.
--
-- An action that is given a byte count @n@ and returns the bytes read.
-- The implementation in this module, 'defaultReadN', returns exactly @n@
-- bytes, or an empty 'B.ByteString' when the underlying socket read
-- reports that no more data is available.
type ReadN = Int -> IO B.ByteString

-- | Naive implementation for readN.
--
-- Returns exactly @n@ bytes, or 'B.empty' if 'N.recv' returns no data
-- before @n@ bytes have been collected. Any bytes gathered so far in that
-- call are dropped in that case. A request for @0@ bytes returns 'B.empty'
-- without touching the socket or the 'IORef'.
--
-- The 'IORef' holds leftover bytes, which are served before the socket is
-- read. When it holds more than @n@ bytes, the surplus is stored back for
-- the next call. This function itself only ever stores the remainder of an
-- existing buffer. Presumably the caller seeds the 'IORef' with bytes it
-- read ahead; this is not visible here and should be confirmed against
-- callers.
--
-- /NOTE/: This function is intended to be used by a single thread only.
-- (It is probably quite rare anyway to want concurrent reads from the /same/
-- network socket.) The 'IORef' is read and then cleared non-atomically.
defaultReadN :: Socket -> IORef (Maybe B.ByteString) -> ReadN
defaultReadN _ _ 0 = return B.empty
defaultReadN s ref n = do
    mbs <- readIORef ref
    -- The stash is consumed up front. Only a surplus beyond @n@ is put back.
    writeIORef ref Nothing
    case mbs of
        Nothing -> fill [] n
        Just bs
            | B.length bs == n -> return bs
            | B.length bs > n -> do
                let (bs0, bs1) = B.splitAt n bs
                writeIORef ref (Just bs1)
                return bs0
            | otherwise -> fill [bs] (n - B.length bs)
  where
    -- Receive until @remaining@ reaches zero. Chunks are kept newest-first
    -- and concatenated once at the end. Appending on every partial read
    -- would re-copy all accumulated bytes each time (quadratic overall).
    fill acc remaining = do
        chunk <- N.recv s remaining
        if B.null chunk
            then return B.empty -- end of stream: partial data is discarded
            else do
                let acc' = chunk : acc
                    remaining' = remaining - B.length chunk
                if remaining' == 0
                    then return (B.concat (reverse acc'))
                    else fill acc' remaining'