#include "duckdb/common/sort/partition_state.hpp"

#include "duckdb/common/types/column/column_data_consumer.hpp"
#include "duckdb/common/row_operations/row_operations.hpp"
#include "duckdb/main/config.hpp"
#include "duckdb/parallel/event.hpp"

#include <numeric>

namespace duckdb {

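// A hash group owns the global sort state for one radix partition of the data.
// The partition columns form a prefix of the sort key, so partition boundaries can be found with a prefix comparison.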
PartitionGlobalHashGroup::PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions,
                                                   const Orders &orders, const Types &payload_types, bool external)
    : count(0) {

    RowLayout payload_layout;
    payload_layout.Initialize(payload_types);
    global_sort = make_uniq<GlobalSortState>(buffer_manager, orders, payload_layout);
    global_sort->external = external;

    // Set up a comparator for the partition subset
    partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size());
}

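// Compare only the partition-column prefix of two sort keys; a result of 0 means both rows are in the same partition.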
int PartitionGlobalHashGroup::ComparePartitions(const SBIterator &left, const SBIterator &right) const {
    int part_cmp = 0;
    if (partition_layout.all_constant) {
        part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size);
    } else {
        part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr,
                                             partition_layout, left.external);
    }
    return part_cmp;
}

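// Walk the sorted data once, setting a valid bit wherever a new partition (partition_mask)
// or a new ordering value (order_mask) begins. Row 0 always starts both.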
void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, ValidityMask &order_mask) {
    D_ASSERT(count > 0);

    SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN);
    SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN);

    partition_mask.SetValidUnsafe(0);
    order_mask.SetValidUnsafe(0);
    for (++curr; curr.GetIndex() < count; ++curr) {
        // Compare the partition subset first because if that differs, then so does the full ordering
        const auto part_cmp = ComparePartitions(prev, curr);

        if (part_cmp) {
            partition_mask.SetValidUnsafe(curr.GetIndex());
            order_mask.SetValidUnsafe(curr.GetIndex());
        } else if (prev.Compare(curr)) {
            order_mask.SetValidUnsafe(curr.GetIndex());
        }
        ++prev;
    }
}

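// Build the combined sort specification: the PARTITION BY expressions (ascending, NULLS FIRST, with any available
// statistics) followed by the ORDER BY clauses. The partition prefix is also copied into `partitions`.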
void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders,
                                                 const vector<unique_ptr<Expression>> &partition_bys,
                                                 const Orders &order_bys,
                                                 const vector<unique_ptr<BaseStatistics>> &partition_stats) {

    // we sort by both 1) partition by expression list and 2) order by expressions
    const auto partition_cols = partition_bys.size();
    for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) {
        auto &pexpr = partition_bys[prt_idx];

        if (partition_stats.empty() || !partition_stats[prt_idx]) {
            orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr);
        } else {
            orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(),
                                partition_stats[prt_idx]->ToUnique());
        }
        partitions.emplace_back(orders.back().Copy());
    }

    for (const auto &order : order_bys) {
        orders.emplace_back(order.Copy());
    }
}

PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context,
                                                   const vector<unique_ptr<Expression>> &partition_bys,
                                                   const vector<BoundOrderByNode> &order_bys,
                                                   const Types &payload_types,
                                                   const vector<unique_ptr<BaseStatistics>> &partition_stats,
                                                   idx_t estimated_cardinality)
    : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)),
      payload_types(payload_types), memory_per_thread(0), count(0) {

    GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats);

    memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context);
    external = ClientConfig::GetConfig(context).force_external;

    if (!orders.empty()) {
        grouping_types = payload_types;
        grouping_types.push_back(LogicalType::HASH);

        ResizeGroupingData(estimated_cardinality);
    }
}

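// Choose a radix bit count (at most 10) so the estimated rows per partition stay near one row group,
// then repartition the grouping data if the bit count changed.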
void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) {
    // Have we started to combine? Then just live with it.
    if (grouping_data && !grouping_data->GetPartitions().empty()) {
        return;
    }
    // Is the average partition size too large?
    const idx_t partition_size = STANDARD_ROW_GROUPS_SIZE;
    const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0;
    auto new_bits = bits ? bits : 4;
    while (new_bits < 10 && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) {
        ++new_bits;
    }

    // Repartition the grouping data
    if (new_bits != bits) {
        const auto hash_col_idx = payload_types.size();
        grouping_data = make_uniq<RadixPartitionedColumnData>(context, grouping_types, new_bits, hash_col_idx);
    }
}

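// Repartition a thread-local collection so that it uses the same radix bit count as the global grouping data.
// The existing local data is rescanned and appended into a freshly created partition set.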
void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
    // We are done if the local_partition is right sized.
    auto &local_radix = local_partition->Cast<RadixPartitionedColumnData>();
    if (local_radix.GetRadixBits() == grouping_data->GetRadixBits()) {
        return;
    }

    // If the local partition is now too small, flush it and reallocate
    auto new_partition = grouping_data->CreateShared();
    auto new_append = make_uniq<PartitionedColumnDataAppendState>();
    new_partition->InitializeAppendState(*new_append);

    local_partition->FlushAppendState(*local_append);
    auto &local_groups = local_partition->GetPartitions();
    for (auto &local_group : local_groups) {
        ColumnDataScanState scanner;
        local_group->InitializeScan(scanner);

        DataChunk scan_chunk;
        local_group->InitializeScanChunk(scan_chunk);
        for (scan_chunk.Reset(); local_group->Scan(scanner, scan_chunk); scan_chunk.Reset()) {
            new_partition->Append(*new_append, scan_chunk);
        }
    }

    // The append state has stale pointers to the old local partition, so nuke it from orbit.
    new_partition->FlushAppendState(*new_append);

    local_partition = std::move(new_partition);
    local_append = make_uniq<PartitionedColumnDataAppendState>();
    local_partition->InitializeAppendState(*local_append);
}

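// Takes the global lock, lazily creates the thread-local partition, grows the global radix partitioning
// from the observed row count, and resyncs the local side to the new bit count.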
void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
    // Make sure grouping_data doesn't change under us.
    lock_guard<mutex> guard(lock);

    if (!local_partition) {
        local_partition = grouping_data->CreateShared();
        local_append = make_uniq<PartitionedColumnDataAppendState>();
        local_partition->InitializeAppendState(*local_append);
        return;
    }

    // Grow the groups if they are too big
    ResizeGroupingData(count);

    // Sync local partition to have the same bit count
    SyncLocalPartition(local_partition, local_append);
}

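// Flush and merge a thread-local partition into the global grouping data at the end of the sink phase.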
void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) {
    if (!local_partition) {
        return;
    }
    local_partition->FlushAppendState(*local_append);

    // Make sure grouping_data doesn't change under us.
    // Combine has an internal mutex, so this is single-threaded anyway.
    lock_guard<mutex> guard(lock);
    SyncLocalPartition(local_partition, local_append);
    grouping_data->Combine(*local_partition);
}

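// Move one hash group's rows into its GlobalSortState: compute the sort keys with an ExpressionExecutor,
// sink key/payload chunks into a LocalSortState (sorting early when the per-thread memory limit is exceeded),
// and finally register the local state with the global sort.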
void PartitionGlobalSinkState::BuildSortState(ColumnDataCollection &group_data, PartitionGlobalHashGroup &hash_group) {
    auto &global_sort = *hash_group.global_sort;

    // Set up the sort expression computation.
    vector<LogicalType> sort_types;
    ExpressionExecutor executor(context);
    for (auto &order : orders) {
        auto &oexpr = order.expression;
        sort_types.emplace_back(oexpr->return_type);
        executor.AddExpression(*oexpr);
    }
    DataChunk sort_chunk;
    sort_chunk.Initialize(allocator, sort_types);

    // Copy the data from the group into the sort state.
    LocalSortState local_sort;
    local_sort.Initialize(global_sort, global_sort.buffer_manager);

    // Strip the hash column from the payload
    DataChunk payload_chunk;
    payload_chunk.Initialize(allocator, payload_types);

    vector<column_t> column_ids;
    column_ids.reserve(payload_types.size());
    for (column_t i = 0; i < payload_types.size(); ++i) {
        column_ids.emplace_back(i);
    }
    ColumnDataConsumer scanner(group_data, column_ids);
    ColumnDataConsumerScanState chunk_state;
    chunk_state.current_chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY;
    scanner.InitializeScan();
    for (auto chunk_idx = scanner.ChunkCount(); chunk_idx-- > 0;) {
        if (!scanner.AssignChunk(chunk_state)) {
            break;
        }
        scanner.ScanChunk(chunk_state, payload_chunk);

        sort_chunk.Reset();
        executor.Execute(payload_chunk, sort_chunk);

        local_sort.SinkChunk(sort_chunk, payload_chunk);
        if (local_sort.SizeInBytes() > memory_per_thread) {
            local_sort.Sort(global_sort, true);
        }
        scanner.FinishChunk(chunk_state);
    }

    global_sort.AddLocalState(local_sort);

    hash_group.count += group_data.Count();
}

// Per-thread sink state
PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p)
    : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) {

    vector<LogicalType> group_types;
    for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) {
        auto &pexpr = *gstate.partitions[prt_idx].expression.get();
        group_types.push_back(pexpr.return_type);
        executor.AddExpression(pexpr);
    }
    sort_cols = gstate.orders.size() + group_types.size();

    if (sort_cols) {
        if (!group_types.empty()) {
            // OVER(PARTITION BY...)
            group_chunk.Initialize(allocator, group_types);
        }
        // OVER(...)
        auto payload_types = gstate.payload_types;
        payload_types.emplace_back(LogicalType::HASH);
        payload_chunk.Initialize(allocator, payload_types);
    } else {
        // OVER()
        payload_layout.Initialize(gstate.payload_types);
    }
}

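// Compute a per-row partition hash from the PARTITION BY expressions,
// or a single constant hash when there is no PARTITION BY clause.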
void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) {
    const auto count = input_chunk.size();
    if (group_chunk.ColumnCount() > 0) {
        // OVER(PARTITION BY...) (hash grouping)
        group_chunk.Reset();
        executor.Execute(input_chunk, group_chunk);
        VectorOperations::Hash(group_chunk.data[0], hash_vector, count);
        for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) {
            VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count);
        }
    } else {
        // OVER(...) (sorting)
        // Single partition => single hash value
        hash_vector.SetVectorType(VectorType::CONSTANT_VECTOR);
        auto hashes = ConstantVector::GetData<hash_t>(hash_vector);
        hashes[0] = 0;
    }
}

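// Route an input chunk: the OVER() case appends directly to paged row data,
// otherwise the rows are hashed and appended to the radix-partitioned column data.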
void PartitionLocalSinkState::Sink(DataChunk &input_chunk) {
    gstate.count += input_chunk.size();

    // OVER()
    if (sort_cols == 0) {
        // No sorts, so build paged row chunks
        if (!rows) {
            const auto entry_size = payload_layout.GetRowWidth();
            const auto capacity = MaxValue<idx_t>(STANDARD_VECTOR_SIZE, (Storage::BLOCK_SIZE / entry_size) + 1);
            rows = make_uniq<RowDataCollection>(gstate.buffer_manager, capacity, entry_size);
            strings = make_uniq<RowDataCollection>(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true);
        }
        const auto row_count = input_chunk.size();
        const auto row_sel = FlatVector::IncrementalSelectionVector();
        Vector addresses(LogicalType::POINTER);
        auto key_locations = FlatVector::GetData<data_ptr_t>(addresses);
        const auto prev_rows_blocks = rows->blocks.size();
        auto handles = rows->Build(row_count, key_locations, nullptr, row_sel);
        auto input_data = input_chunk.ToUnifiedFormat();
        RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel, row_count);
        // Mark that row blocks contain pointers (heap blocks are pinned)
        if (!payload_layout.AllConstant()) {
            D_ASSERT(strings->keep_pinned);
            for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) {
                rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink");
            }
        }
        return;
    }

    // OVER(...)
    payload_chunk.Reset();
    auto &hash_vector = payload_chunk.data.back();
    Hash(input_chunk, hash_vector);
    for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) {
        payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]);
    }
    payload_chunk.SetCardinality(input_chunk);

    gstate.UpdateLocalPartition(local_partition, local_append);
    local_partition->Append(*local_append, payload_chunk);
}

void PartitionLocalSinkState::Combine() {
    // OVER()
    if (sort_cols == 0) {
        // Only one partition again, so we need a global lock.
        lock_guard<mutex> glock(gstate.lock);
        if (gstate.rows) {
            if (rows) {
                gstate.rows->Merge(*rows);
                gstate.strings->Merge(*strings);
                rows.reset();
                strings.reset();
            }
        } else {
            gstate.rows = std::move(rows);
            gstate.strings = std::move(strings);
        }
        return;
    }

    // OVER(...)
    gstate.CombineLocalPartition(local_partition, local_append);
}

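// One merge state is created per non-empty hash bin; it registers a new hash group with the sink
// and records the bin-to-group mapping in bin_groups.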
PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data,
                                                     hash_t hash_bin)
    : sink(sink), group_data(std::move(group_data)), stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0),
      tasks_completed(0) {

    const auto group_idx = sink.hash_groups.size();
    auto new_group = make_uniq<PartitionGlobalHashGroup>(sink.buffer_manager, sink.partitions, sink.orders,
                                                         sink.payload_types, sink.external);
    sink.hash_groups.emplace_back(std::move(new_group));

    hash_group = sink.hash_groups[group_idx].get();
    global_sort = sink.hash_groups[group_idx]->global_sort.get();

    sink.bin_groups[hash_bin] = group_idx;
}

void PartitionLocalMergeState::Prepare() {
    auto &global_sort = *merge_state->global_sort;
    merge_state->sink.BuildSortState(*merge_state->group_data, *merge_state->hash_group);
    merge_state->group_data.reset();

    global_sort.PrepareMergePhase();
}

void PartitionLocalMergeState::Merge() {
    auto &global_sort = *merge_state->global_sort;
    MergeSorter merge_sorter(global_sort, global_sort.buffer_manager);
    merge_sorter.PerformInMergeRound();
}

void PartitionLocalMergeState::ExecuteTask() {
    switch (stage) {
    case PartitionSortStage::PREPARE:
        Prepare();
        break;
    case PartitionSortStage::MERGE:
        Merge();
        break;
    default:
        throw InternalException("Unexpected PartitionGlobalMergeState in ExecuteTask!");
    }

    merge_state->CompleteTask();
    finished = true;
}

bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) {
    lock_guard<mutex> guard(lock);

    if (tasks_assigned >= total_tasks) {
        return false;
    }

    local_state.merge_state = this;
    local_state.stage = stage;
    local_state.finished = false;
    tasks_assigned++;

    return true;
}

void PartitionGlobalMergeState::CompleteTask() {
    lock_guard<mutex> guard(lock);

    ++tasks_completed;
}

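// Once all tasks of the current stage have completed, advance the stage machine:
// INIT -> PREPARE -> MERGE (repeated rounds) -> SORTED. Returns true if a new stage with tasks was prepared.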
bool PartitionGlobalMergeState::TryPrepareNextStage() {
    lock_guard<mutex> guard(lock);

    if (tasks_completed < total_tasks) {
        return false;
    }

    tasks_assigned = tasks_completed = 0;

    switch (stage) {
    case PartitionSortStage::INIT:
        total_tasks = 1;
        stage = PartitionSortStage::PREPARE;
        return true;

    case PartitionSortStage::PREPARE:
        total_tasks = global_sort->sorted_blocks.size() / 2;
        if (!total_tasks) {
            break;
        }
        stage = PartitionSortStage::MERGE;
        global_sort->InitializeMergeRound();
        return true;

    case PartitionSortStage::MERGE:
        global_sort->CompleteMergeRound(true);
        total_tasks = global_sort->sorted_blocks.size() / 2;
        if (!total_tasks) {
            break;
        }
        global_sort->InitializeMergeRound();
        return true;

    case PartitionSortStage::SORTED:
        break;
    }

    stage = PartitionSortStage::SORTED;

    return false;
}

PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) {
    // Schedule all the sorts for maximum thread utilisation
    auto &partitions = sink.grouping_data->GetPartitions();
    sink.bin_groups.resize(partitions.size(), partitions.size());
    for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) {
        auto &group_data = partitions[hash_bin];
        // Prepare for merge sort phase
        if (group_data->Count()) {
            auto state = make_uniq<PartitionGlobalMergeState>(sink, std::move(group_data), hash_bin);
            states.emplace_back(std::move(state));
        }
    }
}

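// A task that repeatedly claims preparation/merge work from any hash group until every group is fully sorted.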
class PartitionMergeTask : public ExecutorTask {
public:
    PartitionMergeTask(shared_ptr<Event> event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p)
        : ExecutorTask(context_p), event(std::move(event_p)), hash_groups(hash_groups_p) {
    }

    TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override;

private:
    shared_ptr<Event> event;
    PartitionLocalMergeState local_state;
    PartitionGlobalMergeStates &hash_groups;
};

TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) {
    // Loop until all hash groups are done
    size_t sorted = 0;
    while (sorted < hash_groups.states.size()) {
        // First check if there is an unfinished task for this thread
        if (executor.HasError()) {
            return TaskExecutionResult::TASK_ERROR;
        }
        if (!local_state.TaskFinished()) {
            local_state.ExecuteTask();
            continue;
        }

        // The thread is done with its assigned task, so try to fetch new work
        for (auto group = sorted; group < hash_groups.states.size(); ++group) {
            auto &global_state = hash_groups.states[group];
            if (global_state->IsSorted()) {
                // This hash group is done
                // Update the high water mark of densely completed groups
                if (sorted == group) {
                    ++sorted;
                }
                continue;
            }

            // Try to assign work for this hash group to this thread
            if (global_state->AssignTask(local_state)) {
                // We assigned a task to this thread!
                // Break out of this loop to re-enter the top-level loop and execute the task
                break;
            }

            // The hash group's global state couldn't assign a task to this thread.
            // Try to prepare the next stage.
            if (!global_state->TryPrepareNextStage()) {
                // This hash group is not yet done,
                // but we were not able to assign a task for it to this thread.
                // See if the next hash group is better.
                continue;
            }

            // We were able to prepare the next stage for this hash group!
            // Try to assign a task once more
            if (global_state->AssignTask(local_state)) {
                // We assigned a task to this thread!
                // Break out of this loop to re-enter the top-level loop and execute the task
                break;
            }

            // We were able to prepare the next merge round,
            // but we were not able to assign a task for it to this thread,
            // because the tasks were assigned to other threads while this thread waited for the lock.
            // Go to the next iteration to see if another hash group has a task.
        }
    }

    event->FinishTask();
    return TaskExecutionResult::TASK_FINISHED;
}

void PartitionMergeEvent::Schedule() {
    auto &context = pipeline->GetClientContext();

    // Schedule tasks equal to the number of threads; each task will merge multiple partitions
    auto &ts = TaskScheduler::GetScheduler(context);
    idx_t num_threads = ts.NumberOfThreads();

    vector<shared_ptr<Task>> merge_tasks;
    for (idx_t tnum = 0; tnum < num_threads; tnum++) {
        merge_tasks.emplace_back(make_uniq<PartitionMergeTask>(shared_from_this(), context, merge_states));
    }
    SetTasks(std::move(merge_tasks));
}

} // namespace duckdb
| 565 | |