1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/IR/RegionGraphTraits.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "spirv-serialization"
23 
24 using namespace mlir;
25 
26 /// A pre-order depth-first visitor function for processing basic blocks.
27 ///
28 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
29 /// depth-first manner and calls `blockHandler` on each block. Skips handling
30 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
31 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
32 /// successors.
33 ///
34 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
35 /// of blocks in a function must satisfy the rule that blocks appear before
36 /// all blocks they dominate." This can be achieved by a pre-order CFG
37 /// traversal algorithm. To make the serialization output more logical and
38 /// readable to human, we perform depth-first CFG traversal and delay the
39 /// serialization of the merge block and the continue block, if exists, until
40 /// after all other blocks have been processed.
41 static LogicalResult
visitInPrettyBlockOrder(Block * headerBlock,function_ref<LogicalResult (Block *)> blockHandler,bool skipHeader=false,BlockRange skipBlocks={})42 visitInPrettyBlockOrder(Block *headerBlock,
43                         function_ref<LogicalResult(Block *)> blockHandler,
44                         bool skipHeader = false, BlockRange skipBlocks = {}) {
45   llvm::df_iterator_default_set<Block *, 4> doneBlocks;
46   doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
47 
48   for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
49     if (skipHeader && block == headerBlock)
50       continue;
51     if (failed(blockHandler(block)))
52       return failure();
53   }
54   return success();
55 }
56 
57 namespace mlir {
58 namespace spirv {
processConstantOp(spirv::ConstantOp op)59 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
60   if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
61     valueIDMap[op.getResult()] = resultID;
62     return success();
63   }
64   return failure();
65 }
66 
processSpecConstantOp(spirv::SpecConstantOp op)67 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
68   if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
69                                             /*isSpec=*/true)) {
70     // Emit the OpDecorate instruction for SpecId.
71     if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
72       auto val = static_cast<uint32_t>(specID.getInt());
73       if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
74         return failure();
75     }
76 
77     specConstIDMap[op.sym_name()] = resultID;
78     return processName(resultID, op.sym_name());
79   }
80   return failure();
81 }
82 
83 LogicalResult
processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op)84 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
85   uint32_t typeID = 0;
86   if (failed(processType(op.getLoc(), op.type(), typeID))) {
87     return failure();
88   }
89 
90   auto resultID = getNextID();
91 
92   SmallVector<uint32_t, 8> operands;
93   operands.push_back(typeID);
94   operands.push_back(resultID);
95 
96   auto constituents = op.constituents();
97 
98   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
99     auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
100 
101     auto constituentName = constituent.getValue();
102     auto constituentID = getSpecConstID(constituentName);
103 
104     if (!constituentID) {
105       return op.emitError("unknown result <id> for specialization constant ")
106              << constituentName;
107     }
108 
109     operands.push_back(constituentID);
110   }
111 
112   encodeInstructionInto(typesGlobalValues,
113                         spirv::Opcode::OpSpecConstantComposite, operands);
114   specConstIDMap[op.sym_name()] = resultID;
115 
116   return processName(resultID, op.sym_name());
117 }
118 
119 LogicalResult
processSpecConstantOperationOp(spirv::SpecConstantOperationOp op)120 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
121   uint32_t typeID = 0;
122   if (failed(processType(op.getLoc(), op.getType(), typeID))) {
123     return failure();
124   }
125 
126   auto resultID = getNextID();
127 
128   SmallVector<uint32_t, 8> operands;
129   operands.push_back(typeID);
130   operands.push_back(resultID);
131 
132   Block &block = op.getRegion().getBlocks().front();
133   Operation &enclosedOp = block.getOperations().front();
134 
135   std::string enclosedOpName;
136   llvm::raw_string_ostream rss(enclosedOpName);
137   rss << "Op" << enclosedOp.getName().stripDialect();
138   auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
139 
140   if (!enclosedOpcode) {
141     op.emitError("Couldn't find op code for op ")
142         << enclosedOp.getName().getStringRef();
143     return failure();
144   }
145 
146   operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
147 
148   // Append operands to the enclosed op to the list of operands.
149   for (Value operand : enclosedOp.getOperands()) {
150     uint32_t id = getValueID(operand);
151     assert(id && "use before def!");
152     operands.push_back(id);
153   }
154 
155   encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
156                         operands);
157   valueIDMap[op.getResult()] = resultID;
158 
159   return success();
160 }
161 
processUndefOp(spirv::UndefOp op)162 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
163   auto undefType = op.getType();
164   auto &id = undefValIDMap[undefType];
165   if (!id) {
166     id = getNextID();
167     uint32_t typeID = 0;
168     if (failed(processType(op.getLoc(), undefType, typeID)))
169       return failure();
170     encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
171                           {typeID, id});
172   }
173   valueIDMap[op.getResult()] = id;
174   return success();
175 }
176 
processFuncOp(spirv::FuncOp op)177 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
178   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
179   assert(functionHeader.empty() && functionBody.empty());
180 
181   uint32_t fnTypeID = 0;
182   // Generate type of the function.
183   if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
184     return failure();
185 
186   // Add the function definition.
187   SmallVector<uint32_t, 4> operands;
188   uint32_t resTypeID = 0;
189   auto resultTypes = op.getFunctionType().getResults();
190   if (resultTypes.size() > 1) {
191     return op.emitError("cannot serialize function with multiple return types");
192   }
193   if (failed(processType(op.getLoc(),
194                          (resultTypes.empty() ? getVoidType() : resultTypes[0]),
195                          resTypeID))) {
196     return failure();
197   }
198   operands.push_back(resTypeID);
199   auto funcID = getOrCreateFunctionID(op.getName());
200   operands.push_back(funcID);
201   operands.push_back(static_cast<uint32_t>(op.function_control()));
202   operands.push_back(fnTypeID);
203   encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
204 
205   // Add function name.
206   if (failed(processName(funcID, op.getName()))) {
207     return failure();
208   }
209 
210   // Declare the parameters.
211   for (auto arg : op.getArguments()) {
212     uint32_t argTypeID = 0;
213     if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
214       return failure();
215     }
216     auto argValueID = getNextID();
217     valueIDMap[arg] = argValueID;
218     encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
219                           {argTypeID, argValueID});
220   }
221 
222   // Process the body.
223   if (op.isExternal()) {
224     return op.emitError("external function is unhandled");
225   }
226 
227   // Some instructions (e.g., OpVariable) in a function must be in the first
228   // block in the function. These instructions will be put in functionHeader.
229   // Thus, we put the label in functionHeader first, and omit it from the first
230   // block.
231   encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
232                         {getOrCreateBlockID(&op.front())});
233   if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
234     return failure();
235   if (failed(visitInPrettyBlockOrder(
236           &op.front(), [&](Block *block) { return processBlock(block); },
237           /*skipHeader=*/true))) {
238     return failure();
239   }
240 
241   // There might be OpPhi instructions who have value references needing to fix.
242   for (const auto &deferredValue : deferredPhiValues) {
243     Value value = deferredValue.first;
244     uint32_t id = getValueID(value);
245     LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
246                             << " to id = " << id << '\n');
247     assert(id && "OpPhi references undefined value!");
248     for (size_t offset : deferredValue.second)
249       functionBody[offset] = id;
250   }
251   deferredPhiValues.clear();
252 
253   LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
254                           << "' --\n");
255   // Insert OpFunctionEnd.
256   encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
257 
258   functions.append(functionHeader.begin(), functionHeader.end());
259   functions.append(functionBody.begin(), functionBody.end());
260   functionHeader.clear();
261   functionBody.clear();
262 
263   return success();
264 }
265 
processVariableOp(spirv::VariableOp op)266 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
267   SmallVector<uint32_t, 4> operands;
268   SmallVector<StringRef, 2> elidedAttrs;
269   uint32_t resultID = 0;
270   uint32_t resultTypeID = 0;
271   if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
272     return failure();
273   }
274   operands.push_back(resultTypeID);
275   resultID = getNextID();
276   valueIDMap[op.getResult()] = resultID;
277   operands.push_back(resultID);
278   auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
279   if (attr) {
280     operands.push_back(static_cast<uint32_t>(
281         attr.cast<IntegerAttr>().getValue().getZExtValue()));
282   }
283   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
284   for (auto arg : op.getODSOperands(0)) {
285     auto argID = getValueID(arg);
286     if (!argID) {
287       return emitError(op.getLoc(), "operand 0 has a use before def");
288     }
289     operands.push_back(argID);
290   }
291   if (failed(emitDebugLine(functionHeader, op.getLoc())))
292     return failure();
293   encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
294   for (auto attr : op->getAttrs()) {
295     if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
296           return attr.getName() == elided;
297         })) {
298       continue;
299     }
300     if (failed(processDecoration(op.getLoc(), resultID, attr))) {
301       return failure();
302     }
303   }
304   return success();
305 }
306 
307 LogicalResult
processGlobalVariableOp(spirv::GlobalVariableOp varOp)308 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
309   // Get TypeID.
310   uint32_t resultTypeID = 0;
311   SmallVector<StringRef, 4> elidedAttrs;
312   if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
313     return failure();
314   }
315 
316   elidedAttrs.push_back("type");
317   SmallVector<uint32_t, 4> operands;
318   operands.push_back(resultTypeID);
319   auto resultID = getNextID();
320 
321   // Encode the name.
322   auto varName = varOp.sym_name();
323   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
324   if (failed(processName(resultID, varName))) {
325     return failure();
326   }
327   globalVarIDMap[varName] = resultID;
328   operands.push_back(resultID);
329 
330   // Encode StorageClass.
331   operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
332 
333   // Encode initialization.
334   if (auto initializer = varOp.initializer()) {
335     auto initializerID = getVariableID(*initializer);
336     if (!initializerID) {
337       return emitError(varOp.getLoc(),
338                        "invalid usage of undefined variable as initializer");
339     }
340     operands.push_back(initializerID);
341     elidedAttrs.push_back("initializer");
342   }
343 
344   if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
345     return failure();
346   encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
347   elidedAttrs.push_back("initializer");
348 
349   // Encode decorations.
350   for (auto attr : varOp->getAttrs()) {
351     if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
352           return attr.getName() == elided;
353         })) {
354       continue;
355     }
356     if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
357       return failure();
358     }
359   }
360   return success();
361 }
362 
processSelectionOp(spirv::SelectionOp selectionOp)363 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
364   // Assign <id>s to all blocks so that branches inside the SelectionOp can
365   // resolve properly.
366   auto &body = selectionOp.body();
367   for (Block &block : body)
368     getOrCreateBlockID(&block);
369 
370   auto *headerBlock = selectionOp.getHeaderBlock();
371   auto *mergeBlock = selectionOp.getMergeBlock();
372   auto headerID = getBlockID(headerBlock);
373   auto mergeID = getBlockID(mergeBlock);
374   auto loc = selectionOp.getLoc();
375 
376   // This SelectionOp is in some MLIR block with preceding and following ops. In
377   // the binary format, it should reside in separate SPIR-V blocks from its
378   // preceding and following ops. So we need to emit unconditional branches to
379   // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
380   // flow afterwards.
381   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
382 
383   // Emit the selection header block, which dominates all other blocks, first.
384   // We need to emit an OpSelectionMerge instruction before the selection header
385   // block's terminator.
386   auto emitSelectionMerge = [&]() {
387     if (failed(emitDebugLine(functionBody, loc)))
388       return failure();
389     lastProcessedWasMergeInst = true;
390     encodeInstructionInto(
391         functionBody, spirv::Opcode::OpSelectionMerge,
392         {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
393     return success();
394   };
395   if (failed(
396           processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
397     return failure();
398 
399   // Process all blocks with a depth-first visitor starting from the header
400   // block. The selection header block and merge block are skipped by this
401   // visitor.
402   if (failed(visitInPrettyBlockOrder(
403           headerBlock, [&](Block *block) { return processBlock(block); },
404           /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
405     return failure();
406 
407   // There is nothing to do for the merge block in the selection, which just
408   // contains a spv.mlir.merge op, itself. But we need to have an OpLabel
409   // instruction to start a new SPIR-V block for ops following this SelectionOp.
410   // The block should use the <id> for the merge block.
411   encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
412   LLVM_DEBUG(llvm::dbgs() << "done merge ");
413   LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
414   LLVM_DEBUG(llvm::dbgs() << "\n");
415   return success();
416 }
417 
processLoopOp(spirv::LoopOp loopOp)418 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
419   // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
420   // properly. We don't need to assign for the entry block, which is just for
421   // satisfying MLIR region's structural requirement.
422   auto &body = loopOp.body();
423   for (Block &block : llvm::drop_begin(body))
424     getOrCreateBlockID(&block);
425 
426   auto *headerBlock = loopOp.getHeaderBlock();
427   auto *continueBlock = loopOp.getContinueBlock();
428   auto *mergeBlock = loopOp.getMergeBlock();
429   auto headerID = getBlockID(headerBlock);
430   auto continueID = getBlockID(continueBlock);
431   auto mergeID = getBlockID(mergeBlock);
432   auto loc = loopOp.getLoc();
433 
434   // This LoopOp is in some MLIR block with preceding and following ops. In the
435   // binary format, it should reside in separate SPIR-V blocks from its
436   // preceding and following ops. So we need to emit unconditional branches to
437   // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
438   // afterwards.
439   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
440 
441   // LoopOp's entry block is just there for satisfying MLIR's structural
442   // requirements so we omit it and start serialization from the loop header
443   // block.
444 
445   // Emit the loop header block, which dominates all other blocks, first. We
446   // need to emit an OpLoopMerge instruction before the loop header block's
447   // terminator.
448   auto emitLoopMerge = [&]() {
449     if (failed(emitDebugLine(functionBody, loc)))
450       return failure();
451     lastProcessedWasMergeInst = true;
452     encodeInstructionInto(
453         functionBody, spirv::Opcode::OpLoopMerge,
454         {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
455     return success();
456   };
457   if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
458     return failure();
459 
460   // Process all blocks with a depth-first visitor starting from the header
461   // block. The loop header block, loop continue block, and loop merge block are
462   // skipped by this visitor and handled later in this function.
463   if (failed(visitInPrettyBlockOrder(
464           headerBlock, [&](Block *block) { return processBlock(block); },
465           /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
466     return failure();
467 
468   // We have handled all other blocks. Now get to the loop continue block.
469   if (failed(processBlock(continueBlock)))
470     return failure();
471 
472   // There is nothing to do for the merge block in the loop, which just contains
473   // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to
474   // start a new SPIR-V block for ops following this LoopOp. The block should
475   // use the <id> for the merge block.
476   encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
477   LLVM_DEBUG(llvm::dbgs() << "done merge ");
478   LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
479   LLVM_DEBUG(llvm::dbgs() << "\n");
480   return success();
481 }
482 
processBranchConditionalOp(spirv::BranchConditionalOp condBranchOp)483 LogicalResult Serializer::processBranchConditionalOp(
484     spirv::BranchConditionalOp condBranchOp) {
485   auto conditionID = getValueID(condBranchOp.condition());
486   auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
487   auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
488   SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
489 
490   if (auto weights = condBranchOp.branch_weights()) {
491     for (auto val : weights->getValue())
492       arguments.push_back(val.cast<IntegerAttr>().getInt());
493   }
494 
495   if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
496     return failure();
497   encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
498                         arguments);
499   return success();
500 }
501 
processBranchOp(spirv::BranchOp branchOp)502 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
503   if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
504     return failure();
505   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
506                         {getOrCreateBlockID(branchOp.getTarget())});
507   return success();
508 }
509 
processAddressOfOp(spirv::AddressOfOp addressOfOp)510 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
511   auto varName = addressOfOp.variable();
512   auto variableID = getVariableID(varName);
513   if (!variableID) {
514     return addressOfOp.emitError("unknown result <id> for variable ")
515            << varName;
516   }
517   valueIDMap[addressOfOp.pointer()] = variableID;
518   return success();
519 }
520 
521 LogicalResult
processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp)522 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
523   auto constName = referenceOfOp.spec_const();
524   auto constID = getSpecConstID(constName);
525   if (!constID) {
526     return referenceOfOp.emitError(
527                "unknown result <id> for specialization constant ")
528            << constName;
529   }
530   valueIDMap[referenceOfOp.reference()] = constID;
531   return success();
532 }
533 
534 template <>
535 LogicalResult
processOp(spirv::EntryPointOp op)536 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
537   SmallVector<uint32_t, 4> operands;
538   // Add the ExecutionModel.
539   operands.push_back(static_cast<uint32_t>(op.execution_model()));
540   // Add the function <id>.
541   auto funcID = getFunctionID(op.fn());
542   if (!funcID) {
543     return op.emitError("missing <id> for function ")
544            << op.fn()
545            << "; function needs to be defined before spv.EntryPoint is "
546               "serialized";
547   }
548   operands.push_back(funcID);
549   // Add the name of the function.
550   spirv::encodeStringLiteralInto(operands, op.fn());
551 
552   // Add the interface values.
553   if (auto interface = op.interface()) {
554     for (auto var : interface.getValue()) {
555       auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
556       if (!id) {
557         return op.emitError("referencing undefined global variable."
558                             "spv.EntryPoint is at the end of spv.module. All "
559                             "referenced variables should already be defined");
560       }
561       operands.push_back(id);
562     }
563   }
564   encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
565   return success();
566 }
567 
568 template <>
569 LogicalResult
processOp(spirv::ControlBarrierOp op)570 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
571   StringRef argNames[] = {"execution_scope", "memory_scope",
572                           "memory_semantics"};
573   SmallVector<uint32_t, 3> operands;
574 
575   for (auto argName : argNames) {
576     auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
577     auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
578     if (!operand) {
579       return failure();
580     }
581     operands.push_back(operand);
582   }
583 
584   encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
585                         operands);
586   return success();
587 }
588 
589 template <>
590 LogicalResult
processOp(spirv::ExecutionModeOp op)591 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
592   SmallVector<uint32_t, 4> operands;
593   // Add the function <id>.
594   auto funcID = getFunctionID(op.fn());
595   if (!funcID) {
596     return op.emitError("missing <id> for function ")
597            << op.fn()
598            << "; function needs to be serialized before ExecutionModeOp is "
599               "serialized";
600   }
601   operands.push_back(funcID);
602   // Add the ExecutionMode.
603   operands.push_back(static_cast<uint32_t>(op.execution_mode()));
604 
605   // Serialize values if any.
606   auto values = op.values();
607   if (values) {
608     for (auto &intVal : values.getValue()) {
609       operands.push_back(static_cast<uint32_t>(
610           intVal.cast<IntegerAttr>().getValue().getZExtValue()));
611     }
612   }
613   encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
614                         operands);
615   return success();
616 }
617 
618 template <>
619 LogicalResult
processOp(spirv::MemoryBarrierOp op)620 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
621   StringRef argNames[] = {"memory_scope", "memory_semantics"};
622   SmallVector<uint32_t, 2> operands;
623 
624   for (auto argName : argNames) {
625     auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
626     auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
627     if (!operand) {
628       return failure();
629     }
630     operands.push_back(operand);
631   }
632 
633   encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands);
634   return success();
635 }
636 
637 template <>
638 LogicalResult
processOp(spirv::FunctionCallOp op)639 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
640   auto funcName = op.callee();
641   uint32_t resTypeID = 0;
642 
643   Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
644   if (failed(processType(op.getLoc(), resultTy, resTypeID)))
645     return failure();
646 
647   auto funcID = getOrCreateFunctionID(funcName);
648   auto funcCallID = getNextID();
649   SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
650 
651   for (auto value : op.arguments()) {
652     auto valueID = getValueID(value);
653     assert(valueID && "cannot find a value for spv.FunctionCall");
654     operands.push_back(valueID);
655   }
656 
657   if (!resultTy.isa<NoneType>())
658     valueIDMap[op.getResult(0)] = funcCallID;
659 
660   encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
661   return success();
662 }
663 
664 template <>
665 LogicalResult
processOp(spirv::CopyMemoryOp op)666 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
667   SmallVector<uint32_t, 4> operands;
668   SmallVector<StringRef, 2> elidedAttrs;
669 
670   for (Value operand : op->getOperands()) {
671     auto id = getValueID(operand);
672     assert(id && "use before def!");
673     operands.push_back(id);
674   }
675 
676   if (auto attr = op->getAttr("memory_access")) {
677     operands.push_back(static_cast<uint32_t>(
678         attr.cast<IntegerAttr>().getValue().getZExtValue()));
679   }
680 
681   elidedAttrs.push_back("memory_access");
682 
683   if (auto attr = op->getAttr("alignment")) {
684     operands.push_back(static_cast<uint32_t>(
685         attr.cast<IntegerAttr>().getValue().getZExtValue()));
686   }
687 
688   elidedAttrs.push_back("alignment");
689 
690   if (auto attr = op->getAttr("source_memory_access")) {
691     operands.push_back(static_cast<uint32_t>(
692         attr.cast<IntegerAttr>().getValue().getZExtValue()));
693   }
694 
695   elidedAttrs.push_back("source_memory_access");
696 
697   if (auto attr = op->getAttr("source_alignment")) {
698     operands.push_back(static_cast<uint32_t>(
699         attr.cast<IntegerAttr>().getValue().getZExtValue()));
700   }
701 
702   elidedAttrs.push_back("source_alignment");
703   if (failed(emitDebugLine(functionBody, op.getLoc())))
704     return failure();
705   encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
706 
707   return success();
708 }
709 
710 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
711 // various Serializer::processOp<...>() specializations.
712 #define GET_SERIALIZATION_FNS
713 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
714 
715 } // namespace spirv
716 } // namespace mlir
717