1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Location.h"
18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "spirv-deserialization"
26 
27 //===----------------------------------------------------------------------===//
28 // Utility Functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Extracts the opcode from the given first word of a SPIR-V instruction.
extractOpcode(uint32_t word)32 static inline spirv::Opcode extractOpcode(uint32_t word) {
33   return static_cast<spirv::Opcode>(word & 0xffff);
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Instruction
38 //===----------------------------------------------------------------------===//
39 
getValue(uint32_t id)40 Value spirv::Deserializer::getValue(uint32_t id) {
41   if (auto constInfo = getConstant(id)) {
42     // Materialize a `spv.Constant` op at every use site.
43     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
44                                                constInfo->first);
45   }
46   if (auto varOp = getGlobalVariable(id)) {
47     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
48         unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
49     return addressOfOp.pointer();
50   }
51   if (auto constOp = getSpecConstant(id)) {
52     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
53         unknownLoc, constOp.default_value().getType(),
54         SymbolRefAttr::get(constOp.getOperation()));
55     return referenceOfOp.reference();
56   }
57   if (auto constCompositeOp = getSpecConstantComposite(id)) {
58     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
59         unknownLoc, constCompositeOp.type(),
60         SymbolRefAttr::get(constCompositeOp.getOperation()));
61     return referenceOfOp.reference();
62   }
63   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
64     return materializeSpecConstantOperation(
65         id, specConstOperationInfo->enclodesOpcode,
66         specConstOperationInfo->resultTypeID,
67         specConstOperationInfo->enclosedOpOperands);
68   }
69   if (auto undef = getUndefType(id)) {
70     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
71   }
72   return valueMap.lookup(id);
73 }
74 
75 LogicalResult
sliceInstruction(spirv::Opcode & opcode,ArrayRef<uint32_t> & operands,Optional<spirv::Opcode> expectedOpcode)76 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
77                                       ArrayRef<uint32_t> &operands,
78                                       Optional<spirv::Opcode> expectedOpcode) {
79   auto binarySize = binary.size();
80   if (curOffset >= binarySize) {
81     return emitError(unknownLoc, "expected ")
82            << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
83                               : "more")
84            << " instruction";
85   }
86 
87   // For each instruction, get its word count from the first word to slice it
88   // from the stream properly, and then dispatch to the instruction handler.
89 
90   uint32_t wordCount = binary[curOffset] >> 16;
91 
92   if (wordCount == 0)
93     return emitError(unknownLoc, "word count cannot be zero");
94 
95   uint32_t nextOffset = curOffset + wordCount;
96   if (nextOffset > binarySize)
97     return emitError(unknownLoc, "insufficient words for the last instruction");
98 
99   opcode = extractOpcode(binary[curOffset]);
100   operands = binary.slice(curOffset + 1, wordCount - 1);
101   curOffset = nextOffset;
102   return success();
103 }
104 
processInstruction(spirv::Opcode opcode,ArrayRef<uint32_t> operands,bool deferInstructions)105 LogicalResult spirv::Deserializer::processInstruction(
106     spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
107   LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
108                                 << spirv::stringifyOpcode(opcode) << "\n");
109 
110   // First dispatch all the instructions whose opcode does not correspond to
111   // those that have a direct mirror in the SPIR-V dialect
112   switch (opcode) {
113   case spirv::Opcode::OpCapability:
114     return processCapability(operands);
115   case spirv::Opcode::OpExtension:
116     return processExtension(operands);
117   case spirv::Opcode::OpExtInst:
118     return processExtInst(operands);
119   case spirv::Opcode::OpExtInstImport:
120     return processExtInstImport(operands);
121   case spirv::Opcode::OpMemberName:
122     return processMemberName(operands);
123   case spirv::Opcode::OpMemoryModel:
124     return processMemoryModel(operands);
125   case spirv::Opcode::OpEntryPoint:
126   case spirv::Opcode::OpExecutionMode:
127     if (deferInstructions) {
128       deferredInstructions.emplace_back(opcode, operands);
129       return success();
130     }
131     break;
132   case spirv::Opcode::OpVariable:
133     if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
134       return processGlobalVariable(operands);
135     }
136     break;
137   case spirv::Opcode::OpLine:
138     return processDebugLine(operands);
139   case spirv::Opcode::OpNoLine:
140     clearDebugLine();
141     return success();
142   case spirv::Opcode::OpName:
143     return processName(operands);
144   case spirv::Opcode::OpString:
145     return processDebugString(operands);
146   case spirv::Opcode::OpModuleProcessed:
147   case spirv::Opcode::OpSource:
148   case spirv::Opcode::OpSourceContinued:
149   case spirv::Opcode::OpSourceExtension:
150     // TODO: This is debug information embedded in the binary which should be
151     // translated into the spv.module.
152     return success();
153   case spirv::Opcode::OpTypeVoid:
154   case spirv::Opcode::OpTypeBool:
155   case spirv::Opcode::OpTypeInt:
156   case spirv::Opcode::OpTypeFloat:
157   case spirv::Opcode::OpTypeVector:
158   case spirv::Opcode::OpTypeMatrix:
159   case spirv::Opcode::OpTypeArray:
160   case spirv::Opcode::OpTypeFunction:
161   case spirv::Opcode::OpTypeImage:
162   case spirv::Opcode::OpTypeSampledImage:
163   case spirv::Opcode::OpTypeRuntimeArray:
164   case spirv::Opcode::OpTypeStruct:
165   case spirv::Opcode::OpTypePointer:
166   case spirv::Opcode::OpTypeCooperativeMatrixNV:
167     return processType(opcode, operands);
168   case spirv::Opcode::OpTypeForwardPointer:
169     return processTypeForwardPointer(operands);
170   case spirv::Opcode::OpConstant:
171     return processConstant(operands, /*isSpec=*/false);
172   case spirv::Opcode::OpSpecConstant:
173     return processConstant(operands, /*isSpec=*/true);
174   case spirv::Opcode::OpConstantComposite:
175     return processConstantComposite(operands);
176   case spirv::Opcode::OpSpecConstantComposite:
177     return processSpecConstantComposite(operands);
178   case spirv::Opcode::OpSpecConstantOp:
179     return processSpecConstantOperation(operands);
180   case spirv::Opcode::OpConstantTrue:
181     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
182   case spirv::Opcode::OpSpecConstantTrue:
183     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
184   case spirv::Opcode::OpConstantFalse:
185     return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
186   case spirv::Opcode::OpSpecConstantFalse:
187     return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
188   case spirv::Opcode::OpConstantNull:
189     return processConstantNull(operands);
190   case spirv::Opcode::OpDecorate:
191     return processDecoration(operands);
192   case spirv::Opcode::OpMemberDecorate:
193     return processMemberDecoration(operands);
194   case spirv::Opcode::OpFunction:
195     return processFunction(operands);
196   case spirv::Opcode::OpLabel:
197     return processLabel(operands);
198   case spirv::Opcode::OpBranch:
199     return processBranch(operands);
200   case spirv::Opcode::OpBranchConditional:
201     return processBranchConditional(operands);
202   case spirv::Opcode::OpSelectionMerge:
203     return processSelectionMerge(operands);
204   case spirv::Opcode::OpLoopMerge:
205     return processLoopMerge(operands);
206   case spirv::Opcode::OpPhi:
207     return processPhi(operands);
208   case spirv::Opcode::OpUndef:
209     return processUndef(operands);
210   default:
211     break;
212   }
213   return dispatchToAutogenDeserialization(opcode, operands);
214 }
215 
processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,StringRef opName,bool hasResult,unsigned numOperands)216 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
217     ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
218     unsigned numOperands) {
219   SmallVector<Type, 1> resultTypes;
220   uint32_t valueID = 0;
221 
222   size_t wordIndex = 0;
223   if (hasResult) {
224     if (wordIndex >= words.size())
225       return emitError(unknownLoc,
226                        "expected result type <id> while deserializing for ")
227              << opName;
228 
229     // Decode the type <id>
230     auto type = getType(words[wordIndex]);
231     if (!type)
232       return emitError(unknownLoc, "unknown type result <id>: ")
233              << words[wordIndex];
234     resultTypes.push_back(type);
235     ++wordIndex;
236 
237     // Decode the result <id>
238     if (wordIndex >= words.size())
239       return emitError(unknownLoc,
240                        "expected result <id> while deserializing for ")
241              << opName;
242     valueID = words[wordIndex];
243     ++wordIndex;
244   }
245 
246   SmallVector<Value, 4> operands;
247   SmallVector<NamedAttribute, 4> attributes;
248 
249   // Decode operands
250   size_t operandIndex = 0;
251   for (; operandIndex < numOperands && wordIndex < words.size();
252        ++operandIndex, ++wordIndex) {
253     auto arg = getValue(words[wordIndex]);
254     if (!arg)
255       return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
256     operands.push_back(arg);
257   }
258   if (operandIndex != numOperands) {
259     return emitError(
260                unknownLoc,
261                "found less operands than expected when deserializing for ")
262            << opName << "; only " << operandIndex << " of " << numOperands
263            << " processed";
264   }
265   if (wordIndex != words.size()) {
266     return emitError(
267                unknownLoc,
268                "found more operands than expected when deserializing for ")
269            << opName << "; only " << wordIndex << " of " << words.size()
270            << " processed";
271   }
272 
273   // Attach attributes from decorations
274   if (decorations.count(valueID)) {
275     auto attrs = decorations[valueID].getAttrs();
276     attributes.append(attrs.begin(), attrs.end());
277   }
278 
279   // Create the op and update bookkeeping maps
280   Location loc = createFileLineColLoc(opBuilder);
281   OperationState opState(loc, opName);
282   opState.addOperands(operands);
283   if (hasResult)
284     opState.addTypes(resultTypes);
285   opState.addAttributes(attributes);
286   Operation *op = opBuilder.create(opState);
287   if (hasResult)
288     valueMap[valueID] = op->getResult(0);
289 
290   if (op->hasTrait<OpTrait::IsTerminator>())
291     clearDebugLine();
292 
293   return success();
294 }
295 
processUndef(ArrayRef<uint32_t> operands)296 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
297   if (operands.size() != 2) {
298     return emitError(unknownLoc, "OpUndef instruction must have two operands");
299   }
300   auto type = getType(operands[0]);
301   if (!type) {
302     return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
303   }
304   undefMap[operands[1]] = type;
305   return success();
306 }
307 
processExtInst(ArrayRef<uint32_t> operands)308 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
309   if (operands.size() < 4) {
310     return emitError(unknownLoc,
311                      "OpExtInst must have at least 4 operands, result type "
312                      "<id>, result <id>, set <id> and instruction opcode");
313   }
314   if (!extendedInstSets.count(operands[2])) {
315     return emitError(unknownLoc, "undefined set <id> in OpExtInst");
316   }
317   SmallVector<uint32_t, 4> slicedOperands;
318   slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
319   slicedOperands.append(std::next(operands.begin(), 4), operands.end());
320   return dispatchToExtensionSetAutogenDeserialization(
321       extendedInstSets[operands[2]], operands[3], slicedOperands);
322 }
323 
324 namespace mlir {
325 namespace spirv {
326 
327 template <>
328 LogicalResult
processOp(ArrayRef<uint32_t> words)329 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
330   unsigned wordIndex = 0;
331   if (wordIndex >= words.size()) {
332     return emitError(unknownLoc,
333                      "missing Execution Model specification in OpEntryPoint");
334   }
335   auto execModel = spirv::ExecutionModelAttr::get(
336       context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
337   if (wordIndex >= words.size()) {
338     return emitError(unknownLoc, "missing <id> in OpEntryPoint");
339   }
340   // Get the function <id>
341   auto fnID = words[wordIndex++];
342   // Get the function name
343   auto fnName = decodeStringLiteral(words, wordIndex);
344   // Verify that the function <id> matches the fnName
345   auto parsedFunc = getFunction(fnID);
346   if (!parsedFunc) {
347     return emitError(unknownLoc, "no function matching <id> ") << fnID;
348   }
349   if (parsedFunc.getName() != fnName) {
350     // The deserializer uses "spirv_fn_<id>" as the function name if the input
351     // SPIR-V blob does not contain a name for it. We should use a more clear
352     // indication for such case rather than relying on naming details.
353     if (!parsedFunc.getName().startswith("spirv_fn_"))
354       return emitError(unknownLoc,
355                        "function name mismatch between OpEntryPoint "
356                        "and OpFunction with <id> ")
357              << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
358     parsedFunc.setName(fnName);
359   }
360   SmallVector<Attribute, 4> interface;
361   while (wordIndex < words.size()) {
362     auto arg = getGlobalVariable(words[wordIndex]);
363     if (!arg) {
364       return emitError(unknownLoc, "undefined result <id> ")
365              << words[wordIndex] << " while decoding OpEntryPoint";
366     }
367     interface.push_back(SymbolRefAttr::get(arg.getOperation()));
368     wordIndex++;
369   }
370   opBuilder.create<spirv::EntryPointOp>(
371       unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
372       opBuilder.getArrayAttr(interface));
373   return success();
374 }
375 
376 template <>
377 LogicalResult
processOp(ArrayRef<uint32_t> words)378 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
379   unsigned wordIndex = 0;
380   if (wordIndex >= words.size()) {
381     return emitError(unknownLoc,
382                      "missing function result <id> in OpExecutionMode");
383   }
384   // Get the function <id> to get the name of the function
385   auto fnID = words[wordIndex++];
386   auto fn = getFunction(fnID);
387   if (!fn) {
388     return emitError(unknownLoc, "no function matching <id> ") << fnID;
389   }
390   // Get the Execution mode
391   if (wordIndex >= words.size()) {
392     return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
393   }
394   auto execMode = spirv::ExecutionModeAttr::get(
395       context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
396 
397   // Get the values
398   SmallVector<Attribute, 4> attrListElems;
399   while (wordIndex < words.size()) {
400     attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
401   }
402   auto values = opBuilder.getArrayAttr(attrListElems);
403   opBuilder.create<spirv::ExecutionModeOp>(
404       unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
405       execMode, values);
406   return success();
407 }
408 
409 template <>
410 LogicalResult
processOp(ArrayRef<uint32_t> operands)411 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
412   if (operands.size() != 3) {
413     return emitError(
414         unknownLoc,
415         "OpControlBarrier must have execution scope <id>, memory scope <id> "
416         "and memory semantics <id>");
417   }
418 
419   SmallVector<IntegerAttr, 3> argAttrs;
420   for (auto operand : operands) {
421     auto argAttr = getConstantInt(operand);
422     if (!argAttr) {
423       return emitError(unknownLoc,
424                        "expected 32-bit integer constant from <id> ")
425              << operand << " for OpControlBarrier";
426     }
427     argAttrs.push_back(argAttr);
428   }
429 
430   opBuilder.create<spirv::ControlBarrierOp>(
431       unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
432       argAttrs[1].cast<spirv::ScopeAttr>(),
433       argAttrs[2].cast<spirv::MemorySemanticsAttr>());
434 
435   return success();
436 }
437 
438 template <>
439 LogicalResult
processOp(ArrayRef<uint32_t> operands)440 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
441   if (operands.size() < 3) {
442     return emitError(unknownLoc,
443                      "OpFunctionCall must have at least 3 operands");
444   }
445 
446   Type resultType = getType(operands[0]);
447   if (!resultType) {
448     return emitError(unknownLoc, "undefined result type from <id> ")
449            << operands[0];
450   }
451 
452   // Use null type to mean no result type.
453   if (isVoidType(resultType))
454     resultType = nullptr;
455 
456   auto resultID = operands[1];
457   auto functionID = operands[2];
458 
459   auto functionName = getFunctionSymbol(functionID);
460 
461   SmallVector<Value, 4> arguments;
462   for (auto operand : llvm::drop_begin(operands, 3)) {
463     auto value = getValue(operand);
464     if (!value) {
465       return emitError(unknownLoc, "unknown <id> ")
466              << operand << " used by OpFunctionCall";
467     }
468     arguments.push_back(value);
469   }
470 
471   auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
472       unknownLoc, resultType,
473       SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
474 
475   if (resultType)
476     valueMap[resultID] = opFunctionCall.getResult(0);
477   return success();
478 }
479 
480 template <>
481 LogicalResult
processOp(ArrayRef<uint32_t> operands)482 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
483   if (operands.size() != 2) {
484     return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
485                                  "and memory semantics <id>");
486   }
487 
488   SmallVector<IntegerAttr, 2> argAttrs;
489   for (auto operand : operands) {
490     auto argAttr = getConstantInt(operand);
491     if (!argAttr) {
492       return emitError(unknownLoc,
493                        "expected 32-bit integer constant from <id> ")
494              << operand << " for OpMemoryBarrier";
495     }
496     argAttrs.push_back(argAttr);
497   }
498 
499   opBuilder.create<spirv::MemoryBarrierOp>(
500       unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
501       argAttrs[1].cast<spirv::MemorySemanticsAttr>());
502   return success();
503 }
504 
505 template <>
506 LogicalResult
processOp(ArrayRef<uint32_t> words)507 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
508   SmallVector<Type, 1> resultTypes;
509   size_t wordIndex = 0;
510   SmallVector<Value, 4> operands;
511   SmallVector<NamedAttribute, 4> attributes;
512 
513   if (wordIndex < words.size()) {
514     auto arg = getValue(words[wordIndex]);
515 
516     if (!arg) {
517       return emitError(unknownLoc, "unknown result <id> : ")
518              << words[wordIndex];
519     }
520 
521     operands.push_back(arg);
522     wordIndex++;
523   }
524 
525   if (wordIndex < words.size()) {
526     auto arg = getValue(words[wordIndex]);
527 
528     if (!arg) {
529       return emitError(unknownLoc, "unknown result <id> : ")
530              << words[wordIndex];
531     }
532 
533     operands.push_back(arg);
534     wordIndex++;
535   }
536 
537   bool isAlignedAttr = false;
538 
539   if (wordIndex < words.size()) {
540     auto attrValue = words[wordIndex++];
541     attributes.push_back(opBuilder.getNamedAttr(
542         "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
543     isAlignedAttr = (attrValue == 2);
544   }
545 
546   if (isAlignedAttr && wordIndex < words.size()) {
547     attributes.push_back(opBuilder.getNamedAttr(
548         "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
549   }
550 
551   if (wordIndex < words.size()) {
552     attributes.push_back(opBuilder.getNamedAttr(
553         "source_memory_access",
554         opBuilder.getI32IntegerAttr(words[wordIndex++])));
555   }
556 
557   if (wordIndex < words.size()) {
558     attributes.push_back(opBuilder.getNamedAttr(
559         "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
560   }
561 
562   if (wordIndex != words.size()) {
563     return emitError(unknownLoc,
564                      "found more operands than expected when deserializing "
565                      "spirv::CopyMemoryOp, only ")
566            << wordIndex << " of " << words.size() << " processed";
567   }
568 
569   Location loc = createFileLineColLoc(opBuilder);
570   opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
571 
572   return success();
573 }
574 
575 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
576 // various Deserializer::processOp<...>() specializations.
577 #define GET_DESERIALIZATION_FNS
578 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
579 
580 } // namespace spirv
581 } // namespace mlir
582