1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Deserializer.h"
14
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Location.h"
18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22
23 using namespace mlir;
24
25 #define DEBUG_TYPE "spirv-deserialization"
26
27 //===----------------------------------------------------------------------===//
28 // Utility Functions
29 //===----------------------------------------------------------------------===//
30
31 /// Extracts the opcode from the given first word of a SPIR-V instruction.
extractOpcode(uint32_t word)32 static inline spirv::Opcode extractOpcode(uint32_t word) {
33 return static_cast<spirv::Opcode>(word & 0xffff);
34 }
35
36 //===----------------------------------------------------------------------===//
37 // Instruction
38 //===----------------------------------------------------------------------===//
39
getValue(uint32_t id)40 Value spirv::Deserializer::getValue(uint32_t id) {
41 if (auto constInfo = getConstant(id)) {
42 // Materialize a `spv.Constant` op at every use site.
43 return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
44 constInfo->first);
45 }
46 if (auto varOp = getGlobalVariable(id)) {
47 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
48 unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
49 return addressOfOp.pointer();
50 }
51 if (auto constOp = getSpecConstant(id)) {
52 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
53 unknownLoc, constOp.default_value().getType(),
54 SymbolRefAttr::get(constOp.getOperation()));
55 return referenceOfOp.reference();
56 }
57 if (auto constCompositeOp = getSpecConstantComposite(id)) {
58 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
59 unknownLoc, constCompositeOp.type(),
60 SymbolRefAttr::get(constCompositeOp.getOperation()));
61 return referenceOfOp.reference();
62 }
63 if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
64 return materializeSpecConstantOperation(
65 id, specConstOperationInfo->enclodesOpcode,
66 specConstOperationInfo->resultTypeID,
67 specConstOperationInfo->enclosedOpOperands);
68 }
69 if (auto undef = getUndefType(id)) {
70 return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
71 }
72 return valueMap.lookup(id);
73 }
74
75 LogicalResult
sliceInstruction(spirv::Opcode & opcode,ArrayRef<uint32_t> & operands,Optional<spirv::Opcode> expectedOpcode)76 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
77 ArrayRef<uint32_t> &operands,
78 Optional<spirv::Opcode> expectedOpcode) {
79 auto binarySize = binary.size();
80 if (curOffset >= binarySize) {
81 return emitError(unknownLoc, "expected ")
82 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
83 : "more")
84 << " instruction";
85 }
86
87 // For each instruction, get its word count from the first word to slice it
88 // from the stream properly, and then dispatch to the instruction handler.
89
90 uint32_t wordCount = binary[curOffset] >> 16;
91
92 if (wordCount == 0)
93 return emitError(unknownLoc, "word count cannot be zero");
94
95 uint32_t nextOffset = curOffset + wordCount;
96 if (nextOffset > binarySize)
97 return emitError(unknownLoc, "insufficient words for the last instruction");
98
99 opcode = extractOpcode(binary[curOffset]);
100 operands = binary.slice(curOffset + 1, wordCount - 1);
101 curOffset = nextOffset;
102 return success();
103 }
104
processInstruction(spirv::Opcode opcode,ArrayRef<uint32_t> operands,bool deferInstructions)105 LogicalResult spirv::Deserializer::processInstruction(
106 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
107 LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
108 << spirv::stringifyOpcode(opcode) << "\n");
109
110 // First dispatch all the instructions whose opcode does not correspond to
111 // those that have a direct mirror in the SPIR-V dialect
112 switch (opcode) {
113 case spirv::Opcode::OpCapability:
114 return processCapability(operands);
115 case spirv::Opcode::OpExtension:
116 return processExtension(operands);
117 case spirv::Opcode::OpExtInst:
118 return processExtInst(operands);
119 case spirv::Opcode::OpExtInstImport:
120 return processExtInstImport(operands);
121 case spirv::Opcode::OpMemberName:
122 return processMemberName(operands);
123 case spirv::Opcode::OpMemoryModel:
124 return processMemoryModel(operands);
125 case spirv::Opcode::OpEntryPoint:
126 case spirv::Opcode::OpExecutionMode:
127 if (deferInstructions) {
128 deferredInstructions.emplace_back(opcode, operands);
129 return success();
130 }
131 break;
132 case spirv::Opcode::OpVariable:
133 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
134 return processGlobalVariable(operands);
135 }
136 break;
137 case spirv::Opcode::OpLine:
138 return processDebugLine(operands);
139 case spirv::Opcode::OpNoLine:
140 clearDebugLine();
141 return success();
142 case spirv::Opcode::OpName:
143 return processName(operands);
144 case spirv::Opcode::OpString:
145 return processDebugString(operands);
146 case spirv::Opcode::OpModuleProcessed:
147 case spirv::Opcode::OpSource:
148 case spirv::Opcode::OpSourceContinued:
149 case spirv::Opcode::OpSourceExtension:
150 // TODO: This is debug information embedded in the binary which should be
151 // translated into the spv.module.
152 return success();
153 case spirv::Opcode::OpTypeVoid:
154 case spirv::Opcode::OpTypeBool:
155 case spirv::Opcode::OpTypeInt:
156 case spirv::Opcode::OpTypeFloat:
157 case spirv::Opcode::OpTypeVector:
158 case spirv::Opcode::OpTypeMatrix:
159 case spirv::Opcode::OpTypeArray:
160 case spirv::Opcode::OpTypeFunction:
161 case spirv::Opcode::OpTypeImage:
162 case spirv::Opcode::OpTypeSampledImage:
163 case spirv::Opcode::OpTypeRuntimeArray:
164 case spirv::Opcode::OpTypeStruct:
165 case spirv::Opcode::OpTypePointer:
166 case spirv::Opcode::OpTypeCooperativeMatrixNV:
167 return processType(opcode, operands);
168 case spirv::Opcode::OpTypeForwardPointer:
169 return processTypeForwardPointer(operands);
170 case spirv::Opcode::OpConstant:
171 return processConstant(operands, /*isSpec=*/false);
172 case spirv::Opcode::OpSpecConstant:
173 return processConstant(operands, /*isSpec=*/true);
174 case spirv::Opcode::OpConstantComposite:
175 return processConstantComposite(operands);
176 case spirv::Opcode::OpSpecConstantComposite:
177 return processSpecConstantComposite(operands);
178 case spirv::Opcode::OpSpecConstantOp:
179 return processSpecConstantOperation(operands);
180 case spirv::Opcode::OpConstantTrue:
181 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
182 case spirv::Opcode::OpSpecConstantTrue:
183 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
184 case spirv::Opcode::OpConstantFalse:
185 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
186 case spirv::Opcode::OpSpecConstantFalse:
187 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
188 case spirv::Opcode::OpConstantNull:
189 return processConstantNull(operands);
190 case spirv::Opcode::OpDecorate:
191 return processDecoration(operands);
192 case spirv::Opcode::OpMemberDecorate:
193 return processMemberDecoration(operands);
194 case spirv::Opcode::OpFunction:
195 return processFunction(operands);
196 case spirv::Opcode::OpLabel:
197 return processLabel(operands);
198 case spirv::Opcode::OpBranch:
199 return processBranch(operands);
200 case spirv::Opcode::OpBranchConditional:
201 return processBranchConditional(operands);
202 case spirv::Opcode::OpSelectionMerge:
203 return processSelectionMerge(operands);
204 case spirv::Opcode::OpLoopMerge:
205 return processLoopMerge(operands);
206 case spirv::Opcode::OpPhi:
207 return processPhi(operands);
208 case spirv::Opcode::OpUndef:
209 return processUndef(operands);
210 default:
211 break;
212 }
213 return dispatchToAutogenDeserialization(opcode, operands);
214 }
215
processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,StringRef opName,bool hasResult,unsigned numOperands)216 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
217 ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
218 unsigned numOperands) {
219 SmallVector<Type, 1> resultTypes;
220 uint32_t valueID = 0;
221
222 size_t wordIndex = 0;
223 if (hasResult) {
224 if (wordIndex >= words.size())
225 return emitError(unknownLoc,
226 "expected result type <id> while deserializing for ")
227 << opName;
228
229 // Decode the type <id>
230 auto type = getType(words[wordIndex]);
231 if (!type)
232 return emitError(unknownLoc, "unknown type result <id>: ")
233 << words[wordIndex];
234 resultTypes.push_back(type);
235 ++wordIndex;
236
237 // Decode the result <id>
238 if (wordIndex >= words.size())
239 return emitError(unknownLoc,
240 "expected result <id> while deserializing for ")
241 << opName;
242 valueID = words[wordIndex];
243 ++wordIndex;
244 }
245
246 SmallVector<Value, 4> operands;
247 SmallVector<NamedAttribute, 4> attributes;
248
249 // Decode operands
250 size_t operandIndex = 0;
251 for (; operandIndex < numOperands && wordIndex < words.size();
252 ++operandIndex, ++wordIndex) {
253 auto arg = getValue(words[wordIndex]);
254 if (!arg)
255 return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
256 operands.push_back(arg);
257 }
258 if (operandIndex != numOperands) {
259 return emitError(
260 unknownLoc,
261 "found less operands than expected when deserializing for ")
262 << opName << "; only " << operandIndex << " of " << numOperands
263 << " processed";
264 }
265 if (wordIndex != words.size()) {
266 return emitError(
267 unknownLoc,
268 "found more operands than expected when deserializing for ")
269 << opName << "; only " << wordIndex << " of " << words.size()
270 << " processed";
271 }
272
273 // Attach attributes from decorations
274 if (decorations.count(valueID)) {
275 auto attrs = decorations[valueID].getAttrs();
276 attributes.append(attrs.begin(), attrs.end());
277 }
278
279 // Create the op and update bookkeeping maps
280 Location loc = createFileLineColLoc(opBuilder);
281 OperationState opState(loc, opName);
282 opState.addOperands(operands);
283 if (hasResult)
284 opState.addTypes(resultTypes);
285 opState.addAttributes(attributes);
286 Operation *op = opBuilder.create(opState);
287 if (hasResult)
288 valueMap[valueID] = op->getResult(0);
289
290 if (op->hasTrait<OpTrait::IsTerminator>())
291 clearDebugLine();
292
293 return success();
294 }
295
processUndef(ArrayRef<uint32_t> operands)296 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
297 if (operands.size() != 2) {
298 return emitError(unknownLoc, "OpUndef instruction must have two operands");
299 }
300 auto type = getType(operands[0]);
301 if (!type) {
302 return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
303 }
304 undefMap[operands[1]] = type;
305 return success();
306 }
307
processExtInst(ArrayRef<uint32_t> operands)308 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
309 if (operands.size() < 4) {
310 return emitError(unknownLoc,
311 "OpExtInst must have at least 4 operands, result type "
312 "<id>, result <id>, set <id> and instruction opcode");
313 }
314 if (!extendedInstSets.count(operands[2])) {
315 return emitError(unknownLoc, "undefined set <id> in OpExtInst");
316 }
317 SmallVector<uint32_t, 4> slicedOperands;
318 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
319 slicedOperands.append(std::next(operands.begin(), 4), operands.end());
320 return dispatchToExtensionSetAutogenDeserialization(
321 extendedInstSets[operands[2]], operands[3], slicedOperands);
322 }
323
324 namespace mlir {
325 namespace spirv {
326
327 template <>
328 LogicalResult
processOp(ArrayRef<uint32_t> words)329 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
330 unsigned wordIndex = 0;
331 if (wordIndex >= words.size()) {
332 return emitError(unknownLoc,
333 "missing Execution Model specification in OpEntryPoint");
334 }
335 auto execModel = spirv::ExecutionModelAttr::get(
336 context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
337 if (wordIndex >= words.size()) {
338 return emitError(unknownLoc, "missing <id> in OpEntryPoint");
339 }
340 // Get the function <id>
341 auto fnID = words[wordIndex++];
342 // Get the function name
343 auto fnName = decodeStringLiteral(words, wordIndex);
344 // Verify that the function <id> matches the fnName
345 auto parsedFunc = getFunction(fnID);
346 if (!parsedFunc) {
347 return emitError(unknownLoc, "no function matching <id> ") << fnID;
348 }
349 if (parsedFunc.getName() != fnName) {
350 // The deserializer uses "spirv_fn_<id>" as the function name if the input
351 // SPIR-V blob does not contain a name for it. We should use a more clear
352 // indication for such case rather than relying on naming details.
353 if (!parsedFunc.getName().startswith("spirv_fn_"))
354 return emitError(unknownLoc,
355 "function name mismatch between OpEntryPoint "
356 "and OpFunction with <id> ")
357 << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
358 parsedFunc.setName(fnName);
359 }
360 SmallVector<Attribute, 4> interface;
361 while (wordIndex < words.size()) {
362 auto arg = getGlobalVariable(words[wordIndex]);
363 if (!arg) {
364 return emitError(unknownLoc, "undefined result <id> ")
365 << words[wordIndex] << " while decoding OpEntryPoint";
366 }
367 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
368 wordIndex++;
369 }
370 opBuilder.create<spirv::EntryPointOp>(
371 unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
372 opBuilder.getArrayAttr(interface));
373 return success();
374 }
375
376 template <>
377 LogicalResult
processOp(ArrayRef<uint32_t> words)378 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
379 unsigned wordIndex = 0;
380 if (wordIndex >= words.size()) {
381 return emitError(unknownLoc,
382 "missing function result <id> in OpExecutionMode");
383 }
384 // Get the function <id> to get the name of the function
385 auto fnID = words[wordIndex++];
386 auto fn = getFunction(fnID);
387 if (!fn) {
388 return emitError(unknownLoc, "no function matching <id> ") << fnID;
389 }
390 // Get the Execution mode
391 if (wordIndex >= words.size()) {
392 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
393 }
394 auto execMode = spirv::ExecutionModeAttr::get(
395 context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
396
397 // Get the values
398 SmallVector<Attribute, 4> attrListElems;
399 while (wordIndex < words.size()) {
400 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
401 }
402 auto values = opBuilder.getArrayAttr(attrListElems);
403 opBuilder.create<spirv::ExecutionModeOp>(
404 unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
405 execMode, values);
406 return success();
407 }
408
409 template <>
410 LogicalResult
processOp(ArrayRef<uint32_t> operands)411 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
412 if (operands.size() != 3) {
413 return emitError(
414 unknownLoc,
415 "OpControlBarrier must have execution scope <id>, memory scope <id> "
416 "and memory semantics <id>");
417 }
418
419 SmallVector<IntegerAttr, 3> argAttrs;
420 for (auto operand : operands) {
421 auto argAttr = getConstantInt(operand);
422 if (!argAttr) {
423 return emitError(unknownLoc,
424 "expected 32-bit integer constant from <id> ")
425 << operand << " for OpControlBarrier";
426 }
427 argAttrs.push_back(argAttr);
428 }
429
430 opBuilder.create<spirv::ControlBarrierOp>(
431 unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
432 argAttrs[1].cast<spirv::ScopeAttr>(),
433 argAttrs[2].cast<spirv::MemorySemanticsAttr>());
434
435 return success();
436 }
437
438 template <>
439 LogicalResult
processOp(ArrayRef<uint32_t> operands)440 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
441 if (operands.size() < 3) {
442 return emitError(unknownLoc,
443 "OpFunctionCall must have at least 3 operands");
444 }
445
446 Type resultType = getType(operands[0]);
447 if (!resultType) {
448 return emitError(unknownLoc, "undefined result type from <id> ")
449 << operands[0];
450 }
451
452 // Use null type to mean no result type.
453 if (isVoidType(resultType))
454 resultType = nullptr;
455
456 auto resultID = operands[1];
457 auto functionID = operands[2];
458
459 auto functionName = getFunctionSymbol(functionID);
460
461 SmallVector<Value, 4> arguments;
462 for (auto operand : llvm::drop_begin(operands, 3)) {
463 auto value = getValue(operand);
464 if (!value) {
465 return emitError(unknownLoc, "unknown <id> ")
466 << operand << " used by OpFunctionCall";
467 }
468 arguments.push_back(value);
469 }
470
471 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
472 unknownLoc, resultType,
473 SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
474
475 if (resultType)
476 valueMap[resultID] = opFunctionCall.getResult(0);
477 return success();
478 }
479
480 template <>
481 LogicalResult
processOp(ArrayRef<uint32_t> operands)482 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
483 if (operands.size() != 2) {
484 return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
485 "and memory semantics <id>");
486 }
487
488 SmallVector<IntegerAttr, 2> argAttrs;
489 for (auto operand : operands) {
490 auto argAttr = getConstantInt(operand);
491 if (!argAttr) {
492 return emitError(unknownLoc,
493 "expected 32-bit integer constant from <id> ")
494 << operand << " for OpMemoryBarrier";
495 }
496 argAttrs.push_back(argAttr);
497 }
498
499 opBuilder.create<spirv::MemoryBarrierOp>(
500 unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
501 argAttrs[1].cast<spirv::MemorySemanticsAttr>());
502 return success();
503 }
504
505 template <>
506 LogicalResult
processOp(ArrayRef<uint32_t> words)507 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
508 SmallVector<Type, 1> resultTypes;
509 size_t wordIndex = 0;
510 SmallVector<Value, 4> operands;
511 SmallVector<NamedAttribute, 4> attributes;
512
513 if (wordIndex < words.size()) {
514 auto arg = getValue(words[wordIndex]);
515
516 if (!arg) {
517 return emitError(unknownLoc, "unknown result <id> : ")
518 << words[wordIndex];
519 }
520
521 operands.push_back(arg);
522 wordIndex++;
523 }
524
525 if (wordIndex < words.size()) {
526 auto arg = getValue(words[wordIndex]);
527
528 if (!arg) {
529 return emitError(unknownLoc, "unknown result <id> : ")
530 << words[wordIndex];
531 }
532
533 operands.push_back(arg);
534 wordIndex++;
535 }
536
537 bool isAlignedAttr = false;
538
539 if (wordIndex < words.size()) {
540 auto attrValue = words[wordIndex++];
541 attributes.push_back(opBuilder.getNamedAttr(
542 "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
543 isAlignedAttr = (attrValue == 2);
544 }
545
546 if (isAlignedAttr && wordIndex < words.size()) {
547 attributes.push_back(opBuilder.getNamedAttr(
548 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
549 }
550
551 if (wordIndex < words.size()) {
552 attributes.push_back(opBuilder.getNamedAttr(
553 "source_memory_access",
554 opBuilder.getI32IntegerAttr(words[wordIndex++])));
555 }
556
557 if (wordIndex < words.size()) {
558 attributes.push_back(opBuilder.getNamedAttr(
559 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
560 }
561
562 if (wordIndex != words.size()) {
563 return emitError(unknownLoc,
564 "found more operands than expected when deserializing "
565 "spirv::CopyMemoryOp, only ")
566 << wordIndex << " of " << words.size() << " processed";
567 }
568
569 Location loc = createFileLineColLoc(opBuilder);
570 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
571
572 return success();
573 }
574
575 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
576 // various Deserializer::processOp<...>() specializations.
577 #define GET_DESERIALIZATION_FNS
578 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
579
580 } // namespace spirv
581 } // namespace mlir
582