{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

module Parquet.Stream.Reader where

import qualified Conduit as C
import Control.Applicative (liftA3)
import Control.Monad.Except
import Control.Monad.Logger
import Control.Monad.Reader
import Data.Bifunctor (first)
import qualified Data.Binary.Get as BG
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.Serialization.Binary as CB
import Data.Int (Int32, Int64)
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
import qualified Data.Text as T
import Data.Traversable (for)
import Data.Word (Word32, Word8)
import Control.Lens
import qualified Pinch
import Safe.Exact (zipExactMay)

import Parquet.Decoder (BitWidth(..), decodeBPBE, decodeRLEBPHybrid)
import Parquet.Monad
import qualified Parquet.ThriftTypes as TT
import Parquet.Utils ((<??>))

data ColumnValue = ColumnValue
  { _cvRepetitionLevel :: Word32
  , _cvDefinitionLevel :: Word32
  , _cvMaxDefinitionLevel :: Word32
  , _cvValue :: Value
  } deriving (Eq, Show)

-- | TODO: This is so unoptimized that my eyes bleed.
replicateMSized :: (Monad m) => Int -> m (Int64, result) -> m (Int64, [result])
replicateMSized 0 _    = pure (0, [])
replicateMSized n comp = do
  (consumed     , result     ) <- comp
  (rest_consumed, rest_result) <- replicateMSized (n - 1) comp
  pure (consumed + rest_consumed, result : rest_result)

forSized :: (Monad m) => [a] -> (a -> m (Int64, result)) -> m (Int64, [result])
forSized = flip traverseSized

traverseSized
  :: (Monad m) => (a -> m (Int64, result)) -> [a] -> m (Int64, [result])
traverseSized _    []       = pure (0, [])
traverseSized comp (x : xs) = do
  (consumed     , result     ) <- comp x
  (rest_consumed, rest_result) <- traverseSized comp xs
  pure (consumed + rest_consumed, result : rest_result)

maxLevelToBitWidth :: Word8 -> BitWidth
maxLevelToBitWidth 0 = BitWidth 0
maxLevelToBitWidth max_level =
  BitWidth $ floor (logBase 2 (fromIntegral max_level) :: Double) + 1

dataPageReader
  :: forall m
   . (PR m, MonadReader PageCtx m)
  => TT.DataPageHeader
  -> Maybe [Value]
  -> C.ConduitT BS.ByteString ColumnValue m ()
dataPageReader header mb_dict = do
  let num_values         = header ^. TT.pinchField @"num_values"
  let def_level_encoding = header ^. TT.pinchField @"definition_level_encoding"
  let rep_level_encoding = header ^. TT.pinchField @"repetition_level_encoding"
  let encoding           = header ^. TT.pinchField @"encoding"
  (max_rep_level, max_def_level) <- calcMaxEncodingLevels
  (_rep_consumed, fill_level_default num_values -> rep_data) <-
    readRepetitionLevel
      rep_level_encoding
      (maxLevelToBitWidth max_rep_level)
      num_values
  (_def_consumed, fill_level_default num_values -> def_data) <-
    readDefinitionLevel
      def_level_encoding
      (maxLevelToBitWidth max_def_level)
      num_values
  level_data <- zip_level_data rep_data def_data
  read_page_content encoding level_data num_values (fromIntegral max_def_level)
  pure ()
 where
  find_from_dict
    :: forall  m0 . MonadError T.Text m0 => [Value] -> Word32 -> m0 Value
  find_from_dict dict (fromIntegral -> d_index) = case dict ^? ix d_index of
    Nothing ->
      throwError $ "A dictionary value couldn't be found in index " <> T.pack
        (show d_index)
    Just val -> pure val

  zip_level_data
    :: forall  m0 a b . MonadError T.Text m0 => [a] -> [b] -> m0 [(a, b)]
  zip_level_data rep_data def_data =
    zipExactMay rep_data def_data
      <??> (  "Repetition and Definition data sizes differ: page has "
           <> T.pack (show (length rep_data))
           <> " repetition values and "
           <> T.pack (show (length def_data))
           <> " definition values."
           )

  fill_level_default :: Int32 -> [Word32] -> [Word32]
  fill_level_default num_values = \case
    [] -> replicate (fromIntegral num_values) 1
    xs -> xs

  read_page_content
    :: TT.Encoding
    -> [(Word32, Word32)]
    -> Int32
    -> Word32
    -> C.ConduitT BS.ByteString ColumnValue m ()
  read_page_content encoding level_data num_values max_def_level =
    case (mb_dict, encoding) of
      (Nothing, TT.PLAIN _) -> do
        vals <- for level_data $ \(r, d) -> if
          | d == max_def_level -> do
            (_, val) <- decodeValue
            pure (ColumnValue r d max_def_level val)
          | otherwise -> pure $ ColumnValue r d max_def_level Null
        C.yieldMany vals
      (Just _, TT.PLAIN _) -> throwError
        "We shouldn't have PLAIN-encoded data pages with a dictionary."
      (Just dict, TT.PLAIN_DICTIONARY _) -> do
        !bit_width  <- CB.sinkGet BG.getWord8
        val_indexes <- CB.sinkGet
          $ decodeRLEBPHybrid (BitWidth bit_width) num_values
        vals <- construct_dict_values max_def_level dict level_data val_indexes
        C.yieldMany vals
      (Nothing, TT.PLAIN_DICTIONARY _) ->
        throwError
          "Data page has PLAIN_DICTIONARY encoding but we don't have a dictionary yet."
      other ->
        throwError
          $  "Don't know how to encode data pages with encoding: "
          <> T.pack (show other)

  -- | Given repetition and definition level data, a dictionary and a set of indexes,
  -- constructs values for this dictionary-encoded page.
  construct_dict_values
    :: forall m0
     . (MonadError T.Text m0)
    => Word32
    -> [Value]
    -> [(Word32, Word32)]
    -> [Word32]
    -> m0 [ColumnValue]
  construct_dict_values _ _ [] _ = pure []
  construct_dict_values _ _ _ [] =
    throwError
      "There are not enough level data for given amount of dictionary indexes."
  construct_dict_values max_def_level dict ((r, d) : lx) (v : vx)
    | d == max_def_level = do
      val <- find_from_dict dict v
      (ColumnValue r d max_def_level val :)
        <$> construct_dict_values max_def_level dict lx vx
    | otherwise = do
      (ColumnValue r d max_def_level Null :)
        <$> construct_dict_values max_def_level dict lx (v : vx)

data Value
  = ValueInt64 Int64
  | ValueByteString BS.ByteString
  | Null
  deriving (Show, Eq)

decodeValue
  :: (PR m, MonadReader PageCtx m)
  => C.ConduitT BS.ByteString o m (Int64, Value)
decodeValue = asks _pcColumnTy >>= \case
  (TT.BYTE_ARRAY _) -> do
    !len               <- CB.sinkGet BG.getWord32le
    (consumed, result) <- replicateMSized
      (fromIntegral len)
      (CB.sinkGet (sizedGet BG.getWord8))
    pure (consumed + 4, ValueByteString (BS.pack result))
  (TT.INT64 _) -> do
    (consumed, result) <- CB.sinkGet (sizedGet BG.getInt64le)
    pure (consumed, ValueInt64 result)
  ty ->
    throwError
      $  "Don't know how to decode value of type "
      <> T.pack (show ty)
      <> " yet."

dictPageReader
  :: (MonadReader PageCtx m, PR m)
  => TT.DictionaryPageHeader
  -> C.ConduitT BS.ByteString o m (Int64, [Value])
dictPageReader header = do
  let num_values = header ^. TT.pinchField @"num_values"
  let _encoding  = header ^. TT.pinchField @"encoding"
  let _is_sorted = header ^. TT.pinchField @"is_sorted"
  (consumed, vals) <- replicateMSized (fromIntegral num_values) decodeValue
  pure (consumed, vals)

data PageCtx =
  PageCtx
    { _pcSchema :: M.Map T.Text TT.SchemaElement
    , _pcPath :: NE.NonEmpty T.Text
    , _pcColumnTy :: TT.Type
    }
  deriving (Show, Eq)

getLastSchemaElement
  :: (MonadError T.Text m, MonadReader PageCtx m) => m TT.SchemaElement
getLastSchemaElement = do
  path   <- asks _pcPath
  schema <- asks _pcSchema
  M.lookup (NE.head (NE.reverse path)) schema
    <??> "Schema element could not be found"

readDefinitionLevel
  :: (PR m, MonadReader PageCtx m)
  => TT.Encoding
  -> BitWidth
  -> Int32
  -> C.ConduitT BS.ByteString a m (Int64, [Word32])
readDefinitionLevel _ (BitWidth 0) _ = pure (0, [])
readDefinitionLevel encoding bit_width num_values =
  getLastSchemaElement >>= getRepType >>= \case
    TT.OPTIONAL _ -> decodeLevel encoding bit_width num_values
    TT.REPEATED _ -> decodeLevel encoding bit_width num_values
    TT.REQUIRED _ -> pure (0, [])

readRepetitionLevel
  :: (C.MonadThrow m, MonadError T.Text m, MonadReader PageCtx m, MonadLogger m)
  => TT.Encoding
  -> BitWidth
  -> Int32
  -> C.ConduitT BS.ByteString a m (Int64, [Word32])
readRepetitionLevel encoding bit_width num_values = do
  path <- asks _pcPath
  if NE.length path > 1
    then decodeLevel encoding bit_width num_values
    else pure (0, [])

sizedGet :: BG.Get result -> BG.Get (Int64, result)
sizedGet g = do
  (before, result, after) <- liftA3 (,,) BG.bytesRead g BG.bytesRead
  pure (after - before, result)

decodeLevel
  :: (C.MonadThrow m, MonadError T.Text m, MonadLogger m, MonadReader PageCtx m)
  => TT.Encoding
  -> BitWidth
  -> Int32
  -> C.ConduitT BS.ByteString a m (Int64, [Word32])
decodeLevel _        (BitWidth 0) _                            = pure (0, [])
decodeLevel encoding bit_width (fromIntegral -> num_values) = case encoding of
  TT.RLE _ ->
    CB.sinkGet
      $  sizedGet
      $  BG.getWord32le
      *> decodeRLEBPHybrid bit_width num_values
  TT.BIT_PACKED _ ->
    CB.sinkGet
      $  sizedGet
      $  BG.getWord32le
      *> (take (fromIntegral num_values) <$> decodeBPBE bit_width)
  _ -> throwError "Only RLE and BIT_PACKED encodings are supported for levels"

-- | Algorithm:
-- https://blog.twitter.com/engineering/en_us/a/2013/dremel-made-simple-with-parquet.html
calcMaxEncodingLevels
  :: (MonadReader PageCtx m, MonadError T.Text m) => m (Word8, Word8)
calcMaxEncodingLevels = do
  schema      <- asks _pcSchema
  path        <- asks _pcPath
  filled_path <- for path $ \name ->
    M.lookup name schema <??> "Schema Element cannot be found: " <> name
  foldM
    (\(rep, def) e -> getRepType e >>= \case
      (TT.REQUIRED _) -> pure (rep, def)
      (TT.OPTIONAL _) -> pure (rep, def + 1)
      (TT.REPEATED _) -> pure (rep + 1, def + 1)
    )
    (0, 0)
    filled_path

getRepType
  :: MonadError T.Text m => TT.SchemaElement -> m TT.FieldRepetitionType
getRepType e =
  e
    ^.   TT.pinchField @"repetition_type"
    <??> "Repetition type could not be found for elem "
    <>   T.pack (show e)

validateCompression :: MonadError T.Text m => TT.ColumnMetaData -> m ()
validateCompression metadata =
  let compression = metadata ^. TT.pinchField @"codec"
  in
    case compression of
      TT.UNCOMPRESSED _ -> pure ()
      _ ->
        throwError "This library doesn't support compression algorithms yet."

readColumnChunk
  :: PR m
  => M.Map T.Text TT.SchemaElement
  -> TT.ColumnChunk
  -> C.ConduitT BS.ByteString ColumnValue m ()
readColumnChunk schema cc = do
  let mb_metadata = cc ^. TT.pinchField @"meta_data"
  metadata <- mb_metadata <??> "Metadata could not be found"
  validateCompression metadata
  let size      = metadata ^. TT.pinchField @"total_compressed_size"
  let column_ty = metadata ^. TT.pinchField @"type"
  let path      = metadata ^. TT.pinchField @"path_in_schema"
  ne_path <- NE.nonEmpty path <??> "Schema path cannot be empty"
  let page_ctx = PageCtx schema ne_path column_ty
  C.runReaderC page_ctx $ readPage size Nothing

readPage
  :: (MonadReader PageCtx m, PR m)
  => Int64
  -> Maybe [Value]
  -> C.ConduitT BS.ByteString ColumnValue m ()
readPage 0         _       = pure ()
readPage remaining mb_dict = do
  (page_header_size, page_header :: TT.PageHeader) <- decodeConduit remaining
  let
    page_content_size = page_header ^. TT.pinchField @"uncompressed_page_size"
  let
    validate_consumed_page_bytes consumed =
      unless (fromIntegral page_content_size == consumed)
        $ throwError "Reader did not consume the whole page!"
  let page_size = fromIntegral page_header_size + page_content_size
  case
      ( page_header ^. TT.pinchField @"dictionary_page_header"
      , page_header ^. TT.pinchField @"data_page_header"
      , mb_dict
      )
    of
      (Just dict_page_header, Nothing, Nothing) -> do
        (page_consumed, dict) <- dictPageReader dict_page_header
        validate_consumed_page_bytes page_consumed
        readPage (remaining - fromIntegral page_size) (Just dict)
      (Just _dict_page_header, Nothing, Just _dict) ->
        throwError "Found dictionary page while we already had a dictionary."
      (Nothing, Just dp_header, Nothing) -> do
        dataPageReader dp_header Nothing
      (Nothing, Just dp_header, Just dict) -> do
        dataPageReader dp_header (Just dict)
      (Nothing, Nothing, _) -> throwError
        "Page doesn't have any of the dictionary or data page header."
      (Just _, Just _, _) ->
        throwError "Page has both dictionary and data page headers."

failOnError :: Show err => IO (Either err b) -> IO b
failOnError v = v >>= \case
  Left  err -> fail $ show err
  Right val -> pure val

decodeConduit
  :: forall a size m o
   . (MonadError T.Text m, MonadIO m, Integral size, Pinch.Pinchable a)
  => size
  -> C.ConduitT BS.ByteString o m (Int, a)
decodeConduit (fromIntegral -> size) = do
  (left, val) <-
    liftEither
    .   first T.pack
    .   Pinch.decodeWithLeftovers Pinch.compactProtocol
    .   LBS.toStrict
    =<< CB.take size
  C.leftover left
  pure (size - BS.length left, val)