1 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Target/SPIRV/Serialization.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/RegionGraphTraits.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/ADT/bit.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 
34 #define DEBUG_TYPE "spirv-serialization"
35 
36 using namespace mlir;
37 
38 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
39 /// the given `binary` vector.
40 static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
41                                            spirv::Opcode op,
42                                            ArrayRef<uint32_t> operands) {
43   uint32_t wordCount = 1 + operands.size();
44   binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
45   binary.append(operands.begin(), operands.end());
46   return success();
47 }
48 
49 /// A pre-order depth-first visitor function for processing basic blocks.
50 ///
51 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
52 /// depth-first manner and calls `blockHandler` on each block. Skips handling
53 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
54 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
55 /// successors.
56 ///
57 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
58 /// of blocks in a function must satisfy the rule that blocks appear before
59 /// all blocks they dominate." This can be achieved by a pre-order CFG
60 /// traversal algorithm. To make the serialization output more logical and
61 /// readable to human, we perform depth-first CFG traversal and delay the
62 /// serialization of the merge block and the continue block, if exists, until
63 /// after all other blocks have been processed.
64 static LogicalResult
65 visitInPrettyBlockOrder(Block *headerBlock,
66                         function_ref<LogicalResult(Block *)> blockHandler,
67                         bool skipHeader = false, BlockRange skipBlocks = {}) {
68   llvm::df_iterator_default_set<Block *, 4> doneBlocks;
69   doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
70 
71   for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
72     if (skipHeader && block == headerBlock)
73       continue;
74     if (failed(blockHandler(block)))
75       return failure();
76   }
77   return success();
78 }
79 
80 /// Returns the merge block if the given `op` is a structured control flow op.
81 /// Otherwise returns nullptr.
82 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
83   if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
84     return selectionOp.getMergeBlock();
85   if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
86     return loopOp.getMergeBlock();
87   return nullptr;
88 }
89 
90 /// Given a predecessor `block` for a block with arguments, returns the block
91 /// that should be used as the parent block for SPIR-V OpPhi instructions
92 /// corresponding to the block arguments.
93 static Block *getPhiIncomingBlock(Block *block) {
94   // If the predecessor block in question is the entry block for a spv.loop,
95   // we jump to this spv.loop from its enclosing block.
96   if (block->isEntryBlock()) {
97     if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
98       // Then the incoming parent block for OpPhi should be the merge block of
99       // the structured control flow op before this loop.
100       Operation *op = loopOp.getOperation();
101       while ((op = op->getPrevNode()) != nullptr)
102         if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
103           return incomingBlock;
104       // Or the enclosing block itself if no structured control flow ops
105       // exists before this loop.
106       return loopOp->getBlock();
107     }
108   }
109 
110   // Otherwise, we jump from the given predecessor block. Try to see if there is
111   // a structured control flow op inside it.
112   for (Operation &op : llvm::reverse(block->getOperations())) {
113     if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
114       return incomingBlock;
115   }
116   return block;
117 }
118 
119 namespace {
120 
121 /// A SPIR-V module serializer.
122 ///
123 /// A SPIR-V binary module is a single linear stream of instructions; each
124 /// instruction is composed of 32-bit words with the layout:
125 ///
126 ///   | <word-count>|<opcode> |  <operand>   |  <operand>   | ... |
127 ///   | <------ word -------> | <-- word --> | <-- word --> | ... |
128 ///
129 /// For the first word, the 16 high-order bits are the word count of the
130 /// instruction, the 16 low-order bits are the opcode enumerant. The
131 /// instructions then belong to different sections, which must be laid out in
132 /// the particular order as specified in "2.4 Logical Layout of a Module" of
133 /// the SPIR-V spec.
134 class Serializer {
135 public:
136   /// Creates a serializer for the given SPIR-V `module`.
137   explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false);
138 
139   /// Serializes the remembered SPIR-V module.
140   LogicalResult serialize();
141 
142   /// Collects the final SPIR-V `binary`.
143   void collect(SmallVectorImpl<uint32_t> &binary);
144 
145 #ifndef NDEBUG
146   /// (For debugging) prints each value and its corresponding result <id>.
147   void printValueIDMap(raw_ostream &os);
148 #endif
149 
150 private:
151   // Note that there are two main categories of methods in this class:
152   // * process*() methods are meant to fully serialize a SPIR-V module entity
153   //   (header, type, op, etc.). They update internal vectors containing
154   //   different binary sections. They are not meant to be called except the
155   //   top-level serialization loop.
156   // * prepare*() methods are meant to be helpers that prepare for serializing
157   //   certain entity. They may or may not update internal vectors containing
158   //   different binary sections. They are meant to be called among themselves
159   //   or by other process*() methods for subtasks.
160 
161   //===--------------------------------------------------------------------===//
162   // <id>
163   //===--------------------------------------------------------------------===//
164 
165   // Note that it is illegal to use id <0> in SPIR-V binary module. Various
166   // methods in this class, if using SPIR-V word (uint32_t) as interface,
167   // check or return id <0> to indicate error in processing.
168 
169   /// Consumes the next unused <id>. This method will never return 0.
170   uint32_t getNextID() { return nextID++; }
171 
172   //===--------------------------------------------------------------------===//
173   // Module structure
174   //===--------------------------------------------------------------------===//
175 
176   uint32_t getSpecConstID(StringRef constName) const {
177     return specConstIDMap.lookup(constName);
178   }
179 
180   uint32_t getVariableID(StringRef varName) const {
181     return globalVarIDMap.lookup(varName);
182   }
183 
184   uint32_t getFunctionID(StringRef fnName) const {
185     return funcIDMap.lookup(fnName);
186   }
187 
188   /// Gets the <id> for the function with the given name. Assigns the next
189   /// available <id> if the function haven't been deserialized.
190   uint32_t getOrCreateFunctionID(StringRef fnName);
191 
192   void processCapability();
193 
194   void processDebugInfo();
195 
196   void processExtension();
197 
198   void processMemoryModel();
199 
200   LogicalResult processConstantOp(spirv::ConstantOp op);
201 
202   LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
203 
204   LogicalResult
205   processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
206 
207   LogicalResult
208   processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
209 
210   /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
211   /// value to use with other operations. The SPIR-V spec recommends that
212   /// OpUndef be generated at module level. The serialization generates an
213   /// OpUndef for each type needed at module level.
214   LogicalResult processUndefOp(spirv::UndefOp op);
215 
216   /// Emit OpName for the given `resultID`.
217   LogicalResult processName(uint32_t resultID, StringRef name);
218 
219   /// Processes a SPIR-V function op.
220   LogicalResult processFuncOp(spirv::FuncOp op);
221 
222   LogicalResult processVariableOp(spirv::VariableOp op);
223 
224   /// Process a SPIR-V GlobalVariableOp
225   LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
226 
227   /// Process attributes that translate to decorations on the result <id>
228   LogicalResult processDecoration(Location loc, uint32_t resultID,
229                                   NamedAttribute attr);
230 
231   template <typename DType>
232   LogicalResult processTypeDecoration(Location loc, DType type,
233                                       uint32_t resultId) {
234     return emitError(loc, "unhandled decoration for type:") << type;
235   }
236 
237   /// Process member decoration
238   LogicalResult processMemberDecoration(
239       uint32_t structID,
240       const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
241 
242   //===--------------------------------------------------------------------===//
243   // Types
244   //===--------------------------------------------------------------------===//
245 
246   uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); }
247 
248   Type getVoidType() { return mlirBuilder.getNoneType(); }
249 
250   bool isVoidType(Type type) const { return type.isa<NoneType>(); }
251 
252   /// Returns true if the given type is a pointer type to a struct in some
253   /// interface storage class.
254   bool isInterfaceStructPtrType(Type type) const;
255 
256   /// Main dispatch method for serializing a type. The result <id> of the
257   /// serialized type will be returned as `typeID`.
258   LogicalResult processType(Location loc, Type type, uint32_t &typeID);
259   LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID,
260                                 llvm::SetVector<StringRef> &serializationCtx);
261 
262   /// Method for preparing basic SPIR-V type serialization. Returns the type's
263   /// opcode and operands for the instruction via `typeEnum` and `operands`.
264   LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
265                                  spirv::Opcode &typeEnum,
266                                  SmallVectorImpl<uint32_t> &operands,
267                                  bool &deferSerialization,
268                                  llvm::SetVector<StringRef> &serializationCtx);
269 
270   LogicalResult prepareFunctionType(Location loc, FunctionType type,
271                                     spirv::Opcode &typeEnum,
272                                     SmallVectorImpl<uint32_t> &operands);
273 
274   //===--------------------------------------------------------------------===//
275   // Constant
276   //===--------------------------------------------------------------------===//
277 
278   uint32_t getConstantID(Attribute value) const {
279     return constIDMap.lookup(value);
280   }
281 
282   /// Main dispatch method for processing a constant with the given `constType`
283   /// and `valueAttr`. `constType` is needed here because we can interpret the
284   /// `valueAttr` as a different type than the type of `valueAttr` itself; for
285   /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
286   /// constants.
287   uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
288 
289   /// Prepares array attribute serialization. This method emits corresponding
290   /// OpConstant* and returns the result <id> associated with it. Returns 0 if
291   /// failed.
292   uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);
293 
294   /// Prepares bool/int/float DenseElementsAttr serialization. This method
295   /// iterates the DenseElementsAttr to construct the constant array, and
296   /// returns the result <id>  associated with it. Returns 0 if failed. Note
297   /// that the size of `index` must match the rank.
298   /// TODO: Consider to enhance splat elements cases. For splat cases,
299   /// we don't need to loop over all elements, especially when the splat value
300   /// is zero. We can use OpConstantNull when the value is zero.
301   uint32_t prepareDenseElementsConstant(Location loc, Type constType,
302                                         DenseElementsAttr valueAttr, int dim,
303                                         MutableArrayRef<uint64_t> index);
304 
305   /// Prepares scalar attribute serialization. This method emits corresponding
306   /// OpConstant* and returns the result <id> associated with it. Returns 0 if
307   /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
308   /// true, then the constant will be serialized as a specialization constant.
309   uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
310                                  bool isSpec = false);
311 
312   uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
313                                bool isSpec = false);
314 
315   uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
316                               bool isSpec = false);
317 
318   uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
319                              bool isSpec = false);
320 
321   //===--------------------------------------------------------------------===//
322   // Control flow
323   //===--------------------------------------------------------------------===//
324 
325   /// Returns the result <id> for the given block.
326   uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); }
327 
328   /// Returns the result <id> for the given block. If no <id> has been assigned,
329   /// assigns the next available <id>
330   uint32_t getOrCreateBlockID(Block *block);
331 
332   /// Processes the given `block` and emits SPIR-V instructions for all ops
333   /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
334   /// `actionBeforeTerminator` is a callback that will be invoked before
335   /// handling the terminator op. It can be used to inject the Op*Merge
336   /// instruction if this is a SPIR-V selection/loop header block.
337   LogicalResult
338   processBlock(Block *block, bool omitLabel = false,
339                function_ref<void()> actionBeforeTerminator = nullptr);
340 
341   /// Emits OpPhi instructions for the given block if it has block arguments.
342   LogicalResult emitPhiForBlockArguments(Block *block);
343 
344   LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
345 
346   LogicalResult processLoopOp(spirv::LoopOp loopOp);
347 
348   LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
349 
350   LogicalResult processBranchOp(spirv::BranchOp branchOp);
351 
352   //===--------------------------------------------------------------------===//
353   // Operations
354   //===--------------------------------------------------------------------===//
355 
356   LogicalResult encodeExtensionInstruction(Operation *op,
357                                            StringRef extensionSetName,
358                                            uint32_t opcode,
359                                            ArrayRef<uint32_t> operands);
360 
361   uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); }
362 
363   LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
364 
365   LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);
366 
367   /// Main dispatch method for serializing an operation.
368   LogicalResult processOperation(Operation *op);
369 
370   /// Serializes an operation `op` as core instruction with `opcode` if
371   /// `extInstSet` is empty. Otherwise serializes it as an extended instruction
372   /// with `opcode` from `extInstSet`.
373   /// This method is a generic one for dispatching any SPIR-V ops that has no
374   /// variadic operands and attributes in TableGen definitions.
375   LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet,
376                                             uint32_t opcode);
377 
378   /// Dispatches to the serialization function for an operation in SPIR-V
379   /// dialect that is a mirror of an instruction in the SPIR-V spec. This is
380   /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V
381   /// dialect that have hasOpcode == 1.
382   LogicalResult dispatchToAutogenSerialization(Operation *op);
383 
384   /// Serializes an operation in the SPIR-V dialect that is a mirror of an
385   /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1
386   /// and autogenSerialization == 1 in ODS.
387   template <typename OpTy>
388   LogicalResult processOp(OpTy op) {
389     return op.emitError("unsupported op serialization");
390   }
391 
392   //===--------------------------------------------------------------------===//
393   // Utilities
394   //===--------------------------------------------------------------------===//
395 
396   /// Emits an OpDecorate instruction to decorate the given `target` with the
397   /// given `decoration`.
398   LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
399                                ArrayRef<uint32_t> params = {});
400 
401   /// Emits an OpLine instruction with the given `loc` location information into
402   /// the given `binary` vector.
403   LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc);
404 
405 private:
406   /// The SPIR-V module to be serialized.
407   spirv::ModuleOp module;
408 
409   /// An MLIR builder for getting MLIR constructs.
410   mlir::Builder mlirBuilder;
411 
412   /// A flag which indicates if the debuginfo should be emitted.
413   bool emitDebugInfo = false;
414 
415   /// A flag which indicates if the last processed instruction was a merge
416   /// instruction.
417   /// According to SPIR-V spec: "If a branch merge instruction is used, the last
418   /// OpLine in the block must be before its merge instruction".
419   bool lastProcessedWasMergeInst = false;
420 
421   /// The <id> of the OpString instruction, which specifies a file name, for
422   /// use by other debug instructions.
423   uint32_t fileID = 0;
424 
425   /// The next available result <id>.
426   uint32_t nextID = 1;
427 
428   // The following are for different SPIR-V instruction sections. They follow
429   // the logical layout of a SPIR-V module.
430 
431   SmallVector<uint32_t, 4> capabilities;
432   SmallVector<uint32_t, 0> extensions;
433   SmallVector<uint32_t, 0> extendedSets;
434   SmallVector<uint32_t, 3> memoryModel;
435   SmallVector<uint32_t, 0> entryPoints;
436   SmallVector<uint32_t, 4> executionModes;
437   SmallVector<uint32_t, 0> debug;
438   SmallVector<uint32_t, 0> names;
439   SmallVector<uint32_t, 0> decorations;
440   SmallVector<uint32_t, 0> typesGlobalValues;
441   SmallVector<uint32_t, 0> functions;
442 
443   /// Recursive struct references are serialized as OpTypePointer instructions
444   /// to the recursive struct type. However, the OpTypePointer instruction
445   /// cannot be emitted before the recursive struct's OpTypeStruct.
446   /// RecursiveStructPointerInfo stores the data needed to emit such
447   /// OpTypePointer instructions after forward references to such types.
448   struct RecursiveStructPointerInfo {
449     uint32_t pointerTypeID;
450     spirv::StorageClass storageClass;
451   };
452 
453   // Maps spirv::StructType to its recursive reference member info.
454   DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>>
455       recursiveStructInfos;
456 
457   /// `functionHeader` contains all the instructions that must be in the first
458   /// block in the function, and `functionBody` contains the rest. After
459   /// processing FuncOp, the encoded instructions of a function are appended to
460   /// `functions`. An example of instructions in `functionHeader` in order:
461   /// OpFunction ...
462   /// OpFunctionParameter ...
463   /// OpFunctionParameter ...
464   /// OpLabel ...
465   /// OpVariable ...
466   /// OpVariable ...
467   SmallVector<uint32_t, 0> functionHeader;
468   SmallVector<uint32_t, 0> functionBody;
469 
470   /// Map from type used in SPIR-V module to their <id>s.
471   DenseMap<Type, uint32_t> typeIDMap;
472 
473   /// Map from constant values to their <id>s.
474   DenseMap<Attribute, uint32_t> constIDMap;
475 
476   /// Map from specialization constant names to their <id>s.
477   llvm::StringMap<uint32_t> specConstIDMap;
478 
479   /// Map from GlobalVariableOps name to <id>s.
480   llvm::StringMap<uint32_t> globalVarIDMap;
481 
482   /// Map from FuncOps name to <id>s.
483   llvm::StringMap<uint32_t> funcIDMap;
484 
485   /// Map from blocks to their <id>s.
486   DenseMap<Block *, uint32_t> blockIDMap;
487 
488   /// Map from the Type to the <id> that represents undef value of that type.
489   DenseMap<Type, uint32_t> undefValIDMap;
490 
491   /// Map from results of normal operations to their <id>s.
492   DenseMap<Value, uint32_t> valueIDMap;
493 
494   /// Map from extended instruction set name to <id>s.
495   llvm::StringMap<uint32_t> extendedInstSetIDMap;
496 
497   /// Map from values used in OpPhi instructions to their offset in the
498   /// `functions` section.
499   ///
500   /// When processing a block with arguments, we need to emit OpPhi
501   /// instructions to record the predecessor block <id>s and the values they
502   /// send to the block in question. But it's not guaranteed all values are
503   /// visited and thus assigned result <id>s. So we need this list to capture
504   /// the offsets into `functions` where a value is used so that we can fix it
505   /// up later after processing all the blocks in a function.
506   ///
507   /// More concretely, say if we are visiting the following blocks:
508   ///
509   /// ```mlir
510   /// ^phi(%arg0: i32):
511   ///   ...
512   /// ^parent1:
513   ///   ...
514   ///   spv.Branch ^phi(%val0: i32)
515   /// ^parent2:
516   ///   ...
517   ///   spv.Branch ^phi(%val1: i32)
518   /// ```
519   ///
520   /// When we are serializing the `^phi` block, we need to emit at the beginning
521   /// of the block OpPhi instructions which has the following parameters:
522   ///
523   /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
524   ///                               id-for-%val1 id-for-^parent2
525   ///
526   /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
527   /// all the blocks twice and use the first visit to assign an <id> to each
528   /// value. But it's paying the overheads just for OpPhi emission. Instead,
529   /// we still visit the blocks once for emission. When we emit the OpPhi
530   /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
531   /// At the same time, we record their offsets in the emitted binary (which is
532   /// placed inside `functions`) here. And then after emitting all blocks, we
533   /// replace the dummy <id> 0 with the real result <id> by overwriting
534   /// `functions[offset]`.
535   DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues;
536 };
537 } // namespace
538 
539 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
540     : module(module), mlirBuilder(module.getContext()),
541       emitDebugInfo(emitDebugInfo) {}
542 
543 LogicalResult Serializer::serialize() {
544   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
545 
546   if (failed(module.verify()))
547     return failure();
548 
549   // TODO: handle the other sections
550   processCapability();
551   processExtension();
552   processMemoryModel();
553   processDebugInfo();
554 
555   // Iterate over the module body to serialize it. Assumptions are that there is
556   // only one basic block in the moduleOp
557   for (auto &op : module.getBlock()) {
558     if (failed(processOperation(&op))) {
559       return failure();
560     }
561   }
562 
563   LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
564   return success();
565 }
566 
567 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
568   auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
569                     extensions.size() + extendedSets.size() +
570                     memoryModel.size() + entryPoints.size() +
571                     executionModes.size() + decorations.size() +
572                     typesGlobalValues.size() + functions.size();
573 
574   binary.clear();
575   binary.reserve(moduleSize);
576 
577   spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
578   binary.append(capabilities.begin(), capabilities.end());
579   binary.append(extensions.begin(), extensions.end());
580   binary.append(extendedSets.begin(), extendedSets.end());
581   binary.append(memoryModel.begin(), memoryModel.end());
582   binary.append(entryPoints.begin(), entryPoints.end());
583   binary.append(executionModes.begin(), executionModes.end());
584   binary.append(debug.begin(), debug.end());
585   binary.append(names.begin(), names.end());
586   binary.append(decorations.begin(), decorations.end());
587   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
588   binary.append(functions.begin(), functions.end());
589 }
590 
591 #ifndef NDEBUG
592 void Serializer::printValueIDMap(raw_ostream &os) {
593   os << "\n= Value <id> Map =\n\n";
594   for (auto valueIDPair : valueIDMap) {
595     Value val = valueIDPair.first;
596     os << "  " << val << " "
597        << "id = " << valueIDPair.second << ' ';
598     if (auto *op = val.getDefiningOp()) {
599       os << "from op '" << op->getName() << "'";
600     } else if (auto arg = val.dyn_cast<BlockArgument>()) {
601       Block *block = arg.getOwner();
602       os << "from argument of block " << block << ' ';
603       os << " in op '" << block->getParentOp()->getName() << "'";
604     }
605     os << '\n';
606   }
607 }
608 #endif
609 
610 //===----------------------------------------------------------------------===//
611 // Module structure
612 //===----------------------------------------------------------------------===//
613 
614 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
615   auto funcID = funcIDMap.lookup(fnName);
616   if (!funcID) {
617     funcID = getNextID();
618     funcIDMap[fnName] = funcID;
619   }
620   return funcID;
621 }
622 
623 void Serializer::processCapability() {
624   for (auto cap : module.vce_triple()->getCapabilities())
625     (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
626                                 {static_cast<uint32_t>(cap)});
627 }
628 
629 void Serializer::processDebugInfo() {
630   if (!emitDebugInfo)
631     return;
632   auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
633   auto fileName = fileLoc ? fileLoc.getFilename() : "<unknown>";
634   fileID = getNextID();
635   SmallVector<uint32_t, 16> operands;
636   operands.push_back(fileID);
637   (void)spirv::encodeStringLiteralInto(operands, fileName);
638   (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
639   // TODO: Encode more debug instructions.
640 }
641 
642 void Serializer::processExtension() {
643   llvm::SmallVector<uint32_t, 16> extName;
644   for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
645     extName.clear();
646     (void)spirv::encodeStringLiteralInto(extName,
647                                          spirv::stringifyExtension(ext));
648     (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension,
649                                 extName);
650   }
651 }
652 
653 void Serializer::processMemoryModel() {
654   uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
655   uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
656 
657   (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel,
658                               {am, mm});
659 }
660 
661 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
662   if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
663     valueIDMap[op.getResult()] = resultID;
664     return success();
665   }
666   return failure();
667 }
668 
669 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
670   if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
671                                             /*isSpec=*/true)) {
672     // Emit the OpDecorate instruction for SpecId.
673     if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
674       auto val = static_cast<uint32_t>(specID.getInt());
675       (void)emitDecoration(resultID, spirv::Decoration::SpecId, {val});
676     }
677 
678     specConstIDMap[op.sym_name()] = resultID;
679     return processName(resultID, op.sym_name());
680   }
681   return failure();
682 }
683 
684 LogicalResult
685 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
686   uint32_t typeID = 0;
687   if (failed(processType(op.getLoc(), op.type(), typeID))) {
688     return failure();
689   }
690 
691   auto resultID = getNextID();
692 
693   SmallVector<uint32_t, 8> operands;
694   operands.push_back(typeID);
695   operands.push_back(resultID);
696 
697   auto constituents = op.constituents();
698 
699   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
700     auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
701 
702     auto constituentName = constituent.getValue();
703     auto constituentID = getSpecConstID(constituentName);
704 
705     if (!constituentID) {
706       return op.emitError("unknown result <id> for specialization constant ")
707              << constituentName;
708     }
709 
710     operands.push_back(constituentID);
711   }
712 
713   (void)encodeInstructionInto(typesGlobalValues,
714                               spirv::Opcode::OpSpecConstantComposite, operands);
715   specConstIDMap[op.sym_name()] = resultID;
716 
717   return processName(resultID, op.sym_name());
718 }
719 
720 LogicalResult
721 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
722   uint32_t typeID = 0;
723   if (failed(processType(op.getLoc(), op.getType(), typeID))) {
724     return failure();
725   }
726 
727   auto resultID = getNextID();
728 
729   SmallVector<uint32_t, 8> operands;
730   operands.push_back(typeID);
731   operands.push_back(resultID);
732 
733   Block &block = op.getRegion().getBlocks().front();
734   Operation &enclosedOp = block.getOperations().front();
735 
736   std::string enclosedOpName;
737   llvm::raw_string_ostream rss(enclosedOpName);
738   rss << "Op" << enclosedOp.getName().stripDialect();
739   auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
740 
741   if (!enclosedOpcode) {
742     op.emitError("Couldn't find op code for op ")
743         << enclosedOp.getName().getStringRef();
744     return failure();
745   }
746 
747   operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue()));
748 
749   // Append operands to the enclosed op to the list of operands.
750   for (Value operand : enclosedOp.getOperands()) {
751     uint32_t id = getValueID(operand);
752     assert(id && "use before def!");
753     operands.push_back(id);
754   }
755 
756   (void)encodeInstructionInto(typesGlobalValues,
757                               spirv::Opcode::OpSpecConstantOp, operands);
758   valueIDMap[op.getResult()] = resultID;
759 
760   return success();
761 }
762 
763 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
764   auto undefType = op.getType();
765   auto &id = undefValIDMap[undefType];
766   if (!id) {
767     id = getNextID();
768     uint32_t typeID = 0;
769     if (failed(processType(op.getLoc(), undefType, typeID)) ||
770         failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
771                                      {typeID, id}))) {
772       return failure();
773     }
774   }
775   valueIDMap[op.getResult()] = id;
776   return success();
777 }
778 
779 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
780                                             NamedAttribute attr) {
781   auto attrName = attr.first.strref();
782   auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
783   auto decoration = spirv::symbolizeDecoration(decorationName);
784   if (!decoration) {
785     return emitError(
786                loc, "non-argument attributes expected to have snake-case-ified "
787                     "decoration name, unhandled attribute with name : ")
788            << attrName;
789   }
790   SmallVector<uint32_t, 1> args;
791   switch (decoration.getValue()) {
792   case spirv::Decoration::Binding:
793   case spirv::Decoration::DescriptorSet:
794   case spirv::Decoration::Location:
795     if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
796       args.push_back(intAttr.getValue().getZExtValue());
797       break;
798     }
799     return emitError(loc, "expected integer attribute for ") << attrName;
800   case spirv::Decoration::BuiltIn:
801     if (auto strAttr = attr.second.dyn_cast<StringAttr>()) {
802       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
803       if (enumVal) {
804         args.push_back(static_cast<uint32_t>(enumVal.getValue()));
805         break;
806       }
807       return emitError(loc, "invalid ")
808              << attrName << " attribute " << strAttr.getValue();
809     }
810     return emitError(loc, "expected string attribute for ") << attrName;
811   case spirv::Decoration::Aliased:
812   case spirv::Decoration::Flat:
813   case spirv::Decoration::NonReadable:
814   case spirv::Decoration::NonWritable:
815   case spirv::Decoration::NoPerspective:
816   case spirv::Decoration::Restrict:
817     // For unit attributes, the args list has no values so we do nothing
818     if (auto unitAttr = attr.second.dyn_cast<UnitAttr>())
819       break;
820     return emitError(loc, "expected unit attribute for ") << attrName;
821   default:
822     return emitError(loc, "unhandled decoration ") << decorationName;
823   }
824   return emitDecoration(resultID, decoration.getValue(), args);
825 }
826 
827 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
828   assert(!name.empty() && "unexpected empty string for OpName");
829 
830   SmallVector<uint32_t, 4> nameOperands;
831   nameOperands.push_back(resultID);
832   if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
833     return failure();
834   }
835   return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
836 }
837 
838 namespace {
839 template <>
840 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
841     Location loc, spirv::ArrayType type, uint32_t resultID) {
842   if (unsigned stride = type.getArrayStride()) {
843     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
844     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
845   }
846   return success();
847 }
848 
849 template <>
850 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
851     Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) {
852   if (unsigned stride = type.getArrayStride()) {
853     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
854     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
855   }
856   return success();
857 }
858 
859 LogicalResult Serializer::processMemberDecoration(
860     uint32_t structID,
861     const spirv::StructType::MemberDecorationInfo &memberDecoration) {
862   SmallVector<uint32_t, 4> args(
863       {structID, memberDecoration.memberIndex,
864        static_cast<uint32_t>(memberDecoration.decoration)});
865   if (memberDecoration.hasValue) {
866     args.push_back(memberDecoration.decorationValue);
867   }
868   return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
869                                args);
870 }
871 } // namespace
872 
873 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
874   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
875   assert(functionHeader.empty() && functionBody.empty());
876 
877   uint32_t fnTypeID = 0;
878   // Generate type of the function.
879   (void)processType(op.getLoc(), op.getType(), fnTypeID);
880 
881   // Add the function definition.
882   SmallVector<uint32_t, 4> operands;
883   uint32_t resTypeID = 0;
884   auto resultTypes = op.getType().getResults();
885   if (resultTypes.size() > 1) {
886     return op.emitError("cannot serialize function with multiple return types");
887   }
888   if (failed(processType(op.getLoc(),
889                          (resultTypes.empty() ? getVoidType() : resultTypes[0]),
890                          resTypeID))) {
891     return failure();
892   }
893   operands.push_back(resTypeID);
894   auto funcID = getOrCreateFunctionID(op.getName());
895   operands.push_back(funcID);
896   operands.push_back(static_cast<uint32_t>(op.function_control()));
897   operands.push_back(fnTypeID);
898   (void)encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction,
899                               operands);
900 
901   // Add function name.
902   if (failed(processName(funcID, op.getName()))) {
903     return failure();
904   }
905 
906   // Declare the parameters.
907   for (auto arg : op.getArguments()) {
908     uint32_t argTypeID = 0;
909     if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
910       return failure();
911     }
912     auto argValueID = getNextID();
913     valueIDMap[arg] = argValueID;
914     (void)encodeInstructionInto(functionHeader,
915                                 spirv::Opcode::OpFunctionParameter,
916                                 {argTypeID, argValueID});
917   }
918 
919   // Process the body.
920   if (op.isExternal()) {
921     return op.emitError("external function is unhandled");
922   }
923 
924   // Some instructions (e.g., OpVariable) in a function must be in the first
925   // block in the function. These instructions will be put in functionHeader.
926   // Thus, we put the label in functionHeader first, and omit it from the first
927   // block.
928   (void)encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
929                               {getOrCreateBlockID(&op.front())});
930   (void)processBlock(&op.front(), /*omitLabel=*/true);
931   if (failed(visitInPrettyBlockOrder(
932           &op.front(), [&](Block *block) { return processBlock(block); },
933           /*skipHeader=*/true))) {
934     return failure();
935   }
936 
937   // There might be OpPhi instructions who have value references needing to fix.
938   for (auto deferredValue : deferredPhiValues) {
939     Value value = deferredValue.first;
940     uint32_t id = getValueID(value);
941     LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
942                             << " to id = " << id << '\n');
943     assert(id && "OpPhi references undefined value!");
944     for (size_t offset : deferredValue.second)
945       functionBody[offset] = id;
946   }
947   deferredPhiValues.clear();
948 
949   LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
950                           << "' --\n");
951   // Insert OpFunctionEnd.
952   if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd,
953                                    {}))) {
954     return failure();
955   }
956 
957   functions.append(functionHeader.begin(), functionHeader.end());
958   functions.append(functionBody.begin(), functionBody.end());
959   functionHeader.clear();
960   functionBody.clear();
961 
962   return success();
963 }
964 
965 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
966   SmallVector<uint32_t, 4> operands;
967   SmallVector<StringRef, 2> elidedAttrs;
968   uint32_t resultID = 0;
969   uint32_t resultTypeID = 0;
970   if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
971     return failure();
972   }
973   operands.push_back(resultTypeID);
974   resultID = getNextID();
975   valueIDMap[op.getResult()] = resultID;
976   operands.push_back(resultID);
977   auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
978   if (attr) {
979     operands.push_back(static_cast<uint32_t>(
980         attr.cast<IntegerAttr>().getValue().getZExtValue()));
981   }
982   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
983   for (auto arg : op.getODSOperands(0)) {
984     auto argID = getValueID(arg);
985     if (!argID) {
986       return emitError(op.getLoc(), "operand 0 has a use before def");
987     }
988     operands.push_back(argID);
989   }
990   (void)emitDebugLine(functionHeader, op.getLoc());
991   (void)encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable,
992                               operands);
993   for (auto attr : op->getAttrs()) {
994     if (llvm::any_of(elidedAttrs,
995                      [&](StringRef elided) { return attr.first == elided; })) {
996       continue;
997     }
998     if (failed(processDecoration(op.getLoc(), resultID, attr))) {
999       return failure();
1000     }
1001   }
1002   return success();
1003 }
1004 
1005 LogicalResult
1006 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
1007   // Get TypeID.
1008   uint32_t resultTypeID = 0;
1009   SmallVector<StringRef, 4> elidedAttrs;
1010   if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
1011     return failure();
1012   }
1013 
1014   if (isInterfaceStructPtrType(varOp.type())) {
1015     auto structType = varOp.type()
1016                           .cast<spirv::PointerType>()
1017                           .getPointeeType()
1018                           .cast<spirv::StructType>();
1019     if (failed(
1020             emitDecoration(getTypeID(structType), spirv::Decoration::Block))) {
1021       return varOp.emitError("cannot decorate ")
1022              << structType << " with Block decoration";
1023     }
1024   }
1025 
1026   elidedAttrs.push_back("type");
1027   SmallVector<uint32_t, 4> operands;
1028   operands.push_back(resultTypeID);
1029   auto resultID = getNextID();
1030 
1031   // Encode the name.
1032   auto varName = varOp.sym_name();
1033   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1034   if (failed(processName(resultID, varName))) {
1035     return failure();
1036   }
1037   globalVarIDMap[varName] = resultID;
1038   operands.push_back(resultID);
1039 
1040   // Encode StorageClass.
1041   operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
1042 
1043   // Encode initialization.
1044   if (auto initializer = varOp.initializer()) {
1045     auto initializerID = getVariableID(initializer.getValue());
1046     if (!initializerID) {
1047       return emitError(varOp.getLoc(),
1048                        "invalid usage of undefined variable as initializer");
1049     }
1050     operands.push_back(initializerID);
1051     elidedAttrs.push_back("initializer");
1052   }
1053 
1054   (void)emitDebugLine(typesGlobalValues, varOp.getLoc());
1055   if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable,
1056                                    operands))) {
1057     elidedAttrs.push_back("initializer");
1058     return failure();
1059   }
1060 
1061   // Encode decorations.
1062   for (auto attr : varOp->getAttrs()) {
1063     if (llvm::any_of(elidedAttrs,
1064                      [&](StringRef elided) { return attr.first == elided; })) {
1065       continue;
1066     }
1067     if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
1068       return failure();
1069     }
1070   }
1071   return success();
1072 }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // Type
1076 //===----------------------------------------------------------------------===//
1077 
1078 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
1079 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
1080 // PushConstant Storage Classes must be explicitly laid out."
1081 bool Serializer::isInterfaceStructPtrType(Type type) const {
1082   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
1083     switch (ptrType.getStorageClass()) {
1084     case spirv::StorageClass::PhysicalStorageBuffer:
1085     case spirv::StorageClass::PushConstant:
1086     case spirv::StorageClass::StorageBuffer:
1087     case spirv::StorageClass::Uniform:
1088       return ptrType.getPointeeType().isa<spirv::StructType>();
1089     default:
1090       break;
1091     }
1092   }
1093   return false;
1094 }
1095 
1096 LogicalResult Serializer::processType(Location loc, Type type,
1097                                       uint32_t &typeID) {
1098   // Maintains a set of names for nested identified struct types. This is used
1099   // to properly serialize recursive references.
1100   llvm::SetVector<StringRef> serializationCtx;
1101   return processTypeImpl(loc, type, typeID, serializationCtx);
1102 }
1103 
1104 LogicalResult
1105 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
1106                             llvm::SetVector<StringRef> &serializationCtx) {
1107   typeID = getTypeID(type);
1108   if (typeID) {
1109     return success();
1110   }
1111   typeID = getNextID();
1112   SmallVector<uint32_t, 4> operands;
1113 
1114   operands.push_back(typeID);
1115   auto typeEnum = spirv::Opcode::OpTypeVoid;
1116   bool deferSerialization = false;
1117 
1118   if ((type.isa<FunctionType>() &&
1119        succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
1120                                      operands))) ||
1121       succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
1122                                  deferSerialization, serializationCtx))) {
1123     if (deferSerialization)
1124       return success();
1125 
1126     typeIDMap[type] = typeID;
1127 
1128     if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands)))
1129       return failure();
1130 
1131     if (recursiveStructInfos.count(type) != 0) {
1132       // This recursive struct type is emitted already, now the OpTypePointer
1133       // instructions referring to recursive references are emitted as well.
1134       for (auto &ptrInfo : recursiveStructInfos[type]) {
1135         // TODO: This might not work if more than 1 recursive reference is
1136         // present in the struct.
1137         SmallVector<uint32_t, 4> ptrOperands;
1138         ptrOperands.push_back(ptrInfo.pointerTypeID);
1139         ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
1140         ptrOperands.push_back(typeIDMap[type]);
1141 
1142         if (failed(encodeInstructionInto(
1143                 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands)))
1144           return failure();
1145       }
1146 
1147       recursiveStructInfos[type].clear();
1148     }
1149 
1150     return success();
1151   }
1152 
1153   return failure();
1154 }
1155 
1156 LogicalResult Serializer::prepareBasicType(
1157     Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
1158     SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
1159     llvm::SetVector<StringRef> &serializationCtx) {
1160   deferSerialization = false;
1161 
1162   if (isVoidType(type)) {
1163     typeEnum = spirv::Opcode::OpTypeVoid;
1164     return success();
1165   }
1166 
1167   if (auto intType = type.dyn_cast<IntegerType>()) {
1168     if (intType.getWidth() == 1) {
1169       typeEnum = spirv::Opcode::OpTypeBool;
1170       return success();
1171     }
1172 
1173     typeEnum = spirv::Opcode::OpTypeInt;
1174     operands.push_back(intType.getWidth());
1175     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1176     // to preserve or validate.
1177     // 0 indicates unsigned, or no signedness semantics
1178     // 1 indicates signed semantics."
1179     operands.push_back(intType.isSigned() ? 1 : 0);
1180     return success();
1181   }
1182 
1183   if (auto floatType = type.dyn_cast<FloatType>()) {
1184     typeEnum = spirv::Opcode::OpTypeFloat;
1185     operands.push_back(floatType.getWidth());
1186     return success();
1187   }
1188 
1189   if (auto vectorType = type.dyn_cast<VectorType>()) {
1190     uint32_t elementTypeID = 0;
1191     if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
1192                                serializationCtx))) {
1193       return failure();
1194     }
1195     typeEnum = spirv::Opcode::OpTypeVector;
1196     operands.push_back(elementTypeID);
1197     operands.push_back(vectorType.getNumElements());
1198     return success();
1199   }
1200 
1201   if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
1202     typeEnum = spirv::Opcode::OpTypeImage;
1203     uint32_t sampledTypeID = 0;
1204     if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
1205       return failure();
1206 
1207     operands.push_back(sampledTypeID);
1208     operands.push_back(static_cast<uint32_t>(imageType.getDim()));
1209     operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
1210     operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
1211     operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
1212     operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
1213     operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
1214     return success();
1215   }
1216 
1217   if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
1218     typeEnum = spirv::Opcode::OpTypeArray;
1219     uint32_t elementTypeID = 0;
1220     if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
1221                                serializationCtx))) {
1222       return failure();
1223     }
1224     operands.push_back(elementTypeID);
1225     if (auto elementCountID = prepareConstantInt(
1226             loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
1227       operands.push_back(elementCountID);
1228     }
1229     return processTypeDecoration(loc, arrayType, resultID);
1230   }
1231 
1232   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
1233     uint32_t pointeeTypeID = 0;
1234     spirv::StructType pointeeStruct =
1235         ptrType.getPointeeType().dyn_cast<spirv::StructType>();
1236 
1237     if (pointeeStruct && pointeeStruct.isIdentified() &&
1238         serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
1239       // A recursive reference to an enclosing struct is found.
1240       //
1241       // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
1242       // class as operands.
1243       SmallVector<uint32_t, 2> forwardPtrOperands;
1244       forwardPtrOperands.push_back(resultID);
1245       forwardPtrOperands.push_back(
1246           static_cast<uint32_t>(ptrType.getStorageClass()));
1247 
1248       (void)encodeInstructionInto(typesGlobalValues,
1249                                   spirv::Opcode::OpTypeForwardPointer,
1250                                   forwardPtrOperands);
1251 
1252       // 2. Find the pointee (enclosing) struct.
1253       auto structType = spirv::StructType::getIdentified(
1254           module.getContext(), pointeeStruct.getIdentifier());
1255 
1256       if (!structType)
1257         return failure();
1258 
1259       // 3. Mark the OpTypePointer that is supposed to be emitted by this call
1260       // as deferred.
1261       deferSerialization = true;
1262 
1263       // 4. Record the info needed to emit the deferred OpTypePointer
1264       // instruction when the enclosing struct is completely serialized.
1265       recursiveStructInfos[structType].push_back(
1266           {resultID, ptrType.getStorageClass()});
1267     } else {
1268       if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
1269                                  serializationCtx)))
1270         return failure();
1271     }
1272 
1273     typeEnum = spirv::Opcode::OpTypePointer;
1274     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
1275     operands.push_back(pointeeTypeID);
1276     return success();
1277   }
1278 
1279   if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
1280     uint32_t elementTypeID = 0;
1281     if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
1282                                elementTypeID, serializationCtx))) {
1283       return failure();
1284     }
1285     typeEnum = spirv::Opcode::OpTypeRuntimeArray;
1286     operands.push_back(elementTypeID);
1287     return processTypeDecoration(loc, runtimeArrayType, resultID);
1288   }
1289 
1290   if (auto structType = type.dyn_cast<spirv::StructType>()) {
1291     if (structType.isIdentified()) {
1292       (void)processName(resultID, structType.getIdentifier());
1293       serializationCtx.insert(structType.getIdentifier());
1294     }
1295 
1296     bool hasOffset = structType.hasOffset();
1297     for (auto elementIndex :
1298          llvm::seq<uint32_t>(0, structType.getNumElements())) {
1299       uint32_t elementTypeID = 0;
1300       if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
1301                                  elementTypeID, serializationCtx))) {
1302         return failure();
1303       }
1304       operands.push_back(elementTypeID);
1305       if (hasOffset) {
1306         // Decorate each struct member with an offset
1307         spirv::StructType::MemberDecorationInfo offsetDecoration{
1308             elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
1309             static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
1310         if (failed(processMemberDecoration(resultID, offsetDecoration))) {
1311           return emitError(loc, "cannot decorate ")
1312                  << elementIndex << "-th member of " << structType
1313                  << " with its offset";
1314         }
1315       }
1316     }
1317     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
1318     structType.getMemberDecorations(memberDecorations);
1319 
1320     for (auto &memberDecoration : memberDecorations) {
1321       if (failed(processMemberDecoration(resultID, memberDecoration))) {
1322         return emitError(loc, "cannot decorate ")
1323                << static_cast<uint32_t>(memberDecoration.memberIndex)
1324                << "-th member of " << structType << " with "
1325                << stringifyDecoration(memberDecoration.decoration);
1326       }
1327     }
1328 
1329     typeEnum = spirv::Opcode::OpTypeStruct;
1330 
1331     if (structType.isIdentified())
1332       serializationCtx.remove(structType.getIdentifier());
1333 
1334     return success();
1335   }
1336 
1337   if (auto cooperativeMatrixType =
1338           type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1339     uint32_t elementTypeID = 0;
1340     if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
1341                                elementTypeID, serializationCtx))) {
1342       return failure();
1343     }
1344     typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
1345     auto getConstantOp = [&](uint32_t id) {
1346       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
1347       return prepareConstantInt(loc, attr);
1348     };
1349     operands.push_back(elementTypeID);
1350     operands.push_back(
1351         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
1352     operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
1353     operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
1354     return success();
1355   }
1356 
1357   if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
1358     uint32_t elementTypeID = 0;
1359     if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
1360                                serializationCtx))) {
1361       return failure();
1362     }
1363     typeEnum = spirv::Opcode::OpTypeMatrix;
1364     operands.push_back(elementTypeID);
1365     operands.push_back(matrixType.getNumColumns());
1366     return success();
1367   }
1368 
1369   // TODO: Handle other types.
1370   return emitError(loc, "unhandled type in serialization: ") << type;
1371 }
1372 
1373 LogicalResult
1374 Serializer::prepareFunctionType(Location loc, FunctionType type,
1375                                 spirv::Opcode &typeEnum,
1376                                 SmallVectorImpl<uint32_t> &operands) {
1377   typeEnum = spirv::Opcode::OpTypeFunction;
1378   assert(type.getNumResults() <= 1 &&
1379          "serialization supports only a single return value");
1380   uint32_t resultID = 0;
1381   if (failed(processType(
1382           loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
1383           resultID))) {
1384     return failure();
1385   }
1386   operands.push_back(resultID);
1387   for (auto &res : type.getInputs()) {
1388     uint32_t argTypeID = 0;
1389     if (failed(processType(loc, res, argTypeID))) {
1390       return failure();
1391     }
1392     operands.push_back(argTypeID);
1393   }
1394   return success();
1395 }
1396 
1397 //===----------------------------------------------------------------------===//
1398 // Constant
1399 //===----------------------------------------------------------------------===//
1400 
1401 uint32_t Serializer::prepareConstant(Location loc, Type constType,
1402                                      Attribute valueAttr) {
1403   if (auto id = prepareConstantScalar(loc, valueAttr)) {
1404     return id;
1405   }
1406 
1407   // This is a composite literal. We need to handle each component separately
1408   // and then emit an OpConstantComposite for the whole.
1409 
1410   if (auto id = getConstantID(valueAttr)) {
1411     return id;
1412   }
1413 
1414   uint32_t typeID = 0;
1415   if (failed(processType(loc, constType, typeID))) {
1416     return 0;
1417   }
1418 
1419   uint32_t resultID = 0;
1420   if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
1421     int rank = attr.getType().dyn_cast<ShapedType>().getRank();
1422     SmallVector<uint64_t, 4> index(rank);
1423     resultID = prepareDenseElementsConstant(loc, constType, attr,
1424                                             /*dim=*/0, index);
1425   } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
1426     resultID = prepareArrayConstant(loc, constType, arrayAttr);
1427   }
1428 
1429   if (resultID == 0) {
1430     emitError(loc, "cannot serialize attribute: ") << valueAttr;
1431     return 0;
1432   }
1433 
1434   constIDMap[valueAttr] = resultID;
1435   return resultID;
1436 }
1437 
1438 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1439                                           ArrayAttr attr) {
1440   uint32_t typeID = 0;
1441   if (failed(processType(loc, constType, typeID))) {
1442     return 0;
1443   }
1444 
1445   uint32_t resultID = getNextID();
1446   SmallVector<uint32_t, 4> operands = {typeID, resultID};
1447   operands.reserve(attr.size() + 2);
1448   auto elementType = constType.cast<spirv::ArrayType>().getElementType();
1449   for (Attribute elementAttr : attr) {
1450     if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
1451       operands.push_back(elementID);
1452     } else {
1453       return 0;
1454     }
1455   }
1456   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1457   (void)encodeInstructionInto(typesGlobalValues, opcode, operands);
1458 
1459   return resultID;
1460 }
1461 
1462 // TODO: Turn the below function into iterative function, instead of
1463 // recursive function.
1464 uint32_t
1465 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
1466                                          DenseElementsAttr valueAttr, int dim,
1467                                          MutableArrayRef<uint64_t> index) {
1468   auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
1469   assert(dim <= shapedType.getRank());
1470   if (shapedType.getRank() == dim) {
1471     if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
1472       return attr.getType().getElementType().isInteger(1)
1473                  ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
1474                  : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
1475     }
1476     if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
1477       return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
1478     }
1479     return 0;
1480   }
1481 
1482   uint32_t typeID = 0;
1483   if (failed(processType(loc, constType, typeID))) {
1484     return 0;
1485   }
1486 
1487   uint32_t resultID = getNextID();
1488   SmallVector<uint32_t, 4> operands = {typeID, resultID};
1489   operands.reserve(shapedType.getDimSize(dim) + 2);
1490   auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
1491   for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
1492     index[dim] = i;
1493     if (auto elementID = prepareDenseElementsConstant(
1494             loc, elementType, valueAttr, dim + 1, index)) {
1495       operands.push_back(elementID);
1496     } else {
1497       return 0;
1498     }
1499   }
1500   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1501   (void)encodeInstructionInto(typesGlobalValues, opcode, operands);
1502 
1503   return resultID;
1504 }
1505 
1506 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1507                                            bool isSpec) {
1508   if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
1509     return prepareConstantFp(loc, floatAttr, isSpec);
1510   }
1511   if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
1512     return prepareConstantBool(loc, boolAttr, isSpec);
1513   }
1514   if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
1515     return prepareConstantInt(loc, intAttr, isSpec);
1516   }
1517 
1518   return 0;
1519 }
1520 
1521 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1522                                          bool isSpec) {
1523   if (!isSpec) {
1524     // We can de-duplicate normal constants, but not specialization constants.
1525     if (auto id = getConstantID(boolAttr)) {
1526       return id;
1527     }
1528   }
1529 
1530   // Process the type for this bool literal
1531   uint32_t typeID = 0;
1532   if (failed(processType(loc, boolAttr.getType(), typeID))) {
1533     return 0;
1534   }
1535 
1536   auto resultID = getNextID();
1537   auto opcode = boolAttr.getValue()
1538                     ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1539                               : spirv::Opcode::OpConstantTrue)
1540                     : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1541                               : spirv::Opcode::OpConstantFalse);
1542   (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1543 
1544   if (!isSpec) {
1545     constIDMap[boolAttr] = resultID;
1546   }
1547   return resultID;
1548 }
1549 
1550 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1551                                         bool isSpec) {
1552   if (!isSpec) {
1553     // We can de-duplicate normal constants, but not specialization constants.
1554     if (auto id = getConstantID(intAttr)) {
1555       return id;
1556     }
1557   }
1558 
1559   // Process the type for this integer literal
1560   uint32_t typeID = 0;
1561   if (failed(processType(loc, intAttr.getType(), typeID))) {
1562     return 0;
1563   }
1564 
1565   auto resultID = getNextID();
1566   APInt value = intAttr.getValue();
1567   unsigned bitwidth = value.getBitWidth();
1568   bool isSigned = value.isSignedIntN(bitwidth);
1569 
1570   auto opcode =
1571       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1572 
1573   // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
1574   // the literal's value appears in the low-order bits of the word, and the
1575   // high-order bits must be 0 for a floating-point type, or 0 for an integer
1576   // type with Signedness of 0, or sign extended when Signedness is 1."
1577   if (bitwidth == 32 || bitwidth == 16) {
1578     uint32_t word = 0;
1579     if (isSigned) {
1580       word = static_cast<int32_t>(value.getSExtValue());
1581     } else {
1582       word = static_cast<uint32_t>(value.getZExtValue());
1583     }
1584     (void)encodeInstructionInto(typesGlobalValues, opcode,
1585                                 {typeID, resultID, word});
1586   }
1587   // According to SPIR-V spec: "When the type's bit width is larger than one
1588   // word, the literal’s low-order words appear first."
1589   else if (bitwidth == 64) {
1590     struct DoubleWord {
1591       uint32_t word1;
1592       uint32_t word2;
1593     } words;
1594     if (isSigned) {
1595       words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1596     } else {
1597       words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1598     }
1599     (void)encodeInstructionInto(typesGlobalValues, opcode,
1600                                 {typeID, resultID, words.word1, words.word2});
1601   } else {
1602     std::string valueStr;
1603     llvm::raw_string_ostream rss(valueStr);
1604     value.print(rss, /*isSigned=*/false);
1605 
1606     emitError(loc, "cannot serialize ")
1607         << bitwidth << "-bit integer literal: " << rss.str();
1608     return 0;
1609   }
1610 
1611   if (!isSpec) {
1612     constIDMap[intAttr] = resultID;
1613   }
1614   return resultID;
1615 }
1616 
1617 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1618                                        bool isSpec) {
1619   if (!isSpec) {
1620     // We can de-duplicate normal constants, but not specialization constants.
1621     if (auto id = getConstantID(floatAttr)) {
1622       return id;
1623     }
1624   }
1625 
1626   // Process the type for this float literal
1627   uint32_t typeID = 0;
1628   if (failed(processType(loc, floatAttr.getType(), typeID))) {
1629     return 0;
1630   }
1631 
1632   auto resultID = getNextID();
1633   APFloat value = floatAttr.getValue();
1634   APInt intValue = value.bitcastToAPInt();
1635 
1636   auto opcode =
1637       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1638 
1639   if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1640     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1641     (void)encodeInstructionInto(typesGlobalValues, opcode,
1642                                 {typeID, resultID, word});
1643   } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1644     struct DoubleWord {
1645       uint32_t word1;
1646       uint32_t word2;
1647     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1648     (void)encodeInstructionInto(typesGlobalValues, opcode,
1649                                 {typeID, resultID, words.word1, words.word2});
1650   } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1651     uint32_t word =
1652         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1653     (void)encodeInstructionInto(typesGlobalValues, opcode,
1654                                 {typeID, resultID, word});
1655   } else {
1656     std::string valueStr;
1657     llvm::raw_string_ostream rss(valueStr);
1658     value.print(rss);
1659 
1660     emitError(loc, "cannot serialize ")
1661         << floatAttr.getType() << "-typed float literal: " << rss.str();
1662     return 0;
1663   }
1664 
1665   if (!isSpec) {
1666     constIDMap[floatAttr] = resultID;
1667   }
1668   return resultID;
1669 }
1670 
1671 //===----------------------------------------------------------------------===//
1672 // Control flow
1673 //===----------------------------------------------------------------------===//
1674 
1675 uint32_t Serializer::getOrCreateBlockID(Block *block) {
1676   if (uint32_t id = getBlockID(block))
1677     return id;
1678   return blockIDMap[block] = getNextID();
1679 }
1680 
1681 LogicalResult
1682 Serializer::processBlock(Block *block, bool omitLabel,
1683                          function_ref<void()> actionBeforeTerminator) {
1684   LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1685   LLVM_DEBUG(block->print(llvm::dbgs()));
1686   LLVM_DEBUG(llvm::dbgs() << '\n');
1687   if (!omitLabel) {
1688     uint32_t blockID = getOrCreateBlockID(block);
1689     LLVM_DEBUG(llvm::dbgs()
1690                << "[block] " << block << " (id = " << blockID << ")\n");
1691 
1692     // Emit OpLabel for this block.
1693     (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel,
1694                                 {blockID});
1695   }
1696 
1697   // Emit OpPhi instructions for block arguments, if any.
1698   if (failed(emitPhiForBlockArguments(block)))
1699     return failure();
1700 
1701   // Process each op in this block except the terminator.
1702   for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
1703     if (failed(processOperation(&op)))
1704       return failure();
1705   }
1706 
1707   // Process the terminator.
1708   if (actionBeforeTerminator)
1709     actionBeforeTerminator();
1710   if (failed(processOperation(&block->back())))
1711     return failure();
1712 
1713   return success();
1714 }
1715 
1716 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1717   // Nothing to do if this block has no arguments or it's the entry block, which
1718   // always has the same arguments as the function signature.
1719   if (block->args_empty() || block->isEntryBlock())
1720     return success();
1721 
1722   // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1723   // A SPIR-V OpPhi instruction is of the syntax:
1724   //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1725   // So we need to collect all predecessor blocks and the arguments they send
1726   // to this block.
1727   SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
1728   for (Block *predecessor : block->getPredecessors()) {
1729     auto *terminator = predecessor->getTerminator();
1730     // The predecessor here is the immediate one according to MLIR's IR
1731     // structure. It does not directly map to the incoming parent block for the
1732     // OpPhi instructions at SPIR-V binary level. This is because structured
1733     // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1734     // spv.selection/spv.loop op in the MLIR predecessor block, the branch op
1735     // jumping to the OpPhi's block then resides in the previous structured
1736     // control flow op's merge block.
1737     predecessor = getPhiIncomingBlock(predecessor);
1738     if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1739       predecessors.emplace_back(predecessor, branchOp.operand_begin());
1740     } else {
1741       return terminator->emitError("unimplemented terminator for Phi creation");
1742     }
1743   }
1744 
1745   // Then create OpPhi instruction for each of the block argument.
1746   for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1747     BlockArgument arg = block->getArgument(argIndex);
1748 
1749     // Get the type <id> and result <id> for this OpPhi instruction.
1750     uint32_t phiTypeID = 0;
1751     if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1752       return failure();
1753     uint32_t phiID = getNextID();
1754 
1755     LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1756                             << arg << " (id = " << phiID << ")\n");
1757 
1758     // Prepare the (value <id>, parent block <id>) pairs.
1759     SmallVector<uint32_t, 8> phiArgs;
1760     phiArgs.push_back(phiTypeID);
1761     phiArgs.push_back(phiID);
1762 
1763     for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1764       Value value = *(predecessors[predIndex].second + argIndex);
1765       uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1766       LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1767                               << ") value " << value << ' ');
1768       // Each pair is a value <id> ...
1769       uint32_t valueId = getValueID(value);
1770       if (valueId == 0) {
1771         // The op generating this value hasn't been visited yet so we don't have
1772         // an <id> assigned yet. Record this to fix up later.
1773         LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1774         deferredPhiValues[value].push_back(functionBody.size() + 1 +
1775                                            phiArgs.size());
1776       } else {
1777         LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1778       }
1779       phiArgs.push_back(valueId);
1780       // ... and a parent block <id>.
1781       phiArgs.push_back(predBlockId);
1782     }
1783 
1784     (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1785     valueIDMap[arg] = phiID;
1786   }
1787 
1788   return success();
1789 }
1790 
1791 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
1792   // Assign <id>s to all blocks so that branches inside the SelectionOp can
1793   // resolve properly.
1794   auto &body = selectionOp.body();
1795   for (Block &block : body)
1796     getOrCreateBlockID(&block);
1797 
1798   auto *headerBlock = selectionOp.getHeaderBlock();
1799   auto *mergeBlock = selectionOp.getMergeBlock();
1800   auto mergeID = getBlockID(mergeBlock);
1801   auto loc = selectionOp.getLoc();
1802 
1803   // Emit the selection header block, which dominates all other blocks, first.
1804   // We need to emit an OpSelectionMerge instruction before the selection header
1805   // block's terminator.
1806   auto emitSelectionMerge = [&]() {
1807     (void)emitDebugLine(functionBody, loc);
1808     lastProcessedWasMergeInst = true;
1809     (void)encodeInstructionInto(
1810         functionBody, spirv::Opcode::OpSelectionMerge,
1811         {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
1812   };
1813   // For structured selection, we cannot have blocks in the selection construct
1814   // branching to the selection header block. Entering the selection (and
1815   // reaching the selection header) must be from the block containing the
1816   // spv.selection op. If there are ops ahead of the spv.selection op in the
1817   // block, we can "merge" them into the selection header. So here we don't need
1818   // to emit a separate block; just continue with the existing block.
1819   if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge)))
1820     return failure();
1821 
1822   // Process all blocks with a depth-first visitor starting from the header
1823   // block. The selection header block and merge block are skipped by this
1824   // visitor.
1825   if (failed(visitInPrettyBlockOrder(
1826           headerBlock, [&](Block *block) { return processBlock(block); },
1827           /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
1828     return failure();
1829 
1830   // There is nothing to do for the merge block in the selection, which just
1831   // contains a spv.mlir.merge op, itself. But we need to have an OpLabel
1832   // instruction to start a new SPIR-V block for ops following this SelectionOp.
1833   // The block should use the <id> for the merge block.
1834   return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
1835 }
1836 
1837 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
1838   // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
1839   // properly. We don't need to assign for the entry block, which is just for
1840   // satisfying MLIR region's structural requirement.
1841   auto &body = loopOp.body();
1842   for (Block &block :
1843        llvm::make_range(std::next(body.begin(), 1), body.end())) {
1844     getOrCreateBlockID(&block);
1845   }
1846   auto *headerBlock = loopOp.getHeaderBlock();
1847   auto *continueBlock = loopOp.getContinueBlock();
1848   auto *mergeBlock = loopOp.getMergeBlock();
1849   auto headerID = getBlockID(headerBlock);
1850   auto continueID = getBlockID(continueBlock);
1851   auto mergeID = getBlockID(mergeBlock);
1852   auto loc = loopOp.getLoc();
1853 
1854   // This LoopOp is in some MLIR block with preceding and following ops. In the
1855   // binary format, it should reside in separate SPIR-V blocks from its
1856   // preceding and following ops. So we need to emit unconditional branches to
1857   // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
1858   // afterwards.
1859   (void)encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
1860                               {headerID});
1861 
1862   // LoopOp's entry block is just there for satisfying MLIR's structural
1863   // requirements so we omit it and start serialization from the loop header
1864   // block.
1865 
1866   // Emit the loop header block, which dominates all other blocks, first. We
1867   // need to emit an OpLoopMerge instruction before the loop header block's
1868   // terminator.
1869   auto emitLoopMerge = [&]() {
1870     (void)emitDebugLine(functionBody, loc);
1871     lastProcessedWasMergeInst = true;
1872     (void)encodeInstructionInto(
1873         functionBody, spirv::Opcode::OpLoopMerge,
1874         {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
1875   };
1876   if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
1877     return failure();
1878 
1879   // Process all blocks with a depth-first visitor starting from the header
1880   // block. The loop header block, loop continue block, and loop merge block are
1881   // skipped by this visitor and handled later in this function.
1882   if (failed(visitInPrettyBlockOrder(
1883           headerBlock, [&](Block *block) { return processBlock(block); },
1884           /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
1885     return failure();
1886 
1887   // We have handled all other blocks. Now get to the loop continue block.
1888   if (failed(processBlock(continueBlock)))
1889     return failure();
1890 
1891   // There is nothing to do for the merge block in the loop, which just contains
1892   // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to
1893   // start a new SPIR-V block for ops following this LoopOp. The block should
1894   // use the <id> for the merge block.
1895   return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
1896 }
1897 
1898 LogicalResult Serializer::processBranchConditionalOp(
1899     spirv::BranchConditionalOp condBranchOp) {
1900   auto conditionID = getValueID(condBranchOp.condition());
1901   auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
1902   auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
1903   SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
1904 
1905   if (auto weights = condBranchOp.branch_weights()) {
1906     for (auto val : weights->getValue())
1907       arguments.push_back(val.cast<IntegerAttr>().getInt());
1908   }
1909 
1910   (void)emitDebugLine(functionBody, condBranchOp.getLoc());
1911   return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
1912                                arguments);
1913 }
1914 
1915 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
1916   (void)emitDebugLine(functionBody, branchOp.getLoc());
1917   return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
1918                                {getOrCreateBlockID(branchOp.getTarget())});
1919 }
1920 
1921 //===----------------------------------------------------------------------===//
1922 // Operation
1923 //===----------------------------------------------------------------------===//
1924 
1925 LogicalResult Serializer::encodeExtensionInstruction(
1926     Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1927     ArrayRef<uint32_t> operands) {
1928   // Check if the extension has been imported.
1929   auto &setID = extendedInstSetIDMap[extensionSetName];
1930   if (!setID) {
1931     setID = getNextID();
1932     SmallVector<uint32_t, 16> importOperands;
1933     importOperands.push_back(setID);
1934     if (failed(
1935             spirv::encodeStringLiteralInto(importOperands, extensionSetName)) ||
1936         failed(encodeInstructionInto(
1937             extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
1938       return failure();
1939     }
1940   }
1941 
1942   // The first two operands are the result type <id> and result <id>. The set
1943   // <id> and the opcode need to be insert after this.
1944   if (operands.size() < 2) {
1945     return op->emitError("extended instructions must have a result encoding");
1946   }
1947   SmallVector<uint32_t, 8> extInstOperands;
1948   extInstOperands.reserve(operands.size() + 2);
1949   extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1950   extInstOperands.push_back(setID);
1951   extInstOperands.push_back(extensionOpcode);
1952   extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1953   return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1954                                extInstOperands);
1955 }
1956 
1957 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
1958   auto varName = addressOfOp.variable();
1959   auto variableID = getVariableID(varName);
1960   if (!variableID) {
1961     return addressOfOp.emitError("unknown result <id> for variable ")
1962            << varName;
1963   }
1964   valueIDMap[addressOfOp.pointer()] = variableID;
1965   return success();
1966 }
1967 
1968 LogicalResult
1969 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
1970   auto constName = referenceOfOp.spec_const();
1971   auto constID = getSpecConstID(constName);
1972   if (!constID) {
1973     return referenceOfOp.emitError(
1974                "unknown result <id> for specialization constant ")
1975            << constName;
1976   }
1977   valueIDMap[referenceOfOp.reference()] = constID;
1978   return success();
1979 }
1980 
1981 LogicalResult Serializer::processOperation(Operation *opInst) {
1982   LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1983 
1984   // First dispatch the ops that do not directly mirror an instruction from
1985   // the SPIR-V spec.
1986   return TypeSwitch<Operation *, LogicalResult>(opInst)
1987       .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1988       .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1989       .Case([&](spirv::BranchConditionalOp op) {
1990         return processBranchConditionalOp(op);
1991       })
1992       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1993       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1994       .Case([&](spirv::GlobalVariableOp op) {
1995         return processGlobalVariableOp(op);
1996       })
1997       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1998       .Case([&](spirv::ModuleEndOp) { return success(); })
1999       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
2000       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
2001       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
2002       .Case([&](spirv::SpecConstantCompositeOp op) {
2003         return processSpecConstantCompositeOp(op);
2004       })
2005       .Case([&](spirv::SpecConstantOperationOp op) {
2006         return processSpecConstantOperationOp(op);
2007       })
2008       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
2009       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
2010 
2011       // Then handle all the ops that directly mirror SPIR-V instructions with
2012       // auto-generated methods.
2013       .Default(
2014           [&](Operation *op) { return dispatchToAutogenSerialization(op); });
2015 }
2016 
2017 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
2018                                                       StringRef extInstSet,
2019                                                       uint32_t opcode) {
2020   SmallVector<uint32_t, 4> operands;
2021   Location loc = op->getLoc();
2022 
2023   uint32_t resultID = 0;
2024   if (op->getNumResults() != 0) {
2025     uint32_t resultTypeID = 0;
2026     if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
2027       return failure();
2028     operands.push_back(resultTypeID);
2029 
2030     resultID = getNextID();
2031     operands.push_back(resultID);
2032     valueIDMap[op->getResult(0)] = resultID;
2033   };
2034 
2035   for (Value operand : op->getOperands())
2036     operands.push_back(getValueID(operand));
2037 
2038   (void)emitDebugLine(functionBody, loc);
2039 
2040   if (extInstSet.empty()) {
2041     (void)encodeInstructionInto(functionBody,
2042                                 static_cast<spirv::Opcode>(opcode), operands);
2043   } else {
2044     (void)encodeExtensionInstruction(op, extInstSet, opcode, operands);
2045   }
2046 
2047   if (op->getNumResults() != 0) {
2048     for (auto attr : op->getAttrs()) {
2049       if (failed(processDecoration(loc, resultID, attr)))
2050         return failure();
2051     }
2052   }
2053 
2054   return success();
2055 }
2056 
2057 namespace {
2058 template <>
2059 LogicalResult
2060 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
2061   SmallVector<uint32_t, 4> operands;
2062   // Add the ExecutionModel.
2063   operands.push_back(static_cast<uint32_t>(op.execution_model()));
2064   // Add the function <id>.
2065   auto funcID = getFunctionID(op.fn());
2066   if (!funcID) {
2067     return op.emitError("missing <id> for function ")
2068            << op.fn()
2069            << "; function needs to be defined before spv.EntryPoint is "
2070               "serialized";
2071   }
2072   operands.push_back(funcID);
2073   // Add the name of the function.
2074   (void)spirv::encodeStringLiteralInto(operands, op.fn());
2075 
2076   // Add the interface values.
2077   if (auto interface = op.interface()) {
2078     for (auto var : interface.getValue()) {
2079       auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
2080       if (!id) {
2081         return op.emitError("referencing undefined global variable."
2082                             "spv.EntryPoint is at the end of spv.module. All "
2083                             "referenced variables should already be defined");
2084       }
2085       operands.push_back(id);
2086     }
2087   }
2088   return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
2089                                operands);
2090 }
2091 
2092 template <>
2093 LogicalResult
2094 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
2095   StringRef argNames[] = {"execution_scope", "memory_scope",
2096                           "memory_semantics"};
2097   SmallVector<uint32_t, 3> operands;
2098 
2099   for (auto argName : argNames) {
2100     auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
2101     auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
2102     if (!operand) {
2103       return failure();
2104     }
2105     operands.push_back(operand);
2106   }
2107 
2108   return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
2109                                operands);
2110 }
2111 
2112 template <>
2113 LogicalResult
2114 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
2115   SmallVector<uint32_t, 4> operands;
2116   // Add the function <id>.
2117   auto funcID = getFunctionID(op.fn());
2118   if (!funcID) {
2119     return op.emitError("missing <id> for function ")
2120            << op.fn()
2121            << "; function needs to be serialized before ExecutionModeOp is "
2122               "serialized";
2123   }
2124   operands.push_back(funcID);
2125   // Add the ExecutionMode.
2126   operands.push_back(static_cast<uint32_t>(op.execution_mode()));
2127 
2128   // Serialize values if any.
2129   auto values = op.values();
2130   if (values) {
2131     for (auto &intVal : values.getValue()) {
2132       operands.push_back(static_cast<uint32_t>(
2133           intVal.cast<IntegerAttr>().getValue().getZExtValue()));
2134     }
2135   }
2136   return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
2137                                operands);
2138 }
2139 
2140 template <>
2141 LogicalResult
2142 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
2143   StringRef argNames[] = {"memory_scope", "memory_semantics"};
2144   SmallVector<uint32_t, 2> operands;
2145 
2146   for (auto argName : argNames) {
2147     auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
2148     auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
2149     if (!operand) {
2150       return failure();
2151     }
2152     operands.push_back(operand);
2153   }
2154 
2155   return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier,
2156                                operands);
2157 }
2158 
2159 template <>
2160 LogicalResult
2161 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
2162   auto funcName = op.callee();
2163   uint32_t resTypeID = 0;
2164 
2165   Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
2166   if (failed(processType(op.getLoc(), resultTy, resTypeID)))
2167     return failure();
2168 
2169   auto funcID = getOrCreateFunctionID(funcName);
2170   auto funcCallID = getNextID();
2171   SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
2172 
2173   for (auto value : op.arguments()) {
2174     auto valueID = getValueID(value);
2175     assert(valueID && "cannot find a value for spv.FunctionCall");
2176     operands.push_back(valueID);
2177   }
2178 
2179   if (!resultTy.isa<NoneType>())
2180     valueIDMap[op.getResult(0)] = funcCallID;
2181 
2182   return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
2183                                operands);
2184 }
2185 
2186 template <>
2187 LogicalResult
2188 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
2189   SmallVector<uint32_t, 4> operands;
2190   SmallVector<StringRef, 2> elidedAttrs;
2191 
2192   for (Value operand : op->getOperands()) {
2193     auto id = getValueID(operand);
2194     assert(id && "use before def!");
2195     operands.push_back(id);
2196   }
2197 
2198   if (auto attr = op->getAttr("memory_access")) {
2199     operands.push_back(static_cast<uint32_t>(
2200         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2201   }
2202 
2203   elidedAttrs.push_back("memory_access");
2204 
2205   if (auto attr = op->getAttr("alignment")) {
2206     operands.push_back(static_cast<uint32_t>(
2207         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2208   }
2209 
2210   elidedAttrs.push_back("alignment");
2211 
2212   if (auto attr = op->getAttr("source_memory_access")) {
2213     operands.push_back(static_cast<uint32_t>(
2214         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2215   }
2216 
2217   elidedAttrs.push_back("source_memory_access");
2218 
2219   if (auto attr = op->getAttr("source_alignment")) {
2220     operands.push_back(static_cast<uint32_t>(
2221         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2222   }
2223 
2224   elidedAttrs.push_back("source_alignment");
2225   (void)emitDebugLine(functionBody, op.getLoc());
2226   (void)encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory,
2227                               operands);
2228 
2229   return success();
2230 }
2231 
2232 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
2233 // various Serializer::processOp<...>() specializations.
2234 #define GET_SERIALIZATION_FNS
2235 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
2236 } // namespace
2237 
2238 LogicalResult Serializer::emitDecoration(uint32_t target,
2239                                          spirv::Decoration decoration,
2240                                          ArrayRef<uint32_t> params) {
2241   uint32_t wordCount = 3 + params.size();
2242   decorations.push_back(
2243       spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
2244   decorations.push_back(target);
2245   decorations.push_back(static_cast<uint32_t>(decoration));
2246   decorations.append(params.begin(), params.end());
2247   return success();
2248 }
2249 
2250 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
2251                                         Location loc) {
2252   if (!emitDebugInfo)
2253     return success();
2254 
2255   if (lastProcessedWasMergeInst) {
2256     lastProcessedWasMergeInst = false;
2257     return success();
2258   }
2259 
2260   auto fileLoc = loc.dyn_cast<FileLineColLoc>();
2261   if (fileLoc)
2262     (void)encodeInstructionInto(
2263         binary, spirv::Opcode::OpLine,
2264         {fileID, fileLoc.getLine(), fileLoc.getColumn()});
2265   return success();
2266 }
2267 
2268 namespace mlir {
2269 LogicalResult spirv::serialize(spirv::ModuleOp module,
2270                                SmallVectorImpl<uint32_t> &binary,
2271                                bool emitDebugInfo) {
2272   if (!module.vce_triple().hasValue())
2273     return module.emitError(
2274         "module must have 'vce_triple' attribute to be serializeable");
2275 
2276   Serializer serializer(module, emitDebugInfo);
2277 
2278   if (failed(serializer.serialize()))
2279     return failure();
2280 
2281   LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs()));
2282 
2283   serializer.collect(binary);
2284   return success();
2285 }
2286 } // namespace mlir
2287