1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the serialization methods for MLIR SPIR-V module ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/IR/RegionGraphTraits.h" 17 #include "mlir/Support/LogicalResult.h" 18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 19 #include "llvm/ADT/DepthFirstIterator.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "spirv-serialization" 23 24 using namespace mlir; 25 26 /// A pre-order depth-first visitor function for processing basic blocks. 27 /// 28 /// Visits the basic blocks starting from the given `headerBlock` in pre-order 29 /// depth-first manner and calls `blockHandler` on each block. Skips handling 30 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler` 31 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s 32 /// successors. 33 /// 34 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order 35 /// of blocks in a function must satisfy the rule that blocks appear before 36 /// all blocks they dominate." This can be achieved by a pre-order CFG 37 /// traversal algorithm. To make the serialization output more logical and 38 /// readable to human, we perform depth-first CFG traversal and delay the 39 /// serialization of the merge block and the continue block, if exists, until 40 /// after all other blocks have been processed. 41 static LogicalResult 42 visitInPrettyBlockOrder(Block *headerBlock, 43 function_ref<LogicalResult(Block *)> blockHandler, 44 bool skipHeader = false, BlockRange skipBlocks = {}) { 45 llvm::df_iterator_default_set<Block *, 4> doneBlocks; 46 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end()); 47 48 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) { 49 if (skipHeader && block == headerBlock) 50 continue; 51 if (failed(blockHandler(block))) 52 return failure(); 53 } 54 return success(); 55 } 56 57 namespace mlir { 58 namespace spirv { 59 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { 60 if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { 61 valueIDMap[op.getResult()] = resultID; 62 return success(); 63 } 64 return failure(); 65 } 66 67 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { 68 if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), 69 /*isSpec=*/true)) { 70 // Emit the OpDecorate instruction for SpecId. 71 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) { 72 auto val = static_cast<uint32_t>(specID.getInt()); 73 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val}))) 74 return failure(); 75 } 76 77 specConstIDMap[op.sym_name()] = resultID; 78 return processName(resultID, op.sym_name()); 79 } 80 return failure(); 81 } 82 83 LogicalResult 84 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { 85 uint32_t typeID = 0; 86 if (failed(processType(op.getLoc(), op.type(), typeID))) { 87 return failure(); 88 } 89 90 auto resultID = getNextID(); 91 92 SmallVector<uint32_t, 8> operands; 93 operands.push_back(typeID); 94 operands.push_back(resultID); 95 96 auto constituents = op.constituents(); 97 98 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { 99 auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>(); 100 101 auto constituentName = constituent.getValue(); 102 auto constituentID = getSpecConstID(constituentName); 103 104 if (!constituentID) { 105 return op.emitError("unknown result <id> for specialization constant ") 106 << constituentName; 107 } 108 109 operands.push_back(constituentID); 110 } 111 112 encodeInstructionInto(typesGlobalValues, 113 spirv::Opcode::OpSpecConstantComposite, operands); 114 specConstIDMap[op.sym_name()] = resultID; 115 116 return processName(resultID, op.sym_name()); 117 } 118 119 LogicalResult 120 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { 121 uint32_t typeID = 0; 122 if (failed(processType(op.getLoc(), op.getType(), typeID))) { 123 return failure(); 124 } 125 126 auto resultID = getNextID(); 127 128 SmallVector<uint32_t, 8> operands; 129 operands.push_back(typeID); 130 operands.push_back(resultID); 131 132 Block &block = op.getRegion().getBlocks().front(); 133 Operation &enclosedOp = block.getOperations().front(); 134 135 std::string enclosedOpName; 136 llvm::raw_string_ostream rss(enclosedOpName); 137 rss << "Op" << enclosedOp.getName().stripDialect(); 138 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); 139 140 if (!enclosedOpcode) { 141 op.emitError("Couldn't find op code for op ") 142 << enclosedOp.getName().getStringRef(); 143 return failure(); 144 } 145 146 operands.push_back(static_cast<uint32_t>(*enclosedOpcode)); 147 148 // Append operands to the enclosed op to the list of operands. 149 for (Value operand : enclosedOp.getOperands()) { 150 uint32_t id = getValueID(operand); 151 assert(id && "use before def!"); 152 operands.push_back(id); 153 } 154 155 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp, 156 operands); 157 valueIDMap[op.getResult()] = resultID; 158 159 return success(); 160 } 161 162 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { 163 auto undefType = op.getType(); 164 auto &id = undefValIDMap[undefType]; 165 if (!id) { 166 id = getNextID(); 167 uint32_t typeID = 0; 168 if (failed(processType(op.getLoc(), undefType, typeID))) 169 return failure(); 170 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef, 171 {typeID, id}); 172 } 173 valueIDMap[op.getResult()] = id; 174 return success(); 175 } 176 177 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { 178 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); 179 assert(functionHeader.empty() && functionBody.empty()); 180 181 uint32_t fnTypeID = 0; 182 // Generate type of the function. 183 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID))) 184 return failure(); 185 186 // Add the function definition. 187 SmallVector<uint32_t, 4> operands; 188 uint32_t resTypeID = 0; 189 auto resultTypes = op.getFunctionType().getResults(); 190 if (resultTypes.size() > 1) { 191 return op.emitError("cannot serialize function with multiple return types"); 192 } 193 if (failed(processType(op.getLoc(), 194 (resultTypes.empty() ? getVoidType() : resultTypes[0]), 195 resTypeID))) { 196 return failure(); 197 } 198 operands.push_back(resTypeID); 199 auto funcID = getOrCreateFunctionID(op.getName()); 200 operands.push_back(funcID); 201 operands.push_back(static_cast<uint32_t>(op.function_control())); 202 operands.push_back(fnTypeID); 203 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); 204 205 // Add function name. 206 if (failed(processName(funcID, op.getName()))) { 207 return failure(); 208 } 209 210 // Declare the parameters. 211 for (auto arg : op.getArguments()) { 212 uint32_t argTypeID = 0; 213 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { 214 return failure(); 215 } 216 auto argValueID = getNextID(); 217 valueIDMap[arg] = argValueID; 218 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, 219 {argTypeID, argValueID}); 220 } 221 222 // Process the body. 223 if (op.isExternal()) { 224 return op.emitError("external function is unhandled"); 225 } 226 227 // Some instructions (e.g., OpVariable) in a function must be in the first 228 // block in the function. These instructions will be put in functionHeader. 229 // Thus, we put the label in functionHeader first, and omit it from the first 230 // block. 231 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, 232 {getOrCreateBlockID(&op.front())}); 233 if (failed(processBlock(&op.front(), /*omitLabel=*/true))) 234 return failure(); 235 if (failed(visitInPrettyBlockOrder( 236 &op.front(), [&](Block *block) { return processBlock(block); }, 237 /*skipHeader=*/true))) { 238 return failure(); 239 } 240 241 // There might be OpPhi instructions who have value references needing to fix. 242 for (const auto &deferredValue : deferredPhiValues) { 243 Value value = deferredValue.first; 244 uint32_t id = getValueID(value); 245 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value 246 << " to id = " << id << '\n'); 247 assert(id && "OpPhi references undefined value!"); 248 for (size_t offset : deferredValue.second) 249 functionBody[offset] = id; 250 } 251 deferredPhiValues.clear(); 252 253 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() 254 << "' --\n"); 255 // Insert OpFunctionEnd. 256 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}); 257 258 functions.append(functionHeader.begin(), functionHeader.end()); 259 functions.append(functionBody.begin(), functionBody.end()); 260 functionHeader.clear(); 261 functionBody.clear(); 262 263 return success(); 264 } 265 266 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { 267 SmallVector<uint32_t, 4> operands; 268 SmallVector<StringRef, 2> elidedAttrs; 269 uint32_t resultID = 0; 270 uint32_t resultTypeID = 0; 271 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) { 272 return failure(); 273 } 274 operands.push_back(resultTypeID); 275 resultID = getNextID(); 276 valueIDMap[op.getResult()] = resultID; 277 operands.push_back(resultID); 278 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>()); 279 if (attr) { 280 operands.push_back(static_cast<uint32_t>( 281 attr.cast<IntegerAttr>().getValue().getZExtValue())); 282 } 283 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 284 for (auto arg : op.getODSOperands(0)) { 285 auto argID = getValueID(arg); 286 if (!argID) { 287 return emitError(op.getLoc(), "operand 0 has a use before def"); 288 } 289 operands.push_back(argID); 290 } 291 if (failed(emitDebugLine(functionHeader, op.getLoc()))) 292 return failure(); 293 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); 294 for (auto attr : op->getAttrs()) { 295 if (llvm::any_of(elidedAttrs, [&](StringRef elided) { 296 return attr.getName() == elided; 297 })) { 298 continue; 299 } 300 if (failed(processDecoration(op.getLoc(), resultID, attr))) { 301 return failure(); 302 } 303 } 304 return success(); 305 } 306 307 LogicalResult 308 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { 309 // Get TypeID. 310 uint32_t resultTypeID = 0; 311 SmallVector<StringRef, 4> elidedAttrs; 312 if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { 313 return failure(); 314 } 315 316 elidedAttrs.push_back("type"); 317 SmallVector<uint32_t, 4> operands; 318 operands.push_back(resultTypeID); 319 auto resultID = getNextID(); 320 321 // Encode the name. 322 auto varName = varOp.sym_name(); 323 elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); 324 if (failed(processName(resultID, varName))) { 325 return failure(); 326 } 327 globalVarIDMap[varName] = resultID; 328 operands.push_back(resultID); 329 330 // Encode StorageClass. 331 operands.push_back(static_cast<uint32_t>(varOp.storageClass())); 332 333 // Encode initialization. 334 if (auto initializer = varOp.initializer()) { 335 auto initializerID = getVariableID(*initializer); 336 if (!initializerID) { 337 return emitError(varOp.getLoc(), 338 "invalid usage of undefined variable as initializer"); 339 } 340 operands.push_back(initializerID); 341 elidedAttrs.push_back("initializer"); 342 } 343 344 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc()))) 345 return failure(); 346 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands); 347 elidedAttrs.push_back("initializer"); 348 349 // Encode decorations. 350 for (auto attr : varOp->getAttrs()) { 351 if (llvm::any_of(elidedAttrs, [&](StringRef elided) { 352 return attr.getName() == elided; 353 })) { 354 continue; 355 } 356 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { 357 return failure(); 358 } 359 } 360 return success(); 361 } 362 363 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { 364 // Assign <id>s to all blocks so that branches inside the SelectionOp can 365 // resolve properly. 366 auto &body = selectionOp.body(); 367 for (Block &block : body) 368 getOrCreateBlockID(&block); 369 370 auto *headerBlock = selectionOp.getHeaderBlock(); 371 auto *mergeBlock = selectionOp.getMergeBlock(); 372 auto headerID = getBlockID(headerBlock); 373 auto mergeID = getBlockID(mergeBlock); 374 auto loc = selectionOp.getLoc(); 375 376 // This SelectionOp is in some MLIR block with preceding and following ops. In 377 // the binary format, it should reside in separate SPIR-V blocks from its 378 // preceding and following ops. So we need to emit unconditional branches to 379 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal 380 // flow afterwards. 381 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); 382 383 // Emit the selection header block, which dominates all other blocks, first. 384 // We need to emit an OpSelectionMerge instruction before the selection header 385 // block's terminator. 386 auto emitSelectionMerge = [&]() { 387 if (failed(emitDebugLine(functionBody, loc))) 388 return failure(); 389 lastProcessedWasMergeInst = true; 390 encodeInstructionInto( 391 functionBody, spirv::Opcode::OpSelectionMerge, 392 {mergeID, static_cast<uint32_t>(selectionOp.selection_control())}); 393 return success(); 394 }; 395 if (failed( 396 processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge))) 397 return failure(); 398 399 // Process all blocks with a depth-first visitor starting from the header 400 // block. The selection header block and merge block are skipped by this 401 // visitor. 402 if (failed(visitInPrettyBlockOrder( 403 headerBlock, [&](Block *block) { return processBlock(block); }, 404 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock}))) 405 return failure(); 406 407 // There is nothing to do for the merge block in the selection, which just 408 // contains a spv.mlir.merge op, itself. But we need to have an OpLabel 409 // instruction to start a new SPIR-V block for ops following this SelectionOp. 410 // The block should use the <id> for the merge block. 411 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 412 LLVM_DEBUG(llvm::dbgs() << "done merge "); 413 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); 414 LLVM_DEBUG(llvm::dbgs() << "\n"); 415 return success(); 416 } 417 418 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { 419 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve 420 // properly. We don't need to assign for the entry block, which is just for 421 // satisfying MLIR region's structural requirement. 422 auto &body = loopOp.body(); 423 for (Block &block : llvm::drop_begin(body)) 424 getOrCreateBlockID(&block); 425 426 auto *headerBlock = loopOp.getHeaderBlock(); 427 auto *continueBlock = loopOp.getContinueBlock(); 428 auto *mergeBlock = loopOp.getMergeBlock(); 429 auto headerID = getBlockID(headerBlock); 430 auto continueID = getBlockID(continueBlock); 431 auto mergeID = getBlockID(mergeBlock); 432 auto loc = loopOp.getLoc(); 433 434 // This LoopOp is in some MLIR block with preceding and following ops. In the 435 // binary format, it should reside in separate SPIR-V blocks from its 436 // preceding and following ops. So we need to emit unconditional branches to 437 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow 438 // afterwards. 439 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); 440 441 // LoopOp's entry block is just there for satisfying MLIR's structural 442 // requirements so we omit it and start serialization from the loop header 443 // block. 444 445 // Emit the loop header block, which dominates all other blocks, first. We 446 // need to emit an OpLoopMerge instruction before the loop header block's 447 // terminator. 448 auto emitLoopMerge = [&]() { 449 if (failed(emitDebugLine(functionBody, loc))) 450 return failure(); 451 lastProcessedWasMergeInst = true; 452 encodeInstructionInto( 453 functionBody, spirv::Opcode::OpLoopMerge, 454 {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())}); 455 return success(); 456 }; 457 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) 458 return failure(); 459 460 // Process all blocks with a depth-first visitor starting from the header 461 // block. The loop header block, loop continue block, and loop merge block are 462 // skipped by this visitor and handled later in this function. 463 if (failed(visitInPrettyBlockOrder( 464 headerBlock, [&](Block *block) { return processBlock(block); }, 465 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock}))) 466 return failure(); 467 468 // We have handled all other blocks. Now get to the loop continue block. 469 if (failed(processBlock(continueBlock))) 470 return failure(); 471 472 // There is nothing to do for the merge block in the loop, which just contains 473 // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to 474 // start a new SPIR-V block for ops following this LoopOp. The block should 475 // use the <id> for the merge block. 476 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); 477 LLVM_DEBUG(llvm::dbgs() << "done merge "); 478 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); 479 LLVM_DEBUG(llvm::dbgs() << "\n"); 480 return success(); 481 } 482 483 LogicalResult Serializer::processBranchConditionalOp( 484 spirv::BranchConditionalOp condBranchOp) { 485 auto conditionID = getValueID(condBranchOp.condition()); 486 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); 487 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); 488 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; 489 490 if (auto weights = condBranchOp.branch_weights()) { 491 for (auto val : weights->getValue()) 492 arguments.push_back(val.cast<IntegerAttr>().getInt()); 493 } 494 495 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc()))) 496 return failure(); 497 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional, 498 arguments); 499 return success(); 500 } 501 502 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { 503 if (failed(emitDebugLine(functionBody, branchOp.getLoc()))) 504 return failure(); 505 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, 506 {getOrCreateBlockID(branchOp.getTarget())}); 507 return success(); 508 } 509 510 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { 511 auto varName = addressOfOp.variable(); 512 auto variableID = getVariableID(varName); 513 if (!variableID) { 514 return addressOfOp.emitError("unknown result <id> for variable ") 515 << varName; 516 } 517 valueIDMap[addressOfOp.pointer()] = variableID; 518 return success(); 519 } 520 521 LogicalResult 522 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { 523 auto constName = referenceOfOp.spec_const(); 524 auto constID = getSpecConstID(constName); 525 if (!constID) { 526 return referenceOfOp.emitError( 527 "unknown result <id> for specialization constant ") 528 << constName; 529 } 530 valueIDMap[referenceOfOp.reference()] = constID; 531 return success(); 532 } 533 534 template <> 535 LogicalResult 536 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { 537 SmallVector<uint32_t, 4> operands; 538 // Add the ExecutionModel. 539 operands.push_back(static_cast<uint32_t>(op.execution_model())); 540 // Add the function <id>. 541 auto funcID = getFunctionID(op.fn()); 542 if (!funcID) { 543 return op.emitError("missing <id> for function ") 544 << op.fn() 545 << "; function needs to be defined before spv.EntryPoint is " 546 "serialized"; 547 } 548 operands.push_back(funcID); 549 // Add the name of the function. 550 spirv::encodeStringLiteralInto(operands, op.fn()); 551 552 // Add the interface values. 553 if (auto interface = op.interface()) { 554 for (auto var : interface.getValue()) { 555 auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue()); 556 if (!id) { 557 return op.emitError("referencing undefined global variable." 558 "spv.EntryPoint is at the end of spv.module. All " 559 "referenced variables should already be defined"); 560 } 561 operands.push_back(id); 562 } 563 } 564 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); 565 return success(); 566 } 567 568 template <> 569 LogicalResult 570 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) { 571 StringRef argNames[] = {"execution_scope", "memory_scope", 572 "memory_semantics"}; 573 SmallVector<uint32_t, 3> operands; 574 575 for (auto argName : argNames) { 576 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName); 577 auto operand = prepareConstantInt(op.getLoc(), argIntAttr); 578 if (!operand) { 579 return failure(); 580 } 581 operands.push_back(operand); 582 } 583 584 encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier, 585 operands); 586 return success(); 587 } 588 589 template <> 590 LogicalResult 591 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { 592 SmallVector<uint32_t, 4> operands; 593 // Add the function <id>. 594 auto funcID = getFunctionID(op.fn()); 595 if (!funcID) { 596 return op.emitError("missing <id> for function ") 597 << op.fn() 598 << "; function needs to be serialized before ExecutionModeOp is " 599 "serialized"; 600 } 601 operands.push_back(funcID); 602 // Add the ExecutionMode. 603 operands.push_back(static_cast<uint32_t>(op.execution_mode())); 604 605 // Serialize values if any. 606 auto values = op.values(); 607 if (values) { 608 for (auto &intVal : values.getValue()) { 609 operands.push_back(static_cast<uint32_t>( 610 intVal.cast<IntegerAttr>().getValue().getZExtValue())); 611 } 612 } 613 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, 614 operands); 615 return success(); 616 } 617 618 template <> 619 LogicalResult 620 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) { 621 StringRef argNames[] = {"memory_scope", "memory_semantics"}; 622 SmallVector<uint32_t, 2> operands; 623 624 for (auto argName : argNames) { 625 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName); 626 auto operand = prepareConstantInt(op.getLoc(), argIntAttr); 627 if (!operand) { 628 return failure(); 629 } 630 operands.push_back(operand); 631 } 632 633 encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands); 634 return success(); 635 } 636 637 template <> 638 LogicalResult 639 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { 640 auto funcName = op.callee(); 641 uint32_t resTypeID = 0; 642 643 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); 644 if (failed(processType(op.getLoc(), resultTy, resTypeID))) 645 return failure(); 646 647 auto funcID = getOrCreateFunctionID(funcName); 648 auto funcCallID = getNextID(); 649 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; 650 651 for (auto value : op.arguments()) { 652 auto valueID = getValueID(value); 653 assert(valueID && "cannot find a value for spv.FunctionCall"); 654 operands.push_back(valueID); 655 } 656 657 if (!resultTy.isa<NoneType>()) 658 valueIDMap[op.getResult(0)] = funcCallID; 659 660 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); 661 return success(); 662 } 663 664 template <> 665 LogicalResult 666 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { 667 SmallVector<uint32_t, 4> operands; 668 SmallVector<StringRef, 2> elidedAttrs; 669 670 for (Value operand : op->getOperands()) { 671 auto id = getValueID(operand); 672 assert(id && "use before def!"); 673 operands.push_back(id); 674 } 675 676 if (auto attr = op->getAttr("memory_access")) { 677 operands.push_back(static_cast<uint32_t>( 678 attr.cast<IntegerAttr>().getValue().getZExtValue())); 679 } 680 681 elidedAttrs.push_back("memory_access"); 682 683 if (auto attr = op->getAttr("alignment")) { 684 operands.push_back(static_cast<uint32_t>( 685 attr.cast<IntegerAttr>().getValue().getZExtValue())); 686 } 687 688 elidedAttrs.push_back("alignment"); 689 690 if (auto attr = op->getAttr("source_memory_access")) { 691 operands.push_back(static_cast<uint32_t>( 692 attr.cast<IntegerAttr>().getValue().getZExtValue())); 693 } 694 695 elidedAttrs.push_back("source_memory_access"); 696 697 if (auto attr = op->getAttr("source_alignment")) { 698 operands.push_back(static_cast<uint32_t>( 699 attr.cast<IntegerAttr>().getValue().getZExtValue())); 700 } 701 702 elidedAttrs.push_back("source_alignment"); 703 if (failed(emitDebugLine(functionBody, op.getLoc()))) 704 return failure(); 705 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); 706 707 return success(); 708 } 709 710 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and 711 // various Serializer::processOp<...>() specializations. 712 #define GET_SERIALIZATION_FNS 713 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 714 715 } // namespace spirv 716 } // namespace mlir 717