| 1 | #include "duckdb/execution/operator/persistent/physical_update.hpp" |
| 2 | |
| 3 | #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" |
| 4 | #include "duckdb/common/types/column/column_data_collection.hpp" |
| 5 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 6 | #include "duckdb/execution/expression_executor.hpp" |
| 7 | #include "duckdb/main/client_context.hpp" |
| 8 | #include "duckdb/parallel/thread_context.hpp" |
| 9 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
| 10 | #include "duckdb/storage/data_table.hpp" |
| 11 | |
| 12 | namespace duckdb { |
| 13 | |
| 14 | PhysicalUpdate::PhysicalUpdate(vector<LogicalType> types, TableCatalogEntry &tableref, DataTable &table, |
| 15 | vector<PhysicalIndex> columns, vector<unique_ptr<Expression>> expressions, |
| 16 | vector<unique_ptr<Expression>> bound_defaults, idx_t estimated_cardinality, |
| 17 | bool return_chunk) |
| 18 | : PhysicalOperator(PhysicalOperatorType::UPDATE, std::move(types), estimated_cardinality), tableref(tableref), |
| 19 | table(table), columns(std::move(columns)), expressions(std::move(expressions)), |
| 20 | bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk) { |
| 21 | } |
| 22 | |
| 23 | //===--------------------------------------------------------------------===// |
| 24 | // Sink |
| 25 | //===--------------------------------------------------------------------===// |
| 26 | class UpdateGlobalState : public GlobalSinkState { |
| 27 | public: |
| 28 | explicit UpdateGlobalState(ClientContext &context, const vector<LogicalType> &return_types) |
| 29 | : updated_count(0), return_collection(context, return_types) { |
| 30 | } |
| 31 | |
| 32 | mutex lock; |
| 33 | idx_t updated_count; |
| 34 | unordered_set<row_t> updated_columns; |
| 35 | ColumnDataCollection return_collection; |
| 36 | }; |
| 37 | |
| 38 | class UpdateLocalState : public LocalSinkState { |
| 39 | public: |
| 40 | UpdateLocalState(ClientContext &context, const vector<unique_ptr<Expression>> &expressions, |
| 41 | const vector<LogicalType> &table_types, const vector<unique_ptr<Expression>> &bound_defaults) |
| 42 | : default_executor(context, bound_defaults) { |
| 43 | // initialize the update chunk |
| 44 | auto &allocator = Allocator::Get(context); |
| 45 | vector<LogicalType> update_types; |
| 46 | update_types.reserve(n: expressions.size()); |
| 47 | for (auto &expr : expressions) { |
| 48 | update_types.push_back(x: expr->return_type); |
| 49 | } |
| 50 | update_chunk.Initialize(allocator, types: update_types); |
| 51 | // initialize the mock chunk |
| 52 | mock_chunk.Initialize(allocator, types: table_types); |
| 53 | } |
| 54 | |
| 55 | DataChunk update_chunk; |
| 56 | DataChunk mock_chunk; |
| 57 | ExpressionExecutor default_executor; |
| 58 | }; |
| 59 | |
| 60 | SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { |
| 61 | auto &gstate = input.global_state.Cast<UpdateGlobalState>(); |
| 62 | auto &lstate = input.local_state.Cast<UpdateLocalState>(); |
| 63 | |
| 64 | DataChunk &update_chunk = lstate.update_chunk; |
| 65 | DataChunk &mock_chunk = lstate.mock_chunk; |
| 66 | |
| 67 | chunk.Flatten(); |
| 68 | lstate.default_executor.SetChunk(chunk); |
| 69 | |
| 70 | // update data in the base table |
| 71 | // the row ids are given to us as the last column of the child chunk |
| 72 | auto &row_ids = chunk.data[chunk.ColumnCount() - 1]; |
| 73 | update_chunk.Reset(); |
| 74 | update_chunk.SetCardinality(chunk); |
| 75 | |
| 76 | for (idx_t i = 0; i < expressions.size(); i++) { |
| 77 | if (expressions[i]->type == ExpressionType::VALUE_DEFAULT) { |
| 78 | // default expression, set to the default value of the column |
| 79 | lstate.default_executor.ExecuteExpression(expr_idx: columns[i].index, result&: update_chunk.data[i]); |
| 80 | } else { |
| 81 | D_ASSERT(expressions[i]->type == ExpressionType::BOUND_REF); |
| 82 | // index into child chunk |
| 83 | auto &binding = expressions[i]->Cast<BoundReferenceExpression>(); |
| 84 | update_chunk.data[i].Reference(other&: chunk.data[binding.index]); |
| 85 | } |
| 86 | } |
| 87 | |
| 88 | lock_guard<mutex> glock(gstate.lock); |
| 89 | if (update_is_del_and_insert) { |
| 90 | // index update or update on complex type, perform a delete and an append instead |
| 91 | |
| 92 | // figure out which rows have not yet been deleted in this update |
| 93 | // this is required since we might see the same row_id multiple times |
| 94 | // in the case of an UPDATE query that e.g. has joins |
| 95 | auto row_id_data = FlatVector::GetData<row_t>(vector&: row_ids); |
| 96 | SelectionVector sel(STANDARD_VECTOR_SIZE); |
| 97 | idx_t update_count = 0; |
| 98 | for (idx_t i = 0; i < update_chunk.size(); i++) { |
| 99 | auto row_id = row_id_data[i]; |
| 100 | if (gstate.updated_columns.find(x: row_id) == gstate.updated_columns.end()) { |
| 101 | gstate.updated_columns.insert(x: row_id); |
| 102 | sel.set_index(idx: update_count++, loc: i); |
| 103 | } |
| 104 | } |
| 105 | if (update_count != update_chunk.size()) { |
| 106 | // we need to slice here |
| 107 | update_chunk.Slice(sel_vector: sel, count: update_count); |
| 108 | } |
| 109 | table.Delete(table&: tableref, context&: context.client, row_ids, count: update_chunk.size()); |
| 110 | // for the append we need to arrange the columns in a specific manner (namely the "standard table order") |
| 111 | mock_chunk.SetCardinality(update_chunk); |
| 112 | for (idx_t i = 0; i < columns.size(); i++) { |
| 113 | mock_chunk.data[columns[i].index].Reference(other&: update_chunk.data[i]); |
| 114 | } |
| 115 | table.LocalAppend(table&: tableref, context&: context.client, chunk&: mock_chunk); |
| 116 | } else { |
| 117 | if (return_chunk) { |
| 118 | mock_chunk.SetCardinality(update_chunk); |
| 119 | for (idx_t i = 0; i < columns.size(); i++) { |
| 120 | mock_chunk.data[columns[i].index].Reference(other&: update_chunk.data[i]); |
| 121 | } |
| 122 | } |
| 123 | table.Update(table&: tableref, context&: context.client, row_ids, column_ids: columns, data&: update_chunk); |
| 124 | } |
| 125 | |
| 126 | if (return_chunk) { |
| 127 | gstate.return_collection.Append(new_chunk&: mock_chunk); |
| 128 | } |
| 129 | |
| 130 | gstate.updated_count += chunk.size(); |
| 131 | |
| 132 | return SinkResultType::NEED_MORE_INPUT; |
| 133 | } |
| 134 | |
| 135 | unique_ptr<GlobalSinkState> PhysicalUpdate::GetGlobalSinkState(ClientContext &context) const { |
| 136 | return make_uniq<UpdateGlobalState>(args&: context, args: GetTypes()); |
| 137 | } |
| 138 | |
| 139 | unique_ptr<LocalSinkState> PhysicalUpdate::GetLocalSinkState(ExecutionContext &context) const { |
| 140 | return make_uniq<UpdateLocalState>(args&: context.client, args: expressions, args: table.GetTypes(), args: bound_defaults); |
| 141 | } |
| 142 | |
| 143 | void PhysicalUpdate::Combine(ExecutionContext &context, GlobalSinkState &gstate, LocalSinkState &lstate) const { |
| 144 | auto &state = lstate.Cast<UpdateLocalState>(); |
| 145 | auto &client_profiler = QueryProfiler::Get(context&: context.client); |
| 146 | context.thread.profiler.Flush(phys_op: *this, expression_executor&: state.default_executor, name: "default_executor" , id: 1); |
| 147 | client_profiler.Flush(profiler&: context.thread.profiler); |
| 148 | } |
| 149 | |
| 150 | //===--------------------------------------------------------------------===// |
| 151 | // Source |
| 152 | //===--------------------------------------------------------------------===// |
| 153 | class UpdateSourceState : public GlobalSourceState { |
| 154 | public: |
| 155 | explicit UpdateSourceState(const PhysicalUpdate &op) { |
| 156 | if (op.return_chunk) { |
| 157 | D_ASSERT(op.sink_state); |
| 158 | auto &g = op.sink_state->Cast<UpdateGlobalState>(); |
| 159 | g.return_collection.InitializeScan(state&: scan_state); |
| 160 | } |
| 161 | } |
| 162 | |
| 163 | ColumnDataScanState scan_state; |
| 164 | }; |
| 165 | |
| 166 | unique_ptr<GlobalSourceState> PhysicalUpdate::GetGlobalSourceState(ClientContext &context) const { |
| 167 | return make_uniq<UpdateSourceState>(args: *this); |
| 168 | } |
| 169 | |
| 170 | SourceResultType PhysicalUpdate::GetData(ExecutionContext &context, DataChunk &chunk, |
| 171 | OperatorSourceInput &input) const { |
| 172 | auto &state = input.global_state.Cast<UpdateSourceState>(); |
| 173 | auto &g = sink_state->Cast<UpdateGlobalState>(); |
| 174 | if (!return_chunk) { |
| 175 | chunk.SetCardinality(1); |
| 176 | chunk.SetValue(col_idx: 0, index: 0, val: Value::BIGINT(value: g.updated_count)); |
| 177 | return SourceResultType::FINISHED; |
| 178 | } |
| 179 | |
| 180 | g.return_collection.Scan(state&: state.scan_state, result&: chunk); |
| 181 | |
| 182 | return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; |
| 183 | } |
| 184 | |
| 185 | } // namespace duckdb |
| 186 | |