| 1 | #include <DataTypes/DataTypeString.h> |
| 2 | #include <Columns/ColumnString.h> |
| 3 | #include <Columns/ColumnFixedString.h> |
| 4 | #include <Columns/ColumnConst.h> |
| 5 | #include <Functions/FunctionFactory.h> |
| 6 | #include <Functions/FunctionHelpers.h> |
| 7 | #include <Functions/IFunctionImpl.h> |
| 8 | #include <Functions/GatherUtils/GatherUtils.h> |
| 9 | #include <Functions/GatherUtils/Sources.h> |
| 10 | #include <Functions/GatherUtils/Sinks.h> |
| 11 | #include <Functions/GatherUtils/Slices.h> |
| 12 | #include <Functions/GatherUtils/Algorithms.h> |
| 13 | #include <IO/WriteHelpers.h> |
| 14 | |
| 15 | |
| 16 | namespace DB |
| 17 | { |
| 18 | |
| 19 | using namespace GatherUtils; |
| 20 | |
/// Error codes referenced from this translation unit; they are only declared here
/// (`extern`), their definitions live elsewhere in the project.
namespace ErrorCodes
{
    /// Argument validation.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int ILLEGAL_COLUMN;

    /// Offset / range validation.
    extern const int ZERO_ARRAY_OR_TUPLE_INDEX;
    extern const int ARGUMENT_OUT_OF_BOUND;
}
| 29 | |
| 30 | |
| 31 | /// If 'is_utf8' - measure offset and length in code points instead of bytes. |
| 32 | /// UTF8 variant is not available for FixedString arguments. |
| 33 | template <bool is_utf8> |
| 34 | class FunctionSubstring : public IFunction |
| 35 | { |
| 36 | public: |
| 37 | static constexpr auto name = is_utf8 ? "substringUTF8" : "substring" ; |
| 38 | static FunctionPtr create(const Context &) |
| 39 | { |
| 40 | return std::make_shared<FunctionSubstring>(); |
| 41 | } |
| 42 | |
| 43 | String getName() const override |
| 44 | { |
| 45 | return name; |
| 46 | } |
| 47 | |
| 48 | bool isVariadic() const override { return true; } |
| 49 | size_t getNumberOfArguments() const override { return 0; } |
| 50 | |
| 51 | bool useDefaultImplementationForConstants() const override { return true; } |
| 52 | |
| 53 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
| 54 | { |
| 55 | size_t number_of_arguments = arguments.size(); |
| 56 | |
| 57 | if (number_of_arguments < 2 || number_of_arguments > 3) |
| 58 | throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " |
| 59 | + toString(number_of_arguments) + ", should be 2 or 3" , |
| 60 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 61 | |
| 62 | if ((is_utf8 && !isString(arguments[0])) || !isStringOrFixedString(arguments[0])) |
| 63 | throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 64 | |
| 65 | if (!isNativeNumber(arguments[1])) |
| 66 | throw Exception("Illegal type " + arguments[1]->getName() |
| 67 | + " of second argument of function " |
| 68 | + getName(), |
| 69 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 70 | |
| 71 | if (number_of_arguments == 3 && !isNativeNumber(arguments[2])) |
| 72 | throw Exception("Illegal type " + arguments[2]->getName() |
| 73 | + " of second argument of function " |
| 74 | + getName(), |
| 75 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 76 | |
| 77 | return std::make_shared<DataTypeString>(); |
| 78 | } |
| 79 | |
| 80 | template <typename Source> |
| 81 | void executeForSource(const ColumnPtr & column_start, const ColumnPtr & column_length, |
| 82 | const ColumnConst * column_start_const, const ColumnConst * column_length_const, |
| 83 | Int64 start_value, Int64 length_value, Block & block, size_t result, Source && source, |
| 84 | size_t input_rows_count) |
| 85 | { |
| 86 | auto col_res = ColumnString::create(); |
| 87 | |
| 88 | if (!column_length) |
| 89 | { |
| 90 | if (column_start_const) |
| 91 | { |
| 92 | if (start_value > 0) |
| 93 | sliceFromLeftConstantOffsetUnbounded(source, StringSink(*col_res, input_rows_count), start_value - 1); |
| 94 | else if (start_value < 0) |
| 95 | sliceFromRightConstantOffsetUnbounded(source, StringSink(*col_res, input_rows_count), -start_value); |
| 96 | else |
| 97 | throw Exception("Indices in strings are 1-based" , ErrorCodes::ZERO_ARRAY_OR_TUPLE_INDEX); |
| 98 | } |
| 99 | else |
| 100 | sliceDynamicOffsetUnbounded(source, StringSink(*col_res, input_rows_count), *column_start); |
| 101 | } |
| 102 | else |
| 103 | { |
| 104 | if (column_start_const && column_length_const) |
| 105 | { |
| 106 | if (start_value > 0) |
| 107 | sliceFromLeftConstantOffsetBounded(source, StringSink(*col_res, input_rows_count), start_value - 1, length_value); |
| 108 | else if (start_value < 0) |
| 109 | sliceFromRightConstantOffsetBounded(source, StringSink(*col_res, input_rows_count), -start_value, length_value); |
| 110 | else |
| 111 | throw Exception("Indices in strings are 1-based" , ErrorCodes::ZERO_ARRAY_OR_TUPLE_INDEX); |
| 112 | } |
| 113 | else |
| 114 | sliceDynamicOffsetBounded(source, StringSink(*col_res, input_rows_count), *column_start, *column_length); |
| 115 | } |
| 116 | |
| 117 | block.getByPosition(result).column = std::move(col_res); |
| 118 | } |
| 119 | |
| 120 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
| 121 | { |
| 122 | size_t number_of_arguments = arguments.size(); |
| 123 | |
| 124 | ColumnPtr column_string = block.getByPosition(arguments[0]).column; |
| 125 | ColumnPtr column_start = block.getByPosition(arguments[1]).column; |
| 126 | ColumnPtr column_length; |
| 127 | |
| 128 | if (number_of_arguments == 3) |
| 129 | column_length = block.getByPosition(arguments[2]).column; |
| 130 | |
| 131 | const ColumnConst * column_start_const = checkAndGetColumn<ColumnConst>(column_start.get()); |
| 132 | const ColumnConst * column_length_const = nullptr; |
| 133 | |
| 134 | if (number_of_arguments == 3) |
| 135 | column_length_const = checkAndGetColumn<ColumnConst>(column_length.get()); |
| 136 | |
| 137 | Int64 start_value = 0; |
| 138 | Int64 length_value = 0; |
| 139 | |
| 140 | if (column_start_const) |
| 141 | start_value = column_start_const->getInt(0); |
| 142 | if (column_length_const) |
| 143 | length_value = column_length_const->getInt(0); |
| 144 | |
| 145 | if constexpr (is_utf8) |
| 146 | { |
| 147 | if (const ColumnString * col = checkAndGetColumn<ColumnString>(column_string.get())) |
| 148 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
| 149 | length_value, block, result, UTF8StringSource(*col), input_rows_count); |
| 150 | else if (const ColumnConst * col_const = checkAndGetColumnConst<ColumnString>(column_string.get())) |
| 151 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
| 152 | length_value, block, result, ConstSource<UTF8StringSource>(*col_const), input_rows_count); |
| 153 | else |
| 154 | throw Exception( |
| 155 | "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(), |
| 156 | ErrorCodes::ILLEGAL_COLUMN); |
| 157 | } |
| 158 | else |
| 159 | { |
| 160 | if (const ColumnString * col = checkAndGetColumn<ColumnString>(column_string.get())) |
| 161 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
| 162 | length_value, block, result, StringSource(*col), input_rows_count); |
| 163 | else if (const ColumnFixedString * col_fixed = checkAndGetColumn<ColumnFixedString>(column_string.get())) |
| 164 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
| 165 | length_value, block, result, FixedStringSource(*col_fixed), input_rows_count); |
| 166 | else if (const ColumnConst * col_const = checkAndGetColumnConst<ColumnString>(column_string.get())) |
| 167 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
| 168 | length_value, block, result, ConstSource<StringSource>(*col_const), input_rows_count); |
| 169 | else if (const ColumnConst * col_const_fixed = checkAndGetColumnConst<ColumnFixedString>(column_string.get())) |
| 170 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
| 171 | length_value, block, result, ConstSource<FixedStringSource>(*col_const_fixed), input_rows_count); |
| 172 | else |
| 173 | throw Exception( |
| 174 | "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(), |
| 175 | ErrorCodes::ILLEGAL_COLUMN); |
| 176 | } |
| 177 | } |
| 178 | }; |
| 179 | |
| 180 | void registerFunctionSubstring(FunctionFactory & factory) |
| 181 | { |
| 182 | factory.registerFunction<FunctionSubstring<false>>(FunctionFactory::CaseInsensitive); |
| 183 | factory.registerAlias("substr" , "substring" , FunctionFactory::CaseInsensitive); |
| 184 | factory.registerAlias("mid" , "substring" , FunctionFactory::CaseInsensitive); /// from MySQL dialect |
| 185 | |
| 186 | factory.registerFunction<FunctionSubstring<true>>(FunctionFactory::CaseSensitive); |
| 187 | } |
| 188 | |
| 189 | } |
| 190 | |