1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Support/LogicalResult.h"
19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/ADT/bit.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "spirv-serialization"
28 
29 using namespace mlir;
30 
31 /// Returns the merge block if the given `op` is a structured control flow op.
32 /// Otherwise returns nullptr.
33 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
34   if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
35     return selectionOp.getMergeBlock();
36   if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
37     return loopOp.getMergeBlock();
38   return nullptr;
39 }
40 
41 /// Given a predecessor `block` for a block with arguments, returns the block
42 /// that should be used as the parent block for SPIR-V OpPhi instructions
43 /// corresponding to the block arguments.
44 static Block *getPhiIncomingBlock(Block *block) {
45   // If the predecessor block in question is the entry block for a
46   // spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block.
47   if (block->isEntryBlock()) {
48     if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
49       // Then the incoming parent block for OpPhi should be the merge block of
50       // the structured control flow op before this loop.
51       Operation *op = loopOp.getOperation();
52       while ((op = op->getPrevNode()) != nullptr)
53         if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
54           return incomingBlock;
55       // Or the enclosing block itself if no structured control flow ops
56       // exists before this loop.
57       return loopOp->getBlock();
58     }
59   }
60 
61   // Otherwise, we jump from the given predecessor block. Try to see if there is
62   // a structured control flow op inside it.
63   for (Operation &op : llvm::reverse(block->getOperations())) {
64     if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
65       return incomingBlock;
66   }
67   return block;
68 }
69 
70 namespace mlir {
71 namespace spirv {
72 
73 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
74 /// the given `binary` vector.
75 LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
76                                     spirv::Opcode op,
77                                     ArrayRef<uint32_t> operands) {
78   uint32_t wordCount = 1 + operands.size();
79   binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
80   binary.append(operands.begin(), operands.end());
81   return success();
82 }
83 
84 Serializer::Serializer(spirv::ModuleOp module,
85                        const SerializationOptions &options)
86     : module(module), mlirBuilder(module.getContext()), options(options) {}
87 
88 LogicalResult Serializer::serialize() {
89   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
90 
91   if (failed(module.verify()))
92     return failure();
93 
94   // TODO: handle the other sections
95   processCapability();
96   processExtension();
97   processMemoryModel();
98   processDebugInfo();
99 
100   // Iterate over the module body to serialize it. Assumptions are that there is
101   // only one basic block in the moduleOp
102   for (auto &op : *module.getBody()) {
103     if (failed(processOperation(&op))) {
104       return failure();
105     }
106   }
107 
108   LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
109   return success();
110 }
111 
112 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
113   auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
114                     extensions.size() + extendedSets.size() +
115                     memoryModel.size() + entryPoints.size() +
116                     executionModes.size() + decorations.size() +
117                     typesGlobalValues.size() + functions.size();
118 
119   binary.clear();
120   binary.reserve(moduleSize);
121 
122   spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
123   binary.append(capabilities.begin(), capabilities.end());
124   binary.append(extensions.begin(), extensions.end());
125   binary.append(extendedSets.begin(), extendedSets.end());
126   binary.append(memoryModel.begin(), memoryModel.end());
127   binary.append(entryPoints.begin(), entryPoints.end());
128   binary.append(executionModes.begin(), executionModes.end());
129   binary.append(debug.begin(), debug.end());
130   binary.append(names.begin(), names.end());
131   binary.append(decorations.begin(), decorations.end());
132   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
133   binary.append(functions.begin(), functions.end());
134 }
135 
136 #ifndef NDEBUG
137 void Serializer::printValueIDMap(raw_ostream &os) {
138   os << "\n= Value <id> Map =\n\n";
139   for (auto valueIDPair : valueIDMap) {
140     Value val = valueIDPair.first;
141     os << "  " << val << " "
142        << "id = " << valueIDPair.second << ' ';
143     if (auto *op = val.getDefiningOp()) {
144       os << "from op '" << op->getName() << "'";
145     } else if (auto arg = val.dyn_cast<BlockArgument>()) {
146       Block *block = arg.getOwner();
147       os << "from argument of block " << block << ' ';
148       os << " in op '" << block->getParentOp()->getName() << "'";
149     }
150     os << '\n';
151   }
152 }
153 #endif
154 
155 //===----------------------------------------------------------------------===//
156 // Module structure
157 //===----------------------------------------------------------------------===//
158 
159 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
160   auto funcID = funcIDMap.lookup(fnName);
161   if (!funcID) {
162     funcID = getNextID();
163     funcIDMap[fnName] = funcID;
164   }
165   return funcID;
166 }
167 
168 void Serializer::processCapability() {
169   for (auto cap : module.vce_triple()->getCapabilities())
170     (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
171                                 {static_cast<uint32_t>(cap)});
172 }
173 
174 void Serializer::processDebugInfo() {
175   if (!options.emitDebugInfo)
176     return;
177   auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
178   auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
179   fileID = getNextID();
180   SmallVector<uint32_t, 16> operands;
181   operands.push_back(fileID);
182   (void)spirv::encodeStringLiteralInto(operands, fileName);
183   (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
184   // TODO: Encode more debug instructions.
185 }
186 
187 void Serializer::processExtension() {
188   llvm::SmallVector<uint32_t, 16> extName;
189   for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
190     extName.clear();
191     (void)spirv::encodeStringLiteralInto(extName,
192                                          spirv::stringifyExtension(ext));
193     (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension,
194                                 extName);
195   }
196 }
197 
198 void Serializer::processMemoryModel() {
199   uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
200   uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
201 
202   (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel,
203                               {am, mm});
204 }
205 
206 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
207                                             NamedAttribute attr) {
208   auto attrName = attr.getName().strref();
209   auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
210   auto decoration = spirv::symbolizeDecoration(decorationName);
211   if (!decoration) {
212     return emitError(
213                loc, "non-argument attributes expected to have snake-case-ified "
214                     "decoration name, unhandled attribute with name : ")
215            << attrName;
216   }
217   SmallVector<uint32_t, 1> args;
218   switch (decoration.getValue()) {
219   case spirv::Decoration::Binding:
220   case spirv::Decoration::DescriptorSet:
221   case spirv::Decoration::Location:
222     if (auto intAttr = attr.getValue().dyn_cast<IntegerAttr>()) {
223       args.push_back(intAttr.getValue().getZExtValue());
224       break;
225     }
226     return emitError(loc, "expected integer attribute for ") << attrName;
227   case spirv::Decoration::BuiltIn:
228     if (auto strAttr = attr.getValue().dyn_cast<StringAttr>()) {
229       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
230       if (enumVal) {
231         args.push_back(static_cast<uint32_t>(enumVal.getValue()));
232         break;
233       }
234       return emitError(loc, "invalid ")
235              << attrName << " attribute " << strAttr.getValue();
236     }
237     return emitError(loc, "expected string attribute for ") << attrName;
238   case spirv::Decoration::Aliased:
239   case spirv::Decoration::Flat:
240   case spirv::Decoration::NonReadable:
241   case spirv::Decoration::NonWritable:
242   case spirv::Decoration::NoPerspective:
243   case spirv::Decoration::Restrict:
244   case spirv::Decoration::RelaxedPrecision:
245     // For unit attributes, the args list has no values so we do nothing
246     if (auto unitAttr = attr.getValue().dyn_cast<UnitAttr>())
247       break;
248     return emitError(loc, "expected unit attribute for ") << attrName;
249   default:
250     return emitError(loc, "unhandled decoration ") << decorationName;
251   }
252   return emitDecoration(resultID, decoration.getValue(), args);
253 }
254 
255 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
256   assert(!name.empty() && "unexpected empty string for OpName");
257   if (!options.emitSymbolName)
258     return success();
259 
260   SmallVector<uint32_t, 4> nameOperands;
261   nameOperands.push_back(resultID);
262   if (failed(spirv::encodeStringLiteralInto(nameOperands, name)))
263     return failure();
264   return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
265 }
266 
267 template <>
268 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
269     Location loc, spirv::ArrayType type, uint32_t resultID) {
270   if (unsigned stride = type.getArrayStride()) {
271     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
272     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
273   }
274   return success();
275 }
276 
277 template <>
278 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
279     Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
280   if (unsigned stride = type.getArrayStride()) {
281     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
282     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
283   }
284   return success();
285 }
286 
287 LogicalResult Serializer::processMemberDecoration(
288     uint32_t structID,
289     const spirv::StructType::MemberDecorationInfo &memberDecoration) {
290   SmallVector<uint32_t, 4> args(
291       {structID, memberDecoration.memberIndex,
292        static_cast<uint32_t>(memberDecoration.decoration)});
293   if (memberDecoration.hasValue) {
294     args.push_back(memberDecoration.decorationValue);
295   }
296   return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
297                                args);
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // Type
302 //===----------------------------------------------------------------------===//
303 
304 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
305 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
306 // PushConstant Storage Classes must be explicitly laid out."
307 bool Serializer::isInterfaceStructPtrType(Type type) const {
308   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
309     switch (ptrType.getStorageClass()) {
310     case spirv::StorageClass::PhysicalStorageBuffer:
311     case spirv::StorageClass::PushConstant:
312     case spirv::StorageClass::StorageBuffer:
313     case spirv::StorageClass::Uniform:
314       return ptrType.getPointeeType().isa<spirv::StructType>();
315     default:
316       break;
317     }
318   }
319   return false;
320 }
321 
322 LogicalResult Serializer::processType(Location loc, Type type,
323                                       uint32_t &typeID) {
324   // Maintains a set of names for nested identified struct types. This is used
325   // to properly serialize recursive references.
326   SetVector<StringRef> serializationCtx;
327   return processTypeImpl(loc, type, typeID, serializationCtx);
328 }
329 
330 LogicalResult
331 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
332                             SetVector<StringRef> &serializationCtx) {
333   typeID = getTypeID(type);
334   if (typeID)
335     return success();
336 
337   typeID = getNextID();
338   SmallVector<uint32_t, 4> operands;
339 
340   operands.push_back(typeID);
341   auto typeEnum = spirv::Opcode::OpTypeVoid;
342   bool deferSerialization = false;
343 
344   if ((type.isa<FunctionType>() &&
345        succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
346                                      operands))) ||
347       succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
348                                  deferSerialization, serializationCtx))) {
349     if (deferSerialization)
350       return success();
351 
352     typeIDMap[type] = typeID;
353 
354     if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands)))
355       return failure();
356 
357     if (recursiveStructInfos.count(type) != 0) {
358       // This recursive struct type is emitted already, now the OpTypePointer
359       // instructions referring to recursive references are emitted as well.
360       for (auto &ptrInfo : recursiveStructInfos[type]) {
361         // TODO: This might not work if more than 1 recursive reference is
362         // present in the struct.
363         SmallVector<uint32_t, 4> ptrOperands;
364         ptrOperands.push_back(ptrInfo.pointerTypeID);
365         ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
366         ptrOperands.push_back(typeIDMap[type]);
367 
368         if (failed(encodeInstructionInto(
369                 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands)))
370           return failure();
371       }
372 
373       recursiveStructInfos[type].clear();
374     }
375 
376     return success();
377   }
378 
379   return failure();
380 }
381 
382 LogicalResult Serializer::prepareBasicType(
383     Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
384     SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
385     SetVector<StringRef> &serializationCtx) {
386   deferSerialization = false;
387 
388   if (isVoidType(type)) {
389     typeEnum = spirv::Opcode::OpTypeVoid;
390     return success();
391   }
392 
393   if (auto intType = type.dyn_cast<IntegerType>()) {
394     if (intType.getWidth() == 1) {
395       typeEnum = spirv::Opcode::OpTypeBool;
396       return success();
397     }
398 
399     typeEnum = spirv::Opcode::OpTypeInt;
400     operands.push_back(intType.getWidth());
401     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
402     // to preserve or validate.
403     // 0 indicates unsigned, or no signedness semantics
404     // 1 indicates signed semantics."
405     operands.push_back(intType.isSigned() ? 1 : 0);
406     return success();
407   }
408 
409   if (auto floatType = type.dyn_cast<FloatType>()) {
410     typeEnum = spirv::Opcode::OpTypeFloat;
411     operands.push_back(floatType.getWidth());
412     return success();
413   }
414 
415   if (auto vectorType = type.dyn_cast<VectorType>()) {
416     uint32_t elementTypeID = 0;
417     if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
418                                serializationCtx))) {
419       return failure();
420     }
421     typeEnum = spirv::Opcode::OpTypeVector;
422     operands.push_back(elementTypeID);
423     operands.push_back(vectorType.getNumElements());
424     return success();
425   }
426 
427   if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
428     typeEnum = spirv::Opcode::OpTypeImage;
429     uint32_t sampledTypeID = 0;
430     if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
431       return failure();
432 
433     operands.push_back(sampledTypeID);
434     operands.push_back(static_cast<uint32_t>(imageType.getDim()));
435     operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
436     operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
437     operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
438     operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
439     operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
440     return success();
441   }
442 
443   if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
444     typeEnum = spirv::Opcode::OpTypeArray;
445     uint32_t elementTypeID = 0;
446     if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
447                                serializationCtx))) {
448       return failure();
449     }
450     operands.push_back(elementTypeID);
451     if (auto elementCountID = prepareConstantInt(
452             loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
453       operands.push_back(elementCountID);
454     }
455     return processTypeDecoration(loc, arrayType, resultID);
456   }
457 
458   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
459     uint32_t pointeeTypeID = 0;
460     spirv::StructType pointeeStruct =
461         ptrType.getPointeeType().dyn_cast<spirv::StructType>();
462 
463     if (pointeeStruct && pointeeStruct.isIdentified() &&
464         serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
465       // A recursive reference to an enclosing struct is found.
466       //
467       // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
468       // class as operands.
469       SmallVector<uint32_t, 2> forwardPtrOperands;
470       forwardPtrOperands.push_back(resultID);
471       forwardPtrOperands.push_back(
472           static_cast<uint32_t>(ptrType.getStorageClass()));
473 
474       (void)encodeInstructionInto(typesGlobalValues,
475                                   spirv::Opcode::OpTypeForwardPointer,
476                                   forwardPtrOperands);
477 
478       // 2. Find the pointee (enclosing) struct.
479       auto structType = spirv::StructType::getIdentified(
480           module.getContext(), pointeeStruct.getIdentifier());
481 
482       if (!structType)
483         return failure();
484 
485       // 3. Mark the OpTypePointer that is supposed to be emitted by this call
486       // as deferred.
487       deferSerialization = true;
488 
489       // 4. Record the info needed to emit the deferred OpTypePointer
490       // instruction when the enclosing struct is completely serialized.
491       recursiveStructInfos[structType].push_back(
492           {resultID, ptrType.getStorageClass()});
493     } else {
494       if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
495                                  serializationCtx)))
496         return failure();
497     }
498 
499     typeEnum = spirv::Opcode::OpTypePointer;
500     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
501     operands.push_back(pointeeTypeID);
502 
503     if (isInterfaceStructPtrType(ptrType)) {
504       if (failed(emitDecoration(getTypeID(pointeeStruct),
505                                 spirv::Decoration::Block)))
506         return emitError(loc, "cannot decorate ")
507                << pointeeStruct << " with Block decoration";
508     }
509 
510     return success();
511   }
512 
513   if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
514     uint32_t elementTypeID = 0;
515     if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
516                                elementTypeID, serializationCtx))) {
517       return failure();
518     }
519     typeEnum = spirv::Opcode::OpTypeRuntimeArray;
520     operands.push_back(elementTypeID);
521     return processTypeDecoration(loc, runtimeArrayType, resultID);
522   }
523 
524   if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) {
525     typeEnum = spirv::Opcode::OpTypeSampledImage;
526     uint32_t imageTypeID = 0;
527     if (failed(
528             processType(loc, sampledImageType.getImageType(), imageTypeID))) {
529       return failure();
530     }
531     operands.push_back(imageTypeID);
532     return success();
533   }
534 
535   if (auto structType = type.dyn_cast<spirv::StructType>()) {
536     if (structType.isIdentified()) {
537       (void)processName(resultID, structType.getIdentifier());
538       serializationCtx.insert(structType.getIdentifier());
539     }
540 
541     bool hasOffset = structType.hasOffset();
542     for (auto elementIndex :
543          llvm::seq<uint32_t>(0, structType.getNumElements())) {
544       uint32_t elementTypeID = 0;
545       if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
546                                  elementTypeID, serializationCtx))) {
547         return failure();
548       }
549       operands.push_back(elementTypeID);
550       if (hasOffset) {
551         // Decorate each struct member with an offset
552         spirv::StructType::MemberDecorationInfo offsetDecoration{
553             elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
554             static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
555         if (failed(processMemberDecoration(resultID, offsetDecoration))) {
556           return emitError(loc, "cannot decorate ")
557                  << elementIndex << "-th member of " << structType
558                  << " with its offset";
559         }
560       }
561     }
562     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
563     structType.getMemberDecorations(memberDecorations);
564 
565     for (auto &memberDecoration : memberDecorations) {
566       if (failed(processMemberDecoration(resultID, memberDecoration))) {
567         return emitError(loc, "cannot decorate ")
568                << static_cast<uint32_t>(memberDecoration.memberIndex)
569                << "-th member of " << structType << " with "
570                << stringifyDecoration(memberDecoration.decoration);
571       }
572     }
573 
574     typeEnum = spirv::Opcode::OpTypeStruct;
575 
576     if (structType.isIdentified())
577       serializationCtx.remove(structType.getIdentifier());
578 
579     return success();
580   }
581 
582   if (auto cooperativeMatrixType =
583           type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
584     uint32_t elementTypeID = 0;
585     if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
586                                elementTypeID, serializationCtx))) {
587       return failure();
588     }
589     typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
590     auto getConstantOp = [&](uint32_t id) {
591       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
592       return prepareConstantInt(loc, attr);
593     };
594     operands.push_back(elementTypeID);
595     operands.push_back(
596         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
597     operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
598     operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
599     return success();
600   }
601 
602   if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
603     uint32_t elementTypeID = 0;
604     if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
605                                serializationCtx))) {
606       return failure();
607     }
608     typeEnum = spirv::Opcode::OpTypeMatrix;
609     operands.push_back(elementTypeID);
610     operands.push_back(matrixType.getNumColumns());
611     return success();
612   }
613 
614   // TODO: Handle other types.
615   return emitError(loc, "unhandled type in serialization: ") << type;
616 }
617 
618 LogicalResult
619 Serializer::prepareFunctionType(Location loc, FunctionType type,
620                                 spirv::Opcode &typeEnum,
621                                 SmallVectorImpl<uint32_t> &operands) {
622   typeEnum = spirv::Opcode::OpTypeFunction;
623   assert(type.getNumResults() <= 1 &&
624          "serialization supports only a single return value");
625   uint32_t resultID = 0;
626   if (failed(processType(
627           loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
628           resultID))) {
629     return failure();
630   }
631   operands.push_back(resultID);
632   for (auto &res : type.getInputs()) {
633     uint32_t argTypeID = 0;
634     if (failed(processType(loc, res, argTypeID))) {
635       return failure();
636     }
637     operands.push_back(argTypeID);
638   }
639   return success();
640 }
641 
642 //===----------------------------------------------------------------------===//
643 // Constant
644 //===----------------------------------------------------------------------===//
645 
646 uint32_t Serializer::prepareConstant(Location loc, Type constType,
647                                      Attribute valueAttr) {
648   if (auto id = prepareConstantScalar(loc, valueAttr)) {
649     return id;
650   }
651 
652   // This is a composite literal. We need to handle each component separately
653   // and then emit an OpConstantComposite for the whole.
654 
655   if (auto id = getConstantID(valueAttr)) {
656     return id;
657   }
658 
659   uint32_t typeID = 0;
660   if (failed(processType(loc, constType, typeID))) {
661     return 0;
662   }
663 
664   uint32_t resultID = 0;
665   if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
666     int rank = attr.getType().dyn_cast<ShapedType>().getRank();
667     SmallVector<uint64_t, 4> index(rank);
668     resultID = prepareDenseElementsConstant(loc, constType, attr,
669                                             /*dim=*/0, index);
670   } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
671     resultID = prepareArrayConstant(loc, constType, arrayAttr);
672   }
673 
674   if (resultID == 0) {
675     emitError(loc, "cannot serialize attribute: ") << valueAttr;
676     return 0;
677   }
678 
679   constIDMap[valueAttr] = resultID;
680   return resultID;
681 }
682 
683 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
684                                           ArrayAttr attr) {
685   uint32_t typeID = 0;
686   if (failed(processType(loc, constType, typeID))) {
687     return 0;
688   }
689 
690   uint32_t resultID = getNextID();
691   SmallVector<uint32_t, 4> operands = {typeID, resultID};
692   operands.reserve(attr.size() + 2);
693   auto elementType = constType.cast<spirv::ArrayType>().getElementType();
694   for (Attribute elementAttr : attr) {
695     if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
696       operands.push_back(elementID);
697     } else {
698       return 0;
699     }
700   }
701   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
702   (void)encodeInstructionInto(typesGlobalValues, opcode, operands);
703 
704   return resultID;
705 }
706 
707 // TODO: Turn the below function into iterative function, instead of
708 // recursive function.
709 uint32_t
710 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
711                                          DenseElementsAttr valueAttr, int dim,
712                                          MutableArrayRef<uint64_t> index) {
713   auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
714   assert(dim <= shapedType.getRank());
715   if (shapedType.getRank() == dim) {
716     if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
717       return attr.getType().getElementType().isInteger(1)
718                  ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
719                  : prepareConstantInt(loc,
720                                       attr.getValues<IntegerAttr>()[index]);
721     }
722     if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
723       return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
724     }
725     return 0;
726   }
727 
728   uint32_t typeID = 0;
729   if (failed(processType(loc, constType, typeID))) {
730     return 0;
731   }
732 
733   uint32_t resultID = getNextID();
734   SmallVector<uint32_t, 4> operands = {typeID, resultID};
735   operands.reserve(shapedType.getDimSize(dim) + 2);
736   auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
737   for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
738     index[dim] = i;
739     if (auto elementID = prepareDenseElementsConstant(
740             loc, elementType, valueAttr, dim + 1, index)) {
741       operands.push_back(elementID);
742     } else {
743       return 0;
744     }
745   }
746   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
747   (void)encodeInstructionInto(typesGlobalValues, opcode, operands);
748 
749   return resultID;
750 }
751 
752 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
753                                            bool isSpec) {
754   if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
755     return prepareConstantFp(loc, floatAttr, isSpec);
756   }
757   if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
758     return prepareConstantBool(loc, boolAttr, isSpec);
759   }
760   if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
761     return prepareConstantInt(loc, intAttr, isSpec);
762   }
763 
764   return 0;
765 }
766 
767 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
768                                          bool isSpec) {
769   if (!isSpec) {
770     // We can de-duplicate normal constants, but not specialization constants.
771     if (auto id = getConstantID(boolAttr)) {
772       return id;
773     }
774   }
775 
776   // Process the type for this bool literal
777   uint32_t typeID = 0;
778   if (failed(processType(loc, boolAttr.getType(), typeID))) {
779     return 0;
780   }
781 
782   auto resultID = getNextID();
783   auto opcode = boolAttr.getValue()
784                     ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
785                               : spirv::Opcode::OpConstantTrue)
786                     : (isSpec ? spirv::Opcode::OpSpecConstantFalse
787                               : spirv::Opcode::OpConstantFalse);
788   (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
789 
790   if (!isSpec) {
791     constIDMap[boolAttr] = resultID;
792   }
793   return resultID;
794 }
795 
796 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
797                                         bool isSpec) {
798   if (!isSpec) {
799     // We can de-duplicate normal constants, but not specialization constants.
800     if (auto id = getConstantID(intAttr)) {
801       return id;
802     }
803   }
804 
805   // Process the type for this integer literal
806   uint32_t typeID = 0;
807   if (failed(processType(loc, intAttr.getType(), typeID))) {
808     return 0;
809   }
810 
811   auto resultID = getNextID();
812   APInt value = intAttr.getValue();
813   unsigned bitwidth = value.getBitWidth();
814   bool isSigned = value.isSignedIntN(bitwidth);
815 
816   auto opcode =
817       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
818 
819   switch (bitwidth) {
820     // According to SPIR-V spec, "When the type's bit width is less than
821     // 32-bits, the literal's value appears in the low-order bits of the word,
822     // and the high-order bits must be 0 for a floating-point type, or 0 for an
823     // integer type with Signedness of 0, or sign extended when Signedness
824     // is 1."
825   case 32:
826   case 16:
827   case 8: {
828     uint32_t word = 0;
829     if (isSigned) {
830       word = static_cast<int32_t>(value.getSExtValue());
831     } else {
832       word = static_cast<uint32_t>(value.getZExtValue());
833     }
834     (void)encodeInstructionInto(typesGlobalValues, opcode,
835                                 {typeID, resultID, word});
836   } break;
837     // According to SPIR-V spec: "When the type's bit width is larger than one
838     // word, the literal’s low-order words appear first."
839   case 64: {
840     struct DoubleWord {
841       uint32_t word1;
842       uint32_t word2;
843     } words;
844     if (isSigned) {
845       words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
846     } else {
847       words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
848     }
849     (void)encodeInstructionInto(typesGlobalValues, opcode,
850                                 {typeID, resultID, words.word1, words.word2});
851   } break;
852   default: {
853     std::string valueStr;
854     llvm::raw_string_ostream rss(valueStr);
855     value.print(rss, /*isSigned=*/false);
856 
857     emitError(loc, "cannot serialize ")
858         << bitwidth << "-bit integer literal: " << rss.str();
859     return 0;
860   }
861   }
862 
863   if (!isSpec) {
864     constIDMap[intAttr] = resultID;
865   }
866   return resultID;
867 }
868 
869 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
870                                        bool isSpec) {
871   if (!isSpec) {
872     // We can de-duplicate normal constants, but not specialization constants.
873     if (auto id = getConstantID(floatAttr)) {
874       return id;
875     }
876   }
877 
878   // Process the type for this float literal
879   uint32_t typeID = 0;
880   if (failed(processType(loc, floatAttr.getType(), typeID))) {
881     return 0;
882   }
883 
884   auto resultID = getNextID();
885   APFloat value = floatAttr.getValue();
886   APInt intValue = value.bitcastToAPInt();
887 
888   auto opcode =
889       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
890 
891   if (&value.getSemantics() == &APFloat::IEEEsingle()) {
892     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
893     (void)encodeInstructionInto(typesGlobalValues, opcode,
894                                 {typeID, resultID, word});
895   } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
896     struct DoubleWord {
897       uint32_t word1;
898       uint32_t word2;
899     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
900     (void)encodeInstructionInto(typesGlobalValues, opcode,
901                                 {typeID, resultID, words.word1, words.word2});
902   } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
903     uint32_t word =
904         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
905     (void)encodeInstructionInto(typesGlobalValues, opcode,
906                                 {typeID, resultID, word});
907   } else {
908     std::string valueStr;
909     llvm::raw_string_ostream rss(valueStr);
910     value.print(rss);
911 
912     emitError(loc, "cannot serialize ")
913         << floatAttr.getType() << "-typed float literal: " << rss.str();
914     return 0;
915   }
916 
917   if (!isSpec) {
918     constIDMap[floatAttr] = resultID;
919   }
920   return resultID;
921 }
922 
923 //===----------------------------------------------------------------------===//
924 // Control flow
925 //===----------------------------------------------------------------------===//
926 
927 uint32_t Serializer::getOrCreateBlockID(Block *block) {
928   if (uint32_t id = getBlockID(block))
929     return id;
930   return blockIDMap[block] = getNextID();
931 }
932 
933 LogicalResult
934 Serializer::processBlock(Block *block, bool omitLabel,
935                          function_ref<void()> actionBeforeTerminator) {
936   LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
937   LLVM_DEBUG(block->print(llvm::dbgs()));
938   LLVM_DEBUG(llvm::dbgs() << '\n');
939   if (!omitLabel) {
940     uint32_t blockID = getOrCreateBlockID(block);
941     LLVM_DEBUG(llvm::dbgs()
942                << "[block] " << block << " (id = " << blockID << ")\n");
943 
944     // Emit OpLabel for this block.
945     (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel,
946                                 {blockID});
947   }
948 
949   // Emit OpPhi instructions for block arguments, if any.
950   if (failed(emitPhiForBlockArguments(block)))
951     return failure();
952 
953   // Process each op in this block except the terminator.
954   for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
955     if (failed(processOperation(&op)))
956       return failure();
957   }
958 
959   // Process the terminator.
960   if (actionBeforeTerminator)
961     actionBeforeTerminator();
962   if (failed(processOperation(&block->back())))
963     return failure();
964 
965   return success();
966 }
967 
968 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
969   // Nothing to do if this block has no arguments or it's the entry block, which
970   // always has the same arguments as the function signature.
971   if (block->args_empty() || block->isEntryBlock())
972     return success();
973 
974   // If the block has arguments, we need to create SPIR-V OpPhi instructions.
975   // A SPIR-V OpPhi instruction is of the syntax:
976   //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
977   // So we need to collect all predecessor blocks and the arguments they send
978   // to this block.
979   SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
980   for (Block *predecessor : block->getPredecessors()) {
981     auto *terminator = predecessor->getTerminator();
982     // The predecessor here is the immediate one according to MLIR's IR
983     // structure. It does not directly map to the incoming parent block for the
984     // OpPhi instructions at SPIR-V binary level. This is because structured
985     // control flow ops are serialized to multiple SPIR-V blocks. If there is a
986     // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the
987     // branch op jumping to the OpPhi's block then resides in the previous
988     // structured control flow op's merge block.
989     predecessor = getPhiIncomingBlock(predecessor);
990     if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
991       predecessors.emplace_back(predecessor, branchOp.getOperands());
992     } else if (auto branchCondOp =
993                    dyn_cast<spirv::BranchConditionalOp>(terminator)) {
994       Optional<OperandRange> blockOperands;
995 
996       for (auto successorIdx :
997            llvm::seq<unsigned>(0, predecessor->getNumSuccessors()))
998         if (predecessor->getSuccessors()[successorIdx] == block) {
999           blockOperands = branchCondOp.getSuccessorOperands(successorIdx);
1000           break;
1001         }
1002 
1003       assert(blockOperands && !blockOperands->empty() &&
1004              "expected non-empty block operand range");
1005       predecessors.emplace_back(predecessor, *blockOperands);
1006     } else {
1007       return terminator->emitError("unimplemented terminator for Phi creation");
1008     }
1009   }
1010 
1011   // Then create OpPhi instruction for each of the block argument.
1012   for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1013     BlockArgument arg = block->getArgument(argIndex);
1014 
1015     // Get the type <id> and result <id> for this OpPhi instruction.
1016     uint32_t phiTypeID = 0;
1017     if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1018       return failure();
1019     uint32_t phiID = getNextID();
1020 
1021     LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1022                             << arg << " (id = " << phiID << ")\n");
1023 
1024     // Prepare the (value <id>, parent block <id>) pairs.
1025     SmallVector<uint32_t, 8> phiArgs;
1026     phiArgs.push_back(phiTypeID);
1027     phiArgs.push_back(phiID);
1028 
1029     for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1030       Value value = predecessors[predIndex].second[argIndex];
1031       uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1032       LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1033                               << ") value " << value << ' ');
1034       // Each pair is a value <id> ...
1035       uint32_t valueId = getValueID(value);
1036       if (valueId == 0) {
1037         // The op generating this value hasn't been visited yet so we don't have
1038         // an <id> assigned yet. Record this to fix up later.
1039         LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1040         deferredPhiValues[value].push_back(functionBody.size() + 1 +
1041                                            phiArgs.size());
1042       } else {
1043         LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1044       }
1045       phiArgs.push_back(valueId);
1046       // ... and a parent block <id>.
1047       phiArgs.push_back(predBlockId);
1048     }
1049 
1050     (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1051     valueIDMap[arg] = phiID;
1052   }
1053 
1054   return success();
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // Operation
1059 //===----------------------------------------------------------------------===//
1060 
1061 LogicalResult Serializer::encodeExtensionInstruction(
1062     Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1063     ArrayRef<uint32_t> operands) {
1064   // Check if the extension has been imported.
1065   auto &setID = extendedInstSetIDMap[extensionSetName];
1066   if (!setID) {
1067     setID = getNextID();
1068     SmallVector<uint32_t, 16> importOperands;
1069     importOperands.push_back(setID);
1070     if (failed(
1071             spirv::encodeStringLiteralInto(importOperands, extensionSetName)) ||
1072         failed(encodeInstructionInto(
1073             extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
1074       return failure();
1075     }
1076   }
1077 
1078   // The first two operands are the result type <id> and result <id>. The set
1079   // <id> and the opcode need to be insert after this.
1080   if (operands.size() < 2) {
1081     return op->emitError("extended instructions must have a result encoding");
1082   }
1083   SmallVector<uint32_t, 8> extInstOperands;
1084   extInstOperands.reserve(operands.size() + 2);
1085   extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1086   extInstOperands.push_back(setID);
1087   extInstOperands.push_back(extensionOpcode);
1088   extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1089   return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1090                                extInstOperands);
1091 }
1092 
1093 LogicalResult Serializer::processOperation(Operation *opInst) {
1094   LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1095 
1096   // First dispatch the ops that do not directly mirror an instruction from
1097   // the SPIR-V spec.
1098   return TypeSwitch<Operation *, LogicalResult>(opInst)
1099       .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1100       .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1101       .Case([&](spirv::BranchConditionalOp op) {
1102         return processBranchConditionalOp(op);
1103       })
1104       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1105       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1106       .Case([&](spirv::GlobalVariableOp op) {
1107         return processGlobalVariableOp(op);
1108       })
1109       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1110       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1111       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1112       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1113       .Case([&](spirv::SpecConstantCompositeOp op) {
1114         return processSpecConstantCompositeOp(op);
1115       })
1116       .Case([&](spirv::SpecConstantOperationOp op) {
1117         return processSpecConstantOperationOp(op);
1118       })
1119       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1120       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1121 
1122       // Then handle all the ops that directly mirror SPIR-V instructions with
1123       // auto-generated methods.
1124       .Default(
1125           [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1126 }
1127 
1128 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1129                                                       StringRef extInstSet,
1130                                                       uint32_t opcode) {
1131   SmallVector<uint32_t, 4> operands;
1132   Location loc = op->getLoc();
1133 
1134   uint32_t resultID = 0;
1135   if (op->getNumResults() != 0) {
1136     uint32_t resultTypeID = 0;
1137     if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1138       return failure();
1139     operands.push_back(resultTypeID);
1140 
1141     resultID = getNextID();
1142     operands.push_back(resultID);
1143     valueIDMap[op->getResult(0)] = resultID;
1144   };
1145 
1146   for (Value operand : op->getOperands())
1147     operands.push_back(getValueID(operand));
1148 
1149   (void)emitDebugLine(functionBody, loc);
1150 
1151   if (extInstSet.empty()) {
1152     (void)encodeInstructionInto(functionBody,
1153                                 static_cast<spirv::Opcode>(opcode), operands);
1154   } else {
1155     (void)encodeExtensionInstruction(op, extInstSet, opcode, operands);
1156   }
1157 
1158   if (op->getNumResults() != 0) {
1159     for (auto attr : op->getAttrs()) {
1160       if (failed(processDecoration(loc, resultID, attr)))
1161         return failure();
1162     }
1163   }
1164 
1165   return success();
1166 }
1167 
1168 LogicalResult Serializer::emitDecoration(uint32_t target,
1169                                          spirv::Decoration decoration,
1170                                          ArrayRef<uint32_t> params) {
1171   uint32_t wordCount = 3 + params.size();
1172   decorations.push_back(
1173       spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
1174   decorations.push_back(target);
1175   decorations.push_back(static_cast<uint32_t>(decoration));
1176   decorations.append(params.begin(), params.end());
1177   return success();
1178 }
1179 
1180 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1181                                         Location loc) {
1182   if (!options.emitDebugInfo)
1183     return success();
1184 
1185   if (lastProcessedWasMergeInst) {
1186     lastProcessedWasMergeInst = false;
1187     return success();
1188   }
1189 
1190   auto fileLoc = loc.dyn_cast<FileLineColLoc>();
1191   if (fileLoc)
1192     (void)encodeInstructionInto(
1193         binary, spirv::Opcode::OpLine,
1194         {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1195   return success();
1196 }
1197 } // namespace spirv
1198 } // namespace mlir
1199