1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/Support/LogicalResult.h" 23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/Sequence.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/bit.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/SaveAndRestore.h" 31 #include "llvm/Support/raw_ostream.h" 32 33 using namespace mlir; 34 35 #define DEBUG_TYPE "spirv-deserialization" 36 37 //===----------------------------------------------------------------------===// 38 // Utility Functions 39 //===----------------------------------------------------------------------===// 40 41 /// Returns true if the given `block` is a function entry block. 42 static inline bool isFnEntryBlock(Block *block) { 43 return block->isEntryBlock() && 44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp()); 45 } 46 47 //===----------------------------------------------------------------------===// 48 // Deserializer Method Definitions 49 //===----------------------------------------------------------------------===// 50 51 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary, 52 MLIRContext *context) 53 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), 54 module(createModuleOp()), opBuilder(module->getRegion()) 55 #ifndef NDEBUG 56 , 57 logger(llvm::dbgs()) 58 #endif 59 { 60 } 61 62 LogicalResult spirv::Deserializer::deserialize() { 63 LLVM_DEBUG({ 64 logger.resetIndent(); 65 logger.startLine() 66 << "//+++---------- start deserialization ----------+++//\n"; 67 }); 68 69 if (failed(processHeader())) 70 return failure(); 71 72 spirv::Opcode opcode = spirv::Opcode::OpNop; 73 ArrayRef<uint32_t> operands; 74 auto binarySize = binary.size(); 75 while (curOffset < binarySize) { 76 // Slice the next instruction out and populate `opcode` and `operands`. 77 // Internally this also updates `curOffset`. 78 if (failed(sliceInstruction(opcode, operands))) 79 return failure(); 80 81 if (failed(processInstruction(opcode, operands))) 82 return failure(); 83 } 84 85 assert(curOffset == binarySize && 86 "deserializer should never index beyond the binary end"); 87 88 for (auto &deferred : deferredInstructions) { 89 if (failed(processInstruction(deferred.first, deferred.second, false))) { 90 return failure(); 91 } 92 } 93 94 attachVCETriple(); 95 96 LLVM_DEBUG(logger.startLine() 97 << "//+++-------- completed deserialization --------+++//\n"); 98 return success(); 99 } 100 101 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() { 102 return std::move(module); 103 } 104 105 //===----------------------------------------------------------------------===// 106 // Module structure 107 //===----------------------------------------------------------------------===// 108 109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() { 110 OpBuilder builder(context); 111 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); 112 spirv::ModuleOp::build(builder, state); 113 return cast<spirv::ModuleOp>(Operation::create(state)); 114 } 115 116 LogicalResult spirv::Deserializer::processHeader() { 117 if (binary.size() < spirv::kHeaderWordCount) 118 return emitError(unknownLoc, 119 "SPIR-V binary module must have a 5-word header"); 120 121 if (binary[0] != spirv::kMagicNumber) 122 return emitError(unknownLoc, "incorrect magic number"); 123 124 // Version number bytes: 0 | major number | minor number | 0 125 uint32_t majorVersion = (binary[1] << 8) >> 24; 126 uint32_t minorVersion = (binary[1] << 16) >> 24; 127 if (majorVersion == 1) { 128 switch (minorVersion) { 129 #define MIN_VERSION_CASE(v) \ 130 case v: \ 131 version = spirv::Version::V_1_##v; \ 132 break 133 134 MIN_VERSION_CASE(0); 135 MIN_VERSION_CASE(1); 136 MIN_VERSION_CASE(2); 137 MIN_VERSION_CASE(3); 138 MIN_VERSION_CASE(4); 139 MIN_VERSION_CASE(5); 140 #undef MIN_VERSION_CASE 141 default: 142 return emitError(unknownLoc, "unsupported SPIR-V minor version: ") 143 << minorVersion; 144 } 145 } else { 146 return emitError(unknownLoc, "unsupported SPIR-V major version: ") 147 << majorVersion; 148 } 149 150 // TODO: generator number, bound, schema 151 curOffset = spirv::kHeaderWordCount; 152 return success(); 153 } 154 155 LogicalResult 156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) { 157 if (operands.size() != 1) 158 return emitError(unknownLoc, "OpMemoryModel must have one parameter"); 159 160 auto cap = spirv::symbolizeCapability(operands[0]); 161 if (!cap) 162 return emitError(unknownLoc, "unknown capability: ") << operands[0]; 163 164 capabilities.insert(*cap); 165 return success(); 166 } 167 168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) { 169 if (words.empty()) { 170 return emitError( 171 unknownLoc, 172 "OpExtension must have a literal string for the extension name"); 173 } 174 175 unsigned wordIndex = 0; 176 StringRef extName = decodeStringLiteral(words, wordIndex); 177 if (wordIndex != words.size()) 178 return emitError(unknownLoc, 179 "unexpected trailing words in OpExtension instruction"); 180 auto ext = spirv::symbolizeExtension(extName); 181 if (!ext) 182 return emitError(unknownLoc, "unknown extension: ") << extName; 183 184 extensions.insert(*ext); 185 return success(); 186 } 187 188 LogicalResult 189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) { 190 if (words.size() < 2) { 191 return emitError(unknownLoc, 192 "OpExtInstImport must have a result <id> and a literal " 193 "string for the extended instruction set name"); 194 } 195 196 unsigned wordIndex = 1; 197 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); 198 if (wordIndex != words.size()) { 199 return emitError(unknownLoc, 200 "unexpected trailing words in OpExtInstImport"); 201 } 202 return success(); 203 } 204 205 void spirv::Deserializer::attachVCETriple() { 206 (*module)->setAttr( 207 spirv::ModuleOp::getVCETripleAttrName(), 208 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), 209 extensions.getArrayRef(), context)); 210 } 211 212 LogicalResult 213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { 214 if (operands.size() != 2) 215 return emitError(unknownLoc, "OpMemoryModel must have two operands"); 216 217 (*module)->setAttr( 218 "addressing_model", 219 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front()))); 220 (*module)->setAttr( 221 "memory_model", 222 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back()))); 223 224 return success(); 225 } 226 227 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { 228 // TODO: This function should also be auto-generated. For now, since only a 229 // few decorations are processed/handled in a meaningful manner, going with a 230 // manual implementation. 231 if (words.size() < 2) { 232 return emitError( 233 unknownLoc, "OpDecorate must have at least result <id> and Decoration"); 234 } 235 auto decorationName = 236 stringifyDecoration(static_cast<spirv::Decoration>(words[1])); 237 if (decorationName.empty()) { 238 return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; 239 } 240 auto attrName = llvm::convertToSnakeFromCamelCase(decorationName); 241 auto symbol = opBuilder.getStringAttr(attrName); 242 switch (static_cast<spirv::Decoration>(words[1])) { 243 case spirv::Decoration::DescriptorSet: 244 case spirv::Decoration::Binding: 245 if (words.size() != 3) { 246 return emitError(unknownLoc, "OpDecorate with ") 247 << decorationName << " needs a single integer literal"; 248 } 249 decorations[words[0]].set( 250 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 251 break; 252 case spirv::Decoration::BuiltIn: 253 if (words.size() != 3) { 254 return emitError(unknownLoc, "OpDecorate with ") 255 << decorationName << " needs a single integer literal"; 256 } 257 decorations[words[0]].set( 258 symbol, opBuilder.getStringAttr( 259 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2])))); 260 break; 261 case spirv::Decoration::ArrayStride: 262 if (words.size() != 3) { 263 return emitError(unknownLoc, "OpDecorate with ") 264 << decorationName << " needs a single integer literal"; 265 } 266 typeDecorations[words[0]] = words[2]; 267 break; 268 case spirv::Decoration::Aliased: 269 case spirv::Decoration::Block: 270 case spirv::Decoration::BufferBlock: 271 case spirv::Decoration::Flat: 272 case spirv::Decoration::NonReadable: 273 case spirv::Decoration::NonWritable: 274 case spirv::Decoration::NoPerspective: 275 case spirv::Decoration::Restrict: 276 case spirv::Decoration::RelaxedPrecision: 277 if (words.size() != 2) { 278 return emitError(unknownLoc, "OpDecoration with ") 279 << decorationName << "needs a single target <id>"; 280 } 281 // Block decoration does not affect spv.struct type, but is still stored for 282 // verification. 283 // TODO: Update StructType to contain this information since 284 // it is needed for many validation rules. 285 decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); 286 break; 287 case spirv::Decoration::Location: 288 case spirv::Decoration::SpecId: 289 if (words.size() != 3) { 290 return emitError(unknownLoc, "OpDecoration with ") 291 << decorationName << "needs a single integer literal"; 292 } 293 decorations[words[0]].set( 294 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 295 break; 296 default: 297 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; 298 } 299 return success(); 300 } 301 302 LogicalResult 303 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { 304 // The binary layout of OpMemberDecorate is different comparing to OpDecorate 305 if (words.size() < 3) { 306 return emitError(unknownLoc, 307 "OpMemberDecorate must have at least 3 operands"); 308 } 309 310 auto decoration = static_cast<spirv::Decoration>(words[2]); 311 if (decoration == spirv::Decoration::Offset && words.size() != 4) { 312 return emitError(unknownLoc, 313 " missing offset specification in OpMemberDecorate with " 314 "Offset decoration"); 315 } 316 ArrayRef<uint32_t> decorationOperands; 317 if (words.size() > 3) { 318 decorationOperands = words.slice(3); 319 } 320 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; 321 return success(); 322 } 323 324 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) { 325 if (words.size() < 3) { 326 return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); 327 } 328 unsigned wordIndex = 2; 329 auto name = decodeStringLiteral(words, wordIndex); 330 if (wordIndex != words.size()) { 331 return emitError(unknownLoc, 332 "unexpected trailing words in OpMemberName instruction"); 333 } 334 memberNameMap[words[0]][words[1]] = name; 335 return success(); 336 } 337 338 LogicalResult 339 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) { 340 if (curFunction) { 341 return emitError(unknownLoc, "found function inside function"); 342 } 343 344 // Get the result type 345 if (operands.size() != 4) { 346 return emitError(unknownLoc, "OpFunction must have 4 parameters"); 347 } 348 Type resultType = getType(operands[0]); 349 if (!resultType) { 350 return emitError(unknownLoc, "undefined result type from <id> ") 351 << operands[0]; 352 } 353 354 uint32_t fnID = operands[1]; 355 if (funcMap.count(fnID)) { 356 return emitError(unknownLoc, "duplicate function definition/declaration"); 357 } 358 359 auto fnControl = spirv::symbolizeFunctionControl(operands[2]); 360 if (!fnControl) { 361 return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; 362 } 363 364 Type fnType = getType(operands[3]); 365 if (!fnType || !fnType.isa<FunctionType>()) { 366 return emitError(unknownLoc, "unknown function type from <id> ") 367 << operands[3]; 368 } 369 auto functionType = fnType.cast<FunctionType>(); 370 371 if ((isVoidType(resultType) && functionType.getNumResults() != 0) || 372 (functionType.getNumResults() == 1 && 373 functionType.getResult(0) != resultType)) { 374 return emitError(unknownLoc, "mismatch in function type ") 375 << functionType << " and return type " << resultType << " specified"; 376 } 377 378 std::string fnName = getFunctionSymbol(fnID); 379 auto funcOp = opBuilder.create<spirv::FuncOp>( 380 unknownLoc, fnName, functionType, fnControl.getValue()); 381 curFunction = funcMap[fnID] = funcOp; 382 auto *entryBlock = funcOp.addEntryBlock(); 383 LLVM_DEBUG({ 384 logger.startLine() 385 << "//===-------------------------------------------===//\n"; 386 logger.startLine() << "[fn] name: " << fnName << "\n"; 387 logger.startLine() << "[fn] type: " << fnType << "\n"; 388 logger.startLine() << "[fn] ID: " << fnID << "\n"; 389 logger.startLine() << "[fn] entry block: " << entryBlock << "\n"; 390 logger.indent(); 391 }); 392 393 // Parse the op argument instructions 394 if (functionType.getNumInputs()) { 395 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { 396 auto argType = functionType.getInput(i); 397 spirv::Opcode opcode = spirv::Opcode::OpNop; 398 ArrayRef<uint32_t> operands; 399 if (failed(sliceInstruction(opcode, operands, 400 spirv::Opcode::OpFunctionParameter))) { 401 return failure(); 402 } 403 if (opcode != spirv::Opcode::OpFunctionParameter) { 404 return emitError( 405 unknownLoc, 406 "missing OpFunctionParameter instruction for argument ") 407 << i; 408 } 409 if (operands.size() != 2) { 410 return emitError( 411 unknownLoc, 412 "expected result type and result <id> for OpFunctionParameter"); 413 } 414 auto argDefinedType = getType(operands[0]); 415 if (!argDefinedType || argDefinedType != argType) { 416 return emitError(unknownLoc, 417 "mismatch in argument type between function type " 418 "definition ") 419 << functionType << " and argument type definition " 420 << argDefinedType << " at argument " << i; 421 } 422 if (getValue(operands[1])) { 423 return emitError(unknownLoc, "duplicate definition of result <id> ") 424 << operands[1]; 425 } 426 auto argValue = funcOp.getArgument(i); 427 valueMap[operands[1]] = argValue; 428 } 429 } 430 431 // RAII guard to reset the insertion point to the module's region after 432 // deserializing the body of this function. 433 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 434 435 spirv::Opcode opcode = spirv::Opcode::OpNop; 436 ArrayRef<uint32_t> instOperands; 437 438 // Special handling for the entry block. We need to make sure it starts with 439 // an OpLabel instruction. The entry block takes the same parameters as the 440 // function. All other blocks do not take any parameter. We have already 441 // created the entry block, here we need to register it to the correct label 442 // <id>. 443 if (failed(sliceInstruction(opcode, instOperands, 444 spirv::Opcode::OpFunctionEnd))) { 445 return failure(); 446 } 447 if (opcode == spirv::Opcode::OpFunctionEnd) { 448 return processFunctionEnd(instOperands); 449 } 450 if (opcode != spirv::Opcode::OpLabel) { 451 return emitError(unknownLoc, "a basic block must start with OpLabel"); 452 } 453 if (instOperands.size() != 1) { 454 return emitError(unknownLoc, "OpLabel should only have result <id>"); 455 } 456 blockMap[instOperands[0]] = entryBlock; 457 if (failed(processLabel(instOperands))) { 458 return failure(); 459 } 460 461 // Then process all the other instructions in the function until we hit 462 // OpFunctionEnd. 463 while (succeeded(sliceInstruction(opcode, instOperands, 464 spirv::Opcode::OpFunctionEnd)) && 465 opcode != spirv::Opcode::OpFunctionEnd) { 466 if (failed(processInstruction(opcode, instOperands))) { 467 return failure(); 468 } 469 } 470 if (opcode != spirv::Opcode::OpFunctionEnd) { 471 return failure(); 472 } 473 474 return processFunctionEnd(instOperands); 475 } 476 477 LogicalResult 478 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) { 479 // Process OpFunctionEnd. 480 if (!operands.empty()) { 481 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); 482 } 483 484 // Wire up block arguments from OpPhi instructions. 485 // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops. 486 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { 487 return failure(); 488 } 489 490 curBlock = nullptr; 491 curFunction = llvm::None; 492 493 LLVM_DEBUG({ 494 logger.unindent(); 495 logger.startLine() 496 << "//===-------------------------------------------===//\n"; 497 }); 498 return success(); 499 } 500 501 Optional<std::pair<Attribute, Type>> 502 spirv::Deserializer::getConstant(uint32_t id) { 503 auto constIt = constantMap.find(id); 504 if (constIt == constantMap.end()) 505 return llvm::None; 506 return constIt->getSecond(); 507 } 508 509 Optional<spirv::SpecConstOperationMaterializationInfo> 510 spirv::Deserializer::getSpecConstantOperation(uint32_t id) { 511 auto constIt = specConstOperationMap.find(id); 512 if (constIt == specConstOperationMap.end()) 513 return llvm::None; 514 return constIt->getSecond(); 515 } 516 517 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { 518 auto funcName = nameMap.lookup(id).str(); 519 if (funcName.empty()) { 520 funcName = "spirv_fn_" + std::to_string(id); 521 } 522 return funcName; 523 } 524 525 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { 526 auto constName = nameMap.lookup(id).str(); 527 if (constName.empty()) { 528 constName = "spirv_spec_const_" + std::to_string(id); 529 } 530 return constName; 531 } 532 533 spirv::SpecConstantOp 534 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, 535 Attribute defaultValue) { 536 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 537 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, 538 defaultValue); 539 if (decorations.count(resultID)) { 540 for (auto attr : decorations[resultID].getAttrs()) 541 op->setAttr(attr.getName(), attr.getValue()); 542 } 543 specConstMap[resultID] = op; 544 return op; 545 } 546 547 LogicalResult 548 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { 549 unsigned wordIndex = 0; 550 if (operands.size() < 3) { 551 return emitError( 552 unknownLoc, 553 "OpVariable needs at least 3 operands, type, <id> and storage class"); 554 } 555 556 // Result Type. 557 auto type = getType(operands[wordIndex]); 558 if (!type) { 559 return emitError(unknownLoc, "unknown result type <id> : ") 560 << operands[wordIndex]; 561 } 562 auto ptrType = type.dyn_cast<spirv::PointerType>(); 563 if (!ptrType) { 564 return emitError(unknownLoc, 565 "expected a result type <id> to be a spv.ptr, found : ") 566 << type; 567 } 568 wordIndex++; 569 570 // Result <id>. 571 auto variableID = operands[wordIndex]; 572 auto variableName = nameMap.lookup(variableID).str(); 573 if (variableName.empty()) { 574 variableName = "spirv_var_" + std::to_string(variableID); 575 } 576 wordIndex++; 577 578 // Storage class. 579 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); 580 if (ptrType.getStorageClass() != storageClass) { 581 return emitError(unknownLoc, "mismatch in storage class of pointer type ") 582 << type << " and that specified in OpVariable instruction : " 583 << stringifyStorageClass(storageClass); 584 } 585 wordIndex++; 586 587 // Initializer. 588 FlatSymbolRefAttr initializer = nullptr; 589 if (wordIndex < operands.size()) { 590 auto initializerOp = getGlobalVariable(operands[wordIndex]); 591 if (!initializerOp) { 592 return emitError(unknownLoc, "unknown <id> ") 593 << operands[wordIndex] << "used as initializer"; 594 } 595 wordIndex++; 596 initializer = SymbolRefAttr::get(initializerOp.getOperation()); 597 } 598 if (wordIndex != operands.size()) { 599 return emitError(unknownLoc, 600 "found more operands than expected when deserializing " 601 "OpVariable instruction, only ") 602 << wordIndex << " of " << operands.size() << " processed"; 603 } 604 auto loc = createFileLineColLoc(opBuilder); 605 auto varOp = opBuilder.create<spirv::GlobalVariableOp>( 606 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), 607 initializer); 608 609 // Decorations. 610 if (decorations.count(variableID)) { 611 for (auto attr : decorations[variableID].getAttrs()) 612 varOp->setAttr(attr.getName(), attr.getValue()); 613 } 614 globalVariableMap[variableID] = varOp; 615 return success(); 616 } 617 618 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) { 619 auto constInfo = getConstant(id); 620 if (!constInfo) { 621 return nullptr; 622 } 623 return constInfo->first.dyn_cast<IntegerAttr>(); 624 } 625 626 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) { 627 if (operands.size() < 2) { 628 return emitError(unknownLoc, "OpName needs at least 2 operands"); 629 } 630 if (!nameMap.lookup(operands[0]).empty()) { 631 return emitError(unknownLoc, "duplicate name found for result <id> ") 632 << operands[0]; 633 } 634 unsigned wordIndex = 1; 635 StringRef name = decodeStringLiteral(operands, wordIndex); 636 if (wordIndex != operands.size()) { 637 return emitError(unknownLoc, 638 "unexpected trailing words in OpName instruction"); 639 } 640 nameMap[operands[0]] = name; 641 return success(); 642 } 643 644 //===----------------------------------------------------------------------===// 645 // Type 646 //===----------------------------------------------------------------------===// 647 648 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, 649 ArrayRef<uint32_t> operands) { 650 if (operands.empty()) { 651 return emitError(unknownLoc, "type instruction with opcode ") 652 << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; 653 } 654 655 /// TODO: Types might be forward declared in some instructions and need to be 656 /// handled appropriately. 657 if (typeMap.count(operands[0])) { 658 return emitError(unknownLoc, "duplicate definition for result <id> ") 659 << operands[0]; 660 } 661 662 switch (opcode) { 663 case spirv::Opcode::OpTypeVoid: 664 if (operands.size() != 1) 665 return emitError(unknownLoc, "OpTypeVoid must have no parameters"); 666 typeMap[operands[0]] = opBuilder.getNoneType(); 667 break; 668 case spirv::Opcode::OpTypeBool: 669 if (operands.size() != 1) 670 return emitError(unknownLoc, "OpTypeBool must have no parameters"); 671 typeMap[operands[0]] = opBuilder.getI1Type(); 672 break; 673 case spirv::Opcode::OpTypeInt: { 674 if (operands.size() != 3) 675 return emitError( 676 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); 677 678 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 679 // to preserve or validate. 680 // 0 indicates unsigned, or no signedness semantics 681 // 1 indicates signed semantics." 682 // 683 // So we cannot differentiate signless and unsigned integers; always use 684 // signless semantics for such cases. 685 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed 686 : IntegerType::SignednessSemantics::Signless; 687 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); 688 } break; 689 case spirv::Opcode::OpTypeFloat: { 690 if (operands.size() != 2) 691 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); 692 693 Type floatTy; 694 switch (operands[1]) { 695 case 16: 696 floatTy = opBuilder.getF16Type(); 697 break; 698 case 32: 699 floatTy = opBuilder.getF32Type(); 700 break; 701 case 64: 702 floatTy = opBuilder.getF64Type(); 703 break; 704 default: 705 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ") 706 << operands[1]; 707 } 708 typeMap[operands[0]] = floatTy; 709 } break; 710 case spirv::Opcode::OpTypeVector: { 711 if (operands.size() != 3) { 712 return emitError( 713 unknownLoc, 714 "OpTypeVector must have element type and count parameters"); 715 } 716 Type elementTy = getType(operands[1]); 717 if (!elementTy) { 718 return emitError(unknownLoc, "OpTypeVector references undefined <id> ") 719 << operands[1]; 720 } 721 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); 722 } break; 723 case spirv::Opcode::OpTypePointer: { 724 return processOpTypePointer(operands); 725 } break; 726 case spirv::Opcode::OpTypeArray: 727 return processArrayType(operands); 728 case spirv::Opcode::OpTypeCooperativeMatrixNV: 729 return processCooperativeMatrixType(operands); 730 case spirv::Opcode::OpTypeFunction: 731 return processFunctionType(operands); 732 case spirv::Opcode::OpTypeImage: 733 return processImageType(operands); 734 case spirv::Opcode::OpTypeSampledImage: 735 return processSampledImageType(operands); 736 case spirv::Opcode::OpTypeRuntimeArray: 737 return processRuntimeArrayType(operands); 738 case spirv::Opcode::OpTypeStruct: 739 return processStructType(operands); 740 case spirv::Opcode::OpTypeMatrix: 741 return processMatrixType(operands); 742 default: 743 return emitError(unknownLoc, "unhandled type instruction"); 744 } 745 return success(); 746 } 747 748 LogicalResult 749 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { 750 if (operands.size() != 3) 751 return emitError(unknownLoc, "OpTypePointer must have two parameters"); 752 753 auto pointeeType = getType(operands[2]); 754 if (!pointeeType) 755 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ") 756 << operands[2]; 757 758 uint32_t typePointerID = operands[0]; 759 auto storageClass = static_cast<spirv::StorageClass>(operands[1]); 760 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); 761 762 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos); 763 deferredStructIt != std::end(deferredStructTypesInfos);) { 764 for (auto *unresolvedMemberIt = 765 std::begin(deferredStructIt->unresolvedMemberTypes); 766 unresolvedMemberIt != 767 std::end(deferredStructIt->unresolvedMemberTypes);) { 768 if (unresolvedMemberIt->first == typePointerID) { 769 // The newly constructed pointer type can resolve one of the 770 // deferred struct type members; update the memberTypes list and 771 // clean the unresolvedMemberTypes list accordingly. 772 deferredStructIt->memberTypes[unresolvedMemberIt->second] = 773 typeMap[typePointerID]; 774 unresolvedMemberIt = 775 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt); 776 } else { 777 ++unresolvedMemberIt; 778 } 779 } 780 781 if (deferredStructIt->unresolvedMemberTypes.empty()) { 782 // All deferred struct type members are now resolved, set the struct body. 783 auto structType = deferredStructIt->deferredStructType; 784 785 assert(structType && "expected a spirv::StructType"); 786 assert(structType.isIdentified() && "expected an indentified struct"); 787 788 if (failed(structType.trySetBody( 789 deferredStructIt->memberTypes, deferredStructIt->offsetInfo, 790 deferredStructIt->memberDecorationsInfo))) 791 return failure(); 792 793 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); 794 } else { 795 ++deferredStructIt; 796 } 797 } 798 799 return success(); 800 } 801 802 LogicalResult 803 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) { 804 if (operands.size() != 3) { 805 return emitError(unknownLoc, 806 "OpTypeArray must have element type and count parameters"); 807 } 808 809 Type elementTy = getType(operands[1]); 810 if (!elementTy) { 811 return emitError(unknownLoc, "OpTypeArray references undefined <id> ") 812 << operands[1]; 813 } 814 815 unsigned count = 0; 816 // TODO: The count can also come frome a specialization constant. 817 auto countInfo = getConstant(operands[2]); 818 if (!countInfo) { 819 return emitError(unknownLoc, "OpTypeArray count <id> ") 820 << operands[2] << "can only come from normal constant right now"; 821 } 822 823 if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) { 824 count = intVal.getValue().getZExtValue(); 825 } else { 826 return emitError(unknownLoc, "OpTypeArray count must come from a " 827 "scalar integer constant instruction"); 828 } 829 830 typeMap[operands[0]] = spirv::ArrayType::get( 831 elementTy, count, typeDecorations.lookup(operands[0])); 832 return success(); 833 } 834 835 LogicalResult 836 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { 837 assert(!operands.empty() && "No operands for processing function type"); 838 if (operands.size() == 1) { 839 return emitError(unknownLoc, "missing return type for OpTypeFunction"); 840 } 841 auto returnType = getType(operands[1]); 842 if (!returnType) { 843 return emitError(unknownLoc, "unknown return type in OpTypeFunction"); 844 } 845 SmallVector<Type, 1> argTypes; 846 for (size_t i = 2, e = operands.size(); i < e; ++i) { 847 auto ty = getType(operands[i]); 848 if (!ty) { 849 return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); 850 } 851 argTypes.push_back(ty); 852 } 853 ArrayRef<Type> returnTypes; 854 if (!isVoidType(returnType)) { 855 returnTypes = llvm::makeArrayRef(returnType); 856 } 857 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes); 858 return success(); 859 } 860 861 LogicalResult 862 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) { 863 if (operands.size() != 5) { 864 return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " 865 "type and row x column parameters"); 866 } 867 868 Type elementTy = getType(operands[1]); 869 if (!elementTy) { 870 return emitError(unknownLoc, 871 "OpTypeCooperativeMatrix references undefined <id> ") 872 << operands[1]; 873 } 874 875 auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); 876 if (!scope) { 877 return emitError(unknownLoc, 878 "OpTypeCooperativeMatrix references undefined scope <id> ") 879 << operands[2]; 880 } 881 882 unsigned rows = getConstantInt(operands[3]).getInt(); 883 unsigned columns = getConstantInt(operands[4]).getInt(); 884 885 typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( 886 elementTy, scope.getValue(), rows, columns); 887 return success(); 888 } 889 890 LogicalResult 891 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) { 892 if (operands.size() != 2) { 893 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); 894 } 895 Type memberType = getType(operands[1]); 896 if (!memberType) { 897 return emitError(unknownLoc, 898 "OpTypeRuntimeArray references undefined <id> ") 899 << operands[1]; 900 } 901 typeMap[operands[0]] = spirv::RuntimeArrayType::get( 902 memberType, typeDecorations.lookup(operands[0])); 903 return success(); 904 } 905 906 LogicalResult 907 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { 908 // TODO: Find a way to handle identified structs when debug info is stripped. 909 910 if (operands.empty()) { 911 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>"); 912 } 913 914 if (operands.size() == 1) { 915 // Handle empty struct. 916 typeMap[operands[0]] = 917 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str()); 918 return success(); 919 } 920 921 // First element is operand ID, second element is member index in the struct. 922 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes; 923 SmallVector<Type, 4> memberTypes; 924 925 for (auto op : llvm::drop_begin(operands, 1)) { 926 Type memberType = getType(op); 927 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0); 928 929 if (!memberType && !typeForwardPtr) 930 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ") 931 << op; 932 933 if (!memberType) 934 unresolvedMemberTypes.emplace_back(op, memberTypes.size()); 935 936 memberTypes.push_back(memberType); 937 } 938 939 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; 940 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; 941 if (memberDecorationMap.count(operands[0])) { 942 auto &allMemberDecorations = memberDecorationMap[operands[0]]; 943 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) { 944 if (allMemberDecorations.count(memberIndex)) { 945 for (auto &memberDecoration : allMemberDecorations[memberIndex]) { 946 // Check for offset. 947 if (memberDecoration.first == spirv::Decoration::Offset) { 948 // If offset info is empty, resize to the number of members; 949 if (offsetInfo.empty()) { 950 offsetInfo.resize(memberTypes.size()); 951 } 952 offsetInfo[memberIndex] = memberDecoration.second[0]; 953 } else { 954 if (!memberDecoration.second.empty()) { 955 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, 956 memberDecoration.first, 957 memberDecoration.second[0]); 958 } else { 959 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, 960 memberDecoration.first, 0); 961 } 962 } 963 } 964 } 965 } 966 } 967 968 uint32_t structID = operands[0]; 969 std::string structIdentifier = nameMap.lookup(structID).str(); 970 971 if (structIdentifier.empty()) { 972 assert(unresolvedMemberTypes.empty() && 973 "didn't expect unresolved member types"); 974 typeMap[structID] = 975 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); 976 } else { 977 auto structTy = spirv::StructType::getIdentified(context, structIdentifier); 978 typeMap[structID] = structTy; 979 980 if (!unresolvedMemberTypes.empty()) 981 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, 982 memberTypes, offsetInfo, 983 memberDecorationsInfo}); 984 else if (failed(structTy.trySetBody(memberTypes, offsetInfo, 985 memberDecorationsInfo))) 986 return failure(); 987 } 988 989 // TODO: Update StructType to have member name as attribute as 990 // well. 991 return success(); 992 } 993 994 LogicalResult 995 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) { 996 if (operands.size() != 3) { 997 // Three operands are needed: result_id, column_type, and column_count 998 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands" 999 " (result_id, column_type, and column_count)"); 1000 } 1001 // Matrix columns must be of vector type 1002 Type elementTy = getType(operands[1]); 1003 if (!elementTy) { 1004 return emitError(unknownLoc, 1005 "OpTypeMatrix references undefined column type.") 1006 << operands[1]; 1007 } 1008 1009 uint32_t colsCount = operands[2]; 1010 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount); 1011 return success(); 1012 } 1013 1014 LogicalResult 1015 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) { 1016 if (operands.size() != 2) 1017 return emitError(unknownLoc, 1018 "OpTypeForwardPointer instruction must have two operands"); 1019 1020 typeForwardPointerIDs.insert(operands[0]); 1021 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer 1022 // instruction that defines the actual type. 1023 1024 return success(); 1025 } 1026 1027 LogicalResult 1028 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) { 1029 // TODO: Add support for Access Qualifier. 1030 if (operands.size() != 8) 1031 return emitError( 1032 unknownLoc, 1033 "OpTypeImage with non-eight operands are not supported yet"); 1034 1035 Type elementTy = getType(operands[1]); 1036 if (!elementTy) 1037 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ") 1038 << operands[1]; 1039 1040 auto dim = spirv::symbolizeDim(operands[2]); 1041 if (!dim) 1042 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ") 1043 << operands[2]; 1044 1045 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]); 1046 if (!depthInfo) 1047 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ") 1048 << operands[3]; 1049 1050 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]); 1051 if (!arrayedInfo) 1052 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ") 1053 << operands[4]; 1054 1055 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]); 1056 if (!samplingInfo) 1057 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5]; 1058 1059 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]); 1060 if (!samplerUseInfo) 1061 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ") 1062 << operands[6]; 1063 1064 auto format = spirv::symbolizeImageFormat(operands[7]); 1065 if (!format) 1066 return emitError(unknownLoc, "unknown Format for OpTypeImage: ") 1067 << operands[7]; 1068 1069 typeMap[operands[0]] = spirv::ImageType::get( 1070 elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(), 1071 samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue()); 1072 return success(); 1073 } 1074 1075 LogicalResult 1076 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) { 1077 if (operands.size() != 2) 1078 return emitError(unknownLoc, "OpTypeSampledImage must have two operands"); 1079 1080 Type elementTy = getType(operands[1]); 1081 if (!elementTy) 1082 return emitError(unknownLoc, 1083 "OpTypeSampledImage references undefined <id>: ") 1084 << operands[1]; 1085 1086 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy); 1087 return success(); 1088 } 1089 1090 //===----------------------------------------------------------------------===// 1091 // Constant 1092 //===----------------------------------------------------------------------===// 1093 1094 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands, 1095 bool isSpec) { 1096 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant"; 1097 1098 if (operands.size() < 2) { 1099 return emitError(unknownLoc) 1100 << opname << " must have type <id> and result <id>"; 1101 } 1102 if (operands.size() < 3) { 1103 return emitError(unknownLoc) 1104 << opname << " must have at least 1 more parameter"; 1105 } 1106 1107 Type resultType = getType(operands[0]); 1108 if (!resultType) { 1109 return emitError(unknownLoc, "undefined result type from <id> ") 1110 << operands[0]; 1111 } 1112 1113 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { 1114 if (bitwidth == 64) { 1115 if (operands.size() == 4) { 1116 return success(); 1117 } 1118 return emitError(unknownLoc) 1119 << opname << " should have 2 parameters for 64-bit values"; 1120 } 1121 if (bitwidth <= 32) { 1122 if (operands.size() == 3) { 1123 return success(); 1124 } 1125 1126 return emitError(unknownLoc) 1127 << opname 1128 << " should have 1 parameter for values with no more than 32 bits"; 1129 } 1130 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ") 1131 << bitwidth; 1132 }; 1133 1134 auto resultID = operands[1]; 1135 1136 if (auto intType = resultType.dyn_cast<IntegerType>()) { 1137 auto bitwidth = intType.getWidth(); 1138 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1139 return failure(); 1140 } 1141 1142 APInt value; 1143 if (bitwidth == 64) { 1144 // 64-bit integers are represented with two SPIR-V words. According to 1145 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1146 // literal’s low-order words appear first." 1147 struct DoubleWord { 1148 uint32_t word1; 1149 uint32_t word2; 1150 } words = {operands[2], operands[3]}; 1151 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true); 1152 } else if (bitwidth <= 32) { 1153 value = APInt(bitwidth, operands[2], /*isSigned=*/true); 1154 } 1155 1156 auto attr = opBuilder.getIntegerAttr(intType, value); 1157 1158 if (isSpec) { 1159 createSpecConstant(unknownLoc, resultID, attr); 1160 } else { 1161 // For normal constants, we just record the attribute (and its type) for 1162 // later materialization at use sites. 1163 constantMap.try_emplace(resultID, attr, intType); 1164 } 1165 1166 return success(); 1167 } 1168 1169 if (auto floatType = resultType.dyn_cast<FloatType>()) { 1170 auto bitwidth = floatType.getWidth(); 1171 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1172 return failure(); 1173 } 1174 1175 APFloat value(0.f); 1176 if (floatType.isF64()) { 1177 // Double values are represented with two SPIR-V words. According to 1178 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1179 // literal’s low-order words appear first." 1180 struct DoubleWord { 1181 uint32_t word1; 1182 uint32_t word2; 1183 } words = {operands[2], operands[3]}; 1184 value = APFloat(llvm::bit_cast<double>(words)); 1185 } else if (floatType.isF32()) { 1186 value = APFloat(llvm::bit_cast<float>(operands[2])); 1187 } else if (floatType.isF16()) { 1188 APInt data(16, operands[2]); 1189 value = APFloat(APFloat::IEEEhalf(), data); 1190 } 1191 1192 auto attr = opBuilder.getFloatAttr(floatType, value); 1193 if (isSpec) { 1194 createSpecConstant(unknownLoc, resultID, attr); 1195 } else { 1196 // For normal constants, we just record the attribute (and its type) for 1197 // later materialization at use sites. 1198 constantMap.try_emplace(resultID, attr, floatType); 1199 } 1200 1201 return success(); 1202 } 1203 1204 return emitError(unknownLoc, "OpConstant can only generate values of " 1205 "scalar integer or floating-point type"); 1206 } 1207 1208 LogicalResult spirv::Deserializer::processConstantBool( 1209 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) { 1210 if (operands.size() != 2) { 1211 return emitError(unknownLoc, "Op") 1212 << (isSpec ? "Spec" : "") << "Constant" 1213 << (isTrue ? "True" : "False") 1214 << " must have type <id> and result <id>"; 1215 } 1216 1217 auto attr = opBuilder.getBoolAttr(isTrue); 1218 auto resultID = operands[1]; 1219 if (isSpec) { 1220 createSpecConstant(unknownLoc, resultID, attr); 1221 } else { 1222 // For normal constants, we just record the attribute (and its type) for 1223 // later materialization at use sites. 1224 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type()); 1225 } 1226 1227 return success(); 1228 } 1229 1230 LogicalResult 1231 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { 1232 if (operands.size() < 2) { 1233 return emitError(unknownLoc, 1234 "OpConstantComposite must have type <id> and result <id>"); 1235 } 1236 if (operands.size() < 3) { 1237 return emitError(unknownLoc, 1238 "OpConstantComposite must have at least 1 parameter"); 1239 } 1240 1241 Type resultType = getType(operands[0]); 1242 if (!resultType) { 1243 return emitError(unknownLoc, "undefined result type from <id> ") 1244 << operands[0]; 1245 } 1246 1247 SmallVector<Attribute, 4> elements; 1248 elements.reserve(operands.size() - 2); 1249 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1250 auto elementInfo = getConstant(operands[i]); 1251 if (!elementInfo) { 1252 return emitError(unknownLoc, "OpConstantComposite component <id> ") 1253 << operands[i] << " must come from a normal constant"; 1254 } 1255 elements.push_back(elementInfo->first); 1256 } 1257 1258 auto resultID = operands[1]; 1259 if (auto vectorType = resultType.dyn_cast<VectorType>()) { 1260 auto attr = DenseElementsAttr::get(vectorType, elements); 1261 // For normal constants, we just record the attribute (and its type) for 1262 // later materialization at use sites. 1263 constantMap.try_emplace(resultID, attr, resultType); 1264 } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) { 1265 auto attr = opBuilder.getArrayAttr(elements); 1266 constantMap.try_emplace(resultID, attr, resultType); 1267 } else { 1268 return emitError(unknownLoc, "unsupported OpConstantComposite type: ") 1269 << resultType; 1270 } 1271 1272 return success(); 1273 } 1274 1275 LogicalResult 1276 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) { 1277 if (operands.size() < 2) { 1278 return emitError(unknownLoc, 1279 "OpConstantComposite must have type <id> and result <id>"); 1280 } 1281 if (operands.size() < 3) { 1282 return emitError(unknownLoc, 1283 "OpConstantComposite must have at least 1 parameter"); 1284 } 1285 1286 Type resultType = getType(operands[0]); 1287 if (!resultType) { 1288 return emitError(unknownLoc, "undefined result type from <id> ") 1289 << operands[0]; 1290 } 1291 1292 auto resultID = operands[1]; 1293 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 1294 1295 SmallVector<Attribute, 4> elements; 1296 elements.reserve(operands.size() - 2); 1297 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1298 auto elementInfo = getSpecConstant(operands[i]); 1299 elements.push_back(SymbolRefAttr::get(elementInfo)); 1300 } 1301 1302 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>( 1303 unknownLoc, TypeAttr::get(resultType), symName, 1304 opBuilder.getArrayAttr(elements)); 1305 specConstCompositeMap[resultID] = op; 1306 1307 return success(); 1308 } 1309 1310 LogicalResult 1311 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) { 1312 if (operands.size() < 3) 1313 return emitError(unknownLoc, "OpConstantOperation must have type <id>, " 1314 "result <id>, and operand opcode"); 1315 1316 uint32_t resultTypeID = operands[0]; 1317 1318 if (!getType(resultTypeID)) 1319 return emitError(unknownLoc, "undefined result type from <id> ") 1320 << resultTypeID; 1321 1322 uint32_t resultID = operands[1]; 1323 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]); 1324 auto emplaceResult = specConstOperationMap.try_emplace( 1325 resultID, 1326 SpecConstOperationMaterializationInfo{ 1327 enclosedOpcode, resultTypeID, 1328 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}}); 1329 1330 if (!emplaceResult.second) 1331 return emitError(unknownLoc, "value with <id>: ") 1332 << resultID << " is probably defined before."; 1333 1334 return success(); 1335 } 1336 1337 Value spirv::Deserializer::materializeSpecConstantOperation( 1338 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, 1339 ArrayRef<uint32_t> enclosedOpOperands) { 1340 1341 Type resultType = getType(resultTypeID); 1342 1343 // Instructions wrapped by OpSpecConstantOp need an ID for their 1344 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V 1345 // dialect wrapped op. For that purpose, a new value map is created and "fake" 1346 // ID in that map is assigned to the result of the enclosed instruction. Note 1347 // that there is no need to update this fake ID since we only need to 1348 // reference the created Value for the enclosed op from the spv::YieldOp 1349 // created later in this method (both of which are the only values in their 1350 // region: the SpecConstantOperation's region). If we encounter another 1351 // SpecConstantOperation in the module, we simply re-use the fake ID since the 1352 // previous Value assigned to it isn't visible in the current scope anyway. 1353 DenseMap<uint32_t, Value> newValueMap; 1354 llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap, 1355 newValueMap); 1356 constexpr uint32_t fakeID = static_cast<uint32_t>(-3); 1357 1358 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands; 1359 enclosedOpResultTypeAndOperands.push_back(resultTypeID); 1360 enclosedOpResultTypeAndOperands.push_back(fakeID); 1361 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(), 1362 enclosedOpOperands.end()); 1363 1364 // Process enclosed instruction before creating the enclosing 1365 // specConstantOperation (and its region). This way, references to constants, 1366 // global variables, and spec constants will be materialized outside the new 1367 // op's region. For more info, see Deserializer::getValue's implementation. 1368 if (failed( 1369 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) 1370 return Value(); 1371 1372 // Since the enclosed op is emitted in the current block, split it in a 1373 // separate new block. 1374 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back()); 1375 1376 auto loc = createFileLineColLoc(opBuilder); 1377 auto specConstOperationOp = 1378 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); 1379 1380 Region &body = specConstOperationOp.body(); 1381 // Move the new block into SpecConstantOperation's body. 1382 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), 1383 Region::iterator(enclosedBlock)); 1384 Block &block = body.back(); 1385 1386 // RAII guard to reset the insertion point to the module's region after 1387 // deserializing the body of the specConstantOperation. 1388 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 1389 opBuilder.setInsertionPointToEnd(&block); 1390 1391 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0)); 1392 return specConstOperationOp.getResult(); 1393 } 1394 1395 LogicalResult 1396 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { 1397 if (operands.size() != 2) { 1398 return emitError(unknownLoc, 1399 "OpConstantNull must have type <id> and result <id>"); 1400 } 1401 1402 Type resultType = getType(operands[0]); 1403 if (!resultType) { 1404 return emitError(unknownLoc, "undefined result type from <id> ") 1405 << operands[0]; 1406 } 1407 1408 auto resultID = operands[1]; 1409 if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) { 1410 auto attr = opBuilder.getZeroAttr(resultType); 1411 // For normal constants, we just record the attribute (and its type) for 1412 // later materialization at use sites. 1413 constantMap.try_emplace(resultID, attr, resultType); 1414 return success(); 1415 } 1416 1417 return emitError(unknownLoc, "unsupported OpConstantNull type: ") 1418 << resultType; 1419 } 1420 1421 //===----------------------------------------------------------------------===// 1422 // Control flow 1423 //===----------------------------------------------------------------------===// 1424 1425 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) { 1426 if (auto *block = getBlock(id)) { 1427 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id 1428 << " @ " << block << "\n"); 1429 return block; 1430 } 1431 1432 // We don't know where this block will be placed finally (in a 1433 // spv.mlir.selection or spv.mlir.loop or function). Create it into the 1434 // function for now and sort out the proper place later. 1435 auto *block = curFunction->addBlock(); 1436 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id 1437 << " @ " << block << "\n"); 1438 return blockMap[id] = block; 1439 } 1440 1441 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) { 1442 if (!curBlock) { 1443 return emitError(unknownLoc, "OpBranch must appear inside a block"); 1444 } 1445 1446 if (operands.size() != 1) { 1447 return emitError(unknownLoc, "OpBranch must take exactly one target label"); 1448 } 1449 1450 auto *target = getOrCreateBlock(operands[0]); 1451 auto loc = createFileLineColLoc(opBuilder); 1452 // The preceding instruction for the OpBranch instruction could be an 1453 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have 1454 // the same OpLine information. 1455 opBuilder.create<spirv::BranchOp>(loc, target); 1456 1457 clearDebugLine(); 1458 return success(); 1459 } 1460 1461 LogicalResult 1462 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) { 1463 if (!curBlock) { 1464 return emitError(unknownLoc, 1465 "OpBranchConditional must appear inside a block"); 1466 } 1467 1468 if (operands.size() != 3 && operands.size() != 5) { 1469 return emitError(unknownLoc, 1470 "OpBranchConditional must have condition, true label, " 1471 "false label, and optionally two branch weights"); 1472 } 1473 1474 auto condition = getValue(operands[0]); 1475 auto *trueBlock = getOrCreateBlock(operands[1]); 1476 auto *falseBlock = getOrCreateBlock(operands[2]); 1477 1478 Optional<std::pair<uint32_t, uint32_t>> weights; 1479 if (operands.size() == 5) { 1480 weights = std::make_pair(operands[3], operands[4]); 1481 } 1482 // The preceding instruction for the OpBranchConditional instruction could be 1483 // an OpSelectionMerge instruction, in this case they will have the same 1484 // OpLine information. 1485 auto loc = createFileLineColLoc(opBuilder); 1486 opBuilder.create<spirv::BranchConditionalOp>( 1487 loc, condition, trueBlock, 1488 /*trueArguments=*/ArrayRef<Value>(), falseBlock, 1489 /*falseArguments=*/ArrayRef<Value>(), weights); 1490 1491 clearDebugLine(); 1492 return success(); 1493 } 1494 1495 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) { 1496 if (!curFunction) { 1497 return emitError(unknownLoc, "OpLabel must appear inside a function"); 1498 } 1499 1500 if (operands.size() != 1) { 1501 return emitError(unknownLoc, "OpLabel should only have result <id>"); 1502 } 1503 1504 auto labelID = operands[0]; 1505 // We may have forward declared this block. 1506 auto *block = getOrCreateBlock(labelID); 1507 LLVM_DEBUG(logger.startLine() 1508 << "[block] populating block " << block << "\n"); 1509 // If we have seen this block, make sure it was just a forward declaration. 1510 assert(block->empty() && "re-deserialize the same block!"); 1511 1512 opBuilder.setInsertionPointToStart(block); 1513 blockMap[labelID] = curBlock = block; 1514 1515 return success(); 1516 } 1517 1518 LogicalResult 1519 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) { 1520 if (!curBlock) { 1521 return emitError(unknownLoc, "OpSelectionMerge must appear in a block"); 1522 } 1523 1524 if (operands.size() < 2) { 1525 return emitError( 1526 unknownLoc, 1527 "OpSelectionMerge must specify merge target and selection control"); 1528 } 1529 1530 auto *mergeBlock = getOrCreateBlock(operands[0]); 1531 auto loc = createFileLineColLoc(opBuilder); 1532 auto selectionControl = operands[1]; 1533 1534 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock) 1535 .second) { 1536 return emitError( 1537 unknownLoc, 1538 "a block cannot have more than one OpSelectionMerge instruction"); 1539 } 1540 1541 return success(); 1542 } 1543 1544 LogicalResult 1545 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) { 1546 if (!curBlock) { 1547 return emitError(unknownLoc, "OpLoopMerge must appear in a block"); 1548 } 1549 1550 if (operands.size() < 3) { 1551 return emitError(unknownLoc, "OpLoopMerge must specify merge target, " 1552 "continue target and loop control"); 1553 } 1554 1555 auto *mergeBlock = getOrCreateBlock(operands[0]); 1556 auto *continueBlock = getOrCreateBlock(operands[1]); 1557 auto loc = createFileLineColLoc(opBuilder); 1558 uint32_t loopControl = operands[2]; 1559 1560 if (!blockMergeInfo 1561 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock) 1562 .second) { 1563 return emitError( 1564 unknownLoc, 1565 "a block cannot have more than one OpLoopMerge instruction"); 1566 } 1567 1568 return success(); 1569 } 1570 1571 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { 1572 if (!curBlock) { 1573 return emitError(unknownLoc, "OpPhi must appear in a block"); 1574 } 1575 1576 if (operands.size() < 4) { 1577 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, " 1578 "and variable-parent pairs"); 1579 } 1580 1581 // Create a block argument for this OpPhi instruction. 1582 Type blockArgType = getType(operands[0]); 1583 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc); 1584 valueMap[operands[1]] = blockArg; 1585 LLVM_DEBUG(logger.startLine() 1586 << "[phi] created block argument " << blockArg 1587 << " id = " << operands[1] << " of type " << blockArgType << "\n"); 1588 1589 // For each (value, predecessor) pair, insert the value to the predecessor's 1590 // blockPhiInfo entry so later we can fix the block argument there. 1591 for (unsigned i = 2, e = operands.size(); i < e; i += 2) { 1592 uint32_t value = operands[i]; 1593 Block *predecessor = getOrCreateBlock(operands[i + 1]); 1594 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock}; 1595 blockPhiInfo[predecessorTargetPair].push_back(value); 1596 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor 1597 << " with arg id = " << value << "\n"); 1598 } 1599 1600 return success(); 1601 } 1602 1603 namespace { 1604 /// A class for putting all blocks in a structured selection/loop in a 1605 /// spv.mlir.selection/spv.mlir.loop op. 1606 class ControlFlowStructurizer { 1607 public: 1608 #ifndef NDEBUG 1609 ControlFlowStructurizer(Location loc, uint32_t control, 1610 spirv::BlockMergeInfoMap &mergeInfo, Block *header, 1611 Block *merge, Block *cont, 1612 llvm::ScopedPrinter &logger) 1613 : location(loc), control(control), blockMergeInfo(mergeInfo), 1614 headerBlock(header), mergeBlock(merge), continueBlock(cont), 1615 logger(logger) {} 1616 #else 1617 ControlFlowStructurizer(Location loc, uint32_t control, 1618 spirv::BlockMergeInfoMap &mergeInfo, Block *header, 1619 Block *merge, Block *cont) 1620 : location(loc), control(control), blockMergeInfo(mergeInfo), 1621 headerBlock(header), mergeBlock(merge), continueBlock(cont) {} 1622 #endif 1623 1624 /// Structurizes the loop at the given `headerBlock`. 1625 /// 1626 /// This method will create an spv.mlir.loop op in the `mergeBlock` and move 1627 /// all blocks in the structured loop into the spv.mlir.loop's region. All 1628 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This 1629 /// method will also update `mergeInfo` by remapping all blocks inside to the 1630 /// newly cloned ones inside structured control flow op's regions. 1631 LogicalResult structurize(); 1632 1633 private: 1634 /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`. 1635 spirv::SelectionOp createSelectionOp(uint32_t selectionControl); 1636 1637 /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`. 1638 spirv::LoopOp createLoopOp(uint32_t loopControl); 1639 1640 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`. 1641 void collectBlocksInConstruct(); 1642 1643 Location location; 1644 uint32_t control; 1645 1646 spirv::BlockMergeInfoMap &blockMergeInfo; 1647 1648 Block *headerBlock; 1649 Block *mergeBlock; 1650 Block *continueBlock; // nullptr for spv.mlir.selection 1651 1652 SetVector<Block *> constructBlocks; 1653 1654 #ifndef NDEBUG 1655 /// A logger used to emit information during the deserialzation process. 1656 llvm::ScopedPrinter &logger; 1657 #endif 1658 }; 1659 } // namespace 1660 1661 spirv::SelectionOp 1662 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { 1663 // Create a builder and set the insertion point to the beginning of the 1664 // merge block so that the newly created SelectionOp will be inserted there. 1665 OpBuilder builder(&mergeBlock->front()); 1666 1667 auto control = static_cast<spirv::SelectionControl>(selectionControl); 1668 auto selectionOp = builder.create<spirv::SelectionOp>(location, control); 1669 selectionOp.addMergeBlock(); 1670 1671 return selectionOp; 1672 } 1673 1674 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { 1675 // Create a builder and set the insertion point to the beginning of the 1676 // merge block so that the newly created LoopOp will be inserted there. 1677 OpBuilder builder(&mergeBlock->front()); 1678 1679 auto control = static_cast<spirv::LoopControl>(loopControl); 1680 auto loopOp = builder.create<spirv::LoopOp>(location, control); 1681 loopOp.addEntryAndMergeBlock(); 1682 1683 return loopOp; 1684 } 1685 1686 void ControlFlowStructurizer::collectBlocksInConstruct() { 1687 assert(constructBlocks.empty() && "expected empty constructBlocks"); 1688 1689 // Put the header block in the work list first. 1690 constructBlocks.insert(headerBlock); 1691 1692 // For each item in the work list, add its successors excluding the merge 1693 // block. 1694 for (unsigned i = 0; i < constructBlocks.size(); ++i) { 1695 for (auto *successor : constructBlocks[i]->getSuccessors()) 1696 if (successor != mergeBlock) 1697 constructBlocks.insert(successor); 1698 } 1699 } 1700 1701 LogicalResult ControlFlowStructurizer::structurize() { 1702 Operation *op = nullptr; 1703 bool isLoop = continueBlock != nullptr; 1704 if (isLoop) { 1705 if (auto loopOp = createLoopOp(control)) 1706 op = loopOp.getOperation(); 1707 } else { 1708 if (auto selectionOp = createSelectionOp(control)) 1709 op = selectionOp.getOperation(); 1710 } 1711 if (!op) 1712 return failure(); 1713 Region &body = op->getRegion(0); 1714 1715 BlockAndValueMapping mapper; 1716 // All references to the old merge block should be directed to the 1717 // selection/loop merge block in the SelectionOp/LoopOp's region. 1718 mapper.map(mergeBlock, &body.back()); 1719 1720 collectBlocksInConstruct(); 1721 1722 // We've identified all blocks belonging to the selection/loop's region. Now 1723 // need to "move" them into the selection/loop. Instead of really moving the 1724 // blocks, in the following we copy them and remap all values and branches. 1725 // This is because: 1726 // * Inserting a block into a region requires the block not in any region 1727 // before. But selections/loops can nest so we can create selection/loop ops 1728 // in a nested manner, which means some blocks may already be in a 1729 // selection/loop region when to be moved again. 1730 // * It's much trickier to fix up the branches into and out of the loop's 1731 // region: we need to treat not-moved blocks and moved blocks differently: 1732 // Not-moved blocks jumping to the loop header block need to jump to the 1733 // merge point containing the new loop op but not the loop continue block's 1734 // back edge. Moved blocks jumping out of the loop need to jump to the 1735 // merge block inside the loop region but not other not-moved blocks. 1736 // We cannot use replaceAllUsesWith clearly and it's harder to follow the 1737 // logic. 1738 1739 // Create a corresponding block in the SelectionOp/LoopOp's region for each 1740 // block in this loop construct. 1741 OpBuilder builder(body); 1742 for (auto *block : constructBlocks) { 1743 // Create a block and insert it before the selection/loop merge block in the 1744 // SelectionOp/LoopOp's region. 1745 auto *newBlock = builder.createBlock(&body.back()); 1746 mapper.map(block, newBlock); 1747 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock 1748 << " from block " << block << "\n"); 1749 if (!isFnEntryBlock(block)) { 1750 for (BlockArgument blockArg : block->getArguments()) { 1751 auto newArg = 1752 newBlock->addArgument(blockArg.getType(), blockArg.getLoc()); 1753 mapper.map(blockArg, newArg); 1754 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument " 1755 << blockArg << " to " << newArg << "\n"); 1756 } 1757 } else { 1758 LLVM_DEBUG(logger.startLine() 1759 << "[cf] block " << block << " is a function entry block\n"); 1760 } 1761 1762 for (auto &op : *block) 1763 newBlock->push_back(op.clone(mapper)); 1764 } 1765 1766 // Go through all ops and remap the operands. 1767 auto remapOperands = [&](Operation *op) { 1768 for (auto &operand : op->getOpOperands()) 1769 if (Value mappedOp = mapper.lookupOrNull(operand.get())) 1770 operand.set(mappedOp); 1771 for (auto &succOp : op->getBlockOperands()) 1772 if (Block *mappedOp = mapper.lookupOrNull(succOp.get())) 1773 succOp.set(mappedOp); 1774 }; 1775 for (auto &block : body) 1776 block.walk(remapOperands); 1777 1778 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to 1779 // the selection/loop construct into its region. Next we need to fix the 1780 // connections between this new SelectionOp/LoopOp with existing blocks. 1781 1782 // All existing incoming branches should go to the merge block, where the 1783 // SelectionOp/LoopOp resides right now. 1784 headerBlock->replaceAllUsesWith(mergeBlock); 1785 1786 LLVM_DEBUG({ 1787 logger.startLine() << "[cf] after cloning and fixing references:\n"; 1788 headerBlock->getParentOp()->print(logger.getOStream()); 1789 logger.startLine() << "\n"; 1790 }); 1791 1792 if (isLoop) { 1793 if (!mergeBlock->args_empty()) { 1794 return mergeBlock->getParentOp()->emitError( 1795 "OpPhi in loop merge block unsupported"); 1796 } 1797 1798 // The loop header block may have block arguments. Since now we place the 1799 // loop op inside the old merge block, we need to make sure the old merge 1800 // block has the same block argument list. 1801 for (BlockArgument blockArg : headerBlock->getArguments()) 1802 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc()); 1803 1804 // If the loop header block has block arguments, make sure the spv.Branch op 1805 // matches. 1806 SmallVector<Value, 4> blockArgs; 1807 if (!headerBlock->args_empty()) 1808 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; 1809 1810 // The loop entry block should have a unconditional branch jumping to the 1811 // loop header block. 1812 builder.setInsertionPointToEnd(&body.front()); 1813 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock), 1814 ArrayRef<Value>(blockArgs)); 1815 } 1816 1817 // All the blocks cloned into the SelectionOp/LoopOp's region can now be 1818 // cleaned up. 1819 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n"); 1820 // First we need to drop all operands' references inside all blocks. This is 1821 // needed because we can have blocks referencing SSA values from one another. 1822 for (auto *block : constructBlocks) 1823 block->dropAllReferences(); 1824 1825 // Check that whether some op in the to-be-erased blocks still has uses. Those 1826 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's 1827 // region. We cannot handle such cases given that once a value is sinked into 1828 // the SelectionOp/LoopOp's region, there is no escape for it: 1829 // SelectionOp/LooOp does not support yield values right now. 1830 for (auto *block : constructBlocks) { 1831 for (Operation &op : *block) 1832 if (!op.use_empty()) 1833 return op.emitOpError( 1834 "failed control flow structurization: it has uses outside of the " 1835 "enclosing selection/loop construct"); 1836 } 1837 1838 // Then erase all old blocks. 1839 for (auto *block : constructBlocks) { 1840 // We've cloned all blocks belonging to this construct into the structured 1841 // control flow op's region. Among these blocks, some may compose another 1842 // selection/loop. If so, they will be recorded within blockMergeInfo. 1843 // We need to update the pointers there to the newly remapped ones so we can 1844 // continue structurizing them later. 1845 // TODO: The asserts in the following assumes input SPIR-V blob forms 1846 // correctly nested selection/loop constructs. We should relax this and 1847 // support error cases better. 1848 auto it = blockMergeInfo.find(block); 1849 if (it != blockMergeInfo.end()) { 1850 // Use the original location for nested selection/loop ops. 1851 Location loc = it->second.loc; 1852 1853 Block *newHeader = mapper.lookupOrNull(block); 1854 if (!newHeader) 1855 return emitError(loc, "failed control flow structurization: nested " 1856 "loop header block should be remapped!"); 1857 1858 Block *newContinue = it->second.continueBlock; 1859 if (newContinue) { 1860 newContinue = mapper.lookupOrNull(newContinue); 1861 if (!newContinue) 1862 return emitError(loc, "failed control flow structurization: nested " 1863 "loop continue block should be remapped!"); 1864 } 1865 1866 Block *newMerge = it->second.mergeBlock; 1867 if (Block *mappedTo = mapper.lookupOrNull(newMerge)) 1868 newMerge = mappedTo; 1869 1870 // The iterator should be erased before adding a new entry into 1871 // blockMergeInfo to avoid iterator invalidation. 1872 blockMergeInfo.erase(it); 1873 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge, 1874 newContinue); 1875 } 1876 1877 // The structured selection/loop's entry block does not have arguments. 1878 // If the function's header block is also part of the structured control 1879 // flow, we cannot just simply erase it because it may contain arguments 1880 // matching the function signature and used by the cloned blocks. 1881 if (isFnEntryBlock(block)) { 1882 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block 1883 << " to only contain a spv.Branch op\n"); 1884 // Still keep the function entry block for the potential block arguments, 1885 // but replace all ops inside with a branch to the merge block. 1886 block->clear(); 1887 builder.setInsertionPointToEnd(block); 1888 builder.create<spirv::BranchOp>(location, mergeBlock); 1889 } else { 1890 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n"); 1891 block->erase(); 1892 } 1893 } 1894 1895 LLVM_DEBUG(logger.startLine() 1896 << "[cf] after structurizing construct with header block " 1897 << headerBlock << ":\n" 1898 << *op << "\n"); 1899 1900 return success(); 1901 } 1902 1903 LogicalResult spirv::Deserializer::wireUpBlockArgument() { 1904 LLVM_DEBUG({ 1905 logger.startLine() 1906 << "//----- [phi] start wiring up block arguments -----//\n"; 1907 logger.indent(); 1908 }); 1909 1910 OpBuilder::InsertionGuard guard(opBuilder); 1911 1912 for (const auto &info : blockPhiInfo) { 1913 Block *block = info.first.first; 1914 Block *target = info.first.second; 1915 const BlockPhiInfo &phiInfo = info.second; 1916 LLVM_DEBUG({ 1917 logger.startLine() << "[phi] block " << block << "\n"; 1918 logger.startLine() << "[phi] before creating block argument:\n"; 1919 block->getParentOp()->print(logger.getOStream()); 1920 logger.startLine() << "\n"; 1921 }); 1922 1923 // Set insertion point to before this block's terminator early because we 1924 // may materialize ops via getValue() call. 1925 auto *op = block->getTerminator(); 1926 opBuilder.setInsertionPoint(op); 1927 1928 SmallVector<Value, 4> blockArgs; 1929 blockArgs.reserve(phiInfo.size()); 1930 for (uint32_t valueId : phiInfo) { 1931 if (Value value = getValue(valueId)) { 1932 blockArgs.push_back(value); 1933 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value 1934 << " id = " << valueId << "\n"); 1935 } else { 1936 return emitError(unknownLoc, "OpPhi references undefined value!"); 1937 } 1938 } 1939 1940 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { 1941 // Replace the previous branch op with a new one with block arguments. 1942 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(), 1943 blockArgs); 1944 branchOp.erase(); 1945 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) { 1946 assert((branchCondOp.getTrueBlock() == target || 1947 branchCondOp.getFalseBlock() == target) && 1948 "expected target to be either the true or false target"); 1949 if (target == branchCondOp.trueTarget()) 1950 opBuilder.create<spirv::BranchConditionalOp>( 1951 branchCondOp.getLoc(), branchCondOp.condition(), blockArgs, 1952 branchCondOp.getFalseBlockArguments(), 1953 branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(), 1954 branchCondOp.falseTarget()); 1955 else 1956 opBuilder.create<spirv::BranchConditionalOp>( 1957 branchCondOp.getLoc(), branchCondOp.condition(), 1958 branchCondOp.getTrueBlockArguments(), blockArgs, 1959 branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(), 1960 branchCondOp.getFalseBlock()); 1961 1962 branchCondOp.erase(); 1963 } else { 1964 return emitError(unknownLoc, "unimplemented terminator for Phi creation"); 1965 } 1966 1967 LLVM_DEBUG({ 1968 logger.startLine() << "[phi] after creating block argument:\n"; 1969 block->getParentOp()->print(logger.getOStream()); 1970 logger.startLine() << "\n"; 1971 }); 1972 } 1973 blockPhiInfo.clear(); 1974 1975 LLVM_DEBUG({ 1976 logger.unindent(); 1977 logger.startLine() 1978 << "//--- [phi] completed wiring up block arguments ---//\n"; 1979 }); 1980 return success(); 1981 } 1982 1983 LogicalResult spirv::Deserializer::structurizeControlFlow() { 1984 LLVM_DEBUG({ 1985 logger.startLine() 1986 << "//----- [cf] start structurizing control flow -----//\n"; 1987 logger.indent(); 1988 }); 1989 1990 while (!blockMergeInfo.empty()) { 1991 Block *headerBlock = blockMergeInfo.begin()->first; 1992 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; 1993 1994 LLVM_DEBUG({ 1995 logger.startLine() << "[cf] header block " << headerBlock << ":\n"; 1996 headerBlock->print(logger.getOStream()); 1997 logger.startLine() << "\n"; 1998 }); 1999 2000 auto *mergeBlock = mergeInfo.mergeBlock; 2001 assert(mergeBlock && "merge block cannot be nullptr"); 2002 if (!mergeBlock->args_empty()) 2003 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); 2004 LLVM_DEBUG({ 2005 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n"; 2006 mergeBlock->print(logger.getOStream()); 2007 logger.startLine() << "\n"; 2008 }); 2009 2010 auto *continueBlock = mergeInfo.continueBlock; 2011 LLVM_DEBUG(if (continueBlock) { 2012 logger.startLine() << "[cf] continue block " << continueBlock << ":\n"; 2013 continueBlock->print(logger.getOStream()); 2014 logger.startLine() << "\n"; 2015 }); 2016 // Erase this case before calling into structurizer, who will update 2017 // blockMergeInfo. 2018 blockMergeInfo.erase(blockMergeInfo.begin()); 2019 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control, 2020 blockMergeInfo, headerBlock, 2021 mergeBlock, continueBlock 2022 #ifndef NDEBUG 2023 , 2024 logger 2025 #endif 2026 ); 2027 if (failed(structurizer.structurize())) 2028 return failure(); 2029 } 2030 2031 LLVM_DEBUG({ 2032 logger.unindent(); 2033 logger.startLine() 2034 << "//--- [cf] completed structurizing control flow ---//\n"; 2035 }); 2036 return success(); 2037 } 2038 2039 //===----------------------------------------------------------------------===// 2040 // Debug 2041 //===----------------------------------------------------------------------===// 2042 2043 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) { 2044 if (!debugLine) 2045 return unknownLoc; 2046 2047 auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); 2048 if (fileName.empty()) 2049 fileName = "<unknown>"; 2050 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line, 2051 debugLine->column); 2052 } 2053 2054 LogicalResult 2055 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) { 2056 // According to SPIR-V spec: 2057 // "This location information applies to the instructions physically 2058 // following this instruction, up to the first occurrence of any of the 2059 // following: the next end of block, the next OpLine instruction, or the next 2060 // OpNoLine instruction." 2061 if (operands.size() != 3) 2062 return emitError(unknownLoc, "OpLine must have 3 operands"); 2063 debugLine = DebugLine{operands[0], operands[1], operands[2]}; 2064 return success(); 2065 } 2066 2067 void spirv::Deserializer::clearDebugLine() { debugLine = llvm::None; } 2068 2069 LogicalResult 2070 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) { 2071 if (operands.size() < 2) 2072 return emitError(unknownLoc, "OpString needs at least 2 operands"); 2073 2074 if (!debugInfoMap.lookup(operands[0]).empty()) 2075 return emitError(unknownLoc, 2076 "duplicate debug string found for result <id> ") 2077 << operands[0]; 2078 2079 unsigned wordIndex = 1; 2080 StringRef debugString = decodeStringLiteral(operands, wordIndex); 2081 if (wordIndex != operands.size()) 2082 return emitError(unknownLoc, 2083 "unexpected trailing words in OpString instruction"); 2084 2085 debugInfoMap[operands[0]] = debugString; 2086 return success(); 2087 } 2088