1 //===- Operator.cpp - Operator class --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Operator wrapper to simplify using a TableGen Record that defines an MLIR Op.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Operator.h"
14 #include "mlir/TableGen/Predicate.h"
15 #include "mlir/TableGen/Trait.h"
16 #include "mlir/TableGen/Type.h"
17 #include "llvm/ADT/EquivalenceClasses.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 
28 #define DEBUG_TYPE "mlir-tblgen-operator"
29 
30 using namespace mlir;
31 using namespace mlir::tblgen;
32 
33 using llvm::DagInit;
34 using llvm::DefInit;
35 using llvm::Record;
36 
37 Operator::Operator(const llvm::Record &def)
38     : dialect(def.getValueAsDef("opDialect")), def(def) {
39   // The first `_` in the op's TableGen def name is treated as separating the
40   // dialect prefix and the op class name. The dialect prefix will be ignored if
41   // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
42   // as part of the class name.
43   StringRef prefix;
44   std::tie(prefix, cppClassName) = def.getName().split('_');
45   if (prefix.empty()) {
46     // Class name with a leading underscore and without dialect prefix
47     cppClassName = def.getName();
48   } else if (cppClassName.empty()) {
49     // Class name without dialect prefix
50     cppClassName = prefix;
51   }
52 
53   cppNamespace = def.getValueAsString("cppNamespace");
54 
55   populateOpStructure();
56 }
57 
58 std::string Operator::getOperationName() const {
59   auto prefix = dialect.getName();
60   auto opName = def.getValueAsString("opName");
61   if (prefix.empty())
62     return std::string(opName);
63   return std::string(llvm::formatv("{0}.{1}", prefix, opName));
64 }
65 
66 std::string Operator::getAdaptorName() const {
67   return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
68 }
69 
70 StringRef Operator::getDialectName() const { return dialect.getName(); }
71 
72 StringRef Operator::getCppClassName() const { return cppClassName; }
73 
74 std::string Operator::getQualCppClassName() const {
75   if (cppNamespace.empty())
76     return std::string(cppClassName);
77   return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName));
78 }
79 
80 StringRef Operator::getCppNamespace() const { return cppNamespace; }
81 
82 int Operator::getNumResults() const {
83   DagInit *results = def.getValueAsDag("results");
84   return results->getNumArgs();
85 }
86 
87 StringRef Operator::getExtraClassDeclaration() const {
88   constexpr auto attr = "extraClassDeclaration";
89   if (def.isValueUnset(attr))
90     return {};
91   return def.getValueAsString(attr);
92 }
93 
94 const llvm::Record &Operator::getDef() const { return def; }
95 
96 bool Operator::skipDefaultBuilders() const {
97   return def.getValueAsBit("skipDefaultBuilders");
98 }
99 
100 auto Operator::result_begin() -> value_iterator { return results.begin(); }
101 
102 auto Operator::result_end() -> value_iterator { return results.end(); }
103 
104 auto Operator::getResults() -> value_range {
105   return {result_begin(), result_end()};
106 }
107 
108 TypeConstraint Operator::getResultTypeConstraint(int index) const {
109   DagInit *results = def.getValueAsDag("results");
110   return TypeConstraint(cast<DefInit>(results->getArg(index)));
111 }
112 
113 StringRef Operator::getResultName(int index) const {
114   DagInit *results = def.getValueAsDag("results");
115   return results->getArgNameStr(index);
116 }
117 
118 auto Operator::getResultDecorators(int index) const -> var_decorator_range {
119   Record *result =
120       cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
121   if (!result->isSubClassOf("OpVariable"))
122     return var_decorator_range(nullptr, nullptr);
123   return *result->getValueAsListInit("decorators");
124 }
125 
126 unsigned Operator::getNumVariableLengthResults() const {
127   return llvm::count_if(results, [](const NamedTypeConstraint &c) {
128     return c.constraint.isVariableLength();
129   });
130 }
131 
132 unsigned Operator::getNumVariableLengthOperands() const {
133   return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
134     return c.constraint.isVariableLength();
135   });
136 }
137 
138 bool Operator::hasSingleVariadicArg() const {
139   return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() &&
140          getOperand(0).isVariadic();
141 }
142 
143 Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
144 
145 Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
146 
147 Operator::arg_range Operator::getArgs() const {
148   return {arg_begin(), arg_end()};
149 }
150 
151 StringRef Operator::getArgName(int index) const {
152   DagInit *argumentValues = def.getValueAsDag("arguments");
153   return argumentValues->getArgNameStr(index);
154 }
155 
156 auto Operator::getArgDecorators(int index) const -> var_decorator_range {
157   Record *arg =
158       cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
159   if (!arg->isSubClassOf("OpVariable"))
160     return var_decorator_range(nullptr, nullptr);
161   return *arg->getValueAsListInit("decorators");
162 }
163 
164 const Trait *Operator::getTrait(StringRef trait) const {
165   for (const auto &t : traits) {
166     if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
167       if (traitDef->getFullyQualifiedTraitName() == trait)
168         return traitDef;
169     } else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
170       if (traitDef->getFullyQualifiedTraitName() == trait)
171         return traitDef;
172     } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
173       if (traitDef->getFullyQualifiedTraitName() == trait)
174         return traitDef;
175     }
176   }
177   return nullptr;
178 }
179 
180 auto Operator::region_begin() const -> const_region_iterator {
181   return regions.begin();
182 }
183 auto Operator::region_end() const -> const_region_iterator {
184   return regions.end();
185 }
186 auto Operator::getRegions() const
187     -> llvm::iterator_range<const_region_iterator> {
188   return {region_begin(), region_end()};
189 }
190 
191 unsigned Operator::getNumRegions() const { return regions.size(); }
192 
193 const NamedRegion &Operator::getRegion(unsigned index) const {
194   return regions[index];
195 }
196 
197 unsigned Operator::getNumVariadicRegions() const {
198   return llvm::count_if(regions,
199                         [](const NamedRegion &c) { return c.isVariadic(); });
200 }
201 
202 auto Operator::successor_begin() const -> const_successor_iterator {
203   return successors.begin();
204 }
205 auto Operator::successor_end() const -> const_successor_iterator {
206   return successors.end();
207 }
208 auto Operator::getSuccessors() const
209     -> llvm::iterator_range<const_successor_iterator> {
210   return {successor_begin(), successor_end()};
211 }
212 
213 unsigned Operator::getNumSuccessors() const { return successors.size(); }
214 
215 const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
216   return successors[index];
217 }
218 
219 unsigned Operator::getNumVariadicSuccessors() const {
220   return llvm::count_if(successors,
221                         [](const NamedSuccessor &c) { return c.isVariadic(); });
222 }
223 
224 auto Operator::trait_begin() const -> const_trait_iterator {
225   return traits.begin();
226 }
227 auto Operator::trait_end() const -> const_trait_iterator {
228   return traits.end();
229 }
230 auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
231   return {trait_begin(), trait_end()};
232 }
233 
234 auto Operator::attribute_begin() const -> attribute_iterator {
235   return attributes.begin();
236 }
237 auto Operator::attribute_end() const -> attribute_iterator {
238   return attributes.end();
239 }
240 auto Operator::getAttributes() const
241     -> llvm::iterator_range<attribute_iterator> {
242   return {attribute_begin(), attribute_end()};
243 }
244 
245 auto Operator::operand_begin() -> value_iterator { return operands.begin(); }
246 auto Operator::operand_end() -> value_iterator { return operands.end(); }
247 auto Operator::getOperands() -> value_range {
248   return {operand_begin(), operand_end()};
249 }
250 
251 auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
252 
// Maps a result number into the index space shared with arguments.
// Arguments keep their (non-negative) getArg index, while results are mapped
// to negative values -- result #0 -> -1, result #1 -> -2, ... -- so the two
// never overlap.
static int resultIndex(int resultNumber) {
  const int firstResultSlot = -1;
  return firstResultSlot - resultNumber;
}
257 
258 bool Operator::isVariadic() const {
259   return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
260                 [](const NamedTypeConstraint &op) { return op.isVariadic(); });
261 }
262 
263 void Operator::populateTypeInferenceInfo(
264     const llvm::StringMap<int> &argumentsAndResultsIndex) {
265   // If the type inference op interface is not registered, then do not attempt
266   // to determine if the result types an be inferred.
267   auto &recordKeeper = def.getRecords();
268   auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface);
269   allResultsHaveKnownTypes = false;
270   if (!inferTrait)
271     return;
272 
273   // If there are no results, the skip this else the build method generated
274   // overlaps with another autogenerated builder.
275   if (getNumResults() == 0)
276     return;
277 
278   // Skip for ops with variadic operands/results.
279   // TODO: This can be relaxed.
280   if (isVariadic())
281     return;
282 
283   // Skip cases currently being custom generated.
284   // TODO: Remove special cases.
285   if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
286     return;
287 
288   // We create equivalence classes of argument/result types where arguments
289   // and results are mapped into the same index space and indices corresponding
290   // to the same type are in the same equivalence class.
291   llvm::EquivalenceClasses<int> ecs;
292   resultTypeMapping.resize(getNumResults());
293   // Captures the argument whose type matches a given result type. Preference
294   // towards capturing operands first before attributes.
295   auto captureMapping = [&](int i) {
296     bool found = false;
297     ecs.insert(resultIndex(i));
298     auto mi = ecs.findLeader(resultIndex(i));
299     for (auto me = ecs.member_end(); mi != me; ++mi) {
300       if (*mi < 0) {
301         auto tc = getResultTypeConstraint(i);
302         if (tc.getBuilderCall().hasValue()) {
303           resultTypeMapping[i].emplace_back(tc);
304           found = true;
305         }
306         continue;
307       }
308 
309       if (getArg(*mi).is<NamedAttribute *>()) {
310         // TODO: Handle attributes.
311         continue;
312       } else {
313         resultTypeMapping[i].emplace_back(*mi);
314         found = true;
315       }
316     }
317     return found;
318   };
319 
320   for (const Trait &trait : traits) {
321     const llvm::Record &def = trait.getDef();
322     // If the infer type op interface was manually added, then treat it as
323     // intention that the op needs special handling.
324     // TODO: Reconsider whether to always generate, this is more conservative
325     // and keeps existing behavior so starting that way for now.
326     if (def.isSubClassOf(
327             llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
328       return;
329     if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
330       if (&traitDef->getDef() == inferTrait)
331         return;
332 
333     if (!def.isSubClassOf("AllTypesMatch"))
334       continue;
335 
336     auto values = def.getValueAsListOfStrings("values");
337     auto root = argumentsAndResultsIndex.lookup(values.front());
338     for (StringRef str : values)
339       ecs.unionSets(argumentsAndResultsIndex.lookup(str), root);
340   }
341 
342   // Verifies that all output types have a corresponding known input type
343   // and chooses matching operand or attribute (in that order) that
344   // matches it.
345   allResultsHaveKnownTypes =
346       all_of(llvm::seq<int>(0, getNumResults()), captureMapping);
347 
348   // If the types could be computed, then add type inference trait.
349   if (allResultsHaveKnownTypes)
350     traits.push_back(Trait::create(inferTrait->getDefInit()));
351 }
352 
353 void Operator::populateOpStructure() {
354   auto &recordKeeper = def.getRecords();
355   auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
356   auto *attrClass = recordKeeper.getClass("Attr");
357   auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
358   auto *opVarClass = recordKeeper.getClass("OpVariable");
359   numNativeAttributes = 0;
360 
361   DagInit *argumentValues = def.getValueAsDag("arguments");
362   unsigned numArgs = argumentValues->getNumArgs();
363 
364   // Mapping from name of to argument or result index. Arguments are indexed
365   // to match getArg index, while the results are negatively indexed.
366   llvm::StringMap<int> argumentsAndResultsIndex;
367 
368   // Handle operands and native attributes.
369   for (unsigned i = 0; i != numArgs; ++i) {
370     auto *arg = argumentValues->getArg(i);
371     auto givenName = argumentValues->getArgNameStr(i);
372     auto *argDefInit = dyn_cast<DefInit>(arg);
373     if (!argDefInit)
374       PrintFatalError(def.getLoc(),
375                       Twine("undefined type for argument #") + Twine(i));
376     Record *argDef = argDefInit->getDef();
377     if (argDef->isSubClassOf(opVarClass))
378       argDef = argDef->getValueAsDef("constraint");
379 
380     if (argDef->isSubClassOf(typeConstraintClass)) {
381       operands.push_back(
382           NamedTypeConstraint{givenName, TypeConstraint(argDef)});
383     } else if (argDef->isSubClassOf(attrClass)) {
384       if (givenName.empty())
385         PrintFatalError(argDef->getLoc(), "attributes must be named");
386       if (argDef->isSubClassOf(derivedAttrClass))
387         PrintFatalError(argDef->getLoc(),
388                         "derived attributes not allowed in argument list");
389       attributes.push_back({givenName, Attribute(argDef)});
390       ++numNativeAttributes;
391     } else {
392       PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
393                                     "from TypeConstraint or Attr are allowed");
394     }
395     if (!givenName.empty())
396       argumentsAndResultsIndex[givenName] = i;
397   }
398 
399   // Handle derived attributes.
400   for (const auto &val : def.getValues()) {
401     if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
402       if (!record->isSubClassOf(attrClass))
403         continue;
404       if (!record->isSubClassOf(derivedAttrClass))
405         PrintFatalError(def.getLoc(),
406                         "unexpected Attr where only DerivedAttr is allowed");
407 
408       if (record->getClasses().size() != 1) {
409         PrintFatalError(
410             def.getLoc(),
411             "unsupported attribute modelling, only single class expected");
412       }
413       attributes.push_back(
414           {cast<llvm::StringInit>(val.getNameInit())->getValue(),
415            Attribute(cast<DefInit>(val.getValue()))});
416     }
417   }
418 
419   // Populate `arguments`. This must happen after we've finalized `operands` and
420   // `attributes` because we will put their elements' pointers in `arguments`.
421   // SmallVector may perform re-allocation under the hood when adding new
422   // elements.
423   int operandIndex = 0, attrIndex = 0;
424   for (unsigned i = 0; i != numArgs; ++i) {
425     Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
426     if (argDef->isSubClassOf(opVarClass))
427       argDef = argDef->getValueAsDef("constraint");
428 
429     if (argDef->isSubClassOf(typeConstraintClass)) {
430       attrOrOperandMapping.push_back(
431           {OperandOrAttribute::Kind::Operand, operandIndex});
432       arguments.emplace_back(&operands[operandIndex++]);
433     } else {
434       assert(argDef->isSubClassOf(attrClass));
435       attrOrOperandMapping.push_back(
436           {OperandOrAttribute::Kind::Attribute, attrIndex});
437       arguments.emplace_back(&attributes[attrIndex++]);
438     }
439   }
440 
441   auto *resultsDag = def.getValueAsDag("results");
442   auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
443   if (!outsOp || outsOp->getDef()->getName() != "outs") {
444     PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
445   }
446 
447   // Handle results.
448   for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
449     auto name = resultsDag->getArgNameStr(i);
450     auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
451     if (!resultInit) {
452       PrintFatalError(def.getLoc(),
453                       Twine("undefined type for result #") + Twine(i));
454     }
455     auto *resultDef = resultInit->getDef();
456     if (resultDef->isSubClassOf(opVarClass))
457       resultDef = resultDef->getValueAsDef("constraint");
458     results.push_back({name, TypeConstraint(resultDef)});
459     if (!name.empty())
460       argumentsAndResultsIndex[name] = resultIndex(i);
461 
462     // We currently only support VariadicOfVariadic operands.
463     if (results.back().constraint.isVariadicOfVariadic()) {
464       PrintFatalError(
465           def.getLoc(),
466           "'VariadicOfVariadic' results are currently not supported");
467     }
468   }
469 
470   // Handle successors
471   auto *successorsDag = def.getValueAsDag("successors");
472   auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
473   if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
474     PrintFatalError(def.getLoc(),
475                     "'successors' must have 'successor' directive");
476   }
477 
478   for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
479     auto name = successorsDag->getArgNameStr(i);
480     auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
481     if (!successorInit) {
482       PrintFatalError(def.getLoc(),
483                       Twine("undefined kind for successor #") + Twine(i));
484     }
485     Successor successor(successorInit->getDef());
486 
487     // Only support variadic successors if it is the last one for now.
488     if (i != e - 1 && successor.isVariadic())
489       PrintFatalError(def.getLoc(), "only the last successor can be variadic");
490     successors.push_back({name, successor});
491   }
492 
493   // Create list of traits, skipping over duplicates: appending to lists in
494   // tablegen is easy, making them unique less so, so dedupe here.
495   if (auto *traitList = def.getValueAsListInit("traits")) {
496     // This is uniquing based on pointers of the trait.
497     SmallPtrSet<const llvm::Init *, 32> traitSet;
498     traits.reserve(traitSet.size());
499 
500     std::function<void(llvm::ListInit *)> insert;
501     insert = [&](llvm::ListInit *traitList) {
502       for (auto *traitInit : *traitList) {
503         auto *def = cast<DefInit>(traitInit)->getDef();
504         if (def->isSubClassOf("OpTraitList")) {
505           insert(def->getValueAsListInit("traits"));
506           continue;
507         }
508         // Keep traits in the same order while skipping over duplicates.
509         if (traitSet.insert(traitInit).second)
510           traits.push_back(Trait::create(traitInit));
511       }
512     };
513     insert(traitList);
514   }
515 
516   populateTypeInferenceInfo(argumentsAndResultsIndex);
517 
518   // Handle regions
519   auto *regionsDag = def.getValueAsDag("regions");
520   auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
521   if (!regionsOp || regionsOp->getDef()->getName() != "region") {
522     PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
523   }
524 
525   for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
526     auto name = regionsDag->getArgNameStr(i);
527     auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
528     if (!regionInit) {
529       PrintFatalError(def.getLoc(),
530                       Twine("undefined kind for region #") + Twine(i));
531     }
532     Region region(regionInit->getDef());
533     if (region.isVariadic()) {
534       // Only support variadic regions if it is the last one for now.
535       if (i != e - 1)
536         PrintFatalError(def.getLoc(), "only the last region can be variadic");
537       if (name.empty())
538         PrintFatalError(def.getLoc(), "variadic regions must be named");
539     }
540 
541     regions.push_back({name, region});
542   }
543 
544   // Populate the builders.
545   auto *builderList =
546       dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
547   if (builderList && !builderList->empty()) {
548     for (llvm::Init *init : builderList->getValues())
549       builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
550   } else if (skipDefaultBuilders()) {
551     PrintFatalError(
552         def.getLoc(),
553         "default builders are skipped and no custom builders provided");
554   }
555 
556   LLVM_DEBUG(print(llvm::dbgs()));
557 }
558 
559 auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> {
560   assert(allResultTypesKnown());
561   return resultTypeMapping[index];
562 }
563 
564 ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); }
565 
566 bool Operator::hasDescription() const {
567   return def.getValue("description") != nullptr;
568 }
569 
570 StringRef Operator::getDescription() const {
571   return def.getValueAsString("description");
572 }
573 
574 bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; }
575 
576 StringRef Operator::getSummary() const {
577   return def.getValueAsString("summary");
578 }
579 
580 bool Operator::hasAssemblyFormat() const {
581   auto *valueInit = def.getValueInit("assemblyFormat");
582   return isa<llvm::StringInit>(valueInit);
583 }
584 
585 StringRef Operator::getAssemblyFormat() const {
586   return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
587       .Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
588 }
589 
590 void Operator::print(llvm::raw_ostream &os) const {
591   os << "op '" << getOperationName() << "'\n";
592   for (Argument arg : arguments) {
593     if (auto *attr = arg.dyn_cast<NamedAttribute *>())
594       os << "[attribute] " << attr->name << '\n';
595     else
596       os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
597   }
598 }
599 
600 auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
601     -> VariableDecorator {
602   return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
603 }
604 
605 auto Operator::getArgToOperandOrAttribute(int index) const
606     -> OperandOrAttribute {
607   return attrOrOperandMapping[index];
608 }
609