1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Deserializer methods for SPIR-V binary instructions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Location.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/Debug.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "spirv-deserialization"
25 
26 //===----------------------------------------------------------------------===//
27 // Utility Functions
28 //===----------------------------------------------------------------------===//
29 
30 /// Extracts the opcode from the given first word of a SPIR-V instruction.
31 static inline spirv::Opcode extractOpcode(uint32_t word) {
32   return static_cast<spirv::Opcode>(word & 0xffff);
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // Instruction
37 //===----------------------------------------------------------------------===//
38 
39 Value spirv::Deserializer::getValue(uint32_t id) {
40   if (auto constInfo = getConstant(id)) {
41     // Materialize a `spv.Constant` op at every use site.
42     return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
43                                                constInfo->first);
44   }
45   if (auto varOp = getGlobalVariable(id)) {
46     auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
47         unknownLoc, varOp.type(),
48         opBuilder.getSymbolRefAttr(varOp.getOperation()));
49     return addressOfOp.pointer();
50   }
51   if (auto constOp = getSpecConstant(id)) {
52     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
53         unknownLoc, constOp.default_value().getType(),
54         opBuilder.getSymbolRefAttr(constOp.getOperation()));
55     return referenceOfOp.reference();
56   }
57   if (auto constCompositeOp = getSpecConstantComposite(id)) {
58     auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
59         unknownLoc, constCompositeOp.type(),
60         opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
61     return referenceOfOp.reference();
62   }
63   if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
64     return materializeSpecConstantOperation(
65         id, specConstOperationInfo->enclodesOpcode,
66         specConstOperationInfo->resultTypeID,
67         specConstOperationInfo->enclosedOpOperands);
68   }
69   if (auto undef = getUndefType(id)) {
70     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
71   }
72   return valueMap.lookup(id);
73 }
74 
75 LogicalResult
76 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
77                                       ArrayRef<uint32_t> &operands,
78                                       Optional<spirv::Opcode> expectedOpcode) {
79   auto binarySize = binary.size();
80   if (curOffset >= binarySize) {
81     return emitError(unknownLoc, "expected ")
82            << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
83                               : "more")
84            << " instruction";
85   }
86 
87   // For each instruction, get its word count from the first word to slice it
88   // from the stream properly, and then dispatch to the instruction handler.
89 
90   uint32_t wordCount = binary[curOffset] >> 16;
91 
92   if (wordCount == 0)
93     return emitError(unknownLoc, "word count cannot be zero");
94 
95   uint32_t nextOffset = curOffset + wordCount;
96   if (nextOffset > binarySize)
97     return emitError(unknownLoc, "insufficient words for the last instruction");
98 
99   opcode = extractOpcode(binary[curOffset]);
100   operands = binary.slice(curOffset + 1, wordCount - 1);
101   curOffset = nextOffset;
102   return success();
103 }
104 
105 LogicalResult spirv::Deserializer::processInstruction(
106     spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
107   LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction "
108                           << spirv::stringifyOpcode(opcode) << "\n");
109 
110   // First dispatch all the instructions whose opcode does not correspond to
111   // those that have a direct mirror in the SPIR-V dialect
112   switch (opcode) {
113   case spirv::Opcode::OpCapability:
114     return processCapability(operands);
115   case spirv::Opcode::OpExtension:
116     return processExtension(operands);
117   case spirv::Opcode::OpExtInst:
118     return processExtInst(operands);
119   case spirv::Opcode::OpExtInstImport:
120     return processExtInstImport(operands);
121   case spirv::Opcode::OpMemberName:
122     return processMemberName(operands);
123   case spirv::Opcode::OpMemoryModel:
124     return processMemoryModel(operands);
125   case spirv::Opcode::OpEntryPoint:
126   case spirv::Opcode::OpExecutionMode:
127     if (deferInstructions) {
128       deferredInstructions.emplace_back(opcode, operands);
129       return success();
130     }
131     break;
132   case spirv::Opcode::OpVariable:
133     if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
134       return processGlobalVariable(operands);
135     }
136     break;
137   case spirv::Opcode::OpLine:
138     return processDebugLine(operands);
139   case spirv::Opcode::OpNoLine:
140     return clearDebugLine();
141   case spirv::Opcode::OpName:
142     return processName(operands);
143   case spirv::Opcode::OpString:
144     return processDebugString(operands);
145   case spirv::Opcode::OpModuleProcessed:
146   case spirv::Opcode::OpSource:
147   case spirv::Opcode::OpSourceContinued:
148   case spirv::Opcode::OpSourceExtension:
149     // TODO: This is debug information embedded in the binary which should be
150     // translated into the spv.module.
151     return success();
152   case spirv::Opcode::OpTypeVoid:
153   case spirv::Opcode::OpTypeBool:
154   case spirv::Opcode::OpTypeInt:
155   case spirv::Opcode::OpTypeFloat:
156   case spirv::Opcode::OpTypeVector:
157   case spirv::Opcode::OpTypeMatrix:
158   case spirv::Opcode::OpTypeArray:
159   case spirv::Opcode::OpTypeFunction:
160   case spirv::Opcode::OpTypeImage:
161   case spirv::Opcode::OpTypeSampledImage:
162   case spirv::Opcode::OpTypeRuntimeArray:
163   case spirv::Opcode::OpTypeStruct:
164   case spirv::Opcode::OpTypePointer:
165   case spirv::Opcode::OpTypeCooperativeMatrixNV:
166     return processType(opcode, operands);
167   case spirv::Opcode::OpTypeForwardPointer:
168     return processTypeForwardPointer(operands);
169   case spirv::Opcode::OpConstant:
170     return processConstant(operands, /*isSpec=*/false);
171   case spirv::Opcode::OpSpecConstant:
172     return processConstant(operands, /*isSpec=*/true);
173   case spirv::Opcode::OpConstantComposite:
174     return processConstantComposite(operands);
175   case spirv::Opcode::OpSpecConstantComposite:
176     return processSpecConstantComposite(operands);
177   case spirv::Opcode::OpSpecConstantOp:
178     return processSpecConstantOperation(operands);
179   case spirv::Opcode::OpConstantTrue:
180     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
181   case spirv::Opcode::OpSpecConstantTrue:
182     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
183   case spirv::Opcode::OpConstantFalse:
184     return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
185   case spirv::Opcode::OpSpecConstantFalse:
186     return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
187   case spirv::Opcode::OpConstantNull:
188     return processConstantNull(operands);
189   case spirv::Opcode::OpDecorate:
190     return processDecoration(operands);
191   case spirv::Opcode::OpMemberDecorate:
192     return processMemberDecoration(operands);
193   case spirv::Opcode::OpFunction:
194     return processFunction(operands);
195   case spirv::Opcode::OpLabel:
196     return processLabel(operands);
197   case spirv::Opcode::OpBranch:
198     return processBranch(operands);
199   case spirv::Opcode::OpBranchConditional:
200     return processBranchConditional(operands);
201   case spirv::Opcode::OpSelectionMerge:
202     return processSelectionMerge(operands);
203   case spirv::Opcode::OpLoopMerge:
204     return processLoopMerge(operands);
205   case spirv::Opcode::OpPhi:
206     return processPhi(operands);
207   case spirv::Opcode::OpUndef:
208     return processUndef(operands);
209   default:
210     break;
211   }
212   return dispatchToAutogenDeserialization(opcode, operands);
213 }
214 
215 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
216     ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
217     unsigned numOperands) {
218   SmallVector<Type, 1> resultTypes;
219   uint32_t valueID = 0;
220 
221   size_t wordIndex = 0;
222   if (hasResult) {
223     if (wordIndex >= words.size())
224       return emitError(unknownLoc,
225                        "expected result type <id> while deserializing for ")
226              << opName;
227 
228     // Decode the type <id>
229     auto type = getType(words[wordIndex]);
230     if (!type)
231       return emitError(unknownLoc, "unknown type result <id>: ")
232              << words[wordIndex];
233     resultTypes.push_back(type);
234     ++wordIndex;
235 
236     // Decode the result <id>
237     if (wordIndex >= words.size())
238       return emitError(unknownLoc,
239                        "expected result <id> while deserializing for ")
240              << opName;
241     valueID = words[wordIndex];
242     ++wordIndex;
243   }
244 
245   SmallVector<Value, 4> operands;
246   SmallVector<NamedAttribute, 4> attributes;
247 
248   // Decode operands
249   size_t operandIndex = 0;
250   for (; operandIndex < numOperands && wordIndex < words.size();
251        ++operandIndex, ++wordIndex) {
252     auto arg = getValue(words[wordIndex]);
253     if (!arg)
254       return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
255     operands.push_back(arg);
256   }
257   if (operandIndex != numOperands) {
258     return emitError(
259                unknownLoc,
260                "found less operands than expected when deserializing for ")
261            << opName << "; only " << operandIndex << " of " << numOperands
262            << " processed";
263   }
264   if (wordIndex != words.size()) {
265     return emitError(
266                unknownLoc,
267                "found more operands than expected when deserializing for ")
268            << opName << "; only " << wordIndex << " of " << words.size()
269            << " processed";
270   }
271 
272   // Attach attributes from decorations
273   if (decorations.count(valueID)) {
274     auto attrs = decorations[valueID].getAttrs();
275     attributes.append(attrs.begin(), attrs.end());
276   }
277 
278   // Create the op and update bookkeeping maps
279   Location loc = createFileLineColLoc(opBuilder);
280   OperationState opState(loc, opName);
281   opState.addOperands(operands);
282   if (hasResult)
283     opState.addTypes(resultTypes);
284   opState.addAttributes(attributes);
285   Operation *op = opBuilder.createOperation(opState);
286   if (hasResult)
287     valueMap[valueID] = op->getResult(0);
288 
289   if (op->hasTrait<OpTrait::IsTerminator>())
290     (void)clearDebugLine();
291 
292   return success();
293 }
294 
295 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
296   if (operands.size() != 2) {
297     return emitError(unknownLoc, "OpUndef instruction must have two operands");
298   }
299   auto type = getType(operands[0]);
300   if (!type) {
301     return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
302   }
303   undefMap[operands[1]] = type;
304   return success();
305 }
306 
307 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
308   if (operands.size() < 4) {
309     return emitError(unknownLoc,
310                      "OpExtInst must have at least 4 operands, result type "
311                      "<id>, result <id>, set <id> and instruction opcode");
312   }
313   if (!extendedInstSets.count(operands[2])) {
314     return emitError(unknownLoc, "undefined set <id> in OpExtInst");
315   }
316   SmallVector<uint32_t, 4> slicedOperands;
317   slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
318   slicedOperands.append(std::next(operands.begin(), 4), operands.end());
319   return dispatchToExtensionSetAutogenDeserialization(
320       extendedInstSets[operands[2]], operands[3], slicedOperands);
321 }
322 
323 namespace mlir {
324 namespace spirv {
325 
326 template <>
327 LogicalResult
328 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
329   unsigned wordIndex = 0;
330   if (wordIndex >= words.size()) {
331     return emitError(unknownLoc,
332                      "missing Execution Model specification in OpEntryPoint");
333   }
334   auto execModel = spirv::ExecutionModelAttr::get(
335       context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
336   if (wordIndex >= words.size()) {
337     return emitError(unknownLoc, "missing <id> in OpEntryPoint");
338   }
339   // Get the function <id>
340   auto fnID = words[wordIndex++];
341   // Get the function name
342   auto fnName = decodeStringLiteral(words, wordIndex);
343   // Verify that the function <id> matches the fnName
344   auto parsedFunc = getFunction(fnID);
345   if (!parsedFunc) {
346     return emitError(unknownLoc, "no function matching <id> ") << fnID;
347   }
348   if (parsedFunc.getName() != fnName) {
349     return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
350                                  "and OpFunction with <id> ")
351            << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
352   }
353   SmallVector<Attribute, 4> interface;
354   while (wordIndex < words.size()) {
355     auto arg = getGlobalVariable(words[wordIndex]);
356     if (!arg) {
357       return emitError(unknownLoc, "undefined result <id> ")
358              << words[wordIndex] << " while decoding OpEntryPoint";
359     }
360     interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
361     wordIndex++;
362   }
363   opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
364                                         opBuilder.getSymbolRefAttr(fnName),
365                                         opBuilder.getArrayAttr(interface));
366   return success();
367 }
368 
369 template <>
370 LogicalResult
371 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
372   unsigned wordIndex = 0;
373   if (wordIndex >= words.size()) {
374     return emitError(unknownLoc,
375                      "missing function result <id> in OpExecutionMode");
376   }
377   // Get the function <id> to get the name of the function
378   auto fnID = words[wordIndex++];
379   auto fn = getFunction(fnID);
380   if (!fn) {
381     return emitError(unknownLoc, "no function matching <id> ") << fnID;
382   }
383   // Get the Execution mode
384   if (wordIndex >= words.size()) {
385     return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
386   }
387   auto execMode = spirv::ExecutionModeAttr::get(
388       context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
389 
390   // Get the values
391   SmallVector<Attribute, 4> attrListElems;
392   while (wordIndex < words.size()) {
393     attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
394   }
395   auto values = opBuilder.getArrayAttr(attrListElems);
396   opBuilder.create<spirv::ExecutionModeOp>(
397       unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
398   return success();
399 }
400 
401 template <>
402 LogicalResult
403 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
404   if (operands.size() != 3) {
405     return emitError(
406         unknownLoc,
407         "OpControlBarrier must have execution scope <id>, memory scope <id> "
408         "and memory semantics <id>");
409   }
410 
411   SmallVector<IntegerAttr, 3> argAttrs;
412   for (auto operand : operands) {
413     auto argAttr = getConstantInt(operand);
414     if (!argAttr) {
415       return emitError(unknownLoc,
416                        "expected 32-bit integer constant from <id> ")
417              << operand << " for OpControlBarrier";
418     }
419     argAttrs.push_back(argAttr);
420   }
421 
422   opBuilder.create<spirv::ControlBarrierOp>(
423       unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
424       argAttrs[1].cast<spirv::ScopeAttr>(),
425       argAttrs[2].cast<spirv::MemorySemanticsAttr>());
426 
427   return success();
428 }
429 
430 template <>
431 LogicalResult
432 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
433   if (operands.size() < 3) {
434     return emitError(unknownLoc,
435                      "OpFunctionCall must have at least 3 operands");
436   }
437 
438   Type resultType = getType(operands[0]);
439   if (!resultType) {
440     return emitError(unknownLoc, "undefined result type from <id> ")
441            << operands[0];
442   }
443 
444   // Use null type to mean no result type.
445   if (isVoidType(resultType))
446     resultType = nullptr;
447 
448   auto resultID = operands[1];
449   auto functionID = operands[2];
450 
451   auto functionName = getFunctionSymbol(functionID);
452 
453   SmallVector<Value, 4> arguments;
454   for (auto operand : llvm::drop_begin(operands, 3)) {
455     auto value = getValue(operand);
456     if (!value) {
457       return emitError(unknownLoc, "unknown <id> ")
458              << operand << " used by OpFunctionCall";
459     }
460     arguments.push_back(value);
461   }
462 
463   auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
464       unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName),
465       arguments);
466 
467   if (resultType)
468     valueMap[resultID] = opFunctionCall.getResult(0);
469   return success();
470 }
471 
472 template <>
473 LogicalResult
474 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
475   if (operands.size() != 2) {
476     return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
477                                  "and memory semantics <id>");
478   }
479 
480   SmallVector<IntegerAttr, 2> argAttrs;
481   for (auto operand : operands) {
482     auto argAttr = getConstantInt(operand);
483     if (!argAttr) {
484       return emitError(unknownLoc,
485                        "expected 32-bit integer constant from <id> ")
486              << operand << " for OpMemoryBarrier";
487     }
488     argAttrs.push_back(argAttr);
489   }
490 
491   opBuilder.create<spirv::MemoryBarrierOp>(
492       unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
493       argAttrs[1].cast<spirv::MemorySemanticsAttr>());
494   return success();
495 }
496 
497 template <>
498 LogicalResult
499 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
500   SmallVector<Type, 1> resultTypes;
501   size_t wordIndex = 0;
502   SmallVector<Value, 4> operands;
503   SmallVector<NamedAttribute, 4> attributes;
504 
505   if (wordIndex < words.size()) {
506     auto arg = getValue(words[wordIndex]);
507 
508     if (!arg) {
509       return emitError(unknownLoc, "unknown result <id> : ")
510              << words[wordIndex];
511     }
512 
513     operands.push_back(arg);
514     wordIndex++;
515   }
516 
517   if (wordIndex < words.size()) {
518     auto arg = getValue(words[wordIndex]);
519 
520     if (!arg) {
521       return emitError(unknownLoc, "unknown result <id> : ")
522              << words[wordIndex];
523     }
524 
525     operands.push_back(arg);
526     wordIndex++;
527   }
528 
529   bool isAlignedAttr = false;
530 
531   if (wordIndex < words.size()) {
532     auto attrValue = words[wordIndex++];
533     attributes.push_back(opBuilder.getNamedAttr(
534         "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
535     isAlignedAttr = (attrValue == 2);
536   }
537 
538   if (isAlignedAttr && wordIndex < words.size()) {
539     attributes.push_back(opBuilder.getNamedAttr(
540         "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
541   }
542 
543   if (wordIndex < words.size()) {
544     attributes.push_back(opBuilder.getNamedAttr(
545         "source_memory_access",
546         opBuilder.getI32IntegerAttr(words[wordIndex++])));
547   }
548 
549   if (wordIndex < words.size()) {
550     attributes.push_back(opBuilder.getNamedAttr(
551         "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
552   }
553 
554   if (wordIndex != words.size()) {
555     return emitError(unknownLoc,
556                      "found more operands than expected when deserializing "
557                      "spirv::CopyMemoryOp, only ")
558            << wordIndex << " of " << words.size() << " processed";
559   }
560 
561   Location loc = createFileLineColLoc(opBuilder);
562   opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
563 
564   return success();
565 }
566 
567 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
568 // various Deserializer::processOp<...>() specializations.
569 #define GET_DESERIALIZATION_FNS
570 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
571 
572 } // namespace spirv
573 } // namespace mlir
574