{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}

-- | Standard aggregation functions (sum, count, collect) over columns and datasets.

module Spark.Core.Internal.AggregationFunctions where

import Data.Aeson(Value(Null))
import qualified Data.Text as T
import Formatting

import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.ColumnStructures
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.RowGenerics(ToSQL)
import Spark.Core.Internal.LocalDataFunctions()
import Spark.Core.Internal.FunctionsInternals
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.Utilities(failure, HasCallStack)
import Spark.Core.Internal.TypesStructures
import Spark.Core.Types

{-| Computes the sum of all the elements of a column.

If the sum cannot be represented by the data type, the returned value is
undefined.
-}
colSum :: forall ref a. (Num a, SQLTypeable a, ToSQL a) =>
  Column ref a -> LocalData a
colSum col = applyUniAgg sumAggregator col
  where
    -- Pins the aggregator to the element type of the column.
    sumAggregator :: UniversalAggregator a a
    sumAggregator = _sumAgg

{-| Counts the number of elements in a dataset.

The dataset is viewed as a column and aggregated with the count aggregator.
-}
-- TODO use Long for the return data type.
count :: forall a. (SQLTypeable a) => Dataset a -> LocalData Int
count ds = applyUniAgg countAggregator $ asCol ds
  where
    -- Pins the aggregator to the element type of the dataset.
    countAggregator :: UniversalAggregator a Int
    countAggregator = _countAgg2

{-| Gathers every element of a column into a list.

NOTE:
The list is sorted according to the canonical ordering of the data type:
regardless of how Spark stores the data, the result always comes back in the
same order. This differs from Spark, which makes no guarantee about the order
of the returned data.
-}
collect :: forall ref a. (SQLTypeable a) => Column ref a -> LocalData [a]
collect col = applyUniAgg collectAggregator col
  where
    -- Pins the aggregator to the element type of the column.
    collectAggregator :: UniversalAggregator a [a]
    collectAggregator = _collectAgg

{-|
The universal aggregator: an invariant aggregator together with extra laws
for combining multiple outputs.
It is useful for combining the results computed over multiple passes.
A real implementation in Spark also has an inner pass; that pass is not
represented by the fields of this type.
-}
data UniversalAggregator a buff = UniversalAggregator {
  -- | Computes a buffer from a whole dataset.
  -- The result must be partitioning invariant.
  uaInitialOuter :: Dataset a -> LocalData buff,
  -- | Merges two buffers. This operation must be associative and commutative.
  -- The logical parents of the final observable have to be the 2 inputs.
  uaMergeBuffer :: LocalData buff -> LocalData buff -> LocalData buff
}

-- | (internal) Converts a universal aggregator into its untyped operator form.
-- The SQL types of the input elements and of the buffer come from their
-- 'SQLTypeable' instances.
univAggToOp :: forall a buff. (SQLTypeable a, SQLTypeable buff) =>
  UniversalAggregator a buff -> UniversalAggregatorOp
univAggToOp ua = univAggToOpTyped inputType bufferType ua
  where
    inputType :: SQLType a
    inputType = buildType
    bufferType :: SQLType buff
    bufferType = buildType

-- | (internal) Converts a universal aggregator into its untyped operator
-- form, using explicitly provided SQL types for the input elements and for
-- the buffer.
--
-- Both functions of the aggregator are turned into node operations and
-- unwrapped with '_unsafeExtractOp'. That function fails if either one does
-- not produce a standard operator.
univAggToOpTyped :: forall a buff.
  SQLType a ->
  SQLType buff ->
  UniversalAggregator a buff ->
  UniversalAggregatorOp
univAggToOpTyped sqlta sqltm ua = UniversalAggregatorOp {
    uaoMergeType = unSQLType sqltm,
    -- The outer pass reads elements of type a and produces a buffer.
    uaoInitialOuter =
      _unsafeExtractOp (fun1ToOpTyped sqlta (uaInitialOuter ua)),
    -- The merge takes two buffers of the same type.
    uaoMergeBuffer =
      _unsafeExtractOp (fun2ToOpTyped sqltm sqltm (uaMergeBuffer ua))
  }

-- | (internal) Applies a universal aggregator to a column.
--
-- The column is packed into a dataset. Only the initial outer pass of the
-- aggregator is run on it; the merge function is not used here.
applyUniAgg :: UniversalAggregator a b -> Column ref a -> LocalData b
applyUniAgg agg col = uaInitialOuter agg packed
  where
    packed = pack1 col
    -- TODO understand how to pass this info. A sketch of the intent, where
    -- `result` is the value returned above:
    -- aggop = univAggToOpTyped (nodeType packed) (nodeType result) agg
    -- ld = emptyLocalData (NodeUniversalAggregator aggop) (nodeType result)


-- | (internal) Builds a one-argument node function that applies the local
-- standard operator called @name@. The node's output type is @sqltb@, and the
-- operator carries no extra payload ('Null').
simpleOp1Typed :: (IsLocality loca, IsLocality locb) =>
  SQLType b ->
  T.Text ->
  ComputeNode loca a -> ComputeNode locb b
simpleOp1Typed sqltb name = nodeOpToFun1Typed sqltb (NodeLocalOp stdOp)
  where
    stdOp = StandardOperator {
        soName = name,
        soOutputType = unSQLType sqltb,
        soExtra = Null
      }

-- | (internal) Like 'simpleOp1Typed'. The output SQL type is taken from the
-- 'SQLTypeable' instance of @b@.
simpleOp1 :: forall a b loca locb. (IsLocality loca, IsLocality locb, SQLTypeable a, SQLTypeable b) =>
  T.Text ->
  ComputeNode loca a -> ComputeNode locb b
simpleOp1 name = simpleOp1Typed outputType name
  where
    outputType :: SQLType b
    outputType = buildType

-- | (internal) Like 'simpleOp2Typed'. The output SQL type is taken from the
-- 'SQLTypeable' instance of @b@.
simpleOp2 :: forall a1 a2 b loc1 loc2 locb. (SQLTypeable b, IsLocality loc1, IsLocality loc2, IsLocality locb) =>
  T.Text ->
  ComputeNode loc1 a1 -> ComputeNode loc2 a2 -> ComputeNode locb b
simpleOp2 name = simpleOp2Typed outputType name
  where
    outputType :: SQLType b
    outputType = buildType

-- | (internal) Builds a two-argument node function that applies the local
-- standard operator called @name@. The node's output type is @sqltb@, and the
-- operator carries no extra payload ('Null').
simpleOp2Typed :: (IsLocality loc1, IsLocality loc2, IsLocality locb) =>
  SQLType b ->
  T.Text ->
  ComputeNode loc1 a1 -> ComputeNode loc2 a2 -> ComputeNode locb b
simpleOp2Typed sqltb name = nodeOpToFun2Typed sqltb (NodeLocalOp stdOp)
  where
    stdOp = StandardOperator {
        soName = name,
        soOutputType = unSQLType sqltb,
        soExtra = Null
      }

-- | (internal) Returns the standard operator wrapped by a local, opaque
-- aggregator or distributed node operation.
-- Any other kind of node operation triggers 'failure'.
_unsafeExtractOp :: (HasCallStack) => NodeOp -> StandardOperator
_unsafeExtractOp nodeOp = case nodeOp of
  NodeLocalOp stdOp -> stdOp
  NodeOpaqueAggregator stdOp -> stdOp
  NodeDistributedOp stdOp -> stdOp
  _ -> failure $ sformat ("Expected standard op, found "%shown) nodeOp

-- | (internal) Aggregator that counts elements. The outer pass applies the
-- "org.spark.Count" operator, and partial counts are merged with (+).
_countAgg2 :: (SQLTypeable a) => UniversalAggregator a Int
_countAgg2 = UniversalAggregator {
    uaMergeBuffer = (+),
    uaInitialOuter = simpleOp1 "org.spark.Count"
  }

-- | (internal) Aggregator that sums elements. The outer pass applies the
-- "org.spark.Sum" operator, and partial sums are merged with (+).
_sumAgg :: forall a. (SQLTypeable a, Num a, ToSQL a) => UniversalAggregator a a
_sumAgg = UniversalAggregator {
    uaMergeBuffer = (+),
    uaInitialOuter = simpleOp1 "org.spark.Sum"
  }

-- | (internal) Aggregator that collects elements into a list. The outer pass
-- applies the "org.spark.Collect" operator, and two partial lists are merged
-- with the "org.spark.CatSorted" operator.
_collectAgg :: forall a. SQLTypeable a => UniversalAggregator a [a]
_collectAgg = UniversalAggregator {
    uaMergeBuffer = simpleOp2 "org.spark.CatSorted",
    uaInitialOuter = simpleOp1 "org.spark.Collect"
  }