1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/Support/LogicalResult.h" 23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/Sequence.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/bit.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/SaveAndRestore.h" 31 #include "llvm/Support/raw_ostream.h" 32 33 using namespace mlir; 34 35 #define DEBUG_TYPE "spirv-deserialization" 36 37 //===----------------------------------------------------------------------===// 38 // Utility Functions 39 //===----------------------------------------------------------------------===// 40 41 /// Returns true if the given `block` is a function entry block. 42 static inline bool isFnEntryBlock(Block *block) { 43 return block->isEntryBlock() && 44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp()); 45 } 46 47 //===----------------------------------------------------------------------===// 48 // Deserializer Method Definitions 49 //===----------------------------------------------------------------------===// 50 51 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary, 52 MLIRContext *context) 53 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), 54 module(createModuleOp()), opBuilder(module->getRegion()) {} 55 56 LogicalResult spirv::Deserializer::deserialize() { 57 LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); 58 59 if (failed(processHeader())) 60 return failure(); 61 62 spirv::Opcode opcode = spirv::Opcode::OpNop; 63 ArrayRef<uint32_t> operands; 64 auto binarySize = binary.size(); 65 while (curOffset < binarySize) { 66 // Slice the next instruction out and populate `opcode` and `operands`. 67 // Internally this also updates `curOffset`. 68 if (failed(sliceInstruction(opcode, operands))) 69 return failure(); 70 71 if (failed(processInstruction(opcode, operands))) 72 return failure(); 73 } 74 75 assert(curOffset == binarySize && 76 "deserializer should never index beyond the binary end"); 77 78 for (auto &deferred : deferredInstructions) { 79 if (failed(processInstruction(deferred.first, deferred.second, false))) { 80 return failure(); 81 } 82 } 83 84 attachVCETriple(); 85 86 LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n"); 87 return success(); 88 } 89 90 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() { 91 return std::move(module); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Module structure 96 //===----------------------------------------------------------------------===// 97 98 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() { 99 OpBuilder builder(context); 100 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); 101 spirv::ModuleOp::build(builder, state); 102 return cast<spirv::ModuleOp>(Operation::create(state)); 103 } 104 105 LogicalResult spirv::Deserializer::processHeader() { 106 if (binary.size() < spirv::kHeaderWordCount) 107 return emitError(unknownLoc, 108 "SPIR-V binary module must have a 5-word header"); 109 110 if (binary[0] != spirv::kMagicNumber) 111 return emitError(unknownLoc, "incorrect magic number"); 112 113 // Version number bytes: 0 | major number | minor number | 0 114 uint32_t majorVersion = (binary[1] << 8) >> 24; 115 uint32_t minorVersion = (binary[1] << 16) >> 24; 116 if (majorVersion == 1) { 117 switch (minorVersion) { 118 #define MIN_VERSION_CASE(v) \ 119 case v: \ 120 version = spirv::Version::V_1_##v; \ 121 break 122 123 MIN_VERSION_CASE(0); 124 MIN_VERSION_CASE(1); 125 MIN_VERSION_CASE(2); 126 MIN_VERSION_CASE(3); 127 MIN_VERSION_CASE(4); 128 MIN_VERSION_CASE(5); 129 #undef MIN_VERSION_CASE 130 default: 131 return emitError(unknownLoc, "unsupported SPIR-V minor version: ") 132 << minorVersion; 133 } 134 } else { 135 return emitError(unknownLoc, "unsupported SPIR-V major version: ") 136 << majorVersion; 137 } 138 139 // TODO: generator number, bound, schema 140 curOffset = spirv::kHeaderWordCount; 141 return success(); 142 } 143 144 LogicalResult 145 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) { 146 if (operands.size() != 1) 147 return emitError(unknownLoc, "OpMemoryModel must have one parameter"); 148 149 auto cap = spirv::symbolizeCapability(operands[0]); 150 if (!cap) 151 return emitError(unknownLoc, "unknown capability: ") << operands[0]; 152 153 capabilities.insert(*cap); 154 return success(); 155 } 156 157 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) { 158 if (words.empty()) { 159 return emitError( 160 unknownLoc, 161 "OpExtension must have a literal string for the extension name"); 162 } 163 164 unsigned wordIndex = 0; 165 StringRef extName = decodeStringLiteral(words, wordIndex); 166 if (wordIndex != words.size()) 167 return emitError(unknownLoc, 168 "unexpected trailing words in OpExtension instruction"); 169 auto ext = spirv::symbolizeExtension(extName); 170 if (!ext) 171 return emitError(unknownLoc, "unknown extension: ") << extName; 172 173 extensions.insert(*ext); 174 return success(); 175 } 176 177 LogicalResult 178 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) { 179 if (words.size() < 2) { 180 return emitError(unknownLoc, 181 "OpExtInstImport must have a result <id> and a literal " 182 "string for the extended instruction set name"); 183 } 184 185 unsigned wordIndex = 1; 186 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); 187 if (wordIndex != words.size()) { 188 return emitError(unknownLoc, 189 "unexpected trailing words in OpExtInstImport"); 190 } 191 return success(); 192 } 193 194 void spirv::Deserializer::attachVCETriple() { 195 (*module)->setAttr( 196 spirv::ModuleOp::getVCETripleAttrName(), 197 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), 198 extensions.getArrayRef(), context)); 199 } 200 201 LogicalResult 202 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { 203 if (operands.size() != 2) 204 return emitError(unknownLoc, "OpMemoryModel must have two operands"); 205 206 (*module)->setAttr( 207 "addressing_model", 208 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front()))); 209 (*module)->setAttr( 210 "memory_model", 211 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back()))); 212 213 return success(); 214 } 215 216 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { 217 // TODO: This function should also be auto-generated. For now, since only a 218 // few decorations are processed/handled in a meaningful manner, going with a 219 // manual implementation. 220 if (words.size() < 2) { 221 return emitError( 222 unknownLoc, "OpDecorate must have at least result <id> and Decoration"); 223 } 224 auto decorationName = 225 stringifyDecoration(static_cast<spirv::Decoration>(words[1])); 226 if (decorationName.empty()) { 227 return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; 228 } 229 auto attrName = llvm::convertToSnakeFromCamelCase(decorationName); 230 auto symbol = opBuilder.getIdentifier(attrName); 231 switch (static_cast<spirv::Decoration>(words[1])) { 232 case spirv::Decoration::DescriptorSet: 233 case spirv::Decoration::Binding: 234 if (words.size() != 3) { 235 return emitError(unknownLoc, "OpDecorate with ") 236 << decorationName << " needs a single integer literal"; 237 } 238 decorations[words[0]].set( 239 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 240 break; 241 case spirv::Decoration::BuiltIn: 242 if (words.size() != 3) { 243 return emitError(unknownLoc, "OpDecorate with ") 244 << decorationName << " needs a single integer literal"; 245 } 246 decorations[words[0]].set( 247 symbol, opBuilder.getStringAttr( 248 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2])))); 249 break; 250 case spirv::Decoration::ArrayStride: 251 if (words.size() != 3) { 252 return emitError(unknownLoc, "OpDecorate with ") 253 << decorationName << " needs a single integer literal"; 254 } 255 typeDecorations[words[0]] = words[2]; 256 break; 257 case spirv::Decoration::Aliased: 258 case spirv::Decoration::Block: 259 case spirv::Decoration::BufferBlock: 260 case spirv::Decoration::Flat: 261 case spirv::Decoration::NonReadable: 262 case spirv::Decoration::NonWritable: 263 case spirv::Decoration::NoPerspective: 264 case spirv::Decoration::Restrict: 265 case spirv::Decoration::RelaxedPrecision: 266 if (words.size() != 2) { 267 return emitError(unknownLoc, "OpDecoration with ") 268 << decorationName << "needs a single target <id>"; 269 } 270 // Block decoration does not affect spv.struct type, but is still stored for 271 // verification. 272 // TODO: Update StructType to contain this information since 273 // it is needed for many validation rules. 274 decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); 275 break; 276 case spirv::Decoration::Location: 277 case spirv::Decoration::SpecId: 278 if (words.size() != 3) { 279 return emitError(unknownLoc, "OpDecoration with ") 280 << decorationName << "needs a single integer literal"; 281 } 282 decorations[words[0]].set( 283 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 284 break; 285 default: 286 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; 287 } 288 return success(); 289 } 290 291 LogicalResult 292 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { 293 // The binary layout of OpMemberDecorate is different comparing to OpDecorate 294 if (words.size() < 3) { 295 return emitError(unknownLoc, 296 "OpMemberDecorate must have at least 3 operands"); 297 } 298 299 auto decoration = static_cast<spirv::Decoration>(words[2]); 300 if (decoration == spirv::Decoration::Offset && words.size() != 4) { 301 return emitError(unknownLoc, 302 " missing offset specification in OpMemberDecorate with " 303 "Offset decoration"); 304 } 305 ArrayRef<uint32_t> decorationOperands; 306 if (words.size() > 3) { 307 decorationOperands = words.slice(3); 308 } 309 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; 310 return success(); 311 } 312 313 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) { 314 if (words.size() < 3) { 315 return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); 316 } 317 unsigned wordIndex = 2; 318 auto name = decodeStringLiteral(words, wordIndex); 319 if (wordIndex != words.size()) { 320 return emitError(unknownLoc, 321 "unexpected trailing words in OpMemberName instruction"); 322 } 323 memberNameMap[words[0]][words[1]] = name; 324 return success(); 325 } 326 327 LogicalResult 328 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) { 329 if (curFunction) { 330 return emitError(unknownLoc, "found function inside function"); 331 } 332 333 // Get the result type 334 if (operands.size() != 4) { 335 return emitError(unknownLoc, "OpFunction must have 4 parameters"); 336 } 337 Type resultType = getType(operands[0]); 338 if (!resultType) { 339 return emitError(unknownLoc, "undefined result type from <id> ") 340 << operands[0]; 341 } 342 343 if (funcMap.count(operands[1])) { 344 return emitError(unknownLoc, "duplicate function definition/declaration"); 345 } 346 347 auto fnControl = spirv::symbolizeFunctionControl(operands[2]); 348 if (!fnControl) { 349 return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; 350 } 351 352 Type fnType = getType(operands[3]); 353 if (!fnType || !fnType.isa<FunctionType>()) { 354 return emitError(unknownLoc, "unknown function type from <id> ") 355 << operands[3]; 356 } 357 auto functionType = fnType.cast<FunctionType>(); 358 359 if ((isVoidType(resultType) && functionType.getNumResults() != 0) || 360 (functionType.getNumResults() == 1 && 361 functionType.getResult(0) != resultType)) { 362 return emitError(unknownLoc, "mismatch in function type ") 363 << functionType << " and return type " << resultType << " specified"; 364 } 365 366 std::string fnName = getFunctionSymbol(operands[1]); 367 auto funcOp = opBuilder.create<spirv::FuncOp>( 368 unknownLoc, fnName, functionType, fnControl.getValue()); 369 curFunction = funcMap[operands[1]] = funcOp; 370 LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " 371 << fnType << ", id = " << operands[1] << ") --\n"); 372 auto *entryBlock = funcOp.addEntryBlock(); 373 LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock 374 << "\n"); 375 376 // Parse the op argument instructions 377 if (functionType.getNumInputs()) { 378 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { 379 auto argType = functionType.getInput(i); 380 spirv::Opcode opcode = spirv::Opcode::OpNop; 381 ArrayRef<uint32_t> operands; 382 if (failed(sliceInstruction(opcode, operands, 383 spirv::Opcode::OpFunctionParameter))) { 384 return failure(); 385 } 386 if (opcode != spirv::Opcode::OpFunctionParameter) { 387 return emitError( 388 unknownLoc, 389 "missing OpFunctionParameter instruction for argument ") 390 << i; 391 } 392 if (operands.size() != 2) { 393 return emitError( 394 unknownLoc, 395 "expected result type and result <id> for OpFunctionParameter"); 396 } 397 auto argDefinedType = getType(operands[0]); 398 if (!argDefinedType || argDefinedType != argType) { 399 return emitError(unknownLoc, 400 "mismatch in argument type between function type " 401 "definition ") 402 << functionType << " and argument type definition " 403 << argDefinedType << " at argument " << i; 404 } 405 if (getValue(operands[1])) { 406 return emitError(unknownLoc, "duplicate definition of result <id> '") 407 << operands[1]; 408 } 409 auto argValue = funcOp.getArgument(i); 410 valueMap[operands[1]] = argValue; 411 } 412 } 413 414 // RAII guard to reset the insertion point to the module's region after 415 // deserializing the body of this function. 416 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 417 418 spirv::Opcode opcode = spirv::Opcode::OpNop; 419 ArrayRef<uint32_t> instOperands; 420 421 // Special handling for the entry block. We need to make sure it starts with 422 // an OpLabel instruction. The entry block takes the same parameters as the 423 // function. All other blocks do not take any parameter. We have already 424 // created the entry block, here we need to register it to the correct label 425 // <id>. 426 if (failed(sliceInstruction(opcode, instOperands, 427 spirv::Opcode::OpFunctionEnd))) { 428 return failure(); 429 } 430 if (opcode == spirv::Opcode::OpFunctionEnd) { 431 LLVM_DEBUG(llvm::dbgs() 432 << "-- completed function '" << fnName << "' (type = " << fnType 433 << ", id = " << operands[1] << ") --\n"); 434 return processFunctionEnd(instOperands); 435 } 436 if (opcode != spirv::Opcode::OpLabel) { 437 return emitError(unknownLoc, "a basic block must start with OpLabel"); 438 } 439 if (instOperands.size() != 1) { 440 return emitError(unknownLoc, "OpLabel should only have result <id>"); 441 } 442 blockMap[instOperands[0]] = entryBlock; 443 if (failed(processLabel(instOperands))) { 444 return failure(); 445 } 446 447 // Then process all the other instructions in the function until we hit 448 // OpFunctionEnd. 449 while (succeeded(sliceInstruction(opcode, instOperands, 450 spirv::Opcode::OpFunctionEnd)) && 451 opcode != spirv::Opcode::OpFunctionEnd) { 452 if (failed(processInstruction(opcode, instOperands))) { 453 return failure(); 454 } 455 } 456 if (opcode != spirv::Opcode::OpFunctionEnd) { 457 return failure(); 458 } 459 460 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " 461 << fnType << ", id = " << operands[1] << ") --\n"); 462 return processFunctionEnd(instOperands); 463 } 464 465 LogicalResult 466 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) { 467 // Process OpFunctionEnd. 468 if (!operands.empty()) { 469 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); 470 } 471 472 // Wire up block arguments from OpPhi instructions. 473 // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops. 474 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { 475 return failure(); 476 } 477 478 curBlock = nullptr; 479 curFunction = llvm::None; 480 481 return success(); 482 } 483 484 Optional<std::pair<Attribute, Type>> 485 spirv::Deserializer::getConstant(uint32_t id) { 486 auto constIt = constantMap.find(id); 487 if (constIt == constantMap.end()) 488 return llvm::None; 489 return constIt->getSecond(); 490 } 491 492 Optional<spirv::SpecConstOperationMaterializationInfo> 493 spirv::Deserializer::getSpecConstantOperation(uint32_t id) { 494 auto constIt = specConstOperationMap.find(id); 495 if (constIt == specConstOperationMap.end()) 496 return llvm::None; 497 return constIt->getSecond(); 498 } 499 500 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { 501 auto funcName = nameMap.lookup(id).str(); 502 if (funcName.empty()) { 503 funcName = "spirv_fn_" + std::to_string(id); 504 } 505 return funcName; 506 } 507 508 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { 509 auto constName = nameMap.lookup(id).str(); 510 if (constName.empty()) { 511 constName = "spirv_spec_const_" + std::to_string(id); 512 } 513 return constName; 514 } 515 516 spirv::SpecConstantOp 517 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, 518 Attribute defaultValue) { 519 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 520 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, 521 defaultValue); 522 if (decorations.count(resultID)) { 523 for (auto attr : decorations[resultID].getAttrs()) 524 op->setAttr(attr.first, attr.second); 525 } 526 specConstMap[resultID] = op; 527 return op; 528 } 529 530 LogicalResult 531 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { 532 unsigned wordIndex = 0; 533 if (operands.size() < 3) { 534 return emitError( 535 unknownLoc, 536 "OpVariable needs at least 3 operands, type, <id> and storage class"); 537 } 538 539 // Result Type. 540 auto type = getType(operands[wordIndex]); 541 if (!type) { 542 return emitError(unknownLoc, "unknown result type <id> : ") 543 << operands[wordIndex]; 544 } 545 auto ptrType = type.dyn_cast<spirv::PointerType>(); 546 if (!ptrType) { 547 return emitError(unknownLoc, 548 "expected a result type <id> to be a spv.ptr, found : ") 549 << type; 550 } 551 wordIndex++; 552 553 // Result <id>. 554 auto variableID = operands[wordIndex]; 555 auto variableName = nameMap.lookup(variableID).str(); 556 if (variableName.empty()) { 557 variableName = "spirv_var_" + std::to_string(variableID); 558 } 559 wordIndex++; 560 561 // Storage class. 562 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); 563 if (ptrType.getStorageClass() != storageClass) { 564 return emitError(unknownLoc, "mismatch in storage class of pointer type ") 565 << type << " and that specified in OpVariable instruction : " 566 << stringifyStorageClass(storageClass); 567 } 568 wordIndex++; 569 570 // Initializer. 571 FlatSymbolRefAttr initializer = nullptr; 572 if (wordIndex < operands.size()) { 573 auto initializerOp = getGlobalVariable(operands[wordIndex]); 574 if (!initializerOp) { 575 return emitError(unknownLoc, "unknown <id> ") 576 << operands[wordIndex] << "used as initializer"; 577 } 578 wordIndex++; 579 initializer = SymbolRefAttr::get(initializerOp.getOperation()); 580 } 581 if (wordIndex != operands.size()) { 582 return emitError(unknownLoc, 583 "found more operands than expected when deserializing " 584 "OpVariable instruction, only ") 585 << wordIndex << " of " << operands.size() << " processed"; 586 } 587 auto loc = createFileLineColLoc(opBuilder); 588 auto varOp = opBuilder.create<spirv::GlobalVariableOp>( 589 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), 590 initializer); 591 592 // Decorations. 593 if (decorations.count(variableID)) { 594 for (auto attr : decorations[variableID].getAttrs()) { 595 varOp->setAttr(attr.first, attr.second); 596 } 597 } 598 globalVariableMap[variableID] = varOp; 599 return success(); 600 } 601 602 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) { 603 auto constInfo = getConstant(id); 604 if (!constInfo) { 605 return nullptr; 606 } 607 return constInfo->first.dyn_cast<IntegerAttr>(); 608 } 609 610 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) { 611 if (operands.size() < 2) { 612 return emitError(unknownLoc, "OpName needs at least 2 operands"); 613 } 614 if (!nameMap.lookup(operands[0]).empty()) { 615 return emitError(unknownLoc, "duplicate name found for result <id> ") 616 << operands[0]; 617 } 618 unsigned wordIndex = 1; 619 StringRef name = decodeStringLiteral(operands, wordIndex); 620 if (wordIndex != operands.size()) { 621 return emitError(unknownLoc, 622 "unexpected trailing words in OpName instruction"); 623 } 624 nameMap[operands[0]] = name; 625 return success(); 626 } 627 628 //===----------------------------------------------------------------------===// 629 // Type 630 //===----------------------------------------------------------------------===// 631 632 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, 633 ArrayRef<uint32_t> operands) { 634 if (operands.empty()) { 635 return emitError(unknownLoc, "type instruction with opcode ") 636 << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; 637 } 638 639 /// TODO: Types might be forward declared in some instructions and need to be 640 /// handled appropriately. 641 if (typeMap.count(operands[0])) { 642 return emitError(unknownLoc, "duplicate definition for result <id> ") 643 << operands[0]; 644 } 645 646 switch (opcode) { 647 case spirv::Opcode::OpTypeVoid: 648 if (operands.size() != 1) 649 return emitError(unknownLoc, "OpTypeVoid must have no parameters"); 650 typeMap[operands[0]] = opBuilder.getNoneType(); 651 break; 652 case spirv::Opcode::OpTypeBool: 653 if (operands.size() != 1) 654 return emitError(unknownLoc, "OpTypeBool must have no parameters"); 655 typeMap[operands[0]] = opBuilder.getI1Type(); 656 break; 657 case spirv::Opcode::OpTypeInt: { 658 if (operands.size() != 3) 659 return emitError( 660 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); 661 662 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 663 // to preserve or validate. 664 // 0 indicates unsigned, or no signedness semantics 665 // 1 indicates signed semantics." 666 // 667 // So we cannot differentiate signless and unsigned integers; always use 668 // signless semantics for such cases. 669 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed 670 : IntegerType::SignednessSemantics::Signless; 671 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); 672 } break; 673 case spirv::Opcode::OpTypeFloat: { 674 if (operands.size() != 2) 675 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); 676 677 Type floatTy; 678 switch (operands[1]) { 679 case 16: 680 floatTy = opBuilder.getF16Type(); 681 break; 682 case 32: 683 floatTy = opBuilder.getF32Type(); 684 break; 685 case 64: 686 floatTy = opBuilder.getF64Type(); 687 break; 688 default: 689 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ") 690 << operands[1]; 691 } 692 typeMap[operands[0]] = floatTy; 693 } break; 694 case spirv::Opcode::OpTypeVector: { 695 if (operands.size() != 3) { 696 return emitError( 697 unknownLoc, 698 "OpTypeVector must have element type and count parameters"); 699 } 700 Type elementTy = getType(operands[1]); 701 if (!elementTy) { 702 return emitError(unknownLoc, "OpTypeVector references undefined <id> ") 703 << operands[1]; 704 } 705 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); 706 } break; 707 case spirv::Opcode::OpTypePointer: { 708 return processOpTypePointer(operands); 709 } break; 710 case spirv::Opcode::OpTypeArray: 711 return processArrayType(operands); 712 case spirv::Opcode::OpTypeCooperativeMatrixNV: 713 return processCooperativeMatrixType(operands); 714 case spirv::Opcode::OpTypeFunction: 715 return processFunctionType(operands); 716 case spirv::Opcode::OpTypeImage: 717 return processImageType(operands); 718 case spirv::Opcode::OpTypeSampledImage: 719 return processSampledImageType(operands); 720 case spirv::Opcode::OpTypeRuntimeArray: 721 return processRuntimeArrayType(operands); 722 case spirv::Opcode::OpTypeStruct: 723 return processStructType(operands); 724 case spirv::Opcode::OpTypeMatrix: 725 return processMatrixType(operands); 726 default: 727 return emitError(unknownLoc, "unhandled type instruction"); 728 } 729 return success(); 730 } 731 732 LogicalResult 733 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { 734 if (operands.size() != 3) 735 return emitError(unknownLoc, "OpTypePointer must have two parameters"); 736 737 auto pointeeType = getType(operands[2]); 738 if (!pointeeType) 739 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ") 740 << operands[2]; 741 742 uint32_t typePointerID = operands[0]; 743 auto storageClass = static_cast<spirv::StorageClass>(operands[1]); 744 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); 745 746 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos); 747 deferredStructIt != std::end(deferredStructTypesInfos);) { 748 for (auto *unresolvedMemberIt = 749 std::begin(deferredStructIt->unresolvedMemberTypes); 750 unresolvedMemberIt != 751 std::end(deferredStructIt->unresolvedMemberTypes);) { 752 if (unresolvedMemberIt->first == typePointerID) { 753 // The newly constructed pointer type can resolve one of the 754 // deferred struct type members; update the memberTypes list and 755 // clean the unresolvedMemberTypes list accordingly. 756 deferredStructIt->memberTypes[unresolvedMemberIt->second] = 757 typeMap[typePointerID]; 758 unresolvedMemberIt = 759 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt); 760 } else { 761 ++unresolvedMemberIt; 762 } 763 } 764 765 if (deferredStructIt->unresolvedMemberTypes.empty()) { 766 // All deferred struct type members are now resolved, set the struct body. 767 auto structType = deferredStructIt->deferredStructType; 768 769 assert(structType && "expected a spirv::StructType"); 770 assert(structType.isIdentified() && "expected an indentified struct"); 771 772 if (failed(structType.trySetBody( 773 deferredStructIt->memberTypes, deferredStructIt->offsetInfo, 774 deferredStructIt->memberDecorationsInfo))) 775 return failure(); 776 777 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); 778 } else { 779 ++deferredStructIt; 780 } 781 } 782 783 return success(); 784 } 785 786 LogicalResult 787 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) { 788 if (operands.size() != 3) { 789 return emitError(unknownLoc, 790 "OpTypeArray must have element type and count parameters"); 791 } 792 793 Type elementTy = getType(operands[1]); 794 if (!elementTy) { 795 return emitError(unknownLoc, "OpTypeArray references undefined <id> ") 796 << operands[1]; 797 } 798 799 unsigned count = 0; 800 // TODO: The count can also come frome a specialization constant. 801 auto countInfo = getConstant(operands[2]); 802 if (!countInfo) { 803 return emitError(unknownLoc, "OpTypeArray count <id> ") 804 << operands[2] << "can only come from normal constant right now"; 805 } 806 807 if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) { 808 count = intVal.getValue().getZExtValue(); 809 } else { 810 return emitError(unknownLoc, "OpTypeArray count must come from a " 811 "scalar integer constant instruction"); 812 } 813 814 typeMap[operands[0]] = spirv::ArrayType::get( 815 elementTy, count, typeDecorations.lookup(operands[0])); 816 return success(); 817 } 818 819 LogicalResult 820 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { 821 assert(!operands.empty() && "No operands for processing function type"); 822 if (operands.size() == 1) { 823 return emitError(unknownLoc, "missing return type for OpTypeFunction"); 824 } 825 auto returnType = getType(operands[1]); 826 if (!returnType) { 827 return emitError(unknownLoc, "unknown return type in OpTypeFunction"); 828 } 829 SmallVector<Type, 1> argTypes; 830 for (size_t i = 2, e = operands.size(); i < e; ++i) { 831 auto ty = getType(operands[i]); 832 if (!ty) { 833 return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); 834 } 835 argTypes.push_back(ty); 836 } 837 ArrayRef<Type> returnTypes; 838 if (!isVoidType(returnType)) { 839 returnTypes = llvm::makeArrayRef(returnType); 840 } 841 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes); 842 return success(); 843 } 844 845 LogicalResult 846 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) { 847 if (operands.size() != 5) { 848 return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " 849 "type and row x column parameters"); 850 } 851 852 Type elementTy = getType(operands[1]); 853 if (!elementTy) { 854 return emitError(unknownLoc, 855 "OpTypeCooperativeMatrix references undefined <id> ") 856 << operands[1]; 857 } 858 859 auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); 860 if (!scope) { 861 return emitError(unknownLoc, 862 "OpTypeCooperativeMatrix references undefined scope <id> ") 863 << operands[2]; 864 } 865 866 unsigned rows = getConstantInt(operands[3]).getInt(); 867 unsigned columns = getConstantInt(operands[4]).getInt(); 868 869 typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( 870 elementTy, scope.getValue(), rows, columns); 871 return success(); 872 } 873 874 LogicalResult 875 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) { 876 if (operands.size() != 2) { 877 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); 878 } 879 Type memberType = getType(operands[1]); 880 if (!memberType) { 881 return emitError(unknownLoc, 882 "OpTypeRuntimeArray references undefined <id> ") 883 << operands[1]; 884 } 885 typeMap[operands[0]] = spirv::RuntimeArrayType::get( 886 memberType, typeDecorations.lookup(operands[0])); 887 return success(); 888 } 889 890 LogicalResult 891 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { 892 // TODO: Find a way to handle identified structs when debug info is stripped. 893 894 if (operands.empty()) { 895 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>"); 896 } 897 898 if (operands.size() == 1) { 899 // Handle empty struct. 900 typeMap[operands[0]] = 901 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str()); 902 return success(); 903 } 904 905 // First element is operand ID, second element is member index in the struct. 906 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes; 907 SmallVector<Type, 4> memberTypes; 908 909 for (auto op : llvm::drop_begin(operands, 1)) { 910 Type memberType = getType(op); 911 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0); 912 913 if (!memberType && !typeForwardPtr) 914 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ") 915 << op; 916 917 if (!memberType) 918 unresolvedMemberTypes.emplace_back(op, memberTypes.size()); 919 920 memberTypes.push_back(memberType); 921 } 922 923 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; 924 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; 925 if (memberDecorationMap.count(operands[0])) { 926 auto &allMemberDecorations = memberDecorationMap[operands[0]]; 927 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) { 928 if (allMemberDecorations.count(memberIndex)) { 929 for (auto &memberDecoration : allMemberDecorations[memberIndex]) { 930 // Check for offset. 931 if (memberDecoration.first == spirv::Decoration::Offset) { 932 // If offset info is empty, resize to the number of members; 933 if (offsetInfo.empty()) { 934 offsetInfo.resize(memberTypes.size()); 935 } 936 offsetInfo[memberIndex] = memberDecoration.second[0]; 937 } else { 938 if (!memberDecoration.second.empty()) { 939 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, 940 memberDecoration.first, 941 memberDecoration.second[0]); 942 } else { 943 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, 944 memberDecoration.first, 0); 945 } 946 } 947 } 948 } 949 } 950 } 951 952 uint32_t structID = operands[0]; 953 std::string structIdentifier = nameMap.lookup(structID).str(); 954 955 if (structIdentifier.empty()) { 956 assert(unresolvedMemberTypes.empty() && 957 "didn't expect unresolved member types"); 958 typeMap[structID] = 959 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); 960 } else { 961 auto structTy = spirv::StructType::getIdentified(context, structIdentifier); 962 typeMap[structID] = structTy; 963 964 if (!unresolvedMemberTypes.empty()) 965 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, 966 memberTypes, offsetInfo, 967 memberDecorationsInfo}); 968 else if (failed(structTy.trySetBody(memberTypes, offsetInfo, 969 memberDecorationsInfo))) 970 return failure(); 971 } 972 973 // TODO: Update StructType to have member name as attribute as 974 // well. 975 return success(); 976 } 977 978 LogicalResult 979 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) { 980 if (operands.size() != 3) { 981 // Three operands are needed: result_id, column_type, and column_count 982 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands" 983 " (result_id, column_type, and column_count)"); 984 } 985 // Matrix columns must be of vector type 986 Type elementTy = getType(operands[1]); 987 if (!elementTy) { 988 return emitError(unknownLoc, 989 "OpTypeMatrix references undefined column type.") 990 << operands[1]; 991 } 992 993 uint32_t colsCount = operands[2]; 994 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount); 995 return success(); 996 } 997 998 LogicalResult 999 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) { 1000 if (operands.size() != 2) 1001 return emitError(unknownLoc, 1002 "OpTypeForwardPointer instruction must have two operands"); 1003 1004 typeForwardPointerIDs.insert(operands[0]); 1005 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer 1006 // instruction that defines the actual type. 1007 1008 return success(); 1009 } 1010 1011 LogicalResult 1012 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) { 1013 // TODO: Add support for Access Qualifier. 1014 if (operands.size() != 8) 1015 return emitError( 1016 unknownLoc, 1017 "OpTypeImage with non-eight operands are not supported yet"); 1018 1019 Type elementTy = getType(operands[1]); 1020 if (!elementTy) 1021 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ") 1022 << operands[1]; 1023 1024 auto dim = spirv::symbolizeDim(operands[2]); 1025 if (!dim) 1026 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ") 1027 << operands[2]; 1028 1029 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]); 1030 if (!depthInfo) 1031 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ") 1032 << operands[3]; 1033 1034 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]); 1035 if (!arrayedInfo) 1036 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ") 1037 << operands[4]; 1038 1039 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]); 1040 if (!samplingInfo) 1041 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5]; 1042 1043 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]); 1044 if (!samplerUseInfo) 1045 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ") 1046 << operands[6]; 1047 1048 auto format = spirv::symbolizeImageFormat(operands[7]); 1049 if (!format) 1050 return emitError(unknownLoc, "unknown Format for OpTypeImage: ") 1051 << operands[7]; 1052 1053 typeMap[operands[0]] = spirv::ImageType::get( 1054 elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(), 1055 samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue()); 1056 return success(); 1057 } 1058 1059 LogicalResult 1060 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) { 1061 if (operands.size() != 2) 1062 return emitError(unknownLoc, "OpTypeSampledImage must have two operands"); 1063 1064 Type elementTy = getType(operands[1]); 1065 if (!elementTy) 1066 return emitError(unknownLoc, 1067 "OpTypeSampledImage references undefined <id>: ") 1068 << operands[1]; 1069 1070 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy); 1071 return success(); 1072 } 1073 1074 //===----------------------------------------------------------------------===// 1075 // Constant 1076 //===----------------------------------------------------------------------===// 1077 1078 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands, 1079 bool isSpec) { 1080 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant"; 1081 1082 if (operands.size() < 2) { 1083 return emitError(unknownLoc) 1084 << opname << " must have type <id> and result <id>"; 1085 } 1086 if (operands.size() < 3) { 1087 return emitError(unknownLoc) 1088 << opname << " must have at least 1 more parameter"; 1089 } 1090 1091 Type resultType = getType(operands[0]); 1092 if (!resultType) { 1093 return emitError(unknownLoc, "undefined result type from <id> ") 1094 << operands[0]; 1095 } 1096 1097 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { 1098 if (bitwidth == 64) { 1099 if (operands.size() == 4) { 1100 return success(); 1101 } 1102 return emitError(unknownLoc) 1103 << opname << " should have 2 parameters for 64-bit values"; 1104 } 1105 if (bitwidth <= 32) { 1106 if (operands.size() == 3) { 1107 return success(); 1108 } 1109 1110 return emitError(unknownLoc) 1111 << opname 1112 << " should have 1 parameter for values with no more than 32 bits"; 1113 } 1114 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ") 1115 << bitwidth; 1116 }; 1117 1118 auto resultID = operands[1]; 1119 1120 if (auto intType = resultType.dyn_cast<IntegerType>()) { 1121 auto bitwidth = intType.getWidth(); 1122 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1123 return failure(); 1124 } 1125 1126 APInt value; 1127 if (bitwidth == 64) { 1128 // 64-bit integers are represented with two SPIR-V words. According to 1129 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1130 // literal’s low-order words appear first." 1131 struct DoubleWord { 1132 uint32_t word1; 1133 uint32_t word2; 1134 } words = {operands[2], operands[3]}; 1135 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true); 1136 } else if (bitwidth <= 32) { 1137 value = APInt(bitwidth, operands[2], /*isSigned=*/true); 1138 } 1139 1140 auto attr = opBuilder.getIntegerAttr(intType, value); 1141 1142 if (isSpec) { 1143 createSpecConstant(unknownLoc, resultID, attr); 1144 } else { 1145 // For normal constants, we just record the attribute (and its type) for 1146 // later materialization at use sites. 1147 constantMap.try_emplace(resultID, attr, intType); 1148 } 1149 1150 return success(); 1151 } 1152 1153 if (auto floatType = resultType.dyn_cast<FloatType>()) { 1154 auto bitwidth = floatType.getWidth(); 1155 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1156 return failure(); 1157 } 1158 1159 APFloat value(0.f); 1160 if (floatType.isF64()) { 1161 // Double values are represented with two SPIR-V words. According to 1162 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1163 // literal’s low-order words appear first." 1164 struct DoubleWord { 1165 uint32_t word1; 1166 uint32_t word2; 1167 } words = {operands[2], operands[3]}; 1168 value = APFloat(llvm::bit_cast<double>(words)); 1169 } else if (floatType.isF32()) { 1170 value = APFloat(llvm::bit_cast<float>(operands[2])); 1171 } else if (floatType.isF16()) { 1172 APInt data(16, operands[2]); 1173 value = APFloat(APFloat::IEEEhalf(), data); 1174 } 1175 1176 auto attr = opBuilder.getFloatAttr(floatType, value); 1177 if (isSpec) { 1178 createSpecConstant(unknownLoc, resultID, attr); 1179 } else { 1180 // For normal constants, we just record the attribute (and its type) for 1181 // later materialization at use sites. 1182 constantMap.try_emplace(resultID, attr, floatType); 1183 } 1184 1185 return success(); 1186 } 1187 1188 return emitError(unknownLoc, "OpConstant can only generate values of " 1189 "scalar integer or floating-point type"); 1190 } 1191 1192 LogicalResult spirv::Deserializer::processConstantBool( 1193 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) { 1194 if (operands.size() != 2) { 1195 return emitError(unknownLoc, "Op") 1196 << (isSpec ? "Spec" : "") << "Constant" 1197 << (isTrue ? "True" : "False") 1198 << " must have type <id> and result <id>"; 1199 } 1200 1201 auto attr = opBuilder.getBoolAttr(isTrue); 1202 auto resultID = operands[1]; 1203 if (isSpec) { 1204 createSpecConstant(unknownLoc, resultID, attr); 1205 } else { 1206 // For normal constants, we just record the attribute (and its type) for 1207 // later materialization at use sites. 1208 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type()); 1209 } 1210 1211 return success(); 1212 } 1213 1214 LogicalResult 1215 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { 1216 if (operands.size() < 2) { 1217 return emitError(unknownLoc, 1218 "OpConstantComposite must have type <id> and result <id>"); 1219 } 1220 if (operands.size() < 3) { 1221 return emitError(unknownLoc, 1222 "OpConstantComposite must have at least 1 parameter"); 1223 } 1224 1225 Type resultType = getType(operands[0]); 1226 if (!resultType) { 1227 return emitError(unknownLoc, "undefined result type from <id> ") 1228 << operands[0]; 1229 } 1230 1231 SmallVector<Attribute, 4> elements; 1232 elements.reserve(operands.size() - 2); 1233 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1234 auto elementInfo = getConstant(operands[i]); 1235 if (!elementInfo) { 1236 return emitError(unknownLoc, "OpConstantComposite component <id> ") 1237 << operands[i] << " must come from a normal constant"; 1238 } 1239 elements.push_back(elementInfo->first); 1240 } 1241 1242 auto resultID = operands[1]; 1243 if (auto vectorType = resultType.dyn_cast<VectorType>()) { 1244 auto attr = DenseElementsAttr::get(vectorType, elements); 1245 // For normal constants, we just record the attribute (and its type) for 1246 // later materialization at use sites. 1247 constantMap.try_emplace(resultID, attr, resultType); 1248 } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) { 1249 auto attr = opBuilder.getArrayAttr(elements); 1250 constantMap.try_emplace(resultID, attr, resultType); 1251 } else { 1252 return emitError(unknownLoc, "unsupported OpConstantComposite type: ") 1253 << resultType; 1254 } 1255 1256 return success(); 1257 } 1258 1259 LogicalResult 1260 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) { 1261 if (operands.size() < 2) { 1262 return emitError(unknownLoc, 1263 "OpConstantComposite must have type <id> and result <id>"); 1264 } 1265 if (operands.size() < 3) { 1266 return emitError(unknownLoc, 1267 "OpConstantComposite must have at least 1 parameter"); 1268 } 1269 1270 Type resultType = getType(operands[0]); 1271 if (!resultType) { 1272 return emitError(unknownLoc, "undefined result type from <id> ") 1273 << operands[0]; 1274 } 1275 1276 auto resultID = operands[1]; 1277 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 1278 1279 SmallVector<Attribute, 4> elements; 1280 elements.reserve(operands.size() - 2); 1281 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1282 auto elementInfo = getSpecConstant(operands[i]); 1283 elements.push_back(SymbolRefAttr::get(elementInfo)); 1284 } 1285 1286 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>( 1287 unknownLoc, TypeAttr::get(resultType), symName, 1288 opBuilder.getArrayAttr(elements)); 1289 specConstCompositeMap[resultID] = op; 1290 1291 return success(); 1292 } 1293 1294 LogicalResult 1295 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) { 1296 if (operands.size() < 3) 1297 return emitError(unknownLoc, "OpConstantOperation must have type <id>, " 1298 "result <id>, and operand opcode"); 1299 1300 uint32_t resultTypeID = operands[0]; 1301 1302 if (!getType(resultTypeID)) 1303 return emitError(unknownLoc, "undefined result type from <id> ") 1304 << resultTypeID; 1305 1306 uint32_t resultID = operands[1]; 1307 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]); 1308 auto emplaceResult = specConstOperationMap.try_emplace( 1309 resultID, 1310 SpecConstOperationMaterializationInfo{ 1311 enclosedOpcode, resultTypeID, 1312 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}}); 1313 1314 if (!emplaceResult.second) 1315 return emitError(unknownLoc, "value with <id>: ") 1316 << resultID << " is probably defined before."; 1317 1318 return success(); 1319 } 1320 1321 Value spirv::Deserializer::materializeSpecConstantOperation( 1322 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, 1323 ArrayRef<uint32_t> enclosedOpOperands) { 1324 1325 Type resultType = getType(resultTypeID); 1326 1327 // Instructions wrapped by OpSpecConstantOp need an ID for their 1328 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V 1329 // dialect wrapped op. For that purpose, a new value map is created and "fake" 1330 // ID in that map is assigned to the result of the enclosed instruction. Note 1331 // that there is no need to update this fake ID since we only need to 1332 // reference the created Value for the enclosed op from the spv::YieldOp 1333 // created later in this method (both of which are the only values in their 1334 // region: the SpecConstantOperation's region). If we encounter another 1335 // SpecConstantOperation in the module, we simply re-use the fake ID since the 1336 // previous Value assigned to it isn't visible in the current scope anyway. 1337 DenseMap<uint32_t, Value> newValueMap; 1338 llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap, 1339 newValueMap); 1340 constexpr uint32_t fakeID = static_cast<uint32_t>(-3); 1341 1342 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands; 1343 enclosedOpResultTypeAndOperands.push_back(resultTypeID); 1344 enclosedOpResultTypeAndOperands.push_back(fakeID); 1345 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(), 1346 enclosedOpOperands.end()); 1347 1348 // Process enclosed instruction before creating the enclosing 1349 // specConstantOperation (and its region). This way, references to constants, 1350 // global variables, and spec constants will be materialized outside the new 1351 // op's region. For more info, see Deserializer::getValue's implementation. 1352 if (failed( 1353 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) 1354 return Value(); 1355 1356 // Since the enclosed op is emitted in the current block, split it in a 1357 // separate new block. 1358 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back()); 1359 1360 auto loc = createFileLineColLoc(opBuilder); 1361 auto specConstOperationOp = 1362 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); 1363 1364 Region &body = specConstOperationOp.body(); 1365 // Move the new block into SpecConstantOperation's body. 1366 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), 1367 Region::iterator(enclosedBlock)); 1368 Block &block = body.back(); 1369 1370 // RAII guard to reset the insertion point to the module's region after 1371 // deserializing the body of the specConstantOperation. 1372 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 1373 opBuilder.setInsertionPointToEnd(&block); 1374 1375 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0)); 1376 return specConstOperationOp.getResult(); 1377 } 1378 1379 LogicalResult 1380 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { 1381 if (operands.size() != 2) { 1382 return emitError(unknownLoc, 1383 "OpConstantNull must have type <id> and result <id>"); 1384 } 1385 1386 Type resultType = getType(operands[0]); 1387 if (!resultType) { 1388 return emitError(unknownLoc, "undefined result type from <id> ") 1389 << operands[0]; 1390 } 1391 1392 auto resultID = operands[1]; 1393 if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) { 1394 auto attr = opBuilder.getZeroAttr(resultType); 1395 // For normal constants, we just record the attribute (and its type) for 1396 // later materialization at use sites. 1397 constantMap.try_emplace(resultID, attr, resultType); 1398 return success(); 1399 } 1400 1401 return emitError(unknownLoc, "unsupported OpConstantNull type: ") 1402 << resultType; 1403 } 1404 1405 //===----------------------------------------------------------------------===// 1406 // Control flow 1407 //===----------------------------------------------------------------------===// 1408 1409 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) { 1410 if (auto *block = getBlock(id)) { 1411 LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id 1412 << " @ " << block << "\n"); 1413 return block; 1414 } 1415 1416 // We don't know where this block will be placed finally (in a 1417 // spv.mlir.selection or spv.mlir.loop or function). Create it into the 1418 // function for now and sort out the proper place later. 1419 auto *block = curFunction->addBlock(); 1420 LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ " 1421 << block << "\n"); 1422 return blockMap[id] = block; 1423 } 1424 1425 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) { 1426 if (!curBlock) { 1427 return emitError(unknownLoc, "OpBranch must appear inside a block"); 1428 } 1429 1430 if (operands.size() != 1) { 1431 return emitError(unknownLoc, "OpBranch must take exactly one target label"); 1432 } 1433 1434 auto *target = getOrCreateBlock(operands[0]); 1435 auto loc = createFileLineColLoc(opBuilder); 1436 // The preceding instruction for the OpBranch instruction could be an 1437 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have 1438 // the same OpLine information. 1439 opBuilder.create<spirv::BranchOp>(loc, target); 1440 1441 (void)clearDebugLine(); 1442 return success(); 1443 } 1444 1445 LogicalResult 1446 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) { 1447 if (!curBlock) { 1448 return emitError(unknownLoc, 1449 "OpBranchConditional must appear inside a block"); 1450 } 1451 1452 if (operands.size() != 3 && operands.size() != 5) { 1453 return emitError(unknownLoc, 1454 "OpBranchConditional must have condition, true label, " 1455 "false label, and optionally two branch weights"); 1456 } 1457 1458 auto condition = getValue(operands[0]); 1459 auto *trueBlock = getOrCreateBlock(operands[1]); 1460 auto *falseBlock = getOrCreateBlock(operands[2]); 1461 1462 Optional<std::pair<uint32_t, uint32_t>> weights; 1463 if (operands.size() == 5) { 1464 weights = std::make_pair(operands[3], operands[4]); 1465 } 1466 // The preceding instruction for the OpBranchConditional instruction could be 1467 // an OpSelectionMerge instruction, in this case they will have the same 1468 // OpLine information. 1469 auto loc = createFileLineColLoc(opBuilder); 1470 opBuilder.create<spirv::BranchConditionalOp>( 1471 loc, condition, trueBlock, 1472 /*trueArguments=*/ArrayRef<Value>(), falseBlock, 1473 /*falseArguments=*/ArrayRef<Value>(), weights); 1474 1475 (void)clearDebugLine(); 1476 return success(); 1477 } 1478 1479 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) { 1480 if (!curFunction) { 1481 return emitError(unknownLoc, "OpLabel must appear inside a function"); 1482 } 1483 1484 if (operands.size() != 1) { 1485 return emitError(unknownLoc, "OpLabel should only have result <id>"); 1486 } 1487 1488 auto labelID = operands[0]; 1489 // We may have forward declared this block. 1490 auto *block = getOrCreateBlock(labelID); 1491 LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n"); 1492 // If we have seen this block, make sure it was just a forward declaration. 1493 assert(block->empty() && "re-deserialize the same block!"); 1494 1495 opBuilder.setInsertionPointToStart(block); 1496 blockMap[labelID] = curBlock = block; 1497 1498 return success(); 1499 } 1500 1501 LogicalResult 1502 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) { 1503 if (!curBlock) { 1504 return emitError(unknownLoc, "OpSelectionMerge must appear in a block"); 1505 } 1506 1507 if (operands.size() < 2) { 1508 return emitError( 1509 unknownLoc, 1510 "OpSelectionMerge must specify merge target and selection control"); 1511 } 1512 1513 auto *mergeBlock = getOrCreateBlock(operands[0]); 1514 auto loc = createFileLineColLoc(opBuilder); 1515 auto selectionControl = operands[1]; 1516 1517 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock) 1518 .second) { 1519 return emitError( 1520 unknownLoc, 1521 "a block cannot have more than one OpSelectionMerge instruction"); 1522 } 1523 1524 return success(); 1525 } 1526 1527 LogicalResult 1528 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) { 1529 if (!curBlock) { 1530 return emitError(unknownLoc, "OpLoopMerge must appear in a block"); 1531 } 1532 1533 if (operands.size() < 3) { 1534 return emitError(unknownLoc, "OpLoopMerge must specify merge target, " 1535 "continue target and loop control"); 1536 } 1537 1538 auto *mergeBlock = getOrCreateBlock(operands[0]); 1539 auto *continueBlock = getOrCreateBlock(operands[1]); 1540 auto loc = createFileLineColLoc(opBuilder); 1541 uint32_t loopControl = operands[2]; 1542 1543 if (!blockMergeInfo 1544 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock) 1545 .second) { 1546 return emitError( 1547 unknownLoc, 1548 "a block cannot have more than one OpLoopMerge instruction"); 1549 } 1550 1551 return success(); 1552 } 1553 1554 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { 1555 if (!curBlock) { 1556 return emitError(unknownLoc, "OpPhi must appear in a block"); 1557 } 1558 1559 if (operands.size() < 4) { 1560 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, " 1561 "and variable-parent pairs"); 1562 } 1563 1564 // Create a block argument for this OpPhi instruction. 1565 Type blockArgType = getType(operands[0]); 1566 BlockArgument blockArg = curBlock->addArgument(blockArgType); 1567 valueMap[operands[1]] = blockArg; 1568 LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg 1569 << " id = " << operands[1] << " of type " 1570 << blockArgType << '\n'); 1571 1572 // For each (value, predecessor) pair, insert the value to the predecessor's 1573 // blockPhiInfo entry so later we can fix the block argument there. 1574 for (unsigned i = 2, e = operands.size(); i < e; i += 2) { 1575 uint32_t value = operands[i]; 1576 Block *predecessor = getOrCreateBlock(operands[i + 1]); 1577 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock}; 1578 blockPhiInfo[predecessorTargetPair].push_back(value); 1579 LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor 1580 << " with arg id = " << value << '\n'); 1581 } 1582 1583 return success(); 1584 } 1585 1586 namespace { 1587 /// A class for putting all blocks in a structured selection/loop in a 1588 /// spv.mlir.selection/spv.mlir.loop op. 1589 class ControlFlowStructurizer { 1590 public: 1591 /// Structurizes the loop at the given `headerBlock`. 1592 /// 1593 /// This method will create an spv.mlir.loop op in the `mergeBlock` and move 1594 /// all blocks in the structured loop into the spv.mlir.loop's region. All 1595 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This 1596 /// method will also update `mergeInfo` by remapping all blocks inside to the 1597 /// newly cloned ones inside structured control flow op's regions. 1598 static LogicalResult structurize(Location loc, uint32_t control, 1599 spirv::BlockMergeInfoMap &mergeInfo, 1600 Block *headerBlock, Block *mergeBlock, 1601 Block *continueBlock) { 1602 return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock, 1603 mergeBlock, continueBlock) 1604 .structurizeImpl(); 1605 } 1606 1607 private: 1608 ControlFlowStructurizer(Location loc, uint32_t control, 1609 spirv::BlockMergeInfoMap &mergeInfo, Block *header, 1610 Block *merge, Block *cont) 1611 : location(loc), control(control), blockMergeInfo(mergeInfo), 1612 headerBlock(header), mergeBlock(merge), continueBlock(cont) {} 1613 1614 /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`. 1615 spirv::SelectionOp createSelectionOp(uint32_t selectionControl); 1616 1617 /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`. 1618 spirv::LoopOp createLoopOp(uint32_t loopControl); 1619 1620 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`. 1621 void collectBlocksInConstruct(); 1622 1623 LogicalResult structurizeImpl(); 1624 1625 Location location; 1626 uint32_t control; 1627 1628 spirv::BlockMergeInfoMap &blockMergeInfo; 1629 1630 Block *headerBlock; 1631 Block *mergeBlock; 1632 Block *continueBlock; // nullptr for spv.mlir.selection 1633 1634 SetVector<Block *> constructBlocks; 1635 }; 1636 } // namespace 1637 1638 spirv::SelectionOp 1639 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { 1640 // Create a builder and set the insertion point to the beginning of the 1641 // merge block so that the newly created SelectionOp will be inserted there. 1642 OpBuilder builder(&mergeBlock->front()); 1643 1644 auto control = static_cast<spirv::SelectionControl>(selectionControl); 1645 auto selectionOp = builder.create<spirv::SelectionOp>(location, control); 1646 selectionOp.addMergeBlock(); 1647 1648 return selectionOp; 1649 } 1650 1651 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { 1652 // Create a builder and set the insertion point to the beginning of the 1653 // merge block so that the newly created LoopOp will be inserted there. 1654 OpBuilder builder(&mergeBlock->front()); 1655 1656 auto control = static_cast<spirv::LoopControl>(loopControl); 1657 auto loopOp = builder.create<spirv::LoopOp>(location, control); 1658 loopOp.addEntryAndMergeBlock(); 1659 1660 return loopOp; 1661 } 1662 1663 void ControlFlowStructurizer::collectBlocksInConstruct() { 1664 assert(constructBlocks.empty() && "expected empty constructBlocks"); 1665 1666 // Put the header block in the work list first. 1667 constructBlocks.insert(headerBlock); 1668 1669 // For each item in the work list, add its successors excluding the merge 1670 // block. 1671 for (unsigned i = 0; i < constructBlocks.size(); ++i) { 1672 for (auto *successor : constructBlocks[i]->getSuccessors()) 1673 if (successor != mergeBlock) 1674 constructBlocks.insert(successor); 1675 } 1676 } 1677 1678 LogicalResult ControlFlowStructurizer::structurizeImpl() { 1679 Operation *op = nullptr; 1680 bool isLoop = continueBlock != nullptr; 1681 if (isLoop) { 1682 if (auto loopOp = createLoopOp(control)) 1683 op = loopOp.getOperation(); 1684 } else { 1685 if (auto selectionOp = createSelectionOp(control)) 1686 op = selectionOp.getOperation(); 1687 } 1688 if (!op) 1689 return failure(); 1690 Region &body = op->getRegion(0); 1691 1692 BlockAndValueMapping mapper; 1693 // All references to the old merge block should be directed to the 1694 // selection/loop merge block in the SelectionOp/LoopOp's region. 1695 mapper.map(mergeBlock, &body.back()); 1696 1697 collectBlocksInConstruct(); 1698 1699 // We've identified all blocks belonging to the selection/loop's region. Now 1700 // need to "move" them into the selection/loop. Instead of really moving the 1701 // blocks, in the following we copy them and remap all values and branches. 1702 // This is because: 1703 // * Inserting a block into a region requires the block not in any region 1704 // before. But selections/loops can nest so we can create selection/loop ops 1705 // in a nested manner, which means some blocks may already be in a 1706 // selection/loop region when to be moved again. 1707 // * It's much trickier to fix up the branches into and out of the loop's 1708 // region: we need to treat not-moved blocks and moved blocks differently: 1709 // Not-moved blocks jumping to the loop header block need to jump to the 1710 // merge point containing the new loop op but not the loop continue block's 1711 // back edge. Moved blocks jumping out of the loop need to jump to the 1712 // merge block inside the loop region but not other not-moved blocks. 1713 // We cannot use replaceAllUsesWith clearly and it's harder to follow the 1714 // logic. 1715 1716 // Create a corresponding block in the SelectionOp/LoopOp's region for each 1717 // block in this loop construct. 1718 OpBuilder builder(body); 1719 for (auto *block : constructBlocks) { 1720 // Create a block and insert it before the selection/loop merge block in the 1721 // SelectionOp/LoopOp's region. 1722 auto *newBlock = builder.createBlock(&body.back()); 1723 mapper.map(block, newBlock); 1724 LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock 1725 << " from block " << block << "\n"); 1726 if (!isFnEntryBlock(block)) { 1727 for (BlockArgument blockArg : block->getArguments()) { 1728 auto newArg = newBlock->addArgument(blockArg.getType()); 1729 mapper.map(blockArg, newArg); 1730 LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg 1731 << " to " << newArg << '\n'); 1732 } 1733 } else { 1734 LLVM_DEBUG(llvm::dbgs() 1735 << "[cf] block " << block << " is a function entry block\n"); 1736 } 1737 for (auto &op : *block) 1738 newBlock->push_back(op.clone(mapper)); 1739 } 1740 1741 // Go through all ops and remap the operands. 1742 auto remapOperands = [&](Operation *op) { 1743 for (auto &operand : op->getOpOperands()) 1744 if (Value mappedOp = mapper.lookupOrNull(operand.get())) 1745 operand.set(mappedOp); 1746 for (auto &succOp : op->getBlockOperands()) 1747 if (Block *mappedOp = mapper.lookupOrNull(succOp.get())) 1748 succOp.set(mappedOp); 1749 }; 1750 for (auto &block : body) { 1751 block.walk(remapOperands); 1752 } 1753 1754 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to 1755 // the selection/loop construct into its region. Next we need to fix the 1756 // connections between this new SelectionOp/LoopOp with existing blocks. 1757 1758 // All existing incoming branches should go to the merge block, where the 1759 // SelectionOp/LoopOp resides right now. 1760 headerBlock->replaceAllUsesWith(mergeBlock); 1761 1762 if (isLoop) { 1763 // The loop selection/loop header block may have block arguments. Since now 1764 // we place the selection/loop op inside the old merge block, we need to 1765 // make sure the old merge block has the same block argument list. 1766 assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); 1767 for (BlockArgument blockArg : headerBlock->getArguments()) { 1768 mergeBlock->addArgument(blockArg.getType()); 1769 } 1770 1771 // If the loop header block has block arguments, make sure the spv.branch op 1772 // matches. 1773 SmallVector<Value, 4> blockArgs; 1774 if (!headerBlock->args_empty()) 1775 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; 1776 1777 // The loop entry block should have a unconditional branch jumping to the 1778 // loop header block. 1779 builder.setInsertionPointToEnd(&body.front()); 1780 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock), 1781 ArrayRef<Value>(blockArgs)); 1782 } 1783 1784 // All the blocks cloned into the SelectionOp/LoopOp's region can now be 1785 // cleaned up. 1786 LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n"); 1787 // First we need to drop all operands' references inside all blocks. This is 1788 // needed because we can have blocks referencing SSA values from one another. 1789 for (auto *block : constructBlocks) 1790 block->dropAllReferences(); 1791 1792 // Then erase all old blocks. 1793 for (auto *block : constructBlocks) { 1794 // We've cloned all blocks belonging to this construct into the structured 1795 // control flow op's region. Among these blocks, some may compose another 1796 // selection/loop. If so, they will be recorded within blockMergeInfo. 1797 // We need to update the pointers there to the newly remapped ones so we can 1798 // continue structurizing them later. 1799 // TODO: The asserts in the following assumes input SPIR-V blob 1800 // forms correctly nested selection/loop constructs. We should relax this 1801 // and support error cases better. 1802 auto it = blockMergeInfo.find(block); 1803 if (it != blockMergeInfo.end()) { 1804 Block *newHeader = mapper.lookupOrNull(block); 1805 assert(newHeader && "nested loop header block should be remapped!"); 1806 1807 Block *newContinue = it->second.continueBlock; 1808 if (newContinue) { 1809 newContinue = mapper.lookupOrNull(newContinue); 1810 assert(newContinue && "nested loop continue block should be remapped!"); 1811 } 1812 1813 Block *newMerge = it->second.mergeBlock; 1814 if (Block *mappedTo = mapper.lookupOrNull(newMerge)) 1815 newMerge = mappedTo; 1816 1817 // Keep original location for nested selection/loop ops. 1818 Location loc = it->second.loc; 1819 // The iterator should be erased before adding a new entry into 1820 // blockMergeInfo to avoid iterator invalidation. 1821 blockMergeInfo.erase(it); 1822 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge, 1823 newContinue); 1824 } 1825 1826 // The structured selection/loop's entry block does not have arguments. 1827 // If the function's header block is also part of the structured control 1828 // flow, we cannot just simply erase it because it may contain arguments 1829 // matching the function signature and used by the cloned blocks. 1830 if (isFnEntryBlock(block)) { 1831 LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block 1832 << " to only contain a spv.Branch op\n"); 1833 // Still keep the function entry block for the potential block arguments, 1834 // but replace all ops inside with a branch to the merge block. 1835 block->clear(); 1836 builder.setInsertionPointToEnd(block); 1837 builder.create<spirv::BranchOp>(location, mergeBlock); 1838 } else { 1839 LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n"); 1840 block->erase(); 1841 } 1842 } 1843 1844 LLVM_DEBUG( 1845 llvm::dbgs() << "[cf] after structurizing construct with header block " 1846 << headerBlock << ":\n" 1847 << *op << '\n'); 1848 1849 return success(); 1850 } 1851 1852 LogicalResult spirv::Deserializer::wireUpBlockArgument() { 1853 LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n"); 1854 1855 OpBuilder::InsertionGuard guard(opBuilder); 1856 1857 for (const auto &info : blockPhiInfo) { 1858 Block *block = info.first.first; 1859 Block *target = info.first.second; 1860 const BlockPhiInfo &phiInfo = info.second; 1861 LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); 1862 LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); 1863 LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); 1864 LLVM_DEBUG(llvm::dbgs() << '\n'); 1865 1866 // Set insertion point to before this block's terminator early because we 1867 // may materialize ops via getValue() call. 1868 auto *op = block->getTerminator(); 1869 opBuilder.setInsertionPoint(op); 1870 1871 SmallVector<Value, 4> blockArgs; 1872 blockArgs.reserve(phiInfo.size()); 1873 for (uint32_t valueId : phiInfo) { 1874 if (Value value = getValue(valueId)) { 1875 blockArgs.push_back(value); 1876 LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value 1877 << " id = " << valueId << '\n'); 1878 } else { 1879 return emitError(unknownLoc, "OpPhi references undefined value!"); 1880 } 1881 } 1882 1883 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { 1884 // Replace the previous branch op with a new one with block arguments. 1885 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(), 1886 blockArgs); 1887 branchOp.erase(); 1888 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) { 1889 assert((branchCondOp.getTrueBlock() == target || 1890 branchCondOp.getFalseBlock() == target) && 1891 "expected target to be either the true or false target"); 1892 if (target == branchCondOp.trueTarget()) 1893 opBuilder.create<spirv::BranchConditionalOp>( 1894 branchCondOp.getLoc(), branchCondOp.condition(), blockArgs, 1895 branchCondOp.getFalseBlockArguments(), 1896 branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(), 1897 branchCondOp.falseTarget()); 1898 else 1899 opBuilder.create<spirv::BranchConditionalOp>( 1900 branchCondOp.getLoc(), branchCondOp.condition(), 1901 branchCondOp.getTrueBlockArguments(), blockArgs, 1902 branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(), 1903 branchCondOp.getFalseBlock()); 1904 1905 branchCondOp.erase(); 1906 } else { 1907 return emitError(unknownLoc, "unimplemented terminator for Phi creation"); 1908 } 1909 1910 LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n"); 1911 LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); 1912 LLVM_DEBUG(llvm::dbgs() << '\n'); 1913 } 1914 blockPhiInfo.clear(); 1915 1916 LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n"); 1917 return success(); 1918 } 1919 1920 LogicalResult spirv::Deserializer::structurizeControlFlow() { 1921 LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); 1922 1923 while (!blockMergeInfo.empty()) { 1924 Block *headerBlock = blockMergeInfo.begin()->first; 1925 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; 1926 1927 LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n"); 1928 LLVM_DEBUG(headerBlock->print(llvm::dbgs())); 1929 1930 auto *mergeBlock = mergeInfo.mergeBlock; 1931 assert(mergeBlock && "merge block cannot be nullptr"); 1932 if (!mergeBlock->args_empty()) 1933 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); 1934 LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n"); 1935 LLVM_DEBUG(mergeBlock->print(llvm::dbgs())); 1936 1937 auto *continueBlock = mergeInfo.continueBlock; 1938 if (continueBlock) { 1939 LLVM_DEBUG(llvm::dbgs() 1940 << "[cf] continue block " << continueBlock << ":\n"); 1941 LLVM_DEBUG(continueBlock->print(llvm::dbgs())); 1942 } 1943 // Erase this case before calling into structurizer, who will update 1944 // blockMergeInfo. 1945 blockMergeInfo.erase(blockMergeInfo.begin()); 1946 if (failed(ControlFlowStructurizer::structurize( 1947 mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock, 1948 mergeBlock, continueBlock))) 1949 return failure(); 1950 } 1951 1952 LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n"); 1953 return success(); 1954 } 1955 1956 //===----------------------------------------------------------------------===// 1957 // Debug 1958 //===----------------------------------------------------------------------===// 1959 1960 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) { 1961 if (!debugLine) 1962 return unknownLoc; 1963 1964 auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); 1965 if (fileName.empty()) 1966 fileName = "<unknown>"; 1967 return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line, 1968 debugLine->col); 1969 } 1970 1971 LogicalResult 1972 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) { 1973 // According to SPIR-V spec: 1974 // "This location information applies to the instructions physically 1975 // following this instruction, up to the first occurrence of any of the 1976 // following: the next end of block, the next OpLine instruction, or the next 1977 // OpNoLine instruction." 1978 if (operands.size() != 3) 1979 return emitError(unknownLoc, "OpLine must have 3 operands"); 1980 debugLine = DebugLine(operands[0], operands[1], operands[2]); 1981 return success(); 1982 } 1983 1984 LogicalResult spirv::Deserializer::clearDebugLine() { 1985 debugLine = llvm::None; 1986 return success(); 1987 } 1988 1989 LogicalResult 1990 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) { 1991 if (operands.size() < 2) 1992 return emitError(unknownLoc, "OpString needs at least 2 operands"); 1993 1994 if (!debugInfoMap.lookup(operands[0]).empty()) 1995 return emitError(unknownLoc, 1996 "duplicate debug string found for result <id> ") 1997 << operands[0]; 1998 1999 unsigned wordIndex = 1; 2000 StringRef debugString = decodeStringLiteral(operands, wordIndex); 2001 if (wordIndex != operands.size()) 2002 return emitError(unknownLoc, 2003 "unexpected trailing words in OpString instruction"); 2004 2005 debugInfoMap[operands[0]] = debugString; 2006 return success(); 2007 } 2008