#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"

#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/aggregate_hashtable.hpp"
#include "duckdb/main/client_context.hpp"
#include "duckdb/parallel/interrupt.hpp"
#include "duckdb/parallel/pipeline.hpp"
#include "duckdb/parallel/task_scheduler.hpp"
#include "duckdb/parallel/thread_context.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/parallel/base_pipeline_event.hpp"
#include "duckdb/common/atomic.hpp"
#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp"

namespace duckdb {

HashAggregateGroupingData::HashAggregateGroupingData(GroupingSet &grouping_set_p,
                                                     const GroupedAggregateData &grouped_aggregate_data,
                                                     unique_ptr<DistinctAggregateCollectionInfo> &info)
    : table_data(grouping_set_p, grouped_aggregate_data) {
	if (info) {
		distinct_data = make_uniq<DistinctAggregateData>(*info, grouping_set_p, &grouped_aggregate_data.groups);
	}
}

bool HashAggregateGroupingData::HasDistinct() const {
	return distinct_data != nullptr;
}

HashAggregateGroupingGlobalState::HashAggregateGroupingGlobalState(const HashAggregateGroupingData &data,
                                                                   ClientContext &context) {
	table_state = data.table_data.GetGlobalSinkState(context);
	if (data.HasDistinct()) {
		distinct_state = make_uniq<DistinctAggregateState>(*data.distinct_data, context);
	}
}

| 40 | |
| 41 | HashAggregateGroupingLocalState::HashAggregateGroupingLocalState(const PhysicalHashAggregate &op, |
| 42 | const HashAggregateGroupingData &data, |
| 43 | ExecutionContext &context) { |
| 44 | table_state = data.table_data.GetLocalSinkState(context); |
| 45 | if (!data.HasDistinct()) { |
| 46 | return; |
| 47 | } |
| 48 | auto &distinct_data = *data.distinct_data; |
| 49 | |
| 50 | auto &distinct_indices = op.distinct_collection_info->Indices(); |
| 51 | D_ASSERT(!distinct_indices.empty()); |
| 52 | |
| 53 | distinct_states.resize(new_size: op.distinct_collection_info->aggregates.size()); |
| 54 | auto &table_map = op.distinct_collection_info->table_map; |
| 55 | |
| 56 | for (auto &idx : distinct_indices) { |
| 57 | idx_t table_idx = table_map[idx]; |
| 58 | auto &radix_table = distinct_data.radix_tables[table_idx]; |
| 59 | if (radix_table == nullptr) { |
| 60 | // This aggregate has identical input as another aggregate, so no table is created for it |
| 61 | continue; |
| 62 | } |
| 63 | // Initialize the states of the radix tables used for the distinct aggregates |
| 64 | distinct_states[table_idx] = radix_table->GetLocalSinkState(context); |
| 65 | } |
| 66 | } |
| 67 | |
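// Builds the input chunk types for the group columns: each group is a BoundReferenceExpression into the operator's
// input chunk, and any input columns that are not referenced by a group are typed as SQLNULL.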
static vector<LogicalType> CreateGroupChunkTypes(vector<unique_ptr<Expression>> &groups) {
	set<idx_t> group_indices;

	if (groups.empty()) {
		return {};
	}

	for (auto &group : groups) {
		D_ASSERT(group->type == ExpressionType::BOUND_REF);
		auto &bound_ref = group->Cast<BoundReferenceExpression>();
		group_indices.insert(bound_ref.index);
	}
	idx_t highest_index = *group_indices.rbegin();
	vector<LogicalType> types(highest_index + 1, LogicalType::SQLNULL);
	for (auto &group : groups) {
		auto &bound_ref = group->Cast<BoundReferenceExpression>();
		types[bound_ref.index] = bound_ref.return_type;
	}
	return types;
}

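// The regular (non-distinct) sink can only be skipped when every aggregate is distinct and none of them have a
// FILTER clause, e.g. for a query like "SELECT COUNT(DISTINCT a) FROM t GROUP BY b" all input rows flow through
// the distinct hash tables first and only reach the main hash table during finalization.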
bool PhysicalHashAggregate::CanSkipRegularSink() const {
	if (!filter_indexes.empty()) {
		// If we have filters, we can't skip the regular sink, because we might lose groups otherwise.
		return false;
	}
	if (grouped_aggregate_data.aggregates.empty()) {
		// When there are no aggregates, we have to add to the main ht right away
		return false;
	}
	if (!non_distinct_filter.empty()) {
		return false;
	}
	return true;
}

PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types,
                                             vector<unique_ptr<Expression>> expressions, idx_t estimated_cardinality)
    : PhysicalHashAggregate(context, std::move(types), std::move(expressions), {}, estimated_cardinality) {
}

PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types,
                                             vector<unique_ptr<Expression>> expressions,
                                             vector<unique_ptr<Expression>> groups_p, idx_t estimated_cardinality)
    : PhysicalHashAggregate(context, std::move(types), std::move(expressions), std::move(groups_p), {}, {},
                            estimated_cardinality) {
}

PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<LogicalType> types,
                                             vector<unique_ptr<Expression>> expressions,
                                             vector<unique_ptr<Expression>> groups_p,
                                             vector<GroupingSet> grouping_sets_p,
                                             vector<unsafe_vector<idx_t>> grouping_functions_p,
                                             idx_t estimated_cardinality)
    : PhysicalOperator(PhysicalOperatorType::HASH_GROUP_BY, std::move(types), estimated_cardinality),
      grouping_sets(std::move(grouping_sets_p)) {
	// get a list of all aggregates to be computed
	const idx_t group_count = groups_p.size();
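	// If no grouping sets were specified, aggregate over all group columns as a single grouping set
	// (a plain "GROUP BY a, b"; constructs like ROLLUP/CUBE would instead pass in multiple grouping sets).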
	if (grouping_sets.empty()) {
		GroupingSet set;
		for (idx_t i = 0; i < group_count; i++) {
			set.insert(i);
		}
		grouping_sets.push_back(std::move(set));
	}
	input_group_types = CreateGroupChunkTypes(groups_p);

	grouped_aggregate_data.InitializeGroupby(std::move(groups_p), std::move(expressions),
	                                         std::move(grouping_functions_p));

	auto &aggregates = grouped_aggregate_data.aggregates;
	// filter_indexes must be pre-built here, not lazily instantiated during execution,
	// because everything that lives in this class should be read-only at execution time
	idx_t aggregate_input_idx = 0;
	for (idx_t i = 0; i < aggregates.size(); i++) {
		auto &aggregate = aggregates[i];
		auto &aggr = aggregate->Cast<BoundAggregateExpression>();
		aggregate_input_idx += aggr.children.size();
		if (aggr.aggr_type == AggregateType::DISTINCT) {
			distinct_filter.push_back(i);
		} else if (aggr.aggr_type == AggregateType::NON_DISTINCT) {
			non_distinct_filter.push_back(i);
		} else { // LCOV_EXCL_START
			throw NotImplementedException("AggregateType not implemented in PhysicalHashAggregate");
		} // LCOV_EXCL_STOP
	}

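	// The payload chunk is laid out as [all aggregate children..., all filter columns...]; aggregate_input_idx now
	// points past the last aggregate child. For each unique FILTER expression we remember its original column index
	// in filter_indexes and rewrite its bound reference to point at its slot at the end of the payload chunk.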
	for (idx_t i = 0; i < aggregates.size(); i++) {
		auto &aggregate = aggregates[i];
		auto &aggr = aggregate->Cast<BoundAggregateExpression>();
		if (aggr.filter) {
			auto &bound_ref_expr = aggr.filter->Cast<BoundReferenceExpression>();
			if (!filter_indexes.count(aggr.filter.get())) {
				// Replace the bound reference expression's index with the corresponding index of the payload chunk
				filter_indexes[aggr.filter.get()] = bound_ref_expr.index;
				bound_ref_expr.index = aggregate_input_idx;
			}
			aggregate_input_idx++;
		}
	}

	distinct_collection_info = DistinctAggregateCollectionInfo::Create(grouped_aggregate_data.aggregates);

	for (idx_t i = 0; i < grouping_sets.size(); i++) {
		groupings.emplace_back(grouping_sets[i], grouped_aggregate_data, distinct_collection_info);
	}
}

//===--------------------------------------------------------------------===//
// Sink
//===--------------------------------------------------------------------===//
class HashAggregateGlobalState : public GlobalSinkState {
public:
	HashAggregateGlobalState(const PhysicalHashAggregate &op, ClientContext &context) {
		grouping_states.reserve(op.groupings.size());
		for (idx_t i = 0; i < op.groupings.size(); i++) {
			auto &grouping = op.groupings[i];
			grouping_states.emplace_back(grouping, context);
		}
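		// Build the payload types to match the layout established in the constructor:
		// first the children of every aggregate, then the filter columns appended at the end.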
		vector<LogicalType> filter_types;
		for (auto &aggr : op.grouped_aggregate_data.aggregates) {
			auto &aggregate = aggr->Cast<BoundAggregateExpression>();
			for (auto &child : aggregate.children) {
				payload_types.push_back(child->return_type);
			}
			if (aggregate.filter) {
				filter_types.push_back(aggregate.filter->return_type);
			}
		}
		payload_types.reserve(payload_types.size() + filter_types.size());
		payload_types.insert(payload_types.end(), filter_types.begin(), filter_types.end());
	}

	vector<HashAggregateGroupingGlobalState> grouping_states;
	vector<LogicalType> payload_types;
	//! Whether or not the aggregate is finished
	bool finished = false;
};

class HashAggregateLocalState : public LocalSinkState {
public:
	HashAggregateLocalState(const PhysicalHashAggregate &op, ExecutionContext &context) {
		auto &payload_types = op.grouped_aggregate_data.payload_types;
		if (!payload_types.empty()) {
			aggregate_input_chunk.InitializeEmpty(payload_types);
		}

		grouping_states.reserve(op.groupings.size());
		for (auto &grouping : op.groupings) {
			grouping_states.emplace_back(op, grouping, context);
		}
		// The filter set is only needed here for the distinct aggregates;
		// the filtering of data for the regular aggregates is done within the hashtable
		vector<AggregateObject> aggregate_objects;
		for (auto &aggregate : op.grouped_aggregate_data.aggregates) {
			auto &aggr = aggregate->Cast<BoundAggregateExpression>();
			aggregate_objects.emplace_back(&aggr);
		}

		filter_set.Initialize(context.client, aggregate_objects, payload_types);
	}

	DataChunk aggregate_input_chunk;
	vector<HashAggregateGroupingLocalState> grouping_states;
	AggregateFilterDataSet filter_set;
};

void PhysicalHashAggregate::SetMultiScan(GlobalSinkState &state) {
	auto &gstate = state.Cast<HashAggregateGlobalState>();
	for (auto &grouping_state : gstate.grouping_states) {
		auto &radix_state = grouping_state.table_state;
		RadixPartitionedHashTable::SetMultiScan(*radix_state);
		if (!grouping_state.distinct_state) {
			continue;
		}
	}
}

unique_ptr<GlobalSinkState> PhysicalHashAggregate::GetGlobalSinkState(ClientContext &context) const {
	return make_uniq<HashAggregateGlobalState>(*this, context);
}

unique_ptr<LocalSinkState> PhysicalHashAggregate::GetLocalSinkState(ExecutionContext &context) const {
	return make_uniq<HashAggregateLocalState>(*this, context);
}

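// Sinks the input chunk into the distinct hash tables of a single grouping. Filters have to be applied eagerly
// here (unlike the regular sink, where the hash table applies them internally): for a filtered distinct aggregate,
// only rows that pass the filter may take part in the deduplication.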
void PhysicalHashAggregate::SinkDistinctGrouping(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input,
                                                 idx_t grouping_idx) const {
	auto &sink = input.local_state.Cast<HashAggregateLocalState>();
	auto &global_sink = input.global_state.Cast<HashAggregateGlobalState>();

	auto &grouping_gstate = global_sink.grouping_states[grouping_idx];
	auto &grouping_lstate = sink.grouping_states[grouping_idx];
	auto &distinct_info = *distinct_collection_info;

	auto &distinct_state = grouping_gstate.distinct_state;
	auto &distinct_data = groupings[grouping_idx].distinct_data;

	DataChunk empty_chunk;

	// Create an empty filter for Sink, since we don't need to update any aggregate states here
	unsafe_vector<idx_t> empty_filter;

	for (idx_t &idx : distinct_info.indices) {
		auto &aggregate = grouped_aggregate_data.aggregates[idx]->Cast<BoundAggregateExpression>();

		D_ASSERT(distinct_info.table_map.count(idx));
		idx_t table_idx = distinct_info.table_map[idx];
		if (!distinct_data->radix_tables[table_idx]) {
			continue;
		}
		D_ASSERT(distinct_data->radix_tables[table_idx]);
		auto &radix_table = *distinct_data->radix_tables[table_idx];
		auto &radix_global_sink = *distinct_state->radix_states[table_idx];
		auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx];

		InterruptState interrupt_state;
		OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, interrupt_state};

		if (aggregate.filter) {
			DataChunk filter_chunk;
			auto &filtered_data = sink.filter_set.GetFilterData(idx);
			filter_chunk.InitializeEmpty(filtered_data.filtered_payload.GetTypes());

			// Add the filter Vector (BOOL)
			auto it = filter_indexes.find(aggregate.filter.get());
			D_ASSERT(it != filter_indexes.end());
			D_ASSERT(it->second < chunk.data.size());
			auto &filter_bound_ref = aggregate.filter->Cast<BoundReferenceExpression>();
			filter_chunk.data[filter_bound_ref.index].Reference(chunk.data[it->second]);
			filter_chunk.SetCardinality(chunk.size());

			// We can't use the AggregateFilterData::ApplyFilter method, because the chunk we need to
			// apply the filter to also has the groups, and the filtered_data.filtered_payload does not have those.
			SelectionVector sel_vec(STANDARD_VECTOR_SIZE);
			idx_t count = filtered_data.filter_executor.SelectExpression(filter_chunk, sel_vec);

			if (count == 0) {
				continue;
			}

			// Because the 'input' chunk needs to be reused after this, we create
			// a duplicate of it that we can apply the filter to
			DataChunk filtered_input;
			filtered_input.InitializeEmpty(chunk.GetTypes());

			for (idx_t group_idx = 0; group_idx < grouped_aggregate_data.groups.size(); group_idx++) {
				auto &group = grouped_aggregate_data.groups[group_idx];
				auto &bound_ref = group->Cast<BoundReferenceExpression>();
				filtered_input.data[bound_ref.index].Reference(chunk.data[bound_ref.index]);
			}
			for (idx_t child_idx = 0; child_idx < aggregate.children.size(); child_idx++) {
				auto &child = aggregate.children[child_idx];
				auto &bound_ref = child->Cast<BoundReferenceExpression>();

				filtered_input.data[bound_ref.index].Reference(chunk.data[bound_ref.index]);
			}
			filtered_input.Slice(sel_vec, count);
			filtered_input.SetCardinality(count);

			radix_table.Sink(context, filtered_input, sink_input, empty_chunk, empty_filter);
		} else {
			radix_table.Sink(context, chunk, sink_input, empty_chunk, empty_filter);
		}
	}
}

void PhysicalHashAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const {
	for (idx_t i = 0; i < groupings.size(); i++) {
		SinkDistinctGrouping(context, chunk, input, i);
	}
}

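// Sink: (1) route the chunk into the distinct hash tables, (2) skip the regular sink if all aggregates are distinct
// and unfiltered, (3) otherwise build the payload chunk (aggregate children followed by filter columns, referencing
// the input vectors rather than copying them) and (4) sink it into one radix hash table per grouping set.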
SinkResultType PhysicalHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk,
                                           OperatorSinkInput &input) const {
	auto &llstate = input.local_state.Cast<HashAggregateLocalState>();
	auto &gstate = input.global_state.Cast<HashAggregateGlobalState>();

	if (distinct_collection_info) {
		SinkDistinct(context, chunk, input);
	}

	if (CanSkipRegularSink()) {
		return SinkResultType::NEED_MORE_INPUT;
	}

	DataChunk &aggregate_input_chunk = llstate.aggregate_input_chunk;

	auto &aggregates = grouped_aggregate_data.aggregates;
	idx_t aggregate_input_idx = 0;

	// Populate the aggregate child vectors
	for (auto &aggregate : aggregates) {
		auto &aggr = aggregate->Cast<BoundAggregateExpression>();
		for (auto &child_expr : aggr.children) {
			D_ASSERT(child_expr->type == ExpressionType::BOUND_REF);
			auto &bound_ref_expr = child_expr->Cast<BoundReferenceExpression>();
			D_ASSERT(bound_ref_expr.index < chunk.data.size());
			aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]);
		}
	}
	// Populate the filter vectors
	for (auto &aggregate : aggregates) {
		auto &aggr = aggregate->Cast<BoundAggregateExpression>();
		if (aggr.filter) {
			auto it = filter_indexes.find(aggr.filter.get());
			D_ASSERT(it != filter_indexes.end());
			D_ASSERT(it->second < chunk.data.size());
			aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]);
		}
	}

	aggregate_input_chunk.SetCardinality(chunk.size());
	aggregate_input_chunk.Verify();

	// For every grouping set there is one radix_table
	for (idx_t i = 0; i < groupings.size(); i++) {
		auto &grouping_gstate = gstate.grouping_states[i];
		auto &grouping_lstate = llstate.grouping_states[i];
		InterruptState interrupt_state;
		OperatorSinkInput sink_input {*grouping_gstate.table_state, *grouping_lstate.table_state, interrupt_state};

		auto &grouping = groupings[i];
		auto &table = grouping.table_data;
		table.Sink(context, chunk, sink_input, aggregate_input_chunk, non_distinct_filter);
	}

	return SinkResultType::NEED_MORE_INPUT;
}

void PhysicalHashAggregate::CombineDistinct(ExecutionContext &context, GlobalSinkState &state,
                                            LocalSinkState &lstate) const {
	auto &global_sink = state.Cast<HashAggregateGlobalState>();
	auto &sink = lstate.Cast<HashAggregateLocalState>();

	if (!distinct_collection_info) {
		return;
	}
	for (idx_t i = 0; i < groupings.size(); i++) {
		auto &grouping_gstate = global_sink.grouping_states[i];
		auto &grouping_lstate = sink.grouping_states[i];

		auto &distinct_data = groupings[i].distinct_data;
		auto &distinct_state = grouping_gstate.distinct_state;

		const auto table_count = distinct_data->radix_tables.size();
		for (idx_t table_idx = 0; table_idx < table_count; table_idx++) {
			if (!distinct_data->radix_tables[table_idx]) {
				continue;
			}
			auto &radix_table = *distinct_data->radix_tables[table_idx];
			auto &radix_global_sink = *distinct_state->radix_states[table_idx];
			auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx];

			radix_table.Combine(context, radix_global_sink, radix_local_sink);
		}
	}
}

void PhysicalHashAggregate::Combine(ExecutionContext &context, GlobalSinkState &state, LocalSinkState &lstate) const {
	auto &gstate = state.Cast<HashAggregateGlobalState>();
	auto &llstate = lstate.Cast<HashAggregateLocalState>();

	CombineDistinct(context, state, lstate);

	if (CanSkipRegularSink()) {
		return;
	}
	for (idx_t i = 0; i < groupings.size(); i++) {
		auto &grouping_gstate = gstate.grouping_states[i];
		auto &grouping_lstate = llstate.grouping_states[i];

		auto &grouping = groupings[i];
		auto &table = grouping.table_data;
		table.Combine(context, *grouping_gstate.table_state, *grouping_lstate.table_state);
	}
}

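// Finalization is driven by a cascade of pipeline events. Without distinct aggregates a single
// HashAggregateMergeEvent merges the partitioned radix tables. With distinct aggregates the order is:
// HashDistinctCombineFinalizeEvent (combine the partitioned distinct tables, if any) ->
// HashDistinctAggregateFinalizeEvent (scan the distinct tables and sink their contents into the main tables) ->
// HashAggregateFinalizeEvent (finalize the main tables).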
//! REGULAR FINALIZE EVENT

class HashAggregateMergeEvent : public BasePipelineEvent {
public:
	HashAggregateMergeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p, Pipeline *pipeline_p)
	    : BasePipelineEvent(*pipeline_p), op(op_p), gstate(gstate_p) {
	}

	const PhysicalHashAggregate &op;
	HashAggregateGlobalState &gstate;

public:
	void Schedule() override {
		vector<shared_ptr<Task>> tasks;
		for (idx_t i = 0; i < op.groupings.size(); i++) {
			auto &grouping_gstate = gstate.grouping_states[i];

			auto &grouping = op.groupings[i];
			auto &table = grouping.table_data;
			table.ScheduleTasks(pipeline->executor, shared_from_this(), *grouping_gstate.table_state, tasks);
		}
		D_ASSERT(!tasks.empty());
		SetTasks(std::move(tasks));
	}
};

//! REGULAR FINALIZE FROM DISTINCT FINALIZE

class HashAggregateFinalizeTask : public ExecutorTask {
public:
	HashAggregateFinalizeTask(Pipeline &pipeline, shared_ptr<Event> event_p, HashAggregateGlobalState &state_p,
	                          ClientContext &context, const PhysicalHashAggregate &op)
	    : ExecutorTask(pipeline.executor), pipeline(pipeline), event(std::move(event_p)), gstate(state_p),
	      context(context), op(op) {
	}

	TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
		op.FinalizeInternal(pipeline, *event, context, gstate, false);
		D_ASSERT(!gstate.finished);
		gstate.finished = true;
		event->FinishTask();
		return TaskExecutionResult::TASK_FINISHED;
	}

private:
	Pipeline &pipeline;
	shared_ptr<Event> event;
	HashAggregateGlobalState &gstate;
	ClientContext &context;
	const PhysicalHashAggregate &op;
};

class HashAggregateFinalizeEvent : public BasePipelineEvent {
public:
	HashAggregateFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p,
	                           Pipeline *pipeline_p, ClientContext &context)
	    : BasePipelineEvent(*pipeline_p), op(op_p), gstate(gstate_p), context(context) {
	}

	const PhysicalHashAggregate &op;
	HashAggregateGlobalState &gstate;
	ClientContext &context;

public:
	void Schedule() override {
		vector<shared_ptr<Task>> tasks;
		tasks.push_back(make_uniq<HashAggregateFinalizeTask>(*pipeline, shared_from_this(), gstate, context, op));
		D_ASSERT(!tasks.empty());
		SetTasks(std::move(tasks));
	}
};

//! DISTINCT FINALIZE TASK

class HashDistinctAggregateFinalizeTask : public ExecutorTask {
public:
	HashDistinctAggregateFinalizeTask(Pipeline &pipeline, shared_ptr<Event> event_p, HashAggregateGlobalState &state_p,
	                                  ClientContext &context, const PhysicalHashAggregate &op,
	                                  vector<vector<unique_ptr<GlobalSourceState>>> &global_sources_p)
	    : ExecutorTask(pipeline.executor), pipeline(pipeline), event(std::move(event_p)), gstate(state_p),
	      context(context), op(op), global_sources(global_sources_p) {
	}

	void AggregateDistinctGrouping(DistinctAggregateCollectionInfo &info,
	                               const HashAggregateGroupingData &grouping_data,
	                               HashAggregateGroupingGlobalState &grouping_state, idx_t grouping_idx) {
		auto &aggregates = info.aggregates;
		auto &data = *grouping_data.distinct_data;
		auto &state = *grouping_state.distinct_state;
		auto &table_state = *grouping_state.table_state;

		ThreadContext temp_thread_context(context);
		ExecutionContext temp_exec_context(context, temp_thread_context, &pipeline);

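		// Each task sinks into the shared main hash table through its own thread/execution context and local sink
		// state, so multiple finalize tasks can push data into the same grouping concurrently.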
		auto temp_local_state = grouping_data.table_data.GetLocalSinkState(temp_exec_context);

		// Create a chunk that mimics the 'input' chunk in Sink, for storing the group vectors
		DataChunk group_chunk;
		if (!op.input_group_types.empty()) {
			group_chunk.Initialize(context, op.input_group_types);
		}

		auto &groups = op.grouped_aggregate_data.groups;
		const idx_t group_by_size = groups.size();

		DataChunk aggregate_input_chunk;
		if (!gstate.payload_types.empty()) {
			aggregate_input_chunk.Initialize(context, gstate.payload_types);
		}

		idx_t payload_idx;
		idx_t next_payload_idx = 0;

		for (idx_t i = 0; i < op.grouped_aggregate_data.aggregates.size(); i++) {
			auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>();

			// Forward the payload idx
			payload_idx = next_payload_idx;
			next_payload_idx = payload_idx + aggregate.children.size();

			// If the aggregate is not distinct, skip it
			if (!data.IsDistinct(i)) {
				continue;
			}
			D_ASSERT(data.info.table_map.count(i));
			auto table_idx = data.info.table_map.at(i);
			auto &radix_table_p = data.radix_tables[table_idx];

			// Create a duplicate of the output_chunk, because of multi-threading we can't alter the original
			DataChunk output_chunk;
			output_chunk.Initialize(context, state.distinct_output_chunks[table_idx]->GetTypes());

			auto &global_source = global_sources[grouping_idx][i];
			auto local_source = radix_table_p->GetLocalSourceState(temp_exec_context);

			// Fetch all the data from the aggregate ht, and Sink it into the main ht
			while (true) {
				output_chunk.Reset();
				group_chunk.Reset();
				aggregate_input_chunk.Reset();

				InterruptState interrupt_state;
				OperatorSourceInput source_input {*global_source, *local_source, interrupt_state};
				auto res = radix_table_p->GetData(temp_exec_context, output_chunk, *state.radix_states[table_idx],
				                                  source_input);

				if (res == SourceResultType::FINISHED) {
					D_ASSERT(output_chunk.size() == 0);
					break;
				} else if (res == SourceResultType::BLOCKED) {
					throw InternalException(
					    "Unexpected interrupt from radix table GetData in HashDistinctAggregateFinalizeTask");
				}

				auto &grouped_aggregate_data = *data.grouped_aggregate_data[table_idx];

				for (idx_t group_idx = 0; group_idx < group_by_size; group_idx++) {
					auto &group = grouped_aggregate_data.groups[group_idx];
					auto &bound_ref_expr = group->Cast<BoundReferenceExpression>();
					group_chunk.data[bound_ref_expr.index].Reference(output_chunk.data[group_idx]);
				}
				group_chunk.SetCardinality(output_chunk);

				for (idx_t child_idx = 0; child_idx < grouped_aggregate_data.groups.size() - group_by_size;
				     child_idx++) {
					aggregate_input_chunk.data[payload_idx + child_idx].Reference(
					    output_chunk.data[group_by_size + child_idx]);
				}
				aggregate_input_chunk.SetCardinality(output_chunk);

				// Sink it into the main ht
				OperatorSinkInput sink_input {table_state, *temp_local_state, interrupt_state};
				grouping_data.table_data.Sink(temp_exec_context, group_chunk, sink_input, aggregate_input_chunk, {i});
			}
		}
		grouping_data.table_data.Combine(temp_exec_context, table_state, *temp_local_state);
	}

	TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
		D_ASSERT(op.distinct_collection_info);
		auto &info = *op.distinct_collection_info;
		for (idx_t i = 0; i < op.groupings.size(); i++) {
			auto &grouping = op.groupings[i];
			auto &grouping_state = gstate.grouping_states[i];
			AggregateDistinctGrouping(info, grouping, grouping_state, i);
		}
		event->FinishTask();
		return TaskExecutionResult::TASK_FINISHED;
	}

private:
	Pipeline &pipeline;
	shared_ptr<Event> event;
	HashAggregateGlobalState &gstate;
	ClientContext &context;
	const PhysicalHashAggregate &op;
	vector<vector<unique_ptr<GlobalSourceState>>> &global_sources;
};

//! DISTINCT FINALIZE EVENT

// TODO: Create tasks and run these in parallel instead of doing this all in Schedule, single threaded
class HashDistinctAggregateFinalizeEvent : public BasePipelineEvent {
public:
	HashDistinctAggregateFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p,
	                                   Pipeline &pipeline_p, ClientContext &context)
	    : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), context(context) {
	}
	const PhysicalHashAggregate &op;
	HashAggregateGlobalState &gstate;
	ClientContext &context;
	//! The GlobalSourceStates for all the radix tables of the distinct aggregates
	vector<vector<unique_ptr<GlobalSourceState>>> global_sources;

public:
	void Schedule() override {
		global_sources = CreateGlobalSources();

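		// Spawn one task per scheduler thread; the tasks share the global source states in global_sources, so the
		// scans over the distinct hash tables are distributed across the tasks.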
		vector<shared_ptr<Task>> tasks;
		auto &scheduler = TaskScheduler::GetScheduler(context);
		auto number_of_threads = scheduler.NumberOfThreads();
		tasks.reserve(number_of_threads);
		for (int32_t i = 0; i < number_of_threads; i++) {
			tasks.push_back(make_uniq<HashDistinctAggregateFinalizeTask>(*pipeline, shared_from_this(), gstate, context,
			                                                             op, global_sources));
		}
		D_ASSERT(!tasks.empty());
		SetTasks(std::move(tasks));
	}

	void FinishEvent() override {
		//! Now that everything is added to the main ht, we can actually finalize
		auto new_event = make_shared<HashAggregateFinalizeEvent>(op, gstate, pipeline.get(), context);
		this->InsertEvent(std::move(new_event));
	}

private:
	vector<vector<unique_ptr<GlobalSourceState>>> CreateGlobalSources() {
		vector<vector<unique_ptr<GlobalSourceState>>> grouping_sources;
		grouping_sources.reserve(op.groupings.size());
		for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) {
			auto &grouping = op.groupings[grouping_idx];
			auto &data = *grouping.distinct_data;

			vector<unique_ptr<GlobalSourceState>> aggregate_sources;
			aggregate_sources.reserve(op.grouped_aggregate_data.aggregates.size());

			for (idx_t i = 0; i < op.grouped_aggregate_data.aggregates.size(); i++) {
				auto &aggregate = op.grouped_aggregate_data.aggregates[i];
				auto &aggr = aggregate->Cast<BoundAggregateExpression>();

				if (!aggr.IsDistinct()) {
					aggregate_sources.push_back(nullptr);
					continue;
				}

				D_ASSERT(data.info.table_map.count(i));
				auto table_idx = data.info.table_map.at(i);
				auto &radix_table_p = data.radix_tables[table_idx];
				aggregate_sources.push_back(radix_table_p->GetGlobalSourceState(context));
			}
			grouping_sources.push_back(std::move(aggregate_sources));
		}
		return grouping_sources;
	}
};

//! DISTINCT COMBINE EVENT

class HashDistinctCombineFinalizeEvent : public BasePipelineEvent {
public:
	HashDistinctCombineFinalizeEvent(const PhysicalHashAggregate &op_p, HashAggregateGlobalState &gstate_p,
	                                 Pipeline &pipeline_p, ClientContext &client)
	    : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), client(client) {
	}

	const PhysicalHashAggregate &op;
	HashAggregateGlobalState &gstate;
	ClientContext &client;

public:
	void Schedule() override {
		vector<shared_ptr<Task>> tasks;
		for (idx_t i = 0; i < op.groupings.size(); i++) {
			auto &grouping = op.groupings[i];
			auto &distinct_data = *grouping.distinct_data;
			auto &distinct_state = *gstate.grouping_states[i].distinct_state;
			for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) {
				if (!distinct_data.radix_tables[table_idx]) {
					continue;
				}
				distinct_data.radix_tables[table_idx]->ScheduleTasks(pipeline->executor, shared_from_this(),
				                                                     *distinct_state.radix_states[table_idx], tasks);
			}
		}

		D_ASSERT(!tasks.empty());
		SetTasks(std::move(tasks));
	}

	void FinishEvent() override {
		//! Now that all tables are combined, it's time to do the distinct aggregations
		auto new_event = make_shared<HashDistinctAggregateFinalizeEvent>(op, gstate, *pipeline, client);
		this->InsertEvent(std::move(new_event));
	}
};

//! FINALIZE

SinkFinalizeType PhysicalHashAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context,
                                                         GlobalSinkState &gstate_p) const {
	auto &gstate = gstate_p.Cast<HashAggregateGlobalState>();
	D_ASSERT(distinct_collection_info);

	bool any_partitioned = false;
	for (idx_t i = 0; i < groupings.size(); i++) {
		auto &grouping = groupings[i];
		auto &distinct_data = *grouping.distinct_data;
		auto &distinct_state = *gstate.grouping_states[i].distinct_state;

		for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) {
			if (!distinct_data.radix_tables[table_idx]) {
				continue;
			}
			auto &radix_table = distinct_data.radix_tables[table_idx];
			auto &radix_state = *distinct_state.radix_states[table_idx];
			bool partitioned = radix_table->Finalize(context, radix_state);
			if (partitioned) {
				any_partitioned = true;
			}
		}
	}
	if (any_partitioned) {
		// If any of the groupings are partitioned then we first need to combine those, then aggregate
		auto new_event = make_shared<HashDistinctCombineFinalizeEvent>(*this, gstate, pipeline, context);
		event.InsertEvent(std::move(new_event));
	} else {
		// The hash tables aren't partitioned, so they don't need to be combined first;
		// we can compute the aggregate right away
		auto new_event = make_shared<HashDistinctAggregateFinalizeEvent>(*this, gstate, pipeline, context);
		event.InsertEvent(std::move(new_event));
	}
	return SinkFinalizeType::READY;
}

SinkFinalizeType PhysicalHashAggregate::FinalizeInternal(Pipeline &pipeline, Event &event, ClientContext &context,
                                                         GlobalSinkState &gstate_p, bool check_distinct) const {
	auto &gstate = gstate_p.Cast<HashAggregateGlobalState>();

	if (check_distinct && distinct_collection_info) {
		// There are distinct aggregates
		// If these are partitioned those need to be combined first
		// Then we Finalize again, skipping this step
		return FinalizeDistinct(pipeline, event, context, gstate_p);
	}

	bool any_partitioned = false;
	for (idx_t i = 0; i < groupings.size(); i++) {
		auto &grouping = groupings[i];
		auto &grouping_gstate = gstate.grouping_states[i];

		bool is_partitioned = grouping.table_data.Finalize(context, *grouping_gstate.table_state);
		if (is_partitioned) {
			any_partitioned = true;
		}
	}
	if (any_partitioned) {
		auto new_event = make_shared<HashAggregateMergeEvent>(*this, gstate, &pipeline);
		event.InsertEvent(std::move(new_event));
	}
	return SinkFinalizeType::READY;
}

SinkFinalizeType PhysicalHashAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context,
                                                 GlobalSinkState &gstate_p) const {
	return FinalizeInternal(pipeline, event, context, gstate_p, true);
}

//===--------------------------------------------------------------------===//
// Source
//===--------------------------------------------------------------------===//
class PhysicalHashAggregateGlobalSourceState : public GlobalSourceState {
public:
	PhysicalHashAggregateGlobalSourceState(ClientContext &context, const PhysicalHashAggregate &op)
	    : op(op), state_index(0) {
		for (auto &grouping : op.groupings) {
			auto &rt = grouping.table_data;
			radix_states.push_back(rt.GetGlobalSourceState(context));
		}
	}

	const PhysicalHashAggregate &op;
	mutex lock;
	atomic<idx_t> state_index;

	vector<unique_ptr<GlobalSourceState>> radix_states;

public:
	idx_t MaxThreads() override {
		// If there are no tables, we only need one thread.
		if (op.groupings.empty()) {
			return 1;
		}

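		// Heuristic: sum the sizes of all grouping tables and allow roughly one scan thread
		// per STANDARD_VECTOR_SIZE rows of output.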
		auto &ht_state = op.sink_state->Cast<HashAggregateGlobalState>();
		idx_t count = 0;
		for (size_t sidx = 0; sidx < op.groupings.size(); ++sidx) {
			auto &grouping = op.groupings[sidx];
			auto &grouping_gstate = ht_state.grouping_states[sidx];
			count += grouping.table_data.Size(*grouping_gstate.table_state);
		}
		return MaxValue<idx_t>(1, count / STANDARD_VECTOR_SIZE);
	}
};

unique_ptr<GlobalSourceState> PhysicalHashAggregate::GetGlobalSourceState(ClientContext &context) const {
	return make_uniq<PhysicalHashAggregateGlobalSourceState>(context, *this);
}

class PhysicalHashAggregateLocalSourceState : public LocalSourceState {
public:
	explicit PhysicalHashAggregateLocalSourceState(ExecutionContext &context, const PhysicalHashAggregate &op) {
		for (auto &grouping : op.groupings) {
			auto &rt = grouping.table_data;
			radix_states.push_back(rt.GetLocalSourceState(context));
		}
	}

	vector<unique_ptr<LocalSourceState>> radix_states;
};

unique_ptr<LocalSourceState> PhysicalHashAggregate::GetLocalSourceState(ExecutionContext &context,
                                                                        GlobalSourceState &gstate) const {
	return make_uniq<PhysicalHashAggregateLocalSourceState>(context, *this);
}

SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk,
                                                OperatorSourceInput &input) const {
	auto &sink_gstate = sink_state->Cast<HashAggregateGlobalState>();
	auto &gstate = input.global_state.Cast<PhysicalHashAggregateGlobalSourceState>();
	auto &lstate = input.local_state.Cast<PhysicalHashAggregateLocalSourceState>();
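	// Scan the grouping tables in order; state_index tracks the first table that may still produce output,
	// so threads that exhaust one table move on to the next one together.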
	while (true) {
		idx_t radix_idx = gstate.state_index;
		if (radix_idx >= groupings.size()) {
			break;
		}
		auto &grouping = groupings[radix_idx];
		auto &radix_table = grouping.table_data;
		auto &grouping_gstate = sink_gstate.grouping_states[radix_idx];

		InterruptState interrupt_state;
		OperatorSourceInput source_input {*gstate.radix_states[radix_idx], *lstate.radix_states[radix_idx],
		                                  interrupt_state};
		auto res = radix_table.GetData(context, chunk, *grouping_gstate.table_state, source_input);
		if (chunk.size() != 0) {
			return SourceResultType::HAVE_MORE_OUTPUT;
		} else if (res == SourceResultType::BLOCKED) {
			throw InternalException("Unexpectedly Blocked from radix_table");
		}

		// move to the next table
		lock_guard<mutex> l(gstate.lock);
		radix_idx++;
		if (radix_idx > gstate.state_index) {
			// we have not yet worked on the table
			// move the global index forwards
			gstate.state_index = radix_idx;
		}
	}

	return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
}

string PhysicalHashAggregate::ParamsToString() const {
	string result;
	auto &groups = grouped_aggregate_data.groups;
	auto &aggregates = grouped_aggregate_data.aggregates;
	for (idx_t i = 0; i < groups.size(); i++) {
		if (i > 0) {
			result += "\n";
		}
		result += groups[i]->GetName();
	}
	for (idx_t i = 0; i < aggregates.size(); i++) {
		auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>();
		if (i > 0 || !groups.empty()) {
			result += "\n";
		}
		result += aggregates[i]->GetName();
		if (aggregate.filter) {
			result += " Filter: " + aggregate.filter->GetName();
		}
	}
	return result;
}

} // namespace duckdb