1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43   return block->isEntryBlock() &&
44          isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
/// Constructs a deserializer over `binary`, the SPIR-V module as a sequence
/// of 32-bit words. An empty spv.module is created eagerly (via
/// createModuleOp()) and `opBuilder` is positioned inside its region.
// NOTE(review): createModuleOp() reads `context` and `unknownLoc`, so this
// initializer list relies on those members being declared before `module`
// in Deserializer.h (members initialize in declaration order) — confirm
// against the header.
spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
                                  MLIRContext *context)
    : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
      module(createModuleOp()), opBuilder(module->getRegion()) {}
55 
56 LogicalResult spirv::Deserializer::deserialize() {
57   LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");
58 
59   if (failed(processHeader()))
60     return failure();
61 
62   spirv::Opcode opcode = spirv::Opcode::OpNop;
63   ArrayRef<uint32_t> operands;
64   auto binarySize = binary.size();
65   while (curOffset < binarySize) {
66     // Slice the next instruction out and populate `opcode` and `operands`.
67     // Internally this also updates `curOffset`.
68     if (failed(sliceInstruction(opcode, operands)))
69       return failure();
70 
71     if (failed(processInstruction(opcode, operands)))
72       return failure();
73   }
74 
75   assert(curOffset == binarySize &&
76          "deserializer should never index beyond the binary end");
77 
78   for (auto &deferred : deferredInstructions) {
79     if (failed(processInstruction(deferred.first, deferred.second, false))) {
80       return failure();
81     }
82   }
83 
84   attachVCETriple();
85 
86   LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n");
87   return success();
88 }
89 
90 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
91   return std::move(module);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Module structure
96 //===----------------------------------------------------------------------===//
97 
98 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
99   OpBuilder builder(context);
100   OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
101   spirv::ModuleOp::build(builder, state);
102   return cast<spirv::ModuleOp>(Operation::create(state));
103 }
104 
105 LogicalResult spirv::Deserializer::processHeader() {
106   if (binary.size() < spirv::kHeaderWordCount)
107     return emitError(unknownLoc,
108                      "SPIR-V binary module must have a 5-word header");
109 
110   if (binary[0] != spirv::kMagicNumber)
111     return emitError(unknownLoc, "incorrect magic number");
112 
113   // Version number bytes: 0 | major number | minor number | 0
114   uint32_t majorVersion = (binary[1] << 8) >> 24;
115   uint32_t minorVersion = (binary[1] << 16) >> 24;
116   if (majorVersion == 1) {
117     switch (minorVersion) {
118 #define MIN_VERSION_CASE(v)                                                    \
119   case v:                                                                      \
120     version = spirv::Version::V_1_##v;                                         \
121     break
122 
123       MIN_VERSION_CASE(0);
124       MIN_VERSION_CASE(1);
125       MIN_VERSION_CASE(2);
126       MIN_VERSION_CASE(3);
127       MIN_VERSION_CASE(4);
128       MIN_VERSION_CASE(5);
129 #undef MIN_VERSION_CASE
130     default:
131       return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
132              << minorVersion;
133     }
134   } else {
135     return emitError(unknownLoc, "unsupported SPIR-V major version: ")
136            << majorVersion;
137   }
138 
139   // TODO: generator number, bound, schema
140   curOffset = spirv::kHeaderWordCount;
141   return success();
142 }
143 
144 LogicalResult
145 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
146   if (operands.size() != 1)
147     return emitError(unknownLoc, "OpMemoryModel must have one parameter");
148 
149   auto cap = spirv::symbolizeCapability(operands[0]);
150   if (!cap)
151     return emitError(unknownLoc, "unknown capability: ") << operands[0];
152 
153   capabilities.insert(*cap);
154   return success();
155 }
156 
157 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
158   if (words.empty()) {
159     return emitError(
160         unknownLoc,
161         "OpExtension must have a literal string for the extension name");
162   }
163 
164   unsigned wordIndex = 0;
165   StringRef extName = decodeStringLiteral(words, wordIndex);
166   if (wordIndex != words.size())
167     return emitError(unknownLoc,
168                      "unexpected trailing words in OpExtension instruction");
169   auto ext = spirv::symbolizeExtension(extName);
170   if (!ext)
171     return emitError(unknownLoc, "unknown extension: ") << extName;
172 
173   extensions.insert(*ext);
174   return success();
175 }
176 
177 LogicalResult
178 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
179   if (words.size() < 2) {
180     return emitError(unknownLoc,
181                      "OpExtInstImport must have a result <id> and a literal "
182                      "string for the extended instruction set name");
183   }
184 
185   unsigned wordIndex = 1;
186   extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
187   if (wordIndex != words.size()) {
188     return emitError(unknownLoc,
189                      "unexpected trailing words in OpExtInstImport");
190   }
191   return success();
192 }
193 
194 void spirv::Deserializer::attachVCETriple() {
195   (*module)->setAttr(
196       spirv::ModuleOp::getVCETripleAttrName(),
197       spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
198                                 extensions.getArrayRef(), context));
199 }
200 
201 LogicalResult
202 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
203   if (operands.size() != 2)
204     return emitError(unknownLoc, "OpMemoryModel must have two operands");
205 
206   (*module)->setAttr(
207       "addressing_model",
208       opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
209   (*module)->setAttr(
210       "memory_model",
211       opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
212 
213   return success();
214 }
215 
216 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
217   // TODO: This function should also be auto-generated. For now, since only a
218   // few decorations are processed/handled in a meaningful manner, going with a
219   // manual implementation.
220   if (words.size() < 2) {
221     return emitError(
222         unknownLoc, "OpDecorate must have at least result <id> and Decoration");
223   }
224   auto decorationName =
225       stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
226   if (decorationName.empty()) {
227     return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
228   }
229   auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
230   auto symbol = opBuilder.getStringAttr(attrName);
231   switch (static_cast<spirv::Decoration>(words[1])) {
232   case spirv::Decoration::DescriptorSet:
233   case spirv::Decoration::Binding:
234     if (words.size() != 3) {
235       return emitError(unknownLoc, "OpDecorate with ")
236              << decorationName << " needs a single integer literal";
237     }
238     decorations[words[0]].set(
239         symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
240     break;
241   case spirv::Decoration::BuiltIn:
242     if (words.size() != 3) {
243       return emitError(unknownLoc, "OpDecorate with ")
244              << decorationName << " needs a single integer literal";
245     }
246     decorations[words[0]].set(
247         symbol, opBuilder.getStringAttr(
248                     stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
249     break;
250   case spirv::Decoration::ArrayStride:
251     if (words.size() != 3) {
252       return emitError(unknownLoc, "OpDecorate with ")
253              << decorationName << " needs a single integer literal";
254     }
255     typeDecorations[words[0]] = words[2];
256     break;
257   case spirv::Decoration::Aliased:
258   case spirv::Decoration::Block:
259   case spirv::Decoration::BufferBlock:
260   case spirv::Decoration::Flat:
261   case spirv::Decoration::NonReadable:
262   case spirv::Decoration::NonWritable:
263   case spirv::Decoration::NoPerspective:
264   case spirv::Decoration::Restrict:
265   case spirv::Decoration::RelaxedPrecision:
266     if (words.size() != 2) {
267       return emitError(unknownLoc, "OpDecoration with ")
268              << decorationName << "needs a single target <id>";
269     }
270     // Block decoration does not affect spv.struct type, but is still stored for
271     // verification.
272     // TODO: Update StructType to contain this information since
273     // it is needed for many validation rules.
274     decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
275     break;
276   case spirv::Decoration::Location:
277   case spirv::Decoration::SpecId:
278     if (words.size() != 3) {
279       return emitError(unknownLoc, "OpDecoration with ")
280              << decorationName << "needs a single integer literal";
281     }
282     decorations[words[0]].set(
283         symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
284     break;
285   default:
286     return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
287   }
288   return success();
289 }
290 
291 LogicalResult
292 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
293   // The binary layout of OpMemberDecorate is different comparing to OpDecorate
294   if (words.size() < 3) {
295     return emitError(unknownLoc,
296                      "OpMemberDecorate must have at least 3 operands");
297   }
298 
299   auto decoration = static_cast<spirv::Decoration>(words[2]);
300   if (decoration == spirv::Decoration::Offset && words.size() != 4) {
301     return emitError(unknownLoc,
302                      " missing offset specification in OpMemberDecorate with "
303                      "Offset decoration");
304   }
305   ArrayRef<uint32_t> decorationOperands;
306   if (words.size() > 3) {
307     decorationOperands = words.slice(3);
308   }
309   memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
310   return success();
311 }
312 
313 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
314   if (words.size() < 3) {
315     return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
316   }
317   unsigned wordIndex = 2;
318   auto name = decodeStringLiteral(words, wordIndex);
319   if (wordIndex != words.size()) {
320     return emitError(unknownLoc,
321                      "unexpected trailing words in OpMemberName instruction");
322   }
323   memberNameMap[words[0]][words[1]] = name;
324   return success();
325 }
326 
327 LogicalResult
328 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
329   if (curFunction) {
330     return emitError(unknownLoc, "found function inside function");
331   }
332 
333   // Get the result type
334   if (operands.size() != 4) {
335     return emitError(unknownLoc, "OpFunction must have 4 parameters");
336   }
337   Type resultType = getType(operands[0]);
338   if (!resultType) {
339     return emitError(unknownLoc, "undefined result type from <id> ")
340            << operands[0];
341   }
342 
343   if (funcMap.count(operands[1])) {
344     return emitError(unknownLoc, "duplicate function definition/declaration");
345   }
346 
347   auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
348   if (!fnControl) {
349     return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
350   }
351 
352   Type fnType = getType(operands[3]);
353   if (!fnType || !fnType.isa<FunctionType>()) {
354     return emitError(unknownLoc, "unknown function type from <id> ")
355            << operands[3];
356   }
357   auto functionType = fnType.cast<FunctionType>();
358 
359   if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
360       (functionType.getNumResults() == 1 &&
361        functionType.getResult(0) != resultType)) {
362     return emitError(unknownLoc, "mismatch in function type ")
363            << functionType << " and return type " << resultType << " specified";
364   }
365 
366   std::string fnName = getFunctionSymbol(operands[1]);
367   auto funcOp = opBuilder.create<spirv::FuncOp>(
368       unknownLoc, fnName, functionType, fnControl.getValue());
369   curFunction = funcMap[operands[1]] = funcOp;
370   LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = "
371                           << fnType << ", id = " << operands[1] << ") --\n");
372   auto *entryBlock = funcOp.addEntryBlock();
373   LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock
374                           << "\n");
375 
376   // Parse the op argument instructions
377   if (functionType.getNumInputs()) {
378     for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
379       auto argType = functionType.getInput(i);
380       spirv::Opcode opcode = spirv::Opcode::OpNop;
381       ArrayRef<uint32_t> operands;
382       if (failed(sliceInstruction(opcode, operands,
383                                   spirv::Opcode::OpFunctionParameter))) {
384         return failure();
385       }
386       if (opcode != spirv::Opcode::OpFunctionParameter) {
387         return emitError(
388                    unknownLoc,
389                    "missing OpFunctionParameter instruction for argument ")
390                << i;
391       }
392       if (operands.size() != 2) {
393         return emitError(
394             unknownLoc,
395             "expected result type and result <id> for OpFunctionParameter");
396       }
397       auto argDefinedType = getType(operands[0]);
398       if (!argDefinedType || argDefinedType != argType) {
399         return emitError(unknownLoc,
400                          "mismatch in argument type between function type "
401                          "definition ")
402                << functionType << " and argument type definition "
403                << argDefinedType << " at argument " << i;
404       }
405       if (getValue(operands[1])) {
406         return emitError(unknownLoc, "duplicate definition of result <id> '")
407                << operands[1];
408       }
409       auto argValue = funcOp.getArgument(i);
410       valueMap[operands[1]] = argValue;
411     }
412   }
413 
414   // RAII guard to reset the insertion point to the module's region after
415   // deserializing the body of this function.
416   OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
417 
418   spirv::Opcode opcode = spirv::Opcode::OpNop;
419   ArrayRef<uint32_t> instOperands;
420 
421   // Special handling for the entry block. We need to make sure it starts with
422   // an OpLabel instruction. The entry block takes the same parameters as the
423   // function. All other blocks do not take any parameter. We have already
424   // created the entry block, here we need to register it to the correct label
425   // <id>.
426   if (failed(sliceInstruction(opcode, instOperands,
427                               spirv::Opcode::OpFunctionEnd))) {
428     return failure();
429   }
430   if (opcode == spirv::Opcode::OpFunctionEnd) {
431     LLVM_DEBUG(llvm::dbgs()
432                << "-- completed function '" << fnName << "' (type = " << fnType
433                << ", id = " << operands[1] << ") --\n");
434     return processFunctionEnd(instOperands);
435   }
436   if (opcode != spirv::Opcode::OpLabel) {
437     return emitError(unknownLoc, "a basic block must start with OpLabel");
438   }
439   if (instOperands.size() != 1) {
440     return emitError(unknownLoc, "OpLabel should only have result <id>");
441   }
442   blockMap[instOperands[0]] = entryBlock;
443   if (failed(processLabel(instOperands))) {
444     return failure();
445   }
446 
447   // Then process all the other instructions in the function until we hit
448   // OpFunctionEnd.
449   while (succeeded(sliceInstruction(opcode, instOperands,
450                                     spirv::Opcode::OpFunctionEnd)) &&
451          opcode != spirv::Opcode::OpFunctionEnd) {
452     if (failed(processInstruction(opcode, instOperands))) {
453       return failure();
454     }
455   }
456   if (opcode != spirv::Opcode::OpFunctionEnd) {
457     return failure();
458   }
459 
460   LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = "
461                           << fnType << ", id = " << operands[1] << ") --\n");
462   return processFunctionEnd(instOperands);
463 }
464 
465 LogicalResult
466 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
467   // Process OpFunctionEnd.
468   if (!operands.empty()) {
469     return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
470   }
471 
472   // Wire up block arguments from OpPhi instructions.
473   // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops.
474   if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
475     return failure();
476   }
477 
478   curBlock = nullptr;
479   curFunction = llvm::None;
480 
481   return success();
482 }
483 
484 Optional<std::pair<Attribute, Type>>
485 spirv::Deserializer::getConstant(uint32_t id) {
486   auto constIt = constantMap.find(id);
487   if (constIt == constantMap.end())
488     return llvm::None;
489   return constIt->getSecond();
490 }
491 
492 Optional<spirv::SpecConstOperationMaterializationInfo>
493 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
494   auto constIt = specConstOperationMap.find(id);
495   if (constIt == specConstOperationMap.end())
496     return llvm::None;
497   return constIt->getSecond();
498 }
499 
500 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
501   auto funcName = nameMap.lookup(id).str();
502   if (funcName.empty()) {
503     funcName = "spirv_fn_" + std::to_string(id);
504   }
505   return funcName;
506 }
507 
508 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
509   auto constName = nameMap.lookup(id).str();
510   if (constName.empty()) {
511     constName = "spirv_spec_const_" + std::to_string(id);
512   }
513   return constName;
514 }
515 
516 spirv::SpecConstantOp
517 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
518                                         Attribute defaultValue) {
519   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
520   auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
521                                                     defaultValue);
522   if (decorations.count(resultID)) {
523     for (auto attr : decorations[resultID].getAttrs())
524       op->setAttr(attr.getName(), attr.getValue());
525   }
526   specConstMap[resultID] = op;
527   return op;
528 }
529 
530 LogicalResult
531 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
532   unsigned wordIndex = 0;
533   if (operands.size() < 3) {
534     return emitError(
535         unknownLoc,
536         "OpVariable needs at least 3 operands, type, <id> and storage class");
537   }
538 
539   // Result Type.
540   auto type = getType(operands[wordIndex]);
541   if (!type) {
542     return emitError(unknownLoc, "unknown result type <id> : ")
543            << operands[wordIndex];
544   }
545   auto ptrType = type.dyn_cast<spirv::PointerType>();
546   if (!ptrType) {
547     return emitError(unknownLoc,
548                      "expected a result type <id> to be a spv.ptr, found : ")
549            << type;
550   }
551   wordIndex++;
552 
553   // Result <id>.
554   auto variableID = operands[wordIndex];
555   auto variableName = nameMap.lookup(variableID).str();
556   if (variableName.empty()) {
557     variableName = "spirv_var_" + std::to_string(variableID);
558   }
559   wordIndex++;
560 
561   // Storage class.
562   auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
563   if (ptrType.getStorageClass() != storageClass) {
564     return emitError(unknownLoc, "mismatch in storage class of pointer type ")
565            << type << " and that specified in OpVariable instruction  : "
566            << stringifyStorageClass(storageClass);
567   }
568   wordIndex++;
569 
570   // Initializer.
571   FlatSymbolRefAttr initializer = nullptr;
572   if (wordIndex < operands.size()) {
573     auto initializerOp = getGlobalVariable(operands[wordIndex]);
574     if (!initializerOp) {
575       return emitError(unknownLoc, "unknown <id> ")
576              << operands[wordIndex] << "used as initializer";
577     }
578     wordIndex++;
579     initializer = SymbolRefAttr::get(initializerOp.getOperation());
580   }
581   if (wordIndex != operands.size()) {
582     return emitError(unknownLoc,
583                      "found more operands than expected when deserializing "
584                      "OpVariable instruction, only ")
585            << wordIndex << " of " << operands.size() << " processed";
586   }
587   auto loc = createFileLineColLoc(opBuilder);
588   auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
589       loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
590       initializer);
591 
592   // Decorations.
593   if (decorations.count(variableID)) {
594     for (auto attr : decorations[variableID].getAttrs())
595       varOp->setAttr(attr.getName(), attr.getValue());
596   }
597   globalVariableMap[variableID] = varOp;
598   return success();
599 }
600 
601 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
602   auto constInfo = getConstant(id);
603   if (!constInfo) {
604     return nullptr;
605   }
606   return constInfo->first.dyn_cast<IntegerAttr>();
607 }
608 
609 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
610   if (operands.size() < 2) {
611     return emitError(unknownLoc, "OpName needs at least 2 operands");
612   }
613   if (!nameMap.lookup(operands[0]).empty()) {
614     return emitError(unknownLoc, "duplicate name found for result <id> ")
615            << operands[0];
616   }
617   unsigned wordIndex = 1;
618   StringRef name = decodeStringLiteral(operands, wordIndex);
619   if (wordIndex != operands.size()) {
620     return emitError(unknownLoc,
621                      "unexpected trailing words in OpName instruction");
622   }
623   nameMap[operands[0]] = name;
624   return success();
625 }
626 
627 //===----------------------------------------------------------------------===//
628 // Type
629 //===----------------------------------------------------------------------===//
630 
631 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
632                                                ArrayRef<uint32_t> operands) {
633   if (operands.empty()) {
634     return emitError(unknownLoc, "type instruction with opcode ")
635            << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
636   }
637 
638   /// TODO: Types might be forward declared in some instructions and need to be
639   /// handled appropriately.
640   if (typeMap.count(operands[0])) {
641     return emitError(unknownLoc, "duplicate definition for result <id> ")
642            << operands[0];
643   }
644 
645   switch (opcode) {
646   case spirv::Opcode::OpTypeVoid:
647     if (operands.size() != 1)
648       return emitError(unknownLoc, "OpTypeVoid must have no parameters");
649     typeMap[operands[0]] = opBuilder.getNoneType();
650     break;
651   case spirv::Opcode::OpTypeBool:
652     if (operands.size() != 1)
653       return emitError(unknownLoc, "OpTypeBool must have no parameters");
654     typeMap[operands[0]] = opBuilder.getI1Type();
655     break;
656   case spirv::Opcode::OpTypeInt: {
657     if (operands.size() != 3)
658       return emitError(
659           unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
660 
661     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
662     // to preserve or validate.
663     // 0 indicates unsigned, or no signedness semantics
664     // 1 indicates signed semantics."
665     //
666     // So we cannot differentiate signless and unsigned integers; always use
667     // signless semantics for such cases.
668     auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
669                                  : IntegerType::SignednessSemantics::Signless;
670     typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
671   } break;
672   case spirv::Opcode::OpTypeFloat: {
673     if (operands.size() != 2)
674       return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
675 
676     Type floatTy;
677     switch (operands[1]) {
678     case 16:
679       floatTy = opBuilder.getF16Type();
680       break;
681     case 32:
682       floatTy = opBuilder.getF32Type();
683       break;
684     case 64:
685       floatTy = opBuilder.getF64Type();
686       break;
687     default:
688       return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
689              << operands[1];
690     }
691     typeMap[operands[0]] = floatTy;
692   } break;
693   case spirv::Opcode::OpTypeVector: {
694     if (operands.size() != 3) {
695       return emitError(
696           unknownLoc,
697           "OpTypeVector must have element type and count parameters");
698     }
699     Type elementTy = getType(operands[1]);
700     if (!elementTy) {
701       return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
702              << operands[1];
703     }
704     typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
705   } break;
706   case spirv::Opcode::OpTypePointer: {
707     return processOpTypePointer(operands);
708   } break;
709   case spirv::Opcode::OpTypeArray:
710     return processArrayType(operands);
711   case spirv::Opcode::OpTypeCooperativeMatrixNV:
712     return processCooperativeMatrixType(operands);
713   case spirv::Opcode::OpTypeFunction:
714     return processFunctionType(operands);
715   case spirv::Opcode::OpTypeImage:
716     return processImageType(operands);
717   case spirv::Opcode::OpTypeSampledImage:
718     return processSampledImageType(operands);
719   case spirv::Opcode::OpTypeRuntimeArray:
720     return processRuntimeArrayType(operands);
721   case spirv::Opcode::OpTypeStruct:
722     return processStructType(operands);
723   case spirv::Opcode::OpTypeMatrix:
724     return processMatrixType(operands);
725   default:
726     return emitError(unknownLoc, "unhandled type instruction");
727   }
728   return success();
729 }
730 
/// Handles OpTypePointer: `operands` is [result <id>, storage class,
/// pointee type <id>]. Registers the new spirv::PointerType in `typeMap`.
///
/// Also walks `deferredStructTypesInfos` and fills in every unresolved struct
/// member that was waiting on this pointer's <id>. A deferred struct whose
/// members are now all resolved gets its body set (trySetBody) and is removed
/// from the deferred list.
// NOTE(review): the deferral presumably exists for identified structs that
// reference themselves through a pointer declared later — confirm at the site
// that populates `deferredStructTypesInfos`.
LogicalResult
spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
  if (operands.size() != 3)
    return emitError(unknownLoc, "OpTypePointer must have two parameters");

  auto pointeeType = getType(operands[2]);
  if (!pointeeType)
    return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
           << operands[2];

  uint32_t typePointerID = operands[0];
  // NOTE(review): the storage class word is cast without validation.
  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);

  // Both loops erase while iterating, so each advances its iterator only when
  // nothing was erased; erase() returns the next valid position.
  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
       deferredStructIt != std::end(deferredStructTypesInfos);) {
    for (auto *unresolvedMemberIt =
             std::begin(deferredStructIt->unresolvedMemberTypes);
         unresolvedMemberIt !=
         std::end(deferredStructIt->unresolvedMemberTypes);) {
      // Each unresolved entry pairs a pending type <id> (first) with the index
      // of the struct member (second) that is waiting on it.
      if (unresolvedMemberIt->first == typePointerID) {
        // The newly constructed pointer type can resolve one of the
        // deferred struct type members; update the memberTypes list and
        // clean the unresolvedMemberTypes list accordingly.
        deferredStructIt->memberTypes[unresolvedMemberIt->second] =
            typeMap[typePointerID];
        unresolvedMemberIt =
            deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
      } else {
        ++unresolvedMemberIt;
      }
    }

    if (deferredStructIt->unresolvedMemberTypes.empty()) {
      // All deferred struct type members are now resolved, set the struct body.
      auto structType = deferredStructIt->deferredStructType;

      assert(structType && "expected a spirv::StructType");
      assert(structType.isIdentified() && "expected an indentified struct");

      if (failed(structType.trySetBody(
              deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
              deferredStructIt->memberDecorationsInfo)))
        return failure();

      deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
    } else {
      ++deferredStructIt;
    }
  }

  return success();
}
784 
785 LogicalResult
786 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
787   if (operands.size() != 3) {
788     return emitError(unknownLoc,
789                      "OpTypeArray must have element type and count parameters");
790   }
791 
792   Type elementTy = getType(operands[1]);
793   if (!elementTy) {
794     return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
795            << operands[1];
796   }
797 
798   unsigned count = 0;
799   // TODO: The count can also come frome a specialization constant.
800   auto countInfo = getConstant(operands[2]);
801   if (!countInfo) {
802     return emitError(unknownLoc, "OpTypeArray count <id> ")
803            << operands[2] << "can only come from normal constant right now";
804   }
805 
806   if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
807     count = intVal.getValue().getZExtValue();
808   } else {
809     return emitError(unknownLoc, "OpTypeArray count must come from a "
810                                  "scalar integer constant instruction");
811   }
812 
813   typeMap[operands[0]] = spirv::ArrayType::get(
814       elementTy, count, typeDecorations.lookup(operands[0]));
815   return success();
816 }
817 
818 LogicalResult
819 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
820   assert(!operands.empty() && "No operands for processing function type");
821   if (operands.size() == 1) {
822     return emitError(unknownLoc, "missing return type for OpTypeFunction");
823   }
824   auto returnType = getType(operands[1]);
825   if (!returnType) {
826     return emitError(unknownLoc, "unknown return type in OpTypeFunction");
827   }
828   SmallVector<Type, 1> argTypes;
829   for (size_t i = 2, e = operands.size(); i < e; ++i) {
830     auto ty = getType(operands[i]);
831     if (!ty) {
832       return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
833     }
834     argTypes.push_back(ty);
835   }
836   ArrayRef<Type> returnTypes;
837   if (!isVoidType(returnType)) {
838     returnTypes = llvm::makeArrayRef(returnType);
839   }
840   typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
841   return success();
842 }
843 
844 LogicalResult
845 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
846   if (operands.size() != 5) {
847     return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
848                                  "type and row x column parameters");
849   }
850 
851   Type elementTy = getType(operands[1]);
852   if (!elementTy) {
853     return emitError(unknownLoc,
854                      "OpTypeCooperativeMatrix references undefined <id> ")
855            << operands[1];
856   }
857 
858   auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
859   if (!scope) {
860     return emitError(unknownLoc,
861                      "OpTypeCooperativeMatrix references undefined scope <id> ")
862            << operands[2];
863   }
864 
865   unsigned rows = getConstantInt(operands[3]).getInt();
866   unsigned columns = getConstantInt(operands[4]).getInt();
867 
868   typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
869       elementTy, scope.getValue(), rows, columns);
870   return success();
871 }
872 
873 LogicalResult
874 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
875   if (operands.size() != 2) {
876     return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
877   }
878   Type memberType = getType(operands[1]);
879   if (!memberType) {
880     return emitError(unknownLoc,
881                      "OpTypeRuntimeArray references undefined <id> ")
882            << operands[1];
883   }
884   typeMap[operands[0]] = spirv::RuntimeArrayType::get(
885       memberType, typeDecorations.lookup(operands[0]));
886   return success();
887 }
888 
889 LogicalResult
890 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
891   // TODO: Find a way to handle identified structs when debug info is stripped.
892 
893   if (operands.empty()) {
894     return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
895   }
896 
897   if (operands.size() == 1) {
898     // Handle empty struct.
899     typeMap[operands[0]] =
900         spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
901     return success();
902   }
903 
904   // First element is operand ID, second element is member index in the struct.
905   SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
906   SmallVector<Type, 4> memberTypes;
907 
908   for (auto op : llvm::drop_begin(operands, 1)) {
909     Type memberType = getType(op);
910     bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
911 
912     if (!memberType && !typeForwardPtr)
913       return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
914              << op;
915 
916     if (!memberType)
917       unresolvedMemberTypes.emplace_back(op, memberTypes.size());
918 
919     memberTypes.push_back(memberType);
920   }
921 
922   SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
923   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
924   if (memberDecorationMap.count(operands[0])) {
925     auto &allMemberDecorations = memberDecorationMap[operands[0]];
926     for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
927       if (allMemberDecorations.count(memberIndex)) {
928         for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
929           // Check for offset.
930           if (memberDecoration.first == spirv::Decoration::Offset) {
931             // If offset info is empty, resize to the number of members;
932             if (offsetInfo.empty()) {
933               offsetInfo.resize(memberTypes.size());
934             }
935             offsetInfo[memberIndex] = memberDecoration.second[0];
936           } else {
937             if (!memberDecoration.second.empty()) {
938               memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
939                                                  memberDecoration.first,
940                                                  memberDecoration.second[0]);
941             } else {
942               memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
943                                                  memberDecoration.first, 0);
944             }
945           }
946         }
947       }
948     }
949   }
950 
951   uint32_t structID = operands[0];
952   std::string structIdentifier = nameMap.lookup(structID).str();
953 
954   if (structIdentifier.empty()) {
955     assert(unresolvedMemberTypes.empty() &&
956            "didn't expect unresolved member types");
957     typeMap[structID] =
958         spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
959   } else {
960     auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
961     typeMap[structID] = structTy;
962 
963     if (!unresolvedMemberTypes.empty())
964       deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
965                                           memberTypes, offsetInfo,
966                                           memberDecorationsInfo});
967     else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
968                                         memberDecorationsInfo)))
969       return failure();
970   }
971 
972   // TODO: Update StructType to have member name as attribute as
973   // well.
974   return success();
975 }
976 
977 LogicalResult
978 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
979   if (operands.size() != 3) {
980     // Three operands are needed: result_id, column_type, and column_count
981     return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
982                                  " (result_id, column_type, and column_count)");
983   }
984   // Matrix columns must be of vector type
985   Type elementTy = getType(operands[1]);
986   if (!elementTy) {
987     return emitError(unknownLoc,
988                      "OpTypeMatrix references undefined column type.")
989            << operands[1];
990   }
991 
992   uint32_t colsCount = operands[2];
993   typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
994   return success();
995 }
996 
997 LogicalResult
998 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
999   if (operands.size() != 2)
1000     return emitError(unknownLoc,
1001                      "OpTypeForwardPointer instruction must have two operands");
1002 
1003   typeForwardPointerIDs.insert(operands[0]);
1004   // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1005   // instruction that defines the actual type.
1006 
1007   return success();
1008 }
1009 
1010 LogicalResult
1011 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1012   // TODO: Add support for Access Qualifier.
1013   if (operands.size() != 8)
1014     return emitError(
1015         unknownLoc,
1016         "OpTypeImage with non-eight operands are not supported yet");
1017 
1018   Type elementTy = getType(operands[1]);
1019   if (!elementTy)
1020     return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1021            << operands[1];
1022 
1023   auto dim = spirv::symbolizeDim(operands[2]);
1024   if (!dim)
1025     return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1026            << operands[2];
1027 
1028   auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1029   if (!depthInfo)
1030     return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1031            << operands[3];
1032 
1033   auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1034   if (!arrayedInfo)
1035     return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1036            << operands[4];
1037 
1038   auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1039   if (!samplingInfo)
1040     return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1041 
1042   auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1043   if (!samplerUseInfo)
1044     return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1045            << operands[6];
1046 
1047   auto format = spirv::symbolizeImageFormat(operands[7]);
1048   if (!format)
1049     return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1050            << operands[7];
1051 
1052   typeMap[operands[0]] = spirv::ImageType::get(
1053       elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(),
1054       samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue());
1055   return success();
1056 }
1057 
1058 LogicalResult
1059 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1060   if (operands.size() != 2)
1061     return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1062 
1063   Type elementTy = getType(operands[1]);
1064   if (!elementTy)
1065     return emitError(unknownLoc,
1066                      "OpTypeSampledImage references undefined <id>: ")
1067            << operands[1];
1068 
1069   typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1070   return success();
1071 }
1072 
1073 //===----------------------------------------------------------------------===//
1074 // Constant
1075 //===----------------------------------------------------------------------===//
1076 
1077 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1078                                                    bool isSpec) {
1079   StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1080 
1081   if (operands.size() < 2) {
1082     return emitError(unknownLoc)
1083            << opname << " must have type <id> and result <id>";
1084   }
1085   if (operands.size() < 3) {
1086     return emitError(unknownLoc)
1087            << opname << " must have at least 1 more parameter";
1088   }
1089 
1090   Type resultType = getType(operands[0]);
1091   if (!resultType) {
1092     return emitError(unknownLoc, "undefined result type from <id> ")
1093            << operands[0];
1094   }
1095 
1096   auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1097     if (bitwidth == 64) {
1098       if (operands.size() == 4) {
1099         return success();
1100       }
1101       return emitError(unknownLoc)
1102              << opname << " should have 2 parameters for 64-bit values";
1103     }
1104     if (bitwidth <= 32) {
1105       if (operands.size() == 3) {
1106         return success();
1107       }
1108 
1109       return emitError(unknownLoc)
1110              << opname
1111              << " should have 1 parameter for values with no more than 32 bits";
1112     }
1113     return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1114            << bitwidth;
1115   };
1116 
1117   auto resultID = operands[1];
1118 
1119   if (auto intType = resultType.dyn_cast<IntegerType>()) {
1120     auto bitwidth = intType.getWidth();
1121     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1122       return failure();
1123     }
1124 
1125     APInt value;
1126     if (bitwidth == 64) {
1127       // 64-bit integers are represented with two SPIR-V words. According to
1128       // SPIR-V spec: "When the type’s bit width is larger than one word, the
1129       // literal’s low-order words appear first."
1130       struct DoubleWord {
1131         uint32_t word1;
1132         uint32_t word2;
1133       } words = {operands[2], operands[3]};
1134       value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1135     } else if (bitwidth <= 32) {
1136       value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1137     }
1138 
1139     auto attr = opBuilder.getIntegerAttr(intType, value);
1140 
1141     if (isSpec) {
1142       createSpecConstant(unknownLoc, resultID, attr);
1143     } else {
1144       // For normal constants, we just record the attribute (and its type) for
1145       // later materialization at use sites.
1146       constantMap.try_emplace(resultID, attr, intType);
1147     }
1148 
1149     return success();
1150   }
1151 
1152   if (auto floatType = resultType.dyn_cast<FloatType>()) {
1153     auto bitwidth = floatType.getWidth();
1154     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1155       return failure();
1156     }
1157 
1158     APFloat value(0.f);
1159     if (floatType.isF64()) {
1160       // Double values are represented with two SPIR-V words. According to
1161       // SPIR-V spec: "When the type’s bit width is larger than one word, the
1162       // literal’s low-order words appear first."
1163       struct DoubleWord {
1164         uint32_t word1;
1165         uint32_t word2;
1166       } words = {operands[2], operands[3]};
1167       value = APFloat(llvm::bit_cast<double>(words));
1168     } else if (floatType.isF32()) {
1169       value = APFloat(llvm::bit_cast<float>(operands[2]));
1170     } else if (floatType.isF16()) {
1171       APInt data(16, operands[2]);
1172       value = APFloat(APFloat::IEEEhalf(), data);
1173     }
1174 
1175     auto attr = opBuilder.getFloatAttr(floatType, value);
1176     if (isSpec) {
1177       createSpecConstant(unknownLoc, resultID, attr);
1178     } else {
1179       // For normal constants, we just record the attribute (and its type) for
1180       // later materialization at use sites.
1181       constantMap.try_emplace(resultID, attr, floatType);
1182     }
1183 
1184     return success();
1185   }
1186 
1187   return emitError(unknownLoc, "OpConstant can only generate values of "
1188                                "scalar integer or floating-point type");
1189 }
1190 
1191 LogicalResult spirv::Deserializer::processConstantBool(
1192     bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1193   if (operands.size() != 2) {
1194     return emitError(unknownLoc, "Op")
1195            << (isSpec ? "Spec" : "") << "Constant"
1196            << (isTrue ? "True" : "False")
1197            << " must have type <id> and result <id>";
1198   }
1199 
1200   auto attr = opBuilder.getBoolAttr(isTrue);
1201   auto resultID = operands[1];
1202   if (isSpec) {
1203     createSpecConstant(unknownLoc, resultID, attr);
1204   } else {
1205     // For normal constants, we just record the attribute (and its type) for
1206     // later materialization at use sites.
1207     constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1208   }
1209 
1210   return success();
1211 }
1212 
1213 LogicalResult
1214 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1215   if (operands.size() < 2) {
1216     return emitError(unknownLoc,
1217                      "OpConstantComposite must have type <id> and result <id>");
1218   }
1219   if (operands.size() < 3) {
1220     return emitError(unknownLoc,
1221                      "OpConstantComposite must have at least 1 parameter");
1222   }
1223 
1224   Type resultType = getType(operands[0]);
1225   if (!resultType) {
1226     return emitError(unknownLoc, "undefined result type from <id> ")
1227            << operands[0];
1228   }
1229 
1230   SmallVector<Attribute, 4> elements;
1231   elements.reserve(operands.size() - 2);
1232   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1233     auto elementInfo = getConstant(operands[i]);
1234     if (!elementInfo) {
1235       return emitError(unknownLoc, "OpConstantComposite component <id> ")
1236              << operands[i] << " must come from a normal constant";
1237     }
1238     elements.push_back(elementInfo->first);
1239   }
1240 
1241   auto resultID = operands[1];
1242   if (auto vectorType = resultType.dyn_cast<VectorType>()) {
1243     auto attr = DenseElementsAttr::get(vectorType, elements);
1244     // For normal constants, we just record the attribute (and its type) for
1245     // later materialization at use sites.
1246     constantMap.try_emplace(resultID, attr, resultType);
1247   } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
1248     auto attr = opBuilder.getArrayAttr(elements);
1249     constantMap.try_emplace(resultID, attr, resultType);
1250   } else {
1251     return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1252            << resultType;
1253   }
1254 
1255   return success();
1256 }
1257 
1258 LogicalResult
1259 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1260   if (operands.size() < 2) {
1261     return emitError(unknownLoc,
1262                      "OpConstantComposite must have type <id> and result <id>");
1263   }
1264   if (operands.size() < 3) {
1265     return emitError(unknownLoc,
1266                      "OpConstantComposite must have at least 1 parameter");
1267   }
1268 
1269   Type resultType = getType(operands[0]);
1270   if (!resultType) {
1271     return emitError(unknownLoc, "undefined result type from <id> ")
1272            << operands[0];
1273   }
1274 
1275   auto resultID = operands[1];
1276   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1277 
1278   SmallVector<Attribute, 4> elements;
1279   elements.reserve(operands.size() - 2);
1280   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1281     auto elementInfo = getSpecConstant(operands[i]);
1282     elements.push_back(SymbolRefAttr::get(elementInfo));
1283   }
1284 
1285   auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1286       unknownLoc, TypeAttr::get(resultType), symName,
1287       opBuilder.getArrayAttr(elements));
1288   specConstCompositeMap[resultID] = op;
1289 
1290   return success();
1291 }
1292 
1293 LogicalResult
1294 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1295   if (operands.size() < 3)
1296     return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1297                                  "result <id>, and operand opcode");
1298 
1299   uint32_t resultTypeID = operands[0];
1300 
1301   if (!getType(resultTypeID))
1302     return emitError(unknownLoc, "undefined result type from <id> ")
1303            << resultTypeID;
1304 
1305   uint32_t resultID = operands[1];
1306   spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1307   auto emplaceResult = specConstOperationMap.try_emplace(
1308       resultID,
1309       SpecConstOperationMaterializationInfo{
1310           enclosedOpcode, resultTypeID,
1311           SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1312 
1313   if (!emplaceResult.second)
1314     return emitError(unknownLoc, "value with <id>: ")
1315            << resultID << " is probably defined before.";
1316 
1317   return success();
1318 }
1319 
/// Materializes a recorded OpSpecConstantOp as a
/// spirv::SpecConstantOperationOp at the current insertion point.
///
/// The enclosed instruction (`enclosedOpcode` applied to `enclosedOpOperands`,
/// with result type `resultTypeID`) is first deserialized through the regular
/// processInstruction path. This uses a temporary value map and a fake result
/// <id>. The last op of the current block is then split into a new block,
/// which is moved into the body region of the new SpecConstantOperationOp and
/// terminated by a spirv::YieldOp of that block's first op's result.
///
/// Returns the SpecConstantOperationOp's result, or a null Value if the
/// enclosed instruction fails to deserialize.
///
/// NOTE(review): `resultID` is not referenced in this body.
Value spirv::Deserializer::materializeSpecConstantOperation(
    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
    ArrayRef<uint32_t> enclosedOpOperands) {

  Type resultType = getType(resultTypeID);

  // Instructions wrapped by OpSpecConstantOp need an ID for their
  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
  // dialect wrapped op. For that purpose, a new value map is created and "fake"
  // ID in that map is assigned to the result of the enclosed instruction. Note
  // that there is no need to update this fake ID since we only need to
  // reference the created Value for the enclosed op from the spv::YieldOp
  // created later in this method (both of which are the only values in their
  // region: the SpecConstantOperation's region). If we encounter another
  // SpecConstantOperation in the module, we simply re-use the fake ID since the
  // previous Value assigned to it isn't visible in the current scope anyway.
  // The SaveAndRestore guard swaps `valueMap` for the empty map and restores
  // the original when this function returns.
  DenseMap<uint32_t, Value> newValueMap;
  llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
                                                                newValueMap);
  // NOTE(review): the specific value (uint32_t)-3 is presumably chosen so it
  // cannot collide with a real <id> — confirm.
  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);

  // Rebuild the operand list in the <result type, result id, operands...>
  // shape that processInstruction expects.
  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
  enclosedOpResultTypeAndOperands.push_back(fakeID);
  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
                                         enclosedOpOperands.end());

  // Process enclosed instruction before creating the enclosing
  // specConstantOperation (and its region). This way, references to constants,
  // global variables, and spec constants will be materialized outside the new
  // op's region. For more info, see Deserializer::getValue's implementation.
  if (failed(
          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
    return Value();

  // Since the enclosed op is emitted in the current block, split it in a
  // separate new block. This assumes the enclosed op is the last op in
  // `curBlock` — TODO confirm that processInstruction never appends more.
  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());

  auto loc = createFileLineColLoc(opBuilder);
  auto specConstOperationOp =
      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);

  Region &body = specConstOperationOp.body();
  // Move the new block into SpecConstantOperation's body.
  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
                          Region::iterator(enclosedBlock));
  Block &block = body.back();

  // RAII guard to reset the insertion point to the module's region after
  // deserializing the body of the specConstantOperation.
  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
  opBuilder.setInsertionPointToEnd(&block);

  // Terminate the region by yielding the enclosed op's (first) result.
  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
  return specConstOperationOp.getResult();
}
1377 
1378 LogicalResult
1379 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1380   if (operands.size() != 2) {
1381     return emitError(unknownLoc,
1382                      "OpConstantNull must have type <id> and result <id>");
1383   }
1384 
1385   Type resultType = getType(operands[0]);
1386   if (!resultType) {
1387     return emitError(unknownLoc, "undefined result type from <id> ")
1388            << operands[0];
1389   }
1390 
1391   auto resultID = operands[1];
1392   if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
1393     auto attr = opBuilder.getZeroAttr(resultType);
1394     // For normal constants, we just record the attribute (and its type) for
1395     // later materialization at use sites.
1396     constantMap.try_emplace(resultID, attr, resultType);
1397     return success();
1398   }
1399 
1400   return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1401          << resultType;
1402 }
1403 
1404 //===----------------------------------------------------------------------===//
1405 // Control flow
1406 //===----------------------------------------------------------------------===//
1407 
1408 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1409   if (auto *block = getBlock(id)) {
1410     LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id
1411                             << " @ " << block << "\n");
1412     return block;
1413   }
1414 
1415   // We don't know where this block will be placed finally (in a
1416   // spv.mlir.selection or spv.mlir.loop or function). Create it into the
1417   // function for now and sort out the proper place later.
1418   auto *block = curFunction->addBlock();
1419   LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ "
1420                           << block << "\n");
1421   return blockMap[id] = block;
1422 }
1423 
1424 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1425   if (!curBlock) {
1426     return emitError(unknownLoc, "OpBranch must appear inside a block");
1427   }
1428 
1429   if (operands.size() != 1) {
1430     return emitError(unknownLoc, "OpBranch must take exactly one target label");
1431   }
1432 
1433   auto *target = getOrCreateBlock(operands[0]);
1434   auto loc = createFileLineColLoc(opBuilder);
1435   // The preceding instruction for the OpBranch instruction could be an
1436   // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1437   // the same OpLine information.
1438   opBuilder.create<spirv::BranchOp>(loc, target);
1439 
1440   (void)clearDebugLine();
1441   return success();
1442 }
1443 
1444 LogicalResult
1445 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1446   if (!curBlock) {
1447     return emitError(unknownLoc,
1448                      "OpBranchConditional must appear inside a block");
1449   }
1450 
1451   if (operands.size() != 3 && operands.size() != 5) {
1452     return emitError(unknownLoc,
1453                      "OpBranchConditional must have condition, true label, "
1454                      "false label, and optionally two branch weights");
1455   }
1456 
1457   auto condition = getValue(operands[0]);
1458   auto *trueBlock = getOrCreateBlock(operands[1]);
1459   auto *falseBlock = getOrCreateBlock(operands[2]);
1460 
1461   Optional<std::pair<uint32_t, uint32_t>> weights;
1462   if (operands.size() == 5) {
1463     weights = std::make_pair(operands[3], operands[4]);
1464   }
1465   // The preceding instruction for the OpBranchConditional instruction could be
1466   // an OpSelectionMerge instruction, in this case they will have the same
1467   // OpLine information.
1468   auto loc = createFileLineColLoc(opBuilder);
1469   opBuilder.create<spirv::BranchConditionalOp>(
1470       loc, condition, trueBlock,
1471       /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1472       /*falseArguments=*/ArrayRef<Value>(), weights);
1473 
1474   (void)clearDebugLine();
1475   return success();
1476 }
1477 
1478 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1479   if (!curFunction) {
1480     return emitError(unknownLoc, "OpLabel must appear inside a function");
1481   }
1482 
1483   if (operands.size() != 1) {
1484     return emitError(unknownLoc, "OpLabel should only have result <id>");
1485   }
1486 
1487   auto labelID = operands[0];
1488   // We may have forward declared this block.
1489   auto *block = getOrCreateBlock(labelID);
1490   LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n");
1491   // If we have seen this block, make sure it was just a forward declaration.
1492   assert(block->empty() && "re-deserialize the same block!");
1493 
1494   opBuilder.setInsertionPointToStart(block);
1495   blockMap[labelID] = curBlock = block;
1496 
1497   return success();
1498 }
1499 
1500 LogicalResult
1501 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1502   if (!curBlock) {
1503     return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1504   }
1505 
1506   if (operands.size() < 2) {
1507     return emitError(
1508         unknownLoc,
1509         "OpSelectionMerge must specify merge target and selection control");
1510   }
1511 
1512   auto *mergeBlock = getOrCreateBlock(operands[0]);
1513   auto loc = createFileLineColLoc(opBuilder);
1514   auto selectionControl = operands[1];
1515 
1516   if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1517            .second) {
1518     return emitError(
1519         unknownLoc,
1520         "a block cannot have more than one OpSelectionMerge instruction");
1521   }
1522 
1523   return success();
1524 }
1525 
1526 LogicalResult
1527 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1528   if (!curBlock) {
1529     return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1530   }
1531 
1532   if (operands.size() < 3) {
1533     return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1534                                  "continue target and loop control");
1535   }
1536 
1537   auto *mergeBlock = getOrCreateBlock(operands[0]);
1538   auto *continueBlock = getOrCreateBlock(operands[1]);
1539   auto loc = createFileLineColLoc(opBuilder);
1540   uint32_t loopControl = operands[2];
1541 
1542   if (!blockMergeInfo
1543            .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1544            .second) {
1545     return emitError(
1546         unknownLoc,
1547         "a block cannot have more than one OpLoopMerge instruction");
1548   }
1549 
1550   return success();
1551 }
1552 
1553 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1554   if (!curBlock) {
1555     return emitError(unknownLoc, "OpPhi must appear in a block");
1556   }
1557 
1558   if (operands.size() < 4) {
1559     return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1560                                  "and variable-parent pairs");
1561   }
1562 
1563   // Create a block argument for this OpPhi instruction.
1564   Type blockArgType = getType(operands[0]);
1565   BlockArgument blockArg = curBlock->addArgument(blockArgType);
1566   valueMap[operands[1]] = blockArg;
1567   LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg
1568                           << " id = " << operands[1] << " of type "
1569                           << blockArgType << '\n');
1570 
1571   // For each (value, predecessor) pair, insert the value to the predecessor's
1572   // blockPhiInfo entry so later we can fix the block argument there.
1573   for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1574     uint32_t value = operands[i];
1575     Block *predecessor = getOrCreateBlock(operands[i + 1]);
1576     std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1577     blockPhiInfo[predecessorTargetPair].push_back(value);
1578     LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
1579                             << " with arg id = " << value << '\n');
1580   }
1581 
1582   return success();
1583 }
1584 
namespace {
/// A class for putting all blocks in a structured selection/loop in a
/// spv.mlir.selection/spv.mlir.loop op.
///
/// One instance handles one construct. A null `continueBlock` means the
/// construct is a selection; a non-null one means it is a loop. Instances are
/// only created through the static `structurize` entry point.
class ControlFlowStructurizer {
public:
  /// Structurizes the selection or loop at the given `headerBlock`.
  ///
  /// For a loop, this method will create an spv.mlir.loop op in the
  /// `mergeBlock` and move all blocks in the structured loop into the
  /// spv.mlir.loop's region. All branches to the `headerBlock` will be
  /// redirected to the `mergeBlock`. This method will also update `mergeInfo`
  /// by remapping all blocks inside to the newly cloned ones inside structured
  /// control flow op's regions.
  ///
  /// `control` is presumably the SelectionControl/LoopControl word recorded
  /// by OpSelectionMerge/OpLoopMerge — confirm at the call site.
  static LogicalResult structurize(Location loc, uint32_t control,
                                   spirv::BlockMergeInfoMap &mergeInfo,
                                   Block *headerBlock, Block *mergeBlock,
                                   Block *continueBlock) {
    return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock,
                                   mergeBlock, continueBlock)
        .structurizeImpl();
  }

private:
  // Member initializers below follow declaration order; keep the two in sync.
  ControlFlowStructurizer(Location loc, uint32_t control,
                          spirv::BlockMergeInfoMap &mergeInfo, Block *header,
                          Block *merge, Block *cont)
      : location(loc), control(control), blockMergeInfo(mergeInfo),
        headerBlock(header), mergeBlock(merge), continueBlock(cont) {}

  /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`.
  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);

  /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`.
  spirv::LoopOp createLoopOp(uint32_t loopControl);

  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
  void collectBlocksInConstruct();

  /// Performs the structurization; invoked by `structurize`.
  LogicalResult structurizeImpl();

  /// Location attached to the structured control flow op.
  Location location;
  /// Selection or loop control word passed in by the caller.
  uint32_t control;

  /// Merge information for all blocks, updated in place (see `structurize`).
  spirv::BlockMergeInfoMap &blockMergeInfo;

  Block *headerBlock;
  Block *mergeBlock;
  Block *continueBlock; // nullptr for spv.mlir.selection

  /// Blocks belonging to the construct, filled by collectBlocksInConstruct.
  SetVector<Block *> constructBlocks;
};
} // namespace
1636 
1637 spirv::SelectionOp
1638 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1639   // Create a builder and set the insertion point to the beginning of the
1640   // merge block so that the newly created SelectionOp will be inserted there.
1641   OpBuilder builder(&mergeBlock->front());
1642 
1643   auto control = static_cast<spirv::SelectionControl>(selectionControl);
1644   auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1645   selectionOp.addMergeBlock();
1646 
1647   return selectionOp;
1648 }
1649 
1650 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1651   // Create a builder and set the insertion point to the beginning of the
1652   // merge block so that the newly created LoopOp will be inserted there.
1653   OpBuilder builder(&mergeBlock->front());
1654 
1655   auto control = static_cast<spirv::LoopControl>(loopControl);
1656   auto loopOp = builder.create<spirv::LoopOp>(location, control);
1657   loopOp.addEntryAndMergeBlock();
1658 
1659   return loopOp;
1660 }
1661 
1662 void ControlFlowStructurizer::collectBlocksInConstruct() {
1663   assert(constructBlocks.empty() && "expected empty constructBlocks");
1664 
1665   // Put the header block in the work list first.
1666   constructBlocks.insert(headerBlock);
1667 
1668   // For each item in the work list, add its successors excluding the merge
1669   // block.
1670   for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1671     for (auto *successor : constructBlocks[i]->getSuccessors())
1672       if (successor != mergeBlock)
1673         constructBlocks.insert(successor);
1674   }
1675 }
1676 
1677 LogicalResult ControlFlowStructurizer::structurizeImpl() {
1678   Operation *op = nullptr;
1679   bool isLoop = continueBlock != nullptr;
1680   if (isLoop) {
1681     if (auto loopOp = createLoopOp(control))
1682       op = loopOp.getOperation();
1683   } else {
1684     if (auto selectionOp = createSelectionOp(control))
1685       op = selectionOp.getOperation();
1686   }
1687   if (!op)
1688     return failure();
1689   Region &body = op->getRegion(0);
1690 
1691   BlockAndValueMapping mapper;
1692   // All references to the old merge block should be directed to the
1693   // selection/loop merge block in the SelectionOp/LoopOp's region.
1694   mapper.map(mergeBlock, &body.back());
1695 
1696   collectBlocksInConstruct();
1697 
1698   // We've identified all blocks belonging to the selection/loop's region. Now
1699   // need to "move" them into the selection/loop. Instead of really moving the
1700   // blocks, in the following we copy them and remap all values and branches.
1701   // This is because:
1702   // * Inserting a block into a region requires the block not in any region
1703   //   before. But selections/loops can nest so we can create selection/loop ops
1704   //   in a nested manner, which means some blocks may already be in a
1705   //   selection/loop region when to be moved again.
1706   // * It's much trickier to fix up the branches into and out of the loop's
1707   //   region: we need to treat not-moved blocks and moved blocks differently:
1708   //   Not-moved blocks jumping to the loop header block need to jump to the
1709   //   merge point containing the new loop op but not the loop continue block's
1710   //   back edge. Moved blocks jumping out of the loop need to jump to the
1711   //   merge block inside the loop region but not other not-moved blocks.
1712   //   We cannot use replaceAllUsesWith clearly and it's harder to follow the
1713   //   logic.
1714 
1715   // Create a corresponding block in the SelectionOp/LoopOp's region for each
1716   // block in this loop construct.
1717   OpBuilder builder(body);
1718   for (auto *block : constructBlocks) {
1719     // Create a block and insert it before the selection/loop merge block in the
1720     // SelectionOp/LoopOp's region.
1721     auto *newBlock = builder.createBlock(&body.back());
1722     mapper.map(block, newBlock);
1723     LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock
1724                             << " from block " << block << "\n");
1725     if (!isFnEntryBlock(block)) {
1726       for (BlockArgument blockArg : block->getArguments()) {
1727         auto newArg = newBlock->addArgument(blockArg.getType());
1728         mapper.map(blockArg, newArg);
1729         LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
1730                                 << " to " << newArg << '\n');
1731       }
1732     } else {
1733       LLVM_DEBUG(llvm::dbgs()
1734                  << "[cf] block " << block << " is a function entry block\n");
1735     }
1736     for (auto &op : *block)
1737       newBlock->push_back(op.clone(mapper));
1738   }
1739 
1740   // Go through all ops and remap the operands.
1741   auto remapOperands = [&](Operation *op) {
1742     for (auto &operand : op->getOpOperands())
1743       if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1744         operand.set(mappedOp);
1745     for (auto &succOp : op->getBlockOperands())
1746       if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1747         succOp.set(mappedOp);
1748   };
1749   for (auto &block : body) {
1750     block.walk(remapOperands);
1751   }
1752 
1753   // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1754   // the selection/loop construct into its region. Next we need to fix the
1755   // connections between this new SelectionOp/LoopOp with existing blocks.
1756 
1757   // All existing incoming branches should go to the merge block, where the
1758   // SelectionOp/LoopOp resides right now.
1759   headerBlock->replaceAllUsesWith(mergeBlock);
1760 
1761   if (isLoop) {
1762     // The loop selection/loop header block may have block arguments. Since now
1763     // we place the selection/loop op inside the old merge block, we need to
1764     // make sure the old merge block has the same block argument list.
1765     assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
1766     for (BlockArgument blockArg : headerBlock->getArguments()) {
1767       mergeBlock->addArgument(blockArg.getType());
1768     }
1769 
1770     // If the loop header block has block arguments, make sure the spv.branch op
1771     // matches.
1772     SmallVector<Value, 4> blockArgs;
1773     if (!headerBlock->args_empty())
1774       blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1775 
1776     // The loop entry block should have a unconditional branch jumping to the
1777     // loop header block.
1778     builder.setInsertionPointToEnd(&body.front());
1779     builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1780                                     ArrayRef<Value>(blockArgs));
1781   }
1782 
1783   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1784   // cleaned up.
1785   LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n");
1786   // First we need to drop all operands' references inside all blocks. This is
1787   // needed because we can have blocks referencing SSA values from one another.
1788   for (auto *block : constructBlocks)
1789     block->dropAllReferences();
1790 
1791   // Then erase all old blocks.
1792   for (auto *block : constructBlocks) {
1793     // We've cloned all blocks belonging to this construct into the structured
1794     // control flow op's region. Among these blocks, some may compose another
1795     // selection/loop. If so, they will be recorded within blockMergeInfo.
1796     // We need to update the pointers there to the newly remapped ones so we can
1797     // continue structurizing them later.
1798     // TODO: The asserts in the following assumes input SPIR-V blob
1799     // forms correctly nested selection/loop constructs. We should relax this
1800     // and support error cases better.
1801     auto it = blockMergeInfo.find(block);
1802     if (it != blockMergeInfo.end()) {
1803       Block *newHeader = mapper.lookupOrNull(block);
1804       assert(newHeader && "nested loop header block should be remapped!");
1805 
1806       Block *newContinue = it->second.continueBlock;
1807       if (newContinue) {
1808         newContinue = mapper.lookupOrNull(newContinue);
1809         assert(newContinue && "nested loop continue block should be remapped!");
1810       }
1811 
1812       Block *newMerge = it->second.mergeBlock;
1813       if (Block *mappedTo = mapper.lookupOrNull(newMerge))
1814         newMerge = mappedTo;
1815 
1816       // Keep original location for nested selection/loop ops.
1817       Location loc = it->second.loc;
1818       // The iterator should be erased before adding a new entry into
1819       // blockMergeInfo to avoid iterator invalidation.
1820       blockMergeInfo.erase(it);
1821       blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
1822                                  newContinue);
1823     }
1824 
1825     // The structured selection/loop's entry block does not have arguments.
1826     // If the function's header block is also part of the structured control
1827     // flow, we cannot just simply erase it because it may contain arguments
1828     // matching the function signature and used by the cloned blocks.
1829     if (isFnEntryBlock(block)) {
1830       LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block
1831                               << " to only contain a spv.Branch op\n");
1832       // Still keep the function entry block for the potential block arguments,
1833       // but replace all ops inside with a branch to the merge block.
1834       block->clear();
1835       builder.setInsertionPointToEnd(block);
1836       builder.create<spirv::BranchOp>(location, mergeBlock);
1837     } else {
1838       LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n");
1839       block->erase();
1840     }
1841   }
1842 
1843   LLVM_DEBUG(
1844       llvm::dbgs() << "[cf] after structurizing construct with header block "
1845                    << headerBlock << ":\n"
1846                    << *op << '\n');
1847 
1848   return success();
1849 }
1850 
1851 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
1852   LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n");
1853 
1854   OpBuilder::InsertionGuard guard(opBuilder);
1855 
1856   for (const auto &info : blockPhiInfo) {
1857     Block *block = info.first.first;
1858     Block *target = info.first.second;
1859     const BlockPhiInfo &phiInfo = info.second;
1860     LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
1861     LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
1862     LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
1863     LLVM_DEBUG(llvm::dbgs() << '\n');
1864 
1865     // Set insertion point to before this block's terminator early because we
1866     // may materialize ops via getValue() call.
1867     auto *op = block->getTerminator();
1868     opBuilder.setInsertionPoint(op);
1869 
1870     SmallVector<Value, 4> blockArgs;
1871     blockArgs.reserve(phiInfo.size());
1872     for (uint32_t valueId : phiInfo) {
1873       if (Value value = getValue(valueId)) {
1874         blockArgs.push_back(value);
1875         LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value
1876                                 << " id = " << valueId << '\n');
1877       } else {
1878         return emitError(unknownLoc, "OpPhi references undefined value!");
1879       }
1880     }
1881 
1882     if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
1883       // Replace the previous branch op with a new one with block arguments.
1884       opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
1885                                         blockArgs);
1886       branchOp.erase();
1887     } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
1888       assert((branchCondOp.getTrueBlock() == target ||
1889               branchCondOp.getFalseBlock() == target) &&
1890              "expected target to be either the true or false target");
1891       if (target == branchCondOp.trueTarget())
1892         opBuilder.create<spirv::BranchConditionalOp>(
1893             branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
1894             branchCondOp.getFalseBlockArguments(),
1895             branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
1896             branchCondOp.falseTarget());
1897       else
1898         opBuilder.create<spirv::BranchConditionalOp>(
1899             branchCondOp.getLoc(), branchCondOp.condition(),
1900             branchCondOp.getTrueBlockArguments(), blockArgs,
1901             branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
1902             branchCondOp.getFalseBlock());
1903 
1904       branchCondOp.erase();
1905     } else {
1906       return emitError(unknownLoc, "unimplemented terminator for Phi creation");
1907     }
1908 
1909     LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n");
1910     LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
1911     LLVM_DEBUG(llvm::dbgs() << '\n');
1912   }
1913   blockPhiInfo.clear();
1914 
1915   LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n");
1916   return success();
1917 }
1918 
1919 LogicalResult spirv::Deserializer::structurizeControlFlow() {
1920   LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n");
1921 
1922   while (!blockMergeInfo.empty()) {
1923     Block *headerBlock = blockMergeInfo.begin()->first;
1924     BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
1925 
1926     LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n");
1927     LLVM_DEBUG(headerBlock->print(llvm::dbgs()));
1928 
1929     auto *mergeBlock = mergeInfo.mergeBlock;
1930     assert(mergeBlock && "merge block cannot be nullptr");
1931     if (!mergeBlock->args_empty())
1932       return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
1933     LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n");
1934     LLVM_DEBUG(mergeBlock->print(llvm::dbgs()));
1935 
1936     auto *continueBlock = mergeInfo.continueBlock;
1937     if (continueBlock) {
1938       LLVM_DEBUG(llvm::dbgs()
1939                  << "[cf] continue block " << continueBlock << ":\n");
1940       LLVM_DEBUG(continueBlock->print(llvm::dbgs()));
1941     }
1942     // Erase this case before calling into structurizer, who will update
1943     // blockMergeInfo.
1944     blockMergeInfo.erase(blockMergeInfo.begin());
1945     if (failed(ControlFlowStructurizer::structurize(
1946             mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock,
1947             mergeBlock, continueBlock)))
1948       return failure();
1949   }
1950 
1951   LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n");
1952   return success();
1953 }
1954 
1955 //===----------------------------------------------------------------------===//
1956 // Debug
1957 //===----------------------------------------------------------------------===//
1958 
1959 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
1960   if (!debugLine)
1961     return unknownLoc;
1962 
1963   auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
1964   if (fileName.empty())
1965     fileName = "<unknown>";
1966   return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
1967                              debugLine->col);
1968 }
1969 
1970 LogicalResult
1971 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
1972   // According to SPIR-V spec:
1973   // "This location information applies to the instructions physically
1974   // following this instruction, up to the first occurrence of any of the
1975   // following: the next end of block, the next OpLine instruction, or the next
1976   // OpNoLine instruction."
1977   if (operands.size() != 3)
1978     return emitError(unknownLoc, "OpLine must have 3 operands");
1979   debugLine = DebugLine(operands[0], operands[1], operands[2]);
1980   return success();
1981 }
1982 
1983 LogicalResult spirv::Deserializer::clearDebugLine() {
1984   debugLine = llvm::None;
1985   return success();
1986 }
1987 
1988 LogicalResult
1989 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
1990   if (operands.size() < 2)
1991     return emitError(unknownLoc, "OpString needs at least 2 operands");
1992 
1993   if (!debugInfoMap.lookup(operands[0]).empty())
1994     return emitError(unknownLoc,
1995                      "duplicate debug string found for result <id> ")
1996            << operands[0];
1997 
1998   unsigned wordIndex = 1;
1999   StringRef debugString = decodeStringLiteral(operands, wordIndex);
2000   if (wordIndex != operands.size())
2001     return emitError(unknownLoc,
2002                      "unexpected trailing words in OpString instruction");
2003 
2004   debugInfoMap[operands[0]] = debugString;
2005   return success();
2006 }
2007