module Data.Binary.Get.Internal (
    
      Get
    , runCont
    , Decoder(..)
    , runGetIncremental
    , readN
    , readNWith
    
    , bytesRead
    , isolate
    
    , withInputChunks
    , Consume
    , failOnEOF
    , get
    , put
    , ensureN
    
    , remaining
    , getBytes
    , isEmpty
    , lookAhead
    , lookAheadM
    , lookAheadE
    , label
    
    , getByteString
    ) where
import Foreign
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Control.Applicative
import Control.Monad
#if MIN_VERSION_base(4,9,0)
import qualified Control.Monad.Fail as Fail
#endif
import Data.Binary.Internal ( accursedUnutterablePerformIO )
-- | The result of running a 'Get' decoder incrementally.
data Decoder a
  = Fail !B.ByteString String
    -- ^ Decoding failed. Carries the input left over at the point of
    -- failure and an error message.
  | Partial (Maybe B.ByteString -> Decoder a)
    -- ^ More input is needed. Supply @Just chunk@ to continue, or 'Nothing'
    -- to signal end of input.
  | Done !B.ByteString a
    -- ^ Decoding succeeded. Carries the unconsumed input and the result.
  | BytesRead !Int64 (Int64 -> Decoder a)
    -- ^ Emitted by 'bytesRead'. The 'Int64' field is the length of the input
    -- that is still unconsumed at that point. The continuation expects the
    -- number of bytes consumed so far. NOTE(review): the driver that answers
    -- this request lives outside this module; confirm.

-- | Success continuation. It receives the input that is still unconsumed and
-- the decoded value, and produces the rest of the decoding.
type Success a r = B.ByteString -> a -> Decoder r

-- | A decoder in continuation-passing style. 'runCont' takes the current
-- input chunk and a success continuation.
newtype Get a = C
  { runCont :: forall r. B.ByteString -> Success a r -> Decoder r
  }
-- | Sequencing is implemented by 'bindG'.
--
-- From base-4.13 (GHC 8.8) on, 'fail' is no longer a method of 'Monad'.
-- It may therefore only be defined here for older base versions. From
-- base-4.9 on, failure is also provided through 'Fail.MonadFail' below.
instance Monad Get where
  return = pure
  (>>=) = bindG
#if !(MIN_VERSION_base(4,13,0))
  fail = failG
#endif

#if MIN_VERSION_base(4,9,0)
-- | Failure aborts decoding with a 'Fail' result (see 'failG').
instance Fail.MonadFail Get where
  fail = failG
#endif
-- | Monadic bind for 'Get'. It runs the first decoder, then passes its result
-- and the input it left over to the decoder produced by @f@.
bindG :: Get a -> (a -> Get b) -> Get b
bindG m f = C $ \inp ks ->
  runCont m inp $ \rest x -> runCont (f x) rest ks
-- | Abort decoding with the given message. The current input is kept as the
-- leftover carried by 'Fail'.
failG :: String -> Get a
failG msg = C (\inp _ -> Fail inp msg)
-- | Applicative apply for 'Get', expressed through monadic bind. The
-- function decoder runs first, then the argument decoder.
apG :: Get (a -> b) -> Get a -> Get b
apG d e =
  d >>= \g ->
  e >>= \x ->
  return (g x)
-- | Map over the result of a 'Get'; the input is threaded through unchanged.
fmapG :: (a -> b) -> Get a -> Get b
fmapG f (C c) = C $ \inp ks -> c inp (\rest -> ks rest . f)
-- | 'pure' hands the value straight to the success continuation without
-- consuming input; '<*>' is 'apG'.
instance Applicative Get where
  pure x = C $ \inp ks -> ks inp x
  (<*>) = apG
-- | 'MonadPlus' reuses the 'Alternative' instance.
instance MonadPlus Get where
  mplus = (<|>)
  mzero = empty
-- | Mapping is delegated to 'fmapG'.
instance Functor Get where
  fmap f g = fmapG f g
-- | Map over the eventual result. 'Fail' carries no result and is rebuilt
-- as is; continuations are mapped over lazily.
instance Functor Decoder where
  fmap f dec = case dec of
    Fail rest msg -> Fail rest msg
    Done rest x   -> Done rest (f x)
    Partial k     -> Partial (\mb -> fmap f (k mb))
    BytesRead n k -> BytesRead n (\m -> fmap f (k m))
-- | Continuations and leftover input are not shown. The output is the
-- constructor plus, where present, the error message or the result.
instance (Show a) => Show (Decoder a) where
  show d = case d of
    Done _ x      -> "Done: " ++ show x
    Partial _     -> "Partial _"
    BytesRead _ _ -> "BytesRead"
    Fail _ msg    -> "Fail: " ++ msg
-- | Run a decoder that starts with no input. The resulting 'Decoder' asks for
-- input through 'Partial'. Once 'Nothing' has been supplied, every later
-- input request is answered with 'Nothing' automatically (see 'noMeansNo').
runGetIncremental :: Get a -> Decoder a
runGetIncremental g = noMeansNo (runCont g B.empty Done)
-- | Make end of input sticky. Once the consumer feeds 'Nothing' to a
-- 'Partial', every later 'Partial' is resolved immediately with 'Nothing',
-- so the consumer is never asked for input again. 'BytesRead' requests pass
-- through unchanged.
noMeansNo :: Decoder a -> Decoder a
noMeansNo = go
  where
    go dec = case dec of
      Partial k -> Partial $ \ms -> case ms of
        Nothing -> neverAgain (k Nothing)
        Just _  -> go (k ms)
      BytesRead n k -> BytesRead n (go . k)
      _ -> dec
    -- End of input has been seen: keep answering 'Nothing'.
    neverAgain dec = case dec of
      Partial k -> neverAgain (k Nothing)
      BytesRead n k -> BytesRead n (neverAgain . k)
      _ -> dec
-- | Like 'prompt'', but the given input is prepended to whatever chunk
-- arrives before it is passed to @ks@.
prompt :: B.ByteString -> Decoder a -> (B.ByteString -> Decoder a) -> Decoder a
prompt inp kf ks = prompt' kf (ks . B.append inp)
-- | Ask for more input.
--
-- * Empty chunks are ignored and the request is repeated.
-- * A non-empty chunk is passed to @ks@.
-- * End of input ('Nothing') yields @kf@.
prompt' :: Decoder a -> (B.ByteString -> Decoder a) -> Decoder a
prompt' kf ks = loop
  where
    loop = Partial $ \sm -> case sm of
      Nothing -> kf
      Just s
        | B.null s  -> loop
        | otherwise -> ks s
-- | Ask how many bytes have been read. This emits a 'BytesRead' request
-- carrying the length of the still-unconsumed current input, and resumes
-- with the answer. The input itself is left untouched.
bytesRead :: Get Int64
bytesRead = C $ \inp k ->
  let unconsumed = fromIntegral (B.length inp)
  in BytesRead unconsumed (k inp)
-- | @isolate n act@ runs @act@ on exactly the next @n@ bytes of input.
-- Input beyond those @n@ bytes is never handed to @act@.
--
-- Fails if:
--
-- * @n@ is negative;
-- * @act@ fails (its leftover input is pushed back first);
-- * @act@ finishes without consuming all @n@ bytes (its leftover is pushed
--   back first).
isolate :: Int   -- ^ The number of bytes that must be consumed
        -> Get a -- ^ The decoder to isolate
        -> Get a
isolate n0 act
  | n0 < 0 = fail "isolate: negative size"
  | otherwise = go n0 (runCont act B.empty Done)
  where
  -- @n@ is the number of bytes of the isolated region not yet fed to @act@.
  go !n (Done left x)
    | n == 0 && B.null left = return x
    | otherwise = do
        pushFront left
        let consumed = n0 - n - B.length left
        fail $ "isolate: the decoder consumed " ++ show consumed ++ " bytes" ++
                 " which is less than the expected " ++ show n0 ++ " bytes"
  -- All n0 bytes have been fed: any further request means end of input.
  go 0 (Partial resume) = go 0 (resume Nothing)
  go n (Partial resume) = do
    inp <- C $ \inp k ->
      -- Hand at most n bytes to act; the remainder stays as our own input.
      let takeLimited str =
            let (inp', out) = B.splitAt n str
            in k out (Just inp')
      in if B.null inp
           then prompt inp (k B.empty Nothing) takeLimited
           else takeLimited inp
    case inp of
      Nothing -> go n (resume Nothing)
      Just str -> go (n - B.length str) (resume (Just str))
  go _ (Fail bs err) = pushFront bs >> fail err
  -- act asked how many bytes it has read: the bytes fed to it minus
  -- the r bytes it still holds unconsumed.
  go n (BytesRead r resume) =
    go n (resume $! fromIntegral n0 - fromIntegral n - r)
-- | A step function for 'withInputChunks'. Given the current state and the
-- next input chunk, it returns one of:
--
-- * 'Left' with an updated state, when more input is needed;
-- * 'Right' with the part of the chunk that completes the request and the
--   rest of the chunk.
type Consume s
  =  s
  -> B.ByteString
  -> Either s (B.ByteString, B.ByteString)
-- | Accumulate input chunks, running @consume@ over each one, until it
-- reports completion.
--
-- * On completion, @onSucc@ receives all chunks in order (the last one cut
--   down to its wanted part), and decoding continues with the rest of that
--   chunk.
-- * If end of input comes first, @onFail@ receives every chunk seen so far
--   and is run with empty input.
withInputChunks :: s -> Consume s -> ([B.ByteString] -> b) -> ([B.ByteString] -> Get b) -> Get b
withInputChunks initS consume onSucc onFail = go initS []
  where
    -- acc holds the chunks consumed so far, most recent first.
    go st acc = C $ \inp ks -> case consume st inp of
      Right (want, rest) -> ks rest (onSucc (reverse (want : acc)))
      Left st' ->
        let acc' = inp : acc
            atEof = runCont (onFail (reverse acc')) B.empty ks
            onMore chunk = runCont (go st' acc') chunk ks
        in prompt' atEof onMore
-- | Fail with "not enough bytes". The concatenation of the given chunks is
-- reported as the leftover input, and the current input is ignored.
failOnEOF :: [B.ByteString] -> Get a
failOnEOF bs = C $ \_inp _ks ->
  let leftover = B.concat bs
  in Fail leftover "not enough bytes"
-- | Test whether the input is exhausted, without consuming any of it.
-- A non-empty current chunk answers 'False'. If the current chunk is empty,
-- more input is requested:
--
-- * end of input answers 'True';
-- * a new non-empty chunk answers 'False' and becomes the input.
isEmpty :: Get Bool
isEmpty = C $ \inp ks ->
  if not (B.null inp)
    then ks inp False
    else prompt inp (ks inp True) (\more -> ks more False)
-- | An alias for 'getByteString'.
getBytes :: Int -> Get B.ByteString
getBytes n = getByteString n
-- | Backtracking choice. @f <|> g@ runs @f@ while recording the input it
-- pulls in. If @f@ fails, that input is pushed back and @g@ runs from the
-- same position. 'empty' always fails.
instance Alternative Get where
  empty = C $ \inp _ks -> Fail inp "Data.Binary.Get(Alternative).empty"

  f <|> g = runAndKeepTrack f >>= \(decoder, bs) ->
    case decoder of
      Done inp x -> C $ \_ ks -> ks inp x
      Fail _ _   -> pushBack bs >> g
      _          -> error "Binary: impossible"

  some p = liftA2 (:) p (many p)

  -- Try p; when it fails, stop and return []. Otherwise keep collecting.
  many p = do
    v <- (Just <$> p) <|> pure Nothing
    maybe (pure []) (\x -> (x :) <$> many p) v
  
-- | Run a decoder to completion (success or failure) and record every extra
-- input chunk it requested, in arrival order.
--
-- The outer computation resumes with its original input untouched, so the
-- caller can decide whether to 'pushBack' the recorded chunks. 'BytesRead'
-- requests are forwarded unchanged. The returned 'Decoder' is always
-- 'Done' or 'Fail'.
runAndKeepTrack :: Get a -> Get (Decoder a, [B.ByteString])
runAndKeepTrack g = C $ \inp ks ->
  let step !seen dec = case dec of
        Done rest a -> ks inp (Done rest a, reverse seen)
        Fail rest s -> ks inp (Fail rest s, reverse seen)
        Partial k -> Partial $ \minp ->
          step (maybe seen (: seen) minp) (k minp)
        BytesRead unused k -> BytesRead unused (step seen . k)
  in step [] (runCont g inp Done)
-- | Append the given chunks after the current input. An empty list leaves
-- the input as it is.
pushBack :: [B.ByteString] -> Get ()
pushBack bs
  | null bs   = C $ \inp ks -> ks inp ()
  | otherwise = C $ \inp ks -> ks (B.concat (inp : bs)) ()
-- | Put a chunk back in front of the current input.
pushFront :: B.ByteString -> Get ()
pushFront bs = C $ \inp ks -> ks (bs `B.append` inp) ()
-- | Run a decoder and, if it succeeds, push back all input it pulled in, so
-- nothing is consumed. If it fails, the whole computation fails with the
-- same leftover and message.
lookAhead :: Get a -> Get a
lookAhead g = runAndKeepTrack g >>= \(decoder, bs) ->
  case decoder of
    Done _ a   -> pushBack bs >> return a
    Fail inp s -> C $ \_ _ -> Fail inp s
    _          -> error "Binary: impossible"
-- | Like 'lookAhead', but input is only rewound when the result is
-- 'Nothing'. A 'Just' result keeps the consumed input consumed.
lookAheadM :: Get (Maybe a) -> Get (Maybe a)
lookAheadM g = fromEither <$> lookAheadE (toEither <$> g)
  where
    toEither = maybe (Left ()) Right
    fromEither = either (const Nothing) Just
-- | Like 'lookAhead', but input is only rewound on a 'Left' result. On
-- 'Right', the input the decoder consumed stays consumed.
lookAheadE :: Get (Either a b) -> Get (Either a b)
lookAheadE g = runAndKeepTrack g >>= \(decoder, bs) ->
  case decoder of
    Done _ (Left x)    -> pushBack bs >> return (Left x)
    Done inp (Right x) -> C $ \_ ks -> ks inp (Right x)
    Fail inp s         -> C $ \_ _ -> Fail inp s
    _                  -> error "Binary: impossible"
-- | Run a decoder; if it fails, append a newline and @msg@ to its error
-- message. Successful results and input requests pass through unchanged.
label :: String -> Get a -> Get a
label msg decoder = C $ \inp ks ->
  let relabel dec = case dec of
        Done rest a   -> ks rest a
        Fail rest s   -> Fail rest (s ++ "\n" ++ msg)
        Partial k     -> Partial (relabel . k)
        BytesRead u k -> BytesRead u (relabel . k)
  in relabel (runCont decoder inp Done)
-- | Number of bytes left in the whole input. Nothing is consumed, but all
-- input up to end of input ('Nothing') must be read to find out. The
-- collected chunks become the new input.
remaining :: Get Int64
remaining = C $ \inp ks ->
  let collect acc = Partial $ \minp -> case minp of
        Just more -> collect (more : acc)
        Nothing ->
          let everything = B.concat (inp : reverse acc)
          in ks everything (fromIntegral (B.length everything))
  in collect []
-- | Read exactly @n@ bytes as a strict 'B.ByteString'.
-- A non-positive @n@ yields the empty string without consuming input.
-- Fails with "not enough bytes" if end of input comes first.
getByteString :: Int -> Get B.ByteString
getByteString n
  | n <= 0    = return B.empty
  | otherwise = readN n (B.unsafeTake n)
-- | Return the current input chunk without consuming it.
get :: Get B.ByteString
get = C (\s k -> k s s)
-- | Replace the current input with the given chunk.
put :: B.ByteString -> Get ()
put s = C (\_ k -> k s ())
-- | Make sure at least @n@ bytes are available, then apply @f@ to the input
-- and drop @n@ bytes from it. @f@ receives the whole current input, which
-- is at least @n@ bytes long; it should look only at the first @n@.
readN :: Int -> (B.ByteString -> a) -> Get a
readN !n f = ensureN n >>= \_ -> unsafeReadN n f
-- | Make sure at least @n0@ bytes are available in the current input. If
-- the current chunk is too short, further chunks are pulled in and
-- concatenated. Fails with "not enough bytes" if end of input is reached
-- first.
ensureN :: Int -> Get ()
ensureN !n0 = C $ \inp ks ->
  if B.length inp >= n0
    then ks inp ()
    else runCont (withInputChunks n0 enoughChunks onSucc failOnEOF >>= put) inp ks
  where
    -- The state is the number of bytes still needed. Once a chunk covers
    -- that, the whole chunk counts as wanted, so nothing is left over.
    enoughChunks n str
      | B.length str >= n = Right (str, B.empty)
      | otherwise = Left (n - B.length str)
    -- Skip leading empty chunks before concatenating. NOTE(review): this
    -- presumably lets a single non-empty chunk be returned without a copy.
    onSucc = B.concat . dropWhile B.null
-- | Apply @f@ to the current input and drop @n@ bytes from it. This does not
-- check that @n@ bytes are present; callers such as 'readN' must ensure
-- that. The result of @f@ is forced to WHNF before decoding continues.
unsafeReadN :: Int -> (B.ByteString -> a) -> Get a
unsafeReadN !n f = C $ \inp ks ->
  let x = f inp
  in x `seq` ks (B.unsafeDrop n inp) x
-- | Read @n@ bytes by running a pointer-based IO action over the start of
-- the current input.
--
-- The action runs outside 'IO' via 'accursedUnutterablePerformIO'. It
-- should therefore be a simple read that has no other side effects and does
-- not return the pointer or anything derived from it. NOTE(review): confirm
-- every caller respects this.
readNWith :: Int -> (Ptr a -> IO a) -> Get a
readNWith n f = readN n $ \s ->
  accursedUnutterablePerformIO (B.unsafeUseAsCString s (f . castPtr))