1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Serializer.h"
14
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/IR/RegionGraphTraits.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/Support/Debug.h"
21
22 #define DEBUG_TYPE "spirv-serialization"
23
24 using namespace mlir;
25
26 /// A pre-order depth-first visitor function for processing basic blocks.
27 ///
28 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
29 /// depth-first manner and calls `blockHandler` on each block. Skips handling
30 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
31 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
32 /// successors.
33 ///
34 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
35 /// of blocks in a function must satisfy the rule that blocks appear before
36 /// all blocks they dominate." This can be achieved by a pre-order CFG
37 /// traversal algorithm. To make the serialization output more logical and
38 /// readable to human, we perform depth-first CFG traversal and delay the
39 /// serialization of the merge block and the continue block, if exists, until
40 /// after all other blocks have been processed.
41 static LogicalResult
visitInPrettyBlockOrder(Block * headerBlock,function_ref<LogicalResult (Block *)> blockHandler,bool skipHeader=false,BlockRange skipBlocks={})42 visitInPrettyBlockOrder(Block *headerBlock,
43 function_ref<LogicalResult(Block *)> blockHandler,
44 bool skipHeader = false, BlockRange skipBlocks = {}) {
45 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
46 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
47
48 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
49 if (skipHeader && block == headerBlock)
50 continue;
51 if (failed(blockHandler(block)))
52 return failure();
53 }
54 return success();
55 }
56
57 namespace mlir {
58 namespace spirv {
processConstantOp(spirv::ConstantOp op)59 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
60 if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
61 valueIDMap[op.getResult()] = resultID;
62 return success();
63 }
64 return failure();
65 }
66
processSpecConstantOp(spirv::SpecConstantOp op)67 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
68 if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
69 /*isSpec=*/true)) {
70 // Emit the OpDecorate instruction for SpecId.
71 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
72 auto val = static_cast<uint32_t>(specID.getInt());
73 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
74 return failure();
75 }
76
77 specConstIDMap[op.sym_name()] = resultID;
78 return processName(resultID, op.sym_name());
79 }
80 return failure();
81 }
82
83 LogicalResult
processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op)84 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
85 uint32_t typeID = 0;
86 if (failed(processType(op.getLoc(), op.type(), typeID))) {
87 return failure();
88 }
89
90 auto resultID = getNextID();
91
92 SmallVector<uint32_t, 8> operands;
93 operands.push_back(typeID);
94 operands.push_back(resultID);
95
96 auto constituents = op.constituents();
97
98 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
99 auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
100
101 auto constituentName = constituent.getValue();
102 auto constituentID = getSpecConstID(constituentName);
103
104 if (!constituentID) {
105 return op.emitError("unknown result <id> for specialization constant ")
106 << constituentName;
107 }
108
109 operands.push_back(constituentID);
110 }
111
112 encodeInstructionInto(typesGlobalValues,
113 spirv::Opcode::OpSpecConstantComposite, operands);
114 specConstIDMap[op.sym_name()] = resultID;
115
116 return processName(resultID, op.sym_name());
117 }
118
119 LogicalResult
processSpecConstantOperationOp(spirv::SpecConstantOperationOp op)120 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
121 uint32_t typeID = 0;
122 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
123 return failure();
124 }
125
126 auto resultID = getNextID();
127
128 SmallVector<uint32_t, 8> operands;
129 operands.push_back(typeID);
130 operands.push_back(resultID);
131
132 Block &block = op.getRegion().getBlocks().front();
133 Operation &enclosedOp = block.getOperations().front();
134
135 std::string enclosedOpName;
136 llvm::raw_string_ostream rss(enclosedOpName);
137 rss << "Op" << enclosedOp.getName().stripDialect();
138 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
139
140 if (!enclosedOpcode) {
141 op.emitError("Couldn't find op code for op ")
142 << enclosedOp.getName().getStringRef();
143 return failure();
144 }
145
146 operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
147
148 // Append operands to the enclosed op to the list of operands.
149 for (Value operand : enclosedOp.getOperands()) {
150 uint32_t id = getValueID(operand);
151 assert(id && "use before def!");
152 operands.push_back(id);
153 }
154
155 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
156 operands);
157 valueIDMap[op.getResult()] = resultID;
158
159 return success();
160 }
161
processUndefOp(spirv::UndefOp op)162 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
163 auto undefType = op.getType();
164 auto &id = undefValIDMap[undefType];
165 if (!id) {
166 id = getNextID();
167 uint32_t typeID = 0;
168 if (failed(processType(op.getLoc(), undefType, typeID)))
169 return failure();
170 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
171 {typeID, id});
172 }
173 valueIDMap[op.getResult()] = id;
174 return success();
175 }
176
processFuncOp(spirv::FuncOp op)177 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
178 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
179 assert(functionHeader.empty() && functionBody.empty());
180
181 uint32_t fnTypeID = 0;
182 // Generate type of the function.
183 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
184 return failure();
185
186 // Add the function definition.
187 SmallVector<uint32_t, 4> operands;
188 uint32_t resTypeID = 0;
189 auto resultTypes = op.getFunctionType().getResults();
190 if (resultTypes.size() > 1) {
191 return op.emitError("cannot serialize function with multiple return types");
192 }
193 if (failed(processType(op.getLoc(),
194 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
195 resTypeID))) {
196 return failure();
197 }
198 operands.push_back(resTypeID);
199 auto funcID = getOrCreateFunctionID(op.getName());
200 operands.push_back(funcID);
201 operands.push_back(static_cast<uint32_t>(op.function_control()));
202 operands.push_back(fnTypeID);
203 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
204
205 // Add function name.
206 if (failed(processName(funcID, op.getName()))) {
207 return failure();
208 }
209
210 // Declare the parameters.
211 for (auto arg : op.getArguments()) {
212 uint32_t argTypeID = 0;
213 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
214 return failure();
215 }
216 auto argValueID = getNextID();
217 valueIDMap[arg] = argValueID;
218 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
219 {argTypeID, argValueID});
220 }
221
222 // Process the body.
223 if (op.isExternal()) {
224 return op.emitError("external function is unhandled");
225 }
226
227 // Some instructions (e.g., OpVariable) in a function must be in the first
228 // block in the function. These instructions will be put in functionHeader.
229 // Thus, we put the label in functionHeader first, and omit it from the first
230 // block.
231 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
232 {getOrCreateBlockID(&op.front())});
233 if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
234 return failure();
235 if (failed(visitInPrettyBlockOrder(
236 &op.front(), [&](Block *block) { return processBlock(block); },
237 /*skipHeader=*/true))) {
238 return failure();
239 }
240
241 // There might be OpPhi instructions who have value references needing to fix.
242 for (const auto &deferredValue : deferredPhiValues) {
243 Value value = deferredValue.first;
244 uint32_t id = getValueID(value);
245 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
246 << " to id = " << id << '\n');
247 assert(id && "OpPhi references undefined value!");
248 for (size_t offset : deferredValue.second)
249 functionBody[offset] = id;
250 }
251 deferredPhiValues.clear();
252
253 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
254 << "' --\n");
255 // Insert OpFunctionEnd.
256 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
257
258 functions.append(functionHeader.begin(), functionHeader.end());
259 functions.append(functionBody.begin(), functionBody.end());
260 functionHeader.clear();
261 functionBody.clear();
262
263 return success();
264 }
265
processVariableOp(spirv::VariableOp op)266 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
267 SmallVector<uint32_t, 4> operands;
268 SmallVector<StringRef, 2> elidedAttrs;
269 uint32_t resultID = 0;
270 uint32_t resultTypeID = 0;
271 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
272 return failure();
273 }
274 operands.push_back(resultTypeID);
275 resultID = getNextID();
276 valueIDMap[op.getResult()] = resultID;
277 operands.push_back(resultID);
278 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
279 if (attr) {
280 operands.push_back(static_cast<uint32_t>(
281 attr.cast<IntegerAttr>().getValue().getZExtValue()));
282 }
283 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
284 for (auto arg : op.getODSOperands(0)) {
285 auto argID = getValueID(arg);
286 if (!argID) {
287 return emitError(op.getLoc(), "operand 0 has a use before def");
288 }
289 operands.push_back(argID);
290 }
291 if (failed(emitDebugLine(functionHeader, op.getLoc())))
292 return failure();
293 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
294 for (auto attr : op->getAttrs()) {
295 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
296 return attr.getName() == elided;
297 })) {
298 continue;
299 }
300 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
301 return failure();
302 }
303 }
304 return success();
305 }
306
307 LogicalResult
processGlobalVariableOp(spirv::GlobalVariableOp varOp)308 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
309 // Get TypeID.
310 uint32_t resultTypeID = 0;
311 SmallVector<StringRef, 4> elidedAttrs;
312 if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
313 return failure();
314 }
315
316 elidedAttrs.push_back("type");
317 SmallVector<uint32_t, 4> operands;
318 operands.push_back(resultTypeID);
319 auto resultID = getNextID();
320
321 // Encode the name.
322 auto varName = varOp.sym_name();
323 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
324 if (failed(processName(resultID, varName))) {
325 return failure();
326 }
327 globalVarIDMap[varName] = resultID;
328 operands.push_back(resultID);
329
330 // Encode StorageClass.
331 operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
332
333 // Encode initialization.
334 if (auto initializer = varOp.initializer()) {
335 auto initializerID = getVariableID(*initializer);
336 if (!initializerID) {
337 return emitError(varOp.getLoc(),
338 "invalid usage of undefined variable as initializer");
339 }
340 operands.push_back(initializerID);
341 elidedAttrs.push_back("initializer");
342 }
343
344 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
345 return failure();
346 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
347 elidedAttrs.push_back("initializer");
348
349 // Encode decorations.
350 for (auto attr : varOp->getAttrs()) {
351 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
352 return attr.getName() == elided;
353 })) {
354 continue;
355 }
356 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
357 return failure();
358 }
359 }
360 return success();
361 }
362
processSelectionOp(spirv::SelectionOp selectionOp)363 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
364 // Assign <id>s to all blocks so that branches inside the SelectionOp can
365 // resolve properly.
366 auto &body = selectionOp.body();
367 for (Block &block : body)
368 getOrCreateBlockID(&block);
369
370 auto *headerBlock = selectionOp.getHeaderBlock();
371 auto *mergeBlock = selectionOp.getMergeBlock();
372 auto headerID = getBlockID(headerBlock);
373 auto mergeID = getBlockID(mergeBlock);
374 auto loc = selectionOp.getLoc();
375
376 // This SelectionOp is in some MLIR block with preceding and following ops. In
377 // the binary format, it should reside in separate SPIR-V blocks from its
378 // preceding and following ops. So we need to emit unconditional branches to
379 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
380 // flow afterwards.
381 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
382
383 // Emit the selection header block, which dominates all other blocks, first.
384 // We need to emit an OpSelectionMerge instruction before the selection header
385 // block's terminator.
386 auto emitSelectionMerge = [&]() {
387 if (failed(emitDebugLine(functionBody, loc)))
388 return failure();
389 lastProcessedWasMergeInst = true;
390 encodeInstructionInto(
391 functionBody, spirv::Opcode::OpSelectionMerge,
392 {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
393 return success();
394 };
395 if (failed(
396 processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
397 return failure();
398
399 // Process all blocks with a depth-first visitor starting from the header
400 // block. The selection header block and merge block are skipped by this
401 // visitor.
402 if (failed(visitInPrettyBlockOrder(
403 headerBlock, [&](Block *block) { return processBlock(block); },
404 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
405 return failure();
406
407 // There is nothing to do for the merge block in the selection, which just
408 // contains a spv.mlir.merge op, itself. But we need to have an OpLabel
409 // instruction to start a new SPIR-V block for ops following this SelectionOp.
410 // The block should use the <id> for the merge block.
411 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
412 LLVM_DEBUG(llvm::dbgs() << "done merge ");
413 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
414 LLVM_DEBUG(llvm::dbgs() << "\n");
415 return success();
416 }
417
processLoopOp(spirv::LoopOp loopOp)418 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
419 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
420 // properly. We don't need to assign for the entry block, which is just for
421 // satisfying MLIR region's structural requirement.
422 auto &body = loopOp.body();
423 for (Block &block : llvm::drop_begin(body))
424 getOrCreateBlockID(&block);
425
426 auto *headerBlock = loopOp.getHeaderBlock();
427 auto *continueBlock = loopOp.getContinueBlock();
428 auto *mergeBlock = loopOp.getMergeBlock();
429 auto headerID = getBlockID(headerBlock);
430 auto continueID = getBlockID(continueBlock);
431 auto mergeID = getBlockID(mergeBlock);
432 auto loc = loopOp.getLoc();
433
434 // This LoopOp is in some MLIR block with preceding and following ops. In the
435 // binary format, it should reside in separate SPIR-V blocks from its
436 // preceding and following ops. So we need to emit unconditional branches to
437 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
438 // afterwards.
439 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
440
441 // LoopOp's entry block is just there for satisfying MLIR's structural
442 // requirements so we omit it and start serialization from the loop header
443 // block.
444
445 // Emit the loop header block, which dominates all other blocks, first. We
446 // need to emit an OpLoopMerge instruction before the loop header block's
447 // terminator.
448 auto emitLoopMerge = [&]() {
449 if (failed(emitDebugLine(functionBody, loc)))
450 return failure();
451 lastProcessedWasMergeInst = true;
452 encodeInstructionInto(
453 functionBody, spirv::Opcode::OpLoopMerge,
454 {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
455 return success();
456 };
457 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
458 return failure();
459
460 // Process all blocks with a depth-first visitor starting from the header
461 // block. The loop header block, loop continue block, and loop merge block are
462 // skipped by this visitor and handled later in this function.
463 if (failed(visitInPrettyBlockOrder(
464 headerBlock, [&](Block *block) { return processBlock(block); },
465 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
466 return failure();
467
468 // We have handled all other blocks. Now get to the loop continue block.
469 if (failed(processBlock(continueBlock)))
470 return failure();
471
472 // There is nothing to do for the merge block in the loop, which just contains
473 // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to
474 // start a new SPIR-V block for ops following this LoopOp. The block should
475 // use the <id> for the merge block.
476 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
477 LLVM_DEBUG(llvm::dbgs() << "done merge ");
478 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
479 LLVM_DEBUG(llvm::dbgs() << "\n");
480 return success();
481 }
482
processBranchConditionalOp(spirv::BranchConditionalOp condBranchOp)483 LogicalResult Serializer::processBranchConditionalOp(
484 spirv::BranchConditionalOp condBranchOp) {
485 auto conditionID = getValueID(condBranchOp.condition());
486 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
487 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
488 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
489
490 if (auto weights = condBranchOp.branch_weights()) {
491 for (auto val : weights->getValue())
492 arguments.push_back(val.cast<IntegerAttr>().getInt());
493 }
494
495 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
496 return failure();
497 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
498 arguments);
499 return success();
500 }
501
processBranchOp(spirv::BranchOp branchOp)502 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
503 if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
504 return failure();
505 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
506 {getOrCreateBlockID(branchOp.getTarget())});
507 return success();
508 }
509
processAddressOfOp(spirv::AddressOfOp addressOfOp)510 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
511 auto varName = addressOfOp.variable();
512 auto variableID = getVariableID(varName);
513 if (!variableID) {
514 return addressOfOp.emitError("unknown result <id> for variable ")
515 << varName;
516 }
517 valueIDMap[addressOfOp.pointer()] = variableID;
518 return success();
519 }
520
521 LogicalResult
processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp)522 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
523 auto constName = referenceOfOp.spec_const();
524 auto constID = getSpecConstID(constName);
525 if (!constID) {
526 return referenceOfOp.emitError(
527 "unknown result <id> for specialization constant ")
528 << constName;
529 }
530 valueIDMap[referenceOfOp.reference()] = constID;
531 return success();
532 }
533
534 template <>
535 LogicalResult
processOp(spirv::EntryPointOp op)536 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
537 SmallVector<uint32_t, 4> operands;
538 // Add the ExecutionModel.
539 operands.push_back(static_cast<uint32_t>(op.execution_model()));
540 // Add the function <id>.
541 auto funcID = getFunctionID(op.fn());
542 if (!funcID) {
543 return op.emitError("missing <id> for function ")
544 << op.fn()
545 << "; function needs to be defined before spv.EntryPoint is "
546 "serialized";
547 }
548 operands.push_back(funcID);
549 // Add the name of the function.
550 spirv::encodeStringLiteralInto(operands, op.fn());
551
552 // Add the interface values.
553 if (auto interface = op.interface()) {
554 for (auto var : interface.getValue()) {
555 auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
556 if (!id) {
557 return op.emitError("referencing undefined global variable."
558 "spv.EntryPoint is at the end of spv.module. All "
559 "referenced variables should already be defined");
560 }
561 operands.push_back(id);
562 }
563 }
564 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
565 return success();
566 }
567
568 template <>
569 LogicalResult
processOp(spirv::ControlBarrierOp op)570 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
571 StringRef argNames[] = {"execution_scope", "memory_scope",
572 "memory_semantics"};
573 SmallVector<uint32_t, 3> operands;
574
575 for (auto argName : argNames) {
576 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
577 auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
578 if (!operand) {
579 return failure();
580 }
581 operands.push_back(operand);
582 }
583
584 encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
585 operands);
586 return success();
587 }
588
589 template <>
590 LogicalResult
processOp(spirv::ExecutionModeOp op)591 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
592 SmallVector<uint32_t, 4> operands;
593 // Add the function <id>.
594 auto funcID = getFunctionID(op.fn());
595 if (!funcID) {
596 return op.emitError("missing <id> for function ")
597 << op.fn()
598 << "; function needs to be serialized before ExecutionModeOp is "
599 "serialized";
600 }
601 operands.push_back(funcID);
602 // Add the ExecutionMode.
603 operands.push_back(static_cast<uint32_t>(op.execution_mode()));
604
605 // Serialize values if any.
606 auto values = op.values();
607 if (values) {
608 for (auto &intVal : values.getValue()) {
609 operands.push_back(static_cast<uint32_t>(
610 intVal.cast<IntegerAttr>().getValue().getZExtValue()));
611 }
612 }
613 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
614 operands);
615 return success();
616 }
617
618 template <>
619 LogicalResult
processOp(spirv::MemoryBarrierOp op)620 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
621 StringRef argNames[] = {"memory_scope", "memory_semantics"};
622 SmallVector<uint32_t, 2> operands;
623
624 for (auto argName : argNames) {
625 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
626 auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
627 if (!operand) {
628 return failure();
629 }
630 operands.push_back(operand);
631 }
632
633 encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands);
634 return success();
635 }
636
637 template <>
638 LogicalResult
processOp(spirv::FunctionCallOp op)639 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
640 auto funcName = op.callee();
641 uint32_t resTypeID = 0;
642
643 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
644 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
645 return failure();
646
647 auto funcID = getOrCreateFunctionID(funcName);
648 auto funcCallID = getNextID();
649 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
650
651 for (auto value : op.arguments()) {
652 auto valueID = getValueID(value);
653 assert(valueID && "cannot find a value for spv.FunctionCall");
654 operands.push_back(valueID);
655 }
656
657 if (!resultTy.isa<NoneType>())
658 valueIDMap[op.getResult(0)] = funcCallID;
659
660 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
661 return success();
662 }
663
664 template <>
665 LogicalResult
processOp(spirv::CopyMemoryOp op)666 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
667 SmallVector<uint32_t, 4> operands;
668 SmallVector<StringRef, 2> elidedAttrs;
669
670 for (Value operand : op->getOperands()) {
671 auto id = getValueID(operand);
672 assert(id && "use before def!");
673 operands.push_back(id);
674 }
675
676 if (auto attr = op->getAttr("memory_access")) {
677 operands.push_back(static_cast<uint32_t>(
678 attr.cast<IntegerAttr>().getValue().getZExtValue()));
679 }
680
681 elidedAttrs.push_back("memory_access");
682
683 if (auto attr = op->getAttr("alignment")) {
684 operands.push_back(static_cast<uint32_t>(
685 attr.cast<IntegerAttr>().getValue().getZExtValue()));
686 }
687
688 elidedAttrs.push_back("alignment");
689
690 if (auto attr = op->getAttr("source_memory_access")) {
691 operands.push_back(static_cast<uint32_t>(
692 attr.cast<IntegerAttr>().getValue().getZExtValue()));
693 }
694
695 elidedAttrs.push_back("source_memory_access");
696
697 if (auto attr = op->getAttr("source_alignment")) {
698 operands.push_back(static_cast<uint32_t>(
699 attr.cast<IntegerAttr>().getValue().getZExtValue()));
700 }
701
702 elidedAttrs.push_back("source_alignment");
703 if (failed(emitDebugLine(functionBody, op.getLoc())))
704 return failure();
705 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
706
707 return success();
708 }
709
710 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
711 // various Serializer::processOp<...>() specializations.
712 #define GET_SERIALIZATION_FNS
713 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
714
715 } // namespace spirv
716 } // namespace mlir
717