#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/common/algorithm.hpp"

namespace duckdb {

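// Overview: distinct aggregates are handled by first routing their argument values
// through a radix-partitioned hash table that groups on those values, de-duplicating
// them; the de-duplicated values are later read back out of the table (see
// distinct_output_chunks, used in Finalize) and fed to the actual aggregate function.
// Distinct aggregates with identical arguments and an identical FILTER share one table.
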
//! Shared information about a collection of distinct aggregates
DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector<unique_ptr<Expression>> &aggregates,
                                                                 vector<idx_t> indices)
    : indices(std::move(indices)), aggregates(aggregates) {
	table_count = CreateTableIndexMap();

	const idx_t aggregate_count = aggregates.size();

	total_child_count = 0;
	for (idx_t i = 0; i < aggregate_count; i++) {
		auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>();

		if (!aggregate.IsDistinct()) {
			continue;
		}
		total_child_count += aggregate.children.size();
	}
}
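
// Example (hypothetical aggregate list): for [COUNT(DISTINCT a), SUM(b), SUM(DISTINCT a)]
// the distinct indices are {0, 2}; both distinct aggregates read the same column and
// carry no FILTER, so they share a single table (table_count == 1), while
// total_child_count is 2 (one child per distinct aggregate).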

//! Stateful data for the distinct aggregates

DistinctAggregateState::DistinctAggregateState(const DistinctAggregateData &data, ClientContext &client)
    : child_executor(client) {

	radix_states.resize(data.info.table_count);
	distinct_output_chunks.resize(data.info.table_count);

	idx_t aggregate_count = data.info.aggregates.size();
	for (idx_t i = 0; i < aggregate_count; i++) {
		auto &aggregate = data.info.aggregates[i]->Cast<BoundAggregateExpression>();

		// Add the children of every aggregate to the child executor
		for (auto &child : aggregate.children) {
			child_executor.AddExpression(*child);
		}
		if (!aggregate.IsDistinct()) {
			continue;
		}
		D_ASSERT(data.info.table_map.count(i));
		idx_t table_idx = data.info.table_map.at(i);
		if (data.radix_tables[table_idx] == nullptr) {
			//! This table is unused because the aggregate shares its data with another
			continue;
		}

		// Get the global sink state for the aggregate
		auto &radix_table = *data.radix_tables[table_idx];
		radix_states[table_idx] = radix_table.GetGlobalSinkState(client);

		// Fill the chunk_types (group_by + children)
		vector<LogicalType> chunk_types;
		for (auto &group_type : data.grouped_aggregate_data[table_idx]->group_types) {
			chunk_types.push_back(group_type);
		}

		// This is used in Finalize to get the data from the radix table
		distinct_output_chunks[table_idx] = make_uniq<DataChunk>();
		distinct_output_chunks[table_idx]->Initialize(client, chunk_types);
	}
}

//! Persistent + shared (read-only) data for the distinct aggregates
DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info)
    : DistinctAggregateData(info, {}, nullptr) {
}

DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info, const GroupingSet &groups,
                                             const vector<unique_ptr<Expression>> *group_expressions)
    : info(info) {
	grouped_aggregate_data.resize(info.table_count);
	radix_tables.resize(info.table_count);
	grouping_sets.resize(info.table_count);

	for (auto &i : info.indices) {
		auto &aggregate = info.aggregates[i]->Cast<BoundAggregateExpression>();

		D_ASSERT(info.table_map.count(i));
		idx_t table_idx = info.table_map.at(i);
		if (radix_tables[table_idx] != nullptr) {
			//! This aggregate shares a table with another aggregate, and the table is already initialized
			continue;
		}
		// The grouping set contains the chunk column indices that determine
		// in which bucket the payload should be placed
		auto &grouping_set = grouping_sets[table_idx];
		//! Populate the grouping set with the original group columns
		for (auto &group : groups) {
			grouping_set.insert(group);
		}
		//! Then append the children of the aggregate behind the groups
		idx_t group_by_size = group_expressions ? group_expressions->size() : 0;
		for (idx_t set_idx = 0; set_idx < aggregate.children.size(); set_idx++) {
			grouping_set.insert(set_idx + group_by_size);
		}
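		// e.g. (illustration): if groups is {0, 1} (two GROUP BY columns) and the
		// aggregate is COUNT(DISTINCT a, b), then group_by_size is 2 and grouping_set
		// becomes {0, 1, 2, 3}: the original groups plus the children appended after them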
		// Create the hashtable for the aggregate
		grouped_aggregate_data[table_idx] = make_uniq<GroupedAggregateData>();
		grouped_aggregate_data[table_idx]->InitializeDistinct(info.aggregates[i], group_expressions);
		radix_tables[table_idx] =
		    make_uniq<RadixPartitionedHashTable>(grouping_set, *grouped_aggregate_data[table_idx]);

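		// Note: the chunk_types vector below is filled but never consumed in this
		// constructor; it appears to be left over and could likely be removed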
		// Fill the chunk_types (only contains the payload of the distinct aggregates)
		vector<LogicalType> chunk_types;
		for (auto &child_p : aggregate.children) {
			chunk_types.push_back(child_p->return_type);
		}
	}
}

using aggr_ref_t = reference<BoundAggregateExpression>;

struct FindMatchingAggregate {
	explicit FindMatchingAggregate(const aggr_ref_t &aggr) : aggr_r(aggr) {
	}
	bool operator()(const aggr_ref_t other_r) {
		auto &other = other_r.get();
		auto &aggr = aggr_r.get();
		if (other.children.size() != aggr.children.size()) {
			return false;
		}
		if (!Expression::Equals(aggr.filter, other.filter)) {
			return false;
		}
		for (idx_t i = 0; i < aggr.children.size(); i++) {
			auto &other_child = other.children[i]->Cast<BoundReferenceExpression>();
			auto &aggr_child = aggr.children[i]->Cast<BoundReferenceExpression>();
			if (other_child.index != aggr_child.index) {
				return false;
			}
		}
		return true;
	}
	const aggr_ref_t aggr_r;
};

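// Two aggregates share a table when their children match (by bound column index) and
// their FILTER clauses compare equal: e.g. COUNT(DISTINCT a) and SUM(DISTINCT a) map
// to one table, while COUNT(DISTINCT a) FILTER (WHERE b > 0) gets a table of its own.
// The linear scan over table_inputs below makes this quadratic in the number of
// distinct aggregates, which is acceptable since that number is typically small.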
idx_t DistinctAggregateCollectionInfo::CreateTableIndexMap() {
	vector<aggr_ref_t> table_inputs;

	D_ASSERT(table_map.empty());
	for (auto &agg_idx : indices) {
		D_ASSERT(agg_idx < aggregates.size());
		auto &aggregate = aggregates[agg_idx]->Cast<BoundAggregateExpression>();

		auto matching_inputs =
		    std::find_if(table_inputs.begin(), table_inputs.end(), FindMatchingAggregate(std::ref(aggregate)));
		if (matching_inputs != table_inputs.end()) {
			//! Assign the existing table to the aggregate
			idx_t found_idx = std::distance(table_inputs.begin(), matching_inputs);
			table_map[agg_idx] = found_idx;
			continue;
		}
		//! Create a new table and assign its index to the aggregate
		table_map[agg_idx] = table_inputs.size();
		table_inputs.push_back(std::ref(aggregate));
	}
	//! Every distinct aggregate needs to be assigned an index
	D_ASSERT(table_map.size() == indices.size());
	//! There cannot be more tables than there are distinct aggregates
	D_ASSERT(table_inputs.size() <= indices.size());

	return table_inputs.size();
}

bool DistinctAggregateCollectionInfo::AnyDistinct() const {
	return !indices.empty();
}

const unsafe_vector<idx_t> &DistinctAggregateCollectionInfo::Indices() const {
	return this->indices;
}

static vector<idx_t> GetDistinctIndices(vector<unique_ptr<Expression>> &aggregates) {
	vector<idx_t> distinct_indices;
	for (idx_t i = 0; i < aggregates.size(); i++) {
		auto &aggregate = aggregates[i];
		auto &aggr = aggregate->Cast<BoundAggregateExpression>();
		if (aggr.IsDistinct()) {
			distinct_indices.push_back(i);
		}
	}
	return distinct_indices;
}

unique_ptr<DistinctAggregateCollectionInfo>
DistinctAggregateCollectionInfo::Create(vector<unique_ptr<Expression>> &aggregates) {
	vector<idx_t> indices = GetDistinctIndices(aggregates);
	if (indices.empty()) {
		return nullptr;
	}
	return make_uniq<DistinctAggregateCollectionInfo>(aggregates, std::move(indices));
}

bool DistinctAggregateData::IsDistinct(idx_t index) const {
	bool is_distinct = !radix_tables.empty() && info.table_map.count(index);
#ifdef DEBUG
	//! Make sure that if it is distinct, it's also in the indices
	//! And if it's not distinct, that it's also not in the indices
	bool found = false;
	for (auto &idx : info.indices) {
		if (idx == index) {
			found = true;
			break;
		}
	}
	D_ASSERT(found == is_distinct);
#endif
	return is_distinct;
}

} // namespace duckdb