-----------------------------------------------------------------------------
-- |
-- Module: Text.XML.LibXML.Enumerator
-- Copyright: 2010 John Millikin
-- License: MIT
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-----------------------------------------------------------------------------
module Text.XML.LibXML.Enumerator
	( parseBytesIO
	, parseBytesST
	) where

import           Control.Monad (unless)
import qualified Data.ByteString as B
import qualified Data.Enumerator as E
import           Data.Enumerator ((>>==))
import qualified Data.Text as T
import qualified Data.XML.Types as X
import qualified Text.XML.LibXML.SAX as SAX

import           Control.Exception (ErrorCall(..))
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Control.Monad.Trans.Class (lift)
import           Control.Monad.ST (ST)
import qualified Data.STRef as ST
import qualified Data.IORef as IO

-- | Install a SAX callback for every event kind this module translates.
--
-- Each callback wraps what the parser reported in the corresponding
-- 'X.Event' constructor and passes it to @emit@. Whatever 'Bool' @emit@
-- returns becomes the callback's result. The external-subset callback is
-- reported as 'X.EventBeginDoctype'.
--
-- When @expandRefs@ is 'False', 'SAX.parsedReference' is also registered, so
-- that each reference is emitted as 'X.ContentEntity' content.
setCallbacks :: Monad m => Bool -> SAX.Parser m -> (X.Event -> m Bool) -> m ()
setCallbacks expandRefs parser emit = do
	SAX.setCallback parser SAX.parsedBeginDocument $
		emit X.EventBeginDocument
	SAX.setCallback parser SAX.parsedEndDocument $
		emit X.EventEndDocument
	SAX.setCallback parser SAX.parsedBeginElement $ \elemName attrs ->
		emit (X.EventBeginElement elemName attrs)
	SAX.setCallback parser SAX.parsedEndElement $ \elemName ->
		emit (X.EventEndElement elemName)
	SAX.setCallback parser SAX.parsedCharacters $ \txt ->
		emit (X.EventContent (X.ContentText txt))
	SAX.setCallback parser SAX.parsedCDATA $ \txt ->
		emit (X.EventCDATA txt)
	SAX.setCallback parser SAX.parsedComment $ \txt ->
		emit (X.EventComment txt)
	SAX.setCallback parser SAX.parsedInstruction $ \instr ->
		emit (X.EventInstruction instr)
	SAX.setCallback parser SAX.parsedExternalSubset $ \doctypeName externalId ->
		emit (X.EventBeginDoctype doctypeName externalId)
	unless expandRefs $
		SAX.setCallback parser SAX.parsedReference $ \entity ->
			emit (X.EventContent (X.ContentEntity entity))

-- | Parse a stream of XML bytes into 'X.Event's inside any 'MonadIO' monad.
--
-- A new SAX parser is created each time the resulting iteratee is run.
-- For every input chunk, the events the parser reports are collected and
-- handed to the inner iteratee in the order they were reported. If the
-- parser reported an error, it is raised as an 'ErrorCall' after that
-- chunk's events have been delivered. At end of input the parser is
-- flushed the same way.
parseBytesIO :: MonadIO m
             => Bool -- ^ Whether to expand entity references
             -> Maybe T.Text -- ^ An optional filename or URI
             -> E.Enumeratee B.ByteString X.Event m b
parseBytesIO expandRefs name s = E.Iteratee $ do
	(feed, finish) <- liftIO $ do
		parser <- SAX.newParserIO name

		-- Only the most recent error message is kept. The callback
		-- answers False, presumably asking the parser to stop (confirm
		-- against libxml-sax).
		errorVar <- IO.newIORef Nothing
		SAX.setCallback parser SAX.reportError $ \msg ->
			IO.writeIORef errorVar (Just msg) >> return False

		-- Events are consed onto the front and reversed when read back.
		eventsVar <- IO.newIORef []
		setCallbacks expandRefs parser $ \event ->
			IO.modifyIORef eventsVar (event :) >> return True

		-- Run one parser action against freshly cleared buffers and
		-- return the events and error it produced.
		let collect action = do
			IO.writeIORef eventsVar []
			IO.writeIORef errorVar Nothing
			void action
			events <- IO.readIORef eventsVar
			failure <- IO.readIORef errorVar
			return (reverse events, failure)

		return (collect . SAX.parseBytes parser, collect (SAX.parseComplete parser))
	E.runIteratee (eneeParser (liftIO . feed) (liftIO finish) s)

-- | Parse a stream of XML bytes into 'X.Event's inside the strict 'ST' monad.
--
-- This behaves like 'parseBytesIO', but keeps its state in 'ST.STRef's. A new
-- SAX parser is created when the iteratee is run. Each chunk's events are
-- delivered in the order the parser reported them. A reported error is
-- raised as an 'ErrorCall' after that chunk's events have been delivered.
parseBytesST :: Bool -- ^ Whether to expand entity references
             -> Maybe T.Text -- ^ An optional filename or URI
             -> E.Enumeratee B.ByteString X.Event (ST s) b
parseBytesST expandRefs name s = E.Iteratee $ do
	parser <- SAX.newParserST name

	-- Only the most recent error message is kept. The callback answers
	-- False, presumably asking the parser to stop (confirm against
	-- libxml-sax).
	errorVar <- ST.newSTRef Nothing
	SAX.setCallback parser SAX.reportError $ \msg ->
		ST.writeSTRef errorVar (Just msg) >> return False

	-- Events are consed onto the front and reversed when read back.
	eventsVar <- ST.newSTRef []
	setCallbacks expandRefs parser $ \event ->
		ST.modifySTRef eventsVar (event :) >> return True

	-- Run one parser action against freshly cleared buffers and return
	-- the events and error it produced.
	let collect action = do
		ST.writeSTRef eventsVar []
		ST.writeSTRef errorVar Nothing
		void action
		events <- ST.readSTRef eventsVar
		failure <- ST.readSTRef errorVar
		return (reverse events, failure)

	E.runIteratee (eneeParser
		(collect . SAX.parseBytes parser)
		(collect (SAX.parseComplete parser))
		s)

-- | Build an enumeratee from two parser actions.
--
-- @feed@ parses one input chunk, and @finish@ is run once at end of input.
-- Each returns the events it produced, together with an optional error
-- message. The events are always passed to the inner iteratee first. If an
-- error message is present and the inner iteratee is still running, it is
-- then raised as an 'ErrorCall' via 'E.throwError'. If the inner iteratee
-- has already finished, its step is yielded with the remaining outer input
-- as leftover, and the error is not raised.
eneeParser :: Monad m
           => (a -> m ([X.Event], Maybe T.Text))
           -> m ([X.Event], Maybe T.Text)
           -> E.Enumeratee a X.Event m b
eneeParser feed finish = E.checkDone (E.continue . onInput) where
	-- End of input: flush the parser, then yield the still-running inner
	-- step with EOF as leftover.
	onInput k E.EOF = deliver k E.EOF finish $ \k' ->
		E.yield (E.Continue k') E.EOF
	onInput k (E.Chunks chunks) = feedAll k chunks

	-- Parse the chunks one at a time, then wait for more input.
	feedAll k [] = E.continue (onInput k)
	feedAll k (c:cs) = deliver k (E.Chunks cs) (feed c) $ \k' ->
		feedAll k' cs

	-- Run one parser action and pass its events to the inner iteratee.
	-- Then either raise the reported error or continue with @andThen@.
	-- @leftover@ is the unconsumed outer input, yielded back if the inner
	-- iteratee stops accepting input.
	deliver k leftover action andThen = do
		(events, problem) <- lift action
		let afterEvents k' = maybe (andThen k') (E.throwError . ErrorCall . T.unpack) problem
		case events of
			[] -> afterEvents k
			_ -> k (E.Chunks events) >>== E.checkDoneEx leftover afterEvents

-- | Run an action and discard its result.
--
-- Equivalent to 'Control.Monad.void'. It is defined locally, presumably so
-- the module builds against versions of base that predate it (TODO
-- confirm). 'const' is used directly rather than the reader-monad
-- 'return', so no @Monad ((->) r)@ instance is needed.
void :: Functor m => m a -> m ()
void = fmap (const ())