1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Location.h"
18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Debug.h"
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "spirv-deserialization"
26 
27 //===----------------------------------------------------------------------===//
28 // Utility Functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Extracts the opcode from the given first word of a SPIR-V instruction.
32 static inline spirv::Opcode extractOpcode(uint32_t word) {
33   return static_cast<spirv::Opcode>(word & 0xffff);
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Instruction
38 //===----------------------------------------------------------------------===//
39 
40 Value spirv::Deserializer::getValue(uint32_t id) {
41   if (auto constInfo = getConstant(id)) {
42     // Materialize a `spv.Constant` op at every use site.
43     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
44                                                constInfo->first);
45   }
46   if (auto varOp = getGlobalVariable(id)) {
47     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
48         unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
49     return addressOfOp.pointer();
50   }
51   if (auto constOp = getSpecConstant(id)) {
52     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
53         unknownLoc, constOp.default_value().getType(),
54         SymbolRefAttr::get(constOp.getOperation()));
55     return referenceOfOp.reference();
56   }
57   if (auto constCompositeOp = getSpecConstantComposite(id)) {
58     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
59         unknownLoc, constCompositeOp.type(),
60         SymbolRefAttr::get(constCompositeOp.getOperation()));
61     return referenceOfOp.reference();
62   }
63   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
64     return materializeSpecConstantOperation(
65         id, specConstOperationInfo->enclodesOpcode,
66         specConstOperationInfo->resultTypeID,
67         specConstOperationInfo->enclosedOpOperands);
68   }
69   if (auto undef = getUndefType(id)) {
70     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
71   }
72   return valueMap.lookup(id);
73 }
74 
75 LogicalResult
76 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
77                                       ArrayRef<uint32_t> &operands,
78                                       Optional<spirv::Opcode> expectedOpcode) {
79   auto binarySize = binary.size();
80   if (curOffset >= binarySize) {
81     return emitError(unknownLoc, "expected ")
82            << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
83                               : "more")
84            << " instruction";
85   }
86 
87   // For each instruction, get its word count from the first word to slice it
88   // from the stream properly, and then dispatch to the instruction handler.
89 
90   uint32_t wordCount = binary[curOffset] >> 16;
91 
92   if (wordCount == 0)
93     return emitError(unknownLoc, "word count cannot be zero");
94 
95   uint32_t nextOffset = curOffset + wordCount;
96   if (nextOffset > binarySize)
97     return emitError(unknownLoc, "insufficient words for the last instruction");
98 
99   opcode = extractOpcode(binary[curOffset]);
100   operands = binary.slice(curOffset + 1, wordCount - 1);
101   curOffset = nextOffset;
102   return success();
103 }
104 
/// Dispatches one already-sliced instruction to its handler.
///
/// Dedicated process*() methods handle instructions that have no one-to-one
/// op in the SPIR-V dialect: module-level declarations, types, constants,
/// decorations, debug info, and function / control-flow structure. Every
/// other opcode, including the cases below that `break` out of the switch, is
/// handed to the auto-generated dispatchToAutogenDeserialization().
///
/// \param opcode            opcode of the instruction.
/// \param operands          the instruction's words after its first word.
/// \param deferInstructions when true, OpEntryPoint and OpExecutionMode are
///                          appended to `deferredInstructions` instead of
///                          being processed now.
/// NOTE(review): deferral is presumably because those instructions reference
/// functions that may not have been deserialized yet; confirm with the caller.
LogicalResult spirv::Deserializer::processInstruction(
    spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
  LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
                                << spirv::stringifyOpcode(opcode) << "\n");

  // First dispatch all the instructions whose opcode has no direct mirror op
  // in the SPIR-V dialect. Anything that falls out of the switch goes to the
  // auto-generated dispatcher at the end.
  switch (opcode) {
  case spirv::Opcode::OpCapability:
    return processCapability(operands);
  case spirv::Opcode::OpExtension:
    return processExtension(operands);
  case spirv::Opcode::OpExtInst:
    return processExtInst(operands);
  case spirv::Opcode::OpExtInstImport:
    return processExtInstImport(operands);
  case spirv::Opcode::OpMemberName:
    return processMemberName(operands);
  case spirv::Opcode::OpMemoryModel:
    return processMemoryModel(operands);
  case spirv::Opcode::OpEntryPoint:
  case spirv::Opcode::OpExecutionMode:
    // Queue for later if requested; otherwise use the auto-generated
    // processOp<> handler reached after the switch.
    if (deferInstructions) {
      deferredInstructions.emplace_back(opcode, operands);
      return success();
    }
    break;
  case spirv::Opcode::OpVariable:
    // Only variables whose insertion block sits directly in spv.module are
    // globals; any other OpVariable goes through the auto-generated path.
    if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
      return processGlobalVariable(operands);
    }
    break;
  case spirv::Opcode::OpLine:
    return processDebugLine(operands);
  case spirv::Opcode::OpNoLine:
    clearDebugLine();
    return success();
  case spirv::Opcode::OpName:
    return processName(operands);
  case spirv::Opcode::OpString:
    return processDebugString(operands);
  case spirv::Opcode::OpModuleProcessed:
  case spirv::Opcode::OpSource:
  case spirv::Opcode::OpSourceContinued:
  case spirv::Opcode::OpSourceExtension:
    // TODO: This is debug information embedded in the binary which should be
    // translated into the spv.module.
    return success();
  // All type declarations share one handler keyed on the opcode.
  case spirv::Opcode::OpTypeVoid:
  case spirv::Opcode::OpTypeBool:
  case spirv::Opcode::OpTypeInt:
  case spirv::Opcode::OpTypeFloat:
  case spirv::Opcode::OpTypeVector:
  case spirv::Opcode::OpTypeMatrix:
  case spirv::Opcode::OpTypeArray:
  case spirv::Opcode::OpTypeFunction:
  case spirv::Opcode::OpTypeImage:
  case spirv::Opcode::OpTypeSampledImage:
  case spirv::Opcode::OpTypeRuntimeArray:
  case spirv::Opcode::OpTypeStruct:
  case spirv::Opcode::OpTypePointer:
  case spirv::Opcode::OpTypeCooperativeMatrixNV:
    return processType(opcode, operands);
  case spirv::Opcode::OpTypeForwardPointer:
    return processTypeForwardPointer(operands);
  // Constants and spec constants; `isSpec` selects the specialization form.
  case spirv::Opcode::OpConstant:
    return processConstant(operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstant:
    return processConstant(operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantComposite:
    return processConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantComposite:
    return processSpecConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantOp:
    return processSpecConstantOperation(operands);
  case spirv::Opcode::OpConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantNull:
    return processConstantNull(operands);
  case spirv::Opcode::OpDecorate:
    return processDecoration(operands);
  case spirv::Opcode::OpMemberDecorate:
    return processMemberDecoration(operands);
  // Function and block structure.
  case spirv::Opcode::OpFunction:
    return processFunction(operands);
  case spirv::Opcode::OpLabel:
    return processLabel(operands);
  case spirv::Opcode::OpBranch:
    return processBranch(operands);
  case spirv::Opcode::OpBranchConditional:
    return processBranchConditional(operands);
  case spirv::Opcode::OpSelectionMerge:
    return processSelectionMerge(operands);
  case spirv::Opcode::OpLoopMerge:
    return processLoopMerge(operands);
  case spirv::Opcode::OpPhi:
    return processPhi(operands);
  case spirv::Opcode::OpUndef:
    return processUndef(operands);
  default:
    break;
  }
  return dispatchToAutogenDeserialization(opcode, operands);
}
215 
216 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
217     ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
218     unsigned numOperands) {
219   SmallVector<Type, 1> resultTypes;
220   uint32_t valueID = 0;
221 
222   size_t wordIndex = 0;
223   if (hasResult) {
224     if (wordIndex >= words.size())
225       return emitError(unknownLoc,
226                        "expected result type <id> while deserializing for ")
227              << opName;
228 
229     // Decode the type <id>
230     auto type = getType(words[wordIndex]);
231     if (!type)
232       return emitError(unknownLoc, "unknown type result <id>: ")
233              << words[wordIndex];
234     resultTypes.push_back(type);
235     ++wordIndex;
236 
237     // Decode the result <id>
238     if (wordIndex >= words.size())
239       return emitError(unknownLoc,
240                        "expected result <id> while deserializing for ")
241              << opName;
242     valueID = words[wordIndex];
243     ++wordIndex;
244   }
245 
246   SmallVector<Value, 4> operands;
247   SmallVector<NamedAttribute, 4> attributes;
248 
249   // Decode operands
250   size_t operandIndex = 0;
251   for (; operandIndex < numOperands && wordIndex < words.size();
252        ++operandIndex, ++wordIndex) {
253     auto arg = getValue(words[wordIndex]);
254     if (!arg)
255       return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
256     operands.push_back(arg);
257   }
258   if (operandIndex != numOperands) {
259     return emitError(
260                unknownLoc,
261                "found less operands than expected when deserializing for ")
262            << opName << "; only " << operandIndex << " of " << numOperands
263            << " processed";
264   }
265   if (wordIndex != words.size()) {
266     return emitError(
267                unknownLoc,
268                "found more operands than expected when deserializing for ")
269            << opName << "; only " << wordIndex << " of " << words.size()
270            << " processed";
271   }
272 
273   // Attach attributes from decorations
274   if (decorations.count(valueID)) {
275     auto attrs = decorations[valueID].getAttrs();
276     attributes.append(attrs.begin(), attrs.end());
277   }
278 
279   // Create the op and update bookkeeping maps
280   Location loc = createFileLineColLoc(opBuilder);
281   OperationState opState(loc, opName);
282   opState.addOperands(operands);
283   if (hasResult)
284     opState.addTypes(resultTypes);
285   opState.addAttributes(attributes);
286   Operation *op = opBuilder.create(opState);
287   if (hasResult)
288     valueMap[valueID] = op->getResult(0);
289 
290   if (op->hasTrait<OpTrait::IsTerminator>())
291     clearDebugLine();
292 
293   return success();
294 }
295 
296 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
297   if (operands.size() != 2) {
298     return emitError(unknownLoc, "OpUndef instruction must have two operands");
299   }
300   auto type = getType(operands[0]);
301   if (!type) {
302     return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
303   }
304   undefMap[operands[1]] = type;
305   return success();
306 }
307 
308 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
309   if (operands.size() < 4) {
310     return emitError(unknownLoc,
311                      "OpExtInst must have at least 4 operands, result type "
312                      "<id>, result <id>, set <id> and instruction opcode");
313   }
314   if (!extendedInstSets.count(operands[2])) {
315     return emitError(unknownLoc, "undefined set <id> in OpExtInst");
316   }
317   SmallVector<uint32_t, 4> slicedOperands;
318   slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
319   slicedOperands.append(std::next(operands.begin(), 4), operands.end());
320   return dispatchToExtensionSetAutogenDeserialization(
321       extendedInstSets[operands[2]], operands[3], slicedOperands);
322 }
323 
324 namespace mlir {
325 namespace spirv {
326 
327 template <>
328 LogicalResult
329 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
330   unsigned wordIndex = 0;
331   if (wordIndex >= words.size()) {
332     return emitError(unknownLoc,
333                      "missing Execution Model specification in OpEntryPoint");
334   }
335   auto execModel = spirv::ExecutionModelAttr::get(
336       context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
337   if (wordIndex >= words.size()) {
338     return emitError(unknownLoc, "missing <id> in OpEntryPoint");
339   }
340   // Get the function <id>
341   auto fnID = words[wordIndex++];
342   // Get the function name
343   auto fnName = decodeStringLiteral(words, wordIndex);
344   // Verify that the function <id> matches the fnName
345   auto parsedFunc = getFunction(fnID);
346   if (!parsedFunc) {
347     return emitError(unknownLoc, "no function matching <id> ") << fnID;
348   }
349   if (parsedFunc.getName() != fnName) {
350     return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
351                                  "and OpFunction with <id> ")
352            << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
353   }
354   SmallVector<Attribute, 4> interface;
355   while (wordIndex < words.size()) {
356     auto arg = getGlobalVariable(words[wordIndex]);
357     if (!arg) {
358       return emitError(unknownLoc, "undefined result <id> ")
359              << words[wordIndex] << " while decoding OpEntryPoint";
360     }
361     interface.push_back(SymbolRefAttr::get(arg.getOperation()));
362     wordIndex++;
363   }
364   opBuilder.create<spirv::EntryPointOp>(
365       unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
366       opBuilder.getArrayAttr(interface));
367   return success();
368 }
369 
370 template <>
371 LogicalResult
372 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
373   unsigned wordIndex = 0;
374   if (wordIndex >= words.size()) {
375     return emitError(unknownLoc,
376                      "missing function result <id> in OpExecutionMode");
377   }
378   // Get the function <id> to get the name of the function
379   auto fnID = words[wordIndex++];
380   auto fn = getFunction(fnID);
381   if (!fn) {
382     return emitError(unknownLoc, "no function matching <id> ") << fnID;
383   }
384   // Get the Execution mode
385   if (wordIndex >= words.size()) {
386     return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
387   }
388   auto execMode = spirv::ExecutionModeAttr::get(
389       context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
390 
391   // Get the values
392   SmallVector<Attribute, 4> attrListElems;
393   while (wordIndex < words.size()) {
394     attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
395   }
396   auto values = opBuilder.getArrayAttr(attrListElems);
397   opBuilder.create<spirv::ExecutionModeOp>(
398       unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
399       execMode, values);
400   return success();
401 }
402 
403 template <>
404 LogicalResult
405 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
406   if (operands.size() != 3) {
407     return emitError(
408         unknownLoc,
409         "OpControlBarrier must have execution scope <id>, memory scope <id> "
410         "and memory semantics <id>");
411   }
412 
413   SmallVector<IntegerAttr, 3> argAttrs;
414   for (auto operand : operands) {
415     auto argAttr = getConstantInt(operand);
416     if (!argAttr) {
417       return emitError(unknownLoc,
418                        "expected 32-bit integer constant from <id> ")
419              << operand << " for OpControlBarrier";
420     }
421     argAttrs.push_back(argAttr);
422   }
423 
424   opBuilder.create<spirv::ControlBarrierOp>(
425       unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
426       argAttrs[1].cast<spirv::ScopeAttr>(),
427       argAttrs[2].cast<spirv::MemorySemanticsAttr>());
428 
429   return success();
430 }
431 
432 template <>
433 LogicalResult
434 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
435   if (operands.size() < 3) {
436     return emitError(unknownLoc,
437                      "OpFunctionCall must have at least 3 operands");
438   }
439 
440   Type resultType = getType(operands[0]);
441   if (!resultType) {
442     return emitError(unknownLoc, "undefined result type from <id> ")
443            << operands[0];
444   }
445 
446   // Use null type to mean no result type.
447   if (isVoidType(resultType))
448     resultType = nullptr;
449 
450   auto resultID = operands[1];
451   auto functionID = operands[2];
452 
453   auto functionName = getFunctionSymbol(functionID);
454 
455   SmallVector<Value, 4> arguments;
456   for (auto operand : llvm::drop_begin(operands, 3)) {
457     auto value = getValue(operand);
458     if (!value) {
459       return emitError(unknownLoc, "unknown <id> ")
460              << operand << " used by OpFunctionCall";
461     }
462     arguments.push_back(value);
463   }
464 
465   auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
466       unknownLoc, resultType,
467       SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
468 
469   if (resultType)
470     valueMap[resultID] = opFunctionCall.getResult(0);
471   return success();
472 }
473 
474 template <>
475 LogicalResult
476 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
477   if (operands.size() != 2) {
478     return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
479                                  "and memory semantics <id>");
480   }
481 
482   SmallVector<IntegerAttr, 2> argAttrs;
483   for (auto operand : operands) {
484     auto argAttr = getConstantInt(operand);
485     if (!argAttr) {
486       return emitError(unknownLoc,
487                        "expected 32-bit integer constant from <id> ")
488              << operand << " for OpMemoryBarrier";
489     }
490     argAttrs.push_back(argAttr);
491   }
492 
493   opBuilder.create<spirv::MemoryBarrierOp>(
494       unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
495       argAttrs[1].cast<spirv::MemorySemanticsAttr>());
496   return success();
497 }
498 
499 template <>
500 LogicalResult
501 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
502   SmallVector<Type, 1> resultTypes;
503   size_t wordIndex = 0;
504   SmallVector<Value, 4> operands;
505   SmallVector<NamedAttribute, 4> attributes;
506 
507   if (wordIndex < words.size()) {
508     auto arg = getValue(words[wordIndex]);
509 
510     if (!arg) {
511       return emitError(unknownLoc, "unknown result <id> : ")
512              << words[wordIndex];
513     }
514 
515     operands.push_back(arg);
516     wordIndex++;
517   }
518 
519   if (wordIndex < words.size()) {
520     auto arg = getValue(words[wordIndex]);
521 
522     if (!arg) {
523       return emitError(unknownLoc, "unknown result <id> : ")
524              << words[wordIndex];
525     }
526 
527     operands.push_back(arg);
528     wordIndex++;
529   }
530 
531   bool isAlignedAttr = false;
532 
533   if (wordIndex < words.size()) {
534     auto attrValue = words[wordIndex++];
535     attributes.push_back(opBuilder.getNamedAttr(
536         "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
537     isAlignedAttr = (attrValue == 2);
538   }
539 
540   if (isAlignedAttr && wordIndex < words.size()) {
541     attributes.push_back(opBuilder.getNamedAttr(
542         "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
543   }
544 
545   if (wordIndex < words.size()) {
546     attributes.push_back(opBuilder.getNamedAttr(
547         "source_memory_access",
548         opBuilder.getI32IntegerAttr(words[wordIndex++])));
549   }
550 
551   if (wordIndex < words.size()) {
552     attributes.push_back(opBuilder.getNamedAttr(
553         "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
554   }
555 
556   if (wordIndex != words.size()) {
557     return emitError(unknownLoc,
558                      "found more operands than expected when deserializing "
559                      "spirv::CopyMemoryOp, only ")
560            << wordIndex << " of " << words.size() << " processed";
561   }
562 
563   Location loc = createFileLineColLoc(opBuilder);
564   opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
565 
566   return success();
567 }
568 
569 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
570 // various Deserializer::processOp<...>() specializations.
571 #define GET_DESERIALIZATION_FNS
572 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
573 
574 } // namespace spirv
575 } // namespace mlir
576