{- | This module contains the incremental decoder logic to process a continuous stream.

Here is an example stream layout:

> | EBML | SIZE | ELT | ELTA | ELT | ... | SEGMENT | USIZE | ELT | ELTB | ... |
> | CLUSTER | USIZE | ELT | ELTC   | ... | CLUSTER | USIZE | ELT | ELT  | ... |

There are two difficulties:

- Elements are not aligned: a segment id can start at position 15.
- Unknown-sized elements, such as segments and clusters, require a look-ahead to detect where they end.

Here are the main scenarios:

- The initial buffer does not contain the beginning of a media segment.
  In that case we need to accumulate the data to provide the complete initialization segments.

- The buffer contains multiple segments. In that case we need to find the last one, i.e. the most recent.

- The buffer ends in the middle of the cluster id, e.g. "...\x1f\x43".
  In that case we need to wait for the next buffer to confirm that a new media segment exists.
  We also need to return the end of the previous buffer, so that the media segment starts with "\x1f\x43...".
  This is already handled by 'Data.Binary.Get.runGetIncremental', whose leftover input includes the backtracked bytes.

Check out the 'testIncrementalLookahead' case in the test/Spec.hs module, which validates these scenarios.
-}
module Codec.EBML.Stream (StreamReader, newStreamReader, StreamFrame (..), feedReader) where

import Control.Monad (when)
import Data.Binary.Get qualified as Get
import Data.ByteString qualified as BS
import Data.Text (Text)
import Data.Text qualified as Text

import Codec.EBML.Decoder
import Codec.EBML.Element
import Codec.EBML.Matroska
import Codec.EBML.Schema
import Codec.EBML.WebM qualified as WebM

-- | A frame that can be served: the initialization segments together with
-- the beginning of the most recent media segment.
data StreamFrame = StreamFrame
    { initialization :: BS.ByteString
    -- ^ Initialization segments (everything before the first cluster), to be
    -- provided before the first media segment.
    , media :: BS.ByteString
    -- ^ Beginning of the last media segment found in the input buffer.
    }

-- | Incremental decoding state. Create one with 'newStreamReader' and decode
-- media segments by feeding it buffers with 'feedReader'.
data StreamReader = StreamReader
    { header :: Either (Int, [BS.ByteString]) BS.ByteString
    -- ^ The stream initialization segments. While they are still being decoded
    -- this is @Left (bytes read so far, buffers received, most recent first)@;
    -- once decoded it is @Right@ the complete segments.
    , decoder :: Get.Decoder ()
    -- ^ The decoder currently in progress.
    }

-- | The compiled 'schemaHeader', shared by every element parser of this module.
streamSchema :: EBMLSchemas
streamSchema = compileSchemas schemaHeader

-- | Decode consecutive elements until a cluster id (0x1F43B675) is reached.
--
-- The cluster id is not consumed: each element is attempted under
-- 'Get.lookAheadM', which rewinds the input when the attempt yields 'Nothing'.
getUntilNextCluster :: Get.Get [EBMLElement]
getUntilNextCluster = do
    next <- Get.lookAheadM nonClusterElement
    case next of
        Nothing -> pure []
        Just element -> (element :) <$> getUntilNextCluster
  where
    -- Parse one element, or give up (Nothing) when its id is the cluster id.
    nonClusterElement :: Get.Get (Maybe EBMLElement)
    nonClusterElement = do
        eid <- getElementID
        if eid == 0x1F43B675
            then pure Nothing
            else do
                size <- getMaybeDataSize
                Just <$> getElementValue streamSchema (EBMLElementHeader eid size)

-- | Read the initialization frame: the EBML header element, then the segment
-- header and its elements up to (but excluding) the first cluster id.
--
-- Fails when the magic or segment id is wrong, or when the elements preceding
-- the first cluster are rejected by 'WebM.decodeSegment'.
getInitialization :: Get.Get ()
getInitialization = do
    -- The stream must open with the EBML header element.
    ebmlHeader <- getElement streamSchema
    when (ebmlHeader.header.eid /= 0x1A45DFA3) $
        fail ("Invalid magic: " <> show ebmlHeader.header)

    -- Then comes the segment header, followed by its children up to the first cluster.
    segmentHead <- getElementHeader
    when (segmentHead.eid /= 0x18538067) $
        fail ("Invalid segment: " <> show segmentHead)
    prelude <- getUntilNextCluster
    -- The decoded document is discarded; decoding only validates the elements.
    either (fail . Text.unpack) (const (pure ())) (WebM.decodeSegment prelude)

-- | Read a cluster frame: a cluster header followed by its elements, up to
-- (but excluding) the next cluster id.
--
-- Fails when the header is not a cluster, or when the first element of the
-- cluster is not a timestamp (id 0xE7).
getCluster :: Get.Get ()
getCluster = do
    clusterHead <- getElementHeader
    when (clusterHead.eid /= 0x1F43B675) $
        fail ("Invalid cluster: " <> show clusterHead)
    children <- getUntilNextCluster
    let startsWithTimestamp = case children of
            first : _ -> first.header.eid == 0xE7
            [] -> False
    when (not startsWithTimestamp) $
        fail "Cluster first element is not a timestamp"

-- | Initialize a stream reader: no initialization data accumulated yet, and a
-- decoder waiting for the EBML header.
newStreamReader :: StreamReader
newStreamReader = StreamReader emptyAccumulator initDecoder
  where
    -- Zero bytes read, no buffers kept.
    emptyAccumulator = Left (0, [])
    initDecoder = Get.runGetIncremental getInitialization

-- | Feed data into a stream reader. Returns either an error, or maybe a new
-- 'StreamFrame' and an updated 'StreamReader'.
--
-- The buffer is pushed into the current decoder:
--
-- * When the decoder needs more data, no new frame is produced. While the
--   initialization segments are incomplete, the buffer is kept to rebuild them.
--
-- * When the decoder completes, a frame is built from the initialization
--   segments and the unconsumed input. That input is then fed to a fresh
--   cluster decoder, so a buffer containing several clusters yields the frame
--   of the last one.
--
-- * When the decoder fails, its error message is returned.
--
-- Feeding an empty buffer is an error (@"empty buffer"@).
feedReader :: BS.ByteString -> StreamReader -> Either Text (Maybe StreamFrame, StreamReader)
feedReader = go Nothing
  where
    -- An empty buffer is only accepted as the leftover of a completed frame.
    go Nothing "" _ = Left "empty buffer"
    -- Feed the decoder
    go mFrame bs sr =
        case Get.pushChunk sr.decoder bs of
            Get.Fail _ _ s -> Left (Text.pack s)
            -- More data is needed.
            newDecoder@(Get.Partial _) -> Right (mFrame, StreamReader newHeader newDecoder)
              where
                -- Accumulate the buffer for the initialization segments if needed
                -- (total length read, buffers most recent first).
                newHeader = case sr.header of
                    Left (total, acc) -> Left (total + BS.length bs, bs : acc)
                    Right _ -> sr.header
            Get.Done leftover consumed _ -> go (Just newFrame) leftover newSR
              where
                -- The header is either the one already parsed, or the part of all
                -- the accumulated buffers that the initialization decoder consumed.
                -- 'consumed' counts bytes from the start of that decoder, across
                -- every buffer. It can be smaller than the length of the previous
                -- buffers when the cluster-id look-ahead backtracked into one of
                -- them (e.g. a buffer ending with "\x1f\x43"). Those bytes are part
                -- of 'leftover' and must not also end up in the header.
                newHeader = case sr.header of
                    Left (_, acc) -> BS.take (fromIntegral consumed) (mconcat (reverse (bs : acc)))
                    Right initSegments -> initSegments
                -- The new frame starts after what was decoded.
                newFrame = StreamFrame newHeader leftover
                newSR = StreamReader (Right newHeader) (Get.runGetIncremental getCluster)