1 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Target/SPIRV/Serialization.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/RegionGraphTraits.h" 21 #include "mlir/Support/LogicalResult.h" 22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 23 #include "llvm/ADT/DepthFirstIterator.h" 24 #include "llvm/ADT/Sequence.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/ADT/SmallPtrSet.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 #include "llvm/ADT/bit.h" 31 #include "llvm/Support/Debug.h" 32 #include "llvm/Support/raw_ostream.h" 33 34 #define DEBUG_TYPE "spirv-serialization" 35 36 using namespace mlir; 37 38 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 39 /// the given `binary` vector. 40 static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, 41 spirv::Opcode op, 42 ArrayRef<uint32_t> operands) { 43 uint32_t wordCount = 1 + operands.size(); 44 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 45 binary.append(operands.begin(), operands.end()); 46 return success(); 47 } 48 49 /// A pre-order depth-first visitor function for processing basic blocks. 50 /// 51 /// Visits the basic blocks starting from the given `headerBlock` in pre-order 52 /// depth-first manner and calls `blockHandler` on each block. Skips handling 53 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` 54 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s 55 /// successors. 56 /// 57 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order 58 /// of blocks in a function must satisfy the rule that blocks appear before 59 /// all blocks they dominate." This can be achieved by a pre-order CFG 60 /// traversal algorithm. To make the serialization output more logical and 61 /// readable to human, we perform depth-first CFG traversal and delay the 62 /// serialization of the merge block and the continue block, if exists, until 63 /// after all other blocks have been processed. 64 static LogicalResult 65 visitInPrettyBlockOrder(Block *headerBlock, 66 function_ref<LogicalResult(Block *)> blockHandler, 67 bool skipHeader = false, BlockRange skipBlocks = {}) { 68 llvm::df_iterator_default_set<Block *, 4> doneBlocks; 69 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); 70 71 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { 72 if (skipHeader && block == headerBlock) 73 continue; 74 if (failed(blockHandler(block))) 75 return failure(); 76 } 77 return success(); 78 } 79 80 /// Returns the merge block if the given `op` is a structured control flow op. 81 /// Otherwise returns nullptr. 82 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 83 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 84 return selectionOp.getMergeBlock(); 85 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 86 return loopOp.getMergeBlock(); 87 return nullptr; 88 } 89 90 /// Given a predecessor `block` for a block with arguments, returns the block 91 /// that should be used as the parent block for SPIR-V OpPhi instructions 92 /// corresponding to the block arguments. 93 static Block *getPhiIncomingBlock(Block *block) { 94 // If the predecessor block in question is the entry block for a spv.loop, 95 // we jump to this spv.loop from its enclosing block. 96 if (block->isEntryBlock()) { 97 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 98 // Then the incoming parent block for OpPhi should be the merge block of 99 // the structured control flow op before this loop. 100 Operation *op = loopOp.getOperation(); 101 while ((op = op->getPrevNode()) != nullptr) 102 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 103 return incomingBlock; 104 // Or the enclosing block itself if no structured control flow ops 105 // exists before this loop. 106 return loopOp->getBlock(); 107 } 108 } 109 110 // Otherwise, we jump from the given predecessor block. Try to see if there is 111 // a structured control flow op inside it. 112 for (Operation &op : llvm::reverse(block->getOperations())) { 113 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 114 return incomingBlock; 115 } 116 return block; 117 } 118 119 namespace { 120 121 /// A SPIR-V module serializer. 122 /// 123 /// A SPIR-V binary module is a single linear stream of instructions; each 124 /// instruction is composed of 32-bit words with the layout: 125 /// 126 /// | <word-count>|<opcode> | <operand> | <operand> | ... | 127 /// | <------ word -------> | <-- word --> | <-- word --> | ... | 128 /// 129 /// For the first word, the 16 high-order bits are the word count of the 130 /// instruction, the 16 low-order bits are the opcode enumerant. The 131 /// instructions then belong to different sections, which must be laid out in 132 /// the particular order as specified in "2.4 Logical Layout of a Module" of 133 /// the SPIR-V spec. 134 class Serializer { 135 public: 136 /// Creates a serializer for the given SPIR-V `module`. 137 explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false); 138 139 /// Serializes the remembered SPIR-V module. 140 LogicalResult serialize(); 141 142 /// Collects the final SPIR-V `binary`. 143 void collect(SmallVectorImpl<uint32_t> &binary); 144 145 #ifndef NDEBUG 146 /// (For debugging) prints each value and its corresponding result <id>. 147 void printValueIDMap(raw_ostream &os); 148 #endif 149 150 private: 151 // Note that there are two main categories of methods in this class: 152 // * process*() methods are meant to fully serialize a SPIR-V module entity 153 // (header, type, op, etc.). They update internal vectors containing 154 // different binary sections. They are not meant to be called except the 155 // top-level serialization loop. 156 // * prepare*() methods are meant to be helpers that prepare for serializing 157 // certain entity. They may or may not update internal vectors containing 158 // different binary sections. They are meant to be called among themselves 159 // or by other process*() methods for subtasks. 160 161 //===--------------------------------------------------------------------===// 162 // <id> 163 //===--------------------------------------------------------------------===// 164 165 // Note that it is illegal to use id <0> in SPIR-V binary module. Various 166 // methods in this class, if using SPIR-V word (uint32_t) as interface, 167 // check or return id <0> to indicate error in processing. 168 169 /// Consumes the next unused <id>. This method will never return 0. 170 uint32_t getNextID() { return nextID++; } 171 172 //===--------------------------------------------------------------------===// 173 // Module structure 174 //===--------------------------------------------------------------------===// 175 176 uint32_t getSpecConstID(StringRef constName) const { 177 return specConstIDMap.lookup(constName); 178 } 179 180 uint32_t getVariableID(StringRef varName) const { 181 return globalVarIDMap.lookup(varName); 182 } 183 184 uint32_t getFunctionID(StringRef fnName) const { 185 return funcIDMap.lookup(fnName); 186 } 187 188 /// Gets the <id> for the function with the given name. Assigns the next 189 /// available <id> if the function haven't been deserialized. 190 uint32_t getOrCreateFunctionID(StringRef fnName); 191 192 void processCapability(); 193 194 void processDebugInfo(); 195 196 void processExtension(); 197 198 void processMemoryModel(); 199 200 LogicalResult processConstantOp(spirv::ConstantOp op); 201 202 LogicalResult processSpecConstantOp(spirv::SpecConstantOp op); 203 204 LogicalResult 205 processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); 206 207 LogicalResult 208 processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); 209 210 /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA 211 /// value to use with other operations. The SPIR-V spec recommends that 212 /// OpUndef be generated at module level. The serialization generates an 213 /// OpUndef for each type needed at module level. 214 LogicalResult processUndefOp(spirv::UndefOp op); 215 216 /// Emit OpName for the given `resultID`. 217 LogicalResult processName(uint32_t resultID, StringRef name); 218 219 /// Processes a SPIR-V function op. 220 LogicalResult processFuncOp(spirv::FuncOp op); 221 222 LogicalResult processVariableOp(spirv::VariableOp op); 223 224 /// Process a SPIR-V GlobalVariableOp 225 LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); 226 227 /// Process attributes that translate to decorations on the result <id> 228 LogicalResult processDecoration(Location loc, uint32_t resultID, 229 NamedAttribute attr); 230 231 template <typename DType> 232 LogicalResult processTypeDecoration(Location loc, DType type, 233 uint32_t resultId) { 234 return emitError(loc, "unhandled decoration for type:") << type; 235 } 236 237 /// Process member decoration 238 LogicalResult processMemberDecoration( 239 uint32_t structID, 240 const spirv::StructType::MemberDecorationInfo &memberDecorationInfo); 241 242 //===--------------------------------------------------------------------===// 243 // Types 244 //===--------------------------------------------------------------------===// 245 246 uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); } 247 248 Type getVoidType() { return mlirBuilder.getNoneType(); } 249 250 bool isVoidType(Type type) const { return type.isa<NoneType>(); } 251 252 /// Returns true if the given type is a pointer type to a struct in some 253 /// interface storage class. 254 bool isInterfaceStructPtrType(Type type) const; 255 256 /// Main dispatch method for serializing a type. The result <id> of the 257 /// serialized type will be returned as `typeID`. 258 LogicalResult processType(Location loc, Type type, uint32_t &typeID); 259 LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID, 260 llvm::SetVector<StringRef> &serializationCtx); 261 262 /// Method for preparing basic SPIR-V type serialization. Returns the type's 263 /// opcode and operands for the instruction via `typeEnum` and `operands`. 264 LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, 265 spirv::Opcode &typeEnum, 266 SmallVectorImpl<uint32_t> &operands, 267 bool &deferSerialization, 268 llvm::SetVector<StringRef> &serializationCtx); 269 270 LogicalResult prepareFunctionType(Location loc, FunctionType type, 271 spirv::Opcode &typeEnum, 272 SmallVectorImpl<uint32_t> &operands); 273 274 //===--------------------------------------------------------------------===// 275 // Constant 276 //===--------------------------------------------------------------------===// 277 278 uint32_t getConstantID(Attribute value) const { 279 return constIDMap.lookup(value); 280 } 281 282 /// Main dispatch method for processing a constant with the given `constType` 283 /// and `valueAttr`. `constType` is needed here because we can interpret the 284 /// `valueAttr` as a different type than the type of `valueAttr` itself; for 285 /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType 286 /// constants. 287 uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); 288 289 /// Prepares array attribute serialization. This method emits corresponding 290 /// OpConstant* and returns the result <id> associated with it. Returns 0 if 291 /// failed. 292 uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr); 293 294 /// Prepares bool/int/float DenseElementsAttr serialization. This method 295 /// iterates the DenseElementsAttr to construct the constant array, and 296 /// returns the result <id> associated with it. Returns 0 if failed. Note 297 /// that the size of `index` must match the rank. 298 /// TODO: Consider to enhance splat elements cases. For splat cases, 299 /// we don't need to loop over all elements, especially when the splat value 300 /// is zero. We can use OpConstantNull when the value is zero. 301 uint32_t prepareDenseElementsConstant(Location loc, Type constType, 302 DenseElementsAttr valueAttr, int dim, 303 MutableArrayRef<uint64_t> index); 304 305 /// Prepares scalar attribute serialization. This method emits corresponding 306 /// OpConstant* and returns the result <id> associated with it. Returns 0 if 307 /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is 308 /// true, then the constant will be serialized as a specialization constant. 309 uint32_t prepareConstantScalar(Location loc, Attribute valueAttr, 310 bool isSpec = false); 311 312 uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, 313 bool isSpec = false); 314 315 uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, 316 bool isSpec = false); 317 318 uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, 319 bool isSpec = false); 320 321 //===--------------------------------------------------------------------===// 322 // Control flow 323 //===--------------------------------------------------------------------===// 324 325 /// Returns the result <id> for the given block. 326 uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } 327 328 /// Returns the result <id> for the given block. If no <id> has been assigned, 329 /// assigns the next available <id> 330 uint32_t getOrCreateBlockID(Block *block); 331 332 /// Processes the given `block` and emits SPIR-V instructions for all ops 333 /// inside. Does not emit OpLabel for this block if `omitLabel` is true. 334 /// `actionBeforeTerminator` is a callback that will be invoked before 335 /// handling the terminator op. It can be used to inject the Op*Merge 336 /// instruction if this is a SPIR-V selection/loop header block. 337 LogicalResult 338 processBlock(Block *block, bool omitLabel = false, 339 function_ref<void()> actionBeforeTerminator = nullptr); 340 341 /// Emits OpPhi instructions for the given block if it has block arguments. 342 LogicalResult emitPhiForBlockArguments(Block *block); 343 344 LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); 345 346 LogicalResult processLoopOp(spirv::LoopOp loopOp); 347 348 LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); 349 350 LogicalResult processBranchOp(spirv::BranchOp branchOp); 351 352 //===--------------------------------------------------------------------===// 353 // Operations 354 //===--------------------------------------------------------------------===// 355 356 LogicalResult encodeExtensionInstruction(Operation *op, 357 StringRef extensionSetName, 358 uint32_t opcode, 359 ArrayRef<uint32_t> operands); 360 361 uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); } 362 363 LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); 364 365 LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp); 366 367 /// Main dispatch method for serializing an operation. 368 LogicalResult processOperation(Operation *op); 369 370 /// Serializes an operation `op` as core instruction with `opcode` if 371 /// `extInstSet` is empty. Otherwise serializes it as an extended instruction 372 /// with `opcode` from `extInstSet`. 373 /// This method is a generic one for dispatching any SPIR-V ops that has no 374 /// variadic operands and attributes in TableGen definitions. 375 LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet, 376 uint32_t opcode); 377 378 /// Dispatches to the serialization function for an operation in SPIR-V 379 /// dialect that is a mirror of an instruction in the SPIR-V spec. This is 380 /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V 381 /// dialect that have hasOpcode == 1. 382 LogicalResult dispatchToAutogenSerialization(Operation *op); 383 384 /// Serializes an operation in the SPIR-V dialect that is a mirror of an 385 /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1 386 /// and autogenSerialization == 1 in ODS. 387 template <typename OpTy> 388 LogicalResult processOp(OpTy op) { 389 return op.emitError("unsupported op serialization"); 390 } 391 392 //===--------------------------------------------------------------------===// 393 // Utilities 394 //===--------------------------------------------------------------------===// 395 396 /// Emits an OpDecorate instruction to decorate the given `target` with the 397 /// given `decoration`. 398 LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration, 399 ArrayRef<uint32_t> params = {}); 400 401 /// Emits an OpLine instruction with the given `loc` location information into 402 /// the given `binary` vector. 403 LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc); 404 405 private: 406 /// The SPIR-V module to be serialized. 407 spirv::ModuleOp module; 408 409 /// An MLIR builder for getting MLIR constructs. 410 mlir::Builder mlirBuilder; 411 412 /// A flag which indicates if the debuginfo should be emitted. 413 bool emitDebugInfo = false; 414 415 /// A flag which indicates if the last processed instruction was a merge 416 /// instruction. 417 /// According to SPIR-V spec: "If a branch merge instruction is used, the last 418 /// OpLine in the block must be before its merge instruction". 419 bool lastProcessedWasMergeInst = false; 420 421 /// The <id> of the OpString instruction, which specifies a file name, for 422 /// use by other debug instructions. 423 uint32_t fileID = 0; 424 425 /// The next available result <id>. 426 uint32_t nextID = 1; 427 428 // The following are for different SPIR-V instruction sections. They follow 429 // the logical layout of a SPIR-V module. 430 431 SmallVector<uint32_t, 4> capabilities; 432 SmallVector<uint32_t, 0> extensions; 433 SmallVector<uint32_t, 0> extendedSets; 434 SmallVector<uint32_t, 3> memoryModel; 435 SmallVector<uint32_t, 0> entryPoints; 436 SmallVector<uint32_t, 4> executionModes; 437 SmallVector<uint32_t, 0> debug; 438 SmallVector<uint32_t, 0> names; 439 SmallVector<uint32_t, 0> decorations; 440 SmallVector<uint32_t, 0> typesGlobalValues; 441 SmallVector<uint32_t, 0> functions; 442 443 /// Recursive struct references are serialized as OpTypePointer instructions 444 /// to the recursive struct type. However, the OpTypePointer instruction 445 /// cannot be emitted before the recursive struct's OpTypeStruct. 446 /// RecursiveStructPointerInfo stores the data needed to emit such 447 /// OpTypePointer instructions after forward references to such types. 448 struct RecursiveStructPointerInfo { 449 uint32_t pointerTypeID; 450 spirv::StorageClass storageClass; 451 }; 452 453 // Maps spirv::StructType to its recursive reference member info. 454 DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>> 455 recursiveStructInfos; 456 457 /// `functionHeader` contains all the instructions that must be in the first 458 /// block in the function, and `functionBody` contains the rest. After 459 /// processing FuncOp, the encoded instructions of a function are appended to 460 /// `functions`. An example of instructions in `functionHeader` in order: 461 /// OpFunction ... 462 /// OpFunctionParameter ... 463 /// OpFunctionParameter ... 464 /// OpLabel ... 465 /// OpVariable ... 466 /// OpVariable ... 467 SmallVector<uint32_t, 0> functionHeader; 468 SmallVector<uint32_t, 0> functionBody; 469 470 /// Map from type used in SPIR-V module to their <id>s. 471 DenseMap<Type, uint32_t> typeIDMap; 472 473 /// Map from constant values to their <id>s. 474 DenseMap<Attribute, uint32_t> constIDMap; 475 476 /// Map from specialization constant names to their <id>s. 477 llvm::StringMap<uint32_t> specConstIDMap; 478 479 /// Map from GlobalVariableOps name to <id>s. 480 llvm::StringMap<uint32_t> globalVarIDMap; 481 482 /// Map from FuncOps name to <id>s. 483 llvm::StringMap<uint32_t> funcIDMap; 484 485 /// Map from blocks to their <id>s. 486 DenseMap<Block *, uint32_t> blockIDMap; 487 488 /// Map from the Type to the <id> that represents undef value of that type. 489 DenseMap<Type, uint32_t> undefValIDMap; 490 491 /// Map from results of normal operations to their <id>s. 492 DenseMap<Value, uint32_t> valueIDMap; 493 494 /// Map from extended instruction set name to <id>s. 495 llvm::StringMap<uint32_t> extendedInstSetIDMap; 496 497 /// Map from values used in OpPhi instructions to their offset in the 498 /// `functions` section. 499 /// 500 /// When processing a block with arguments, we need to emit OpPhi 501 /// instructions to record the predecessor block <id>s and the values they 502 /// send to the block in question. But it's not guaranteed all values are 503 /// visited and thus assigned result <id>s. So we need this list to capture 504 /// the offsets into `functions` where a value is used so that we can fix it 505 /// up later after processing all the blocks in a function. 506 /// 507 /// More concretely, say if we are visiting the following blocks: 508 /// 509 /// ```mlir 510 /// ^phi(%arg0: i32): 511 /// ... 512 /// ^parent1: 513 /// ... 514 /// spv.Branch ^phi(%val0: i32) 515 /// ^parent2: 516 /// ... 517 /// spv.Branch ^phi(%val1: i32) 518 /// ``` 519 /// 520 /// When we are serializing the `^phi` block, we need to emit at the beginning 521 /// of the block OpPhi instructions which has the following parameters: 522 /// 523 /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 524 /// id-for-%val1 id-for-^parent2 525 /// 526 /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit 527 /// all the blocks twice and use the first visit to assign an <id> to each 528 /// value. But it's paying the overheads just for OpPhi emission. Instead, 529 /// we still visit the blocks once for emission. When we emit the OpPhi 530 /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1. 531 /// At the same time, we record their offsets in the emitted binary (which is 532 /// placed inside `functions`) here. And then after emitting all blocks, we 533 /// replace the dummy <id> 0 with the real result <id> by overwriting 534 /// `functions[offset]`. 535 DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues; 536 }; 537 } // namespace 538 539 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) 540 : module(module), mlirBuilder(module.getContext()), 541 emitDebugInfo(emitDebugInfo) {} 542 543 LogicalResult Serializer::serialize() { 544 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 545 546 if (failed(module.verify())) 547 return failure(); 548 549 // TODO: handle the other sections 550 processCapability(); 551 processExtension(); 552 processMemoryModel(); 553 processDebugInfo(); 554 555 // Iterate over the module body to serialize it. Assumptions are that there is 556 // only one basic block in the moduleOp 557 for (auto &op : module.getBlock()) { 558 if (failed(processOperation(&op))) { 559 return failure(); 560 } 561 } 562 563 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 564 return success(); 565 } 566 567 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 568 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 569 extensions.size() + extendedSets.size() + 570 memoryModel.size() + entryPoints.size() + 571 executionModes.size() + decorations.size() + 572 typesGlobalValues.size() + functions.size(); 573 574 binary.clear(); 575 binary.reserve(moduleSize); 576 577 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); 578 binary.append(capabilities.begin(), capabilities.end()); 579 binary.append(extensions.begin(), extensions.end()); 580 binary.append(extendedSets.begin(), extendedSets.end()); 581 binary.append(memoryModel.begin(), memoryModel.end()); 582 binary.append(entryPoints.begin(), entryPoints.end()); 583 binary.append(executionModes.begin(), executionModes.end()); 584 binary.append(debug.begin(), debug.end()); 585 binary.append(names.begin(), names.end()); 586 binary.append(decorations.begin(), decorations.end()); 587 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 588 binary.append(functions.begin(), functions.end()); 589 } 590 591 #ifndef NDEBUG 592 void Serializer::printValueIDMap(raw_ostream &os) { 593 os << "\n= Value <id> Map =\n\n"; 594 for (auto valueIDPair : valueIDMap) { 595 Value val = valueIDPair.first; 596 os << " " << val << " " 597 << "id = " << valueIDPair.second << ' '; 598 if (auto *op = val.getDefiningOp()) { 599 os << "from op '" << op->getName() << "'"; 600 } else if (auto arg = val.dyn_cast<BlockArgument>()) { 601 Block *block = arg.getOwner(); 602 os << "from argument of block " << block << ' '; 603 os << " in op '" << block->getParentOp()->getName() << "'"; 604 } 605 os << '\n'; 606 } 607 } 608 #endif 609 610 //===----------------------------------------------------------------------===// 611 // Module structure 612 //===----------------------------------------------------------------------===// 613 614 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 615 auto funcID = funcIDMap.lookup(fnName); 616 if (!funcID) { 617 funcID = getNextID(); 618 funcIDMap[fnName] = funcID; 619 } 620 return funcID; 621 } 622 623 void Serializer::processCapability() { 624 for (auto cap : module.vce_triple()->getCapabilities()) 625 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 626 {static_cast<uint32_t>(cap)}); 627 } 628 629 void Serializer::processDebugInfo() { 630 if (!emitDebugInfo) 631 return; 632 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>(); 633 auto fileName = fileLoc ? fileLoc.getFilename() : "<unknown>"; 634 fileID = getNextID(); 635 SmallVector<uint32_t, 16> operands; 636 operands.push_back(fileID); 637 spirv::encodeStringLiteralInto(operands, fileName); 638 encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 639 // TODO: Encode more debug instructions. 640 } 641 642 void Serializer::processExtension() { 643 llvm::SmallVector<uint32_t, 16> extName; 644 for (spirv::Extension ext : module.vce_triple()->getExtensions()) { 645 extName.clear(); 646 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); 647 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); 648 } 649 } 650 651 void Serializer::processMemoryModel() { 652 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt(); 653 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt(); 654 655 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); 656 } 657 658 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { 659 if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { 660 valueIDMap[op.getResult()] = resultID; 661 return success(); 662 } 663 return failure(); 664 } 665 666 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { 667 if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), 668 /*isSpec=*/true)) { 669 // Emit the OpDecorate instruction for SpecId. 670 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) { 671 auto val = static_cast<uint32_t>(specID.getInt()); 672 emitDecoration(resultID, spirv::Decoration::SpecId, {val}); 673 } 674 675 specConstIDMap[op.sym_name()] = resultID; 676 return processName(resultID, op.sym_name()); 677 } 678 return failure(); 679 } 680 681 LogicalResult 682 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { 683 uint32_t typeID = 0; 684 if (failed(processType(op.getLoc(), op.type(), typeID))) { 685 return failure(); 686 } 687 688 auto resultID = getNextID(); 689 690 SmallVector<uint32_t, 8> operands; 691 operands.push_back(typeID); 692 operands.push_back(resultID); 693 694 auto constituents = op.constituents(); 695 696 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { 697 auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>(); 698 699 auto constituentName = constituent.getValue(); 700 auto constituentID = getSpecConstID(constituentName); 701 702 if (!constituentID) { 703 return op.emitError("unknown result <id> for specialization constant ") 704 << constituentName; 705 } 706 707 operands.push_back(constituentID); 708 } 709 710 encodeInstructionInto(typesGlobalValues, 711 spirv::Opcode::OpSpecConstantComposite, operands); 712 specConstIDMap[op.sym_name()] = resultID; 713 714 return processName(resultID, op.sym_name()); 715 } 716 717 LogicalResult 718 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { 719 uint32_t typeID = 0; 720 if (failed(processType(op.getLoc(), op.getType(), typeID))) { 721 return failure(); 722 } 723 724 auto resultID = getNextID(); 725 726 SmallVector<uint32_t, 8> operands; 727 operands.push_back(typeID); 728 operands.push_back(resultID); 729 730 Block &block = op.getRegion().getBlocks().front(); 731 Operation &enclosedOp = block.getOperations().front(); 732 733 std::string enclosedOpName; 734 llvm::raw_string_ostream rss(enclosedOpName); 735 rss << "Op" << enclosedOp.getName().stripDialect(); 736 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); 737 738 if (!enclosedOpcode) { 739 op.emitError("Couldn't find op code for op ") 740 << enclosedOp.getName().getStringRef(); 741 return failure(); 742 } 743 744 operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue())); 745 746 // Append operands to the enclosed op to the list of operands. 747 for (Value operand : enclosedOp.getOperands()) { 748 uint32_t id = getValueID(operand); 749 assert(id && "use before def!"); 750 operands.push_back(id); 751 } 752 753 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, 754 operands); 755 valueIDMap[op.getResult()] = resultID; 756 757 return success(); 758 } 759 760 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { 761 auto undefType = op.getType(); 762 auto &id = undefValIDMap[undefType]; 763 if (!id) { 764 id = getNextID(); 765 uint32_t typeID = 0; 766 if (failed(processType(op.getLoc(), undefType, typeID)) || 767 failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, 768 {typeID, id}))) { 769 return failure(); 770 } 771 } 772 valueIDMap[op.getResult()] = id; 773 return success(); 774 } 775 776 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 777 NamedAttribute attr) { 778 auto attrName = attr.first.strref(); 779 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); 780 auto decoration = spirv::symbolizeDecoration(decorationName); 781 if (!decoration) { 782 return emitError( 783 loc, "non-argument attributes expected to have snake-case-ified " 784 "decoration name, unhandled attribute with name : ") 785 << attrName; 786 } 787 SmallVector<uint32_t, 1> args; 788 switch (decoration.getValue()) { 789 case spirv::Decoration::Binding: 790 case spirv::Decoration::DescriptorSet: 791 case spirv::Decoration::Location: 792 if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) { 793 args.push_back(intAttr.getValue().getZExtValue()); 794 break; 795 } 796 return emitError(loc, "expected integer attribute for ") << attrName; 797 case spirv::Decoration::BuiltIn: 798 if (auto strAttr = attr.second.dyn_cast<StringAttr>()) { 799 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 800 if (enumVal) { 801 args.push_back(static_cast<uint32_t>(enumVal.getValue())); 802 break; 803 } 804 return emitError(loc, "invalid ") 805 << attrName << " attribute " << strAttr.getValue(); 806 } 807 return emitError(loc, "expected string attribute for ") << attrName; 808 case spirv::Decoration::Aliased: 809 case spirv::Decoration::Flat: 810 case spirv::Decoration::NonReadable: 811 case spirv::Decoration::NonWritable: 812 case spirv::Decoration::NoPerspective: 813 case spirv::Decoration::Restrict: 814 // For unit attributes, the args list has no values so we do nothing 815 if (auto unitAttr = attr.second.dyn_cast<UnitAttr>()) 816 break; 817 return emitError(loc, "expected unit attribute for ") << attrName; 818 default: 819 return emitError(loc, "unhandled decoration ") << decorationName; 820 } 821 return emitDecoration(resultID, decoration.getValue(), args); 822 } 823 824 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 825 assert(!name.empty() && "unexpected empty string for OpName"); 826 827 SmallVector<uint32_t, 4> nameOperands; 828 nameOperands.push_back(resultID); 829 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { 830 return failure(); 831 } 832 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 833 } 834 835 namespace { 836 template <> 837 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 838 Location loc, spirv::ArrayType type, uint32_t resultID) { 839 if (unsigned stride = type.getArrayStride()) { 840 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 841 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 842 } 843 return success(); 844 } 845 846 template <> 847 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 848 Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) { 849 if (unsigned stride = type.getArrayStride()) { 850 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 851 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 852 } 853 return success(); 854 } 855 856 LogicalResult Serializer::processMemberDecoration( 857 uint32_t structID, 858 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 859 SmallVector<uint32_t, 4> args( 860 {structID, memberDecoration.memberIndex, 861 static_cast<uint32_t>(memberDecoration.decoration)}); 862 if (memberDecoration.hasValue) { 863 args.push_back(memberDecoration.decorationValue); 864 } 865 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, 866 args); 867 } 868 } // namespace 869 870 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { 871 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); 872 assert(functionHeader.empty() && functionBody.empty()); 873 874 uint32_t fnTypeID = 0; 875 // Generate type of the function. 876 processType(op.getLoc(), op.getType(), fnTypeID); 877 878 // Add the function definition. 879 SmallVector<uint32_t, 4> operands; 880 uint32_t resTypeID = 0; 881 auto resultTypes = op.getType().getResults(); 882 if (resultTypes.size() > 1) { 883 return op.emitError("cannot serialize function with multiple return types"); 884 } 885 if (failed(processType(op.getLoc(), 886 (resultTypes.empty() ? getVoidType() : resultTypes[0]), 887 resTypeID))) { 888 return failure(); 889 } 890 operands.push_back(resTypeID); 891 auto funcID = getOrCreateFunctionID(op.getName()); 892 operands.push_back(funcID); 893 operands.push_back(static_cast<uint32_t>(op.function_control())); 894 operands.push_back(fnTypeID); 895 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); 896 897 // Add function name. 898 if (failed(processName(funcID, op.getName()))) { 899 return failure(); 900 } 901 902 // Declare the parameters. 903 for (auto arg : op.getArguments()) { 904 uint32_t argTypeID = 0; 905 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { 906 return failure(); 907 } 908 auto argValueID = getNextID(); 909 valueIDMap[arg] = argValueID; 910 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, 911 {argTypeID, argValueID}); 912 } 913 914 // Process the body. 915 if (op.isExternal()) { 916 return op.emitError("external function is unhandled"); 917 } 918 919 // Some instructions (e.g., OpVariable) in a function must be in the first 920 // block in the function. These instructions will be put in functionHeader. 921 // Thus, we put the label in functionHeader first, and omit it from the first 922 // block. 923 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, 924 {getOrCreateBlockID(&op.front())}); 925 processBlock(&op.front(), /*omitLabel=*/true); 926 if (failed(visitInPrettyBlockOrder( 927 &op.front(), [&](Block *block) { return processBlock(block); }, 928 /*skipHeader=*/true))) { 929 return failure(); 930 } 931 932 // There might be OpPhi instructions who have value references needing to fix. 933 for (auto deferredValue : deferredPhiValues) { 934 Value value = deferredValue.first; 935 uint32_t id = getValueID(value); 936 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value 937 << " to id = " << id << '\n'); 938 assert(id && "OpPhi references undefined value!"); 939 for (size_t offset : deferredValue.second) 940 functionBody[offset] = id; 941 } 942 deferredPhiValues.clear(); 943 944 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() 945 << "' --\n"); 946 // Insert OpFunctionEnd. 947 if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, 948 {}))) { 949 return failure(); 950 } 951 952 functions.append(functionHeader.begin(), functionHeader.end()); 953 functions.append(functionBody.begin(), functionBody.end()); 954 functionHeader.clear(); 955 functionBody.clear(); 956 957 return success(); 958 } 959 960 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { 961 SmallVector<uint32_t, 4> operands; 962 SmallVector<StringRef, 2> elidedAttrs; 963 uint32_t resultID = 0; 964 uint32_t resultTypeID = 0; 965 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { 966 return failure(); 967 } 968 operands.push_back(resultTypeID); 969 resultID = getNextID(); 970 valueIDMap[op.getResult()] = resultID; 971 operands.push_back(resultID); 972 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); 973 if (attr) { 974 operands.push_back(static_cast<uint32_t>( 975 attr.cast<IntegerAttr>().getValue().getZExtValue())); 976 } 977 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 978 for (auto arg : op.getODSOperands(0)) { 979 auto argID = getValueID(arg); 980 if (!argID) { 981 return emitError(op.getLoc(), "operand 0 has a use before def"); 982 } 983 operands.push_back(argID); 984 } 985 emitDebugLine(functionHeader, op.getLoc()); 986 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); 987 for (auto attr : op->getAttrs()) { 988 if (llvm::any_of(elidedAttrs, 989 [&](StringRef elided) { return attr.first == elided; })) { 990 continue; 991 } 992 if (failed(processDecoration(op.getLoc(), resultID, attr))) { 993 return failure(); 994 } 995 } 996 return success(); 997 } 998 999 LogicalResult 1000 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { 1001 // Get TypeID. 1002 uint32_t resultTypeID = 0; 1003 SmallVector<StringRef, 4> elidedAttrs; 1004 if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { 1005 return failure(); 1006 } 1007 1008 if (isInterfaceStructPtrType(varOp.type())) { 1009 auto structType = varOp.type() 1010 .cast<spirv::PointerType>() 1011 .getPointeeType() 1012 .cast<spirv::StructType>(); 1013 if (failed( 1014 emitDecoration(getTypeID(structType), spirv::Decoration::Block))) { 1015 return varOp.emitError("cannot decorate ") 1016 << structType << " with Block decoration"; 1017 } 1018 } 1019 1020 elidedAttrs.push_back("type"); 1021 SmallVector<uint32_t, 4> operands; 1022 operands.push_back(resultTypeID); 1023 auto resultID = getNextID(); 1024 1025 // Encode the name. 1026 auto varName = varOp.sym_name(); 1027 elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); 1028 if (failed(processName(resultID, varName))) { 1029 return failure(); 1030 } 1031 globalVarIDMap[varName] = resultID; 1032 operands.push_back(resultID); 1033 1034 // Encode StorageClass. 1035 operands.push_back(static_cast<uint32_t>(varOp.storageClass())); 1036 1037 // Encode initialization. 1038 if (auto initializer = varOp.initializer()) { 1039 auto initializerID = getVariableID(initializer.getValue()); 1040 if (!initializerID) { 1041 return emitError(varOp.getLoc(), 1042 "invalid usage of undefined variable as initializer"); 1043 } 1044 operands.push_back(initializerID); 1045 elidedAttrs.push_back("initializer"); 1046 } 1047 1048 emitDebugLine(typesGlobalValues, varOp.getLoc()); 1049 if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, 1050 operands))) { 1051 elidedAttrs.push_back("initializer"); 1052 return failure(); 1053 } 1054 1055 // Encode decorations. 1056 for (auto attr : varOp->getAttrs()) { 1057 if (llvm::any_of(elidedAttrs, 1058 [&](StringRef elided) { return attr.first == elided; })) { 1059 continue; 1060 } 1061 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { 1062 return failure(); 1063 } 1064 } 1065 return success(); 1066 } 1067 1068 //===----------------------------------------------------------------------===// 1069 // Type 1070 //===----------------------------------------------------------------------===// 1071 1072 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 1073 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 1074 // PushConstant Storage Classes must be explicitly laid out." 1075 bool Serializer::isInterfaceStructPtrType(Type type) const { 1076 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 1077 switch (ptrType.getStorageClass()) { 1078 case spirv::StorageClass::PhysicalStorageBuffer: 1079 case spirv::StorageClass::PushConstant: 1080 case spirv::StorageClass::StorageBuffer: 1081 case spirv::StorageClass::Uniform: 1082 return ptrType.getPointeeType().isa<spirv::StructType>(); 1083 default: 1084 break; 1085 } 1086 } 1087 return false; 1088 } 1089 1090 LogicalResult Serializer::processType(Location loc, Type type, 1091 uint32_t &typeID) { 1092 // Maintains a set of names for nested identified struct types. This is used 1093 // to properly serialize recursive references. 1094 llvm::SetVector<StringRef> serializationCtx; 1095 return processTypeImpl(loc, type, typeID, serializationCtx); 1096 } 1097 1098 LogicalResult 1099 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 1100 llvm::SetVector<StringRef> &serializationCtx) { 1101 typeID = getTypeID(type); 1102 if (typeID) { 1103 return success(); 1104 } 1105 typeID = getNextID(); 1106 SmallVector<uint32_t, 4> operands; 1107 1108 operands.push_back(typeID); 1109 auto typeEnum = spirv::Opcode::OpTypeVoid; 1110 bool deferSerialization = false; 1111 1112 if ((type.isa<FunctionType>() && 1113 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, 1114 operands))) || 1115 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 1116 deferSerialization, serializationCtx))) { 1117 if (deferSerialization) 1118 return success(); 1119 1120 typeIDMap[type] = typeID; 1121 1122 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) 1123 return failure(); 1124 1125 if (recursiveStructInfos.count(type) != 0) { 1126 // This recursive struct type is emitted already, now the OpTypePointer 1127 // instructions referring to recursive references are emitted as well. 1128 for (auto &ptrInfo : recursiveStructInfos[type]) { 1129 // TODO: This might not work if more than 1 recursive reference is 1130 // present in the struct. 1131 SmallVector<uint32_t, 4> ptrOperands; 1132 ptrOperands.push_back(ptrInfo.pointerTypeID); 1133 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 1134 ptrOperands.push_back(typeIDMap[type]); 1135 1136 if (failed(encodeInstructionInto( 1137 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) 1138 return failure(); 1139 } 1140 1141 recursiveStructInfos[type].clear(); 1142 } 1143 1144 return success(); 1145 } 1146 1147 return failure(); 1148 } 1149 1150 LogicalResult Serializer::prepareBasicType( 1151 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 1152 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 1153 llvm::SetVector<StringRef> &serializationCtx) { 1154 deferSerialization = false; 1155 1156 if (isVoidType(type)) { 1157 typeEnum = spirv::Opcode::OpTypeVoid; 1158 return success(); 1159 } 1160 1161 if (auto intType = type.dyn_cast<IntegerType>()) { 1162 if (intType.getWidth() == 1) { 1163 typeEnum = spirv::Opcode::OpTypeBool; 1164 return success(); 1165 } 1166 1167 typeEnum = spirv::Opcode::OpTypeInt; 1168 operands.push_back(intType.getWidth()); 1169 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 1170 // to preserve or validate. 1171 // 0 indicates unsigned, or no signedness semantics 1172 // 1 indicates signed semantics." 1173 operands.push_back(intType.isSigned() ? 1 : 0); 1174 return success(); 1175 } 1176 1177 if (auto floatType = type.dyn_cast<FloatType>()) { 1178 typeEnum = spirv::Opcode::OpTypeFloat; 1179 operands.push_back(floatType.getWidth()); 1180 return success(); 1181 } 1182 1183 if (auto vectorType = type.dyn_cast<VectorType>()) { 1184 uint32_t elementTypeID = 0; 1185 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 1186 serializationCtx))) { 1187 return failure(); 1188 } 1189 typeEnum = spirv::Opcode::OpTypeVector; 1190 operands.push_back(elementTypeID); 1191 operands.push_back(vectorType.getNumElements()); 1192 return success(); 1193 } 1194 1195 if (auto imageType = type.dyn_cast<spirv::ImageType>()) { 1196 typeEnum = spirv::Opcode::OpTypeImage; 1197 uint32_t sampledTypeID = 0; 1198 if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) 1199 return failure(); 1200 1201 operands.push_back(sampledTypeID); 1202 operands.push_back(static_cast<uint32_t>(imageType.getDim())); 1203 operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo())); 1204 operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo())); 1205 operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo())); 1206 operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo())); 1207 operands.push_back(static_cast<uint32_t>(imageType.getImageFormat())); 1208 return success(); 1209 } 1210 1211 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { 1212 typeEnum = spirv::Opcode::OpTypeArray; 1213 uint32_t elementTypeID = 0; 1214 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 1215 serializationCtx))) { 1216 return failure(); 1217 } 1218 operands.push_back(elementTypeID); 1219 if (auto elementCountID = prepareConstantInt( 1220 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 1221 operands.push_back(elementCountID); 1222 } 1223 return processTypeDecoration(loc, arrayType, resultID); 1224 } 1225 1226 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 1227 uint32_t pointeeTypeID = 0; 1228 spirv::StructType pointeeStruct = 1229 ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 1230 1231 if (pointeeStruct && pointeeStruct.isIdentified() && 1232 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 1233 // A recursive reference to an enclosing struct is found. 1234 // 1235 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 1236 // class as operands. 1237 SmallVector<uint32_t, 2> forwardPtrOperands; 1238 forwardPtrOperands.push_back(resultID); 1239 forwardPtrOperands.push_back( 1240 static_cast<uint32_t>(ptrType.getStorageClass())); 1241 1242 encodeInstructionInto(typesGlobalValues, 1243 spirv::Opcode::OpTypeForwardPointer, 1244 forwardPtrOperands); 1245 1246 // 2. Find the pointee (enclosing) struct. 1247 auto structType = spirv::StructType::getIdentified( 1248 module.getContext(), pointeeStruct.getIdentifier()); 1249 1250 if (!structType) 1251 return failure(); 1252 1253 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 1254 // as deferred. 1255 deferSerialization = true; 1256 1257 // 4. Record the info needed to emit the deferred OpTypePointer 1258 // instruction when the enclosing struct is completely serialized. 1259 recursiveStructInfos[structType].push_back( 1260 {resultID, ptrType.getStorageClass()}); 1261 } else { 1262 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 1263 serializationCtx))) 1264 return failure(); 1265 } 1266 1267 typeEnum = spirv::Opcode::OpTypePointer; 1268 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 1269 operands.push_back(pointeeTypeID); 1270 return success(); 1271 } 1272 1273 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 1274 uint32_t elementTypeID = 0; 1275 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 1276 elementTypeID, serializationCtx))) { 1277 return failure(); 1278 } 1279 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 1280 operands.push_back(elementTypeID); 1281 return processTypeDecoration(loc, runtimeArrayType, resultID); 1282 } 1283 1284 if (auto structType = type.dyn_cast<spirv::StructType>()) { 1285 if (structType.isIdentified()) { 1286 processName(resultID, structType.getIdentifier()); 1287 serializationCtx.insert(structType.getIdentifier()); 1288 } 1289 1290 bool hasOffset = structType.hasOffset(); 1291 for (auto elementIndex : 1292 llvm::seq<uint32_t>(0, structType.getNumElements())) { 1293 uint32_t elementTypeID = 0; 1294 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 1295 elementTypeID, serializationCtx))) { 1296 return failure(); 1297 } 1298 operands.push_back(elementTypeID); 1299 if (hasOffset) { 1300 // Decorate each struct member with an offset 1301 spirv::StructType::MemberDecorationInfo offsetDecoration{ 1302 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 1303 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 1304 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 1305 return emitError(loc, "cannot decorate ") 1306 << elementIndex << "-th member of " << structType 1307 << " with its offset"; 1308 } 1309 } 1310 } 1311 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 1312 structType.getMemberDecorations(memberDecorations); 1313 1314 for (auto &memberDecoration : memberDecorations) { 1315 if (failed(processMemberDecoration(resultID, memberDecoration))) { 1316 return emitError(loc, "cannot decorate ") 1317 << static_cast<uint32_t>(memberDecoration.memberIndex) 1318 << "-th member of " << structType << " with " 1319 << stringifyDecoration(memberDecoration.decoration); 1320 } 1321 } 1322 1323 typeEnum = spirv::Opcode::OpTypeStruct; 1324 1325 if (structType.isIdentified()) 1326 serializationCtx.remove(structType.getIdentifier()); 1327 1328 return success(); 1329 } 1330 1331 if (auto cooperativeMatrixType = 1332 type.dyn_cast<spirv::CooperativeMatrixNVType>()) { 1333 uint32_t elementTypeID = 0; 1334 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 1335 elementTypeID, serializationCtx))) { 1336 return failure(); 1337 } 1338 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; 1339 auto getConstantOp = [&](uint32_t id) { 1340 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 1341 return prepareConstantInt(loc, attr); 1342 }; 1343 operands.push_back(elementTypeID); 1344 operands.push_back( 1345 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()))); 1346 operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); 1347 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); 1348 return success(); 1349 } 1350 1351 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) { 1352 uint32_t elementTypeID = 0; 1353 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 1354 serializationCtx))) { 1355 return failure(); 1356 } 1357 typeEnum = spirv::Opcode::OpTypeMatrix; 1358 operands.push_back(elementTypeID); 1359 operands.push_back(matrixType.getNumColumns()); 1360 return success(); 1361 } 1362 1363 // TODO: Handle other types. 1364 return emitError(loc, "unhandled type in serialization: ") << type; 1365 } 1366 1367 LogicalResult 1368 Serializer::prepareFunctionType(Location loc, FunctionType type, 1369 spirv::Opcode &typeEnum, 1370 SmallVectorImpl<uint32_t> &operands) { 1371 typeEnum = spirv::Opcode::OpTypeFunction; 1372 assert(type.getNumResults() <= 1 && 1373 "serialization supports only a single return value"); 1374 uint32_t resultID = 0; 1375 if (failed(processType( 1376 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 1377 resultID))) { 1378 return failure(); 1379 } 1380 operands.push_back(resultID); 1381 for (auto &res : type.getInputs()) { 1382 uint32_t argTypeID = 0; 1383 if (failed(processType(loc, res, argTypeID))) { 1384 return failure(); 1385 } 1386 operands.push_back(argTypeID); 1387 } 1388 return success(); 1389 } 1390 1391 //===----------------------------------------------------------------------===// 1392 // Constant 1393 //===----------------------------------------------------------------------===// 1394 1395 uint32_t Serializer::prepareConstant(Location loc, Type constType, 1396 Attribute valueAttr) { 1397 if (auto id = prepareConstantScalar(loc, valueAttr)) { 1398 return id; 1399 } 1400 1401 // This is a composite literal. We need to handle each component separately 1402 // and then emit an OpConstantComposite for the whole. 1403 1404 if (auto id = getConstantID(valueAttr)) { 1405 return id; 1406 } 1407 1408 uint32_t typeID = 0; 1409 if (failed(processType(loc, constType, typeID))) { 1410 return 0; 1411 } 1412 1413 uint32_t resultID = 0; 1414 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { 1415 int rank = attr.getType().dyn_cast<ShapedType>().getRank(); 1416 SmallVector<uint64_t, 4> index(rank); 1417 resultID = prepareDenseElementsConstant(loc, constType, attr, 1418 /*dim=*/0, index); 1419 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { 1420 resultID = prepareArrayConstant(loc, constType, arrayAttr); 1421 } 1422 1423 if (resultID == 0) { 1424 emitError(loc, "cannot serialize attribute: ") << valueAttr; 1425 return 0; 1426 } 1427 1428 constIDMap[valueAttr] = resultID; 1429 return resultID; 1430 } 1431 1432 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 1433 ArrayAttr attr) { 1434 uint32_t typeID = 0; 1435 if (failed(processType(loc, constType, typeID))) { 1436 return 0; 1437 } 1438 1439 uint32_t resultID = getNextID(); 1440 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 1441 operands.reserve(attr.size() + 2); 1442 auto elementType = constType.cast<spirv::ArrayType>().getElementType(); 1443 for (Attribute elementAttr : attr) { 1444 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 1445 operands.push_back(elementID); 1446 } else { 1447 return 0; 1448 } 1449 } 1450 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 1451 encodeInstructionInto(typesGlobalValues, opcode, operands); 1452 1453 return resultID; 1454 } 1455 1456 // TODO: Turn the below function into iterative function, instead of 1457 // recursive function. 1458 uint32_t 1459 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 1460 DenseElementsAttr valueAttr, int dim, 1461 MutableArrayRef<uint64_t> index) { 1462 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); 1463 assert(dim <= shapedType.getRank()); 1464 if (shapedType.getRank() == dim) { 1465 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { 1466 return attr.getType().getElementType().isInteger(1) 1467 ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index)) 1468 : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index)); 1469 } 1470 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { 1471 return prepareConstantFp(loc, attr.getValue<FloatAttr>(index)); 1472 } 1473 return 0; 1474 } 1475 1476 uint32_t typeID = 0; 1477 if (failed(processType(loc, constType, typeID))) { 1478 return 0; 1479 } 1480 1481 uint32_t resultID = getNextID(); 1482 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 1483 operands.reserve(shapedType.getDimSize(dim) + 2); 1484 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); 1485 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 1486 index[dim] = i; 1487 if (auto elementID = prepareDenseElementsConstant( 1488 loc, elementType, valueAttr, dim + 1, index)) { 1489 operands.push_back(elementID); 1490 } else { 1491 return 0; 1492 } 1493 } 1494 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 1495 encodeInstructionInto(typesGlobalValues, opcode, operands); 1496 1497 return resultID; 1498 } 1499 1500 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 1501 bool isSpec) { 1502 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { 1503 return prepareConstantFp(loc, floatAttr, isSpec); 1504 } 1505 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { 1506 return prepareConstantBool(loc, boolAttr, isSpec); 1507 } 1508 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { 1509 return prepareConstantInt(loc, intAttr, isSpec); 1510 } 1511 1512 return 0; 1513 } 1514 1515 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 1516 bool isSpec) { 1517 if (!isSpec) { 1518 // We can de-duplicate normal constants, but not specialization constants. 1519 if (auto id = getConstantID(boolAttr)) { 1520 return id; 1521 } 1522 } 1523 1524 // Process the type for this bool literal 1525 uint32_t typeID = 0; 1526 if (failed(processType(loc, boolAttr.getType(), typeID))) { 1527 return 0; 1528 } 1529 1530 auto resultID = getNextID(); 1531 auto opcode = boolAttr.getValue() 1532 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 1533 : spirv::Opcode::OpConstantTrue) 1534 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 1535 : spirv::Opcode::OpConstantFalse); 1536 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 1537 1538 if (!isSpec) { 1539 constIDMap[boolAttr] = resultID; 1540 } 1541 return resultID; 1542 } 1543 1544 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 1545 bool isSpec) { 1546 if (!isSpec) { 1547 // We can de-duplicate normal constants, but not specialization constants. 1548 if (auto id = getConstantID(intAttr)) { 1549 return id; 1550 } 1551 } 1552 1553 // Process the type for this integer literal 1554 uint32_t typeID = 0; 1555 if (failed(processType(loc, intAttr.getType(), typeID))) { 1556 return 0; 1557 } 1558 1559 auto resultID = getNextID(); 1560 APInt value = intAttr.getValue(); 1561 unsigned bitwidth = value.getBitWidth(); 1562 bool isSigned = value.isSignedIntN(bitwidth); 1563 1564 auto opcode = 1565 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 1566 1567 // According to SPIR-V spec, "When the type's bit width is less than 32-bits, 1568 // the literal's value appears in the low-order bits of the word, and the 1569 // high-order bits must be 0 for a floating-point type, or 0 for an integer 1570 // type with Signedness of 0, or sign extended when Signedness is 1." 1571 if (bitwidth == 32 || bitwidth == 16) { 1572 uint32_t word = 0; 1573 if (isSigned) { 1574 word = static_cast<int32_t>(value.getSExtValue()); 1575 } else { 1576 word = static_cast<uint32_t>(value.getZExtValue()); 1577 } 1578 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 1579 } 1580 // According to SPIR-V spec: "When the type's bit width is larger than one 1581 // word, the literal’s low-order words appear first." 1582 else if (bitwidth == 64) { 1583 struct DoubleWord { 1584 uint32_t word1; 1585 uint32_t word2; 1586 } words; 1587 if (isSigned) { 1588 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 1589 } else { 1590 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 1591 } 1592 encodeInstructionInto(typesGlobalValues, opcode, 1593 {typeID, resultID, words.word1, words.word2}); 1594 } else { 1595 std::string valueStr; 1596 llvm::raw_string_ostream rss(valueStr); 1597 value.print(rss, /*isSigned=*/false); 1598 1599 emitError(loc, "cannot serialize ") 1600 << bitwidth << "-bit integer literal: " << rss.str(); 1601 return 0; 1602 } 1603 1604 if (!isSpec) { 1605 constIDMap[intAttr] = resultID; 1606 } 1607 return resultID; 1608 } 1609 1610 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 1611 bool isSpec) { 1612 if (!isSpec) { 1613 // We can de-duplicate normal constants, but not specialization constants. 1614 if (auto id = getConstantID(floatAttr)) { 1615 return id; 1616 } 1617 } 1618 1619 // Process the type for this float literal 1620 uint32_t typeID = 0; 1621 if (failed(processType(loc, floatAttr.getType(), typeID))) { 1622 return 0; 1623 } 1624 1625 auto resultID = getNextID(); 1626 APFloat value = floatAttr.getValue(); 1627 APInt intValue = value.bitcastToAPInt(); 1628 1629 auto opcode = 1630 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 1631 1632 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 1633 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 1634 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 1635 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 1636 struct DoubleWord { 1637 uint32_t word1; 1638 uint32_t word2; 1639 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 1640 encodeInstructionInto(typesGlobalValues, opcode, 1641 {typeID, resultID, words.word1, words.word2}); 1642 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 1643 uint32_t word = 1644 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 1645 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); 1646 } else { 1647 std::string valueStr; 1648 llvm::raw_string_ostream rss(valueStr); 1649 value.print(rss); 1650 1651 emitError(loc, "cannot serialize ") 1652 << floatAttr.getType() << "-typed float literal: " << rss.str(); 1653 return 0; 1654 } 1655 1656 if (!isSpec) { 1657 constIDMap[floatAttr] = resultID; 1658 } 1659 return resultID; 1660 } 1661 1662 //===----------------------------------------------------------------------===// 1663 // Control flow 1664 //===----------------------------------------------------------------------===// 1665 1666 uint32_t Serializer::getOrCreateBlockID(Block *block) { 1667 if (uint32_t id = getBlockID(block)) 1668 return id; 1669 return blockIDMap[block] = getNextID(); 1670 } 1671 1672 LogicalResult 1673 Serializer::processBlock(Block *block, bool omitLabel, 1674 function_ref<void()> actionBeforeTerminator) { 1675 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 1676 LLVM_DEBUG(block->print(llvm::dbgs())); 1677 LLVM_DEBUG(llvm::dbgs() << '\n'); 1678 if (!omitLabel) { 1679 uint32_t blockID = getOrCreateBlockID(block); 1680 LLVM_DEBUG(llvm::dbgs() 1681 << "[block] " << block << " (id = " << blockID << ")\n"); 1682 1683 // Emit OpLabel for this block. 1684 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); 1685 } 1686 1687 // Emit OpPhi instructions for block arguments, if any. 1688 if (failed(emitPhiForBlockArguments(block))) 1689 return failure(); 1690 1691 // Process each op in this block except the terminator. 1692 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { 1693 if (failed(processOperation(&op))) 1694 return failure(); 1695 } 1696 1697 // Process the terminator. 1698 if (actionBeforeTerminator) 1699 actionBeforeTerminator(); 1700 if (failed(processOperation(&block->back()))) 1701 return failure(); 1702 1703 return success(); 1704 } 1705 1706 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 1707 // Nothing to do if this block has no arguments or it's the entry block, which 1708 // always has the same arguments as the function signature. 1709 if (block->args_empty() || block->isEntryBlock()) 1710 return success(); 1711 1712 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 1713 // A SPIR-V OpPhi instruction is of the syntax: 1714 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 1715 // So we need to collect all predecessor blocks and the arguments they send 1716 // to this block. 1717 SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors; 1718 for (Block *predecessor : block->getPredecessors()) { 1719 auto *terminator = predecessor->getTerminator(); 1720 // The predecessor here is the immediate one according to MLIR's IR 1721 // structure. It does not directly map to the incoming parent block for the 1722 // OpPhi instructions at SPIR-V binary level. This is because structured 1723 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 1724 // spv.selection/spv.loop op in the MLIR predecessor block, the branch op 1725 // jumping to the OpPhi's block then resides in the previous structured 1726 // control flow op's merge block. 1727 predecessor = getPhiIncomingBlock(predecessor); 1728 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 1729 predecessors.emplace_back(predecessor, branchOp.operand_begin()); 1730 } else { 1731 return terminator->emitError("unimplemented terminator for Phi creation"); 1732 } 1733 } 1734 1735 // Then create OpPhi instruction for each of the block argument. 1736 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 1737 BlockArgument arg = block->getArgument(argIndex); 1738 1739 // Get the type <id> and result <id> for this OpPhi instruction. 1740 uint32_t phiTypeID = 0; 1741 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 1742 return failure(); 1743 uint32_t phiID = getNextID(); 1744 1745 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 1746 << arg << " (id = " << phiID << ")\n"); 1747 1748 // Prepare the (value <id>, parent block <id>) pairs. 1749 SmallVector<uint32_t, 8> phiArgs; 1750 phiArgs.push_back(phiTypeID); 1751 phiArgs.push_back(phiID); 1752 1753 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 1754 Value value = *(predecessors[predIndex].second + argIndex); 1755 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1756 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1757 << ") value " << value << ' '); 1758 // Each pair is a value <id> ... 1759 uint32_t valueId = getValueID(value); 1760 if (valueId == 0) { 1761 // The op generating this value hasn't been visited yet so we don't have 1762 // an <id> assigned yet. Record this to fix up later. 1763 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1764 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1765 phiArgs.size()); 1766 } else { 1767 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1768 } 1769 phiArgs.push_back(valueId); 1770 // ... and a parent block <id>. 1771 phiArgs.push_back(predBlockId); 1772 } 1773 1774 encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1775 valueIDMap[arg] = phiID; 1776 } 1777 1778 return success(); 1779 } 1780 1781 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { 1782 // Assign <id>s to all blocks so that branches inside the SelectionOp can 1783 // resolve properly. 1784 auto &body = selectionOp.body(); 1785 for (Block &block : body) 1786 getOrCreateBlockID(&block); 1787 1788 auto *headerBlock = selectionOp.getHeaderBlock(); 1789 auto *mergeBlock = selectionOp.getMergeBlock(); 1790 auto mergeID = getBlockID(mergeBlock); 1791 auto loc = selectionOp.getLoc(); 1792 1793 // Emit the selection header block, which dominates all other blocks, first. 1794 // We need to emit an OpSelectionMerge instruction before the selection header 1795 // block's terminator. 1796 auto emitSelectionMerge = [&]() { 1797 emitDebugLine(functionBody, loc); 1798 lastProcessedWasMergeInst = true; 1799 encodeInstructionInto( 1800 functionBody, spirv::Opcode::OpSelectionMerge, 1801 {mergeID, static_cast<uint32_t>(selectionOp.selection_control())}); 1802 }; 1803 // For structured selection, we cannot have blocks in the selection construct 1804 // branching to the selection header block. Entering the selection (and 1805 // reaching the selection header) must be from the block containing the 1806 // spv.selection op. If there are ops ahead of the spv.selection op in the 1807 // block, we can "merge" them into the selection header. So here we don't need 1808 // to emit a separate block; just continue with the existing block. 1809 if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge))) 1810 return failure(); 1811 1812 // Process all blocks with a depth-first visitor starting from the header 1813 // block. The selection header block and merge block are skipped by this 1814 // visitor. 1815 if (failed(visitInPrettyBlockOrder( 1816 headerBlock, [&](Block *block) { return processBlock(block); }, 1817 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) 1818 return failure(); 1819 1820 // There is nothing to do for the merge block in the selection, which just 1821 // contains a spv.mlir.merge op, itself. But we need to have an OpLabel 1822 // instruction to start a new SPIR-V block for ops following this SelectionOp. 1823 // The block should use the <id> for the merge block. 1824 return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 1825 } 1826 1827 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { 1828 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve 1829 // properly. We don't need to assign for the entry block, which is just for 1830 // satisfying MLIR region's structural requirement. 1831 auto &body = loopOp.body(); 1832 for (Block &block : 1833 llvm::make_range(std::next(body.begin(), 1), body.end())) { 1834 getOrCreateBlockID(&block); 1835 } 1836 auto *headerBlock = loopOp.getHeaderBlock(); 1837 auto *continueBlock = loopOp.getContinueBlock(); 1838 auto *mergeBlock = loopOp.getMergeBlock(); 1839 auto headerID = getBlockID(headerBlock); 1840 auto continueID = getBlockID(continueBlock); 1841 auto mergeID = getBlockID(mergeBlock); 1842 auto loc = loopOp.getLoc(); 1843 1844 // This LoopOp is in some MLIR block with preceding and following ops. In the 1845 // binary format, it should reside in separate SPIR-V blocks from its 1846 // preceding and following ops. So we need to emit unconditional branches to 1847 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow 1848 // afterwards. 1849 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); 1850 1851 // LoopOp's entry block is just there for satisfying MLIR's structural 1852 // requirements so we omit it and start serialization from the loop header 1853 // block. 1854 1855 // Emit the loop header block, which dominates all other blocks, first. We 1856 // need to emit an OpLoopMerge instruction before the loop header block's 1857 // terminator. 1858 auto emitLoopMerge = [&]() { 1859 emitDebugLine(functionBody, loc); 1860 lastProcessedWasMergeInst = true; 1861 encodeInstructionInto( 1862 functionBody, spirv::Opcode::OpLoopMerge, 1863 {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())}); 1864 }; 1865 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) 1866 return failure(); 1867 1868 // Process all blocks with a depth-first visitor starting from the header 1869 // block. The loop header block, loop continue block, and loop merge block are 1870 // skipped by this visitor and handled later in this function. 1871 if (failed(visitInPrettyBlockOrder( 1872 headerBlock, [&](Block *block) { return processBlock(block); }, 1873 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) 1874 return failure(); 1875 1876 // We have handled all other blocks. Now get to the loop continue block. 1877 if (failed(processBlock(continueBlock))) 1878 return failure(); 1879 1880 // There is nothing to do for the merge block in the loop, which just contains 1881 // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to 1882 // start a new SPIR-V block for ops following this LoopOp. The block should 1883 // use the <id> for the merge block. 1884 return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 1885 } 1886 1887 LogicalResult Serializer::processBranchConditionalOp( 1888 spirv::BranchConditionalOp condBranchOp) { 1889 auto conditionID = getValueID(condBranchOp.condition()); 1890 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); 1891 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); 1892 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; 1893 1894 if (auto weights = condBranchOp.branch_weights()) { 1895 for (auto val : weights->getValue()) 1896 arguments.push_back(val.cast<IntegerAttr>().getInt()); 1897 } 1898 1899 emitDebugLine(functionBody, condBranchOp.getLoc()); 1900 return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, 1901 arguments); 1902 } 1903 1904 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { 1905 emitDebugLine(functionBody, branchOp.getLoc()); 1906 return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, 1907 {getOrCreateBlockID(branchOp.getTarget())}); 1908 } 1909 1910 //===----------------------------------------------------------------------===// 1911 // Operation 1912 //===----------------------------------------------------------------------===// 1913 1914 LogicalResult Serializer::encodeExtensionInstruction( 1915 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1916 ArrayRef<uint32_t> operands) { 1917 // Check if the extension has been imported. 1918 auto &setID = extendedInstSetIDMap[extensionSetName]; 1919 if (!setID) { 1920 setID = getNextID(); 1921 SmallVector<uint32_t, 16> importOperands; 1922 importOperands.push_back(setID); 1923 if (failed( 1924 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || 1925 failed(encodeInstructionInto( 1926 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { 1927 return failure(); 1928 } 1929 } 1930 1931 // The first two operands are the result type <id> and result <id>. The set 1932 // <id> and the opcode need to be insert after this. 1933 if (operands.size() < 2) { 1934 return op->emitError("extended instructions must have a result encoding"); 1935 } 1936 SmallVector<uint32_t, 8> extInstOperands; 1937 extInstOperands.reserve(operands.size() + 2); 1938 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1939 extInstOperands.push_back(setID); 1940 extInstOperands.push_back(extensionOpcode); 1941 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1942 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1943 extInstOperands); 1944 } 1945 1946 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { 1947 auto varName = addressOfOp.variable(); 1948 auto variableID = getVariableID(varName); 1949 if (!variableID) { 1950 return addressOfOp.emitError("unknown result <id> for variable ") 1951 << varName; 1952 } 1953 valueIDMap[addressOfOp.pointer()] = variableID; 1954 return success(); 1955 } 1956 1957 LogicalResult 1958 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { 1959 auto constName = referenceOfOp.spec_const(); 1960 auto constID = getSpecConstID(constName); 1961 if (!constID) { 1962 return referenceOfOp.emitError( 1963 "unknown result <id> for specialization constant ") 1964 << constName; 1965 } 1966 valueIDMap[referenceOfOp.reference()] = constID; 1967 return success(); 1968 } 1969 1970 LogicalResult Serializer::processOperation(Operation *opInst) { 1971 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1972 1973 // First dispatch the ops that do not directly mirror an instruction from 1974 // the SPIR-V spec. 1975 return TypeSwitch<Operation *, LogicalResult>(opInst) 1976 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1977 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1978 .Case([&](spirv::BranchConditionalOp op) { 1979 return processBranchConditionalOp(op); 1980 }) 1981 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1982 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1983 .Case([&](spirv::GlobalVariableOp op) { 1984 return processGlobalVariableOp(op); 1985 }) 1986 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1987 .Case([&](spirv::ModuleEndOp) { return success(); }) 1988 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 1989 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 1990 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 1991 .Case([&](spirv::SpecConstantCompositeOp op) { 1992 return processSpecConstantCompositeOp(op); 1993 }) 1994 .Case([&](spirv::SpecConstantOperationOp op) { 1995 return processSpecConstantOperationOp(op); 1996 }) 1997 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 1998 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 1999 2000 // Then handle all the ops that directly mirror SPIR-V instructions with 2001 // auto-generated methods. 2002 .Default( 2003 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 2004 } 2005 2006 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 2007 StringRef extInstSet, 2008 uint32_t opcode) { 2009 SmallVector<uint32_t, 4> operands; 2010 Location loc = op->getLoc(); 2011 2012 uint32_t resultID = 0; 2013 if (op->getNumResults() != 0) { 2014 uint32_t resultTypeID = 0; 2015 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 2016 return failure(); 2017 operands.push_back(resultTypeID); 2018 2019 resultID = getNextID(); 2020 operands.push_back(resultID); 2021 valueIDMap[op->getResult(0)] = resultID; 2022 }; 2023 2024 for (Value operand : op->getOperands()) 2025 operands.push_back(getValueID(operand)); 2026 2027 emitDebugLine(functionBody, loc); 2028 2029 if (extInstSet.empty()) { 2030 encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode), 2031 operands); 2032 } else { 2033 encodeExtensionInstruction(op, extInstSet, opcode, operands); 2034 } 2035 2036 if (op->getNumResults() != 0) { 2037 for (auto attr : op->getAttrs()) { 2038 if (failed(processDecoration(loc, resultID, attr))) 2039 return failure(); 2040 } 2041 } 2042 2043 return success(); 2044 } 2045 2046 namespace { 2047 template <> 2048 LogicalResult 2049 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { 2050 SmallVector<uint32_t, 4> operands; 2051 // Add the ExecutionModel. 2052 operands.push_back(static_cast<uint32_t>(op.execution_model())); 2053 // Add the function <id>. 2054 auto funcID = getFunctionID(op.fn()); 2055 if (!funcID) { 2056 return op.emitError("missing <id> for function ") 2057 << op.fn() 2058 << "; function needs to be defined before spv.EntryPoint is " 2059 "serialized"; 2060 } 2061 operands.push_back(funcID); 2062 // Add the name of the function. 2063 spirv::encodeStringLiteralInto(operands, op.fn()); 2064 2065 // Add the interface values. 2066 if (auto interface = op.interface()) { 2067 for (auto var : interface.getValue()) { 2068 auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue()); 2069 if (!id) { 2070 return op.emitError("referencing undefined global variable." 2071 "spv.EntryPoint is at the end of spv.module. All " 2072 "referenced variables should already be defined"); 2073 } 2074 operands.push_back(id); 2075 } 2076 } 2077 return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, 2078 operands); 2079 } 2080 2081 template <> 2082 LogicalResult 2083 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) { 2084 StringRef argNames[] = {"execution_scope", "memory_scope", 2085 "memory_semantics"}; 2086 SmallVector<uint32_t, 3> operands; 2087 2088 for (auto argName : argNames) { 2089 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName); 2090 auto operand = prepareConstantInt(op.getLoc(), argIntAttr); 2091 if (!operand) { 2092 return failure(); 2093 } 2094 operands.push_back(operand); 2095 } 2096 2097 return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, 2098 operands); 2099 } 2100 2101 template <> 2102 LogicalResult 2103 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { 2104 SmallVector<uint32_t, 4> operands; 2105 // Add the function <id>. 2106 auto funcID = getFunctionID(op.fn()); 2107 if (!funcID) { 2108 return op.emitError("missing <id> for function ") 2109 << op.fn() 2110 << "; function needs to be serialized before ExecutionModeOp is " 2111 "serialized"; 2112 } 2113 operands.push_back(funcID); 2114 // Add the ExecutionMode. 2115 operands.push_back(static_cast<uint32_t>(op.execution_mode())); 2116 2117 // Serialize values if any. 2118 auto values = op.values(); 2119 if (values) { 2120 for (auto &intVal : values.getValue()) { 2121 operands.push_back(static_cast<uint32_t>( 2122 intVal.cast<IntegerAttr>().getValue().getZExtValue())); 2123 } 2124 } 2125 return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, 2126 operands); 2127 } 2128 2129 template <> 2130 LogicalResult 2131 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) { 2132 StringRef argNames[] = {"memory_scope", "memory_semantics"}; 2133 SmallVector<uint32_t, 2> operands; 2134 2135 for (auto argName : argNames) { 2136 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName); 2137 auto operand = prepareConstantInt(op.getLoc(), argIntAttr); 2138 if (!operand) { 2139 return failure(); 2140 } 2141 operands.push_back(operand); 2142 } 2143 2144 return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, 2145 operands); 2146 } 2147 2148 template <> 2149 LogicalResult 2150 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { 2151 auto funcName = op.callee(); 2152 uint32_t resTypeID = 0; 2153 2154 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); 2155 if (failed(processType(op.getLoc(), resultTy, resTypeID))) 2156 return failure(); 2157 2158 auto funcID = getOrCreateFunctionID(funcName); 2159 auto funcCallID = getNextID(); 2160 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; 2161 2162 for (auto value : op.arguments()) { 2163 auto valueID = getValueID(value); 2164 assert(valueID && "cannot find a value for spv.FunctionCall"); 2165 operands.push_back(valueID); 2166 } 2167 2168 if (!resultTy.isa<NoneType>()) 2169 valueIDMap[op.getResult(0)] = funcCallID; 2170 2171 return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, 2172 operands); 2173 } 2174 2175 template <> 2176 LogicalResult 2177 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { 2178 SmallVector<uint32_t, 4> operands; 2179 SmallVector<StringRef, 2> elidedAttrs; 2180 2181 for (Value operand : op->getOperands()) { 2182 auto id = getValueID(operand); 2183 assert(id && "use before def!"); 2184 operands.push_back(id); 2185 } 2186 2187 if (auto attr = op->getAttr("memory_access")) { 2188 operands.push_back(static_cast<uint32_t>( 2189 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2190 } 2191 2192 elidedAttrs.push_back("memory_access"); 2193 2194 if (auto attr = op->getAttr("alignment")) { 2195 operands.push_back(static_cast<uint32_t>( 2196 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2197 } 2198 2199 elidedAttrs.push_back("alignment"); 2200 2201 if (auto attr = op->getAttr("source_memory_access")) { 2202 operands.push_back(static_cast<uint32_t>( 2203 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2204 } 2205 2206 elidedAttrs.push_back("source_memory_access"); 2207 2208 if (auto attr = op->getAttr("source_alignment")) { 2209 operands.push_back(static_cast<uint32_t>( 2210 attr.cast<IntegerAttr>().getValue().getZExtValue())); 2211 } 2212 2213 elidedAttrs.push_back("source_alignment"); 2214 emitDebugLine(functionBody, op.getLoc()); 2215 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); 2216 2217 return success(); 2218 } 2219 2220 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and 2221 // various Serializer::processOp<...>() specializations. 2222 #define GET_SERIALIZATION_FNS 2223 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 2224 } // namespace 2225 2226 LogicalResult Serializer::emitDecoration(uint32_t target, 2227 spirv::Decoration decoration, 2228 ArrayRef<uint32_t> params) { 2229 uint32_t wordCount = 3 + params.size(); 2230 decorations.push_back( 2231 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); 2232 decorations.push_back(target); 2233 decorations.push_back(static_cast<uint32_t>(decoration)); 2234 decorations.append(params.begin(), params.end()); 2235 return success(); 2236 } 2237 2238 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 2239 Location loc) { 2240 if (!emitDebugInfo) 2241 return success(); 2242 2243 if (lastProcessedWasMergeInst) { 2244 lastProcessedWasMergeInst = false; 2245 return success(); 2246 } 2247 2248 auto fileLoc = loc.dyn_cast<FileLineColLoc>(); 2249 if (fileLoc) 2250 encodeInstructionInto(binary, spirv::Opcode::OpLine, 2251 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 2252 return success(); 2253 } 2254 2255 namespace mlir { 2256 LogicalResult spirv::serialize(spirv::ModuleOp module, 2257 SmallVectorImpl<uint32_t> &binary, 2258 bool emitDebugInfo) { 2259 if (!module.vce_triple().hasValue()) 2260 return module.emitError( 2261 "module must have 'vce_triple' attribute to be serializeable"); 2262 2263 Serializer serializer(module, emitDebugInfo); 2264 2265 if (failed(serializer.serialize())) 2266 return failure(); 2267 2268 LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs())); 2269 2270 serializer.collect(binary); 2271 return success(); 2272 } 2273 } // namespace mlir 2274