1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/Support/LogicalResult.h" 23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/Sequence.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/bit.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/SaveAndRestore.h" 31 #include "llvm/Support/raw_ostream.h" 32 33 using namespace mlir; 34 35 #define DEBUG_TYPE "spirv-deserialization" 36 37 //===----------------------------------------------------------------------===// 38 // Utility Functions 39 //===----------------------------------------------------------------------===// 40 41 /// Returns true if the given `block` is a function entry block. 42 static inline bool isFnEntryBlock(Block *block) { 43 return block->isEntryBlock() && 44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp()); 45 } 46 47 //===----------------------------------------------------------------------===// 48 // Deserializer Method Definitions 49 //===----------------------------------------------------------------------===// 50 51 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary, 52 MLIRContext *context) 53 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), 54 module(createModuleOp()), opBuilder(module->getRegion()) {} 55 56 LogicalResult spirv::Deserializer::deserialize() { 57 LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); 58 59 if (failed(processHeader())) 60 return failure(); 61 62 spirv::Opcode opcode = spirv::Opcode::OpNop; 63 ArrayRef<uint32_t> operands; 64 auto binarySize = binary.size(); 65 while (curOffset < binarySize) { 66 // Slice the next instruction out and populate `opcode` and `operands`. 67 // Internally this also updates `curOffset`. 68 if (failed(sliceInstruction(opcode, operands))) 69 return failure(); 70 71 if (failed(processInstruction(opcode, operands))) 72 return failure(); 73 } 74 75 assert(curOffset == binarySize && 76 "deserializer should never index beyond the binary end"); 77 78 for (auto &deferred : deferredInstructions) { 79 if (failed(processInstruction(deferred.first, deferred.second, false))) { 80 return failure(); 81 } 82 } 83 84 attachVCETriple(); 85 86 LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n"); 87 return success(); 88 } 89 90 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() { 91 return std::move(module); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Module structure 96 //===----------------------------------------------------------------------===// 97 98 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() { 99 OpBuilder builder(context); 100 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); 101 spirv::ModuleOp::build(builder, state); 102 return cast<spirv::ModuleOp>(Operation::create(state)); 103 } 104 105 LogicalResult spirv::Deserializer::processHeader() { 106 if (binary.size() < spirv::kHeaderWordCount) 107 return emitError(unknownLoc, 108 "SPIR-V binary module must have a 5-word header"); 109 110 if (binary[0] != spirv::kMagicNumber) 111 return emitError(unknownLoc, "incorrect magic number"); 112 113 // Version number bytes: 0 | major number | minor number | 0 114 uint32_t majorVersion = (binary[1] << 8) >> 24; 115 uint32_t minorVersion = (binary[1] << 16) >> 24; 116 if (majorVersion == 1) { 117 switch (minorVersion) { 118 #define MIN_VERSION_CASE(v) \ 119 case v: \ 120 version = spirv::Version::V_1_##v; \ 121 break 122 123 MIN_VERSION_CASE(0); 124 MIN_VERSION_CASE(1); 125 MIN_VERSION_CASE(2); 126 MIN_VERSION_CASE(3); 127 MIN_VERSION_CASE(4); 128 MIN_VERSION_CASE(5); 129 #undef MIN_VERSION_CASE 130 default: 131 return emitError(unknownLoc, "unsupported SPIR-V minor version: ") 132 << minorVersion; 133 } 134 } else { 135 return emitError(unknownLoc, "unsupported SPIR-V major version: ") 136 << majorVersion; 137 } 138 139 // TODO: generator number, bound, schema 140 curOffset = spirv::kHeaderWordCount; 141 return success(); 142 } 143 144 LogicalResult 145 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) { 146 if (operands.size() != 1) 147 return emitError(unknownLoc, "OpMemoryModel must have one parameter"); 148 149 auto cap = spirv::symbolizeCapability(operands[0]); 150 if (!cap) 151 return emitError(unknownLoc, "unknown capability: ") << operands[0]; 152 153 capabilities.insert(*cap); 154 return success(); 155 } 156 157 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) { 158 if (words.empty()) { 159 return emitError( 160 unknownLoc, 161 "OpExtension must have a literal string for the extension name"); 162 } 163 164 unsigned wordIndex = 0; 165 StringRef extName = decodeStringLiteral(words, wordIndex); 166 if (wordIndex != words.size()) 167 return emitError(unknownLoc, 168 "unexpected trailing words in OpExtension instruction"); 169 auto ext = spirv::symbolizeExtension(extName); 170 if (!ext) 171 return emitError(unknownLoc, "unknown extension: ") << extName; 172 173 extensions.insert(*ext); 174 return success(); 175 } 176 177 LogicalResult 178 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) { 179 if (words.size() < 2) { 180 return emitError(unknownLoc, 181 "OpExtInstImport must have a result <id> and a literal " 182 "string for the extended instruction set name"); 183 } 184 185 unsigned wordIndex = 1; 186 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); 187 if (wordIndex != words.size()) { 188 return emitError(unknownLoc, 189 "unexpected trailing words in OpExtInstImport"); 190 } 191 return success(); 192 } 193 194 void spirv::Deserializer::attachVCETriple() { 195 (*module)->setAttr( 196 spirv::ModuleOp::getVCETripleAttrName(), 197 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), 198 extensions.getArrayRef(), context)); 199 } 200 201 LogicalResult 202 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { 203 if (operands.size() != 2) 204 return emitError(unknownLoc, "OpMemoryModel must have two operands"); 205 206 (*module)->setAttr( 207 "addressing_model", 208 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front()))); 209 (*module)->setAttr( 210 "memory_model", 211 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back()))); 212 213 return success(); 214 } 215 216 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { 217 // TODO: This function should also be auto-generated. For now, since only a 218 // few decorations are processed/handled in a meaningful manner, going with a 219 // manual implementation. 220 if (words.size() < 2) { 221 return emitError( 222 unknownLoc, "OpDecorate must have at least result <id> and Decoration"); 223 } 224 auto decorationName = 225 stringifyDecoration(static_cast<spirv::Decoration>(words[1])); 226 if (decorationName.empty()) { 227 return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; 228 } 229 auto attrName = llvm::convertToSnakeFromCamelCase(decorationName); 230 auto symbol = opBuilder.getStringAttr(attrName); 231 switch (static_cast<spirv::Decoration>(words[1])) { 232 case spirv::Decoration::DescriptorSet: 233 case spirv::Decoration::Binding: 234 if (words.size() != 3) { 235 return emitError(unknownLoc, "OpDecorate with ") 236 << decorationName << " needs a single integer literal"; 237 } 238 decorations[words[0]].set( 239 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 240 break; 241 case spirv::Decoration::BuiltIn: 242 if (words.size() != 3) { 243 return emitError(unknownLoc, "OpDecorate with ") 244 << decorationName << " needs a single integer literal"; 245 } 246 decorations[words[0]].set( 247 symbol, opBuilder.getStringAttr( 248 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2])))); 249 break; 250 case spirv::Decoration::ArrayStride: 251 if (words.size() != 3) { 252 return emitError(unknownLoc, "OpDecorate with ") 253 << decorationName << " needs a single integer literal"; 254 } 255 typeDecorations[words[0]] = words[2]; 256 break; 257 case spirv::Decoration::Aliased: 258 case spirv::Decoration::Block: 259 case spirv::Decoration::BufferBlock: 260 case spirv::Decoration::Flat: 261 case spirv::Decoration::NonReadable: 262 case spirv::Decoration::NonWritable: 263 case spirv::Decoration::NoPerspective: 264 case spirv::Decoration::Restrict: 265 case spirv::Decoration::RelaxedPrecision: 266 if (words.size() != 2) { 267 return emitError(unknownLoc, "OpDecoration with ") 268 << decorationName << "needs a single target <id>"; 269 } 270 // Block decoration does not affect spv.struct type, but is still stored for 271 // verification. 272 // TODO: Update StructType to contain this information since 273 // it is needed for many validation rules. 274 decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); 275 break; 276 case spirv::Decoration::Location: 277 case spirv::Decoration::SpecId: 278 if (words.size() != 3) { 279 return emitError(unknownLoc, "OpDecoration with ") 280 << decorationName << "needs a single integer literal"; 281 } 282 decorations[words[0]].set( 283 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 284 break; 285 default: 286 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; 287 } 288 return success(); 289 } 290 291 LogicalResult 292 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { 293 // The binary layout of OpMemberDecorate is different comparing to OpDecorate 294 if (words.size() < 3) { 295 return emitError(unknownLoc, 296 "OpMemberDecorate must have at least 3 operands"); 297 } 298 299 auto decoration = static_cast<spirv::Decoration>(words[2]); 300 if (decoration == spirv::Decoration::Offset && words.size() != 4) { 301 return emitError(unknownLoc, 302 " missing offset specification in OpMemberDecorate with " 303 "Offset decoration"); 304 } 305 ArrayRef<uint32_t> decorationOperands; 306 if (words.size() > 3) { 307 decorationOperands = words.slice(3); 308 } 309 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; 310 return success(); 311 } 312 313 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) { 314 if (words.size() < 3) { 315 return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); 316 } 317 unsigned wordIndex = 2; 318 auto name = decodeStringLiteral(words, wordIndex); 319 if (wordIndex != words.size()) { 320 return emitError(unknownLoc, 321 "unexpected trailing words in OpMemberName instruction"); 322 } 323 memberNameMap[words[0]][words[1]] = name; 324 return success(); 325 } 326 327 LogicalResult 328 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) { 329 if (curFunction) { 330 return emitError(unknownLoc, "found function inside function"); 331 } 332 333 // Get the result type 334 if (operands.size() != 4) { 335 return emitError(unknownLoc, "OpFunction must have 4 parameters"); 336 } 337 Type resultType = getType(operands[0]); 338 if (!resultType) { 339 return emitError(unknownLoc, "undefined result type from <id> ") 340 << operands[0]; 341 } 342 343 if (funcMap.count(operands[1])) { 344 return emitError(unknownLoc, "duplicate function definition/declaration"); 345 } 346 347 auto fnControl = spirv::symbolizeFunctionControl(operands[2]); 348 if (!fnControl) { 349 return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; 350 } 351 352 Type fnType = getType(operands[3]); 353 if (!fnType || !fnType.isa<FunctionType>()) { 354 return emitError(unknownLoc, "unknown function type from <id> ") 355 << operands[3]; 356 } 357 auto functionType = fnType.cast<FunctionType>(); 358 359 if ((isVoidType(resultType) && functionType.getNumResults() != 0) || 360 (functionType.getNumResults() == 1 && 361 functionType.getResult(0) != resultType)) { 362 return emitError(unknownLoc, "mismatch in function type ") 363 << functionType << " and return type " << resultType << " specified"; 364 } 365 366 std::string fnName = getFunctionSymbol(operands[1]); 367 auto funcOp = opBuilder.create<spirv::FuncOp>( 368 unknownLoc, fnName, functionType, fnControl.getValue()); 369 curFunction = funcMap[operands[1]] = funcOp; 370 LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " 371 << fnType << ", id = " << operands[1] << ") --\n"); 372 auto *entryBlock = funcOp.addEntryBlock(); 373 LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock 374 << "\n"); 375 376 // Parse the op argument instructions 377 if (functionType.getNumInputs()) { 378 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { 379 auto argType = functionType.getInput(i); 380 spirv::Opcode opcode = spirv::Opcode::OpNop; 381 ArrayRef<uint32_t> operands; 382 if (failed(sliceInstruction(opcode, operands, 383 spirv::Opcode::OpFunctionParameter))) { 384 return failure(); 385 } 386 if (opcode != spirv::Opcode::OpFunctionParameter) { 387 return emitError( 388 unknownLoc, 389 "missing OpFunctionParameter instruction for argument ") 390 << i; 391 } 392 if (operands.size() != 2) { 393 return emitError( 394 unknownLoc, 395 "expected result type and result <id> for OpFunctionParameter"); 396 } 397 auto argDefinedType = getType(operands[0]); 398 if (!argDefinedType || argDefinedType != argType) { 399 return emitError(unknownLoc, 400 "mismatch in argument type between function type " 401 "definition ") 402 << functionType << " and argument type definition " 403 << argDefinedType << " at argument " << i; 404 } 405 if (getValue(operands[1])) { 406 return emitError(unknownLoc, "duplicate definition of result <id> '") 407 << operands[1]; 408 } 409 auto argValue = funcOp.getArgument(i); 410 valueMap[operands[1]] = argValue; 411 } 412 } 413 414 // RAII guard to reset the insertion point to the module's region after 415 // deserializing the body of this function. 416 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 417 418 spirv::Opcode opcode = spirv::Opcode::OpNop; 419 ArrayRef<uint32_t> instOperands; 420 421 // Special handling for the entry block. We need to make sure it starts with 422 // an OpLabel instruction. The entry block takes the same parameters as the 423 // function. All other blocks do not take any parameter. We have already 424 // created the entry block, here we need to register it to the correct label 425 // <id>. 426 if (failed(sliceInstruction(opcode, instOperands, 427 spirv::Opcode::OpFunctionEnd))) { 428 return failure(); 429 } 430 if (opcode == spirv::Opcode::OpFunctionEnd) { 431 LLVM_DEBUG(llvm::dbgs() 432 << "-- completed function '" << fnName << "' (type = " << fnType 433 << ", id = " << operands[1] << ") --\n"); 434 return processFunctionEnd(instOperands); 435 } 436 if (opcode != spirv::Opcode::OpLabel) { 437 return emitError(unknownLoc, "a basic block must start with OpLabel"); 438 } 439 if (instOperands.size() != 1) { 440 return emitError(unknownLoc, "OpLabel should only have result <id>"); 441 } 442 blockMap[instOperands[0]] = entryBlock; 443 if (failed(processLabel(instOperands))) { 444 return failure(); 445 } 446 447 // Then process all the other instructions in the function until we hit 448 // OpFunctionEnd. 449 while (succeeded(sliceInstruction(opcode, instOperands, 450 spirv::Opcode::OpFunctionEnd)) && 451 opcode != spirv::Opcode::OpFunctionEnd) { 452 if (failed(processInstruction(opcode, instOperands))) { 453 return failure(); 454 } 455 } 456 if (opcode != spirv::Opcode::OpFunctionEnd) { 457 return failure(); 458 } 459 460 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " 461 << fnType << ", id = " << operands[1] << ") --\n"); 462 return processFunctionEnd(instOperands); 463 } 464 465 LogicalResult 466 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) { 467 // Process OpFunctionEnd. 468 if (!operands.empty()) { 469 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); 470 } 471 472 // Wire up block arguments from OpPhi instructions. 473 // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops. 474 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { 475 return failure(); 476 } 477 478 curBlock = nullptr; 479 curFunction = llvm::None; 480 481 return success(); 482 } 483 484 Optional<std::pair<Attribute, Type>> 485 spirv::Deserializer::getConstant(uint32_t id) { 486 auto constIt = constantMap.find(id); 487 if (constIt == constantMap.end()) 488 return llvm::None; 489 return constIt->getSecond(); 490 } 491 492 Optional<spirv::SpecConstOperationMaterializationInfo> 493 spirv::Deserializer::getSpecConstantOperation(uint32_t id) { 494 auto constIt = specConstOperationMap.find(id); 495 if (constIt == specConstOperationMap.end()) 496 return llvm::None; 497 return constIt->getSecond(); 498 } 499 500 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { 501 auto funcName = nameMap.lookup(id).str(); 502 if (funcName.empty()) { 503 funcName = "spirv_fn_" + std::to_string(id); 504 } 505 return funcName; 506 } 507 508 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { 509 auto constName = nameMap.lookup(id).str(); 510 if (constName.empty()) { 511 constName = "spirv_spec_const_" + std::to_string(id); 512 } 513 return constName; 514 } 515 516 spirv::SpecConstantOp 517 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, 518 Attribute defaultValue) { 519 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 520 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, 521 defaultValue); 522 if (decorations.count(resultID)) { 523 for (auto attr : decorations[resultID].getAttrs()) 524 op->setAttr(attr.getName(), attr.getValue()); 525 } 526 specConstMap[resultID] = op; 527 return op; 528 } 529 530 LogicalResult 531 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { 532 unsigned wordIndex = 0; 533 if (operands.size() < 3) { 534 return emitError( 535 unknownLoc, 536 "OpVariable needs at least 3 operands, type, <id> and storage class"); 537 } 538 539 // Result Type. 540 auto type = getType(operands[wordIndex]); 541 if (!type) { 542 return emitError(unknownLoc, "unknown result type <id> : ") 543 << operands[wordIndex]; 544 } 545 auto ptrType = type.dyn_cast<spirv::PointerType>(); 546 if (!ptrType) { 547 return emitError(unknownLoc, 548 "expected a result type <id> to be a spv.ptr, found : ") 549 << type; 550 } 551 wordIndex++; 552 553 // Result <id>. 554 auto variableID = operands[wordIndex]; 555 auto variableName = nameMap.lookup(variableID).str(); 556 if (variableName.empty()) { 557 variableName = "spirv_var_" + std::to_string(variableID); 558 } 559 wordIndex++; 560 561 // Storage class. 562 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); 563 if (ptrType.getStorageClass() != storageClass) { 564 return emitError(unknownLoc, "mismatch in storage class of pointer type ") 565 << type << " and that specified in OpVariable instruction : " 566 << stringifyStorageClass(storageClass); 567 } 568 wordIndex++; 569 570 // Initializer. 571 FlatSymbolRefAttr initializer = nullptr; 572 if (wordIndex < operands.size()) { 573 auto initializerOp = getGlobalVariable(operands[wordIndex]); 574 if (!initializerOp) { 575 return emitError(unknownLoc, "unknown <id> ") 576 << operands[wordIndex] << "used as initializer"; 577 } 578 wordIndex++; 579 initializer = SymbolRefAttr::get(initializerOp.getOperation()); 580 } 581 if (wordIndex != operands.size()) { 582 return emitError(unknownLoc, 583 "found more operands than expected when deserializing " 584 "OpVariable instruction, only ") 585 << wordIndex << " of " << operands.size() << " processed"; 586 } 587 auto loc = createFileLineColLoc(opBuilder); 588 auto varOp = opBuilder.create<spirv::GlobalVariableOp>( 589 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), 590 initializer); 591 592 // Decorations. 593 if (decorations.count(variableID)) { 594 for (auto attr : decorations[variableID].getAttrs()) 595 varOp->setAttr(attr.getName(), attr.getValue()); 596 } 597 globalVariableMap[variableID] = varOp; 598 return success(); 599 } 600 601 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) { 602 auto constInfo = getConstant(id); 603 if (!constInfo) { 604 return nullptr; 605 } 606 return constInfo->first.dyn_cast<IntegerAttr>(); 607 } 608 609 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) { 610 if (operands.size() < 2) { 611 return emitError(unknownLoc, "OpName needs at least 2 operands"); 612 } 613 if (!nameMap.lookup(operands[0]).empty()) { 614 return emitError(unknownLoc, "duplicate name found for result <id> ") 615 << operands[0]; 616 } 617 unsigned wordIndex = 1; 618 StringRef name = decodeStringLiteral(operands, wordIndex); 619 if (wordIndex != operands.size()) { 620 return emitError(unknownLoc, 621 "unexpected trailing words in OpName instruction"); 622 } 623 nameMap[operands[0]] = name; 624 return success(); 625 } 626 627 //===----------------------------------------------------------------------===// 628 // Type 629 //===----------------------------------------------------------------------===// 630 631 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, 632 ArrayRef<uint32_t> operands) { 633 if (operands.empty()) { 634 return emitError(unknownLoc, "type instruction with opcode ") 635 << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; 636 } 637 638 /// TODO: Types might be forward declared in some instructions and need to be 639 /// handled appropriately. 640 if (typeMap.count(operands[0])) { 641 return emitError(unknownLoc, "duplicate definition for result <id> ") 642 << operands[0]; 643 } 644 645 switch (opcode) { 646 case spirv::Opcode::OpTypeVoid: 647 if (operands.size() != 1) 648 return emitError(unknownLoc, "OpTypeVoid must have no parameters"); 649 typeMap[operands[0]] = opBuilder.getNoneType(); 650 break; 651 case spirv::Opcode::OpTypeBool: 652 if (operands.size() != 1) 653 return emitError(unknownLoc, "OpTypeBool must have no parameters"); 654 typeMap[operands[0]] = opBuilder.getI1Type(); 655 break; 656 case spirv::Opcode::OpTypeInt: { 657 if (operands.size() != 3) 658 return emitError( 659 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); 660 661 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 662 // to preserve or validate. 663 // 0 indicates unsigned, or no signedness semantics 664 // 1 indicates signed semantics." 665 // 666 // So we cannot differentiate signless and unsigned integers; always use 667 // signless semantics for such cases. 668 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed 669 : IntegerType::SignednessSemantics::Signless; 670 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); 671 } break; 672 case spirv::Opcode::OpTypeFloat: { 673 if (operands.size() != 2) 674 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); 675 676 Type floatTy; 677 switch (operands[1]) { 678 case 16: 679 floatTy = opBuilder.getF16Type(); 680 break; 681 case 32: 682 floatTy = opBuilder.getF32Type(); 683 break; 684 case 64: 685 floatTy = opBuilder.getF64Type(); 686 break; 687 default: 688 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ") 689 << operands[1]; 690 } 691 typeMap[operands[0]] = floatTy; 692 } break; 693 case spirv::Opcode::OpTypeVector: { 694 if (operands.size() != 3) { 695 return emitError( 696 unknownLoc, 697 "OpTypeVector must have element type and count parameters"); 698 } 699 Type elementTy = getType(operands[1]); 700 if (!elementTy) { 701 return emitError(unknownLoc, "OpTypeVector references undefined <id> ") 702 << operands[1]; 703 } 704 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); 705 } break; 706 case spirv::Opcode::OpTypePointer: { 707 return processOpTypePointer(operands); 708 } break; 709 case spirv::Opcode::OpTypeArray: 710 return processArrayType(operands); 711 case spirv::Opcode::OpTypeCooperativeMatrixNV: 712 return processCooperativeMatrixType(operands); 713 case spirv::Opcode::OpTypeFunction: 714 return processFunctionType(operands); 715 case spirv::Opcode::OpTypeImage: 716 return processImageType(operands); 717 case spirv::Opcode::OpTypeSampledImage: 718 return processSampledImageType(operands); 719 case spirv::Opcode::OpTypeRuntimeArray: 720 return processRuntimeArrayType(operands); 721 case spirv::Opcode::OpTypeStruct: 722 return processStructType(operands); 723 case spirv::Opcode::OpTypeMatrix: 724 return processMatrixType(operands); 725 default: 726 return emitError(unknownLoc, "unhandled type instruction"); 727 } 728 return success(); 729 } 730 731 LogicalResult 732 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { 733 if (operands.size() != 3) 734 return emitError(unknownLoc, "OpTypePointer must have two parameters"); 735 736 auto pointeeType = getType(operands[2]); 737 if (!pointeeType) 738 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ") 739 << operands[2]; 740 741 uint32_t typePointerID = operands[0]; 742 auto storageClass = static_cast<spirv::StorageClass>(operands[1]); 743 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); 744 745 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos); 746 deferredStructIt != std::end(deferredStructTypesInfos);) { 747 for (auto *unresolvedMemberIt = 748 std::begin(deferredStructIt->unresolvedMemberTypes); 749 unresolvedMemberIt != 750 std::end(deferredStructIt->unresolvedMemberTypes);) { 751 if (unresolvedMemberIt->first == typePointerID) { 752 // The newly constructed pointer type can resolve one of the 753 // deferred struct type members; update the memberTypes list and 754 // clean the unresolvedMemberTypes list accordingly. 755 deferredStructIt->memberTypes[unresolvedMemberIt->second] = 756 typeMap[typePointerID]; 757 unresolvedMemberIt = 758 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt); 759 } else { 760 ++unresolvedMemberIt; 761 } 762 } 763 764 if (deferredStructIt->unresolvedMemberTypes.empty()) { 765 // All deferred struct type members are now resolved, set the struct body. 766 auto structType = deferredStructIt->deferredStructType; 767 768 assert(structType && "expected a spirv::StructType"); 769 assert(structType.isIdentified() && "expected an indentified struct"); 770 771 if (failed(structType.trySetBody( 772 deferredStructIt->memberTypes, deferredStructIt->offsetInfo, 773 deferredStructIt->memberDecorationsInfo))) 774 return failure(); 775 776 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); 777 } else { 778 ++deferredStructIt; 779 } 780 } 781 782 return success(); 783 } 784 785 LogicalResult 786 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) { 787 if (operands.size() != 3) { 788 return emitError(unknownLoc, 789 "OpTypeArray must have element type and count parameters"); 790 } 791 792 Type elementTy = getType(operands[1]); 793 if (!elementTy) { 794 return emitError(unknownLoc, "OpTypeArray references undefined <id> ") 795 << operands[1]; 796 } 797 798 unsigned count = 0; 799 // TODO: The count can also come frome a specialization constant. 800 auto countInfo = getConstant(operands[2]); 801 if (!countInfo) { 802 return emitError(unknownLoc, "OpTypeArray count <id> ") 803 << operands[2] << "can only come from normal constant right now"; 804 } 805 806 if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) { 807 count = intVal.getValue().getZExtValue(); 808 } else { 809 return emitError(unknownLoc, "OpTypeArray count must come from a " 810 "scalar integer constant instruction"); 811 } 812 813 typeMap[operands[0]] = spirv::ArrayType::get( 814 elementTy, count, typeDecorations.lookup(operands[0])); 815 return success(); 816 } 817 818 LogicalResult 819 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { 820 assert(!operands.empty() && "No operands for processing function type"); 821 if (operands.size() == 1) { 822 return emitError(unknownLoc, "missing return type for OpTypeFunction"); 823 } 824 auto returnType = getType(operands[1]); 825 if (!returnType) { 826 return emitError(unknownLoc, "unknown return type in OpTypeFunction"); 827 } 828 SmallVector<Type, 1> argTypes; 829 for (size_t i = 2, e = operands.size(); i < e; ++i) { 830 auto ty = getType(operands[i]); 831 if (!ty) { 832 return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); 833 } 834 argTypes.push_back(ty); 835 } 836 ArrayRef<Type> returnTypes; 837 if (!isVoidType(returnType)) { 838 returnTypes = llvm::makeArrayRef(returnType); 839 } 840 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes); 841 return success(); 842 } 843 844 LogicalResult 845 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) { 846 if (operands.size() != 5) { 847 return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " 848 "type and row x column parameters"); 849 } 850 851 Type elementTy = getType(operands[1]); 852 if (!elementTy) { 853 return emitError(unknownLoc, 854 "OpTypeCooperativeMatrix references undefined <id> ") 855 << operands[1]; 856 } 857 858 auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); 859 if (!scope) { 860 return emitError(unknownLoc, 861 "OpTypeCooperativeMatrix references undefined scope <id> ") 862 << operands[2]; 863 } 864 865 unsigned rows = getConstantInt(operands[3]).getInt(); 866 unsigned columns = getConstantInt(operands[4]).getInt(); 867 868 typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( 869 elementTy, scope.getValue(), rows, columns); 870 return success(); 871 } 872 873 LogicalResult 874 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) { 875 if (operands.size() != 2) { 876 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); 877 } 878 Type memberType = getType(operands[1]); 879 if (!memberType) { 880 return emitError(unknownLoc, 881 "OpTypeRuntimeArray references undefined <id> ") 882 << operands[1]; 883 } 884 typeMap[operands[0]] = spirv::RuntimeArrayType::get( 885 memberType, typeDecorations.lookup(operands[0])); 886 return success(); 887 } 888 889 LogicalResult 890 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { 891 // TODO: Find a way to handle identified structs when debug info is stripped. 892 893 if (operands.empty()) { 894 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>"); 895 } 896 897 if (operands.size() == 1) { 898 // Handle empty struct. 899 typeMap[operands[0]] = 900 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str()); 901 return success(); 902 } 903 904 // First element is operand ID, second element is member index in the struct. 905 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes; 906 SmallVector<Type, 4> memberTypes; 907 908 for (auto op : llvm::drop_begin(operands, 1)) { 909 Type memberType = getType(op); 910 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0); 911 912 if (!memberType && !typeForwardPtr) 913 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ") 914 << op; 915 916 if (!memberType) 917 unresolvedMemberTypes.emplace_back(op, memberTypes.size()); 918 919 memberTypes.push_back(memberType); 920 } 921 922 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; 923 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; 924 if (memberDecorationMap.count(operands[0])) { 925 auto &allMemberDecorations = memberDecorationMap[operands[0]]; 926 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) { 927 if (allMemberDecorations.count(memberIndex)) { 928 for (auto &memberDecoration : allMemberDecorations[memberIndex]) { 929 // Check for offset. 930 if (memberDecoration.first == spirv::Decoration::Offset) { 931 // If offset info is empty, resize to the number of members; 932 if (offsetInfo.empty()) { 933 offsetInfo.resize(memberTypes.size()); 934 } 935 offsetInfo[memberIndex] = memberDecoration.second[0]; 936 } else { 937 if (!memberDecoration.second.empty()) { 938 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, 939 memberDecoration.first, 940 memberDecoration.second[0]); 941 } else { 942 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, 943 memberDecoration.first, 0); 944 } 945 } 946 } 947 } 948 } 949 } 950 951 uint32_t structID = operands[0]; 952 std::string structIdentifier = nameMap.lookup(structID).str(); 953 954 if (structIdentifier.empty()) { 955 assert(unresolvedMemberTypes.empty() && 956 "didn't expect unresolved member types"); 957 typeMap[structID] = 958 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); 959 } else { 960 auto structTy = spirv::StructType::getIdentified(context, structIdentifier); 961 typeMap[structID] = structTy; 962 963 if (!unresolvedMemberTypes.empty()) 964 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, 965 memberTypes, offsetInfo, 966 memberDecorationsInfo}); 967 else if (failed(structTy.trySetBody(memberTypes, offsetInfo, 968 memberDecorationsInfo))) 969 return failure(); 970 } 971 972 // TODO: Update StructType to have member name as attribute as 973 // well. 974 return success(); 975 } 976 977 LogicalResult 978 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) { 979 if (operands.size() != 3) { 980 // Three operands are needed: result_id, column_type, and column_count 981 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands" 982 " (result_id, column_type, and column_count)"); 983 } 984 // Matrix columns must be of vector type 985 Type elementTy = getType(operands[1]); 986 if (!elementTy) { 987 return emitError(unknownLoc, 988 "OpTypeMatrix references undefined column type.") 989 << operands[1]; 990 } 991 992 uint32_t colsCount = operands[2]; 993 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount); 994 return success(); 995 } 996 997 LogicalResult 998 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) { 999 if (operands.size() != 2) 1000 return emitError(unknownLoc, 1001 "OpTypeForwardPointer instruction must have two operands"); 1002 1003 typeForwardPointerIDs.insert(operands[0]); 1004 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer 1005 // instruction that defines the actual type. 1006 1007 return success(); 1008 } 1009 1010 LogicalResult 1011 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) { 1012 // TODO: Add support for Access Qualifier. 1013 if (operands.size() != 8) 1014 return emitError( 1015 unknownLoc, 1016 "OpTypeImage with non-eight operands are not supported yet"); 1017 1018 Type elementTy = getType(operands[1]); 1019 if (!elementTy) 1020 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ") 1021 << operands[1]; 1022 1023 auto dim = spirv::symbolizeDim(operands[2]); 1024 if (!dim) 1025 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ") 1026 << operands[2]; 1027 1028 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]); 1029 if (!depthInfo) 1030 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ") 1031 << operands[3]; 1032 1033 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]); 1034 if (!arrayedInfo) 1035 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ") 1036 << operands[4]; 1037 1038 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]); 1039 if (!samplingInfo) 1040 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5]; 1041 1042 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]); 1043 if (!samplerUseInfo) 1044 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ") 1045 << operands[6]; 1046 1047 auto format = spirv::symbolizeImageFormat(operands[7]); 1048 if (!format) 1049 return emitError(unknownLoc, "unknown Format for OpTypeImage: ") 1050 << operands[7]; 1051 1052 typeMap[operands[0]] = spirv::ImageType::get( 1053 elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(), 1054 samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue()); 1055 return success(); 1056 } 1057 1058 LogicalResult 1059 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) { 1060 if (operands.size() != 2) 1061 return emitError(unknownLoc, "OpTypeSampledImage must have two operands"); 1062 1063 Type elementTy = getType(operands[1]); 1064 if (!elementTy) 1065 return emitError(unknownLoc, 1066 "OpTypeSampledImage references undefined <id>: ") 1067 << operands[1]; 1068 1069 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy); 1070 return success(); 1071 } 1072 1073 //===----------------------------------------------------------------------===// 1074 // Constant 1075 //===----------------------------------------------------------------------===// 1076 1077 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands, 1078 bool isSpec) { 1079 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant"; 1080 1081 if (operands.size() < 2) { 1082 return emitError(unknownLoc) 1083 << opname << " must have type <id> and result <id>"; 1084 } 1085 if (operands.size() < 3) { 1086 return emitError(unknownLoc) 1087 << opname << " must have at least 1 more parameter"; 1088 } 1089 1090 Type resultType = getType(operands[0]); 1091 if (!resultType) { 1092 return emitError(unknownLoc, "undefined result type from <id> ") 1093 << operands[0]; 1094 } 1095 1096 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { 1097 if (bitwidth == 64) { 1098 if (operands.size() == 4) { 1099 return success(); 1100 } 1101 return emitError(unknownLoc) 1102 << opname << " should have 2 parameters for 64-bit values"; 1103 } 1104 if (bitwidth <= 32) { 1105 if (operands.size() == 3) { 1106 return success(); 1107 } 1108 1109 return emitError(unknownLoc) 1110 << opname 1111 << " should have 1 parameter for values with no more than 32 bits"; 1112 } 1113 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ") 1114 << bitwidth; 1115 }; 1116 1117 auto resultID = operands[1]; 1118 1119 if (auto intType = resultType.dyn_cast<IntegerType>()) { 1120 auto bitwidth = intType.getWidth(); 1121 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1122 return failure(); 1123 } 1124 1125 APInt value; 1126 if (bitwidth == 64) { 1127 // 64-bit integers are represented with two SPIR-V words. According to 1128 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1129 // literal’s low-order words appear first." 1130 struct DoubleWord { 1131 uint32_t word1; 1132 uint32_t word2; 1133 } words = {operands[2], operands[3]}; 1134 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true); 1135 } else if (bitwidth <= 32) { 1136 value = APInt(bitwidth, operands[2], /*isSigned=*/true); 1137 } 1138 1139 auto attr = opBuilder.getIntegerAttr(intType, value); 1140 1141 if (isSpec) { 1142 createSpecConstant(unknownLoc, resultID, attr); 1143 } else { 1144 // For normal constants, we just record the attribute (and its type) for 1145 // later materialization at use sites. 1146 constantMap.try_emplace(resultID, attr, intType); 1147 } 1148 1149 return success(); 1150 } 1151 1152 if (auto floatType = resultType.dyn_cast<FloatType>()) { 1153 auto bitwidth = floatType.getWidth(); 1154 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1155 return failure(); 1156 } 1157 1158 APFloat value(0.f); 1159 if (floatType.isF64()) { 1160 // Double values are represented with two SPIR-V words. According to 1161 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1162 // literal’s low-order words appear first." 1163 struct DoubleWord { 1164 uint32_t word1; 1165 uint32_t word2; 1166 } words = {operands[2], operands[3]}; 1167 value = APFloat(llvm::bit_cast<double>(words)); 1168 } else if (floatType.isF32()) { 1169 value = APFloat(llvm::bit_cast<float>(operands[2])); 1170 } else if (floatType.isF16()) { 1171 APInt data(16, operands[2]); 1172 value = APFloat(APFloat::IEEEhalf(), data); 1173 } 1174 1175 auto attr = opBuilder.getFloatAttr(floatType, value); 1176 if (isSpec) { 1177 createSpecConstant(unknownLoc, resultID, attr); 1178 } else { 1179 // For normal constants, we just record the attribute (and its type) for 1180 // later materialization at use sites. 1181 constantMap.try_emplace(resultID, attr, floatType); 1182 } 1183 1184 return success(); 1185 } 1186 1187 return emitError(unknownLoc, "OpConstant can only generate values of " 1188 "scalar integer or floating-point type"); 1189 } 1190 1191 LogicalResult spirv::Deserializer::processConstantBool( 1192 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) { 1193 if (operands.size() != 2) { 1194 return emitError(unknownLoc, "Op") 1195 << (isSpec ? "Spec" : "") << "Constant" 1196 << (isTrue ? "True" : "False") 1197 << " must have type <id> and result <id>"; 1198 } 1199 1200 auto attr = opBuilder.getBoolAttr(isTrue); 1201 auto resultID = operands[1]; 1202 if (isSpec) { 1203 createSpecConstant(unknownLoc, resultID, attr); 1204 } else { 1205 // For normal constants, we just record the attribute (and its type) for 1206 // later materialization at use sites. 1207 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type()); 1208 } 1209 1210 return success(); 1211 } 1212 1213 LogicalResult 1214 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { 1215 if (operands.size() < 2) { 1216 return emitError(unknownLoc, 1217 "OpConstantComposite must have type <id> and result <id>"); 1218 } 1219 if (operands.size() < 3) { 1220 return emitError(unknownLoc, 1221 "OpConstantComposite must have at least 1 parameter"); 1222 } 1223 1224 Type resultType = getType(operands[0]); 1225 if (!resultType) { 1226 return emitError(unknownLoc, "undefined result type from <id> ") 1227 << operands[0]; 1228 } 1229 1230 SmallVector<Attribute, 4> elements; 1231 elements.reserve(operands.size() - 2); 1232 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1233 auto elementInfo = getConstant(operands[i]); 1234 if (!elementInfo) { 1235 return emitError(unknownLoc, "OpConstantComposite component <id> ") 1236 << operands[i] << " must come from a normal constant"; 1237 } 1238 elements.push_back(elementInfo->first); 1239 } 1240 1241 auto resultID = operands[1]; 1242 if (auto vectorType = resultType.dyn_cast<VectorType>()) { 1243 auto attr = DenseElementsAttr::get(vectorType, elements); 1244 // For normal constants, we just record the attribute (and its type) for 1245 // later materialization at use sites. 1246 constantMap.try_emplace(resultID, attr, resultType); 1247 } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) { 1248 auto attr = opBuilder.getArrayAttr(elements); 1249 constantMap.try_emplace(resultID, attr, resultType); 1250 } else { 1251 return emitError(unknownLoc, "unsupported OpConstantComposite type: ") 1252 << resultType; 1253 } 1254 1255 return success(); 1256 } 1257 1258 LogicalResult 1259 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) { 1260 if (operands.size() < 2) { 1261 return emitError(unknownLoc, 1262 "OpConstantComposite must have type <id> and result <id>"); 1263 } 1264 if (operands.size() < 3) { 1265 return emitError(unknownLoc, 1266 "OpConstantComposite must have at least 1 parameter"); 1267 } 1268 1269 Type resultType = getType(operands[0]); 1270 if (!resultType) { 1271 return emitError(unknownLoc, "undefined result type from <id> ") 1272 << operands[0]; 1273 } 1274 1275 auto resultID = operands[1]; 1276 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 1277 1278 SmallVector<Attribute, 4> elements; 1279 elements.reserve(operands.size() - 2); 1280 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1281 auto elementInfo = getSpecConstant(operands[i]); 1282 elements.push_back(SymbolRefAttr::get(elementInfo)); 1283 } 1284 1285 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>( 1286 unknownLoc, TypeAttr::get(resultType), symName, 1287 opBuilder.getArrayAttr(elements)); 1288 specConstCompositeMap[resultID] = op; 1289 1290 return success(); 1291 } 1292 1293 LogicalResult 1294 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) { 1295 if (operands.size() < 3) 1296 return emitError(unknownLoc, "OpConstantOperation must have type <id>, " 1297 "result <id>, and operand opcode"); 1298 1299 uint32_t resultTypeID = operands[0]; 1300 1301 if (!getType(resultTypeID)) 1302 return emitError(unknownLoc, "undefined result type from <id> ") 1303 << resultTypeID; 1304 1305 uint32_t resultID = operands[1]; 1306 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]); 1307 auto emplaceResult = specConstOperationMap.try_emplace( 1308 resultID, 1309 SpecConstOperationMaterializationInfo{ 1310 enclosedOpcode, resultTypeID, 1311 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}}); 1312 1313 if (!emplaceResult.second) 1314 return emitError(unknownLoc, "value with <id>: ") 1315 << resultID << " is probably defined before."; 1316 1317 return success(); 1318 } 1319 1320 Value spirv::Deserializer::materializeSpecConstantOperation( 1321 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, 1322 ArrayRef<uint32_t> enclosedOpOperands) { 1323 1324 Type resultType = getType(resultTypeID); 1325 1326 // Instructions wrapped by OpSpecConstantOp need an ID for their 1327 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V 1328 // dialect wrapped op. For that purpose, a new value map is created and "fake" 1329 // ID in that map is assigned to the result of the enclosed instruction. Note 1330 // that there is no need to update this fake ID since we only need to 1331 // reference the created Value for the enclosed op from the spv::YieldOp 1332 // created later in this method (both of which are the only values in their 1333 // region: the SpecConstantOperation's region). If we encounter another 1334 // SpecConstantOperation in the module, we simply re-use the fake ID since the 1335 // previous Value assigned to it isn't visible in the current scope anyway. 1336 DenseMap<uint32_t, Value> newValueMap; 1337 llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap, 1338 newValueMap); 1339 constexpr uint32_t fakeID = static_cast<uint32_t>(-3); 1340 1341 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands; 1342 enclosedOpResultTypeAndOperands.push_back(resultTypeID); 1343 enclosedOpResultTypeAndOperands.push_back(fakeID); 1344 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(), 1345 enclosedOpOperands.end()); 1346 1347 // Process enclosed instruction before creating the enclosing 1348 // specConstantOperation (and its region). This way, references to constants, 1349 // global variables, and spec constants will be materialized outside the new 1350 // op's region. For more info, see Deserializer::getValue's implementation. 1351 if (failed( 1352 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) 1353 return Value(); 1354 1355 // Since the enclosed op is emitted in the current block, split it in a 1356 // separate new block. 1357 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back()); 1358 1359 auto loc = createFileLineColLoc(opBuilder); 1360 auto specConstOperationOp = 1361 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); 1362 1363 Region &body = specConstOperationOp.body(); 1364 // Move the new block into SpecConstantOperation's body. 1365 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), 1366 Region::iterator(enclosedBlock)); 1367 Block &block = body.back(); 1368 1369 // RAII guard to reset the insertion point to the module's region after 1370 // deserializing the body of the specConstantOperation. 1371 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 1372 opBuilder.setInsertionPointToEnd(&block); 1373 1374 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0)); 1375 return specConstOperationOp.getResult(); 1376 } 1377 1378 LogicalResult 1379 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { 1380 if (operands.size() != 2) { 1381 return emitError(unknownLoc, 1382 "OpConstantNull must have type <id> and result <id>"); 1383 } 1384 1385 Type resultType = getType(operands[0]); 1386 if (!resultType) { 1387 return emitError(unknownLoc, "undefined result type from <id> ") 1388 << operands[0]; 1389 } 1390 1391 auto resultID = operands[1]; 1392 if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) { 1393 auto attr = opBuilder.getZeroAttr(resultType); 1394 // For normal constants, we just record the attribute (and its type) for 1395 // later materialization at use sites. 1396 constantMap.try_emplace(resultID, attr, resultType); 1397 return success(); 1398 } 1399 1400 return emitError(unknownLoc, "unsupported OpConstantNull type: ") 1401 << resultType; 1402 } 1403 1404 //===----------------------------------------------------------------------===// 1405 // Control flow 1406 //===----------------------------------------------------------------------===// 1407 1408 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) { 1409 if (auto *block = getBlock(id)) { 1410 LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id 1411 << " @ " << block << "\n"); 1412 return block; 1413 } 1414 1415 // We don't know where this block will be placed finally (in a 1416 // spv.mlir.selection or spv.mlir.loop or function). Create it into the 1417 // function for now and sort out the proper place later. 1418 auto *block = curFunction->addBlock(); 1419 LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ " 1420 << block << "\n"); 1421 return blockMap[id] = block; 1422 } 1423 1424 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) { 1425 if (!curBlock) { 1426 return emitError(unknownLoc, "OpBranch must appear inside a block"); 1427 } 1428 1429 if (operands.size() != 1) { 1430 return emitError(unknownLoc, "OpBranch must take exactly one target label"); 1431 } 1432 1433 auto *target = getOrCreateBlock(operands[0]); 1434 auto loc = createFileLineColLoc(opBuilder); 1435 // The preceding instruction for the OpBranch instruction could be an 1436 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have 1437 // the same OpLine information. 1438 opBuilder.create<spirv::BranchOp>(loc, target); 1439 1440 (void)clearDebugLine(); 1441 return success(); 1442 } 1443 1444 LogicalResult 1445 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) { 1446 if (!curBlock) { 1447 return emitError(unknownLoc, 1448 "OpBranchConditional must appear inside a block"); 1449 } 1450 1451 if (operands.size() != 3 && operands.size() != 5) { 1452 return emitError(unknownLoc, 1453 "OpBranchConditional must have condition, true label, " 1454 "false label, and optionally two branch weights"); 1455 } 1456 1457 auto condition = getValue(operands[0]); 1458 auto *trueBlock = getOrCreateBlock(operands[1]); 1459 auto *falseBlock = getOrCreateBlock(operands[2]); 1460 1461 Optional<std::pair<uint32_t, uint32_t>> weights; 1462 if (operands.size() == 5) { 1463 weights = std::make_pair(operands[3], operands[4]); 1464 } 1465 // The preceding instruction for the OpBranchConditional instruction could be 1466 // an OpSelectionMerge instruction, in this case they will have the same 1467 // OpLine information. 1468 auto loc = createFileLineColLoc(opBuilder); 1469 opBuilder.create<spirv::BranchConditionalOp>( 1470 loc, condition, trueBlock, 1471 /*trueArguments=*/ArrayRef<Value>(), falseBlock, 1472 /*falseArguments=*/ArrayRef<Value>(), weights); 1473 1474 (void)clearDebugLine(); 1475 return success(); 1476 } 1477 1478 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) { 1479 if (!curFunction) { 1480 return emitError(unknownLoc, "OpLabel must appear inside a function"); 1481 } 1482 1483 if (operands.size() != 1) { 1484 return emitError(unknownLoc, "OpLabel should only have result <id>"); 1485 } 1486 1487 auto labelID = operands[0]; 1488 // We may have forward declared this block. 1489 auto *block = getOrCreateBlock(labelID); 1490 LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n"); 1491 // If we have seen this block, make sure it was just a forward declaration. 1492 assert(block->empty() && "re-deserialize the same block!"); 1493 1494 opBuilder.setInsertionPointToStart(block); 1495 blockMap[labelID] = curBlock = block; 1496 1497 return success(); 1498 } 1499 1500 LogicalResult 1501 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) { 1502 if (!curBlock) { 1503 return emitError(unknownLoc, "OpSelectionMerge must appear in a block"); 1504 } 1505 1506 if (operands.size() < 2) { 1507 return emitError( 1508 unknownLoc, 1509 "OpSelectionMerge must specify merge target and selection control"); 1510 } 1511 1512 auto *mergeBlock = getOrCreateBlock(operands[0]); 1513 auto loc = createFileLineColLoc(opBuilder); 1514 auto selectionControl = operands[1]; 1515 1516 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock) 1517 .second) { 1518 return emitError( 1519 unknownLoc, 1520 "a block cannot have more than one OpSelectionMerge instruction"); 1521 } 1522 1523 return success(); 1524 } 1525 1526 LogicalResult 1527 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) { 1528 if (!curBlock) { 1529 return emitError(unknownLoc, "OpLoopMerge must appear in a block"); 1530 } 1531 1532 if (operands.size() < 3) { 1533 return emitError(unknownLoc, "OpLoopMerge must specify merge target, " 1534 "continue target and loop control"); 1535 } 1536 1537 auto *mergeBlock = getOrCreateBlock(operands[0]); 1538 auto *continueBlock = getOrCreateBlock(operands[1]); 1539 auto loc = createFileLineColLoc(opBuilder); 1540 uint32_t loopControl = operands[2]; 1541 1542 if (!blockMergeInfo 1543 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock) 1544 .second) { 1545 return emitError( 1546 unknownLoc, 1547 "a block cannot have more than one OpLoopMerge instruction"); 1548 } 1549 1550 return success(); 1551 } 1552 1553 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { 1554 if (!curBlock) { 1555 return emitError(unknownLoc, "OpPhi must appear in a block"); 1556 } 1557 1558 if (operands.size() < 4) { 1559 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, " 1560 "and variable-parent pairs"); 1561 } 1562 1563 // Create a block argument for this OpPhi instruction. 1564 Type blockArgType = getType(operands[0]); 1565 BlockArgument blockArg = curBlock->addArgument(blockArgType); 1566 valueMap[operands[1]] = blockArg; 1567 LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg 1568 << " id = " << operands[1] << " of type " 1569 << blockArgType << '\n'); 1570 1571 // For each (value, predecessor) pair, insert the value to the predecessor's 1572 // blockPhiInfo entry so later we can fix the block argument there. 1573 for (unsigned i = 2, e = operands.size(); i < e; i += 2) { 1574 uint32_t value = operands[i]; 1575 Block *predecessor = getOrCreateBlock(operands[i + 1]); 1576 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock}; 1577 blockPhiInfo[predecessorTargetPair].push_back(value); 1578 LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor 1579 << " with arg id = " << value << '\n'); 1580 } 1581 1582 return success(); 1583 } 1584 1585 namespace { 1586 /// A class for putting all blocks in a structured selection/loop in a 1587 /// spv.mlir.selection/spv.mlir.loop op. 1588 class ControlFlowStructurizer { 1589 public: 1590 /// Structurizes the loop at the given `headerBlock`. 1591 /// 1592 /// This method will create an spv.mlir.loop op in the `mergeBlock` and move 1593 /// all blocks in the structured loop into the spv.mlir.loop's region. All 1594 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This 1595 /// method will also update `mergeInfo` by remapping all blocks inside to the 1596 /// newly cloned ones inside structured control flow op's regions. 1597 static LogicalResult structurize(Location loc, uint32_t control, 1598 spirv::BlockMergeInfoMap &mergeInfo, 1599 Block *headerBlock, Block *mergeBlock, 1600 Block *continueBlock) { 1601 return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock, 1602 mergeBlock, continueBlock) 1603 .structurizeImpl(); 1604 } 1605 1606 private: 1607 ControlFlowStructurizer(Location loc, uint32_t control, 1608 spirv::BlockMergeInfoMap &mergeInfo, Block *header, 1609 Block *merge, Block *cont) 1610 : location(loc), control(control), blockMergeInfo(mergeInfo), 1611 headerBlock(header), mergeBlock(merge), continueBlock(cont) {} 1612 1613 /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`. 1614 spirv::SelectionOp createSelectionOp(uint32_t selectionControl); 1615 1616 /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`. 1617 spirv::LoopOp createLoopOp(uint32_t loopControl); 1618 1619 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`. 1620 void collectBlocksInConstruct(); 1621 1622 LogicalResult structurizeImpl(); 1623 1624 Location location; 1625 uint32_t control; 1626 1627 spirv::BlockMergeInfoMap &blockMergeInfo; 1628 1629 Block *headerBlock; 1630 Block *mergeBlock; 1631 Block *continueBlock; // nullptr for spv.mlir.selection 1632 1633 SetVector<Block *> constructBlocks; 1634 }; 1635 } // namespace 1636 1637 spirv::SelectionOp 1638 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { 1639 // Create a builder and set the insertion point to the beginning of the 1640 // merge block so that the newly created SelectionOp will be inserted there. 1641 OpBuilder builder(&mergeBlock->front()); 1642 1643 auto control = static_cast<spirv::SelectionControl>(selectionControl); 1644 auto selectionOp = builder.create<spirv::SelectionOp>(location, control); 1645 selectionOp.addMergeBlock(); 1646 1647 return selectionOp; 1648 } 1649 1650 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { 1651 // Create a builder and set the insertion point to the beginning of the 1652 // merge block so that the newly created LoopOp will be inserted there. 1653 OpBuilder builder(&mergeBlock->front()); 1654 1655 auto control = static_cast<spirv::LoopControl>(loopControl); 1656 auto loopOp = builder.create<spirv::LoopOp>(location, control); 1657 loopOp.addEntryAndMergeBlock(); 1658 1659 return loopOp; 1660 } 1661 1662 void ControlFlowStructurizer::collectBlocksInConstruct() { 1663 assert(constructBlocks.empty() && "expected empty constructBlocks"); 1664 1665 // Put the header block in the work list first. 1666 constructBlocks.insert(headerBlock); 1667 1668 // For each item in the work list, add its successors excluding the merge 1669 // block. 1670 for (unsigned i = 0; i < constructBlocks.size(); ++i) { 1671 for (auto *successor : constructBlocks[i]->getSuccessors()) 1672 if (successor != mergeBlock) 1673 constructBlocks.insert(successor); 1674 } 1675 } 1676 1677 LogicalResult ControlFlowStructurizer::structurizeImpl() { 1678 Operation *op = nullptr; 1679 bool isLoop = continueBlock != nullptr; 1680 if (isLoop) { 1681 if (auto loopOp = createLoopOp(control)) 1682 op = loopOp.getOperation(); 1683 } else { 1684 if (auto selectionOp = createSelectionOp(control)) 1685 op = selectionOp.getOperation(); 1686 } 1687 if (!op) 1688 return failure(); 1689 Region &body = op->getRegion(0); 1690 1691 BlockAndValueMapping mapper; 1692 // All references to the old merge block should be directed to the 1693 // selection/loop merge block in the SelectionOp/LoopOp's region. 1694 mapper.map(mergeBlock, &body.back()); 1695 1696 collectBlocksInConstruct(); 1697 1698 // We've identified all blocks belonging to the selection/loop's region. Now 1699 // need to "move" them into the selection/loop. Instead of really moving the 1700 // blocks, in the following we copy them and remap all values and branches. 1701 // This is because: 1702 // * Inserting a block into a region requires the block not in any region 1703 // before. But selections/loops can nest so we can create selection/loop ops 1704 // in a nested manner, which means some blocks may already be in a 1705 // selection/loop region when to be moved again. 1706 // * It's much trickier to fix up the branches into and out of the loop's 1707 // region: we need to treat not-moved blocks and moved blocks differently: 1708 // Not-moved blocks jumping to the loop header block need to jump to the 1709 // merge point containing the new loop op but not the loop continue block's 1710 // back edge. Moved blocks jumping out of the loop need to jump to the 1711 // merge block inside the loop region but not other not-moved blocks. 1712 // We cannot use replaceAllUsesWith clearly and it's harder to follow the 1713 // logic. 1714 1715 // Create a corresponding block in the SelectionOp/LoopOp's region for each 1716 // block in this loop construct. 1717 OpBuilder builder(body); 1718 for (auto *block : constructBlocks) { 1719 // Create a block and insert it before the selection/loop merge block in the 1720 // SelectionOp/LoopOp's region. 1721 auto *newBlock = builder.createBlock(&body.back()); 1722 mapper.map(block, newBlock); 1723 LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock 1724 << " from block " << block << "\n"); 1725 if (!isFnEntryBlock(block)) { 1726 for (BlockArgument blockArg : block->getArguments()) { 1727 auto newArg = newBlock->addArgument(blockArg.getType()); 1728 mapper.map(blockArg, newArg); 1729 LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg 1730 << " to " << newArg << '\n'); 1731 } 1732 } else { 1733 LLVM_DEBUG(llvm::dbgs() 1734 << "[cf] block " << block << " is a function entry block\n"); 1735 } 1736 for (auto &op : *block) 1737 newBlock->push_back(op.clone(mapper)); 1738 } 1739 1740 // Go through all ops and remap the operands. 1741 auto remapOperands = [&](Operation *op) { 1742 for (auto &operand : op->getOpOperands()) 1743 if (Value mappedOp = mapper.lookupOrNull(operand.get())) 1744 operand.set(mappedOp); 1745 for (auto &succOp : op->getBlockOperands()) 1746 if (Block *mappedOp = mapper.lookupOrNull(succOp.get())) 1747 succOp.set(mappedOp); 1748 }; 1749 for (auto &block : body) { 1750 block.walk(remapOperands); 1751 } 1752 1753 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to 1754 // the selection/loop construct into its region. Next we need to fix the 1755 // connections between this new SelectionOp/LoopOp with existing blocks. 1756 1757 // All existing incoming branches should go to the merge block, where the 1758 // SelectionOp/LoopOp resides right now. 1759 headerBlock->replaceAllUsesWith(mergeBlock); 1760 1761 if (isLoop) { 1762 // The loop selection/loop header block may have block arguments. Since now 1763 // we place the selection/loop op inside the old merge block, we need to 1764 // make sure the old merge block has the same block argument list. 1765 assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); 1766 for (BlockArgument blockArg : headerBlock->getArguments()) { 1767 mergeBlock->addArgument(blockArg.getType()); 1768 } 1769 1770 // If the loop header block has block arguments, make sure the spv.branch op 1771 // matches. 1772 SmallVector<Value, 4> blockArgs; 1773 if (!headerBlock->args_empty()) 1774 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; 1775 1776 // The loop entry block should have a unconditional branch jumping to the 1777 // loop header block. 1778 builder.setInsertionPointToEnd(&body.front()); 1779 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock), 1780 ArrayRef<Value>(blockArgs)); 1781 } 1782 1783 // All the blocks cloned into the SelectionOp/LoopOp's region can now be 1784 // cleaned up. 1785 LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n"); 1786 // First we need to drop all operands' references inside all blocks. This is 1787 // needed because we can have blocks referencing SSA values from one another. 1788 for (auto *block : constructBlocks) 1789 block->dropAllReferences(); 1790 1791 // Then erase all old blocks. 1792 for (auto *block : constructBlocks) { 1793 // We've cloned all blocks belonging to this construct into the structured 1794 // control flow op's region. Among these blocks, some may compose another 1795 // selection/loop. If so, they will be recorded within blockMergeInfo. 1796 // We need to update the pointers there to the newly remapped ones so we can 1797 // continue structurizing them later. 1798 // TODO: The asserts in the following assumes input SPIR-V blob 1799 // forms correctly nested selection/loop constructs. We should relax this 1800 // and support error cases better. 1801 auto it = blockMergeInfo.find(block); 1802 if (it != blockMergeInfo.end()) { 1803 Block *newHeader = mapper.lookupOrNull(block); 1804 assert(newHeader && "nested loop header block should be remapped!"); 1805 1806 Block *newContinue = it->second.continueBlock; 1807 if (newContinue) { 1808 newContinue = mapper.lookupOrNull(newContinue); 1809 assert(newContinue && "nested loop continue block should be remapped!"); 1810 } 1811 1812 Block *newMerge = it->second.mergeBlock; 1813 if (Block *mappedTo = mapper.lookupOrNull(newMerge)) 1814 newMerge = mappedTo; 1815 1816 // Keep original location for nested selection/loop ops. 1817 Location loc = it->second.loc; 1818 // The iterator should be erased before adding a new entry into 1819 // blockMergeInfo to avoid iterator invalidation. 1820 blockMergeInfo.erase(it); 1821 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge, 1822 newContinue); 1823 } 1824 1825 // The structured selection/loop's entry block does not have arguments. 1826 // If the function's header block is also part of the structured control 1827 // flow, we cannot just simply erase it because it may contain arguments 1828 // matching the function signature and used by the cloned blocks. 1829 if (isFnEntryBlock(block)) { 1830 LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block 1831 << " to only contain a spv.Branch op\n"); 1832 // Still keep the function entry block for the potential block arguments, 1833 // but replace all ops inside with a branch to the merge block. 1834 block->clear(); 1835 builder.setInsertionPointToEnd(block); 1836 builder.create<spirv::BranchOp>(location, mergeBlock); 1837 } else { 1838 LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n"); 1839 block->erase(); 1840 } 1841 } 1842 1843 LLVM_DEBUG( 1844 llvm::dbgs() << "[cf] after structurizing construct with header block " 1845 << headerBlock << ":\n" 1846 << *op << '\n'); 1847 1848 return success(); 1849 } 1850 1851 LogicalResult spirv::Deserializer::wireUpBlockArgument() { 1852 LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n"); 1853 1854 OpBuilder::InsertionGuard guard(opBuilder); 1855 1856 for (const auto &info : blockPhiInfo) { 1857 Block *block = info.first.first; 1858 Block *target = info.first.second; 1859 const BlockPhiInfo &phiInfo = info.second; 1860 LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); 1861 LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); 1862 LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); 1863 LLVM_DEBUG(llvm::dbgs() << '\n'); 1864 1865 // Set insertion point to before this block's terminator early because we 1866 // may materialize ops via getValue() call. 1867 auto *op = block->getTerminator(); 1868 opBuilder.setInsertionPoint(op); 1869 1870 SmallVector<Value, 4> blockArgs; 1871 blockArgs.reserve(phiInfo.size()); 1872 for (uint32_t valueId : phiInfo) { 1873 if (Value value = getValue(valueId)) { 1874 blockArgs.push_back(value); 1875 LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value 1876 << " id = " << valueId << '\n'); 1877 } else { 1878 return emitError(unknownLoc, "OpPhi references undefined value!"); 1879 } 1880 } 1881 1882 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { 1883 // Replace the previous branch op with a new one with block arguments. 1884 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(), 1885 blockArgs); 1886 branchOp.erase(); 1887 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) { 1888 assert((branchCondOp.getTrueBlock() == target || 1889 branchCondOp.getFalseBlock() == target) && 1890 "expected target to be either the true or false target"); 1891 if (target == branchCondOp.trueTarget()) 1892 opBuilder.create<spirv::BranchConditionalOp>( 1893 branchCondOp.getLoc(), branchCondOp.condition(), blockArgs, 1894 branchCondOp.getFalseBlockArguments(), 1895 branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(), 1896 branchCondOp.falseTarget()); 1897 else 1898 opBuilder.create<spirv::BranchConditionalOp>( 1899 branchCondOp.getLoc(), branchCondOp.condition(), 1900 branchCondOp.getTrueBlockArguments(), blockArgs, 1901 branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(), 1902 branchCondOp.getFalseBlock()); 1903 1904 branchCondOp.erase(); 1905 } else { 1906 return emitError(unknownLoc, "unimplemented terminator for Phi creation"); 1907 } 1908 1909 LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n"); 1910 LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); 1911 LLVM_DEBUG(llvm::dbgs() << '\n'); 1912 } 1913 blockPhiInfo.clear(); 1914 1915 LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n"); 1916 return success(); 1917 } 1918 1919 LogicalResult spirv::Deserializer::structurizeControlFlow() { 1920 LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); 1921 1922 while (!blockMergeInfo.empty()) { 1923 Block *headerBlock = blockMergeInfo.begin()->first; 1924 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; 1925 1926 LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n"); 1927 LLVM_DEBUG(headerBlock->print(llvm::dbgs())); 1928 1929 auto *mergeBlock = mergeInfo.mergeBlock; 1930 assert(mergeBlock && "merge block cannot be nullptr"); 1931 if (!mergeBlock->args_empty()) 1932 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); 1933 LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n"); 1934 LLVM_DEBUG(mergeBlock->print(llvm::dbgs())); 1935 1936 auto *continueBlock = mergeInfo.continueBlock; 1937 if (continueBlock) { 1938 LLVM_DEBUG(llvm::dbgs() 1939 << "[cf] continue block " << continueBlock << ":\n"); 1940 LLVM_DEBUG(continueBlock->print(llvm::dbgs())); 1941 } 1942 // Erase this case before calling into structurizer, who will update 1943 // blockMergeInfo. 1944 blockMergeInfo.erase(blockMergeInfo.begin()); 1945 if (failed(ControlFlowStructurizer::structurize( 1946 mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock, 1947 mergeBlock, continueBlock))) 1948 return failure(); 1949 } 1950 1951 LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n"); 1952 return success(); 1953 } 1954 1955 //===----------------------------------------------------------------------===// 1956 // Debug 1957 //===----------------------------------------------------------------------===// 1958 1959 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) { 1960 if (!debugLine) 1961 return unknownLoc; 1962 1963 auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); 1964 if (fileName.empty()) 1965 fileName = "<unknown>"; 1966 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line, 1967 debugLine->col); 1968 } 1969 1970 LogicalResult 1971 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) { 1972 // According to SPIR-V spec: 1973 // "This location information applies to the instructions physically 1974 // following this instruction, up to the first occurrence of any of the 1975 // following: the next end of block, the next OpLine instruction, or the next 1976 // OpNoLine instruction." 1977 if (operands.size() != 3) 1978 return emitError(unknownLoc, "OpLine must have 3 operands"); 1979 debugLine = DebugLine(operands[0], operands[1], operands[2]); 1980 return success(); 1981 } 1982 1983 LogicalResult spirv::Deserializer::clearDebugLine() { 1984 debugLine = llvm::None; 1985 return success(); 1986 } 1987 1988 LogicalResult 1989 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) { 1990 if (operands.size() < 2) 1991 return emitError(unknownLoc, "OpString needs at least 2 operands"); 1992 1993 if (!debugInfoMap.lookup(operands[0]).empty()) 1994 return emitError(unknownLoc, 1995 "duplicate debug string found for result <id> ") 1996 << operands[0]; 1997 1998 unsigned wordIndex = 1; 1999 StringRef debugString = decodeStringLiteral(operands, wordIndex); 2000 if (wordIndex != operands.size()) 2001 return emitError(unknownLoc, 2002 "unexpected trailing words in OpString instruction"); 2003 2004 debugInfoMap[operands[0]] = debugString; 2005 return success(); 2006 } 2007