1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "OpClass.h"
15 #include "OpFormatGen.h"
16 #include "OpGenHelpers.h"
17 #include "mlir/TableGen/Class.h"
18 #include "mlir/TableGen/CodeGenHelpers.h"
19 #include "mlir/TableGen/Format.h"
20 #include "mlir/TableGen/GenInfo.h"
21 #include "mlir/TableGen/Interfaces.h"
22 #include "mlir/TableGen/Operator.h"
23 #include "mlir/TableGen/SideEffects.h"
24 #include "mlir/TableGen/Trait.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringSet.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/Signals.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Record.h"
34 #include "llvm/TableGen/TableGenBackend.h"
35 
36 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
37 
38 using namespace llvm;
39 using namespace mlir;
40 using namespace mlir::tblgen;
41 
/// Prefix for local variables in generated code; used so that generated
/// locals do not hide the op's attribute accessors.
static const char *const tblgenNamePrefix = "tblgen_";
/// Base name for synthesized names of unnamed operands (`odsArg_<index>`, see
/// getArgumentName).
static const char *const generatedArgName = "odsArg";
// NOTE(review): `odsBuilder` and `builderOpState` are not used in this
// section; presumably they name the builder and operation-state parameters of
// generated build() methods — confirm against their uses.
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";

/// The names of the implicit attributes that contain variadic operand and
/// result segment sizes.
static const char *const operandSegmentAttrName = "operand_segment_sizes";
static const char *const resultSegmentAttrName = "result_segment_sizes";

/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
/// lookup: because attributes are kept sorted, the search can skip the known
/// number of required attributes before ({1}) and after ({2}) this one.
///
/// {0}: Code snippet to get the attribute's name or identifier.
/// {1}: The lower bound on the sorted subrange.
/// {2}: The upper bound on the sorted subrange.
/// {3}: Code snippet to get the array of named attributes.
/// {4}: "Named" to get the named attribute.
static const char *const subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
    "{2}, {0})";
63 
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// The emitted snippet expects an unsigned `index` (the static operand/result
/// index) to be in scope, and returns a {start, size} pair.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// The emitted snippet expects `sizeAttr` (the segment size attribute) and
/// `index` to be in scope; it returns a {start, size} pair. A splat attribute
/// is handled separately because every segment then has the same size
/// (NOTE(review): presumably also because a splat stores a single element, so
/// indexing past it would be invalid — confirm).
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
  if (sizeAttr.isSplat())
    return {*sizeAttrValueIt * index, *sizeAttrValueIt};

  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttrValueIt[i];
  return {start, sizeAttrValueIt[index]};
)";

/// The code snippet to initialize the sizes for the value range calculation
/// inside an adaptor. It asserts that `odsAttrs` is present before use.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert(odsAttrs && "missing segment size attribute for op");
  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
)";
/// The code snippet to initialize the sizes for the value range calculation
/// inside an op.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
)";

/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor. The flat
/// operand range is split into consecutive groups whose sizes are read from
/// the segment attribute.
///
/// {0}: The name of the accessor for the segment attribute (invoked as `{0}()`).
/// {1}: The index of the main operand.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizeAttrValues = {0}().getValues<uint32_t>();
  auto sizeAttrIt = sizeAttrValues.begin();

  ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
  for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt);
  }
  return tblgenTmpOperandGroups;
)";

/// The logic to build a range of either operand or result values.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";

/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
158 
159 //===----------------------------------------------------------------------===//
160 // Utility structs and functions
161 //===----------------------------------------------------------------------===//
162 
/// Replaces all non-overlapping occurrences of `match` in `str` with
/// `substitute`, scanning left to right. Text inserted by a substitution is
/// never rescanned, so `substitute` may itself contain `match`.
///
/// An empty `match` is a no-op and returns `str` unchanged. Without this guard
/// the empty pattern would be found again right after every insertion and the
/// loop would never terminate.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc = std::string::npos;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    // std::string::replace mutates in place; no reassignment is needed.
    str.replace(matchLoc, match.size(), substitute);
    // Resume scanning after the inserted text.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
173 
174 // Returns whether the record has a value of the given name that can be returned
175 // via getValueAsString.
hasStringAttribute(const Record & record,StringRef fieldName)176 static inline bool hasStringAttribute(const Record &record,
177                                       StringRef fieldName) {
178   auto *valueInit = record.getValueInit(fieldName);
179   return isa<StringInit>(valueInit);
180 }
181 
getArgumentName(const Operator & op,int index)182 static std::string getArgumentName(const Operator &op, int index) {
183   const auto &operand = op.getOperand(index);
184   if (!operand.name.empty())
185     return std::string(operand.name);
186   return std::string(formatv("{0}_{1}", generatedArgName, index));
187 }
188 
189 // Returns true if we can use unwrapped value for the given `attr` in builders.
canUseUnwrappedRawValue(const tblgen::Attribute & attr)190 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
191   return attr.getReturnType() != attr.getStorageType() &&
192          // We need to wrap the raw value into an attribute in the builder impl
193          // so we need to make sure that the attribute specifies how to do that.
194          !attr.getConstBuilderTemplate().empty();
195 }
196 
namespace {
/// Metadata on a registered attribute. Given that attributes are stored in
/// sorted order on operations, we can use information from ODS to deduce the
/// number of required attributes sorted before and after each attribute,
/// allowing us to search only a subrange of the attributes in ODS-generated
/// getters.
struct AttributeMetadata {
  /// The attribute name.
  StringRef attrName;
  /// Whether the attribute is required.
  bool isRequired;
  /// The ODS attribute constraint. Not present for implicit attributes.
  Optional<Attribute> constraint;
  /// The number of required attributes sorted before this attribute.
  unsigned lowerBound = 0;
  /// The number of required attributes sorted after this attribute.
  unsigned upperBound = 0;
};

/// Helper class to select between OpAdaptor and Op code templates.
///
/// Each query returns a code snippet for either an op (`emitForOp == true`) or
/// its adaptor (`emitForOp == false`). Snippets are returned as Formatters that
/// capture `this` and their StringRef arguments, so they must not outlive this
/// helper or the referenced strings.
class OpOrAdaptorHelper {
public:
  // Attribute metadata is computed eagerly on construction.
  OpOrAdaptorHelper(const Operator &op, bool emitForOp)
      : op(op), emitForOp(emitForOp) {
    computeAttrMetadata();
  }

  /// Object that wraps a functor in a stream operator for interop with
  /// llvm::formatv.
  class Formatter {
  public:
    // NOTE(review): this unconstrained forwarding constructor is a better
    // match than the implicit copy constructor for a non-const lvalue
    // Formatter, and would then fail to compile (a Formatter is not callable).
    // Constrain it if Formatters ever need to be copied from lvalues.
    template <typename Functor>
    Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}

    // Render the wrapped functor into a string.
    std::string str() const {
      std::string result;
      llvm::raw_string_ostream os(result);
      os << *this;
      return os.str();
    }

  private:
    std::function<raw_ostream &(raw_ostream &)> func;

    // Streaming a Formatter simply invokes the functor on the stream.
    friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
      return fmt.func(os);
    }
  };

  // Generate code for getting an attribute. The lookup is restricted to the
  // sorted subrange given by the attribute's precomputed bounds; with
  // `isNamed` the generated call yields the named attribute. The attribute
  // must have metadata (asserted).
  Formatter getAttr(StringRef attrName, bool isNamed = false) const {
    assert(attrMetadata.count(attrName) && "expected attribute metadata");
    return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
      const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
      return os << formatv(subrangeGetAttr, getAttrName(attrName),
                           attr.lowerBound, attr.upperBound, getAttrRange(),
                           isNamed ? "Named" : "");
    };
  }

  // Generate code for getting the name of an attribute. Ops call their own
  // `<getter>AttrName()`; adaptors call it qualified with the op class name,
  // passing `*odsOpName`.
  Formatter getAttrName(StringRef attrName) const {
    return [this, attrName](raw_ostream &os) -> raw_ostream & {
      if (emitForOp)
        return os << op.getGetterName(attrName) << "AttrName()";
      return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(),
                           op.getGetterName(attrName));
    };
  }

  // Get the code snippet for getting the named attribute range: the op's
  // attribute list for ops, `odsAttrs` for adaptors.
  StringRef getAttrRange() const {
    return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
  }

  // Get the prefix code for emitting an error. The prefix leaves the call
  // open: the caller appends the message and the closing parenthesis.
  Formatter emitErrorPrefix() const {
    return [this](raw_ostream &os) -> raw_ostream & {
      if (emitForOp)
        return os << "emitOpError(";
      return os << formatv("emitError(loc, \"'{0}' op \"",
                           op.getOperationName());
    };
  }

  // Get the call to get an operand or segment of operands. Variadic operands
  // yield the whole range; others yield its first (only) element.
  Formatter getOperand(unsigned index) const {
    return [this, index](raw_ostream &os) -> raw_ostream & {
      return os << formatv(op.getOperand(index).isVariadic()
                               ? "this->getODSOperands({0})"
                               : "(*this->getODSOperands({0}).begin())",
                           index);
    };
  }

  // Get the call to get a result or segment of results. Adaptors have no
  // results, so a placeholder is emitted for them.
  Formatter getResult(unsigned index) const {
    return [this, index](raw_ostream &os) -> raw_ostream & {
      if (!emitForOp)
        return os << "<no results should be generated>";
      return os << formatv(op.getResult(index).isVariadic()
                               ? "this->getODSResults({0})"
                               : "(*this->getODSResults({0}).begin())",
                           index);
    };
  }

  // Return whether an op instance is available.
  bool isEmittingForOp() const { return emitForOp; }

  // Return the ODS operation wrapper.
  const Operator &getOp() const { return op; }

  // Get the attribute metadata sorted by name.
  const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
    return attrMetadata;
  }

private:
  // Compute the attribute metadata.
  void computeAttrMetadata();

  // The operation ODS wrapper.
  const Operator &op;
  // True if code is being generated for an op. False for an adaptor.
  const bool emitForOp;

  // The attribute metadata, mapped by name.
  llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
  // The number of required attributes. Set by computeAttrMetadata().
  unsigned numRequired;
};

} // namespace
331 
computeAttrMetadata()332 void OpOrAdaptorHelper::computeAttrMetadata() {
333   // Enumerate the attribute names of this op, ensuring the attribute names are
334   // unique in case implicit attributes are explicitly registered.
335   for (const NamedAttribute &namedAttr : op.getAttributes()) {
336     Attribute attr = namedAttr.attr;
337     bool isOptional =
338         attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
339     attrMetadata.insert(
340         {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}});
341   }
342   // Include key attributes from several traits as implicitly registered.
343   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
344     attrMetadata.insert(
345         {operandSegmentAttrName,
346          AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true,
347                            /*attr=*/llvm::None}});
348   }
349   if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
350     attrMetadata.insert(
351         {resultSegmentAttrName,
352          AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
353                            /*attr=*/llvm::None}});
354   }
355 
356   // Store the metadata in sorted order.
357   SmallVector<AttributeMetadata> sortedAttrMetadata =
358       llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector()));
359   llvm::sort(sortedAttrMetadata,
360              [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
361                return lhs.attrName < rhs.attrName;
362              });
363 
364   // Compute the subrange bounds for each attribute.
365   numRequired = 0;
366   for (AttributeMetadata &attr : sortedAttrMetadata) {
367     attr.lowerBound = numRequired;
368     numRequired += attr.isRequired;
369   };
370   for (AttributeMetadata &attr : sortedAttrMetadata)
371     attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
372 
373   // Store the results back into the map.
374   for (const AttributeMetadata &attr : sortedAttrMetadata)
375     attrMetadata.insert({attr.attrName, attr});
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // Op emitter
380 //===----------------------------------------------------------------------===//
381 
namespace {
// Helper class to emit a record into the given output stream.
//
// Generates the C++ code for a single ODS operation. The two static entry
// points form the public interface. Per-op state (the TableGen record, the
// Operator wrapper, the OpClass under construction, the verification format
// context, the static verifier emitter and the op/adaptor helper) is kept in
// the private members at the end of the class.
class OpEmitter {
public:
  // Entry points: each takes the op, the stream to write to, and the emitter
  // holding the locally emitted static verification functions.
  static void
  emitDecl(const Operator &op, raw_ostream &os,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
  static void
  emitDef(const Operator &op, raw_ostream &os,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter);

private:
  OpEmitter(const Operator &op,
            const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generate methods for accessing the attribute names of this operation.
  void genAttrNameGetters();

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setters for the attributes.
  void genAttrSetters();

  // Generates removers for optional attributes.
  void genOptionalAttrRemovers();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates setters for named operands.
  void genNamedOperandSetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates the method to populate default attributes.
  void genPopulateDefaultAttributes();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all result's types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. `inferredAttributes` is populated with any attributes that are
  // elided from the build list. The given `typeParamKind` and `attrParamKind`
  // controls how result types and attributes are placed in the parameter list.
  void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                      llvm::StringSet<> &inferredAttributes,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  void
  genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
                                         llvm::StringSet<> &inferredAttributes,
                                         bool isRawValueAttr = false);

  // Generates canonicalizer declarations for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates custom verify methods for the operation.
  void genCustomVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(MethodBody &body,
                                Operator::const_value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(MethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(MethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods for all interfaces.
  void genOpInterfaceMethods();

  // Generate op interface methods for the given interface.
  void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);

  // Generate op interface method for the given interface method. If
  // 'declaration' is true, generates a declaration, else a definition.
  Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                               bool declaration = true);

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

  // Generate the type inference interface methods.
  void genTypeInterfaceMethods();

private:
  // The TableGen record for this op.
  // TODO: OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  const Operator &op;

  // The C++ code builder for this op
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting op code.
  OpOrAdaptorHelper emitHelper;
};

} // namespace
573 
574 // Populate the format context `ctx` with substitutions of attributes, operands
575 // and results.
populateSubstitutions(const OpOrAdaptorHelper & emitHelper,FmtContext & ctx)576 static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
577                                   FmtContext &ctx) {
578   // Populate substitutions for attributes.
579   auto &op = emitHelper.getOp();
580   for (const auto &namedAttr : op.getAttributes())
581     ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str());
582 
583   // Populate substitutions for named operands.
584   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
585     auto &value = op.getOperand(i);
586     if (!value.name.empty())
587       ctx.addSubst(value.name, emitHelper.getOperand(i).str());
588   }
589 
590   // Populate substitutions for results.
591   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
592     auto &value = op.getResult(i);
593     if (!value.name.empty())
594       ctx.addSubst(value.name, emitHelper.getResult(i).str());
595   }
596 }
597 
598 /// Generate verification on native traits requiring attributes.
genNativeTraitAttrVerifier(MethodBody & body,const OpOrAdaptorHelper & emitHelper)599 static void genNativeTraitAttrVerifier(MethodBody &body,
600                                        const OpOrAdaptorHelper &emitHelper) {
601   // Check that the variadic segment sizes attribute exists and contains the
602   // expected number of elements.
603   //
604   // {0}: Attribute name.
605   // {1}: Expected number of elements.
606   // {2}: "operand" or "result".
607   // {3}: Emit error prefix.
608   const char *const checkAttrSizedValueSegmentsCode = R"(
609   {
610     auto sizeAttr = tblgen_{0}.cast<::mlir::DenseIntElementsAttr>();
611     auto numElements =
612         sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
613     if (numElements != {1})
614       return {3}"'{0}' attribute for specifying {2} segments must have {1} "
615                 "elements, but got ") << numElements;
616   }
617   )";
618 
619   // Verify a few traits first so that we can use getODSOperands() and
620   // getODSResults() in the rest of the verifier.
621   auto &op = emitHelper.getOp();
622   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
623     body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
624                     op.getNumOperands(), "operand",
625                     emitHelper.emitErrorPrefix());
626   }
627   if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
628     body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
629                     op.getNumResults(), "result", emitHelper.emitErrorPrefix());
630   }
631 }
632 
633 // Generate attribute verification. If an op instance is not available, then
634 // attribute checks that require one will not be emitted.
635 //
636 // Attribute verification is performed as follows:
637 //
638 // 1. Verify that all required attributes are present in sorted order. This
639 // ensures that we can use subrange lookup even with potentially missing
640 // attributes.
641 // 2. Verify native trait attributes so that other attributes may call methods
642 // that depend on the validity of these attributes, e.g. segment size attributes
643 // and operand or result getters.
644 // 3. Verify the constraints on all present attributes.
genAttributeVerifier(const OpOrAdaptorHelper & emitHelper,FmtContext & ctx,MethodBody & body,const StaticVerifierFunctionEmitter & staticVerifierEmitter)645 static void genAttributeVerifier(
646     const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
647     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
648   if (emitHelper.getAttrMetadata().empty())
649     return;
650 
651   // Verify the attribute if it is present. This assumes that default values
652   // are valid. This code snippet pastes the condition inline.
653   //
654   // TODO: verify the default value is valid (perhaps in debug mode only).
655   //
656   // {0}: Attribute variable name.
657   // {1}: Attribute condition code.
658   // {2}: Emit error prefix.
659   // {3}: Attribute name.
660   // {4}: Attribute/constraint description.
661   const char *const verifyAttrInline = R"(
662   if ({0} && !({1}))
663     return {2}"attribute '{3}' failed to satisfy constraint: {4}");
664 )";
665   // Verify the attribute using a uniqued constraint. Can only be used within
666   // the context of an op.
667   //
668   // {0}: Unique constraint name.
669   // {1}: Attribute variable name.
670   // {2}: Attribute name.
671   const char *const verifyAttrUnique = R"(
672   if (::mlir::failed({0}(*this, {1}, "{2}")))
673     return ::mlir::failure();
674 )";
675 
676   // Traverse the array until the required attribute is found. Return an error
677   // if the traversal reached the end.
678   //
679   // {0}: Code to get the name of the attribute.
680   // {1}: The emit error prefix.
681   // {2}: The name of the attribute.
682   const char *const findRequiredAttr = R"(while (true) {{
683   if (namedAttrIt == namedAttrRange.end())
684     return {1}"requires attribute '{2}'");
685   if (namedAttrIt->getName() == {0}) {{
686     tblgen_{2} = namedAttrIt->getValue();
687     break;
688   })";
689 
690   // Emit a check to see if the iteration has encountered an optional attribute.
691   //
692   // {0}: Code to get the name of the attribute.
693   // {1}: The name of the attribute.
694   const char *const checkOptionalAttr = R"(
695   else if (namedAttrIt->getName() == {0}) {{
696     tblgen_{1} = namedAttrIt->getValue();
697   })";
698 
699   // Emit the start of the loop for checking trailing attributes.
700   const char *const checkTrailingAttrs = R"(while (true) {
701   if (namedAttrIt == namedAttrRange.end()) {
702     break;
703   })";
704 
705   // Return true if a verifier can be emitted for the attribute: it is not a
706   // derived attribute, it has a predicate, its condition is not empty, and, for
707   // adaptors, the condition does not reference the op.
708   const auto canEmitVerifier = [&](Attribute attr) {
709     if (attr.isDerivedAttr())
710       return false;
711     Pred pred = attr.getPredicate();
712     if (pred.isNull())
713       return false;
714     std::string condition = pred.getCondition();
715     return !condition.empty() && (!StringRef(condition).contains("$_op") ||
716                                   emitHelper.isEmittingForOp());
717   };
718 
719   // Emit the verifier for the attribute.
720   const auto emitVerifier = [&](Attribute attr, StringRef attrName,
721                                 StringRef varName) {
722     std::string condition = attr.getPredicate().getCondition();
723 
724     Optional<StringRef> constraintFn;
725     if (emitHelper.isEmittingForOp() &&
726         (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
727       body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
728     } else {
729       body << formatv(verifyAttrInline, varName,
730                       tgfmt(condition, &ctx.withSelf(varName)),
731                       emitHelper.emitErrorPrefix(), attrName,
732                       escapeString(attr.getSummary()));
733     }
734   };
735 
736   // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
737   const auto getVarName = [&](StringRef attrName) {
738     return (tblgenNamePrefix + attrName).str();
739   };
740 
741   body.indent() << formatv("auto namedAttrRange = {0};\n",
742                            emitHelper.getAttrRange());
743   body << "auto namedAttrIt = namedAttrRange.begin();\n";
744 
745   // Iterate over the attributes in sorted order. Keep track of the optional
746   // attributes that may be encountered along the way.
747   SmallVector<const AttributeMetadata *> optionalAttrs;
748   for (const std::pair<StringRef, AttributeMetadata> &it :
749        emitHelper.getAttrMetadata()) {
750     const AttributeMetadata &metadata = it.second;
751     if (!metadata.isRequired) {
752       optionalAttrs.push_back(&metadata);
753       continue;
754     }
755 
756     body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
757     for (const AttributeMetadata *optional : optionalAttrs) {
758       body << formatv("::mlir::Attribute {0};\n",
759                       getVarName(optional->attrName));
760     }
761     body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
762                     emitHelper.emitErrorPrefix(), it.first);
763     for (const AttributeMetadata *optional : optionalAttrs) {
764       body << formatv(checkOptionalAttr,
765                       emitHelper.getAttrName(optional->attrName),
766                       optional->attrName);
767     }
768     body << "\n  ++namedAttrIt;\n}\n";
769     optionalAttrs.clear();
770   }
771   // Get trailing optional attributes.
772   if (!optionalAttrs.empty()) {
773     for (const AttributeMetadata *optional : optionalAttrs) {
774       body << formatv("::mlir::Attribute {0};\n",
775                       getVarName(optional->attrName));
776     }
777     body << checkTrailingAttrs;
778     for (const AttributeMetadata *optional : optionalAttrs) {
779       body << formatv(checkOptionalAttr,
780                       emitHelper.getAttrName(optional->attrName),
781                       optional->attrName);
782     }
783     body << "\n  ++namedAttrIt;\n}\n";
784   }
785   body.unindent();
786 
787   // Emit the checks for segment attributes first so that the other constraints
788   // can call operand and result getters.
789   genNativeTraitAttrVerifier(body, emitHelper);
790 
791   for (const auto &namedAttr : emitHelper.getOp().getAttributes())
792     if (canEmitVerifier(namedAttr.attr))
793       emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
794 }
795 
796 /// Op extra class definitions have a `$cppClass` substitution that is to be
797 /// replaced by the C++ class name.
formatExtraDefinitions(const Operator & op)798 static std::string formatExtraDefinitions(const Operator &op) {
799   FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
800   return tgfmt(op.getExtraClassDefinition(), &ctx).str();
801 }
802 
OpEmitter(const Operator & op,const StaticVerifierFunctionEmitter & staticVerifierEmitter)803 OpEmitter::OpEmitter(const Operator &op,
804                      const StaticVerifierFunctionEmitter &staticVerifierEmitter)
805     : def(op.getDef()), op(op),
806       opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
807               formatExtraDefinitions(op)),
808       staticVerifierEmitter(staticVerifierEmitter),
809       emitHelper(op, /*emitForOp=*/true) {
810   verifyCtx.withOp("(*this->getOperation())");
811   verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
812 
813   genTraits();
814 
815   // Generate C++ code for various op methods. The order here determines the
816   // methods in the generated file.
817   genAttrNameGetters();
818   genOpAsmInterface();
819   genOpNameGetter();
820   genNamedOperandGetters();
821   genNamedOperandSetters();
822   genNamedResultGetters();
823   genNamedRegionGetters();
824   genNamedSuccessorGetters();
825   genAttrGetters();
826   genAttrSetters();
827   genOptionalAttrRemovers();
828   genBuilder();
829   genPopulateDefaultAttributes();
830   genParser();
831   genPrinter();
832   genVerifier();
833   genCustomVerifier();
834   genCanonicalizerDecls();
835   genFolderDecls();
836   genTypeInterfaceMethods();
837   genOpInterfaceMethods();
838   generateOpFormat(op, opClass);
839   genSideEffectInterfaceMethods();
840 }
emitDecl(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)841 void OpEmitter::emitDecl(
842     const Operator &op, raw_ostream &os,
843     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
844   OpEmitter(op, staticVerifierEmitter).emitDecl(os);
845 }
846 
emitDef(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)847 void OpEmitter::emitDef(
848     const Operator &op, raw_ostream &os,
849     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
850   OpEmitter(op, staticVerifierEmitter).emitDef(os);
851 }
852 
emitDecl(raw_ostream & os)853 void OpEmitter::emitDecl(raw_ostream &os) {
854   opClass.finalize();
855   opClass.writeDeclTo(os);
856 }
857 
emitDef(raw_ostream & os)858 void OpEmitter::emitDef(raw_ostream &os) {
859   opClass.finalize();
860   opClass.writeDefTo(os);
861 }
862 
errorIfPruned(size_t line,Method * m,const Twine & methodName,const Operator & op)863 static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
864                           const Operator &op) {
865   if (m)
866     return;
867   PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
868                                    methodName + "` for " +
869                                    op.getOperationName() + " (from line " +
870                                    Twine(line) + ")");
871 }
872 
873 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
874 
genAttrNameGetters()875 void OpEmitter::genAttrNameGetters() {
876   const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
877       emitHelper.getAttrMetadata();
878 
879   // Emit the getAttributeNames method.
880   {
881     auto *method = opClass.addStaticInlineMethod(
882         "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
883     ERROR_IF_PRUNED(method, "getAttributeNames", op);
884     auto &body = method->body();
885     if (attributes.empty()) {
886       body << "  return {};";
887       // Nothing else to do if there are no registered attributes. Exit early.
888       return;
889     }
890     body << "  static ::llvm::StringRef attrNames[] = {";
891     llvm::interleaveComma(llvm::make_first_range(attributes), body,
892                           [&](StringRef attrName) {
893                             body << "::llvm::StringRef(\"" << attrName << "\")";
894                           });
895     body << "};\n  return ::llvm::makeArrayRef(attrNames);";
896   }
897 
898   // Emit the getAttributeNameForIndex methods.
899   {
900     auto *method = opClass.addInlineMethod<Method::Private>(
901         "::mlir::StringAttr", "getAttributeNameForIndex",
902         MethodParameter("unsigned", "index"));
903     ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
904     method->body()
905         << "  return getAttributeNameForIndex((*this)->getName(), index);";
906   }
907   {
908     auto *method = opClass.addStaticInlineMethod<Method::Private>(
909         "::mlir::StringAttr", "getAttributeNameForIndex",
910         MethodParameter("::mlir::OperationName", "name"),
911         MethodParameter("unsigned", "index"));
912     ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
913 
914     const char *const getAttrName = R"(
915   assert(index < {0} && "invalid attribute index");
916   return name.getRegisteredInfo()->getAttributeNames()[index];
917 )";
918     method->body() << formatv(getAttrName, attributes.size());
919   }
920 
921   // Generate the <attr>AttrName methods, that expose the attribute names to
922   // users.
923   const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
924   for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) {
925     for (StringRef name : op.getGetterNames(attrIt.value())) {
926       std::string methodName = (name + "AttrName").str();
927 
928       // Generate the non-static variant.
929       {
930         auto *method =
931             opClass.addInlineMethod("::mlir::StringAttr", methodName);
932         ERROR_IF_PRUNED(method, methodName, op);
933         method->body() << llvm::formatv(attrNameMethodBody, attrIt.index());
934       }
935 
936       // Generate the static variant.
937       {
938         auto *method = opClass.addStaticInlineMethod(
939             "::mlir::StringAttr", methodName,
940             MethodParameter("::mlir::OperationName", "name"));
941         ERROR_IF_PRUNED(method, methodName, op);
942         method->body() << llvm::formatv(attrNameMethodBody,
943                                         "name, " + Twine(attrIt.index()));
944       }
945     }
946   }
947 }
948 
949 // Emit the getter for an attribute with the return type specified.
950 // It is templated to be shared between the Op and the adaptor class.
951 template <typename OpClassOrAdaptor>
emitAttrGetterWithReturnType(FmtContext & fctx,OpClassOrAdaptor & opClass,const Operator & op,StringRef name,Attribute attr)952 static void emitAttrGetterWithReturnType(FmtContext &fctx,
953                                          OpClassOrAdaptor &opClass,
954                                          const Operator &op, StringRef name,
955                                          Attribute attr) {
956   auto *method = opClass.addMethod(attr.getReturnType(), name);
957   ERROR_IF_PRUNED(method, name, op);
958   auto &body = method->body();
959   body << "  auto attr = " << name << "Attr();\n";
960   if (attr.hasDefaultValue()) {
961     // Returns the default value if not set.
962     // TODO: this is inefficient, we are recreating the attribute for every
963     // call. This should be set instead.
964     if (!attr.isConstBuildable()) {
965       PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
966                       " must have a constBuilder");
967     }
968     std::string defaultValue = std::string(
969         tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
970     body << "    if (!attr)\n      return "
971          << tgfmt(attr.getConvertFromStorageCall(),
972                   &fctx.withSelf(defaultValue))
973          << ";\n";
974   }
975   body << "  return "
976        << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
977        << ";\n";
978 }
979 
genAttrGetters()980 void OpEmitter::genAttrGetters() {
981   FmtContext fctx;
982   fctx.withBuilder("::mlir::Builder((*this)->getContext())");
983 
984   // Emit the derived attribute body.
985   auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
986     if (auto *method = opClass.addMethod(attr.getReturnType(), name))
987       method->body() << "  " << attr.getDerivedCodeBody() << "\n";
988   };
989 
990   // Generate named accessor with Attribute return type. This is a wrapper class
991   // that allows referring to the attributes via accessors instead of having to
992   // use the string interface for better compile time verification.
993   auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
994                                      Attribute attr) {
995     auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr");
996     if (!method)
997       return;
998     method->body() << formatv(
999         "  return {0}.{1}<{2}>();", emitHelper.getAttr(attrName),
1000         attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
1001                                                     : "cast",
1002         attr.getStorageType());
1003   };
1004 
1005   for (const NamedAttribute &namedAttr : op.getAttributes()) {
1006     for (StringRef name : op.getGetterNames(namedAttr.name)) {
1007       if (namedAttr.attr.isDerivedAttr()) {
1008         emitDerivedAttr(name, namedAttr.attr);
1009       } else {
1010         emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
1011         emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
1012       }
1013     }
1014   }
1015 
1016   auto derivedAttrs = make_filter_range(op.getAttributes(),
1017                                         [](const NamedAttribute &namedAttr) {
1018                                           return namedAttr.attr.isDerivedAttr();
1019                                         });
1020   if (derivedAttrs.empty())
1021     return;
1022 
1023   opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
1024   // Generate helper method to query whether a named attribute is a derived
1025   // attribute. This enables, for example, avoiding adding an attribute that
1026   // overlaps with a derived attribute.
1027   {
1028     auto *method =
1029         opClass.addStaticMethod("bool", "isDerivedAttribute",
1030                                 MethodParameter("::llvm::StringRef", "name"));
1031     ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
1032     auto &body = method->body();
1033     for (auto namedAttr : derivedAttrs)
1034       body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
1035     body << " return false;";
1036   }
1037   // Generate method to materialize derived attributes as a DictionaryAttr.
1038   {
1039     auto *method = opClass.addMethod("::mlir::DictionaryAttr",
1040                                      "materializeDerivedAttributes");
1041     ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
1042     auto &body = method->body();
1043 
1044     auto nonMaterializable =
1045         make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
1046           return namedAttr.attr.getConvertFromStorageCall().empty();
1047         });
1048     if (!nonMaterializable.empty()) {
1049       std::string attrs;
1050       llvm::raw_string_ostream os(attrs);
1051       interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
1052         os << op.getGetterName(attr.name);
1053       });
1054       PrintWarning(
1055           op.getLoc(),
1056           formatv(
1057               "op has non-materializable derived attributes '{0}', skipping",
1058               os.str()));
1059       body << formatv("  emitOpError(\"op has non-materializable derived "
1060                       "attributes '{0}'\");\n",
1061                       attrs);
1062       body << "  return nullptr;";
1063       return;
1064     }
1065 
1066     body << "  ::mlir::MLIRContext* ctx = getContext();\n";
1067     body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1068     body << "  return ::mlir::DictionaryAttr::get(";
1069     body << "  ctx, {\n";
1070     interleave(
1071         derivedAttrs, body,
1072         [&](const NamedAttribute &namedAttr) {
1073           auto tmpl = namedAttr.attr.getConvertFromStorageCall();
1074           std::string name = op.getGetterName(namedAttr.name);
1075           body << "    {" << name << "AttrName(),\n"
1076                << tgfmt(tmpl, &fctx.withSelf(name + "()")
1077                                    .withBuilder("odsBuilder")
1078                                    .addSubst("_ctxt", "ctx"))
1079                << "}";
1080         },
1081         ",\n");
1082     body << "});";
1083   }
1084 }
1085 
genAttrSetters()1086 void OpEmitter::genAttrSetters() {
1087   // Generate raw named setter type. This is a wrapper class that allows setting
1088   // to the attributes via setters instead of having to use the string interface
1089   // for better compile time verification.
1090   auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
1091                                      Attribute attr) {
1092     auto *method =
1093         opClass.addMethod("void", setterName + "Attr",
1094                           MethodParameter(attr.getStorageType(), "attr"));
1095     if (method)
1096       method->body() << formatv("  (*this)->setAttr({0}AttrName(), attr);",
1097                                 getterName);
1098   };
1099 
1100   for (const NamedAttribute &namedAttr : op.getAttributes()) {
1101     if (namedAttr.attr.isDerivedAttr())
1102       continue;
1103     for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
1104                                 op.getGetterNames(namedAttr.name)))
1105       emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
1106                               namedAttr.attr);
1107   }
1108 }
1109 
genOptionalAttrRemovers()1110 void OpEmitter::genOptionalAttrRemovers() {
1111   // Generate methods for removing optional attributes, instead of having to
1112   // use the string interface. Enables better compile time verification.
1113   auto emitRemoveAttr = [&](StringRef name) {
1114     auto upperInitial = name.take_front().upper();
1115     auto suffix = name.drop_front();
1116     auto *method = opClass.addMethod("::mlir::Attribute",
1117                                      "remove" + upperInitial + suffix + "Attr");
1118     if (!method)
1119       return;
1120     method->body() << formatv("  return (*this)->removeAttr({0}AttrName());",
1121                               op.getGetterName(name));
1122   };
1123 
1124   for (const NamedAttribute &namedAttr : op.getAttributes())
1125     if (namedAttr.attr.isOptional())
1126       emitRemoveAttr(namedAttr.name);
1127 }
1128 
1129 // Generates the code to compute the start and end index of an operand or result
1130 // range.
1131 template <typename RangeT>
1132 static void
generateValueRangeStartAndEnd(Class & opClass,StringRef methodName,int numVariadic,int numNonVariadic,StringRef rangeSizeCall,bool hasAttrSegmentSize,StringRef sizeAttrInit,RangeT && odsValues)1133 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
1134                               int numVariadic, int numNonVariadic,
1135                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
1136                               StringRef sizeAttrInit, RangeT &&odsValues) {
1137   auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
1138                                    MethodParameter("unsigned", "index"));
1139   if (!method)
1140     return;
1141   auto &body = method->body();
1142   if (numVariadic == 0) {
1143     body << "  return {index, 1};\n";
1144   } else if (hasAttrSegmentSize) {
1145     body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
1146   } else {
1147     // Because the op can have arbitrarily interleaved variadic and non-variadic
1148     // operands, we need to embed a list in the "sink" getter method for
1149     // calculation at run-time.
1150     SmallVector<StringRef, 4> isVariadic;
1151     isVariadic.reserve(llvm::size(odsValues));
1152     for (auto &it : odsValues)
1153       isVariadic.push_back(it.isVariableLength() ? "true" : "false");
1154     std::string isVariadicList = llvm::join(isVariadic, ", ");
1155     body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
1156                     numNonVariadic, numVariadic, rangeSizeCall, "operand");
1157   }
1158 }
1159 
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`.
//
// - `isAdaptor` selects how variadic-of-variadic operands are exposed: for
//   adaptors, as a `::llvm::SmallVector<::mlir::ValueRange>` computed from the
//   segment size attribute; for ops, as an `::mlir::OperandRangeRange`.
// - `sizeAttrInit` is the code snippet that initializes the segment sizes. It
//   is forwarded to the "sink" getter and only used when the op has the
//   `AttrSizedOperandSegments` trait.
// - `rangeType` is the return type of getters that return a range of operands.
//   Individual operands are `::mlir::Value`, and each element in the range
//   must also be a `::mlir::Value`.
// - `rangeBeginCall` is the code that yields an iterator to the beginning of
//   the operand range.
// - `rangeSizeCall` is the code that yields the number of operands.
// - `getOperandCallPattern` is code obtaining a single operand, with the
//   operand's position substituted for the "{0}" marker. The pattern should
//   work for any kind of op, in particular for one-operand ops that may not
//   have the `getOperand(unsigned)` method.
//   NOTE(review): this parameter is currently not referenced in the body.
//
// Aborts with a fatal error if the traits describing variadic operand sizes
// are missing, unusable, or contradictory.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
                                        bool isAdaptor, StringRef sizeAttrInit,
                                        StringRef rangeType,
                                        StringRef rangeBeginCall,
                                        StringRef rangeSizeCall,
                                        StringRef getOperandCallPattern) {
  const int numOperands = op.getNumOperands();
  const int numVariadicOperands = op.getNumVariableLengthOperands();
  const int numNormalOperands = numOperands - numVariadicOperands;

  const auto *sameVariadicSize =
      op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
  const auto *attrSizedOperands =
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");

  // With more than one variadic operand, a trait must specify how the dynamic
  // operands are distributed between them.
  if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
                                 "specification over their sizes");
  }

  if (numVariadicOperands < 2 && attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
                                 "to use 'AttrSizedOperandSegments' trait");
  }

  if (attrSizedOperands && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedOperandSegments' and "
                    "'SameVariadicOperandSize' traits");
  }

  // First emit a few "sink" getter methods upon which we layer all nicer named
  // getter methods.
  generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
                                numVariadicOperands, numNormalOperands,
                                rangeSizeCall, attrSizedOperands, sizeAttrInit,
                                const_cast<Operator &>(op).getOperands());

  auto *m = opClass.addMethod(rangeType, "getODSOperands",
                              MethodParameter("unsigned", "index"));
  ERROR_IF_PRUNED(m, "getODSOperands", op);
  auto &body = m->body();
  body << formatv(valueRangeReturnCode, rangeBeginCall,
                  "getODSOperandIndexAndLength(index)");

  // Then we emit nicer named getter methods by redirecting to the "sink" getter
  // method.
  for (int i = 0; i != numOperands; ++i) {
    const auto &operand = op.getOperand(i);
    if (operand.name.empty())
      continue;
    for (StringRef name : op.getGetterNames(operand.name)) {
      if (operand.isOptional()) {
        // Optional operand: a null Value when the group is empty.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  auto operands = getODSOperands(" << i << ");\n"
                  << "  return operands.empty() ? ::mlir::Value() : "
                     "*operands.begin();";
      } else if (operand.isVariadicOfVariadic()) {
        // Nested variadic: split the group using its segment size attribute.
        std::string segmentAttr = op.getGetterName(
            operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
        if (isAdaptor) {
          m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>",
                                name);
          ERROR_IF_PRUNED(m, name, op);
          m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
                                     segmentAttr, i);
          continue;
        }

        m = opClass.addMethod("::mlir::OperandRangeRange", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return getODSOperands(" << i << ").split("
                  << segmentAttr << "Attr());";
      } else if (operand.isVariadic()) {
        // Variadic operand: the whole group as a range.
        m = opClass.addMethod(rangeType, name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return getODSOperands(" << i << ");";
      } else {
        // Single operand: the first (only) value of the group.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return *getODSOperands(" << i << ").begin();";
      }
    }
  }
}
1257 
genNamedOperandGetters()1258 void OpEmitter::genNamedOperandGetters() {
1259   // Build the code snippet used for initializing the operand_segment_size)s
1260   // array.
1261   std::string attrSizeInitCode;
1262   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1263     attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
1264                                emitHelper.getAttr(operandSegmentAttrName));
1265   }
1266 
1267   generateNamedOperandGetters(
1268       op, opClass,
1269       /*isAdaptor=*/false,
1270       /*sizeAttrInit=*/attrSizeInitCode,
1271       /*rangeType=*/"::mlir::Operation::operand_range",
1272       /*rangeBeginCall=*/"getOperation()->operand_begin()",
1273       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
1274       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
1275 }
1276 
genNamedOperandSetters()1277 void OpEmitter::genNamedOperandSetters() {
1278   auto *attrSizedOperands =
1279       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1280   for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
1281     const auto &operand = op.getOperand(i);
1282     if (operand.name.empty())
1283       continue;
1284     for (StringRef name : op.getGetterNames(operand.name)) {
1285       auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
1286                                       ? "::mlir::MutableOperandRangeRange"
1287                                       : "::mlir::MutableOperandRange",
1288                                   (name + "Mutable").str());
1289       ERROR_IF_PRUNED(m, name, op);
1290       auto &body = m->body();
1291       body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
1292            << "  auto mutableRange = "
1293               "::mlir::MutableOperandRange(getOperation(), "
1294               "range.first, range.second";
1295       if (attrSizedOperands) {
1296         body << formatv(
1297             ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
1298             emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
1299       }
1300       body << ");\n";
1301 
1302       // If this operand is a nested variadic, we split the range into a
1303       // MutableOperandRangeRange that provides a range over all of the
1304       // sub-ranges.
1305       if (operand.isVariadicOfVariadic()) {
1306         body << "  return "
1307                 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
1308              << op.getGetterName(
1309                     operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
1310              << "AttrName()));\n";
1311       } else {
1312         // Otherwise, we use the full range directly.
1313         body << "  return mutableRange;\n";
1314       }
1315     }
1316   }
1317 }
1318 
genNamedResultGetters()1319 void OpEmitter::genNamedResultGetters() {
1320   const int numResults = op.getNumResults();
1321   const int numVariadicResults = op.getNumVariableLengthResults();
1322   const int numNormalResults = numResults - numVariadicResults;
1323 
1324   // If we have more than one variadic results, we need more complicated logic
1325   // to calculate the value range for each result.
1326 
1327   const auto *sameVariadicSize =
1328       op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
1329   const auto *attrSizedResults =
1330       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
1331 
1332   if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
1333     PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
1334                                  "specification over their sizes");
1335   }
1336 
1337   if (numVariadicResults < 2 && attrSizedResults) {
1338     PrintFatalError(op.getLoc(), "op must have at least two variadic results "
1339                                  "to use 'AttrSizedResultSegments' trait");
1340   }
1341 
1342   if (attrSizedResults && sameVariadicSize) {
1343     PrintFatalError(op.getLoc(),
1344                     "op cannot have both 'AttrSizedResultSegments' and "
1345                     "'SameVariadicResultSize' traits");
1346   }
1347 
1348   // Build the initializer string for the result segment size attribute.
1349   std::string attrSizeInitCode;
1350   if (attrSizedResults) {
1351     attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
1352                                emitHelper.getAttr(resultSegmentAttrName));
1353   }
1354 
1355   generateValueRangeStartAndEnd(
1356       opClass, "getODSResultIndexAndLength", numVariadicResults,
1357       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
1358       attrSizeInitCode, op.getResults());
1359 
1360   auto *m =
1361       opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
1362                         MethodParameter("unsigned", "index"));
1363   ERROR_IF_PRUNED(m, "getODSResults", op);
1364   m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
1365                        "getODSResultIndexAndLength(index)");
1366 
1367   for (int i = 0; i != numResults; ++i) {
1368     const auto &result = op.getResult(i);
1369     if (result.name.empty())
1370       continue;
1371     for (StringRef name : op.getGetterNames(result.name)) {
1372       if (result.isOptional()) {
1373         m = opClass.addMethod("::mlir::Value", name);
1374         ERROR_IF_PRUNED(m, name, op);
1375         m->body()
1376             << "  auto results = getODSResults(" << i << ");\n"
1377             << "  return results.empty() ? ::mlir::Value() : *results.begin();";
1378       } else if (result.isVariadic()) {
1379         m = opClass.addMethod("::mlir::Operation::result_range", name);
1380         ERROR_IF_PRUNED(m, name, op);
1381         m->body() << "  return getODSResults(" << i << ");";
1382       } else {
1383         m = opClass.addMethod("::mlir::Value", name);
1384         ERROR_IF_PRUNED(m, name, op);
1385         m->body() << "  return *getODSResults(" << i << ").begin();";
1386       }
1387     }
1388   }
1389 }
1390 
genNamedRegionGetters()1391 void OpEmitter::genNamedRegionGetters() {
1392   unsigned numRegions = op.getNumRegions();
1393   for (unsigned i = 0; i < numRegions; ++i) {
1394     const auto &region = op.getRegion(i);
1395     if (region.name.empty())
1396       continue;
1397 
1398     for (StringRef name : op.getGetterNames(region.name)) {
1399       // Generate the accessors for a variadic region.
1400       if (region.isVariadic()) {
1401         auto *m =
1402             opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
1403         ERROR_IF_PRUNED(m, name, op);
1404         m->body() << formatv("  return (*this)->getRegions().drop_front({0});",
1405                              i);
1406         continue;
1407       }
1408 
1409       auto *m = opClass.addMethod("::mlir::Region &", name);
1410       ERROR_IF_PRUNED(m, name, op);
1411       m->body() << formatv("  return (*this)->getRegion({0});", i);
1412     }
1413   }
1414 }
1415 
genNamedSuccessorGetters()1416 void OpEmitter::genNamedSuccessorGetters() {
1417   unsigned numSuccessors = op.getNumSuccessors();
1418   for (unsigned i = 0; i < numSuccessors; ++i) {
1419     const NamedSuccessor &successor = op.getSuccessor(i);
1420     if (successor.name.empty())
1421       continue;
1422 
1423     for (StringRef name : op.getGetterNames(successor.name)) {
1424       // Generate the accessors for a variadic successor list.
1425       if (successor.isVariadic()) {
1426         auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
1427         ERROR_IF_PRUNED(m, name, op);
1428         m->body() << formatv(
1429             "  return {std::next((*this)->successor_begin(), {0}), "
1430             "(*this)->successor_end()};",
1431             i);
1432         continue;
1433       }
1434 
1435       auto *m = opClass.addMethod("::mlir::Block *", name);
1436       ERROR_IF_PRUNED(m, name, op);
1437       m->body() << formatv("  return (*this)->getSuccessor({0});", i);
1438     }
1439   }
1440 }
1441 
canGenerateUnwrappedBuilder(const Operator & op)1442 static bool canGenerateUnwrappedBuilder(const Operator &op) {
1443   // If this op does not have native attributes at all, return directly to avoid
1444   // redefining builders.
1445   if (op.getNumNativeAttributes() == 0)
1446     return false;
1447 
1448   bool canGenerate = false;
1449   // We are generating builders that take raw values for attributes. We need to
1450   // make sure the native attributes have a meaningful "unwrapped" value type
1451   // different from the wrapped mlir::Attribute type to avoid redefining
1452   // builders. This checks for the op has at least one such native attribute.
1453   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1454     const NamedAttribute &namedAttr = op.getAttribute(i);
1455     if (canUseUnwrappedRawValue(namedAttr.attr)) {
1456       canGenerate = true;
1457       break;
1458     }
1459   }
1460   return canGenerate;
1461 }
1462 
canInferType(const Operator & op)1463 static bool canInferType(const Operator &op) {
1464   return op.getTrait("::mlir::InferTypeOpInterface::Trait");
1465 }
1466 
genSeparateArgParamBuilder()1467 void OpEmitter::genSeparateArgParamBuilder() {
1468   SmallVector<AttrParamKind, 2> attrBuilderType;
1469   attrBuilderType.push_back(AttrParamKind::WrappedAttr);
1470   if (canGenerateUnwrappedBuilder(op))
1471     attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
1472 
1473   // Emit with separate builders with or without unwrapped attributes and/or
1474   // inferring result type.
1475   auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
1476                   bool inferType) {
1477     SmallVector<MethodParameter> paramList;
1478     SmallVector<std::string, 4> resultNames;
1479     llvm::StringSet<> inferredAttributes;
1480     buildParamList(paramList, inferredAttributes, resultNames, paramKind,
1481                    attrType);
1482 
1483     auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1484     // If the builder is redundant, skip generating the method.
1485     if (!m)
1486       return;
1487     auto &body = m->body();
1488     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
1489                                            /*isRawValueAttr=*/attrType ==
1490                                                AttrParamKind::UnwrappedValue);
1491 
1492     // Push all result types to the operation state
1493 
1494     if (inferType) {
1495       // Generate builder that infers type too.
1496       // TODO: Subsume this with general checking if type can be
1497       // inferred automatically.
1498       // TODO: Expand to handle regions.
1499       body << formatv(R"(
1500         ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1501         if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1502                       {1}.location, {1}.operands,
1503                       {1}.attributes.getDictionary({1}.getContext()),
1504                       /*regions=*/{{}, inferredReturnTypes)))
1505           {1}.addTypes(inferredReturnTypes);
1506         else
1507           ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1508                       opClass.getClassName(), builderOpState);
1509       return;
1510     }
1511 
1512     switch (paramKind) {
1513     case TypeParamKind::None:
1514       return;
1515     case TypeParamKind::Separate:
1516       for (int i = 0, e = op.getNumResults(); i < e; ++i) {
1517         if (op.getResult(i).isOptional())
1518           body << "  if (" << resultNames[i] << ")\n  ";
1519         body << "  " << builderOpState << ".addTypes(" << resultNames[i]
1520              << ");\n";
1521       }
1522       return;
1523     case TypeParamKind::Collective: {
1524       int numResults = op.getNumResults();
1525       int numVariadicResults = op.getNumVariableLengthResults();
1526       int numNonVariadicResults = numResults - numVariadicResults;
1527       bool hasVariadicResult = numVariadicResults != 0;
1528 
1529       // Avoid emitting "resultTypes.size() >= 0u" which is always true.
1530       if (!(hasVariadicResult && numNonVariadicResults == 0))
1531         body << "  "
1532              << "assert(resultTypes.size() "
1533              << (hasVariadicResult ? ">=" : "==") << " "
1534              << numNonVariadicResults
1535              << "u && \"mismatched number of results\");\n";
1536       body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1537     }
1538       return;
1539     }
1540     llvm_unreachable("unhandled TypeParamKind");
1541   };
1542 
1543   // Some of the build methods generated here may be ambiguous, but TableGen's
1544   // ambiguous function detection will elide those ones.
1545   for (auto attrType : attrBuilderType) {
1546     emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
1547     if (canInferType(op) && op.getNumRegions() == 0)
1548       emit(attrType, TypeParamKind::None, /*inferType=*/true);
1549     emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
1550   }
1551 }
1552 
genUseOperandAsResultTypeCollectiveParamBuilder()1553 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1554   int numResults = op.getNumResults();
1555 
1556   // Signature
1557   SmallVector<MethodParameter> paramList;
1558   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1559   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1560   paramList.emplace_back("::mlir::ValueRange", "operands");
1561   // Provide default value for `attributes` when its the last parameter
1562   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1563   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1564                          "attributes", attributesDefaultValue);
1565   if (op.getNumVariadicRegions())
1566     paramList.emplace_back("unsigned", "numRegions");
1567 
1568   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1569   // If the builder is redundant, skip generating the method
1570   if (!m)
1571     return;
1572   auto &body = m->body();
1573 
1574   // Operands
1575   body << "  " << builderOpState << ".addOperands(operands);\n";
1576 
1577   // Attributes
1578   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1579 
1580   // Create the correct number of regions
1581   if (int numRegions = op.getNumRegions()) {
1582     body << llvm::formatv(
1583         "  for (unsigned i = 0; i != {0}; ++i)\n",
1584         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1585     body << "    (void)" << builderOpState << ".addRegion();\n";
1586   }
1587 
1588   // Result types
1589   SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1590   body << "  " << builderOpState << ".addTypes({"
1591        << llvm::join(resultTypes, ", ") << "});\n\n";
1592 }
1593 
genPopulateDefaultAttributes()1594 void OpEmitter::genPopulateDefaultAttributes() {
1595   // All done if no attributes have default values.
1596   if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
1597         return !named.attr.hasDefaultValue();
1598       }))
1599     return;
1600 
1601   SmallVector<MethodParameter> paramList;
1602   paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
1603   paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
1604   auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
1605   ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
1606   auto &body = m->body();
1607   body.indent();
1608 
1609   // Set default attributes that are unset.
1610   body << "auto attrNames = opName.getAttributeNames();\n";
1611   body << "::mlir::Builder " << odsBuilder
1612        << "(attrNames.front().getContext());\n";
1613   StringMap<int> attrIndex;
1614   for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
1615     attrIndex[it.value().first] = it.index();
1616   }
1617   for (const NamedAttribute &namedAttr : op.getAttributes()) {
1618     auto &attr = namedAttr.attr;
1619     if (!attr.hasDefaultValue())
1620       continue;
1621     auto index = attrIndex[namedAttr.name];
1622     body << "if (!attributes.get(attrNames[" << index << "])) {\n";
1623     FmtContext fctx;
1624     fctx.withBuilder(odsBuilder);
1625     std::string defaultValue = std::string(
1626         tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1627     body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n",
1628                              index, defaultValue);
1629     body.unindent() << "}\n";
1630   }
1631 }
1632 
genInferredTypeCollectiveParamBuilder()1633 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1634   SmallVector<MethodParameter> paramList;
1635   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1636   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1637   paramList.emplace_back("::mlir::ValueRange", "operands");
1638   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1639   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1640                          "attributes", attributesDefaultValue);
1641   if (op.getNumVariadicRegions())
1642     paramList.emplace_back("unsigned", "numRegions");
1643 
1644   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1645   // If the builder is redundant, skip generating the method
1646   if (!m)
1647     return;
1648   auto &body = m->body();
1649 
1650   int numResults = op.getNumResults();
1651   int numVariadicResults = op.getNumVariableLengthResults();
1652   int numNonVariadicResults = numResults - numVariadicResults;
1653 
1654   int numOperands = op.getNumOperands();
1655   int numVariadicOperands = op.getNumVariableLengthOperands();
1656   int numNonVariadicOperands = numOperands - numVariadicOperands;
1657 
1658   // Operands
1659   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1660     body << "  assert(operands.size()"
1661          << (numVariadicOperands != 0 ? " >= " : " == ")
1662          << numNonVariadicOperands
1663          << "u && \"mismatched number of parameters\");\n";
1664   body << "  " << builderOpState << ".addOperands(operands);\n";
1665   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1666 
1667   // Create the correct number of regions
1668   if (int numRegions = op.getNumRegions()) {
1669     body << llvm::formatv(
1670         "  for (unsigned i = 0; i != {0}; ++i)\n",
1671         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1672     body << "    (void)" << builderOpState << ".addRegion();\n";
1673   }
1674 
1675   // Result types
1676   body << formatv(R"(
1677   ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1678   if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1679           {1}.location, operands,
1680           {1}.attributes.getDictionary({1}.getContext()),
1681           {1}.regions, inferredReturnTypes))) {{)",
1682                   opClass.getClassName(), builderOpState);
1683   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1684     body << "\n    assert(inferredReturnTypes.size()"
1685          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1686          << "u && \"mismatched number of return types\");";
1687   body << "\n    " << builderOpState << ".addTypes(inferredReturnTypes);";
1688 
1689   body << formatv(R"(
1690   } else {{
1691     ::llvm::report_fatal_error("Failed to infer result type(s).");
1692   })",
1693                   opClass.getClassName(), builderOpState);
1694 }
1695 
genUseOperandAsResultTypeSeparateParamBuilder()1696 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1697   auto emit = [&](AttrParamKind attrType) {
1698     SmallVector<MethodParameter> paramList;
1699     SmallVector<std::string, 4> resultNames;
1700     llvm::StringSet<> inferredAttributes;
1701     buildParamList(paramList, inferredAttributes, resultNames,
1702                    TypeParamKind::None, attrType);
1703 
1704     auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1705     // If the builder is redundant, skip generating the method
1706     if (!m)
1707       return;
1708     auto &body = m->body();
1709     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
1710                                            /*isRawValueAttr=*/attrType ==
1711                                                AttrParamKind::UnwrappedValue);
1712 
1713     auto numResults = op.getNumResults();
1714     if (numResults == 0)
1715       return;
1716 
1717     // Push all result types to the operation state
1718     const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1719     std::string resultType =
1720         formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1721     body << "  " << builderOpState << ".addTypes({" << resultType;
1722     for (int i = 1; i != numResults; ++i)
1723       body << ", " << resultType;
1724     body << "});\n\n";
1725   };
1726 
1727   emit(AttrParamKind::WrappedAttr);
1728   // Generate additional builder(s) if any attributes can be "unwrapped"
1729   if (canGenerateUnwrappedBuilder(op))
1730     emit(AttrParamKind::UnwrappedValue);
1731 }
1732 
genUseAttrAsResultTypeBuilder()1733 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1734   SmallVector<MethodParameter> paramList;
1735   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1736   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1737   paramList.emplace_back("::mlir::ValueRange", "operands");
1738   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1739                          "attributes", "{}");
1740   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1741   // If the builder is redundant, skip generating the method
1742   if (!m)
1743     return;
1744 
1745   auto &body = m->body();
1746 
1747   // Push all result types to the operation state
1748   std::string resultType;
1749   const auto &namedAttr = op.getAttribute(0);
1750 
1751   body << "  auto attrName = " << op.getGetterName(namedAttr.name)
1752        << "AttrName(" << builderOpState
1753        << ".name);\n"
1754           "  for (auto attr : attributes) {\n"
1755           "    if (attr.getName() != attrName) continue;\n";
1756   if (namedAttr.attr.isTypeAttr()) {
1757     resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()";
1758   } else {
1759     resultType = "attr.getValue().getType()";
1760   }
1761 
1762   // Operands
1763   body << "  " << builderOpState << ".addOperands(operands);\n";
1764 
1765   // Attributes
1766   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1767 
1768   // Result types
1769   SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1770   body << "    " << builderOpState << ".addTypes({"
1771        << llvm::join(resultTypes, ", ") << "});\n";
1772   body << "  }\n";
1773 }
1774 
1775 /// Returns a signature of the builder. Updates the context `fctx` to enable
1776 /// replacement of $_builder and $_state in the body.
1777 static SmallVector<MethodParameter>
getBuilderSignature(const Builder & builder)1778 getBuilderSignature(const Builder &builder) {
1779   ArrayRef<Builder::Parameter> params(builder.getParameters());
1780 
1781   // Inject builder and state arguments.
1782   SmallVector<MethodParameter> arguments;
1783   arguments.reserve(params.size() + 2);
1784   arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
1785   arguments.emplace_back("::mlir::OperationState &", builderOpState);
1786 
1787   for (unsigned i = 0, e = params.size(); i < e; ++i) {
1788     // If no name is provided, generate one.
1789     Optional<StringRef> paramName = params[i].getName();
1790     std::string name =
1791         paramName ? paramName->str() : "odsArg" + std::to_string(i);
1792 
1793     StringRef defaultValue;
1794     if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
1795       defaultValue = *defaultParamValue;
1796 
1797     arguments.emplace_back(params[i].getCppType(), std::move(name),
1798                            defaultValue);
1799   }
1800 
1801   return arguments;
1802 }
1803 
genBuilder()1804 void OpEmitter::genBuilder() {
1805   // Handle custom builders if provided.
1806   for (const Builder &builder : op.getBuilders()) {
1807     SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
1808 
1809     Optional<StringRef> body = builder.getBody();
1810     auto properties = body ? Method::Static : Method::StaticDeclaration;
1811     auto *method =
1812         opClass.addMethod("void", "build", properties, std::move(arguments));
1813     if (body)
1814       ERROR_IF_PRUNED(method, "build", op);
1815 
1816     FmtContext fctx;
1817     fctx.withBuilder(odsBuilder);
1818     fctx.addSubst("_state", builderOpState);
1819     if (body)
1820       method->body() << tgfmt(*body, &fctx);
1821   }
1822 
1823   // Generate default builders that requires all result type, operands, and
1824   // attributes as parameters.
1825   if (op.skipDefaultBuilders())
1826     return;
1827 
1828   // We generate three classes of builders here:
1829   // 1. one having a stand-alone parameter for each operand / attribute, and
1830   genSeparateArgParamBuilder();
1831   // 2. one having an aggregated parameter for all result types / operands /
1832   //    attributes, and
1833   genCollectiveParamBuilder();
1834   // 3. one having a stand-alone parameter for each operand and attribute,
1835   //    use the first operand or attribute's type as all result types
1836   //    to facilitate different call patterns.
1837   if (op.getNumVariableLengthResults() == 0) {
1838     if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1839       genUseOperandAsResultTypeSeparateParamBuilder();
1840       genUseOperandAsResultTypeCollectiveParamBuilder();
1841     }
1842     if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1843       genUseAttrAsResultTypeBuilder();
1844   }
1845 }
1846 
genCollectiveParamBuilder()1847 void OpEmitter::genCollectiveParamBuilder() {
1848   int numResults = op.getNumResults();
1849   int numVariadicResults = op.getNumVariableLengthResults();
1850   int numNonVariadicResults = numResults - numVariadicResults;
1851 
1852   int numOperands = op.getNumOperands();
1853   int numVariadicOperands = op.getNumVariableLengthOperands();
1854   int numNonVariadicOperands = numOperands - numVariadicOperands;
1855 
1856   SmallVector<MethodParameter> paramList;
1857   paramList.emplace_back("::mlir::OpBuilder &", "");
1858   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1859   paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1860   paramList.emplace_back("::mlir::ValueRange", "operands");
1861   // Provide default value for `attributes` when its the last parameter
1862   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1863   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1864                          "attributes", attributesDefaultValue);
1865   if (op.getNumVariadicRegions())
1866     paramList.emplace_back("unsigned", "numRegions");
1867 
1868   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1869   // If the builder is redundant, skip generating the method
1870   if (!m)
1871     return;
1872   auto &body = m->body();
1873 
1874   // Operands
1875   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1876     body << "  assert(operands.size()"
1877          << (numVariadicOperands != 0 ? " >= " : " == ")
1878          << numNonVariadicOperands
1879          << "u && \"mismatched number of parameters\");\n";
1880   body << "  " << builderOpState << ".addOperands(operands);\n";
1881 
1882   // Attributes
1883   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1884 
1885   // Create the correct number of regions
1886   if (int numRegions = op.getNumRegions()) {
1887     body << llvm::formatv(
1888         "  for (unsigned i = 0; i != {0}; ++i)\n",
1889         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1890     body << "    (void)" << builderOpState << ".addRegion();\n";
1891   }
1892 
1893   // Result types
1894   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1895     body << "  assert(resultTypes.size()"
1896          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1897          << "u && \"mismatched number of return types\");\n";
1898   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1899 
1900   // Generate builder that infers type too.
1901   // TODO: Expand to handle successors.
1902   if (canInferType(op) && op.getNumSuccessors() == 0)
1903     genInferredTypeCollectiveParamBuilder();
1904 }
1905 
buildParamList(SmallVectorImpl<MethodParameter> & paramList,llvm::StringSet<> & inferredAttributes,SmallVectorImpl<std::string> & resultTypeNames,TypeParamKind typeParamKind,AttrParamKind attrParamKind)1906 void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
1907                                llvm::StringSet<> &inferredAttributes,
1908                                SmallVectorImpl<std::string> &resultTypeNames,
1909                                TypeParamKind typeParamKind,
1910                                AttrParamKind attrParamKind) {
1911   resultTypeNames.clear();
1912   auto numResults = op.getNumResults();
1913   resultTypeNames.reserve(numResults);
1914 
1915   paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
1916   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1917 
1918   switch (typeParamKind) {
1919   case TypeParamKind::None:
1920     break;
1921   case TypeParamKind::Separate: {
1922     // Add parameters for all return types
1923     for (int i = 0; i < numResults; ++i) {
1924       const auto &result = op.getResult(i);
1925       std::string resultName = std::string(result.name);
1926       if (resultName.empty())
1927         resultName = std::string(formatv("resultType{0}", i));
1928 
1929       StringRef type =
1930           result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
1931 
1932       paramList.emplace_back(type, resultName, result.isOptional());
1933       resultTypeNames.emplace_back(std::move(resultName));
1934     }
1935   } break;
1936   case TypeParamKind::Collective: {
1937     paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1938     resultTypeNames.push_back("resultTypes");
1939   } break;
1940   }
1941 
1942   // Add parameters for all arguments (operands and attributes).
1943   int defaultValuedAttrStartIndex = op.getNumArgs();
1944   // Successors and variadic regions go at the end of the parameter list, so no
1945   // default arguments are possible.
1946   bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
1947   if (attrParamKind == AttrParamKind::UnwrappedValue && !hasTrailingParams) {
1948     // Calculate the start index from which we can attach default values in the
1949     // builder declaration.
1950     for (int i = op.getNumArgs() - 1; i >= 0; --i) {
1951       auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
1952       if (!namedAttr || !namedAttr->attr.hasDefaultValue())
1953         break;
1954 
1955       if (!canUseUnwrappedRawValue(namedAttr->attr))
1956         break;
1957 
1958       // Creating an APInt requires us to provide bitwidth, value, and
1959       // signedness, which is complicated compared to others. Similarly
1960       // for APFloat.
1961       // TODO: Adjust the 'returnType' field of such attributes
1962       // to support them.
1963       StringRef retType = namedAttr->attr.getReturnType();
1964       if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
1965         break;
1966 
1967       defaultValuedAttrStartIndex = i;
1968     }
1969   }
1970 
1971   /// Collect any inferred attributes.
1972   for (const NamedTypeConstraint &operand : op.getOperands()) {
1973     if (operand.isVariadicOfVariadic()) {
1974       inferredAttributes.insert(
1975           operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
1976     }
1977   }
1978 
1979   for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
1980     Argument arg = op.getArg(i);
1981     if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
1982       StringRef type;
1983       if (operand->isVariadicOfVariadic())
1984         type = "::llvm::ArrayRef<::mlir::ValueRange>";
1985       else if (operand->isVariadic())
1986         type = "::mlir::ValueRange";
1987       else
1988         type = "::mlir::Value";
1989 
1990       paramList.emplace_back(type, getArgumentName(op, numOperands++),
1991                              operand->isOptional());
1992       continue;
1993     }
1994     const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
1995     const Attribute &attr = namedAttr.attr;
1996 
1997     // Inferred attributes don't need to be added to the param list.
1998     if (inferredAttributes.contains(namedAttr.name))
1999       continue;
2000 
2001     StringRef type;
2002     switch (attrParamKind) {
2003     case AttrParamKind::WrappedAttr:
2004       type = attr.getStorageType();
2005       break;
2006     case AttrParamKind::UnwrappedValue:
2007       if (canUseUnwrappedRawValue(attr))
2008         type = attr.getReturnType();
2009       else
2010         type = attr.getStorageType();
2011       break;
2012     }
2013 
2014     // Attach default value if requested and possible.
2015     std::string defaultValue;
2016     if (attrParamKind == AttrParamKind::UnwrappedValue &&
2017         i >= defaultValuedAttrStartIndex) {
2018       defaultValue += attr.getDefaultValue();
2019     }
2020     paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue),
2021                            attr.isOptional());
2022   }
2023 
2024   /// Insert parameters for each successor.
2025   for (const NamedSuccessor &succ : op.getSuccessors()) {
2026     StringRef type =
2027         succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
2028     paramList.emplace_back(type, succ.name);
2029   }
2030 
2031   /// Insert parameters for variadic regions.
2032   for (const NamedRegion &region : op.getRegions())
2033     if (region.isVariadic())
2034       paramList.emplace_back("unsigned",
2035                              llvm::formatv("{0}Count", region.name).str());
2036 }
2037 
genCodeForAddingArgAndRegionForBuilder(MethodBody & body,llvm::StringSet<> & inferredAttributes,bool isRawValueAttr)2038 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
2039     MethodBody &body, llvm::StringSet<> &inferredAttributes,
2040     bool isRawValueAttr) {
2041   // Push all operands to the result.
2042   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
2043     std::string argName = getArgumentName(op, i);
2044     const NamedTypeConstraint &operand = op.getOperand(i);
2045     if (operand.constraint.isVariadicOfVariadic()) {
2046       body << "  for (::mlir::ValueRange range : " << argName << ")\n   "
2047            << builderOpState << ".addOperands(range);\n";
2048 
2049       // Add the segment attribute.
2050       body << "  {\n"
2051            << "    ::llvm::SmallVector<int32_t> rangeSegments;\n"
2052            << "    for (::mlir::ValueRange range : " << argName << ")\n"
2053            << "      rangeSegments.push_back(range.size());\n"
2054            << "    " << builderOpState << ".addAttribute("
2055            << op.getGetterName(
2056                   operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
2057            << "AttrName(" << builderOpState << ".name), " << odsBuilder
2058            << ".getI32TensorAttr(rangeSegments));"
2059            << "  }\n";
2060       continue;
2061     }
2062 
2063     if (operand.isOptional())
2064       body << "  if (" << argName << ")\n  ";
2065     body << "  " << builderOpState << ".addOperands(" << argName << ");\n";
2066   }
2067 
2068   // If the operation has the operand segment size attribute, add it here.
2069   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2070     std::string sizes = op.getGetterName(operandSegmentAttrName);
2071     body << "  " << builderOpState << ".addAttribute(" << sizes << "AttrName("
2072          << builderOpState << ".name), "
2073          << "odsBuilder.getI32VectorAttr({";
2074     interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
2075       const NamedTypeConstraint &operand = op.getOperand(i);
2076       if (!operand.isVariableLength()) {
2077         body << "1";
2078         return;
2079       }
2080 
2081       std::string operandName = getArgumentName(op, i);
2082       if (operand.isOptional()) {
2083         body << "(" << operandName << " ? 1 : 0)";
2084       } else if (operand.isVariadicOfVariadic()) {
2085         body << llvm::formatv(
2086             "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
2087             "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
2088             "range.size(); }))",
2089             operandName);
2090       } else {
2091         body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
2092       }
2093     });
2094     body << "}));\n";
2095   }
2096 
2097   // Push all attributes to the result.
2098   for (const auto &namedAttr : op.getAttributes()) {
2099     auto &attr = namedAttr.attr;
2100     if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
2101       continue;
2102 
2103     bool emitNotNullCheck =
2104         attr.isOptional() || (attr.hasDefaultValue() && !isRawValueAttr);
2105     if (emitNotNullCheck)
2106       body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
2107 
2108     if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
2109       // If this is a raw value, then we need to wrap it in an Attribute
2110       // instance.
2111       FmtContext fctx;
2112       fctx.withBuilder("odsBuilder");
2113 
2114       std::string builderTemplate = std::string(attr.getConstBuilderTemplate());
2115 
2116       // For StringAttr, its constant builder call will wrap the input in
2117       // quotes, which is correct for normal string literals, but incorrect
2118       // here given we use function arguments. So we need to strip the
2119       // wrapping quotes.
2120       if (StringRef(builderTemplate).contains("\"$0\""))
2121         builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
2122 
2123       std::string value =
2124           std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
2125       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
2126                       builderOpState, op.getGetterName(namedAttr.name), value);
2127     } else {
2128       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
2129                       builderOpState, op.getGetterName(namedAttr.name),
2130                       namedAttr.name);
2131     }
2132     if (emitNotNullCheck)
2133       body << "  }\n";
2134   }
2135 
2136   // Create the correct number of regions.
2137   for (const NamedRegion &region : op.getRegions()) {
2138     if (region.isVariadic())
2139       body << formatv("  for (unsigned i = 0; i < {0}Count; ++i)\n  ",
2140                       region.name);
2141 
2142     body << "  (void)" << builderOpState << ".addRegion();\n";
2143   }
2144 
2145   // Push all successors to the result.
2146   for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
2147     body << formatv("  {0}.addSuccessors({1});\n", builderOpState,
2148                     namedSuccessor.name);
2149   }
2150 }
2151 
genCanonicalizerDecls()2152 void OpEmitter::genCanonicalizerDecls() {
2153   bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
2154   if (hasCanonicalizeMethod) {
2155     // static LogicResult FooOp::
2156     // canonicalize(FooOp op, PatternRewriter &rewriter);
2157     SmallVector<MethodParameter> paramList;
2158     paramList.emplace_back(op.getCppClassName(), "op");
2159     paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
2160     auto *m = opClass.declareStaticMethod("::mlir::LogicalResult",
2161                                           "canonicalize", std::move(paramList));
2162     ERROR_IF_PRUNED(m, "canonicalize", op);
2163   }
2164 
2165   // We get a prototype for 'getCanonicalizationPatterns' if requested directly
2166   // or if using a 'canonicalize' method.
2167   bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
2168   if (!hasCanonicalizeMethod && !hasCanonicalizer)
2169     return;
2170 
2171   // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
2172   // method, but not implementing 'getCanonicalizationPatterns' manually.
2173   bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
2174 
2175   // Add a signature for getCanonicalizationPatterns if implemented by the
2176   // dialect or if synthesized to call 'canonicalize'.
2177   SmallVector<MethodParameter> paramList;
2178   paramList.emplace_back("::mlir::RewritePatternSet &", "results");
2179   paramList.emplace_back("::mlir::MLIRContext *", "context");
2180   auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
2181   auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
2182                                    std::move(paramList));
2183 
2184   // If synthesizing the method, fill it it.
2185   if (hasBody) {
2186     ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
2187     method->body() << "  results.add(canonicalize);\n";
2188   }
2189 }
2190 
genFolderDecls()2191 void OpEmitter::genFolderDecls() {
2192   bool hasSingleResult =
2193       op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
2194 
2195   if (def.getValueAsBit("hasFolder")) {
2196     if (hasSingleResult) {
2197       auto *m = opClass.declareMethod(
2198           "::mlir::OpFoldResult", "fold",
2199           MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
2200       ERROR_IF_PRUNED(m, "operands", op);
2201     } else {
2202       SmallVector<MethodParameter> paramList;
2203       paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
2204       paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
2205                              "results");
2206       auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
2207                                       std::move(paramList));
2208       ERROR_IF_PRUNED(m, "fold", op);
2209     }
2210   }
2211 }
2212 
genOpInterfaceMethods(const tblgen::InterfaceTrait * opTrait)2213 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
2214   Interface interface = opTrait->getInterface();
2215 
2216   // Get the set of methods that should always be declared.
2217   auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
2218   llvm::StringSet<> alwaysDeclaredMethods;
2219   alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
2220                                alwaysDeclaredMethodsVec.end());
2221 
2222   for (const InterfaceMethod &method : interface.getMethods()) {
2223     // Don't declare if the method has a body.
2224     if (method.getBody())
2225       continue;
2226     // Don't declare if the method has a default implementation and the op
2227     // didn't request that it always be declared.
2228     if (method.getDefaultImplementation() &&
2229         !alwaysDeclaredMethods.count(method.getName()))
2230       continue;
2231     // Interface methods are allowed to overlap with existing methods, so don't
2232     // check if pruned.
2233     (void)genOpInterfaceMethod(method);
2234   }
2235 }
2236 
genOpInterfaceMethod(const InterfaceMethod & method,bool declaration)2237 Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
2238                                         bool declaration) {
2239   SmallVector<MethodParameter> paramList;
2240   for (const InterfaceMethod::Argument &arg : method.getArguments())
2241     paramList.emplace_back(arg.type, arg.name);
2242 
2243   auto props = (method.isStatic() ? Method::Static : Method::None) |
2244                (declaration ? Method::Declaration : Method::None);
2245   return opClass.addMethod(method.getReturnType(), method.getName(), props,
2246                            std::move(paramList));
2247 }
2248 
genOpInterfaceMethods()2249 void OpEmitter::genOpInterfaceMethods() {
2250   for (const auto &trait : op.getTraits()) {
2251     if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
2252       if (opTrait->shouldDeclareMethods())
2253         genOpInterfaceMethods(opTrait);
2254   }
2255 }
2256 
genSideEffectInterfaceMethods()2257 void OpEmitter::genSideEffectInterfaceMethods() {
2258   enum EffectKind { Operand, Result, Symbol, Static };
2259   struct EffectLocation {
2260     /// The effect applied.
2261     SideEffect effect;
2262 
2263     /// The index if the kind is not static.
2264     unsigned index;
2265 
2266     /// The kind of the location.
2267     unsigned kind;
2268   };
2269 
2270   StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
2271   auto resolveDecorators = [&](Operator::var_decorator_range decorators,
2272                                unsigned index, unsigned kind) {
2273     for (auto decorator : decorators)
2274       if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
2275         opClass.addTrait(effect->getInterfaceTrait());
2276         interfaceEffects[effect->getBaseEffectName()].push_back(
2277             EffectLocation{*effect, index, kind});
2278       }
2279   };
2280 
2281   // Collect effects that were specified via:
2282   /// Traits.
2283   for (const auto &trait : op.getTraits()) {
2284     const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
2285     if (!opTrait)
2286       continue;
2287     auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
2288     for (auto decorator : opTrait->getEffects())
2289       effects.push_back(EffectLocation{cast<SideEffect>(decorator),
2290                                        /*index=*/0, EffectKind::Static});
2291   }
2292   /// Attributes and Operands.
2293   for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
2294     Argument arg = op.getArg(i);
2295     if (arg.is<NamedTypeConstraint *>()) {
2296       resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
2297       ++operandIt;
2298       continue;
2299     }
2300     const NamedAttribute *attr = arg.get<NamedAttribute *>();
2301     if (attr->attr.getBaseAttr().isSymbolRefAttr())
2302       resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
2303   }
2304   /// Results.
2305   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2306     resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
2307 
2308   // The code used to add an effect instance.
2309   // {0}: The effect class.
2310   // {1}: Optional value or symbol reference.
2311   // {1}: The resource class.
2312   const char *addEffectCode =
2313       "  effects.emplace_back({0}::get(), {1}{2}::get());\n";
2314 
2315   for (auto &it : interfaceEffects) {
2316     // Generate the 'getEffects' method.
2317     std::string type = llvm::formatv("::llvm::SmallVectorImpl<::mlir::"
2318                                      "SideEffects::EffectInstance<{0}>> &",
2319                                      it.first())
2320                            .str();
2321     auto *getEffects = opClass.addMethod("void", "getEffects",
2322                                          MethodParameter(type, "effects"));
2323     ERROR_IF_PRUNED(getEffects, "getEffects", op);
2324     auto &body = getEffects->body();
2325 
2326     // Add effect instances for each of the locations marked on the operation.
2327     for (auto &location : it.second) {
2328       StringRef effect = location.effect.getName();
2329       StringRef resource = location.effect.getResource();
2330       if (location.kind == EffectKind::Static) {
2331         // A static instance has no attached value.
2332         body << llvm::formatv(addEffectCode, effect, "", resource).str();
2333       } else if (location.kind == EffectKind::Symbol) {
2334         // A symbol reference requires adding the proper attribute.
2335         const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
2336         std::string argName = op.getGetterName(attr->name);
2337         if (attr->attr.isOptional()) {
2338           body << "  if (auto symbolRef = " << argName << "Attr())\n  "
2339                << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
2340                       .str();
2341         } else {
2342           body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
2343                                 resource)
2344                       .str();
2345         }
2346       } else {
2347         // Otherwise this is an operand/result, so we need to attach the Value.
2348         body << "  for (::mlir::Value value : getODS"
2349              << (location.kind == EffectKind::Operand ? "Operands" : "Results")
2350              << "(" << location.index << "))\n  "
2351              << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
2352       }
2353     }
2354   }
2355 }
2356 
genTypeInterfaceMethods()2357 void OpEmitter::genTypeInterfaceMethods() {
2358   if (!op.allResultTypesKnown())
2359     return;
2360   // Generate 'inferReturnTypes' method declaration using the interface method
2361   // declared in 'InferTypeOpInterface' op interface.
2362   const auto *trait =
2363       cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait"));
2364   Interface interface = trait->getInterface();
2365   Method *method = [&]() -> Method * {
2366     for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
2367       if (interfaceMethod.getName() == "inferReturnTypes") {
2368         return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
2369       }
2370     }
2371     assert(0 && "unable to find inferReturnTypes interface method");
2372     return nullptr;
2373   }();
2374   ERROR_IF_PRUNED(method, "inferReturnTypes", op);
2375   auto &body = method->body();
2376   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
2377 
2378   FmtContext fctx;
2379   fctx.withBuilder("odsBuilder");
2380   body << "  ::mlir::Builder odsBuilder(context);\n";
2381 
2382   // Preprocess the result types and build all of the types used during
2383   // inferrence. This limits the amount of duplicated work when a type is used
2384   // to infer multiple others.
2385   llvm::DenseMap<Constraint, int> constraintsTypes;
2386   llvm::DenseMap<int, int> argumentsTypes;
2387   int inferredTypeIdx = 0;
2388   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2389     auto type = op.getSameTypeAsResult(i).front();
2390 
2391     // If the type isn't an argument, it refers to a buildable type.
2392     if (!type.isArg()) {
2393       auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
2394       if (!it.second)
2395         continue;
2396 
2397       // If we haven't seen this constraint, generate a variable for it.
2398       body << "  ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
2399            << tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
2400       continue;
2401     }
2402 
2403     // Otherwise, this is an argument.
2404     int argIndex = type.getArg();
2405     auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
2406     if (!it.second)
2407       continue;
2408     body << "  ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
2409 
2410     // If this is an operand, just index into operand list to access the type.
2411     auto arg = op.getArgToOperandOrAttribute(argIndex);
2412     if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
2413       body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
2414 
2415       // If this is an attribute, index into the attribute dictionary.
2416     } else {
2417       auto *attr =
2418           op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
2419       body << "attributes.get(\"" << attr->name << "\").getType()";
2420     }
2421     body << ";\n";
2422   }
2423 
2424   // Perform a second pass that handles assigning the inferred types to the
2425   // results.
2426   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2427     auto types = op.getSameTypeAsResult(i);
2428 
2429     // Append the inferred type.
2430     auto type = types.front();
2431     body << "  inferredReturnTypes[" << i << "] = odsInferredType"
2432          << (type.isArg() ? argumentsTypes[type.getArg()]
2433                           : constraintsTypes[type.getType()])
2434          << ";\n";
2435 
2436     if (types.size() == 1)
2437       continue;
2438     // TODO: We could verify equality here, but skipping that for verification.
2439   }
2440   body << "  return ::mlir::success();";
2441 }
2442 
genParser()2443 void OpEmitter::genParser() {
2444   if (hasStringAttribute(def, "assemblyFormat"))
2445     return;
2446 
2447   if (!def.getValueAsBit("hasCustomAssemblyFormat"))
2448     return;
2449 
2450   SmallVector<MethodParameter> paramList;
2451   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
2452   paramList.emplace_back("::mlir::OperationState &", "result");
2453 
2454   auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse",
2455                                              std::move(paramList));
2456   ERROR_IF_PRUNED(method, "parse", op);
2457 }
2458 
genPrinter()2459 void OpEmitter::genPrinter() {
2460   if (hasStringAttribute(def, "assemblyFormat"))
2461     return;
2462 
2463   // Check to see if this op uses a c++ format.
2464   if (!def.getValueAsBit("hasCustomAssemblyFormat"))
2465     return;
2466   auto *method = opClass.declareMethod(
2467       "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
2468   ERROR_IF_PRUNED(method, "print", op);
2469 }
2470 
genVerifier()2471 void OpEmitter::genVerifier() {
2472   auto *implMethod =
2473       opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
2474   ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
2475   auto &implBody = implMethod->body();
2476 
2477   populateSubstitutions(emitHelper, verifyCtx);
2478   genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
2479   genOperandResultVerifier(implBody, op.getOperands(), "operand");
2480   genOperandResultVerifier(implBody, op.getResults(), "result");
2481 
2482   for (auto &trait : op.getTraits()) {
2483     if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
2484       implBody << tgfmt("  if (!($0))\n    "
2485                         "return emitOpError(\"failed to verify that $1\");\n",
2486                         &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
2487                         t->getSummary());
2488     }
2489   }
2490 
2491   genRegionVerifier(implBody);
2492   genSuccessorVerifier(implBody);
2493 
2494   implBody << "  return ::mlir::success();\n";
2495 
2496   // TODO: Some places use the `verifyInvariants` to do operation verification.
2497   // This may not act as their expectation because this doesn't call any
2498   // verifiers of native/interface traits. Needs to review those use cases and
2499   // see if we should use the mlir::verify() instead.
2500   auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
2501   ERROR_IF_PRUNED(method, "verifyInvariants", op);
2502   auto &body = method->body();
2503   if (def.getValueAsBit("hasVerifier")) {
2504     body << "  if(::mlir::succeeded(verifyInvariantsImpl()) && "
2505             "::mlir::succeeded(verify()))\n";
2506     body << "    return ::mlir::success();\n";
2507     body << "  return ::mlir::failure();";
2508   } else {
2509     body << "  return verifyInvariantsImpl();";
2510   }
2511 }
2512 
genCustomVerifier()2513 void OpEmitter::genCustomVerifier() {
2514   if (def.getValueAsBit("hasVerifier")) {
2515     auto *method = opClass.declareMethod("::mlir::LogicalResult", "verify");
2516     ERROR_IF_PRUNED(method, "verify", op);
2517   }
2518 
2519   if (def.getValueAsBit("hasRegionVerifier")) {
2520     auto *method =
2521         opClass.declareMethod("::mlir::LogicalResult", "verifyRegions");
2522     ERROR_IF_PRUNED(method, "verifyRegions", op);
2523   }
2524 }
2525 
genOperandResultVerifier(MethodBody & body,Operator::const_value_range values,StringRef valueKind)2526 void OpEmitter::genOperandResultVerifier(MethodBody &body,
2527                                          Operator::const_value_range values,
2528                                          StringRef valueKind) {
2529   // Check that an optional value is at most 1 element.
2530   //
2531   // {0}: Value index.
2532   // {1}: "operand" or "result"
2533   const char *const verifyOptional = R"(
2534     if (valueGroup{0}.size() > 1) {
2535       return emitOpError("{1} group starting at #") << index
2536           << " requires 0 or 1 element, but found " << valueGroup{0}.size();
2537     }
2538 )";
2539   // Check the types of a range of values.
2540   //
2541   // {0}: Value index.
2542   // {1}: Type constraint function.
2543   // {2}: "operand" or "result"
2544   const char *const verifyValues = R"(
2545     for (auto v : valueGroup{0}) {
2546       if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
2547         return ::mlir::failure();
2548     }
2549 )";
2550 
2551   const auto canSkip = [](const NamedTypeConstraint &value) {
2552     return !value.hasPredicate() && !value.isOptional() &&
2553            !value.isVariadicOfVariadic();
2554   };
2555   if (values.empty() || llvm::all_of(values, canSkip))
2556     return;
2557 
2558   FmtContext fctx;
2559 
2560   body << "  {\n    unsigned index = 0; (void)index;\n";
2561 
2562   for (const auto &staticValue : llvm::enumerate(values)) {
2563     const NamedTypeConstraint &value = staticValue.value();
2564 
2565     bool hasPredicate = value.hasPredicate();
2566     bool isOptional = value.isOptional();
2567     bool isVariadicOfVariadic = value.isVariadicOfVariadic();
2568     if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
2569       continue;
2570     body << formatv("    auto valueGroup{2} = getODS{0}{1}s({2});\n",
2571                     // Capitalize the first letter to match the function name
2572                     valueKind.substr(0, 1).upper(), valueKind.substr(1),
2573                     staticValue.index());
2574 
2575     // If the constraint is optional check that the value group has at most 1
2576     // value.
2577     if (isOptional) {
2578       body << formatv(verifyOptional, staticValue.index(), valueKind);
2579     } else if (isVariadicOfVariadic) {
2580       body << formatv(
2581           "    if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
2582           "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
2583           "      return ::mlir::failure();\n",
2584           value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
2585           staticValue.index());
2586     }
2587 
2588     // Otherwise, if there is no predicate there is nothing left to do.
2589     if (!hasPredicate)
2590       continue;
2591     // Emit a loop to check all the dynamic values in the pack.
2592     StringRef constraintFn =
2593         staticVerifierEmitter.getTypeConstraintFn(value.constraint);
2594     body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind);
2595   }
2596 
2597   body << "  }\n";
2598 }
2599 
genRegionVerifier(MethodBody & body)2600 void OpEmitter::genRegionVerifier(MethodBody &body) {
2601   /// Code to verify a region.
2602   ///
2603   /// {0}: Getter for the regions.
2604   /// {1}: The region constraint.
2605   /// {2}: The region's name.
2606   /// {3}: The region description.
2607   const char *const verifyRegion = R"(
2608     for (auto &region : {0})
2609       if (::mlir::failed({1}(*this, region, "{2}", index++)))
2610         return ::mlir::failure();
2611 )";
2612   /// Get a single region.
2613   ///
2614   /// {0}: The region's index.
2615   const char *const getSingleRegion =
2616       "::llvm::makeMutableArrayRef((*this)->getRegion({0}))";
2617 
2618   // If we have no regions, there is nothing more to do.
2619   const auto canSkip = [](const NamedRegion &region) {
2620     return region.constraint.getPredicate().isNull();
2621   };
2622   auto regions = op.getRegions();
2623   if (regions.empty() && llvm::all_of(regions, canSkip))
2624     return;
2625 
2626   body << "  {\n    unsigned index = 0; (void)index;\n";
2627   for (const auto &it : llvm::enumerate(regions)) {
2628     const auto &region = it.value();
2629     if (canSkip(region))
2630       continue;
2631 
2632     auto getRegion = region.isVariadic()
2633                          ? formatv("{0}()", op.getGetterName(region.name)).str()
2634                          : formatv(getSingleRegion, it.index()).str();
2635     auto constraintFn =
2636         staticVerifierEmitter.getRegionConstraintFn(region.constraint);
2637     body << formatv(verifyRegion, getRegion, constraintFn, region.name);
2638   }
2639   body << "  }\n";
2640 }
2641 
genSuccessorVerifier(MethodBody & body)2642 void OpEmitter::genSuccessorVerifier(MethodBody &body) {
2643   const char *const verifySuccessor = R"(
2644     for (auto *successor : {0})
2645       if (::mlir::failed({1}(*this, successor, "{2}", index++)))
2646         return ::mlir::failure();
2647 )";
2648   /// Get a single successor.
2649   ///
2650   /// {0}: The successor's name.
2651   const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())";
2652 
2653   // If we have no successors, there is nothing more to do.
2654   const auto canSkip = [](const NamedSuccessor &successor) {
2655     return successor.constraint.getPredicate().isNull();
2656   };
2657   auto successors = op.getSuccessors();
2658   if (successors.empty() && llvm::all_of(successors, canSkip))
2659     return;
2660 
2661   body << "  {\n    unsigned index = 0; (void)index;\n";
2662 
2663   for (auto &it : llvm::enumerate(successors)) {
2664     const auto &successor = it.value();
2665     if (canSkip(successor))
2666       continue;
2667 
2668     auto getSuccessor =
2669         formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor,
2670                 successor.name, it.index())
2671             .str();
2672     auto constraintFn =
2673         staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint);
2674     body << formatv(verifySuccessor, getSuccessor, constraintFn,
2675                     successor.name);
2676   }
2677   body << "  }\n";
2678 }
2679 
2680 /// Add a size count trait to the given operation class.
addSizeCountTrait(OpClass & opClass,StringRef traitKind,int numTotal,int numVariadic)2681 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2682                               int numTotal, int numVariadic) {
2683   if (numVariadic != 0) {
2684     if (numTotal == numVariadic)
2685       opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2686     else
2687       opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2688                        Twine(numTotal - numVariadic) + ">::Impl");
2689     return;
2690   }
2691   switch (numTotal) {
2692   case 0:
2693     opClass.addTrait("::mlir::OpTrait::Zero" + traitKind + "s");
2694     break;
2695   case 1:
2696     opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2697     break;
2698   default:
2699     opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2700                      ">::Impl");
2701     break;
2702   }
2703 }
2704 
genTraits()2705 void OpEmitter::genTraits() {
2706   // Add region size trait.
2707   unsigned numRegions = op.getNumRegions();
2708   unsigned numVariadicRegions = op.getNumVariadicRegions();
2709   addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2710 
2711   // Add result size traits.
2712   int numResults = op.getNumResults();
2713   int numVariadicResults = op.getNumVariableLengthResults();
2714   addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2715 
2716   // For single result ops with a known specific type, generate a OneTypedResult
2717   // trait.
2718   if (numResults == 1 && numVariadicResults == 0) {
2719     auto cppName = op.getResults().begin()->constraint.getCPPClassName();
2720     opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
2721   }
2722 
2723   // Add successor size trait.
2724   unsigned numSuccessors = op.getNumSuccessors();
2725   unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2726   addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2727 
2728   // Add variadic size trait and normal op traits.
2729   int numOperands = op.getNumOperands();
2730   int numVariadicOperands = op.getNumVariableLengthOperands();
2731 
2732   // Add operand size trait.
2733   addSizeCountTrait(opClass, "Operand", numOperands, numVariadicOperands);
2734 
2735   // The op traits defined internal are ensured that they can be verified
2736   // earlier.
2737   for (const auto &trait : op.getTraits()) {
2738     if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
2739       if (opTrait->isStructuralOpTrait())
2740         opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2741     }
2742   }
2743 
2744   // OpInvariants wrapps the verifyInvariants which needs to be run before
2745   // native/interface traits and after all the traits with `StructuralOpTrait`.
2746   opClass.addTrait("::mlir::OpTrait::OpInvariants");
2747 
2748   // Add the native and interface traits.
2749   for (const auto &trait : op.getTraits()) {
2750     if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
2751       if (!opTrait->isStructuralOpTrait())
2752         opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2753     } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) {
2754       opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2755     }
2756   }
2757 }
2758 
genOpNameGetter()2759 void OpEmitter::genOpNameGetter() {
2760   auto *method = opClass.addStaticMethod<Method::Constexpr>(
2761       "::llvm::StringLiteral", "getOperationName");
2762   ERROR_IF_PRUNED(method, "getOperationName", op);
2763   method->body() << "  return ::llvm::StringLiteral(\"" << op.getOperationName()
2764                  << "\");";
2765 }
2766 
genOpAsmInterface()2767 void OpEmitter::genOpAsmInterface() {
2768   // If the user only has one results or specifically added the Asm trait,
2769   // then don't generate it for them. We specifically only handle multi result
2770   // operations, because the name of a single result in the common case is not
2771   // interesting(generally 'result'/'output'/etc.).
2772   // TODO: We could also add a flag to allow operations to opt in to this
2773   // generation, even if they only have a single operation.
2774   int numResults = op.getNumResults();
2775   if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2776     return;
2777 
2778   SmallVector<StringRef, 4> resultNames(numResults);
2779   for (int i = 0; i != numResults; ++i)
2780     resultNames[i] = op.getResultName(i);
2781 
2782   // Don't add the trait if none of the results have a valid name.
2783   if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2784     return;
2785   opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2786 
2787   // Generate the right accessor for the number of results.
2788   auto *method = opClass.addMethod(
2789       "void", "getAsmResultNames",
2790       MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
2791   ERROR_IF_PRUNED(method, "getAsmResultNames", op);
2792   auto &body = method->body();
2793   for (int i = 0; i != numResults; ++i) {
2794     body << "  auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2795          << "  if (!llvm::empty(resultGroup" << i << "))\n"
2796          << "    setNameFn(*resultGroup" << i << ".begin(), \""
2797          << resultNames[i] << "\");\n";
2798   }
2799 }
2800 
2801 //===----------------------------------------------------------------------===//
2802 // OpOperandAdaptor emitter
2803 //===----------------------------------------------------------------------===//
2804 
namespace {
// Helper class to emit Op operand adaptors to an output stream. An adaptor
// bundles an op's operand range, attribute dictionary, regions and (optional)
// operation name, and provides named getters identical to those defined in the
// Op, so the same accessors work before the op itself exists.
class OpOperandAdaptorEmitter {
public:
  // Emits the adaptor class declaration for `op` to `os`.
  static void
  emitDecl(const Operator &op,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter,
           raw_ostream &os);
  // Emits the adaptor class method definitions for `op` to `os`.
  static void
  emitDef(const Operator &op,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter,
          raw_ostream &os);

private:
  // Builds the adaptor class model; only reachable via emitDecl/emitDef.
  explicit OpOperandAdaptorEmitter(
      const Operator &op,
      const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The operation for which to emit an adaptor.
  const Operator &op;

  // The generated adaptor class.
  Class adaptor;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting adaptor code.
  // NOTE: member order matters; the constructor's initializer list relies on
  // `op` being initialized before `adaptor` and `emitHelper`.
  OpOrAdaptorHelper emitHelper;
};
} // namespace
2842 
OpOperandAdaptorEmitter(const Operator & op,const StaticVerifierFunctionEmitter & staticVerifierEmitter)2843 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
2844     const Operator &op,
2845     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
2846     : op(op), adaptor(op.getAdaptorName()),
2847       staticVerifierEmitter(staticVerifierEmitter),
2848       emitHelper(op, /*emitForOp=*/false) {
2849   adaptor.addField("::mlir::ValueRange", "odsOperands");
2850   adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
2851   adaptor.addField("::mlir::RegionRange", "odsRegions");
2852   adaptor.addField("::llvm::Optional<::mlir::OperationName>", "odsOpName");
2853 
2854   const auto *attrSizedOperands =
2855       op.getTrait("::m::OpTrait::AttrSizedOperandSegments");
2856   {
2857     SmallVector<MethodParameter> paramList;
2858     paramList.emplace_back("::mlir::ValueRange", "values");
2859     paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2860                            attrSizedOperands ? "" : "nullptr");
2861     paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
2862     auto *constructor = adaptor.addConstructor(std::move(paramList));
2863 
2864     constructor->addMemberInitializer("odsOperands", "values");
2865     constructor->addMemberInitializer("odsAttrs", "attrs");
2866     constructor->addMemberInitializer("odsRegions", "regions");
2867 
2868     MethodBody &body = constructor->body();
2869     body.indent() << "if (odsAttrs)\n";
2870     body.indent() << formatv(
2871         "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
2872         op.getOperationName());
2873   }
2874 
2875   {
2876     auto *constructor =
2877         adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
2878     constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2879     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2880     constructor->addMemberInitializer("odsRegions", "op->getRegions()");
2881     constructor->addMemberInitializer("odsOpName", "op->getName()");
2882   }
2883 
2884   {
2885     auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands");
2886     ERROR_IF_PRUNED(m, "getOperands", op);
2887     m->body() << "  return odsOperands;";
2888   }
2889   std::string sizeAttrInit;
2890   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2891     sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
2892                            emitHelper.getAttr(operandSegmentAttrName));
2893   }
2894   generateNamedOperandGetters(op, adaptor,
2895                               /*isAdaptor=*/true, sizeAttrInit,
2896                               /*rangeType=*/"::mlir::ValueRange",
2897                               /*rangeBeginCall=*/"odsOperands.begin()",
2898                               /*rangeSizeCall=*/"odsOperands.size()",
2899                               /*getOperandCallPattern=*/"odsOperands[{0}]");
2900 
2901   FmtContext fctx;
2902   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2903 
2904   // Generate named accessor with Attribute return type.
2905   auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
2906                                      Attribute attr) {
2907     auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr");
2908     ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
2909     auto &body = method->body().indent();
2910     body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n"
2911          << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name),
2912                     attr.hasDefaultValue() || attr.isOptional()
2913                         ? "dyn_cast_or_null"
2914                         : "cast",
2915                     attr.getStorageType());
2916 
2917     if (attr.hasDefaultValue()) {
2918       // Use the default value if attribute is not set.
2919       // TODO: this is inefficient, we are recreating the attribute for every
2920       // call. This should be set instead.
2921       std::string defaultValue = std::string(
2922           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2923       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
2924     }
2925     body << "return attr;\n";
2926   };
2927 
2928   {
2929     auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
2930     ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
2931     m->body() << "  return odsAttrs;";
2932   }
2933   for (auto &namedAttr : op.getAttributes()) {
2934     const auto &name = namedAttr.name;
2935     const auto &attr = namedAttr.attr;
2936     if (attr.isDerivedAttr())
2937       continue;
2938     for (const auto &emitName : op.getGetterNames(name)) {
2939       emitAttrWithStorageType(name, emitName, attr);
2940       emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
2941     }
2942   }
2943 
2944   unsigned numRegions = op.getNumRegions();
2945   if (numRegions > 0) {
2946     auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions");
2947     ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
2948     m->body() << "  return odsRegions;";
2949   }
2950   for (unsigned i = 0; i < numRegions; ++i) {
2951     const auto &region = op.getRegion(i);
2952     if (region.name.empty())
2953       continue;
2954 
2955     // Generate the accessors for a variadic region.
2956     for (StringRef name : op.getGetterNames(region.name)) {
2957       if (region.isVariadic()) {
2958         auto *m = adaptor.addMethod("::mlir::RegionRange", name);
2959         ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
2960         m->body() << formatv("  return odsRegions.drop_front({0});", i);
2961         continue;
2962       }
2963 
2964       auto *m = adaptor.addMethod("::mlir::Region &", name);
2965       ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
2966       m->body() << formatv("  return *odsRegions[{0}];", i);
2967     }
2968   }
2969 
2970   // Add verification function.
2971   addVerification();
2972   adaptor.finalize();
2973 }
2974 
addVerification()2975 void OpOperandAdaptorEmitter::addVerification() {
2976   auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify",
2977                                    MethodParameter("::mlir::Location", "loc"));
2978   ERROR_IF_PRUNED(method, "verify", op);
2979   auto &body = method->body();
2980 
2981   FmtContext verifyCtx;
2982   populateSubstitutions(emitHelper, verifyCtx);
2983   genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
2984 
2985   body << "  return ::mlir::success();";
2986 }
2987 
emitDecl(const Operator & op,const StaticVerifierFunctionEmitter & staticVerifierEmitter,raw_ostream & os)2988 void OpOperandAdaptorEmitter::emitDecl(
2989     const Operator &op,
2990     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
2991     raw_ostream &os) {
2992   OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
2993 }
2994 
emitDef(const Operator & op,const StaticVerifierFunctionEmitter & staticVerifierEmitter,raw_ostream & os)2995 void OpOperandAdaptorEmitter::emitDef(
2996     const Operator &op,
2997     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
2998     raw_ostream &os) {
2999   OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
3000 }
3001 
3002 // Emits the opcode enum and op classes.
emitOpClasses(const RecordKeeper & recordKeeper,const std::vector<Record * > & defs,raw_ostream & os,bool emitDecl)3003 static void emitOpClasses(const RecordKeeper &recordKeeper,
3004                           const std::vector<Record *> &defs, raw_ostream &os,
3005                           bool emitDecl) {
3006   // First emit forward declaration for each class, this allows them to refer
3007   // to each others in traits for example.
3008   if (emitDecl) {
3009     os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
3010     os << "#undef GET_OP_FWD_DEFINES\n";
3011     for (auto *def : defs) {
3012       Operator op(*def);
3013       NamespaceEmitter emitter(os, op.getCppNamespace());
3014       os << "class " << op.getCppClassName() << ";\n";
3015     }
3016     os << "#endif\n\n";
3017   }
3018 
3019   IfDefScope scope("GET_OP_CLASSES", os);
3020   if (defs.empty())
3021     return;
3022 
3023   // Generate all of the locally instantiated methods first.
3024   StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
3025   os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
3026   staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
3027 
3028   for (auto *def : defs) {
3029     Operator op(*def);
3030     if (emitDecl) {
3031       {
3032         NamespaceEmitter emitter(os, op.getCppNamespace());
3033         os << formatv(opCommentHeader, op.getQualCppClassName(),
3034                       "declarations");
3035         OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
3036         OpEmitter::emitDecl(op, os, staticVerifierEmitter);
3037       }
3038       // Emit the TypeID explicit specialization to have a single definition.
3039       if (!op.getCppNamespace().empty())
3040         os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
3041            << "::" << op.getCppClassName() << ")\n\n";
3042     } else {
3043       {
3044         NamespaceEmitter emitter(os, op.getCppNamespace());
3045         os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
3046         OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
3047         OpEmitter::emitDef(op, os, staticVerifierEmitter);
3048       }
3049       // Emit the TypeID explicit specialization to have a single definition.
3050       if (!op.getCppNamespace().empty())
3051         os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
3052            << "::" << op.getCppClassName() << ")\n\n";
3053     }
3054   }
3055 }
3056 
3057 // Emits a comma-separated list of the ops.
emitOpList(const std::vector<Record * > & defs,raw_ostream & os)3058 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
3059   IfDefScope scope("GET_OP_LIST", os);
3060 
3061   interleave(
3062       // TODO: We are constructing the Operator wrapper instance just for
3063       // getting it's qualified class name here. Reduce the overhead by having a
3064       // lightweight version of Operator class just for that purpose.
3065       defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
3066       [&os]() { os << ",\n"; });
3067 }
3068 
emitOpDecls(const RecordKeeper & recordKeeper,raw_ostream & os)3069 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
3070   emitSourceFileHeader("Op Declarations", os);
3071 
3072   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
3073   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
3074 
3075   return false;
3076 }
3077 
emitOpDefs(const RecordKeeper & recordKeeper,raw_ostream & os)3078 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
3079   emitSourceFileHeader("Op Definitions", os);
3080 
3081   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
3082   emitOpList(defs, os);
3083   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
3084 
3085   return false;
3086 }
3087 
3088 static mlir::GenRegistration
3089     genOpDecls("gen-op-decls", "Generate op declarations",
__anon5e0e87942302(const RecordKeeper &records, raw_ostream &os) 3090                [](const RecordKeeper &records, raw_ostream &os) {
3091                  return emitOpDecls(records, os);
3092                });
3093 
3094 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
3095                                        [](const RecordKeeper &records,
__anon5e0e87942402(const RecordKeeper &records, raw_ostream &os) 3096                                           raw_ostream &os) {
3097                                          return emitOpDefs(records, os);
3098                                        });
3099