1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Support/LogicalResult.h"
19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/ADT/bit.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "spirv-serialization"
28 
29 using namespace mlir;
30 
31 /// Returns the merge block if the given `op` is a structured control flow op.
32 /// Otherwise returns nullptr.
33 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
34   if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
35     return selectionOp.getMergeBlock();
36   if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
37     return loopOp.getMergeBlock();
38   return nullptr;
39 }
40 
41 /// Given a predecessor `block` for a block with arguments, returns the block
42 /// that should be used as the parent block for SPIR-V OpPhi instructions
43 /// corresponding to the block arguments.
44 static Block *getPhiIncomingBlock(Block *block) {
45   // If the predecessor block in question is the entry block for a spv.loop,
46   // we jump to this spv.loop from its enclosing block.
47   if (block->isEntryBlock()) {
48     if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
49       // Then the incoming parent block for OpPhi should be the merge block of
50       // the structured control flow op before this loop.
51       Operation *op = loopOp.getOperation();
52       while ((op = op->getPrevNode()) != nullptr)
53         if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
54           return incomingBlock;
55       // Or the enclosing block itself if no structured control flow ops
56       // exists before this loop.
57       return loopOp->getBlock();
58     }
59   }
60 
61   // Otherwise, we jump from the given predecessor block. Try to see if there is
62   // a structured control flow op inside it.
63   for (Operation &op : llvm::reverse(block->getOperations())) {
64     if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
65       return incomingBlock;
66   }
67   return block;
68 }
69 
70 namespace mlir {
71 namespace spirv {
72 
73 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
74 /// the given `binary` vector.
75 LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
76                                     spirv::Opcode op,
77                                     ArrayRef<uint32_t> operands) {
78   uint32_t wordCount = 1 + operands.size();
79   binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
80   binary.append(operands.begin(), operands.end());
81   return success();
82 }
83 
84 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
85     : module(module), mlirBuilder(module.getContext()),
86       emitDebugInfo(emitDebugInfo) {}
87 
88 LogicalResult Serializer::serialize() {
89   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
90 
91   if (failed(module.verify()))
92     return failure();
93 
94   // TODO: handle the other sections
95   processCapability();
96   processExtension();
97   processMemoryModel();
98   processDebugInfo();
99 
100   // Iterate over the module body to serialize it. Assumptions are that there is
101   // only one basic block in the moduleOp
102   for (auto &op : module.getBlock()) {
103     if (failed(processOperation(&op))) {
104       return failure();
105     }
106   }
107 
108   LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
109   return success();
110 }
111 
112 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
113   auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
114                     extensions.size() + extendedSets.size() +
115                     memoryModel.size() + entryPoints.size() +
116                     executionModes.size() + decorations.size() +
117                     typesGlobalValues.size() + functions.size();
118 
119   binary.clear();
120   binary.reserve(moduleSize);
121 
122   spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
123   binary.append(capabilities.begin(), capabilities.end());
124   binary.append(extensions.begin(), extensions.end());
125   binary.append(extendedSets.begin(), extendedSets.end());
126   binary.append(memoryModel.begin(), memoryModel.end());
127   binary.append(entryPoints.begin(), entryPoints.end());
128   binary.append(executionModes.begin(), executionModes.end());
129   binary.append(debug.begin(), debug.end());
130   binary.append(names.begin(), names.end());
131   binary.append(decorations.begin(), decorations.end());
132   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
133   binary.append(functions.begin(), functions.end());
134 }
135 
136 #ifndef NDEBUG
137 void Serializer::printValueIDMap(raw_ostream &os) {
138   os << "\n= Value <id> Map =\n\n";
139   for (auto valueIDPair : valueIDMap) {
140     Value val = valueIDPair.first;
141     os << "  " << val << " "
142        << "id = " << valueIDPair.second << ' ';
143     if (auto *op = val.getDefiningOp()) {
144       os << "from op '" << op->getName() << "'";
145     } else if (auto arg = val.dyn_cast<BlockArgument>()) {
146       Block *block = arg.getOwner();
147       os << "from argument of block " << block << ' ';
148       os << " in op '" << block->getParentOp()->getName() << "'";
149     }
150     os << '\n';
151   }
152 }
153 #endif
154 
155 //===----------------------------------------------------------------------===//
156 // Module structure
157 //===----------------------------------------------------------------------===//
158 
159 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
160   auto funcID = funcIDMap.lookup(fnName);
161   if (!funcID) {
162     funcID = getNextID();
163     funcIDMap[fnName] = funcID;
164   }
165   return funcID;
166 }
167 
168 void Serializer::processCapability() {
169   for (auto cap : module.vce_triple()->getCapabilities())
170     (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
171                                 {static_cast<uint32_t>(cap)});
172 }
173 
174 void Serializer::processDebugInfo() {
175   if (!emitDebugInfo)
176     return;
177   auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
178   auto fileName = fileLoc ? fileLoc.getFilename() : "<unknown>";
179   fileID = getNextID();
180   SmallVector<uint32_t, 16> operands;
181   operands.push_back(fileID);
182   (void)spirv::encodeStringLiteralInto(operands, fileName);
183   (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
184   // TODO: Encode more debug instructions.
185 }
186 
187 void Serializer::processExtension() {
188   llvm::SmallVector<uint32_t, 16> extName;
189   for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
190     extName.clear();
191     (void)spirv::encodeStringLiteralInto(extName,
192                                          spirv::stringifyExtension(ext));
193     (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension,
194                                 extName);
195   }
196 }
197 
198 void Serializer::processMemoryModel() {
199   uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
200   uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
201 
202   (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel,
203                               {am, mm});
204 }
205 
206 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
207                                             NamedAttribute attr) {
208   auto attrName = attr.first.strref();
209   auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
210   auto decoration = spirv::symbolizeDecoration(decorationName);
211   if (!decoration) {
212     return emitError(
213                loc, "non-argument attributes expected to have snake-case-ified "
214                     "decoration name, unhandled attribute with name : ")
215            << attrName;
216   }
217   SmallVector<uint32_t, 1> args;
218   switch (decoration.getValue()) {
219   case spirv::Decoration::Binding:
220   case spirv::Decoration::DescriptorSet:
221   case spirv::Decoration::Location:
222     if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
223       args.push_back(intAttr.getValue().getZExtValue());
224       break;
225     }
226     return emitError(loc, "expected integer attribute for ") << attrName;
227   case spirv::Decoration::BuiltIn:
228     if (auto strAttr = attr.second.dyn_cast<StringAttr>()) {
229       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
230       if (enumVal) {
231         args.push_back(static_cast<uint32_t>(enumVal.getValue()));
232         break;
233       }
234       return emitError(loc, "invalid ")
235              << attrName << " attribute " << strAttr.getValue();
236     }
237     return emitError(loc, "expected string attribute for ") << attrName;
238   case spirv::Decoration::Aliased:
239   case spirv::Decoration::Flat:
240   case spirv::Decoration::NonReadable:
241   case spirv::Decoration::NonWritable:
242   case spirv::Decoration::NoPerspective:
243   case spirv::Decoration::Restrict:
244     // For unit attributes, the args list has no values so we do nothing
245     if (auto unitAttr = attr.second.dyn_cast<UnitAttr>())
246       break;
247     return emitError(loc, "expected unit attribute for ") << attrName;
248   default:
249     return emitError(loc, "unhandled decoration ") << decorationName;
250   }
251   return emitDecoration(resultID, decoration.getValue(), args);
252 }
253 
254 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
255   assert(!name.empty() && "unexpected empty string for OpName");
256 
257   SmallVector<uint32_t, 4> nameOperands;
258   nameOperands.push_back(resultID);
259   if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
260     return failure();
261   }
262   return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
263 }
264 
265 template <>
266 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
267     Location loc, spirv::ArrayType type, uint32_t resultID) {
268   if (unsigned stride = type.getArrayStride()) {
269     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
270     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
271   }
272   return success();
273 }
274 
275 template <>
276 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
277     Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
278   if (unsigned stride = type.getArrayStride()) {
279     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
280     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
281   }
282   return success();
283 }
284 
285 LogicalResult Serializer::processMemberDecoration(
286     uint32_t structID,
287     const spirv::StructType::MemberDecorationInfo &memberDecoration) {
288   SmallVector<uint32_t, 4> args(
289       {structID, memberDecoration.memberIndex,
290        static_cast<uint32_t>(memberDecoration.decoration)});
291   if (memberDecoration.hasValue) {
292     args.push_back(memberDecoration.decorationValue);
293   }
294   return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
295                                args);
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // Type
300 //===----------------------------------------------------------------------===//
301 
302 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
303 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
304 // PushConstant Storage Classes must be explicitly laid out."
305 bool Serializer::isInterfaceStructPtrType(Type type) const {
306   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
307     switch (ptrType.getStorageClass()) {
308     case spirv::StorageClass::PhysicalStorageBuffer:
309     case spirv::StorageClass::PushConstant:
310     case spirv::StorageClass::StorageBuffer:
311     case spirv::StorageClass::Uniform:
312       return ptrType.getPointeeType().isa<spirv::StructType>();
313     default:
314       break;
315     }
316   }
317   return false;
318 }
319 
320 LogicalResult Serializer::processType(Location loc, Type type,
321                                       uint32_t &typeID) {
322   // Maintains a set of names for nested identified struct types. This is used
323   // to properly serialize recursive references.
324   llvm::SetVector<StringRef> serializationCtx;
325   return processTypeImpl(loc, type, typeID, serializationCtx);
326 }
327 
328 LogicalResult
329 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
330                             llvm::SetVector<StringRef> &serializationCtx) {
331   typeID = getTypeID(type);
332   if (typeID) {
333     return success();
334   }
335   typeID = getNextID();
336   SmallVector<uint32_t, 4> operands;
337 
338   operands.push_back(typeID);
339   auto typeEnum = spirv::Opcode::OpTypeVoid;
340   bool deferSerialization = false;
341 
342   if ((type.isa<FunctionType>() &&
343        succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
344                                      operands))) ||
345       succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
346                                  deferSerialization, serializationCtx))) {
347     if (deferSerialization)
348       return success();
349 
350     typeIDMap[type] = typeID;
351 
352     if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands)))
353       return failure();
354 
355     if (recursiveStructInfos.count(type) != 0) {
356       // This recursive struct type is emitted already, now the OpTypePointer
357       // instructions referring to recursive references are emitted as well.
358       for (auto &ptrInfo : recursiveStructInfos[type]) {
359         // TODO: This might not work if more than 1 recursive reference is
360         // present in the struct.
361         SmallVector<uint32_t, 4> ptrOperands;
362         ptrOperands.push_back(ptrInfo.pointerTypeID);
363         ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
364         ptrOperands.push_back(typeIDMap[type]);
365 
366         if (failed(encodeInstructionInto(
367                 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands)))
368           return failure();
369       }
370 
371       recursiveStructInfos[type].clear();
372     }
373 
374     return success();
375   }
376 
377   return failure();
378 }
379 
380 LogicalResult Serializer::prepareBasicType(
381     Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
382     SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
383     llvm::SetVector<StringRef> &serializationCtx) {
384   deferSerialization = false;
385 
386   if (isVoidType(type)) {
387     typeEnum = spirv::Opcode::OpTypeVoid;
388     return success();
389   }
390 
391   if (auto intType = type.dyn_cast<IntegerType>()) {
392     if (intType.getWidth() == 1) {
393       typeEnum = spirv::Opcode::OpTypeBool;
394       return success();
395     }
396 
397     typeEnum = spirv::Opcode::OpTypeInt;
398     operands.push_back(intType.getWidth());
399     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
400     // to preserve or validate.
401     // 0 indicates unsigned, or no signedness semantics
402     // 1 indicates signed semantics."
403     operands.push_back(intType.isSigned() ? 1 : 0);
404     return success();
405   }
406 
407   if (auto floatType = type.dyn_cast<FloatType>()) {
408     typeEnum = spirv::Opcode::OpTypeFloat;
409     operands.push_back(floatType.getWidth());
410     return success();
411   }
412 
413   if (auto vectorType = type.dyn_cast<VectorType>()) {
414     uint32_t elementTypeID = 0;
415     if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
416                                serializationCtx))) {
417       return failure();
418     }
419     typeEnum = spirv::Opcode::OpTypeVector;
420     operands.push_back(elementTypeID);
421     operands.push_back(vectorType.getNumElements());
422     return success();
423   }
424 
425   if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
426     typeEnum = spirv::Opcode::OpTypeImage;
427     uint32_t sampledTypeID = 0;
428     if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
429       return failure();
430 
431     operands.push_back(sampledTypeID);
432     operands.push_back(static_cast<uint32_t>(imageType.getDim()));
433     operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
434     operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
435     operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
436     operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
437     operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
438     return success();
439   }
440 
441   if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
442     typeEnum = spirv::Opcode::OpTypeArray;
443     uint32_t elementTypeID = 0;
444     if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
445                                serializationCtx))) {
446       return failure();
447     }
448     operands.push_back(elementTypeID);
449     if (auto elementCountID = prepareConstantInt(
450             loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
451       operands.push_back(elementCountID);
452     }
453     return processTypeDecoration(loc, arrayType, resultID);
454   }
455 
456   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
457     uint32_t pointeeTypeID = 0;
458     spirv::StructType pointeeStruct =
459         ptrType.getPointeeType().dyn_cast<spirv::StructType>();
460 
461     if (pointeeStruct && pointeeStruct.isIdentified() &&
462         serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
463       // A recursive reference to an enclosing struct is found.
464       //
465       // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
466       // class as operands.
467       SmallVector<uint32_t, 2> forwardPtrOperands;
468       forwardPtrOperands.push_back(resultID);
469       forwardPtrOperands.push_back(
470           static_cast<uint32_t>(ptrType.getStorageClass()));
471 
472       (void)encodeInstructionInto(typesGlobalValues,
473                                   spirv::Opcode::OpTypeForwardPointer,
474                                   forwardPtrOperands);
475 
476       // 2. Find the pointee (enclosing) struct.
477       auto structType = spirv::StructType::getIdentified(
478           module.getContext(), pointeeStruct.getIdentifier());
479 
480       if (!structType)
481         return failure();
482 
483       // 3. Mark the OpTypePointer that is supposed to be emitted by this call
484       // as deferred.
485       deferSerialization = true;
486 
487       // 4. Record the info needed to emit the deferred OpTypePointer
488       // instruction when the enclosing struct is completely serialized.
489       recursiveStructInfos[structType].push_back(
490           {resultID, ptrType.getStorageClass()});
491     } else {
492       if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
493                                  serializationCtx)))
494         return failure();
495     }
496 
497     typeEnum = spirv::Opcode::OpTypePointer;
498     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
499     operands.push_back(pointeeTypeID);
500     return success();
501   }
502 
503   if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
504     uint32_t elementTypeID = 0;
505     if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
506                                elementTypeID, serializationCtx))) {
507       return failure();
508     }
509     typeEnum = spirv::Opcode::OpTypeRuntimeArray;
510     operands.push_back(elementTypeID);
511     return processTypeDecoration(loc, runtimeArrayType, resultID);
512   }
513 
514   if (auto structType = type.dyn_cast<spirv::StructType>()) {
515     if (structType.isIdentified()) {
516       (void)processName(resultID, structType.getIdentifier());
517       serializationCtx.insert(structType.getIdentifier());
518     }
519 
520     bool hasOffset = structType.hasOffset();
521     for (auto elementIndex :
522          llvm::seq<uint32_t>(0, structType.getNumElements())) {
523       uint32_t elementTypeID = 0;
524       if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
525                                  elementTypeID, serializationCtx))) {
526         return failure();
527       }
528       operands.push_back(elementTypeID);
529       if (hasOffset) {
530         // Decorate each struct member with an offset
531         spirv::StructType::MemberDecorationInfo offsetDecoration{
532             elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
533             static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
534         if (failed(processMemberDecoration(resultID, offsetDecoration))) {
535           return emitError(loc, "cannot decorate ")
536                  << elementIndex << "-th member of " << structType
537                  << " with its offset";
538         }
539       }
540     }
541     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
542     structType.getMemberDecorations(memberDecorations);
543 
544     for (auto &memberDecoration : memberDecorations) {
545       if (failed(processMemberDecoration(resultID, memberDecoration))) {
546         return emitError(loc, "cannot decorate ")
547                << static_cast<uint32_t>(memberDecoration.memberIndex)
548                << "-th member of " << structType << " with "
549                << stringifyDecoration(memberDecoration.decoration);
550       }
551     }
552 
553     typeEnum = spirv::Opcode::OpTypeStruct;
554 
555     if (structType.isIdentified())
556       serializationCtx.remove(structType.getIdentifier());
557 
558     return success();
559   }
560 
561   if (auto cooperativeMatrixType =
562           type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
563     uint32_t elementTypeID = 0;
564     if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
565                                elementTypeID, serializationCtx))) {
566       return failure();
567     }
568     typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
569     auto getConstantOp = [&](uint32_t id) {
570       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
571       return prepareConstantInt(loc, attr);
572     };
573     operands.push_back(elementTypeID);
574     operands.push_back(
575         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
576     operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
577     operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
578     return success();
579   }
580 
581   if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
582     uint32_t elementTypeID = 0;
583     if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
584                                serializationCtx))) {
585       return failure();
586     }
587     typeEnum = spirv::Opcode::OpTypeMatrix;
588     operands.push_back(elementTypeID);
589     operands.push_back(matrixType.getNumColumns());
590     return success();
591   }
592 
593   // TODO: Handle other types.
594   return emitError(loc, "unhandled type in serialization: ") << type;
595 }
596 
597 LogicalResult
598 Serializer::prepareFunctionType(Location loc, FunctionType type,
599                                 spirv::Opcode &typeEnum,
600                                 SmallVectorImpl<uint32_t> &operands) {
601   typeEnum = spirv::Opcode::OpTypeFunction;
602   assert(type.getNumResults() <= 1 &&
603          "serialization supports only a single return value");
604   uint32_t resultID = 0;
605   if (failed(processType(
606           loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
607           resultID))) {
608     return failure();
609   }
610   operands.push_back(resultID);
611   for (auto &res : type.getInputs()) {
612     uint32_t argTypeID = 0;
613     if (failed(processType(loc, res, argTypeID))) {
614       return failure();
615     }
616     operands.push_back(argTypeID);
617   }
618   return success();
619 }
620 
621 //===----------------------------------------------------------------------===//
622 // Constant
623 //===----------------------------------------------------------------------===//
624 
625 uint32_t Serializer::prepareConstant(Location loc, Type constType,
626                                      Attribute valueAttr) {
627   if (auto id = prepareConstantScalar(loc, valueAttr)) {
628     return id;
629   }
630 
631   // This is a composite literal. We need to handle each component separately
632   // and then emit an OpConstantComposite for the whole.
633 
634   if (auto id = getConstantID(valueAttr)) {
635     return id;
636   }
637 
638   uint32_t typeID = 0;
639   if (failed(processType(loc, constType, typeID))) {
640     return 0;
641   }
642 
643   uint32_t resultID = 0;
644   if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
645     int rank = attr.getType().dyn_cast<ShapedType>().getRank();
646     SmallVector<uint64_t, 4> index(rank);
647     resultID = prepareDenseElementsConstant(loc, constType, attr,
648                                             /*dim=*/0, index);
649   } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
650     resultID = prepareArrayConstant(loc, constType, arrayAttr);
651   }
652 
653   if (resultID == 0) {
654     emitError(loc, "cannot serialize attribute: ") << valueAttr;
655     return 0;
656   }
657 
658   constIDMap[valueAttr] = resultID;
659   return resultID;
660 }
661 
662 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
663                                           ArrayAttr attr) {
664   uint32_t typeID = 0;
665   if (failed(processType(loc, constType, typeID))) {
666     return 0;
667   }
668 
669   uint32_t resultID = getNextID();
670   SmallVector<uint32_t, 4> operands = {typeID, resultID};
671   operands.reserve(attr.size() + 2);
672   auto elementType = constType.cast<spirv::ArrayType>().getElementType();
673   for (Attribute elementAttr : attr) {
674     if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
675       operands.push_back(elementID);
676     } else {
677       return 0;
678     }
679   }
680   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
681   (void)encodeInstructionInto(typesGlobalValues, opcode, operands);
682 
683   return resultID;
684 }
685 
686 // TODO: Turn the below function into iterative function, instead of
687 // recursive function.
688 uint32_t
689 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
690                                          DenseElementsAttr valueAttr, int dim,
691                                          MutableArrayRef<uint64_t> index) {
692   auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
693   assert(dim <= shapedType.getRank());
694   if (shapedType.getRank() == dim) {
695     if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
696       return attr.getType().getElementType().isInteger(1)
697                  ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
698                  : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
699     }
700     if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
701       return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
702     }
703     return 0;
704   }
705 
706   uint32_t typeID = 0;
707   if (failed(processType(loc, constType, typeID))) {
708     return 0;
709   }
710 
711   uint32_t resultID = getNextID();
712   SmallVector<uint32_t, 4> operands = {typeID, resultID};
713   operands.reserve(shapedType.getDimSize(dim) + 2);
714   auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
715   for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
716     index[dim] = i;
717     if (auto elementID = prepareDenseElementsConstant(
718             loc, elementType, valueAttr, dim + 1, index)) {
719       operands.push_back(elementID);
720     } else {
721       return 0;
722     }
723   }
724   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
725   (void)encodeInstructionInto(typesGlobalValues, opcode, operands);
726 
727   return resultID;
728 }
729 
730 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
731                                            bool isSpec) {
732   if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
733     return prepareConstantFp(loc, floatAttr, isSpec);
734   }
735   if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
736     return prepareConstantBool(loc, boolAttr, isSpec);
737   }
738   if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
739     return prepareConstantInt(loc, intAttr, isSpec);
740   }
741 
742   return 0;
743 }
744 
745 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
746                                          bool isSpec) {
747   if (!isSpec) {
748     // We can de-duplicate normal constants, but not specialization constants.
749     if (auto id = getConstantID(boolAttr)) {
750       return id;
751     }
752   }
753 
754   // Process the type for this bool literal
755   uint32_t typeID = 0;
756   if (failed(processType(loc, boolAttr.getType(), typeID))) {
757     return 0;
758   }
759 
760   auto resultID = getNextID();
761   auto opcode = boolAttr.getValue()
762                     ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
763                               : spirv::Opcode::OpConstantTrue)
764                     : (isSpec ? spirv::Opcode::OpSpecConstantFalse
765                               : spirv::Opcode::OpConstantFalse);
766   (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
767 
768   if (!isSpec) {
769     constIDMap[boolAttr] = resultID;
770   }
771   return resultID;
772 }
773 
774 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
775                                         bool isSpec) {
776   if (!isSpec) {
777     // We can de-duplicate normal constants, but not specialization constants.
778     if (auto id = getConstantID(intAttr)) {
779       return id;
780     }
781   }
782 
783   // Process the type for this integer literal
784   uint32_t typeID = 0;
785   if (failed(processType(loc, intAttr.getType(), typeID))) {
786     return 0;
787   }
788 
789   auto resultID = getNextID();
790   APInt value = intAttr.getValue();
791   unsigned bitwidth = value.getBitWidth();
792   bool isSigned = value.isSignedIntN(bitwidth);
793 
794   auto opcode =
795       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
796 
797   // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
798   // the literal's value appears in the low-order bits of the word, and the
799   // high-order bits must be 0 for a floating-point type, or 0 for an integer
800   // type with Signedness of 0, or sign extended when Signedness is 1."
801   if (bitwidth == 32 || bitwidth == 16) {
802     uint32_t word = 0;
803     if (isSigned) {
804       word = static_cast<int32_t>(value.getSExtValue());
805     } else {
806       word = static_cast<uint32_t>(value.getZExtValue());
807     }
808     (void)encodeInstructionInto(typesGlobalValues, opcode,
809                                 {typeID, resultID, word});
810   }
811   // According to SPIR-V spec: "When the type's bit width is larger than one
812   // word, the literal’s low-order words appear first."
813   else if (bitwidth == 64) {
814     struct DoubleWord {
815       uint32_t word1;
816       uint32_t word2;
817     } words;
818     if (isSigned) {
819       words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
820     } else {
821       words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
822     }
823     (void)encodeInstructionInto(typesGlobalValues, opcode,
824                                 {typeID, resultID, words.word1, words.word2});
825   } else {
826     std::string valueStr;
827     llvm::raw_string_ostream rss(valueStr);
828     value.print(rss, /*isSigned=*/false);
829 
830     emitError(loc, "cannot serialize ")
831         << bitwidth << "-bit integer literal: " << rss.str();
832     return 0;
833   }
834 
835   if (!isSpec) {
836     constIDMap[intAttr] = resultID;
837   }
838   return resultID;
839 }
840 
841 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
842                                        bool isSpec) {
843   if (!isSpec) {
844     // We can de-duplicate normal constants, but not specialization constants.
845     if (auto id = getConstantID(floatAttr)) {
846       return id;
847     }
848   }
849 
850   // Process the type for this float literal
851   uint32_t typeID = 0;
852   if (failed(processType(loc, floatAttr.getType(), typeID))) {
853     return 0;
854   }
855 
856   auto resultID = getNextID();
857   APFloat value = floatAttr.getValue();
858   APInt intValue = value.bitcastToAPInt();
859 
860   auto opcode =
861       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
862 
863   if (&value.getSemantics() == &APFloat::IEEEsingle()) {
864     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
865     (void)encodeInstructionInto(typesGlobalValues, opcode,
866                                 {typeID, resultID, word});
867   } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
868     struct DoubleWord {
869       uint32_t word1;
870       uint32_t word2;
871     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
872     (void)encodeInstructionInto(typesGlobalValues, opcode,
873                                 {typeID, resultID, words.word1, words.word2});
874   } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
875     uint32_t word =
876         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
877     (void)encodeInstructionInto(typesGlobalValues, opcode,
878                                 {typeID, resultID, word});
879   } else {
880     std::string valueStr;
881     llvm::raw_string_ostream rss(valueStr);
882     value.print(rss);
883 
884     emitError(loc, "cannot serialize ")
885         << floatAttr.getType() << "-typed float literal: " << rss.str();
886     return 0;
887   }
888 
889   if (!isSpec) {
890     constIDMap[floatAttr] = resultID;
891   }
892   return resultID;
893 }
894 
895 //===----------------------------------------------------------------------===//
896 // Control flow
897 //===----------------------------------------------------------------------===//
898 
899 uint32_t Serializer::getOrCreateBlockID(Block *block) {
900   if (uint32_t id = getBlockID(block))
901     return id;
902   return blockIDMap[block] = getNextID();
903 }
904 
905 LogicalResult
906 Serializer::processBlock(Block *block, bool omitLabel,
907                          function_ref<void()> actionBeforeTerminator) {
908   LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
909   LLVM_DEBUG(block->print(llvm::dbgs()));
910   LLVM_DEBUG(llvm::dbgs() << '\n');
911   if (!omitLabel) {
912     uint32_t blockID = getOrCreateBlockID(block);
913     LLVM_DEBUG(llvm::dbgs()
914                << "[block] " << block << " (id = " << blockID << ")\n");
915 
916     // Emit OpLabel for this block.
917     (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel,
918                                 {blockID});
919   }
920 
921   // Emit OpPhi instructions for block arguments, if any.
922   if (failed(emitPhiForBlockArguments(block)))
923     return failure();
924 
925   // Process each op in this block except the terminator.
926   for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
927     if (failed(processOperation(&op)))
928       return failure();
929   }
930 
931   // Process the terminator.
932   if (actionBeforeTerminator)
933     actionBeforeTerminator();
934   if (failed(processOperation(&block->back())))
935     return failure();
936 
937   return success();
938 }
939 
/// Emits one SPIR-V OpPhi instruction into `functionBody` for each argument
/// of `block`, except when the block has no arguments or is the entry block.
/// Each OpPhi lists, for every MLIR predecessor, the value that predecessor
/// passes plus the <id> of the SPIR-V block the branch actually comes from
/// (as computed by getPhiIncomingBlock). Values that have no <id> yet are
/// written as 0 and their operand positions are recorded in
/// `deferredPhiValues` to be fixed up later. The phi's result <id> is
/// recorded in `valueIDMap` for the block argument.
///
/// Fails if a predecessor's terminator is not a spirv::BranchOp, or if an
/// argument's type cannot be processed.
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
  // Nothing to do if this block has no arguments or it's the entry block, which
  // always has the same arguments as the function signature.
  if (block->args_empty() || block->isEntryBlock())
    return success();

  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
  // A SPIR-V OpPhi instruction is of the syntax:
  //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
  // So we need to collect all predecessor blocks and the arguments they send
  // to this block.
  // Each entry pairs the SPIR-V incoming block with an iterator to the first
  // operand that predecessor's branch forwards to `block`.
  SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
  for (Block *predecessor : block->getPredecessors()) {
    auto *terminator = predecessor->getTerminator();
    // The predecessor here is the immediate one according to MLIR's IR
    // structure. It does not directly map to the incoming parent block for the
    // OpPhi instructions at SPIR-V binary level. This is because structured
    // control flow ops are serialized to multiple SPIR-V blocks. If there is a
    // spv.selection/spv.loop op in the MLIR predecessor block, the branch op
    // jumping to the OpPhi's block then resides in the previous structured
    // control flow op's merge block.
    predecessor = getPhiIncomingBlock(predecessor);
    // Only unconditional branches are supported as the phi source. NOTE: this
    // assumes all of the spv.Branch operands are the arguments forwarded to
    // `block`, in order.
    if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
      predecessors.emplace_back(predecessor, branchOp.operand_begin());
    } else {
      return terminator->emitError("unimplemented terminator for Phi creation");
    }
  }

  // Then create OpPhi instruction for each of the block argument.
  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
    BlockArgument arg = block->getArgument(argIndex);

    // Get the type <id> and result <id> for this OpPhi instruction.
    uint32_t phiTypeID = 0;
    if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
      return failure();
    uint32_t phiID = getNextID();

    LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                            << arg << " (id = " << phiID << ")\n");

    // Prepare the (value <id>, parent block <id>) pairs.
    SmallVector<uint32_t, 8> phiArgs;
    phiArgs.push_back(phiTypeID);
    phiArgs.push_back(phiID);

    for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
      // The value this predecessor passes for argument #argIndex.
      Value value = *(predecessors[predIndex].second + argIndex);
      uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
      LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                              << ") value " << value << ' ');
      // Each pair is a value <id> ...
      uint32_t valueId = getValueID(value);
      if (valueId == 0) {
        // The op generating this value hasn't been visited yet so we don't have
        // an <id> assigned yet. Record this to fix up later.
        // The recorded number is the word offset this operand will have in
        // `functionBody` once the OpPhi below is appended. `+ 1` skips the
        // leading word-count/opcode word, and `phiArgs.size()` is the index
        // the operand is about to take. It must be computed before the
        // push_back below. NOTE: this assumes encodeInstructionInto writes
        // exactly one prefix word before the operands.
        LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
        deferredPhiValues[value].push_back(functionBody.size() + 1 +
                                           phiArgs.size());
      } else {
        LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
      }
      // A 0 placeholder is written when the value is still unresolved.
      phiArgs.push_back(valueId);
      // ... and a parent block <id>.
      phiArgs.push_back(predBlockId);
    }

    (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
    // Uses of the block argument now refer to the phi's result <id>.
    valueIDMap[arg] = phiID;
  }

  return success();
}
1014 
1015 //===----------------------------------------------------------------------===//
1016 // Operation
1017 //===----------------------------------------------------------------------===//
1018 
1019 LogicalResult Serializer::encodeExtensionInstruction(
1020     Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1021     ArrayRef<uint32_t> operands) {
1022   // Check if the extension has been imported.
1023   auto &setID = extendedInstSetIDMap[extensionSetName];
1024   if (!setID) {
1025     setID = getNextID();
1026     SmallVector<uint32_t, 16> importOperands;
1027     importOperands.push_back(setID);
1028     if (failed(
1029             spirv::encodeStringLiteralInto(importOperands, extensionSetName)) ||
1030         failed(encodeInstructionInto(
1031             extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
1032       return failure();
1033     }
1034   }
1035 
1036   // The first two operands are the result type <id> and result <id>. The set
1037   // <id> and the opcode need to be insert after this.
1038   if (operands.size() < 2) {
1039     return op->emitError("extended instructions must have a result encoding");
1040   }
1041   SmallVector<uint32_t, 8> extInstOperands;
1042   extInstOperands.reserve(operands.size() + 2);
1043   extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1044   extInstOperands.push_back(setID);
1045   extInstOperands.push_back(extensionOpcode);
1046   extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1047   return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1048                                extInstOperands);
1049 }
1050 
1051 LogicalResult Serializer::processOperation(Operation *opInst) {
1052   LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1053 
1054   // First dispatch the ops that do not directly mirror an instruction from
1055   // the SPIR-V spec.
1056   return TypeSwitch<Operation *, LogicalResult>(opInst)
1057       .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1058       .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1059       .Case([&](spirv::BranchConditionalOp op) {
1060         return processBranchConditionalOp(op);
1061       })
1062       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1063       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1064       .Case([&](spirv::GlobalVariableOp op) {
1065         return processGlobalVariableOp(op);
1066       })
1067       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1068       .Case([&](spirv::ModuleEndOp) { return success(); })
1069       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1070       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1071       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1072       .Case([&](spirv::SpecConstantCompositeOp op) {
1073         return processSpecConstantCompositeOp(op);
1074       })
1075       .Case([&](spirv::SpecConstantOperationOp op) {
1076         return processSpecConstantOperationOp(op);
1077       })
1078       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1079       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1080 
1081       // Then handle all the ops that directly mirror SPIR-V instructions with
1082       // auto-generated methods.
1083       .Default(
1084           [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1085 }
1086 
1087 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1088                                                       StringRef extInstSet,
1089                                                       uint32_t opcode) {
1090   SmallVector<uint32_t, 4> operands;
1091   Location loc = op->getLoc();
1092 
1093   uint32_t resultID = 0;
1094   if (op->getNumResults() != 0) {
1095     uint32_t resultTypeID = 0;
1096     if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1097       return failure();
1098     operands.push_back(resultTypeID);
1099 
1100     resultID = getNextID();
1101     operands.push_back(resultID);
1102     valueIDMap[op->getResult(0)] = resultID;
1103   };
1104 
1105   for (Value operand : op->getOperands())
1106     operands.push_back(getValueID(operand));
1107 
1108   (void)emitDebugLine(functionBody, loc);
1109 
1110   if (extInstSet.empty()) {
1111     (void)encodeInstructionInto(functionBody,
1112                                 static_cast<spirv::Opcode>(opcode), operands);
1113   } else {
1114     (void)encodeExtensionInstruction(op, extInstSet, opcode, operands);
1115   }
1116 
1117   if (op->getNumResults() != 0) {
1118     for (auto attr : op->getAttrs()) {
1119       if (failed(processDecoration(loc, resultID, attr)))
1120         return failure();
1121     }
1122   }
1123 
1124   return success();
1125 }
1126 
1127 LogicalResult Serializer::emitDecoration(uint32_t target,
1128                                          spirv::Decoration decoration,
1129                                          ArrayRef<uint32_t> params) {
1130   uint32_t wordCount = 3 + params.size();
1131   decorations.push_back(
1132       spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
1133   decorations.push_back(target);
1134   decorations.push_back(static_cast<uint32_t>(decoration));
1135   decorations.append(params.begin(), params.end());
1136   return success();
1137 }
1138 
1139 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1140                                         Location loc) {
1141   if (!emitDebugInfo)
1142     return success();
1143 
1144   if (lastProcessedWasMergeInst) {
1145     lastProcessedWasMergeInst = false;
1146     return success();
1147   }
1148 
1149   auto fileLoc = loc.dyn_cast<FileLineColLoc>();
1150   if (fileLoc)
1151     (void)encodeInstructionInto(
1152         binary, spirv::Opcode::OpLine,
1153         {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1154   return success();
1155 }
1156 } // namespace spirv
1157 } // namespace mlir
1158