1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Location.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/Debug.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "spirv-deserialization"
25 
26 //===----------------------------------------------------------------------===//
27 // Utility Functions
28 //===----------------------------------------------------------------------===//
29 
30 /// Extracts the opcode from the given first word of a SPIR-V instruction.
31 static inline spirv::Opcode extractOpcode(uint32_t word) {
32   return static_cast<spirv::Opcode>(word & 0xffff);
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // Instruction
37 //===----------------------------------------------------------------------===//
38 
39 Value spirv::Deserializer::getValue(uint32_t id) {
40   if (auto constInfo = getConstant(id)) {
41     // Materialize a `spv.Constant` op at every use site.
42     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
43                                                constInfo->first);
44   }
45   if (auto varOp = getGlobalVariable(id)) {
46     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
47         unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
48     return addressOfOp.pointer();
49   }
50   if (auto constOp = getSpecConstant(id)) {
51     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
52         unknownLoc, constOp.default_value().getType(),
53         SymbolRefAttr::get(constOp.getOperation()));
54     return referenceOfOp.reference();
55   }
56   if (auto constCompositeOp = getSpecConstantComposite(id)) {
57     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
58         unknownLoc, constCompositeOp.type(),
59         SymbolRefAttr::get(constCompositeOp.getOperation()));
60     return referenceOfOp.reference();
61   }
62   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
63     return materializeSpecConstantOperation(
64         id, specConstOperationInfo->enclodesOpcode,
65         specConstOperationInfo->resultTypeID,
66         specConstOperationInfo->enclosedOpOperands);
67   }
68   if (auto undef = getUndefType(id)) {
69     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
70   }
71   return valueMap.lookup(id);
72 }
73 
74 LogicalResult
75 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
76                                       ArrayRef<uint32_t> &operands,
77                                       Optional<spirv::Opcode> expectedOpcode) {
78   auto binarySize = binary.size();
79   if (curOffset >= binarySize) {
80     return emitError(unknownLoc, "expected ")
81            << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
82                               : "more")
83            << " instruction";
84   }
85 
86   // For each instruction, get its word count from the first word to slice it
87   // from the stream properly, and then dispatch to the instruction handler.
88 
89   uint32_t wordCount = binary[curOffset] >> 16;
90 
91   if (wordCount == 0)
92     return emitError(unknownLoc, "word count cannot be zero");
93 
94   uint32_t nextOffset = curOffset + wordCount;
95   if (nextOffset > binarySize)
96     return emitError(unknownLoc, "insufficient words for the last instruction");
97 
98   opcode = extractOpcode(binary[curOffset]);
99   operands = binary.slice(curOffset + 1, wordCount - 1);
100   curOffset = nextOffset;
101   return success();
102 }
103 
/// Dispatches a single SPIR-V instruction to its handler.
///
/// Opcodes without a direct mirror op in the SPIR-V dialect are routed to
/// dedicated process* methods here. This covers module-level declarations,
/// debug info, types, constants, decorations, functions and control flow.
/// Everything else goes to the auto-generated dispatcher at the end.
///
/// When `deferInstructions` is true, OpEntryPoint and OpExecutionMode are
/// appended to `deferredInstructions` instead of being processed now.
LogicalResult spirv::Deserializer::processInstruction(
    spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
  LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction "
                          << spirv::stringifyOpcode(opcode) << "\n");

  // First dispatch all the instructions whose opcode does not correspond to an
  // op with a direct mirror in the SPIR-V dialect.
  switch (opcode) {
  // Module-level declarations.
  case spirv::Opcode::OpCapability:
    return processCapability(operands);
  case spirv::Opcode::OpExtension:
    return processExtension(operands);
  case spirv::Opcode::OpExtInst:
    return processExtInst(operands);
  case spirv::Opcode::OpExtInstImport:
    return processExtInstImport(operands);
  case spirv::Opcode::OpMemberName:
    return processMemberName(operands);
  case spirv::Opcode::OpMemoryModel:
    return processMemoryModel(operands);
  case spirv::Opcode::OpEntryPoint:
  case spirv::Opcode::OpExecutionMode:
    // These refer to functions by <id>. When requested, queue them to be
    // replayed later (presumably once all functions have been deserialized).
    // Otherwise fall through to the generated dispatcher below.
    if (deferInstructions) {
      deferredInstructions.emplace_back(opcode, operands);
      return success();
    }
    break;
  case spirv::Opcode::OpVariable:
    // Only variables placed directly in the module are globals. Function-local
    // OpVariable is handled by the generated dispatcher below.
    if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
      return processGlobalVariable(operands);
    }
    break;
  // Debug information.
  case spirv::Opcode::OpLine:
    return processDebugLine(operands);
  case spirv::Opcode::OpNoLine:
    return clearDebugLine();
  case spirv::Opcode::OpName:
    return processName(operands);
  case spirv::Opcode::OpString:
    return processDebugString(operands);
  case spirv::Opcode::OpModuleProcessed:
  case spirv::Opcode::OpSource:
  case spirv::Opcode::OpSourceContinued:
  case spirv::Opcode::OpSourceExtension:
    // TODO: This is debug information embedded in the binary which should be
    // translated into the spv.module.
    return success();
  // Types.
  case spirv::Opcode::OpTypeVoid:
  case spirv::Opcode::OpTypeBool:
  case spirv::Opcode::OpTypeInt:
  case spirv::Opcode::OpTypeFloat:
  case spirv::Opcode::OpTypeVector:
  case spirv::Opcode::OpTypeMatrix:
  case spirv::Opcode::OpTypeArray:
  case spirv::Opcode::OpTypeFunction:
  case spirv::Opcode::OpTypeImage:
  case spirv::Opcode::OpTypeSampledImage:
  case spirv::Opcode::OpTypeRuntimeArray:
  case spirv::Opcode::OpTypeStruct:
  case spirv::Opcode::OpTypePointer:
  case spirv::Opcode::OpTypeCooperativeMatrixNV:
    return processType(opcode, operands);
  case spirv::Opcode::OpTypeForwardPointer:
    return processTypeForwardPointer(operands);
  // Constants and spec constants.
  case spirv::Opcode::OpConstant:
    return processConstant(operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstant:
    return processConstant(operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantComposite:
    return processConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantComposite:
    return processSpecConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantOp:
    return processSpecConstantOperation(operands);
  case spirv::Opcode::OpConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantNull:
    return processConstantNull(operands);
  // Decorations.
  case spirv::Opcode::OpDecorate:
    return processDecoration(operands);
  case spirv::Opcode::OpMemberDecorate:
    return processMemberDecoration(operands);
  // Functions, blocks and structured control flow.
  case spirv::Opcode::OpFunction:
    return processFunction(operands);
  case spirv::Opcode::OpLabel:
    return processLabel(operands);
  case spirv::Opcode::OpBranch:
    return processBranch(operands);
  case spirv::Opcode::OpBranchConditional:
    return processBranchConditional(operands);
  case spirv::Opcode::OpSelectionMerge:
    return processSelectionMerge(operands);
  case spirv::Opcode::OpLoopMerge:
    return processLoopMerge(operands);
  case spirv::Opcode::OpPhi:
    return processPhi(operands);
  case spirv::Opcode::OpUndef:
    return processUndef(operands);
  default:
    break;
  }
  // Everything not returned above is handed to the auto-generated dispatcher.
  // This includes non-deferred OpEntryPoint/OpExecutionMode and function-local
  // OpVariable.
  return dispatchToAutogenDeserialization(opcode, operands);
}
213 
214 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
215     ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
216     unsigned numOperands) {
217   SmallVector<Type, 1> resultTypes;
218   uint32_t valueID = 0;
219 
220   size_t wordIndex = 0;
221   if (hasResult) {
222     if (wordIndex >= words.size())
223       return emitError(unknownLoc,
224                        "expected result type <id> while deserializing for ")
225              << opName;
226 
227     // Decode the type <id>
228     auto type = getType(words[wordIndex]);
229     if (!type)
230       return emitError(unknownLoc, "unknown type result <id>: ")
231              << words[wordIndex];
232     resultTypes.push_back(type);
233     ++wordIndex;
234 
235     // Decode the result <id>
236     if (wordIndex >= words.size())
237       return emitError(unknownLoc,
238                        "expected result <id> while deserializing for ")
239              << opName;
240     valueID = words[wordIndex];
241     ++wordIndex;
242   }
243 
244   SmallVector<Value, 4> operands;
245   SmallVector<NamedAttribute, 4> attributes;
246 
247   // Decode operands
248   size_t operandIndex = 0;
249   for (; operandIndex < numOperands && wordIndex < words.size();
250        ++operandIndex, ++wordIndex) {
251     auto arg = getValue(words[wordIndex]);
252     if (!arg)
253       return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
254     operands.push_back(arg);
255   }
256   if (operandIndex != numOperands) {
257     return emitError(
258                unknownLoc,
259                "found less operands than expected when deserializing for ")
260            << opName << "; only " << operandIndex << " of " << numOperands
261            << " processed";
262   }
263   if (wordIndex != words.size()) {
264     return emitError(
265                unknownLoc,
266                "found more operands than expected when deserializing for ")
267            << opName << "; only " << wordIndex << " of " << words.size()
268            << " processed";
269   }
270 
271   // Attach attributes from decorations
272   if (decorations.count(valueID)) {
273     auto attrs = decorations[valueID].getAttrs();
274     attributes.append(attrs.begin(), attrs.end());
275   }
276 
277   // Create the op and update bookkeeping maps
278   Location loc = createFileLineColLoc(opBuilder);
279   OperationState opState(loc, opName);
280   opState.addOperands(operands);
281   if (hasResult)
282     opState.addTypes(resultTypes);
283   opState.addAttributes(attributes);
284   Operation *op = opBuilder.createOperation(opState);
285   if (hasResult)
286     valueMap[valueID] = op->getResult(0);
287 
288   if (op->hasTrait<OpTrait::IsTerminator>())
289     (void)clearDebugLine();
290 
291   return success();
292 }
293 
294 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
295   if (operands.size() != 2) {
296     return emitError(unknownLoc, "OpUndef instruction must have two operands");
297   }
298   auto type = getType(operands[0]);
299   if (!type) {
300     return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
301   }
302   undefMap[operands[1]] = type;
303   return success();
304 }
305 
306 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
307   if (operands.size() < 4) {
308     return emitError(unknownLoc,
309                      "OpExtInst must have at least 4 operands, result type "
310                      "<id>, result <id>, set <id> and instruction opcode");
311   }
312   if (!extendedInstSets.count(operands[2])) {
313     return emitError(unknownLoc, "undefined set <id> in OpExtInst");
314   }
315   SmallVector<uint32_t, 4> slicedOperands;
316   slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
317   slicedOperands.append(std::next(operands.begin(), 4), operands.end());
318   return dispatchToExtensionSetAutogenDeserialization(
319       extendedInstSets[operands[2]], operands[3], slicedOperands);
320 }
321 
322 namespace mlir {
323 namespace spirv {
324 
325 template <>
326 LogicalResult
327 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
328   unsigned wordIndex = 0;
329   if (wordIndex >= words.size()) {
330     return emitError(unknownLoc,
331                      "missing Execution Model specification in OpEntryPoint");
332   }
333   auto execModel = spirv::ExecutionModelAttr::get(
334       context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
335   if (wordIndex >= words.size()) {
336     return emitError(unknownLoc, "missing <id> in OpEntryPoint");
337   }
338   // Get the function <id>
339   auto fnID = words[wordIndex++];
340   // Get the function name
341   auto fnName = decodeStringLiteral(words, wordIndex);
342   // Verify that the function <id> matches the fnName
343   auto parsedFunc = getFunction(fnID);
344   if (!parsedFunc) {
345     return emitError(unknownLoc, "no function matching <id> ") << fnID;
346   }
347   if (parsedFunc.getName() != fnName) {
348     return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
349                                  "and OpFunction with <id> ")
350            << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
351   }
352   SmallVector<Attribute, 4> interface;
353   while (wordIndex < words.size()) {
354     auto arg = getGlobalVariable(words[wordIndex]);
355     if (!arg) {
356       return emitError(unknownLoc, "undefined result <id> ")
357              << words[wordIndex] << " while decoding OpEntryPoint";
358     }
359     interface.push_back(SymbolRefAttr::get(arg.getOperation()));
360     wordIndex++;
361   }
362   opBuilder.create<spirv::EntryPointOp>(
363       unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
364       opBuilder.getArrayAttr(interface));
365   return success();
366 }
367 
368 template <>
369 LogicalResult
370 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
371   unsigned wordIndex = 0;
372   if (wordIndex >= words.size()) {
373     return emitError(unknownLoc,
374                      "missing function result <id> in OpExecutionMode");
375   }
376   // Get the function <id> to get the name of the function
377   auto fnID = words[wordIndex++];
378   auto fn = getFunction(fnID);
379   if (!fn) {
380     return emitError(unknownLoc, "no function matching <id> ") << fnID;
381   }
382   // Get the Execution mode
383   if (wordIndex >= words.size()) {
384     return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
385   }
386   auto execMode = spirv::ExecutionModeAttr::get(
387       context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
388 
389   // Get the values
390   SmallVector<Attribute, 4> attrListElems;
391   while (wordIndex < words.size()) {
392     attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
393   }
394   auto values = opBuilder.getArrayAttr(attrListElems);
395   opBuilder.create<spirv::ExecutionModeOp>(
396       unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
397       execMode, values);
398   return success();
399 }
400 
401 template <>
402 LogicalResult
403 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
404   if (operands.size() != 3) {
405     return emitError(
406         unknownLoc,
407         "OpControlBarrier must have execution scope <id>, memory scope <id> "
408         "and memory semantics <id>");
409   }
410 
411   SmallVector<IntegerAttr, 3> argAttrs;
412   for (auto operand : operands) {
413     auto argAttr = getConstantInt(operand);
414     if (!argAttr) {
415       return emitError(unknownLoc,
416                        "expected 32-bit integer constant from <id> ")
417              << operand << " for OpControlBarrier";
418     }
419     argAttrs.push_back(argAttr);
420   }
421 
422   opBuilder.create<spirv::ControlBarrierOp>(
423       unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
424       argAttrs[1].cast<spirv::ScopeAttr>(),
425       argAttrs[2].cast<spirv::MemorySemanticsAttr>());
426 
427   return success();
428 }
429 
430 template <>
431 LogicalResult
432 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
433   if (operands.size() < 3) {
434     return emitError(unknownLoc,
435                      "OpFunctionCall must have at least 3 operands");
436   }
437 
438   Type resultType = getType(operands[0]);
439   if (!resultType) {
440     return emitError(unknownLoc, "undefined result type from <id> ")
441            << operands[0];
442   }
443 
444   // Use null type to mean no result type.
445   if (isVoidType(resultType))
446     resultType = nullptr;
447 
448   auto resultID = operands[1];
449   auto functionID = operands[2];
450 
451   auto functionName = getFunctionSymbol(functionID);
452 
453   SmallVector<Value, 4> arguments;
454   for (auto operand : llvm::drop_begin(operands, 3)) {
455     auto value = getValue(operand);
456     if (!value) {
457       return emitError(unknownLoc, "unknown <id> ")
458              << operand << " used by OpFunctionCall";
459     }
460     arguments.push_back(value);
461   }
462 
463   auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
464       unknownLoc, resultType,
465       SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
466 
467   if (resultType)
468     valueMap[resultID] = opFunctionCall.getResult(0);
469   return success();
470 }
471 
472 template <>
473 LogicalResult
474 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
475   if (operands.size() != 2) {
476     return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
477                                  "and memory semantics <id>");
478   }
479 
480   SmallVector<IntegerAttr, 2> argAttrs;
481   for (auto operand : operands) {
482     auto argAttr = getConstantInt(operand);
483     if (!argAttr) {
484       return emitError(unknownLoc,
485                        "expected 32-bit integer constant from <id> ")
486              << operand << " for OpMemoryBarrier";
487     }
488     argAttrs.push_back(argAttr);
489   }
490 
491   opBuilder.create<spirv::MemoryBarrierOp>(
492       unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
493       argAttrs[1].cast<spirv::MemorySemanticsAttr>());
494   return success();
495 }
496 
497 template <>
498 LogicalResult
499 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
500   SmallVector<Type, 1> resultTypes;
501   size_t wordIndex = 0;
502   SmallVector<Value, 4> operands;
503   SmallVector<NamedAttribute, 4> attributes;
504 
505   if (wordIndex < words.size()) {
506     auto arg = getValue(words[wordIndex]);
507 
508     if (!arg) {
509       return emitError(unknownLoc, "unknown result <id> : ")
510              << words[wordIndex];
511     }
512 
513     operands.push_back(arg);
514     wordIndex++;
515   }
516 
517   if (wordIndex < words.size()) {
518     auto arg = getValue(words[wordIndex]);
519 
520     if (!arg) {
521       return emitError(unknownLoc, "unknown result <id> : ")
522              << words[wordIndex];
523     }
524 
525     operands.push_back(arg);
526     wordIndex++;
527   }
528 
529   bool isAlignedAttr = false;
530 
531   if (wordIndex < words.size()) {
532     auto attrValue = words[wordIndex++];
533     attributes.push_back(opBuilder.getNamedAttr(
534         "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
535     isAlignedAttr = (attrValue == 2);
536   }
537 
538   if (isAlignedAttr && wordIndex < words.size()) {
539     attributes.push_back(opBuilder.getNamedAttr(
540         "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
541   }
542 
543   if (wordIndex < words.size()) {
544     attributes.push_back(opBuilder.getNamedAttr(
545         "source_memory_access",
546         opBuilder.getI32IntegerAttr(words[wordIndex++])));
547   }
548 
549   if (wordIndex < words.size()) {
550     attributes.push_back(opBuilder.getNamedAttr(
551         "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
552   }
553 
554   if (wordIndex != words.size()) {
555     return emitError(unknownLoc,
556                      "found more operands than expected when deserializing "
557                      "spirv::CopyMemoryOp, only ")
558            << wordIndex << " of " << words.size() << " processed";
559   }
560 
561   Location loc = createFileLineColLoc(opBuilder);
562   opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
563 
564   return success();
565 }
566 
567 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
568 // various Deserializer::processOp<...>() specializations.
569 #define GET_DESERIALIZATION_FNS
570 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
571 
572 } // namespace spirv
573 } // namespace mlir
574