1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Location.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/Debug.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "spirv-deserialization"
25 
26 //===----------------------------------------------------------------------===//
27 // Utility Functions
28 //===----------------------------------------------------------------------===//
29 
30 /// Extracts the opcode from the given first word of a SPIR-V instruction.
31 static inline spirv::Opcode extractOpcode(uint32_t word) {
32   return static_cast<spirv::Opcode>(word & 0xffff);
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // Instruction
37 //===----------------------------------------------------------------------===//
38 
/// Returns the SSA value for the result <id> `id`, usable at the current
/// insertion point of `opBuilder`.
///
/// Module-level entities are not cached in `valueMap`. Instead, a new op that
/// refers to them is created at every call (i.e. at every use site):
///   - normal constants           -> spirv::ConstantOp
///   - global variables           -> spirv::AddressOfOp
///   - spec constants             -> spirv::ReferenceOfOp
///   - spec constant composites   -> spirv::ReferenceOfOp
///   - spec constant operations   -> materializeSpecConstantOperation(...)
///   - OpUndef results            -> spirv::UndefOp
/// Any other <id> is looked up in `valueMap`. Callers in this file treat a
/// null result as "unknown <id>". This is presumably what `valueMap.lookup`
/// yields for a missing key; confirm against the map's declaration.
///
/// The order of the checks matters: the first kind of entity that matches
/// `id` wins.
Value spirv::Deserializer::getValue(uint32_t id) {
  if (auto constInfo = getConstant(id)) {
    // Materialize a `spv.constant` op at every use site.
    return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
                                               constInfo->first);
  }
  if (auto varOp = getGlobalVariable(id)) {
    // Globals are referenced by symbol; take their address locally.
    auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
        unknownLoc, varOp.type(),
        opBuilder.getSymbolRefAttr(varOp.getOperation()));
    return addressOfOp.pointer();
  }
  if (auto constOp = getSpecConstant(id)) {
    // Spec constants are module-level symbols; reference them by name.
    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
        unknownLoc, constOp.default_value().getType(),
        opBuilder.getSymbolRefAttr(constOp.getOperation()));
    return referenceOfOp.reference();
  }
  if (auto constCompositeOp = getSpecConstantComposite(id)) {
    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
        unknownLoc, constCompositeOp.type(),
        opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
    return referenceOfOp.reference();
  }
  if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
    // NOTE(review): `enclodesOpcode` looks like a misspelling of
    // `enclosedOpcode` in the info struct declared outside this file.
    return materializeSpecConstantOperation(
        id, specConstOperationInfo->enclodesOpcode,
        specConstOperationInfo->resultTypeID,
        specConstOperationInfo->enclosedOpOperands);
  }
  if (auto undef = getUndefType(id)) {
    // OpUndef only records a type (see processUndef); create the op on use.
    return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
  }
  return valueMap.lookup(id);
}
74 
75 LogicalResult
76 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
77                                       ArrayRef<uint32_t> &operands,
78                                       Optional<spirv::Opcode> expectedOpcode) {
79   auto binarySize = binary.size();
80   if (curOffset >= binarySize) {
81     return emitError(unknownLoc, "expected ")
82            << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
83                               : "more")
84            << " instruction";
85   }
86 
87   // For each instruction, get its word count from the first word to slice it
88   // from the stream properly, and then dispatch to the instruction handler.
89 
90   uint32_t wordCount = binary[curOffset] >> 16;
91 
92   if (wordCount == 0)
93     return emitError(unknownLoc, "word count cannot be zero");
94 
95   uint32_t nextOffset = curOffset + wordCount;
96   if (nextOffset > binarySize)
97     return emitError(unknownLoc, "insufficient words for the last instruction");
98 
99   opcode = extractOpcode(binary[curOffset]);
100   operands = binary.slice(curOffset + 1, wordCount - 1);
101   curOffset = nextOffset;
102   return success();
103 }
104 
/// Dispatches one decoded SPIR-V instruction to its handler.
///
/// Some instructions have no one-to-one mirror op in the SPIR-V dialect:
/// module metadata, debug info, types, constants, decorations, and structured
/// control flow. These are routed to dedicated process* methods. Every other
/// opcode, and the cases below that `break` out of the switch, are handed to
/// the autogenerated dispatchToAutogenDeserialization().
///
/// \param opcode            The instruction's opcode.
/// \param operands          All words after the instruction's leading word.
/// \param deferInstructions When true, OpEntryPoint and OpExecutionMode are
///                          queued in `deferredInstructions` instead of being
///                          processed immediately. Presumably this is because
///                          they refer to function <id>s that may not have
///                          been seen yet: processOp<EntryPointOp> fails when
///                          getFunction(fnID) is unknown.
LogicalResult spirv::Deserializer::processInstruction(
    spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
  LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction "
                          << spirv::stringifyOpcode(opcode) << "\n");

  // First dispatch all the instructions whose opcode does not correspond to
  // those that have a direct mirror in the SPIR-V dialect
  switch (opcode) {
  case spirv::Opcode::OpCapability:
    return processCapability(operands);
  case spirv::Opcode::OpExtension:
    return processExtension(operands);
  case spirv::Opcode::OpExtInst:
    return processExtInst(operands);
  case spirv::Opcode::OpExtInstImport:
    return processExtInstImport(operands);
  case spirv::Opcode::OpMemberName:
    return processMemberName(operands);
  case spirv::Opcode::OpMemoryModel:
    return processMemoryModel(operands);
  case spirv::Opcode::OpEntryPoint:
  case spirv::Opcode::OpExecutionMode:
    if (deferInstructions) {
      deferredInstructions.emplace_back(opcode, operands);
      return success();
    }
    // Not deferring: handled by the autogenerated dispatcher below.
    break;
  case spirv::Opcode::OpVariable:
    // Only variables declared directly in the module are globals; variables
    // inside functions go through the autogenerated dispatcher below.
    if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
      return processGlobalVariable(operands);
    }
    break;
  case spirv::Opcode::OpLine:
    return processDebugLine(operands);
  case spirv::Opcode::OpNoLine:
    return clearDebugLine();
  case spirv::Opcode::OpName:
    return processName(operands);
  case spirv::Opcode::OpString:
    return processDebugString(operands);
  case spirv::Opcode::OpModuleProcessed:
  case spirv::Opcode::OpSource:
  case spirv::Opcode::OpSourceContinued:
  case spirv::Opcode::OpSourceExtension:
    // TODO: This is debug information embedded in the binary which should be
    // translated into the spv.module.
    return success();
  case spirv::Opcode::OpTypeVoid:
  case spirv::Opcode::OpTypeBool:
  case spirv::Opcode::OpTypeInt:
  case spirv::Opcode::OpTypeFloat:
  case spirv::Opcode::OpTypeVector:
  case spirv::Opcode::OpTypeMatrix:
  case spirv::Opcode::OpTypeArray:
  case spirv::Opcode::OpTypeFunction:
  case spirv::Opcode::OpTypeRuntimeArray:
  case spirv::Opcode::OpTypeStruct:
  case spirv::Opcode::OpTypePointer:
  case spirv::Opcode::OpTypeCooperativeMatrixNV:
    return processType(opcode, operands);
  case spirv::Opcode::OpTypeForwardPointer:
    return processTypeForwardPointer(operands);
  case spirv::Opcode::OpConstant:
    return processConstant(operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstant:
    return processConstant(operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantComposite:
    return processConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantComposite:
    return processSpecConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantOperation:
    return processSpecConstantOperation(operands);
  case spirv::Opcode::OpConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantNull:
    return processConstantNull(operands);
  case spirv::Opcode::OpDecorate:
    return processDecoration(operands);
  case spirv::Opcode::OpMemberDecorate:
    return processMemberDecoration(operands);
  case spirv::Opcode::OpFunction:
    return processFunction(operands);
  case spirv::Opcode::OpLabel:
    return processLabel(operands);
  case spirv::Opcode::OpBranch:
    return processBranch(operands);
  case spirv::Opcode::OpBranchConditional:
    return processBranchConditional(operands);
  case spirv::Opcode::OpSelectionMerge:
    return processSelectionMerge(operands);
  case spirv::Opcode::OpLoopMerge:
    return processLoopMerge(operands);
  case spirv::Opcode::OpPhi:
    return processPhi(operands);
  case spirv::Opcode::OpUndef:
    return processUndef(operands);
  default:
    break;
  }
  // Everything with a direct mirror op in the SPIR-V dialect.
  return dispatchToAutogenDeserialization(opcode, operands);
}
212 
213 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
214     ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
215     unsigned numOperands) {
216   SmallVector<Type, 1> resultTypes;
217   uint32_t valueID = 0;
218 
219   size_t wordIndex = 0;
220   if (hasResult) {
221     if (wordIndex >= words.size())
222       return emitError(unknownLoc,
223                        "expected result type <id> while deserializing for ")
224              << opName;
225 
226     // Decode the type <id>
227     auto type = getType(words[wordIndex]);
228     if (!type)
229       return emitError(unknownLoc, "unknown type result <id>: ")
230              << words[wordIndex];
231     resultTypes.push_back(type);
232     ++wordIndex;
233 
234     // Decode the result <id>
235     if (wordIndex >= words.size())
236       return emitError(unknownLoc,
237                        "expected result <id> while deserializing for ")
238              << opName;
239     valueID = words[wordIndex];
240     ++wordIndex;
241   }
242 
243   SmallVector<Value, 4> operands;
244   SmallVector<NamedAttribute, 4> attributes;
245 
246   // Decode operands
247   size_t operandIndex = 0;
248   for (; operandIndex < numOperands && wordIndex < words.size();
249        ++operandIndex, ++wordIndex) {
250     auto arg = getValue(words[wordIndex]);
251     if (!arg)
252       return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
253     operands.push_back(arg);
254   }
255   if (operandIndex != numOperands) {
256     return emitError(
257                unknownLoc,
258                "found less operands than expected when deserializing for ")
259            << opName << "; only " << operandIndex << " of " << numOperands
260            << " processed";
261   }
262   if (wordIndex != words.size()) {
263     return emitError(
264                unknownLoc,
265                "found more operands than expected when deserializing for ")
266            << opName << "; only " << wordIndex << " of " << words.size()
267            << " processed";
268   }
269 
270   // Attach attributes from decorations
271   if (decorations.count(valueID)) {
272     auto attrs = decorations[valueID].getAttrs();
273     attributes.append(attrs.begin(), attrs.end());
274   }
275 
276   // Create the op and update bookkeeping maps
277   Location loc = createFileLineColLoc(opBuilder);
278   OperationState opState(loc, opName);
279   opState.addOperands(operands);
280   if (hasResult)
281     opState.addTypes(resultTypes);
282   opState.addAttributes(attributes);
283   Operation *op = opBuilder.createOperation(opState);
284   if (hasResult)
285     valueMap[valueID] = op->getResult(0);
286 
287   if (op->hasTrait<OpTrait::IsTerminator>())
288     clearDebugLine();
289 
290   return success();
291 }
292 
293 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
294   if (operands.size() != 2) {
295     return emitError(unknownLoc, "OpUndef instruction must have two operands");
296   }
297   auto type = getType(operands[0]);
298   if (!type) {
299     return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
300   }
301   undefMap[operands[1]] = type;
302   return success();
303 }
304 
305 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
306   if (operands.size() < 4) {
307     return emitError(unknownLoc,
308                      "OpExtInst must have at least 4 operands, result type "
309                      "<id>, result <id>, set <id> and instruction opcode");
310   }
311   if (!extendedInstSets.count(operands[2])) {
312     return emitError(unknownLoc, "undefined set <id> in OpExtInst");
313   }
314   SmallVector<uint32_t, 4> slicedOperands;
315   slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
316   slicedOperands.append(std::next(operands.begin(), 4), operands.end());
317   return dispatchToExtensionSetAutogenDeserialization(
318       extendedInstSets[operands[2]], operands[3], slicedOperands);
319 }
320 
321 namespace mlir {
322 namespace spirv {
323 
324 template <>
325 LogicalResult
326 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
327   unsigned wordIndex = 0;
328   if (wordIndex >= words.size()) {
329     return emitError(unknownLoc,
330                      "missing Execution Model specification in OpEntryPoint");
331   }
332   auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
333   if (wordIndex >= words.size()) {
334     return emitError(unknownLoc, "missing <id> in OpEntryPoint");
335   }
336   // Get the function <id>
337   auto fnID = words[wordIndex++];
338   // Get the function name
339   auto fnName = decodeStringLiteral(words, wordIndex);
340   // Verify that the function <id> matches the fnName
341   auto parsedFunc = getFunction(fnID);
342   if (!parsedFunc) {
343     return emitError(unknownLoc, "no function matching <id> ") << fnID;
344   }
345   if (parsedFunc.getName() != fnName) {
346     return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
347                                  "and OpFunction with <id> ")
348            << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
349   }
350   SmallVector<Attribute, 4> interface;
351   while (wordIndex < words.size()) {
352     auto arg = getGlobalVariable(words[wordIndex]);
353     if (!arg) {
354       return emitError(unknownLoc, "undefined result <id> ")
355              << words[wordIndex] << " while decoding OpEntryPoint";
356     }
357     interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
358     wordIndex++;
359   }
360   opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
361                                         opBuilder.getSymbolRefAttr(fnName),
362                                         opBuilder.getArrayAttr(interface));
363   return success();
364 }
365 
366 template <>
367 LogicalResult
368 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
369   unsigned wordIndex = 0;
370   if (wordIndex >= words.size()) {
371     return emitError(unknownLoc,
372                      "missing function result <id> in OpExecutionMode");
373   }
374   // Get the function <id> to get the name of the function
375   auto fnID = words[wordIndex++];
376   auto fn = getFunction(fnID);
377   if (!fn) {
378     return emitError(unknownLoc, "no function matching <id> ") << fnID;
379   }
380   // Get the Execution mode
381   if (wordIndex >= words.size()) {
382     return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
383   }
384   auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]);
385 
386   // Get the values
387   SmallVector<Attribute, 4> attrListElems;
388   while (wordIndex < words.size()) {
389     attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
390   }
391   auto values = opBuilder.getArrayAttr(attrListElems);
392   opBuilder.create<spirv::ExecutionModeOp>(
393       unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
394   return success();
395 }
396 
397 template <>
398 LogicalResult
399 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
400   if (operands.size() != 3) {
401     return emitError(
402         unknownLoc,
403         "OpControlBarrier must have execution scope <id>, memory scope <id> "
404         "and memory semantics <id>");
405   }
406 
407   SmallVector<IntegerAttr, 3> argAttrs;
408   for (auto operand : operands) {
409     auto argAttr = getConstantInt(operand);
410     if (!argAttr) {
411       return emitError(unknownLoc,
412                        "expected 32-bit integer constant from <id> ")
413              << operand << " for OpControlBarrier";
414     }
415     argAttrs.push_back(argAttr);
416   }
417 
418   opBuilder.create<spirv::ControlBarrierOp>(unknownLoc, argAttrs[0],
419                                             argAttrs[1], argAttrs[2]);
420   return success();
421 }
422 
423 template <>
424 LogicalResult
425 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
426   if (operands.size() < 3) {
427     return emitError(unknownLoc,
428                      "OpFunctionCall must have at least 3 operands");
429   }
430 
431   Type resultType = getType(operands[0]);
432   if (!resultType) {
433     return emitError(unknownLoc, "undefined result type from <id> ")
434            << operands[0];
435   }
436 
437   // Use null type to mean no result type.
438   if (isVoidType(resultType))
439     resultType = nullptr;
440 
441   auto resultID = operands[1];
442   auto functionID = operands[2];
443 
444   auto functionName = getFunctionSymbol(functionID);
445 
446   SmallVector<Value, 4> arguments;
447   for (auto operand : llvm::drop_begin(operands, 3)) {
448     auto value = getValue(operand);
449     if (!value) {
450       return emitError(unknownLoc, "unknown <id> ")
451              << operand << " used by OpFunctionCall";
452     }
453     arguments.push_back(value);
454   }
455 
456   auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
457       unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName),
458       arguments);
459 
460   if (resultType)
461     valueMap[resultID] = opFunctionCall.getResult(0);
462   return success();
463 }
464 
465 template <>
466 LogicalResult
467 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
468   if (operands.size() != 2) {
469     return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
470                                  "and memory semantics <id>");
471   }
472 
473   SmallVector<IntegerAttr, 2> argAttrs;
474   for (auto operand : operands) {
475     auto argAttr = getConstantInt(operand);
476     if (!argAttr) {
477       return emitError(unknownLoc,
478                        "expected 32-bit integer constant from <id> ")
479              << operand << " for OpMemoryBarrier";
480     }
481     argAttrs.push_back(argAttr);
482   }
483 
484   opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, argAttrs[0],
485                                            argAttrs[1]);
486   return success();
487 }
488 
489 template <>
490 LogicalResult
491 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
492   SmallVector<Type, 1> resultTypes;
493   size_t wordIndex = 0;
494   SmallVector<Value, 4> operands;
495   SmallVector<NamedAttribute, 4> attributes;
496 
497   if (wordIndex < words.size()) {
498     auto arg = getValue(words[wordIndex]);
499 
500     if (!arg) {
501       return emitError(unknownLoc, "unknown result <id> : ")
502              << words[wordIndex];
503     }
504 
505     operands.push_back(arg);
506     wordIndex++;
507   }
508 
509   if (wordIndex < words.size()) {
510     auto arg = getValue(words[wordIndex]);
511 
512     if (!arg) {
513       return emitError(unknownLoc, "unknown result <id> : ")
514              << words[wordIndex];
515     }
516 
517     operands.push_back(arg);
518     wordIndex++;
519   }
520 
521   bool isAlignedAttr = false;
522 
523   if (wordIndex < words.size()) {
524     auto attrValue = words[wordIndex++];
525     attributes.push_back(opBuilder.getNamedAttr(
526         "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
527     isAlignedAttr = (attrValue == 2);
528   }
529 
530   if (isAlignedAttr && wordIndex < words.size()) {
531     attributes.push_back(opBuilder.getNamedAttr(
532         "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
533   }
534 
535   if (wordIndex < words.size()) {
536     attributes.push_back(opBuilder.getNamedAttr(
537         "source_memory_access",
538         opBuilder.getI32IntegerAttr(words[wordIndex++])));
539   }
540 
541   if (wordIndex < words.size()) {
542     attributes.push_back(opBuilder.getNamedAttr(
543         "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
544   }
545 
546   if (wordIndex != words.size()) {
547     return emitError(unknownLoc,
548                      "found more operands than expected when deserializing "
549                      "spirv::CopyMemoryOp, only ")
550            << wordIndex << " of " << words.size() << " processed";
551   }
552 
553   Location loc = createFileLineColLoc(opBuilder);
554   opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
555 
556   return success();
557 }
558 
559 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
560 // various Deserializer::processOp<...>() specializations.
561 #define GET_DESERIALIZATION_FNS
562 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
563 
564 } // namespace spirv
565 } // namespace mlir
566