{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}

-- A number of standard aggregation functions.

module Spark.Core.Internal.Groups(
  GroupData,
  LogicalGroupData,
  -- Typed functions
  groupByKey,
  mapGroup,
  aggKey,
  groupAsDS
  -- Developer

) where

import qualified Data.Text as T
import qualified Data.Vector as V
import Formatting
import Debug.Trace(trace)

import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.ColumnStructures
import Spark.Core.Internal.ColumnFunctions(untypedCol, colType, colOp, iUntypedColData, colOrigin, castTypeCol, dropColReference, genColOp)
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.LocalDataFunctions()
import Spark.Core.Internal.FunctionsInternals
import Spark.Core.Internal.TypesFunctions(tupleType, structTypeFromFields)
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.Projections
import Spark.Core.Internal.TypesStructures
import Spark.Core.Internal.Utilities
import Spark.Core.Internal.RowStructures(Cell)
import Spark.Core.Try
import Spark.Core.StructuresInternal
import Spark.Core.Internal.CanRename

{-| A dataset that has been partitioned according to some given field.

The type parameters @key@ and @val@ are phantom: none of the fields mentions
them, they only carry the static types of the key and value columns.
-}
data GroupData key val = GroupData
  { -- The dataset of reference for this group.
    _gdRef :: !UntypedDataset
    -- The column used to partition the data by keys.
  , _gdKey :: !GroupColumn
    -- The column that contains the values.
  , _gdValue :: !GroupColumn
  }

{-| An untyped group wrapped in 'Try': either the group itself, or the failure
that prevented it from being built.
-}
type LogicalGroupData = Try UntypedGroupData

-- A column in a group, usable either as the key or as the values.
-- It differs from the regular column data in that its operation is a plain
-- ColOp: it does not include broadcast data.
data GroupColumn = GroupColumn
  { -- The data type of the column.
    _gcType :: !DataType
    -- The (broadcast-free) operation that computes the column.
  , _gcOp :: !ColOp
    -- The optional field name this column is referred by.
  , _gcRefName :: !(Maybe FieldName)
  } deriving (Eq, Show)


{-| (developer)

A group data type with no typing information: both the key and the value are
seen as plain cells.
-}
type UntypedGroupData = GroupData Cell Cell

-- type GroupTry a = Either T.Text a

-- A useful type when chaining operations within groups.
-- It tracks the state reached while replaying a chain of nodes: a failure,
-- a regular dataset, or data that is still grouped.
data PipedTrans
  = PipedError !T.Text
  | PipedDataset !UntypedDataset
  | PipedGroup !UntypedGroupData
  deriving (Show)


{-| Performs a logical group of data based on a key.

The two columns are first grouped without their type information, and the
resulting group is then checked against the static types of the columns.
The intermediate 'Try' is unwrapped with forceRight, so a failure is not
reported through the return type.
-}
groupByKey :: (HasCallStack) => Column ref key -> Column ref val -> GroupData key val
groupByKey keys vals = forceRight typedGroup
  where
    -- The group, built from the untyped views of both columns.
    untypedGroup = _groupByKey (iUntypedColData keys) (iUntypedColData vals)
    -- The same group, once its types have been verified.
    typedGroup = untypedGroup >>= _castGroup (colType keys) (colType vals)

{-| Transforms the values in a group.

The key column and the dataset of reference are left untouched; only the value
column is replaced by the result of the transform.
-}
-- This only allows direct transforms, so it is probably valid in all cases.
mapGroup :: GroupData key val -> (forall ref. Column ref val -> Column ref val') -> GroupData key val'
mapGroup g f = g { _gdValue = newValues }
  where
    -- The user transform, applied to the current value column.
    transformed = f (_unsafeCastColData (_valueCol g))
    -- Assume for now that there is no broadcast.
    -- TODO: deal with broadcast eventually
    newValues = forceRight (_groupCol transformed)

{- The generalized value transform (not implemented yet).

This generalizes mapGroup to allow more complex transforms involving joins,
groups, etc.
-}
-- TODO: this can fail
-- magGroupGen :: (forall ref. Column ref val -> Dataset val') -> GroupData key val -> GroupData key val'
-- magGroupGen _ _ = undefined

{-| Given a group and an aggregation function, aggregates the data.

The result is a dataset of pairs made of a key and of the aggregated value for
that key.

Note: not all the reduction functions may be used in this case. The analyzer
will fail if the function is not universal.
-}
-- TODO: it should be a try, this can fail
aggKey :: (HasCallStack) => GroupData key val -> (forall ref. Column ref val -> LocalData val') -> Dataset (key, val')
aggKey gd f = trace "aggKey" (forceRight typedResult)
  where
    keyType = traceHint "aggKey: keyt: " (mapGroupKeys gd colType)
    valueType = traceHint "aggKey: valt: " (mapGroupValues gd colType)
    -- The function is called twice: this first call only recovers the type
    -- of the output, the second one (in untypedAgg) performs the unrolling.
    -- TODO we should be able to do it in one pass instead.
    probe = traceHint "aggKey: fOut: " (f (mapGroupValues gd dropColReference))
    aggType = traceHint "aggKey: valt': " (nodeType probe)
    resultType = traceHint "aggKey: t: " (tupleType keyType aggType)
    -- The aggregation function, seen as a function over untyped data.
    untypedAgg col = untypedLocalData . f <$> castTypeCol valueType col
    untypedResult = traceHint "aggKey: tud: " (_aggKey (_untypedGroup gd) untypedAgg)
    typedResult = castType' resultType untypedResult

{- Creates a group by 'expanding' a value into a potentially large collection
(not implemented yet).

Note on performance: this function is optimized to work at any scale and may not
be the most efficient when the generated collections are small (a few elements).
-}
-- TODO: it should be a try, this can fail
-- expand :: Column ref key -> Column ref val -> (LocalData val -> Dataset val') -> GroupData key val'
-- expand = undefined

{- Builds groups within groups (not implemented yet).

This function allows groups to be constructed from each collection inside a
group.

This function is usually not used directly by the user, but rather as part of
more complex pipelines that may involve multiple levels of nesting.
-}
-- groupInGroup :: GroupData key val -> (forall ref. Column ref val -> GroupData key' val') -> GroupData (key', key) val'
-- groupInGroup _ _ = undefined

{- Reduces a group in group into a single group (not implemented yet).
-}
-- aggGroup :: GroupData (key, key') val -> (forall ref. LocalData key -> Column ref val -> LocalData val') -> GroupData key val
-- aggGroup _ _ = undefined

{-| Returns the collapsed representation of a grouped dataset, discarding group
information.

The result is built by packing a structure made of the key column and of the
value column of the group.
-}
groupAsDS :: forall key val. GroupData key val -> Dataset (key, val)
groupAsDS g =
  let keyColumn = _unsafeCastColData (_keyCol g) :: Column UnknownReference key
      valueColumn = _unsafeCastColData (_valueCol g) :: Column UnknownReference val
      pair = struct (keyColumn, valueColumn) :: Column UnknownReference (key, val)
  in pack pair

-- Applies a reference-agnostic function to the key column of a group.
mapGroupKeys :: GroupData key val -> (forall ref. Column ref key -> a) -> a
mapGroupKeys gd f = f typedKeys
  where
    -- The key column, relabelled with the static key type of the group.
    typedKeys = _unsafeCastColData (_keyCol gd)

-- Applies a reference-agnostic function to the value column of a group.
mapGroupValues :: GroupData key val -> (forall ref. Column ref val -> a) -> a
mapGroupValues gd f = f typedValues
  where
    -- The value column, relabelled with the static value type of the group.
    typedValues = _unsafeCastColData (_valueCol gd)

-- ******** INSTANCES ***********


-- Only the key and the value columns are displayed, not the dataset of
-- reference.
instance Show (GroupData key val) where
  show gd =
    let rendered = sformat ("GroupData[key="%sh%", val="%sh%"]") (_gdKey gd) (_gdValue gd)
    in T.unpack rendered

-- ******** PRIVATE METHODS ********

-- Rebuilds the (untyped) key column of a group, attached to the dataset of
-- reference of the group.
_keyCol :: GroupData key val -> UntypedColumnData
_keyCol gd =
  let key = _gdKey gd
  in ColumnData
       { _cOrigin = _gdRef gd
       , _cType = _gcType key
       , _cOp = genColOp (_gcOp key)
       , _cReferingPath = _gcRefName key
       }

-- Rebuilds the (untyped) value column of a group, attached to the dataset of
-- reference of the group.
_valueCol :: GroupData key val -> UntypedColumnData
_valueCol gd =
  let value = _gdValue gd
  in ColumnData
       { _cOrigin = _gdRef gd
       , _cType = _gcType value
       , _cOp = genColOp (_gcOp value)
       , _cReferingPath = _gcRefName value
       }


-- Wraps an error message into a failed piped transform.
_pError :: T.Text -> PipedTrans
_pError msg = PipedError msg

-- Replays a chain of single-parent nodes on top of a starting state.
--
-- The chain is followed from `un` through its parents until the node with
-- id `nid` is reached: that node is replaced by `start`, and every node met
-- on the way is then applied, in order, with _unrollStep.
-- A node that does not have exactly one parent makes the whole unrolling fail.
_unrollTransform :: PipedTrans -> NodeId -> UntypedNode -> PipedTrans
_unrollTransform start nid un
  | nodeId un == nid = start
  | [parent] <- nodeParents un =
      _unrollStep (_unrollTransform start nid parent) un
  | otherwise =
      _pError $ sformat (sh%": operations with multiple parents cannot be used in groups yet.") un

-- Applies the operation of a single node to the current piped state.
--
-- pt: the state reached so far (error, dataset or group)
-- un: the node whose operation must be applied; it must have a single parent
_unrollStep :: PipedTrans -> UntypedNode -> PipedTrans
_unrollStep pt un =
  traceHint ("_unrollStep: pt=" <> show' pt <> " un=" <> show' un <> " res=") outcome
  where
    op = nodeOp un
    -- The output type of the node, only used for aggregations.
    dt = unSQLType (nodeType un)
    outcome = case nodeParents un of
      [p] -> applyTo p
      l -> _pError $ sformat (sh%": expected one parent but got "%sh) un l
    -- The parent is only used to report unsupported combinations.
    applyTo p = case (pt, op) of
      -- Errors are propagated untouched.
      (PipedError e, _) -> PipedError e
      (PipedDataset ds, NodeStructuredTransform _) ->
        -- This is simply doing a DS -> DS transform: the node is kept as is
        -- and reattached to the current dataset.
        -- TODO: this breaks the encapsulation of ComputeNode
        PipedDataset (updateNode un (\un' -> un' { _cnParents = V.singleton (untyped ds)}))
      (PipedGroup g, NodeStructuredTransform co) ->
        _unrollGroupTrans g co
      (PipedGroup g, NodeAggregatorReduction uao) ->
        case uaoInitialOuter uao of
          OpaqueAggTransform x ->
            _pError $ sformat ("Cannot apply opaque transform in the context of an aggregation: "%sh) x
          InnerAggOp ao ->
            PipedDataset (_applyAggOp dt ao g)
      _ -> _pError $ sformat (sh%": Operation not supported with trans="%sh%" and parents="%sh) op pt p

-- Builds the dataset that results from applying an aggregation to a group.
--
-- dt: output type of the aggregation op
-- ao: the aggregation to apply
-- ugd: the group being aggregated
--
-- The result is a grouped reduction node whose parent is the packed
-- (key, value) dataset, and whose type is a structure with the fields
-- "key" and "agg".
_applyAggOp :: (HasCallStack) => DataType -> AggOp -> UntypedGroupData -> UntypedDataset
_applyAggOp dt ao ugd =
  traceHint ("_applyAggOp dt=" <> show' dt <> " ao=" <> show' ao <> " ugd=" <> show' ugd <> " res=") reduced
  where
    -- Reset the names to make sure there is no collision.
    keyColumn = untypedCol (_keyCol ugd) @@ T.unpack "_1"
    valueColumn = untypedCol (_valueCol ugd) @@ T.unpack "_2"
    packed = forceRight (pack1 <$> struct' [keyColumn, valueColumn])
    -- The structure of the result dataframe
    keyDt = unSQLType (colType (_keyCol ugd))
    fields = [(unsafeFieldName "key", keyDt), (unsafeFieldName "agg", dt)]
    -- The field names are different, so we know this operation is legit:
    resDt = SQLType . StrictType . Struct $ forceRight (structTypeFromFields fields)
    reduced = emptyDataset (NodeGroupedReduction ao) resDt `parents` [untyped packed]

-- Applies a structured transform to the values of a group.
--
-- The transform is combined with the operation that currently defines the
-- value column, and the value column of the group is replaced by the result.
-- The keys are left untouched. Any failure is turned into a piped error.
_unrollGroupTrans :: UntypedGroupData -> ColOp -> PipedTrans
_unrollGroupTrans ugd co = either id PipedGroup regrouped
  where
    valueColumn = _valueCol ugd
    -- Turns a failure into a piped error that starts with the given prefix.
    -- TODO: this is ugly, we are losing the error structure.
    orFail :: (Show e) => T.Text -> Either e a -> Either PipedTrans a
    orFail prefix = either (Left . _pError . (prefix <>) . show') Right
    regrouped = do
      current <- orFail "_unrollGroupTrans (1): using unimplemented feature:"
        (colOpNoBroadcast (colOp valueColumn))
      combined <- orFail "_unrollGroupTrans (2): failure with "
        (_combineColOp current co)
      newValues <- orFail "_unrollGroupTrans (3): failure with "
        (_groupCol (_transformCol combined valueColumn))
      pure ugd { _gdValue = newValues }

-- Replaces the operation of a column, keeping everything else (origin, type,
-- referring path) unchanged.
-- TODO: this should be moved to ColumnFunctions
-- TODO: at this point, it should be checked for correctness (the fields
-- being extracted should exist)
_transformCol :: ColOp -> UntypedColumnData -> UntypedColumnData
_transformCol co ucd = ucd { _cOp = generalized }
  where
    generalized = genColOp co

-- Takes a column operation and chains it with another column operation.
--
-- The first argument is the operation that produces the input column, the
-- second one is the operation to apply on top of it:
--  - literals do not depend on the input and are returned as they are
--  - functions and structures are rewritten recursively, element by element
--  - extractions are resolved against the first operation (see _extractColOp)
-- The result is a failure as soon as one of the extractions fails.
_combineColOp :: ColOp -> ColOp -> Try ColOp
-- Note: the as-pattern must not be surrounded by whitespace, otherwise recent
-- versions of GHC parse '@' as an operator.
_combineColOp _ x@(ColLit _ _) = pure x
_combineColOp x (ColFunction fn v) =
  ColFunction fn <$> traverse (_combineColOp x) v
_combineColOp x (ColExtraction fp) = _extractColOp x (V.toList (unFieldPath fp))
_combineColOp x (ColStruct v) =
  ColStruct <$> traverse f v where
    f (TransformField n val) = TransformField n <$> _combineColOp x val

-- Follows a path of field names inside a column operation.
--
-- An empty path returns the operation itself. Otherwise the operation must be
-- a structure that contains the first field of the path, and the rest of the
-- path is followed inside the operation of that field. Anything else is a
-- failure.
_extractColOp :: ColOp -> [FieldName] -> Try ColOp
_extractColOp op path = case (path, op) of
  ([], _) -> pure op
  (fn : rest, ColStruct fields) ->
    case V.find (\field -> tfName field == fn) fields of
      Nothing ->
        tryError $ sformat ("Expected to find field "%sh%" in structure "%sh) fn fields
      Just (TransformField _ inner) ->
        _extractColOp inner rest
  _ ->
    tryError $ sformat ("Cannot perform extraction "%sh%" on column operation "%sh) path op

-- The untyped implementation of the aggregation by key.
--
-- The aggregation function is first applied to a placeholder dataset that has
-- the type of the values. The chain of nodes between this placeholder and the
-- output of the function is then replayed on the group itself.
_aggKey :: UntypedGroupData -> (UntypedColumnData -> Try UntypedLocalData) -> Try UntypedDataset
_aggKey ugd f = do
  uld <- f (_unsafeCastColData (asCol seed))
  case _unrollTransform (PipedGroup ugd) (nodeId seed) (untyped uld) of
    PipedDataset ds -> pure ds
    PipedError t -> tryError t
    PipedGroup g ->
      -- This is a programming error
      tryError $ sformat ("Expected a dataframe at the output but got a group: "%sh) g
  where
    -- The placeholder that stands for the values of the group.
    seed = placeholder (unSQLType (colType (_valueCol ugd))) :: UntypedDataset

-- Changes the static reference and the static type of a column without any
-- check. The content of the column is unchanged: the record update only
-- serves to let the type parameters vary.
_unsafeCastColData :: Column ref a -> Column ref' a'
_unsafeCastColData c = c { _cType = sameType }
  where
    sameType = _cType c

{-| Checks that the group can be cast.

The data types of the key column and of the value column of the group must be
equal to the requested key type and value type. The key is checked first. When
both checks pass, the same group is returned with the requested static types;
otherwise the result is an error that names the offending column.
-}
_castGroup ::
  SQLType key -> SQLType val -> UntypedGroupData -> Try (GroupData key val)
_castGroup (SQLType keyType) (SQLType valType) ugd
  | keyType /= keyType' =
      tryError $ sformat ("The key column (of type "%sh%") cannot be cast to type "%sh) keyType' keyType
  | valType /= valType' =
      tryError $ sformat ("The value column (of type "%sh%") cannot be cast to type "%sh) valType' valType
  | otherwise =
      -- The record update does not change any value, it only lets the phantom
      -- type parameters of the group change.
      pure ugd { _gdRef = _gdRef ugd }
  where
    -- The actual types found in the group.
    keyType' = unSQLType . colType . _keyCol $ ugd
    valType' = unSQLType . colType . _valueCol $ ugd

-- Drops the static type information of a group. The content of the group is
-- unchanged: the record update only serves to let the phantom type
-- parameters vary.
_untypedGroup :: GroupData key val -> UntypedGroupData
_untypedGroup gd = gd { _gdRef = reference }
  where
    reference = _gdRef gd

-- The untyped implementation of the group by key.
--
-- Both columns must come from the same node, otherwise the result is an
-- error. The two columns are packed together into a new dataset, which
-- becomes the dataset of reference of the group; the key and the value of the
-- group are the first and the second field of this dataset.
_groupByKey :: UntypedColumnData -> UntypedColumnData -> LogicalGroupData
_groupByKey keys vals
  | sameOrigin = do
      gKeys <- _groupCol keys'
      gVals <- _groupCol vals'
      return GroupData {
                _gdRef = colOrigin keys',
                _gdKey = gKeys,
                _gdValue = gVals
              }
  | otherwise =
      tryError $ sformat ("The columns have different origin: "%sh%" and "%sh) keys vals
  where
    sameOrigin = nodeId (colOrigin keys) == nodeId (colOrigin vals)
    -- Get the latest data (packed)
    -- TODO: put a scoping
    s = struct (keys, vals) :: Column UnknownReference (Cell, Cell)
    ds = pack1 s
    keys' = ds // _1
    vals' = ds // _2

-- Converts a column into a group column.
--
-- This fails when the operation of the column cannot be expressed without
-- broadcast data. The resulting group column has no reference name.
_groupCol :: Column ref a -> Try GroupColumn
_groupCol c = wrap <$> colOpNoBroadcast (colOp c)
  where
    wrap co = GroupColumn
      { _gcType = unSQLType (colType c)
      , _gcOp = co
      , _gcRefName = Nothing
      }