| 1 | #include "duckdb/common/pair.hpp" |
| 2 | #include "duckdb/common/string_util.hpp" |
| 3 | #include "duckdb/common/types/chunk_collection.hpp" |
| 4 | #include "duckdb/common/types/data_chunk.hpp" |
| 5 | #include "duckdb/common/vector_operations/binary_executor.hpp" |
| 6 | #include "duckdb/function/scalar/nested_functions.hpp" |
| 7 | #include "duckdb/function/scalar/string_functions.hpp" |
| 8 | #include "duckdb/parser/expression/bound_expression.hpp" |
| 9 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
| 10 | #include "duckdb/storage/statistics/list_stats.hpp" |
| 11 | |
| 12 | namespace duckdb { |
| 13 | |
| 14 | template <class T, bool HEAP_REF = false, bool VALIDITY_ONLY = false> |
| 15 | void (idx_t count, UnifiedVectorFormat &list_data, UnifiedVectorFormat &offsets_data, |
| 16 | Vector &child_vector, idx_t list_size, Vector &result) { |
| 17 | UnifiedVectorFormat child_format; |
| 18 | child_vector.ToUnifiedFormat(count: list_size, data&: child_format); |
| 19 | |
| 20 | T *result_data; |
| 21 | |
| 22 | result.SetVectorType(VectorType::FLAT_VECTOR); |
| 23 | if (!VALIDITY_ONLY) { |
| 24 | result_data = FlatVector::GetData<T>(result); |
| 25 | } |
| 26 | auto &result_mask = FlatVector::Validity(vector&: result); |
| 27 | |
| 28 | // heap-ref once |
| 29 | if (HEAP_REF) { |
| 30 | StringVector::AddHeapReference(vector&: result, other&: child_vector); |
| 31 | } |
| 32 | |
| 33 | // this is lifted from ExecuteGenericLoop because we can't push the list child data into this otherwise |
| 34 | // should have gone with GetValue perhaps |
| 35 | auto child_data = UnifiedVectorFormat::GetData<T>(child_format); |
| 36 | for (idx_t i = 0; i < count; i++) { |
| 37 | auto list_index = list_data.sel->get_index(idx: i); |
| 38 | auto offsets_index = offsets_data.sel->get_index(idx: i); |
| 39 | if (!list_data.validity.RowIsValid(row_idx: list_index)) { |
| 40 | result_mask.SetInvalid(i); |
| 41 | continue; |
| 42 | } |
| 43 | if (!offsets_data.validity.RowIsValid(row_idx: offsets_index)) { |
| 44 | result_mask.SetInvalid(i); |
| 45 | continue; |
| 46 | } |
| 47 | auto list_entry = (UnifiedVectorFormat::GetData<list_entry_t>(format: list_data))[list_index]; |
| 48 | auto offsets_entry = (UnifiedVectorFormat::GetData<int64_t>(format: offsets_data))[offsets_index]; |
| 49 | |
| 50 | // 1-based indexing |
| 51 | if (offsets_entry == 0) { |
| 52 | result_mask.SetInvalid(i); |
| 53 | continue; |
| 54 | } |
| 55 | offsets_entry = (offsets_entry > 0) ? offsets_entry - 1 : offsets_entry; |
| 56 | |
| 57 | idx_t child_offset; |
| 58 | if (offsets_entry < 0) { |
| 59 | if (offsets_entry < -int64_t(list_entry.length)) { |
| 60 | result_mask.SetInvalid(i); |
| 61 | continue; |
| 62 | } |
| 63 | child_offset = list_entry.offset + list_entry.length + offsets_entry; |
| 64 | } else { |
| 65 | if ((idx_t)offsets_entry >= list_entry.length) { |
| 66 | result_mask.SetInvalid(i); |
| 67 | continue; |
| 68 | } |
| 69 | child_offset = list_entry.offset + offsets_entry; |
| 70 | } |
| 71 | auto child_index = child_format.sel->get_index(idx: child_offset); |
| 72 | if (child_format.validity.RowIsValid(row_idx: child_index)) { |
| 73 | if (!VALIDITY_ONLY) { |
| 74 | result_data[i] = child_data[child_index]; |
| 75 | } |
| 76 | } else { |
| 77 | result_mask.SetInvalid(i); |
| 78 | } |
| 79 | } |
| 80 | if (count == 1) { |
| 81 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
| 82 | } |
| 83 | } |
| 84 | static void (const idx_t count, UnifiedVectorFormat &list, UnifiedVectorFormat &offsets, |
| 85 | Vector &child_vector, idx_t list_size, Vector &result) { |
| 86 | D_ASSERT(child_vector.GetType() == result.GetType()); |
| 87 | switch (result.GetType().InternalType()) { |
| 88 | case PhysicalType::BOOL: |
| 89 | case PhysicalType::INT8: |
| 90 | ListExtractTemplate<int8_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 91 | break; |
| 92 | case PhysicalType::INT16: |
| 93 | ListExtractTemplate<int16_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 94 | break; |
| 95 | case PhysicalType::INT32: |
| 96 | ListExtractTemplate<int32_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 97 | break; |
| 98 | case PhysicalType::INT64: |
| 99 | ListExtractTemplate<int64_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 100 | break; |
| 101 | case PhysicalType::INT128: |
| 102 | ListExtractTemplate<hugeint_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 103 | break; |
| 104 | case PhysicalType::UINT8: |
| 105 | ListExtractTemplate<uint8_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 106 | break; |
| 107 | case PhysicalType::UINT16: |
| 108 | ListExtractTemplate<uint16_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 109 | break; |
| 110 | case PhysicalType::UINT32: |
| 111 | ListExtractTemplate<uint32_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 112 | break; |
| 113 | case PhysicalType::UINT64: |
| 114 | ListExtractTemplate<uint64_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 115 | break; |
| 116 | case PhysicalType::FLOAT: |
| 117 | ListExtractTemplate<float>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 118 | break; |
| 119 | case PhysicalType::DOUBLE: |
| 120 | ListExtractTemplate<double>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 121 | break; |
| 122 | case PhysicalType::VARCHAR: |
| 123 | ListExtractTemplate<string_t, true>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 124 | break; |
| 125 | case PhysicalType::INTERVAL: |
| 126 | ListExtractTemplate<interval_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 127 | break; |
| 128 | case PhysicalType::STRUCT: { |
| 129 | auto &entries = StructVector::GetEntries(vector&: child_vector); |
| 130 | auto &result_entries = StructVector::GetEntries(vector&: result); |
| 131 | D_ASSERT(entries.size() == result_entries.size()); |
| 132 | // extract the child entries of the struct |
| 133 | for (idx_t i = 0; i < entries.size(); i++) { |
| 134 | ExecuteListExtractInternal(count, list, offsets, child_vector&: *entries[i], list_size, result&: *result_entries[i]); |
| 135 | } |
| 136 | // extract the validity mask |
| 137 | ListExtractTemplate<bool, false, true>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 138 | break; |
| 139 | } |
| 140 | case PhysicalType::LIST: { |
| 141 | // nested list: we have to reference the child |
| 142 | auto &child_child_list = ListVector::GetEntry(vector&: child_vector); |
| 143 | |
| 144 | ListVector::GetEntry(vector&: result).Reference(other&: child_child_list); |
| 145 | ListVector::SetListSize(vec&: result, size: ListVector::GetListSize(vector: child_vector)); |
| 146 | ListExtractTemplate<list_entry_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
| 147 | break; |
| 148 | } |
| 149 | default: |
| 150 | throw NotImplementedException("Unimplemented type for LIST_EXTRACT" ); |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | static void (Vector &result, Vector &list, Vector &offsets, const idx_t count) { |
| 155 | D_ASSERT(list.GetType().id() == LogicalTypeId::LIST); |
| 156 | UnifiedVectorFormat list_data; |
| 157 | UnifiedVectorFormat offsets_data; |
| 158 | |
| 159 | list.ToUnifiedFormat(count, data&: list_data); |
| 160 | offsets.ToUnifiedFormat(count, data&: offsets_data); |
| 161 | ExecuteListExtractInternal(count, list&: list_data, offsets&: offsets_data, child_vector&: ListVector::GetEntry(vector&: list), |
| 162 | list_size: ListVector::GetListSize(vector: list), result); |
| 163 | result.Verify(count); |
| 164 | } |
| 165 | |
| 166 | static void (Vector &result, Vector &input_vector, Vector &subscript_vector, const idx_t count) { |
| 167 | BinaryExecutor::Execute<string_t, int64_t, string_t>( |
| 168 | left&: input_vector, right&: subscript_vector, result, count, fun: [&](string_t input_string, int64_t subscript) { |
| 169 | return SubstringFun::SubstringUnicode(result, input: input_string, offset: subscript, length: 1); |
| 170 | }); |
| 171 | } |
| 172 | |
| 173 | static void (DataChunk &args, ExpressionState &state, Vector &result) { |
| 174 | D_ASSERT(args.ColumnCount() == 2); |
| 175 | auto count = args.size(); |
| 176 | |
| 177 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
| 178 | for (idx_t i = 0; i < args.ColumnCount(); i++) { |
| 179 | if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { |
| 180 | result.SetVectorType(VectorType::FLAT_VECTOR); |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | Vector &base = args.data[0]; |
| 185 | Vector &subscript = args.data[1]; |
| 186 | |
| 187 | switch (base.GetType().id()) { |
| 188 | case LogicalTypeId::LIST: |
| 189 | ExecuteListExtract(result, list&: base, offsets&: subscript, count); |
| 190 | break; |
| 191 | case LogicalTypeId::VARCHAR: |
| 192 | ExecuteStringExtract(result, input_vector&: base, subscript_vector&: subscript, count); |
| 193 | break; |
| 194 | case LogicalTypeId::SQLNULL: |
| 195 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
| 196 | ConstantVector::SetNull(vector&: result, is_null: true); |
| 197 | break; |
| 198 | default: |
| 199 | throw NotImplementedException("Specifier type not implemented" ); |
| 200 | } |
| 201 | } |
| 202 | |
| 203 | static unique_ptr<FunctionData> (ClientContext &context, ScalarFunction &bound_function, |
| 204 | vector<unique_ptr<Expression>> &arguments) { |
| 205 | D_ASSERT(bound_function.arguments.size() == 2); |
| 206 | D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id()); |
| 207 | // list extract returns the child type of the list as return type |
| 208 | bound_function.return_type = ListType::GetChildType(type: arguments[0]->return_type); |
| 209 | return make_uniq<VariableReturnBindData>(args&: bound_function.return_type); |
| 210 | } |
| 211 | |
| 212 | static unique_ptr<BaseStatistics> (ClientContext &context, FunctionStatisticsInput &input) { |
| 213 | auto &child_stats = input.child_stats; |
| 214 | auto &list_child_stats = ListStats::GetChildStats(stats&: child_stats[0]); |
| 215 | auto child_copy = list_child_stats.Copy(); |
| 216 | // list_extract always pushes a NULL, since if the offset is out of range for a list it inserts a null |
| 217 | child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); |
| 218 | return child_copy.ToUnique(); |
| 219 | } |
| 220 | |
| 221 | void ListExtractFun::(BuiltinFunctions &set) { |
| 222 | // the arguments and return types are actually set in the binder function |
| 223 | ScalarFunction lfun({LogicalType::LIST(child: LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY, |
| 224 | ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); |
| 225 | |
| 226 | ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); |
| 227 | |
| 228 | ScalarFunctionSet ("list_extract" ); |
| 229 | list_extract.AddFunction(function: lfun); |
| 230 | list_extract.AddFunction(function: sfun); |
| 231 | set.AddFunction(set: list_extract); |
| 232 | |
| 233 | ScalarFunctionSet list_element("list_element" ); |
| 234 | list_element.AddFunction(function: lfun); |
| 235 | list_element.AddFunction(function: sfun); |
| 236 | set.AddFunction(set: list_element); |
| 237 | |
| 238 | ScalarFunctionSet ("array_extract" ); |
| 239 | array_extract.AddFunction(function: lfun); |
| 240 | array_extract.AddFunction(function: sfun); |
| 241 | array_extract.AddFunction(function: StructExtractFun::GetFunction()); |
| 242 | set.AddFunction(set: array_extract); |
| 243 | } |
| 244 | |
| 245 | } // namespace duckdb |
| 246 | |