1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 20 #include "mlir/IR/BlockAndValueMapping.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/Support/LogicalResult.h" 24 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 25 #include "llvm/ADT/STLExtras.h" 26 #include "llvm/ADT/Sequence.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/ADT/bit.h" 30 #include "llvm/Support/Debug.h" 31 #include "llvm/Support/SaveAndRestore.h" 32 #include "llvm/Support/raw_ostream.h" 33 34 using namespace mlir; 35 36 #define DEBUG_TYPE "spirv-deserialization" 37 38 //===----------------------------------------------------------------------===// 39 // Utility Functions 40 //===----------------------------------------------------------------------===// 41 42 /// Returns true if the given `block` is a function entry block. 43 static inline bool isFnEntryBlock(Block *block) { 44 return block->isEntryBlock() && 45 isa_and_nonnull<spirv::FuncOp>(block->getParentOp()); 46 } 47 48 //===----------------------------------------------------------------------===// 49 // Deserializer Method Definitions 50 //===----------------------------------------------------------------------===// 51 52 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary, 53 MLIRContext *context) 54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), 55 module(createModuleOp()), opBuilder(module->body()) {} 56 57 LogicalResult spirv::Deserializer::deserialize() { 58 LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); 59 60 if (failed(processHeader())) 61 return failure(); 62 63 spirv::Opcode opcode = spirv::Opcode::OpNop; 64 ArrayRef<uint32_t> operands; 65 auto binarySize = binary.size(); 66 while (curOffset < binarySize) { 67 // Slice the next instruction out and populate `opcode` and `operands`. 68 // Internally this also updates `curOffset`. 69 if (failed(sliceInstruction(opcode, operands))) 70 return failure(); 71 72 if (failed(processInstruction(opcode, operands))) 73 return failure(); 74 } 75 76 assert(curOffset == binarySize && 77 "deserializer should never index beyond the binary end"); 78 79 for (auto &deferred : deferredInstructions) { 80 if (failed(processInstruction(deferred.first, deferred.second, false))) { 81 return failure(); 82 } 83 } 84 85 attachVCETriple(); 86 87 LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n"); 88 return success(); 89 } 90 91 spirv::OwningSPIRVModuleRef spirv::Deserializer::collect() { 92 return std::move(module); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // Module structure 97 //===----------------------------------------------------------------------===// 98 99 spirv::OwningSPIRVModuleRef spirv::Deserializer::createModuleOp() { 100 OpBuilder builder(context); 101 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); 102 spirv::ModuleOp::build(builder, state); 103 return cast<spirv::ModuleOp>(Operation::create(state)); 104 } 105 106 LogicalResult spirv::Deserializer::processHeader() { 107 if (binary.size() < spirv::kHeaderWordCount) 108 return emitError(unknownLoc, 109 "SPIR-V binary module must have a 5-word header"); 110 111 if (binary[0] != spirv::kMagicNumber) 112 return emitError(unknownLoc, "incorrect magic number"); 113 114 // Version number bytes: 0 | major number | minor number | 0 115 uint32_t majorVersion = (binary[1] << 8) >> 24; 116 uint32_t minorVersion = (binary[1] << 16) >> 24; 117 if (majorVersion == 1) { 118 switch (minorVersion) { 119 #define MIN_VERSION_CASE(v) \ 120 case v: \ 121 version = spirv::Version::V_1_##v; \ 122 break 123 124 MIN_VERSION_CASE(0); 125 MIN_VERSION_CASE(1); 126 MIN_VERSION_CASE(2); 127 MIN_VERSION_CASE(3); 128 MIN_VERSION_CASE(4); 129 MIN_VERSION_CASE(5); 130 #undef MIN_VERSION_CASE 131 default: 132 return emitError(unknownLoc, "unsupported SPIR-V minor version: ") 133 << minorVersion; 134 } 135 } else { 136 return emitError(unknownLoc, "unsupported SPIR-V major version: ") 137 << majorVersion; 138 } 139 140 // TODO: generator number, bound, schema 141 curOffset = spirv::kHeaderWordCount; 142 return success(); 143 } 144 145 LogicalResult 146 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) { 147 if (operands.size() != 1) 148 return emitError(unknownLoc, "OpMemoryModel must have one parameter"); 149 150 auto cap = spirv::symbolizeCapability(operands[0]); 151 if (!cap) 152 return emitError(unknownLoc, "unknown capability: ") << operands[0]; 153 154 capabilities.insert(*cap); 155 return success(); 156 } 157 158 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) { 159 if (words.empty()) { 160 return emitError( 161 unknownLoc, 162 "OpExtension must have a literal string for the extension name"); 163 } 164 165 unsigned wordIndex = 0; 166 StringRef extName = decodeStringLiteral(words, wordIndex); 167 if (wordIndex != words.size()) 168 return emitError(unknownLoc, 169 "unexpected trailing words in OpExtension instruction"); 170 auto ext = spirv::symbolizeExtension(extName); 171 if (!ext) 172 return emitError(unknownLoc, "unknown extension: ") << extName; 173 174 extensions.insert(*ext); 175 return success(); 176 } 177 178 LogicalResult 179 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) { 180 if (words.size() < 2) { 181 return emitError(unknownLoc, 182 "OpExtInstImport must have a result <id> and a literal " 183 "string for the extended instruction set name"); 184 } 185 186 unsigned wordIndex = 1; 187 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); 188 if (wordIndex != words.size()) { 189 return emitError(unknownLoc, 190 "unexpected trailing words in OpExtInstImport"); 191 } 192 return success(); 193 } 194 195 void spirv::Deserializer::attachVCETriple() { 196 (*module)->setAttr( 197 spirv::ModuleOp::getVCETripleAttrName(), 198 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), 199 extensions.getArrayRef(), context)); 200 } 201 202 LogicalResult 203 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { 204 if (operands.size() != 2) 205 return emitError(unknownLoc, "OpMemoryModel must have two operands"); 206 207 (*module)->setAttr( 208 "addressing_model", 209 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front()))); 210 (*module)->setAttr( 211 "memory_model", 212 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back()))); 213 214 return success(); 215 } 216 217 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { 218 // TODO: This function should also be auto-generated. For now, since only a 219 // few decorations are processed/handled in a meaningful manner, going with a 220 // manual implementation. 221 if (words.size() < 2) { 222 return emitError( 223 unknownLoc, "OpDecorate must have at least result <id> and Decoration"); 224 } 225 auto decorationName = 226 stringifyDecoration(static_cast<spirv::Decoration>(words[1])); 227 if (decorationName.empty()) { 228 return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; 229 } 230 auto attrName = llvm::convertToSnakeFromCamelCase(decorationName); 231 auto symbol = opBuilder.getIdentifier(attrName); 232 switch (static_cast<spirv::Decoration>(words[1])) { 233 case spirv::Decoration::DescriptorSet: 234 case spirv::Decoration::Binding: 235 if (words.size() != 3) { 236 return emitError(unknownLoc, "OpDecorate with ") 237 << decorationName << " needs a single integer literal"; 238 } 239 decorations[words[0]].set( 240 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 241 break; 242 case spirv::Decoration::BuiltIn: 243 if (words.size() != 3) { 244 return emitError(unknownLoc, "OpDecorate with ") 245 << decorationName << " needs a single integer literal"; 246 } 247 decorations[words[0]].set( 248 symbol, opBuilder.getStringAttr( 249 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2])))); 250 break; 251 case spirv::Decoration::ArrayStride: 252 if (words.size() != 3) { 253 return emitError(unknownLoc, "OpDecorate with ") 254 << decorationName << " needs a single integer literal"; 255 } 256 typeDecorations[words[0]] = words[2]; 257 break; 258 case spirv::Decoration::Aliased: 259 case spirv::Decoration::Block: 260 case spirv::Decoration::BufferBlock: 261 case spirv::Decoration::Flat: 262 case spirv::Decoration::NonReadable: 263 case spirv::Decoration::NonWritable: 264 case spirv::Decoration::NoPerspective: 265 case spirv::Decoration::Restrict: 266 if (words.size() != 2) { 267 return emitError(unknownLoc, "OpDecoration with ") 268 << decorationName << "needs a single target <id>"; 269 } 270 // Block decoration does not affect spv.struct type, but is still stored for 271 // verification. 272 // TODO: Update StructType to contain this information since 273 // it is needed for many validation rules. 274 decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); 275 break; 276 case spirv::Decoration::Location: 277 case spirv::Decoration::SpecId: 278 if (words.size() != 3) { 279 return emitError(unknownLoc, "OpDecoration with ") 280 << decorationName << "needs a single integer literal"; 281 } 282 decorations[words[0]].set( 283 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 284 break; 285 default: 286 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; 287 } 288 return success(); 289 } 290 291 LogicalResult 292 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { 293 // The binary layout of OpMemberDecorate is different comparing to OpDecorate 294 if (words.size() < 3) { 295 return emitError(unknownLoc, 296 "OpMemberDecorate must have at least 3 operands"); 297 } 298 299 auto decoration = static_cast<spirv::Decoration>(words[2]); 300 if (decoration == spirv::Decoration::Offset && words.size() != 4) { 301 return emitError(unknownLoc, 302 " missing offset specification in OpMemberDecorate with " 303 "Offset decoration"); 304 } 305 ArrayRef<uint32_t> decorationOperands; 306 if (words.size() > 3) { 307 decorationOperands = words.slice(3); 308 } 309 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; 310 return success(); 311 } 312 313 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) { 314 if (words.size() < 3) { 315 return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); 316 } 317 unsigned wordIndex = 2; 318 auto name = decodeStringLiteral(words, wordIndex); 319 if (wordIndex != words.size()) { 320 return emitError(unknownLoc, 321 "unexpected trailing words in OpMemberName instruction"); 322 } 323 memberNameMap[words[0]][words[1]] = name; 324 return success(); 325 } 326 327 LogicalResult 328 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) { 329 if (curFunction) { 330 return emitError(unknownLoc, "found function inside function"); 331 } 332 333 // Get the result type 334 if (operands.size() != 4) { 335 return emitError(unknownLoc, "OpFunction must have 4 parameters"); 336 } 337 Type resultType = getType(operands[0]); 338 if (!resultType) { 339 return emitError(unknownLoc, "undefined result type from <id> ") 340 << operands[0]; 341 } 342 343 if (funcMap.count(operands[1])) { 344 return emitError(unknownLoc, "duplicate function definition/declaration"); 345 } 346 347 auto fnControl = spirv::symbolizeFunctionControl(operands[2]); 348 if (!fnControl) { 349 return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; 350 } 351 352 Type fnType = getType(operands[3]); 353 if (!fnType || !fnType.isa<FunctionType>()) { 354 return emitError(unknownLoc, "unknown function type from <id> ") 355 << operands[3]; 356 } 357 auto functionType = fnType.cast<FunctionType>(); 358 359 if ((isVoidType(resultType) && functionType.getNumResults() != 0) || 360 (functionType.getNumResults() == 1 && 361 functionType.getResult(0) != resultType)) { 362 return emitError(unknownLoc, "mismatch in function type ") 363 << functionType << " and return type " << resultType << " specified"; 364 } 365 366 std::string fnName = getFunctionSymbol(operands[1]); 367 auto funcOp = opBuilder.create<spirv::FuncOp>( 368 unknownLoc, fnName, functionType, fnControl.getValue()); 369 curFunction = funcMap[operands[1]] = funcOp; 370 LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " 371 << fnType << ", id = " << operands[1] << ") --\n"); 372 auto *entryBlock = funcOp.addEntryBlock(); 373 LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock 374 << "\n"); 375 376 // Parse the op argument instructions 377 if (functionType.getNumInputs()) { 378 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { 379 auto argType = functionType.getInput(i); 380 spirv::Opcode opcode = spirv::Opcode::OpNop; 381 ArrayRef<uint32_t> operands; 382 if (failed(sliceInstruction(opcode, operands, 383 spirv::Opcode::OpFunctionParameter))) { 384 return failure(); 385 } 386 if (opcode != spirv::Opcode::OpFunctionParameter) { 387 return emitError( 388 unknownLoc, 389 "missing OpFunctionParameter instruction for argument ") 390 << i; 391 } 392 if (operands.size() != 2) { 393 return emitError( 394 unknownLoc, 395 "expected result type and result <id> for OpFunctionParameter"); 396 } 397 auto argDefinedType = getType(operands[0]); 398 if (!argDefinedType || argDefinedType != argType) { 399 return emitError(unknownLoc, 400 "mismatch in argument type between function type " 401 "definition ") 402 << functionType << " and argument type definition " 403 << argDefinedType << " at argument " << i; 404 } 405 if (getValue(operands[1])) { 406 return emitError(unknownLoc, "duplicate definition of result <id> '") 407 << operands[1]; 408 } 409 auto argValue = funcOp.getArgument(i); 410 valueMap[operands[1]] = argValue; 411 } 412 } 413 414 // RAII guard to reset the insertion point to the module's region after 415 // deserializing the body of this function. 416 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 417 418 spirv::Opcode opcode = spirv::Opcode::OpNop; 419 ArrayRef<uint32_t> instOperands; 420 421 // Special handling for the entry block. We need to make sure it starts with 422 // an OpLabel instruction. The entry block takes the same parameters as the 423 // function. All other blocks do not take any parameter. We have already 424 // created the entry block, here we need to register it to the correct label 425 // <id>. 426 if (failed(sliceInstruction(opcode, instOperands, 427 spirv::Opcode::OpFunctionEnd))) { 428 return failure(); 429 } 430 if (opcode == spirv::Opcode::OpFunctionEnd) { 431 LLVM_DEBUG(llvm::dbgs() 432 << "-- completed function '" << fnName << "' (type = " << fnType 433 << ", id = " << operands[1] << ") --\n"); 434 return processFunctionEnd(instOperands); 435 } 436 if (opcode != spirv::Opcode::OpLabel) { 437 return emitError(unknownLoc, "a basic block must start with OpLabel"); 438 } 439 if (instOperands.size() != 1) { 440 return emitError(unknownLoc, "OpLabel should only have result <id>"); 441 } 442 blockMap[instOperands[0]] = entryBlock; 443 if (failed(processLabel(instOperands))) { 444 return failure(); 445 } 446 447 // Then process all the other instructions in the function until we hit 448 // OpFunctionEnd. 449 while (succeeded(sliceInstruction(opcode, instOperands, 450 spirv::Opcode::OpFunctionEnd)) && 451 opcode != spirv::Opcode::OpFunctionEnd) { 452 if (failed(processInstruction(opcode, instOperands))) { 453 return failure(); 454 } 455 } 456 if (opcode != spirv::Opcode::OpFunctionEnd) { 457 return failure(); 458 } 459 460 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " 461 << fnType << ", id = " << operands[1] << ") --\n"); 462 return processFunctionEnd(instOperands); 463 } 464 465 LogicalResult 466 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) { 467 // Process OpFunctionEnd. 468 if (!operands.empty()) { 469 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); 470 } 471 472 // Wire up block arguments from OpPhi instructions. 473 // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops. 474 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { 475 return failure(); 476 } 477 478 curBlock = nullptr; 479 curFunction = llvm::None; 480 481 return success(); 482 } 483 484 Optional<std::pair<Attribute, Type>> 485 spirv::Deserializer::getConstant(uint32_t id) { 486 auto constIt = constantMap.find(id); 487 if (constIt == constantMap.end()) 488 return llvm::None; 489 return constIt->getSecond(); 490 } 491 492 Optional<spirv::SpecConstOperationMaterializationInfo> 493 spirv::Deserializer::getSpecConstantOperation(uint32_t id) { 494 auto constIt = specConstOperationMap.find(id); 495 if (constIt == specConstOperationMap.end()) 496 return llvm::None; 497 return constIt->getSecond(); 498 } 499 500 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) { 501 auto funcName = nameMap.lookup(id).str(); 502 if (funcName.empty()) { 503 funcName = "spirv_fn_" + std::to_string(id); 504 } 505 return funcName; 506 } 507 508 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) { 509 auto constName = nameMap.lookup(id).str(); 510 if (constName.empty()) { 511 constName = "spirv_spec_const_" + std::to_string(id); 512 } 513 return constName; 514 } 515 516 spirv::SpecConstantOp 517 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID, 518 Attribute defaultValue) { 519 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 520 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, 521 defaultValue); 522 if (decorations.count(resultID)) { 523 for (auto attr : decorations[resultID].getAttrs()) 524 op->setAttr(attr.first, attr.second); 525 } 526 specConstMap[resultID] = op; 527 return op; 528 } 529 530 LogicalResult 531 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { 532 unsigned wordIndex = 0; 533 if (operands.size() < 3) { 534 return emitError( 535 unknownLoc, 536 "OpVariable needs at least 3 operands, type, <id> and storage class"); 537 } 538 539 // Result Type. 540 auto type = getType(operands[wordIndex]); 541 if (!type) { 542 return emitError(unknownLoc, "unknown result type <id> : ") 543 << operands[wordIndex]; 544 } 545 auto ptrType = type.dyn_cast<spirv::PointerType>(); 546 if (!ptrType) { 547 return emitError(unknownLoc, 548 "expected a result type <id> to be a spv.ptr, found : ") 549 << type; 550 } 551 wordIndex++; 552 553 // Result <id>. 554 auto variableID = operands[wordIndex]; 555 auto variableName = nameMap.lookup(variableID).str(); 556 if (variableName.empty()) { 557 variableName = "spirv_var_" + std::to_string(variableID); 558 } 559 wordIndex++; 560 561 // Storage class. 562 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); 563 if (ptrType.getStorageClass() != storageClass) { 564 return emitError(unknownLoc, "mismatch in storage class of pointer type ") 565 << type << " and that specified in OpVariable instruction : " 566 << stringifyStorageClass(storageClass); 567 } 568 wordIndex++; 569 570 // Initializer. 571 FlatSymbolRefAttr initializer = nullptr; 572 if (wordIndex < operands.size()) { 573 auto initializerOp = getGlobalVariable(operands[wordIndex]); 574 if (!initializerOp) { 575 return emitError(unknownLoc, "unknown <id> ") 576 << operands[wordIndex] << "used as initializer"; 577 } 578 wordIndex++; 579 initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation()); 580 } 581 if (wordIndex != operands.size()) { 582 return emitError(unknownLoc, 583 "found more operands than expected when deserializing " 584 "OpVariable instruction, only ") 585 << wordIndex << " of " << operands.size() << " processed"; 586 } 587 auto loc = createFileLineColLoc(opBuilder); 588 auto varOp = opBuilder.create<spirv::GlobalVariableOp>( 589 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), 590 initializer); 591 592 // Decorations. 593 if (decorations.count(variableID)) { 594 for (auto attr : decorations[variableID].getAttrs()) { 595 varOp->setAttr(attr.first, attr.second); 596 } 597 } 598 globalVariableMap[variableID] = varOp; 599 return success(); 600 } 601 602 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) { 603 auto constInfo = getConstant(id); 604 if (!constInfo) { 605 return nullptr; 606 } 607 return constInfo->first.dyn_cast<IntegerAttr>(); 608 } 609 610 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) { 611 if (operands.size() < 2) { 612 return emitError(unknownLoc, "OpName needs at least 2 operands"); 613 } 614 if (!nameMap.lookup(operands[0]).empty()) { 615 return emitError(unknownLoc, "duplicate name found for result <id> ") 616 << operands[0]; 617 } 618 unsigned wordIndex = 1; 619 StringRef name = decodeStringLiteral(operands, wordIndex); 620 if (wordIndex != operands.size()) { 621 return emitError(unknownLoc, 622 "unexpected trailing words in OpName instruction"); 623 } 624 nameMap[operands[0]] = name; 625 return success(); 626 } 627 628 //===----------------------------------------------------------------------===// 629 // Type 630 //===----------------------------------------------------------------------===// 631 632 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, 633 ArrayRef<uint32_t> operands) { 634 if (operands.empty()) { 635 return emitError(unknownLoc, "type instruction with opcode ") 636 << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; 637 } 638 639 /// TODO: Types might be forward declared in some instructions and need to be 640 /// handled appropriately. 641 if (typeMap.count(operands[0])) { 642 return emitError(unknownLoc, "duplicate definition for result <id> ") 643 << operands[0]; 644 } 645 646 switch (opcode) { 647 case spirv::Opcode::OpTypeVoid: 648 if (operands.size() != 1) 649 return emitError(unknownLoc, "OpTypeVoid must have no parameters"); 650 typeMap[operands[0]] = opBuilder.getNoneType(); 651 break; 652 case spirv::Opcode::OpTypeBool: 653 if (operands.size() != 1) 654 return emitError(unknownLoc, "OpTypeBool must have no parameters"); 655 typeMap[operands[0]] = opBuilder.getI1Type(); 656 break; 657 case spirv::Opcode::OpTypeInt: { 658 if (operands.size() != 3) 659 return emitError( 660 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); 661 662 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 663 // to preserve or validate. 664 // 0 indicates unsigned, or no signedness semantics 665 // 1 indicates signed semantics." 666 // 667 // So we cannot differentiate signless and unsigned integers; always use 668 // signless semantics for such cases. 669 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed 670 : IntegerType::SignednessSemantics::Signless; 671 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); 672 } break; 673 case spirv::Opcode::OpTypeFloat: { 674 if (operands.size() != 2) 675 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); 676 677 Type floatTy; 678 switch (operands[1]) { 679 case 16: 680 floatTy = opBuilder.getF16Type(); 681 break; 682 case 32: 683 floatTy = opBuilder.getF32Type(); 684 break; 685 case 64: 686 floatTy = opBuilder.getF64Type(); 687 break; 688 default: 689 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ") 690 << operands[1]; 691 } 692 typeMap[operands[0]] = floatTy; 693 } break; 694 case spirv::Opcode::OpTypeVector: { 695 if (operands.size() != 3) { 696 return emitError( 697 unknownLoc, 698 "OpTypeVector must have element type and count parameters"); 699 } 700 Type elementTy = getType(operands[1]); 701 if (!elementTy) { 702 return emitError(unknownLoc, "OpTypeVector references undefined <id> ") 703 << operands[1]; 704 } 705 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); 706 } break; 707 case spirv::Opcode::OpTypePointer: { 708 return processOpTypePointer(operands); 709 } break; 710 case spirv::Opcode::OpTypeArray: 711 return processArrayType(operands); 712 case spirv::Opcode::OpTypeCooperativeMatrixNV: 713 return processCooperativeMatrixType(operands); 714 case spirv::Opcode::OpTypeFunction: 715 return processFunctionType(operands); 716 case spirv::Opcode::OpTypeImage: 717 return processImageType(operands); 718 case spirv::Opcode::OpTypeSampledImage: 719 return processSampledImageType(operands); 720 case spirv::Opcode::OpTypeRuntimeArray: 721 return processRuntimeArrayType(operands); 722 case spirv::Opcode::OpTypeStruct: 723 return processStructType(operands); 724 case spirv::Opcode::OpTypeMatrix: 725 return processMatrixType(operands); 726 default: 727 return emitError(unknownLoc, "unhandled type instruction"); 728 } 729 return success(); 730 } 731 732 LogicalResult 733 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { 734 if (operands.size() != 3) 735 return emitError(unknownLoc, "OpTypePointer must have two parameters"); 736 737 auto pointeeType = getType(operands[2]); 738 if (!pointeeType) 739 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ") 740 << operands[2]; 741 742 uint32_t typePointerID = operands[0]; 743 auto storageClass = static_cast<spirv::StorageClass>(operands[1]); 744 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); 745 746 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos); 747 deferredStructIt != std::end(deferredStructTypesInfos);) { 748 for (auto *unresolvedMemberIt = 749 std::begin(deferredStructIt->unresolvedMemberTypes); 750 unresolvedMemberIt != 751 std::end(deferredStructIt->unresolvedMemberTypes);) { 752 if (unresolvedMemberIt->first == typePointerID) { 753 // The newly constructed pointer type can resolve one of the 754 // deferred struct type members; update the memberTypes list and 755 // clean the unresolvedMemberTypes list accordingly. 756 deferredStructIt->memberTypes[unresolvedMemberIt->second] = 757 typeMap[typePointerID]; 758 unresolvedMemberIt = 759 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt); 760 } else { 761 ++unresolvedMemberIt; 762 } 763 } 764 765 if (deferredStructIt->unresolvedMemberTypes.empty()) { 766 // All deferred struct type members are now resolved, set the struct body. 767 auto structType = deferredStructIt->deferredStructType; 768 769 assert(structType && "expected a spirv::StructType"); 770 assert(structType.isIdentified() && "expected an indentified struct"); 771 772 if (failed(structType.trySetBody( 773 deferredStructIt->memberTypes, deferredStructIt->offsetInfo, 774 deferredStructIt->memberDecorationsInfo))) 775 return failure(); 776 777 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); 778 } else { 779 ++deferredStructIt; 780 } 781 } 782 783 return success(); 784 } 785 786 LogicalResult 787 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) { 788 if (operands.size() != 3) { 789 return emitError(unknownLoc, 790 "OpTypeArray must have element type and count parameters"); 791 } 792 793 Type elementTy = getType(operands[1]); 794 if (!elementTy) { 795 return emitError(unknownLoc, "OpTypeArray references undefined <id> ") 796 << operands[1]; 797 } 798 799 unsigned count = 0; 800 // TODO: The count can also come frome a specialization constant. 801 auto countInfo = getConstant(operands[2]); 802 if (!countInfo) { 803 return emitError(unknownLoc, "OpTypeArray count <id> ") 804 << operands[2] << "can only come from normal constant right now"; 805 } 806 807 if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) { 808 count = intVal.getValue().getZExtValue(); 809 } else { 810 return emitError(unknownLoc, "OpTypeArray count must come from a " 811 "scalar integer constant instruction"); 812 } 813 814 typeMap[operands[0]] = spirv::ArrayType::get( 815 elementTy, count, typeDecorations.lookup(operands[0])); 816 return success(); 817 } 818 819 LogicalResult 820 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { 821 assert(!operands.empty() && "No operands for processing function type"); 822 if (operands.size() == 1) { 823 return emitError(unknownLoc, "missing return type for OpTypeFunction"); 824 } 825 auto returnType = getType(operands[1]); 826 if (!returnType) { 827 return emitError(unknownLoc, "unknown return type in OpTypeFunction"); 828 } 829 SmallVector<Type, 1> argTypes; 830 for (size_t i = 2, e = operands.size(); i < e; ++i) { 831 auto ty = getType(operands[i]); 832 if (!ty) { 833 return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); 834 } 835 argTypes.push_back(ty); 836 } 837 ArrayRef<Type> returnTypes; 838 if (!isVoidType(returnType)) { 839 returnTypes = llvm::makeArrayRef(returnType); 840 } 841 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes); 842 return success(); 843 } 844 845 LogicalResult 846 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) { 847 if (operands.size() != 5) { 848 return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " 849 "type and row x column parameters"); 850 } 851 852 Type elementTy = getType(operands[1]); 853 if (!elementTy) { 854 return emitError(unknownLoc, 855 "OpTypeCooperativeMatrix references undefined <id> ") 856 << operands[1]; 857 } 858 859 auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); 860 if (!scope) { 861 return emitError(unknownLoc, 862 "OpTypeCooperativeMatrix references undefined scope <id> ") 863 << operands[2]; 864 } 865 866 unsigned rows = getConstantInt(operands[3]).getInt(); 867 unsigned columns = getConstantInt(operands[4]).getInt(); 868 869 typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( 870 elementTy, scope.getValue(), rows, columns); 871 return success(); 872 } 873 874 LogicalResult 875 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) { 876 if (operands.size() != 2) { 877 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); 878 } 879 Type memberType = getType(operands[1]); 880 if (!memberType) { 881 return emitError(unknownLoc, 882 "OpTypeRuntimeArray references undefined <id> ") 883 << operands[1]; 884 } 885 typeMap[operands[0]] = spirv::RuntimeArrayType::get( 886 memberType, typeDecorations.lookup(operands[0])); 887 return success(); 888 } 889 890 LogicalResult 891 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { 892 // TODO: Find a way to handle identified structs when debug info is stripped. 893 894 if (operands.empty()) { 895 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>"); 896 } 897 898 if (operands.size() == 1) { 899 // Handle empty struct. 900 typeMap[operands[0]] = 901 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str()); 902 return success(); 903 } 904 905 // First element is operand ID, second element is member index in the struct. 906 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes; 907 SmallVector<Type, 4> memberTypes; 908 909 for (auto op : llvm::drop_begin(operands, 1)) { 910 Type memberType = getType(op); 911 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0); 912 913 if (!memberType && !typeForwardPtr) 914 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ") 915 << op; 916 917 if (!memberType) 918 unresolvedMemberTypes.emplace_back(op, memberTypes.size()); 919 920 memberTypes.push_back(memberType); 921 } 922 923 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; 924 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; 925 if (memberDecorationMap.count(operands[0])) { 926 auto &allMemberDecorations = memberDecorationMap[operands[0]]; 927 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) { 928 if (allMemberDecorations.count(memberIndex)) { 929 for (auto &memberDecoration : allMemberDecorations[memberIndex]) { 930 // Check for offset. 931 if (memberDecoration.first == spirv::Decoration::Offset) { 932 // If offset info is empty, resize to the number of members; 933 if (offsetInfo.empty()) { 934 offsetInfo.resize(memberTypes.size()); 935 } 936 offsetInfo[memberIndex] = memberDecoration.second[0]; 937 } else { 938 if (!memberDecoration.second.empty()) { 939 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, 940 memberDecoration.first, 941 memberDecoration.second[0]); 942 } else { 943 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, 944 memberDecoration.first, 0); 945 } 946 } 947 } 948 } 949 } 950 } 951 952 uint32_t structID = operands[0]; 953 std::string structIdentifier = nameMap.lookup(structID).str(); 954 955 if (structIdentifier.empty()) { 956 assert(unresolvedMemberTypes.empty() && 957 "didn't expect unresolved member types"); 958 typeMap[structID] = 959 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); 960 } else { 961 auto structTy = spirv::StructType::getIdentified(context, structIdentifier); 962 typeMap[structID] = structTy; 963 964 if (!unresolvedMemberTypes.empty()) 965 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, 966 memberTypes, offsetInfo, 967 memberDecorationsInfo}); 968 else if (failed(structTy.trySetBody(memberTypes, offsetInfo, 969 memberDecorationsInfo))) 970 return failure(); 971 } 972 973 // TODO: Update StructType to have member name as attribute as 974 // well. 975 return success(); 976 } 977 978 LogicalResult 979 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) { 980 if (operands.size() != 3) { 981 // Three operands are needed: result_id, column_type, and column_count 982 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands" 983 " (result_id, column_type, and column_count)"); 984 } 985 // Matrix columns must be of vector type 986 Type elementTy = getType(operands[1]); 987 if (!elementTy) { 988 return emitError(unknownLoc, 989 "OpTypeMatrix references undefined column type.") 990 << operands[1]; 991 } 992 993 uint32_t colsCount = operands[2]; 994 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount); 995 return success(); 996 } 997 998 LogicalResult 999 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) { 1000 if (operands.size() != 2) 1001 return emitError(unknownLoc, 1002 "OpTypeForwardPointer instruction must have two operands"); 1003 1004 typeForwardPointerIDs.insert(operands[0]); 1005 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer 1006 // instruction that defines the actual type. 1007 1008 return success(); 1009 } 1010 1011 LogicalResult 1012 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) { 1013 // TODO: Add support for Access Qualifier. 1014 if (operands.size() != 8) 1015 return emitError( 1016 unknownLoc, 1017 "OpTypeImage with non-eight operands are not supported yet"); 1018 1019 Type elementTy = getType(operands[1]); 1020 if (!elementTy) 1021 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ") 1022 << operands[1]; 1023 1024 auto dim = spirv::symbolizeDim(operands[2]); 1025 if (!dim) 1026 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ") 1027 << operands[2]; 1028 1029 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]); 1030 if (!depthInfo) 1031 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ") 1032 << operands[3]; 1033 1034 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]); 1035 if (!arrayedInfo) 1036 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ") 1037 << operands[4]; 1038 1039 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]); 1040 if (!samplingInfo) 1041 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5]; 1042 1043 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]); 1044 if (!samplerUseInfo) 1045 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ") 1046 << operands[6]; 1047 1048 auto format = spirv::symbolizeImageFormat(operands[7]); 1049 if (!format) 1050 return emitError(unknownLoc, "unknown Format for OpTypeImage: ") 1051 << operands[7]; 1052 1053 typeMap[operands[0]] = spirv::ImageType::get( 1054 elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(), 1055 samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue()); 1056 return success(); 1057 } 1058 1059 LogicalResult 1060 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) { 1061 if (operands.size() != 2) 1062 return emitError(unknownLoc, "OpTypeSampledImage must have two operands"); 1063 1064 Type elementTy = getType(operands[1]); 1065 if (!elementTy) 1066 return emitError(unknownLoc, 1067 "OpTypeSampledImage references undefined <id>: ") 1068 << operands[1]; 1069 1070 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy); 1071 return success(); 1072 } 1073 1074 //===----------------------------------------------------------------------===// 1075 // Constant 1076 //===----------------------------------------------------------------------===// 1077 1078 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands, 1079 bool isSpec) { 1080 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant"; 1081 1082 if (operands.size() < 2) { 1083 return emitError(unknownLoc) 1084 << opname << " must have type <id> and result <id>"; 1085 } 1086 if (operands.size() < 3) { 1087 return emitError(unknownLoc) 1088 << opname << " must have at least 1 more parameter"; 1089 } 1090 1091 Type resultType = getType(operands[0]); 1092 if (!resultType) { 1093 return emitError(unknownLoc, "undefined result type from <id> ") 1094 << operands[0]; 1095 } 1096 1097 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { 1098 if (bitwidth == 64) { 1099 if (operands.size() == 4) { 1100 return success(); 1101 } 1102 return emitError(unknownLoc) 1103 << opname << " should have 2 parameters for 64-bit values"; 1104 } 1105 if (bitwidth <= 32) { 1106 if (operands.size() == 3) { 1107 return success(); 1108 } 1109 1110 return emitError(unknownLoc) 1111 << opname 1112 << " should have 1 parameter for values with no more than 32 bits"; 1113 } 1114 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ") 1115 << bitwidth; 1116 }; 1117 1118 auto resultID = operands[1]; 1119 1120 if (auto intType = resultType.dyn_cast<IntegerType>()) { 1121 auto bitwidth = intType.getWidth(); 1122 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1123 return failure(); 1124 } 1125 1126 APInt value; 1127 if (bitwidth == 64) { 1128 // 64-bit integers are represented with two SPIR-V words. According to 1129 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1130 // literal’s low-order words appear first." 1131 struct DoubleWord { 1132 uint32_t word1; 1133 uint32_t word2; 1134 } words = {operands[2], operands[3]}; 1135 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true); 1136 } else if (bitwidth <= 32) { 1137 value = APInt(bitwidth, operands[2], /*isSigned=*/true); 1138 } 1139 1140 auto attr = opBuilder.getIntegerAttr(intType, value); 1141 1142 if (isSpec) { 1143 createSpecConstant(unknownLoc, resultID, attr); 1144 } else { 1145 // For normal constants, we just record the attribute (and its type) for 1146 // later materialization at use sites. 1147 constantMap.try_emplace(resultID, attr, intType); 1148 } 1149 1150 return success(); 1151 } 1152 1153 if (auto floatType = resultType.dyn_cast<FloatType>()) { 1154 auto bitwidth = floatType.getWidth(); 1155 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1156 return failure(); 1157 } 1158 1159 APFloat value(0.f); 1160 if (floatType.isF64()) { 1161 // Double values are represented with two SPIR-V words. According to 1162 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1163 // literal’s low-order words appear first." 1164 struct DoubleWord { 1165 uint32_t word1; 1166 uint32_t word2; 1167 } words = {operands[2], operands[3]}; 1168 value = APFloat(llvm::bit_cast<double>(words)); 1169 } else if (floatType.isF32()) { 1170 value = APFloat(llvm::bit_cast<float>(operands[2])); 1171 } else if (floatType.isF16()) { 1172 APInt data(16, operands[2]); 1173 value = APFloat(APFloat::IEEEhalf(), data); 1174 } 1175 1176 auto attr = opBuilder.getFloatAttr(floatType, value); 1177 if (isSpec) { 1178 createSpecConstant(unknownLoc, resultID, attr); 1179 } else { 1180 // For normal constants, we just record the attribute (and its type) for 1181 // later materialization at use sites. 1182 constantMap.try_emplace(resultID, attr, floatType); 1183 } 1184 1185 return success(); 1186 } 1187 1188 return emitError(unknownLoc, "OpConstant can only generate values of " 1189 "scalar integer or floating-point type"); 1190 } 1191 1192 LogicalResult spirv::Deserializer::processConstantBool( 1193 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) { 1194 if (operands.size() != 2) { 1195 return emitError(unknownLoc, "Op") 1196 << (isSpec ? "Spec" : "") << "Constant" 1197 << (isTrue ? "True" : "False") 1198 << " must have type <id> and result <id>"; 1199 } 1200 1201 auto attr = opBuilder.getBoolAttr(isTrue); 1202 auto resultID = operands[1]; 1203 if (isSpec) { 1204 createSpecConstant(unknownLoc, resultID, attr); 1205 } else { 1206 // For normal constants, we just record the attribute (and its type) for 1207 // later materialization at use sites. 1208 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type()); 1209 } 1210 1211 return success(); 1212 } 1213 1214 LogicalResult 1215 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { 1216 if (operands.size() < 2) { 1217 return emitError(unknownLoc, 1218 "OpConstantComposite must have type <id> and result <id>"); 1219 } 1220 if (operands.size() < 3) { 1221 return emitError(unknownLoc, 1222 "OpConstantComposite must have at least 1 parameter"); 1223 } 1224 1225 Type resultType = getType(operands[0]); 1226 if (!resultType) { 1227 return emitError(unknownLoc, "undefined result type from <id> ") 1228 << operands[0]; 1229 } 1230 1231 SmallVector<Attribute, 4> elements; 1232 elements.reserve(operands.size() - 2); 1233 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1234 auto elementInfo = getConstant(operands[i]); 1235 if (!elementInfo) { 1236 return emitError(unknownLoc, "OpConstantComposite component <id> ") 1237 << operands[i] << " must come from a normal constant"; 1238 } 1239 elements.push_back(elementInfo->first); 1240 } 1241 1242 auto resultID = operands[1]; 1243 if (auto vectorType = resultType.dyn_cast<VectorType>()) { 1244 auto attr = DenseElementsAttr::get(vectorType, elements); 1245 // For normal constants, we just record the attribute (and its type) for 1246 // later materialization at use sites. 1247 constantMap.try_emplace(resultID, attr, resultType); 1248 } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) { 1249 auto attr = opBuilder.getArrayAttr(elements); 1250 constantMap.try_emplace(resultID, attr, resultType); 1251 } else { 1252 return emitError(unknownLoc, "unsupported OpConstantComposite type: ") 1253 << resultType; 1254 } 1255 1256 return success(); 1257 } 1258 1259 LogicalResult 1260 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) { 1261 if (operands.size() < 2) { 1262 return emitError(unknownLoc, 1263 "OpConstantComposite must have type <id> and result <id>"); 1264 } 1265 if (operands.size() < 3) { 1266 return emitError(unknownLoc, 1267 "OpConstantComposite must have at least 1 parameter"); 1268 } 1269 1270 Type resultType = getType(operands[0]); 1271 if (!resultType) { 1272 return emitError(unknownLoc, "undefined result type from <id> ") 1273 << operands[0]; 1274 } 1275 1276 auto resultID = operands[1]; 1277 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 1278 1279 SmallVector<Attribute, 4> elements; 1280 elements.reserve(operands.size() - 2); 1281 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1282 auto elementInfo = getSpecConstant(operands[i]); 1283 elements.push_back(opBuilder.getSymbolRefAttr(elementInfo)); 1284 } 1285 1286 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>( 1287 unknownLoc, TypeAttr::get(resultType), symName, 1288 opBuilder.getArrayAttr(elements)); 1289 specConstCompositeMap[resultID] = op; 1290 1291 return success(); 1292 } 1293 1294 LogicalResult 1295 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) { 1296 if (operands.size() < 3) 1297 return emitError(unknownLoc, "OpConstantOperation must have type <id>, " 1298 "result <id>, and operand opcode"); 1299 1300 uint32_t resultTypeID = operands[0]; 1301 1302 if (!getType(resultTypeID)) 1303 return emitError(unknownLoc, "undefined result type from <id> ") 1304 << resultTypeID; 1305 1306 uint32_t resultID = operands[1]; 1307 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]); 1308 auto emplaceResult = specConstOperationMap.try_emplace( 1309 resultID, 1310 SpecConstOperationMaterializationInfo{ 1311 enclosedOpcode, resultTypeID, 1312 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}}); 1313 1314 if (!emplaceResult.second) 1315 return emitError(unknownLoc, "value with <id>: ") 1316 << resultID << " is probably defined before."; 1317 1318 return success(); 1319 } 1320 1321 Value spirv::Deserializer::materializeSpecConstantOperation( 1322 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, 1323 ArrayRef<uint32_t> enclosedOpOperands) { 1324 1325 Type resultType = getType(resultTypeID); 1326 1327 // Instructions wrapped by OpSpecConstantOp need an ID for their 1328 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V 1329 // dialect wrapped op. For that purpose, a new value map is created and "fake" 1330 // ID in that map is assigned to the result of the enclosed instruction. Note 1331 // that there is no need to update this fake ID since we only need to 1332 // reference the created Value for the enclosed op from the spv::YieldOp 1333 // created later in this method (both of which are the only values in their 1334 // region: the SpecConstantOperation's region). If we encounter another 1335 // SpecConstantOperation in the module, we simply re-use the fake ID since the 1336 // previous Value assigned to it isn't visible in the current scope anyway. 1337 DenseMap<uint32_t, Value> newValueMap; 1338 llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap, 1339 newValueMap); 1340 constexpr uint32_t fakeID = static_cast<uint32_t>(-3); 1341 1342 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands; 1343 enclosedOpResultTypeAndOperands.push_back(resultTypeID); 1344 enclosedOpResultTypeAndOperands.push_back(fakeID); 1345 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(), 1346 enclosedOpOperands.end()); 1347 1348 // Process enclosed instruction before creating the enclosing 1349 // specConstantOperation (and its region). This way, references to constants, 1350 // global variables, and spec constants will be materialized outside the new 1351 // op's region. For more info, see Deserializer::getValue's implementation. 1352 if (failed( 1353 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) 1354 return Value(); 1355 1356 // Since the enclosed op is emitted in the current block, split it in a 1357 // separate new block. 1358 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back()); 1359 1360 auto loc = createFileLineColLoc(opBuilder); 1361 auto specConstOperationOp = 1362 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); 1363 1364 Region &body = specConstOperationOp.body(); 1365 // Move the new block into SpecConstantOperation's body. 1366 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), 1367 Region::iterator(enclosedBlock)); 1368 Block &block = body.back(); 1369 1370 // RAII guard to reset the insertion point to the module's region after 1371 // deserializing the body of the specConstantOperation. 1372 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); 1373 opBuilder.setInsertionPointToEnd(&block); 1374 1375 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0)); 1376 return specConstOperationOp.getResult(); 1377 } 1378 1379 LogicalResult 1380 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { 1381 if (operands.size() != 2) { 1382 return emitError(unknownLoc, 1383 "OpConstantNull must have type <id> and result <id>"); 1384 } 1385 1386 Type resultType = getType(operands[0]); 1387 if (!resultType) { 1388 return emitError(unknownLoc, "undefined result type from <id> ") 1389 << operands[0]; 1390 } 1391 1392 auto resultID = operands[1]; 1393 if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) { 1394 auto attr = opBuilder.getZeroAttr(resultType); 1395 // For normal constants, we just record the attribute (and its type) for 1396 // later materialization at use sites. 1397 constantMap.try_emplace(resultID, attr, resultType); 1398 return success(); 1399 } 1400 1401 return emitError(unknownLoc, "unsupported OpConstantNull type: ") 1402 << resultType; 1403 } 1404 1405 //===----------------------------------------------------------------------===// 1406 // Control flow 1407 //===----------------------------------------------------------------------===// 1408 1409 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) { 1410 if (auto *block = getBlock(id)) { 1411 LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id 1412 << " @ " << block << "\n"); 1413 return block; 1414 } 1415 1416 // We don't know where this block will be placed finally (in a 1417 // spv.mlir.selection or spv.mlir.loop or function). Create it into the 1418 // function for now and sort out the proper place later. 1419 auto *block = curFunction->addBlock(); 1420 LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ " 1421 << block << "\n"); 1422 return blockMap[id] = block; 1423 } 1424 1425 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) { 1426 if (!curBlock) { 1427 return emitError(unknownLoc, "OpBranch must appear inside a block"); 1428 } 1429 1430 if (operands.size() != 1) { 1431 return emitError(unknownLoc, "OpBranch must take exactly one target label"); 1432 } 1433 1434 auto *target = getOrCreateBlock(operands[0]); 1435 auto loc = createFileLineColLoc(opBuilder); 1436 // The preceding instruction for the OpBranch instruction could be an 1437 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have 1438 // the same OpLine information. 1439 opBuilder.create<spirv::BranchOp>(loc, target); 1440 1441 (void)clearDebugLine(); 1442 return success(); 1443 } 1444 1445 LogicalResult 1446 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) { 1447 if (!curBlock) { 1448 return emitError(unknownLoc, 1449 "OpBranchConditional must appear inside a block"); 1450 } 1451 1452 if (operands.size() != 3 && operands.size() != 5) { 1453 return emitError(unknownLoc, 1454 "OpBranchConditional must have condition, true label, " 1455 "false label, and optionally two branch weights"); 1456 } 1457 1458 auto condition = getValue(operands[0]); 1459 auto *trueBlock = getOrCreateBlock(operands[1]); 1460 auto *falseBlock = getOrCreateBlock(operands[2]); 1461 1462 Optional<std::pair<uint32_t, uint32_t>> weights; 1463 if (operands.size() == 5) { 1464 weights = std::make_pair(operands[3], operands[4]); 1465 } 1466 // The preceding instruction for the OpBranchConditional instruction could be 1467 // an OpSelectionMerge instruction, in this case they will have the same 1468 // OpLine information. 1469 auto loc = createFileLineColLoc(opBuilder); 1470 opBuilder.create<spirv::BranchConditionalOp>( 1471 loc, condition, trueBlock, 1472 /*trueArguments=*/ArrayRef<Value>(), falseBlock, 1473 /*falseArguments=*/ArrayRef<Value>(), weights); 1474 1475 (void)clearDebugLine(); 1476 return success(); 1477 } 1478 1479 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) { 1480 if (!curFunction) { 1481 return emitError(unknownLoc, "OpLabel must appear inside a function"); 1482 } 1483 1484 if (operands.size() != 1) { 1485 return emitError(unknownLoc, "OpLabel should only have result <id>"); 1486 } 1487 1488 auto labelID = operands[0]; 1489 // We may have forward declared this block. 1490 auto *block = getOrCreateBlock(labelID); 1491 LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n"); 1492 // If we have seen this block, make sure it was just a forward declaration. 1493 assert(block->empty() && "re-deserialize the same block!"); 1494 1495 opBuilder.setInsertionPointToStart(block); 1496 blockMap[labelID] = curBlock = block; 1497 1498 return success(); 1499 } 1500 1501 LogicalResult 1502 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) { 1503 if (!curBlock) { 1504 return emitError(unknownLoc, "OpSelectionMerge must appear in a block"); 1505 } 1506 1507 if (operands.size() < 2) { 1508 return emitError( 1509 unknownLoc, 1510 "OpSelectionMerge must specify merge target and selection control"); 1511 } 1512 1513 auto *mergeBlock = getOrCreateBlock(operands[0]); 1514 auto loc = createFileLineColLoc(opBuilder); 1515 auto selectionControl = operands[1]; 1516 1517 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock) 1518 .second) { 1519 return emitError( 1520 unknownLoc, 1521 "a block cannot have more than one OpSelectionMerge instruction"); 1522 } 1523 1524 return success(); 1525 } 1526 1527 LogicalResult 1528 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) { 1529 if (!curBlock) { 1530 return emitError(unknownLoc, "OpLoopMerge must appear in a block"); 1531 } 1532 1533 if (operands.size() < 3) { 1534 return emitError(unknownLoc, "OpLoopMerge must specify merge target, " 1535 "continue target and loop control"); 1536 } 1537 1538 auto *mergeBlock = getOrCreateBlock(operands[0]); 1539 auto *continueBlock = getOrCreateBlock(operands[1]); 1540 auto loc = createFileLineColLoc(opBuilder); 1541 uint32_t loopControl = operands[2]; 1542 1543 if (!blockMergeInfo 1544 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock) 1545 .second) { 1546 return emitError( 1547 unknownLoc, 1548 "a block cannot have more than one OpLoopMerge instruction"); 1549 } 1550 1551 return success(); 1552 } 1553 1554 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { 1555 if (!curBlock) { 1556 return emitError(unknownLoc, "OpPhi must appear in a block"); 1557 } 1558 1559 if (operands.size() < 4) { 1560 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, " 1561 "and variable-parent pairs"); 1562 } 1563 1564 // Create a block argument for this OpPhi instruction. 1565 Type blockArgType = getType(operands[0]); 1566 BlockArgument blockArg = curBlock->addArgument(blockArgType); 1567 valueMap[operands[1]] = blockArg; 1568 LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg 1569 << " id = " << operands[1] << " of type " 1570 << blockArgType << '\n'); 1571 1572 // For each (value, predecessor) pair, insert the value to the predecessor's 1573 // blockPhiInfo entry so later we can fix the block argument there. 1574 for (unsigned i = 2, e = operands.size(); i < e; i += 2) { 1575 uint32_t value = operands[i]; 1576 Block *predecessor = getOrCreateBlock(operands[i + 1]); 1577 blockPhiInfo[predecessor].push_back(value); 1578 LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor 1579 << " with arg id = " << value << '\n'); 1580 } 1581 1582 return success(); 1583 } 1584 1585 namespace { 1586 /// A class for putting all blocks in a structured selection/loop in a 1587 /// spv.mlir.selection/spv.mlir.loop op. 1588 class ControlFlowStructurizer { 1589 public: 1590 /// Structurizes the loop at the given `headerBlock`. 1591 /// 1592 /// This method will create an spv.mlir.loop op in the `mergeBlock` and move 1593 /// all blocks in the structured loop into the spv.mlir.loop's region. All 1594 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This 1595 /// method will also update `mergeInfo` by remapping all blocks inside to the 1596 /// newly cloned ones inside structured control flow op's regions. 1597 static LogicalResult structurize(Location loc, uint32_t control, 1598 spirv::BlockMergeInfoMap &mergeInfo, 1599 Block *headerBlock, Block *mergeBlock, 1600 Block *continueBlock) { 1601 return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock, 1602 mergeBlock, continueBlock) 1603 .structurizeImpl(); 1604 } 1605 1606 private: 1607 ControlFlowStructurizer(Location loc, uint32_t control, 1608 spirv::BlockMergeInfoMap &mergeInfo, Block *header, 1609 Block *merge, Block *cont) 1610 : location(loc), control(control), blockMergeInfo(mergeInfo), 1611 headerBlock(header), mergeBlock(merge), continueBlock(cont) {} 1612 1613 /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`. 1614 spirv::SelectionOp createSelectionOp(uint32_t selectionControl); 1615 1616 /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`. 1617 spirv::LoopOp createLoopOp(uint32_t loopControl); 1618 1619 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`. 1620 void collectBlocksInConstruct(); 1621 1622 LogicalResult structurizeImpl(); 1623 1624 Location location; 1625 uint32_t control; 1626 1627 spirv::BlockMergeInfoMap &blockMergeInfo; 1628 1629 Block *headerBlock; 1630 Block *mergeBlock; 1631 Block *continueBlock; // nullptr for spv.mlir.selection 1632 1633 llvm::SetVector<Block *> constructBlocks; 1634 }; 1635 } // namespace 1636 1637 spirv::SelectionOp 1638 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { 1639 // Create a builder and set the insertion point to the beginning of the 1640 // merge block so that the newly created SelectionOp will be inserted there. 1641 OpBuilder builder(&mergeBlock->front()); 1642 1643 auto control = builder.getI32IntegerAttr(selectionControl); 1644 auto selectionOp = builder.create<spirv::SelectionOp>(location, control); 1645 selectionOp.addMergeBlock(); 1646 1647 return selectionOp; 1648 } 1649 1650 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { 1651 // Create a builder and set the insertion point to the beginning of the 1652 // merge block so that the newly created LoopOp will be inserted there. 1653 OpBuilder builder(&mergeBlock->front()); 1654 1655 auto control = builder.getI32IntegerAttr(loopControl); 1656 auto loopOp = builder.create<spirv::LoopOp>(location, control); 1657 loopOp.addEntryAndMergeBlock(); 1658 1659 return loopOp; 1660 } 1661 1662 void ControlFlowStructurizer::collectBlocksInConstruct() { 1663 assert(constructBlocks.empty() && "expected empty constructBlocks"); 1664 1665 // Put the header block in the work list first. 1666 constructBlocks.insert(headerBlock); 1667 1668 // For each item in the work list, add its successors excluding the merge 1669 // block. 1670 for (unsigned i = 0; i < constructBlocks.size(); ++i) { 1671 for (auto *successor : constructBlocks[i]->getSuccessors()) 1672 if (successor != mergeBlock) 1673 constructBlocks.insert(successor); 1674 } 1675 } 1676 1677 LogicalResult ControlFlowStructurizer::structurizeImpl() { 1678 Operation *op = nullptr; 1679 bool isLoop = continueBlock != nullptr; 1680 if (isLoop) { 1681 if (auto loopOp = createLoopOp(control)) 1682 op = loopOp.getOperation(); 1683 } else { 1684 if (auto selectionOp = createSelectionOp(control)) 1685 op = selectionOp.getOperation(); 1686 } 1687 if (!op) 1688 return failure(); 1689 Region &body = op->getRegion(0); 1690 1691 BlockAndValueMapping mapper; 1692 // All references to the old merge block should be directed to the 1693 // selection/loop merge block in the SelectionOp/LoopOp's region. 1694 mapper.map(mergeBlock, &body.back()); 1695 1696 collectBlocksInConstruct(); 1697 1698 // We've identified all blocks belonging to the selection/loop's region. Now 1699 // need to "move" them into the selection/loop. Instead of really moving the 1700 // blocks, in the following we copy them and remap all values and branches. 1701 // This is because: 1702 // * Inserting a block into a region requires the block not in any region 1703 // before. But selections/loops can nest so we can create selection/loop ops 1704 // in a nested manner, which means some blocks may already be in a 1705 // selection/loop region when to be moved again. 1706 // * It's much trickier to fix up the branches into and out of the loop's 1707 // region: we need to treat not-moved blocks and moved blocks differently: 1708 // Not-moved blocks jumping to the loop header block need to jump to the 1709 // merge point containing the new loop op but not the loop continue block's 1710 // back edge. Moved blocks jumping out of the loop need to jump to the 1711 // merge block inside the loop region but not other not-moved blocks. 1712 // We cannot use replaceAllUsesWith clearly and it's harder to follow the 1713 // logic. 1714 1715 // Create a corresponding block in the SelectionOp/LoopOp's region for each 1716 // block in this loop construct. 1717 OpBuilder builder(body); 1718 for (auto *block : constructBlocks) { 1719 // Create a block and insert it before the selection/loop merge block in the 1720 // SelectionOp/LoopOp's region. 1721 auto *newBlock = builder.createBlock(&body.back()); 1722 mapper.map(block, newBlock); 1723 LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock 1724 << " from block " << block << "\n"); 1725 if (!isFnEntryBlock(block)) { 1726 for (BlockArgument blockArg : block->getArguments()) { 1727 auto newArg = newBlock->addArgument(blockArg.getType()); 1728 mapper.map(blockArg, newArg); 1729 LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg 1730 << " to " << newArg << '\n'); 1731 } 1732 } else { 1733 LLVM_DEBUG(llvm::dbgs() 1734 << "[cf] block " << block << " is a function entry block\n"); 1735 } 1736 for (auto &op : *block) 1737 newBlock->push_back(op.clone(mapper)); 1738 } 1739 1740 // Go through all ops and remap the operands. 1741 auto remapOperands = [&](Operation *op) { 1742 for (auto &operand : op->getOpOperands()) 1743 if (Value mappedOp = mapper.lookupOrNull(operand.get())) 1744 operand.set(mappedOp); 1745 for (auto &succOp : op->getBlockOperands()) 1746 if (Block *mappedOp = mapper.lookupOrNull(succOp.get())) 1747 succOp.set(mappedOp); 1748 }; 1749 for (auto &block : body) { 1750 block.walk(remapOperands); 1751 } 1752 1753 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to 1754 // the selection/loop construct into its region. Next we need to fix the 1755 // connections between this new SelectionOp/LoopOp with existing blocks. 1756 1757 // All existing incoming branches should go to the merge block, where the 1758 // SelectionOp/LoopOp resides right now. 1759 headerBlock->replaceAllUsesWith(mergeBlock); 1760 1761 if (isLoop) { 1762 // The loop selection/loop header block may have block arguments. Since now 1763 // we place the selection/loop op inside the old merge block, we need to 1764 // make sure the old merge block has the same block argument list. 1765 assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); 1766 for (BlockArgument blockArg : headerBlock->getArguments()) { 1767 mergeBlock->addArgument(blockArg.getType()); 1768 } 1769 1770 // If the loop header block has block arguments, make sure the spv.branch op 1771 // matches. 1772 SmallVector<Value, 4> blockArgs; 1773 if (!headerBlock->args_empty()) 1774 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; 1775 1776 // The loop entry block should have a unconditional branch jumping to the 1777 // loop header block. 1778 builder.setInsertionPointToEnd(&body.front()); 1779 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock), 1780 ArrayRef<Value>(blockArgs)); 1781 } 1782 1783 // All the blocks cloned into the SelectionOp/LoopOp's region can now be 1784 // cleaned up. 1785 LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n"); 1786 // First we need to drop all operands' references inside all blocks. This is 1787 // needed because we can have blocks referencing SSA values from one another. 1788 for (auto *block : constructBlocks) 1789 block->dropAllReferences(); 1790 1791 // Then erase all old blocks. 1792 for (auto *block : constructBlocks) { 1793 // We've cloned all blocks belonging to this construct into the structured 1794 // control flow op's region. Among these blocks, some may compose another 1795 // selection/loop. If so, they will be recorded within blockMergeInfo. 1796 // We need to update the pointers there to the newly remapped ones so we can 1797 // continue structurizing them later. 1798 // TODO: The asserts in the following assumes input SPIR-V blob 1799 // forms correctly nested selection/loop constructs. We should relax this 1800 // and support error cases better. 1801 auto it = blockMergeInfo.find(block); 1802 if (it != blockMergeInfo.end()) { 1803 Block *newHeader = mapper.lookupOrNull(block); 1804 assert(newHeader && "nested loop header block should be remapped!"); 1805 1806 Block *newContinue = it->second.continueBlock; 1807 if (newContinue) { 1808 newContinue = mapper.lookupOrNull(newContinue); 1809 assert(newContinue && "nested loop continue block should be remapped!"); 1810 } 1811 1812 Block *newMerge = it->second.mergeBlock; 1813 if (Block *mappedTo = mapper.lookupOrNull(newMerge)) 1814 newMerge = mappedTo; 1815 1816 // Keep original location for nested selection/loop ops. 1817 Location loc = it->second.loc; 1818 // The iterator should be erased before adding a new entry into 1819 // blockMergeInfo to avoid iterator invalidation. 1820 blockMergeInfo.erase(it); 1821 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge, 1822 newContinue); 1823 } 1824 1825 // The structured selection/loop's entry block does not have arguments. 1826 // If the function's header block is also part of the structured control 1827 // flow, we cannot just simply erase it because it may contain arguments 1828 // matching the function signature and used by the cloned blocks. 1829 if (isFnEntryBlock(block)) { 1830 LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block 1831 << " to only contain a spv.Branch op\n"); 1832 // Still keep the function entry block for the potential block arguments, 1833 // but replace all ops inside with a branch to the merge block. 1834 block->clear(); 1835 builder.setInsertionPointToEnd(block); 1836 builder.create<spirv::BranchOp>(location, mergeBlock); 1837 } else { 1838 LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n"); 1839 block->erase(); 1840 } 1841 } 1842 1843 LLVM_DEBUG( 1844 llvm::dbgs() << "[cf] after structurizing construct with header block " 1845 << headerBlock << ":\n" 1846 << *op << '\n'); 1847 1848 return success(); 1849 } 1850 1851 LogicalResult spirv::Deserializer::wireUpBlockArgument() { 1852 LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n"); 1853 1854 OpBuilder::InsertionGuard guard(opBuilder); 1855 1856 for (const auto &info : blockPhiInfo) { 1857 Block *block = info.first; 1858 const BlockPhiInfo &phiInfo = info.second; 1859 LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); 1860 LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); 1861 LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); 1862 LLVM_DEBUG(llvm::dbgs() << '\n'); 1863 1864 // Set insertion point to before this block's terminator early because we 1865 // may materialize ops via getValue() call. 1866 auto *op = block->getTerminator(); 1867 opBuilder.setInsertionPoint(op); 1868 1869 SmallVector<Value, 4> blockArgs; 1870 blockArgs.reserve(phiInfo.size()); 1871 for (uint32_t valueId : phiInfo) { 1872 if (Value value = getValue(valueId)) { 1873 blockArgs.push_back(value); 1874 LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value 1875 << " id = " << valueId << '\n'); 1876 } else { 1877 return emitError(unknownLoc, "OpPhi references undefined value!"); 1878 } 1879 } 1880 1881 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) { 1882 // Replace the previous branch op with a new one with block arguments. 1883 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(), 1884 blockArgs); 1885 branchOp.erase(); 1886 } else { 1887 return emitError(unknownLoc, "unimplemented terminator for Phi creation"); 1888 } 1889 1890 LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n"); 1891 LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); 1892 LLVM_DEBUG(llvm::dbgs() << '\n'); 1893 } 1894 blockPhiInfo.clear(); 1895 1896 LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n"); 1897 return success(); 1898 } 1899 1900 LogicalResult spirv::Deserializer::structurizeControlFlow() { 1901 LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); 1902 1903 while (!blockMergeInfo.empty()) { 1904 Block *headerBlock = blockMergeInfo.begin()->first; 1905 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; 1906 1907 LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n"); 1908 LLVM_DEBUG(headerBlock->print(llvm::dbgs())); 1909 1910 auto *mergeBlock = mergeInfo.mergeBlock; 1911 assert(mergeBlock && "merge block cannot be nullptr"); 1912 if (!mergeBlock->args_empty()) 1913 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); 1914 LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n"); 1915 LLVM_DEBUG(mergeBlock->print(llvm::dbgs())); 1916 1917 auto *continueBlock = mergeInfo.continueBlock; 1918 if (continueBlock) { 1919 LLVM_DEBUG(llvm::dbgs() 1920 << "[cf] continue block " << continueBlock << ":\n"); 1921 LLVM_DEBUG(continueBlock->print(llvm::dbgs())); 1922 } 1923 // Erase this case before calling into structurizer, who will update 1924 // blockMergeInfo. 1925 blockMergeInfo.erase(blockMergeInfo.begin()); 1926 if (failed(ControlFlowStructurizer::structurize( 1927 mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock, 1928 mergeBlock, continueBlock))) 1929 return failure(); 1930 } 1931 1932 LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n"); 1933 return success(); 1934 } 1935 1936 //===----------------------------------------------------------------------===// 1937 // Debug 1938 //===----------------------------------------------------------------------===// 1939 1940 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) { 1941 if (!debugLine) 1942 return unknownLoc; 1943 1944 auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); 1945 if (fileName.empty()) 1946 fileName = "<unknown>"; 1947 return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line, 1948 debugLine->col); 1949 } 1950 1951 LogicalResult 1952 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) { 1953 // According to SPIR-V spec: 1954 // "This location information applies to the instructions physically 1955 // following this instruction, up to the first occurrence of any of the 1956 // following: the next end of block, the next OpLine instruction, or the next 1957 // OpNoLine instruction." 1958 if (operands.size() != 3) 1959 return emitError(unknownLoc, "OpLine must have 3 operands"); 1960 debugLine = DebugLine(operands[0], operands[1], operands[2]); 1961 return success(); 1962 } 1963 1964 LogicalResult spirv::Deserializer::clearDebugLine() { 1965 debugLine = llvm::None; 1966 return success(); 1967 } 1968 1969 LogicalResult 1970 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) { 1971 if (operands.size() < 2) 1972 return emitError(unknownLoc, "OpString needs at least 2 operands"); 1973 1974 if (!debugInfoMap.lookup(operands[0]).empty()) 1975 return emitError(unknownLoc, 1976 "duplicate debug string found for result <id> ") 1977 << operands[0]; 1978 1979 unsigned wordIndex = 1; 1980 StringRef debugString = decodeStringLiteral(operands, wordIndex); 1981 if (wordIndex != operands.size()) 1982 return emitError(unknownLoc, 1983 "unexpected trailing words in OpString instruction"); 1984 1985 debugInfoMap[operands[0]] = debugString; 1986 return success(); 1987 } 1988