| 1 | // Copyright (c) 2018 Google LLC. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #include "source/opcode.h" |
| 16 | #include "source/val/instruction.h" |
| 17 | #include "source/val/validate.h" |
| 18 | #include "source/val/validation_state.h" |
| 19 | |
| 20 | namespace spvtools { |
| 21 | namespace val { |
| 22 | namespace { |
| 23 | |
| 24 | spv_result_t ValidateConstantBool(ValidationState_t& _, |
| 25 | const Instruction* inst) { |
| 26 | auto type = _.FindDef(inst->type_id()); |
| 27 | if (!type || type->opcode() != SpvOpTypeBool) { |
| 28 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 29 | << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '" |
| 30 | << _.getIdName(inst->type_id()) << "' is not a boolean type." ; |
| 31 | } |
| 32 | |
| 33 | return SPV_SUCCESS; |
| 34 | } |
| 35 | |
| 36 | spv_result_t ValidateConstantComposite(ValidationState_t& _, |
| 37 | const Instruction* inst) { |
| 38 | std::string opcode_name = std::string("Op" ) + spvOpcodeString(inst->opcode()); |
| 39 | |
| 40 | const auto result_type = _.FindDef(inst->type_id()); |
| 41 | if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) { |
| 42 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 43 | << opcode_name << " Result Type <id> '" |
| 44 | << _.getIdName(inst->type_id()) << "' is not a composite type." ; |
| 45 | } |
| 46 | |
| 47 | const auto constituent_count = inst->words().size() - 3; |
| 48 | switch (result_type->opcode()) { |
| 49 | case SpvOpTypeVector: { |
| 50 | const auto component_count = result_type->GetOperandAs<uint32_t>(2); |
| 51 | if (component_count != constituent_count) { |
| 52 | // TODO: Output ID's on diagnostic |
| 53 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 54 | << opcode_name |
| 55 | << " Constituent <id> count does not match " |
| 56 | "Result Type <id> '" |
| 57 | << _.getIdName(result_type->id()) |
| 58 | << "'s vector component count." ; |
| 59 | } |
| 60 | const auto component_type = |
| 61 | _.FindDef(result_type->GetOperandAs<uint32_t>(1)); |
| 62 | if (!component_type) { |
| 63 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
| 64 | << "Component type is not defined." ; |
| 65 | } |
| 66 | for (size_t constituent_index = 2; |
| 67 | constituent_index < inst->operands().size(); constituent_index++) { |
| 68 | const auto constituent_id = |
| 69 | inst->GetOperandAs<uint32_t>(constituent_index); |
| 70 | const auto constituent = _.FindDef(constituent_id); |
| 71 | if (!constituent || |
| 72 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
| 73 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 74 | << opcode_name << " Constituent <id> '" |
| 75 | << _.getIdName(constituent_id) |
| 76 | << "' is not a constant or undef." ; |
| 77 | } |
| 78 | const auto constituent_result_type = _.FindDef(constituent->type_id()); |
| 79 | if (!constituent_result_type || |
| 80 | component_type->opcode() != constituent_result_type->opcode()) { |
| 81 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 82 | << opcode_name << " Constituent <id> '" |
| 83 | << _.getIdName(constituent_id) |
| 84 | << "'s type does not match Result Type <id> '" |
| 85 | << _.getIdName(result_type->id()) << "'s vector element type." ; |
| 86 | } |
| 87 | } |
| 88 | } break; |
| 89 | case SpvOpTypeMatrix: { |
| 90 | const auto column_count = result_type->GetOperandAs<uint32_t>(2); |
| 91 | if (column_count != constituent_count) { |
| 92 | // TODO: Output ID's on diagnostic |
| 93 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 94 | << opcode_name |
| 95 | << " Constituent <id> count does not match " |
| 96 | "Result Type <id> '" |
| 97 | << _.getIdName(result_type->id()) << "'s matrix column count." ; |
| 98 | } |
| 99 | |
| 100 | const auto column_type = _.FindDef(result_type->words()[2]); |
| 101 | if (!column_type) { |
| 102 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
| 103 | << "Column type is not defined." ; |
| 104 | } |
| 105 | const auto component_count = column_type->GetOperandAs<uint32_t>(2); |
| 106 | const auto component_type = |
| 107 | _.FindDef(column_type->GetOperandAs<uint32_t>(1)); |
| 108 | if (!component_type) { |
| 109 | return _.diag(SPV_ERROR_INVALID_ID, column_type) |
| 110 | << "Component type is not defined." ; |
| 111 | } |
| 112 | |
| 113 | for (size_t constituent_index = 2; |
| 114 | constituent_index < inst->operands().size(); constituent_index++) { |
| 115 | const auto constituent_id = |
| 116 | inst->GetOperandAs<uint32_t>(constituent_index); |
| 117 | const auto constituent = _.FindDef(constituent_id); |
| 118 | if (!constituent || |
| 119 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
| 120 | // The message says "... or undef" because the spec does not say |
| 121 | // undef is a constant. |
| 122 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 123 | << opcode_name << " Constituent <id> '" |
| 124 | << _.getIdName(constituent_id) |
| 125 | << "' is not a constant or undef." ; |
| 126 | } |
| 127 | const auto vector = _.FindDef(constituent->type_id()); |
| 128 | if (!vector) { |
| 129 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
| 130 | << "Result type is not defined." ; |
| 131 | } |
| 132 | if (column_type->opcode() != vector->opcode()) { |
| 133 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 134 | << opcode_name << " Constituent <id> '" |
| 135 | << _.getIdName(constituent_id) |
| 136 | << "' type does not match Result Type <id> '" |
| 137 | << _.getIdName(result_type->id()) << "'s matrix column type." ; |
| 138 | } |
| 139 | const auto vector_component_type = |
| 140 | _.FindDef(vector->GetOperandAs<uint32_t>(1)); |
| 141 | if (component_type->id() != vector_component_type->id()) { |
| 142 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 143 | << opcode_name << " Constituent <id> '" |
| 144 | << _.getIdName(constituent_id) |
| 145 | << "' component type does not match Result Type <id> '" |
| 146 | << _.getIdName(result_type->id()) |
| 147 | << "'s matrix column component type." ; |
| 148 | } |
| 149 | if (component_count != vector->words()[3]) { |
| 150 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 151 | << opcode_name << " Constituent <id> '" |
| 152 | << _.getIdName(constituent_id) |
| 153 | << "' vector component count does not match Result Type <id> '" |
| 154 | << _.getIdName(result_type->id()) |
| 155 | << "'s vector component count." ; |
| 156 | } |
| 157 | } |
| 158 | } break; |
| 159 | case SpvOpTypeArray: { |
| 160 | auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1)); |
| 161 | if (!element_type) { |
| 162 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
| 163 | << "Element type is not defined." ; |
| 164 | } |
| 165 | const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2)); |
| 166 | if (!length) { |
| 167 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
| 168 | << "Length is not defined." ; |
| 169 | } |
| 170 | bool is_int32; |
| 171 | bool is_const; |
| 172 | uint32_t value; |
| 173 | std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id()); |
| 174 | if (is_int32 && is_const && value != constituent_count) { |
| 175 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 176 | << opcode_name |
| 177 | << " Constituent count does not match " |
| 178 | "Result Type <id> '" |
| 179 | << _.getIdName(result_type->id()) << "'s array length." ; |
| 180 | } |
| 181 | for (size_t constituent_index = 2; |
| 182 | constituent_index < inst->operands().size(); constituent_index++) { |
| 183 | const auto constituent_id = |
| 184 | inst->GetOperandAs<uint32_t>(constituent_index); |
| 185 | const auto constituent = _.FindDef(constituent_id); |
| 186 | if (!constituent || |
| 187 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
| 188 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 189 | << opcode_name << " Constituent <id> '" |
| 190 | << _.getIdName(constituent_id) |
| 191 | << "' is not a constant or undef." ; |
| 192 | } |
| 193 | const auto constituent_type = _.FindDef(constituent->type_id()); |
| 194 | if (!constituent_type) { |
| 195 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
| 196 | << "Result type is not defined." ; |
| 197 | } |
| 198 | if (element_type->id() != constituent_type->id()) { |
| 199 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 200 | << opcode_name << " Constituent <id> '" |
| 201 | << _.getIdName(constituent_id) |
| 202 | << "'s type does not match Result Type <id> '" |
| 203 | << _.getIdName(result_type->id()) << "'s array element type." ; |
| 204 | } |
| 205 | } |
| 206 | } break; |
| 207 | case SpvOpTypeStruct: { |
| 208 | const auto member_count = result_type->words().size() - 2; |
| 209 | if (member_count != constituent_count) { |
| 210 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 211 | << opcode_name << " Constituent <id> '" |
| 212 | << _.getIdName(inst->type_id()) |
| 213 | << "' count does not match Result Type <id> '" |
| 214 | << _.getIdName(result_type->id()) << "'s struct member count." ; |
| 215 | } |
| 216 | for (uint32_t constituent_index = 2, member_index = 1; |
| 217 | constituent_index < inst->operands().size(); |
| 218 | constituent_index++, member_index++) { |
| 219 | const auto constituent_id = |
| 220 | inst->GetOperandAs<uint32_t>(constituent_index); |
| 221 | const auto constituent = _.FindDef(constituent_id); |
| 222 | if (!constituent || |
| 223 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
| 224 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 225 | << opcode_name << " Constituent <id> '" |
| 226 | << _.getIdName(constituent_id) |
| 227 | << "' is not a constant or undef." ; |
| 228 | } |
| 229 | const auto constituent_type = _.FindDef(constituent->type_id()); |
| 230 | if (!constituent_type) { |
| 231 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
| 232 | << "Result type is not defined." ; |
| 233 | } |
| 234 | |
| 235 | const auto member_type_id = |
| 236 | result_type->GetOperandAs<uint32_t>(member_index); |
| 237 | const auto member_type = _.FindDef(member_type_id); |
| 238 | if (!member_type || member_type->id() != constituent_type->id()) { |
| 239 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 240 | << opcode_name << " Constituent <id> '" |
| 241 | << _.getIdName(constituent_id) |
| 242 | << "' type does not match the Result Type <id> '" |
| 243 | << _.getIdName(result_type->id()) << "'s member type." ; |
| 244 | } |
| 245 | } |
| 246 | } break; |
| 247 | case SpvOpTypeCooperativeMatrixNV: { |
| 248 | if (1 != constituent_count) { |
| 249 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 250 | << opcode_name << " Constituent <id> '" |
| 251 | << _.getIdName(inst->type_id()) << "' count must be one." ; |
| 252 | } |
| 253 | const auto constituent_id = inst->GetOperandAs<uint32_t>(2); |
| 254 | const auto constituent = _.FindDef(constituent_id); |
| 255 | if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
| 256 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 257 | << opcode_name << " Constituent <id> '" |
| 258 | << _.getIdName(constituent_id) |
| 259 | << "' is not a constant or undef." ; |
| 260 | } |
| 261 | const auto constituent_type = _.FindDef(constituent->type_id()); |
| 262 | if (!constituent_type) { |
| 263 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
| 264 | << "Result type is not defined." ; |
| 265 | } |
| 266 | |
| 267 | const auto component_type_id = result_type->GetOperandAs<uint32_t>(1); |
| 268 | const auto component_type = _.FindDef(component_type_id); |
| 269 | if (!component_type || component_type->id() != constituent_type->id()) { |
| 270 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 271 | << opcode_name << " Constituent <id> '" |
| 272 | << _.getIdName(constituent_id) |
| 273 | << "' type does not match the Result Type <id> '" |
| 274 | << _.getIdName(result_type->id()) << "'s component type." ; |
| 275 | } |
| 276 | } break; |
| 277 | default: |
| 278 | break; |
| 279 | } |
| 280 | return SPV_SUCCESS; |
| 281 | } |
| 282 | |
| 283 | spv_result_t ValidateConstantSampler(ValidationState_t& _, |
| 284 | const Instruction* inst) { |
| 285 | const auto result_type = _.FindDef(inst->type_id()); |
| 286 | if (!result_type || result_type->opcode() != SpvOpTypeSampler) { |
| 287 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
| 288 | << "OpConstantSampler Result Type <id> '" |
| 289 | << _.getIdName(inst->type_id()) << "' is not a sampler type." ; |
| 290 | } |
| 291 | |
| 292 | return SPV_SUCCESS; |
| 293 | } |
| 294 | |
| 295 | // True if instruction defines a type that can have a null value, as defined by |
| 296 | // the SPIR-V spec. Tracks composite-type components through module to check |
| 297 | // nullability transitively. |
| 298 | bool IsTypeNullable(const std::vector<uint32_t>& instruction, |
| 299 | const ValidationState_t& _) { |
| 300 | uint16_t opcode; |
| 301 | uint16_t word_count; |
| 302 | spvOpcodeSplit(instruction[0], &word_count, &opcode); |
| 303 | switch (static_cast<SpvOp>(opcode)) { |
| 304 | case SpvOpTypeBool: |
| 305 | case SpvOpTypeInt: |
| 306 | case SpvOpTypeFloat: |
| 307 | case SpvOpTypeEvent: |
| 308 | case SpvOpTypeDeviceEvent: |
| 309 | case SpvOpTypeReserveId: |
| 310 | case SpvOpTypeQueue: |
| 311 | return true; |
| 312 | case SpvOpTypeArray: |
| 313 | case SpvOpTypeMatrix: |
| 314 | case SpvOpTypeCooperativeMatrixNV: |
| 315 | case SpvOpTypeVector: { |
| 316 | auto base_type = _.FindDef(instruction[2]); |
| 317 | return base_type && IsTypeNullable(base_type->words(), _); |
| 318 | } |
| 319 | case SpvOpTypeStruct: { |
| 320 | for (size_t elementIndex = 2; elementIndex < instruction.size(); |
| 321 | ++elementIndex) { |
| 322 | auto element = _.FindDef(instruction[elementIndex]); |
| 323 | if (!element || !IsTypeNullable(element->words(), _)) return false; |
| 324 | } |
| 325 | return true; |
| 326 | } |
| 327 | case SpvOpTypePointer: |
| 328 | if (instruction[2] == SpvStorageClassPhysicalStorageBuffer) { |
| 329 | return false; |
| 330 | } |
| 331 | return true; |
| 332 | default: |
| 333 | return false; |
| 334 | } |
| 335 | } |
| 336 | |
| 337 | spv_result_t ValidateConstantNull(ValidationState_t& _, |
| 338 | const Instruction* inst) { |
| 339 | const auto result_type = _.FindDef(inst->type_id()); |
| 340 | if (!result_type || !IsTypeNullable(result_type->words(), _)) { |
| 341 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 342 | << "OpConstantNull Result Type <id> '" |
| 343 | << _.getIdName(inst->type_id()) << "' cannot have a null value." ; |
| 344 | } |
| 345 | |
| 346 | return SPV_SUCCESS; |
| 347 | } |
| 348 | |
| 349 | // Validates that OpSpecConstant specializes to either int or float type. |
| 350 | spv_result_t ValidateSpecConstant(ValidationState_t& _, |
| 351 | const Instruction* inst) { |
| 352 | // Operand 0 is the <id> of the type that we're specializing to. |
| 353 | auto type_id = inst->GetOperandAs<const uint32_t>(0); |
| 354 | auto type_instruction = _.FindDef(type_id); |
| 355 | auto type_opcode = type_instruction->opcode(); |
| 356 | if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) { |
| 357 | return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant " |
| 358 | "must be an integer or " |
| 359 | "floating-point number." ; |
| 360 | } |
| 361 | return SPV_SUCCESS; |
| 362 | } |
| 363 | |
| 364 | spv_result_t ValidateSpecConstantOp(ValidationState_t& _, |
| 365 | const Instruction* inst) { |
| 366 | const auto op = inst->GetOperandAs<SpvOp>(2); |
| 367 | |
| 368 | // The binary parser already ensures that the op is valid for *some* |
| 369 | // environment. Here we check restrictions. |
| 370 | switch (op) { |
| 371 | case SpvOpQuantizeToF16: |
| 372 | if (!_.HasCapability(SpvCapabilityShader)) { |
| 373 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 374 | << "Specialization constant operation " << spvOpcodeString(op) |
| 375 | << " requires Shader capability" ; |
| 376 | } |
| 377 | break; |
| 378 | |
| 379 | case SpvOpUConvert: |
| 380 | if (!_.features().uconvert_spec_constant_op && |
| 381 | !_.HasCapability(SpvCapabilityKernel)) { |
| 382 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 383 | << "Prior to SPIR-V 1.4, specialization constant operation " |
| 384 | "UConvert requires Kernel capability or extension " |
| 385 | "SPV_AMD_gpu_shader_int16" ; |
| 386 | } |
| 387 | break; |
| 388 | |
| 389 | case SpvOpConvertFToS: |
| 390 | case SpvOpConvertSToF: |
| 391 | case SpvOpConvertFToU: |
| 392 | case SpvOpConvertUToF: |
| 393 | case SpvOpConvertPtrToU: |
| 394 | case SpvOpConvertUToPtr: |
| 395 | case SpvOpGenericCastToPtr: |
| 396 | case SpvOpPtrCastToGeneric: |
| 397 | case SpvOpBitcast: |
| 398 | case SpvOpFNegate: |
| 399 | case SpvOpFAdd: |
| 400 | case SpvOpFSub: |
| 401 | case SpvOpFMul: |
| 402 | case SpvOpFDiv: |
| 403 | case SpvOpFRem: |
| 404 | case SpvOpFMod: |
| 405 | case SpvOpAccessChain: |
| 406 | case SpvOpInBoundsAccessChain: |
| 407 | case SpvOpPtrAccessChain: |
| 408 | case SpvOpInBoundsPtrAccessChain: |
| 409 | if (!_.HasCapability(SpvCapabilityKernel)) { |
| 410 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 411 | << "Specialization constant operation " << spvOpcodeString(op) |
| 412 | << " requires Kernel capability" ; |
| 413 | } |
| 414 | break; |
| 415 | |
| 416 | default: |
| 417 | break; |
| 418 | } |
| 419 | |
| 420 | // TODO(dneto): Validate result type and arguments to the various operations. |
| 421 | return SPV_SUCCESS; |
| 422 | } |
| 423 | |
| 424 | } // namespace |
| 425 | |
| 426 | spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) { |
| 427 | switch (inst->opcode()) { |
| 428 | case SpvOpConstantTrue: |
| 429 | case SpvOpConstantFalse: |
| 430 | case SpvOpSpecConstantTrue: |
| 431 | case SpvOpSpecConstantFalse: |
| 432 | if (auto error = ValidateConstantBool(_, inst)) return error; |
| 433 | break; |
| 434 | case SpvOpConstantComposite: |
| 435 | case SpvOpSpecConstantComposite: |
| 436 | if (auto error = ValidateConstantComposite(_, inst)) return error; |
| 437 | break; |
| 438 | case SpvOpConstantSampler: |
| 439 | if (auto error = ValidateConstantSampler(_, inst)) return error; |
| 440 | break; |
| 441 | case SpvOpConstantNull: |
| 442 | if (auto error = ValidateConstantNull(_, inst)) return error; |
| 443 | break; |
| 444 | case SpvOpSpecConstant: |
| 445 | if (auto error = ValidateSpecConstant(_, inst)) return error; |
| 446 | break; |
| 447 | case SpvOpSpecConstantOp: |
| 448 | if (auto error = ValidateSpecConstantOp(_, inst)) return error; |
| 449 | break; |
| 450 | default: |
| 451 | break; |
| 452 | } |
| 453 | |
| 454 | // Generally disallow creating 8- or 16-bit constants unless the full |
| 455 | // capabilities are present. |
| 456 | if (spvOpcodeIsConstant(inst->opcode()) && |
| 457 | _.HasCapability(SpvCapabilityShader) && |
| 458 | !_.IsPointerType(inst->type_id()) && |
| 459 | _.ContainsLimitedUseIntOrFloatType(inst->type_id())) { |
| 460 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 461 | << "Cannot form constants of 8- or 16-bit types" ; |
| 462 | } |
| 463 | |
| 464 | return SPV_SUCCESS; |
| 465 | } |
| 466 | |
| 467 | } // namespace val |
| 468 | } // namespace spvtools |
| 469 | |