1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
isFnEntryBlock(Block * block)42 static inline bool isFnEntryBlock(Block *block) {
43   return block->isEntryBlock() &&
44          isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
/// Constructs a deserializer over the SPIR-V word stream `binary`.
///
/// The destination `spv.module` is created eagerly via createModuleOp(), and
/// `opBuilder` is constructed from that module's region so that ops produced
/// while processing instructions are built inside it. In debug builds the
/// logger writes to llvm::dbgs().
///
/// NOTE(review): createModuleOp() reads `context` and `unknownLoc`, so this
/// relies on both being declared before `module` in the class (members are
/// initialized in declaration order) — TODO confirm in Deserializer.h.
spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
                                  MLIRContext *context)
    : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
      module(createModuleOp()), opBuilder(module->getRegion())
#ifndef NDEBUG
      ,
      logger(llvm::dbgs())
#endif
{
}
61 
deserialize()62 LogicalResult spirv::Deserializer::deserialize() {
63   LLVM_DEBUG({
64     logger.resetIndent();
65     logger.startLine()
66         << "//+++---------- start deserialization ----------+++//\n";
67   });
68 
69   if (failed(processHeader()))
70     return failure();
71 
72   spirv::Opcode opcode = spirv::Opcode::OpNop;
73   ArrayRef<uint32_t> operands;
74   auto binarySize = binary.size();
75   while (curOffset < binarySize) {
76     // Slice the next instruction out and populate `opcode` and `operands`.
77     // Internally this also updates `curOffset`.
78     if (failed(sliceInstruction(opcode, operands)))
79       return failure();
80 
81     if (failed(processInstruction(opcode, operands)))
82       return failure();
83   }
84 
85   assert(curOffset == binarySize &&
86          "deserializer should never index beyond the binary end");
87 
88   for (auto &deferred : deferredInstructions) {
89     if (failed(processInstruction(deferred.first, deferred.second, false))) {
90       return failure();
91     }
92   }
93 
94   attachVCETriple();
95 
96   LLVM_DEBUG(logger.startLine()
97              << "//+++-------- completed deserialization --------+++//\n");
98   return success();
99 }
100 
collect()101 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
102   return std::move(module);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Module structure
107 //===----------------------------------------------------------------------===//
108 
createModuleOp()109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
110   OpBuilder builder(context);
111   OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112   spirv::ModuleOp::build(builder, state);
113   return cast<spirv::ModuleOp>(Operation::create(state));
114 }
115 
processHeader()116 LogicalResult spirv::Deserializer::processHeader() {
117   if (binary.size() < spirv::kHeaderWordCount)
118     return emitError(unknownLoc,
119                      "SPIR-V binary module must have a 5-word header");
120 
121   if (binary[0] != spirv::kMagicNumber)
122     return emitError(unknownLoc, "incorrect magic number");
123 
124   // Version number bytes: 0 | major number | minor number | 0
125   uint32_t majorVersion = (binary[1] << 8) >> 24;
126   uint32_t minorVersion = (binary[1] << 16) >> 24;
127   if (majorVersion == 1) {
128     switch (minorVersion) {
129 #define MIN_VERSION_CASE(v)                                                    \
130   case v:                                                                      \
131     version = spirv::Version::V_1_##v;                                         \
132     break
133 
134       MIN_VERSION_CASE(0);
135       MIN_VERSION_CASE(1);
136       MIN_VERSION_CASE(2);
137       MIN_VERSION_CASE(3);
138       MIN_VERSION_CASE(4);
139       MIN_VERSION_CASE(5);
140 #undef MIN_VERSION_CASE
141     default:
142       return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
143              << minorVersion;
144     }
145   } else {
146     return emitError(unknownLoc, "unsupported SPIR-V major version: ")
147            << majorVersion;
148   }
149 
150   // TODO: generator number, bound, schema
151   curOffset = spirv::kHeaderWordCount;
152   return success();
153 }
154 
155 LogicalResult
processCapability(ArrayRef<uint32_t> operands)156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
157   if (operands.size() != 1)
158     return emitError(unknownLoc, "OpMemoryModel must have one parameter");
159 
160   auto cap = spirv::symbolizeCapability(operands[0]);
161   if (!cap)
162     return emitError(unknownLoc, "unknown capability: ") << operands[0];
163 
164   capabilities.insert(*cap);
165   return success();
166 }
167 
processExtension(ArrayRef<uint32_t> words)168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
169   if (words.empty()) {
170     return emitError(
171         unknownLoc,
172         "OpExtension must have a literal string for the extension name");
173   }
174 
175   unsigned wordIndex = 0;
176   StringRef extName = decodeStringLiteral(words, wordIndex);
177   if (wordIndex != words.size())
178     return emitError(unknownLoc,
179                      "unexpected trailing words in OpExtension instruction");
180   auto ext = spirv::symbolizeExtension(extName);
181   if (!ext)
182     return emitError(unknownLoc, "unknown extension: ") << extName;
183 
184   extensions.insert(*ext);
185   return success();
186 }
187 
188 LogicalResult
processExtInstImport(ArrayRef<uint32_t> words)189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
190   if (words.size() < 2) {
191     return emitError(unknownLoc,
192                      "OpExtInstImport must have a result <id> and a literal "
193                      "string for the extended instruction set name");
194   }
195 
196   unsigned wordIndex = 1;
197   extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
198   if (wordIndex != words.size()) {
199     return emitError(unknownLoc,
200                      "unexpected trailing words in OpExtInstImport");
201   }
202   return success();
203 }
204 
attachVCETriple()205 void spirv::Deserializer::attachVCETriple() {
206   (*module)->setAttr(
207       spirv::ModuleOp::getVCETripleAttrName(),
208       spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
209                                 extensions.getArrayRef(), context));
210 }
211 
212 LogicalResult
processMemoryModel(ArrayRef<uint32_t> operands)213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
214   if (operands.size() != 2)
215     return emitError(unknownLoc, "OpMemoryModel must have two operands");
216 
217   (*module)->setAttr(
218       "addressing_model",
219       opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
220   (*module)->setAttr(
221       "memory_model",
222       opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
223 
224   return success();
225 }
226 
processDecoration(ArrayRef<uint32_t> words)227 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
228   // TODO: This function should also be auto-generated. For now, since only a
229   // few decorations are processed/handled in a meaningful manner, going with a
230   // manual implementation.
231   if (words.size() < 2) {
232     return emitError(
233         unknownLoc, "OpDecorate must have at least result <id> and Decoration");
234   }
235   auto decorationName =
236       stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
237   if (decorationName.empty()) {
238     return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
239   }
240   auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
241   auto symbol = opBuilder.getStringAttr(attrName);
242   switch (static_cast<spirv::Decoration>(words[1])) {
243   case spirv::Decoration::DescriptorSet:
244   case spirv::Decoration::Binding:
245     if (words.size() != 3) {
246       return emitError(unknownLoc, "OpDecorate with ")
247              << decorationName << " needs a single integer literal";
248     }
249     decorations[words[0]].set(
250         symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
251     break;
252   case spirv::Decoration::BuiltIn:
253     if (words.size() != 3) {
254       return emitError(unknownLoc, "OpDecorate with ")
255              << decorationName << " needs a single integer literal";
256     }
257     decorations[words[0]].set(
258         symbol, opBuilder.getStringAttr(
259                     stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
260     break;
261   case spirv::Decoration::ArrayStride:
262     if (words.size() != 3) {
263       return emitError(unknownLoc, "OpDecorate with ")
264              << decorationName << " needs a single integer literal";
265     }
266     typeDecorations[words[0]] = words[2];
267     break;
268   case spirv::Decoration::Aliased:
269   case spirv::Decoration::Block:
270   case spirv::Decoration::BufferBlock:
271   case spirv::Decoration::Flat:
272   case spirv::Decoration::NonReadable:
273   case spirv::Decoration::NonWritable:
274   case spirv::Decoration::NoPerspective:
275   case spirv::Decoration::Restrict:
276   case spirv::Decoration::RelaxedPrecision:
277     if (words.size() != 2) {
278       return emitError(unknownLoc, "OpDecoration with ")
279              << decorationName << "needs a single target <id>";
280     }
281     // Block decoration does not affect spv.struct type, but is still stored for
282     // verification.
283     // TODO: Update StructType to contain this information since
284     // it is needed for many validation rules.
285     decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
286     break;
287   case spirv::Decoration::Location:
288   case spirv::Decoration::SpecId:
289     if (words.size() != 3) {
290       return emitError(unknownLoc, "OpDecoration with ")
291              << decorationName << "needs a single integer literal";
292     }
293     decorations[words[0]].set(
294         symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
295     break;
296   default:
297     return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
298   }
299   return success();
300 }
301 
302 LogicalResult
processMemberDecoration(ArrayRef<uint32_t> words)303 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
304   // The binary layout of OpMemberDecorate is different comparing to OpDecorate
305   if (words.size() < 3) {
306     return emitError(unknownLoc,
307                      "OpMemberDecorate must have at least 3 operands");
308   }
309 
310   auto decoration = static_cast<spirv::Decoration>(words[2]);
311   if (decoration == spirv::Decoration::Offset && words.size() != 4) {
312     return emitError(unknownLoc,
313                      " missing offset specification in OpMemberDecorate with "
314                      "Offset decoration");
315   }
316   ArrayRef<uint32_t> decorationOperands;
317   if (words.size() > 3) {
318     decorationOperands = words.slice(3);
319   }
320   memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
321   return success();
322 }
323 
processMemberName(ArrayRef<uint32_t> words)324 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
325   if (words.size() < 3) {
326     return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
327   }
328   unsigned wordIndex = 2;
329   auto name = decodeStringLiteral(words, wordIndex);
330   if (wordIndex != words.size()) {
331     return emitError(unknownLoc,
332                      "unexpected trailing words in OpMemberName instruction");
333   }
334   memberNameMap[words[0]][words[1]] = name;
335   return success();
336 }
337 
338 LogicalResult
processFunction(ArrayRef<uint32_t> operands)339 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
340   if (curFunction) {
341     return emitError(unknownLoc, "found function inside function");
342   }
343 
344   // Get the result type
345   if (operands.size() != 4) {
346     return emitError(unknownLoc, "OpFunction must have 4 parameters");
347   }
348   Type resultType = getType(operands[0]);
349   if (!resultType) {
350     return emitError(unknownLoc, "undefined result type from <id> ")
351            << operands[0];
352   }
353 
354   uint32_t fnID = operands[1];
355   if (funcMap.count(fnID)) {
356     return emitError(unknownLoc, "duplicate function definition/declaration");
357   }
358 
359   auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
360   if (!fnControl) {
361     return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
362   }
363 
364   Type fnType = getType(operands[3]);
365   if (!fnType || !fnType.isa<FunctionType>()) {
366     return emitError(unknownLoc, "unknown function type from <id> ")
367            << operands[3];
368   }
369   auto functionType = fnType.cast<FunctionType>();
370 
371   if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
372       (functionType.getNumResults() == 1 &&
373        functionType.getResult(0) != resultType)) {
374     return emitError(unknownLoc, "mismatch in function type ")
375            << functionType << " and return type " << resultType << " specified";
376   }
377 
378   std::string fnName = getFunctionSymbol(fnID);
379   auto funcOp = opBuilder.create<spirv::FuncOp>(
380       unknownLoc, fnName, functionType, fnControl.value());
381   curFunction = funcMap[fnID] = funcOp;
382   auto *entryBlock = funcOp.addEntryBlock();
383   LLVM_DEBUG({
384     logger.startLine()
385         << "//===-------------------------------------------===//\n";
386     logger.startLine() << "[fn] name: " << fnName << "\n";
387     logger.startLine() << "[fn] type: " << fnType << "\n";
388     logger.startLine() << "[fn] ID: " << fnID << "\n";
389     logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
390     logger.indent();
391   });
392 
393   // Parse the op argument instructions
394   if (functionType.getNumInputs()) {
395     for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
396       auto argType = functionType.getInput(i);
397       spirv::Opcode opcode = spirv::Opcode::OpNop;
398       ArrayRef<uint32_t> operands;
399       if (failed(sliceInstruction(opcode, operands,
400                                   spirv::Opcode::OpFunctionParameter))) {
401         return failure();
402       }
403       if (opcode != spirv::Opcode::OpFunctionParameter) {
404         return emitError(
405                    unknownLoc,
406                    "missing OpFunctionParameter instruction for argument ")
407                << i;
408       }
409       if (operands.size() != 2) {
410         return emitError(
411             unknownLoc,
412             "expected result type and result <id> for OpFunctionParameter");
413       }
414       auto argDefinedType = getType(operands[0]);
415       if (!argDefinedType || argDefinedType != argType) {
416         return emitError(unknownLoc,
417                          "mismatch in argument type between function type "
418                          "definition ")
419                << functionType << " and argument type definition "
420                << argDefinedType << " at argument " << i;
421       }
422       if (getValue(operands[1])) {
423         return emitError(unknownLoc, "duplicate definition of result <id> ")
424                << operands[1];
425       }
426       auto argValue = funcOp.getArgument(i);
427       valueMap[operands[1]] = argValue;
428     }
429   }
430 
431   // RAII guard to reset the insertion point to the module's region after
432   // deserializing the body of this function.
433   OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
434 
435   spirv::Opcode opcode = spirv::Opcode::OpNop;
436   ArrayRef<uint32_t> instOperands;
437 
438   // Special handling for the entry block. We need to make sure it starts with
439   // an OpLabel instruction. The entry block takes the same parameters as the
440   // function. All other blocks do not take any parameter. We have already
441   // created the entry block, here we need to register it to the correct label
442   // <id>.
443   if (failed(sliceInstruction(opcode, instOperands,
444                               spirv::Opcode::OpFunctionEnd))) {
445     return failure();
446   }
447   if (opcode == spirv::Opcode::OpFunctionEnd) {
448     return processFunctionEnd(instOperands);
449   }
450   if (opcode != spirv::Opcode::OpLabel) {
451     return emitError(unknownLoc, "a basic block must start with OpLabel");
452   }
453   if (instOperands.size() != 1) {
454     return emitError(unknownLoc, "OpLabel should only have result <id>");
455   }
456   blockMap[instOperands[0]] = entryBlock;
457   if (failed(processLabel(instOperands))) {
458     return failure();
459   }
460 
461   // Then process all the other instructions in the function until we hit
462   // OpFunctionEnd.
463   while (succeeded(sliceInstruction(opcode, instOperands,
464                                     spirv::Opcode::OpFunctionEnd)) &&
465          opcode != spirv::Opcode::OpFunctionEnd) {
466     if (failed(processInstruction(opcode, instOperands))) {
467       return failure();
468     }
469   }
470   if (opcode != spirv::Opcode::OpFunctionEnd) {
471     return failure();
472   }
473 
474   return processFunctionEnd(instOperands);
475 }
476 
477 LogicalResult
processFunctionEnd(ArrayRef<uint32_t> operands)478 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
479   // Process OpFunctionEnd.
480   if (!operands.empty()) {
481     return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
482   }
483 
484   // Wire up block arguments from OpPhi instructions.
485   // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops.
486   if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
487     return failure();
488   }
489 
490   curBlock = nullptr;
491   curFunction = llvm::None;
492 
493   LLVM_DEBUG({
494     logger.unindent();
495     logger.startLine()
496         << "//===-------------------------------------------===//\n";
497   });
498   return success();
499 }
500 
501 Optional<std::pair<Attribute, Type>>
getConstant(uint32_t id)502 spirv::Deserializer::getConstant(uint32_t id) {
503   auto constIt = constantMap.find(id);
504   if (constIt == constantMap.end())
505     return llvm::None;
506   return constIt->getSecond();
507 }
508 
509 Optional<spirv::SpecConstOperationMaterializationInfo>
getSpecConstantOperation(uint32_t id)510 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
511   auto constIt = specConstOperationMap.find(id);
512   if (constIt == specConstOperationMap.end())
513     return llvm::None;
514   return constIt->getSecond();
515 }
516 
getFunctionSymbol(uint32_t id)517 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
518   auto funcName = nameMap.lookup(id).str();
519   if (funcName.empty()) {
520     funcName = "spirv_fn_" + std::to_string(id);
521   }
522   return funcName;
523 }
524 
getSpecConstantSymbol(uint32_t id)525 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
526   auto constName = nameMap.lookup(id).str();
527   if (constName.empty()) {
528     constName = "spirv_spec_const_" + std::to_string(id);
529   }
530   return constName;
531 }
532 
533 spirv::SpecConstantOp
createSpecConstant(Location loc,uint32_t resultID,Attribute defaultValue)534 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
535                                         Attribute defaultValue) {
536   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
537   auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
538                                                     defaultValue);
539   if (decorations.count(resultID)) {
540     for (auto attr : decorations[resultID].getAttrs())
541       op->setAttr(attr.getName(), attr.getValue());
542   }
543   specConstMap[resultID] = op;
544   return op;
545 }
546 
547 LogicalResult
processGlobalVariable(ArrayRef<uint32_t> operands)548 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
549   unsigned wordIndex = 0;
550   if (operands.size() < 3) {
551     return emitError(
552         unknownLoc,
553         "OpVariable needs at least 3 operands, type, <id> and storage class");
554   }
555 
556   // Result Type.
557   auto type = getType(operands[wordIndex]);
558   if (!type) {
559     return emitError(unknownLoc, "unknown result type <id> : ")
560            << operands[wordIndex];
561   }
562   auto ptrType = type.dyn_cast<spirv::PointerType>();
563   if (!ptrType) {
564     return emitError(unknownLoc,
565                      "expected a result type <id> to be a spv.ptr, found : ")
566            << type;
567   }
568   wordIndex++;
569 
570   // Result <id>.
571   auto variableID = operands[wordIndex];
572   auto variableName = nameMap.lookup(variableID).str();
573   if (variableName.empty()) {
574     variableName = "spirv_var_" + std::to_string(variableID);
575   }
576   wordIndex++;
577 
578   // Storage class.
579   auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
580   if (ptrType.getStorageClass() != storageClass) {
581     return emitError(unknownLoc, "mismatch in storage class of pointer type ")
582            << type << " and that specified in OpVariable instruction  : "
583            << stringifyStorageClass(storageClass);
584   }
585   wordIndex++;
586 
587   // Initializer.
588   FlatSymbolRefAttr initializer = nullptr;
589   if (wordIndex < operands.size()) {
590     auto initializerOp = getGlobalVariable(operands[wordIndex]);
591     if (!initializerOp) {
592       return emitError(unknownLoc, "unknown <id> ")
593              << operands[wordIndex] << "used as initializer";
594     }
595     wordIndex++;
596     initializer = SymbolRefAttr::get(initializerOp.getOperation());
597   }
598   if (wordIndex != operands.size()) {
599     return emitError(unknownLoc,
600                      "found more operands than expected when deserializing "
601                      "OpVariable instruction, only ")
602            << wordIndex << " of " << operands.size() << " processed";
603   }
604   auto loc = createFileLineColLoc(opBuilder);
605   auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
606       loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
607       initializer);
608 
609   // Decorations.
610   if (decorations.count(variableID)) {
611     for (auto attr : decorations[variableID].getAttrs())
612       varOp->setAttr(attr.getName(), attr.getValue());
613   }
614   globalVariableMap[variableID] = varOp;
615   return success();
616 }
617 
getConstantInt(uint32_t id)618 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
619   auto constInfo = getConstant(id);
620   if (!constInfo) {
621     return nullptr;
622   }
623   return constInfo->first.dyn_cast<IntegerAttr>();
624 }
625 
processName(ArrayRef<uint32_t> operands)626 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
627   if (operands.size() < 2) {
628     return emitError(unknownLoc, "OpName needs at least 2 operands");
629   }
630   if (!nameMap.lookup(operands[0]).empty()) {
631     return emitError(unknownLoc, "duplicate name found for result <id> ")
632            << operands[0];
633   }
634   unsigned wordIndex = 1;
635   StringRef name = decodeStringLiteral(operands, wordIndex);
636   if (wordIndex != operands.size()) {
637     return emitError(unknownLoc,
638                      "unexpected trailing words in OpName instruction");
639   }
640   nameMap[operands[0]] = name;
641   return success();
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // Type
646 //===----------------------------------------------------------------------===//
647 
processType(spirv::Opcode opcode,ArrayRef<uint32_t> operands)648 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
649                                                ArrayRef<uint32_t> operands) {
650   if (operands.empty()) {
651     return emitError(unknownLoc, "type instruction with opcode ")
652            << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
653   }
654 
655   /// TODO: Types might be forward declared in some instructions and need to be
656   /// handled appropriately.
657   if (typeMap.count(operands[0])) {
658     return emitError(unknownLoc, "duplicate definition for result <id> ")
659            << operands[0];
660   }
661 
662   switch (opcode) {
663   case spirv::Opcode::OpTypeVoid:
664     if (operands.size() != 1)
665       return emitError(unknownLoc, "OpTypeVoid must have no parameters");
666     typeMap[operands[0]] = opBuilder.getNoneType();
667     break;
668   case spirv::Opcode::OpTypeBool:
669     if (operands.size() != 1)
670       return emitError(unknownLoc, "OpTypeBool must have no parameters");
671     typeMap[operands[0]] = opBuilder.getI1Type();
672     break;
673   case spirv::Opcode::OpTypeInt: {
674     if (operands.size() != 3)
675       return emitError(
676           unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
677 
678     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
679     // to preserve or validate.
680     // 0 indicates unsigned, or no signedness semantics
681     // 1 indicates signed semantics."
682     //
683     // So we cannot differentiate signless and unsigned integers; always use
684     // signless semantics for such cases.
685     auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
686                                  : IntegerType::SignednessSemantics::Signless;
687     typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
688   } break;
689   case spirv::Opcode::OpTypeFloat: {
690     if (operands.size() != 2)
691       return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
692 
693     Type floatTy;
694     switch (operands[1]) {
695     case 16:
696       floatTy = opBuilder.getF16Type();
697       break;
698     case 32:
699       floatTy = opBuilder.getF32Type();
700       break;
701     case 64:
702       floatTy = opBuilder.getF64Type();
703       break;
704     default:
705       return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
706              << operands[1];
707     }
708     typeMap[operands[0]] = floatTy;
709   } break;
710   case spirv::Opcode::OpTypeVector: {
711     if (operands.size() != 3) {
712       return emitError(
713           unknownLoc,
714           "OpTypeVector must have element type and count parameters");
715     }
716     Type elementTy = getType(operands[1]);
717     if (!elementTy) {
718       return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
719              << operands[1];
720     }
721     typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
722   } break;
723   case spirv::Opcode::OpTypePointer: {
724     return processOpTypePointer(operands);
725   } break;
726   case spirv::Opcode::OpTypeArray:
727     return processArrayType(operands);
728   case spirv::Opcode::OpTypeCooperativeMatrixNV:
729     return processCooperativeMatrixType(operands);
730   case spirv::Opcode::OpTypeFunction:
731     return processFunctionType(operands);
732   case spirv::Opcode::OpTypeImage:
733     return processImageType(operands);
734   case spirv::Opcode::OpTypeSampledImage:
735     return processSampledImageType(operands);
736   case spirv::Opcode::OpTypeRuntimeArray:
737     return processRuntimeArrayType(operands);
738   case spirv::Opcode::OpTypeStruct:
739     return processStructType(operands);
740   case spirv::Opcode::OpTypeMatrix:
741     return processMatrixType(operands);
742   default:
743     return emitError(unknownLoc, "unhandled type instruction");
744   }
745   return success();
746 }
747 
/// Handles OpTypePointer: operands are result <id>, storage class, and
/// pointee type <id>. Registers the new spirv::PointerType in `typeMap`.
///
/// The new pointer <id> may also complete struct types whose bodies were
/// deferred because a member referred to this pointer before it was declared.
/// Every entry of `deferredStructTypesInfos` is scanned. Any unresolved member
/// naming `typePointerID` is filled in. Once a struct has no unresolved
/// members left, its body is set and the struct is dropped from the deferred
/// list.
LogicalResult
spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
  if (operands.size() != 3)
    return emitError(unknownLoc, "OpTypePointer must have two parameters");

  auto pointeeType = getType(operands[2]);
  if (!pointeeType)
    return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
           << operands[2];

  uint32_t typePointerID = operands[0];
  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);

  // Both loops erase elements while iterating, so the iterators are advanced
  // manually: either by erase() (which returns the next element) or by ++.
  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
       deferredStructIt != std::end(deferredStructTypesInfos);) {
    // Each unresolved member is presumably a (type <id>, member index) pair,
    // judging by the `first`/`second` uses below — TODO confirm in the header.
    for (auto *unresolvedMemberIt =
             std::begin(deferredStructIt->unresolvedMemberTypes);
         unresolvedMemberIt !=
         std::end(deferredStructIt->unresolvedMemberTypes);) {
      if (unresolvedMemberIt->first == typePointerID) {
        // The newly constructed pointer type can resolve one of the
        // deferred struct type members; update the memberTypes list and
        // clean the unresolvedMemberTypes list accordingly.
        deferredStructIt->memberTypes[unresolvedMemberIt->second] =
            typeMap[typePointerID];
        unresolvedMemberIt =
            deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
      } else {
        ++unresolvedMemberIt;
      }
    }

    if (deferredStructIt->unresolvedMemberTypes.empty()) {
      // All deferred struct type members are now resolved, set the struct body.
      auto structType = deferredStructIt->deferredStructType;

      assert(structType && "expected a spirv::StructType");
      assert(structType.isIdentified() && "expected an indentified struct");

      // Failure here aborts the whole scan; the remaining deferred structs are
      // left untouched.
      if (failed(structType.trySetBody(
              deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
              deferredStructIt->memberDecorationsInfo)))
        return failure();

      deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
    } else {
      ++deferredStructIt;
    }
  }

  return success();
}
801 
802 LogicalResult
processArrayType(ArrayRef<uint32_t> operands)803 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
804   if (operands.size() != 3) {
805     return emitError(unknownLoc,
806                      "OpTypeArray must have element type and count parameters");
807   }
808 
809   Type elementTy = getType(operands[1]);
810   if (!elementTy) {
811     return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
812            << operands[1];
813   }
814 
815   unsigned count = 0;
816   // TODO: The count can also come frome a specialization constant.
817   auto countInfo = getConstant(operands[2]);
818   if (!countInfo) {
819     return emitError(unknownLoc, "OpTypeArray count <id> ")
820            << operands[2] << "can only come from normal constant right now";
821   }
822 
823   if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
824     count = intVal.getValue().getZExtValue();
825   } else {
826     return emitError(unknownLoc, "OpTypeArray count must come from a "
827                                  "scalar integer constant instruction");
828   }
829 
830   typeMap[operands[0]] = spirv::ArrayType::get(
831       elementTy, count, typeDecorations.lookup(operands[0]));
832   return success();
833 }
834 
835 LogicalResult
processFunctionType(ArrayRef<uint32_t> operands)836 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
837   assert(!operands.empty() && "No operands for processing function type");
838   if (operands.size() == 1) {
839     return emitError(unknownLoc, "missing return type for OpTypeFunction");
840   }
841   auto returnType = getType(operands[1]);
842   if (!returnType) {
843     return emitError(unknownLoc, "unknown return type in OpTypeFunction");
844   }
845   SmallVector<Type, 1> argTypes;
846   for (size_t i = 2, e = operands.size(); i < e; ++i) {
847     auto ty = getType(operands[i]);
848     if (!ty) {
849       return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
850     }
851     argTypes.push_back(ty);
852   }
853   ArrayRef<Type> returnTypes;
854   if (!isVoidType(returnType)) {
855     returnTypes = llvm::makeArrayRef(returnType);
856   }
857   typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
858   return success();
859 }
860 
861 LogicalResult
processCooperativeMatrixType(ArrayRef<uint32_t> operands)862 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
863   if (operands.size() != 5) {
864     return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
865                                  "type and row x column parameters");
866   }
867 
868   Type elementTy = getType(operands[1]);
869   if (!elementTy) {
870     return emitError(unknownLoc,
871                      "OpTypeCooperativeMatrix references undefined <id> ")
872            << operands[1];
873   }
874 
875   auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
876   if (!scope) {
877     return emitError(unknownLoc,
878                      "OpTypeCooperativeMatrix references undefined scope <id> ")
879            << operands[2];
880   }
881 
882   unsigned rows = getConstantInt(operands[3]).getInt();
883   unsigned columns = getConstantInt(operands[4]).getInt();
884 
885   typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
886       elementTy, scope.value(), rows, columns);
887   return success();
888 }
889 
890 LogicalResult
processRuntimeArrayType(ArrayRef<uint32_t> operands)891 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
892   if (operands.size() != 2) {
893     return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
894   }
895   Type memberType = getType(operands[1]);
896   if (!memberType) {
897     return emitError(unknownLoc,
898                      "OpTypeRuntimeArray references undefined <id> ")
899            << operands[1];
900   }
901   typeMap[operands[0]] = spirv::RuntimeArrayType::get(
902       memberType, typeDecorations.lookup(operands[0]));
903   return success();
904 }
905 
906 LogicalResult
processStructType(ArrayRef<uint32_t> operands)907 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
908   // TODO: Find a way to handle identified structs when debug info is stripped.
909 
910   if (operands.empty()) {
911     return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
912   }
913 
914   if (operands.size() == 1) {
915     // Handle empty struct.
916     typeMap[operands[0]] =
917         spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
918     return success();
919   }
920 
921   // First element is operand ID, second element is member index in the struct.
922   SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
923   SmallVector<Type, 4> memberTypes;
924 
925   for (auto op : llvm::drop_begin(operands, 1)) {
926     Type memberType = getType(op);
927     bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
928 
929     if (!memberType && !typeForwardPtr)
930       return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
931              << op;
932 
933     if (!memberType)
934       unresolvedMemberTypes.emplace_back(op, memberTypes.size());
935 
936     memberTypes.push_back(memberType);
937   }
938 
939   SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
940   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
941   if (memberDecorationMap.count(operands[0])) {
942     auto &allMemberDecorations = memberDecorationMap[operands[0]];
943     for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
944       if (allMemberDecorations.count(memberIndex)) {
945         for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
946           // Check for offset.
947           if (memberDecoration.first == spirv::Decoration::Offset) {
948             // If offset info is empty, resize to the number of members;
949             if (offsetInfo.empty()) {
950               offsetInfo.resize(memberTypes.size());
951             }
952             offsetInfo[memberIndex] = memberDecoration.second[0];
953           } else {
954             if (!memberDecoration.second.empty()) {
955               memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
956                                                  memberDecoration.first,
957                                                  memberDecoration.second[0]);
958             } else {
959               memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
960                                                  memberDecoration.first, 0);
961             }
962           }
963         }
964       }
965     }
966   }
967 
968   uint32_t structID = operands[0];
969   std::string structIdentifier = nameMap.lookup(structID).str();
970 
971   if (structIdentifier.empty()) {
972     assert(unresolvedMemberTypes.empty() &&
973            "didn't expect unresolved member types");
974     typeMap[structID] =
975         spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
976   } else {
977     auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
978     typeMap[structID] = structTy;
979 
980     if (!unresolvedMemberTypes.empty())
981       deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
982                                           memberTypes, offsetInfo,
983                                           memberDecorationsInfo});
984     else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
985                                         memberDecorationsInfo)))
986       return failure();
987   }
988 
989   // TODO: Update StructType to have member name as attribute as
990   // well.
991   return success();
992 }
993 
994 LogicalResult
processMatrixType(ArrayRef<uint32_t> operands)995 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
996   if (operands.size() != 3) {
997     // Three operands are needed: result_id, column_type, and column_count
998     return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
999                                  " (result_id, column_type, and column_count)");
1000   }
1001   // Matrix columns must be of vector type
1002   Type elementTy = getType(operands[1]);
1003   if (!elementTy) {
1004     return emitError(unknownLoc,
1005                      "OpTypeMatrix references undefined column type.")
1006            << operands[1];
1007   }
1008 
1009   uint32_t colsCount = operands[2];
1010   typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1011   return success();
1012 }
1013 
1014 LogicalResult
processTypeForwardPointer(ArrayRef<uint32_t> operands)1015 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1016   if (operands.size() != 2)
1017     return emitError(unknownLoc,
1018                      "OpTypeForwardPointer instruction must have two operands");
1019 
1020   typeForwardPointerIDs.insert(operands[0]);
1021   // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1022   // instruction that defines the actual type.
1023 
1024   return success();
1025 }
1026 
1027 LogicalResult
processImageType(ArrayRef<uint32_t> operands)1028 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1029   // TODO: Add support for Access Qualifier.
1030   if (operands.size() != 8)
1031     return emitError(
1032         unknownLoc,
1033         "OpTypeImage with non-eight operands are not supported yet");
1034 
1035   Type elementTy = getType(operands[1]);
1036   if (!elementTy)
1037     return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1038            << operands[1];
1039 
1040   auto dim = spirv::symbolizeDim(operands[2]);
1041   if (!dim)
1042     return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1043            << operands[2];
1044 
1045   auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1046   if (!depthInfo)
1047     return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1048            << operands[3];
1049 
1050   auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1051   if (!arrayedInfo)
1052     return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1053            << operands[4];
1054 
1055   auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1056   if (!samplingInfo)
1057     return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1058 
1059   auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1060   if (!samplerUseInfo)
1061     return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1062            << operands[6];
1063 
1064   auto format = spirv::symbolizeImageFormat(operands[7]);
1065   if (!format)
1066     return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1067            << operands[7];
1068 
1069   typeMap[operands[0]] = spirv::ImageType::get(
1070       elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1071       samplingInfo.value(), samplerUseInfo.value(), format.value());
1072   return success();
1073 }
1074 
1075 LogicalResult
processSampledImageType(ArrayRef<uint32_t> operands)1076 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1077   if (operands.size() != 2)
1078     return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1079 
1080   Type elementTy = getType(operands[1]);
1081   if (!elementTy)
1082     return emitError(unknownLoc,
1083                      "OpTypeSampledImage references undefined <id>: ")
1084            << operands[1];
1085 
1086   typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1087   return success();
1088 }
1089 
1090 //===----------------------------------------------------------------------===//
1091 // Constant
1092 //===----------------------------------------------------------------------===//
1093 
processConstant(ArrayRef<uint32_t> operands,bool isSpec)1094 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1095                                                    bool isSpec) {
1096   StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1097 
1098   if (operands.size() < 2) {
1099     return emitError(unknownLoc)
1100            << opname << " must have type <id> and result <id>";
1101   }
1102   if (operands.size() < 3) {
1103     return emitError(unknownLoc)
1104            << opname << " must have at least 1 more parameter";
1105   }
1106 
1107   Type resultType = getType(operands[0]);
1108   if (!resultType) {
1109     return emitError(unknownLoc, "undefined result type from <id> ")
1110            << operands[0];
1111   }
1112 
1113   auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1114     if (bitwidth == 64) {
1115       if (operands.size() == 4) {
1116         return success();
1117       }
1118       return emitError(unknownLoc)
1119              << opname << " should have 2 parameters for 64-bit values";
1120     }
1121     if (bitwidth <= 32) {
1122       if (operands.size() == 3) {
1123         return success();
1124       }
1125 
1126       return emitError(unknownLoc)
1127              << opname
1128              << " should have 1 parameter for values with no more than 32 bits";
1129     }
1130     return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1131            << bitwidth;
1132   };
1133 
1134   auto resultID = operands[1];
1135 
1136   if (auto intType = resultType.dyn_cast<IntegerType>()) {
1137     auto bitwidth = intType.getWidth();
1138     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1139       return failure();
1140     }
1141 
1142     APInt value;
1143     if (bitwidth == 64) {
1144       // 64-bit integers are represented with two SPIR-V words. According to
1145       // SPIR-V spec: "When the type’s bit width is larger than one word, the
1146       // literal’s low-order words appear first."
1147       struct DoubleWord {
1148         uint32_t word1;
1149         uint32_t word2;
1150       } words = {operands[2], operands[3]};
1151       value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1152     } else if (bitwidth <= 32) {
1153       value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1154     }
1155 
1156     auto attr = opBuilder.getIntegerAttr(intType, value);
1157 
1158     if (isSpec) {
1159       createSpecConstant(unknownLoc, resultID, attr);
1160     } else {
1161       // For normal constants, we just record the attribute (and its type) for
1162       // later materialization at use sites.
1163       constantMap.try_emplace(resultID, attr, intType);
1164     }
1165 
1166     return success();
1167   }
1168 
1169   if (auto floatType = resultType.dyn_cast<FloatType>()) {
1170     auto bitwidth = floatType.getWidth();
1171     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1172       return failure();
1173     }
1174 
1175     APFloat value(0.f);
1176     if (floatType.isF64()) {
1177       // Double values are represented with two SPIR-V words. According to
1178       // SPIR-V spec: "When the type’s bit width is larger than one word, the
1179       // literal’s low-order words appear first."
1180       struct DoubleWord {
1181         uint32_t word1;
1182         uint32_t word2;
1183       } words = {operands[2], operands[3]};
1184       value = APFloat(llvm::bit_cast<double>(words));
1185     } else if (floatType.isF32()) {
1186       value = APFloat(llvm::bit_cast<float>(operands[2]));
1187     } else if (floatType.isF16()) {
1188       APInt data(16, operands[2]);
1189       value = APFloat(APFloat::IEEEhalf(), data);
1190     }
1191 
1192     auto attr = opBuilder.getFloatAttr(floatType, value);
1193     if (isSpec) {
1194       createSpecConstant(unknownLoc, resultID, attr);
1195     } else {
1196       // For normal constants, we just record the attribute (and its type) for
1197       // later materialization at use sites.
1198       constantMap.try_emplace(resultID, attr, floatType);
1199     }
1200 
1201     return success();
1202   }
1203 
1204   return emitError(unknownLoc, "OpConstant can only generate values of "
1205                                "scalar integer or floating-point type");
1206 }
1207 
processConstantBool(bool isTrue,ArrayRef<uint32_t> operands,bool isSpec)1208 LogicalResult spirv::Deserializer::processConstantBool(
1209     bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1210   if (operands.size() != 2) {
1211     return emitError(unknownLoc, "Op")
1212            << (isSpec ? "Spec" : "") << "Constant"
1213            << (isTrue ? "True" : "False")
1214            << " must have type <id> and result <id>";
1215   }
1216 
1217   auto attr = opBuilder.getBoolAttr(isTrue);
1218   auto resultID = operands[1];
1219   if (isSpec) {
1220     createSpecConstant(unknownLoc, resultID, attr);
1221   } else {
1222     // For normal constants, we just record the attribute (and its type) for
1223     // later materialization at use sites.
1224     constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1225   }
1226 
1227   return success();
1228 }
1229 
1230 LogicalResult
processConstantComposite(ArrayRef<uint32_t> operands)1231 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1232   if (operands.size() < 2) {
1233     return emitError(unknownLoc,
1234                      "OpConstantComposite must have type <id> and result <id>");
1235   }
1236   if (operands.size() < 3) {
1237     return emitError(unknownLoc,
1238                      "OpConstantComposite must have at least 1 parameter");
1239   }
1240 
1241   Type resultType = getType(operands[0]);
1242   if (!resultType) {
1243     return emitError(unknownLoc, "undefined result type from <id> ")
1244            << operands[0];
1245   }
1246 
1247   SmallVector<Attribute, 4> elements;
1248   elements.reserve(operands.size() - 2);
1249   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1250     auto elementInfo = getConstant(operands[i]);
1251     if (!elementInfo) {
1252       return emitError(unknownLoc, "OpConstantComposite component <id> ")
1253              << operands[i] << " must come from a normal constant";
1254     }
1255     elements.push_back(elementInfo->first);
1256   }
1257 
1258   auto resultID = operands[1];
1259   if (auto vectorType = resultType.dyn_cast<VectorType>()) {
1260     auto attr = DenseElementsAttr::get(vectorType, elements);
1261     // For normal constants, we just record the attribute (and its type) for
1262     // later materialization at use sites.
1263     constantMap.try_emplace(resultID, attr, resultType);
1264   } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
1265     auto attr = opBuilder.getArrayAttr(elements);
1266     constantMap.try_emplace(resultID, attr, resultType);
1267   } else {
1268     return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1269            << resultType;
1270   }
1271 
1272   return success();
1273 }
1274 
1275 LogicalResult
processSpecConstantComposite(ArrayRef<uint32_t> operands)1276 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1277   if (operands.size() < 2) {
1278     return emitError(unknownLoc,
1279                      "OpConstantComposite must have type <id> and result <id>");
1280   }
1281   if (operands.size() < 3) {
1282     return emitError(unknownLoc,
1283                      "OpConstantComposite must have at least 1 parameter");
1284   }
1285 
1286   Type resultType = getType(operands[0]);
1287   if (!resultType) {
1288     return emitError(unknownLoc, "undefined result type from <id> ")
1289            << operands[0];
1290   }
1291 
1292   auto resultID = operands[1];
1293   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1294 
1295   SmallVector<Attribute, 4> elements;
1296   elements.reserve(operands.size() - 2);
1297   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1298     auto elementInfo = getSpecConstant(operands[i]);
1299     elements.push_back(SymbolRefAttr::get(elementInfo));
1300   }
1301 
1302   auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1303       unknownLoc, TypeAttr::get(resultType), symName,
1304       opBuilder.getArrayAttr(elements));
1305   specConstCompositeMap[resultID] = op;
1306 
1307   return success();
1308 }
1309 
1310 LogicalResult
processSpecConstantOperation(ArrayRef<uint32_t> operands)1311 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1312   if (operands.size() < 3)
1313     return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1314                                  "result <id>, and operand opcode");
1315 
1316   uint32_t resultTypeID = operands[0];
1317 
1318   if (!getType(resultTypeID))
1319     return emitError(unknownLoc, "undefined result type from <id> ")
1320            << resultTypeID;
1321 
1322   uint32_t resultID = operands[1];
1323   spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1324   auto emplaceResult = specConstOperationMap.try_emplace(
1325       resultID,
1326       SpecConstOperationMaterializationInfo{
1327           enclosedOpcode, resultTypeID,
1328           SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1329 
1330   if (!emplaceResult.second)
1331     return emitError(unknownLoc, "value with <id>: ")
1332            << resultID << " is probably defined before.";
1333 
1334   return success();
1335 }
1336 
/// Materializes a recorded OpSpecConstantOp as a
/// spirv::SpecConstantOperationOp created through `opBuilder`.
///
/// The enclosed instruction (`enclosedOpcode` applied to `enclosedOpOperands`)
/// is deserialized with processInstruction, using `resultTypeID` as its result
/// type and a fake result <id>. The last op in `curBlock` is then split into a
/// new block. That block is moved into the SpecConstantOperationOp's body and
/// terminated with a spirv::YieldOp of the block's first op's first result.
///
/// Returns the SpecConstantOperationOp's result, or a null Value if
/// processInstruction fails.
///
/// NOTE(review): `resultID` is not referenced anywhere in this body.
Value spirv::Deserializer::materializeSpecConstantOperation(
    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
    ArrayRef<uint32_t> enclosedOpOperands) {

  Type resultType = getType(resultTypeID);

  // Instructions wrapped by OpSpecConstantOp need an ID for their
  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
  // dialect wrapped op. For that purpose, a new value map is created and "fake"
  // ID in that map is assigned to the result of the enclosed instruction. Note
  // that there is no need to update this fake ID since we only need to
  // reference the created Value for the enclosed op from the spv::YieldOp
  // created later in this method (both of which are the only values in their
  // region: the SpecConstantOperation's region). If we encounter another
  // SpecConstantOperation in the module, we simply re-use the fake ID since the
  // previous Value assigned to it isn't visible in the current scope anyway.
  DenseMap<uint32_t, Value> newValueMap;
  // `valueMap` is replaced by the empty map for the rest of this call; the
  // guard restores the previous map when the function returns.
  llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
                                                                newValueMap);
  // NOTE(review): presumably chosen so it cannot collide with a real <id> in
  // the module; confirm.
  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);

  // Prepend the result type <id> and the fake result <id> to the enclosed
  // instruction's own operands.
  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
  enclosedOpResultTypeAndOperands.push_back(fakeID);
  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
                                         enclosedOpOperands.end());

  // Process enclosed instruction before creating the enclosing
  // specConstantOperation (and its region). This way, references to constants,
  // global variables, and spec constants will be materialized outside the new
  // op's region. For more info, see Deserializer::getValue's implementation.
  if (failed(
          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
    return Value();

  // Since the enclosed op is emitted in the current block, split it in a
  // separate new block.
  // NOTE(review): this assumes the enclosed op is the last op in `curBlock`
  // and that `curBlock` is non-empty at this point.
  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());

  auto loc = createFileLineColLoc(opBuilder);
  auto specConstOperationOp =
      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);

  Region &body = specConstOperationOp.body();
  // Move the new block into SpecConstantOperation's body.
  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
                          Region::iterator(enclosedBlock));
  Block &block = body.back();

  // RAII guard to reset the insertion point to the module's region after
  // deserializing the body of the specConstantOperation.
  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
  opBuilder.setInsertionPointToEnd(&block);

  // NOTE(review): assumes the enclosed op has at least one result.
  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
  return specConstOperationOp.getResult();
}
1394 
1395 LogicalResult
processConstantNull(ArrayRef<uint32_t> operands)1396 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1397   if (operands.size() != 2) {
1398     return emitError(unknownLoc,
1399                      "OpConstantNull must have type <id> and result <id>");
1400   }
1401 
1402   Type resultType = getType(operands[0]);
1403   if (!resultType) {
1404     return emitError(unknownLoc, "undefined result type from <id> ")
1405            << operands[0];
1406   }
1407 
1408   auto resultID = operands[1];
1409   if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
1410     auto attr = opBuilder.getZeroAttr(resultType);
1411     // For normal constants, we just record the attribute (and its type) for
1412     // later materialization at use sites.
1413     constantMap.try_emplace(resultID, attr, resultType);
1414     return success();
1415   }
1416 
1417   return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1418          << resultType;
1419 }
1420 
1421 //===----------------------------------------------------------------------===//
1422 // Control flow
1423 //===----------------------------------------------------------------------===//
1424 
getOrCreateBlock(uint32_t id)1425 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1426   if (auto *block = getBlock(id)) {
1427     LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1428                                   << " @ " << block << "\n");
1429     return block;
1430   }
1431 
1432   // We don't know where this block will be placed finally (in a
1433   // spv.mlir.selection or spv.mlir.loop or function). Create it into the
1434   // function for now and sort out the proper place later.
1435   auto *block = curFunction->addBlock();
1436   LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1437                                 << " @ " << block << "\n");
1438   return blockMap[id] = block;
1439 }
1440 
processBranch(ArrayRef<uint32_t> operands)1441 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1442   if (!curBlock) {
1443     return emitError(unknownLoc, "OpBranch must appear inside a block");
1444   }
1445 
1446   if (operands.size() != 1) {
1447     return emitError(unknownLoc, "OpBranch must take exactly one target label");
1448   }
1449 
1450   auto *target = getOrCreateBlock(operands[0]);
1451   auto loc = createFileLineColLoc(opBuilder);
1452   // The preceding instruction for the OpBranch instruction could be an
1453   // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1454   // the same OpLine information.
1455   opBuilder.create<spirv::BranchOp>(loc, target);
1456 
1457   clearDebugLine();
1458   return success();
1459 }
1460 
1461 LogicalResult
processBranchConditional(ArrayRef<uint32_t> operands)1462 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1463   if (!curBlock) {
1464     return emitError(unknownLoc,
1465                      "OpBranchConditional must appear inside a block");
1466   }
1467 
1468   if (operands.size() != 3 && operands.size() != 5) {
1469     return emitError(unknownLoc,
1470                      "OpBranchConditional must have condition, true label, "
1471                      "false label, and optionally two branch weights");
1472   }
1473 
1474   auto condition = getValue(operands[0]);
1475   auto *trueBlock = getOrCreateBlock(operands[1]);
1476   auto *falseBlock = getOrCreateBlock(operands[2]);
1477 
1478   Optional<std::pair<uint32_t, uint32_t>> weights;
1479   if (operands.size() == 5) {
1480     weights = std::make_pair(operands[3], operands[4]);
1481   }
1482   // The preceding instruction for the OpBranchConditional instruction could be
1483   // an OpSelectionMerge instruction, in this case they will have the same
1484   // OpLine information.
1485   auto loc = createFileLineColLoc(opBuilder);
1486   opBuilder.create<spirv::BranchConditionalOp>(
1487       loc, condition, trueBlock,
1488       /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1489       /*falseArguments=*/ArrayRef<Value>(), weights);
1490 
1491   clearDebugLine();
1492   return success();
1493 }
1494 
processLabel(ArrayRef<uint32_t> operands)1495 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1496   if (!curFunction) {
1497     return emitError(unknownLoc, "OpLabel must appear inside a function");
1498   }
1499 
1500   if (operands.size() != 1) {
1501     return emitError(unknownLoc, "OpLabel should only have result <id>");
1502   }
1503 
1504   auto labelID = operands[0];
1505   // We may have forward declared this block.
1506   auto *block = getOrCreateBlock(labelID);
1507   LLVM_DEBUG(logger.startLine()
1508              << "[block] populating block " << block << "\n");
1509   // If we have seen this block, make sure it was just a forward declaration.
1510   assert(block->empty() && "re-deserialize the same block!");
1511 
1512   opBuilder.setInsertionPointToStart(block);
1513   blockMap[labelID] = curBlock = block;
1514 
1515   return success();
1516 }
1517 
1518 LogicalResult
processSelectionMerge(ArrayRef<uint32_t> operands)1519 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1520   if (!curBlock) {
1521     return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1522   }
1523 
1524   if (operands.size() < 2) {
1525     return emitError(
1526         unknownLoc,
1527         "OpSelectionMerge must specify merge target and selection control");
1528   }
1529 
1530   auto *mergeBlock = getOrCreateBlock(operands[0]);
1531   auto loc = createFileLineColLoc(opBuilder);
1532   auto selectionControl = operands[1];
1533 
1534   if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1535            .second) {
1536     return emitError(
1537         unknownLoc,
1538         "a block cannot have more than one OpSelectionMerge instruction");
1539   }
1540 
1541   return success();
1542 }
1543 
1544 LogicalResult
processLoopMerge(ArrayRef<uint32_t> operands)1545 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1546   if (!curBlock) {
1547     return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1548   }
1549 
1550   if (operands.size() < 3) {
1551     return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1552                                  "continue target and loop control");
1553   }
1554 
1555   auto *mergeBlock = getOrCreateBlock(operands[0]);
1556   auto *continueBlock = getOrCreateBlock(operands[1]);
1557   auto loc = createFileLineColLoc(opBuilder);
1558   uint32_t loopControl = operands[2];
1559 
1560   if (!blockMergeInfo
1561            .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1562            .second) {
1563     return emitError(
1564         unknownLoc,
1565         "a block cannot have more than one OpLoopMerge instruction");
1566   }
1567 
1568   return success();
1569 }
1570 
processPhi(ArrayRef<uint32_t> operands)1571 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1572   if (!curBlock) {
1573     return emitError(unknownLoc, "OpPhi must appear in a block");
1574   }
1575 
1576   if (operands.size() < 4) {
1577     return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1578                                  "and variable-parent pairs");
1579   }
1580 
1581   // Create a block argument for this OpPhi instruction.
1582   Type blockArgType = getType(operands[0]);
1583   BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1584   valueMap[operands[1]] = blockArg;
1585   LLVM_DEBUG(logger.startLine()
1586              << "[phi] created block argument " << blockArg
1587              << " id = " << operands[1] << " of type " << blockArgType << "\n");
1588 
1589   // For each (value, predecessor) pair, insert the value to the predecessor's
1590   // blockPhiInfo entry so later we can fix the block argument there.
1591   for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1592     uint32_t value = operands[i];
1593     Block *predecessor = getOrCreateBlock(operands[i + 1]);
1594     std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1595     blockPhiInfo[predecessorTargetPair].push_back(value);
1596     LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1597                                   << " with arg id = " << value << "\n");
1598   }
1599 
1600   return success();
1601 }
1602 
1603 namespace {
1604 /// A class for putting all blocks in a structured selection/loop in a
1605 /// spv.mlir.selection/spv.mlir.loop op.
1606 class ControlFlowStructurizer {
1607 public:
1608 #ifndef NDEBUG
ControlFlowStructurizer(Location loc,uint32_t control,spirv::BlockMergeInfoMap & mergeInfo,Block * header,Block * merge,Block * cont,llvm::ScopedPrinter & logger)1609   ControlFlowStructurizer(Location loc, uint32_t control,
1610                           spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1611                           Block *merge, Block *cont,
1612                           llvm::ScopedPrinter &logger)
1613       : location(loc), control(control), blockMergeInfo(mergeInfo),
1614         headerBlock(header), mergeBlock(merge), continueBlock(cont),
1615         logger(logger) {}
1616 #else
1617   ControlFlowStructurizer(Location loc, uint32_t control,
1618                           spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1619                           Block *merge, Block *cont)
1620       : location(loc), control(control), blockMergeInfo(mergeInfo),
1621         headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1622 #endif
1623 
1624   /// Structurizes the loop at the given `headerBlock`.
1625   ///
1626   /// This method will create an spv.mlir.loop op in the `mergeBlock` and move
1627   /// all blocks in the structured loop into the spv.mlir.loop's region. All
1628   /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1629   /// method will also update `mergeInfo` by remapping all blocks inside to the
1630   /// newly cloned ones inside structured control flow op's regions.
1631   LogicalResult structurize();
1632 
1633 private:
1634   /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`.
1635   spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1636 
1637   /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`.
1638   spirv::LoopOp createLoopOp(uint32_t loopControl);
1639 
1640   /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1641   void collectBlocksInConstruct();
1642 
1643   Location location;
1644   uint32_t control;
1645 
1646   spirv::BlockMergeInfoMap &blockMergeInfo;
1647 
1648   Block *headerBlock;
1649   Block *mergeBlock;
1650   Block *continueBlock; // nullptr for spv.mlir.selection
1651 
1652   SetVector<Block *> constructBlocks;
1653 
1654 #ifndef NDEBUG
1655   /// A logger used to emit information during the deserialzation process.
1656   llvm::ScopedPrinter &logger;
1657 #endif
1658 };
1659 } // namespace
1660 
1661 spirv::SelectionOp
createSelectionOp(uint32_t selectionControl)1662 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1663   // Create a builder and set the insertion point to the beginning of the
1664   // merge block so that the newly created SelectionOp will be inserted there.
1665   OpBuilder builder(&mergeBlock->front());
1666 
1667   auto control = static_cast<spirv::SelectionControl>(selectionControl);
1668   auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1669   selectionOp.addMergeBlock();
1670 
1671   return selectionOp;
1672 }
1673 
createLoopOp(uint32_t loopControl)1674 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1675   // Create a builder and set the insertion point to the beginning of the
1676   // merge block so that the newly created LoopOp will be inserted there.
1677   OpBuilder builder(&mergeBlock->front());
1678 
1679   auto control = static_cast<spirv::LoopControl>(loopControl);
1680   auto loopOp = builder.create<spirv::LoopOp>(location, control);
1681   loopOp.addEntryAndMergeBlock();
1682 
1683   return loopOp;
1684 }
1685 
collectBlocksInConstruct()1686 void ControlFlowStructurizer::collectBlocksInConstruct() {
1687   assert(constructBlocks.empty() && "expected empty constructBlocks");
1688 
1689   // Put the header block in the work list first.
1690   constructBlocks.insert(headerBlock);
1691 
1692   // For each item in the work list, add its successors excluding the merge
1693   // block.
1694   for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1695     for (auto *successor : constructBlocks[i]->getSuccessors())
1696       if (successor != mergeBlock)
1697         constructBlocks.insert(successor);
1698   }
1699 }
1700 
structurize()1701 LogicalResult ControlFlowStructurizer::structurize() {
1702   Operation *op = nullptr;
1703   bool isLoop = continueBlock != nullptr;
1704   if (isLoop) {
1705     if (auto loopOp = createLoopOp(control))
1706       op = loopOp.getOperation();
1707   } else {
1708     if (auto selectionOp = createSelectionOp(control))
1709       op = selectionOp.getOperation();
1710   }
1711   if (!op)
1712     return failure();
1713   Region &body = op->getRegion(0);
1714 
1715   BlockAndValueMapping mapper;
1716   // All references to the old merge block should be directed to the
1717   // selection/loop merge block in the SelectionOp/LoopOp's region.
1718   mapper.map(mergeBlock, &body.back());
1719 
1720   collectBlocksInConstruct();
1721 
1722   // We've identified all blocks belonging to the selection/loop's region. Now
1723   // need to "move" them into the selection/loop. Instead of really moving the
1724   // blocks, in the following we copy them and remap all values and branches.
1725   // This is because:
1726   // * Inserting a block into a region requires the block not in any region
1727   //   before. But selections/loops can nest so we can create selection/loop ops
1728   //   in a nested manner, which means some blocks may already be in a
1729   //   selection/loop region when to be moved again.
1730   // * It's much trickier to fix up the branches into and out of the loop's
1731   //   region: we need to treat not-moved blocks and moved blocks differently:
1732   //   Not-moved blocks jumping to the loop header block need to jump to the
1733   //   merge point containing the new loop op but not the loop continue block's
1734   //   back edge. Moved blocks jumping out of the loop need to jump to the
1735   //   merge block inside the loop region but not other not-moved blocks.
1736   //   We cannot use replaceAllUsesWith clearly and it's harder to follow the
1737   //   logic.
1738 
1739   // Create a corresponding block in the SelectionOp/LoopOp's region for each
1740   // block in this loop construct.
1741   OpBuilder builder(body);
1742   for (auto *block : constructBlocks) {
1743     // Create a block and insert it before the selection/loop merge block in the
1744     // SelectionOp/LoopOp's region.
1745     auto *newBlock = builder.createBlock(&body.back());
1746     mapper.map(block, newBlock);
1747     LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1748                                   << " from block " << block << "\n");
1749     if (!isFnEntryBlock(block)) {
1750       for (BlockArgument blockArg : block->getArguments()) {
1751         auto newArg =
1752             newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1753         mapper.map(blockArg, newArg);
1754         LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1755                                       << blockArg << " to " << newArg << "\n");
1756       }
1757     } else {
1758       LLVM_DEBUG(logger.startLine()
1759                  << "[cf] block " << block << " is a function entry block\n");
1760     }
1761 
1762     for (auto &op : *block)
1763       newBlock->push_back(op.clone(mapper));
1764   }
1765 
1766   // Go through all ops and remap the operands.
1767   auto remapOperands = [&](Operation *op) {
1768     for (auto &operand : op->getOpOperands())
1769       if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1770         operand.set(mappedOp);
1771     for (auto &succOp : op->getBlockOperands())
1772       if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1773         succOp.set(mappedOp);
1774   };
1775   for (auto &block : body)
1776     block.walk(remapOperands);
1777 
1778   // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1779   // the selection/loop construct into its region. Next we need to fix the
1780   // connections between this new SelectionOp/LoopOp with existing blocks.
1781 
1782   // All existing incoming branches should go to the merge block, where the
1783   // SelectionOp/LoopOp resides right now.
1784   headerBlock->replaceAllUsesWith(mergeBlock);
1785 
1786   LLVM_DEBUG({
1787     logger.startLine() << "[cf] after cloning and fixing references:\n";
1788     headerBlock->getParentOp()->print(logger.getOStream());
1789     logger.startLine() << "\n";
1790   });
1791 
1792   if (isLoop) {
1793     if (!mergeBlock->args_empty()) {
1794       return mergeBlock->getParentOp()->emitError(
1795           "OpPhi in loop merge block unsupported");
1796     }
1797 
1798     // The loop header block may have block arguments. Since now we place the
1799     // loop op inside the old merge block, we need to make sure the old merge
1800     // block has the same block argument list.
1801     for (BlockArgument blockArg : headerBlock->getArguments())
1802       mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1803 
1804     // If the loop header block has block arguments, make sure the spv.Branch op
1805     // matches.
1806     SmallVector<Value, 4> blockArgs;
1807     if (!headerBlock->args_empty())
1808       blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1809 
1810     // The loop entry block should have a unconditional branch jumping to the
1811     // loop header block.
1812     builder.setInsertionPointToEnd(&body.front());
1813     builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1814                                     ArrayRef<Value>(blockArgs));
1815   }
1816 
1817   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1818   // cleaned up.
1819   LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1820   // First we need to drop all operands' references inside all blocks. This is
1821   // needed because we can have blocks referencing SSA values from one another.
1822   for (auto *block : constructBlocks)
1823     block->dropAllReferences();
1824 
1825   // Check that whether some op in the to-be-erased blocks still has uses. Those
1826   // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1827   // region. We cannot handle such cases given that once a value is sinked into
1828   // the SelectionOp/LoopOp's region, there is no escape for it:
1829   // SelectionOp/LooOp does not support yield values right now.
1830   for (auto *block : constructBlocks) {
1831     for (Operation &op : *block)
1832       if (!op.use_empty())
1833         return op.emitOpError(
1834             "failed control flow structurization: it has uses outside of the "
1835             "enclosing selection/loop construct");
1836   }
1837 
1838   // Then erase all old blocks.
1839   for (auto *block : constructBlocks) {
1840     // We've cloned all blocks belonging to this construct into the structured
1841     // control flow op's region. Among these blocks, some may compose another
1842     // selection/loop. If so, they will be recorded within blockMergeInfo.
1843     // We need to update the pointers there to the newly remapped ones so we can
1844     // continue structurizing them later.
1845     // TODO: The asserts in the following assumes input SPIR-V blob forms
1846     // correctly nested selection/loop constructs. We should relax this and
1847     // support error cases better.
1848     auto it = blockMergeInfo.find(block);
1849     if (it != blockMergeInfo.end()) {
1850       // Use the original location for nested selection/loop ops.
1851       Location loc = it->second.loc;
1852 
1853       Block *newHeader = mapper.lookupOrNull(block);
1854       if (!newHeader)
1855         return emitError(loc, "failed control flow structurization: nested "
1856                               "loop header block should be remapped!");
1857 
1858       Block *newContinue = it->second.continueBlock;
1859       if (newContinue) {
1860         newContinue = mapper.lookupOrNull(newContinue);
1861         if (!newContinue)
1862           return emitError(loc, "failed control flow structurization: nested "
1863                                 "loop continue block should be remapped!");
1864       }
1865 
1866       Block *newMerge = it->second.mergeBlock;
1867       if (Block *mappedTo = mapper.lookupOrNull(newMerge))
1868         newMerge = mappedTo;
1869 
1870       // The iterator should be erased before adding a new entry into
1871       // blockMergeInfo to avoid iterator invalidation.
1872       blockMergeInfo.erase(it);
1873       blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
1874                                  newContinue);
1875     }
1876 
1877     // The structured selection/loop's entry block does not have arguments.
1878     // If the function's header block is also part of the structured control
1879     // flow, we cannot just simply erase it because it may contain arguments
1880     // matching the function signature and used by the cloned blocks.
1881     if (isFnEntryBlock(block)) {
1882       LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
1883                                     << " to only contain a spv.Branch op\n");
1884       // Still keep the function entry block for the potential block arguments,
1885       // but replace all ops inside with a branch to the merge block.
1886       block->clear();
1887       builder.setInsertionPointToEnd(block);
1888       builder.create<spirv::BranchOp>(location, mergeBlock);
1889     } else {
1890       LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
1891       block->erase();
1892     }
1893   }
1894 
1895   LLVM_DEBUG(logger.startLine()
1896              << "[cf] after structurizing construct with header block "
1897              << headerBlock << ":\n"
1898              << *op << "\n");
1899 
1900   return success();
1901 }
1902 
wireUpBlockArgument()1903 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
1904   LLVM_DEBUG({
1905     logger.startLine()
1906         << "//----- [phi] start wiring up block arguments -----//\n";
1907     logger.indent();
1908   });
1909 
1910   OpBuilder::InsertionGuard guard(opBuilder);
1911 
1912   for (const auto &info : blockPhiInfo) {
1913     Block *block = info.first.first;
1914     Block *target = info.first.second;
1915     const BlockPhiInfo &phiInfo = info.second;
1916     LLVM_DEBUG({
1917       logger.startLine() << "[phi] block " << block << "\n";
1918       logger.startLine() << "[phi] before creating block argument:\n";
1919       block->getParentOp()->print(logger.getOStream());
1920       logger.startLine() << "\n";
1921     });
1922 
1923     // Set insertion point to before this block's terminator early because we
1924     // may materialize ops via getValue() call.
1925     auto *op = block->getTerminator();
1926     opBuilder.setInsertionPoint(op);
1927 
1928     SmallVector<Value, 4> blockArgs;
1929     blockArgs.reserve(phiInfo.size());
1930     for (uint32_t valueId : phiInfo) {
1931       if (Value value = getValue(valueId)) {
1932         blockArgs.push_back(value);
1933         LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
1934                                       << " id = " << valueId << "\n");
1935       } else {
1936         return emitError(unknownLoc, "OpPhi references undefined value!");
1937       }
1938     }
1939 
1940     if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
1941       // Replace the previous branch op with a new one with block arguments.
1942       opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
1943                                         blockArgs);
1944       branchOp.erase();
1945     } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
1946       assert((branchCondOp.getTrueBlock() == target ||
1947               branchCondOp.getFalseBlock() == target) &&
1948              "expected target to be either the true or false target");
1949       if (target == branchCondOp.trueTarget())
1950         opBuilder.create<spirv::BranchConditionalOp>(
1951             branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
1952             branchCondOp.getFalseBlockArguments(),
1953             branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
1954             branchCondOp.falseTarget());
1955       else
1956         opBuilder.create<spirv::BranchConditionalOp>(
1957             branchCondOp.getLoc(), branchCondOp.condition(),
1958             branchCondOp.getTrueBlockArguments(), blockArgs,
1959             branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
1960             branchCondOp.getFalseBlock());
1961 
1962       branchCondOp.erase();
1963     } else {
1964       return emitError(unknownLoc, "unimplemented terminator for Phi creation");
1965     }
1966 
1967     LLVM_DEBUG({
1968       logger.startLine() << "[phi] after creating block argument:\n";
1969       block->getParentOp()->print(logger.getOStream());
1970       logger.startLine() << "\n";
1971     });
1972   }
1973   blockPhiInfo.clear();
1974 
1975   LLVM_DEBUG({
1976     logger.unindent();
1977     logger.startLine()
1978         << "//--- [phi] completed wiring up block arguments ---//\n";
1979   });
1980   return success();
1981 }
1982 
structurizeControlFlow()1983 LogicalResult spirv::Deserializer::structurizeControlFlow() {
1984   LLVM_DEBUG({
1985     logger.startLine()
1986         << "//----- [cf] start structurizing control flow -----//\n";
1987     logger.indent();
1988   });
1989 
1990   while (!blockMergeInfo.empty()) {
1991     Block *headerBlock = blockMergeInfo.begin()->first;
1992     BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
1993 
1994     LLVM_DEBUG({
1995       logger.startLine() << "[cf] header block " << headerBlock << ":\n";
1996       headerBlock->print(logger.getOStream());
1997       logger.startLine() << "\n";
1998     });
1999 
2000     auto *mergeBlock = mergeInfo.mergeBlock;
2001     assert(mergeBlock && "merge block cannot be nullptr");
2002     if (!mergeBlock->args_empty())
2003       return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2004     LLVM_DEBUG({
2005       logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2006       mergeBlock->print(logger.getOStream());
2007       logger.startLine() << "\n";
2008     });
2009 
2010     auto *continueBlock = mergeInfo.continueBlock;
2011     LLVM_DEBUG(if (continueBlock) {
2012       logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2013       continueBlock->print(logger.getOStream());
2014       logger.startLine() << "\n";
2015     });
2016     // Erase this case before calling into structurizer, who will update
2017     // blockMergeInfo.
2018     blockMergeInfo.erase(blockMergeInfo.begin());
2019     ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2020                                          blockMergeInfo, headerBlock,
2021                                          mergeBlock, continueBlock
2022 #ifndef NDEBUG
2023                                          ,
2024                                          logger
2025 #endif
2026     );
2027     if (failed(structurizer.structurize()))
2028       return failure();
2029   }
2030 
2031   LLVM_DEBUG({
2032     logger.unindent();
2033     logger.startLine()
2034         << "//--- [cf] completed structurizing control flow ---//\n";
2035   });
2036   return success();
2037 }
2038 
2039 //===----------------------------------------------------------------------===//
2040 // Debug
2041 //===----------------------------------------------------------------------===//
2042 
createFileLineColLoc(OpBuilder opBuilder)2043 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2044   if (!debugLine)
2045     return unknownLoc;
2046 
2047   auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2048   if (fileName.empty())
2049     fileName = "<unknown>";
2050   return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2051                              debugLine->column);
2052 }
2053 
2054 LogicalResult
processDebugLine(ArrayRef<uint32_t> operands)2055 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2056   // According to SPIR-V spec:
2057   // "This location information applies to the instructions physically
2058   // following this instruction, up to the first occurrence of any of the
2059   // following: the next end of block, the next OpLine instruction, or the next
2060   // OpNoLine instruction."
2061   if (operands.size() != 3)
2062     return emitError(unknownLoc, "OpLine must have 3 operands");
2063   debugLine = DebugLine{operands[0], operands[1], operands[2]};
2064   return success();
2065 }
2066 
clearDebugLine()2067 void spirv::Deserializer::clearDebugLine() { debugLine = llvm::None; }
2068 
2069 LogicalResult
processDebugString(ArrayRef<uint32_t> operands)2070 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2071   if (operands.size() < 2)
2072     return emitError(unknownLoc, "OpString needs at least 2 operands");
2073 
2074   if (!debugInfoMap.lookup(operands[0]).empty())
2075     return emitError(unknownLoc,
2076                      "duplicate debug string found for result <id> ")
2077            << operands[0];
2078 
2079   unsigned wordIndex = 1;
2080   StringRef debugString = decodeStringLiteral(operands, wordIndex);
2081   if (wordIndex != operands.size())
2082     return emitError(unknownLoc,
2083                      "unexpected trailing words in OpString instruction");
2084 
2085   debugInfoMap[operands[0]] = debugString;
2086   return success();
2087 }
2088