//===----------------------------------------------------------------------===// // DuckDB // // duckdb/function/aggregate_function.hpp // // //===----------------------------------------------------------------------===// #pragma once #include "duckdb/function/aggregate_state.hpp" #include "duckdb/planner/bound_result_modifier.hpp" #include "duckdb/planner/expression.hpp" #include "duckdb/common/vector_operations/aggregate_executor.hpp" namespace duckdb { //! The type used for sizing hashed aggregate function states typedef idx_t (*aggregate_size_t)(); //! The type used for initializing hashed aggregate function states typedef void (*aggregate_initialize_t)(data_ptr_t state); //! The type used for updating hashed aggregate functions typedef void (*aggregate_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &state, idx_t count); //! The type used for combining hashed aggregate states typedef void (*aggregate_combine_t)(Vector &state, Vector &combined, AggregateInputData &aggr_input_data, idx_t count); //! The type used for finalizing hashed aggregate function payloads typedef void (*aggregate_finalize_t)(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset); //! The type used for propagating statistics in aggregate functions (optional) typedef unique_ptr (*aggregate_statistics_t)(ClientContext &context, BoundAggregateExpression &expr, AggregateStatisticsInput &input); //! Binds the scalar function and creates the function data typedef unique_ptr (*bind_aggregate_function_t)(ClientContext &context, AggregateFunction &function, vector> &arguments); //! The type used for the aggregate destructor method. NOTE: this method is used in destructors and MAY NOT throw. typedef void (*aggregate_destructor_t)(Vector &state, AggregateInputData &aggr_input_data, idx_t count); //! The type used for updating simple (non-grouped) aggregate functions typedef void (*aggregate_simple_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, idx_t count); //! The type used for updating complex windowed aggregate functions (optional) typedef std::pair FrameBounds; typedef void (*aggregate_window_t)(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, Vector &result, idx_t rid, idx_t bias); typedef void (*aggregate_serialize_t)(FieldWriter &writer, const FunctionData *bind_data, const AggregateFunction &function); typedef unique_ptr (*aggregate_deserialize_t)(PlanDeserializationState &context, FieldReader &reader, AggregateFunction &function); class AggregateFunction : public BaseScalarFunction { public: AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, aggregate_finalize_t finalize, FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS, LogicalType(LogicalTypeId::INVALID), null_handling), state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { } AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, aggregate_finalize_t finalize, aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS, LogicalType(LogicalTypeId::INVALID)), state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { } AggregateFunction(const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, aggregate_finalize_t finalize, FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) : AggregateFunction(string(), arguments, return_type, state_size, initialize, update, combine, finalize, null_handling, simple_update, bind, destructor, statistics, window, serialize, deserialize) { } AggregateFunction(const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, aggregate_finalize_t finalize, aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) : AggregateFunction(string(), arguments, return_type, state_size, initialize, update, combine, finalize, FunctionNullHandling::DEFAULT_NULL_HANDLING, simple_update, bind, destructor, statistics, window, serialize, deserialize) { } //! The hashed aggregate state sizing function aggregate_size_t state_size; //! The hashed aggregate state initialization function aggregate_initialize_t initialize; //! The hashed aggregate update state function aggregate_update_t update; //! The hashed aggregate combine states function aggregate_combine_t combine; //! The hashed aggregate finalization function aggregate_finalize_t finalize; //! The simple aggregate update function (may be null) aggregate_simple_update_t simple_update; //! The windowed aggregate frame update function (may be null) aggregate_window_t window; //! The bind function (may be null) bind_aggregate_function_t bind; //! The destructor method (may be null) aggregate_destructor_t destructor; //! The statistics propagation function (may be null) aggregate_statistics_t statistics; aggregate_serialize_t serialize; aggregate_deserialize_t deserialize; //! Whether or not the aggregate is order dependent AggregateOrderDependent order_dependent; bool operator==(const AggregateFunction &rhs) const { return state_size == rhs.state_size && initialize == rhs.initialize && update == rhs.update && combine == rhs.combine && finalize == rhs.finalize && window == rhs.window; } bool operator!=(const AggregateFunction &rhs) const { return !(*this == rhs); } public: template static AggregateFunction NullaryAggregate(LogicalType return_type) { return AggregateFunction( {}, return_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateFunction::NullaryScatterUpdate, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, AggregateFunction::NullaryUpdate); } template static AggregateFunction UnaryAggregate(const LogicalType &input_type, LogicalType return_type, FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING) { return AggregateFunction( {input_type}, return_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, null_handling, AggregateFunction::UnaryUpdate); } template static AggregateFunction UnaryAggregateDestructor(LogicalType input_type, LogicalType return_type) { auto aggregate = UnaryAggregate(input_type, return_type); aggregate.destructor = AggregateFunction::StateDestroy; return aggregate; } template static AggregateFunction BinaryAggregate(const LogicalType &a_type, const LogicalType &b_type, LogicalType return_type) { return AggregateFunction({a_type, b_type}, return_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateFunction::BinaryScatterUpdate, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, AggregateFunction::BinaryUpdate); } public: template static idx_t StateSize() { return sizeof(STATE); } template static void StateInitialize(data_ptr_t state) { OP::Initialize(*reinterpret_cast(state)); } template static void NullaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, idx_t count) { D_ASSERT(input_count == 0); AggregateExecutor::NullaryScatter(states, aggr_input_data, count); } template static void NullaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, idx_t count) { D_ASSERT(input_count == 0); AggregateExecutor::NullaryUpdate(state, aggr_input_data, count); } template static void UnaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, idx_t count) { D_ASSERT(input_count == 1); AggregateExecutor::UnaryScatter(inputs[0], states, aggr_input_data, count); } template static void UnaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, idx_t count) { D_ASSERT(input_count == 1); AggregateExecutor::UnaryUpdate(inputs[0], aggr_input_data, state, count); } template static void UnaryWindow(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, Vector &result, idx_t rid, idx_t bias) { D_ASSERT(input_count == 1); AggregateExecutor::UnaryWindow(inputs[0], filter_mask, aggr_input_data, state, frame, prev, result, rid, bias); } template static void BinaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, idx_t count) { D_ASSERT(input_count == 2); AggregateExecutor::BinaryScatter(aggr_input_data, inputs[0], inputs[1], states, count); } template static void BinaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, idx_t count) { D_ASSERT(input_count == 2); AggregateExecutor::BinaryUpdate(aggr_input_data, inputs[0], inputs[1], state, count); } template static void StateCombine(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) { AggregateExecutor::Combine(source, target, aggr_input_data, count); } template static void StateFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { AggregateExecutor::Finalize(states, aggr_input_data, result, count, offset); } template static void StateVoidFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { AggregateExecutor::VoidFinalize(states, aggr_input_data, result, count, offset); } template static void StateDestroy(Vector &states, AggregateInputData &aggr_input_data, idx_t count) { AggregateExecutor::Destroy(states, aggr_input_data, count); } }; } // namespace duckdb