1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Deserializer.h"
14
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32
33 using namespace mlir;
34
35 #define DEBUG_TYPE "spirv-deserialization"
36
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40
41 /// Returns true if the given `block` is a function entry block.
isFnEntryBlock(Block * block)42 static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50
Deserializer(ArrayRef<uint32_t> binary,MLIRContext * context)51 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52 MLIRContext *context)
53 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
54 module(createModuleOp()), opBuilder(module->getRegion())
55 #ifndef NDEBUG
56 ,
57 logger(llvm::dbgs())
58 #endif
59 {
60 }
61
deserialize()62 LogicalResult spirv::Deserializer::deserialize() {
63 LLVM_DEBUG({
64 logger.resetIndent();
65 logger.startLine()
66 << "//+++---------- start deserialization ----------+++//\n";
67 });
68
69 if (failed(processHeader()))
70 return failure();
71
72 spirv::Opcode opcode = spirv::Opcode::OpNop;
73 ArrayRef<uint32_t> operands;
74 auto binarySize = binary.size();
75 while (curOffset < binarySize) {
76 // Slice the next instruction out and populate `opcode` and `operands`.
77 // Internally this also updates `curOffset`.
78 if (failed(sliceInstruction(opcode, operands)))
79 return failure();
80
81 if (failed(processInstruction(opcode, operands)))
82 return failure();
83 }
84
85 assert(curOffset == binarySize &&
86 "deserializer should never index beyond the binary end");
87
88 for (auto &deferred : deferredInstructions) {
89 if (failed(processInstruction(deferred.first, deferred.second, false))) {
90 return failure();
91 }
92 }
93
94 attachVCETriple();
95
96 LLVM_DEBUG(logger.startLine()
97 << "//+++-------- completed deserialization --------+++//\n");
98 return success();
99 }
100
collect()101 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
102 return std::move(module);
103 }
104
105 //===----------------------------------------------------------------------===//
106 // Module structure
107 //===----------------------------------------------------------------------===//
108
createModuleOp()109 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
110 OpBuilder builder(context);
111 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112 spirv::ModuleOp::build(builder, state);
113 return cast<spirv::ModuleOp>(Operation::create(state));
114 }
115
processHeader()116 LogicalResult spirv::Deserializer::processHeader() {
117 if (binary.size() < spirv::kHeaderWordCount)
118 return emitError(unknownLoc,
119 "SPIR-V binary module must have a 5-word header");
120
121 if (binary[0] != spirv::kMagicNumber)
122 return emitError(unknownLoc, "incorrect magic number");
123
124 // Version number bytes: 0 | major number | minor number | 0
125 uint32_t majorVersion = (binary[1] << 8) >> 24;
126 uint32_t minorVersion = (binary[1] << 16) >> 24;
127 if (majorVersion == 1) {
128 switch (minorVersion) {
129 #define MIN_VERSION_CASE(v) \
130 case v: \
131 version = spirv::Version::V_1_##v; \
132 break
133
134 MIN_VERSION_CASE(0);
135 MIN_VERSION_CASE(1);
136 MIN_VERSION_CASE(2);
137 MIN_VERSION_CASE(3);
138 MIN_VERSION_CASE(4);
139 MIN_VERSION_CASE(5);
140 #undef MIN_VERSION_CASE
141 default:
142 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
143 << minorVersion;
144 }
145 } else {
146 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
147 << majorVersion;
148 }
149
150 // TODO: generator number, bound, schema
151 curOffset = spirv::kHeaderWordCount;
152 return success();
153 }
154
155 LogicalResult
processCapability(ArrayRef<uint32_t> operands)156 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
157 if (operands.size() != 1)
158 return emitError(unknownLoc, "OpMemoryModel must have one parameter");
159
160 auto cap = spirv::symbolizeCapability(operands[0]);
161 if (!cap)
162 return emitError(unknownLoc, "unknown capability: ") << operands[0];
163
164 capabilities.insert(*cap);
165 return success();
166 }
167
processExtension(ArrayRef<uint32_t> words)168 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
169 if (words.empty()) {
170 return emitError(
171 unknownLoc,
172 "OpExtension must have a literal string for the extension name");
173 }
174
175 unsigned wordIndex = 0;
176 StringRef extName = decodeStringLiteral(words, wordIndex);
177 if (wordIndex != words.size())
178 return emitError(unknownLoc,
179 "unexpected trailing words in OpExtension instruction");
180 auto ext = spirv::symbolizeExtension(extName);
181 if (!ext)
182 return emitError(unknownLoc, "unknown extension: ") << extName;
183
184 extensions.insert(*ext);
185 return success();
186 }
187
188 LogicalResult
processExtInstImport(ArrayRef<uint32_t> words)189 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
190 if (words.size() < 2) {
191 return emitError(unknownLoc,
192 "OpExtInstImport must have a result <id> and a literal "
193 "string for the extended instruction set name");
194 }
195
196 unsigned wordIndex = 1;
197 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
198 if (wordIndex != words.size()) {
199 return emitError(unknownLoc,
200 "unexpected trailing words in OpExtInstImport");
201 }
202 return success();
203 }
204
attachVCETriple()205 void spirv::Deserializer::attachVCETriple() {
206 (*module)->setAttr(
207 spirv::ModuleOp::getVCETripleAttrName(),
208 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
209 extensions.getArrayRef(), context));
210 }
211
212 LogicalResult
processMemoryModel(ArrayRef<uint32_t> operands)213 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
214 if (operands.size() != 2)
215 return emitError(unknownLoc, "OpMemoryModel must have two operands");
216
217 (*module)->setAttr(
218 "addressing_model",
219 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
220 (*module)->setAttr(
221 "memory_model",
222 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
223
224 return success();
225 }
226
processDecoration(ArrayRef<uint32_t> words)227 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
228 // TODO: This function should also be auto-generated. For now, since only a
229 // few decorations are processed/handled in a meaningful manner, going with a
230 // manual implementation.
231 if (words.size() < 2) {
232 return emitError(
233 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
234 }
235 auto decorationName =
236 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
237 if (decorationName.empty()) {
238 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
239 }
240 auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
241 auto symbol = opBuilder.getStringAttr(attrName);
242 switch (static_cast<spirv::Decoration>(words[1])) {
243 case spirv::Decoration::DescriptorSet:
244 case spirv::Decoration::Binding:
245 if (words.size() != 3) {
246 return emitError(unknownLoc, "OpDecorate with ")
247 << decorationName << " needs a single integer literal";
248 }
249 decorations[words[0]].set(
250 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
251 break;
252 case spirv::Decoration::BuiltIn:
253 if (words.size() != 3) {
254 return emitError(unknownLoc, "OpDecorate with ")
255 << decorationName << " needs a single integer literal";
256 }
257 decorations[words[0]].set(
258 symbol, opBuilder.getStringAttr(
259 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
260 break;
261 case spirv::Decoration::ArrayStride:
262 if (words.size() != 3) {
263 return emitError(unknownLoc, "OpDecorate with ")
264 << decorationName << " needs a single integer literal";
265 }
266 typeDecorations[words[0]] = words[2];
267 break;
268 case spirv::Decoration::Aliased:
269 case spirv::Decoration::Block:
270 case spirv::Decoration::BufferBlock:
271 case spirv::Decoration::Flat:
272 case spirv::Decoration::NonReadable:
273 case spirv::Decoration::NonWritable:
274 case spirv::Decoration::NoPerspective:
275 case spirv::Decoration::Restrict:
276 case spirv::Decoration::RelaxedPrecision:
277 if (words.size() != 2) {
278 return emitError(unknownLoc, "OpDecoration with ")
279 << decorationName << "needs a single target <id>";
280 }
281 // Block decoration does not affect spv.struct type, but is still stored for
282 // verification.
283 // TODO: Update StructType to contain this information since
284 // it is needed for many validation rules.
285 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
286 break;
287 case spirv::Decoration::Location:
288 case spirv::Decoration::SpecId:
289 if (words.size() != 3) {
290 return emitError(unknownLoc, "OpDecoration with ")
291 << decorationName << "needs a single integer literal";
292 }
293 decorations[words[0]].set(
294 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
295 break;
296 default:
297 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
298 }
299 return success();
300 }
301
302 LogicalResult
processMemberDecoration(ArrayRef<uint32_t> words)303 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
304 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
305 if (words.size() < 3) {
306 return emitError(unknownLoc,
307 "OpMemberDecorate must have at least 3 operands");
308 }
309
310 auto decoration = static_cast<spirv::Decoration>(words[2]);
311 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
312 return emitError(unknownLoc,
313 " missing offset specification in OpMemberDecorate with "
314 "Offset decoration");
315 }
316 ArrayRef<uint32_t> decorationOperands;
317 if (words.size() > 3) {
318 decorationOperands = words.slice(3);
319 }
320 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
321 return success();
322 }
323
processMemberName(ArrayRef<uint32_t> words)324 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
325 if (words.size() < 3) {
326 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
327 }
328 unsigned wordIndex = 2;
329 auto name = decodeStringLiteral(words, wordIndex);
330 if (wordIndex != words.size()) {
331 return emitError(unknownLoc,
332 "unexpected trailing words in OpMemberName instruction");
333 }
334 memberNameMap[words[0]][words[1]] = name;
335 return success();
336 }
337
338 LogicalResult
processFunction(ArrayRef<uint32_t> operands)339 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
340 if (curFunction) {
341 return emitError(unknownLoc, "found function inside function");
342 }
343
344 // Get the result type
345 if (operands.size() != 4) {
346 return emitError(unknownLoc, "OpFunction must have 4 parameters");
347 }
348 Type resultType = getType(operands[0]);
349 if (!resultType) {
350 return emitError(unknownLoc, "undefined result type from <id> ")
351 << operands[0];
352 }
353
354 uint32_t fnID = operands[1];
355 if (funcMap.count(fnID)) {
356 return emitError(unknownLoc, "duplicate function definition/declaration");
357 }
358
359 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
360 if (!fnControl) {
361 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
362 }
363
364 Type fnType = getType(operands[3]);
365 if (!fnType || !fnType.isa<FunctionType>()) {
366 return emitError(unknownLoc, "unknown function type from <id> ")
367 << operands[3];
368 }
369 auto functionType = fnType.cast<FunctionType>();
370
371 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
372 (functionType.getNumResults() == 1 &&
373 functionType.getResult(0) != resultType)) {
374 return emitError(unknownLoc, "mismatch in function type ")
375 << functionType << " and return type " << resultType << " specified";
376 }
377
378 std::string fnName = getFunctionSymbol(fnID);
379 auto funcOp = opBuilder.create<spirv::FuncOp>(
380 unknownLoc, fnName, functionType, fnControl.value());
381 curFunction = funcMap[fnID] = funcOp;
382 auto *entryBlock = funcOp.addEntryBlock();
383 LLVM_DEBUG({
384 logger.startLine()
385 << "//===-------------------------------------------===//\n";
386 logger.startLine() << "[fn] name: " << fnName << "\n";
387 logger.startLine() << "[fn] type: " << fnType << "\n";
388 logger.startLine() << "[fn] ID: " << fnID << "\n";
389 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
390 logger.indent();
391 });
392
393 // Parse the op argument instructions
394 if (functionType.getNumInputs()) {
395 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
396 auto argType = functionType.getInput(i);
397 spirv::Opcode opcode = spirv::Opcode::OpNop;
398 ArrayRef<uint32_t> operands;
399 if (failed(sliceInstruction(opcode, operands,
400 spirv::Opcode::OpFunctionParameter))) {
401 return failure();
402 }
403 if (opcode != spirv::Opcode::OpFunctionParameter) {
404 return emitError(
405 unknownLoc,
406 "missing OpFunctionParameter instruction for argument ")
407 << i;
408 }
409 if (operands.size() != 2) {
410 return emitError(
411 unknownLoc,
412 "expected result type and result <id> for OpFunctionParameter");
413 }
414 auto argDefinedType = getType(operands[0]);
415 if (!argDefinedType || argDefinedType != argType) {
416 return emitError(unknownLoc,
417 "mismatch in argument type between function type "
418 "definition ")
419 << functionType << " and argument type definition "
420 << argDefinedType << " at argument " << i;
421 }
422 if (getValue(operands[1])) {
423 return emitError(unknownLoc, "duplicate definition of result <id> ")
424 << operands[1];
425 }
426 auto argValue = funcOp.getArgument(i);
427 valueMap[operands[1]] = argValue;
428 }
429 }
430
431 // RAII guard to reset the insertion point to the module's region after
432 // deserializing the body of this function.
433 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
434
435 spirv::Opcode opcode = spirv::Opcode::OpNop;
436 ArrayRef<uint32_t> instOperands;
437
438 // Special handling for the entry block. We need to make sure it starts with
439 // an OpLabel instruction. The entry block takes the same parameters as the
440 // function. All other blocks do not take any parameter. We have already
441 // created the entry block, here we need to register it to the correct label
442 // <id>.
443 if (failed(sliceInstruction(opcode, instOperands,
444 spirv::Opcode::OpFunctionEnd))) {
445 return failure();
446 }
447 if (opcode == spirv::Opcode::OpFunctionEnd) {
448 return processFunctionEnd(instOperands);
449 }
450 if (opcode != spirv::Opcode::OpLabel) {
451 return emitError(unknownLoc, "a basic block must start with OpLabel");
452 }
453 if (instOperands.size() != 1) {
454 return emitError(unknownLoc, "OpLabel should only have result <id>");
455 }
456 blockMap[instOperands[0]] = entryBlock;
457 if (failed(processLabel(instOperands))) {
458 return failure();
459 }
460
461 // Then process all the other instructions in the function until we hit
462 // OpFunctionEnd.
463 while (succeeded(sliceInstruction(opcode, instOperands,
464 spirv::Opcode::OpFunctionEnd)) &&
465 opcode != spirv::Opcode::OpFunctionEnd) {
466 if (failed(processInstruction(opcode, instOperands))) {
467 return failure();
468 }
469 }
470 if (opcode != spirv::Opcode::OpFunctionEnd) {
471 return failure();
472 }
473
474 return processFunctionEnd(instOperands);
475 }
476
477 LogicalResult
processFunctionEnd(ArrayRef<uint32_t> operands)478 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
479 // Process OpFunctionEnd.
480 if (!operands.empty()) {
481 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
482 }
483
484 // Wire up block arguments from OpPhi instructions.
485 // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops.
486 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
487 return failure();
488 }
489
490 curBlock = nullptr;
491 curFunction = llvm::None;
492
493 LLVM_DEBUG({
494 logger.unindent();
495 logger.startLine()
496 << "//===-------------------------------------------===//\n";
497 });
498 return success();
499 }
500
501 Optional<std::pair<Attribute, Type>>
getConstant(uint32_t id)502 spirv::Deserializer::getConstant(uint32_t id) {
503 auto constIt = constantMap.find(id);
504 if (constIt == constantMap.end())
505 return llvm::None;
506 return constIt->getSecond();
507 }
508
509 Optional<spirv::SpecConstOperationMaterializationInfo>
getSpecConstantOperation(uint32_t id)510 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
511 auto constIt = specConstOperationMap.find(id);
512 if (constIt == specConstOperationMap.end())
513 return llvm::None;
514 return constIt->getSecond();
515 }
516
getFunctionSymbol(uint32_t id)517 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
518 auto funcName = nameMap.lookup(id).str();
519 if (funcName.empty()) {
520 funcName = "spirv_fn_" + std::to_string(id);
521 }
522 return funcName;
523 }
524
getSpecConstantSymbol(uint32_t id)525 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
526 auto constName = nameMap.lookup(id).str();
527 if (constName.empty()) {
528 constName = "spirv_spec_const_" + std::to_string(id);
529 }
530 return constName;
531 }
532
533 spirv::SpecConstantOp
createSpecConstant(Location loc,uint32_t resultID,Attribute defaultValue)534 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
535 Attribute defaultValue) {
536 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
537 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
538 defaultValue);
539 if (decorations.count(resultID)) {
540 for (auto attr : decorations[resultID].getAttrs())
541 op->setAttr(attr.getName(), attr.getValue());
542 }
543 specConstMap[resultID] = op;
544 return op;
545 }
546
547 LogicalResult
processGlobalVariable(ArrayRef<uint32_t> operands)548 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
549 unsigned wordIndex = 0;
550 if (operands.size() < 3) {
551 return emitError(
552 unknownLoc,
553 "OpVariable needs at least 3 operands, type, <id> and storage class");
554 }
555
556 // Result Type.
557 auto type = getType(operands[wordIndex]);
558 if (!type) {
559 return emitError(unknownLoc, "unknown result type <id> : ")
560 << operands[wordIndex];
561 }
562 auto ptrType = type.dyn_cast<spirv::PointerType>();
563 if (!ptrType) {
564 return emitError(unknownLoc,
565 "expected a result type <id> to be a spv.ptr, found : ")
566 << type;
567 }
568 wordIndex++;
569
570 // Result <id>.
571 auto variableID = operands[wordIndex];
572 auto variableName = nameMap.lookup(variableID).str();
573 if (variableName.empty()) {
574 variableName = "spirv_var_" + std::to_string(variableID);
575 }
576 wordIndex++;
577
578 // Storage class.
579 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
580 if (ptrType.getStorageClass() != storageClass) {
581 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
582 << type << " and that specified in OpVariable instruction : "
583 << stringifyStorageClass(storageClass);
584 }
585 wordIndex++;
586
587 // Initializer.
588 FlatSymbolRefAttr initializer = nullptr;
589 if (wordIndex < operands.size()) {
590 auto initializerOp = getGlobalVariable(operands[wordIndex]);
591 if (!initializerOp) {
592 return emitError(unknownLoc, "unknown <id> ")
593 << operands[wordIndex] << "used as initializer";
594 }
595 wordIndex++;
596 initializer = SymbolRefAttr::get(initializerOp.getOperation());
597 }
598 if (wordIndex != operands.size()) {
599 return emitError(unknownLoc,
600 "found more operands than expected when deserializing "
601 "OpVariable instruction, only ")
602 << wordIndex << " of " << operands.size() << " processed";
603 }
604 auto loc = createFileLineColLoc(opBuilder);
605 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
606 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
607 initializer);
608
609 // Decorations.
610 if (decorations.count(variableID)) {
611 for (auto attr : decorations[variableID].getAttrs())
612 varOp->setAttr(attr.getName(), attr.getValue());
613 }
614 globalVariableMap[variableID] = varOp;
615 return success();
616 }
617
getConstantInt(uint32_t id)618 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
619 auto constInfo = getConstant(id);
620 if (!constInfo) {
621 return nullptr;
622 }
623 return constInfo->first.dyn_cast<IntegerAttr>();
624 }
625
processName(ArrayRef<uint32_t> operands)626 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
627 if (operands.size() < 2) {
628 return emitError(unknownLoc, "OpName needs at least 2 operands");
629 }
630 if (!nameMap.lookup(operands[0]).empty()) {
631 return emitError(unknownLoc, "duplicate name found for result <id> ")
632 << operands[0];
633 }
634 unsigned wordIndex = 1;
635 StringRef name = decodeStringLiteral(operands, wordIndex);
636 if (wordIndex != operands.size()) {
637 return emitError(unknownLoc,
638 "unexpected trailing words in OpName instruction");
639 }
640 nameMap[operands[0]] = name;
641 return success();
642 }
643
644 //===----------------------------------------------------------------------===//
645 // Type
646 //===----------------------------------------------------------------------===//
647
processType(spirv::Opcode opcode,ArrayRef<uint32_t> operands)648 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
649 ArrayRef<uint32_t> operands) {
650 if (operands.empty()) {
651 return emitError(unknownLoc, "type instruction with opcode ")
652 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
653 }
654
655 /// TODO: Types might be forward declared in some instructions and need to be
656 /// handled appropriately.
657 if (typeMap.count(operands[0])) {
658 return emitError(unknownLoc, "duplicate definition for result <id> ")
659 << operands[0];
660 }
661
662 switch (opcode) {
663 case spirv::Opcode::OpTypeVoid:
664 if (operands.size() != 1)
665 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
666 typeMap[operands[0]] = opBuilder.getNoneType();
667 break;
668 case spirv::Opcode::OpTypeBool:
669 if (operands.size() != 1)
670 return emitError(unknownLoc, "OpTypeBool must have no parameters");
671 typeMap[operands[0]] = opBuilder.getI1Type();
672 break;
673 case spirv::Opcode::OpTypeInt: {
674 if (operands.size() != 3)
675 return emitError(
676 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
677
678 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
679 // to preserve or validate.
680 // 0 indicates unsigned, or no signedness semantics
681 // 1 indicates signed semantics."
682 //
683 // So we cannot differentiate signless and unsigned integers; always use
684 // signless semantics for such cases.
685 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
686 : IntegerType::SignednessSemantics::Signless;
687 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
688 } break;
689 case spirv::Opcode::OpTypeFloat: {
690 if (operands.size() != 2)
691 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
692
693 Type floatTy;
694 switch (operands[1]) {
695 case 16:
696 floatTy = opBuilder.getF16Type();
697 break;
698 case 32:
699 floatTy = opBuilder.getF32Type();
700 break;
701 case 64:
702 floatTy = opBuilder.getF64Type();
703 break;
704 default:
705 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
706 << operands[1];
707 }
708 typeMap[operands[0]] = floatTy;
709 } break;
710 case spirv::Opcode::OpTypeVector: {
711 if (operands.size() != 3) {
712 return emitError(
713 unknownLoc,
714 "OpTypeVector must have element type and count parameters");
715 }
716 Type elementTy = getType(operands[1]);
717 if (!elementTy) {
718 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
719 << operands[1];
720 }
721 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
722 } break;
723 case spirv::Opcode::OpTypePointer: {
724 return processOpTypePointer(operands);
725 } break;
726 case spirv::Opcode::OpTypeArray:
727 return processArrayType(operands);
728 case spirv::Opcode::OpTypeCooperativeMatrixNV:
729 return processCooperativeMatrixType(operands);
730 case spirv::Opcode::OpTypeFunction:
731 return processFunctionType(operands);
732 case spirv::Opcode::OpTypeImage:
733 return processImageType(operands);
734 case spirv::Opcode::OpTypeSampledImage:
735 return processSampledImageType(operands);
736 case spirv::Opcode::OpTypeRuntimeArray:
737 return processRuntimeArrayType(operands);
738 case spirv::Opcode::OpTypeStruct:
739 return processStructType(operands);
740 case spirv::Opcode::OpTypeMatrix:
741 return processMatrixType(operands);
742 default:
743 return emitError(unknownLoc, "unhandled type instruction");
744 }
745 return success();
746 }
747
748 LogicalResult
processOpTypePointer(ArrayRef<uint32_t> operands)749 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
750 if (operands.size() != 3)
751 return emitError(unknownLoc, "OpTypePointer must have two parameters");
752
753 auto pointeeType = getType(operands[2]);
754 if (!pointeeType)
755 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
756 << operands[2];
757
758 uint32_t typePointerID = operands[0];
759 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
760 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
761
762 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
763 deferredStructIt != std::end(deferredStructTypesInfos);) {
764 for (auto *unresolvedMemberIt =
765 std::begin(deferredStructIt->unresolvedMemberTypes);
766 unresolvedMemberIt !=
767 std::end(deferredStructIt->unresolvedMemberTypes);) {
768 if (unresolvedMemberIt->first == typePointerID) {
769 // The newly constructed pointer type can resolve one of the
770 // deferred struct type members; update the memberTypes list and
771 // clean the unresolvedMemberTypes list accordingly.
772 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
773 typeMap[typePointerID];
774 unresolvedMemberIt =
775 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
776 } else {
777 ++unresolvedMemberIt;
778 }
779 }
780
781 if (deferredStructIt->unresolvedMemberTypes.empty()) {
782 // All deferred struct type members are now resolved, set the struct body.
783 auto structType = deferredStructIt->deferredStructType;
784
785 assert(structType && "expected a spirv::StructType");
786 assert(structType.isIdentified() && "expected an indentified struct");
787
788 if (failed(structType.trySetBody(
789 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
790 deferredStructIt->memberDecorationsInfo)))
791 return failure();
792
793 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
794 } else {
795 ++deferredStructIt;
796 }
797 }
798
799 return success();
800 }
801
802 LogicalResult
processArrayType(ArrayRef<uint32_t> operands)803 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
804 if (operands.size() != 3) {
805 return emitError(unknownLoc,
806 "OpTypeArray must have element type and count parameters");
807 }
808
809 Type elementTy = getType(operands[1]);
810 if (!elementTy) {
811 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
812 << operands[1];
813 }
814
815 unsigned count = 0;
816 // TODO: The count can also come frome a specialization constant.
817 auto countInfo = getConstant(operands[2]);
818 if (!countInfo) {
819 return emitError(unknownLoc, "OpTypeArray count <id> ")
820 << operands[2] << "can only come from normal constant right now";
821 }
822
823 if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
824 count = intVal.getValue().getZExtValue();
825 } else {
826 return emitError(unknownLoc, "OpTypeArray count must come from a "
827 "scalar integer constant instruction");
828 }
829
830 typeMap[operands[0]] = spirv::ArrayType::get(
831 elementTy, count, typeDecorations.lookup(operands[0]));
832 return success();
833 }
834
835 LogicalResult
processFunctionType(ArrayRef<uint32_t> operands)836 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
837 assert(!operands.empty() && "No operands for processing function type");
838 if (operands.size() == 1) {
839 return emitError(unknownLoc, "missing return type for OpTypeFunction");
840 }
841 auto returnType = getType(operands[1]);
842 if (!returnType) {
843 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
844 }
845 SmallVector<Type, 1> argTypes;
846 for (size_t i = 2, e = operands.size(); i < e; ++i) {
847 auto ty = getType(operands[i]);
848 if (!ty) {
849 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
850 }
851 argTypes.push_back(ty);
852 }
853 ArrayRef<Type> returnTypes;
854 if (!isVoidType(returnType)) {
855 returnTypes = llvm::makeArrayRef(returnType);
856 }
857 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
858 return success();
859 }
860
861 LogicalResult
processCooperativeMatrixType(ArrayRef<uint32_t> operands)862 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
863 if (operands.size() != 5) {
864 return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
865 "type and row x column parameters");
866 }
867
868 Type elementTy = getType(operands[1]);
869 if (!elementTy) {
870 return emitError(unknownLoc,
871 "OpTypeCooperativeMatrix references undefined <id> ")
872 << operands[1];
873 }
874
875 auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
876 if (!scope) {
877 return emitError(unknownLoc,
878 "OpTypeCooperativeMatrix references undefined scope <id> ")
879 << operands[2];
880 }
881
882 unsigned rows = getConstantInt(operands[3]).getInt();
883 unsigned columns = getConstantInt(operands[4]).getInt();
884
885 typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
886 elementTy, scope.value(), rows, columns);
887 return success();
888 }
889
890 LogicalResult
processRuntimeArrayType(ArrayRef<uint32_t> operands)891 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
892 if (operands.size() != 2) {
893 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
894 }
895 Type memberType = getType(operands[1]);
896 if (!memberType) {
897 return emitError(unknownLoc,
898 "OpTypeRuntimeArray references undefined <id> ")
899 << operands[1];
900 }
901 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
902 memberType, typeDecorations.lookup(operands[0]));
903 return success();
904 }
905
906 LogicalResult
processStructType(ArrayRef<uint32_t> operands)907 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
908 // TODO: Find a way to handle identified structs when debug info is stripped.
909
910 if (operands.empty()) {
911 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
912 }
913
914 if (operands.size() == 1) {
915 // Handle empty struct.
916 typeMap[operands[0]] =
917 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
918 return success();
919 }
920
921 // First element is operand ID, second element is member index in the struct.
922 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
923 SmallVector<Type, 4> memberTypes;
924
925 for (auto op : llvm::drop_begin(operands, 1)) {
926 Type memberType = getType(op);
927 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
928
929 if (!memberType && !typeForwardPtr)
930 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
931 << op;
932
933 if (!memberType)
934 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
935
936 memberTypes.push_back(memberType);
937 }
938
939 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
940 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
941 if (memberDecorationMap.count(operands[0])) {
942 auto &allMemberDecorations = memberDecorationMap[operands[0]];
943 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
944 if (allMemberDecorations.count(memberIndex)) {
945 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
946 // Check for offset.
947 if (memberDecoration.first == spirv::Decoration::Offset) {
948 // If offset info is empty, resize to the number of members;
949 if (offsetInfo.empty()) {
950 offsetInfo.resize(memberTypes.size());
951 }
952 offsetInfo[memberIndex] = memberDecoration.second[0];
953 } else {
954 if (!memberDecoration.second.empty()) {
955 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
956 memberDecoration.first,
957 memberDecoration.second[0]);
958 } else {
959 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
960 memberDecoration.first, 0);
961 }
962 }
963 }
964 }
965 }
966 }
967
968 uint32_t structID = operands[0];
969 std::string structIdentifier = nameMap.lookup(structID).str();
970
971 if (structIdentifier.empty()) {
972 assert(unresolvedMemberTypes.empty() &&
973 "didn't expect unresolved member types");
974 typeMap[structID] =
975 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
976 } else {
977 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
978 typeMap[structID] = structTy;
979
980 if (!unresolvedMemberTypes.empty())
981 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
982 memberTypes, offsetInfo,
983 memberDecorationsInfo});
984 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
985 memberDecorationsInfo)))
986 return failure();
987 }
988
989 // TODO: Update StructType to have member name as attribute as
990 // well.
991 return success();
992 }
993
994 LogicalResult
processMatrixType(ArrayRef<uint32_t> operands)995 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
996 if (operands.size() != 3) {
997 // Three operands are needed: result_id, column_type, and column_count
998 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
999 " (result_id, column_type, and column_count)");
1000 }
1001 // Matrix columns must be of vector type
1002 Type elementTy = getType(operands[1]);
1003 if (!elementTy) {
1004 return emitError(unknownLoc,
1005 "OpTypeMatrix references undefined column type.")
1006 << operands[1];
1007 }
1008
1009 uint32_t colsCount = operands[2];
1010 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1011 return success();
1012 }
1013
1014 LogicalResult
processTypeForwardPointer(ArrayRef<uint32_t> operands)1015 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1016 if (operands.size() != 2)
1017 return emitError(unknownLoc,
1018 "OpTypeForwardPointer instruction must have two operands");
1019
1020 typeForwardPointerIDs.insert(operands[0]);
1021 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1022 // instruction that defines the actual type.
1023
1024 return success();
1025 }
1026
1027 LogicalResult
processImageType(ArrayRef<uint32_t> operands)1028 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1029 // TODO: Add support for Access Qualifier.
1030 if (operands.size() != 8)
1031 return emitError(
1032 unknownLoc,
1033 "OpTypeImage with non-eight operands are not supported yet");
1034
1035 Type elementTy = getType(operands[1]);
1036 if (!elementTy)
1037 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1038 << operands[1];
1039
1040 auto dim = spirv::symbolizeDim(operands[2]);
1041 if (!dim)
1042 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1043 << operands[2];
1044
1045 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1046 if (!depthInfo)
1047 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1048 << operands[3];
1049
1050 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1051 if (!arrayedInfo)
1052 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1053 << operands[4];
1054
1055 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1056 if (!samplingInfo)
1057 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1058
1059 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1060 if (!samplerUseInfo)
1061 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1062 << operands[6];
1063
1064 auto format = spirv::symbolizeImageFormat(operands[7]);
1065 if (!format)
1066 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1067 << operands[7];
1068
1069 typeMap[operands[0]] = spirv::ImageType::get(
1070 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1071 samplingInfo.value(), samplerUseInfo.value(), format.value());
1072 return success();
1073 }
1074
1075 LogicalResult
processSampledImageType(ArrayRef<uint32_t> operands)1076 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1077 if (operands.size() != 2)
1078 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1079
1080 Type elementTy = getType(operands[1]);
1081 if (!elementTy)
1082 return emitError(unknownLoc,
1083 "OpTypeSampledImage references undefined <id>: ")
1084 << operands[1];
1085
1086 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1087 return success();
1088 }
1089
1090 //===----------------------------------------------------------------------===//
1091 // Constant
1092 //===----------------------------------------------------------------------===//
1093
processConstant(ArrayRef<uint32_t> operands,bool isSpec)1094 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1095 bool isSpec) {
1096 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1097
1098 if (operands.size() < 2) {
1099 return emitError(unknownLoc)
1100 << opname << " must have type <id> and result <id>";
1101 }
1102 if (operands.size() < 3) {
1103 return emitError(unknownLoc)
1104 << opname << " must have at least 1 more parameter";
1105 }
1106
1107 Type resultType = getType(operands[0]);
1108 if (!resultType) {
1109 return emitError(unknownLoc, "undefined result type from <id> ")
1110 << operands[0];
1111 }
1112
1113 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1114 if (bitwidth == 64) {
1115 if (operands.size() == 4) {
1116 return success();
1117 }
1118 return emitError(unknownLoc)
1119 << opname << " should have 2 parameters for 64-bit values";
1120 }
1121 if (bitwidth <= 32) {
1122 if (operands.size() == 3) {
1123 return success();
1124 }
1125
1126 return emitError(unknownLoc)
1127 << opname
1128 << " should have 1 parameter for values with no more than 32 bits";
1129 }
1130 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1131 << bitwidth;
1132 };
1133
1134 auto resultID = operands[1];
1135
1136 if (auto intType = resultType.dyn_cast<IntegerType>()) {
1137 auto bitwidth = intType.getWidth();
1138 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1139 return failure();
1140 }
1141
1142 APInt value;
1143 if (bitwidth == 64) {
1144 // 64-bit integers are represented with two SPIR-V words. According to
1145 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1146 // literal’s low-order words appear first."
1147 struct DoubleWord {
1148 uint32_t word1;
1149 uint32_t word2;
1150 } words = {operands[2], operands[3]};
1151 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1152 } else if (bitwidth <= 32) {
1153 value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1154 }
1155
1156 auto attr = opBuilder.getIntegerAttr(intType, value);
1157
1158 if (isSpec) {
1159 createSpecConstant(unknownLoc, resultID, attr);
1160 } else {
1161 // For normal constants, we just record the attribute (and its type) for
1162 // later materialization at use sites.
1163 constantMap.try_emplace(resultID, attr, intType);
1164 }
1165
1166 return success();
1167 }
1168
1169 if (auto floatType = resultType.dyn_cast<FloatType>()) {
1170 auto bitwidth = floatType.getWidth();
1171 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1172 return failure();
1173 }
1174
1175 APFloat value(0.f);
1176 if (floatType.isF64()) {
1177 // Double values are represented with two SPIR-V words. According to
1178 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1179 // literal’s low-order words appear first."
1180 struct DoubleWord {
1181 uint32_t word1;
1182 uint32_t word2;
1183 } words = {operands[2], operands[3]};
1184 value = APFloat(llvm::bit_cast<double>(words));
1185 } else if (floatType.isF32()) {
1186 value = APFloat(llvm::bit_cast<float>(operands[2]));
1187 } else if (floatType.isF16()) {
1188 APInt data(16, operands[2]);
1189 value = APFloat(APFloat::IEEEhalf(), data);
1190 }
1191
1192 auto attr = opBuilder.getFloatAttr(floatType, value);
1193 if (isSpec) {
1194 createSpecConstant(unknownLoc, resultID, attr);
1195 } else {
1196 // For normal constants, we just record the attribute (and its type) for
1197 // later materialization at use sites.
1198 constantMap.try_emplace(resultID, attr, floatType);
1199 }
1200
1201 return success();
1202 }
1203
1204 return emitError(unknownLoc, "OpConstant can only generate values of "
1205 "scalar integer or floating-point type");
1206 }
1207
processConstantBool(bool isTrue,ArrayRef<uint32_t> operands,bool isSpec)1208 LogicalResult spirv::Deserializer::processConstantBool(
1209 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1210 if (operands.size() != 2) {
1211 return emitError(unknownLoc, "Op")
1212 << (isSpec ? "Spec" : "") << "Constant"
1213 << (isTrue ? "True" : "False")
1214 << " must have type <id> and result <id>";
1215 }
1216
1217 auto attr = opBuilder.getBoolAttr(isTrue);
1218 auto resultID = operands[1];
1219 if (isSpec) {
1220 createSpecConstant(unknownLoc, resultID, attr);
1221 } else {
1222 // For normal constants, we just record the attribute (and its type) for
1223 // later materialization at use sites.
1224 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1225 }
1226
1227 return success();
1228 }
1229
1230 LogicalResult
processConstantComposite(ArrayRef<uint32_t> operands)1231 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1232 if (operands.size() < 2) {
1233 return emitError(unknownLoc,
1234 "OpConstantComposite must have type <id> and result <id>");
1235 }
1236 if (operands.size() < 3) {
1237 return emitError(unknownLoc,
1238 "OpConstantComposite must have at least 1 parameter");
1239 }
1240
1241 Type resultType = getType(operands[0]);
1242 if (!resultType) {
1243 return emitError(unknownLoc, "undefined result type from <id> ")
1244 << operands[0];
1245 }
1246
1247 SmallVector<Attribute, 4> elements;
1248 elements.reserve(operands.size() - 2);
1249 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1250 auto elementInfo = getConstant(operands[i]);
1251 if (!elementInfo) {
1252 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1253 << operands[i] << " must come from a normal constant";
1254 }
1255 elements.push_back(elementInfo->first);
1256 }
1257
1258 auto resultID = operands[1];
1259 if (auto vectorType = resultType.dyn_cast<VectorType>()) {
1260 auto attr = DenseElementsAttr::get(vectorType, elements);
1261 // For normal constants, we just record the attribute (and its type) for
1262 // later materialization at use sites.
1263 constantMap.try_emplace(resultID, attr, resultType);
1264 } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
1265 auto attr = opBuilder.getArrayAttr(elements);
1266 constantMap.try_emplace(resultID, attr, resultType);
1267 } else {
1268 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1269 << resultType;
1270 }
1271
1272 return success();
1273 }
1274
1275 LogicalResult
processSpecConstantComposite(ArrayRef<uint32_t> operands)1276 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1277 if (operands.size() < 2) {
1278 return emitError(unknownLoc,
1279 "OpConstantComposite must have type <id> and result <id>");
1280 }
1281 if (operands.size() < 3) {
1282 return emitError(unknownLoc,
1283 "OpConstantComposite must have at least 1 parameter");
1284 }
1285
1286 Type resultType = getType(operands[0]);
1287 if (!resultType) {
1288 return emitError(unknownLoc, "undefined result type from <id> ")
1289 << operands[0];
1290 }
1291
1292 auto resultID = operands[1];
1293 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1294
1295 SmallVector<Attribute, 4> elements;
1296 elements.reserve(operands.size() - 2);
1297 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1298 auto elementInfo = getSpecConstant(operands[i]);
1299 elements.push_back(SymbolRefAttr::get(elementInfo));
1300 }
1301
1302 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1303 unknownLoc, TypeAttr::get(resultType), symName,
1304 opBuilder.getArrayAttr(elements));
1305 specConstCompositeMap[resultID] = op;
1306
1307 return success();
1308 }
1309
1310 LogicalResult
processSpecConstantOperation(ArrayRef<uint32_t> operands)1311 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1312 if (operands.size() < 3)
1313 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1314 "result <id>, and operand opcode");
1315
1316 uint32_t resultTypeID = operands[0];
1317
1318 if (!getType(resultTypeID))
1319 return emitError(unknownLoc, "undefined result type from <id> ")
1320 << resultTypeID;
1321
1322 uint32_t resultID = operands[1];
1323 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1324 auto emplaceResult = specConstOperationMap.try_emplace(
1325 resultID,
1326 SpecConstOperationMaterializationInfo{
1327 enclosedOpcode, resultTypeID,
1328 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1329
1330 if (!emplaceResult.second)
1331 return emitError(unknownLoc, "value with <id>: ")
1332 << resultID << " is probably defined before.";
1333
1334 return success();
1335 }
1336
materializeSpecConstantOperation(uint32_t resultID,spirv::Opcode enclosedOpcode,uint32_t resultTypeID,ArrayRef<uint32_t> enclosedOpOperands)1337 Value spirv::Deserializer::materializeSpecConstantOperation(
1338 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1339 ArrayRef<uint32_t> enclosedOpOperands) {
1340
1341 Type resultType = getType(resultTypeID);
1342
1343 // Instructions wrapped by OpSpecConstantOp need an ID for their
1344 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1345 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1346 // ID in that map is assigned to the result of the enclosed instruction. Note
1347 // that there is no need to update this fake ID since we only need to
1348 // reference the created Value for the enclosed op from the spv::YieldOp
1349 // created later in this method (both of which are the only values in their
1350 // region: the SpecConstantOperation's region). If we encounter another
1351 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1352 // previous Value assigned to it isn't visible in the current scope anyway.
1353 DenseMap<uint32_t, Value> newValueMap;
1354 llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
1355 newValueMap);
1356 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1357
1358 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1359 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1360 enclosedOpResultTypeAndOperands.push_back(fakeID);
1361 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1362 enclosedOpOperands.end());
1363
1364 // Process enclosed instruction before creating the enclosing
1365 // specConstantOperation (and its region). This way, references to constants,
1366 // global variables, and spec constants will be materialized outside the new
1367 // op's region. For more info, see Deserializer::getValue's implementation.
1368 if (failed(
1369 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1370 return Value();
1371
1372 // Since the enclosed op is emitted in the current block, split it in a
1373 // separate new block.
1374 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1375
1376 auto loc = createFileLineColLoc(opBuilder);
1377 auto specConstOperationOp =
1378 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1379
1380 Region &body = specConstOperationOp.body();
1381 // Move the new block into SpecConstantOperation's body.
1382 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1383 Region::iterator(enclosedBlock));
1384 Block &block = body.back();
1385
1386 // RAII guard to reset the insertion point to the module's region after
1387 // deserializing the body of the specConstantOperation.
1388 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1389 opBuilder.setInsertionPointToEnd(&block);
1390
1391 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1392 return specConstOperationOp.getResult();
1393 }
1394
1395 LogicalResult
processConstantNull(ArrayRef<uint32_t> operands)1396 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1397 if (operands.size() != 2) {
1398 return emitError(unknownLoc,
1399 "OpConstantNull must have type <id> and result <id>");
1400 }
1401
1402 Type resultType = getType(operands[0]);
1403 if (!resultType) {
1404 return emitError(unknownLoc, "undefined result type from <id> ")
1405 << operands[0];
1406 }
1407
1408 auto resultID = operands[1];
1409 if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
1410 auto attr = opBuilder.getZeroAttr(resultType);
1411 // For normal constants, we just record the attribute (and its type) for
1412 // later materialization at use sites.
1413 constantMap.try_emplace(resultID, attr, resultType);
1414 return success();
1415 }
1416
1417 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1418 << resultType;
1419 }
1420
1421 //===----------------------------------------------------------------------===//
1422 // Control flow
1423 //===----------------------------------------------------------------------===//
1424
getOrCreateBlock(uint32_t id)1425 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1426 if (auto *block = getBlock(id)) {
1427 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
1428 << " @ " << block << "\n");
1429 return block;
1430 }
1431
1432 // We don't know where this block will be placed finally (in a
1433 // spv.mlir.selection or spv.mlir.loop or function). Create it into the
1434 // function for now and sort out the proper place later.
1435 auto *block = curFunction->addBlock();
1436 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
1437 << " @ " << block << "\n");
1438 return blockMap[id] = block;
1439 }
1440
processBranch(ArrayRef<uint32_t> operands)1441 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1442 if (!curBlock) {
1443 return emitError(unknownLoc, "OpBranch must appear inside a block");
1444 }
1445
1446 if (operands.size() != 1) {
1447 return emitError(unknownLoc, "OpBranch must take exactly one target label");
1448 }
1449
1450 auto *target = getOrCreateBlock(operands[0]);
1451 auto loc = createFileLineColLoc(opBuilder);
1452 // The preceding instruction for the OpBranch instruction could be an
1453 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1454 // the same OpLine information.
1455 opBuilder.create<spirv::BranchOp>(loc, target);
1456
1457 clearDebugLine();
1458 return success();
1459 }
1460
1461 LogicalResult
processBranchConditional(ArrayRef<uint32_t> operands)1462 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1463 if (!curBlock) {
1464 return emitError(unknownLoc,
1465 "OpBranchConditional must appear inside a block");
1466 }
1467
1468 if (operands.size() != 3 && operands.size() != 5) {
1469 return emitError(unknownLoc,
1470 "OpBranchConditional must have condition, true label, "
1471 "false label, and optionally two branch weights");
1472 }
1473
1474 auto condition = getValue(operands[0]);
1475 auto *trueBlock = getOrCreateBlock(operands[1]);
1476 auto *falseBlock = getOrCreateBlock(operands[2]);
1477
1478 Optional<std::pair<uint32_t, uint32_t>> weights;
1479 if (operands.size() == 5) {
1480 weights = std::make_pair(operands[3], operands[4]);
1481 }
1482 // The preceding instruction for the OpBranchConditional instruction could be
1483 // an OpSelectionMerge instruction, in this case they will have the same
1484 // OpLine information.
1485 auto loc = createFileLineColLoc(opBuilder);
1486 opBuilder.create<spirv::BranchConditionalOp>(
1487 loc, condition, trueBlock,
1488 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1489 /*falseArguments=*/ArrayRef<Value>(), weights);
1490
1491 clearDebugLine();
1492 return success();
1493 }
1494
processLabel(ArrayRef<uint32_t> operands)1495 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1496 if (!curFunction) {
1497 return emitError(unknownLoc, "OpLabel must appear inside a function");
1498 }
1499
1500 if (operands.size() != 1) {
1501 return emitError(unknownLoc, "OpLabel should only have result <id>");
1502 }
1503
1504 auto labelID = operands[0];
1505 // We may have forward declared this block.
1506 auto *block = getOrCreateBlock(labelID);
1507 LLVM_DEBUG(logger.startLine()
1508 << "[block] populating block " << block << "\n");
1509 // If we have seen this block, make sure it was just a forward declaration.
1510 assert(block->empty() && "re-deserialize the same block!");
1511
1512 opBuilder.setInsertionPointToStart(block);
1513 blockMap[labelID] = curBlock = block;
1514
1515 return success();
1516 }
1517
1518 LogicalResult
processSelectionMerge(ArrayRef<uint32_t> operands)1519 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1520 if (!curBlock) {
1521 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1522 }
1523
1524 if (operands.size() < 2) {
1525 return emitError(
1526 unknownLoc,
1527 "OpSelectionMerge must specify merge target and selection control");
1528 }
1529
1530 auto *mergeBlock = getOrCreateBlock(operands[0]);
1531 auto loc = createFileLineColLoc(opBuilder);
1532 auto selectionControl = operands[1];
1533
1534 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1535 .second) {
1536 return emitError(
1537 unknownLoc,
1538 "a block cannot have more than one OpSelectionMerge instruction");
1539 }
1540
1541 return success();
1542 }
1543
1544 LogicalResult
processLoopMerge(ArrayRef<uint32_t> operands)1545 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1546 if (!curBlock) {
1547 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1548 }
1549
1550 if (operands.size() < 3) {
1551 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1552 "continue target and loop control");
1553 }
1554
1555 auto *mergeBlock = getOrCreateBlock(operands[0]);
1556 auto *continueBlock = getOrCreateBlock(operands[1]);
1557 auto loc = createFileLineColLoc(opBuilder);
1558 uint32_t loopControl = operands[2];
1559
1560 if (!blockMergeInfo
1561 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1562 .second) {
1563 return emitError(
1564 unknownLoc,
1565 "a block cannot have more than one OpLoopMerge instruction");
1566 }
1567
1568 return success();
1569 }
1570
processPhi(ArrayRef<uint32_t> operands)1571 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1572 if (!curBlock) {
1573 return emitError(unknownLoc, "OpPhi must appear in a block");
1574 }
1575
1576 if (operands.size() < 4) {
1577 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1578 "and variable-parent pairs");
1579 }
1580
1581 // Create a block argument for this OpPhi instruction.
1582 Type blockArgType = getType(operands[0]);
1583 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1584 valueMap[operands[1]] = blockArg;
1585 LLVM_DEBUG(logger.startLine()
1586 << "[phi] created block argument " << blockArg
1587 << " id = " << operands[1] << " of type " << blockArgType << "\n");
1588
1589 // For each (value, predecessor) pair, insert the value to the predecessor's
1590 // blockPhiInfo entry so later we can fix the block argument there.
1591 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1592 uint32_t value = operands[i];
1593 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1594 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1595 blockPhiInfo[predecessorTargetPair].push_back(value);
1596 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
1597 << " with arg id = " << value << "\n");
1598 }
1599
1600 return success();
1601 }
1602
1603 namespace {
1604 /// A class for putting all blocks in a structured selection/loop in a
1605 /// spv.mlir.selection/spv.mlir.loop op.
1606 class ControlFlowStructurizer {
1607 public:
1608 #ifndef NDEBUG
ControlFlowStructurizer(Location loc,uint32_t control,spirv::BlockMergeInfoMap & mergeInfo,Block * header,Block * merge,Block * cont,llvm::ScopedPrinter & logger)1609 ControlFlowStructurizer(Location loc, uint32_t control,
1610 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1611 Block *merge, Block *cont,
1612 llvm::ScopedPrinter &logger)
1613 : location(loc), control(control), blockMergeInfo(mergeInfo),
1614 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1615 logger(logger) {}
1616 #else
1617 ControlFlowStructurizer(Location loc, uint32_t control,
1618 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1619 Block *merge, Block *cont)
1620 : location(loc), control(control), blockMergeInfo(mergeInfo),
1621 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1622 #endif
1623
1624 /// Structurizes the loop at the given `headerBlock`.
1625 ///
1626 /// This method will create an spv.mlir.loop op in the `mergeBlock` and move
1627 /// all blocks in the structured loop into the spv.mlir.loop's region. All
1628 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1629 /// method will also update `mergeInfo` by remapping all blocks inside to the
1630 /// newly cloned ones inside structured control flow op's regions.
1631 LogicalResult structurize();
1632
1633 private:
1634 /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`.
1635 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1636
1637 /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`.
1638 spirv::LoopOp createLoopOp(uint32_t loopControl);
1639
1640 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1641 void collectBlocksInConstruct();
1642
1643 Location location;
1644 uint32_t control;
1645
1646 spirv::BlockMergeInfoMap &blockMergeInfo;
1647
1648 Block *headerBlock;
1649 Block *mergeBlock;
1650 Block *continueBlock; // nullptr for spv.mlir.selection
1651
1652 SetVector<Block *> constructBlocks;
1653
1654 #ifndef NDEBUG
1655 /// A logger used to emit information during the deserialzation process.
1656 llvm::ScopedPrinter &logger;
1657 #endif
1658 };
1659 } // namespace
1660
1661 spirv::SelectionOp
createSelectionOp(uint32_t selectionControl)1662 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1663 // Create a builder and set the insertion point to the beginning of the
1664 // merge block so that the newly created SelectionOp will be inserted there.
1665 OpBuilder builder(&mergeBlock->front());
1666
1667 auto control = static_cast<spirv::SelectionControl>(selectionControl);
1668 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1669 selectionOp.addMergeBlock();
1670
1671 return selectionOp;
1672 }
1673
createLoopOp(uint32_t loopControl)1674 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1675 // Create a builder and set the insertion point to the beginning of the
1676 // merge block so that the newly created LoopOp will be inserted there.
1677 OpBuilder builder(&mergeBlock->front());
1678
1679 auto control = static_cast<spirv::LoopControl>(loopControl);
1680 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1681 loopOp.addEntryAndMergeBlock();
1682
1683 return loopOp;
1684 }
1685
collectBlocksInConstruct()1686 void ControlFlowStructurizer::collectBlocksInConstruct() {
1687 assert(constructBlocks.empty() && "expected empty constructBlocks");
1688
1689 // Put the header block in the work list first.
1690 constructBlocks.insert(headerBlock);
1691
1692 // For each item in the work list, add its successors excluding the merge
1693 // block.
1694 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1695 for (auto *successor : constructBlocks[i]->getSuccessors())
1696 if (successor != mergeBlock)
1697 constructBlocks.insert(successor);
1698 }
1699 }
1700
structurize()1701 LogicalResult ControlFlowStructurizer::structurize() {
1702 Operation *op = nullptr;
1703 bool isLoop = continueBlock != nullptr;
1704 if (isLoop) {
1705 if (auto loopOp = createLoopOp(control))
1706 op = loopOp.getOperation();
1707 } else {
1708 if (auto selectionOp = createSelectionOp(control))
1709 op = selectionOp.getOperation();
1710 }
1711 if (!op)
1712 return failure();
1713 Region &body = op->getRegion(0);
1714
1715 BlockAndValueMapping mapper;
1716 // All references to the old merge block should be directed to the
1717 // selection/loop merge block in the SelectionOp/LoopOp's region.
1718 mapper.map(mergeBlock, &body.back());
1719
1720 collectBlocksInConstruct();
1721
1722 // We've identified all blocks belonging to the selection/loop's region. Now
1723 // need to "move" them into the selection/loop. Instead of really moving the
1724 // blocks, in the following we copy them and remap all values and branches.
1725 // This is because:
1726 // * Inserting a block into a region requires the block not in any region
1727 // before. But selections/loops can nest so we can create selection/loop ops
1728 // in a nested manner, which means some blocks may already be in a
1729 // selection/loop region when to be moved again.
1730 // * It's much trickier to fix up the branches into and out of the loop's
1731 // region: we need to treat not-moved blocks and moved blocks differently:
1732 // Not-moved blocks jumping to the loop header block need to jump to the
1733 // merge point containing the new loop op but not the loop continue block's
1734 // back edge. Moved blocks jumping out of the loop need to jump to the
1735 // merge block inside the loop region but not other not-moved blocks.
1736 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1737 // logic.
1738
1739 // Create a corresponding block in the SelectionOp/LoopOp's region for each
1740 // block in this loop construct.
1741 OpBuilder builder(body);
1742 for (auto *block : constructBlocks) {
1743 // Create a block and insert it before the selection/loop merge block in the
1744 // SelectionOp/LoopOp's region.
1745 auto *newBlock = builder.createBlock(&body.back());
1746 mapper.map(block, newBlock);
1747 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
1748 << " from block " << block << "\n");
1749 if (!isFnEntryBlock(block)) {
1750 for (BlockArgument blockArg : block->getArguments()) {
1751 auto newArg =
1752 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1753 mapper.map(blockArg, newArg);
1754 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
1755 << blockArg << " to " << newArg << "\n");
1756 }
1757 } else {
1758 LLVM_DEBUG(logger.startLine()
1759 << "[cf] block " << block << " is a function entry block\n");
1760 }
1761
1762 for (auto &op : *block)
1763 newBlock->push_back(op.clone(mapper));
1764 }
1765
1766 // Go through all ops and remap the operands.
1767 auto remapOperands = [&](Operation *op) {
1768 for (auto &operand : op->getOpOperands())
1769 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1770 operand.set(mappedOp);
1771 for (auto &succOp : op->getBlockOperands())
1772 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1773 succOp.set(mappedOp);
1774 };
1775 for (auto &block : body)
1776 block.walk(remapOperands);
1777
1778 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1779 // the selection/loop construct into its region. Next we need to fix the
1780 // connections between this new SelectionOp/LoopOp with existing blocks.
1781
1782 // All existing incoming branches should go to the merge block, where the
1783 // SelectionOp/LoopOp resides right now.
1784 headerBlock->replaceAllUsesWith(mergeBlock);
1785
1786 LLVM_DEBUG({
1787 logger.startLine() << "[cf] after cloning and fixing references:\n";
1788 headerBlock->getParentOp()->print(logger.getOStream());
1789 logger.startLine() << "\n";
1790 });
1791
1792 if (isLoop) {
1793 if (!mergeBlock->args_empty()) {
1794 return mergeBlock->getParentOp()->emitError(
1795 "OpPhi in loop merge block unsupported");
1796 }
1797
1798 // The loop header block may have block arguments. Since now we place the
1799 // loop op inside the old merge block, we need to make sure the old merge
1800 // block has the same block argument list.
1801 for (BlockArgument blockArg : headerBlock->getArguments())
1802 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
1803
1804 // If the loop header block has block arguments, make sure the spv.Branch op
1805 // matches.
1806 SmallVector<Value, 4> blockArgs;
1807 if (!headerBlock->args_empty())
1808 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1809
1810 // The loop entry block should have a unconditional branch jumping to the
1811 // loop header block.
1812 builder.setInsertionPointToEnd(&body.front());
1813 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1814 ArrayRef<Value>(blockArgs));
1815 }
1816
1817 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1818 // cleaned up.
1819 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
1820 // First we need to drop all operands' references inside all blocks. This is
1821 // needed because we can have blocks referencing SSA values from one another.
1822 for (auto *block : constructBlocks)
1823 block->dropAllReferences();
1824
1825 // Check that whether some op in the to-be-erased blocks still has uses. Those
1826 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
1827 // region. We cannot handle such cases given that once a value is sinked into
1828 // the SelectionOp/LoopOp's region, there is no escape for it:
1829 // SelectionOp/LooOp does not support yield values right now.
1830 for (auto *block : constructBlocks) {
1831 for (Operation &op : *block)
1832 if (!op.use_empty())
1833 return op.emitOpError(
1834 "failed control flow structurization: it has uses outside of the "
1835 "enclosing selection/loop construct");
1836 }
1837
1838 // Then erase all old blocks.
1839 for (auto *block : constructBlocks) {
1840 // We've cloned all blocks belonging to this construct into the structured
1841 // control flow op's region. Among these blocks, some may compose another
1842 // selection/loop. If so, they will be recorded within blockMergeInfo.
1843 // We need to update the pointers there to the newly remapped ones so we can
1844 // continue structurizing them later.
1845 // TODO: The asserts in the following assumes input SPIR-V blob forms
1846 // correctly nested selection/loop constructs. We should relax this and
1847 // support error cases better.
1848 auto it = blockMergeInfo.find(block);
1849 if (it != blockMergeInfo.end()) {
1850 // Use the original location for nested selection/loop ops.
1851 Location loc = it->second.loc;
1852
1853 Block *newHeader = mapper.lookupOrNull(block);
1854 if (!newHeader)
1855 return emitError(loc, "failed control flow structurization: nested "
1856 "loop header block should be remapped!");
1857
1858 Block *newContinue = it->second.continueBlock;
1859 if (newContinue) {
1860 newContinue = mapper.lookupOrNull(newContinue);
1861 if (!newContinue)
1862 return emitError(loc, "failed control flow structurization: nested "
1863 "loop continue block should be remapped!");
1864 }
1865
1866 Block *newMerge = it->second.mergeBlock;
1867 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
1868 newMerge = mappedTo;
1869
1870 // The iterator should be erased before adding a new entry into
1871 // blockMergeInfo to avoid iterator invalidation.
1872 blockMergeInfo.erase(it);
1873 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
1874 newContinue);
1875 }
1876
1877 // The structured selection/loop's entry block does not have arguments.
1878 // If the function's header block is also part of the structured control
1879 // flow, we cannot just simply erase it because it may contain arguments
1880 // matching the function signature and used by the cloned blocks.
1881 if (isFnEntryBlock(block)) {
1882 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
1883 << " to only contain a spv.Branch op\n");
1884 // Still keep the function entry block for the potential block arguments,
1885 // but replace all ops inside with a branch to the merge block.
1886 block->clear();
1887 builder.setInsertionPointToEnd(block);
1888 builder.create<spirv::BranchOp>(location, mergeBlock);
1889 } else {
1890 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
1891 block->erase();
1892 }
1893 }
1894
1895 LLVM_DEBUG(logger.startLine()
1896 << "[cf] after structurizing construct with header block "
1897 << headerBlock << ":\n"
1898 << *op << "\n");
1899
1900 return success();
1901 }
1902
wireUpBlockArgument()1903 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
1904 LLVM_DEBUG({
1905 logger.startLine()
1906 << "//----- [phi] start wiring up block arguments -----//\n";
1907 logger.indent();
1908 });
1909
1910 OpBuilder::InsertionGuard guard(opBuilder);
1911
1912 for (const auto &info : blockPhiInfo) {
1913 Block *block = info.first.first;
1914 Block *target = info.first.second;
1915 const BlockPhiInfo &phiInfo = info.second;
1916 LLVM_DEBUG({
1917 logger.startLine() << "[phi] block " << block << "\n";
1918 logger.startLine() << "[phi] before creating block argument:\n";
1919 block->getParentOp()->print(logger.getOStream());
1920 logger.startLine() << "\n";
1921 });
1922
1923 // Set insertion point to before this block's terminator early because we
1924 // may materialize ops via getValue() call.
1925 auto *op = block->getTerminator();
1926 opBuilder.setInsertionPoint(op);
1927
1928 SmallVector<Value, 4> blockArgs;
1929 blockArgs.reserve(phiInfo.size());
1930 for (uint32_t valueId : phiInfo) {
1931 if (Value value = getValue(valueId)) {
1932 blockArgs.push_back(value);
1933 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
1934 << " id = " << valueId << "\n");
1935 } else {
1936 return emitError(unknownLoc, "OpPhi references undefined value!");
1937 }
1938 }
1939
1940 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
1941 // Replace the previous branch op with a new one with block arguments.
1942 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
1943 blockArgs);
1944 branchOp.erase();
1945 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
1946 assert((branchCondOp.getTrueBlock() == target ||
1947 branchCondOp.getFalseBlock() == target) &&
1948 "expected target to be either the true or false target");
1949 if (target == branchCondOp.trueTarget())
1950 opBuilder.create<spirv::BranchConditionalOp>(
1951 branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
1952 branchCondOp.getFalseBlockArguments(),
1953 branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
1954 branchCondOp.falseTarget());
1955 else
1956 opBuilder.create<spirv::BranchConditionalOp>(
1957 branchCondOp.getLoc(), branchCondOp.condition(),
1958 branchCondOp.getTrueBlockArguments(), blockArgs,
1959 branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
1960 branchCondOp.getFalseBlock());
1961
1962 branchCondOp.erase();
1963 } else {
1964 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
1965 }
1966
1967 LLVM_DEBUG({
1968 logger.startLine() << "[phi] after creating block argument:\n";
1969 block->getParentOp()->print(logger.getOStream());
1970 logger.startLine() << "\n";
1971 });
1972 }
1973 blockPhiInfo.clear();
1974
1975 LLVM_DEBUG({
1976 logger.unindent();
1977 logger.startLine()
1978 << "//--- [phi] completed wiring up block arguments ---//\n";
1979 });
1980 return success();
1981 }
1982
structurizeControlFlow()1983 LogicalResult spirv::Deserializer::structurizeControlFlow() {
1984 LLVM_DEBUG({
1985 logger.startLine()
1986 << "//----- [cf] start structurizing control flow -----//\n";
1987 logger.indent();
1988 });
1989
1990 while (!blockMergeInfo.empty()) {
1991 Block *headerBlock = blockMergeInfo.begin()->first;
1992 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
1993
1994 LLVM_DEBUG({
1995 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
1996 headerBlock->print(logger.getOStream());
1997 logger.startLine() << "\n";
1998 });
1999
2000 auto *mergeBlock = mergeInfo.mergeBlock;
2001 assert(mergeBlock && "merge block cannot be nullptr");
2002 if (!mergeBlock->args_empty())
2003 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2004 LLVM_DEBUG({
2005 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2006 mergeBlock->print(logger.getOStream());
2007 logger.startLine() << "\n";
2008 });
2009
2010 auto *continueBlock = mergeInfo.continueBlock;
2011 LLVM_DEBUG(if (continueBlock) {
2012 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2013 continueBlock->print(logger.getOStream());
2014 logger.startLine() << "\n";
2015 });
2016 // Erase this case before calling into structurizer, who will update
2017 // blockMergeInfo.
2018 blockMergeInfo.erase(blockMergeInfo.begin());
2019 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2020 blockMergeInfo, headerBlock,
2021 mergeBlock, continueBlock
2022 #ifndef NDEBUG
2023 ,
2024 logger
2025 #endif
2026 );
2027 if (failed(structurizer.structurize()))
2028 return failure();
2029 }
2030
2031 LLVM_DEBUG({
2032 logger.unindent();
2033 logger.startLine()
2034 << "//--- [cf] completed structurizing control flow ---//\n";
2035 });
2036 return success();
2037 }
2038
2039 //===----------------------------------------------------------------------===//
2040 // Debug
2041 //===----------------------------------------------------------------------===//
2042
createFileLineColLoc(OpBuilder opBuilder)2043 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2044 if (!debugLine)
2045 return unknownLoc;
2046
2047 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2048 if (fileName.empty())
2049 fileName = "<unknown>";
2050 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2051 debugLine->column);
2052 }
2053
2054 LogicalResult
processDebugLine(ArrayRef<uint32_t> operands)2055 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2056 // According to SPIR-V spec:
2057 // "This location information applies to the instructions physically
2058 // following this instruction, up to the first occurrence of any of the
2059 // following: the next end of block, the next OpLine instruction, or the next
2060 // OpNoLine instruction."
2061 if (operands.size() != 3)
2062 return emitError(unknownLoc, "OpLine must have 3 operands");
2063 debugLine = DebugLine{operands[0], operands[1], operands[2]};
2064 return success();
2065 }
2066
clearDebugLine()2067 void spirv::Deserializer::clearDebugLine() { debugLine = llvm::None; }
2068
2069 LogicalResult
processDebugString(ArrayRef<uint32_t> operands)2070 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2071 if (operands.size() < 2)
2072 return emitError(unknownLoc, "OpString needs at least 2 operands");
2073
2074 if (!debugInfoMap.lookup(operands[0]).empty())
2075 return emitError(unknownLoc,
2076 "duplicate debug string found for result <id> ")
2077 << operands[0];
2078
2079 unsigned wordIndex = 1;
2080 StringRef debugString = decodeStringLiteral(operands, wordIndex);
2081 if (wordIndex != operands.size())
2082 return emitError(unknownLoc,
2083 "unexpected trailing words in OpString instruction");
2084
2085 debugInfoMap[operands[0]] = debugString;
2086 return success();
2087 }
2088