1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "OpClass.h"
15 #include "OpFormatGen.h"
16 #include "OpGenHelpers.h"
17 #include "mlir/TableGen/Class.h"
18 #include "mlir/TableGen/CodeGenHelpers.h"
19 #include "mlir/TableGen/Format.h"
20 #include "mlir/TableGen/GenInfo.h"
21 #include "mlir/TableGen/Interfaces.h"
22 #include "mlir/TableGen/Operator.h"
23 #include "mlir/TableGen/SideEffects.h"
24 #include "mlir/TableGen/Trait.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringSet.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/Signals.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Record.h"
34 #include "llvm/TableGen/TableGenBackend.h"
35 
36 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
37 
38 using namespace llvm;
39 using namespace mlir;
40 using namespace mlir::tblgen;
41 
/// Prefix given to local variables in generated code so that they do not hide
/// the attribute accessors of the op.
static constexpr const char *tblgenNamePrefix = "tblgen_";
/// Base name used for operands that have no name in ODS; the operand index is
/// appended as `odsArg_<index>`.
static constexpr const char *generatedArgName = "odsArg";
/// Identifiers used in generated builders (presumably the OpBuilder and
/// OperationState parameter names; their uses are outside this view).
static constexpr const char *odsBuilder = "odsBuilder";
static constexpr const char *builderOpState = "odsState";

/// Names of the implicit attributes that carry the sizes of the variadic
/// operand and result segments.
static constexpr const char *operandSegmentAttrName = "operand_segment_sizes";
static constexpr const char *resultSegmentAttrName = "result_segment_sizes";
51 
/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
/// lookup: the search is narrowed to the part of the sorted attribute array in
/// which the attribute can appear.
///
/// {0}: Code snippet to get the attribute's name or identifier.
/// {1}: The lower bound on the sorted subrange (number of leading entries
///      skipped from the start of the array).
/// {2}: The upper bound on the sorted subrange (number of trailing entries
///      skipped from the end of the array).
/// {3}: Code snippet to get the array of named attributes.
/// {4}: "Named" to get the named attribute, or empty to get the value.
static const char *const subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
    "{2}, {0})";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// The snippet expects a variable `index` (the static operand/result index) to
/// be in scope in the generated function, and returns a {start, size} pair.
/// Literal braces are doubled because the text is run through formatv.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes the
/// op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// The snippet expects `sizeAttr` (the segment size attribute) and `index` to
/// be in scope. It reads the attribute's values as contiguous `uint32_t`s and
/// handles a splat attribute (every segment has the same size) separately.
/// NOTE(review): the braces here are not doubled, so this text is presumably
/// emitted verbatim rather than through formatv — confirm at the use site.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
  if (sizeAttr.isSplat())
    return {*sizeAttrValueIt * index, *sizeAttrValueIt};

  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttrValueIt[i];
  return {start, sizeAttrValueIt[index]};
)";
/// The code snippet to initialize the sizes for the value range calculation,
/// adaptor flavor: additionally asserts that `odsAttrs` is non-null.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert(odsAttrs && "missing segment size attribute for op");
  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
)";
/// The code snippet to initialize the sizes for the value range calculation,
/// op flavor.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
)";

/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor. The
/// operands of the main segment are split into consecutive groups whose sizes
/// come from the segment attribute.
///
/// {0}: The name of the segment attribute.
/// {1}: The index of the main operand.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizeAttrValues = {0}().getValues<uint32_t>();
  auto sizeAttrIt = sizeAttrValues.begin();

  ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
  for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt);
  }
  return tblgenTmpOperandGroups;
)";

/// The logic to build a range of either operand or result values.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range, yielding
///      a pair whose `first` is the start and `second` is the length.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";

/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
158 
159 //===----------------------------------------------------------------------===//
160 // Utility structs and functions
161 //===----------------------------------------------------------------------===//
162 
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the result. Scanning resumes just past each inserted `substitute`, so text
// produced by a replacement is never matched again (e.g. replacing "a" with
// "aa" terminates). An empty `match` leaves `str` unchanged.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  // An empty pattern is found at every position; without this guard the loop
  // below would never terminate.
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  while ((scanLoc = str.find(match, scanLoc)) != std::string::npos) {
    // `replace` edits `str` in place; no self-assignment needed.
    str.replace(scanLoc, match.size(), substitute);
    scanLoc += substitute.size();
  }
  return str;
}
173 
174 // Returns whether the record has a value of the given name that can be returned
175 // via getValueAsString.
176 static inline bool hasStringAttribute(const Record &record,
177                                       StringRef fieldName) {
178   auto *valueInit = record.getValueInit(fieldName);
179   return isa<StringInit>(valueInit);
180 }
181 
182 static std::string getArgumentName(const Operator &op, int index) {
183   const auto &operand = op.getOperand(index);
184   if (!operand.name.empty())
185     return std::string(operand.name);
186   return std::string(formatv("{0}_{1}", generatedArgName, index));
187 }
188 
189 // Returns true if we can use unwrapped value for the given `attr` in builders.
190 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
191   return attr.getReturnType() != attr.getStorageType() &&
192          // We need to wrap the raw value into an attribute in the builder impl
193          // so we need to make sure that the attribute specifies how to do that.
194          !attr.getConstBuilderTemplate().empty();
195 }
196 
197 namespace {
/// Metadata on a registered attribute. Given that attributes are stored in
/// sorted order on operations, we can use information from ODS to deduce the
/// number of required attributes less than and greater than each attribute,
/// allowing us to search only a subrange of the attributes in ODS-generated
/// getters.
struct AttributeMetadata {
  /// The attribute name.
  StringRef attrName;
  /// Whether the attribute is required, i.e. it must be present on every
  /// instance of the op (no default value, not optional, not derived).
  bool isRequired;
  /// The ODS attribute constraint. Not present for implicit attributes.
  Optional<Attribute> constraint;
  /// The number of required attributes less than this attribute.
  unsigned lowerBound = 0;
  /// The number of required attributes greater than this attribute.
  unsigned upperBound = 0;
};
215 
216 /// Helper class to select between OpAdaptor and Op code templates.
217 class OpOrAdaptorHelper {
218 public:
219   OpOrAdaptorHelper(const Operator &op, bool emitForOp)
220       : op(op), emitForOp(emitForOp) {
221     computeAttrMetadata();
222   }
223 
224   /// Object that wraps a functor in a stream operator for interop with
225   /// llvm::formatv.
226   class Formatter {
227   public:
228     template <typename Functor>
229     Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
230 
231     std::string str() const {
232       std::string result;
233       llvm::raw_string_ostream os(result);
234       os << *this;
235       return os.str();
236     }
237 
238   private:
239     std::function<raw_ostream &(raw_ostream &)> func;
240 
241     friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
242       return fmt.func(os);
243     }
244   };
245 
246   // Generate code for getting an attribute.
247   Formatter getAttr(StringRef attrName, bool isNamed = false) const {
248     assert(attrMetadata.count(attrName) && "expected attribute metadata");
249     return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
250       const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
251       return os << formatv(subrangeGetAttr, getAttrName(attrName),
252                            attr.lowerBound, attr.upperBound, getAttrRange(),
253                            isNamed ? "Named" : "");
254     };
255   }
256 
257   // Generate code for getting the name of an attribute.
258   Formatter getAttrName(StringRef attrName) const {
259     return [this, attrName](raw_ostream &os) -> raw_ostream & {
260       if (emitForOp)
261         return os << op.getGetterName(attrName) << "AttrName()";
262       return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(),
263                            op.getGetterName(attrName));
264     };
265   }
266 
267   // Get the code snippet for getting the named attribute range.
268   StringRef getAttrRange() const {
269     return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
270   }
271 
272   // Get the prefix code for emitting an error.
273   Formatter emitErrorPrefix() const {
274     return [this](raw_ostream &os) -> raw_ostream & {
275       if (emitForOp)
276         return os << "emitOpError(";
277       return os << formatv("emitError(loc, \"'{0}' op \"",
278                            op.getOperationName());
279     };
280   }
281 
282   // Get the call to get an operand or segment of operands.
283   Formatter getOperand(unsigned index) const {
284     return [this, index](raw_ostream &os) -> raw_ostream & {
285       return os << formatv(op.getOperand(index).isVariadic()
286                                ? "this->getODSOperands({0})"
287                                : "(*this->getODSOperands({0}).begin())",
288                            index);
289     };
290   }
291 
292   // Get the call to get a result of segment of results.
293   Formatter getResult(unsigned index) const {
294     return [this, index](raw_ostream &os) -> raw_ostream & {
295       if (!emitForOp)
296         return os << "<no results should be generated>";
297       return os << formatv(op.getResult(index).isVariadic()
298                                ? "this->getODSResults({0})"
299                                : "(*this->getODSResults({0}).begin())",
300                            index);
301     };
302   }
303 
304   // Return whether an op instance is available.
305   bool isEmittingForOp() const { return emitForOp; }
306 
307   // Return the ODS operation wrapper.
308   const Operator &getOp() const { return op; }
309 
310   // Get the attribute metadata sorted by name.
311   const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
312     return attrMetadata;
313   }
314 
315 private:
316   // Compute the attribute metadata.
317   void computeAttrMetadata();
318 
319   // The operation ODS wrapper.
320   const Operator &op;
321   // True if code is being generate for an op. False for an adaptor.
322   const bool emitForOp;
323 
324   // The attribute metadata, mapped by name.
325   llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
326   // The number of required attributes.
327   unsigned numRequired;
328 };
329 
330 } // namespace
331 
332 void OpOrAdaptorHelper::computeAttrMetadata() {
333   // Enumerate the attribute names of this op, ensuring the attribute names are
334   // unique in case implicit attributes are explicitly registered.
335   for (const NamedAttribute &namedAttr : op.getAttributes()) {
336     Attribute attr = namedAttr.attr;
337     bool isOptional =
338         attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
339     attrMetadata.insert(
340         {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}});
341   }
342   // Include key attributes from several traits as implicitly registered.
343   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
344     attrMetadata.insert(
345         {operandSegmentAttrName,
346          AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true,
347                            /*attr=*/llvm::None}});
348   }
349   if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
350     attrMetadata.insert(
351         {resultSegmentAttrName,
352          AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
353                            /*attr=*/llvm::None}});
354   }
355 
356   // Store the metadata in sorted order.
357   SmallVector<AttributeMetadata> sortedAttrMetadata =
358       llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector()));
359   llvm::sort(sortedAttrMetadata,
360              [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
361                return lhs.attrName < rhs.attrName;
362              });
363 
364   // Compute the subrange bounds for each attribute.
365   numRequired = 0;
366   for (AttributeMetadata &attr : sortedAttrMetadata) {
367     attr.lowerBound = numRequired;
368     numRequired += attr.isRequired;
369   };
370   for (AttributeMetadata &attr : sortedAttrMetadata)
371     attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
372 
373   // Store the results back into the map.
374   for (const AttributeMetadata &attr : sortedAttrMetadata)
375     attrMetadata.insert({attr.attrName, attr});
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // Op emitter
380 //===----------------------------------------------------------------------===//
381 
382 namespace {
// Helper class to emit a record into the given output stream.
//
// The public static `emitDecl`/`emitDef` entry points take the ODS operator
// and the output stream. The constructor is private, so emitters are
// presumably only created through those entry points. The `gen*` helpers
// below presumably add the corresponding pieces to `opClass` (definitions are
// outside this view).
class OpEmitter {
public:
  static void
  emitDecl(const Operator &op, raw_ostream &os,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
  static void
  emitDef(const Operator &op, raw_ostream &os,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter);

private:
  OpEmitter(const Operator &op,
            const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generate methods for accessing the attribute names of this operation.
  void genAttrNameGetters();

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setters for the attributes.
  void genAttrSetters();

  // Generates removers for optional attributes.
  void genOptionalAttrRemovers();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates setters for named operands.
  void genNamedOperandSetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all result's types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. `inferredAttributes` is populated with any attributes that are
  // elided from the build list. The given `typeParamKind` and `attrParamKind`
  // controls how result types and attributes are placed in the parameter list.
  void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                      llvm::StringSet<> &inferredAttributes,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  void
  genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
                                         llvm::StringSet<> &inferredAttributes,
                                         bool isRawValueAttr = false);

  // Generates canonicalizer declaration for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates custom verify methods for the operation.
  void genCustomVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(MethodBody &body,
                                Operator::const_value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(MethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(MethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods for all interfaces.
  void genOpInterfaceMethods();

  // Generate op interface methods for the given interface.
  void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);

  // Generate op interface method for the given interface method. If
  // 'declaration' is true, generates a declaration, else a definition.
  Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                               bool declaration = true);

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

  // Generate the type inference interface methods.
  void genTypeInterfaceMethods();

private:
  // NOTE: data members are initialized in declaration order, not in the order
  // written in the constructor's initializer list; keep the two in sync.

  // The TableGen record for this op.
  // TODO: OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  const Operator &op;

  // The C++ code builder for this op
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting op code.
  OpOrAdaptorHelper emitHelper;
};
568 
569 } // namespace
570 
571 // Populate the format context `ctx` with substitutions of attributes, operands
572 // and results.
573 static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
574                                   FmtContext &ctx) {
575   // Populate substitutions for attributes.
576   auto &op = emitHelper.getOp();
577   for (const auto &namedAttr : op.getAttributes())
578     ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str());
579 
580   // Populate substitutions for named operands.
581   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
582     auto &value = op.getOperand(i);
583     if (!value.name.empty())
584       ctx.addSubst(value.name, emitHelper.getOperand(i).str());
585   }
586 
587   // Populate substitutions for results.
588   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
589     auto &value = op.getResult(i);
590     if (!value.name.empty())
591       ctx.addSubst(value.name, emitHelper.getResult(i).str());
592   }
593 }
594 
595 /// Generate verification on native traits requiring attributes.
596 static void genNativeTraitAttrVerifier(MethodBody &body,
597                                        const OpOrAdaptorHelper &emitHelper) {
598   // Check that the variadic segment sizes attribute exists and contains the
599   // expected number of elements.
600   //
601   // {0}: Attribute name.
602   // {1}: Expected number of elements.
603   // {2}: "operand" or "result".
604   // {3}: Emit error prefix.
605   const char *const checkAttrSizedValueSegmentsCode = R"(
606   {
607     auto sizeAttr = tblgen_{0}.cast<::mlir::DenseIntElementsAttr>();
608     auto numElements =
609         sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
610     if (numElements != {1})
611       return {3}"'{0}' attribute for specifying {2} segments must have {1} "
612                 "elements, but got ") << numElements;
613   }
614   )";
615 
616   // Verify a few traits first so that we can use getODSOperands() and
617   // getODSResults() in the rest of the verifier.
618   auto &op = emitHelper.getOp();
619   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
620     body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
621                     op.getNumOperands(), "operand",
622                     emitHelper.emitErrorPrefix());
623   }
624   if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
625     body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
626                     op.getNumResults(), "result", emitHelper.emitErrorPrefix());
627   }
628 }
629 
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
// Attribute verification is performed as follows:
//
// 1. Verify that all required attributes are present in sorted order. This
// ensures that we can use subrange lookup even with potentially missing
// attributes.
// 2. Verify native trait attributes so that other attributes may call methods
// that depend on the validity of these attributes, e.g. segment size attributes
// and operand or result getters.
// 3. Verify the constraints on all present attributes.
//
// The generated code walks the attribute range with a single iterator
// (`namedAttrIt`), binding each registered attribute it meets to a local
// `tblgen_<name>` variable. Entries that match no registered attribute are
// stepped over. Nothing is emitted when the op has no attribute metadata.
static void genAttributeVerifier(
    const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
  if (emitHelper.getAttrMetadata().empty())
    return;

  // Verify the attribute if it is present. This assumes that default values
  // are valid. This code snippet pastes the condition inline.
  //
  // TODO: verify the default value is valid (perhaps in debug mode only).
  //
  // {0}: Attribute variable name.
  // {1}: Attribute condition code.
  // {2}: Emit error prefix.
  // {3}: Attribute name.
  // {4}: Attribute/constraint description.
  const char *const verifyAttrInline = R"(
  if ({0} && !({1}))
    return {2}"attribute '{3}' failed to satisfy constraint: {4}");
)";
  // Verify the attribute using a uniqued constraint. Can only be used within
  // the context of an op.
  //
  // {0}: Unique constraint name.
  // {1}: Attribute variable name.
  // {2}: Attribute name.
  const char *const verifyAttrUnique = R"(
  if (::mlir::failed({0}(*this, {1}, "{2}")))
    return ::mlir::failure();
)";

  // Traverse the array until the required attribute is found. Return an error
  // if the traversal reached the end.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The emit error prefix.
  // {2}: The name of the attribute.
  const char *const findRequiredAttr = R"(while (true) {{
  if (namedAttrIt == namedAttrRange.end())
    return {1}"requires attribute '{2}'");
  if (namedAttrIt->getName() == {0}) {{
    tblgen_{2} = namedAttrIt->getValue();
    break;
  })";

  // Emit a check to see if the iteration has encountered an optional attribute.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The name of the attribute.
  const char *const checkOptionalAttr = R"(
  else if (namedAttrIt->getName() == {0}) {{
    tblgen_{1} = namedAttrIt->getValue();
  })";

  // Emit the start of the loop for checking trailing attributes. This snippet
  // is streamed directly (not via formatv), hence the single braces.
  const char *const checkTrailingAttrs = R"(while (true) {
  if (namedAttrIt == namedAttrRange.end()) {
    break;
  })";

  // Return true if a verifier can be emitted for the attribute: it is not a
  // derived attribute, it has a predicate, its condition is not empty, and, for
  // adaptors, the condition does not reference the op.
  const auto canEmitVerifier = [&](Attribute attr) {
    if (attr.isDerivedAttr())
      return false;
    Pred pred = attr.getPredicate();
    if (pred.isNull())
      return false;
    std::string condition = pred.getCondition();
    return !condition.empty() && (!StringRef(condition).contains("$_op") ||
                                  emitHelper.isEmittingForOp());
  };

  // Emit the verifier for the attribute. For ops, a uniqued static constraint
  // function is called when one is available; otherwise the predicate is
  // inlined with `$_self` bound to the attribute's local variable.
  const auto emitVerifier = [&](Attribute attr, StringRef attrName,
                                StringRef varName) {
    std::string condition = attr.getPredicate().getCondition();

    Optional<StringRef> constraintFn;
    if (emitHelper.isEmittingForOp() &&
        (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
      body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
    } else {
      body << formatv(verifyAttrInline, varName,
                      tgfmt(condition, &ctx.withSelf(varName)),
                      emitHelper.emitErrorPrefix(), attrName,
                      escapeString(attr.getSummary()));
    }
  };

  // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
  const auto getVarName = [&](StringRef attrName) {
    return (tblgenNamePrefix + attrName).str();
  };

  body.indent() << formatv("auto namedAttrRange = {0};\n",
                           emitHelper.getAttrRange());
  body << "auto namedAttrIt = namedAttrRange.begin();\n";

  // Iterate over the attributes in sorted order. Keep track of the optional
  // attributes that may be encountered along the way.
  //
  // Each required attribute gets its own `while` loop that scans forward until
  // that attribute is found. The optional attributes sorting before it are
  // bound inside the same loop.
  SmallVector<const AttributeMetadata *> optionalAttrs;
  for (const std::pair<StringRef, AttributeMetadata> &it :
       emitHelper.getAttrMetadata()) {
    const AttributeMetadata &metadata = it.second;
    if (!metadata.isRequired) {
      optionalAttrs.push_back(&metadata);
      continue;
    }

    body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
    for (const AttributeMetadata *optional : optionalAttrs) {
      body << formatv("::mlir::Attribute {0};\n",
                      getVarName(optional->attrName));
    }
    body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
                    emitHelper.emitErrorPrefix(), it.first);
    for (const AttributeMetadata *optional : optionalAttrs) {
      body << formatv(checkOptionalAttr,
                      emitHelper.getAttrName(optional->attrName),
                      optional->attrName);
    }
    body << "\n  ++namedAttrIt;\n}\n";
    optionalAttrs.clear();
  }
  // Get trailing optional attributes: those sorting after the last required
  // attribute are picked up by a final loop that runs to the end of the range.
  if (!optionalAttrs.empty()) {
    for (const AttributeMetadata *optional : optionalAttrs) {
      body << formatv("::mlir::Attribute {0};\n",
                      getVarName(optional->attrName));
    }
    body << checkTrailingAttrs;
    for (const AttributeMetadata *optional : optionalAttrs) {
      body << formatv(checkOptionalAttr,
                      emitHelper.getAttrName(optional->attrName),
                      optional->attrName);
    }
    body << "\n  ++namedAttrIt;\n}\n";
  }
  body.unindent();

  // Emit the checks for segment attributes first so that the other constraints
  // can call operand and result getters.
  genNativeTraitAttrVerifier(body, emitHelper);

  for (const auto &namedAttr : emitHelper.getOp().getAttributes())
    if (canEmitVerifier(namedAttr.attr))
      emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
792 
793 /// Op extra class definitions have a `$cppClass` substitution that is to be
794 /// replaced by the C++ class name.
795 static std::string formatExtraDefinitions(const Operator &op) {
796   FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
797   return tgfmt(op.getExtraClassDefinition(), &ctx).str();
798 }
799 
/// Builds the emitter for `op`. It creates the C++ op class, named after the
/// op's C++ class name and carrying the op's extra class declaration and its
/// formatted extra definitions. It then configures the verifier format
/// context and populates the class with every generated trait and method.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/true) {
  // Verifier snippets are evaluated inside an op method, so the op
  // substitution is the underlying Operation and `_ctxt` is its context.
  verifyCtx.withOp("(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCustomVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
837 void OpEmitter::emitDecl(
838     const Operator &op, raw_ostream &os,
839     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
840   OpEmitter(op, staticVerifierEmitter).emitDecl(os);
841 }
842 
843 void OpEmitter::emitDef(
844     const Operator &op, raw_ostream &os,
845     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
846   OpEmitter(op, staticVerifierEmitter).emitDef(os);
847 }
848 
849 void OpEmitter::emitDecl(raw_ostream &os) {
850   opClass.finalize();
851   opClass.writeDeclTo(os);
852 }
853 
854 void OpEmitter::emitDef(raw_ostream &os) {
855   opClass.finalize();
856   opClass.writeDefTo(os);
857 }
858 
859 static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
860                           const Operator &op) {
861   if (m)
862     return;
863   PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
864                                    methodName + "` for " +
865                                    op.getOperationName() + " (from line " +
866                                    Twine(line) + ")");
867 }
868 
869 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
870 
/// Generates the attribute-name accessors on the op class:
///  - static `getAttributeNames()`: the op's attribute names, in the order of
///    `emitHelper.getAttrMetadata()`;
///  - private `getAttributeNameForIndex(index)` and static
///    `getAttributeNameForIndex(name, index)`: name lookup by position;
///  - `<getter>AttrName()` and static `<getter>AttrName(name)`: one pair per
///    getter name of each attribute, forwarding to `getAttributeNameForIndex`
///    with that attribute's position in the metadata map.
/// If the op has no attributes, only `getAttributeNames()` (returning `{}`) is
/// generated.
void OpEmitter::genAttrNameGetters() {
  const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
      emitHelper.getAttrMetadata();

  // Emit the getAttributeNames method.
  {
    auto *method = opClass.addStaticInlineMethod(
        "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
    ERROR_IF_PRUNED(method, "getAttributeNames", op);
    auto &body = method->body();
    if (attributes.empty()) {
      body << "  return {};";
      // Nothing else to do if there are no registered attributes. Exit early.
      return;
    }
    // Names are listed in the MapVector's order, which is also the order used
    // for the indices passed by the accessors below.
    body << "  static ::llvm::StringRef attrNames[] = {";
    llvm::interleaveComma(llvm::make_first_range(attributes), body,
                          [&](StringRef attrName) {
                            body << "::llvm::StringRef(\"" << attrName << "\")";
                          });
    body << "};\n  return ::llvm::makeArrayRef(attrNames);";
  }

  // Emit the getAttributeNameForIndex methods.
  {
    // Instance variant: forwards to the static variant using this op's name.
    auto *method = opClass.addInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
    method->body()
        << "  return getAttributeNameForIndex((*this)->getName(), index);";
  }
  {
    // Static variant: bounds-checks `index` against the attribute count and
    // reads the name from the OperationName's registered info. NOTE(review):
    // presumably that registered list is populated from getAttributeNames()
    // above; confirm the orders agree.
    auto *method = opClass.addStaticInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("::mlir::OperationName", "name"),
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);

    const char *const getAttrName = R"(
  assert(index < {0} && "invalid attribute index");
  return name.getRegisteredInfo()->getAttributeNames()[index];
)";
    method->body() << formatv(getAttrName, attributes.size());
  }

  // Generate the <attr>AttrName methods, that expose the attribute names to
  // users.
  const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
  for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) {
    for (StringRef name : op.getGetterNames(attrIt.value())) {
      std::string methodName = (name + "AttrName").str();

      // Generate the non-static variant.
      {
        auto *method =
            opClass.addInlineMethod("::mlir::StringAttr", methodName);
        ERROR_IF_PRUNED(method, methodName, op);
        method->body() << llvm::formatv(attrNameMethodBody, attrIt.index());
      }

      // Generate the static variant.
      {
        auto *method = opClass.addStaticInlineMethod(
            "::mlir::StringAttr", methodName,
            MethodParameter("::mlir::OperationName", "name"));
        ERROR_IF_PRUNED(method, methodName, op);
        // The substituted argument text is "name, <index>", which selects the
        // static getAttributeNameForIndex overload.
        method->body() << llvm::formatv(attrNameMethodBody,
                                        "name, " + Twine(attrIt.index()));
      }
    }
  }
}
944 
945 // Emit the getter for an attribute with the return type specified.
946 // It is templated to be shared between the Op and the adaptor class.
947 template <typename OpClassOrAdaptor>
948 static void emitAttrGetterWithReturnType(FmtContext &fctx,
949                                          OpClassOrAdaptor &opClass,
950                                          const Operator &op, StringRef name,
951                                          Attribute attr) {
952   auto *method = opClass.addMethod(attr.getReturnType(), name);
953   ERROR_IF_PRUNED(method, name, op);
954   auto &body = method->body();
955   body << "  auto attr = " << name << "Attr();\n";
956   if (attr.hasDefaultValue()) {
957     // Returns the default value if not set.
958     // TODO: this is inefficient, we are recreating the attribute for every
959     // call. This should be set instead.
960     if (!attr.isConstBuildable()) {
961       PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
962                       " must have a constBuilder");
963     }
964     std::string defaultValue = std::string(
965         tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
966     body << "    if (!attr)\n      return "
967          << tgfmt(attr.getConvertFromStorageCall(),
968                   &fctx.withSelf(defaultValue))
969          << ";\n";
970   }
971   body << "  return "
972        << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
973        << ";\n";
974 }
975 
976 void OpEmitter::genAttrGetters() {
977   FmtContext fctx;
978   fctx.withBuilder("::mlir::Builder((*this)->getContext())");
979 
980   // Emit the derived attribute body.
981   auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
982     if (auto *method = opClass.addMethod(attr.getReturnType(), name))
983       method->body() << "  " << attr.getDerivedCodeBody() << "\n";
984   };
985 
986   // Generate named accessor with Attribute return type. This is a wrapper class
987   // that allows referring to the attributes via accessors instead of having to
988   // use the string interface for better compile time verification.
989   auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
990                                      Attribute attr) {
991     auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr");
992     if (!method)
993       return;
994     method->body() << formatv(
995         "  return {0}.{1}<{2}>();", emitHelper.getAttr(attrName),
996         attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
997                                                     : "cast",
998         attr.getStorageType());
999   };
1000 
1001   for (const NamedAttribute &namedAttr : op.getAttributes()) {
1002     for (StringRef name : op.getGetterNames(namedAttr.name)) {
1003       if (namedAttr.attr.isDerivedAttr()) {
1004         emitDerivedAttr(name, namedAttr.attr);
1005       } else {
1006         emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
1007         emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
1008       }
1009     }
1010   }
1011 
1012   auto derivedAttrs = make_filter_range(op.getAttributes(),
1013                                         [](const NamedAttribute &namedAttr) {
1014                                           return namedAttr.attr.isDerivedAttr();
1015                                         });
1016   if (derivedAttrs.empty())
1017     return;
1018 
1019   opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
1020   // Generate helper method to query whether a named attribute is a derived
1021   // attribute. This enables, for example, avoiding adding an attribute that
1022   // overlaps with a derived attribute.
1023   {
1024     auto *method =
1025         opClass.addStaticMethod("bool", "isDerivedAttribute",
1026                                 MethodParameter("::llvm::StringRef", "name"));
1027     ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
1028     auto &body = method->body();
1029     for (auto namedAttr : derivedAttrs)
1030       body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
1031     body << " return false;";
1032   }
1033   // Generate method to materialize derived attributes as a DictionaryAttr.
1034   {
1035     auto *method = opClass.addMethod("::mlir::DictionaryAttr",
1036                                      "materializeDerivedAttributes");
1037     ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
1038     auto &body = method->body();
1039 
1040     auto nonMaterializable =
1041         make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
1042           return namedAttr.attr.getConvertFromStorageCall().empty();
1043         });
1044     if (!nonMaterializable.empty()) {
1045       std::string attrs;
1046       llvm::raw_string_ostream os(attrs);
1047       interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
1048         os << op.getGetterName(attr.name);
1049       });
1050       PrintWarning(
1051           op.getLoc(),
1052           formatv(
1053               "op has non-materializable derived attributes '{0}', skipping",
1054               os.str()));
1055       body << formatv("  emitOpError(\"op has non-materializable derived "
1056                       "attributes '{0}'\");\n",
1057                       attrs);
1058       body << "  return nullptr;";
1059       return;
1060     }
1061 
1062     body << "  ::mlir::MLIRContext* ctx = getContext();\n";
1063     body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1064     body << "  return ::mlir::DictionaryAttr::get(";
1065     body << "  ctx, {\n";
1066     interleave(
1067         derivedAttrs, body,
1068         [&](const NamedAttribute &namedAttr) {
1069           auto tmpl = namedAttr.attr.getConvertFromStorageCall();
1070           std::string name = op.getGetterName(namedAttr.name);
1071           body << "    {" << name << "AttrName(),\n"
1072                << tgfmt(tmpl, &fctx.withSelf(name + "()")
1073                                    .withBuilder("odsBuilder")
1074                                    .addSubst("_ctx", "ctx"))
1075                << "}";
1076         },
1077         ",\n");
1078     body << "});";
1079   }
1080 }
1081 
1082 void OpEmitter::genAttrSetters() {
1083   // Generate raw named setter type. This is a wrapper class that allows setting
1084   // to the attributes via setters instead of having to use the string interface
1085   // for better compile time verification.
1086   auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
1087                                      Attribute attr) {
1088     auto *method =
1089         opClass.addMethod("void", setterName + "Attr",
1090                           MethodParameter(attr.getStorageType(), "attr"));
1091     if (method)
1092       method->body() << formatv("  (*this)->setAttr({0}AttrName(), attr);",
1093                                 getterName);
1094   };
1095 
1096   for (const NamedAttribute &namedAttr : op.getAttributes()) {
1097     if (namedAttr.attr.isDerivedAttr())
1098       continue;
1099     for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
1100                                 op.getGetterNames(namedAttr.name)))
1101       emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
1102                               namedAttr.attr);
1103   }
1104 }
1105 
1106 void OpEmitter::genOptionalAttrRemovers() {
1107   // Generate methods for removing optional attributes, instead of having to
1108   // use the string interface. Enables better compile time verification.
1109   auto emitRemoveAttr = [&](StringRef name) {
1110     auto upperInitial = name.take_front().upper();
1111     auto suffix = name.drop_front();
1112     auto *method = opClass.addMethod("::mlir::Attribute",
1113                                      "remove" + upperInitial + suffix + "Attr");
1114     if (!method)
1115       return;
1116     method->body() << formatv("  return (*this)->removeAttr({0}AttrName());",
1117                               op.getGetterName(name));
1118   };
1119 
1120   for (const NamedAttribute &namedAttr : op.getAttributes())
1121     if (namedAttr.attr.isOptional())
1122       emitRemoveAttr(namedAttr.name);
1123 }
1124 
1125 // Generates the code to compute the start and end index of an operand or result
1126 // range.
1127 template <typename RangeT>
1128 static void
1129 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
1130                               int numVariadic, int numNonVariadic,
1131                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
1132                               StringRef sizeAttrInit, RangeT &&odsValues) {
1133   auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
1134                                    MethodParameter("unsigned", "index"));
1135   if (!method)
1136     return;
1137   auto &body = method->body();
1138   if (numVariadic == 0) {
1139     body << "  return {index, 1};\n";
1140   } else if (hasAttrSegmentSize) {
1141     body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
1142   } else {
1143     // Because the op can have arbitrarily interleaved variadic and non-variadic
1144     // operands, we need to embed a list in the "sink" getter method for
1145     // calculation at run-time.
1146     SmallVector<StringRef, 4> isVariadic;
1147     isVariadic.reserve(llvm::size(odsValues));
1148     for (auto &it : odsValues)
1149       isVariadic.push_back(it.isVariableLength() ? "true" : "false");
1150     std::string isVariadicList = llvm::join(isVariadic, ", ");
1151     body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
1152                     numNonVariadic, numVariadic, rangeSizeCall, "operand");
1153   }
1154 }
1155 
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`.
//  - `rangeType`: return type of getters that return a range of operands
//    (individual operands are `Value`, and each element of the range must also
//    be a `Value`).
//  - `rangeBeginCall`: code yielding an iterator to the start of the operand
//    range.
//  - `rangeSizeCall`: code yielding the number of operands.
//  - `sizeAttrInit`: code that initializes the operand segment sizes; it is
//    used only when the op has AttrSizedOperandSegments.
//  - `isAdaptor`: selects a SmallVector<ValueRange> return type for
//    variadic-of-variadic operands.
//  - `getOperandCallPattern`: code for obtaining a single operand, with the
//    operand's position substituted for the "{0}" marker. The pattern should
//    work for any kind of op, in particular for one-operand ops that may not
//    have the `getOperand(unsigned)` method. NOTE(review): this parameter is
//    not referenced in this function body.
// Aborts with a fatal error if the variadic-size traits are inconsistent with
// the number of variadic operands.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
                                        bool isAdaptor, StringRef sizeAttrInit,
                                        StringRef rangeType,
                                        StringRef rangeBeginCall,
                                        StringRef rangeSizeCall,
                                        StringRef getOperandCallPattern) {
  const int numOperands = op.getNumOperands();
  const int numVariadicOperands = op.getNumVariableLengthOperands();
  const int numNormalOperands = numOperands - numVariadicOperands;

  const auto *sameVariadicSize =
      op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
  const auto *attrSizedOperands =
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");

  // With more than one variadic group, the op must say how to split operands.
  if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
                                 "specification over their sizes");
  }

  if (numVariadicOperands < 2 && attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
                                 "to use 'AttrSizedOperandSegments' trait");
  }

  if (attrSizedOperands && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedOperandSegments' and "
                    "'SameVariadicOperandSize' traits");
  }

  // First emit a few "sink" getter methods upon which we layer all nicer named
  // getter methods.
  generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
                                numVariadicOperands, numNormalOperands,
                                rangeSizeCall, attrSizedOperands, sizeAttrInit,
                                const_cast<Operator &>(op).getOperands());

  auto *m = opClass.addMethod(rangeType, "getODSOperands",
                              MethodParameter("unsigned", "index"));
  ERROR_IF_PRUNED(m, "getODSOperands", op);
  auto &body = m->body();
  body << formatv(valueRangeReturnCode, rangeBeginCall,
                  "getODSOperandIndexAndLength(index)");

  // Then we emit nicer named getter methods by redirecting to the "sink" getter
  // method.
  for (int i = 0; i != numOperands; ++i) {
    const auto &operand = op.getOperand(i);
    if (operand.name.empty())
      continue;
    for (StringRef name : op.getGetterNames(operand.name)) {
      if (operand.isOptional()) {
        // Optional: a null Value when the operand group is empty.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  auto operands = getODSOperands(" << i << ");\n"
                  << "  return operands.empty() ? ::mlir::Value() : "
                     "*operands.begin();";
      } else if (operand.isVariadicOfVariadic()) {
        // Variadic-of-variadic: split the group using the segment-size
        // attribute named by the constraint.
        std::string segmentAttr = op.getGetterName(
            operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
        if (isAdaptor) {
          m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>",
                                name);
          ERROR_IF_PRUNED(m, name, op);
          m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
                                     segmentAttr, i);
          continue;
        }

        m = opClass.addMethod("::mlir::OperandRangeRange", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return getODSOperands(" << i << ").split("
                  << segmentAttr << "Attr());";
      } else if (operand.isVariadic()) {
        m = opClass.addMethod(rangeType, name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return getODSOperands(" << i << ");";
      } else {
        // Single operand: the first (and only) value of its group.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return *getODSOperands(" << i << ").begin();";
      }
    }
  }
}
1253 
1254 void OpEmitter::genNamedOperandGetters() {
1255   // Build the code snippet used for initializing the operand_segment_size)s
1256   // array.
1257   std::string attrSizeInitCode;
1258   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1259     attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
1260                                emitHelper.getAttr(operandSegmentAttrName));
1261   }
1262 
1263   generateNamedOperandGetters(
1264       op, opClass,
1265       /*isAdaptor=*/false,
1266       /*sizeAttrInit=*/attrSizeInitCode,
1267       /*rangeType=*/"::mlir::Operation::operand_range",
1268       /*rangeBeginCall=*/"getOperation()->operand_begin()",
1269       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
1270       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
1271 }
1272 
1273 void OpEmitter::genNamedOperandSetters() {
1274   auto *attrSizedOperands =
1275       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1276   for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
1277     const auto &operand = op.getOperand(i);
1278     if (operand.name.empty())
1279       continue;
1280     for (StringRef name : op.getGetterNames(operand.name)) {
1281       auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
1282                                       ? "::mlir::MutableOperandRangeRange"
1283                                       : "::mlir::MutableOperandRange",
1284                                   (name + "Mutable").str());
1285       ERROR_IF_PRUNED(m, name, op);
1286       auto &body = m->body();
1287       body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
1288            << "  auto mutableRange = "
1289               "::mlir::MutableOperandRange(getOperation(), "
1290               "range.first, range.second";
1291       if (attrSizedOperands) {
1292         body << formatv(
1293             ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
1294             emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
1295       }
1296       body << ");\n";
1297 
1298       // If this operand is a nested variadic, we split the range into a
1299       // MutableOperandRangeRange that provides a range over all of the
1300       // sub-ranges.
1301       if (operand.isVariadicOfVariadic()) {
1302         body << "  return "
1303                 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
1304              << op.getGetterName(
1305                     operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
1306              << "AttrName()));\n";
1307       } else {
1308         // Otherwise, we use the full range directly.
1309         body << "  return mutableRange;\n";
1310       }
1311     }
1312   }
1313 }
1314 
/// Generates the result accessors:
///  - `getODSResultIndexAndLength(index)` and `getODSResults(index)`, which
///    locate a result group within the op's results;
///  - one getter per name of each named result: a Value for single results, a
///    possibly-null Value for optional results, and a result_range for
///    variadic results.
/// Aborts with a fatal error if the variadic-size traits are inconsistent with
/// the number of variadic results.
void OpEmitter::genNamedResultGetters() {
  const int numResults = op.getNumResults();
  const int numVariadicResults = op.getNumVariableLengthResults();
  const int numNormalResults = numResults - numVariadicResults;

  // If we have more than one variadic results, we need more complicated logic
  // to calculate the value range for each result.

  const auto *sameVariadicSize =
      op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
  const auto *attrSizedResults =
      op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");

  if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
    PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
                                 "specification over their sizes");
  }

  if (numVariadicResults < 2 && attrSizedResults) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic results "
                                 "to use 'AttrSizedResultSegments' trait");
  }

  if (attrSizedResults && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedResultSegments' and "
                    "'SameVariadicResultSize' traits");
  }

  // Build the initializer string for the result segment size attribute.
  std::string attrSizeInitCode;
  if (attrSizedResults) {
    attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
                               emitHelper.getAttr(resultSegmentAttrName));
  }

  generateValueRangeStartAndEnd(
      opClass, "getODSResultIndexAndLength", numVariadicResults,
      numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
      attrSizeInitCode, op.getResults());

  // "Sink" getter that every named result getter below forwards to.
  auto *m =
      opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
                        MethodParameter("unsigned", "index"));
  ERROR_IF_PRUNED(m, "getODSResults", op);
  m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
                       "getODSResultIndexAndLength(index)");

  for (int i = 0; i != numResults; ++i) {
    const auto &result = op.getResult(i);
    if (result.name.empty())
      continue;
    for (StringRef name : op.getGetterNames(result.name)) {
      if (result.isOptional()) {
        // Optional: a null Value when the result group is empty.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body()
            << "  auto results = getODSResults(" << i << ");\n"
            << "  return results.empty() ? ::mlir::Value() : *results.begin();";
      } else if (result.isVariadic()) {
        m = opClass.addMethod("::mlir::Operation::result_range", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return getODSResults(" << i << ");";
      } else {
        // Single result: the first (and only) value of its group.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return *getODSResults(" << i << ").begin();";
      }
    }
  }
}
1386 
1387 void OpEmitter::genNamedRegionGetters() {
1388   unsigned numRegions = op.getNumRegions();
1389   for (unsigned i = 0; i < numRegions; ++i) {
1390     const auto &region = op.getRegion(i);
1391     if (region.name.empty())
1392       continue;
1393 
1394     for (StringRef name : op.getGetterNames(region.name)) {
1395       // Generate the accessors for a variadic region.
1396       if (region.isVariadic()) {
1397         auto *m =
1398             opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
1399         ERROR_IF_PRUNED(m, name, op);
1400         m->body() << formatv("  return (*this)->getRegions().drop_front({0});",
1401                              i);
1402         continue;
1403       }
1404 
1405       auto *m = opClass.addMethod("::mlir::Region &", name);
1406       ERROR_IF_PRUNED(m, name, op);
1407       m->body() << formatv("  return (*this)->getRegion({0});", i);
1408     }
1409   }
1410 }
1411 
1412 void OpEmitter::genNamedSuccessorGetters() {
1413   unsigned numSuccessors = op.getNumSuccessors();
1414   for (unsigned i = 0; i < numSuccessors; ++i) {
1415     const NamedSuccessor &successor = op.getSuccessor(i);
1416     if (successor.name.empty())
1417       continue;
1418 
1419     for (StringRef name : op.getGetterNames(successor.name)) {
1420       // Generate the accessors for a variadic successor list.
1421       if (successor.isVariadic()) {
1422         auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
1423         ERROR_IF_PRUNED(m, name, op);
1424         m->body() << formatv(
1425             "  return {std::next((*this)->successor_begin(), {0}), "
1426             "(*this)->successor_end()};",
1427             i);
1428         continue;
1429       }
1430 
1431       auto *m = opClass.addMethod("::mlir::Block *", name);
1432       ERROR_IF_PRUNED(m, name, op);
1433       m->body() << formatv("  return (*this)->getSuccessor({0});", i);
1434     }
1435   }
1436 }
1437 
1438 static bool canGenerateUnwrappedBuilder(const Operator &op) {
1439   // If this op does not have native attributes at all, return directly to avoid
1440   // redefining builders.
1441   if (op.getNumNativeAttributes() == 0)
1442     return false;
1443 
1444   bool canGenerate = false;
1445   // We are generating builders that take raw values for attributes. We need to
1446   // make sure the native attributes have a meaningful "unwrapped" value type
1447   // different from the wrapped mlir::Attribute type to avoid redefining
1448   // builders. This checks for the op has at least one such native attribute.
1449   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1450     const NamedAttribute &namedAttr = op.getAttribute(i);
1451     if (canUseUnwrappedRawValue(namedAttr.attr)) {
1452       canGenerate = true;
1453       break;
1454     }
1455   }
1456   return canGenerate;
1457 }
1458 
1459 static bool canInferType(const Operator &op) {
1460   return op.getTrait("::mlir::InferTypeOpInterface::Trait");
1461 }
1462 
/// Emits the default builders that take one parameter per operand and one per
/// attribute.
///
/// For each attribute parameter flavor, up to three builders are emitted. The
/// flavors are:
///   - always, the wrapped `Attribute` storage types;
///   - additionally, raw unwrapped values when canGenerateUnwrappedBuilder()
///     allows it.
///
/// The three builders per flavor are:
///   - result types as separate parameters (TypeParamKind::Separate);
///   - no result-type parameters, inferring them in the generated body through
///     `<OpClass>::inferReturnTypes`. This one is only emitted for ops with the
///     InferTypeOpInterface trait and no regions;
///   - result types as a single `resultTypes` range
///     (TypeParamKind::Collective).
///
/// Builders that duplicate an existing signature are pruned by `addStaticMethod`
/// returning null and are skipped.
void OpEmitter::genSeparateArgParamBuilder() {
  SmallVector<AttrParamKind, 2> attrBuilderType;
  attrBuilderType.push_back(AttrParamKind::WrappedAttr);
  if (canGenerateUnwrappedBuilder(op))
    attrBuilderType.push_back(AttrParamKind::UnwrappedValue);

  // Emit with separate builders with or without unwrapped attributes and/or
  // inferring result type.
  // The parameters are:
  //   - `attrType` selects wrapped vs. raw attribute parameters;
  //   - `paramKind` selects how result types are passed;
  //   - `inferType` requests a generated body that infers the result types
  //     instead of reading them from parameters.
  auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
                  bool inferType) {
    SmallVector<MethodParameter> paramList;
    SmallVector<std::string, 4> resultNames;
    llvm::StringSet<> inferredAttributes;
    buildParamList(paramList, inferredAttributes, resultNames, paramKind,
                   attrType);

    auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
    // If the builder is redundant, skip generating the method.
    if (!m)
      return;
    auto &body = m->body();
    // Operands, attributes, regions and successors are added the same way for
    // every variant; only the result-type handling below differs.
    genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                           /*isRawValueAttr=*/attrType ==
                                               AttrParamKind::UnwrappedValue);

    // Push all result types to the operation state

    if (inferType) {
      // Generate builder that infers type too.
      // The generated body calls `<OpClass>::inferReturnTypes` on the operands
      // and attributes already stored in the state, with an empty region list.
      // If inference fails, it aborts via report_fatal_error.
      // TODO: Subsume this with general checking if type can be
      // inferred automatically.
      // TODO: Expand to handle regions.
      body << formatv(R"(
        ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
        if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
                      {1}.location, {1}.operands,
                      {1}.attributes.getDictionary({1}.getContext()),
                      /*regions=*/{{}, inferredReturnTypes)))
          {1}.addTypes(inferredReturnTypes);
        else
          ::llvm::report_fatal_error("Failed to infer result type(s).");)",
                      opClass.getClassName(), builderOpState);
      return;
    }

    switch (paramKind) {
    case TypeParamKind::None:
      return;
    case TypeParamKind::Separate:
      // One addTypes call per result. Optional results are only added when the
      // caller passed a non-null type.
      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
        if (op.getResult(i).isOptional())
          body << "  if (" << resultNames[i] << ")\n  ";
        body << "  " << builderOpState << ".addTypes(" << resultNames[i]
             << ");\n";
      }
      return;
    case TypeParamKind::Collective: {
      int numResults = op.getNumResults();
      int numVariadicResults = op.getNumVariableLengthResults();
      int numNonVariadicResults = numResults - numVariadicResults;
      bool hasVariadicResult = numVariadicResults != 0;

      // Avoid emitting "resultTypes.size() >= 0u" which is always true.
      if (!(hasVariadicResult && numNonVariadicResults == 0))
        body << "  "
             << "assert(resultTypes.size() "
             << (hasVariadicResult ? ">=" : "==") << " "
             << numNonVariadicResults
             << "u && \"mismatched number of results\");\n";
      body << "  " << builderOpState << ".addTypes(resultTypes);\n";
    }
      return;
    }
    llvm_unreachable("unhandled TypeParamKind");
  };

  // Some of the build methods generated here may be ambiguous, but TableGen's
  // ambiguous function detection will elide those ones.
  for (auto attrType : attrBuilderType) {
    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
    if (canInferType(op) && op.getNumRegions() == 0)
      emit(attrType, TypeParamKind::None, /*inferType=*/true);
    emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
  }
}
1548 
1549 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1550   int numResults = op.getNumResults();
1551 
1552   // Signature
1553   SmallVector<MethodParameter> paramList;
1554   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1555   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1556   paramList.emplace_back("::mlir::ValueRange", "operands");
1557   // Provide default value for `attributes` when its the last parameter
1558   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1559   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1560                          "attributes", attributesDefaultValue);
1561   if (op.getNumVariadicRegions())
1562     paramList.emplace_back("unsigned", "numRegions");
1563 
1564   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1565   // If the builder is redundant, skip generating the method
1566   if (!m)
1567     return;
1568   auto &body = m->body();
1569 
1570   // Operands
1571   body << "  " << builderOpState << ".addOperands(operands);\n";
1572 
1573   // Attributes
1574   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1575 
1576   // Create the correct number of regions
1577   if (int numRegions = op.getNumRegions()) {
1578     body << llvm::formatv(
1579         "  for (unsigned i = 0; i != {0}; ++i)\n",
1580         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1581     body << "    (void)" << builderOpState << ".addRegion();\n";
1582   }
1583 
1584   // Result types
1585   SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1586   body << "  " << builderOpState << ".addTypes({"
1587        << llvm::join(resultTypes, ", ") << "});\n\n";
1588 }
1589 
1590 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1591   SmallVector<MethodParameter> paramList;
1592   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1593   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1594   paramList.emplace_back("::mlir::ValueRange", "operands");
1595   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1596   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1597                          "attributes", attributesDefaultValue);
1598   if (op.getNumVariadicRegions())
1599     paramList.emplace_back("unsigned", "numRegions");
1600 
1601   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1602   // If the builder is redundant, skip generating the method
1603   if (!m)
1604     return;
1605   auto &body = m->body();
1606 
1607   int numResults = op.getNumResults();
1608   int numVariadicResults = op.getNumVariableLengthResults();
1609   int numNonVariadicResults = numResults - numVariadicResults;
1610 
1611   int numOperands = op.getNumOperands();
1612   int numVariadicOperands = op.getNumVariableLengthOperands();
1613   int numNonVariadicOperands = numOperands - numVariadicOperands;
1614 
1615   // Operands
1616   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1617     body << "  assert(operands.size()"
1618          << (numVariadicOperands != 0 ? " >= " : " == ")
1619          << numNonVariadicOperands
1620          << "u && \"mismatched number of parameters\");\n";
1621   body << "  " << builderOpState << ".addOperands(operands);\n";
1622   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1623 
1624   // Create the correct number of regions
1625   if (int numRegions = op.getNumRegions()) {
1626     body << llvm::formatv(
1627         "  for (unsigned i = 0; i != {0}; ++i)\n",
1628         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1629     body << "    (void)" << builderOpState << ".addRegion();\n";
1630   }
1631 
1632   // Result types
1633   body << formatv(R"(
1634   ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1635   if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1636           {1}.location, operands,
1637           {1}.attributes.getDictionary({1}.getContext()),
1638           {1}.regions, inferredReturnTypes))) {{)",
1639                   opClass.getClassName(), builderOpState);
1640   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1641     body << "\n    assert(inferredReturnTypes.size()"
1642          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1643          << "u && \"mismatched number of return types\");";
1644   body << "\n    " << builderOpState << ".addTypes(inferredReturnTypes);";
1645 
1646   body << formatv(R"(
1647   } else {{
1648     ::llvm::report_fatal_error("Failed to infer result type(s).");
1649   })",
1650                   opClass.getClassName(), builderOpState);
1651 }
1652 
1653 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1654   auto emit = [&](AttrParamKind attrType) {
1655     SmallVector<MethodParameter> paramList;
1656     SmallVector<std::string, 4> resultNames;
1657     llvm::StringSet<> inferredAttributes;
1658     buildParamList(paramList, inferredAttributes, resultNames,
1659                    TypeParamKind::None, attrType);
1660 
1661     auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1662     // If the builder is redundant, skip generating the method
1663     if (!m)
1664       return;
1665     auto &body = m->body();
1666     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
1667                                            /*isRawValueAttr=*/attrType ==
1668                                                AttrParamKind::UnwrappedValue);
1669 
1670     auto numResults = op.getNumResults();
1671     if (numResults == 0)
1672       return;
1673 
1674     // Push all result types to the operation state
1675     const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1676     std::string resultType =
1677         formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1678     body << "  " << builderOpState << ".addTypes({" << resultType;
1679     for (int i = 1; i != numResults; ++i)
1680       body << ", " << resultType;
1681     body << "});\n\n";
1682   };
1683 
1684   emit(AttrParamKind::WrappedAttr);
1685   // Generate additional builder(s) if any attributes can be "unwrapped"
1686   if (canGenerateUnwrappedBuilder(op))
1687     emit(AttrParamKind::UnwrappedValue);
1688 }
1689 
1690 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1691   SmallVector<MethodParameter> paramList;
1692   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1693   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1694   paramList.emplace_back("::mlir::ValueRange", "operands");
1695   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1696                          "attributes", "{}");
1697   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1698   // If the builder is redundant, skip generating the method
1699   if (!m)
1700     return;
1701 
1702   auto &body = m->body();
1703 
1704   // Push all result types to the operation state
1705   std::string resultType;
1706   const auto &namedAttr = op.getAttribute(0);
1707 
1708   body << "  auto attrName = " << op.getGetterName(namedAttr.name)
1709        << "AttrName(" << builderOpState
1710        << ".name);\n"
1711           "  for (auto attr : attributes) {\n"
1712           "    if (attr.getName() != attrName) continue;\n";
1713   if (namedAttr.attr.isTypeAttr()) {
1714     resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()";
1715   } else {
1716     resultType = "attr.getValue().getType()";
1717   }
1718 
1719   // Operands
1720   body << "  " << builderOpState << ".addOperands(operands);\n";
1721 
1722   // Attributes
1723   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1724 
1725   // Result types
1726   SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1727   body << "    " << builderOpState << ".addTypes({"
1728        << llvm::join(resultTypes, ", ") << "});\n";
1729   body << "  }\n";
1730 }
1731 
1732 /// Returns a signature of the builder. Updates the context `fctx` to enable
1733 /// replacement of $_builder and $_state in the body.
1734 static SmallVector<MethodParameter>
1735 getBuilderSignature(const Builder &builder) {
1736   ArrayRef<Builder::Parameter> params(builder.getParameters());
1737 
1738   // Inject builder and state arguments.
1739   SmallVector<MethodParameter> arguments;
1740   arguments.reserve(params.size() + 2);
1741   arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
1742   arguments.emplace_back("::mlir::OperationState &", builderOpState);
1743 
1744   for (unsigned i = 0, e = params.size(); i < e; ++i) {
1745     // If no name is provided, generate one.
1746     Optional<StringRef> paramName = params[i].getName();
1747     std::string name =
1748         paramName ? paramName->str() : "odsArg" + std::to_string(i);
1749 
1750     StringRef defaultValue;
1751     if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
1752       defaultValue = *defaultParamValue;
1753 
1754     arguments.emplace_back(params[i].getCppType(), std::move(name),
1755                            defaultValue);
1756   }
1757 
1758   return arguments;
1759 }
1760 
1761 void OpEmitter::genBuilder() {
1762   // Handle custom builders if provided.
1763   for (const Builder &builder : op.getBuilders()) {
1764     SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
1765 
1766     Optional<StringRef> body = builder.getBody();
1767     auto properties = body ? Method::Static : Method::StaticDeclaration;
1768     auto *method =
1769         opClass.addMethod("void", "build", properties, std::move(arguments));
1770     if (body)
1771       ERROR_IF_PRUNED(method, "build", op);
1772 
1773     FmtContext fctx;
1774     fctx.withBuilder(odsBuilder);
1775     fctx.addSubst("_state", builderOpState);
1776     if (body)
1777       method->body() << tgfmt(*body, &fctx);
1778   }
1779 
1780   // Generate default builders that requires all result type, operands, and
1781   // attributes as parameters.
1782   if (op.skipDefaultBuilders())
1783     return;
1784 
1785   // We generate three classes of builders here:
1786   // 1. one having a stand-alone parameter for each operand / attribute, and
1787   genSeparateArgParamBuilder();
1788   // 2. one having an aggregated parameter for all result types / operands /
1789   //    attributes, and
1790   genCollectiveParamBuilder();
1791   // 3. one having a stand-alone parameter for each operand and attribute,
1792   //    use the first operand or attribute's type as all result types
1793   //    to facilitate different call patterns.
1794   if (op.getNumVariableLengthResults() == 0) {
1795     if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1796       genUseOperandAsResultTypeSeparateParamBuilder();
1797       genUseOperandAsResultTypeCollectiveParamBuilder();
1798     }
1799     if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1800       genUseAttrAsResultTypeBuilder();
1801   }
1802 }
1803 
1804 void OpEmitter::genCollectiveParamBuilder() {
1805   int numResults = op.getNumResults();
1806   int numVariadicResults = op.getNumVariableLengthResults();
1807   int numNonVariadicResults = numResults - numVariadicResults;
1808 
1809   int numOperands = op.getNumOperands();
1810   int numVariadicOperands = op.getNumVariableLengthOperands();
1811   int numNonVariadicOperands = numOperands - numVariadicOperands;
1812 
1813   SmallVector<MethodParameter> paramList;
1814   paramList.emplace_back("::mlir::OpBuilder &", "");
1815   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1816   paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1817   paramList.emplace_back("::mlir::ValueRange", "operands");
1818   // Provide default value for `attributes` when its the last parameter
1819   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1820   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1821                          "attributes", attributesDefaultValue);
1822   if (op.getNumVariadicRegions())
1823     paramList.emplace_back("unsigned", "numRegions");
1824 
1825   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1826   // If the builder is redundant, skip generating the method
1827   if (!m)
1828     return;
1829   auto &body = m->body();
1830 
1831   // Operands
1832   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1833     body << "  assert(operands.size()"
1834          << (numVariadicOperands != 0 ? " >= " : " == ")
1835          << numNonVariadicOperands
1836          << "u && \"mismatched number of parameters\");\n";
1837   body << "  " << builderOpState << ".addOperands(operands);\n";
1838 
1839   // Attributes
1840   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1841 
1842   // Create the correct number of regions
1843   if (int numRegions = op.getNumRegions()) {
1844     body << llvm::formatv(
1845         "  for (unsigned i = 0; i != {0}; ++i)\n",
1846         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1847     body << "    (void)" << builderOpState << ".addRegion();\n";
1848   }
1849 
1850   // Result types
1851   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1852     body << "  assert(resultTypes.size()"
1853          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1854          << "u && \"mismatched number of return types\");\n";
1855   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1856 
1857   // Generate builder that infers type too.
1858   // TODO: Expand to handle successors.
1859   if (canInferType(op) && op.getNumSuccessors() == 0)
1860     genInferredTypeCollectiveParamBuilder();
1861 }
1862 
/// Builds the parameter list of a generated default builder, in this order:
///   1. the implicit `odsBuilder` and `odsState` parameters;
///   2. result-type parameters, as selected by `typeParamKind`:
///      - None: no result-type parameters;
///      - Separate: one `Type` or `TypeRange` parameter per result;
///      - Collective: a single `resultTypes` range;
///   3. one parameter per operand and attribute, in argument order;
///   4. one parameter per successor;
///   5. one `unsigned <region>Count` parameter per variadic region.
///
/// Outputs:
///   - `resultTypeNames` is cleared, then receives the names of the result-type
///     parameters.
///   - `inferredAttributes` receives the segment-size attribute names of
///     variadic-of-variadic operands. Those attributes are computed by the
///     builder body rather than passed in, so they are left out of the
///     parameter list.
///
/// `attrParamKind` selects the parameter type of each attribute: its wrapped
/// storage type, or, with UnwrappedValue, its raw return type where
/// canUseUnwrappedRawValue() allows it.
void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                               llvm::StringSet<> &inferredAttributes,
                               SmallVectorImpl<std::string> &resultTypeNames,
                               TypeParamKind typeParamKind,
                               AttrParamKind attrParamKind) {
  resultTypeNames.clear();
  auto numResults = op.getNumResults();
  resultTypeNames.reserve(numResults);

  paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
  paramList.emplace_back("::mlir::OperationState &", builderOpState);

  switch (typeParamKind) {
  case TypeParamKind::None:
    break;
  case TypeParamKind::Separate: {
    // Add parameters for all return types
    // Unnamed results are called `resultType<i>`.
    for (int i = 0; i < numResults; ++i) {
      const auto &result = op.getResult(i);
      std::string resultName = std::string(result.name);
      if (resultName.empty())
        resultName = std::string(formatv("resultType{0}", i));

      StringRef type =
          result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";

      paramList.emplace_back(type, resultName, result.isOptional());
      resultTypeNames.emplace_back(std::move(resultName));
    }
  } break;
  case TypeParamKind::Collective: {
    paramList.emplace_back("::mlir::TypeRange", "resultTypes");
    resultTypeNames.push_back("resultTypes");
  } break;
  }

  // Add parameters for all arguments (operands and attributes).
  int defaultValuedAttrStartIndex = op.getNumArgs();
  // Successors and variadic regions go at the end of the parameter list, so no
  // default arguments are possible.
  bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
  if (attrParamKind == AttrParamKind::UnwrappedValue && !hasTrailingParams) {
    // Calculate the start index from which we can attach default values in the
    // builder declaration.
    // C++ only allows defaults on a trailing run of parameters. The scan walks
    // backwards from the last argument and stops at the first argument that
    // cannot take a default value.
    for (int i = op.getNumArgs() - 1; i >= 0; --i) {
      auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
      if (!namedAttr || !namedAttr->attr.hasDefaultValue())
        break;

      if (!canUseUnwrappedRawValue(namedAttr->attr))
        break;

      // Creating an APInt requires us to provide bitwidth, value, and
      // signedness, which is complicated compared to others. Similarly
      // for APFloat.
      // TODO: Adjust the 'returnType' field of such attributes
      // to support them.
      StringRef retType = namedAttr->attr.getReturnType();
      if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
        break;

      defaultValuedAttrStartIndex = i;
    }
  }

  /// Collect any inferred attributes.
  for (const NamedTypeConstraint &operand : op.getOperands()) {
    if (operand.isVariadicOfVariadic()) {
      inferredAttributes.insert(
          operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
    }
  }

  // `numOperands` counts operands only. It indexes getArgumentName(), which is
  // keyed by operand position rather than argument position.
  for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
    Argument arg = op.getArg(i);
    if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
      StringRef type;
      if (operand->isVariadicOfVariadic())
        type = "::llvm::ArrayRef<::mlir::ValueRange>";
      else if (operand->isVariadic())
        type = "::mlir::ValueRange";
      else
        type = "::mlir::Value";

      paramList.emplace_back(type, getArgumentName(op, numOperands++),
                             operand->isOptional());
      continue;
    }
    const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
    const Attribute &attr = namedAttr.attr;

    // Inferred attributes don't need to be added to the param list.
    if (inferredAttributes.contains(namedAttr.name))
      continue;

    StringRef type;
    switch (attrParamKind) {
    case AttrParamKind::WrappedAttr:
      type = attr.getStorageType();
      break;
    case AttrParamKind::UnwrappedValue:
      if (canUseUnwrappedRawValue(attr))
        type = attr.getReturnType();
      else
        type = attr.getStorageType();
      break;
    }

    // Attach default value if requested and possible.
    std::string defaultValue;
    if (attrParamKind == AttrParamKind::UnwrappedValue &&
        i >= defaultValuedAttrStartIndex) {
      defaultValue += attr.getDefaultValue();
    }
    paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue),
                           attr.isOptional());
  }

  /// Insert parameters for each successor.
  for (const NamedSuccessor &succ : op.getSuccessors()) {
    StringRef type =
        succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
    paramList.emplace_back(type, succ.name);
  }

  /// Insert parameters for variadic regions.
  for (const NamedRegion &region : op.getRegions())
    if (region.isVariadic())
      paramList.emplace_back("unsigned",
                             llvm::formatv("{0}Count", region.name).str());
}
1994 
/// Emits the part of a default builder body that populates `odsState`, in this
/// order:
///   - operands;
///   - the operand segment-size attribute, for ops with
///     AttrSizedOperandSegments;
///   - attributes;
///   - regions;
///   - successors.
/// Result types are left to the caller.
///
/// Parameters:
///   - `inferredAttributes` lists attributes computed here (the segment sizes
///     of variadic-of-variadic operands) that are not builder parameters.
///   - When `isRawValueAttr` is true, attribute parameters that
///     canUseUnwrappedRawValue() accepts hold raw values. These are wrapped into
///     Attributes using the attribute's constant builder template.
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
    MethodBody &body, llvm::StringSet<> &inferredAttributes,
    bool isRawValueAttr) {
  // Push all operands to the result.
  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
    std::string argName = getArgumentName(op, i);
    const NamedTypeConstraint &operand = op.getOperand(i);
    if (operand.constraint.isVariadicOfVariadic()) {
      // Each inner range is appended as operands. The size of every range is
      // recorded in the op's segment-size attribute, as an I32 tensor.
      body << "  for (::mlir::ValueRange range : " << argName << ")\n   "
           << builderOpState << ".addOperands(range);\n";

      // Add the segment attribute.
      // NOTE(review): no newline is emitted between the addAttribute statement
      // and the closing brace, so the generated code reads `...));  }`.
      // This is cosmetic only.
      body << "  {\n"
           << "    ::llvm::SmallVector<int32_t> rangeSegments;\n"
           << "    for (::mlir::ValueRange range : " << argName << ")\n"
           << "      rangeSegments.push_back(range.size());\n"
           << "    " << builderOpState << ".addAttribute("
           << op.getGetterName(
                  operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
           << "AttrName(" << builderOpState << ".name), " << odsBuilder
           << ".getI32TensorAttr(rangeSegments));"
           << "  }\n";
      continue;
    }

    // Optional operands are only added when the caller passed a non-null value.
    if (operand.isOptional())
      body << "  if (" << argName << ")\n  ";
    body << "  " << builderOpState << ".addOperands(" << argName << ");\n";
  }

  // If the operation has the operand segment size attribute, add it here.
  // The attribute has one entry per ODS operand:
  //   - 1 for a fixed operand;
  //   - 0/1 for an optional operand;
  //   - the range size for a variadic operand;
  //   - the total number of values for a variadic-of-variadic operand.
  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
    std::string sizes = op.getGetterName(operandSegmentAttrName);
    body << "  " << builderOpState << ".addAttribute(" << sizes << "AttrName("
         << builderOpState << ".name), "
         << "odsBuilder.getI32VectorAttr({";
    interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
      const NamedTypeConstraint &operand = op.getOperand(i);
      if (!operand.isVariableLength()) {
        body << "1";
        return;
      }

      std::string operandName = getArgumentName(op, i);
      if (operand.isOptional()) {
        body << "(" << operandName << " ? 1 : 0)";
      } else if (operand.isVariadicOfVariadic()) {
        // NOTE(review): the generated code calls std::accumulate, so the
        // generated file presumably needs <numeric> available. Confirm.
        body << llvm::formatv(
            "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
            "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
            "range.size(); }))",
            operandName);
      } else {
        body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
      }
    });
    body << "}));\n";
  }

  // Push all attributes to the result.
  for (const auto &namedAttr : op.getAttributes()) {
    auto &attr = namedAttr.attr;
    // Derived attributes are not stored; inferred ones were added above.
    if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
      continue;

    // Optional attributes, and default-valued attributes passed in wrapped
    // form, may be null and are only added when set. In raw-value mode, a
    // default-valued attribute is always added.
    bool emitNotNullCheck =
        attr.isOptional() || (attr.hasDefaultValue() && !isRawValueAttr);
    if (emitNotNullCheck)
      body << formatv("  if ({0}) ", namedAttr.name) << "{\n";

    if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
      // If this is a raw value, then we need to wrap it in an Attribute
      // instance.
      FmtContext fctx;
      fctx.withBuilder("odsBuilder");

      std::string builderTemplate = std::string(attr.getConstBuilderTemplate());

      // For StringAttr, its constant builder call will wrap the input in
      // quotes, which is correct for normal string literals, but incorrect
      // here given we use function arguments. So we need to strip the
      // wrapping quotes.
      if (StringRef(builderTemplate).contains("\"$0\""))
        builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");

      std::string value =
          std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
      body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
                      builderOpState, op.getGetterName(namedAttr.name), value);
    } else {
      body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
                      builderOpState, op.getGetterName(namedAttr.name),
                      namedAttr.name);
    }
    if (emitNotNullCheck)
      body << "  }\n";
  }

  // Create the correct number of regions.
  // A variadic region is created `<name>Count` times, a fixed region once.
  for (const NamedRegion &region : op.getRegions()) {
    if (region.isVariadic())
      body << formatv("  for (unsigned i = 0; i < {0}Count; ++i)\n  ",
                      region.name);

    body << "  (void)" << builderOpState << ".addRegion();\n";
  }

  // Push all successors to the result.
  for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
    body << formatv("  {0}.addSuccessors({1});\n", builderOpState,
                    namedSuccessor.name);
  }
}
2108 
2109 void OpEmitter::genCanonicalizerDecls() {
2110   bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
2111   if (hasCanonicalizeMethod) {
2112     // static LogicResult FooOp::
2113     // canonicalize(FooOp op, PatternRewriter &rewriter);
2114     SmallVector<MethodParameter> paramList;
2115     paramList.emplace_back(op.getCppClassName(), "op");
2116     paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
2117     auto *m = opClass.declareStaticMethod("::mlir::LogicalResult",
2118                                           "canonicalize", std::move(paramList));
2119     ERROR_IF_PRUNED(m, "canonicalize", op);
2120   }
2121 
2122   // We get a prototype for 'getCanonicalizationPatterns' if requested directly
2123   // or if using a 'canonicalize' method.
2124   bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
2125   if (!hasCanonicalizeMethod && !hasCanonicalizer)
2126     return;
2127 
2128   // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
2129   // method, but not implementing 'getCanonicalizationPatterns' manually.
2130   bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
2131 
2132   // Add a signature for getCanonicalizationPatterns if implemented by the
2133   // dialect or if synthesized to call 'canonicalize'.
2134   SmallVector<MethodParameter> paramList;
2135   paramList.emplace_back("::mlir::RewritePatternSet &", "results");
2136   paramList.emplace_back("::mlir::MLIRContext *", "context");
2137   auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
2138   auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
2139                                    std::move(paramList));
2140 
2141   // If synthesizing the method, fill it it.
2142   if (hasBody) {
2143     ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
2144     method->body() << "  results.add(canonicalize);\n";
2145   }
2146 }
2147 
2148 void OpEmitter::genFolderDecls() {
2149   bool hasSingleResult =
2150       op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
2151 
2152   if (def.getValueAsBit("hasFolder")) {
2153     if (hasSingleResult) {
2154       auto *m = opClass.declareMethod(
2155           "::mlir::OpFoldResult", "fold",
2156           MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
2157       ERROR_IF_PRUNED(m, "operands", op);
2158     } else {
2159       SmallVector<MethodParameter> paramList;
2160       paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
2161       paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
2162                              "results");
2163       auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
2164                                       std::move(paramList));
2165       ERROR_IF_PRUNED(m, "fold", op);
2166     }
2167   }
2168 }
2169 
2170 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
2171   Interface interface = opTrait->getInterface();
2172 
2173   // Get the set of methods that should always be declared.
2174   auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
2175   llvm::StringSet<> alwaysDeclaredMethods;
2176   alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
2177                                alwaysDeclaredMethodsVec.end());
2178 
2179   for (const InterfaceMethod &method : interface.getMethods()) {
2180     // Don't declare if the method has a body.
2181     if (method.getBody())
2182       continue;
2183     // Don't declare if the method has a default implementation and the op
2184     // didn't request that it always be declared.
2185     if (method.getDefaultImplementation() &&
2186         !alwaysDeclaredMethods.count(method.getName()))
2187       continue;
2188     // Interface methods are allowed to overlap with existing methods, so don't
2189     // check if pruned.
2190     (void)genOpInterfaceMethod(method);
2191   }
2192 }
2193 
2194 Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
2195                                         bool declaration) {
2196   SmallVector<MethodParameter> paramList;
2197   for (const InterfaceMethod::Argument &arg : method.getArguments())
2198     paramList.emplace_back(arg.type, arg.name);
2199 
2200   auto props = (method.isStatic() ? Method::Static : Method::None) |
2201                (declaration ? Method::Declaration : Method::None);
2202   return opClass.addMethod(method.getReturnType(), method.getName(), props,
2203                            std::move(paramList));
2204 }
2205 
2206 void OpEmitter::genOpInterfaceMethods() {
2207   for (const auto &trait : op.getTraits()) {
2208     if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
2209       if (opTrait->shouldDeclareMethods())
2210         genOpInterfaceMethods(opTrait);
2211   }
2212 }
2213 
/// Emits a `getEffects` method for every side-effect interface the op uses.
///
/// Effects are collected from:
///   - `SideEffectTrait`s on the op (static effects with no attached value),
///   - decorators on operands and on symbol-reference attributes,
///   - decorators on results.
/// Effects are grouped by base effect name; for each group one
/// `getEffects(SmallVectorImpl<SideEffects::EffectInstance<Base>> &effects)`
/// method is added whose body appends one `EffectInstance` per collected
/// location. Effects found via decorators also add their interface trait to
/// the op class.
void OpEmitter::genSideEffectInterfaceMethods() {
  // What an effect is attached to.
  enum EffectKind { Operand, Result, Symbol, Static };
  struct EffectLocation {
    /// The effect applied.
    SideEffect effect;

    /// The index if the kind is not static: the ODS operand index for
    /// operands, the argument index for symbols, and the result index for
    /// results (see the collection loops below).
    unsigned index;

    /// The kind of the location.
    unsigned kind;
  };

  // Base effect name -> all locations carrying an effect of that base.
  StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
  // Records every `SideEffect` decorator in `decorators` at the given
  // location and attaches the effect's interface trait to the op class.
  auto resolveDecorators = [&](Operator::var_decorator_range decorators,
                               unsigned index, unsigned kind) {
    for (auto decorator : decorators)
      if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
        opClass.addTrait(effect->getInterfaceTrait());
        interfaceEffects[effect->getBaseEffectName()].push_back(
            EffectLocation{*effect, index, kind});
      }
  };

  // Collect effects that were specified via:
  //   - Traits.
  for (const auto &trait : op.getTraits()) {
    const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
    if (!opTrait)
      continue;
    auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
    for (auto decorator : opTrait->getEffects())
      effects.push_back(EffectLocation{cast<SideEffect>(decorator),
                                       /*index=*/0, EffectKind::Static});
  }
  //   - Attributes and Operands. `operandIt` counts operands only, while `i`
  //     is the position among all arguments. Only symbol-reference attributes
  //     are considered; decorators on other attributes are ignored here.
  for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
    Argument arg = op.getArg(i);
    if (arg.is<NamedTypeConstraint *>()) {
      resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
      ++operandIt;
      continue;
    }
    const NamedAttribute *attr = arg.get<NamedAttribute *>();
    if (attr->attr.getBaseAttr().isSymbolRefAttr())
      resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
  }
  //   - Results.
  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
    resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);

  // The code used to add an effect instance.
  // {0}: The effect class.
  // {1}: Optional value or symbol reference (including the trailing ", ").
  // {2}: The resource class.
  const char *addEffectCode =
      "  effects.emplace_back({0}::get(), {1}{2}::get());\n";

  for (auto &it : interfaceEffects) {
    // Generate the 'getEffects' method.
    std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
                                     "SideEffects::EffectInstance<{0}>> &",
                                     it.first())
                           .str();
    auto *getEffects = opClass.addMethod("void", "getEffects",
                                         MethodParameter(type, "effects"));
    ERROR_IF_PRUNED(getEffects, "getEffects", op);
    auto &body = getEffects->body();

    // Add effect instances for each of the locations marked on the operation.
    for (auto &location : it.second) {
      StringRef effect = location.effect.getName();
      StringRef resource = location.effect.getResource();
      if (location.kind == EffectKind::Static) {
        // A static instance has no attached value.
        body << llvm::formatv(addEffectCode, effect, "", resource).str();
      } else if (location.kind == EffectKind::Symbol) {
        // A symbol reference requires adding the proper attribute. Optional
        // attributes are only reported when present.
        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
        std::string argName = op.getGetterName(attr->name);
        if (attr->attr.isOptional()) {
          body << "  if (auto symbolRef = " << argName << "Attr())\n  "
               << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
                      .str();
        } else {
          body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
                                resource)
                      .str();
        }
      } else {
        // Otherwise this is an operand/result, so we need to attach the Value.
        // One instance is emitted per value in the ODS group.
        body << "  for (::mlir::Value value : getODS"
             << (location.kind == EffectKind::Operand ? "Operands" : "Results")
             << "(" << location.index << "))\n  "
             << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
      }
    }
  }
}
2313 
2314 void OpEmitter::genTypeInterfaceMethods() {
2315   if (!op.allResultTypesKnown())
2316     return;
2317   // Generate 'inferReturnTypes' method declaration using the interface method
2318   // declared in 'InferTypeOpInterface' op interface.
2319   const auto *trait =
2320       cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait"));
2321   Interface interface = trait->getInterface();
2322   Method *method = [&]() -> Method * {
2323     for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
2324       if (interfaceMethod.getName() == "inferReturnTypes") {
2325         return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
2326       }
2327     }
2328     assert(0 && "unable to find inferReturnTypes interface method");
2329     return nullptr;
2330   }();
2331   ERROR_IF_PRUNED(method, "inferReturnTypes", op);
2332   auto &body = method->body();
2333   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
2334 
2335   FmtContext fctx;
2336   fctx.withBuilder("odsBuilder");
2337   body << "  ::mlir::Builder odsBuilder(context);\n";
2338 
2339   auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
2340     if (!type.isArg())
2341       return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
2342     auto argIndex = type.getArg();
2343     assert(!op.getArg(argIndex).is<NamedAttribute *>());
2344     auto arg = op.getArgToOperandOrAttribute(argIndex);
2345     if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
2346       return body << "operands[" << arg.operandOrAttributeIndex()
2347                   << "].getType()";
2348     return body << "attributes[" << arg.operandOrAttributeIndex()
2349                 << "].getType()";
2350   };
2351 
2352   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2353     body << "  inferredReturnTypes[" << i << "] = ";
2354     auto types = op.getSameTypeAsResult(i);
2355     emitType(types[0]) << ";\n";
2356     if (types.size() == 1)
2357       continue;
2358     // TODO: We could verify equality here, but skipping that for verification.
2359   }
2360   body << "  return ::mlir::success();";
2361 }
2362 
2363 void OpEmitter::genParser() {
2364   if (hasStringAttribute(def, "assemblyFormat"))
2365     return;
2366 
2367   if (!def.getValueAsBit("hasCustomAssemblyFormat"))
2368     return;
2369 
2370   SmallVector<MethodParameter> paramList;
2371   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
2372   paramList.emplace_back("::mlir::OperationState &", "result");
2373 
2374   auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse",
2375                                              std::move(paramList));
2376   ERROR_IF_PRUNED(method, "parse", op);
2377 }
2378 
2379 void OpEmitter::genPrinter() {
2380   if (hasStringAttribute(def, "assemblyFormat"))
2381     return;
2382 
2383   // Check to see if this op uses a c++ format.
2384   if (!def.getValueAsBit("hasCustomAssemblyFormat"))
2385     return;
2386   auto *method = opClass.declareMethod(
2387       "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
2388   ERROR_IF_PRUNED(method, "print", op);
2389 }
2390 
2391 void OpEmitter::genVerifier() {
2392   auto *implMethod =
2393       opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
2394   ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
2395   auto &implBody = implMethod->body();
2396 
2397   populateSubstitutions(emitHelper, verifyCtx);
2398   genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
2399   genOperandResultVerifier(implBody, op.getOperands(), "operand");
2400   genOperandResultVerifier(implBody, op.getResults(), "result");
2401 
2402   for (auto &trait : op.getTraits()) {
2403     if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
2404       implBody << tgfmt("  if (!($0))\n    "
2405                         "return emitOpError(\"failed to verify that $1\");\n",
2406                         &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
2407                         t->getSummary());
2408     }
2409   }
2410 
2411   genRegionVerifier(implBody);
2412   genSuccessorVerifier(implBody);
2413 
2414   implBody << "  return ::mlir::success();\n";
2415 
2416   // TODO: Some places use the `verifyInvariants` to do operation verification.
2417   // This may not act as their expectation because this doesn't call any
2418   // verifiers of native/interface traits. Needs to review those use cases and
2419   // see if we should use the mlir::verify() instead.
2420   auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
2421   ERROR_IF_PRUNED(method, "verifyInvariants", op);
2422   auto &body = method->body();
2423   if (def.getValueAsBit("hasVerifier")) {
2424     body << "  if(::mlir::succeeded(verifyInvariantsImpl()) && "
2425             "::mlir::succeeded(verify()))\n";
2426     body << "    return ::mlir::success();\n";
2427     body << "  return ::mlir::failure();";
2428   } else {
2429     body << "  return verifyInvariantsImpl();";
2430   }
2431 }
2432 
2433 void OpEmitter::genCustomVerifier() {
2434   if (def.getValueAsBit("hasVerifier")) {
2435     auto *method = opClass.declareMethod("::mlir::LogicalResult", "verify");
2436     ERROR_IF_PRUNED(method, "verify", op);
2437   }
2438 
2439   if (def.getValueAsBit("hasRegionVerifier")) {
2440     auto *method =
2441         opClass.declareMethod("::mlir::LogicalResult", "verifyRegions");
2442     ERROR_IF_PRUNED(method, "verifyRegions", op);
2443   }
2444 }
2445 
/// Emits verification of the op's operand (or result) groups into `body`.
///
/// For every ODS group that is optional, variadic-of-variadic, or has a type
/// predicate, this emits, inside a single `{ ... }` scope:
///   - for optional groups, a check that the group has at most one value;
///   - for variadic-of-variadic groups, a `verifyValueSizeAttr` check against
///     the group's segment-size attribute;
///   - for groups with a predicate, a loop calling the static type-constraint
///     function on each value.
/// Nothing is emitted if there are no values or none needs checking.
///
/// `valueKind` is "operand" or "result". It selects `getODSOperands` or
/// `getODSResults` and is spliced into the diagnostics.
///
/// NOTE(review): the generated `index` counter is only advanced inside the
/// predicate loop (`index++` in `verifyValues`). Skipped groups and groups
/// without a predicate do not advance it, so the index reported in
/// diagnostics may not match the value's position in that case. Confirm
/// whether this is intended.
void OpEmitter::genOperandResultVerifier(MethodBody &body,
                                         Operator::const_value_range values,
                                         StringRef valueKind) {
  // Check that an optional value is at most 1 element.
  //
  // {0}: Value index.
  // {1}: "operand" or "result"
  const char *const verifyOptional = R"(
    if (valueGroup{0}.size() > 1) {
      return emitOpError("{1} group starting at #") << index
          << " requires 0 or 1 element, but found " << valueGroup{0}.size();
    }
)";
  // Check the types of a range of values.
  //
  // {0}: Value index.
  // {1}: Type constraint function.
  // {2}: "operand" or "result"
  const char *const verifyValues = R"(
    for (auto v : valueGroup{0}) {
      if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
        return ::mlir::failure();
    }
)";

  // A group needs no generated code when it is a plain, unconstrained value.
  const auto canSkip = [](const NamedTypeConstraint &value) {
    return !value.hasPredicate() && !value.isOptional() &&
           !value.isVariadicOfVariadic();
  };
  if (values.empty() || llvm::all_of(values, canSkip))
    return;

  FmtContext fctx;

  body << "  {\n    unsigned index = 0; (void)index;\n";

  for (const auto &staticValue : llvm::enumerate(values)) {
    const NamedTypeConstraint &value = staticValue.value();

    bool hasPredicate = value.hasPredicate();
    bool isOptional = value.isOptional();
    bool isVariadicOfVariadic = value.isVariadicOfVariadic();
    if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
      continue;
    // Bind the ODS group for this value, e.g. `valueGroup2 = getODSOperands(2)`.
    body << formatv("    auto valueGroup{2} = getODS{0}{1}s({2});\n",
                    // Capitalize the first letter to match the function name
                    valueKind.substr(0, 1).upper(), valueKind.substr(1),
                    staticValue.index());

    // If the constraint is optional check that the value group has at most 1
    // value. A variadic-of-variadic group instead checks its size against the
    // segment-size attribute.
    if (isOptional) {
      body << formatv(verifyOptional, staticValue.index(), valueKind);
    } else if (isVariadicOfVariadic) {
      body << formatv(
          "    if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
          "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
          "      return ::mlir::failure();\n",
          value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
          staticValue.index());
    }

    // Otherwise, if there is no predicate there is nothing left to do.
    if (!hasPredicate)
      continue;
    // Emit a loop to check all the dynamic values in the pack.
    StringRef constraintFn =
        staticVerifierEmitter.getTypeConstraintFn(value.constraint);
    body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind);
  }

  body << "  }\n";
}
2519 
2520 void OpEmitter::genRegionVerifier(MethodBody &body) {
2521   /// Code to verify a region.
2522   ///
2523   /// {0}: Getter for the regions.
2524   /// {1}: The region constraint.
2525   /// {2}: The region's name.
2526   /// {3}: The region description.
2527   const char *const verifyRegion = R"(
2528     for (auto &region : {0})
2529       if (::mlir::failed({1}(*this, region, "{2}", index++)))
2530         return ::mlir::failure();
2531 )";
2532   /// Get a single region.
2533   ///
2534   /// {0}: The region's index.
2535   const char *const getSingleRegion =
2536       "::llvm::makeMutableArrayRef((*this)->getRegion({0}))";
2537 
2538   // If we have no regions, there is nothing more to do.
2539   const auto canSkip = [](const NamedRegion &region) {
2540     return region.constraint.getPredicate().isNull();
2541   };
2542   auto regions = op.getRegions();
2543   if (regions.empty() && llvm::all_of(regions, canSkip))
2544     return;
2545 
2546   body << "  {\n    unsigned index = 0; (void)index;\n";
2547   for (const auto &it : llvm::enumerate(regions)) {
2548     const auto &region = it.value();
2549     if (canSkip(region))
2550       continue;
2551 
2552     auto getRegion = region.isVariadic()
2553                          ? formatv("{0}()", op.getGetterName(region.name)).str()
2554                          : formatv(getSingleRegion, it.index()).str();
2555     auto constraintFn =
2556         staticVerifierEmitter.getRegionConstraintFn(region.constraint);
2557     body << formatv(verifyRegion, getRegion, constraintFn, region.name);
2558   }
2559   body << "  }\n";
2560 }
2561 
2562 void OpEmitter::genSuccessorVerifier(MethodBody &body) {
2563   const char *const verifySuccessor = R"(
2564     for (auto *successor : {0})
2565       if (::mlir::failed({1}(*this, successor, "{2}", index++)))
2566         return ::mlir::failure();
2567 )";
2568   /// Get a single successor.
2569   ///
2570   /// {0}: The successor's name.
2571   const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())";
2572 
2573   // If we have no successors, there is nothing more to do.
2574   const auto canSkip = [](const NamedSuccessor &successor) {
2575     return successor.constraint.getPredicate().isNull();
2576   };
2577   auto successors = op.getSuccessors();
2578   if (successors.empty() && llvm::all_of(successors, canSkip))
2579     return;
2580 
2581   body << "  {\n    unsigned index = 0; (void)index;\n";
2582 
2583   for (auto &it : llvm::enumerate(successors)) {
2584     const auto &successor = it.value();
2585     if (canSkip(successor))
2586       continue;
2587 
2588     auto getSuccessor =
2589         formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor,
2590                 successor.name, it.index())
2591             .str();
2592     auto constraintFn =
2593         staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint);
2594     body << formatv(verifySuccessor, getSuccessor, constraintFn,
2595                     successor.name);
2596   }
2597   body << "  }\n";
2598 }
2599 
2600 /// Add a size count trait to the given operation class.
2601 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2602                               int numTotal, int numVariadic) {
2603   if (numVariadic != 0) {
2604     if (numTotal == numVariadic)
2605       opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2606     else
2607       opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2608                        Twine(numTotal - numVariadic) + ">::Impl");
2609     return;
2610   }
2611   switch (numTotal) {
2612   case 0:
2613     opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2614     break;
2615   case 1:
2616     opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2617     break;
2618   default:
2619     opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2620                      ">::Impl");
2621     break;
2622   }
2623 }
2624 
/// Adds all C++ traits to the op class.
///
/// The insertion order is deliberate:
///   1. size-count traits for regions, results (plus `OneTypedResult` for a
///      single fixed result), successors and operands;
///   2. native traits marked as structural (`isStructuralOpTrait`);
///   3. `OpInvariants`;
///   4. the remaining native traits and all interface traits.
void OpEmitter::genTraits() {
  // Add region size trait.
  unsigned numRegions = op.getNumRegions();
  unsigned numVariadicRegions = op.getNumVariadicRegions();
  addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);

  // Add result size traits.
  int numResults = op.getNumResults();
  int numVariadicResults = op.getNumVariableLengthResults();
  addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);

  // For single result ops with a known specific type, generate a OneTypedResult
  // trait.
  if (numResults == 1 && numVariadicResults == 0) {
    auto cppName = op.getResults().begin()->constraint.getCPPClassName();
    opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
  }

  // Add successor size trait.
  unsigned numSuccessors = op.getNumSuccessors();
  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
  addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);

  // Operand counts. These are spelled out here rather than through
  // `addSizeCountTrait` because the operand trait names differ from its
  // pattern (e.g. `ZeroOperands`, not `ZeroOperand`).
  int numOperands = op.getNumOperands();
  int numVariadicOperands = op.getNumVariableLengthOperands();

  // Add operand size trait.
  if (numVariadicOperands != 0) {
    if (numOperands == numVariadicOperands)
      opClass.addTrait("::mlir::OpTrait::VariadicOperands");
    else
      opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
                       Twine(numOperands - numVariadicOperands) + ">::Impl");
  } else {
    switch (numOperands) {
    case 0:
      opClass.addTrait("::mlir::OpTrait::ZeroOperands");
      break;
    case 1:
      opClass.addTrait("::mlir::OpTrait::OneOperand");
      break;
    default:
      opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
                       ">::Impl");
      break;
    }
  }

  // Structural native traits (`isStructuralOpTrait`) are added first so that
  // they can be verified earlier than the rest.
  for (const auto &trait : op.getTraits()) {
    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
      if (opTrait->isStructuralOpTrait())
        opClass.addTrait(opTrait->getFullyQualifiedTraitName());
    }
  }

  // OpInvariants wraps `verifyInvariants`, which needs to run after all the
  // traits with `StructuralOpTrait` and before the native/interface traits.
  opClass.addTrait("::mlir::OpTrait::OpInvariants");

  // Add the remaining (non-structural) native traits and the interface traits.
  for (const auto &trait : op.getTraits()) {
    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
      if (!opTrait->isStructuralOpTrait())
        opClass.addTrait(opTrait->getFullyQualifiedTraitName());
    } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) {
      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
    }
  }
}
2697 
2698 void OpEmitter::genOpNameGetter() {
2699   auto *method = opClass.addStaticMethod<Method::Constexpr>(
2700       "::llvm::StringLiteral", "getOperationName");
2701   ERROR_IF_PRUNED(method, "getOperationName", op);
2702   method->body() << "  return ::llvm::StringLiteral(\"" << op.getOperationName()
2703                  << "\");";
2704 }
2705 
2706 void OpEmitter::genOpAsmInterface() {
2707   // If the user only has one results or specifically added the Asm trait,
2708   // then don't generate it for them. We specifically only handle multi result
2709   // operations, because the name of a single result in the common case is not
2710   // interesting(generally 'result'/'output'/etc.).
2711   // TODO: We could also add a flag to allow operations to opt in to this
2712   // generation, even if they only have a single operation.
2713   int numResults = op.getNumResults();
2714   if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2715     return;
2716 
2717   SmallVector<StringRef, 4> resultNames(numResults);
2718   for (int i = 0; i != numResults; ++i)
2719     resultNames[i] = op.getResultName(i);
2720 
2721   // Don't add the trait if none of the results have a valid name.
2722   if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2723     return;
2724   opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2725 
2726   // Generate the right accessor for the number of results.
2727   auto *method = opClass.addMethod(
2728       "void", "getAsmResultNames",
2729       MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
2730   ERROR_IF_PRUNED(method, "getAsmResultNames", op);
2731   auto &body = method->body();
2732   for (int i = 0; i != numResults; ++i) {
2733     body << "  auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2734          << "  if (!llvm::empty(resultGroup" << i << "))\n"
2735          << "    setNameFn(*resultGroup" << i << ".begin(), \""
2736          << resultNames[i] << "\");\n";
2737   }
2738 }
2739 
2740 //===----------------------------------------------------------------------===//
2741 // OpOperandAdaptor emitter
2742 //===----------------------------------------------------------------------===//
2743 
namespace {
// Helper class to emit Op operand adaptors to an output stream. An operand
// adaptor wraps an op's operands (`::mlir::ValueRange`), attribute dictionary
// and regions, and provides named getters with the same names as those
// defined on the Op itself.
class OpOperandAdaptorEmitter {
public:
  // Emit the adaptor class declaration for `op` to `os`.
  static void
  emitDecl(const Operator &op,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter,
           raw_ostream &os);
  // Emit the adaptor class definitions for `op` to `os`.
  static void
  emitDef(const Operator &op,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter,
          raw_ostream &os);

private:
  // Builds the complete adaptor class; only used through emitDecl/emitDef.
  explicit OpOperandAdaptorEmitter(
      const Operator &op,
      const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // NOTE: members are initialized in declaration order by the constructor;
  // keep this order in sync with its initializer list.

  // The operation for which to emit an adaptor.
  const Operator &op;

  // The generated adaptor class.
  Class adaptor;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting adaptor code.
  OpOrAdaptorHelper emitHelper;
};
} // namespace
2781 
2782 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
2783     const Operator &op,
2784     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
2785     : op(op), adaptor(op.getAdaptorName()),
2786       staticVerifierEmitter(staticVerifierEmitter),
2787       emitHelper(op, /*emitForOp=*/false) {
2788   adaptor.addField("::mlir::ValueRange", "odsOperands");
2789   adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
2790   adaptor.addField("::mlir::RegionRange", "odsRegions");
2791   adaptor.addField("::llvm::Optional<::mlir::OperationName>", "odsOpName");
2792 
2793   const auto *attrSizedOperands =
2794       op.getTrait("::m::OpTrait::AttrSizedOperandSegments");
2795   {
2796     SmallVector<MethodParameter> paramList;
2797     paramList.emplace_back("::mlir::ValueRange", "values");
2798     paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2799                            attrSizedOperands ? "" : "nullptr");
2800     paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
2801     auto *constructor = adaptor.addConstructor(std::move(paramList));
2802 
2803     constructor->addMemberInitializer("odsOperands", "values");
2804     constructor->addMemberInitializer("odsAttrs", "attrs");
2805     constructor->addMemberInitializer("odsRegions", "regions");
2806 
2807     MethodBody &body = constructor->body();
2808     body.indent() << "if (odsAttrs)\n";
2809     body.indent() << formatv(
2810         "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
2811         op.getOperationName());
2812   }
2813 
2814   {
2815     auto *constructor =
2816         adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
2817     constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2818     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2819     constructor->addMemberInitializer("odsRegions", "op->getRegions()");
2820     constructor->addMemberInitializer("odsOpName", "op->getName()");
2821   }
2822 
2823   {
2824     auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands");
2825     ERROR_IF_PRUNED(m, "getOperands", op);
2826     m->body() << "  return odsOperands;";
2827   }
2828   std::string sizeAttrInit;
2829   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2830     sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
2831                            emitHelper.getAttr(operandSegmentAttrName));
2832   }
2833   generateNamedOperandGetters(op, adaptor,
2834                               /*isAdaptor=*/true, sizeAttrInit,
2835                               /*rangeType=*/"::mlir::ValueRange",
2836                               /*rangeBeginCall=*/"odsOperands.begin()",
2837                               /*rangeSizeCall=*/"odsOperands.size()",
2838                               /*getOperandCallPattern=*/"odsOperands[{0}]");
2839 
2840   FmtContext fctx;
2841   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2842 
2843   // Generate named accessor with Attribute return type.
2844   auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
2845                                      Attribute attr) {
2846     auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr");
2847     ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
2848     auto &body = method->body().indent();
2849     body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n"
2850          << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name),
2851                     attr.hasDefaultValue() || attr.isOptional()
2852                         ? "dyn_cast_or_null"
2853                         : "cast",
2854                     attr.getStorageType());
2855 
2856     if (attr.hasDefaultValue()) {
2857       // Use the default value if attribute is not set.
2858       // TODO: this is inefficient, we are recreating the attribute for every
2859       // call. This should be set instead.
2860       std::string defaultValue = std::string(
2861           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2862       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
2863     }
2864     body << "  return attr;\n";
2865   };
2866 
2867   {
2868     auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
2869     ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
2870     m->body() << "  return odsAttrs;";
2871   }
2872   for (auto &namedAttr : op.getAttributes()) {
2873     const auto &name = namedAttr.name;
2874     const auto &attr = namedAttr.attr;
2875     if (attr.isDerivedAttr())
2876       continue;
2877     for (const auto &emitName : op.getGetterNames(name)) {
2878       emitAttrWithStorageType(name, emitName, attr);
2879       emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
2880     }
2881   }
2882 
2883   unsigned numRegions = op.getNumRegions();
2884   if (numRegions > 0) {
2885     auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions");
2886     ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
2887     m->body() << "  return odsRegions;";
2888   }
2889   for (unsigned i = 0; i < numRegions; ++i) {
2890     const auto &region = op.getRegion(i);
2891     if (region.name.empty())
2892       continue;
2893 
2894     // Generate the accessors for a variadic region.
2895     for (StringRef name : op.getGetterNames(region.name)) {
2896       if (region.isVariadic()) {
2897         auto *m = adaptor.addMethod("::mlir::RegionRange", name);
2898         ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
2899         m->body() << formatv("  return odsRegions.drop_front({0});", i);
2900         continue;
2901       }
2902 
2903       auto *m = adaptor.addMethod("::mlir::Region &", name);
2904       ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
2905       m->body() << formatv("  return *odsRegions[{0}];", i);
2906     }
2907   }
2908 
2909   // Add verification function.
2910   addVerification();
2911   adaptor.finalize();
2912 }
2913 
2914 void OpOperandAdaptorEmitter::addVerification() {
2915   auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify",
2916                                    MethodParameter("::mlir::Location", "loc"));
2917   ERROR_IF_PRUNED(method, "verify", op);
2918   auto &body = method->body();
2919 
2920   FmtContext verifyCtx;
2921   populateSubstitutions(emitHelper, verifyCtx);
2922   genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
2923 
2924   body << "  return ::mlir::success();";
2925 }
2926 
2927 void OpOperandAdaptorEmitter::emitDecl(
2928     const Operator &op,
2929     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
2930     raw_ostream &os) {
2931   OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
2932 }
2933 
2934 void OpOperandAdaptorEmitter::emitDef(
2935     const Operator &op,
2936     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
2937     raw_ostream &os) {
2938   OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
2939 }
2940 
2941 // Emits the opcode enum and op classes.
2942 static void emitOpClasses(const RecordKeeper &recordKeeper,
2943                           const std::vector<Record *> &defs, raw_ostream &os,
2944                           bool emitDecl) {
2945   // First emit forward declaration for each class, this allows them to refer
2946   // to each others in traits for example.
2947   if (emitDecl) {
2948     os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
2949     os << "#undef GET_OP_FWD_DEFINES\n";
2950     for (auto *def : defs) {
2951       Operator op(*def);
2952       NamespaceEmitter emitter(os, op.getCppNamespace());
2953       os << "class " << op.getCppClassName() << ";\n";
2954     }
2955     os << "#endif\n\n";
2956   }
2957 
2958   IfDefScope scope("GET_OP_CLASSES", os);
2959   if (defs.empty())
2960     return;
2961 
2962   // Generate all of the locally instantiated methods first.
2963   StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
2964   os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
2965   staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
2966 
2967   for (auto *def : defs) {
2968     Operator op(*def);
2969     if (emitDecl) {
2970       {
2971         NamespaceEmitter emitter(os, op.getCppNamespace());
2972         os << formatv(opCommentHeader, op.getQualCppClassName(),
2973                       "declarations");
2974         OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
2975         OpEmitter::emitDecl(op, os, staticVerifierEmitter);
2976       }
2977       // Emit the TypeID explicit specialization to have a single definition.
2978       if (!op.getCppNamespace().empty())
2979         os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
2980            << "::" << op.getCppClassName() << ")\n\n";
2981     } else {
2982       {
2983         NamespaceEmitter emitter(os, op.getCppNamespace());
2984         os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2985         OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
2986         OpEmitter::emitDef(op, os, staticVerifierEmitter);
2987       }
2988       // Emit the TypeID explicit specialization to have a single definition.
2989       if (!op.getCppNamespace().empty())
2990         os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
2991            << "::" << op.getCppClassName() << ")\n\n";
2992     }
2993   }
2994 }
2995 
2996 // Emits a comma-separated list of the ops.
2997 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2998   IfDefScope scope("GET_OP_LIST", os);
2999 
3000   interleave(
3001       // TODO: We are constructing the Operator wrapper instance just for
3002       // getting it's qualified class name here. Reduce the overhead by having a
3003       // lightweight version of Operator class just for that purpose.
3004       defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
3005       [&os]() { os << ",\n"; });
3006 }
3007 
3008 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
3009   emitSourceFileHeader("Op Declarations", os);
3010 
3011   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
3012   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
3013 
3014   return false;
3015 }
3016 
3017 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
3018   emitSourceFileHeader("Op Definitions", os);
3019 
3020   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
3021   emitOpList(defs, os);
3022   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
3023 
3024   return false;
3025 }
3026 
3027 static mlir::GenRegistration
3028     genOpDecls("gen-op-decls", "Generate op declarations",
3029                [](const RecordKeeper &records, raw_ostream &os) {
3030                  return emitOpDecls(records, os);
3031                });
3032 
3033 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
3034                                        [](const RecordKeeper &records,
3035                                           raw_ostream &os) {
3036                                          return emitOpDefs(records, os);
3037                                        });
3038