{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}

-- A number of standard aggregation functions.

module Spark.Core.Internal.AggregationFunctions(
  -- Standard library
  collect,
  collect',
  count,
  count',
  sumCol,
  sumCol',
  -- Developer functions
  AggTry,
  UniversalAggregator(..),
  applyUAOUnsafe,
  applyUntypedUniAgg3
) where

import Data.Aeson(Value(Null))
import qualified Data.Text as T
import qualified Data.Vector as V

import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.ColumnStructures
import Spark.Core.Internal.ColumnFunctions(colType, untypedCol)
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.RowGenerics(ToSQL)
import Spark.Core.Internal.LocalDataFunctions()
import Spark.Core.Internal.FunctionsInternals
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.TypesStructures
import Spark.Core.Internal.Utilities
import Spark.Core.Internal.TypesFunctions(arrayType')
import Spark.Core.StructuresInternal(emptyFieldPath)
import Spark.Core.Types
import Spark.Core.Try

{-| The sum of all the elements in a column.

If the data type is too small to represent the sum, the value being returned is
undefined.
-}
sumCol :: forall ref a. (Num a, SQLTypeable a, ToSQL a) =>
  Column ref a -> LocalData a
sumCol = applyUAOUnsafe _sumAgg'

{-| The untyped (dynamic) variant of sumCol.

Failures are reported through the returned LocalFrame instead of being forced.
-}
sumCol' :: DynColumn -> LocalFrame
sumCol' = applyUntypedUniAgg3 _sumAgg'

{-| The number of elements in a dataset.

The dataset is first converted to a single column (asCol) and then counted.
-}
-- TODO use Long for the return data type.
count :: forall a. Dataset a -> LocalData Int
count ds = applyUAOUnsafe _countAgg' (asCol ds)

{-| The untyped (dynamic) variant of count. -}
count' :: DataFrame -> LocalFrame
count' df = applyUntypedUniAgg3 _countAgg' (asCol' df)

{-| Collects all the elements of a column into a list.

NOTE:
This list is sorted in the canonical ordering of the data type: however the
data may be stored by Spark, the result will always be in the same order.
This is a departure from Spark, which does not guarantee an ordering on
the returned data.
-}
collect :: forall ref a. (SQLTypeable a) => Column ref a -> LocalData [a]
collect = applyUAOUnsafe _collectAgg'

{-| See the documentation of collect. -}
collect' :: DynColumn -> LocalFrame
collect' = applyUntypedUniAgg3 _collectAgg'

{-| The result of building an aggregator: either an error message or a value. -}
type AggTry a = Either T.Text a

{-|
This is the universal aggregator: the invariant aggregator and
some extra laws to combine multiple outputs.
It is useful for combining the results over multiple passes.
A real implementation in Spark has also an inner pass.
-}
data UniversalAggregator a buff = UniversalAggregator {
  -- The type of the intermediate (merge) buffer.
  uaMergeType :: SQLType buff,
  -- The result is partitioning invariant
  uaInitialOuter :: Dataset a -> LocalData buff,
  -- This operation is associative and commutative
  -- The logical parents of the final observable have to be the 2 inputs
  uaMergeBuffer :: LocalData buff -> LocalData buff -> LocalData buff
}

-- The aggregator description for SUM.
--
-- The merge type follows the type of the input column, so that the type of
-- the node built by applyUntypedUniAgg3 agrees with the result type declared
-- by sumCol (LocalData a for a column of a). It used to be hard-coded to a
-- strict integer, which only matched integer columns.
-- TODO(kps) check the coming type for non-summable types
_sumAgg' :: DataType -> AggTry UniversalAggregatorOp
_sumAgg' dt = pure UniversalAggregatorOp {
    uaoMergeType = dt,
    uaoInitialOuter = InnerAggOp $ AggFunction "SUM" (V.singleton emptyFieldPath),
    uaoMergeBuffer = ColumnSemiGroupLaw "SUM"
  }

-- The aggregator description for COUNT.
-- Counting will always succeed: the input type is ignored and partial counts
-- are merged with SUM.
_countAgg' :: DataType -> AggTry UniversalAggregatorOp
_countAgg' _ = pure UniversalAggregatorOp {
    -- TODO(kps) switch to BigInt
    uaoMergeType = StrictType IntType,
    uaoInitialOuter = InnerAggOp $ AggFunction "COUNT" (V.singleton emptyFieldPath),
    uaoMergeBuffer = ColumnSemiGroupLaw "SUM"
  }

-- The aggregator description for collect.
-- Collecting will always succeed: the buffer is an array of the input type,
-- and both passes are delegated to opaque standard operators.
_collectAgg' :: DataType -> AggTry UniversalAggregatorOp
_collectAgg' dt =
  let ldt = arrayType' dt
      soMerge = StandardOperator {
                  soName = "org.spark.Collect",
                  soOutputType = ldt,
                  soExtra = Null }
      soMono = StandardOperator {
                  soName = "org.spark.CatSorted",
                  soOutputType = ldt,
                  soExtra = Null }
  in pure UniversalAggregatorOp {
        uaoMergeType = ldt,
        uaoInitialOuter = OpaqueAggTransform soMerge,
        uaoMergeBuffer = OpaqueSemiGroupLaw soMono
      }

{-| Applies an aggregator description to a dynamic column.

The aggregator is built from the data type of the column. The result is a
local data node whose type is the merge type of the aggregator and whose
parent is the dataset obtained by packing the column (pack1).

A failing column or a failing aggregator construction is propagated through
the returned LocalFrame.
-}
applyUntypedUniAgg3 :: (DataType -> AggTry UniversalAggregatorOp) -> DynColumn -> LocalFrame
applyUntypedUniAgg3 f dc = do
  c <- dc
  let uaot = f . unSQLType . colType $ c
  -- Converts the Either-based aggregator result into the surrounding monad.
  uao <- tryEither uaot
  let no = NodeAggregatorReduction uao
  let ds = pack1 c
  return $ emptyLocalData no (SQLType (uaoMergeType uao)) `parents` [untyped ds]

{-| The typed variant of applyUntypedUniAgg3.

The column is made untyped, aggregated, and the result is converted back to
the requested type b with asObservable.

NOTE(review): as the name indicates, this is unsafe: the result is forced with
forceRight, which presumably raises an error when the aggregation or the
conversion to b fails -- confirm against the definition of forceRight.
-}
applyUAOUnsafe :: forall a b ref. (SQLTypeable b, HasCallStack) => (DataType -> AggTry UniversalAggregatorOp) -> Column ref a -> LocalData b
applyUAOUnsafe f c =
  let lf = applyUntypedUniAgg3 f (untypedCol c)
  in forceRight (asObservable lf)

-- _guardType :: DataType -> (UntypedDataset -> UntypedLocalData) -> (UntypedDataset -> LocalFrame)
-- _guardType dt f ds =
--   if unSQLType (nodeType ds) == dt
--   then
--     pure $ f ds
--   else
--     tryError $ sformat ("Expected type "%sh%" but got type "%sh) dt (nodeType ds)