1 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Target/SPIRV/Serialization.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/RegionGraphTraits.h" 21 #include "mlir/Support/LogicalResult.h" 22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 23 #include "llvm/ADT/DepthFirstIterator.h" 24 #include "llvm/ADT/Sequence.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/ADT/SmallPtrSet.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 #include "llvm/ADT/bit.h" 31 #include "llvm/Support/Debug.h" 32 #include "llvm/Support/raw_ostream.h" 33 34 #define DEBUG_TYPE "spirv-serialization" 35 36 using namespace mlir; 37 38 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 39 /// the given `binary` vector. 40 static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, 41 spirv::Opcode op, 42 ArrayRef<uint32_t> operands) { 43 uint32_t wordCount = 1 + operands.size(); 44 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 45 binary.append(operands.begin(), operands.end()); 46 return success(); 47 } 48 49 /// A pre-order depth-first visitor function for processing basic blocks. 50 /// 51 /// Visits the basic blocks starting from the given `headerBlock` in pre-order 52 /// depth-first manner and calls `blockHandler` on each block. Skips handling 53 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` 54 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s 55 /// successors. 56 /// 57 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order 58 /// of blocks in a function must satisfy the rule that blocks appear before 59 /// all blocks they dominate." This can be achieved by a pre-order CFG 60 /// traversal algorithm. To make the serialization output more logical and 61 /// readable to human, we perform depth-first CFG traversal and delay the 62 /// serialization of the merge block and the continue block, if exists, until 63 /// after all other blocks have been processed. 64 static LogicalResult 65 visitInPrettyBlockOrder(Block *headerBlock, 66 function_ref<LogicalResult(Block *)> blockHandler, 67 bool skipHeader = false, BlockRange skipBlocks = {}) { 68 llvm::df_iterator_default_set<Block *, 4> doneBlocks; 69 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); 70 71 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { 72 if (skipHeader && block == headerBlock) 73 continue; 74 if (failed(blockHandler(block))) 75 return failure(); 76 } 77 return success(); 78 } 79 80 /// Returns the merge block if the given `op` is a structured control flow op. 81 /// Otherwise returns nullptr. 82 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 83 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 84 return selectionOp.getMergeBlock(); 85 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 86 return loopOp.getMergeBlock(); 87 return nullptr; 88 } 89 90 /// Given a predecessor `block` for a block with arguments, returns the block 91 /// that should be used as the parent block for SPIR-V OpPhi instructions 92 /// corresponding to the block arguments. 93 static Block *getPhiIncomingBlock(Block *block) { 94 // If the predecessor block in question is the entry block for a spv.loop, 95 // we jump to this spv.loop from its enclosing block. 96 if (block->isEntryBlock()) { 97 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 98 // Then the incoming parent block for OpPhi should be the merge block of 99 // the structured control flow op before this loop. 100 Operation *op = loopOp.getOperation(); 101 while ((op = op->getPrevNode()) != nullptr) 102 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 103 return incomingBlock; 104 // Or the enclosing block itself if no structured control flow ops 105 // exists before this loop. 106 return loopOp->getBlock(); 107 } 108 } 109 110 // Otherwise, we jump from the given predecessor block. Try to see if there is 111 // a structured control flow op inside it. 112 for (Operation &op : llvm::reverse(block->getOperations())) { 113 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 114 return incomingBlock; 115 } 116 return block; 117 } 118 119 namespace { 120 121 /// A SPIR-V module serializer. 122 /// 123 /// A SPIR-V binary module is a single linear stream of instructions; each 124 /// instruction is composed of 32-bit words with the layout: 125 /// 126 /// | <word-count>|<opcode> | <operand> | <operand> | ... | 127 /// | <------ word -------> | <-- word --> | <-- word --> | ... | 128 /// 129 /// For the first word, the 16 high-order bits are the word count of the 130 /// instruction, the 16 low-order bits are the opcode enumerant. The 131 /// instructions then belong to different sections, which must be laid out in 132 /// the particular order as specified in "2.4 Logical Layout of a Module" of 133 /// the SPIR-V spec. 134 class Serializer { 135 public: 136 /// Creates a serializer for the given SPIR-V `module`. 137 explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false); 138 139 /// Serializes the remembered SPIR-V module. 140 LogicalResult serialize(); 141 142 /// Collects the final SPIR-V `binary`. 143 void collect(SmallVectorImpl<uint32_t> &binary); 144 145 #ifndef NDEBUG 146 /// (For debugging) prints each value and its corresponding result <id>. 147 void printValueIDMap(raw_ostream &os); 148 #endif 149 150 private: 151 // Note that there are two main categories of methods in this class: 152 // * process*() methods are meant to fully serialize a SPIR-V module entity 153 // (header, type, op, etc.). They update internal vectors containing 154 // different binary sections. They are not meant to be called except the 155 // top-level serialization loop. 156 // * prepare*() methods are meant to be helpers that prepare for serializing 157 // certain entity. They may or may not update internal vectors containing 158 // different binary sections. They are meant to be called among themselves 159 // or by other process*() methods for subtasks. 160 161 //===--------------------------------------------------------------------===// 162 // <id> 163 //===--------------------------------------------------------------------===// 164 165 // Note that it is illegal to use id <0> in SPIR-V binary module. Various 166 // methods in this class, if using SPIR-V word (uint32_t) as interface, 167 // check or return id <0> to indicate error in processing. 168 169 /// Consumes the next unused <id>. This method will never return 0. 170 uint32_t getNextID() { return nextID++; } 171 172 //===--------------------------------------------------------------------===// 173 // Module structure 174 //===--------------------------------------------------------------------===// 175 176 uint32_t getSpecConstID(StringRef constName) const { 177 return specConstIDMap.lookup(constName); 178 } 179 180 uint32_t getVariableID(StringRef varName) const { 181 return globalVarIDMap.lookup(varName); 182 } 183 184 uint32_t getFunctionID(StringRef fnName) const { 185 return funcIDMap.lookup(fnName); 186 } 187 188 /// Gets the <id> for the function with the given name. Assigns the next 189 /// available <id> if the function haven't been deserialized. 190 uint32_t getOrCreateFunctionID(StringRef fnName); 191 192 void processCapability(); 193 194 void processDebugInfo(); 195 196 void processExtension(); 197 198 void processMemoryModel(); 199 200 LogicalResult processConstantOp(spirv::ConstantOp op); 201 202 LogicalResult processSpecConstantOp(spirv::SpecConstantOp op); 203 204 LogicalResult 205 processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); 206 207 LogicalResult 208 processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); 209 210 /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA 211 /// value to use with other operations. The SPIR-V spec recommends that 212 /// OpUndef be generated at module level. The serialization generates an 213 /// OpUndef for each type needed at module level. 214 LogicalResult processUndefOp(spirv::UndefOp op); 215 216 /// Emit OpName for the given `resultID`. 217 LogicalResult processName(uint32_t resultID, StringRef name); 218 219 /// Processes a SPIR-V function op. 220 LogicalResult processFuncOp(spirv::FuncOp op); 221 222 LogicalResult processVariableOp(spirv::VariableOp op); 223 224 /// Process a SPIR-V GlobalVariableOp 225 LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); 226 227 /// Process attributes that translate to decorations on the result <id> 228 LogicalResult processDecoration(Location loc, uint32_t resultID, 229 NamedAttribute attr); 230 231 template <typename DType> 232 LogicalResult processTypeDecoration(Location loc, DType type, 233 uint32_t resultId) { 234 return emitError(loc, "unhandled decoration for type:") << type; 235 } 236 237 /// Process member decoration 238 LogicalResult processMemberDecoration( 239 uint32_t structID, 240 const spirv::StructType::MemberDecorationInfo &memberDecorationInfo); 241 242 //===--------------------------------------------------------------------===// 243 // Types 244 //===--------------------------------------------------------------------===// 245 246 uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); } 247 248 Type getVoidType() { return mlirBuilder.getNoneType(); } 249 250 bool isVoidType(Type type) const { return type.isa<NoneType>(); } 251 252 /// Returns true if the given type is a pointer type to a struct in some 253 /// interface storage class. 254 bool isInterfaceStructPtrType(Type type) const; 255 256 /// Main dispatch method for serializing a type. The result <id> of the 257 /// serialized type will be returned as `typeID`. 258 LogicalResult processType(Location loc, Type type, uint32_t &typeID); 259 LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID, 260 llvm::SetVector<StringRef> &serializationCtx); 261 262 /// Method for preparing basic SPIR-V type serialization. Returns the type's 263 /// opcode and operands for the instruction via `typeEnum` and `operands`. 264 LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, 265 spirv::Opcode &typeEnum, 266 SmallVectorImpl<uint32_t> &operands, 267 bool &deferSerialization, 268 llvm::SetVector<StringRef> &serializationCtx); 269 270 LogicalResult prepareFunctionType(Location loc, FunctionType type, 271 spirv::Opcode &typeEnum, 272 SmallVectorImpl<uint32_t> &operands); 273 274 //===--------------------------------------------------------------------===// 275 // Constant 276 //===--------------------------------------------------------------------===// 277 278 uint32_t getConstantID(Attribute value) const { 279 return constIDMap.lookup(value); 280 } 281 282 /// Main dispatch method for processing a constant with the given `constType` 283 /// and `valueAttr`. `constType` is needed here because we can interpret the 284 /// `valueAttr` as a different type than the type of `valueAttr` itself; for 285 /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType 286 /// constants. 287 uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); 288 289 /// Prepares array attribute serialization. This method emits corresponding 290 /// OpConstant* and returns the result <id> associated with it. Returns 0 if 291 /// failed. 292 uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr); 293 294 /// Prepares bool/int/float DenseElementsAttr serialization. This method 295 /// iterates the DenseElementsAttr to construct the constant array, and 296 /// returns the result <id> associated with it. Returns 0 if failed. Note 297 /// that the size of `index` must match the rank. 298 /// TODO: Consider to enhance splat elements cases. For splat cases, 299 /// we don't need to loop over all elements, especially when the splat value 300 /// is zero. We can use OpConstantNull when the value is zero. 301 uint32_t prepareDenseElementsConstant(Location loc, Type constType, 302 DenseElementsAttr valueAttr, int dim, 303 MutableArrayRef<uint64_t> index); 304 305 /// Prepares scalar attribute serialization. This method emits corresponding 306 /// OpConstant* and returns the result <id> associated with it. Returns 0 if 307 /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is 308 /// true, then the constant will be serialized as a specialization constant. 309 uint32_t prepareConstantScalar(Location loc, Attribute valueAttr, 310 bool isSpec = false); 311 312 uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, 313 bool isSpec = false); 314 315 uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, 316 bool isSpec = false); 317 318 uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, 319 bool isSpec = false); 320 321 //===--------------------------------------------------------------------===// 322 // Control flow 323 //===--------------------------------------------------------------------===// 324 325 /// Returns the result <id> for the given block. 326 uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } 327 328 /// Returns the result <id> for the given block. If no <id> has been assigned, 329 /// assigns the next available <id> 330 uint32_t getOrCreateBlockID(Block *block); 331 332 /// Processes the given `block` and emits SPIR-V instructions for all ops 333 /// inside. Does not emit OpLabel for this block if `omitLabel` is true. 334 /// `actionBeforeTerminator` is a callback that will be invoked before 335 /// handling the terminator op. It can be used to inject the Op*Merge 336 /// instruction if this is a SPIR-V selection/loop header block. 337 LogicalResult 338 processBlock(Block *block, bool omitLabel = false, 339 function_ref<void()> actionBeforeTerminator = nullptr); 340 341 /// Emits OpPhi instructions for the given block if it has block arguments. 342 LogicalResult emitPhiForBlockArguments(Block *block); 343 344 LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); 345 346 LogicalResult processLoopOp(spirv::LoopOp loopOp); 347 348 LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); 349 350 LogicalResult processBranchOp(spirv::BranchOp branchOp); 351 352 //===--------------------------------------------------------------------===// 353 // Operations 354 //===--------------------------------------------------------------------===// 355 356 LogicalResult encodeExtensionInstruction(Operation *op, 357 StringRef extensionSetName, 358 uint32_t opcode, 359 ArrayRef<uint32_t> operands); 360 361 uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); } 362 363 LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); 364 365 LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp); 366 367 /// Main dispatch method for serializing an operation. 368 LogicalResult processOperation(Operation *op); 369 370 /// Serializes an operation `op` as core instruction with `opcode` if 371 /// `extInstSet` is empty. Otherwise serializes it as an extended instruction 372 /// with `opcode` from `extInstSet`. 373 /// This method is a generic one for dispatching any SPIR-V ops that has no 374 /// variadic operands and attributes in TableGen definitions. 375 LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet, 376 uint32_t opcode); 377 378 /// Dispatches to the serialization function for an operation in SPIR-V 379 /// dialect that is a mirror of an instruction in the SPIR-V spec. This is 380 /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V 381 /// dialect that have hasOpcode == 1. 382 LogicalResult dispatchToAutogenSerialization(Operation *op); 383 384 /// Serializes an operation in the SPIR-V dialect that is a mirror of an 385 /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1 386 /// and autogenSerialization == 1 in ODS. 387 template <typename OpTy> 388 LogicalResult processOp(OpTy op) { 389 return op.emitError("unsupported op serialization"); 390 } 391 392 //===--------------------------------------------------------------------===// 393 // Utilities 394 //===--------------------------------------------------------------------===// 395 396 /// Emits an OpDecorate instruction to decorate the given `target` with the 397 /// given `decoration`. 398 LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration, 399 ArrayRef<uint32_t> params = {}); 400 401 /// Emits an OpLine instruction with the given `loc` location information into 402 /// the given `binary` vector. 403 LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc); 404 405 private: 406 /// The SPIR-V module to be serialized. 407 spirv::ModuleOp module; 408 409 /// An MLIR builder for getting MLIR constructs. 410 mlir::Builder mlirBuilder; 411 412 /// A flag which indicates if the debuginfo should be emitted. 413 bool emitDebugInfo = false; 414 415 /// A flag which indicates if the last processed instruction was a merge 416 /// instruction. 417 /// According to SPIR-V spec: "If a branch merge instruction is used, the last 418 /// OpLine in the block must be before its merge instruction". 419 bool lastProcessedWasMergeInst = false; 420 421 /// The <id> of the OpString instruction, which specifies a file name, for 422 /// use by other debug instructions. 423 uint32_t fileID = 0; 424 425 /// The next available result <id>. 426 uint32_t nextID = 1; 427 428 // The following are for different SPIR-V instruction sections. They follow 429 // the logical layout of a SPIR-V module. 430 431 SmallVector<uint32_t, 4> capabilities; 432 SmallVector<uint32_t, 0> extensions; 433 SmallVector<uint32_t, 0> extendedSets; 434 SmallVector<uint32_t, 3> memoryModel; 435 SmallVector<uint32_t, 0> entryPoints; 436 SmallVector<uint32_t, 4> executionModes; 437 SmallVector<uint32_t, 0> debug; 438 SmallVector<uint32_t, 0> names; 439 SmallVector<uint32_t, 0> decorations; 440 SmallVector<uint32_t, 0> typesGlobalValues; 441 SmallVector<uint32_t, 0> functions; 442 443 /// Recursive struct references are serialized as OpTypePointer instructions 444 /// to the recursive struct type. However, the OpTypePointer instruction 445 /// cannot be emitted before the recursive struct's OpTypeStruct. 446 /// RecursiveStructPointerInfo stores the data needed to emit such 447 /// OpTypePointer instructions after forward references to such types. 448 struct RecursiveStructPointerInfo { 449 uint32_t pointerTypeID; 450 spirv::StorageClass storageClass; 451 }; 452 453 // Maps spirv::StructType to its recursive reference member info. 454 DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>> 455 recursiveStructInfos; 456 457 /// `functionHeader` contains all the instructions that must be in the first 458 /// block in the function, and `functionBody` contains the rest. After 459 /// processing FuncOp, the encoded instructions of a function are appended to 460 /// `functions`. An example of instructions in `functionHeader` in order: 461 /// OpFunction ... 462 /// OpFunctionParameter ... 463 /// OpFunctionParameter ... 464 /// OpLabel ... 465 /// OpVariable ... 466 /// OpVariable ... 467 SmallVector<uint32_t, 0> functionHeader; 468 SmallVector<uint32_t, 0> functionBody; 469 470 /// Map from type used in SPIR-V module to their <id>s. 471 DenseMap<Type, uint32_t> typeIDMap; 472 473 /// Map from constant values to their <id>s. 474 DenseMap<Attribute, uint32_t> constIDMap; 475 476 /// Map from specialization constant names to their <id>s. 477 llvm::StringMap<uint32_t> specConstIDMap; 478 479 /// Map from GlobalVariableOps name to <id>s. 480 llvm::StringMap<uint32_t> globalVarIDMap; 481 482 /// Map from FuncOps name to <id>s. 483 llvm::StringMap<uint32_t> funcIDMap; 484 485 /// Map from blocks to their <id>s. 486 DenseMap<Block *, uint32_t> blockIDMap; 487 488 /// Map from the Type to the <id> that represents undef value of that type. 489 DenseMap<Type, uint32_t> undefValIDMap; 490 491 /// Map from results of normal operations to their <id>s. 492 DenseMap<Value, uint32_t> valueIDMap; 493 494 /// Map from extended instruction set name to <id>s. 495 llvm::StringMap<uint32_t> extendedInstSetIDMap; 496 497 /// Map from values used in OpPhi instructions to their offset in the 498 /// `functions` section. 499 /// 500 /// When processing a block with arguments, we need to emit OpPhi 501 /// instructions to record the predecessor block <id>s and the values they 502 /// send to the block in question. But it's not guaranteed all values are 503 /// visited and thus assigned result <id>s. So we need this list to capture 504 /// the offsets into `functions` where a value is used so that we can fix it 505 /// up later after processing all the blocks in a function. 506 /// 507 /// More concretely, say if we are visiting the following blocks: 508 /// 509 /// ```mlir 510 /// ^phi(%arg0: i32): 511 /// ... 512 /// ^parent1: 513 /// ... 514 /// spv.Branch ^phi(%val0: i32) 515 /// ^parent2: 516 /// ... 517 /// spv.Branch ^phi(%val1: i32) 518 /// ``` 519 /// 520 /// When we are serializing the `^phi` block, we need to emit at the beginning 521 /// of the block OpPhi instructions which has the following parameters: 522 /// 523 /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 524 /// id-for-%val1 id-for-^parent2 525 /// 526 /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit 527 /// all the blocks twice and use the first visit to assign an <id> to each 528 /// value. But it's paying the overheads just for OpPhi emission. Instead, 529 /// we still visit the blocks once for emission. When we emit the OpPhi 530 /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1. 531 /// At the same time, we record their offsets in the emitted binary (which is 532 /// placed inside `functions`) here. And then after emitting all blocks, we 533 /// replace the dummy <id> 0 with the real result <id> by overwriting 534 /// `functions[offset]`. 535 DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues; 536 }; 537 } // namespace 538 539 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) 540 : module(module), mlirBuilder(module.getContext()), 541 emitDebugInfo(emitDebugInfo) {} 542 543 LogicalResult Serializer::serialize() { 544 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 545 546 if (failed(module.verify())) 547 return failure(); 548 549 // TODO: handle the other sections 550 processCapability(); 551 processExtension(); 552 processMemoryModel(); 553 processDebugInfo(); 554 555 // Iterate over the module body to serialize it. Assumptions are that there is 556 // only one basic block in the moduleOp 557 for (auto &op : module.getBlock()) { 558 if (failed(processOperation(&op))) { 559 return failure(); 560 } 561 } 562 563 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 564 return success(); 565 } 566 567 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 568 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 569 extensions.size() + extendedSets.size() + 570 memoryModel.size() + entryPoints.size() + 571 executionModes.size() + decorations.size() + 572 typesGlobalValues.size() + functions.size(); 573 574 binary.clear(); 575 binary.reserve(moduleSize); 576 577 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); 578 binary.append(capabilities.begin(), capabilities.end()); 579 binary.append(extensions.begin(), extensions.end()); 580 binary.append(extendedSets.begin(), extendedSets.end()); 581 binary.append(memoryModel.begin(), memoryModel.end()); 582 binary.append(entryPoints.begin(), entryPoints.end()); 583 binary.append(executionModes.begin(), executionModes.end()); 584 binary.append(debug.begin(), debug.end()); 585 binary.append(names.begin(), names.end()); 586 binary.append(decorations.begin(), decorations.end()); 587 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 588 binary.append(functions.begin(), functions.end()); 589 } 590 591 #ifndef NDEBUG 592 void Serializer::printValueIDMap(raw_ostream &os) { 593 os << "\n= Value <id> Map =\n\n"; 594 for (auto valueIDPair : valueIDMap) { 595 Value val = valueIDPair.first; 596 os << " " << val << " " 597 << "id = " << valueIDPair.second << ' '; 598 if (auto *op = val.getDefiningOp()) { 599 os << "from op '" << op->getName() << "'"; 600 } else if (auto arg = val.dyn_cast<BlockArgument>()) { 601 Block *block = arg.getOwner(); 602 os << "from argument of block " << block << ' '; 603 os << " in op '" << block->getParentOp()->getName() << "'"; 604 } 605 os << '\n'; 606 } 607 } 608 #endif 609 610 //===----------------------------------------------------------------------===// 611 // Module structure 612 //===----------------------------------------------------------------------===// 613 614 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 615 auto funcID = funcIDMap.lookup(fnName); 616 if (!funcID) { 617 funcID = getNextID(); 618 funcIDMap[fnName] = funcID; 619 } 620 return funcID; 621 } 622 623 void Serializer::processCapability() { 624 for (auto cap : module.vce_triple()->getCapabilities()) 625 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 626 {static_cast<uint32_t>(cap)}); 627 } 628 629 void Serializer::processDebugInfo() { 630 if (!emitDebugInfo) 631 return; 632 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>(); 633 auto fileName = fileLoc ? fileLoc.getFilename() : "<unknown>"; 634 fileID = getNextID(); 635 SmallVector<uint32_t, 16> operands; 636 operands.push_back(fileID); 637 spirv::encodeStringLiteralInto(operands, fileName); 638 encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 639 // TODO: Encode more debug instructions. 640 } 641 642 void Serializer::processExtension() { 643 llvm::SmallVector<uint32_t, 16> extName; 644 for (spirv::Extension ext : module.vce_triple()->getExtensions()) { 645 extName.clear(); 646 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); 647 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); 648 } 649 } 650 651 void Serializer::processMemoryModel() { 652 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt(); 653 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt(); 654 655 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); 656 } 657 658 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { 659 if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { 660 valueIDMap[op.getResult()] = resultID; 661 return success(); 662 } 663 return failure(); 664 } 665 666 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { 667 if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), 668 /*isSpec=*/true)) { 669 // Emit the OpDecorate instruction for SpecId. 670 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) { 671 auto val = static_cast<uint32_t>(specID.getInt()); 672 emitDecoration(resultID, spirv::Decoration::SpecId, {val}); 673 } 674 675 specConstIDMap[op.sym_name()] = resultID; 676 return processName(resultID, op.sym_name()); 677 } 678 return failure(); 679 } 680 681 LogicalResult 682 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { 683 uint32_t typeID = 0; 684 if (failed(processType(op.getLoc(), op.type(), typeID))) { 685 return failure(); 686 } 687 688 auto resultID = getNextID(); 689 690 SmallVector<uint32_t, 8> operands; 691 operands.push_back(typeID); 692 operands.push_back(resultID); 693 694 auto constituents = op.constituents(); 695 696 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { 697 auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>(); 698 699 auto constituentName = constituent.getValue(); 700 auto constituentID = getSpecConstID(constituentName); 701 702 if (!constituentID) { 703 return op.emitError("unknown result <id> for specialization constant ") 704 << constituentName; 705 } 706 707 operands.push_back(constituentID); 708 } 709 710 encodeInstructionInto(typesGlobalValues, 711 spirv::Opcode::OpSpecConstantComposite, operands); 712 specConstIDMap[op.sym_name()] = resultID; 713 714 return processName(resultID, op.sym_name()); 715 } 716 717 LogicalResult 718 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { 719 uint32_t typeID = 0; 720 if (failed(processType(op.getLoc(), op.getType(), typeID))) { 721 return failure(); 722 } 723 724 auto resultID = getNextID(); 725 726 SmallVector<uint32_t, 8> operands; 727 operands.push_back(typeID); 728 operands.push_back(resultID); 729 730 Block &block = op.getRegion().getBlocks().front(); 731 Operation &enclosedOp = block.getOperations().front(); 732 733 std::string enclosedOpName; 734 llvm::raw_string_ostream rss(enclosedOpName); 735 rss << "Op" << enclosedOp.getName().stripDialect(); 736 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); 737 738 if (!enclosedOpcode) { 739 op.emitError("Couldn't find op code for op ") 740 << enclosedOp.getName().getStringRef(); 741 return failure(); 742 } 743 744 operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue())); 745 746 // Append operands to the enclosed op to the list of operands. 747 for (Value operand : enclosedOp.getOperands()) { 748 uint32_t id = getValueID(operand); 749 assert(id && "use before def!"); 750 operands.push_back(id); 751 } 752 753 encodeInstructionInto(typesGlobalValues, 754 spirv::Opcode::OpSpecConstantOperation, operands); 755 valueIDMap[op.getResult()] = resultID; 756 757 return success(); 758 } 759 760 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { 761 auto undefType = op.getType(); 762 auto &id = undefValIDMap[undefType]; 763 if (!id) { 764 id = getNextID(); 765 uint32_t typeID = 0; 766 if (failed(processType(op.getLoc(), undefType, typeID)) || 767 failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, 768 {typeID, id}))) { 769 return failure(); 770 } 771 } 772 valueIDMap[op.getResult()] = id; 773 return success(); 774 } 775 776 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 777 NamedAttribute attr) { 778 auto attrName = attr.first.strref(); 779 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); 780 auto decoration = spirv::symbolizeDecoration(decorationName); 781 if (!decoration) { 782 return emitError( 783 loc, "non-argument attributes expected to have snake-case-ified " 784 "decoration name, unhandled attribute with name : ") 785 << attrName; 786 } 787 SmallVector<uint32_t, 1> args; 788 switch (decoration.getValue()) { 789 case spirv::Decoration::Binding: 790 case spirv::Decoration::DescriptorSet: 791 case spirv::Decoration::Location: 792 if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) { 793 args.push_back(intAttr.getValue().getZExtValue()); 794 break; 795 } 796 return emitError(loc, "expected integer attribute for ") << attrName; 797 case spirv::Decoration::BuiltIn: 798 if (auto strAttr = attr.second.dyn_cast<StringAttr>()) { 799 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 800 if (enumVal) { 801 args.push_back(static_cast<uint32_t>(enumVal.getValue())); 802 break; 803 } 804 return emitError(loc, "invalid ") 805 << attrName << " attribute " << strAttr.getValue(); 806 } 807 return emitError(loc, "expected string attribute for ") << attrName; 808 case spirv::Decoration::Aliased: 809 case spirv::Decoration::Flat: 810 case spirv::Decoration::NonReadable: 811 case spirv::Decoration::NonWritable: 812 case spirv::Decoration::NoPerspective: 813 case spirv::Decoration::Restrict: 814 // For unit attributes, the args list has no values so we do nothing 815 if (auto unitAttr = attr.second.dyn_cast<UnitAttr>()) 816 break; 817 return emitError(loc, "expected unit attribute for ") << attrName; 818 default: 819 return emitError(loc, "unhandled decoration ") << decorationName; 820 } 821 return emitDecoration(resultID, decoration.getValue(), args); 822 } 823 824 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 825 assert(!name.empty() && "unexpected empty string for OpName"); 826 827 SmallVector<uint32_t, 4> nameOperands; 828 nameOperands.push_back(resultID); 829 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { 830 return failure(); 831 } 832 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 833 } 834 835 namespace { 836 template <> 837 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 838 Location loc, spirv::ArrayType type, uint32_t resultID) { 839 if (unsigned stride = type.getArrayStride()) { 840 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 841 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 842 } 843 return success(); 844 } 845 846 template <> 847 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 848 Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) { 849 if (unsigned stride = type.getArrayStride()) { 850 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 851 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 852 } 853 return success(); 854 } 855 856 LogicalResult Serializer::processMemberDecoration( 857 uint32_t structID, 858 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 859 SmallVector<uint32_t, 4> args( 860 {structID, memberDecoration.memberIndex, 861 static_cast<uint32_t>(memberDecoration.decoration)}); 862 if (memberDecoration.hasValue) { 863 args.push_back(memberDecoration.decorationValue); 864 } 865 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, 866 args); 867 } 868 } // namespace 869 870 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { 871 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); 872 assert(functionHeader.empty() && functionBody.empty()); 873 874 uint32_t fnTypeID = 0; 875 // Generate type of the function. 876 processType(op.getLoc(), op.getType(), fnTypeID); 877 878 // Add the function definition. 879 SmallVector<uint32_t, 4> operands; 880 uint32_t resTypeID = 0; 881 auto resultTypes = op.getType().getResults(); 882 if (resultTypes.size() > 1) { 883 return op.emitError("cannot serialize function with multiple return types"); 884 } 885 if (failed(processType(op.getLoc(), 886 (resultTypes.empty() ? getVoidType() : resultTypes[0]), 887 resTypeID))) { 888 return failure(); 889 } 890 operands.push_back(resTypeID); 891 auto funcID = getOrCreateFunctionID(op.getName()); 892 operands.push_back(funcID); 893 operands.push_back(static_cast<uint32_t>(op.function_control())); 894 operands.push_back(fnTypeID); 895 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); 896 897 // Add function name. 898 if (failed(processName(funcID, op.getName()))) { 899 return failure(); 900 } 901 902 // Declare the parameters. 903 for (auto arg : op.getArguments()) { 904 uint32_t argTypeID = 0; 905 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { 906 return failure(); 907 } 908 auto argValueID = getNextID(); 909 valueIDMap[arg] = argValueID; 910 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, 911 {argTypeID, argValueID}); 912 } 913 914 // Process the body. 915 if (op.isExternal()) { 916 return op.emitError("external function is unhandled"); 917 } 918 919 // Some instructions (e.g., OpVariable) in a function must be in the first 920 // block in the function. These instructions will be put in functionHeader. 921 // Thus, we put the label in functionHeader first, and omit it from the first 922 // block. 923 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, 924 {getOrCreateBlockID(&op.front())}); 925 processBlock(&op.front(), /*omitLabel=*/true); 926 if (failed(visitInPrettyBlockOrder( 927 &op.front(), [&](Block *block) { return processBlock(block); }, 928 /*skipHeader=*/true))) { 929 return failure(); 930 } 931 932 // There might be OpPhi instructions who have value references needing to fix. 933 for (auto deferredValue : deferredPhiValues) { 934 Value value = deferredValue.first; 935 uint32_t id = getValueID(value); 936 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value 937 << " to id = " << id << '\n'); 938 assert(id && "OpPhi references undefined value!"); 939 for (size_t offset : deferredValue.second) 940 functionBody[offset] = id; 941 } 942 deferredPhiValues.clear(); 943 944 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() 945 << "' --\n"); 946 // Insert OpFunctionEnd. 947 if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, 948 {}))) { 949 return failure(); 950 } 951 952 functions.append(functionHeader.begin(), functionHeader.end()); 953 functions.append(functionBody.begin(), functionBody.end()); 954 functionHeader.clear(); 955 functionBody.clear(); 956 957 return success(); 958 } 959 960 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { 961 SmallVector<uint32_t, 4> operands; 962 SmallVector<StringRef, 2> elidedAttrs; 963 uint32_t resultID = 0; 964 uint32_t resultTypeID = 0; 965 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { 966 return failure(); 967 } 968 operands.push_back(resultTypeID); 969 resultID = getNextID(); 970 valueIDMap[op.getResult()] = resultID; 971 operands.push_back(resultID); 972 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); 973 if (attr) { 974 operands.push_back(static_cast<uint32_t>( 975 attr.cast<IntegerAttr>().getValue().getZExtValue())); 976 } 977 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 978 for (auto arg : op.getODSOperands(0)) { 979 auto argID = getValueID(arg); 980 if (!argID) { 981 return emitError(op.getLoc(), "operand 0 has a use before def"); 982 } 983 operands.push_back(argID); 984 } 985 emitDebugLine(functionHeader, op.getLoc()); 986 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); 987 for (auto attr : op->getAttrs()) { 988 if (llvm::any_of(elidedAttrs, 989 [&](StringRef elided) { return attr.first == elided; })) { 990 continue; 991 } 992 if (failed(processDecoration(op.getLoc(), resultID, attr))) { 993 return failure(); 994 } 995 } 996 return success(); 997 } 998 999 LogicalResult 1000 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { 1001 // Get TypeID. 1002 uint32_t resultTypeID = 0; 1003 SmallVector<StringRef, 4> elidedAttrs; 1004 if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { 1005 return failure(); 1006 } 1007 1008 if (isInterfaceStructPtrType(varOp.type())) { 1009 auto structType = varOp.type() 1010 .cast<spirv::PointerType>() 1011 .getPointeeType() 1012 .cast<spirv::StructType>(); 1013 if (failed( 1014 emitDecoration(getTypeID(structType), spirv::Decoration::Block))) { 1015 return varOp.emitError("cannot decorate ") 1016 << structType << " with Block decoration"; 1017 } 1018 } 1019 1020 elidedAttrs.push_back("type"); 1021 SmallVector<uint32_t, 4> operands; 1022 operands.push_back(resultTypeID); 1023 auto resultID = getNextID(); 1024 1025 // Encode the name. 1026 auto varName = varOp.sym_name(); 1027 elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); 1028 if (failed(processName(resultID, varName))) { 1029 return failure(); 1030 } 1031 globalVarIDMap[varName] = resultID; 1032 operands.push_back(resultID); 1033 1034 // Encode StorageClass. 1035 operands.push_back(static_cast<uint32_t>(varOp.storageClass())); 1036 1037 // Encode initialization. 1038 if (auto initializer = varOp.initializer()) { 1039 auto initializerID = getVariableID(initializer.getValue()); 1040 if (!initializerID) { 1041 return emitError(varOp.getLoc(), 1042 "invalid usage of undefined variable as initializer"); 1043 } 1044 operands.push_back(initializerID); 1045 elidedAttrs.push_back("initializer"); 1046 } 1047 1048 emitDebugLine(typesGlobalValues, varOp.getLoc()); 1049 if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, 1050 operands))) { 1051 elidedAttrs.push_back("initializer"); 1052 return failure(); 1053 } 1054 1055 // Encode decorations. 1056 for (auto attr : varOp->getAttrs()) { 1057 if (llvm::any_of(elidedAttrs, 1058 [&](StringRef elided) { return attr.first == elided; })) { 1059 continue; 1060 } 1061 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { 1062 return failure(); 1063 } 1064 } 1065 return success(); 1066 } 1067 1068 //===----------------------------------------------------------------------===// 1069 // Type 1070 //===----------------------------------------------------------------------===// 1071 1072 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 1073 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 1074 // PushConstant Storage Classes must be explicitly laid out." 1075 bool Serializer::isInterfaceStructPtrType(Type type) const { 1076 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 1077 switch (ptrType.getStorageClass()) { 1078 case spirv::StorageClass::PhysicalStorageBuffer: 1079 case spirv::StorageClass::PushConstant: 1080 case spirv::StorageClass::StorageBuffer: 1081 case spirv::StorageClass::Uniform: 1082 return ptrType.getPointeeType().isa<spirv::StructType>(); 1083 default: 1084 break; 1085 } 1086 } 1087 return false; 1088 } 1089 1090 LogicalResult Serializer::processType(Location loc, Type type, 1091 uint32_t &typeID) { 1092 // Maintains a set of names for nested identified struct types. This is used 1093 // to properly serialize recursive references. 1094 llvm::SetVector<StringRef> serializationCtx; 1095 return processTypeImpl(loc, type, typeID, serializationCtx); 1096 } 1097 1098 LogicalResult 1099 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 1100 llvm::SetVector<StringRef> &serializationCtx) { 1101 typeID = getTypeID(type); 1102 if (typeID) { 1103 return success(); 1104 } 1105 typeID = getNextID(); 1106 SmallVector<uint32_t, 4> operands; 1107 1108 operands.push_back(typeID); 1109 auto typeEnum = spirv::Opcode::OpTypeVoid; 1110 bool deferSerialization = false; 1111 1112 if ((type.isa<FunctionType>() && 1113 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, 1114 operands))) || 1115 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 1116 deferSerialization, serializationCtx))) { 1117 if (deferSerialization) 1118 return success(); 1119 1120 typeIDMap[type] = typeID; 1121 1122 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) 1123 return failure(); 1124 1125 if (recursiveStructInfos.count(type) != 0) { 1126 // This recursive struct type is emitted already, now the OpTypePointer 1127 // instructions referring to recursive references are emitted as well. 1128 for (auto &ptrInfo : recursiveStructInfos[type]) { 1129 // TODO: This might not work if more than 1 recursive reference is 1130 // present in the struct. 1131 SmallVector<uint32_t, 4> ptrOperands; 1132 ptrOperands.push_back(ptrInfo.pointerTypeID); 1133 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 1134 ptrOperands.push_back(typeIDMap[type]); 1135 1136 if (failed(encodeInstructionInto( 1137 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) 1138 return failure(); 1139 } 1140 1141 recursiveStructInfos[type].clear(); 1142 } 1143 1144 return success(); 1145 } 1146 1147 return failure(); 1148 } 1149 1150 LogicalResult Serializer::prepareBasicType( 1151 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 1152 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 1153 llvm::SetVector<StringRef> &serializationCtx) { 1154 deferSerialization = false; 1155 1156 if (isVoidType(type)) { 1157 typeEnum = spirv::Opcode::OpTypeVoid; 1158 return success(); 1159 } 1160 1161 if (auto intType = type.dyn_cast<IntegerType>()) { 1162 if (intType.getWidth() == 1) { 1163 typeEnum = spirv::Opcode::OpTypeBool; 1164 return success(); 1165 } 1166 1167 typeEnum = spirv::Opcode::OpTypeInt; 1168 operands.push_back(intType.getWidth()); 1169 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 1170 // to preserve or validate. 1171 // 0 indicates unsigned, or no signedness semantics 1172 // 1 indicates signed semantics." 1173 operands.push_back(intType.isSigned() ? 1 : 0); 1174 return success(); 1175 } 1176 1177 if (auto floatType = type.dyn_cast<FloatType>()) { 1178 typeEnum = spirv::Opcode::OpTypeFloat; 1179 operands.push_back(floatType.getWidth()); 1180 return success(); 1181 } 1182 1183 if (auto vectorType = type.dyn_cast<VectorType>()) { 1184 uint32_t elementTypeID = 0; 1185 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 1186 serializationCtx))) { 1187 return failure(); 1188 } 1189 typeEnum = spirv::Opcode::OpTypeVector; 1190 operands.push_back(elementTypeID); 1191 operands.push_back(vectorType.getNumElements()); 1192 return success(); 1193 } 1194 1195 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { 1196 typeEnum = spirv::Opcode::OpTypeArray; 1197 uint32_t elementTypeID = 0; 1198 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 1199 serializationCtx))) { 1200 return failure(); 1201 } 1202 operands.push_back(elementTypeID); 1203 if (auto elementCountID = prepareConstantInt( 1204 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 1205 operands.push_back(elementCountID); 1206 } 1207 return processTypeDecoration(loc, arrayType, resultID); 1208 } 1209 1210 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 1211 uint32_t pointeeTypeID = 0; 1212 spirv::StructType pointeeStruct = 1213 ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 1214 1215 if (pointeeStruct && pointeeStruct.isIdentified() && 1216 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 1217 // A recursive reference to an enclosing struct is found. 1218 // 1219 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 1220 // class as operands. 1221 SmallVector<uint32_t, 2> forwardPtrOperands; 1222 forwardPtrOperands.push_back(resultID); 1223 forwardPtrOperands.push_back( 1224 static_cast<uint32_t>(ptrType.getStorageClass())); 1225 1226 encodeInstructionInto(typesGlobalValues, 1227 spirv::Opcode::OpTypeForwardPointer, 1228 forwardPtrOperands); 1229 1230 // 2. Find the pointee (enclosing) struct. 1231 auto structType = spirv::StructType::getIdentified( 1232 module.getContext(), pointeeStruct.getIdentifier()); 1233 1234 if (!structType) 1235 return failure(); 1236 1237 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 1238 // as deferred. 1239 deferSerialization = true; 1240 1241 // 4. Record the info needed to emit the deferred OpTypePointer 1242 // instruction when the enclosing struct is completely serialized. 1243 recursiveStructInfos[structType].push_back( 1244 {resultID, ptrType.getStorageClass()}); 1245 } else { 1246 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 1247 serializationCtx))) 1248 return failure(); 1249 } 1250 1251 typeEnum = spirv::Opcode::OpTypePointer; 1252 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 1253 operands.push_back(pointeeTypeID); 1254 return success(); 1255 } 1256 1257 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 1258 uint32_t elementTypeID = 0; 1259 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 1260 elementTypeID, serializationCtx))) { 1261 return failure(); 1262 } 1263 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 1264 operands.push_back(elementTypeID); 1265 return processTypeDecoration(loc, runtimeArrayType, resultID); 1266 } 1267 1268 if (auto structType = type.dyn_cast<spirv::StructType>()) { 1269 if (structType.isIdentified()) { 1270 processName(resultID, structType.getIdentifier()); 1271 serializationCtx.insert(structType.getIdentifier()); 1272 } 1273 1274 bool hasOffset = structType.hasOffset(); 1275 for (auto elementIndex : 1276 llvm::seq<uint32_t>(0, structType.getNumElements())) { 1277 uint32_t elementTypeID = 0; 1278 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 1279 elementTypeID, serializationCtx))) { 1280 return failure(); 1281 } 1282 operands.push_back(elementTypeID); 1283 if (hasOffset) { 1284 // Decorate each struct member with an offset 1285 spirv::StructType::MemberDecorationInfo offsetDecoration{ 1286 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 1287 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 1288 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 1289 return emitError(loc, "cannot decorate ") 1290 << elementIndex << "-th member of " << structType 1291 << " with its offset"; 1292 } 1293 } 1294 } 1295 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 1296 structType.getMemberDecorations(memberDecorations); 1297 1298 for (auto &memberDecoration : memberDecorations) { 1299 if (failed(processMemberDecoration(resultID, memberDecoration))) { 1300 return emitError(loc, "cannot decorate ") 1301 << static_cast<uint32_t>(memberDecoration.memberIndex) 1302 << "-th member of " << structType << " with " 1303 << stringifyDecoration(memberDecoration.decoration); 1304 } 1305 } 1306 1307 typeEnum = spirv::Opcode::OpTypeStruct; 1308 1309 if (structType.isIdentified()) 1310 serializationCtx.remove(structType.getIdentifier()); 1311 1312 return success(); 1313 } 1314 1315 if (auto cooperativeMatrixType = 1316 type.dyn_cast<spirv::CooperativeMatrixNVType>()) { 1317 uint32_t elementTypeID = 0; 1318 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 1319 elementTypeID, serializationCtx))) { 1320 return failure(); 1321 } 1322 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; 1323 auto getConstantOp = [&](uint32_t id) { 1324 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 1325 return prepareConstantInt(loc, attr); 1326 }; 1327 operands.push_back(elementTypeID); 1328 operands.push_back( 1329 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()))); 1330 operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); 1331 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); 1332 return success(); 1333 } 1334 1335 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) { 1336 uint32_t elementTypeID = 0; 1337 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 1338 serializationCtx))) { 1339 return failure(); 1340 } 1341 typeEnum = spirv::Opcode::OpTypeMatrix; 1342 operands.push_back(elementTypeID); 1343 operands.push_back(matrixType.getNumColumns()); 1344 return success(); 1345 } 1346 1347 // TODO: Handle other types. 1348 return emitError(loc, "unhandled type in serialization: ") << type; 1349 } 1350 1351 LogicalResult 1352 Serializer::prepareFunctionType(Location loc, FunctionType type, 1353 spirv::Opcode &typeEnum, 1354 SmallVectorImpl<uint32_t> &operands) { 1355 typeEnum = spirv::Opcode::OpTypeFunction; 1356 assert(type.getNumResults() <= 1 && 1357 "serialization supports only a single return value"); 1358 uint32_t resultID = 0; 1359 if (failed(processType( 1360 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 1361 resultID))) { 1362 return failure(); 1363 } 1364 operands.push_back(resultID); 1365 for (auto &res : type.getInputs()) { 1366 uint32_t argTypeID = 0; 1367 if (failed(processType(loc, res, argTypeID))) { 1368 return failure(); 1369 } 1370 operands.push_back(argTypeID); 1371 } 1372 return success(); 1373 } 1374 1375 //===----------------------------------------------------------------------===// 1376 // Constant 1377 //===----------------------------------------------------------------------===// 1378 1379 uint32_t Serializer::prepareConstant(Location loc, Type constType, 1380 Attribute valueAttr) { 1381 if (auto id = prepareConstantScalar(loc, valueAttr)) { 1382 return id; 1383 } 1384 1385 // This is a composite literal. We need to handle each component separately 1386 // and then emit an OpConstantComposite for the whole. 1387 1388 if (auto id = getConstantID(valueAttr)) { 1389 return id; 1390 } 1391 1392 uint32_t typeID = 0; 1393 if (failed(processType(loc, constType, typeID))) { 1394 return 0; 1395 } 1396 1397 uint32_t resultID = 0; 1398 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { 1399 int rank = attr.getType().dyn_cast<ShapedType>().getRank(); 1400 SmallVector<uint64_t, 4> index(rank); 1401 resultID = prepareDenseElementsConstant(loc, constType, attr, 1402 /*dim=*/0, index); 1403 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { 1404 resultID = prepareArrayConstant(loc, constType, arrayAttr); 1405 } 1406 1407 if (resultID == 0) { 1408 emitError(loc, "cannot serialize attribute: ") << valueAttr; 1409 return 0; 1410 } 1411 1412 constIDMap[valueAttr] = resultID; 1413 return resultID; 1414 } 1415 1416 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 1417 ArrayAttr attr) { 1418 uint32_t typeID = 0; 1419 if (failed(processType(loc, constType, typeID))) { 1420 return 0; 1421 } 1422 1423 uint32_t resultID = getNextID(); 1424 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 1425 operands.reserve(attr.size() + 2); 1426 auto elementType = constType.cast<spirv::ArrayType>().getElementType(); 1427 for (Attribute elementAttr : attr) { 1428 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 1429 operands.push_back(elementID); 1430 } else { 1431 return 0; 1432 } 1433 } 1434 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 1435 encodeInstructionInto(typesGlobalValues, opcode, operands); 1436 1437 return resultID; 1438 } 1439 1440 // TODO: Turn the below function into iterative function, instead of 1441 // recursive function. 1442 uint32_t 1443 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 1444 DenseElementsAttr valueAttr, int dim, 1445 MutableArrayRef<uint64_t> index) { 1446 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); 1447 assert(dim <= shapedType.getRank()); 1448 if (shapedType.getRank() == dim) { 1449 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { 1450 return attr.getType().getElementType().isInteger(1) 1451 ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index)) 1452 : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index)); 1453 } 1454 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { 1455 return prepareConstantFp(loc, attr.getValue<FloatAttr>(index)); 1456 } 1457 return 0; 1458 } 1459 1460 uint32_t typeID = 0; 1461 if (failed(processType(loc, constType, typeID))) { 1462 return 0; 1463 } 1464 1465 uint32_t resultID = getNextID(); 1466 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 1467 operands.reserve(shapedType.getDimSize(dim) + 2); 1468 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); 1469 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 1470 index[dim] = i; 1471 if (auto elementID = prepareDenseElementsConstant( 1472 loc, elementType, valueAttr, dim + 1, index)) { 1473 operands.push_back(elementID); 1474 } else { 1475 return 0; 1476 } 1477 } 1478 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 1479 encodeInstructionInto(typesGlobalValues, opcode, operands); 1480 1481 return resultID; 1482 } 1483 1484 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 1485 bool isSpec) { 1486 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { 1487 return prepareConstantFp(loc, floatAttr, isSpec); 1488 } 1489 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { 1490 return prepareConstantBool(loc, boolAttr, isSpec); 1491 } 1492 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { 1493 return prepareConstantInt(loc, intAttr, isSpec); 1494 } 1495 1496 return 0; 1497 } 1498 1499 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 1500 bool isSpec) { 1501 if (!isSpec) { 1502 // We can de-duplicate normal constants, but not specialization constants. 1503 if (auto id = getConstantID(boolAttr)) { 1504 return id; 1505 } 1506 } 1507 1508 // Process the type for this bool literal 1509 uint32_t typeID = 0; 1510 if (failed(processType(loc, boolAttr.getType(), typeID))) { 1511 return 0; 1512 } 1513 1514 auto resultID = getNextID(); 1515 auto opcode = boolAttr.getValue() 1516 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 1517 : spirv::Opcode::OpConstantTrue) 1518 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 1519 : spirv::Opcode::OpConstantFalse); 1520 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 1521 1522 if (!isSpec) { 1523 constIDMap[boolAttr] = resultID; 1524 } 1525 return resultID; 1526 } 1527 1528 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 1529 bool isSpec) { 1530 if (!isSpec) { 1531 // We can de-duplicate normal constants, but not specialization constants. 1532 if (auto id = getConstantID(intAttr)) { 1533 return id; 1534 } 1535 } 1536 1537 // Process the type for this integer literal 1538 uint32_t typeID = 0; 1539 if (failed(processType(loc, intAttr.getType(), typeID))) { 1540 return 0; 1541 } 1542 1543 auto resultID = getNextID(); 1544 APInt value = intAttr.getValue(); 1545 unsigned bitwidth = value.getBitWidth(); 1546 bool isSigned = value.isSignedIntN(bitwidth); 1547 1548 auto opcode = 1549 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 1550 1551 // According to SPIR-V spec, "When the type's bit width is less than 32-bits, 1552 // the literal's value appears in the low-order bits of the word, and the 1553 // high-order bits must be 0 for a floating-point type, or 0 for an integer 1554 // type with Signedness of 0, or sign extended when Signedness is 1." 1555 if (bitwidth == 32 || bitwidth == 16) { 1556 uint32_t word = 0; 1557 if (isSigned) { 1558 word = static_cast<int32_t>(value.getSExtValue()); 1559 } else { 1560 word = static_cast<uint32_t>(value.getZExtValue()); 1561 } 1562 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 1563 } 1564 // According to SPIR-V spec: "When the type's bit width is larger than one 1565 // word, the literal’s low-order words appear first." 1566 else if (bitwidth == 64) { 1567 struct DoubleWord { 1568 uint32_t word1; 1569 uint32_t word2; 1570 } words; 1571 if (isSigned) { 1572 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 1573 } else { 1574 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 1575 } 1576 encodeInstructionInto(typesGlobalValues, opcode, 1577 {typeID, resultID, words.word1, words.word2}); 1578 } else { 1579 std::string valueStr; 1580 llvm::raw_string_ostream rss(valueStr); 1581 value.print(rss, /*isSigned=*/false); 1582 1583 emitError(loc, "cannot serialize ") 1584 << bitwidth << "-bit integer literal: " << rss.str(); 1585 return 0; 1586 } 1587 1588 if (!isSpec) { 1589 constIDMap[intAttr] = resultID; 1590 } 1591 return resultID; 1592 } 1593 1594 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 1595 bool isSpec) { 1596 if (!isSpec) { 1597 // We can de-duplicate normal constants, but not specialization constants. 1598 if (auto id = getConstantID(floatAttr)) { 1599 return id; 1600 } 1601 } 1602 1603 // Process the type for this float literal 1604 uint32_t typeID = 0; 1605 if (failed(processType(loc, floatAttr.getType(), typeID))) { 1606 return 0; 1607 } 1608 1609 auto resultID = getNextID(); 1610 APFloat value = floatAttr.getValue(); 1611 APInt intValue = value.bitcastToAPInt(); 1612 1613 auto opcode = 1614 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 1615 1616 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 1617 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 1618 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 1619 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 1620 struct DoubleWord { 1621 uint32_t word1; 1622 uint32_t word2; 1623 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 1624 encodeInstructionInto(typesGlobalValues, opcode, 1625 {typeID, resultID, words.word1, words.word2}); 1626 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 1627 uint32_t word = 1628 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 1629 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 1630 } else { 1631 std::string valueStr; 1632 llvm::raw_string_ostream rss(valueStr); 1633 value.print(rss); 1634 1635 emitError(loc, "cannot serialize ") 1636 << floatAttr.getType() << "-typed float literal: " << rss.str(); 1637 return 0; 1638 } 1639 1640 if (!isSpec) { 1641 constIDMap[floatAttr] = resultID; 1642 } 1643 return resultID; 1644 } 1645 1646 //===----------------------------------------------------------------------===// 1647 // Control flow 1648 //===----------------------------------------------------------------------===// 1649 1650 uint32_t Serializer::getOrCreateBlockID(Block *block) { 1651 if (uint32_t id = getBlockID(block)) 1652 return id; 1653 return blockIDMap[block] = getNextID(); 1654 } 1655 1656 LogicalResult 1657 Serializer::processBlock(Block *block, bool omitLabel, 1658 function_ref<void()> actionBeforeTerminator) { 1659 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 1660 LLVM_DEBUG(block->print(llvm::dbgs())); 1661 LLVM_DEBUG(llvm::dbgs() << '\n'); 1662 if (!omitLabel) { 1663 uint32_t blockID = getOrCreateBlockID(block); 1664 LLVM_DEBUG(llvm::dbgs() 1665 << "[block] " << block << " (id = " << blockID << ")\n"); 1666 1667 // Emit OpLabel for this block. 1668 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); 1669 } 1670 1671 // Emit OpPhi instructions for block arguments, if any. 1672 if (failed(emitPhiForBlockArguments(block))) 1673 return failure(); 1674 1675 // Process each op in this block except the terminator. 1676 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { 1677 if (failed(processOperation(&op))) 1678 return failure(); 1679 } 1680 1681 // Process the terminator. 1682 if (actionBeforeTerminator) 1683 actionBeforeTerminator(); 1684 if (failed(processOperation(&block->back()))) 1685 return failure(); 1686 1687 return success(); 1688 } 1689 1690 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 1691 // Nothing to do if this block has no arguments or it's the entry block, which 1692 // always has the same arguments as the function signature. 1693 if (block->args_empty() || block->isEntryBlock()) 1694 return success(); 1695 1696 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 1697 // A SPIR-V OpPhi instruction is of the syntax: 1698 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 1699 // So we need to collect all predecessor blocks and the arguments they send 1700 // to this block. 1701 SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors; 1702 for (Block *predecessor : block->getPredecessors()) { 1703 auto *terminator = predecessor->getTerminator(); 1704 // The predecessor here is the immediate one according to MLIR's IR 1705 // structure. It does not directly map to the incoming parent block for the 1706 // OpPhi instructions at SPIR-V binary level. This is because structured 1707 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 1708 // spv.selection/spv.loop op in the MLIR predecessor block, the branch op 1709 // jumping to the OpPhi's block then resides in the previous structured 1710 // control flow op's merge block. 1711 predecessor = getPhiIncomingBlock(predecessor); 1712 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 1713 predecessors.emplace_back(predecessor, branchOp.operand_begin()); 1714 } else { 1715 return terminator->emitError("unimplemented terminator for Phi creation"); 1716 } 1717 } 1718 1719 // Then create OpPhi instruction for each of the block argument. 1720 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 1721 BlockArgument arg = block->getArgument(argIndex); 1722 1723 // Get the type <id> and result <id> for this OpPhi instruction. 1724 uint32_t phiTypeID = 0; 1725 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 1726 return failure(); 1727 uint32_t phiID = getNextID(); 1728 1729 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 1730 << arg << " (id = " << phiID << ")\n"); 1731 1732 // Prepare the (value <id>, parent block <id>) pairs. 1733 SmallVector<uint32_t, 8> phiArgs; 1734 phiArgs.push_back(phiTypeID); 1735 phiArgs.push_back(phiID); 1736 1737 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 1738 Value value = *(predecessors[predIndex].second + argIndex); 1739 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1740 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1741 << ") value " << value << ' '); 1742 // Each pair is a value <id> ... 1743 uint32_t valueId = getValueID(value); 1744 if (valueId == 0) { 1745 // The op generating this value hasn't been visited yet so we don't have 1746 // an <id> assigned yet. Record this to fix up later. 1747 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1748 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1749 phiArgs.size()); 1750 } else { 1751 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1752 } 1753 phiArgs.push_back(valueId); 1754 // ... and a parent block <id>. 1755 phiArgs.push_back(predBlockId); 1756 } 1757 1758 encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1759 valueIDMap[arg] = phiID; 1760 } 1761 1762 return success(); 1763 } 1764 1765 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { 1766 // Assign <id>s to all blocks so that branches inside the SelectionOp can 1767 // resolve properly. 1768 auto &body = selectionOp.body(); 1769 for (Block &block : body) 1770 getOrCreateBlockID(&block); 1771 1772 auto *headerBlock = selectionOp.getHeaderBlock(); 1773 auto *mergeBlock = selectionOp.getMergeBlock(); 1774 auto mergeID = getBlockID(mergeBlock); 1775 auto loc = selectionOp.getLoc(); 1776 1777 // Emit the selection header block, which dominates all other blocks, first. 1778 // We need to emit an OpSelectionMerge instruction before the selection header 1779 // block's terminator. 1780 auto emitSelectionMerge = [&]() { 1781 emitDebugLine(functionBody, loc); 1782 lastProcessedWasMergeInst = true; 1783 encodeInstructionInto( 1784 functionBody, spirv::Opcode::OpSelectionMerge, 1785 {mergeID, static_cast<uint32_t>(selectionOp.selection_control())}); 1786 }; 1787 // For structured selection, we cannot have blocks in the selection construct 1788 // branching to the selection header block. Entering the selection (and 1789 // reaching the selection header) must be from the block containing the 1790 // spv.selection op. If there are ops ahead of the spv.selection op in the 1791 // block, we can "merge" them into the selection header. So here we don't need 1792 // to emit a separate block; just continue with the existing block. 1793 if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge))) 1794 return failure(); 1795 1796 // Process all blocks with a depth-first visitor starting from the header 1797 // block. The selection header block and merge block are skipped by this 1798 // visitor. 1799 if (failed(visitInPrettyBlockOrder( 1800 headerBlock, [&](Block *block) { return processBlock(block); }, 1801 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) 1802 return failure(); 1803 1804 // There is nothing to do for the merge block in the selection, which just 1805 // contains a spv.mlir.merge op, itself. But we need to have an OpLabel 1806 // instruction to start a new SPIR-V block for ops following this SelectionOp. 1807 // The block should use the <id> for the merge block. 1808 return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 1809 } 1810 1811 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { 1812 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve 1813 // properly. We don't need to assign for the entry block, which is just for 1814 // satisfying MLIR region's structural requirement. 1815 auto &body = loopOp.body(); 1816 for (Block &block : 1817 llvm::make_range(std::next(body.begin(), 1), body.end())) { 1818 getOrCreateBlockID(&block); 1819 } 1820 auto *headerBlock = loopOp.getHeaderBlock(); 1821 auto *continueBlock = loopOp.getContinueBlock(); 1822 auto *mergeBlock = loopOp.getMergeBlock(); 1823 auto headerID = getBlockID(headerBlock); 1824 auto continueID = getBlockID(continueBlock); 1825 auto mergeID = getBlockID(mergeBlock); 1826 auto loc = loopOp.getLoc(); 1827 1828 // This LoopOp is in some MLIR block with preceding and following ops. In the 1829 // binary format, it should reside in separate SPIR-V blocks from its 1830 // preceding and following ops. So we need to emit unconditional branches to 1831 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow 1832 // afterwards. 1833 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); 1834 1835 // LoopOp's entry block is just there for satisfying MLIR's structural 1836 // requirements so we omit it and start serialization from the loop header 1837 // block. 1838 1839 // Emit the loop header block, which dominates all other blocks, first. We 1840 // need to emit an OpLoopMerge instruction before the loop header block's 1841 // terminator. 1842 auto emitLoopMerge = [&]() { 1843 emitDebugLine(functionBody, loc); 1844 lastProcessedWasMergeInst = true; 1845 encodeInstructionInto( 1846 functionBody, spirv::Opcode::OpLoopMerge, 1847 {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())}); 1848 }; 1849 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) 1850 return failure(); 1851 1852 // Process all blocks with a depth-first visitor starting from the header 1853 // block. The loop header block, loop continue block, and loop merge block are 1854 // skipped by this visitor and handled later in this function. 1855 if (failed(visitInPrettyBlockOrder( 1856 headerBlock, [&](Block *block) { return processBlock(block); }, 1857 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) 1858 return failure(); 1859 1860 // We have handled all other blocks. Now get to the loop continue block. 1861 if (failed(processBlock(continueBlock))) 1862 return failure(); 1863 1864 // There is nothing to do for the merge block in the loop, which just contains 1865 // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to 1866 // start a new SPIR-V block for ops following this LoopOp. The block should 1867 // use the <id> for the merge block. 1868 return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 1869 } 1870 1871 LogicalResult Serializer::processBranchConditionalOp( 1872 spirv::BranchConditionalOp condBranchOp) { 1873 auto conditionID = getValueID(condBranchOp.condition()); 1874 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); 1875 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); 1876 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; 1877 1878 if (auto weights = condBranchOp.branch_weights()) { 1879 for (auto val : weights->getValue()) 1880 arguments.push_back(val.cast<IntegerAttr>().getInt()); 1881 } 1882 1883 emitDebugLine(functionBody, condBranchOp.getLoc()); 1884 return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, 1885 arguments); 1886 } 1887 1888 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { 1889 emitDebugLine(functionBody, branchOp.getLoc()); 1890 return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, 1891 {getOrCreateBlockID(branchOp.getTarget())}); 1892 } 1893 1894 //===----------------------------------------------------------------------===// 1895 // Operation 1896 //===----------------------------------------------------------------------===// 1897 1898 LogicalResult Serializer::encodeExtensionInstruction( 1899 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1900 ArrayRef<uint32_t> operands) { 1901 // Check if the extension has been imported. 1902 auto &setID = extendedInstSetIDMap[extensionSetName]; 1903 if (!setID) { 1904 setID = getNextID(); 1905 SmallVector<uint32_t, 16> importOperands; 1906 importOperands.push_back(setID); 1907 if (failed( 1908 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || 1909 failed(encodeInstructionInto( 1910 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { 1911 return failure(); 1912 } 1913 } 1914 1915 // The first two operands are the result type <id> and result <id>. The set 1916 // <id> and the opcode need to be insert after this. 1917 if (operands.size() < 2) { 1918 return op->emitError("extended instructions must have a result encoding"); 1919 } 1920 SmallVector<uint32_t, 8> extInstOperands; 1921 extInstOperands.reserve(operands.size() + 2); 1922 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1923 extInstOperands.push_back(setID); 1924 extInstOperands.push_back(extensionOpcode); 1925 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1926 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1927 extInstOperands); 1928 } 1929 1930 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { 1931 auto varName = addressOfOp.variable(); 1932 auto variableID = getVariableID(varName); 1933 if (!variableID) { 1934 return addressOfOp.emitError("unknown result <id> for variable ") 1935 << varName; 1936 } 1937 valueIDMap[addressOfOp.pointer()] = variableID; 1938 return success(); 1939 } 1940 1941 LogicalResult 1942 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { 1943 auto constName = referenceOfOp.spec_const(); 1944 auto constID = getSpecConstID(constName); 1945 if (!constID) { 1946 return referenceOfOp.emitError( 1947 "unknown result <id> for specialization constant ") 1948 << constName; 1949 } 1950 valueIDMap[referenceOfOp.reference()] = constID; 1951 return success(); 1952 } 1953 1954 LogicalResult Serializer::processOperation(Operation *opInst) { 1955 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1956 1957 // First dispatch the ops that do not directly mirror an instruction from 1958 // the SPIR-V spec. 1959 return TypeSwitch<Operation *, LogicalResult>(opInst) 1960 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1961 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1962 .Case([&](spirv::BranchConditionalOp op) { 1963 return processBranchConditionalOp(op); 1964 }) 1965 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1966 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1967 .Case([&](spirv::GlobalVariableOp op) { 1968 return processGlobalVariableOp(op); 1969 }) 1970 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1971 .Case([&](spirv::ModuleEndOp) { return success(); }) 1972 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 1973 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 1974 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 1975 .Case([&](spirv::SpecConstantCompositeOp op) { 1976 return processSpecConstantCompositeOp(op); 1977 }) 1978 .Case([&](spirv::SpecConstantOperationOp op) { 1979 return processSpecConstantOperationOp(op); 1980 }) 1981 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 1982 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 1983 1984 // Then handle all the ops that directly mirror SPIR-V instructions with 1985 // auto-generated methods. 1986 .Default( 1987 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 1988 } 1989 1990 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 1991 StringRef extInstSet, 1992 uint32_t opcode) { 1993 SmallVector<uint32_t, 4> operands; 1994 Location loc = op->getLoc(); 1995 1996 uint32_t resultID = 0; 1997 if (op->getNumResults() != 0) { 1998 uint32_t resultTypeID = 0; 1999 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 2000 return failure(); 2001 operands.push_back(resultTypeID); 2002 2003 resultID = getNextID(); 2004 operands.push_back(resultID); 2005 valueIDMap[op->getResult(0)] = resultID; 2006 }; 2007 2008 for (Value operand : op->getOperands()) 2009 operands.push_back(getValueID(operand)); 2010 2011 emitDebugLine(functionBody, loc); 2012 2013 if (extInstSet.empty()) { 2014 encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode), 2015 operands); 2016 } else { 2017 encodeExtensionInstruction(op, extInstSet, opcode, operands); 2018 } 2019 2020 if (op->getNumResults() != 0) { 2021 for (auto attr : op->getAttrs()) { 2022 if (failed(processDecoration(loc, resultID, attr))) 2023 return failure(); 2024 } 2025 } 2026 2027 return success(); 2028 } 2029 2030 namespace { 2031 template <> 2032 LogicalResult 2033 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { 2034 SmallVector<uint32_t, 4> operands; 2035 // Add the ExecutionModel. 2036 operands.push_back(static_cast<uint32_t>(op.execution_model())); 2037 // Add the function <id>. 2038 auto funcID = getFunctionID(op.fn()); 2039 if (!funcID) { 2040 return op.emitError("missing <id> for function ") 2041 << op.fn() 2042 << "; function needs to be defined before spv.EntryPoint is " 2043 "serialized"; 2044 } 2045 operands.push_back(funcID); 2046 // Add the name of the function. 2047 spirv::encodeStringLiteralInto(operands, op.fn()); 2048 2049 // Add the interface values. 2050 if (auto interface = op.interface()) { 2051 for (auto var : interface.getValue()) { 2052 auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue()); 2053 if (!id) { 2054 return op.emitError("referencing undefined global variable." 2055 "spv.EntryPoint is at the end of spv.module. All " 2056 "referenced variables should already be defined"); 2057 } 2058 operands.push_back(id); 2059 } 2060 } 2061 return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, 2062 operands); 2063 } 2064 2065 template <> 2066 LogicalResult 2067 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) { 2068 StringRef argNames[] = {"execution_scope", "memory_scope", 2069 "memory_semantics"}; 2070 SmallVector<uint32_t, 3> operands; 2071 2072 for (auto argName : argNames) { 2073 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName); 2074 auto operand = prepareConstantInt(op.getLoc(), argIntAttr); 2075 if (!operand) { 2076 return failure(); 2077 } 2078 operands.push_back(operand); 2079 } 2080 2081 return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, 2082 operands); 2083 } 2084 2085 template <> 2086 LogicalResult 2087 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { 2088 SmallVector<uint32_t, 4> operands; 2089 // Add the function <id>. 2090 auto funcID = getFunctionID(op.fn()); 2091 if (!funcID) { 2092 return op.emitError("missing <id> for function ") 2093 << op.fn() 2094 << "; function needs to be serialized before ExecutionModeOp is " 2095 "serialized"; 2096 } 2097 operands.push_back(funcID); 2098 // Add the ExecutionMode. 2099 operands.push_back(static_cast<uint32_t>(op.execution_mode())); 2100 2101 // Serialize values if any. 2102 auto values = op.values(); 2103 if (values) { 2104 for (auto &intVal : values.getValue()) { 2105 operands.push_back(static_cast<uint32_t>( 2106 intVal.cast<IntegerAttr>().getValue().getZExtValue())); 2107 } 2108 } 2109 return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, 2110 operands); 2111 } 2112 2113 template <> 2114 LogicalResult 2115 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) { 2116 StringRef argNames[] = {"memory_scope", "memory_semantics"}; 2117 SmallVector<uint32_t, 2> operands; 2118 2119 for (auto argName : argNames) { 2120 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName); 2121 auto operand = prepareConstantInt(op.getLoc(), argIntAttr); 2122 if (!operand) { 2123 return failure(); 2124 } 2125 operands.push_back(operand); 2126 } 2127 2128 return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, 2129 operands); 2130 } 2131 2132 template <> 2133 LogicalResult 2134 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { 2135 auto funcName = op.callee(); 2136 uint32_t resTypeID = 0; 2137 2138 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); 2139 if (failed(processType(op.getLoc(), resultTy, resTypeID))) 2140 return failure(); 2141 2142 auto funcID = getOrCreateFunctionID(funcName); 2143 auto funcCallID = getNextID(); 2144 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; 2145 2146 for (auto value : op.arguments()) { 2147 auto valueID = getValueID(value); 2148 assert(valueID && "cannot find a value for spv.FunctionCall"); 2149 operands.push_back(valueID); 2150 } 2151 2152 if (!resultTy.isa<NoneType>()) 2153 valueIDMap[op.getResult(0)] = funcCallID; 2154 2155 return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, 2156 operands); 2157 } 2158 2159 template <> 2160 LogicalResult 2161 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { 2162 SmallVector<uint32_t, 4> operands; 2163 SmallVector<StringRef, 2> elidedAttrs; 2164 2165 for (Value operand : op->getOperands()) { 2166 auto id = getValueID(operand); 2167 assert(id && "use before def!"); 2168 operands.push_back(id); 2169 } 2170 2171 if (auto attr = op->getAttr("memory_access")) { 2172 operands.push_back(static_cast<uint32_t>( 2173 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2174 } 2175 2176 elidedAttrs.push_back("memory_access"); 2177 2178 if (auto attr = op->getAttr("alignment")) { 2179 operands.push_back(static_cast<uint32_t>( 2180 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2181 } 2182 2183 elidedAttrs.push_back("alignment"); 2184 2185 if (auto attr = op->getAttr("source_memory_access")) { 2186 operands.push_back(static_cast<uint32_t>( 2187 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2188 } 2189 2190 elidedAttrs.push_back("source_memory_access"); 2191 2192 if (auto attr = op->getAttr("source_alignment")) { 2193 operands.push_back(static_cast<uint32_t>( 2194 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2195 } 2196 2197 elidedAttrs.push_back("source_alignment"); 2198 emitDebugLine(functionBody, op.getLoc()); 2199 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); 2200 2201 return success(); 2202 } 2203 2204 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and 2205 // various Serializer::processOp<...>() specializations. 2206 #define GET_SERIALIZATION_FNS 2207 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 2208 } // namespace 2209 2210 LogicalResult Serializer::emitDecoration(uint32_t target, 2211 spirv::Decoration decoration, 2212 ArrayRef<uint32_t> params) { 2213 uint32_t wordCount = 3 + params.size(); 2214 decorations.push_back( 2215 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); 2216 decorations.push_back(target); 2217 decorations.push_back(static_cast<uint32_t>(decoration)); 2218 decorations.append(params.begin(), params.end()); 2219 return success(); 2220 } 2221 2222 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 2223 Location loc) { 2224 if (!emitDebugInfo) 2225 return success(); 2226 2227 if (lastProcessedWasMergeInst) { 2228 lastProcessedWasMergeInst = false; 2229 return success(); 2230 } 2231 2232 auto fileLoc = loc.dyn_cast<FileLineColLoc>(); 2233 if (fileLoc) 2234 encodeInstructionInto(binary, spirv::Opcode::OpLine, 2235 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 2236 return success(); 2237 } 2238 2239 namespace mlir { 2240 LogicalResult spirv::serialize(spirv::ModuleOp module, 2241 SmallVectorImpl<uint32_t> &binary, 2242 bool emitDebugInfo) { 2243 if (!module.vce_triple().hasValue()) 2244 return module.emitError( 2245 "module must have 'vce_triple' attribute to be serializeable"); 2246 2247 Serializer serializer(module, emitDebugInfo); 2248 2249 if (failed(serializer.serialize())) 2250 return failure(); 2251 2252 LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs())); 2253 2254 serializer.collect(binary); 2255 return success(); 2256 } 2257 } // namespace mlir 2258