//===----------------------------------------------------------------------===// // DuckDB // // duckdb/function/udf_function.hpp // // //===----------------------------------------------------------------------===// #pragma once #include "duckdb/function/scalar_function.hpp" #include "duckdb/function/aggregate_function.hpp" namespace duckdb { struct UDFWrapper { public: template inline static scalar_function_t CreateScalarFunction(const string &name, TR (*udf_func)(Args...)) { const std::size_t num_template_argc = sizeof...(Args); switch (num_template_argc) { case 1: return CreateUnaryFunction(name, udf_func); case 2: return CreateBinaryFunction(name, udf_func); case 3: return CreateTernaryFunction(name, udf_func); default: // LCOV_EXCL_START throw std::runtime_error("UDF function only supported until ternary!"); } // LCOV_EXCL_STOP } template inline static scalar_function_t CreateScalarFunction(const string &name, vector args, LogicalType ret_type, TR (*udf_func)(Args...)) { if (!TypesMatch(ret_type)) { // LCOV_EXCL_START throw std::runtime_error("Return type doesn't match with the first template type."); } // LCOV_EXCL_STOP const std::size_t num_template_types = sizeof...(Args); if (num_template_types != args.size()) { // LCOV_EXCL_START throw std::runtime_error( "The number of templated types should be the same quantity of the LogicalType arguments."); } // LCOV_EXCL_STOP switch (num_template_types) { case 1: return CreateUnaryFunction(name, args, ret_type, udf_func); case 2: return CreateBinaryFunction(name, args, ret_type, udf_func); case 3: return CreateTernaryFunction(name, args, ret_type, udf_func); default: // LCOV_EXCL_START throw std::runtime_error("UDF function only supported until ternary!"); } // LCOV_EXCL_STOP } template inline static void RegisterFunction(const string &name, scalar_function_t udf_function, ClientContext &context, LogicalType varargs = LogicalType(LogicalTypeId::INVALID)) { vector arguments; GetArgumentTypesRecursive(arguments); LogicalType ret_type = GetArgumentType(); RegisterFunction(name, arguments, ret_type, udf_function, context, varargs); } static void RegisterFunction(string name, vector args, LogicalType ret_type, scalar_function_t udf_function, ClientContext &context, LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); //--------------------------------- Aggregate UDFs ------------------------------------// template inline static AggregateFunction CreateAggregateFunction(const string &name) { return CreateUnaryAggregateFunction(name); } template inline static AggregateFunction CreateAggregateFunction(const string &name) { return CreateBinaryAggregateFunction(name); } template inline static AggregateFunction CreateAggregateFunction(const string &name, LogicalType ret_type, LogicalType input_type) { if (!TypesMatch(ret_type)) { // LCOV_EXCL_START throw std::runtime_error("The return argument don't match!"); } // LCOV_EXCL_STOP if (!TypesMatch(input_type)) { // LCOV_EXCL_START throw std::runtime_error("The input argument don't match!"); } // LCOV_EXCL_STOP return CreateUnaryAggregateFunction(name, ret_type, input_type); } template inline static AggregateFunction CreateAggregateFunction(const string &name, LogicalType ret_type, LogicalType input_typeA, LogicalType input_typeB) { if (!TypesMatch(ret_type)) { // LCOV_EXCL_START throw std::runtime_error("The return argument don't match!"); } if (!TypesMatch(input_typeA)) { throw std::runtime_error("The first input argument don't match!"); } if (!TypesMatch(input_typeB)) { throw std::runtime_error("The second input argument don't match!"); } // LCOV_EXCL_STOP return CreateBinaryAggregateFunction(name, ret_type, input_typeA, input_typeB); } //! A generic CreateAggregateFunction ---------------------------------------------------------------------------// inline static AggregateFunction CreateAggregateFunction(string name, vector arguments, LogicalType return_type, aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, aggregate_finalize_t finalize, aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr) { AggregateFunction aggr_function(std::move(name), std::move(arguments), std::move(return_type), state_size, initialize, update, combine, finalize, simple_update, bind, destructor); aggr_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; return aggr_function; } static void RegisterAggrFunction(AggregateFunction aggr_function, ClientContext &context, LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); private: //-------------------------------- Templated functions --------------------------------// struct UnaryUDFExecutor { template static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { typedef RESULT_TYPE (*unary_function_t)(INPUT_TYPE); auto udf = (unary_function_t)dataptr; return udf(input); } }; template inline static scalar_function_t CreateUnaryFunction(const string &name, TR (*udf_func)(TA)) { scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { UnaryExecutor::GenericExecute(input.data[0], result, input.size(), (void *)udf_func); }; return udf_function; } template inline static scalar_function_t CreateBinaryFunction(const string &name, TR (*udf_func)(TA, TB)) { scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { BinaryExecutor::Execute(input.data[0], input.data[1], result, input.size(), udf_func); }; return udf_function; } template inline static scalar_function_t CreateTernaryFunction(const string &name, TR (*udf_func)(TA, TB, TC)) { scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { TernaryExecutor::Execute(input.data[0], input.data[1], input.data[2], result, input.size(), udf_func); }; return udf_function; } template inline static scalar_function_t CreateUnaryFunction(const string &name, TR (*udf_func)(Args...)) { // LCOV_EXCL_START throw std::runtime_error("Incorrect number of arguments for unary function"); } // LCOV_EXCL_STOP template inline static scalar_function_t CreateBinaryFunction(const string &name, TR (*udf_func)(Args...)) { // LCOV_EXCL_START throw std::runtime_error("Incorrect number of arguments for binary function"); } // LCOV_EXCL_STOP template inline static scalar_function_t CreateTernaryFunction(const string &name, TR (*udf_func)(Args...)) { // LCOV_EXCL_START throw std::runtime_error("Incorrect number of arguments for ternary function"); } // LCOV_EXCL_STOP template inline static LogicalType GetArgumentType() { if (std::is_same()) { return LogicalType(LogicalTypeId::BOOLEAN); } else if (std::is_same()) { return LogicalType(LogicalTypeId::TINYINT); } else if (std::is_same()) { return LogicalType(LogicalTypeId::SMALLINT); } else if (std::is_same()) { return LogicalType(LogicalTypeId::INTEGER); } else if (std::is_same()) { return LogicalType(LogicalTypeId::BIGINT); } else if (std::is_same()) { return LogicalType(LogicalTypeId::FLOAT); } else if (std::is_same()) { return LogicalType(LogicalTypeId::DOUBLE); } else if (std::is_same()) { return LogicalType(LogicalTypeId::VARCHAR); } else { // LCOV_EXCL_START throw std::runtime_error("Unrecognized type!"); } // LCOV_EXCL_STOP } template inline static void GetArgumentTypesRecursive(vector &arguments) { arguments.push_back(GetArgumentType()); GetArgumentTypesRecursive(arguments); } template inline static void GetArgumentTypesRecursive(vector &arguments) { arguments.push_back(GetArgumentType()); } private: //-------------------------------- Argumented functions --------------------------------// template inline static scalar_function_t CreateUnaryFunction(const string &name, vector args, LogicalType ret_type, TR (*udf_func)(Args...)) { // LCOV_EXCL_START throw std::runtime_error("Incorrect number of arguments for unary function"); } // LCOV_EXCL_STOP template inline static scalar_function_t CreateUnaryFunction(const string &name, vector args, LogicalType ret_type, TR (*udf_func)(TA)) { if (args.size() != 1) { // LCOV_EXCL_START throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 1!"); } if (!TypesMatch(args[0])) { throw std::runtime_error("The first arguments don't match!"); } // LCOV_EXCL_STOP scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { UnaryExecutor::GenericExecute(input.data[0], result, input.size(), (void *)udf_func); }; return udf_function; } template inline static scalar_function_t CreateBinaryFunction(const string &name, vector args, LogicalType ret_type, TR (*udf_func)(Args...)) { // LCOV_EXCL_START throw std::runtime_error("Incorrect number of arguments for binary function"); } // LCOV_EXCL_STOP template inline static scalar_function_t CreateBinaryFunction(const string &name, vector args, LogicalType ret_type, TR (*udf_func)(TA, TB)) { if (args.size() != 2) { // LCOV_EXCL_START throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 2!"); } if (!TypesMatch(args[0])) { throw std::runtime_error("The first arguments don't match!"); } if (!TypesMatch(args[1])) { throw std::runtime_error("The second arguments don't match!"); } // LCOV_EXCL_STOP scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) { BinaryExecutor::Execute(input.data[0], input.data[1], result, input.size(), udf_func); }; return udf_function; } template inline static scalar_function_t CreateTernaryFunction(const string &name, vector args, LogicalType ret_type, TR (*udf_func)(Args...)) { // LCOV_EXCL_START throw std::runtime_error("Incorrect number of arguments for ternary function"); } // LCOV_EXCL_STOP template inline static scalar_function_t CreateTernaryFunction(const string &name, vector args, LogicalType ret_type, TR (*udf_func)(TA, TB, TC)) { if (args.size() != 3) { // LCOV_EXCL_START throw std::runtime_error("The number of LogicalType arguments (\"args\") should be 3!"); } if (!TypesMatch(args[0])) { throw std::runtime_error("The first arguments don't match!"); } if (!TypesMatch(args[1])) { throw std::runtime_error("The second arguments don't match!"); } if (!TypesMatch(args[2])) { throw std::runtime_error("The second arguments don't match!"); } // LCOV_EXCL_STOP scalar_function_t udf_function = [=](DataChunk &input, ExpressionState &state, Vector &result) -> void { TernaryExecutor::Execute(input.data[0], input.data[1], input.data[2], result, input.size(), udf_func); }; return udf_function; } template inline static bool TypesMatch(const LogicalType &sql_type) { switch (sql_type.id()) { case LogicalTypeId::BOOLEAN: return std::is_same(); case LogicalTypeId::TINYINT: return std::is_same(); case LogicalTypeId::SMALLINT: return std::is_same(); case LogicalTypeId::INTEGER: return std::is_same(); case LogicalTypeId::BIGINT: return std::is_same(); case LogicalTypeId::DATE: return std::is_same(); case LogicalTypeId::TIME: case LogicalTypeId::TIME_TZ: return std::is_same(); case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_MS: case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_SEC: case LogicalTypeId::TIMESTAMP_TZ: return std::is_same(); case LogicalTypeId::FLOAT: return std::is_same(); case LogicalTypeId::DOUBLE: return std::is_same(); case LogicalTypeId::VARCHAR: case LogicalTypeId::CHAR: case LogicalTypeId::BLOB: return std::is_same(); default: // LCOV_EXCL_START throw std::runtime_error("Type is not supported!"); } // LCOV_EXCL_STOP } private: //-------------------------------- Aggregate functions --------------------------------// template inline static AggregateFunction CreateUnaryAggregateFunction(const string &name) { LogicalType return_type = GetArgumentType(); LogicalType input_type = GetArgumentType(); return CreateUnaryAggregateFunction(name, return_type, input_type); } template inline static AggregateFunction CreateUnaryAggregateFunction(const string &name, LogicalType ret_type, LogicalType input_type) { AggregateFunction aggr_function = AggregateFunction::UnaryAggregate(input_type, ret_type); aggr_function.name = name; return aggr_function; } template inline static AggregateFunction CreateBinaryAggregateFunction(const string &name) { LogicalType return_type = GetArgumentType(); LogicalType input_typeA = GetArgumentType(); LogicalType input_typeB = GetArgumentType(); return CreateBinaryAggregateFunction(name, return_type, input_typeA, input_typeB); } template inline static AggregateFunction CreateBinaryAggregateFunction(const string &name, LogicalType ret_type, LogicalType input_typeA, LogicalType input_typeB) { AggregateFunction aggr_function = AggregateFunction::BinaryAggregate(input_typeA, input_typeB, ret_type); aggr_function.name = name; return aggr_function; } }; // end UDFWrapper } // namespace duckdb