| 1 | #include "duckdb/common/field_writer.hpp" |
| 2 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
| 3 | #include "duckdb/planner/expression/bound_default_expression.hpp" |
| 4 | #include "duckdb/planner/expression/bound_parameter_expression.hpp" |
| 5 | #include "duckdb/function/cast_rules.hpp" |
| 6 | #include "duckdb/function/cast/cast_function_set.hpp" |
| 7 | #include "duckdb/main/config.hpp" |
| 8 | |
| 9 | namespace duckdb { |
| 10 | |
| 11 | BoundCastExpression::BoundCastExpression(unique_ptr<Expression> child_p, LogicalType target_type_p, |
| 12 | BoundCastInfo bound_cast_p, bool try_cast_p) |
| 13 | : Expression(ExpressionType::OPERATOR_CAST, ExpressionClass::BOUND_CAST, std::move(target_type_p)), |
| 14 | child(std::move(child_p)), try_cast(try_cast_p), bound_cast(std::move(bound_cast_p)) { |
| 15 | } |
| 16 | |
| 17 | unique_ptr<Expression> AddCastExpressionInternal(unique_ptr<Expression> expr, const LogicalType &target_type, |
| 18 | BoundCastInfo bound_cast, bool try_cast) { |
| 19 | if (expr->return_type == target_type) { |
| 20 | return expr; |
| 21 | } |
| 22 | auto &expr_type = expr->return_type; |
| 23 | if (target_type.id() == LogicalTypeId::LIST && expr_type.id() == LogicalTypeId::LIST) { |
| 24 | auto &target_list = ListType::GetChildType(type: target_type); |
| 25 | auto &expr_list = ListType::GetChildType(type: expr_type); |
| 26 | if (target_list.id() == LogicalTypeId::ANY || expr_list == target_list) { |
| 27 | return expr; |
| 28 | } |
| 29 | } |
| 30 | return make_uniq<BoundCastExpression>(args: std::move(expr), args: target_type, args: std::move(bound_cast), args&: try_cast); |
| 31 | } |
| 32 | |
| 33 | static BoundCastInfo BindCastFunction(ClientContext &context, const LogicalType &source, const LogicalType &target) { |
| 34 | auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions(); |
| 35 | GetCastFunctionInput input(context); |
| 36 | return cast_functions.GetCastFunction(source, target, input); |
| 37 | } |
| 38 | |
| 39 | unique_ptr<Expression> AddCastToTypeInternal(unique_ptr<Expression> expr, const LogicalType &target_type, |
| 40 | CastFunctionSet &cast_functions, GetCastFunctionInput &get_input, |
| 41 | bool try_cast) { |
| 42 | D_ASSERT(expr); |
| 43 | if (expr->expression_class == ExpressionClass::BOUND_PARAMETER) { |
| 44 | auto ¶meter = expr->Cast<BoundParameterExpression>(); |
| 45 | if (!target_type.IsValid()) { |
| 46 | // invalidate the parameter |
| 47 | parameter.parameter_data->return_type = LogicalType::INVALID; |
| 48 | parameter.return_type = target_type; |
| 49 | return expr; |
| 50 | } |
| 51 | if (parameter.parameter_data->return_type.id() == LogicalTypeId::INVALID) { |
| 52 | // we don't know the type of this parameter |
| 53 | parameter.return_type = target_type; |
| 54 | return expr; |
| 55 | } |
| 56 | if (parameter.parameter_data->return_type.id() == LogicalTypeId::UNKNOWN) { |
| 57 | // prepared statement parameter cast - but there is no type, convert the type |
| 58 | parameter.parameter_data->return_type = target_type; |
| 59 | parameter.return_type = target_type; |
| 60 | return expr; |
| 61 | } |
| 62 | // prepared statement parameter already has a type |
| 63 | if (parameter.parameter_data->return_type == target_type) { |
| 64 | // this type! we are done |
| 65 | parameter.return_type = parameter.parameter_data->return_type; |
| 66 | return expr; |
| 67 | } |
| 68 | // invalidate the type |
| 69 | parameter.parameter_data->return_type = LogicalType::INVALID; |
| 70 | parameter.return_type = target_type; |
| 71 | return expr; |
| 72 | } else if (expr->expression_class == ExpressionClass::BOUND_DEFAULT) { |
| 73 | D_ASSERT(target_type.IsValid()); |
| 74 | auto &def = expr->Cast<BoundDefaultExpression>(); |
| 75 | def.return_type = target_type; |
| 76 | } |
| 77 | if (!target_type.IsValid()) { |
| 78 | return expr; |
| 79 | } |
| 80 | |
| 81 | auto cast_function = cast_functions.GetCastFunction(source: expr->return_type, target: target_type, input&: get_input); |
| 82 | return AddCastExpressionInternal(expr: std::move(expr), target_type, bound_cast: std::move(cast_function), try_cast); |
| 83 | } |
| 84 | |
| 85 | unique_ptr<Expression> BoundCastExpression::AddDefaultCastToType(unique_ptr<Expression> expr, |
| 86 | const LogicalType &target_type, bool try_cast) { |
| 87 | CastFunctionSet default_set; |
| 88 | GetCastFunctionInput get_input; |
| 89 | return AddCastToTypeInternal(expr: std::move(expr), target_type, cast_functions&: default_set, get_input, try_cast); |
| 90 | } |
| 91 | |
| 92 | unique_ptr<Expression> BoundCastExpression::AddCastToType(ClientContext &context, unique_ptr<Expression> expr, |
| 93 | const LogicalType &target_type, bool try_cast) { |
| 94 | auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions(); |
| 95 | GetCastFunctionInput get_input(context); |
| 96 | return AddCastToTypeInternal(expr: std::move(expr), target_type, cast_functions, get_input, try_cast); |
| 97 | } |
| 98 | |
| 99 | bool BoundCastExpression::CastIsInvertible(const LogicalType &source_type, const LogicalType &target_type) { |
| 100 | D_ASSERT(source_type.IsValid() && target_type.IsValid()); |
| 101 | if (source_type.id() == LogicalTypeId::BOOLEAN || target_type.id() == LogicalTypeId::BOOLEAN) { |
| 102 | return false; |
| 103 | } |
| 104 | if (source_type.id() == LogicalTypeId::FLOAT || target_type.id() == LogicalTypeId::FLOAT) { |
| 105 | return false; |
| 106 | } |
| 107 | if (source_type.id() == LogicalTypeId::DOUBLE || target_type.id() == LogicalTypeId::DOUBLE) { |
| 108 | return false; |
| 109 | } |
| 110 | if (source_type.id() == LogicalTypeId::DECIMAL || target_type.id() == LogicalTypeId::DECIMAL) { |
| 111 | uint8_t source_width, target_width; |
| 112 | uint8_t source_scale, target_scale; |
| 113 | // cast to or from decimal |
| 114 | // cast is only invertible if the cast is strictly widening |
| 115 | if (!source_type.GetDecimalProperties(width&: source_width, scale&: source_scale)) { |
| 116 | return false; |
| 117 | } |
| 118 | if (!target_type.GetDecimalProperties(width&: target_width, scale&: target_scale)) { |
| 119 | return false; |
| 120 | } |
| 121 | if (target_scale < source_scale) { |
| 122 | return false; |
| 123 | } |
| 124 | return true; |
| 125 | } |
| 126 | if (source_type.id() == LogicalTypeId::TIMESTAMP || source_type.id() == LogicalTypeId::TIMESTAMP_TZ) { |
| 127 | switch (target_type.id()) { |
| 128 | case LogicalTypeId::DATE: |
| 129 | case LogicalTypeId::TIME: |
| 130 | case LogicalTypeId::TIME_TZ: |
| 131 | return false; |
| 132 | default: |
| 133 | break; |
| 134 | } |
| 135 | } |
| 136 | if (source_type.id() == LogicalTypeId::VARCHAR) { |
| 137 | switch (target_type.id()) { |
| 138 | case LogicalTypeId::TIME: |
| 139 | case LogicalTypeId::TIMESTAMP: |
| 140 | case LogicalTypeId::TIMESTAMP_NS: |
| 141 | case LogicalTypeId::TIMESTAMP_MS: |
| 142 | case LogicalTypeId::TIMESTAMP_SEC: |
| 143 | case LogicalTypeId::TIME_TZ: |
| 144 | case LogicalTypeId::TIMESTAMP_TZ: |
| 145 | return true; |
| 146 | default: |
| 147 | return false; |
| 148 | } |
| 149 | } |
| 150 | if (target_type.id() == LogicalTypeId::VARCHAR) { |
| 151 | switch (source_type.id()) { |
| 152 | case LogicalTypeId::DATE: |
| 153 | case LogicalTypeId::TIME: |
| 154 | case LogicalTypeId::TIMESTAMP: |
| 155 | case LogicalTypeId::TIMESTAMP_NS: |
| 156 | case LogicalTypeId::TIMESTAMP_MS: |
| 157 | case LogicalTypeId::TIMESTAMP_SEC: |
| 158 | case LogicalTypeId::TIME_TZ: |
| 159 | case LogicalTypeId::TIMESTAMP_TZ: |
| 160 | return true; |
| 161 | default: |
| 162 | return false; |
| 163 | } |
| 164 | } |
| 165 | return true; |
| 166 | } |
| 167 | |
| 168 | string BoundCastExpression::ToString() const { |
| 169 | return (try_cast ? "TRY_CAST(" : "CAST(" ) + child->GetName() + " AS " + return_type.ToString() + ")" ; |
| 170 | } |
| 171 | |
| 172 | bool BoundCastExpression::Equals(const BaseExpression &other_p) const { |
| 173 | if (!Expression::Equals(other: other_p)) { |
| 174 | return false; |
| 175 | } |
| 176 | auto &other = other_p.Cast<BoundCastExpression>(); |
| 177 | if (!Expression::Equals(left: *child, right: *other.child)) { |
| 178 | return false; |
| 179 | } |
| 180 | if (try_cast != other.try_cast) { |
| 181 | return false; |
| 182 | } |
| 183 | return true; |
| 184 | } |
| 185 | |
| 186 | unique_ptr<Expression> BoundCastExpression::Copy() { |
| 187 | auto copy = make_uniq<BoundCastExpression>(args: child->Copy(), args&: return_type, args: bound_cast.Copy(), args&: try_cast); |
| 188 | copy->CopyProperties(other&: *this); |
| 189 | return std::move(copy); |
| 190 | } |
| 191 | |
| 192 | void BoundCastExpression::Serialize(FieldWriter &writer) const { |
| 193 | writer.WriteSerializable(element: *child); |
| 194 | writer.WriteSerializable(element: return_type); |
| 195 | writer.WriteField(element: try_cast); |
| 196 | } |
| 197 | |
| 198 | unique_ptr<Expression> BoundCastExpression::Deserialize(ExpressionDeserializationState &state, FieldReader &reader) { |
| 199 | auto child = reader.ReadRequiredSerializable<Expression>(args&: state.gstate); |
| 200 | auto target_type = reader.ReadRequiredSerializable<LogicalType, LogicalType>(); |
| 201 | auto try_cast = reader.ReadRequired<bool>(); |
| 202 | auto cast_function = BindCastFunction(context&: state.gstate.context, source: child->return_type, target: target_type); |
| 203 | return make_uniq<BoundCastExpression>(args: std::move(child), args: std::move(target_type), args: std::move(cast_function), args&: try_cast); |
| 204 | } |
| 205 | |
| 206 | } // namespace duckdb |
| 207 | |