//===----------------------------------------------------------------------===// // DuckDB // // duckdb/common/vector_operations/unary_executor.hpp // // //===----------------------------------------------------------------------===// #pragma once #include "duckdb/common/exception.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include namespace duckdb { struct UnaryOperatorWrapper { template static inline RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { return OP::template Operation(input); } }; struct UnaryLambdaWrapper { template static inline RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { auto fun = (FUNC *)dataptr; return (*fun)(input); } }; struct GenericUnaryWrapper { template static inline RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { return OP::template Operation(input, mask, idx, dataptr); } }; struct UnaryLambdaWrapperWithNulls { template static inline RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { auto fun = (FUNC *)dataptr; return (*fun)(input, mask, idx); } }; template struct UnaryStringOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { auto vector = (Vector *)dataptr; return OP::template Operation(input, *vector); } }; struct UnaryExecutor { private: template static inline void ExecuteLoop(const INPUT_TYPE *__restrict ldata, RESULT_TYPE *__restrict result_data, idx_t count, const SelectionVector *__restrict sel_vector, ValidityMask &mask, ValidityMask &result_mask, void *dataptr, bool adds_nulls) { #ifdef DEBUG // ldata may point to a compressed dictionary buffer which can be smaller than ldata + count idx_t max_index = 0; for (idx_t i = 0; i < count; i++) { auto idx = sel_vector->get_index(i); max_index = MaxValue(max_index, idx); } ASSERT_RESTRICT(ldata, ldata + max_index, result_data, result_data + count); #endif if (!mask.AllValid()) { result_mask.EnsureWritable(); for (idx_t i = 0; i < count; i++) { auto idx = sel_vector->get_index(i); if (mask.RowIsValidUnsafe(idx)) { result_data[i] = OPWRAPPER::template Operation(ldata[idx], result_mask, i, dataptr); } else { result_mask.SetInvalid(i); } } } else { if (adds_nulls) { result_mask.EnsureWritable(); } for (idx_t i = 0; i < count; i++) { auto idx = sel_vector->get_index(i); result_data[i] = OPWRAPPER::template Operation(ldata[idx], result_mask, i, dataptr); } } } template static inline void ExecuteFlat(const INPUT_TYPE *__restrict ldata, RESULT_TYPE *__restrict result_data, idx_t count, ValidityMask &mask, ValidityMask &result_mask, void *dataptr, bool adds_nulls) { ASSERT_RESTRICT(ldata, ldata + count, result_data, result_data + count); if (!mask.AllValid()) { if (!adds_nulls) { result_mask.Initialize(mask); } else { result_mask.Copy(mask, count); } idx_t base_idx = 0; auto entry_count = ValidityMask::EntryCount(count); for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { auto validity_entry = mask.GetValidityEntry(entry_idx); idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); if (ValidityMask::AllValid(validity_entry)) { // all valid: perform operation for (; base_idx < next; base_idx++) { result_data[base_idx] = OPWRAPPER::template Operation( ldata[base_idx], result_mask, base_idx, dataptr); } } else if (ValidityMask::NoneValid(validity_entry)) { // nothing valid: skip all base_idx = next; continue; } else { // partially valid: need to check individual elements for validity idx_t start = base_idx; for (; base_idx < next; base_idx++) { if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { D_ASSERT(mask.RowIsValid(base_idx)); result_data[base_idx] = OPWRAPPER::template Operation( ldata[base_idx], result_mask, base_idx, dataptr); } } } } } else { if (adds_nulls) { result_mask.EnsureWritable(); } for (idx_t i = 0; i < count; i++) { result_data[i] = OPWRAPPER::template Operation(ldata[i], result_mask, i, dataptr); } } } template static inline void ExecuteStandard(Vector &input, Vector &result, idx_t count, void *dataptr, bool adds_nulls) { switch (input.GetVectorType()) { case VectorType::CONSTANT_VECTOR: { result.SetVectorType(VectorType::CONSTANT_VECTOR); auto result_data = ConstantVector::GetData(result); auto ldata = ConstantVector::GetData(input); if (ConstantVector::IsNull(input)) { ConstantVector::SetNull(result, true); } else { ConstantVector::SetNull(result, false); *result_data = OPWRAPPER::template Operation( *ldata, ConstantVector::Validity(result), 0, dataptr); } break; } case VectorType::FLAT_VECTOR: { result.SetVectorType(VectorType::FLAT_VECTOR); auto result_data = FlatVector::GetData(result); auto ldata = FlatVector::GetData(input); ExecuteFlat(ldata, result_data, count, FlatVector::Validity(input), FlatVector::Validity(result), dataptr, adds_nulls); break; } default: { UnifiedVectorFormat vdata; input.ToUnifiedFormat(count, vdata); result.SetVectorType(VectorType::FLAT_VECTOR); auto result_data = FlatVector::GetData(result); auto ldata = UnifiedVectorFormat::GetData(vdata); ExecuteLoop(ldata, result_data, count, vdata.sel, vdata.validity, FlatVector::Validity(result), dataptr, adds_nulls); break; } } } public: template static void Execute(Vector &input, Vector &result, idx_t count) { ExecuteStandard(input, result, count, nullptr, false); } template > static void Execute(Vector &input, Vector &result, idx_t count, FUNC fun) { ExecuteStandard(input, result, count, (void *)&fun, false); } template static void GenericExecute(Vector &input, Vector &result, idx_t count, void *dataptr, bool adds_nulls = false) { ExecuteStandard(input, result, count, dataptr, adds_nulls); } template > static void ExecuteWithNulls(Vector &input, Vector &result, idx_t count, FUNC fun) { ExecuteStandard(input, result, count, (void *)&fun, true); } template static void ExecuteString(Vector &input, Vector &result, idx_t count) { UnaryExecutor::GenericExecute>(input, result, count, (void *)&result); } }; } // namespace duckdb