1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Serializer.h"
14
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Support/LogicalResult.h"
19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/ADT/bit.h"
25 #include "llvm/Support/Debug.h"
26
27 #define DEBUG_TYPE "spirv-serialization"
28
29 using namespace mlir;
30
31 /// Returns the merge block if the given `op` is a structured control flow op.
32 /// Otherwise returns nullptr.
getStructuredControlFlowOpMergeBlock(Operation * op)33 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
34 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
35 return selectionOp.getMergeBlock();
36 if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
37 return loopOp.getMergeBlock();
38 return nullptr;
39 }
40
41 /// Given a predecessor `block` for a block with arguments, returns the block
42 /// that should be used as the parent block for SPIR-V OpPhi instructions
43 /// corresponding to the block arguments.
getPhiIncomingBlock(Block * block)44 static Block *getPhiIncomingBlock(Block *block) {
45 // If the predecessor block in question is the entry block for a
46 // spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block.
47 if (block->isEntryBlock()) {
48 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
49 // Then the incoming parent block for OpPhi should be the merge block of
50 // the structured control flow op before this loop.
51 Operation *op = loopOp.getOperation();
52 while ((op = op->getPrevNode()) != nullptr)
53 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
54 return incomingBlock;
55 // Or the enclosing block itself if no structured control flow ops
56 // exists before this loop.
57 return loopOp->getBlock();
58 }
59 }
60
61 // Otherwise, we jump from the given predecessor block. Try to see if there is
62 // a structured control flow op inside it.
63 for (Operation &op : llvm::reverse(block->getOperations())) {
64 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
65 return incomingBlock;
66 }
67 return block;
68 }
69
70 namespace mlir {
71 namespace spirv {
72
73 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
74 /// the given `binary` vector.
encodeInstructionInto(SmallVectorImpl<uint32_t> & binary,spirv::Opcode op,ArrayRef<uint32_t> operands)75 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
76 ArrayRef<uint32_t> operands) {
77 uint32_t wordCount = 1 + operands.size();
78 binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
79 binary.append(operands.begin(), operands.end());
80 }
81
Serializer(spirv::ModuleOp module,const SerializationOptions & options)82 Serializer::Serializer(spirv::ModuleOp module,
83 const SerializationOptions &options)
84 : module(module), mlirBuilder(module.getContext()), options(options) {}
85
serialize()86 LogicalResult Serializer::serialize() {
87 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
88
89 if (failed(module.verifyInvariants()))
90 return failure();
91
92 // TODO: handle the other sections
93 processCapability();
94 processExtension();
95 processMemoryModel();
96 processDebugInfo();
97
98 // Iterate over the module body to serialize it. Assumptions are that there is
99 // only one basic block in the moduleOp
100 for (auto &op : *module.getBody()) {
101 if (failed(processOperation(&op))) {
102 return failure();
103 }
104 }
105
106 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
107 return success();
108 }
109
collect(SmallVectorImpl<uint32_t> & binary)110 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
111 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
112 extensions.size() + extendedSets.size() +
113 memoryModel.size() + entryPoints.size() +
114 executionModes.size() + decorations.size() +
115 typesGlobalValues.size() + functions.size();
116
117 binary.clear();
118 binary.reserve(moduleSize);
119
120 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
121 binary.append(capabilities.begin(), capabilities.end());
122 binary.append(extensions.begin(), extensions.end());
123 binary.append(extendedSets.begin(), extendedSets.end());
124 binary.append(memoryModel.begin(), memoryModel.end());
125 binary.append(entryPoints.begin(), entryPoints.end());
126 binary.append(executionModes.begin(), executionModes.end());
127 binary.append(debug.begin(), debug.end());
128 binary.append(names.begin(), names.end());
129 binary.append(decorations.begin(), decorations.end());
130 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
131 binary.append(functions.begin(), functions.end());
132 }
133
134 #ifndef NDEBUG
printValueIDMap(raw_ostream & os)135 void Serializer::printValueIDMap(raw_ostream &os) {
136 os << "\n= Value <id> Map =\n\n";
137 for (auto valueIDPair : valueIDMap) {
138 Value val = valueIDPair.first;
139 os << " " << val << " "
140 << "id = " << valueIDPair.second << ' ';
141 if (auto *op = val.getDefiningOp()) {
142 os << "from op '" << op->getName() << "'";
143 } else if (auto arg = val.dyn_cast<BlockArgument>()) {
144 Block *block = arg.getOwner();
145 os << "from argument of block " << block << ' ';
146 os << " in op '" << block->getParentOp()->getName() << "'";
147 }
148 os << '\n';
149 }
150 }
151 #endif
152
153 //===----------------------------------------------------------------------===//
154 // Module structure
155 //===----------------------------------------------------------------------===//
156
getOrCreateFunctionID(StringRef fnName)157 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
158 auto funcID = funcIDMap.lookup(fnName);
159 if (!funcID) {
160 funcID = getNextID();
161 funcIDMap[fnName] = funcID;
162 }
163 return funcID;
164 }
165
processCapability()166 void Serializer::processCapability() {
167 for (auto cap : module.vce_triple()->getCapabilities())
168 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
169 {static_cast<uint32_t>(cap)});
170 }
171
processDebugInfo()172 void Serializer::processDebugInfo() {
173 if (!options.emitDebugInfo)
174 return;
175 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
176 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
177 fileID = getNextID();
178 SmallVector<uint32_t, 16> operands;
179 operands.push_back(fileID);
180 spirv::encodeStringLiteralInto(operands, fileName);
181 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
182 // TODO: Encode more debug instructions.
183 }
184
processExtension()185 void Serializer::processExtension() {
186 llvm::SmallVector<uint32_t, 16> extName;
187 for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
188 extName.clear();
189 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
190 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
191 }
192 }
193
processMemoryModel()194 void Serializer::processMemoryModel() {
195 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
196 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
197
198 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
199 }
200
processDecoration(Location loc,uint32_t resultID,NamedAttribute attr)201 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
202 NamedAttribute attr) {
203 auto attrName = attr.getName().strref();
204 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
205 auto decoration = spirv::symbolizeDecoration(decorationName);
206 if (!decoration) {
207 return emitError(
208 loc, "non-argument attributes expected to have snake-case-ified "
209 "decoration name, unhandled attribute with name : ")
210 << attrName;
211 }
212 SmallVector<uint32_t, 1> args;
213 switch (*decoration) {
214 case spirv::Decoration::Binding:
215 case spirv::Decoration::DescriptorSet:
216 case spirv::Decoration::Location:
217 if (auto intAttr = attr.getValue().dyn_cast<IntegerAttr>()) {
218 args.push_back(intAttr.getValue().getZExtValue());
219 break;
220 }
221 return emitError(loc, "expected integer attribute for ") << attrName;
222 case spirv::Decoration::BuiltIn:
223 if (auto strAttr = attr.getValue().dyn_cast<StringAttr>()) {
224 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
225 if (enumVal) {
226 args.push_back(static_cast<uint32_t>(*enumVal));
227 break;
228 }
229 return emitError(loc, "invalid ")
230 << attrName << " attribute " << strAttr.getValue();
231 }
232 return emitError(loc, "expected string attribute for ") << attrName;
233 case spirv::Decoration::Aliased:
234 case spirv::Decoration::Flat:
235 case spirv::Decoration::NonReadable:
236 case spirv::Decoration::NonWritable:
237 case spirv::Decoration::NoPerspective:
238 case spirv::Decoration::Restrict:
239 case spirv::Decoration::RelaxedPrecision:
240 // For unit attributes, the args list has no values so we do nothing
241 if (auto unitAttr = attr.getValue().dyn_cast<UnitAttr>())
242 break;
243 return emitError(loc, "expected unit attribute for ") << attrName;
244 default:
245 return emitError(loc, "unhandled decoration ") << decorationName;
246 }
247 return emitDecoration(resultID, *decoration, args);
248 }
249
processName(uint32_t resultID,StringRef name)250 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
251 assert(!name.empty() && "unexpected empty string for OpName");
252 if (!options.emitSymbolName)
253 return success();
254
255 SmallVector<uint32_t, 4> nameOperands;
256 nameOperands.push_back(resultID);
257 spirv::encodeStringLiteralInto(nameOperands, name);
258 encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
259 return success();
260 }
261
262 template <>
processTypeDecoration(Location loc,spirv::ArrayType type,uint32_t resultID)263 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
264 Location loc, spirv::ArrayType type, uint32_t resultID) {
265 if (unsigned stride = type.getArrayStride()) {
266 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
267 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
268 }
269 return success();
270 }
271
272 template <>
processTypeDecoration(Location loc,spirv::RuntimeArrayType type,uint32_t resultID)273 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
274 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
275 if (unsigned stride = type.getArrayStride()) {
276 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
277 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
278 }
279 return success();
280 }
281
processMemberDecoration(uint32_t structID,const spirv::StructType::MemberDecorationInfo & memberDecoration)282 LogicalResult Serializer::processMemberDecoration(
283 uint32_t structID,
284 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
285 SmallVector<uint32_t, 4> args(
286 {structID, memberDecoration.memberIndex,
287 static_cast<uint32_t>(memberDecoration.decoration)});
288 if (memberDecoration.hasValue) {
289 args.push_back(memberDecoration.decorationValue);
290 }
291 encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
292 return success();
293 }
294
295 //===----------------------------------------------------------------------===//
296 // Type
297 //===----------------------------------------------------------------------===//
298
299 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
300 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
301 // PushConstant Storage Classes must be explicitly laid out."
isInterfaceStructPtrType(Type type) const302 bool Serializer::isInterfaceStructPtrType(Type type) const {
303 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
304 switch (ptrType.getStorageClass()) {
305 case spirv::StorageClass::PhysicalStorageBuffer:
306 case spirv::StorageClass::PushConstant:
307 case spirv::StorageClass::StorageBuffer:
308 case spirv::StorageClass::Uniform:
309 return ptrType.getPointeeType().isa<spirv::StructType>();
310 default:
311 break;
312 }
313 }
314 return false;
315 }
316
processType(Location loc,Type type,uint32_t & typeID)317 LogicalResult Serializer::processType(Location loc, Type type,
318 uint32_t &typeID) {
319 // Maintains a set of names for nested identified struct types. This is used
320 // to properly serialize recursive references.
321 SetVector<StringRef> serializationCtx;
322 return processTypeImpl(loc, type, typeID, serializationCtx);
323 }
324
325 LogicalResult
processTypeImpl(Location loc,Type type,uint32_t & typeID,SetVector<StringRef> & serializationCtx)326 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
327 SetVector<StringRef> &serializationCtx) {
328 typeID = getTypeID(type);
329 if (typeID)
330 return success();
331
332 typeID = getNextID();
333 SmallVector<uint32_t, 4> operands;
334
335 operands.push_back(typeID);
336 auto typeEnum = spirv::Opcode::OpTypeVoid;
337 bool deferSerialization = false;
338
339 if ((type.isa<FunctionType>() &&
340 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
341 operands))) ||
342 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
343 deferSerialization, serializationCtx))) {
344 if (deferSerialization)
345 return success();
346
347 typeIDMap[type] = typeID;
348
349 encodeInstructionInto(typesGlobalValues, typeEnum, operands);
350
351 if (recursiveStructInfos.count(type) != 0) {
352 // This recursive struct type is emitted already, now the OpTypePointer
353 // instructions referring to recursive references are emitted as well.
354 for (auto &ptrInfo : recursiveStructInfos[type]) {
355 // TODO: This might not work if more than 1 recursive reference is
356 // present in the struct.
357 SmallVector<uint32_t, 4> ptrOperands;
358 ptrOperands.push_back(ptrInfo.pointerTypeID);
359 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
360 ptrOperands.push_back(typeIDMap[type]);
361
362 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
363 ptrOperands);
364 }
365
366 recursiveStructInfos[type].clear();
367 }
368
369 return success();
370 }
371
372 return failure();
373 }
374
prepareBasicType(Location loc,Type type,uint32_t resultID,spirv::Opcode & typeEnum,SmallVectorImpl<uint32_t> & operands,bool & deferSerialization,SetVector<StringRef> & serializationCtx)375 LogicalResult Serializer::prepareBasicType(
376 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
377 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
378 SetVector<StringRef> &serializationCtx) {
379 deferSerialization = false;
380
381 if (isVoidType(type)) {
382 typeEnum = spirv::Opcode::OpTypeVoid;
383 return success();
384 }
385
386 if (auto intType = type.dyn_cast<IntegerType>()) {
387 if (intType.getWidth() == 1) {
388 typeEnum = spirv::Opcode::OpTypeBool;
389 return success();
390 }
391
392 typeEnum = spirv::Opcode::OpTypeInt;
393 operands.push_back(intType.getWidth());
394 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
395 // to preserve or validate.
396 // 0 indicates unsigned, or no signedness semantics
397 // 1 indicates signed semantics."
398 operands.push_back(intType.isSigned() ? 1 : 0);
399 return success();
400 }
401
402 if (auto floatType = type.dyn_cast<FloatType>()) {
403 typeEnum = spirv::Opcode::OpTypeFloat;
404 operands.push_back(floatType.getWidth());
405 return success();
406 }
407
408 if (auto vectorType = type.dyn_cast<VectorType>()) {
409 uint32_t elementTypeID = 0;
410 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
411 serializationCtx))) {
412 return failure();
413 }
414 typeEnum = spirv::Opcode::OpTypeVector;
415 operands.push_back(elementTypeID);
416 operands.push_back(vectorType.getNumElements());
417 return success();
418 }
419
420 if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
421 typeEnum = spirv::Opcode::OpTypeImage;
422 uint32_t sampledTypeID = 0;
423 if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
424 return failure();
425
426 operands.push_back(sampledTypeID);
427 operands.push_back(static_cast<uint32_t>(imageType.getDim()));
428 operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
429 operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
430 operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
431 operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
432 operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
433 return success();
434 }
435
436 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
437 typeEnum = spirv::Opcode::OpTypeArray;
438 uint32_t elementTypeID = 0;
439 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
440 serializationCtx))) {
441 return failure();
442 }
443 operands.push_back(elementTypeID);
444 if (auto elementCountID = prepareConstantInt(
445 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
446 operands.push_back(elementCountID);
447 }
448 return processTypeDecoration(loc, arrayType, resultID);
449 }
450
451 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
452 uint32_t pointeeTypeID = 0;
453 spirv::StructType pointeeStruct =
454 ptrType.getPointeeType().dyn_cast<spirv::StructType>();
455
456 if (pointeeStruct && pointeeStruct.isIdentified() &&
457 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
458 // A recursive reference to an enclosing struct is found.
459 //
460 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
461 // class as operands.
462 SmallVector<uint32_t, 2> forwardPtrOperands;
463 forwardPtrOperands.push_back(resultID);
464 forwardPtrOperands.push_back(
465 static_cast<uint32_t>(ptrType.getStorageClass()));
466
467 encodeInstructionInto(typesGlobalValues,
468 spirv::Opcode::OpTypeForwardPointer,
469 forwardPtrOperands);
470
471 // 2. Find the pointee (enclosing) struct.
472 auto structType = spirv::StructType::getIdentified(
473 module.getContext(), pointeeStruct.getIdentifier());
474
475 if (!structType)
476 return failure();
477
478 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
479 // as deferred.
480 deferSerialization = true;
481
482 // 4. Record the info needed to emit the deferred OpTypePointer
483 // instruction when the enclosing struct is completely serialized.
484 recursiveStructInfos[structType].push_back(
485 {resultID, ptrType.getStorageClass()});
486 } else {
487 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
488 serializationCtx)))
489 return failure();
490 }
491
492 typeEnum = spirv::Opcode::OpTypePointer;
493 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
494 operands.push_back(pointeeTypeID);
495
496 if (isInterfaceStructPtrType(ptrType)) {
497 if (failed(emitDecoration(getTypeID(pointeeStruct),
498 spirv::Decoration::Block)))
499 return emitError(loc, "cannot decorate ")
500 << pointeeStruct << " with Block decoration";
501 }
502
503 return success();
504 }
505
506 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
507 uint32_t elementTypeID = 0;
508 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
509 elementTypeID, serializationCtx))) {
510 return failure();
511 }
512 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
513 operands.push_back(elementTypeID);
514 return processTypeDecoration(loc, runtimeArrayType, resultID);
515 }
516
517 if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) {
518 typeEnum = spirv::Opcode::OpTypeSampledImage;
519 uint32_t imageTypeID = 0;
520 if (failed(
521 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
522 return failure();
523 }
524 operands.push_back(imageTypeID);
525 return success();
526 }
527
528 if (auto structType = type.dyn_cast<spirv::StructType>()) {
529 if (structType.isIdentified()) {
530 if (failed(processName(resultID, structType.getIdentifier())))
531 return failure();
532 serializationCtx.insert(structType.getIdentifier());
533 }
534
535 bool hasOffset = structType.hasOffset();
536 for (auto elementIndex :
537 llvm::seq<uint32_t>(0, structType.getNumElements())) {
538 uint32_t elementTypeID = 0;
539 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
540 elementTypeID, serializationCtx))) {
541 return failure();
542 }
543 operands.push_back(elementTypeID);
544 if (hasOffset) {
545 // Decorate each struct member with an offset
546 spirv::StructType::MemberDecorationInfo offsetDecoration{
547 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
548 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
549 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
550 return emitError(loc, "cannot decorate ")
551 << elementIndex << "-th member of " << structType
552 << " with its offset";
553 }
554 }
555 }
556 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
557 structType.getMemberDecorations(memberDecorations);
558
559 for (auto &memberDecoration : memberDecorations) {
560 if (failed(processMemberDecoration(resultID, memberDecoration))) {
561 return emitError(loc, "cannot decorate ")
562 << static_cast<uint32_t>(memberDecoration.memberIndex)
563 << "-th member of " << structType << " with "
564 << stringifyDecoration(memberDecoration.decoration);
565 }
566 }
567
568 typeEnum = spirv::Opcode::OpTypeStruct;
569
570 if (structType.isIdentified())
571 serializationCtx.remove(structType.getIdentifier());
572
573 return success();
574 }
575
576 if (auto cooperativeMatrixType =
577 type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
578 uint32_t elementTypeID = 0;
579 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
580 elementTypeID, serializationCtx))) {
581 return failure();
582 }
583 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
584 auto getConstantOp = [&](uint32_t id) {
585 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
586 return prepareConstantInt(loc, attr);
587 };
588 operands.push_back(elementTypeID);
589 operands.push_back(
590 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
591 operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
592 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
593 return success();
594 }
595
596 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
597 uint32_t elementTypeID = 0;
598 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
599 serializationCtx))) {
600 return failure();
601 }
602 typeEnum = spirv::Opcode::OpTypeMatrix;
603 operands.push_back(elementTypeID);
604 operands.push_back(matrixType.getNumColumns());
605 return success();
606 }
607
608 // TODO: Handle other types.
609 return emitError(loc, "unhandled type in serialization: ") << type;
610 }
611
612 LogicalResult
prepareFunctionType(Location loc,FunctionType type,spirv::Opcode & typeEnum,SmallVectorImpl<uint32_t> & operands)613 Serializer::prepareFunctionType(Location loc, FunctionType type,
614 spirv::Opcode &typeEnum,
615 SmallVectorImpl<uint32_t> &operands) {
616 typeEnum = spirv::Opcode::OpTypeFunction;
617 assert(type.getNumResults() <= 1 &&
618 "serialization supports only a single return value");
619 uint32_t resultID = 0;
620 if (failed(processType(
621 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
622 resultID))) {
623 return failure();
624 }
625 operands.push_back(resultID);
626 for (auto &res : type.getInputs()) {
627 uint32_t argTypeID = 0;
628 if (failed(processType(loc, res, argTypeID))) {
629 return failure();
630 }
631 operands.push_back(argTypeID);
632 }
633 return success();
634 }
635
636 //===----------------------------------------------------------------------===//
637 // Constant
638 //===----------------------------------------------------------------------===//
639
prepareConstant(Location loc,Type constType,Attribute valueAttr)640 uint32_t Serializer::prepareConstant(Location loc, Type constType,
641 Attribute valueAttr) {
642 if (auto id = prepareConstantScalar(loc, valueAttr)) {
643 return id;
644 }
645
646 // This is a composite literal. We need to handle each component separately
647 // and then emit an OpConstantComposite for the whole.
648
649 if (auto id = getConstantID(valueAttr)) {
650 return id;
651 }
652
653 uint32_t typeID = 0;
654 if (failed(processType(loc, constType, typeID))) {
655 return 0;
656 }
657
658 uint32_t resultID = 0;
659 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
660 int rank = attr.getType().dyn_cast<ShapedType>().getRank();
661 SmallVector<uint64_t, 4> index(rank);
662 resultID = prepareDenseElementsConstant(loc, constType, attr,
663 /*dim=*/0, index);
664 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
665 resultID = prepareArrayConstant(loc, constType, arrayAttr);
666 }
667
668 if (resultID == 0) {
669 emitError(loc, "cannot serialize attribute: ") << valueAttr;
670 return 0;
671 }
672
673 constIDMap[valueAttr] = resultID;
674 return resultID;
675 }
676
prepareArrayConstant(Location loc,Type constType,ArrayAttr attr)677 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
678 ArrayAttr attr) {
679 uint32_t typeID = 0;
680 if (failed(processType(loc, constType, typeID))) {
681 return 0;
682 }
683
684 uint32_t resultID = getNextID();
685 SmallVector<uint32_t, 4> operands = {typeID, resultID};
686 operands.reserve(attr.size() + 2);
687 auto elementType = constType.cast<spirv::ArrayType>().getElementType();
688 for (Attribute elementAttr : attr) {
689 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
690 operands.push_back(elementID);
691 } else {
692 return 0;
693 }
694 }
695 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
696 encodeInstructionInto(typesGlobalValues, opcode, operands);
697
698 return resultID;
699 }
700
701 // TODO: Turn the below function into iterative function, instead of
702 // recursive function.
703 uint32_t
prepareDenseElementsConstant(Location loc,Type constType,DenseElementsAttr valueAttr,int dim,MutableArrayRef<uint64_t> index)704 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
705 DenseElementsAttr valueAttr, int dim,
706 MutableArrayRef<uint64_t> index) {
707 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
708 assert(dim <= shapedType.getRank());
709 if (shapedType.getRank() == dim) {
710 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
711 return attr.getType().getElementType().isInteger(1)
712 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
713 : prepareConstantInt(loc,
714 attr.getValues<IntegerAttr>()[index]);
715 }
716 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
717 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
718 }
719 return 0;
720 }
721
722 uint32_t typeID = 0;
723 if (failed(processType(loc, constType, typeID))) {
724 return 0;
725 }
726
727 uint32_t resultID = getNextID();
728 SmallVector<uint32_t, 4> operands = {typeID, resultID};
729 operands.reserve(shapedType.getDimSize(dim) + 2);
730 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
731 for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
732 index[dim] = i;
733 if (auto elementID = prepareDenseElementsConstant(
734 loc, elementType, valueAttr, dim + 1, index)) {
735 operands.push_back(elementID);
736 } else {
737 return 0;
738 }
739 }
740 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
741 encodeInstructionInto(typesGlobalValues, opcode, operands);
742
743 return resultID;
744 }
745
prepareConstantScalar(Location loc,Attribute valueAttr,bool isSpec)746 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
747 bool isSpec) {
748 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
749 return prepareConstantFp(loc, floatAttr, isSpec);
750 }
751 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
752 return prepareConstantBool(loc, boolAttr, isSpec);
753 }
754 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
755 return prepareConstantInt(loc, intAttr, isSpec);
756 }
757
758 return 0;
759 }
760
prepareConstantBool(Location loc,BoolAttr boolAttr,bool isSpec)761 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
762 bool isSpec) {
763 if (!isSpec) {
764 // We can de-duplicate normal constants, but not specialization constants.
765 if (auto id = getConstantID(boolAttr)) {
766 return id;
767 }
768 }
769
770 // Process the type for this bool literal
771 uint32_t typeID = 0;
772 if (failed(processType(loc, boolAttr.getType(), typeID))) {
773 return 0;
774 }
775
776 auto resultID = getNextID();
777 auto opcode = boolAttr.getValue()
778 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
779 : spirv::Opcode::OpConstantTrue)
780 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
781 : spirv::Opcode::OpConstantFalse);
782 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
783
784 if (!isSpec) {
785 constIDMap[boolAttr] = resultID;
786 }
787 return resultID;
788 }
789
prepareConstantInt(Location loc,IntegerAttr intAttr,bool isSpec)790 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
791 bool isSpec) {
792 if (!isSpec) {
793 // We can de-duplicate normal constants, but not specialization constants.
794 if (auto id = getConstantID(intAttr)) {
795 return id;
796 }
797 }
798
799 // Process the type for this integer literal
800 uint32_t typeID = 0;
801 if (failed(processType(loc, intAttr.getType(), typeID))) {
802 return 0;
803 }
804
805 auto resultID = getNextID();
806 APInt value = intAttr.getValue();
807 unsigned bitwidth = value.getBitWidth();
808 bool isSigned = value.isSignedIntN(bitwidth);
809
810 auto opcode =
811 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
812
813 switch (bitwidth) {
814 // According to SPIR-V spec, "When the type's bit width is less than
815 // 32-bits, the literal's value appears in the low-order bits of the word,
816 // and the high-order bits must be 0 for a floating-point type, or 0 for an
817 // integer type with Signedness of 0, or sign extended when Signedness
818 // is 1."
819 case 32:
820 case 16:
821 case 8: {
822 uint32_t word = 0;
823 if (isSigned) {
824 word = static_cast<int32_t>(value.getSExtValue());
825 } else {
826 word = static_cast<uint32_t>(value.getZExtValue());
827 }
828 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
829 } break;
830 // According to SPIR-V spec: "When the type's bit width is larger than one
831 // word, the literal’s low-order words appear first."
832 case 64: {
833 struct DoubleWord {
834 uint32_t word1;
835 uint32_t word2;
836 } words;
837 if (isSigned) {
838 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
839 } else {
840 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
841 }
842 encodeInstructionInto(typesGlobalValues, opcode,
843 {typeID, resultID, words.word1, words.word2});
844 } break;
845 default: {
846 std::string valueStr;
847 llvm::raw_string_ostream rss(valueStr);
848 value.print(rss, /*isSigned=*/false);
849
850 emitError(loc, "cannot serialize ")
851 << bitwidth << "-bit integer literal: " << rss.str();
852 return 0;
853 }
854 }
855
856 if (!isSpec) {
857 constIDMap[intAttr] = resultID;
858 }
859 return resultID;
860 }
861
prepareConstantFp(Location loc,FloatAttr floatAttr,bool isSpec)862 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
863 bool isSpec) {
864 if (!isSpec) {
865 // We can de-duplicate normal constants, but not specialization constants.
866 if (auto id = getConstantID(floatAttr)) {
867 return id;
868 }
869 }
870
871 // Process the type for this float literal
872 uint32_t typeID = 0;
873 if (failed(processType(loc, floatAttr.getType(), typeID))) {
874 return 0;
875 }
876
877 auto resultID = getNextID();
878 APFloat value = floatAttr.getValue();
879 APInt intValue = value.bitcastToAPInt();
880
881 auto opcode =
882 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
883
884 if (&value.getSemantics() == &APFloat::IEEEsingle()) {
885 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
886 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
887 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
888 struct DoubleWord {
889 uint32_t word1;
890 uint32_t word2;
891 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
892 encodeInstructionInto(typesGlobalValues, opcode,
893 {typeID, resultID, words.word1, words.word2});
894 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
895 uint32_t word =
896 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
897 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
898 } else {
899 std::string valueStr;
900 llvm::raw_string_ostream rss(valueStr);
901 value.print(rss);
902
903 emitError(loc, "cannot serialize ")
904 << floatAttr.getType() << "-typed float literal: " << rss.str();
905 return 0;
906 }
907
908 if (!isSpec) {
909 constIDMap[floatAttr] = resultID;
910 }
911 return resultID;
912 }
913
914 //===----------------------------------------------------------------------===//
915 // Control flow
916 //===----------------------------------------------------------------------===//
917
getOrCreateBlockID(Block * block)918 uint32_t Serializer::getOrCreateBlockID(Block *block) {
919 if (uint32_t id = getBlockID(block))
920 return id;
921 return blockIDMap[block] = getNextID();
922 }
923
924 #ifndef NDEBUG
printBlock(Block * block,raw_ostream & os)925 void Serializer::printBlock(Block *block, raw_ostream &os) {
926 os << "block " << block << " (id = ";
927 if (uint32_t id = getBlockID(block))
928 os << id;
929 else
930 os << "unknown";
931 os << ")\n";
932 }
933 #endif
934
935 LogicalResult
processBlock(Block * block,bool omitLabel,function_ref<LogicalResult ()> emitMerge)936 Serializer::processBlock(Block *block, bool omitLabel,
937 function_ref<LogicalResult()> emitMerge) {
938 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
939 LLVM_DEBUG(block->print(llvm::dbgs()));
940 LLVM_DEBUG(llvm::dbgs() << '\n');
941 if (!omitLabel) {
942 uint32_t blockID = getOrCreateBlockID(block);
943 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
944
945 // Emit OpLabel for this block.
946 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
947 }
948
949 // Emit OpPhi instructions for block arguments, if any.
950 if (failed(emitPhiForBlockArguments(block)))
951 return failure();
952
953 // If we need to emit merge instructions, it must happen in this block. Check
954 // whether we have other structured control flow ops, which will be expanded
955 // into multiple basic blocks. If that's the case, we need to emit the merge
956 // right now and then create new blocks for further serialization of the ops
957 // in this block.
958 if (emitMerge && llvm::any_of(block->getOperations(), [](Operation &op) {
959 return isa<spirv::LoopOp, spirv::SelectionOp>(op);
960 })) {
961 if (failed(emitMerge()))
962 return failure();
963 emitMerge = nullptr;
964
965 // Start a new block for further serialization.
966 uint32_t blockID = getNextID();
967 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
968 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
969 }
970
971 // Process each op in this block except the terminator.
972 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
973 if (failed(processOperation(&op)))
974 return failure();
975 }
976
977 // Process the terminator.
978 if (emitMerge)
979 if (failed(emitMerge()))
980 return failure();
981 if (failed(processOperation(&block->back())))
982 return failure();
983
984 return success();
985 }
986
emitPhiForBlockArguments(Block * block)987 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
988 // Nothing to do if this block has no arguments or it's the entry block, which
989 // always has the same arguments as the function signature.
990 if (block->args_empty() || block->isEntryBlock())
991 return success();
992
993 LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
994
995 // If the block has arguments, we need to create SPIR-V OpPhi instructions.
996 // A SPIR-V OpPhi instruction is of the syntax:
997 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
998 // So we need to collect all predecessor blocks and the arguments they send
999 // to this block.
1000 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1001 for (Block *mlirPredecessor : block->getPredecessors()) {
1002 auto *terminator = mlirPredecessor->getTerminator();
1003 LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1004 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1005 LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1006 // The predecessor here is the immediate one according to MLIR's IR
1007 // structure. It does not directly map to the incoming parent block for the
1008 // OpPhi instructions at SPIR-V binary level. This is because structured
1009 // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1010 // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the
1011 // branch op jumping to the OpPhi's block then resides in the previous
1012 // structured control flow op's merge block.
1013 Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1014 LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1015 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1016 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1017 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1018 } else if (auto branchCondOp =
1019 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1020 Optional<OperandRange> blockOperands;
1021 if (branchCondOp.trueTarget() == block) {
1022 blockOperands = branchCondOp.trueTargetOperands();
1023 } else {
1024 assert(branchCondOp.falseTarget() == block);
1025 blockOperands = branchCondOp.falseTargetOperands();
1026 }
1027
1028 assert(!blockOperands->empty() &&
1029 "expected non-empty block operand range");
1030 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1031 } else {
1032 return terminator->emitError("unimplemented terminator for Phi creation");
1033 }
1034 LLVM_DEBUG({
1035 llvm::dbgs() << " block arguments:\n";
1036 for (Value v : predecessors.back().second)
1037 llvm::dbgs() << " " << v << "\n";
1038 });
1039 }
1040
1041 // Then create OpPhi instruction for each of the block argument.
1042 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1043 BlockArgument arg = block->getArgument(argIndex);
1044
1045 // Get the type <id> and result <id> for this OpPhi instruction.
1046 uint32_t phiTypeID = 0;
1047 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1048 return failure();
1049 uint32_t phiID = getNextID();
1050
1051 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1052 << arg << " (id = " << phiID << ")\n");
1053
1054 // Prepare the (value <id>, parent block <id>) pairs.
1055 SmallVector<uint32_t, 8> phiArgs;
1056 phiArgs.push_back(phiTypeID);
1057 phiArgs.push_back(phiID);
1058
1059 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1060 Value value = predecessors[predIndex].second[argIndex];
1061 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1062 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1063 << ") value " << value << ' ');
1064 // Each pair is a value <id> ...
1065 uint32_t valueId = getValueID(value);
1066 if (valueId == 0) {
1067 // The op generating this value hasn't been visited yet so we don't have
1068 // an <id> assigned yet. Record this to fix up later.
1069 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1070 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1071 phiArgs.size());
1072 } else {
1073 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1074 }
1075 phiArgs.push_back(valueId);
1076 // ... and a parent block <id>.
1077 phiArgs.push_back(predBlockId);
1078 }
1079
1080 encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1081 valueIDMap[arg] = phiID;
1082 }
1083
1084 return success();
1085 }
1086
1087 //===----------------------------------------------------------------------===//
1088 // Operation
1089 //===----------------------------------------------------------------------===//
1090
encodeExtensionInstruction(Operation * op,StringRef extensionSetName,uint32_t extensionOpcode,ArrayRef<uint32_t> operands)1091 LogicalResult Serializer::encodeExtensionInstruction(
1092 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1093 ArrayRef<uint32_t> operands) {
1094 // Check if the extension has been imported.
1095 auto &setID = extendedInstSetIDMap[extensionSetName];
1096 if (!setID) {
1097 setID = getNextID();
1098 SmallVector<uint32_t, 16> importOperands;
1099 importOperands.push_back(setID);
1100 spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1101 encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1102 importOperands);
1103 }
1104
1105 // The first two operands are the result type <id> and result <id>. The set
1106 // <id> and the opcode need to be insert after this.
1107 if (operands.size() < 2) {
1108 return op->emitError("extended instructions must have a result encoding");
1109 }
1110 SmallVector<uint32_t, 8> extInstOperands;
1111 extInstOperands.reserve(operands.size() + 2);
1112 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1113 extInstOperands.push_back(setID);
1114 extInstOperands.push_back(extensionOpcode);
1115 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1116 encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1117 extInstOperands);
1118 return success();
1119 }
1120
processOperation(Operation * opInst)1121 LogicalResult Serializer::processOperation(Operation *opInst) {
1122 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1123
1124 // First dispatch the ops that do not directly mirror an instruction from
1125 // the SPIR-V spec.
1126 return TypeSwitch<Operation *, LogicalResult>(opInst)
1127 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1128 .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1129 .Case([&](spirv::BranchConditionalOp op) {
1130 return processBranchConditionalOp(op);
1131 })
1132 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1133 .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1134 .Case([&](spirv::GlobalVariableOp op) {
1135 return processGlobalVariableOp(op);
1136 })
1137 .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1138 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1139 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1140 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1141 .Case([&](spirv::SpecConstantCompositeOp op) {
1142 return processSpecConstantCompositeOp(op);
1143 })
1144 .Case([&](spirv::SpecConstantOperationOp op) {
1145 return processSpecConstantOperationOp(op);
1146 })
1147 .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1148 .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1149
1150 // Then handle all the ops that directly mirror SPIR-V instructions with
1151 // auto-generated methods.
1152 .Default(
1153 [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1154 }
1155
processOpWithoutGrammarAttr(Operation * op,StringRef extInstSet,uint32_t opcode)1156 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1157 StringRef extInstSet,
1158 uint32_t opcode) {
1159 SmallVector<uint32_t, 4> operands;
1160 Location loc = op->getLoc();
1161
1162 uint32_t resultID = 0;
1163 if (op->getNumResults() != 0) {
1164 uint32_t resultTypeID = 0;
1165 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1166 return failure();
1167 operands.push_back(resultTypeID);
1168
1169 resultID = getNextID();
1170 operands.push_back(resultID);
1171 valueIDMap[op->getResult(0)] = resultID;
1172 };
1173
1174 for (Value operand : op->getOperands())
1175 operands.push_back(getValueID(operand));
1176
1177 if (failed(emitDebugLine(functionBody, loc)))
1178 return failure();
1179
1180 if (extInstSet.empty()) {
1181 encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1182 operands);
1183 } else {
1184 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1185 return failure();
1186 }
1187
1188 if (op->getNumResults() != 0) {
1189 for (auto attr : op->getAttrs()) {
1190 if (failed(processDecoration(loc, resultID, attr)))
1191 return failure();
1192 }
1193 }
1194
1195 return success();
1196 }
1197
emitDecoration(uint32_t target,spirv::Decoration decoration,ArrayRef<uint32_t> params)1198 LogicalResult Serializer::emitDecoration(uint32_t target,
1199 spirv::Decoration decoration,
1200 ArrayRef<uint32_t> params) {
1201 uint32_t wordCount = 3 + params.size();
1202 decorations.push_back(
1203 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
1204 decorations.push_back(target);
1205 decorations.push_back(static_cast<uint32_t>(decoration));
1206 decorations.append(params.begin(), params.end());
1207 return success();
1208 }
1209
emitDebugLine(SmallVectorImpl<uint32_t> & binary,Location loc)1210 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1211 Location loc) {
1212 if (!options.emitDebugInfo)
1213 return success();
1214
1215 if (lastProcessedWasMergeInst) {
1216 lastProcessedWasMergeInst = false;
1217 return success();
1218 }
1219
1220 auto fileLoc = loc.dyn_cast<FileLineColLoc>();
1221 if (fileLoc)
1222 encodeInstructionInto(binary, spirv::Opcode::OpLine,
1223 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1224 return success();
1225 }
1226 } // namespace spirv
1227 } // namespace mlir
1228