{-# LANGUAGE Rank2Types #-}
-- | Parsers for use with 'ByteStream's.
module Bio.Streaming.Parse
    ( Parser
    , ParseError(..)
    , EofException(..)
    , parse
    , parseIO
    , parseLog
    , parseM
    , abortParse
    , isFinished
    , drop
    , dropLine
    , getByte
    , getString
    , getWord32
    , getWord64
    , isolate
    , atto
    ) where

import Bio.Prelude                       hiding ( drop )
import Bio.Streaming.Bytes                      ( ByteStream )

import qualified Bio.Streaming.Bytes            as S
import qualified Data.Attoparsec.ByteString     as A
import qualified Data.ByteString                as B
import qualified Streaming.Prelude              as Q

-- | A parser over a 'ByteStream', written in continuation-passing style.
--
-- @r@ is the result type of the input stream, @m@ the underlying monad and
-- @a@ the value the parser produces.  'runP' takes three continuations
-- followed by the input stream:
--
--  * the success continuation receives the parsed value and the input that
--    remains after it;
--  * the end-of-input continuation receives the stream's final result
--    ('abortParse' calls it after draining the input);
--  * the error continuation receives the exception and the remaining input.
--
-- Because the answer type @x@ is universally quantified, a parser can only
-- produce an answer (short of diverging) by calling these continuations.
newtype Parser r m a = P {
    runP :: forall x .
            (a -> ByteStream m r -> m x)                -- successful parse
         -> (r -> m x)                                  -- end of input stream
         -> (SomeException -> ByteStream m r -> m x)    -- exception and remaining input
         -> ByteStream m r -> m x }                     -- input, result

-- | Mapping a function over a parser post-processes the value handed to the
-- success continuation; the other continuations are passed through untouched.
instance Functor (Parser r m) where
    fmap f p = P $ \sk rk ek -> runP p (\a -> sk (f a)) rk ek

-- | 'pure' succeeds immediately without touching the input.  '<*>' runs the
-- function parser first, then the argument parser on the remaining input;
-- end-of-input and errors from either side are passed straight through.
instance Applicative (Parser r m) where
    pure x = P $ \sk _ _ -> sk x
    pf <*> px = P $ \sk rk ek ->
        let withFun f = runP px (sk . f) rk ek
        in  runP pf withFun rk ek

-- | Sequencing: the value produced by the first parser selects the second,
-- which continues on the remaining input.  'return' is left at its default
-- ('pure').
instance Monad (Parser r m) where
    mx >>= next = P $ \sk rk ek ->
        let andThen a = runP (next a) sk rk ek
        in  runP mx andThen rk ek

-- | Run an 'IO' action in the underlying monad without consuming input.
instance MonadIO m => MonadIO (Parser r m) where
    liftIO io = P $ \sk _ _ input -> do
        a <- liftIO io
        sk a input

-- | Run an action of the underlying monad without consuming input.
instance MonadTrans (Parser r) where
    lift act = P $ \sk _ _ input -> act >>= flip sk input

-- | Throwing hands the exception, together with the current input, to the
-- error continuation.
instance MonadThrow (Parser r m) where
    throwM ex = P $ \_ _ ek input -> ek (toException ex) input

-- | Transform the remaining input with @f@ and continue with @()@.
modify :: (ByteStream m r -> ByteStream m r) -> Parser r m ()
modify f = P $ \sk _ _ input -> sk () (f input)

-- | Run a parser on a 'ByteStream'.
--
-- 'S.Go' steps and empty chunks at the front of the stream are skipped.  If
-- the stream ends before a non-empty chunk is found, the result is
-- @Right (Left r)@.  Otherwise the parser is obtained by applying @p@ to the
-- offset stored in the first non-empty chunk (presumably that chunk's
-- position in the underlying input; confirm), and it is run on the stream
-- starting at that chunk.
--
-- Results:
--
--  * @Left (exception, remaining input)@ if the parser failed;
--  * @Right (Left r)@ if the end-of-input continuation was reached;
--  * @Right (Right (value, remaining input))@ on success.
parse :: Monad m => (Int64 -> Parser r m a) -> ByteStream m r
      -> m (Either (SomeException, ByteStream m r) (Either r (a, ByteStream m r)))
parse p = skip
  where
    skip stream = case stream of
        S.Empty r -> return (Right (Left r))
        S.Go act  -> act >>= skip
        S.Chunk bytes off rest
            | B.null bytes -> skip rest
            | otherwise    -> runP (p off) done eof failed stream

    done a rest    = return (Right (Right (a, rest)))
    eof r          = return (Right (Left r))
    failed ex rest = return (Left (ex, rest))

-- | Like 'parse', but a parse error is rethrown in 'IO' (via 'liftIO' and
-- 'throwM'), and the remaining input is discarded.
parseIO :: MonadIO m => (Int64 -> Parser r m a) -> ByteStream m r -> m (Either r (a, ByteStream m r))
parseIO p input = do
    result <- parse p input
    case result of
        Left failure -> liftIO (throwM (fst failure))
        Right ok     -> return ok

-- | Like 'parse', but a parse error does not escape.  The exception is logged
-- at level @lv@ with 'logMsg', the remaining input is drained with
-- 'S.effects', and the stream's final result is returned as 'Left'.
parseLog :: MonadLog m => Level -> (Int64 -> Parser r m a) -> ByteStream m r -> m (Either r (a, ByteStream m r))
parseLog lv p input = parse p input >>= either report pure
  where
    report (ex, rest) = do
        logMsg lv ex
        Left <$> S.effects rest

-- | Like 'parse', but a parse error is rethrown with 'throwM' in @m@, and the
-- remaining input is discarded.
parseM :: MonadThrow m => (Int64 -> Parser r m a) -> ByteStream m r -> m (Either r (a, ByteStream m r))
parseM p input = do
    result <- parse p input
    case result of
        Left failure -> throwM (fst failure)
        Right ok     -> return ok

-- | Give up: drain the remaining input with 'S.effects' and pass the stream's
-- final result to the end-of-input continuation.
abortParse :: Monad m => Parser r m a
abortParse = P $ \_ rk _ input -> S.effects input >>= rk

-- | Turn a stream function that returns a value together with the rest of
-- the stream into a parser that always succeeds.
liftFun :: Monad m => (ByteStream m r -> m (a, ByteStream m r)) -> Parser r m a
liftFun step = P $ \sk _ _ input ->
    step input >>= \res -> sk (fst res) (snd res)

-- | Test whether the input is exhausted.
--
-- Runs 'S.Go' effects and drops empty chunks until it reaches either the
-- end of the stream ('True') or a non-empty chunk ('False').  That chunk is
-- left in place, so no data bytes are consumed.
isFinished :: Monad m => Parser r m Bool
isFinished = liftFun peek
  where
    peek stream = case stream of
        S.Empty r -> return (True, S.Empty r)
        S.Go act  -> act >>= peek
        S.Chunk bytes _ rest
            | B.null bytes -> peek rest
            | otherwise    -> return (False, stream)

-- | Discard the next @l@ bytes of input (via 'S.drop').
drop :: Monad m => Int -> Parser r m ()
drop = modify . S.drop . fromIntegral

-- | Discard input up to and including the next newline byte (10).  If there
-- is no newline, all remaining input is discarded.
dropLine :: Monad m => Parser r m ()
dropLine = modify $ \input -> S.drop 1 (S.dropWhile (10 /=) input)

-- | Read one byte.  If the input is exhausted, the error continuation receives
-- 'EofException' and the (empty) remaining stream.
getByte :: Monad m => Parser r m Word8
getByte = P $ \sk _ ek input -> do
    next <- S.nextByte input
    case next of
        Left r    -> ek (toException EofException) (pure r)
        Right hit -> sk (fst hit) (snd hit)

-- | Split off the next @l@ bytes with 'S.splitAt'' and return them as a
-- strict 'B.ByteString'.
getString :: Monad m => Int -> Parser r m B.ByteString
getString n = liftFun $ \input -> Q.lazily <$> S.splitAt' n input

-- | Read a 32-bit word stored little-endian (first byte least significant).
--
-- The bytes from 'getString' are combined with a strict right fold, so no
-- chain of thunks is built up.
--
-- NOTE(review): if 'getString' returns fewer than four bytes (presumably at
-- end of input), only those bytes are combined and no error is raised.
-- Confirm that callers guard against truncated input.
getWord32 :: Monad m => Parser r m Word32
getWord32 = liftM fromLittleEndian (getString 4)
  where
    -- foldr' visits the last byte first: each step shifts the accumulated
    -- higher-order bytes up by 8 bits and adds the current byte.
    fromLittleEndian = B.foldr' (\byte acc -> shiftL acc 8 + fromIntegral byte) 0

-- | Read a 64-bit word stored little-endian (first byte least significant).
--
-- The bytes from 'getString' are combined with a strict right fold, so no
-- chain of thunks is built up.
--
-- NOTE(review): if 'getString' returns fewer than eight bytes (presumably at
-- end of input), only those bytes are combined and no error is raised.
-- Confirm that callers guard against truncated input.
getWord64 :: Monad m => Parser r m Word64
getWord64 = liftM fromLittleEndian (getString 8)
  where
    -- foldr' visits the last byte first: each step shifts the accumulated
    -- higher-order bytes up by 8 bits and adds the current byte.
    fromLittleEndian = B.foldr' (\byte acc -> shiftL acc 8 + fromIntegral byte) 0

-- | Run a parser on the first @l@ bytes of input, as split off by 'S.splitAt'.
--
--  * On success, whatever the inner parser left of that region is drained
--    with 'S.effects' before continuing with the rest of the stream.
--  * If the inner parser takes its end-of-input continuation, the rest of
--    the outer stream is drained as well, and the outer end-of-input
--    continuation receives the final result.
--  * On error, the unconsumed part of the region and the rest of the stream
--    are rejoined with 'join' and passed to the error continuation.
isolate :: Monad m => Int -> Parser (ByteStream m r) m a -> Parser r m a
isolate l p = P $ \sk rk ek input ->
    let region             = S.splitAt (fromIntegral l) input
        finish a leftover  = S.effects leftover >>= sk a
        exhausted rest     = S.effects rest >>= rk
        failed e leftover  = ek e (join leftover)
    in  runP p finish exhausted failed region


-- | Signals that the input ended before a parser got the data it needed
-- (raised by 'getByte' and by 'atto' on incomplete input).
data EofException = EofException
    deriving (Show, Typeable)

instance Exception EofException where
    displayException = const "end-of-file"

-- | A parse failure, as reported by 'atto' from an attoparsec 'A.Fail'.
data ParseError = ParseError
    { errorContexts :: [String]  -- ^ contexts the failure occurred in (may be empty)
    , errorMessage  :: String    -- ^ description of the failure
    } deriving (Show, Typeable)

-- | Contexts, when present, are listed comma-separated.  With no contexts the
-- "at" clause is left out, so the message does not read "Parse error at : ...".
instance Exception ParseError where
    displayException (ParseError ctx msg)
        | null ctx  = "Parse error: " ++ msg
        | otherwise = "Parse error at " ++ intercalate ", " ctx ++ ": " ++ msg

-- | Run an attoparsec parser on the input stream, feeding it one chunk at a
-- time.
--
-- On success, the unconsumed part of the last chunk is pushed back onto the
-- stream.  An attoparsec failure becomes a 'ParseError'.  If the stream ends
-- while attoparsec still wants more input, the result is 'EofException'.
-- In both error cases the error continuation receives the input remaining
-- after the last chunk read.
atto :: Monad m => A.Parser a -> Parser r m a
atto = resume . A.parse
  where
    resume cont = P $ \sk rk ek input -> do
        next <- S.nextChunk input
        case next of
            -- Stream exhausted: an empty chunk tells attoparsec there is no
            -- more input.
            Left r ->
                case cont B.empty of
                    A.Done leftover v -> sk v (S.consChunk leftover (pure r))
                    A.Partial _       -> ek (toException EofException) (pure r)
                    A.Fail _ ctx msg  -> ek (toException (ParseError ctx msg)) (pure r)
            Right (chunk, rest)
                -- Skip empty chunks, so an empty chunk is never mistaken for
                -- end of input.
                | B.null chunk -> runP (resume cont) sk rk ek rest
                | otherwise ->
                    case cont chunk of
                        A.Done leftover v -> sk v (S.consChunk leftover rest)
                        A.Partial cont'   -> runP (resume cont') sk rk ek rest
                        A.Fail _ ctx msg  -> ek (toException (ParseError ctx msg)) rest