{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData          #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE TypeApplications    #-}
module Data.Avro.Internal.Container
where
import           Control.Monad                (when)
import qualified Data.Aeson                   as Aeson
import           Data.Avro.Codec              (Codec (..), Decompress)
import qualified Data.Avro.Codec              as Codec
import           Data.Avro.Encoding.ToAvro    (toAvro)
import           Data.Avro.Internal.EncodeRaw (encodeRaw)
import           Data.Avro.Schema.Schema      (Schema)
import qualified Data.Avro.Schema.Schema      as Schema
import           Data.Binary.Get              (Get)
import qualified Data.Binary.Get              as Get
import           Data.ByteString              (ByteString)
import           Data.ByteString.Builder      (Builder, lazyByteString, toLazyByteString)
import qualified Data.ByteString.Lazy         as BL
import qualified Data.ByteString.Lazy.Char8   as BLC
import           Data.Either                  (isRight)
import           Data.HashMap.Strict          (HashMap)
import qualified Data.HashMap.Strict          as HashMap
import           Data.Int                     (Int32, Int64)
import           Data.List                    (foldl', unfoldr)
import qualified Data.Map.Strict              as Map
import           Data.Text                    (Text)
import           System.Random.TF.Init        (initTFGen)
import           System.Random.TF.Instances   (randoms)
import qualified Data.Avro.Internal.Get as AGet
data ContainerHeader = ContainerHeader
  { syncBytes       :: BL.ByteString
  , decompress      :: forall a. Decompress a
  , containedSchema :: Schema
  }
nrSyncBytes :: Integral sb => sb
nrSyncBytes = 16
{-# INLINE nrSyncBytes #-}
newSyncBytes :: IO BL.ByteString
newSyncBytes = BL.pack . take nrSyncBytes . randoms <$> initTFGen
getContainerHeader :: Get ContainerHeader
getContainerHeader = do
  magic <- getFixed avroMagicSize
  when (BL.fromStrict magic /= avroMagicBytes)
        (fail "Invalid magic number at start of container.")
  metadata <- getMeta
  sync  <- BL.fromStrict <$> getFixed nrSyncBytes
  codec <- parseCodec (Map.lookup "avro.codec" metadata)
  schema <- case Map.lookup "avro.schema" metadata of
              Nothing -> fail "Invalid container object: no schema."
              Just s  -> case Aeson.eitherDecode' s of
                            Left e  -> fail ("Can not decode container schema: " <> e)
                            Right x -> return x
  return ContainerHeader  { syncBytes = sync
                          , decompress = Codec.codecDecompress codec
                          , containedSchema = schema
                          }
  where avroMagicSize :: Integral a => a
        avroMagicSize = 4
        avroMagicBytes :: BL.ByteString
        avroMagicBytes = BLC.pack "Obj" <> BL.pack [1]
        getFixed :: Int -> Get ByteString
        getFixed = Get.getByteString
        getMeta :: Get (Map.Map Text BL.ByteString)
        getMeta =
          let keyValue = (,) <$> AGet.getString <*> AGet.getBytesLazy
          in Map.fromList <$> AGet.decodeBlocks keyValue
decodeRawBlocks :: BL.ByteString -> Either String (Schema, [Either String (Int, BL.ByteString)])
decodeRawBlocks bs =
  case Get.runGetOrFail getContainerHeader bs of
    Left (bs', _, err) -> Left err
    Right (bs', _, ContainerHeader {..}) ->
      let blocks = allBlocks syncBytes decompress bs'
      in Right (containedSchema, blocks)
  where
    allBlocks sync decompress bytes =
      flip unfoldr (Just bytes) $ \acc -> case acc of
        Just rest -> next sync decompress rest
        Nothing   -> Nothing
    next syncBytes decompress bytes =
      case getNextBlock syncBytes decompress bytes of
        Right (Just (numObj, block, rest)) -> Just (Right (numObj, block), Just rest)
        Right Nothing                      -> Nothing
        Left err                           -> Just (Left err, Nothing)
getNextBlock :: BL.ByteString
             -> Decompress BL.ByteString
             -> BL.ByteString
             -> Either String (Maybe (Int, BL.ByteString, BL.ByteString))
getNextBlock sync decompress bs =
  if BL.null bs
    then Right Nothing
    else case Get.runGetOrFail (getRawBlock decompress) bs of
      Left (bs', _, err)             -> Left err
      Right (bs', _, (nrObj, bytes)) ->
        case checkMarker sync bs' of
          Left err   -> Left err
          Right rest -> Right $ Just (nrObj, bytes, rest)
  where
    getRawBlock :: Decompress BL.ByteString -> Get (Int, BL.ByteString)
    getRawBlock decompress = do
      nrObj    <- AGet.getLong >>= AGet.sFromIntegral
      nrBytes  <- AGet.getLong
      compressed <- Get.getLazyByteString nrBytes
      bytes <- case decompress compressed Get.getRemainingLazyByteString of
        Right x  -> pure x
        Left err -> fail err
      pure (nrObj, bytes)
    checkMarker :: BL.ByteString -> BL.ByteString -> Either String BL.ByteString
    checkMarker sync bs =
      case BL.splitAt nrSyncBytes bs of
        (marker, _) | marker /= sync -> Left "Invalid marker, does not match sync bytes."
        (_, rest)                    -> Right rest
extractContainerValuesBytes :: forall a schema.
     (Schema -> Either String schema)
  -> (schema -> Get a)
  -> BL.ByteString
  -> Either String (Schema, [Either String (a, BL.ByteString)])
extractContainerValuesBytes deconflict f =
  extractContainerValues deconflict readBytes
  where
    readBytes sch = do
      start <- Get.bytesRead
      (val, end) <- Get.lookAhead (f sch >>= (\v -> (v, ) <$> Get.bytesRead))
      res <- Get.getLazyByteString (end-start)
      pure (val, res)
extractContainerValues :: forall a schema.
     (Schema -> Either String schema)
  -> (schema -> Get a)
  -> BL.ByteString
  -> Either String (Schema, [Either String a])
extractContainerValues deconflict f bs = do
  (sch, blocks) <- decodeRawBlocks bs
  readSchema <- deconflict sch
  pure (sch, takeWhileInclusive isRight $ blocks >>= decodeBlock readSchema)
  where
    decodeBlock _ (Left err)               = undefined
    decodeBlock sch (Right (nrObj, bytes)) = snd $ consumeN (fromIntegral nrObj) (decodeValue sch) bytes
    decodeValue sch bytes =
      case Get.runGetOrFail (f sch) bytes of
        Left (bs', _, err)  -> (bs', Left err)
        Right (bs', _, res) -> (bs', Right res)
packContainerValues :: Codec -> Schema -> [[BL.ByteString]] -> IO BL.ByteString
packContainerValues codec sch values = do
  sync <- newSyncBytes
  pure $ packContainerValuesWithSync codec sch sync values
packContainerValuesWithSync :: Codec -> Schema -> BL.ByteString -> [[BL.ByteString]] -> BL.ByteString
packContainerValuesWithSync = packContainerValuesWithSync' (\_ a -> lazyByteString a)
{-# INLINABLE packContainerValuesWithSync #-}
packContainerValuesWithSync' ::
     (Schema -> a -> Builder)
  -> Codec
  -> Schema
  -> BL.ByteString
  -> [[a]]
  -> BL.ByteString
packContainerValuesWithSync' encode codec sch syncBytes values =
  toLazyByteString $ containerHeaderWithSync codec sch syncBytes <> foldMap putBlock values
  where
    putBlock ys =
      let nrObj = length ys
          nrBytes = BL.length theBytes
          theBytes = codecCompress codec $ toLazyByteString $ foldMap (encode sch) ys
      in encodeRaw @Int32 (fromIntegral nrObj) <>
         encodeRaw nrBytes <>
         lazyByteString theBytes <>
         lazyByteString syncBytes
packContainerBlocks :: Codec -> Schema -> [(Int, BL.ByteString)] -> IO BL.ByteString
packContainerBlocks codec sch blocks = do
  sync <- newSyncBytes
  pure $ packContainerBlocksWithSync codec sch sync blocks
packContainerBlocksWithSync :: Codec -> Schema -> BL.ByteString -> [(Int, BL.ByteString)] -> BL.ByteString
packContainerBlocksWithSync codec sch syncBytes blocks =
  toLazyByteString $
    containerHeaderWithSync codec sch syncBytes <>
    foldMap putBlock blocks
  where
    putBlock (nrObj, bytes) =
      let compressed = codecCompress codec bytes in
        encodeRaw @Int32 (fromIntegral nrObj) <>
        encodeRaw (BL.length compressed) <>
        lazyByteString compressed <>
        lazyByteString syncBytes
containerHeaderWithSync :: Codec -> Schema -> BL.ByteString -> Builder
containerHeaderWithSync codec sch syncBytes =
  lazyByteString avroMagicBytes
    <> toAvro (Schema.Map Schema.Bytes') headers
    <> lazyByteString syncBytes
  where
    avroMagicBytes :: BL.ByteString
    avroMagicBytes = "Obj" <> BL.pack [1]
    headers :: HashMap Text BL.ByteString
    headers =
      HashMap.fromList
        [
          ("avro.schema", Aeson.encode sch)
        , ("avro.codec", BL.fromStrict (codecName codec))
        ]
consumeN :: Int64 -> (a -> (a, b)) -> a -> (a, [b])
consumeN n f a =
  if n == 0
    then (a, [])
    else
      let (a', b) = f a
          (r, bs) = consumeN (n-1) f a'
      in (r, b:bs)
{-# INLINE consumeN #-}
parseCodec :: Monad m => Maybe BL.ByteString -> m Codec
parseCodec (Just "null")    = pure Codec.nullCodec
parseCodec (Just "deflate") = pure Codec.deflateCodec
parseCodec (Just x)         = error $ "Unrecognized codec: " <> BLC.unpack x
parseCodec Nothing          = pure Codec.nullCodec
takeWhileInclusive :: (a -> Bool) -> [a] -> [a]
takeWhileInclusive _ [] = []
takeWhileInclusive p (x:xs) =
  x : if p x then takeWhileInclusive p xs else []
{-# INLINE takeWhileInclusive #-}