module Data.Conduit.Cereal.Internal
( ConduitErrorHandler
, SinkErrorHandler
, SinkTerminationHandler
, mkConduitGet
, mkSinkGet
) where
import Control.Monad (forever, when)
import qualified Data.ByteString as BS
import qualified Data.Conduit as C
import Data.Serialize hiding (get, put)
-- | Handler used by 'mkConduitGet' when a parse fails. It receives the
-- parser's error message and supplies the remainder of the stream as a
-- conduit.
type ConduitErrorHandler monad out =
  String -> C.Conduit BS.ByteString monad out

-- | Handler used by 'mkSinkGet' when a parse fails. It receives the
-- parser's error message and produces the sink's result.
type SinkErrorHandler monad result =
  String -> C.Consumer BS.ByteString monad result

-- | Handler used by 'mkSinkGet' when upstream ends while the parse still
-- needs input. It receives the pending cereal continuation, after the
-- buffered input has been pushed back.
type SinkTerminationHandler monad result =
  (BS.ByteString -> Result result) -> C.Consumer BS.ByteString monad result
-- | Build a conduit that repeatedly runs a 'Get' over an incoming
-- 'BS.ByteString' stream, yielding one value per successful parse.
--
-- * Chunks consumed by a parse that is still in progress are buffered, so
--   they can be returned upstream with 'C.leftover' if needed.
-- * When a parse fails, every byte consumed by that parse is pushed back.
--   The stream then continues with the error handler.
-- * When upstream ends in the middle of a parse, the buffered bytes are
--   pushed back. The conduit then finishes without calling the error
--   handler.
-- * When the 'Get' succeeds on the initial empty input, its result is
--   yielded forever.
mkConduitGet :: Monad m
             => ConduitErrorHandler m o
             -> Get o
             -> C.Conduit BS.ByteString m o
mkConduitGet errorHandler get = consume True (runGetPartial get) [] BS.empty
  where
    -- Return buffered chunks (stored newest-first) as a single leftover.
    -- Nothing is pushed when they contain no bytes, so downstream never
    -- sees a spurious empty chunk.
    pushBack chunks =
      let bytes = BS.concat (reverse chunks)
      in when (not $ BS.null bytes) (C.leftover bytes)

    -- Feed the next non-empty chunk to the parser continuation @f@.
    -- Empty chunks are skipped (we await again) rather than passed on,
    -- because cereal uses an empty string to tell a 'Partial'
    -- continuation that input has ended.
    pull f b s
      | BS.null s = C.await >>= maybe (pushBack b) (pull f b)
      | otherwise = consume False f b s

    -- Step the parser with chunk @s@. @b@ holds the chunks this parse has
    -- already consumed, newest first.
    consume initial f b s = case f s of
        Fail msg _ -> do
          -- Push back everything this parse consumed. This includes the
          -- current chunk when it is the parse's first one (b empty), for
          -- example the remainder left over by a previous successful parse.
          pushBack consumed
          errorHandler msg
        Partial p -> pull p consumed BS.empty
        Done a s'
          -- The Get needed no input at all. This assumes such a Get never
          -- consumes input, so re-running it could only repeat @a@.
          | initial   -> forever $ C.yield a
          | otherwise -> C.yield a >> pull (runGetPartial get) [] s'
      where
        consumed = s : b
-- | Build a sink that runs a 'Get' once over an incoming 'BS.ByteString'
-- stream and returns its result.
--
-- * Unconsumed input remaining after a successful parse is pushed back
--   with 'C.leftover'.
-- * When the parse fails, every byte it consumed is pushed back. The error
--   handler then produces the sink's result.
-- * When upstream ends while the parse still needs input, the buffered
--   bytes are pushed back. The termination handler is then given the
--   pending cereal continuation.
mkSinkGet :: Monad m
          => SinkErrorHandler m r
          -> SinkTerminationHandler m r
          -> Get r
          -> C.Consumer BS.ByteString m r
mkSinkGet errorHandler terminationHandler get = consume (runGetPartial get) [] BS.empty
  where
    -- Return buffered chunks (stored newest-first) as a single leftover.
    -- Nothing is pushed when they contain no bytes, so downstream never
    -- sees a spurious empty chunk.
    pushBack chunks =
      let bytes = BS.concat (reverse chunks)
      in when (not $ BS.null bytes) (C.leftover bytes)

    -- Feed the next non-empty chunk to the continuation @f@. On end of
    -- stream, return the buffered input and defer to the termination
    -- handler.
    -- NOTE(review): the buffered bytes are pushed back before the handler
    -- runs, even if the handler then finishes the parse successfully via
    -- @f@. Confirm that re-delivering that consumed input is intended.
    pull f b s
      | BS.null s =
          C.await >>= maybe (pushBack b >> terminationHandler f) (pull f b)
      | otherwise = consume f b s

    -- Step the parser with chunk @s@. @b@ holds the chunks this parse has
    -- already consumed, newest first.
    consume f b s = case f s of
        Fail msg _ -> pushBack consumed >> errorHandler msg
        Partial p -> pull p consumed BS.empty
        Done r s' -> when (not $ BS.null s') (C.leftover s') >> return r
      where
        consumed = s : b