1 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Target/SPIRV/Serialization.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/RegionGraphTraits.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/ADT/bit.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 
34 #define DEBUG_TYPE "spirv-serialization"
35 
36 using namespace mlir;
37 
38 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
39 /// the given `binary` vector.
40 static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
41                                            spirv::Opcode op,
42                                            ArrayRef<uint32_t> operands) {
43   uint32_t wordCount = 1 + operands.size();
44   binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
45   binary.append(operands.begin(), operands.end());
46   return success();
47 }
48 
49 /// A pre-order depth-first visitor function for processing basic blocks.
50 ///
51 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
52 /// depth-first manner and calls `blockHandler` on each block. Skips handling
53 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
54 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
55 /// successors.
56 ///
57 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
58 /// of blocks in a function must satisfy the rule that blocks appear before
59 /// all blocks they dominate." This can be achieved by a pre-order CFG
60 /// traversal algorithm. To make the serialization output more logical and
61 /// readable to human, we perform depth-first CFG traversal and delay the
62 /// serialization of the merge block and the continue block, if exists, until
63 /// after all other blocks have been processed.
64 static LogicalResult
65 visitInPrettyBlockOrder(Block *headerBlock,
66                         function_ref<LogicalResult(Block *)> blockHandler,
67                         bool skipHeader = false, BlockRange skipBlocks = {}) {
68   llvm::df_iterator_default_set<Block *, 4> doneBlocks;
69   doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
70 
71   for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
72     if (skipHeader && block == headerBlock)
73       continue;
74     if (failed(blockHandler(block)))
75       return failure();
76   }
77   return success();
78 }
79 
80 /// Returns the merge block if the given `op` is a structured control flow op.
81 /// Otherwise returns nullptr.
82 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
83   if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
84     return selectionOp.getMergeBlock();
85   if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
86     return loopOp.getMergeBlock();
87   return nullptr;
88 }
89 
90 /// Given a predecessor `block` for a block with arguments, returns the block
91 /// that should be used as the parent block for SPIR-V OpPhi instructions
92 /// corresponding to the block arguments.
93 static Block *getPhiIncomingBlock(Block *block) {
94   // If the predecessor block in question is the entry block for a spv.loop,
95   // we jump to this spv.loop from its enclosing block.
96   if (block->isEntryBlock()) {
97     if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
98       // Then the incoming parent block for OpPhi should be the merge block of
99       // the structured control flow op before this loop.
100       Operation *op = loopOp.getOperation();
101       while ((op = op->getPrevNode()) != nullptr)
102         if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
103           return incomingBlock;
104       // Or the enclosing block itself if no structured control flow ops
105       // exists before this loop.
106       return loopOp->getBlock();
107     }
108   }
109 
110   // Otherwise, we jump from the given predecessor block. Try to see if there is
111   // a structured control flow op inside it.
112   for (Operation &op : llvm::reverse(block->getOperations())) {
113     if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
114       return incomingBlock;
115   }
116   return block;
117 }
118 
namespace {

/// A SPIR-V module serializer.
///
/// A SPIR-V binary module is a single linear stream of instructions; each
/// instruction is composed of 32-bit words with the layout:
///
///   | <word-count>|<opcode> |  <operand>   |  <operand>   | ... |
///   | <------ word -------> | <-- word --> | <-- word --> | ... |
///
/// For the first word, the 16 high-order bits are the word count of the
/// instruction, the 16 low-order bits are the opcode enumerant. The
/// instructions then belong to different sections, which must be laid out in
/// the particular order as specified in "2.4 Logical Layout of a Module" of
/// the SPIR-V spec.
///
/// serialize() fills the per-section buffers declared below; collect() then
/// concatenates them, after the module header, into the final binary.
class Serializer {
public:
  /// Creates a serializer for the given SPIR-V `module`. `emitDebugInfo`
  /// controls whether debug information (such as an OpString naming the
  /// source file) is emitted.
  explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false);

  /// Serializes the remembered SPIR-V module into the internal section
  /// buffers. Returns failure if the module does not verify or if any op in
  /// its body cannot be serialized.
  LogicalResult serialize();

  /// Collects the final SPIR-V `binary`. Any previous contents of `binary`
  /// are discarded.
  void collect(SmallVectorImpl<uint32_t> &binary);

#ifndef NDEBUG
  /// (For debugging) prints each value and its corresponding result <id>.
  void printValueIDMap(raw_ostream &os);
#endif

private:
  // Note that there are two main categories of methods in this class:
  // * process*() methods are meant to fully serialize a SPIR-V module entity
  //   (header, type, op, etc.). They update internal vectors containing
  //   different binary sections. They are mostly driven by the top-level
  //   serialization loop, although some of them (e.g. processType() and
  //   processName()) are also invoked by other process*() methods.
  // * prepare*() methods are meant to be helpers that prepare for serializing
  //   a certain entity. They may or may not update internal vectors containing
  //   different binary sections. They are meant to be called among themselves
  //   or by other process*() methods for subtasks.

  //===--------------------------------------------------------------------===//
  // <id>
  //===--------------------------------------------------------------------===//

  // Note that it is illegal to use id <0> in a SPIR-V binary module. Various
  // methods in this class, if using SPIR-V word (uint32_t) as interface,
  // check or return id <0> to indicate error in processing.

  /// Consumes the next unused <id>. This method will never return 0 because
  /// `nextID` starts at 1 and only increases.
  uint32_t getNextID() { return nextID++; }

  //===--------------------------------------------------------------------===//
  // Module structure
  //===--------------------------------------------------------------------===//

  /// Returns the <id> recorded for the specialization constant named
  /// `constName`, or 0 if none has been recorded.
  uint32_t getSpecConstID(StringRef constName) const {
    return specConstIDMap.lookup(constName);
  }

  /// Returns the <id> recorded for the global variable named `varName`, or 0
  /// if none has been recorded.
  uint32_t getVariableID(StringRef varName) const {
    return globalVarIDMap.lookup(varName);
  }

  /// Returns the <id> recorded for the function named `fnName`, or 0 if none
  /// has been recorded.
  uint32_t getFunctionID(StringRef fnName) const {
    return funcIDMap.lookup(fnName);
  }

  /// Gets the <id> for the function with the given name. Assigns (and
  /// records) the next available <id> if no <id> has been assigned to that
  /// function yet.
  uint32_t getOrCreateFunctionID(StringRef fnName);

  /// Emits an OpCapability for each capability in the module's
  /// version/capability/extension triple.
  void processCapability();

  /// When debug info is enabled, emits an OpString holding the module's
  /// source file name and records its <id> in `fileID`.
  void processDebugInfo();

  /// Emits an OpExtension for each extension in the module's
  /// version/capability/extension triple.
  void processExtension();

  /// Emits OpMemoryModel from the module's `addressing_model` and
  /// `memory_model` attributes.
  void processMemoryModel();

  /// Serializes the value of a spirv::ConstantOp and maps its result to the
  /// constant's <id>.
  LogicalResult processConstantOp(spirv::ConstantOp op);

  /// Serializes a scalar specialization constant, plus its optional SpecId
  /// decoration and its OpName.
  LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);

  /// Emits an OpSpecConstantComposite whose constituents are previously
  /// serialized specialization constants referenced by symbol name.
  LogicalResult
  processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);

  /// Emits an OpSpecConstantOperation for the op enclosed in `op`'s region.
  LogicalResult
  processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);

  /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces an SSA
  /// value to use with other operations. The SPIR-V spec recommends that
  /// OpUndef be generated at module level. The serialization generates an
  /// OpUndef for each type needed at module level.
  LogicalResult processUndefOp(spirv::UndefOp op);

  /// Emit OpName for the given `resultID`.
  LogicalResult processName(uint32_t resultID, StringRef name);

  /// Processes a SPIR-V function op.
  LogicalResult processFuncOp(spirv::FuncOp op);

  /// Processes a spirv::VariableOp.
  LogicalResult processVariableOp(spirv::VariableOp op);

  /// Process a SPIR-V GlobalVariableOp.
  LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);

  /// Process attributes that translate to decorations on the result <id>.
  LogicalResult processDecoration(Location loc, uint32_t resultID,
                                  NamedAttribute attr);

  /// Generic fallback for decorating a type: always emits an error.
  /// NOTE(review): specializations for the supported types are presumably
  /// defined elsewhere in this file; confirm before relying on this.
  template <typename DType>
  LogicalResult processTypeDecoration(Location loc, DType type,
                                      uint32_t resultId) {
    return emitError(loc, "unhandled decoration for type:") << type;
  }

  /// Process member decoration: emits the decoration described by
  /// `memberDecorationInfo` for the struct type whose <id> is `structID`.
  LogicalResult processMemberDecoration(
      uint32_t structID,
      const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);

  //===--------------------------------------------------------------------===//
  // Types
  //===--------------------------------------------------------------------===//

  /// Returns the <id> of `type`, or 0 if it has not been serialized yet.
  uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); }

  /// This serializer represents the SPIR-V void type with MLIR's NoneType.
  Type getVoidType() { return mlirBuilder.getNoneType(); }

  /// Returns true if `type` is the NoneType used to represent SPIR-V void.
  bool isVoidType(Type type) const { return type.isa<NoneType>(); }

  /// Returns true if the given type is a pointer type to a struct in some
  /// interface storage class.
  bool isInterfaceStructPtrType(Type type) const;

  /// Main dispatch method for serializing a type. The result <id> of the
  /// serialized type will be returned as `typeID`.
  LogicalResult processType(Location loc, Type type, uint32_t &typeID);
  /// Implementation of processType(). NOTE(review): `serializationCtx`
  /// presumably tracks the struct types currently being serialized so that
  /// recursive struct references can be detected; confirm in the definition.
  LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID,
                                llvm::SetVector<StringRef> &serializationCtx);

  /// Method for preparing basic SPIR-V type serialization. Returns the type's
  /// opcode and operands for the instruction via `typeEnum` and `operands`.
  /// NOTE(review): `deferSerialization` is an output flag; presumably it is
  /// set when emission must wait (e.g. for recursive structs). Confirm.
  LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
                                 spirv::Opcode &typeEnum,
                                 SmallVectorImpl<uint32_t> &operands,
                                 bool &deferSerialization,
                                 llvm::SetVector<StringRef> &serializationCtx);

  /// Prepares the opcode and operands for serializing a function type.
  LogicalResult prepareFunctionType(Location loc, FunctionType type,
                                    spirv::Opcode &typeEnum,
                                    SmallVectorImpl<uint32_t> &operands);

  //===--------------------------------------------------------------------===//
  // Constant
  //===--------------------------------------------------------------------===//

  /// Returns the <id> of the constant `value`, or 0 if none is recorded.
  uint32_t getConstantID(Attribute value) const {
    return constIDMap.lookup(value);
  }

  /// Main dispatch method for processing a constant with the given `constType`
  /// and `valueAttr`. `constType` is needed here because we can interpret the
  /// `valueAttr` as a different type than the type of `valueAttr` itself; for
  /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
  /// constants. Callers treat a returned 0 as failure.
  uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);

  /// Prepares array attribute serialization. This method emits corresponding
  /// OpConstant* and returns the result <id> associated with it. Returns 0 if
  /// failed.
  uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);

  /// Prepares bool/int/float DenseElementsAttr serialization. This method
  /// iterates the DenseElementsAttr to construct the constant array, and
  /// returns the result <id> associated with it. Returns 0 if failed. Note
  /// that the size of `index` must match the rank.
  /// TODO: Consider enhancing splat element cases. For splat cases,
  /// we don't need to loop over all elements, especially when the splat value
  /// is zero. We can use OpConstantNull when the value is zero.
  uint32_t prepareDenseElementsConstant(Location loc, Type constType,
                                        DenseElementsAttr valueAttr, int dim,
                                        MutableArrayRef<uint64_t> index);

  /// Prepares scalar attribute serialization. This method emits corresponding
  /// OpConstant* and returns the result <id> associated with it. Returns 0 if
  /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
  /// true, then the constant will be serialized as a specialization constant.
  uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
                                 bool isSpec = false);

  /// Bool case of prepareConstantScalar().
  uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
                               bool isSpec = false);

  /// Integer case of prepareConstantScalar().
  uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
                              bool isSpec = false);

  /// Floating-point case of prepareConstantScalar().
  uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
                             bool isSpec = false);

  //===--------------------------------------------------------------------===//
  // Control flow
  //===--------------------------------------------------------------------===//

  /// Returns the result <id> for the given block, or 0 if none is assigned.
  uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); }

  /// Returns the result <id> for the given block. If no <id> has been assigned,
  /// assigns the next available <id>.
  uint32_t getOrCreateBlockID(Block *block);

  /// Processes the given `block` and emits SPIR-V instructions for all ops
  /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
  /// `actionBeforeTerminator` is a callback that will be invoked before
  /// handling the terminator op. It can be used to inject the Op*Merge
  /// instruction if this is a SPIR-V selection/loop header block.
  LogicalResult
  processBlock(Block *block, bool omitLabel = false,
               function_ref<void()> actionBeforeTerminator = nullptr);

  /// Emits OpPhi instructions for the given block if it has block arguments.
  LogicalResult emitPhiForBlockArguments(Block *block);

  /// Serializes a spv.selection structured control flow op.
  LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);

  /// Serializes a spv.loop structured control flow op.
  LogicalResult processLoopOp(spirv::LoopOp loopOp);

  /// Serializes a conditional branch terminator.
  LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);

  /// Serializes an unconditional branch terminator.
  LogicalResult processBranchOp(spirv::BranchOp branchOp);

  //===--------------------------------------------------------------------===//
  // Operations
  //===--------------------------------------------------------------------===//

  /// Encodes `op` as instruction `opcode` of the extended instruction set
  /// named `extensionSetName`, with the given `operands`.
  LogicalResult encodeExtensionInstruction(Operation *op,
                                           StringRef extensionSetName,
                                           uint32_t opcode,
                                           ArrayRef<uint32_t> operands);

  /// Returns the <id> assigned to `val`, or 0 if none has been assigned yet.
  uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); }

  /// Serializes a spirv::AddressOfOp.
  LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);

  /// Serializes a spirv::ReferenceOfOp.
  LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);

  /// Main dispatch method for serializing an operation.
  LogicalResult processOperation(Operation *op);

  /// Serializes an operation `op` as core instruction with `opcode` if
  /// `extInstSet` is empty. Otherwise serializes it as an extended instruction
  /// with `opcode` from `extInstSet`.
  /// This method is a generic one for dispatching any SPIR-V ops that have no
  /// variadic operands and attributes in TableGen definitions.
  LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet,
                                            uint32_t opcode);

  /// Dispatches to the serialization function for an operation in SPIR-V
  /// dialect that is a mirror of an instruction in the SPIR-V spec. This is
  /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V
  /// dialect that have hasOpcode == 1.
  LogicalResult dispatchToAutogenSerialization(Operation *op);

  /// Serializes an operation in the SPIR-V dialect that is a mirror of an
  /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1
  /// and autogenSerialization == 1 in ODS. The generic template below is the
  /// fallback and always emits an error.
  template <typename OpTy>
  LogicalResult processOp(OpTy op) {
    return op.emitError("unsupported op serialization");
  }

  //===--------------------------------------------------------------------===//
  // Utilities
  //===--------------------------------------------------------------------===//

  /// Emits an OpDecorate instruction to decorate the given `target` with the
  /// given `decoration`.
  LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
                               ArrayRef<uint32_t> params = {});

  /// Emits an OpLine instruction with the given `loc` location information into
  /// the given `binary` vector.
  LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc);

private:
  /// The SPIR-V module to be serialized.
  spirv::ModuleOp module;

  /// An MLIR builder for getting MLIR constructs.
  mlir::Builder mlirBuilder;

  /// A flag which indicates if the debuginfo should be emitted.
  bool emitDebugInfo = false;

  /// A flag which indicates if the last processed instruction was a merge
  /// instruction.
  /// According to SPIR-V spec: "If a branch merge instruction is used, the last
  /// OpLine in the block must be before its merge instruction".
  bool lastProcessedWasMergeInst = false;

  /// The <id> of the OpString instruction, which specifies a file name, for
  /// use by other debug instructions. 0 until processDebugInfo() assigns it.
  uint32_t fileID = 0;

  /// The next available result <id>. collect() also writes it into the module
  /// header.
  uint32_t nextID = 1;

  // The following are for different SPIR-V instruction sections. They follow
  // the logical layout of a SPIR-V module; collect() appends them in this
  // order.

  SmallVector<uint32_t, 4> capabilities;
  SmallVector<uint32_t, 0> extensions;
  SmallVector<uint32_t, 0> extendedSets;
  SmallVector<uint32_t, 3> memoryModel;
  SmallVector<uint32_t, 0> entryPoints;
  SmallVector<uint32_t, 4> executionModes;
  SmallVector<uint32_t, 0> debug;
  SmallVector<uint32_t, 0> names;
  SmallVector<uint32_t, 0> decorations;
  SmallVector<uint32_t, 0> typesGlobalValues;
  SmallVector<uint32_t, 0> functions;

  /// Recursive struct references are serialized as OpTypePointer instructions
  /// to the recursive struct type. However, the OpTypePointer instruction
  /// cannot be emitted before the recursive struct's OpTypeStruct.
  /// RecursiveStructPointerInfo stores the data needed to emit such
  /// OpTypePointer instructions after forward references to such types.
  struct RecursiveStructPointerInfo {
    uint32_t pointerTypeID;
    spirv::StorageClass storageClass;
  };

  /// Maps spirv::StructType to its recursive reference member info.
  DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>>
      recursiveStructInfos;

  /// `functionHeader` contains all the instructions that must be in the first
  /// block in the function, and `functionBody` contains the rest. After
  /// processing FuncOp, the encoded instructions of a function are appended to
  /// `functions`. An example of instructions in `functionHeader` in order:
  /// OpFunction ...
  /// OpFunctionParameter ...
  /// OpFunctionParameter ...
  /// OpLabel ...
  /// OpVariable ...
  /// OpVariable ...
  SmallVector<uint32_t, 0> functionHeader;
  SmallVector<uint32_t, 0> functionBody;

  /// Map from type used in SPIR-V module to their <id>s.
  DenseMap<Type, uint32_t> typeIDMap;

  /// Map from constant values to their <id>s.
  DenseMap<Attribute, uint32_t> constIDMap;

  /// Map from specialization constant names to their <id>s.
  llvm::StringMap<uint32_t> specConstIDMap;

  /// Map from GlobalVariableOps name to <id>s.
  llvm::StringMap<uint32_t> globalVarIDMap;

  /// Map from FuncOps name to <id>s.
  llvm::StringMap<uint32_t> funcIDMap;

  /// Map from blocks to their <id>s.
  DenseMap<Block *, uint32_t> blockIDMap;

  /// Map from the Type to the <id> that represents undef value of that type.
  DenseMap<Type, uint32_t> undefValIDMap;

  /// Map from results of normal operations to their <id>s.
  DenseMap<Value, uint32_t> valueIDMap;

  /// Map from extended instruction set name to <id>s.
  llvm::StringMap<uint32_t> extendedInstSetIDMap;

  /// Map from values used in OpPhi instructions to their offset in the
  /// `functions` section.
  ///
  /// When processing a block with arguments, we need to emit OpPhi
  /// instructions to record the predecessor block <id>s and the values they
  /// send to the block in question. But it's not guaranteed all values are
  /// visited and thus assigned result <id>s. So we need this list to capture
  /// the offsets into `functions` where a value is used so that we can fix it
  /// up later after processing all the blocks in a function.
  ///
  /// More concretely, say if we are visiting the following blocks:
  ///
  /// ```mlir
  /// ^phi(%arg0: i32):
  ///   ...
  /// ^parent1:
  ///   ...
  ///   spv.Branch ^phi(%val0: i32)
  /// ^parent2:
  ///   ...
  ///   spv.Branch ^phi(%val1: i32)
  /// ```
  ///
  /// When we are serializing the `^phi` block, we need to emit at the beginning
  /// of the block OpPhi instructions with the following parameters:
  ///
  /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
  ///                               id-for-%val1 id-for-^parent2
  ///
  /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
  /// all the blocks twice and use the first visit to assign an <id> to each
  /// value. But it's paying the overheads just for OpPhi emission. Instead,
  /// we still visit the blocks once for emission. When we emit the OpPhi
  /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
  /// At the same time, we record their offsets in the emitted binary (which is
  /// placed inside `functions`) here. And then after emitting all blocks, we
  /// replace the dummy <id> 0 with the real result <id> by overwriting
  /// `functions[offset]`.
  DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues;
};
} // namespace
538 
539 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
540     : module(module), mlirBuilder(module.getContext()),
541       emitDebugInfo(emitDebugInfo) {}
542 
543 LogicalResult Serializer::serialize() {
544   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
545 
546   if (failed(module.verify()))
547     return failure();
548 
549   // TODO: handle the other sections
550   processCapability();
551   processExtension();
552   processMemoryModel();
553   processDebugInfo();
554 
555   // Iterate over the module body to serialize it. Assumptions are that there is
556   // only one basic block in the moduleOp
557   for (auto &op : module.getBlock()) {
558     if (failed(processOperation(&op))) {
559       return failure();
560     }
561   }
562 
563   LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
564   return success();
565 }
566 
567 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
568   auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
569                     extensions.size() + extendedSets.size() +
570                     memoryModel.size() + entryPoints.size() +
571                     executionModes.size() + decorations.size() +
572                     typesGlobalValues.size() + functions.size();
573 
574   binary.clear();
575   binary.reserve(moduleSize);
576 
577   spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
578   binary.append(capabilities.begin(), capabilities.end());
579   binary.append(extensions.begin(), extensions.end());
580   binary.append(extendedSets.begin(), extendedSets.end());
581   binary.append(memoryModel.begin(), memoryModel.end());
582   binary.append(entryPoints.begin(), entryPoints.end());
583   binary.append(executionModes.begin(), executionModes.end());
584   binary.append(debug.begin(), debug.end());
585   binary.append(names.begin(), names.end());
586   binary.append(decorations.begin(), decorations.end());
587   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
588   binary.append(functions.begin(), functions.end());
589 }
590 
591 #ifndef NDEBUG
592 void Serializer::printValueIDMap(raw_ostream &os) {
593   os << "\n= Value <id> Map =\n\n";
594   for (auto valueIDPair : valueIDMap) {
595     Value val = valueIDPair.first;
596     os << "  " << val << " "
597        << "id = " << valueIDPair.second << ' ';
598     if (auto *op = val.getDefiningOp()) {
599       os << "from op '" << op->getName() << "'";
600     } else if (auto arg = val.dyn_cast<BlockArgument>()) {
601       Block *block = arg.getOwner();
602       os << "from argument of block " << block << ' ';
603       os << " in op '" << block->getParentOp()->getName() << "'";
604     }
605     os << '\n';
606   }
607 }
608 #endif
609 
610 //===----------------------------------------------------------------------===//
611 // Module structure
612 //===----------------------------------------------------------------------===//
613 
614 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
615   auto funcID = funcIDMap.lookup(fnName);
616   if (!funcID) {
617     funcID = getNextID();
618     funcIDMap[fnName] = funcID;
619   }
620   return funcID;
621 }
622 
623 void Serializer::processCapability() {
624   for (auto cap : module.vce_triple()->getCapabilities())
625     encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
626                           {static_cast<uint32_t>(cap)});
627 }
628 
629 void Serializer::processDebugInfo() {
630   if (!emitDebugInfo)
631     return;
632   auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
633   auto fileName = fileLoc ? fileLoc.getFilename() : "<unknown>";
634   fileID = getNextID();
635   SmallVector<uint32_t, 16> operands;
636   operands.push_back(fileID);
637   spirv::encodeStringLiteralInto(operands, fileName);
638   encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
639   // TODO: Encode more debug instructions.
640 }
641 
642 void Serializer::processExtension() {
643   llvm::SmallVector<uint32_t, 16> extName;
644   for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
645     extName.clear();
646     spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
647     encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
648   }
649 }
650 
651 void Serializer::processMemoryModel() {
652   uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
653   uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
654 
655   encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
656 }
657 
658 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
659   if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
660     valueIDMap[op.getResult()] = resultID;
661     return success();
662   }
663   return failure();
664 }
665 
666 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
667   if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
668                                             /*isSpec=*/true)) {
669     // Emit the OpDecorate instruction for SpecId.
670     if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
671       auto val = static_cast<uint32_t>(specID.getInt());
672       emitDecoration(resultID, spirv::Decoration::SpecId, {val});
673     }
674 
675     specConstIDMap[op.sym_name()] = resultID;
676     return processName(resultID, op.sym_name());
677   }
678   return failure();
679 }
680 
681 LogicalResult
682 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
683   uint32_t typeID = 0;
684   if (failed(processType(op.getLoc(), op.type(), typeID))) {
685     return failure();
686   }
687 
688   auto resultID = getNextID();
689 
690   SmallVector<uint32_t, 8> operands;
691   operands.push_back(typeID);
692   operands.push_back(resultID);
693 
694   auto constituents = op.constituents();
695 
696   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
697     auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
698 
699     auto constituentName = constituent.getValue();
700     auto constituentID = getSpecConstID(constituentName);
701 
702     if (!constituentID) {
703       return op.emitError("unknown result <id> for specialization constant ")
704              << constituentName;
705     }
706 
707     operands.push_back(constituentID);
708   }
709 
710   encodeInstructionInto(typesGlobalValues,
711                         spirv::Opcode::OpSpecConstantComposite, operands);
712   specConstIDMap[op.sym_name()] = resultID;
713 
714   return processName(resultID, op.sym_name());
715 }
716 
717 LogicalResult
718 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
719   uint32_t typeID = 0;
720   if (failed(processType(op.getLoc(), op.getType(), typeID))) {
721     return failure();
722   }
723 
724   auto resultID = getNextID();
725 
726   SmallVector<uint32_t, 8> operands;
727   operands.push_back(typeID);
728   operands.push_back(resultID);
729 
730   Block &block = op.getRegion().getBlocks().front();
731   Operation &enclosedOp = block.getOperations().front();
732 
733   std::string enclosedOpName;
734   llvm::raw_string_ostream rss(enclosedOpName);
735   rss << "Op" << enclosedOp.getName().stripDialect();
736   auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
737 
738   if (!enclosedOpcode) {
739     op.emitError("Couldn't find op code for op ")
740         << enclosedOp.getName().getStringRef();
741     return failure();
742   }
743 
744   operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue()));
745 
746   // Append operands to the enclosed op to the list of operands.
747   for (Value operand : enclosedOp.getOperands()) {
748     uint32_t id = getValueID(operand);
749     assert(id && "use before def!");
750     operands.push_back(id);
751   }
752 
753   encodeInstructionInto(typesGlobalValues,
754                         spirv::Opcode::OpSpecConstantOperation, operands);
755   valueIDMap[op.getResult()] = resultID;
756 
757   return success();
758 }
759 
760 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
761   auto undefType = op.getType();
762   auto &id = undefValIDMap[undefType];
763   if (!id) {
764     id = getNextID();
765     uint32_t typeID = 0;
766     if (failed(processType(op.getLoc(), undefType, typeID)) ||
767         failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
768                                      {typeID, id}))) {
769       return failure();
770     }
771   }
772   valueIDMap[op.getResult()] = id;
773   return success();
774 }
775 
776 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
777                                             NamedAttribute attr) {
778   auto attrName = attr.first.strref();
779   auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
780   auto decoration = spirv::symbolizeDecoration(decorationName);
781   if (!decoration) {
782     return emitError(
783                loc, "non-argument attributes expected to have snake-case-ified "
784                     "decoration name, unhandled attribute with name : ")
785            << attrName;
786   }
787   SmallVector<uint32_t, 1> args;
788   switch (decoration.getValue()) {
789   case spirv::Decoration::Binding:
790   case spirv::Decoration::DescriptorSet:
791   case spirv::Decoration::Location:
792     if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
793       args.push_back(intAttr.getValue().getZExtValue());
794       break;
795     }
796     return emitError(loc, "expected integer attribute for ") << attrName;
797   case spirv::Decoration::BuiltIn:
798     if (auto strAttr = attr.second.dyn_cast<StringAttr>()) {
799       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
800       if (enumVal) {
801         args.push_back(static_cast<uint32_t>(enumVal.getValue()));
802         break;
803       }
804       return emitError(loc, "invalid ")
805              << attrName << " attribute " << strAttr.getValue();
806     }
807     return emitError(loc, "expected string attribute for ") << attrName;
808   case spirv::Decoration::Aliased:
809   case spirv::Decoration::Flat:
810   case spirv::Decoration::NonReadable:
811   case spirv::Decoration::NonWritable:
812   case spirv::Decoration::NoPerspective:
813   case spirv::Decoration::Restrict:
814     // For unit attributes, the args list has no values so we do nothing
815     if (auto unitAttr = attr.second.dyn_cast<UnitAttr>())
816       break;
817     return emitError(loc, "expected unit attribute for ") << attrName;
818   default:
819     return emitError(loc, "unhandled decoration ") << decorationName;
820   }
821   return emitDecoration(resultID, decoration.getValue(), args);
822 }
823 
824 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
825   assert(!name.empty() && "unexpected empty string for OpName");
826 
827   SmallVector<uint32_t, 4> nameOperands;
828   nameOperands.push_back(resultID);
829   if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
830     return failure();
831   }
832   return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
833 }
834 
835 namespace {
836 template <>
837 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
838     Location loc, spirv::ArrayType type, uint32_t resultID) {
839   if (unsigned stride = type.getArrayStride()) {
840     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
841     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
842   }
843   return success();
844 }
845 
846 template <>
847 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
848     Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) {
849   if (unsigned stride = type.getArrayStride()) {
850     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
851     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
852   }
853   return success();
854 }
855 
856 LogicalResult Serializer::processMemberDecoration(
857     uint32_t structID,
858     const spirv::StructType::MemberDecorationInfo &memberDecoration) {
859   SmallVector<uint32_t, 4> args(
860       {structID, memberDecoration.memberIndex,
861        static_cast<uint32_t>(memberDecoration.decoration)});
862   if (memberDecoration.hasValue) {
863     args.push_back(memberDecoration.decorationValue);
864   }
865   return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
866                                args);
867 }
868 } // namespace
869 
870 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
871   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
872   assert(functionHeader.empty() && functionBody.empty());
873 
874   uint32_t fnTypeID = 0;
875   // Generate type of the function.
876   processType(op.getLoc(), op.getType(), fnTypeID);
877 
878   // Add the function definition.
879   SmallVector<uint32_t, 4> operands;
880   uint32_t resTypeID = 0;
881   auto resultTypes = op.getType().getResults();
882   if (resultTypes.size() > 1) {
883     return op.emitError("cannot serialize function with multiple return types");
884   }
885   if (failed(processType(op.getLoc(),
886                          (resultTypes.empty() ? getVoidType() : resultTypes[0]),
887                          resTypeID))) {
888     return failure();
889   }
890   operands.push_back(resTypeID);
891   auto funcID = getOrCreateFunctionID(op.getName());
892   operands.push_back(funcID);
893   operands.push_back(static_cast<uint32_t>(op.function_control()));
894   operands.push_back(fnTypeID);
895   encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
896 
897   // Add function name.
898   if (failed(processName(funcID, op.getName()))) {
899     return failure();
900   }
901 
902   // Declare the parameters.
903   for (auto arg : op.getArguments()) {
904     uint32_t argTypeID = 0;
905     if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
906       return failure();
907     }
908     auto argValueID = getNextID();
909     valueIDMap[arg] = argValueID;
910     encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
911                           {argTypeID, argValueID});
912   }
913 
914   // Process the body.
915   if (op.isExternal()) {
916     return op.emitError("external function is unhandled");
917   }
918 
919   // Some instructions (e.g., OpVariable) in a function must be in the first
920   // block in the function. These instructions will be put in functionHeader.
921   // Thus, we put the label in functionHeader first, and omit it from the first
922   // block.
923   encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
924                         {getOrCreateBlockID(&op.front())});
925   processBlock(&op.front(), /*omitLabel=*/true);
926   if (failed(visitInPrettyBlockOrder(
927           &op.front(), [&](Block *block) { return processBlock(block); },
928           /*skipHeader=*/true))) {
929     return failure();
930   }
931 
932   // There might be OpPhi instructions who have value references needing to fix.
933   for (auto deferredValue : deferredPhiValues) {
934     Value value = deferredValue.first;
935     uint32_t id = getValueID(value);
936     LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
937                             << " to id = " << id << '\n');
938     assert(id && "OpPhi references undefined value!");
939     for (size_t offset : deferredValue.second)
940       functionBody[offset] = id;
941   }
942   deferredPhiValues.clear();
943 
944   LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
945                           << "' --\n");
946   // Insert OpFunctionEnd.
947   if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd,
948                                    {}))) {
949     return failure();
950   }
951 
952   functions.append(functionHeader.begin(), functionHeader.end());
953   functions.append(functionBody.begin(), functionBody.end());
954   functionHeader.clear();
955   functionBody.clear();
956 
957   return success();
958 }
959 
960 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
961   SmallVector<uint32_t, 4> operands;
962   SmallVector<StringRef, 2> elidedAttrs;
963   uint32_t resultID = 0;
964   uint32_t resultTypeID = 0;
965   if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
966     return failure();
967   }
968   operands.push_back(resultTypeID);
969   resultID = getNextID();
970   valueIDMap[op.getResult()] = resultID;
971   operands.push_back(resultID);
972   auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
973   if (attr) {
974     operands.push_back(static_cast<uint32_t>(
975         attr.cast<IntegerAttr>().getValue().getZExtValue()));
976   }
977   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
978   for (auto arg : op.getODSOperands(0)) {
979     auto argID = getValueID(arg);
980     if (!argID) {
981       return emitError(op.getLoc(), "operand 0 has a use before def");
982     }
983     operands.push_back(argID);
984   }
985   emitDebugLine(functionHeader, op.getLoc());
986   encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
987   for (auto attr : op->getAttrs()) {
988     if (llvm::any_of(elidedAttrs,
989                      [&](StringRef elided) { return attr.first == elided; })) {
990       continue;
991     }
992     if (failed(processDecoration(op.getLoc(), resultID, attr))) {
993       return failure();
994     }
995   }
996   return success();
997 }
998 
999 LogicalResult
1000 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
1001   // Get TypeID.
1002   uint32_t resultTypeID = 0;
1003   SmallVector<StringRef, 4> elidedAttrs;
1004   if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
1005     return failure();
1006   }
1007 
1008   if (isInterfaceStructPtrType(varOp.type())) {
1009     auto structType = varOp.type()
1010                           .cast<spirv::PointerType>()
1011                           .getPointeeType()
1012                           .cast<spirv::StructType>();
1013     if (failed(
1014             emitDecoration(getTypeID(structType), spirv::Decoration::Block))) {
1015       return varOp.emitError("cannot decorate ")
1016              << structType << " with Block decoration";
1017     }
1018   }
1019 
1020   elidedAttrs.push_back("type");
1021   SmallVector<uint32_t, 4> operands;
1022   operands.push_back(resultTypeID);
1023   auto resultID = getNextID();
1024 
1025   // Encode the name.
1026   auto varName = varOp.sym_name();
1027   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1028   if (failed(processName(resultID, varName))) {
1029     return failure();
1030   }
1031   globalVarIDMap[varName] = resultID;
1032   operands.push_back(resultID);
1033 
1034   // Encode StorageClass.
1035   operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
1036 
1037   // Encode initialization.
1038   if (auto initializer = varOp.initializer()) {
1039     auto initializerID = getVariableID(initializer.getValue());
1040     if (!initializerID) {
1041       return emitError(varOp.getLoc(),
1042                        "invalid usage of undefined variable as initializer");
1043     }
1044     operands.push_back(initializerID);
1045     elidedAttrs.push_back("initializer");
1046   }
1047 
1048   emitDebugLine(typesGlobalValues, varOp.getLoc());
1049   if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable,
1050                                    operands))) {
1051     elidedAttrs.push_back("initializer");
1052     return failure();
1053   }
1054 
1055   // Encode decorations.
1056   for (auto attr : varOp->getAttrs()) {
1057     if (llvm::any_of(elidedAttrs,
1058                      [&](StringRef elided) { return attr.first == elided; })) {
1059       continue;
1060     }
1061     if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
1062       return failure();
1063     }
1064   }
1065   return success();
1066 }
1067 
1068 //===----------------------------------------------------------------------===//
1069 // Type
1070 //===----------------------------------------------------------------------===//
1071 
1072 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
1073 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
1074 // PushConstant Storage Classes must be explicitly laid out."
1075 bool Serializer::isInterfaceStructPtrType(Type type) const {
1076   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
1077     switch (ptrType.getStorageClass()) {
1078     case spirv::StorageClass::PhysicalStorageBuffer:
1079     case spirv::StorageClass::PushConstant:
1080     case spirv::StorageClass::StorageBuffer:
1081     case spirv::StorageClass::Uniform:
1082       return ptrType.getPointeeType().isa<spirv::StructType>();
1083     default:
1084       break;
1085     }
1086   }
1087   return false;
1088 }
1089 
1090 LogicalResult Serializer::processType(Location loc, Type type,
1091                                       uint32_t &typeID) {
1092   // Maintains a set of names for nested identified struct types. This is used
1093   // to properly serialize recursive references.
1094   llvm::SetVector<StringRef> serializationCtx;
1095   return processTypeImpl(loc, type, typeID, serializationCtx);
1096 }
1097 
1098 LogicalResult
1099 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
1100                             llvm::SetVector<StringRef> &serializationCtx) {
1101   typeID = getTypeID(type);
1102   if (typeID) {
1103     return success();
1104   }
1105   typeID = getNextID();
1106   SmallVector<uint32_t, 4> operands;
1107 
1108   operands.push_back(typeID);
1109   auto typeEnum = spirv::Opcode::OpTypeVoid;
1110   bool deferSerialization = false;
1111 
1112   if ((type.isa<FunctionType>() &&
1113        succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
1114                                      operands))) ||
1115       succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
1116                                  deferSerialization, serializationCtx))) {
1117     if (deferSerialization)
1118       return success();
1119 
1120     typeIDMap[type] = typeID;
1121 
1122     if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands)))
1123       return failure();
1124 
1125     if (recursiveStructInfos.count(type) != 0) {
1126       // This recursive struct type is emitted already, now the OpTypePointer
1127       // instructions referring to recursive references are emitted as well.
1128       for (auto &ptrInfo : recursiveStructInfos[type]) {
1129         // TODO: This might not work if more than 1 recursive reference is
1130         // present in the struct.
1131         SmallVector<uint32_t, 4> ptrOperands;
1132         ptrOperands.push_back(ptrInfo.pointerTypeID);
1133         ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
1134         ptrOperands.push_back(typeIDMap[type]);
1135 
1136         if (failed(encodeInstructionInto(
1137                 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands)))
1138           return failure();
1139       }
1140 
1141       recursiveStructInfos[type].clear();
1142     }
1143 
1144     return success();
1145   }
1146 
1147   return failure();
1148 }
1149 
1150 LogicalResult Serializer::prepareBasicType(
1151     Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
1152     SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
1153     llvm::SetVector<StringRef> &serializationCtx) {
1154   deferSerialization = false;
1155 
1156   if (isVoidType(type)) {
1157     typeEnum = spirv::Opcode::OpTypeVoid;
1158     return success();
1159   }
1160 
1161   if (auto intType = type.dyn_cast<IntegerType>()) {
1162     if (intType.getWidth() == 1) {
1163       typeEnum = spirv::Opcode::OpTypeBool;
1164       return success();
1165     }
1166 
1167     typeEnum = spirv::Opcode::OpTypeInt;
1168     operands.push_back(intType.getWidth());
1169     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1170     // to preserve or validate.
1171     // 0 indicates unsigned, or no signedness semantics
1172     // 1 indicates signed semantics."
1173     operands.push_back(intType.isSigned() ? 1 : 0);
1174     return success();
1175   }
1176 
1177   if (auto floatType = type.dyn_cast<FloatType>()) {
1178     typeEnum = spirv::Opcode::OpTypeFloat;
1179     operands.push_back(floatType.getWidth());
1180     return success();
1181   }
1182 
1183   if (auto vectorType = type.dyn_cast<VectorType>()) {
1184     uint32_t elementTypeID = 0;
1185     if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
1186                                serializationCtx))) {
1187       return failure();
1188     }
1189     typeEnum = spirv::Opcode::OpTypeVector;
1190     operands.push_back(elementTypeID);
1191     operands.push_back(vectorType.getNumElements());
1192     return success();
1193   }
1194 
1195   if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
1196     typeEnum = spirv::Opcode::OpTypeArray;
1197     uint32_t elementTypeID = 0;
1198     if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
1199                                serializationCtx))) {
1200       return failure();
1201     }
1202     operands.push_back(elementTypeID);
1203     if (auto elementCountID = prepareConstantInt(
1204             loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
1205       operands.push_back(elementCountID);
1206     }
1207     return processTypeDecoration(loc, arrayType, resultID);
1208   }
1209 
1210   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
1211     uint32_t pointeeTypeID = 0;
1212     spirv::StructType pointeeStruct =
1213         ptrType.getPointeeType().dyn_cast<spirv::StructType>();
1214 
1215     if (pointeeStruct && pointeeStruct.isIdentified() &&
1216         serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
1217       // A recursive reference to an enclosing struct is found.
1218       //
1219       // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
1220       // class as operands.
1221       SmallVector<uint32_t, 2> forwardPtrOperands;
1222       forwardPtrOperands.push_back(resultID);
1223       forwardPtrOperands.push_back(
1224           static_cast<uint32_t>(ptrType.getStorageClass()));
1225 
1226       encodeInstructionInto(typesGlobalValues,
1227                             spirv::Opcode::OpTypeForwardPointer,
1228                             forwardPtrOperands);
1229 
1230       // 2. Find the pointee (enclosing) struct.
1231       auto structType = spirv::StructType::getIdentified(
1232           module.getContext(), pointeeStruct.getIdentifier());
1233 
1234       if (!structType)
1235         return failure();
1236 
1237       // 3. Mark the OpTypePointer that is supposed to be emitted by this call
1238       // as deferred.
1239       deferSerialization = true;
1240 
1241       // 4. Record the info needed to emit the deferred OpTypePointer
1242       // instruction when the enclosing struct is completely serialized.
1243       recursiveStructInfos[structType].push_back(
1244           {resultID, ptrType.getStorageClass()});
1245     } else {
1246       if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
1247                                  serializationCtx)))
1248         return failure();
1249     }
1250 
1251     typeEnum = spirv::Opcode::OpTypePointer;
1252     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
1253     operands.push_back(pointeeTypeID);
1254     return success();
1255   }
1256 
1257   if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
1258     uint32_t elementTypeID = 0;
1259     if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
1260                                elementTypeID, serializationCtx))) {
1261       return failure();
1262     }
1263     typeEnum = spirv::Opcode::OpTypeRuntimeArray;
1264     operands.push_back(elementTypeID);
1265     return processTypeDecoration(loc, runtimeArrayType, resultID);
1266   }
1267 
1268   if (auto structType = type.dyn_cast<spirv::StructType>()) {
1269     if (structType.isIdentified()) {
1270       processName(resultID, structType.getIdentifier());
1271       serializationCtx.insert(structType.getIdentifier());
1272     }
1273 
1274     bool hasOffset = structType.hasOffset();
1275     for (auto elementIndex :
1276          llvm::seq<uint32_t>(0, structType.getNumElements())) {
1277       uint32_t elementTypeID = 0;
1278       if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
1279                                  elementTypeID, serializationCtx))) {
1280         return failure();
1281       }
1282       operands.push_back(elementTypeID);
1283       if (hasOffset) {
1284         // Decorate each struct member with an offset
1285         spirv::StructType::MemberDecorationInfo offsetDecoration{
1286             elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
1287             static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
1288         if (failed(processMemberDecoration(resultID, offsetDecoration))) {
1289           return emitError(loc, "cannot decorate ")
1290                  << elementIndex << "-th member of " << structType
1291                  << " with its offset";
1292         }
1293       }
1294     }
1295     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
1296     structType.getMemberDecorations(memberDecorations);
1297 
1298     for (auto &memberDecoration : memberDecorations) {
1299       if (failed(processMemberDecoration(resultID, memberDecoration))) {
1300         return emitError(loc, "cannot decorate ")
1301                << static_cast<uint32_t>(memberDecoration.memberIndex)
1302                << "-th member of " << structType << " with "
1303                << stringifyDecoration(memberDecoration.decoration);
1304       }
1305     }
1306 
1307     typeEnum = spirv::Opcode::OpTypeStruct;
1308 
1309     if (structType.isIdentified())
1310       serializationCtx.remove(structType.getIdentifier());
1311 
1312     return success();
1313   }
1314 
1315   if (auto cooperativeMatrixType =
1316           type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1317     uint32_t elementTypeID = 0;
1318     if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
1319                                elementTypeID, serializationCtx))) {
1320       return failure();
1321     }
1322     typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
1323     auto getConstantOp = [&](uint32_t id) {
1324       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
1325       return prepareConstantInt(loc, attr);
1326     };
1327     operands.push_back(elementTypeID);
1328     operands.push_back(
1329         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
1330     operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
1331     operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
1332     return success();
1333   }
1334 
1335   if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
1336     uint32_t elementTypeID = 0;
1337     if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
1338                                serializationCtx))) {
1339       return failure();
1340     }
1341     typeEnum = spirv::Opcode::OpTypeMatrix;
1342     operands.push_back(elementTypeID);
1343     operands.push_back(matrixType.getNumColumns());
1344     return success();
1345   }
1346 
1347   // TODO: Handle other types.
1348   return emitError(loc, "unhandled type in serialization: ") << type;
1349 }
1350 
1351 LogicalResult
1352 Serializer::prepareFunctionType(Location loc, FunctionType type,
1353                                 spirv::Opcode &typeEnum,
1354                                 SmallVectorImpl<uint32_t> &operands) {
1355   typeEnum = spirv::Opcode::OpTypeFunction;
1356   assert(type.getNumResults() <= 1 &&
1357          "serialization supports only a single return value");
1358   uint32_t resultID = 0;
1359   if (failed(processType(
1360           loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
1361           resultID))) {
1362     return failure();
1363   }
1364   operands.push_back(resultID);
1365   for (auto &res : type.getInputs()) {
1366     uint32_t argTypeID = 0;
1367     if (failed(processType(loc, res, argTypeID))) {
1368       return failure();
1369     }
1370     operands.push_back(argTypeID);
1371   }
1372   return success();
1373 }
1374 
1375 //===----------------------------------------------------------------------===//
1376 // Constant
1377 //===----------------------------------------------------------------------===//
1378 
1379 uint32_t Serializer::prepareConstant(Location loc, Type constType,
1380                                      Attribute valueAttr) {
1381   if (auto id = prepareConstantScalar(loc, valueAttr)) {
1382     return id;
1383   }
1384 
1385   // This is a composite literal. We need to handle each component separately
1386   // and then emit an OpConstantComposite for the whole.
1387 
1388   if (auto id = getConstantID(valueAttr)) {
1389     return id;
1390   }
1391 
1392   uint32_t typeID = 0;
1393   if (failed(processType(loc, constType, typeID))) {
1394     return 0;
1395   }
1396 
1397   uint32_t resultID = 0;
1398   if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
1399     int rank = attr.getType().dyn_cast<ShapedType>().getRank();
1400     SmallVector<uint64_t, 4> index(rank);
1401     resultID = prepareDenseElementsConstant(loc, constType, attr,
1402                                             /*dim=*/0, index);
1403   } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
1404     resultID = prepareArrayConstant(loc, constType, arrayAttr);
1405   }
1406 
1407   if (resultID == 0) {
1408     emitError(loc, "cannot serialize attribute: ") << valueAttr;
1409     return 0;
1410   }
1411 
1412   constIDMap[valueAttr] = resultID;
1413   return resultID;
1414 }
1415 
1416 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1417                                           ArrayAttr attr) {
1418   uint32_t typeID = 0;
1419   if (failed(processType(loc, constType, typeID))) {
1420     return 0;
1421   }
1422 
1423   uint32_t resultID = getNextID();
1424   SmallVector<uint32_t, 4> operands = {typeID, resultID};
1425   operands.reserve(attr.size() + 2);
1426   auto elementType = constType.cast<spirv::ArrayType>().getElementType();
1427   for (Attribute elementAttr : attr) {
1428     if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
1429       operands.push_back(elementID);
1430     } else {
1431       return 0;
1432     }
1433   }
1434   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1435   encodeInstructionInto(typesGlobalValues, opcode, operands);
1436 
1437   return resultID;
1438 }
1439 
1440 // TODO: Turn the below function into iterative function, instead of
1441 // recursive function.
1442 uint32_t
1443 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
1444                                          DenseElementsAttr valueAttr, int dim,
1445                                          MutableArrayRef<uint64_t> index) {
1446   auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
1447   assert(dim <= shapedType.getRank());
1448   if (shapedType.getRank() == dim) {
1449     if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
1450       return attr.getType().getElementType().isInteger(1)
1451                  ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
1452                  : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
1453     }
1454     if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
1455       return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
1456     }
1457     return 0;
1458   }
1459 
1460   uint32_t typeID = 0;
1461   if (failed(processType(loc, constType, typeID))) {
1462     return 0;
1463   }
1464 
1465   uint32_t resultID = getNextID();
1466   SmallVector<uint32_t, 4> operands = {typeID, resultID};
1467   operands.reserve(shapedType.getDimSize(dim) + 2);
1468   auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
1469   for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
1470     index[dim] = i;
1471     if (auto elementID = prepareDenseElementsConstant(
1472             loc, elementType, valueAttr, dim + 1, index)) {
1473       operands.push_back(elementID);
1474     } else {
1475       return 0;
1476     }
1477   }
1478   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1479   encodeInstructionInto(typesGlobalValues, opcode, operands);
1480 
1481   return resultID;
1482 }
1483 
1484 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1485                                            bool isSpec) {
1486   if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
1487     return prepareConstantFp(loc, floatAttr, isSpec);
1488   }
1489   if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
1490     return prepareConstantBool(loc, boolAttr, isSpec);
1491   }
1492   if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
1493     return prepareConstantInt(loc, intAttr, isSpec);
1494   }
1495 
1496   return 0;
1497 }
1498 
1499 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1500                                          bool isSpec) {
1501   if (!isSpec) {
1502     // We can de-duplicate normal constants, but not specialization constants.
1503     if (auto id = getConstantID(boolAttr)) {
1504       return id;
1505     }
1506   }
1507 
1508   // Process the type for this bool literal
1509   uint32_t typeID = 0;
1510   if (failed(processType(loc, boolAttr.getType(), typeID))) {
1511     return 0;
1512   }
1513 
1514   auto resultID = getNextID();
1515   auto opcode = boolAttr.getValue()
1516                     ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1517                               : spirv::Opcode::OpConstantTrue)
1518                     : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1519                               : spirv::Opcode::OpConstantFalse);
1520   encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1521 
1522   if (!isSpec) {
1523     constIDMap[boolAttr] = resultID;
1524   }
1525   return resultID;
1526 }
1527 
1528 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1529                                         bool isSpec) {
1530   if (!isSpec) {
1531     // We can de-duplicate normal constants, but not specialization constants.
1532     if (auto id = getConstantID(intAttr)) {
1533       return id;
1534     }
1535   }
1536 
1537   // Process the type for this integer literal
1538   uint32_t typeID = 0;
1539   if (failed(processType(loc, intAttr.getType(), typeID))) {
1540     return 0;
1541   }
1542 
1543   auto resultID = getNextID();
1544   APInt value = intAttr.getValue();
1545   unsigned bitwidth = value.getBitWidth();
1546   bool isSigned = value.isSignedIntN(bitwidth);
1547 
1548   auto opcode =
1549       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1550 
1551   // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
1552   // the literal's value appears in the low-order bits of the word, and the
1553   // high-order bits must be 0 for a floating-point type, or 0 for an integer
1554   // type with Signedness of 0, or sign extended when Signedness is 1."
1555   if (bitwidth == 32 || bitwidth == 16) {
1556     uint32_t word = 0;
1557     if (isSigned) {
1558       word = static_cast<int32_t>(value.getSExtValue());
1559     } else {
1560       word = static_cast<uint32_t>(value.getZExtValue());
1561     }
1562     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1563   }
1564   // According to SPIR-V spec: "When the type's bit width is larger than one
1565   // word, the literal’s low-order words appear first."
1566   else if (bitwidth == 64) {
1567     struct DoubleWord {
1568       uint32_t word1;
1569       uint32_t word2;
1570     } words;
1571     if (isSigned) {
1572       words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1573     } else {
1574       words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1575     }
1576     encodeInstructionInto(typesGlobalValues, opcode,
1577                           {typeID, resultID, words.word1, words.word2});
1578   } else {
1579     std::string valueStr;
1580     llvm::raw_string_ostream rss(valueStr);
1581     value.print(rss, /*isSigned=*/false);
1582 
1583     emitError(loc, "cannot serialize ")
1584         << bitwidth << "-bit integer literal: " << rss.str();
1585     return 0;
1586   }
1587 
1588   if (!isSpec) {
1589     constIDMap[intAttr] = resultID;
1590   }
1591   return resultID;
1592 }
1593 
1594 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1595                                        bool isSpec) {
1596   if (!isSpec) {
1597     // We can de-duplicate normal constants, but not specialization constants.
1598     if (auto id = getConstantID(floatAttr)) {
1599       return id;
1600     }
1601   }
1602 
1603   // Process the type for this float literal
1604   uint32_t typeID = 0;
1605   if (failed(processType(loc, floatAttr.getType(), typeID))) {
1606     return 0;
1607   }
1608 
1609   auto resultID = getNextID();
1610   APFloat value = floatAttr.getValue();
1611   APInt intValue = value.bitcastToAPInt();
1612 
1613   auto opcode =
1614       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1615 
1616   if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1617     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1618     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1619   } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1620     struct DoubleWord {
1621       uint32_t word1;
1622       uint32_t word2;
1623     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1624     encodeInstructionInto(typesGlobalValues, opcode,
1625                           {typeID, resultID, words.word1, words.word2});
1626   } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1627     uint32_t word =
1628         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1629     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1630   } else {
1631     std::string valueStr;
1632     llvm::raw_string_ostream rss(valueStr);
1633     value.print(rss);
1634 
1635     emitError(loc, "cannot serialize ")
1636         << floatAttr.getType() << "-typed float literal: " << rss.str();
1637     return 0;
1638   }
1639 
1640   if (!isSpec) {
1641     constIDMap[floatAttr] = resultID;
1642   }
1643   return resultID;
1644 }
1645 
1646 //===----------------------------------------------------------------------===//
1647 // Control flow
1648 //===----------------------------------------------------------------------===//
1649 
1650 uint32_t Serializer::getOrCreateBlockID(Block *block) {
1651   if (uint32_t id = getBlockID(block))
1652     return id;
1653   return blockIDMap[block] = getNextID();
1654 }
1655 
1656 LogicalResult
1657 Serializer::processBlock(Block *block, bool omitLabel,
1658                          function_ref<void()> actionBeforeTerminator) {
1659   LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1660   LLVM_DEBUG(block->print(llvm::dbgs()));
1661   LLVM_DEBUG(llvm::dbgs() << '\n');
1662   if (!omitLabel) {
1663     uint32_t blockID = getOrCreateBlockID(block);
1664     LLVM_DEBUG(llvm::dbgs()
1665                << "[block] " << block << " (id = " << blockID << ")\n");
1666 
1667     // Emit OpLabel for this block.
1668     encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1669   }
1670 
1671   // Emit OpPhi instructions for block arguments, if any.
1672   if (failed(emitPhiForBlockArguments(block)))
1673     return failure();
1674 
1675   // Process each op in this block except the terminator.
1676   for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
1677     if (failed(processOperation(&op)))
1678       return failure();
1679   }
1680 
1681   // Process the terminator.
1682   if (actionBeforeTerminator)
1683     actionBeforeTerminator();
1684   if (failed(processOperation(&block->back())))
1685     return failure();
1686 
1687   return success();
1688 }
1689 
/// Emits one OpPhi into `functionBody` for each argument of `block`.
///
/// Every MLIR predecessor must end in a spv.Branch; any other terminator is
/// rejected with an error. The incoming parent block for each pair is remapped
/// through getPhiIncomingBlock() because structured control flow ops expand
/// into several SPIR-V blocks. Incoming values whose <id> is not assigned yet
/// are encoded as 0 and recorded in `deferredPhiValues` for later patching.
/// The phi's result <id> is recorded in `valueIDMap` for the block argument.
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
  // Nothing to do if this block has no arguments or it's the entry block, which
  // always has the same arguments as the function signature.
  if (block->args_empty() || block->isEntryBlock())
    return success();

  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
  // A SPIR-V OpPhi instruction is of the syntax:
  //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
  // So we need to collect all predecessor blocks and the arguments they send
  // to this block.
  // Each entry pairs the (remapped) incoming block with an iterator to the
  // first operand forwarded by that predecessor's branch.
  SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
  for (Block *predecessor : block->getPredecessors()) {
    // Take the terminator from the MLIR predecessor *before* remapping it
    // below; the forwarded operands live on this terminator.
    auto *terminator = predecessor->getTerminator();
    // The predecessor here is the immediate one according to MLIR's IR
    // structure. It does not directly map to the incoming parent block for the
    // OpPhi instructions at SPIR-V binary level. This is because structured
    // control flow ops are serialized to multiple SPIR-V blocks. If there is a
    // spv.selection/spv.loop op in the MLIR predecessor block, the branch op
    // jumping to the OpPhi's block then resides in the previous structured
    // control flow op's merge block.
    predecessor = getPhiIncomingBlock(predecessor);
    if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
      predecessors.emplace_back(predecessor, branchOp.operand_begin());
    } else {
      return terminator->emitError("unimplemented terminator for Phi creation");
    }
  }

  // Then create OpPhi instruction for each of the block argument.
  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
    BlockArgument arg = block->getArgument(argIndex);

    // Get the type <id> and result <id> for this OpPhi instruction.
    uint32_t phiTypeID = 0;
    if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
      return failure();
    uint32_t phiID = getNextID();

    LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                            << arg << " (id = " << phiID << ")\n");

    // Prepare the (value <id>, parent block <id>) pairs.
    SmallVector<uint32_t, 8> phiArgs;
    phiArgs.push_back(phiTypeID);
    phiArgs.push_back(phiID);

    for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
      // The argIndex-th operand forwarded by this predecessor's branch.
      Value value = *(predecessors[predIndex].second + argIndex);
      uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
      LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                              << ") value " << value << ' ');
      // Each pair is a value <id> ...
      uint32_t valueId = getValueID(value);
      if (valueId == 0) {
        // The op generating this value hasn't been visited yet so we don't have
        // an <id> assigned yet. Record this to fix up later.
        // The recorded index is where this word will land in functionBody:
        // the current end, plus 1 for the prefixed opcode word that
        // encodeInstructionInto writes first, plus this word's position
        // within phiArgs (pushed just below).
        LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
        deferredPhiValues[value].push_back(functionBody.size() + 1 +
                                           phiArgs.size());
      } else {
        LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
      }
      // A placeholder 0 is written when the value <id> is still unknown.
      phiArgs.push_back(valueId);
      // ... and a parent block <id>.
      phiArgs.push_back(predBlockId);
    }

    encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
    valueIDMap[arg] = phiID;
  }

  return success();
}
1764 
1765 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
1766   // Assign <id>s to all blocks so that branches inside the SelectionOp can
1767   // resolve properly.
1768   auto &body = selectionOp.body();
1769   for (Block &block : body)
1770     getOrCreateBlockID(&block);
1771 
1772   auto *headerBlock = selectionOp.getHeaderBlock();
1773   auto *mergeBlock = selectionOp.getMergeBlock();
1774   auto mergeID = getBlockID(mergeBlock);
1775   auto loc = selectionOp.getLoc();
1776 
1777   // Emit the selection header block, which dominates all other blocks, first.
1778   // We need to emit an OpSelectionMerge instruction before the selection header
1779   // block's terminator.
1780   auto emitSelectionMerge = [&]() {
1781     emitDebugLine(functionBody, loc);
1782     lastProcessedWasMergeInst = true;
1783     encodeInstructionInto(
1784         functionBody, spirv::Opcode::OpSelectionMerge,
1785         {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
1786   };
1787   // For structured selection, we cannot have blocks in the selection construct
1788   // branching to the selection header block. Entering the selection (and
1789   // reaching the selection header) must be from the block containing the
1790   // spv.selection op. If there are ops ahead of the spv.selection op in the
1791   // block, we can "merge" them into the selection header. So here we don't need
1792   // to emit a separate block; just continue with the existing block.
1793   if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge)))
1794     return failure();
1795 
1796   // Process all blocks with a depth-first visitor starting from the header
1797   // block. The selection header block and merge block are skipped by this
1798   // visitor.
1799   if (failed(visitInPrettyBlockOrder(
1800           headerBlock, [&](Block *block) { return processBlock(block); },
1801           /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
1802     return failure();
1803 
1804   // There is nothing to do for the merge block in the selection, which just
1805   // contains a spv.mlir.merge op, itself. But we need to have an OpLabel
1806   // instruction to start a new SPIR-V block for ops following this SelectionOp.
1807   // The block should use the <id> for the merge block.
1808   return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
1809 }
1810 
1811 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
1812   // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
1813   // properly. We don't need to assign for the entry block, which is just for
1814   // satisfying MLIR region's structural requirement.
1815   auto &body = loopOp.body();
1816   for (Block &block :
1817        llvm::make_range(std::next(body.begin(), 1), body.end())) {
1818     getOrCreateBlockID(&block);
1819   }
1820   auto *headerBlock = loopOp.getHeaderBlock();
1821   auto *continueBlock = loopOp.getContinueBlock();
1822   auto *mergeBlock = loopOp.getMergeBlock();
1823   auto headerID = getBlockID(headerBlock);
1824   auto continueID = getBlockID(continueBlock);
1825   auto mergeID = getBlockID(mergeBlock);
1826   auto loc = loopOp.getLoc();
1827 
1828   // This LoopOp is in some MLIR block with preceding and following ops. In the
1829   // binary format, it should reside in separate SPIR-V blocks from its
1830   // preceding and following ops. So we need to emit unconditional branches to
1831   // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
1832   // afterwards.
1833   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
1834 
1835   // LoopOp's entry block is just there for satisfying MLIR's structural
1836   // requirements so we omit it and start serialization from the loop header
1837   // block.
1838 
1839   // Emit the loop header block, which dominates all other blocks, first. We
1840   // need to emit an OpLoopMerge instruction before the loop header block's
1841   // terminator.
1842   auto emitLoopMerge = [&]() {
1843     emitDebugLine(functionBody, loc);
1844     lastProcessedWasMergeInst = true;
1845     encodeInstructionInto(
1846         functionBody, spirv::Opcode::OpLoopMerge,
1847         {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
1848   };
1849   if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
1850     return failure();
1851 
1852   // Process all blocks with a depth-first visitor starting from the header
1853   // block. The loop header block, loop continue block, and loop merge block are
1854   // skipped by this visitor and handled later in this function.
1855   if (failed(visitInPrettyBlockOrder(
1856           headerBlock, [&](Block *block) { return processBlock(block); },
1857           /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
1858     return failure();
1859 
1860   // We have handled all other blocks. Now get to the loop continue block.
1861   if (failed(processBlock(continueBlock)))
1862     return failure();
1863 
1864   // There is nothing to do for the merge block in the loop, which just contains
1865   // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to
1866   // start a new SPIR-V block for ops following this LoopOp. The block should
1867   // use the <id> for the merge block.
1868   return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
1869 }
1870 
1871 LogicalResult Serializer::processBranchConditionalOp(
1872     spirv::BranchConditionalOp condBranchOp) {
1873   auto conditionID = getValueID(condBranchOp.condition());
1874   auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
1875   auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
1876   SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
1877 
1878   if (auto weights = condBranchOp.branch_weights()) {
1879     for (auto val : weights->getValue())
1880       arguments.push_back(val.cast<IntegerAttr>().getInt());
1881   }
1882 
1883   emitDebugLine(functionBody, condBranchOp.getLoc());
1884   return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
1885                                arguments);
1886 }
1887 
1888 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
1889   emitDebugLine(functionBody, branchOp.getLoc());
1890   return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
1891                                {getOrCreateBlockID(branchOp.getTarget())});
1892 }
1893 
1894 //===----------------------------------------------------------------------===//
1895 // Operation
1896 //===----------------------------------------------------------------------===//
1897 
1898 LogicalResult Serializer::encodeExtensionInstruction(
1899     Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1900     ArrayRef<uint32_t> operands) {
1901   // Check if the extension has been imported.
1902   auto &setID = extendedInstSetIDMap[extensionSetName];
1903   if (!setID) {
1904     setID = getNextID();
1905     SmallVector<uint32_t, 16> importOperands;
1906     importOperands.push_back(setID);
1907     if (failed(
1908             spirv::encodeStringLiteralInto(importOperands, extensionSetName)) ||
1909         failed(encodeInstructionInto(
1910             extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
1911       return failure();
1912     }
1913   }
1914 
1915   // The first two operands are the result type <id> and result <id>. The set
1916   // <id> and the opcode need to be insert after this.
1917   if (operands.size() < 2) {
1918     return op->emitError("extended instructions must have a result encoding");
1919   }
1920   SmallVector<uint32_t, 8> extInstOperands;
1921   extInstOperands.reserve(operands.size() + 2);
1922   extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1923   extInstOperands.push_back(setID);
1924   extInstOperands.push_back(extensionOpcode);
1925   extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1926   return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1927                                extInstOperands);
1928 }
1929 
1930 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
1931   auto varName = addressOfOp.variable();
1932   auto variableID = getVariableID(varName);
1933   if (!variableID) {
1934     return addressOfOp.emitError("unknown result <id> for variable ")
1935            << varName;
1936   }
1937   valueIDMap[addressOfOp.pointer()] = variableID;
1938   return success();
1939 }
1940 
1941 LogicalResult
1942 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
1943   auto constName = referenceOfOp.spec_const();
1944   auto constID = getSpecConstID(constName);
1945   if (!constID) {
1946     return referenceOfOp.emitError(
1947                "unknown result <id> for specialization constant ")
1948            << constName;
1949   }
1950   valueIDMap[referenceOfOp.reference()] = constID;
1951   return success();
1952 }
1953 
1954 LogicalResult Serializer::processOperation(Operation *opInst) {
1955   LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1956 
1957   // First dispatch the ops that do not directly mirror an instruction from
1958   // the SPIR-V spec.
1959   return TypeSwitch<Operation *, LogicalResult>(opInst)
1960       .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1961       .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1962       .Case([&](spirv::BranchConditionalOp op) {
1963         return processBranchConditionalOp(op);
1964       })
1965       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1966       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1967       .Case([&](spirv::GlobalVariableOp op) {
1968         return processGlobalVariableOp(op);
1969       })
1970       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1971       .Case([&](spirv::ModuleEndOp) { return success(); })
1972       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1973       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1974       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1975       .Case([&](spirv::SpecConstantCompositeOp op) {
1976         return processSpecConstantCompositeOp(op);
1977       })
1978       .Case([&](spirv::SpecConstantOperationOp op) {
1979         return processSpecConstantOperationOp(op);
1980       })
1981       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1982       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1983 
1984       // Then handle all the ops that directly mirror SPIR-V instructions with
1985       // auto-generated methods.
1986       .Default(
1987           [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1988 }
1989 
1990 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1991                                                       StringRef extInstSet,
1992                                                       uint32_t opcode) {
1993   SmallVector<uint32_t, 4> operands;
1994   Location loc = op->getLoc();
1995 
1996   uint32_t resultID = 0;
1997   if (op->getNumResults() != 0) {
1998     uint32_t resultTypeID = 0;
1999     if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
2000       return failure();
2001     operands.push_back(resultTypeID);
2002 
2003     resultID = getNextID();
2004     operands.push_back(resultID);
2005     valueIDMap[op->getResult(0)] = resultID;
2006   };
2007 
2008   for (Value operand : op->getOperands())
2009     operands.push_back(getValueID(operand));
2010 
2011   emitDebugLine(functionBody, loc);
2012 
2013   if (extInstSet.empty()) {
2014     encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
2015                           operands);
2016   } else {
2017     encodeExtensionInstruction(op, extInstSet, opcode, operands);
2018   }
2019 
2020   if (op->getNumResults() != 0) {
2021     for (auto attr : op->getAttrs()) {
2022       if (failed(processDecoration(loc, resultID, attr)))
2023         return failure();
2024     }
2025   }
2026 
2027   return success();
2028 }
2029 
2030 namespace {
2031 template <>
2032 LogicalResult
2033 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
2034   SmallVector<uint32_t, 4> operands;
2035   // Add the ExecutionModel.
2036   operands.push_back(static_cast<uint32_t>(op.execution_model()));
2037   // Add the function <id>.
2038   auto funcID = getFunctionID(op.fn());
2039   if (!funcID) {
2040     return op.emitError("missing <id> for function ")
2041            << op.fn()
2042            << "; function needs to be defined before spv.EntryPoint is "
2043               "serialized";
2044   }
2045   operands.push_back(funcID);
2046   // Add the name of the function.
2047   spirv::encodeStringLiteralInto(operands, op.fn());
2048 
2049   // Add the interface values.
2050   if (auto interface = op.interface()) {
2051     for (auto var : interface.getValue()) {
2052       auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
2053       if (!id) {
2054         return op.emitError("referencing undefined global variable."
2055                             "spv.EntryPoint is at the end of spv.module. All "
2056                             "referenced variables should already be defined");
2057       }
2058       operands.push_back(id);
2059     }
2060   }
2061   return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
2062                                operands);
2063 }
2064 
2065 template <>
2066 LogicalResult
2067 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
2068   StringRef argNames[] = {"execution_scope", "memory_scope",
2069                           "memory_semantics"};
2070   SmallVector<uint32_t, 3> operands;
2071 
2072   for (auto argName : argNames) {
2073     auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
2074     auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
2075     if (!operand) {
2076       return failure();
2077     }
2078     operands.push_back(operand);
2079   }
2080 
2081   return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
2082                                operands);
2083 }
2084 
2085 template <>
2086 LogicalResult
2087 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
2088   SmallVector<uint32_t, 4> operands;
2089   // Add the function <id>.
2090   auto funcID = getFunctionID(op.fn());
2091   if (!funcID) {
2092     return op.emitError("missing <id> for function ")
2093            << op.fn()
2094            << "; function needs to be serialized before ExecutionModeOp is "
2095               "serialized";
2096   }
2097   operands.push_back(funcID);
2098   // Add the ExecutionMode.
2099   operands.push_back(static_cast<uint32_t>(op.execution_mode()));
2100 
2101   // Serialize values if any.
2102   auto values = op.values();
2103   if (values) {
2104     for (auto &intVal : values.getValue()) {
2105       operands.push_back(static_cast<uint32_t>(
2106           intVal.cast<IntegerAttr>().getValue().getZExtValue()));
2107     }
2108   }
2109   return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
2110                                operands);
2111 }
2112 
2113 template <>
2114 LogicalResult
2115 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
2116   StringRef argNames[] = {"memory_scope", "memory_semantics"};
2117   SmallVector<uint32_t, 2> operands;
2118 
2119   for (auto argName : argNames) {
2120     auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
2121     auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
2122     if (!operand) {
2123       return failure();
2124     }
2125     operands.push_back(operand);
2126   }
2127 
2128   return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier,
2129                                operands);
2130 }
2131 
2132 template <>
2133 LogicalResult
2134 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
2135   auto funcName = op.callee();
2136   uint32_t resTypeID = 0;
2137 
2138   Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
2139   if (failed(processType(op.getLoc(), resultTy, resTypeID)))
2140     return failure();
2141 
2142   auto funcID = getOrCreateFunctionID(funcName);
2143   auto funcCallID = getNextID();
2144   SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
2145 
2146   for (auto value : op.arguments()) {
2147     auto valueID = getValueID(value);
2148     assert(valueID && "cannot find a value for spv.FunctionCall");
2149     operands.push_back(valueID);
2150   }
2151 
2152   if (!resultTy.isa<NoneType>())
2153     valueIDMap[op.getResult(0)] = funcCallID;
2154 
2155   return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
2156                                operands);
2157 }
2158 
2159 template <>
2160 LogicalResult
2161 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
2162   SmallVector<uint32_t, 4> operands;
2163   SmallVector<StringRef, 2> elidedAttrs;
2164 
2165   for (Value operand : op->getOperands()) {
2166     auto id = getValueID(operand);
2167     assert(id && "use before def!");
2168     operands.push_back(id);
2169   }
2170 
2171   if (auto attr = op->getAttr("memory_access")) {
2172     operands.push_back(static_cast<uint32_t>(
2173         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2174   }
2175 
2176   elidedAttrs.push_back("memory_access");
2177 
2178   if (auto attr = op->getAttr("alignment")) {
2179     operands.push_back(static_cast<uint32_t>(
2180         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2181   }
2182 
2183   elidedAttrs.push_back("alignment");
2184 
2185   if (auto attr = op->getAttr("source_memory_access")) {
2186     operands.push_back(static_cast<uint32_t>(
2187         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2188   }
2189 
2190   elidedAttrs.push_back("source_memory_access");
2191 
2192   if (auto attr = op->getAttr("source_alignment")) {
2193     operands.push_back(static_cast<uint32_t>(
2194         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2195   }
2196 
2197   elidedAttrs.push_back("source_alignment");
2198   emitDebugLine(functionBody, op.getLoc());
2199   encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
2200 
2201   return success();
2202 }
2203 
2204 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
2205 // various Serializer::processOp<...>() specializations.
2206 #define GET_SERIALIZATION_FNS
2207 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
2208 } // namespace
2209 
2210 LogicalResult Serializer::emitDecoration(uint32_t target,
2211                                          spirv::Decoration decoration,
2212                                          ArrayRef<uint32_t> params) {
2213   uint32_t wordCount = 3 + params.size();
2214   decorations.push_back(
2215       spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
2216   decorations.push_back(target);
2217   decorations.push_back(static_cast<uint32_t>(decoration));
2218   decorations.append(params.begin(), params.end());
2219   return success();
2220 }
2221 
2222 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
2223                                         Location loc) {
2224   if (!emitDebugInfo)
2225     return success();
2226 
2227   if (lastProcessedWasMergeInst) {
2228     lastProcessedWasMergeInst = false;
2229     return success();
2230   }
2231 
2232   auto fileLoc = loc.dyn_cast<FileLineColLoc>();
2233   if (fileLoc)
2234     encodeInstructionInto(binary, spirv::Opcode::OpLine,
2235                           {fileID, fileLoc.getLine(), fileLoc.getColumn()});
2236   return success();
2237 }
2238 
2239 namespace mlir {
2240 LogicalResult spirv::serialize(spirv::ModuleOp module,
2241                                SmallVectorImpl<uint32_t> &binary,
2242                                bool emitDebugInfo) {
2243   if (!module.vce_triple().hasValue())
2244     return module.emitError(
2245         "module must have 'vce_triple' attribute to be serializeable");
2246 
2247   Serializer serializer(module, emitDebugInfo);
2248 
2249   if (failed(serializer.serialize()))
2250     return failure();
2251 
2252   LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs()));
2253 
2254   serializer.collect(binary);
2255   return success();
2256 }
2257 } // namespace mlir
2258