1 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Target/SPIRV/Serialization.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/RegionGraphTraits.h" 21 #include "mlir/Support/LogicalResult.h" 22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 23 #include "llvm/ADT/DepthFirstIterator.h" 24 #include "llvm/ADT/Sequence.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/ADT/SmallPtrSet.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 #include "llvm/ADT/bit.h" 31 #include "llvm/Support/Debug.h" 32 #include "llvm/Support/raw_ostream.h" 33 34 #define DEBUG_TYPE "spirv-serialization" 35 36 using namespace mlir; 37 38 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 39 /// the given `binary` vector. 40 static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, 41 spirv::Opcode op, 42 ArrayRef<uint32_t> operands) { 43 uint32_t wordCount = 1 + operands.size(); 44 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 45 binary.append(operands.begin(), operands.end()); 46 return success(); 47 } 48 49 /// A pre-order depth-first visitor function for processing basic blocks. 50 /// 51 /// Visits the basic blocks starting from the given `headerBlock` in pre-order 52 /// depth-first manner and calls `blockHandler` on each block. Skips handling 53 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` 54 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s 55 /// successors. 56 /// 57 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order 58 /// of blocks in a function must satisfy the rule that blocks appear before 59 /// all blocks they dominate." This can be achieved by a pre-order CFG 60 /// traversal algorithm. To make the serialization output more logical and 61 /// readable to human, we perform depth-first CFG traversal and delay the 62 /// serialization of the merge block and the continue block, if exists, until 63 /// after all other blocks have been processed. 64 static LogicalResult 65 visitInPrettyBlockOrder(Block *headerBlock, 66 function_ref<LogicalResult(Block *)> blockHandler, 67 bool skipHeader = false, BlockRange skipBlocks = {}) { 68 llvm::df_iterator_default_set<Block *, 4> doneBlocks; 69 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); 70 71 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { 72 if (skipHeader && block == headerBlock) 73 continue; 74 if (failed(blockHandler(block))) 75 return failure(); 76 } 77 return success(); 78 } 79 80 /// Returns the merge block if the given `op` is a structured control flow op. 81 /// Otherwise returns nullptr. 82 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 83 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 84 return selectionOp.getMergeBlock(); 85 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 86 return loopOp.getMergeBlock(); 87 return nullptr; 88 } 89 90 /// Given a predecessor `block` for a block with arguments, returns the block 91 /// that should be used as the parent block for SPIR-V OpPhi instructions 92 /// corresponding to the block arguments. 93 static Block *getPhiIncomingBlock(Block *block) { 94 // If the predecessor block in question is the entry block for a spv.loop, 95 // we jump to this spv.loop from its enclosing block. 96 if (block->isEntryBlock()) { 97 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 98 // Then the incoming parent block for OpPhi should be the merge block of 99 // the structured control flow op before this loop. 100 Operation *op = loopOp.getOperation(); 101 while ((op = op->getPrevNode()) != nullptr) 102 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 103 return incomingBlock; 104 // Or the enclosing block itself if no structured control flow ops 105 // exists before this loop. 106 return loopOp->getBlock(); 107 } 108 } 109 110 // Otherwise, we jump from the given predecessor block. Try to see if there is 111 // a structured control flow op inside it. 112 for (Operation &op : llvm::reverse(block->getOperations())) { 113 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 114 return incomingBlock; 115 } 116 return block; 117 } 118 119 namespace { 120 121 /// A SPIR-V module serializer. 122 /// 123 /// A SPIR-V binary module is a single linear stream of instructions; each 124 /// instruction is composed of 32-bit words with the layout: 125 /// 126 /// | <word-count>|<opcode> | <operand> | <operand> | ... | 127 /// | <------ word -------> | <-- word --> | <-- word --> | ... | 128 /// 129 /// For the first word, the 16 high-order bits are the word count of the 130 /// instruction, the 16 low-order bits are the opcode enumerant. The 131 /// instructions then belong to different sections, which must be laid out in 132 /// the particular order as specified in "2.4 Logical Layout of a Module" of 133 /// the SPIR-V spec. 134 class Serializer { 135 public: 136 /// Creates a serializer for the given SPIR-V `module`. 137 explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false); 138 139 /// Serializes the remembered SPIR-V module. 140 LogicalResult serialize(); 141 142 /// Collects the final SPIR-V `binary`. 143 void collect(SmallVectorImpl<uint32_t> &binary); 144 145 #ifndef NDEBUG 146 /// (For debugging) prints each value and its corresponding result <id>. 147 void printValueIDMap(raw_ostream &os); 148 #endif 149 150 private: 151 // Note that there are two main categories of methods in this class: 152 // * process*() methods are meant to fully serialize a SPIR-V module entity 153 // (header, type, op, etc.). They update internal vectors containing 154 // different binary sections. They are not meant to be called except the 155 // top-level serialization loop. 156 // * prepare*() methods are meant to be helpers that prepare for serializing 157 // certain entity. They may or may not update internal vectors containing 158 // different binary sections. They are meant to be called among themselves 159 // or by other process*() methods for subtasks. 160 161 //===--------------------------------------------------------------------===// 162 // <id> 163 //===--------------------------------------------------------------------===// 164 165 // Note that it is illegal to use id <0> in SPIR-V binary module. Various 166 // methods in this class, if using SPIR-V word (uint32_t) as interface, 167 // check or return id <0> to indicate error in processing. 168 169 /// Consumes the next unused <id>. This method will never return 0. 170 uint32_t getNextID() { return nextID++; } 171 172 //===--------------------------------------------------------------------===// 173 // Module structure 174 //===--------------------------------------------------------------------===// 175 176 uint32_t getSpecConstID(StringRef constName) const { 177 return specConstIDMap.lookup(constName); 178 } 179 180 uint32_t getVariableID(StringRef varName) const { 181 return globalVarIDMap.lookup(varName); 182 } 183 184 uint32_t getFunctionID(StringRef fnName) const { 185 return funcIDMap.lookup(fnName); 186 } 187 188 /// Gets the <id> for the function with the given name. Assigns the next 189 /// available <id> if the function haven't been deserialized. 190 uint32_t getOrCreateFunctionID(StringRef fnName); 191 192 void processCapability(); 193 194 void processDebugInfo(); 195 196 void processExtension(); 197 198 void processMemoryModel(); 199 200 LogicalResult processConstantOp(spirv::ConstantOp op); 201 202 LogicalResult processSpecConstantOp(spirv::SpecConstantOp op); 203 204 LogicalResult 205 processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); 206 207 LogicalResult 208 processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); 209 210 /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA 211 /// value to use with other operations. The SPIR-V spec recommends that 212 /// OpUndef be generated at module level. The serialization generates an 213 /// OpUndef for each type needed at module level. 214 LogicalResult processUndefOp(spirv::UndefOp op); 215 216 /// Emit OpName for the given `resultID`. 217 LogicalResult processName(uint32_t resultID, StringRef name); 218 219 /// Processes a SPIR-V function op. 220 LogicalResult processFuncOp(spirv::FuncOp op); 221 222 LogicalResult processVariableOp(spirv::VariableOp op); 223 224 /// Process a SPIR-V GlobalVariableOp 225 LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); 226 227 /// Process attributes that translate to decorations on the result <id> 228 LogicalResult processDecoration(Location loc, uint32_t resultID, 229 NamedAttribute attr); 230 231 template <typename DType> 232 LogicalResult processTypeDecoration(Location loc, DType type, 233 uint32_t resultId) { 234 return emitError(loc, "unhandled decoration for type:") << type; 235 } 236 237 /// Process member decoration 238 LogicalResult processMemberDecoration( 239 uint32_t structID, 240 const spirv::StructType::MemberDecorationInfo &memberDecorationInfo); 241 242 //===--------------------------------------------------------------------===// 243 // Types 244 //===--------------------------------------------------------------------===// 245 246 uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); } 247 248 Type getVoidType() { return mlirBuilder.getNoneType(); } 249 250 bool isVoidType(Type type) const { return type.isa<NoneType>(); } 251 252 /// Returns true if the given type is a pointer type to a struct in some 253 /// interface storage class. 254 bool isInterfaceStructPtrType(Type type) const; 255 256 /// Main dispatch method for serializing a type. The result <id> of the 257 /// serialized type will be returned as `typeID`. 258 LogicalResult processType(Location loc, Type type, uint32_t &typeID); 259 LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID, 260 llvm::SetVector<StringRef> &serializationCtx); 261 262 /// Method for preparing basic SPIR-V type serialization. Returns the type's 263 /// opcode and operands for the instruction via `typeEnum` and `operands`. 264 LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, 265 spirv::Opcode &typeEnum, 266 SmallVectorImpl<uint32_t> &operands, 267 bool &deferSerialization, 268 llvm::SetVector<StringRef> &serializationCtx); 269 270 LogicalResult prepareFunctionType(Location loc, FunctionType type, 271 spirv::Opcode &typeEnum, 272 SmallVectorImpl<uint32_t> &operands); 273 274 //===--------------------------------------------------------------------===// 275 // Constant 276 //===--------------------------------------------------------------------===// 277 278 uint32_t getConstantID(Attribute value) const { 279 return constIDMap.lookup(value); 280 } 281 282 /// Main dispatch method for processing a constant with the given `constType` 283 /// and `valueAttr`. `constType` is needed here because we can interpret the 284 /// `valueAttr` as a different type than the type of `valueAttr` itself; for 285 /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType 286 /// constants. 287 uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); 288 289 /// Prepares array attribute serialization. This method emits corresponding 290 /// OpConstant* and returns the result <id> associated with it. Returns 0 if 291 /// failed. 292 uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr); 293 294 /// Prepares bool/int/float DenseElementsAttr serialization. This method 295 /// iterates the DenseElementsAttr to construct the constant array, and 296 /// returns the result <id> associated with it. Returns 0 if failed. Note 297 /// that the size of `index` must match the rank. 298 /// TODO: Consider to enhance splat elements cases. For splat cases, 299 /// we don't need to loop over all elements, especially when the splat value 300 /// is zero. We can use OpConstantNull when the value is zero. 301 uint32_t prepareDenseElementsConstant(Location loc, Type constType, 302 DenseElementsAttr valueAttr, int dim, 303 MutableArrayRef<uint64_t> index); 304 305 /// Prepares scalar attribute serialization. This method emits corresponding 306 /// OpConstant* and returns the result <id> associated with it. Returns 0 if 307 /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is 308 /// true, then the constant will be serialized as a specialization constant. 309 uint32_t prepareConstantScalar(Location loc, Attribute valueAttr, 310 bool isSpec = false); 311 312 uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, 313 bool isSpec = false); 314 315 uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, 316 bool isSpec = false); 317 318 uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, 319 bool isSpec = false); 320 321 //===--------------------------------------------------------------------===// 322 // Control flow 323 //===--------------------------------------------------------------------===// 324 325 /// Returns the result <id> for the given block. 326 uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } 327 328 /// Returns the result <id> for the given block. If no <id> has been assigned, 329 /// assigns the next available <id> 330 uint32_t getOrCreateBlockID(Block *block); 331 332 /// Processes the given `block` and emits SPIR-V instructions for all ops 333 /// inside. Does not emit OpLabel for this block if `omitLabel` is true. 334 /// `actionBeforeTerminator` is a callback that will be invoked before 335 /// handling the terminator op. It can be used to inject the Op*Merge 336 /// instruction if this is a SPIR-V selection/loop header block. 337 LogicalResult 338 processBlock(Block *block, bool omitLabel = false, 339 function_ref<void()> actionBeforeTerminator = nullptr); 340 341 /// Emits OpPhi instructions for the given block if it has block arguments. 342 LogicalResult emitPhiForBlockArguments(Block *block); 343 344 LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); 345 346 LogicalResult processLoopOp(spirv::LoopOp loopOp); 347 348 LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); 349 350 LogicalResult processBranchOp(spirv::BranchOp branchOp); 351 352 //===--------------------------------------------------------------------===// 353 // Operations 354 //===--------------------------------------------------------------------===// 355 356 LogicalResult encodeExtensionInstruction(Operation *op, 357 StringRef extensionSetName, 358 uint32_t opcode, 359 ArrayRef<uint32_t> operands); 360 361 uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); } 362 363 LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); 364 365 LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp); 366 367 /// Main dispatch method for serializing an operation. 368 LogicalResult processOperation(Operation *op); 369 370 /// Serializes an operation `op` as core instruction with `opcode` if 371 /// `extInstSet` is empty. Otherwise serializes it as an extended instruction 372 /// with `opcode` from `extInstSet`. 373 /// This method is a generic one for dispatching any SPIR-V ops that has no 374 /// variadic operands and attributes in TableGen definitions. 375 LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet, 376 uint32_t opcode); 377 378 /// Dispatches to the serialization function for an operation in SPIR-V 379 /// dialect that is a mirror of an instruction in the SPIR-V spec. This is 380 /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V 381 /// dialect that have hasOpcode == 1. 382 LogicalResult dispatchToAutogenSerialization(Operation *op); 383 384 /// Serializes an operation in the SPIR-V dialect that is a mirror of an 385 /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1 386 /// and autogenSerialization == 1 in ODS. 387 template <typename OpTy> 388 LogicalResult processOp(OpTy op) { 389 return op.emitError("unsupported op serialization"); 390 } 391 392 //===--------------------------------------------------------------------===// 393 // Utilities 394 //===--------------------------------------------------------------------===// 395 396 /// Emits an OpDecorate instruction to decorate the given `target` with the 397 /// given `decoration`. 398 LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration, 399 ArrayRef<uint32_t> params = {}); 400 401 /// Emits an OpLine instruction with the given `loc` location information into 402 /// the given `binary` vector. 403 LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc); 404 405 private: 406 /// The SPIR-V module to be serialized. 407 spirv::ModuleOp module; 408 409 /// An MLIR builder for getting MLIR constructs. 410 mlir::Builder mlirBuilder; 411 412 /// A flag which indicates if the debuginfo should be emitted. 413 bool emitDebugInfo = false; 414 415 /// A flag which indicates if the last processed instruction was a merge 416 /// instruction. 417 /// According to SPIR-V spec: "If a branch merge instruction is used, the last 418 /// OpLine in the block must be before its merge instruction". 419 bool lastProcessedWasMergeInst = false; 420 421 /// The <id> of the OpString instruction, which specifies a file name, for 422 /// use by other debug instructions. 423 uint32_t fileID = 0; 424 425 /// The next available result <id>. 426 uint32_t nextID = 1; 427 428 // The following are for different SPIR-V instruction sections. They follow 429 // the logical layout of a SPIR-V module. 430 431 SmallVector<uint32_t, 4> capabilities; 432 SmallVector<uint32_t, 0> extensions; 433 SmallVector<uint32_t, 0> extendedSets; 434 SmallVector<uint32_t, 3> memoryModel; 435 SmallVector<uint32_t, 0> entryPoints; 436 SmallVector<uint32_t, 4> executionModes; 437 SmallVector<uint32_t, 0> debug; 438 SmallVector<uint32_t, 0> names; 439 SmallVector<uint32_t, 0> decorations; 440 SmallVector<uint32_t, 0> typesGlobalValues; 441 SmallVector<uint32_t, 0> functions; 442 443 /// Recursive struct references are serialized as OpTypePointer instructions 444 /// to the recursive struct type. However, the OpTypePointer instruction 445 /// cannot be emitted before the recursive struct's OpTypeStruct. 446 /// RecursiveStructPointerInfo stores the data needed to emit such 447 /// OpTypePointer instructions after forward references to such types. 448 struct RecursiveStructPointerInfo { 449 uint32_t pointerTypeID; 450 spirv::StorageClass storageClass; 451 }; 452 453 // Maps spirv::StructType to its recursive reference member info. 454 DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>> 455 recursiveStructInfos; 456 457 /// `functionHeader` contains all the instructions that must be in the first 458 /// block in the function, and `functionBody` contains the rest. After 459 /// processing FuncOp, the encoded instructions of a function are appended to 460 /// `functions`. An example of instructions in `functionHeader` in order: 461 /// OpFunction ... 462 /// OpFunctionParameter ... 463 /// OpFunctionParameter ... 464 /// OpLabel ... 465 /// OpVariable ... 466 /// OpVariable ... 467 SmallVector<uint32_t, 0> functionHeader; 468 SmallVector<uint32_t, 0> functionBody; 469 470 /// Map from type used in SPIR-V module to their <id>s. 471 DenseMap<Type, uint32_t> typeIDMap; 472 473 /// Map from constant values to their <id>s. 474 DenseMap<Attribute, uint32_t> constIDMap; 475 476 /// Map from specialization constant names to their <id>s. 477 llvm::StringMap<uint32_t> specConstIDMap; 478 479 /// Map from GlobalVariableOps name to <id>s. 480 llvm::StringMap<uint32_t> globalVarIDMap; 481 482 /// Map from FuncOps name to <id>s. 483 llvm::StringMap<uint32_t> funcIDMap; 484 485 /// Map from blocks to their <id>s. 486 DenseMap<Block *, uint32_t> blockIDMap; 487 488 /// Map from the Type to the <id> that represents undef value of that type. 489 DenseMap<Type, uint32_t> undefValIDMap; 490 491 /// Map from results of normal operations to their <id>s. 492 DenseMap<Value, uint32_t> valueIDMap; 493 494 /// Map from extended instruction set name to <id>s. 495 llvm::StringMap<uint32_t> extendedInstSetIDMap; 496 497 /// Map from values used in OpPhi instructions to their offset in the 498 /// `functions` section. 499 /// 500 /// When processing a block with arguments, we need to emit OpPhi 501 /// instructions to record the predecessor block <id>s and the values they 502 /// send to the block in question. But it's not guaranteed all values are 503 /// visited and thus assigned result <id>s. So we need this list to capture 504 /// the offsets into `functions` where a value is used so that we can fix it 505 /// up later after processing all the blocks in a function. 506 /// 507 /// More concretely, say if we are visiting the following blocks: 508 /// 509 /// ```mlir 510 /// ^phi(%arg0: i32): 511 /// ... 512 /// ^parent1: 513 /// ... 514 /// spv.Branch ^phi(%val0: i32) 515 /// ^parent2: 516 /// ... 517 /// spv.Branch ^phi(%val1: i32) 518 /// ``` 519 /// 520 /// When we are serializing the `^phi` block, we need to emit at the beginning 521 /// of the block OpPhi instructions which has the following parameters: 522 /// 523 /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 524 /// id-for-%val1 id-for-^parent2 525 /// 526 /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit 527 /// all the blocks twice and use the first visit to assign an <id> to each 528 /// value. But it's paying the overheads just for OpPhi emission. Instead, 529 /// we still visit the blocks once for emission. When we emit the OpPhi 530 /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1. 531 /// At the same time, we record their offsets in the emitted binary (which is 532 /// placed inside `functions`) here. And then after emitting all blocks, we 533 /// replace the dummy <id> 0 with the real result <id> by overwriting 534 /// `functions[offset]`. 535 DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues; 536 }; 537 } // namespace 538 539 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) 540 : module(module), mlirBuilder(module.getContext()), 541 emitDebugInfo(emitDebugInfo) {} 542 543 LogicalResult Serializer::serialize() { 544 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 545 546 if (failed(module.verify())) 547 return failure(); 548 549 // TODO: handle the other sections 550 processCapability(); 551 processExtension(); 552 processMemoryModel(); 553 processDebugInfo(); 554 555 // Iterate over the module body to serialize it. Assumptions are that there is 556 // only one basic block in the moduleOp 557 for (auto &op : module.getBlock()) { 558 if (failed(processOperation(&op))) { 559 return failure(); 560 } 561 } 562 563 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 564 return success(); 565 } 566 567 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 568 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 569 extensions.size() + extendedSets.size() + 570 memoryModel.size() + entryPoints.size() + 571 executionModes.size() + decorations.size() + 572 typesGlobalValues.size() + functions.size(); 573 574 binary.clear(); 575 binary.reserve(moduleSize); 576 577 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); 578 binary.append(capabilities.begin(), capabilities.end()); 579 binary.append(extensions.begin(), extensions.end()); 580 binary.append(extendedSets.begin(), extendedSets.end()); 581 binary.append(memoryModel.begin(), memoryModel.end()); 582 binary.append(entryPoints.begin(), entryPoints.end()); 583 binary.append(executionModes.begin(), executionModes.end()); 584 binary.append(debug.begin(), debug.end()); 585 binary.append(names.begin(), names.end()); 586 binary.append(decorations.begin(), decorations.end()); 587 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 588 binary.append(functions.begin(), functions.end()); 589 } 590 591 #ifndef NDEBUG 592 void Serializer::printValueIDMap(raw_ostream &os) { 593 os << "\n= Value <id> Map =\n\n"; 594 for (auto valueIDPair : valueIDMap) { 595 Value val = valueIDPair.first; 596 os << " " << val << " " 597 << "id = " << valueIDPair.second << ' '; 598 if (auto *op = val.getDefiningOp()) { 599 os << "from op '" << op->getName() << "'"; 600 } else if (auto arg = val.dyn_cast<BlockArgument>()) { 601 Block *block = arg.getOwner(); 602 os << "from argument of block " << block << ' '; 603 os << " in op '" << block->getParentOp()->getName() << "'"; 604 } 605 os << '\n'; 606 } 607 } 608 #endif 609 610 //===----------------------------------------------------------------------===// 611 // Module structure 612 //===----------------------------------------------------------------------===// 613 614 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 615 auto funcID = funcIDMap.lookup(fnName); 616 if (!funcID) { 617 funcID = getNextID(); 618 funcIDMap[fnName] = funcID; 619 } 620 return funcID; 621 } 622 623 void Serializer::processCapability() { 624 for (auto cap : module.vce_triple()->getCapabilities()) 625 (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 626 {static_cast<uint32_t>(cap)}); 627 } 628 629 void Serializer::processDebugInfo() { 630 if (!emitDebugInfo) 631 return; 632 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>(); 633 auto fileName = fileLoc ? fileLoc.getFilename() : "<unknown>"; 634 fileID = getNextID(); 635 SmallVector<uint32_t, 16> operands; 636 operands.push_back(fileID); 637 (void)spirv::encodeStringLiteralInto(operands, fileName); 638 (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 639 // TODO: Encode more debug instructions. 640 } 641 642 void Serializer::processExtension() { 643 llvm::SmallVector<uint32_t, 16> extName; 644 for (spirv::Extension ext : module.vce_triple()->getExtensions()) { 645 extName.clear(); 646 (void)spirv::encodeStringLiteralInto(extName, 647 spirv::stringifyExtension(ext)); 648 (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension, 649 extName); 650 } 651 } 652 653 void Serializer::processMemoryModel() { 654 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt(); 655 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt(); 656 657 (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, 658 {am, mm}); 659 } 660 661 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { 662 if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { 663 valueIDMap[op.getResult()] = resultID; 664 return success(); 665 } 666 return failure(); 667 } 668 669 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { 670 if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), 671 /*isSpec=*/true)) { 672 // Emit the OpDecorate instruction for SpecId. 673 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) { 674 auto val = static_cast<uint32_t>(specID.getInt()); 675 (void)emitDecoration(resultID, spirv::Decoration::SpecId, {val}); 676 } 677 678 specConstIDMap[op.sym_name()] = resultID; 679 return processName(resultID, op.sym_name()); 680 } 681 return failure(); 682 } 683 684 LogicalResult 685 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { 686 uint32_t typeID = 0; 687 if (failed(processType(op.getLoc(), op.type(), typeID))) { 688 return failure(); 689 } 690 691 auto resultID = getNextID(); 692 693 SmallVector<uint32_t, 8> operands; 694 operands.push_back(typeID); 695 operands.push_back(resultID); 696 697 auto constituents = op.constituents(); 698 699 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { 700 auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>(); 701 702 auto constituentName = constituent.getValue(); 703 auto constituentID = getSpecConstID(constituentName); 704 705 if (!constituentID) { 706 return op.emitError("unknown result <id> for specialization constant ") 707 << constituentName; 708 } 709 710 operands.push_back(constituentID); 711 } 712 713 (void)encodeInstructionInto(typesGlobalValues, 714 spirv::Opcode::OpSpecConstantComposite, operands); 715 specConstIDMap[op.sym_name()] = resultID; 716 717 return processName(resultID, op.sym_name()); 718 } 719 720 LogicalResult 721 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { 722 uint32_t typeID = 0; 723 if (failed(processType(op.getLoc(), op.getType(), typeID))) { 724 return failure(); 725 } 726 727 auto resultID = getNextID(); 728 729 SmallVector<uint32_t, 8> operands; 730 operands.push_back(typeID); 731 operands.push_back(resultID); 732 733 Block &block = op.getRegion().getBlocks().front(); 734 Operation &enclosedOp = block.getOperations().front(); 735 736 std::string enclosedOpName; 737 llvm::raw_string_ostream rss(enclosedOpName); 738 rss << "Op" << enclosedOp.getName().stripDialect(); 739 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); 740 741 if (!enclosedOpcode) { 742 op.emitError("Couldn't find op code for op ") 743 << enclosedOp.getName().getStringRef(); 744 return failure(); 745 } 746 747 operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue())); 748 749 // Append operands to the enclosed op to the list of operands. 750 for (Value operand : enclosedOp.getOperands()) { 751 uint32_t id = getValueID(operand); 752 assert(id && "use before def!"); 753 operands.push_back(id); 754 } 755 756 (void)encodeInstructionInto(typesGlobalValues, 757 spirv::Opcode::OpSpecConstantOp, operands); 758 valueIDMap[op.getResult()] = resultID; 759 760 return success(); 761 } 762 763 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { 764 auto undefType = op.getType(); 765 auto &id = undefValIDMap[undefType]; 766 if (!id) { 767 id = getNextID(); 768 uint32_t typeID = 0; 769 if (failed(processType(op.getLoc(), undefType, typeID)) || 770 failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, 771 {typeID, id}))) { 772 return failure(); 773 } 774 } 775 valueIDMap[op.getResult()] = id; 776 return success(); 777 } 778 779 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 780 NamedAttribute attr) { 781 auto attrName = attr.first.strref(); 782 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); 783 auto decoration = spirv::symbolizeDecoration(decorationName); 784 if (!decoration) { 785 return emitError( 786 loc, "non-argument attributes expected to have snake-case-ified " 787 "decoration name, unhandled attribute with name : ") 788 << attrName; 789 } 790 SmallVector<uint32_t, 1> args; 791 switch (decoration.getValue()) { 792 case spirv::Decoration::Binding: 793 case spirv::Decoration::DescriptorSet: 794 case spirv::Decoration::Location: 795 if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) { 796 args.push_back(intAttr.getValue().getZExtValue()); 797 break; 798 } 799 return emitError(loc, "expected integer attribute for ") << attrName; 800 case spirv::Decoration::BuiltIn: 801 if (auto strAttr = attr.second.dyn_cast<StringAttr>()) { 802 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 803 if (enumVal) { 804 args.push_back(static_cast<uint32_t>(enumVal.getValue())); 805 break; 806 } 807 return emitError(loc, "invalid ") 808 << attrName << " attribute " << strAttr.getValue(); 809 } 810 return emitError(loc, "expected string attribute for ") << attrName; 811 case spirv::Decoration::Aliased: 812 case spirv::Decoration::Flat: 813 case spirv::Decoration::NonReadable: 814 case spirv::Decoration::NonWritable: 815 case spirv::Decoration::NoPerspective: 816 case spirv::Decoration::Restrict: 817 // For unit attributes, the args list has no values so we do nothing 818 if (auto unitAttr = attr.second.dyn_cast<UnitAttr>()) 819 break; 820 return emitError(loc, "expected unit attribute for ") << attrName; 821 default: 822 return emitError(loc, "unhandled decoration ") << decorationName; 823 } 824 return emitDecoration(resultID, decoration.getValue(), args); 825 } 826 827 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 828 assert(!name.empty() && "unexpected empty string for OpName"); 829 830 SmallVector<uint32_t, 4> nameOperands; 831 nameOperands.push_back(resultID); 832 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { 833 return failure(); 834 } 835 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 836 } 837 838 namespace { 839 template <> 840 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 841 Location loc, spirv::ArrayType type, uint32_t resultID) { 842 if (unsigned stride = type.getArrayStride()) { 843 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 844 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 845 } 846 return success(); 847 } 848 849 template <> 850 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 851 Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) { 852 if (unsigned stride = type.getArrayStride()) { 853 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 854 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 855 } 856 return success(); 857 } 858 859 LogicalResult Serializer::processMemberDecoration( 860 uint32_t structID, 861 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 862 SmallVector<uint32_t, 4> args( 863 {structID, memberDecoration.memberIndex, 864 static_cast<uint32_t>(memberDecoration.decoration)}); 865 if (memberDecoration.hasValue) { 866 args.push_back(memberDecoration.decorationValue); 867 } 868 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, 869 args); 870 } 871 } // namespace 872 873 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { 874 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); 875 assert(functionHeader.empty() && functionBody.empty()); 876 877 uint32_t fnTypeID = 0; 878 // Generate type of the function. 879 (void)processType(op.getLoc(), op.getType(), fnTypeID); 880 881 // Add the function definition. 882 SmallVector<uint32_t, 4> operands; 883 uint32_t resTypeID = 0; 884 auto resultTypes = op.getType().getResults(); 885 if (resultTypes.size() > 1) { 886 return op.emitError("cannot serialize function with multiple return types"); 887 } 888 if (failed(processType(op.getLoc(), 889 (resultTypes.empty() ? getVoidType() : resultTypes[0]), 890 resTypeID))) { 891 return failure(); 892 } 893 operands.push_back(resTypeID); 894 auto funcID = getOrCreateFunctionID(op.getName()); 895 operands.push_back(funcID); 896 operands.push_back(static_cast<uint32_t>(op.function_control())); 897 operands.push_back(fnTypeID); 898 (void)encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, 899 operands); 900 901 // Add function name. 902 if (failed(processName(funcID, op.getName()))) { 903 return failure(); 904 } 905 906 // Declare the parameters. 907 for (auto arg : op.getArguments()) { 908 uint32_t argTypeID = 0; 909 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { 910 return failure(); 911 } 912 auto argValueID = getNextID(); 913 valueIDMap[arg] = argValueID; 914 (void)encodeInstructionInto(functionHeader, 915 spirv::Opcode::OpFunctionParameter, 916 {argTypeID, argValueID}); 917 } 918 919 // Process the body. 920 if (op.isExternal()) { 921 return op.emitError("external function is unhandled"); 922 } 923 924 // Some instructions (e.g., OpVariable) in a function must be in the first 925 // block in the function. These instructions will be put in functionHeader. 926 // Thus, we put the label in functionHeader first, and omit it from the first 927 // block. 928 (void)encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, 929 {getOrCreateBlockID(&op.front())}); 930 (void)processBlock(&op.front(), /*omitLabel=*/true); 931 if (failed(visitInPrettyBlockOrder( 932 &op.front(), [&](Block *block) { return processBlock(block); }, 933 /*skipHeader=*/true))) { 934 return failure(); 935 } 936 937 // There might be OpPhi instructions who have value references needing to fix. 938 for (auto deferredValue : deferredPhiValues) { 939 Value value = deferredValue.first; 940 uint32_t id = getValueID(value); 941 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value 942 << " to id = " << id << '\n'); 943 assert(id && "OpPhi references undefined value!"); 944 for (size_t offset : deferredValue.second) 945 functionBody[offset] = id; 946 } 947 deferredPhiValues.clear(); 948 949 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() 950 << "' --\n"); 951 // Insert OpFunctionEnd. 952 if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, 953 {}))) { 954 return failure(); 955 } 956 957 functions.append(functionHeader.begin(), functionHeader.end()); 958 functions.append(functionBody.begin(), functionBody.end()); 959 functionHeader.clear(); 960 functionBody.clear(); 961 962 return success(); 963 } 964 965 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { 966 SmallVector<uint32_t, 4> operands; 967 SmallVector<StringRef, 2> elidedAttrs; 968 uint32_t resultID = 0; 969 uint32_t resultTypeID = 0; 970 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { 971 return failure(); 972 } 973 operands.push_back(resultTypeID); 974 resultID = getNextID(); 975 valueIDMap[op.getResult()] = resultID; 976 operands.push_back(resultID); 977 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); 978 if (attr) { 979 operands.push_back(static_cast<uint32_t>( 980 attr.cast<IntegerAttr>().getValue().getZExtValue())); 981 } 982 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 983 for (auto arg : op.getODSOperands(0)) { 984 auto argID = getValueID(arg); 985 if (!argID) { 986 return emitError(op.getLoc(), "operand 0 has a use before def"); 987 } 988 operands.push_back(argID); 989 } 990 (void)emitDebugLine(functionHeader, op.getLoc()); 991 (void)encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, 992 operands); 993 for (auto attr : op->getAttrs()) { 994 if (llvm::any_of(elidedAttrs, 995 [&](StringRef elided) { return attr.first == elided; })) { 996 continue; 997 } 998 if (failed(processDecoration(op.getLoc(), resultID, attr))) { 999 return failure(); 1000 } 1001 } 1002 return success(); 1003 } 1004 1005 LogicalResult 1006 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { 1007 // Get TypeID. 1008 uint32_t resultTypeID = 0; 1009 SmallVector<StringRef, 4> elidedAttrs; 1010 if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { 1011 return failure(); 1012 } 1013 1014 if (isInterfaceStructPtrType(varOp.type())) { 1015 auto structType = varOp.type() 1016 .cast<spirv::PointerType>() 1017 .getPointeeType() 1018 .cast<spirv::StructType>(); 1019 if (failed( 1020 emitDecoration(getTypeID(structType), spirv::Decoration::Block))) { 1021 return varOp.emitError("cannot decorate ") 1022 << structType << " with Block decoration"; 1023 } 1024 } 1025 1026 elidedAttrs.push_back("type"); 1027 SmallVector<uint32_t, 4> operands; 1028 operands.push_back(resultTypeID); 1029 auto resultID = getNextID(); 1030 1031 // Encode the name. 1032 auto varName = varOp.sym_name(); 1033 elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); 1034 if (failed(processName(resultID, varName))) { 1035 return failure(); 1036 } 1037 globalVarIDMap[varName] = resultID; 1038 operands.push_back(resultID); 1039 1040 // Encode StorageClass. 1041 operands.push_back(static_cast<uint32_t>(varOp.storageClass())); 1042 1043 // Encode initialization. 1044 if (auto initializer = varOp.initializer()) { 1045 auto initializerID = getVariableID(initializer.getValue()); 1046 if (!initializerID) { 1047 return emitError(varOp.getLoc(), 1048 "invalid usage of undefined variable as initializer"); 1049 } 1050 operands.push_back(initializerID); 1051 elidedAttrs.push_back("initializer"); 1052 } 1053 1054 (void)emitDebugLine(typesGlobalValues, varOp.getLoc()); 1055 if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, 1056 operands))) { 1057 elidedAttrs.push_back("initializer"); 1058 return failure(); 1059 } 1060 1061 // Encode decorations. 1062 for (auto attr : varOp->getAttrs()) { 1063 if (llvm::any_of(elidedAttrs, 1064 [&](StringRef elided) { return attr.first == elided; })) { 1065 continue; 1066 } 1067 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { 1068 return failure(); 1069 } 1070 } 1071 return success(); 1072 } 1073 1074 //===----------------------------------------------------------------------===// 1075 // Type 1076 //===----------------------------------------------------------------------===// 1077 1078 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 1079 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 1080 // PushConstant Storage Classes must be explicitly laid out." 1081 bool Serializer::isInterfaceStructPtrType(Type type) const { 1082 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 1083 switch (ptrType.getStorageClass()) { 1084 case spirv::StorageClass::PhysicalStorageBuffer: 1085 case spirv::StorageClass::PushConstant: 1086 case spirv::StorageClass::StorageBuffer: 1087 case spirv::StorageClass::Uniform: 1088 return ptrType.getPointeeType().isa<spirv::StructType>(); 1089 default: 1090 break; 1091 } 1092 } 1093 return false; 1094 } 1095 1096 LogicalResult Serializer::processType(Location loc, Type type, 1097 uint32_t &typeID) { 1098 // Maintains a set of names for nested identified struct types. This is used 1099 // to properly serialize recursive references. 1100 llvm::SetVector<StringRef> serializationCtx; 1101 return processTypeImpl(loc, type, typeID, serializationCtx); 1102 } 1103 1104 LogicalResult 1105 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 1106 llvm::SetVector<StringRef> &serializationCtx) { 1107 typeID = getTypeID(type); 1108 if (typeID) { 1109 return success(); 1110 } 1111 typeID = getNextID(); 1112 SmallVector<uint32_t, 4> operands; 1113 1114 operands.push_back(typeID); 1115 auto typeEnum = spirv::Opcode::OpTypeVoid; 1116 bool deferSerialization = false; 1117 1118 if ((type.isa<FunctionType>() && 1119 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, 1120 operands))) || 1121 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 1122 deferSerialization, serializationCtx))) { 1123 if (deferSerialization) 1124 return success(); 1125 1126 typeIDMap[type] = typeID; 1127 1128 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) 1129 return failure(); 1130 1131 if (recursiveStructInfos.count(type) != 0) { 1132 // This recursive struct type is emitted already, now the OpTypePointer 1133 // instructions referring to recursive references are emitted as well. 1134 for (auto &ptrInfo : recursiveStructInfos[type]) { 1135 // TODO: This might not work if more than 1 recursive reference is 1136 // present in the struct. 1137 SmallVector<uint32_t, 4> ptrOperands; 1138 ptrOperands.push_back(ptrInfo.pointerTypeID); 1139 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 1140 ptrOperands.push_back(typeIDMap[type]); 1141 1142 if (failed(encodeInstructionInto( 1143 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) 1144 return failure(); 1145 } 1146 1147 recursiveStructInfos[type].clear(); 1148 } 1149 1150 return success(); 1151 } 1152 1153 return failure(); 1154 } 1155 1156 LogicalResult Serializer::prepareBasicType( 1157 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 1158 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 1159 llvm::SetVector<StringRef> &serializationCtx) { 1160 deferSerialization = false; 1161 1162 if (isVoidType(type)) { 1163 typeEnum = spirv::Opcode::OpTypeVoid; 1164 return success(); 1165 } 1166 1167 if (auto intType = type.dyn_cast<IntegerType>()) { 1168 if (intType.getWidth() == 1) { 1169 typeEnum = spirv::Opcode::OpTypeBool; 1170 return success(); 1171 } 1172 1173 typeEnum = spirv::Opcode::OpTypeInt; 1174 operands.push_back(intType.getWidth()); 1175 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 1176 // to preserve or validate. 1177 // 0 indicates unsigned, or no signedness semantics 1178 // 1 indicates signed semantics." 1179 operands.push_back(intType.isSigned() ? 1 : 0); 1180 return success(); 1181 } 1182 1183 if (auto floatType = type.dyn_cast<FloatType>()) { 1184 typeEnum = spirv::Opcode::OpTypeFloat; 1185 operands.push_back(floatType.getWidth()); 1186 return success(); 1187 } 1188 1189 if (auto vectorType = type.dyn_cast<VectorType>()) { 1190 uint32_t elementTypeID = 0; 1191 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 1192 serializationCtx))) { 1193 return failure(); 1194 } 1195 typeEnum = spirv::Opcode::OpTypeVector; 1196 operands.push_back(elementTypeID); 1197 operands.push_back(vectorType.getNumElements()); 1198 return success(); 1199 } 1200 1201 if (auto imageType = type.dyn_cast<spirv::ImageType>()) { 1202 typeEnum = spirv::Opcode::OpTypeImage; 1203 uint32_t sampledTypeID = 0; 1204 if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) 1205 return failure(); 1206 1207 operands.push_back(sampledTypeID); 1208 operands.push_back(static_cast<uint32_t>(imageType.getDim())); 1209 operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo())); 1210 operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo())); 1211 operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo())); 1212 operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo())); 1213 operands.push_back(static_cast<uint32_t>(imageType.getImageFormat())); 1214 return success(); 1215 } 1216 1217 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { 1218 typeEnum = spirv::Opcode::OpTypeArray; 1219 uint32_t elementTypeID = 0; 1220 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 1221 serializationCtx))) { 1222 return failure(); 1223 } 1224 operands.push_back(elementTypeID); 1225 if (auto elementCountID = prepareConstantInt( 1226 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 1227 operands.push_back(elementCountID); 1228 } 1229 return processTypeDecoration(loc, arrayType, resultID); 1230 } 1231 1232 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 1233 uint32_t pointeeTypeID = 0; 1234 spirv::StructType pointeeStruct = 1235 ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 1236 1237 if (pointeeStruct && pointeeStruct.isIdentified() && 1238 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 1239 // A recursive reference to an enclosing struct is found. 1240 // 1241 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 1242 // class as operands. 1243 SmallVector<uint32_t, 2> forwardPtrOperands; 1244 forwardPtrOperands.push_back(resultID); 1245 forwardPtrOperands.push_back( 1246 static_cast<uint32_t>(ptrType.getStorageClass())); 1247 1248 (void)encodeInstructionInto(typesGlobalValues, 1249 spirv::Opcode::OpTypeForwardPointer, 1250 forwardPtrOperands); 1251 1252 // 2. Find the pointee (enclosing) struct. 1253 auto structType = spirv::StructType::getIdentified( 1254 module.getContext(), pointeeStruct.getIdentifier()); 1255 1256 if (!structType) 1257 return failure(); 1258 1259 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 1260 // as deferred. 1261 deferSerialization = true; 1262 1263 // 4. Record the info needed to emit the deferred OpTypePointer 1264 // instruction when the enclosing struct is completely serialized. 1265 recursiveStructInfos[structType].push_back( 1266 {resultID, ptrType.getStorageClass()}); 1267 } else { 1268 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 1269 serializationCtx))) 1270 return failure(); 1271 } 1272 1273 typeEnum = spirv::Opcode::OpTypePointer; 1274 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 1275 operands.push_back(pointeeTypeID); 1276 return success(); 1277 } 1278 1279 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 1280 uint32_t elementTypeID = 0; 1281 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 1282 elementTypeID, serializationCtx))) { 1283 return failure(); 1284 } 1285 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 1286 operands.push_back(elementTypeID); 1287 return processTypeDecoration(loc, runtimeArrayType, resultID); 1288 } 1289 1290 if (auto structType = type.dyn_cast<spirv::StructType>()) { 1291 if (structType.isIdentified()) { 1292 (void)processName(resultID, structType.getIdentifier()); 1293 serializationCtx.insert(structType.getIdentifier()); 1294 } 1295 1296 bool hasOffset = structType.hasOffset(); 1297 for (auto elementIndex : 1298 llvm::seq<uint32_t>(0, structType.getNumElements())) { 1299 uint32_t elementTypeID = 0; 1300 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 1301 elementTypeID, serializationCtx))) { 1302 return failure(); 1303 } 1304 operands.push_back(elementTypeID); 1305 if (hasOffset) { 1306 // Decorate each struct member with an offset 1307 spirv::StructType::MemberDecorationInfo offsetDecoration{ 1308 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 1309 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 1310 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 1311 return emitError(loc, "cannot decorate ") 1312 << elementIndex << "-th member of " << structType 1313 << " with its offset"; 1314 } 1315 } 1316 } 1317 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 1318 structType.getMemberDecorations(memberDecorations); 1319 1320 for (auto &memberDecoration : memberDecorations) { 1321 if (failed(processMemberDecoration(resultID, memberDecoration))) { 1322 return emitError(loc, "cannot decorate ") 1323 << static_cast<uint32_t>(memberDecoration.memberIndex) 1324 << "-th member of " << structType << " with " 1325 << stringifyDecoration(memberDecoration.decoration); 1326 } 1327 } 1328 1329 typeEnum = spirv::Opcode::OpTypeStruct; 1330 1331 if (structType.isIdentified()) 1332 serializationCtx.remove(structType.getIdentifier()); 1333 1334 return success(); 1335 } 1336 1337 if (auto cooperativeMatrixType = 1338 type.dyn_cast<spirv::CooperativeMatrixNVType>()) { 1339 uint32_t elementTypeID = 0; 1340 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 1341 elementTypeID, serializationCtx))) { 1342 return failure(); 1343 } 1344 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; 1345 auto getConstantOp = [&](uint32_t id) { 1346 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 1347 return prepareConstantInt(loc, attr); 1348 }; 1349 operands.push_back(elementTypeID); 1350 operands.push_back( 1351 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()))); 1352 operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); 1353 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); 1354 return success(); 1355 } 1356 1357 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) { 1358 uint32_t elementTypeID = 0; 1359 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 1360 serializationCtx))) { 1361 return failure(); 1362 } 1363 typeEnum = spirv::Opcode::OpTypeMatrix; 1364 operands.push_back(elementTypeID); 1365 operands.push_back(matrixType.getNumColumns()); 1366 return success(); 1367 } 1368 1369 // TODO: Handle other types. 1370 return emitError(loc, "unhandled type in serialization: ") << type; 1371 } 1372 1373 LogicalResult 1374 Serializer::prepareFunctionType(Location loc, FunctionType type, 1375 spirv::Opcode &typeEnum, 1376 SmallVectorImpl<uint32_t> &operands) { 1377 typeEnum = spirv::Opcode::OpTypeFunction; 1378 assert(type.getNumResults() <= 1 && 1379 "serialization supports only a single return value"); 1380 uint32_t resultID = 0; 1381 if (failed(processType( 1382 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 1383 resultID))) { 1384 return failure(); 1385 } 1386 operands.push_back(resultID); 1387 for (auto &res : type.getInputs()) { 1388 uint32_t argTypeID = 0; 1389 if (failed(processType(loc, res, argTypeID))) { 1390 return failure(); 1391 } 1392 operands.push_back(argTypeID); 1393 } 1394 return success(); 1395 } 1396 1397 //===----------------------------------------------------------------------===// 1398 // Constant 1399 //===----------------------------------------------------------------------===// 1400 1401 uint32_t Serializer::prepareConstant(Location loc, Type constType, 1402 Attribute valueAttr) { 1403 if (auto id = prepareConstantScalar(loc, valueAttr)) { 1404 return id; 1405 } 1406 1407 // This is a composite literal. We need to handle each component separately 1408 // and then emit an OpConstantComposite for the whole. 1409 1410 if (auto id = getConstantID(valueAttr)) { 1411 return id; 1412 } 1413 1414 uint32_t typeID = 0; 1415 if (failed(processType(loc, constType, typeID))) { 1416 return 0; 1417 } 1418 1419 uint32_t resultID = 0; 1420 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { 1421 int rank = attr.getType().dyn_cast<ShapedType>().getRank(); 1422 SmallVector<uint64_t, 4> index(rank); 1423 resultID = prepareDenseElementsConstant(loc, constType, attr, 1424 /*dim=*/0, index); 1425 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { 1426 resultID = prepareArrayConstant(loc, constType, arrayAttr); 1427 } 1428 1429 if (resultID == 0) { 1430 emitError(loc, "cannot serialize attribute: ") << valueAttr; 1431 return 0; 1432 } 1433 1434 constIDMap[valueAttr] = resultID; 1435 return resultID; 1436 } 1437 1438 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 1439 ArrayAttr attr) { 1440 uint32_t typeID = 0; 1441 if (failed(processType(loc, constType, typeID))) { 1442 return 0; 1443 } 1444 1445 uint32_t resultID = getNextID(); 1446 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 1447 operands.reserve(attr.size() + 2); 1448 auto elementType = constType.cast<spirv::ArrayType>().getElementType(); 1449 for (Attribute elementAttr : attr) { 1450 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 1451 operands.push_back(elementID); 1452 } else { 1453 return 0; 1454 } 1455 } 1456 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 1457 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 1458 1459 return resultID; 1460 } 1461 1462 // TODO: Turn the below function into iterative function, instead of 1463 // recursive function. 1464 uint32_t 1465 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 1466 DenseElementsAttr valueAttr, int dim, 1467 MutableArrayRef<uint64_t> index) { 1468 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); 1469 assert(dim <= shapedType.getRank()); 1470 if (shapedType.getRank() == dim) { 1471 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { 1472 return attr.getType().getElementType().isInteger(1) 1473 ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index)) 1474 : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index)); 1475 } 1476 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { 1477 return prepareConstantFp(loc, attr.getValue<FloatAttr>(index)); 1478 } 1479 return 0; 1480 } 1481 1482 uint32_t typeID = 0; 1483 if (failed(processType(loc, constType, typeID))) { 1484 return 0; 1485 } 1486 1487 uint32_t resultID = getNextID(); 1488 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 1489 operands.reserve(shapedType.getDimSize(dim) + 2); 1490 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); 1491 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 1492 index[dim] = i; 1493 if (auto elementID = prepareDenseElementsConstant( 1494 loc, elementType, valueAttr, dim + 1, index)) { 1495 operands.push_back(elementID); 1496 } else { 1497 return 0; 1498 } 1499 } 1500 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 1501 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 1502 1503 return resultID; 1504 } 1505 1506 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 1507 bool isSpec) { 1508 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { 1509 return prepareConstantFp(loc, floatAttr, isSpec); 1510 } 1511 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { 1512 return prepareConstantBool(loc, boolAttr, isSpec); 1513 } 1514 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { 1515 return prepareConstantInt(loc, intAttr, isSpec); 1516 } 1517 1518 return 0; 1519 } 1520 1521 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 1522 bool isSpec) { 1523 if (!isSpec) { 1524 // We can de-duplicate normal constants, but not specialization constants. 1525 if (auto id = getConstantID(boolAttr)) { 1526 return id; 1527 } 1528 } 1529 1530 // Process the type for this bool literal 1531 uint32_t typeID = 0; 1532 if (failed(processType(loc, boolAttr.getType(), typeID))) { 1533 return 0; 1534 } 1535 1536 auto resultID = getNextID(); 1537 auto opcode = boolAttr.getValue() 1538 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 1539 : spirv::Opcode::OpConstantTrue) 1540 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 1541 : spirv::Opcode::OpConstantFalse); 1542 (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 1543 1544 if (!isSpec) { 1545 constIDMap[boolAttr] = resultID; 1546 } 1547 return resultID; 1548 } 1549 1550 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 1551 bool isSpec) { 1552 if (!isSpec) { 1553 // We can de-duplicate normal constants, but not specialization constants. 1554 if (auto id = getConstantID(intAttr)) { 1555 return id; 1556 } 1557 } 1558 1559 // Process the type for this integer literal 1560 uint32_t typeID = 0; 1561 if (failed(processType(loc, intAttr.getType(), typeID))) { 1562 return 0; 1563 } 1564 1565 auto resultID = getNextID(); 1566 APInt value = intAttr.getValue(); 1567 unsigned bitwidth = value.getBitWidth(); 1568 bool isSigned = value.isSignedIntN(bitwidth); 1569 1570 auto opcode = 1571 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 1572 1573 // According to SPIR-V spec, "When the type's bit width is less than 32-bits, 1574 // the literal's value appears in the low-order bits of the word, and the 1575 // high-order bits must be 0 for a floating-point type, or 0 for an integer 1576 // type with Signedness of 0, or sign extended when Signedness is 1." 1577 if (bitwidth == 32 || bitwidth == 16) { 1578 uint32_t word = 0; 1579 if (isSigned) { 1580 word = static_cast<int32_t>(value.getSExtValue()); 1581 } else { 1582 word = static_cast<uint32_t>(value.getZExtValue()); 1583 } 1584 (void)encodeInstructionInto(typesGlobalValues, opcode, 1585 {typeID, resultID, word}); 1586 } 1587 // According to SPIR-V spec: "When the type's bit width is larger than one 1588 // word, the literal’s low-order words appear first." 1589 else if (bitwidth == 64) { 1590 struct DoubleWord { 1591 uint32_t word1; 1592 uint32_t word2; 1593 } words; 1594 if (isSigned) { 1595 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 1596 } else { 1597 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 1598 } 1599 (void)encodeInstructionInto(typesGlobalValues, opcode, 1600 {typeID, resultID, words.word1, words.word2}); 1601 } else { 1602 std::string valueStr; 1603 llvm::raw_string_ostream rss(valueStr); 1604 value.print(rss, /*isSigned=*/false); 1605 1606 emitError(loc, "cannot serialize ") 1607 << bitwidth << "-bit integer literal: " << rss.str(); 1608 return 0; 1609 } 1610 1611 if (!isSpec) { 1612 constIDMap[intAttr] = resultID; 1613 } 1614 return resultID; 1615 } 1616 1617 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 1618 bool isSpec) { 1619 if (!isSpec) { 1620 // We can de-duplicate normal constants, but not specialization constants. 1621 if (auto id = getConstantID(floatAttr)) { 1622 return id; 1623 } 1624 } 1625 1626 // Process the type for this float literal 1627 uint32_t typeID = 0; 1628 if (failed(processType(loc, floatAttr.getType(), typeID))) { 1629 return 0; 1630 } 1631 1632 auto resultID = getNextID(); 1633 APFloat value = floatAttr.getValue(); 1634 APInt intValue = value.bitcastToAPInt(); 1635 1636 auto opcode = 1637 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 1638 1639 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 1640 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 1641 (void)encodeInstructionInto(typesGlobalValues, opcode, 1642 {typeID, resultID, word}); 1643 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 1644 struct DoubleWord { 1645 uint32_t word1; 1646 uint32_t word2; 1647 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 1648 (void)encodeInstructionInto(typesGlobalValues, opcode, 1649 {typeID, resultID, words.word1, words.word2}); 1650 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 1651 uint32_t word = 1652 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 1653 (void)encodeInstructionInto(typesGlobalValues, opcode, 1654 {typeID, resultID, word}); 1655 } else { 1656 std::string valueStr; 1657 llvm::raw_string_ostream rss(valueStr); 1658 value.print(rss); 1659 1660 emitError(loc, "cannot serialize ") 1661 << floatAttr.getType() << "-typed float literal: " << rss.str(); 1662 return 0; 1663 } 1664 1665 if (!isSpec) { 1666 constIDMap[floatAttr] = resultID; 1667 } 1668 return resultID; 1669 } 1670 1671 //===----------------------------------------------------------------------===// 1672 // Control flow 1673 //===----------------------------------------------------------------------===// 1674 1675 uint32_t Serializer::getOrCreateBlockID(Block *block) { 1676 if (uint32_t id = getBlockID(block)) 1677 return id; 1678 return blockIDMap[block] = getNextID(); 1679 } 1680 1681 LogicalResult 1682 Serializer::processBlock(Block *block, bool omitLabel, 1683 function_ref<void()> actionBeforeTerminator) { 1684 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 1685 LLVM_DEBUG(block->print(llvm::dbgs())); 1686 LLVM_DEBUG(llvm::dbgs() << '\n'); 1687 if (!omitLabel) { 1688 uint32_t blockID = getOrCreateBlockID(block); 1689 LLVM_DEBUG(llvm::dbgs() 1690 << "[block] " << block << " (id = " << blockID << ")\n"); 1691 1692 // Emit OpLabel for this block. 1693 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, 1694 {blockID}); 1695 } 1696 1697 // Emit OpPhi instructions for block arguments, if any. 1698 if (failed(emitPhiForBlockArguments(block))) 1699 return failure(); 1700 1701 // Process each op in this block except the terminator. 1702 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { 1703 if (failed(processOperation(&op))) 1704 return failure(); 1705 } 1706 1707 // Process the terminator. 1708 if (actionBeforeTerminator) 1709 actionBeforeTerminator(); 1710 if (failed(processOperation(&block->back()))) 1711 return failure(); 1712 1713 return success(); 1714 } 1715 1716 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 1717 // Nothing to do if this block has no arguments or it's the entry block, which 1718 // always has the same arguments as the function signature. 1719 if (block->args_empty() || block->isEntryBlock()) 1720 return success(); 1721 1722 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 1723 // A SPIR-V OpPhi instruction is of the syntax: 1724 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 1725 // So we need to collect all predecessor blocks and the arguments they send 1726 // to this block. 1727 SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors; 1728 for (Block *predecessor : block->getPredecessors()) { 1729 auto *terminator = predecessor->getTerminator(); 1730 // The predecessor here is the immediate one according to MLIR's IR 1731 // structure. It does not directly map to the incoming parent block for the 1732 // OpPhi instructions at SPIR-V binary level. This is because structured 1733 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 1734 // spv.selection/spv.loop op in the MLIR predecessor block, the branch op 1735 // jumping to the OpPhi's block then resides in the previous structured 1736 // control flow op's merge block. 1737 predecessor = getPhiIncomingBlock(predecessor); 1738 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 1739 predecessors.emplace_back(predecessor, branchOp.operand_begin()); 1740 } else { 1741 return terminator->emitError("unimplemented terminator for Phi creation"); 1742 } 1743 } 1744 1745 // Then create OpPhi instruction for each of the block argument. 1746 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 1747 BlockArgument arg = block->getArgument(argIndex); 1748 1749 // Get the type <id> and result <id> for this OpPhi instruction. 1750 uint32_t phiTypeID = 0; 1751 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 1752 return failure(); 1753 uint32_t phiID = getNextID(); 1754 1755 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 1756 << arg << " (id = " << phiID << ")\n"); 1757 1758 // Prepare the (value <id>, parent block <id>) pairs. 1759 SmallVector<uint32_t, 8> phiArgs; 1760 phiArgs.push_back(phiTypeID); 1761 phiArgs.push_back(phiID); 1762 1763 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 1764 Value value = *(predecessors[predIndex].second + argIndex); 1765 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1766 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1767 << ") value " << value << ' '); 1768 // Each pair is a value <id> ... 1769 uint32_t valueId = getValueID(value); 1770 if (valueId == 0) { 1771 // The op generating this value hasn't been visited yet so we don't have 1772 // an <id> assigned yet. Record this to fix up later. 1773 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1774 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1775 phiArgs.size()); 1776 } else { 1777 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1778 } 1779 phiArgs.push_back(valueId); 1780 // ... and a parent block <id>. 1781 phiArgs.push_back(predBlockId); 1782 } 1783 1784 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1785 valueIDMap[arg] = phiID; 1786 } 1787 1788 return success(); 1789 } 1790 1791 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { 1792 // Assign <id>s to all blocks so that branches inside the SelectionOp can 1793 // resolve properly. 1794 auto &body = selectionOp.body(); 1795 for (Block &block : body) 1796 getOrCreateBlockID(&block); 1797 1798 auto *headerBlock = selectionOp.getHeaderBlock(); 1799 auto *mergeBlock = selectionOp.getMergeBlock(); 1800 auto mergeID = getBlockID(mergeBlock); 1801 auto loc = selectionOp.getLoc(); 1802 1803 // Emit the selection header block, which dominates all other blocks, first. 1804 // We need to emit an OpSelectionMerge instruction before the selection header 1805 // block's terminator. 1806 auto emitSelectionMerge = [&]() { 1807 (void)emitDebugLine(functionBody, loc); 1808 lastProcessedWasMergeInst = true; 1809 (void)encodeInstructionInto( 1810 functionBody, spirv::Opcode::OpSelectionMerge, 1811 {mergeID, static_cast<uint32_t>(selectionOp.selection_control())}); 1812 }; 1813 // For structured selection, we cannot have blocks in the selection construct 1814 // branching to the selection header block. Entering the selection (and 1815 // reaching the selection header) must be from the block containing the 1816 // spv.selection op. If there are ops ahead of the spv.selection op in the 1817 // block, we can "merge" them into the selection header. So here we don't need 1818 // to emit a separate block; just continue with the existing block. 1819 if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge))) 1820 return failure(); 1821 1822 // Process all blocks with a depth-first visitor starting from the header 1823 // block. The selection header block and merge block are skipped by this 1824 // visitor. 1825 if (failed(visitInPrettyBlockOrder( 1826 headerBlock, [&](Block *block) { return processBlock(block); }, 1827 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) 1828 return failure(); 1829 1830 // There is nothing to do for the merge block in the selection, which just 1831 // contains a spv.mlir.merge op, itself. But we need to have an OpLabel 1832 // instruction to start a new SPIR-V block for ops following this SelectionOp. 1833 // The block should use the <id> for the merge block. 1834 return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 1835 } 1836 1837 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { 1838 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve 1839 // properly. We don't need to assign for the entry block, which is just for 1840 // satisfying MLIR region's structural requirement. 1841 auto &body = loopOp.body(); 1842 for (Block &block : 1843 llvm::make_range(std::next(body.begin(), 1), body.end())) { 1844 getOrCreateBlockID(&block); 1845 } 1846 auto *headerBlock = loopOp.getHeaderBlock(); 1847 auto *continueBlock = loopOp.getContinueBlock(); 1848 auto *mergeBlock = loopOp.getMergeBlock(); 1849 auto headerID = getBlockID(headerBlock); 1850 auto continueID = getBlockID(continueBlock); 1851 auto mergeID = getBlockID(mergeBlock); 1852 auto loc = loopOp.getLoc(); 1853 1854 // This LoopOp is in some MLIR block with preceding and following ops. In the 1855 // binary format, it should reside in separate SPIR-V blocks from its 1856 // preceding and following ops. So we need to emit unconditional branches to 1857 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow 1858 // afterwards. 1859 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, 1860 {headerID}); 1861 1862 // LoopOp's entry block is just there for satisfying MLIR's structural 1863 // requirements so we omit it and start serialization from the loop header 1864 // block. 1865 1866 // Emit the loop header block, which dominates all other blocks, first. We 1867 // need to emit an OpLoopMerge instruction before the loop header block's 1868 // terminator. 1869 auto emitLoopMerge = [&]() { 1870 (void)emitDebugLine(functionBody, loc); 1871 lastProcessedWasMergeInst = true; 1872 (void)encodeInstructionInto( 1873 functionBody, spirv::Opcode::OpLoopMerge, 1874 {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())}); 1875 }; 1876 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) 1877 return failure(); 1878 1879 // Process all blocks with a depth-first visitor starting from the header 1880 // block. The loop header block, loop continue block, and loop merge block are 1881 // skipped by this visitor and handled later in this function. 1882 if (failed(visitInPrettyBlockOrder( 1883 headerBlock, [&](Block *block) { return processBlock(block); }, 1884 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) 1885 return failure(); 1886 1887 // We have handled all other blocks. Now get to the loop continue block. 1888 if (failed(processBlock(continueBlock))) 1889 return failure(); 1890 1891 // There is nothing to do for the merge block in the loop, which just contains 1892 // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to 1893 // start a new SPIR-V block for ops following this LoopOp. The block should 1894 // use the <id> for the merge block. 1895 return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 1896 } 1897 1898 LogicalResult Serializer::processBranchConditionalOp( 1899 spirv::BranchConditionalOp condBranchOp) { 1900 auto conditionID = getValueID(condBranchOp.condition()); 1901 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); 1902 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); 1903 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; 1904 1905 if (auto weights = condBranchOp.branch_weights()) { 1906 for (auto val : weights->getValue()) 1907 arguments.push_back(val.cast<IntegerAttr>().getInt()); 1908 } 1909 1910 (void)emitDebugLine(functionBody, condBranchOp.getLoc()); 1911 return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, 1912 arguments); 1913 } 1914 1915 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { 1916 (void)emitDebugLine(functionBody, branchOp.getLoc()); 1917 return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, 1918 {getOrCreateBlockID(branchOp.getTarget())}); 1919 } 1920 1921 //===----------------------------------------------------------------------===// 1922 // Operation 1923 //===----------------------------------------------------------------------===// 1924 1925 LogicalResult Serializer::encodeExtensionInstruction( 1926 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1927 ArrayRef<uint32_t> operands) { 1928 // Check if the extension has been imported. 1929 auto &setID = extendedInstSetIDMap[extensionSetName]; 1930 if (!setID) { 1931 setID = getNextID(); 1932 SmallVector<uint32_t, 16> importOperands; 1933 importOperands.push_back(setID); 1934 if (failed( 1935 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || 1936 failed(encodeInstructionInto( 1937 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { 1938 return failure(); 1939 } 1940 } 1941 1942 // The first two operands are the result type <id> and result <id>. The set 1943 // <id> and the opcode need to be insert after this. 1944 if (operands.size() < 2) { 1945 return op->emitError("extended instructions must have a result encoding"); 1946 } 1947 SmallVector<uint32_t, 8> extInstOperands; 1948 extInstOperands.reserve(operands.size() + 2); 1949 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1950 extInstOperands.push_back(setID); 1951 extInstOperands.push_back(extensionOpcode); 1952 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1953 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1954 extInstOperands); 1955 } 1956 1957 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { 1958 auto varName = addressOfOp.variable(); 1959 auto variableID = getVariableID(varName); 1960 if (!variableID) { 1961 return addressOfOp.emitError("unknown result <id> for variable ") 1962 << varName; 1963 } 1964 valueIDMap[addressOfOp.pointer()] = variableID; 1965 return success(); 1966 } 1967 1968 LogicalResult 1969 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { 1970 auto constName = referenceOfOp.spec_const(); 1971 auto constID = getSpecConstID(constName); 1972 if (!constID) { 1973 return referenceOfOp.emitError( 1974 "unknown result <id> for specialization constant ") 1975 << constName; 1976 } 1977 valueIDMap[referenceOfOp.reference()] = constID; 1978 return success(); 1979 } 1980 1981 LogicalResult Serializer::processOperation(Operation *opInst) { 1982 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1983 1984 // First dispatch the ops that do not directly mirror an instruction from 1985 // the SPIR-V spec. 1986 return TypeSwitch<Operation *, LogicalResult>(opInst) 1987 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1988 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1989 .Case([&](spirv::BranchConditionalOp op) { 1990 return processBranchConditionalOp(op); 1991 }) 1992 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1993 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1994 .Case([&](spirv::GlobalVariableOp op) { 1995 return processGlobalVariableOp(op); 1996 }) 1997 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1998 .Case([&](spirv::ModuleEndOp) { return success(); }) 1999 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 2000 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 2001 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 2002 .Case([&](spirv::SpecConstantCompositeOp op) { 2003 return processSpecConstantCompositeOp(op); 2004 }) 2005 .Case([&](spirv::SpecConstantOperationOp op) { 2006 return processSpecConstantOperationOp(op); 2007 }) 2008 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 2009 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 2010 2011 // Then handle all the ops that directly mirror SPIR-V instructions with 2012 // auto-generated methods. 2013 .Default( 2014 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 2015 } 2016 2017 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 2018 StringRef extInstSet, 2019 uint32_t opcode) { 2020 SmallVector<uint32_t, 4> operands; 2021 Location loc = op->getLoc(); 2022 2023 uint32_t resultID = 0; 2024 if (op->getNumResults() != 0) { 2025 uint32_t resultTypeID = 0; 2026 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 2027 return failure(); 2028 operands.push_back(resultTypeID); 2029 2030 resultID = getNextID(); 2031 operands.push_back(resultID); 2032 valueIDMap[op->getResult(0)] = resultID; 2033 }; 2034 2035 for (Value operand : op->getOperands()) 2036 operands.push_back(getValueID(operand)); 2037 2038 (void)emitDebugLine(functionBody, loc); 2039 2040 if (extInstSet.empty()) { 2041 (void)encodeInstructionInto(functionBody, 2042 static_cast<spirv::Opcode>(opcode), operands); 2043 } else { 2044 (void)encodeExtensionInstruction(op, extInstSet, opcode, operands); 2045 } 2046 2047 if (op->getNumResults() != 0) { 2048 for (auto attr : op->getAttrs()) { 2049 if (failed(processDecoration(loc, resultID, attr))) 2050 return failure(); 2051 } 2052 } 2053 2054 return success(); 2055 } 2056 2057 namespace { 2058 template <> 2059 LogicalResult 2060 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { 2061 SmallVector<uint32_t, 4> operands; 2062 // Add the ExecutionModel. 2063 operands.push_back(static_cast<uint32_t>(op.execution_model())); 2064 // Add the function <id>. 2065 auto funcID = getFunctionID(op.fn()); 2066 if (!funcID) { 2067 return op.emitError("missing <id> for function ") 2068 << op.fn() 2069 << "; function needs to be defined before spv.EntryPoint is " 2070 "serialized"; 2071 } 2072 operands.push_back(funcID); 2073 // Add the name of the function. 2074 (void)spirv::encodeStringLiteralInto(operands, op.fn()); 2075 2076 // Add the interface values. 2077 if (auto interface = op.interface()) { 2078 for (auto var : interface.getValue()) { 2079 auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue()); 2080 if (!id) { 2081 return op.emitError("referencing undefined global variable." 2082 "spv.EntryPoint is at the end of spv.module. All " 2083 "referenced variables should already be defined"); 2084 } 2085 operands.push_back(id); 2086 } 2087 } 2088 return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, 2089 operands); 2090 } 2091 2092 template <> 2093 LogicalResult 2094 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) { 2095 StringRef argNames[] = {"execution_scope", "memory_scope", 2096 "memory_semantics"}; 2097 SmallVector<uint32_t, 3> operands; 2098 2099 for (auto argName : argNames) { 2100 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName); 2101 auto operand = prepareConstantInt(op.getLoc(), argIntAttr); 2102 if (!operand) { 2103 return failure(); 2104 } 2105 operands.push_back(operand); 2106 } 2107 2108 return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, 2109 operands); 2110 } 2111 2112 template <> 2113 LogicalResult 2114 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { 2115 SmallVector<uint32_t, 4> operands; 2116 // Add the function <id>. 2117 auto funcID = getFunctionID(op.fn()); 2118 if (!funcID) { 2119 return op.emitError("missing <id> for function ") 2120 << op.fn() 2121 << "; function needs to be serialized before ExecutionModeOp is " 2122 "serialized"; 2123 } 2124 operands.push_back(funcID); 2125 // Add the ExecutionMode. 2126 operands.push_back(static_cast<uint32_t>(op.execution_mode())); 2127 2128 // Serialize values if any. 2129 auto values = op.values(); 2130 if (values) { 2131 for (auto &intVal : values.getValue()) { 2132 operands.push_back(static_cast<uint32_t>( 2133 intVal.cast<IntegerAttr>().getValue().getZExtValue())); 2134 } 2135 } 2136 return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, 2137 operands); 2138 } 2139 2140 template <> 2141 LogicalResult 2142 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) { 2143 StringRef argNames[] = {"memory_scope", "memory_semantics"}; 2144 SmallVector<uint32_t, 2> operands; 2145 2146 for (auto argName : argNames) { 2147 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName); 2148 auto operand = prepareConstantInt(op.getLoc(), argIntAttr); 2149 if (!operand) { 2150 return failure(); 2151 } 2152 operands.push_back(operand); 2153 } 2154 2155 return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, 2156 operands); 2157 } 2158 2159 template <> 2160 LogicalResult 2161 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { 2162 auto funcName = op.callee(); 2163 uint32_t resTypeID = 0; 2164 2165 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); 2166 if (failed(processType(op.getLoc(), resultTy, resTypeID))) 2167 return failure(); 2168 2169 auto funcID = getOrCreateFunctionID(funcName); 2170 auto funcCallID = getNextID(); 2171 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; 2172 2173 for (auto value : op.arguments()) { 2174 auto valueID = getValueID(value); 2175 assert(valueID && "cannot find a value for spv.FunctionCall"); 2176 operands.push_back(valueID); 2177 } 2178 2179 if (!resultTy.isa<NoneType>()) 2180 valueIDMap[op.getResult(0)] = funcCallID; 2181 2182 return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, 2183 operands); 2184 } 2185 2186 template <> 2187 LogicalResult 2188 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { 2189 SmallVector<uint32_t, 4> operands; 2190 SmallVector<StringRef, 2> elidedAttrs; 2191 2192 for (Value operand : op->getOperands()) { 2193 auto id = getValueID(operand); 2194 assert(id && "use before def!"); 2195 operands.push_back(id); 2196 } 2197 2198 if (auto attr = op->getAttr("memory_access")) { 2199 operands.push_back(static_cast<uint32_t>( 2200 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2201 } 2202 2203 elidedAttrs.push_back("memory_access"); 2204 2205 if (auto attr = op->getAttr("alignment")) { 2206 operands.push_back(static_cast<uint32_t>( 2207 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2208 } 2209 2210 elidedAttrs.push_back("alignment"); 2211 2212 if (auto attr = op->getAttr("source_memory_access")) { 2213 operands.push_back(static_cast<uint32_t>( 2214 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2215 } 2216 2217 elidedAttrs.push_back("source_memory_access"); 2218 2219 if (auto attr = op->getAttr("source_alignment")) { 2220 operands.push_back(static_cast<uint32_t>( 2221 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2222 } 2223 2224 elidedAttrs.push_back("source_alignment"); 2225 (void)emitDebugLine(functionBody, op.getLoc()); 2226 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, 2227 operands); 2228 2229 return success(); 2230 } 2231 2232 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and 2233 // various Serializer::processOp<...>() specializations. 2234 #define GET_SERIALIZATION_FNS 2235 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 2236 } // namespace 2237 2238 LogicalResult Serializer::emitDecoration(uint32_t target, 2239 spirv::Decoration decoration, 2240 ArrayRef<uint32_t> params) { 2241 uint32_t wordCount = 3 + params.size(); 2242 decorations.push_back( 2243 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); 2244 decorations.push_back(target); 2245 decorations.push_back(static_cast<uint32_t>(decoration)); 2246 decorations.append(params.begin(), params.end()); 2247 return success(); 2248 } 2249 2250 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 2251 Location loc) { 2252 if (!emitDebugInfo) 2253 return success(); 2254 2255 if (lastProcessedWasMergeInst) { 2256 lastProcessedWasMergeInst = false; 2257 return success(); 2258 } 2259 2260 auto fileLoc = loc.dyn_cast<FileLineColLoc>(); 2261 if (fileLoc) 2262 (void)encodeInstructionInto( 2263 binary, spirv::Opcode::OpLine, 2264 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 2265 return success(); 2266 } 2267 2268 namespace mlir { 2269 LogicalResult spirv::serialize(spirv::ModuleOp module, 2270 SmallVectorImpl<uint32_t> &binary, 2271 bool emitDebugInfo) { 2272 if (!module.vce_triple().hasValue()) 2273 return module.emitError( 2274 "module must have 'vce_triple' attribute to be serializeable"); 2275 2276 Serializer serializer(module, emitDebugInfo); 2277 2278 if (failed(serializer.serialize())) 2279 return failure(); 2280 2281 LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs())); 2282 2283 serializer.collect(binary); 2284 return success(); 2285 } 2286 } // namespace mlir 2287