1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "OpClass.h"
15 #include "OpFormatGen.h"
16 #include "OpGenHelpers.h"
17 #include "mlir/TableGen/Class.h"
18 #include "mlir/TableGen/CodeGenHelpers.h"
19 #include "mlir/TableGen/Format.h"
20 #include "mlir/TableGen/GenInfo.h"
21 #include "mlir/TableGen/Interfaces.h"
22 #include "mlir/TableGen/Operator.h"
23 #include "mlir/TableGen/SideEffects.h"
24 #include "mlir/TableGen/Trait.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringSet.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/Signals.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Record.h"
33 #include "llvm/TableGen/TableGenBackend.h"
34 
35 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
36 
37 using namespace llvm;
38 using namespace mlir;
39 using namespace mlir::tblgen;
40 
// Prefix for local variables in generated verifier code; it keeps those locals
// from hiding the op's attribute accessors of the same name.
static const char *const tblgenNamePrefix = "tblgen_";
// Base of the fallback name `odsArg_<index>` given to operands that were
// declared without a name (see getArgumentName).
static const char *const generatedArgName = "odsArg";
// Identifiers reserved for generated code. Their uses are not visible in this
// part of the file (NOTE(review): presumably the builder and OperationState
// parameter names of generated `build()` methods -- confirm at the use sites).
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";

/// Code for an Op to lookup an attribute. Uses cached identifiers: the
/// `{0}AttrName()` accessor it calls is generated by genAttrNameGetters.
///
/// {0}: The attribute's getter name.
static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())";
50 
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// The template uses llvm::formatv syntax (`{{` produces a literal `{`). The
/// emitted code reads a variable named `index` (the static operand/result
/// index), which must be in scope where the snippet is pasted. The snippet
/// ends by returning a `{start, size}` pair. Because the emitted code divides
/// by {2}, it is only meaningful when there is at least one variadic entry.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";
78 
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// The two "init" snippets below each define a local `sizeAttr`. The adaptor
/// variant reads it from `odsAttrs` by name; the op variant goes through the
/// cached `{0}AttrName()` accessor. The shared calculation snippet then reads
/// that `sizeAttr` together with an in-scope `index`.
///
/// {0}: The name of the attribute specifying the segment sizes.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert(odsAttrs && "missing segment size attribute for op");
  auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr =
      (*this)->getAttr({0}AttrName()).cast<::mlir::DenseIntElementsAttr>();
)";
// Returns the `{start, size}` of segment `index`. A splat size attribute means
// every segment has the same size. NOTE(review): this snippet contains braces
// that are not escaped as `{{`, so it is presumably pasted verbatim rather than
// passed through formatv -- confirm at the use site.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
  if (sizeAttr.isSplat())
    return {*sizeAttrValueIt * index, *sizeAttrValueIt};

  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttrValueIt[i];
  return {start, sizeAttrValueIt[index]};
)";
103 
/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
///
/// The emitted code splits the operands of group {1} into consecutive
/// ValueRanges, one per entry of the segment attribute. The attribute is read
/// through the call `{0}()`, so {0} must name an accessor that is available at
/// the paste site.
///
/// {0}: The name of the segment attribute.
/// {1}: The index of the main operand.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizeAttrValues = {0}().getValues<uint32_t>();
  auto sizeAttrIt = sizeAttrValues.begin();

  ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
  for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt);
  }
  return tblgenTmpOperandGroups;
)";

/// The logic to build a range of either operand or result values.
///
/// {1} must evaluate to a pair whose `.first` is the start offset and whose
/// `.second` is the length; the emitted code advances {0} by those amounts.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";

/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
142 
143 //===----------------------------------------------------------------------===//
144 // Utility structs and functions
145 //===----------------------------------------------------------------------===//
146 
// Replaces all non-overlapping occurrences of `match` in `str` with
// `substitute`, scanning left to right, and returns the result.
//
// Text that has just been substituted is never re-scanned, so `substitute`
// may itself contain `match` without causing runaway expansion. An empty
// `match` leaves `str` unchanged. Without this guard, `find("")` matches at
// every position: an empty `substitute` would then loop forever, and a
// non-empty one would be inserted between every character.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    // std::string::replace edits in place; no self-assignment needed.
    str.replace(matchLoc, match.size(), substitute);
    // Resume just past the inserted text.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
157 
158 // Returns whether the record has a value of the given name that can be returned
159 // via getValueAsString.
160 static inline bool hasStringAttribute(const Record &record,
161                                       StringRef fieldName) {
162   auto *valueInit = record.getValueInit(fieldName);
163   return isa<StringInit>(valueInit);
164 }
165 
166 static std::string getArgumentName(const Operator &op, int index) {
167   const auto &operand = op.getOperand(index);
168   if (!operand.name.empty())
169     return std::string(operand.name);
170   return std::string(formatv("{0}_{1}", generatedArgName, index));
171 }
172 
173 // Returns true if we can use unwrapped value for the given `attr` in builders.
174 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
175   return attr.getReturnType() != attr.getStorageType() &&
176          // We need to wrap the raw value into an attribute in the builder impl
177          // so we need to make sure that the attribute specifies how to do that.
178          !attr.getConstBuilderTemplate().empty();
179 }
180 
181 namespace {
182 /// Helper class to select between OpAdaptor and Op code templates.
183 class OpOrAdaptorHelper {
184 public:
185   OpOrAdaptorHelper(const Operator &op, bool emitForOp)
186       : op(op), emitForOp(emitForOp) {}
187 
188   /// Object that wraps a functor in a stream operator for interop with
189   /// llvm::formatv.
190   class Formatter {
191   public:
192     template <typename Functor>
193     Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
194 
195     std::string str() const {
196       std::string result;
197       llvm::raw_string_ostream os(result);
198       os << *this;
199       return os.str();
200     }
201 
202   private:
203     std::function<raw_ostream &(raw_ostream &)> func;
204 
205     friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
206       return fmt.func(os);
207     }
208   };
209 
210   // Generate code for getting an attribute.
211   Formatter getAttr(StringRef attrName) const {
212     return [this, attrName](raw_ostream &os) -> raw_ostream & {
213       if (!emitForOp)
214         return os << formatv("odsAttrs.get(\"{0}\")", attrName);
215       return os << formatv(opGetAttr, op.getGetterName(attrName));
216     };
217   }
218 
219   // Get the prefix code for emitting an error.
220   Formatter emitErrorPrefix() const {
221     return [this](raw_ostream &os) -> raw_ostream & {
222       if (emitForOp)
223         return os << "emitOpError(";
224       return os << formatv("emitError(loc, \"'{0}' op \"",
225                            op.getOperationName());
226     };
227   }
228 
229   // Get the call to get an operand or segment of operands.
230   Formatter getOperand(unsigned index) const {
231     return [this, index](raw_ostream &os) -> raw_ostream & {
232       return os << formatv(op.getOperand(index).isVariadic()
233                                ? "this->getODSOperands({0})"
234                                : "(*this->getODSOperands({0}).begin())",
235                            index);
236     };
237   }
238 
239   // Get the call to get a result of segment of results.
240   Formatter getResult(unsigned index) const {
241     return [this, index](raw_ostream &os) -> raw_ostream & {
242       if (!emitForOp)
243         return os << "<no results should be generated>";
244       return os << formatv(op.getResult(index).isVariadic()
245                                ? "this->getODSResults({0})"
246                                : "(*this->getODSResults({0}).begin())",
247                            index);
248     };
249   }
250 
251   // Return whether an op instance is available.
252   bool isEmittingForOp() const { return emitForOp; }
253 
254   // Return the ODS operation wrapper.
255   const Operator &getOp() const { return op; }
256 
257 private:
258   // The operation ODS wrapper.
259   const Operator &op;
260   // True if code is being generate for an op. False for an adaptor.
261   const bool emitForOp;
262 };
263 } // namespace
264 
265 //===----------------------------------------------------------------------===//
266 // Op emitter
267 //===----------------------------------------------------------------------===//
268 
namespace {
// Generates the C++ class for a single ODS operation. All methods are
// collected into `opClass` when the emitter is constructed; the static
// emitDecl/emitDef entry points then print either the declaration or the
// definition of that class to an output stream.
class OpEmitter {
public:
  static void
  emitDecl(const Operator &op, raw_ostream &os,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
  static void
  emitDef(const Operator &op, raw_ostream &os,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter);

private:
  OpEmitter(const Operator &op,
            const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generate methods for accessing the attribute names of this operation.
  void genAttrNameGetters();

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setters for the attributes.
  void genAttrSetters();

  // Generates removers for optional attributes.
  void genOptionalAttrRemovers();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates setters for named operands.
  void genNamedOperandSetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all result's types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. `inferredAttributes` is populated with any attributes that are
  // elided from the build list. The given `typeParamKind` and `attrParamKind`
  // control how result types and attributes are placed in the parameter list.
  void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                      llvm::StringSet<> &inferredAttributes,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  // NOTE(review): `isRawValueAttr` presumably mirrors
  // AttrParamKind::UnwrappedValue -- confirm in the definition.
  void
  genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
                                         llvm::StringSet<> &inferredAttributes,
                                         bool isRawValueAttr = false);

  // Generates canonicalizer declaration for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(MethodBody &body,
                                Operator::const_value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(MethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(MethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods for all interfaces.
  void genOpInterfaceMethods();

  // Generate op interface methods for the given interface.
  void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);

  // Generate op interface method for the given interface method. If
  // 'declaration' is true, generates a declaration, else a definition.
  Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                               bool declaration = true);

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

  // Generate the type inference interface methods.
  void genTypeInterfaceMethods();

private:
  // Data members are initialized in declaration order; the constructor's
  // initializer list follows the same order.

  // The TableGen record for this op.
  // TODO: OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  Operator op;

  // The C++ code builder for this op
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;
};

} // namespace
451 
452 // Populate the format context `ctx` with substitutions of attributes, operands
453 // and results.
454 static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
455                                   FmtContext &ctx) {
456   // Populate substitutions for attributes.
457   auto &op = emitHelper.getOp();
458   for (const auto &namedAttr : op.getAttributes())
459     ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str());
460 
461   // Populate substitutions for named operands.
462   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
463     auto &value = op.getOperand(i);
464     if (!value.name.empty())
465       ctx.addSubst(value.name, emitHelper.getOperand(i).str());
466   }
467 
468   // Populate substitutions for results.
469   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
470     auto &value = op.getResult(i);
471     if (!value.name.empty())
472       ctx.addSubst(value.name, emitHelper.getResult(i).str());
473   }
474 }
475 
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
// For every non-derived attribute of the op, this appends to `body` a braced
// scope that loads the attribute into a `tblgen_`-prefixed local and then:
//   * returns an error if the attribute is required (it has no default value
//     and is not optional) but missing, and
//   * checks the attribute's predicate, when it has one that can be emitted.
//     This calls a uniqued constraint function from `staticVerifierEmitter`
//     (only when emitting for an op and such a function exists); otherwise
//     the condition is pasted inline.
// Attributes that may be missing and have no emittable condition produce no
// code at all. Note: `ctx.withSelf(...)` is called for each inline condition,
// so `ctx`'s self substitution is left bound to the last attribute's local.
static void genAttributeVerifier(
    const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
  // Check that a required attribute exists.
  //
  // {0}: Attribute variable name.
  // {1}: Emit error prefix.
  // {2}: Attribute name.
  const char *const verifyRequiredAttr = R"(
    if (!{0})
      return {1}"requires attribute '{2}'");
)";
  // Verify the attribute if it is present. This assumes that default values
  // are valid. This code snippet pastes the condition inline.
  //
  // TODO: verify the default value is valid (perhaps in debug mode only).
  //
  // {0}: Attribute variable name.
  // {1}: Attribute condition code.
  // {2}: Emit error prefix.
  // {3}: Attribute name.
  // {4}: Attribute/constraint description.
  const char *const verifyAttrInline = R"(
    if ({0} && !({1}))
      return {2}"attribute '{3}' failed to satisfy constraint: {4}");
)";
  // Verify the attribute using a uniqued constraint. Can only be used within
  // the context of an op.
  //
  // {0}: Unique constraint name.
  // {1}: Attribute variable name.
  // {2}: Attribute name.
  const char *const verifyAttrUnique = R"(
    if (::mlir::failed({0}(*this, {1}, "{2}")))
      return ::mlir::failure();
)";

  for (const auto &namedAttr : emitHelper.getOp().getAttributes()) {
    const auto &attr = namedAttr.attr;
    StringRef attrName = namedAttr.name;
    // Derived attributes are computed, not stored, so there is nothing to
    // verify.
    if (attr.isDerivedAttr())
      continue;

    bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
    auto attrPred = attr.getPredicate();
    std::string condition = attrPred.isNull() ? "" : attrPred.getCondition();
    // If the attribute's condition needs an op but none is available, then the
    // condition cannot be emitted.
    bool canEmitCondition =
        !condition.empty() && (!StringRef(condition).contains("$_op") ||
                               emitHelper.isEmittingForOp());

    // Prefix with `tblgen_` to avoid hiding the attribute accessor.
    std::string varName = (tblgenNamePrefix + attrName).str();

    // If the attribute is not required and we cannot emit the condition, then
    // there is nothing to be done.
    if (allowMissingAttr && !canEmitCondition)
      continue;

    // Open a scope for this attribute's local; it is closed below. The lone
    // `{` is emitted literally because formatv finds `{0}` before a `}`.
    body << formatv("  {\n    auto {0} = {1};", varName,
                    emitHelper.getAttr(attrName));

    if (!allowMissingAttr) {
      body << formatv(verifyRequiredAttr, varName, emitHelper.emitErrorPrefix(),
                      attrName);
    }
    if (canEmitCondition) {
      Optional<StringRef> constraintFn;
      // Prefer the shared, uniqued constraint function when one exists; it
      // needs `*this`, so this path is only taken for ops.
      if (emitHelper.isEmittingForOp() &&
          (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
        body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
      } else {
        body << formatv(verifyAttrInline, varName,
                        tgfmt(condition, &ctx.withSelf(varName)),
                        emitHelper.emitErrorPrefix(), attrName,
                        escapeString(attr.getSummary()));
      }
    }
    body << "  }\n";
  }
}
560 
561 /// Op extra class definitions have a `$cppClass` substitution that is to be
562 /// replaced by the C++ class name.
563 static std::string formatExtraDefinitions(const Operator &op) {
564   FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
565   return tgfmt(op.getExtraClassDefinition(), &ctx).str();
566 }
567 
// Builds the op class for `op`. It sets up the verification format context,
// records the op's traits, and then runs every method generator in sequence.
// All code generation happens here; emitDecl/emitDef only finalize and print
// the resulting class.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter) {
  // Verification snippets refer to the op as `(*this->getOperation())`, and
  // the `_ctxt` placeholder expands to the operation's context.
  verifyCtx.withOp("(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
603 void OpEmitter::emitDecl(
604     const Operator &op, raw_ostream &os,
605     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
606   OpEmitter(op, staticVerifierEmitter).emitDecl(os);
607 }
608 
609 void OpEmitter::emitDef(
610     const Operator &op, raw_ostream &os,
611     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
612   OpEmitter(op, staticVerifierEmitter).emitDef(os);
613 }
614 
615 void OpEmitter::emitDecl(raw_ostream &os) {
616   opClass.finalize();
617   opClass.writeDeclTo(os);
618 }
619 
620 void OpEmitter::emitDef(raw_ostream &os) {
621   opClass.finalize();
622   opClass.writeDefTo(os);
623 }
624 
625 static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
626                           const Operator &op) {
627   if (m)
628     return;
629   PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
630                                    methodName + "` for " +
631                                    op.getOperationName() + " (from line " +
632                                    Twine(line) + ")");
633 }
634 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
635 
636 void OpEmitter::genAttrNameGetters() {
637   // A map of attribute names (including implicit attributes) registered to the
638   // current operation, to the relative order in which they were registered.
639   llvm::MapVector<StringRef, unsigned> attributeNames;
640 
641   // Enumerate the attribute names of this op, assigning each a relative
642   // ordering.
643   auto addAttrName = [&](StringRef name) {
644     unsigned index = attributeNames.size();
645     attributeNames.insert({name, index});
646   };
647   for (const NamedAttribute &namedAttr : op.getAttributes())
648     addAttrName(namedAttr.name);
649   // Include key attributes from several traits as implicitly registered.
650   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
651     addAttrName("operand_segment_sizes");
652   if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
653     addAttrName("result_segment_sizes");
654 
655   // Emit the getAttributeNames method.
656   {
657     auto *method = opClass.addStaticInlineMethod(
658         "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
659     ERROR_IF_PRUNED(method, "getAttributeNames", op);
660     auto &body = method->body();
661     if (attributeNames.empty()) {
662       body << "  return {};";
663     } else {
664       body << "  static ::llvm::StringRef attrNames[] = {";
665       llvm::interleaveComma(llvm::make_first_range(attributeNames), body,
666                             [&](StringRef attrName) {
667                               body << "::llvm::StringRef(\"" << attrName
668                                    << "\")";
669                             });
670       body << "};\n  return ::llvm::makeArrayRef(attrNames);";
671     }
672   }
673   if (attributeNames.empty())
674     return;
675 
676   // Emit the getAttributeNameForIndex methods.
677   {
678     auto *method = opClass.addInlineMethod<Method::Private>(
679         "::mlir::StringAttr", "getAttributeNameForIndex",
680         MethodParameter("unsigned", "index"));
681     ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
682     method->body()
683         << "  return getAttributeNameForIndex((*this)->getName(), index);";
684   }
685   {
686     auto *method = opClass.addStaticInlineMethod<Method::Private>(
687         "::mlir::StringAttr", "getAttributeNameForIndex",
688         MethodParameter("::mlir::OperationName", "name"),
689         MethodParameter("unsigned", "index"));
690     ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
691 
692     const char *const getAttrName = R"(
693   assert(index < {0} && "invalid attribute index");
694   return name.getRegisteredInfo()->getAttributeNames()[index];
695 )";
696     method->body() << formatv(getAttrName, attributeNames.size());
697   }
698 
699   // Generate the <attr>AttrName methods, that expose the attribute names to
700   // users.
701   const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
702   for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
703     for (StringRef name : op.getGetterNames(attrIt.first)) {
704       std::string methodName = (name + "AttrName").str();
705 
706       // Generate the non-static variant.
707       {
708         auto *method =
709             opClass.addInlineMethod("::mlir::StringAttr", methodName);
710         ERROR_IF_PRUNED(method, methodName, op);
711         method->body() << llvm::formatv(attrNameMethodBody, attrIt.second);
712       }
713 
714       // Generate the static variant.
715       {
716         auto *method = opClass.addStaticInlineMethod(
717             "::mlir::StringAttr", methodName,
718             MethodParameter("::mlir::OperationName", "name"));
719         ERROR_IF_PRUNED(method, methodName, op);
720         method->body() << llvm::formatv(attrNameMethodBody,
721                                         "name, " + Twine(attrIt.second));
722       }
723     }
724   }
725 }
726 
727 // Emit the getter for an attribute with the return type specified.
728 // It is templated to be shared between the Op and the adaptor class.
729 template <typename OpClassOrAdaptor>
730 static void emitAttrGetterWithReturnType(FmtContext &fctx,
731                                          OpClassOrAdaptor &opClass,
732                                          const Operator &op, StringRef name,
733                                          Attribute attr) {
734   auto *method = opClass.addMethod(attr.getReturnType(), name);
735   ERROR_IF_PRUNED(method, name, op);
736   auto &body = method->body();
737   body << "  auto attr = " << name << "Attr();\n";
738   if (attr.hasDefaultValue()) {
739     // Returns the default value if not set.
740     // TODO: this is inefficient, we are recreating the attribute for every
741     // call. This should be set instead.
742     std::string defaultValue = std::string(
743         tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
744     body << "    if (!attr)\n      return "
745          << tgfmt(attr.getConvertFromStorageCall(),
746                   &fctx.withSelf(defaultValue))
747          << ";\n";
748   }
749   body << "  return "
750        << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
751        << ";\n";
752 }
753 
/// Emits the attribute accessors of the op.
///  - Derived attributes get a getter whose body is the attribute's
///    derived-code snippet.
///  - Every other attribute gets a `<name>Attr()` getter returning the storage
///    type and a `<name>()` getter returning the converted return type (see
///    emitAttrGetterWithReturnType).
/// If any derived attribute exists, the op also receives the
/// DerivedAttributeOpInterface trait plus the `isDerivedAttribute` and
/// `materializeDerivedAttributes` methods.
void OpEmitter::genAttrGetters() {
  // Format context shared by every conversion template emitted below; the
  // builder placeholder is bound to a Builder created from the op's context.
  FmtContext fctx;
  fctx.withBuilder("::mlir::Builder((*this)->getContext())");

  // Emit the derived attribute body. A null method (not added) is skipped
  // silently here, unlike the ERROR_IF_PRUNED paths further down.
  auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
    if (auto *method = opClass.addMethod(attr.getReturnType(), name))
      method->body() << "  " << attr.getDerivedCodeBody() << "\n";
  };

  // Generate named accessor with Attribute return type. This is a wrapper class
  // that allows referring to the attributes via accessors instead of having to
  // use the string interface for better compile time verification.
  // Optional attributes and attributes with a default value may be missing
  // from the op, so they use `dyn_cast_or_null`; required ones use `cast`.
  auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
    auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr");
    if (!method)
      return;
    method->body() << formatv(
        "  return {0}.{1}<{2}>();", formatv(opGetAttr, name),
        attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
                                                    : "cast",
        attr.getStorageType());
  };

  // An attribute may have several getter names; emit accessors for each.
  for (const NamedAttribute &namedAttr : op.getAttributes()) {
    for (StringRef name : op.getGetterNames(namedAttr.name)) {
      if (namedAttr.attr.isDerivedAttr()) {
        emitDerivedAttr(name, namedAttr.attr);
      } else {
        emitAttrWithStorageType(name, namedAttr.attr);
        emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
      }
    }
  }

  // Filtered view over the op's attributes that keeps only derived ones.
  auto derivedAttrs = make_filter_range(op.getAttributes(),
                                        [](const NamedAttribute &namedAttr) {
                                          return namedAttr.attr.isDerivedAttr();
                                        });
  if (derivedAttrs.empty())
    return;

  opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
  // Generate helper method to query whether a named attribute is a derived
  // attribute. This enables, for example, avoiding adding an attribute that
  // overlaps with a derived attribute.
  {
    auto *method =
        opClass.addStaticMethod("bool", "isDerivedAttribute",
                                MethodParameter("::llvm::StringRef", "name"));
    ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
    auto &body = method->body();
    // One string comparison per derived attribute, matched on its ODS name.
    for (auto namedAttr : derivedAttrs)
      body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
    body << " return false;";
  }
  // Generate method to materialize derived attributes as a DictionaryAttr.
  {
    auto *method = opClass.addMethod("::mlir::DictionaryAttr",
                                     "materializeDerivedAttributes");
    ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
    auto &body = method->body();

    // A derived attribute without a convert-from-storage call cannot be
    // turned back into an Attribute value.
    auto nonMaterializable =
        make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
          return namedAttr.attr.getConvertFromStorageCall().empty();
        });
    if (!nonMaterializable.empty()) {
      // Comma-separated getter names of the offending attributes; `os.str()`
      // flushes into `attrs`, which is reused for the emitted error below.
      std::string attrs;
      llvm::raw_string_ostream os(attrs);
      interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
        os << op.getGetterName(attr.name);
      });
      PrintWarning(
          op.getLoc(),
          formatv(
              "op has non-materializable derived attributes '{0}', skipping",
              os.str()));
      // Emit a body that reports the problem at runtime and returns null.
      body << formatv("  emitOpError(\"op has non-materializable derived "
                      "attributes '{0}'\");\n",
                      attrs);
      body << "  return nullptr;";
      // Exits genAttrGetters; nothing else remains to be emitted.
      return;
    }

    // Build a DictionaryAttr mapping each derived attribute's name to its
    // getter value converted back to an attribute. In the conversion template
    // `$_self` is the getter call, the builder is `odsBuilder`, and `$_ctx`
    // is `ctx`.
    body << "  ::mlir::MLIRContext* ctx = getContext();\n";
    body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
    body << "  return ::mlir::DictionaryAttr::get(";
    body << "  ctx, {\n";
    interleave(
        derivedAttrs, body,
        [&](const NamedAttribute &namedAttr) {
          auto tmpl = namedAttr.attr.getConvertFromStorageCall();
          std::string name = op.getGetterName(namedAttr.name);
          body << "    {" << name << "AttrName(),\n"
               << tgfmt(tmpl, &fctx.withSelf(name + "()")
                                   .withBuilder("odsBuilder")
                                   .addSubst("_ctx", "ctx"))
               << "}";
        },
        ",\n");
    body << "});";
  }
}
858 
859 void OpEmitter::genAttrSetters() {
860   // Generate raw named setter type. This is a wrapper class that allows setting
861   // to the attributes via setters instead of having to use the string interface
862   // for better compile time verification.
863   auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
864                                      Attribute attr) {
865     auto *method =
866         opClass.addMethod("void", setterName + "Attr",
867                           MethodParameter(attr.getStorageType(), "attr"));
868     if (method)
869       method->body() << formatv("  (*this)->setAttr({0}AttrName(), attr);",
870                                 getterName);
871   };
872 
873   for (const NamedAttribute &namedAttr : op.getAttributes()) {
874     if (namedAttr.attr.isDerivedAttr())
875       continue;
876     for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
877                                 op.getGetterNames(namedAttr.name)))
878       emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
879                               namedAttr.attr);
880   }
881 }
882 
883 void OpEmitter::genOptionalAttrRemovers() {
884   // Generate methods for removing optional attributes, instead of having to
885   // use the string interface. Enables better compile time verification.
886   auto emitRemoveAttr = [&](StringRef name) {
887     auto upperInitial = name.take_front().upper();
888     auto suffix = name.drop_front();
889     auto *method = opClass.addMethod("::mlir::Attribute",
890                                      "remove" + upperInitial + suffix + "Attr");
891     if (!method)
892       return;
893     method->body() << formatv("  return (*this)->removeAttr({0}AttrName());",
894                               op.getGetterName(name));
895   };
896 
897   for (const NamedAttribute &namedAttr : op.getAttributes())
898     if (namedAttr.attr.isOptional())
899       emitRemoveAttr(namedAttr.name);
900 }
901 
902 // Generates the code to compute the start and end index of an operand or result
903 // range.
904 template <typename RangeT>
905 static void
906 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
907                               int numVariadic, int numNonVariadic,
908                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
909                               StringRef sizeAttrInit, RangeT &&odsValues) {
910   auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
911                                    MethodParameter("unsigned", "index"));
912   if (!method)
913     return;
914   auto &body = method->body();
915   if (numVariadic == 0) {
916     body << "  return {index, 1};\n";
917   } else if (hasAttrSegmentSize) {
918     body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
919   } else {
920     // Because the op can have arbitrarily interleaved variadic and non-variadic
921     // operands, we need to embed a list in the "sink" getter method for
922     // calculation at run-time.
923     SmallVector<StringRef, 4> isVariadic;
924     isVariadic.reserve(llvm::size(odsValues));
925     for (auto &it : odsValues)
926       isVariadic.push_back(it.isVariableLength() ? "true" : "false");
927     std::string isVariadicList = llvm::join(isVariadic, ", ");
928     body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
929                     numNonVariadic, numVariadic, rangeSizeCall, "operand");
930   }
931 }
932 
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`.
//  - `rangeType` is the return type of getters that return a range of
//    operands. Individual operands are `Value`, and each element in the range
//    must also be a `Value`.
//  - `rangeBeginCall` produces an iterator to the beginning of the operand
//    range.
//  - `rangeSizeCall` obtains the number of operands.
//  - `sizeAttrInit` is the snippet that loads the segment sizes when the op
//    has the AttrSizedOperandSegments trait.
//  - `isAdaptor` selects the adaptor form of variadic-of-variadic getters.
//  - `getOperandCallPattern` is meant to hold the code that obtains a single
//    operand, with its position substituted for the "{0}" marker. The pattern
//    should work for any kind of op, in particular one-operand ops that may
//    not have the `getOperand(unsigned)` method.
// NOTE(review): `getOperandCallPattern` is not referenced in this body.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
                                        bool isAdaptor, StringRef sizeAttrInit,
                                        StringRef rangeType,
                                        StringRef rangeBeginCall,
                                        StringRef rangeSizeCall,
                                        StringRef getOperandCallPattern) {
  const int numOperands = op.getNumOperands();
  const int numVariadicOperands = op.getNumVariableLengthOperands();
  const int numNormalOperands = numOperands - numVariadicOperands;

  const auto *sameVariadicSize =
      op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
  const auto *attrSizedOperands =
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");

  // Multiple variadic operands are ambiguous unless one of the two
  // size-specifying traits says how to split the dynamic operand list.
  if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
                                 "specification over their sizes");
  }

  if (numVariadicOperands < 2 && attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
                                 "to use 'AttrSizedOperandSegments' trait");
  }

  if (attrSizedOperands && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedOperandSegments' and "
                    "'SameVariadicOperandSize' traits");
  }

  // First emit a few "sink" getter methods upon which we layer all nicer named
  // getter methods.
  generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
                                numVariadicOperands, numNormalOperands,
                                rangeSizeCall, attrSizedOperands, sizeAttrInit,
                                const_cast<Operator &>(op).getOperands());

  auto *m = opClass.addMethod(rangeType, "getODSOperands",
                              MethodParameter("unsigned", "index"));
  ERROR_IF_PRUNED(m, "getODSOperands", op);
  auto &body = m->body();
  body << formatv(valueRangeReturnCode, rangeBeginCall,
                  "getODSOperandIndexAndLength(index)");

  // Then we emit nicer named getter methods by redirecting to the "sink" getter
  // method.
  for (int i = 0; i != numOperands; ++i) {
    const auto &operand = op.getOperand(i);
    if (operand.name.empty())
      continue;
    for (StringRef name : op.getGetterNames(operand.name)) {
      if (operand.isOptional()) {
        // Optional operand: a null Value when the operand is absent.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  auto operands = getODSOperands(" << i << ");\n"
                  << "  return operands.empty() ? ::mlir::Value() : "
                     "*operands.begin();";
      } else if (operand.isVariadicOfVariadic()) {
        // Nested variadic: split the flat range using the segment-size
        // attribute named by the constraint.
        std::string segmentAttr = op.getGetterName(
            operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
        if (isAdaptor) {
          m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>",
                                name);
          ERROR_IF_PRUNED(m, name, op);
          m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
                                     segmentAttr, i);
          continue;
        }

        m = opClass.addMethod("::mlir::OperandRangeRange", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return getODSOperands(" << i << ").split("
                  << segmentAttr << "Attr());";
      } else if (operand.isVariadic()) {
        // Variadic operand: return the whole sub-range.
        m = opClass.addMethod(rangeType, name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return getODSOperands(" << i << ");";
      } else {
        // Single operand: the first (and only) element of its sub-range.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << "  return *getODSOperands(" << i << ").begin();";
      }
    }
  }
}
1030 
1031 void OpEmitter::genNamedOperandGetters() {
1032   // Build the code snippet used for initializing the operand_segment_size)s
1033   // array.
1034   std::string attrSizeInitCode;
1035   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1036     std::string attr = op.getGetterName("operand_segment_sizes");
1037     attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str();
1038   }
1039 
1040   generateNamedOperandGetters(
1041       op, opClass,
1042       /*isAdaptor=*/false,
1043       /*sizeAttrInit=*/attrSizeInitCode,
1044       /*rangeType=*/"::mlir::Operation::operand_range",
1045       /*rangeBeginCall=*/"getOperation()->operand_begin()",
1046       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
1047       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
1048 }
1049 
1050 void OpEmitter::genNamedOperandSetters() {
1051   auto *attrSizedOperands =
1052       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1053   for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
1054     const auto &operand = op.getOperand(i);
1055     if (operand.name.empty())
1056       continue;
1057     for (StringRef name : op.getGetterNames(operand.name)) {
1058       auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
1059                                       ? "::mlir::MutableOperandRangeRange"
1060                                       : "::mlir::MutableOperandRange",
1061                                   (name + "Mutable").str());
1062       ERROR_IF_PRUNED(m, name, op);
1063       auto &body = m->body();
1064       body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
1065            << "  auto mutableRange = "
1066               "::mlir::MutableOperandRange(getOperation(), "
1067               "range.first, range.second";
1068       if (attrSizedOperands)
1069         body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
1070              << "u, *getOperation()->getAttrDictionary().getNamed("
1071              << op.getGetterName("operand_segment_sizes") << "AttrName()))";
1072       body << ");\n";
1073 
1074       // If this operand is a nested variadic, we split the range into a
1075       // MutableOperandRangeRange that provides a range over all of the
1076       // sub-ranges.
1077       if (operand.isVariadicOfVariadic()) {
1078         //
1079         body << "  return "
1080                 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
1081              << op.getGetterName(
1082                     operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
1083              << "AttrName()));\n";
1084       } else {
1085         // Otherwise, we use the full range directly.
1086         body << "  return mutableRange;\n";
1087       }
1088     }
1089   }
1090 }
1091 
1092 void OpEmitter::genNamedResultGetters() {
1093   const int numResults = op.getNumResults();
1094   const int numVariadicResults = op.getNumVariableLengthResults();
1095   const int numNormalResults = numResults - numVariadicResults;
1096 
1097   // If we have more than one variadic results, we need more complicated logic
1098   // to calculate the value range for each result.
1099 
1100   const auto *sameVariadicSize =
1101       op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
1102   const auto *attrSizedResults =
1103       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
1104 
1105   if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
1106     PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
1107                                  "specification over their sizes");
1108   }
1109 
1110   if (numVariadicResults < 2 && attrSizedResults) {
1111     PrintFatalError(op.getLoc(), "op must have at least two variadic results "
1112                                  "to use 'AttrSizedResultSegments' trait");
1113   }
1114 
1115   if (attrSizedResults && sameVariadicSize) {
1116     PrintFatalError(op.getLoc(),
1117                     "op cannot have both 'AttrSizedResultSegments' and "
1118                     "'SameVariadicResultSize' traits");
1119   }
1120 
1121   // Build the initializer string for the result segment size attribute.
1122   std::string attrSizeInitCode;
1123   if (attrSizedResults) {
1124     std::string attr = op.getGetterName("result_segment_sizes");
1125     attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str();
1126   }
1127 
1128   generateValueRangeStartAndEnd(
1129       opClass, "getODSResultIndexAndLength", numVariadicResults,
1130       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
1131       attrSizeInitCode, op.getResults());
1132 
1133   auto *m =
1134       opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
1135                         MethodParameter("unsigned", "index"));
1136   ERROR_IF_PRUNED(m, "getODSResults", op);
1137   m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
1138                        "getODSResultIndexAndLength(index)");
1139 
1140   for (int i = 0; i != numResults; ++i) {
1141     const auto &result = op.getResult(i);
1142     if (result.name.empty())
1143       continue;
1144     for (StringRef name : op.getGetterNames(result.name)) {
1145       if (result.isOptional()) {
1146         m = opClass.addMethod("::mlir::Value", name);
1147         ERROR_IF_PRUNED(m, name, op);
1148         m->body()
1149             << "  auto results = getODSResults(" << i << ");\n"
1150             << "  return results.empty() ? ::mlir::Value() : *results.begin();";
1151       } else if (result.isVariadic()) {
1152         m = opClass.addMethod("::mlir::Operation::result_range", name);
1153         ERROR_IF_PRUNED(m, name, op);
1154         m->body() << "  return getODSResults(" << i << ");";
1155       } else {
1156         m = opClass.addMethod("::mlir::Value", name);
1157         ERROR_IF_PRUNED(m, name, op);
1158         m->body() << "  return *getODSResults(" << i << ").begin();";
1159       }
1160     }
1161   }
1162 }
1163 
1164 void OpEmitter::genNamedRegionGetters() {
1165   unsigned numRegions = op.getNumRegions();
1166   for (unsigned i = 0; i < numRegions; ++i) {
1167     const auto &region = op.getRegion(i);
1168     if (region.name.empty())
1169       continue;
1170 
1171     for (StringRef name : op.getGetterNames(region.name)) {
1172       // Generate the accessors for a variadic region.
1173       if (region.isVariadic()) {
1174         auto *m =
1175             opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
1176         ERROR_IF_PRUNED(m, name, op);
1177         m->body() << formatv("  return (*this)->getRegions().drop_front({0});",
1178                              i);
1179         continue;
1180       }
1181 
1182       auto *m = opClass.addMethod("::mlir::Region &", name);
1183       ERROR_IF_PRUNED(m, name, op);
1184       m->body() << formatv("  return (*this)->getRegion({0});", i);
1185     }
1186   }
1187 }
1188 
1189 void OpEmitter::genNamedSuccessorGetters() {
1190   unsigned numSuccessors = op.getNumSuccessors();
1191   for (unsigned i = 0; i < numSuccessors; ++i) {
1192     const NamedSuccessor &successor = op.getSuccessor(i);
1193     if (successor.name.empty())
1194       continue;
1195 
1196     for (StringRef name : op.getGetterNames(successor.name)) {
1197       // Generate the accessors for a variadic successor list.
1198       if (successor.isVariadic()) {
1199         auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
1200         ERROR_IF_PRUNED(m, name, op);
1201         m->body() << formatv(
1202             "  return {std::next((*this)->successor_begin(), {0}), "
1203             "(*this)->successor_end()};",
1204             i);
1205         continue;
1206       }
1207 
1208       auto *m = opClass.addMethod("::mlir::Block *", name);
1209       ERROR_IF_PRUNED(m, name, op);
1210       m->body() << formatv("  return (*this)->getSuccessor({0});", i);
1211     }
1212   }
1213 }
1214 
1215 static bool canGenerateUnwrappedBuilder(Operator &op) {
1216   // If this op does not have native attributes at all, return directly to avoid
1217   // redefining builders.
1218   if (op.getNumNativeAttributes() == 0)
1219     return false;
1220 
1221   bool canGenerate = false;
1222   // We are generating builders that take raw values for attributes. We need to
1223   // make sure the native attributes have a meaningful "unwrapped" value type
1224   // different from the wrapped mlir::Attribute type to avoid redefining
1225   // builders. This checks for the op has at least one such native attribute.
1226   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1227     NamedAttribute &namedAttr = op.getAttribute(i);
1228     if (canUseUnwrappedRawValue(namedAttr.attr)) {
1229       canGenerate = true;
1230       break;
1231     }
1232   }
1233   return canGenerate;
1234 }
1235 
1236 static bool canInferType(Operator &op) {
1237   return op.getTrait("::mlir::InferTypeOpInterface::Trait");
1238 }
1239 
/// Emits static `build` methods whose parameter lists come from
/// buildParamList. For each attribute-parameter kind (always wrapped
/// attributes; additionally unwrapped raw values when
/// canGenerateUnwrappedBuilder(op) holds), up to three builders are emitted:
///  - separate result-type parameters;
///  - no result-type parameters, with types obtained from
///    `inferReturnTypes` (only for InferTypeOpInterface ops without regions);
///  - a single collective `resultTypes` parameter.
void OpEmitter::genSeparateArgParamBuilder() {
  SmallVector<AttrParamKind, 2> attrBuilderType;
  attrBuilderType.push_back(AttrParamKind::WrappedAttr);
  if (canGenerateUnwrappedBuilder(op))
    attrBuilderType.push_back(AttrParamKind::UnwrappedValue);

  // Emit with separate builders with or without unwrapped attributes and/or
  // inferring result type.
  // `paramKind` controls how result types appear in the signature;
  // `inferType` replaces explicit result types with a call to
  // `inferReturnTypes` in the generated body.
  auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
                  bool inferType) {
    SmallVector<MethodParameter> paramList;
    SmallVector<std::string, 4> resultNames;
    llvm::StringSet<> inferredAttributes;
    buildParamList(paramList, inferredAttributes, resultNames, paramKind,
                   attrType);

    auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
    // If the builder is redundant, skip generating the method.
    if (!m)
      return;
    auto &body = m->body();
    genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                           /*isRawValueAttr=*/attrType ==
                                               AttrParamKind::UnwrappedValue);

    // Push all result types to the operation state

    if (inferType) {
      // Generate builder that infers type too.
      // TODO: Subsume this with general checking if type can be
      // inferred automatically.
      // TODO: Expand to handle regions.
      // Regions are passed as an empty list; the caller below only requests
      // this flavor for ops without regions.
      body << formatv(R"(
        ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
        if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
                      {1}.location, {1}.operands,
                      {1}.attributes.getDictionary({1}.getContext()),
                      /*regions=*/{{}, inferredReturnTypes)))
          {1}.addTypes(inferredReturnTypes);
        else
          ::llvm::report_fatal_error("Failed to infer result type(s).");)",
                      opClass.getClassName(), builderOpState);
      return;
    }

    switch (paramKind) {
    case TypeParamKind::None:
      return;
    case TypeParamKind::Separate:
      // One parameter per result; optional results are only added when the
      // caller passed a non-null type.
      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
        if (op.getResult(i).isOptional())
          body << "  if (" << resultNames[i] << ")\n  ";
        body << "  " << builderOpState << ".addTypes(" << resultNames[i]
             << ");\n";
      }
      return;
    case TypeParamKind::Collective: {
      // All result types arrive in `resultTypes`; assert its size is
      // consistent with the declared results before adding them.
      int numResults = op.getNumResults();
      int numVariadicResults = op.getNumVariableLengthResults();
      int numNonVariadicResults = numResults - numVariadicResults;
      bool hasVariadicResult = numVariadicResults != 0;

      // Avoid emitting "resultTypes.size() >= 0u" which is always true.
      if (!(hasVariadicResult && numNonVariadicResults == 0))
        body << "  "
             << "assert(resultTypes.size() "
             << (hasVariadicResult ? ">=" : "==") << " "
             << numNonVariadicResults
             << "u && \"mismatched number of results\");\n";
      body << "  " << builderOpState << ".addTypes(resultTypes);\n";
    }
      return;
    }
    llvm_unreachable("unhandled TypeParamKind");
  };

  // Some of the build methods generated here may be ambiguous, but TableGen's
  // ambiguous function detection will elide those ones.
  for (auto attrType : attrBuilderType) {
    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
    if (canInferType(op) && op.getNumRegions() == 0)
      emit(attrType, TypeParamKind::None, /*inferType=*/true);
    emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
  }
}
1325 
1326 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1327   int numResults = op.getNumResults();
1328 
1329   // Signature
1330   SmallVector<MethodParameter> paramList;
1331   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1332   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1333   paramList.emplace_back("::mlir::ValueRange", "operands");
1334   // Provide default value for `attributes` when its the last parameter
1335   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1336   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1337                          "attributes", attributesDefaultValue);
1338   if (op.getNumVariadicRegions())
1339     paramList.emplace_back("unsigned", "numRegions");
1340 
1341   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1342   // If the builder is redundant, skip generating the method
1343   if (!m)
1344     return;
1345   auto &body = m->body();
1346 
1347   // Operands
1348   body << "  " << builderOpState << ".addOperands(operands);\n";
1349 
1350   // Attributes
1351   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1352 
1353   // Create the correct number of regions
1354   if (int numRegions = op.getNumRegions()) {
1355     body << llvm::formatv(
1356         "  for (unsigned i = 0; i != {0}; ++i)\n",
1357         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1358     body << "    (void)" << builderOpState << ".addRegion();\n";
1359   }
1360 
1361   // Result types
1362   SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1363   body << "  " << builderOpState << ".addTypes({"
1364        << llvm::join(resultTypes, ", ") << "});\n\n";
1365 }
1366 
1367 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1368   // TODO: Expand to support regions.
1369   SmallVector<MethodParameter> paramList;
1370   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1371   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1372   paramList.emplace_back("::mlir::ValueRange", "operands");
1373   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1374                          "attributes", "{}");
1375   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1376   // If the builder is redundant, skip generating the method
1377   if (!m)
1378     return;
1379   auto &body = m->body();
1380 
1381   int numResults = op.getNumResults();
1382   int numVariadicResults = op.getNumVariableLengthResults();
1383   int numNonVariadicResults = numResults - numVariadicResults;
1384 
1385   int numOperands = op.getNumOperands();
1386   int numVariadicOperands = op.getNumVariableLengthOperands();
1387   int numNonVariadicOperands = numOperands - numVariadicOperands;
1388 
1389   // Operands
1390   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1391     body << "  assert(operands.size()"
1392          << (numVariadicOperands != 0 ? " >= " : " == ")
1393          << numNonVariadicOperands
1394          << "u && \"mismatched number of parameters\");\n";
1395   body << "  " << builderOpState << ".addOperands(operands);\n";
1396   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1397 
1398   // Create the correct number of regions
1399   if (int numRegions = op.getNumRegions()) {
1400     body << llvm::formatv(
1401         "  for (unsigned i = 0; i != {0}; ++i)\n",
1402         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1403     body << "    (void)" << builderOpState << ".addRegion();\n";
1404   }
1405 
1406   // Result types
1407   body << formatv(R"(
1408   ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1409   if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1410           {1}.location, operands,
1411           {1}.attributes.getDictionary({1}.getContext()),
1412           {1}.regions, inferredReturnTypes))) {{)",
1413                   opClass.getClassName(), builderOpState);
1414   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1415     body << "\n    assert(inferredReturnTypes.size()"
1416          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1417          << "u && \"mismatched number of return types\");";
1418   body << "\n    " << builderOpState << ".addTypes(inferredReturnTypes);";
1419 
1420   body << formatv(R"(
1421   } else {{
1422     ::llvm::report_fatal_error("Failed to infer result type(s).");
1423   })",
1424                   opClass.getClassName(), builderOpState);
1425 }
1426 
1427 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1428   SmallVector<MethodParameter> paramList;
1429   SmallVector<std::string, 4> resultNames;
1430   llvm::StringSet<> inferredAttributes;
1431   buildParamList(paramList, inferredAttributes, resultNames,
1432                  TypeParamKind::None);
1433 
1434   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1435   // If the builder is redundant, skip generating the method
1436   if (!m)
1437     return;
1438   auto &body = m->body();
1439   genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes);
1440 
1441   auto numResults = op.getNumResults();
1442   if (numResults == 0)
1443     return;
1444 
1445   // Push all result types to the operation state
1446   const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1447   std::string resultType =
1448       formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1449   body << "  " << builderOpState << ".addTypes({" << resultType;
1450   for (int i = 1; i != numResults; ++i)
1451     body << ", " << resultType;
1452   body << "});\n\n";
1453 }
1454 
1455 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1456   SmallVector<MethodParameter> paramList;
1457   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1458   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1459   paramList.emplace_back("::mlir::ValueRange", "operands");
1460   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1461                          "attributes", "{}");
1462   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1463   // If the builder is redundant, skip generating the method
1464   if (!m)
1465     return;
1466 
1467   auto &body = m->body();
1468 
1469   // Push all result types to the operation state
1470   std::string resultType;
1471   const auto &namedAttr = op.getAttribute(0);
1472 
1473   body << "  auto attrName = " << op.getGetterName(namedAttr.name)
1474        << "AttrName(" << builderOpState
1475        << ".name);\n"
1476           "  for (auto attr : attributes) {\n"
1477           "    if (attr.getName() != attrName) continue;\n";
1478   if (namedAttr.attr.isTypeAttr()) {
1479     resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()";
1480   } else {
1481     resultType = "attr.getValue().getType()";
1482   }
1483 
1484   // Operands
1485   body << "  " << builderOpState << ".addOperands(operands);\n";
1486 
1487   // Attributes
1488   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1489 
1490   // Result types
1491   SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1492   body << "    " << builderOpState << ".addTypes({"
1493        << llvm::join(resultTypes, ", ") << "});\n";
1494   body << "  }\n";
1495 }
1496 
1497 /// Returns a signature of the builder. Updates the context `fctx` to enable
1498 /// replacement of $_builder and $_state in the body.
1499 static SmallVector<MethodParameter>
1500 getBuilderSignature(const Builder &builder) {
1501   ArrayRef<Builder::Parameter> params(builder.getParameters());
1502 
1503   // Inject builder and state arguments.
1504   SmallVector<MethodParameter> arguments;
1505   arguments.reserve(params.size() + 2);
1506   arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
1507   arguments.emplace_back("::mlir::OperationState &", builderOpState);
1508 
1509   for (unsigned i = 0, e = params.size(); i < e; ++i) {
1510     // If no name is provided, generate one.
1511     Optional<StringRef> paramName = params[i].getName();
1512     std::string name =
1513         paramName ? paramName->str() : "odsArg" + std::to_string(i);
1514 
1515     StringRef defaultValue;
1516     if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
1517       defaultValue = *defaultParamValue;
1518 
1519     arguments.emplace_back(params[i].getCppType(), std::move(name),
1520                            defaultValue);
1521   }
1522 
1523   return arguments;
1524 }
1525 
1526 void OpEmitter::genBuilder() {
1527   // Handle custom builders if provided.
1528   for (const Builder &builder : op.getBuilders()) {
1529     SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
1530 
1531     Optional<StringRef> body = builder.getBody();
1532     auto properties = body ? Method::Static : Method::StaticDeclaration;
1533     auto *method =
1534         opClass.addMethod("void", "build", properties, std::move(arguments));
1535     if (body)
1536       ERROR_IF_PRUNED(method, "build", op);
1537 
1538     FmtContext fctx;
1539     fctx.withBuilder(odsBuilder);
1540     fctx.addSubst("_state", builderOpState);
1541     if (body)
1542       method->body() << tgfmt(*body, &fctx);
1543   }
1544 
1545   // Generate default builders that requires all result type, operands, and
1546   // attributes as parameters.
1547   if (op.skipDefaultBuilders())
1548     return;
1549 
1550   // We generate three classes of builders here:
1551   // 1. one having a stand-alone parameter for each operand / attribute, and
1552   genSeparateArgParamBuilder();
1553   // 2. one having an aggregated parameter for all result types / operands /
1554   //    attributes, and
1555   genCollectiveParamBuilder();
1556   // 3. one having a stand-alone parameter for each operand and attribute,
1557   //    use the first operand or attribute's type as all result types
1558   //    to facilitate different call patterns.
1559   if (op.getNumVariableLengthResults() == 0) {
1560     if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1561       genUseOperandAsResultTypeSeparateParamBuilder();
1562       genUseOperandAsResultTypeCollectiveParamBuilder();
1563     }
1564     if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1565       genUseAttrAsResultTypeBuilder();
1566   }
1567 }
1568 
1569 void OpEmitter::genCollectiveParamBuilder() {
1570   int numResults = op.getNumResults();
1571   int numVariadicResults = op.getNumVariableLengthResults();
1572   int numNonVariadicResults = numResults - numVariadicResults;
1573 
1574   int numOperands = op.getNumOperands();
1575   int numVariadicOperands = op.getNumVariableLengthOperands();
1576   int numNonVariadicOperands = numOperands - numVariadicOperands;
1577 
1578   SmallVector<MethodParameter> paramList;
1579   paramList.emplace_back("::mlir::OpBuilder &", "");
1580   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1581   paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1582   paramList.emplace_back("::mlir::ValueRange", "operands");
1583   // Provide default value for `attributes` when its the last parameter
1584   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1585   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1586                          "attributes", attributesDefaultValue);
1587   if (op.getNumVariadicRegions())
1588     paramList.emplace_back("unsigned", "numRegions");
1589 
1590   auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1591   // If the builder is redundant, skip generating the method
1592   if (!m)
1593     return;
1594   auto &body = m->body();
1595 
1596   // Operands
1597   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1598     body << "  assert(operands.size()"
1599          << (numVariadicOperands != 0 ? " >= " : " == ")
1600          << numNonVariadicOperands
1601          << "u && \"mismatched number of parameters\");\n";
1602   body << "  " << builderOpState << ".addOperands(operands);\n";
1603 
1604   // Attributes
1605   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1606 
1607   // Create the correct number of regions
1608   if (int numRegions = op.getNumRegions()) {
1609     body << llvm::formatv(
1610         "  for (unsigned i = 0; i != {0}; ++i)\n",
1611         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1612     body << "    (void)" << builderOpState << ".addRegion();\n";
1613   }
1614 
1615   // Result types
1616   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1617     body << "  assert(resultTypes.size()"
1618          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1619          << "u && \"mismatched number of return types\");\n";
1620   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1621 
1622   // Generate builder that infers type too.
1623   // TODO: Expand to handle successors.
1624   if (canInferType(op) && op.getNumSuccessors() == 0)
1625     genInferredTypeCollectiveParamBuilder();
1626 }
1627 
/// Appends the parameters of a generated `build` method to `paramList`.
///
/// Parameter order: `odsBuilder`, the operation state, the result type
/// parameter(s) selected by `typeParamKind`, one parameter per operand and per
/// non-inferred attribute (in ODS argument order), one per successor, and one
/// `<name>Count` parameter per variadic region.
///
/// \param paramList          parameters are appended to this list.
/// \param inferredAttributes names of attributes computed by the builder
///                           itself (the segment size attributes of
///                           variadic-of-variadic operands) are inserted here;
///                           attributes in this set get no parameter.
/// \param resultTypeNames    cleared, then filled with the names of the result
///                           type parameters (stays empty for
///                           TypeParamKind::None).
/// \param typeParamKind      whether result types are omitted, passed one
///                           parameter per result, or passed as a single
///                           `resultTypes` TypeRange.
/// \param attrParamKind      whether attributes are passed as their storage
///                           type or, where possible, as unwrapped raw values.
///                           Default values are only attached in the latter
///                           case.
void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                               llvm::StringSet<> &inferredAttributes,
                               SmallVectorImpl<std::string> &resultTypeNames,
                               TypeParamKind typeParamKind,
                               AttrParamKind attrParamKind) {
  resultTypeNames.clear();
  auto numResults = op.getNumResults();
  resultTypeNames.reserve(numResults);

  paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
  paramList.emplace_back("::mlir::OperationState &", builderOpState);

  switch (typeParamKind) {
  case TypeParamKind::None:
    break;
  case TypeParamKind::Separate: {
    // Add parameters for all return types
    for (int i = 0; i < numResults; ++i) {
      const auto &result = op.getResult(i);
      std::string resultName = std::string(result.name);
      // Unnamed results get a positional name.
      if (resultName.empty())
        resultName = std::string(formatv("resultType{0}", i));

      StringRef type =
          result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";

      paramList.emplace_back(type, resultName, result.isOptional());
      resultTypeNames.emplace_back(std::move(resultName));
    }
  } break;
  case TypeParamKind::Collective: {
    paramList.emplace_back("::mlir::TypeRange", "resultTypes");
    resultTypeNames.push_back("resultTypes");
  } break;
  }

  // Add parameters for all arguments (operands and attributes).
  int defaultValuedAttrStartIndex = op.getNumArgs();
  // Successors and variadic regions go at the end of the parameter list, so no
  // default arguments are possible.
  bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
  if (attrParamKind == AttrParamKind::UnwrappedValue && !hasTrailingParams) {
    // Calculate the start index from which we can attach default values in the
    // builder declaration. C++ default arguments must form a suffix of the
    // parameter list, so walk backwards and stop at the first argument that
    // cannot carry a default.
    for (int i = op.getNumArgs() - 1; i >= 0; --i) {
      auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
      if (!namedAttr || !namedAttr->attr.hasDefaultValue())
        break;

      if (!canUseUnwrappedRawValue(namedAttr->attr))
        break;

      // Creating an APInt requires us to provide bitwidth, value, and
      // signedness, which is complicated compared to others. Similarly
      // for APFloat.
      // TODO: Adjust the 'returnType' field of such attributes
      // to support them.
      StringRef retType = namedAttr->attr.getReturnType();
      if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
        break;

      defaultValuedAttrStartIndex = i;
    }
  }

  /// Collect any inferred attributes.
  for (const NamedTypeConstraint &operand : op.getOperands()) {
    if (operand.isVariadicOfVariadic()) {
      inferredAttributes.insert(
          operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
    }
  }

  // `numOperands` counts operands seen so far and is used as the operand index
  // passed to getArgumentName.
  for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
    Argument arg = op.getArg(i);
    if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
      StringRef type;
      if (operand->isVariadicOfVariadic())
        type = "::llvm::ArrayRef<::mlir::ValueRange>";
      else if (operand->isVariadic())
        type = "::mlir::ValueRange";
      else
        type = "::mlir::Value";

      paramList.emplace_back(type, getArgumentName(op, numOperands++),
                             operand->isOptional());
      continue;
    }
    const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
    const Attribute &attr = namedAttr.attr;

    // inferred attributes don't need to be added to the param list.
    if (inferredAttributes.contains(namedAttr.name))
      continue;

    StringRef type;
    switch (attrParamKind) {
    case AttrParamKind::WrappedAttr:
      type = attr.getStorageType();
      break;
    case AttrParamKind::UnwrappedValue:
      // Fall back to the storage type when no raw value form is usable.
      if (canUseUnwrappedRawValue(attr))
        type = attr.getReturnType();
      else
        type = attr.getStorageType();
      break;
    }

    // Attach default value if requested and possible.
    std::string defaultValue;
    if (attrParamKind == AttrParamKind::UnwrappedValue &&
        i >= defaultValuedAttrStartIndex) {
      defaultValue += attr.getDefaultValue();
    }
    paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue),
                           attr.isOptional());
  }

  /// Insert parameters for each successor.
  for (const NamedSuccessor &succ : op.getSuccessors()) {
    StringRef type =
        succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
    paramList.emplace_back(type, succ.name);
  }

  /// Insert parameters for variadic regions.
  for (const NamedRegion &region : op.getRegions())
    if (region.isVariadic())
      paramList.emplace_back("unsigned",
                             llvm::formatv("{0}Count", region.name).str());
}
1759 
1760 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
1761     MethodBody &body, llvm::StringSet<> &inferredAttributes,
1762     bool isRawValueAttr) {
1763   // Push all operands to the result.
1764   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
1765     std::string argName = getArgumentName(op, i);
1766     NamedTypeConstraint &operand = op.getOperand(i);
1767     if (operand.constraint.isVariadicOfVariadic()) {
1768       body << "  for (::mlir::ValueRange range : " << argName << ")\n   "
1769            << builderOpState << ".addOperands(range);\n";
1770 
1771       // Add the segment attribute.
1772       body << "  {\n"
1773            << "    ::llvm::SmallVector<int32_t> rangeSegments;\n"
1774            << "    for (::mlir::ValueRange range : " << argName << ")\n"
1775            << "      rangeSegments.push_back(range.size());\n"
1776            << "    " << builderOpState << ".addAttribute("
1777            << op.getGetterName(
1778                   operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
1779            << "AttrName(" << builderOpState << ".name), " << odsBuilder
1780            << ".getI32TensorAttr(rangeSegments));"
1781            << "  }\n";
1782       continue;
1783     }
1784 
1785     if (operand.isOptional())
1786       body << "  if (" << argName << ")\n  ";
1787     body << "  " << builderOpState << ".addOperands(" << argName << ");\n";
1788   }
1789 
1790   // If the operation has the operand segment size attribute, add it here.
1791   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1792     std::string sizes = op.getGetterName("operand_segment_sizes");
1793     body << "  " << builderOpState << ".addAttribute(" << sizes << "AttrName("
1794          << builderOpState << ".name), "
1795          << "odsBuilder.getI32VectorAttr({";
1796     interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1797       const NamedTypeConstraint &operand = op.getOperand(i);
1798       if (!operand.isVariableLength()) {
1799         body << "1";
1800         return;
1801       }
1802 
1803       std::string operandName = getArgumentName(op, i);
1804       if (operand.isOptional()) {
1805         body << "(" << operandName << " ? 1 : 0)";
1806       } else if (operand.isVariadicOfVariadic()) {
1807         body << llvm::formatv(
1808             "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
1809             "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
1810             "range.size(); }))",
1811             operandName);
1812       } else {
1813         body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
1814       }
1815     });
1816     body << "}));\n";
1817   }
1818 
1819   // Push all attributes to the result.
1820   for (const auto &namedAttr : op.getAttributes()) {
1821     auto &attr = namedAttr.attr;
1822     if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
1823       continue;
1824 
1825     bool emitNotNullCheck = attr.isOptional();
1826     if (emitNotNullCheck)
1827       body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
1828 
1829     if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
1830       // If this is a raw value, then we need to wrap it in an Attribute
1831       // instance.
1832       FmtContext fctx;
1833       fctx.withBuilder("odsBuilder");
1834 
1835       std::string builderTemplate = std::string(attr.getConstBuilderTemplate());
1836 
1837       // For StringAttr, its constant builder call will wrap the input in
1838       // quotes, which is correct for normal string literals, but incorrect
1839       // here given we use function arguments. So we need to strip the
1840       // wrapping quotes.
1841       if (StringRef(builderTemplate).contains("\"$0\""))
1842         builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
1843 
1844       std::string value =
1845           std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
1846       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
1847                       builderOpState, op.getGetterName(namedAttr.name), value);
1848     } else {
1849       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
1850                       builderOpState, op.getGetterName(namedAttr.name),
1851                       namedAttr.name);
1852     }
1853     if (emitNotNullCheck)
1854       body << "  }\n";
1855   }
1856 
1857   // Create the correct number of regions.
1858   for (const NamedRegion &region : op.getRegions()) {
1859     if (region.isVariadic())
1860       body << formatv("  for (unsigned i = 0; i < {0}Count; ++i)\n  ",
1861                       region.name);
1862 
1863     body << "  (void)" << builderOpState << ".addRegion();\n";
1864   }
1865 
1866   // Push all successors to the result.
1867   for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
1868     body << formatv("  {0}.addSuccessors({1});\n", builderOpState,
1869                     namedSuccessor.name);
1870   }
1871 }
1872 
1873 void OpEmitter::genCanonicalizerDecls() {
1874   bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
1875   if (hasCanonicalizeMethod) {
1876     // static LogicResult FooOp::
1877     // canonicalize(FooOp op, PatternRewriter &rewriter);
1878     SmallVector<MethodParameter> paramList;
1879     paramList.emplace_back(op.getCppClassName(), "op");
1880     paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
1881     auto *m = opClass.declareStaticMethod("::mlir::LogicalResult",
1882                                           "canonicalize", std::move(paramList));
1883     ERROR_IF_PRUNED(m, "canonicalize", op);
1884   }
1885 
1886   // We get a prototype for 'getCanonicalizationPatterns' if requested directly
1887   // or if using a 'canonicalize' method.
1888   bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
1889   if (!hasCanonicalizeMethod && !hasCanonicalizer)
1890     return;
1891 
1892   // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
1893   // method, but not implementing 'getCanonicalizationPatterns' manually.
1894   bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
1895 
1896   // Add a signature for getCanonicalizationPatterns if implemented by the
1897   // dialect or if synthesized to call 'canonicalize'.
1898   SmallVector<MethodParameter> paramList;
1899   paramList.emplace_back("::mlir::RewritePatternSet &", "results");
1900   paramList.emplace_back("::mlir::MLIRContext *", "context");
1901   auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
1902   auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
1903                                    std::move(paramList));
1904 
1905   // If synthesizing the method, fill it it.
1906   if (hasBody) {
1907     ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
1908     method->body() << "  results.add(canonicalize);\n";
1909   }
1910 }
1911 
1912 void OpEmitter::genFolderDecls() {
1913   bool hasSingleResult =
1914       op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
1915 
1916   if (def.getValueAsBit("hasFolder")) {
1917     if (hasSingleResult) {
1918       auto *m = opClass.declareMethod(
1919           "::mlir::OpFoldResult", "fold",
1920           MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
1921       ERROR_IF_PRUNED(m, "operands", op);
1922     } else {
1923       SmallVector<MethodParameter> paramList;
1924       paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
1925       paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
1926                              "results");
1927       auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
1928                                       std::move(paramList));
1929       ERROR_IF_PRUNED(m, "fold", op);
1930     }
1931   }
1932 }
1933 
1934 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
1935   Interface interface = opTrait->getInterface();
1936 
1937   // Get the set of methods that should always be declared.
1938   auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
1939   llvm::StringSet<> alwaysDeclaredMethods;
1940   alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
1941                                alwaysDeclaredMethodsVec.end());
1942 
1943   for (const InterfaceMethod &method : interface.getMethods()) {
1944     // Don't declare if the method has a body.
1945     if (method.getBody())
1946       continue;
1947     // Don't declare if the method has a default implementation and the op
1948     // didn't request that it always be declared.
1949     if (method.getDefaultImplementation() &&
1950         !alwaysDeclaredMethods.count(method.getName()))
1951       continue;
1952     // Interface methods are allowed to overlap with existing methods, so don't
1953     // check if pruned.
1954     (void)genOpInterfaceMethod(method);
1955   }
1956 }
1957 
1958 Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1959                                         bool declaration) {
1960   SmallVector<MethodParameter> paramList;
1961   for (const InterfaceMethod::Argument &arg : method.getArguments())
1962     paramList.emplace_back(arg.type, arg.name);
1963 
1964   auto props = (method.isStatic() ? Method::Static : Method::None) |
1965                (declaration ? Method::Declaration : Method::None);
1966   return opClass.addMethod(method.getReturnType(), method.getName(), props,
1967                            std::move(paramList));
1968 }
1969 
1970 void OpEmitter::genOpInterfaceMethods() {
1971   for (const auto &trait : op.getTraits()) {
1972     if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
1973       if (opTrait->shouldDeclareMethods())
1974         genOpInterfaceMethods(opTrait);
1975   }
1976 }
1977 
1978 void OpEmitter::genSideEffectInterfaceMethods() {
1979   enum EffectKind { Operand, Result, Symbol, Static };
1980   struct EffectLocation {
1981     /// The effect applied.
1982     SideEffect effect;
1983 
1984     /// The index if the kind is not static.
1985     unsigned index;
1986 
1987     /// The kind of the location.
1988     unsigned kind;
1989   };
1990 
1991   StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
1992   auto resolveDecorators = [&](Operator::var_decorator_range decorators,
1993                                unsigned index, unsigned kind) {
1994     for (auto decorator : decorators)
1995       if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
1996         opClass.addTrait(effect->getInterfaceTrait());
1997         interfaceEffects[effect->getBaseEffectName()].push_back(
1998             EffectLocation{*effect, index, kind});
1999       }
2000   };
2001 
2002   // Collect effects that were specified via:
2003   /// Traits.
2004   for (const auto &trait : op.getTraits()) {
2005     const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
2006     if (!opTrait)
2007       continue;
2008     auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
2009     for (auto decorator : opTrait->getEffects())
2010       effects.push_back(EffectLocation{cast<SideEffect>(decorator),
2011                                        /*index=*/0, EffectKind::Static});
2012   }
2013   /// Attributes and Operands.
2014   for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
2015     Argument arg = op.getArg(i);
2016     if (arg.is<NamedTypeConstraint *>()) {
2017       resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
2018       ++operandIt;
2019       continue;
2020     }
2021     const NamedAttribute *attr = arg.get<NamedAttribute *>();
2022     if (attr->attr.getBaseAttr().isSymbolRefAttr())
2023       resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
2024   }
2025   /// Results.
2026   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2027     resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
2028 
2029   // The code used to add an effect instance.
2030   // {0}: The effect class.
2031   // {1}: Optional value or symbol reference.
2032   // {1}: The resource class.
2033   const char *addEffectCode =
2034       "  effects.emplace_back({0}::get(), {1}{2}::get());\n";
2035 
2036   for (auto &it : interfaceEffects) {
2037     // Generate the 'getEffects' method.
2038     std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
2039                                      "SideEffects::EffectInstance<{0}>> &",
2040                                      it.first())
2041                            .str();
2042     auto *getEffects = opClass.addMethod("void", "getEffects",
2043                                          MethodParameter(type, "effects"));
2044     ERROR_IF_PRUNED(getEffects, "getEffects", op);
2045     auto &body = getEffects->body();
2046 
2047     // Add effect instances for each of the locations marked on the operation.
2048     for (auto &location : it.second) {
2049       StringRef effect = location.effect.getName();
2050       StringRef resource = location.effect.getResource();
2051       if (location.kind == EffectKind::Static) {
2052         // A static instance has no attached value.
2053         body << llvm::formatv(addEffectCode, effect, "", resource).str();
2054       } else if (location.kind == EffectKind::Symbol) {
2055         // A symbol reference requires adding the proper attribute.
2056         const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
2057         std::string argName = op.getGetterName(attr->name);
2058         if (attr->attr.isOptional()) {
2059           body << "  if (auto symbolRef = " << argName << "Attr())\n  "
2060                << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
2061                       .str();
2062         } else {
2063           body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
2064                                 resource)
2065                       .str();
2066         }
2067       } else {
2068         // Otherwise this is an operand/result, so we need to attach the Value.
2069         body << "  for (::mlir::Value value : getODS"
2070              << (location.kind == EffectKind::Operand ? "Operands" : "Results")
2071              << "(" << location.index << "))\n  "
2072              << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
2073       }
2074     }
2075   }
2076 }
2077 
2078 void OpEmitter::genTypeInterfaceMethods() {
2079   if (!op.allResultTypesKnown())
2080     return;
2081   // Generate 'inferReturnTypes' method declaration using the interface method
2082   // declared in 'InferTypeOpInterface' op interface.
2083   const auto *trait =
2084       cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait"));
2085   Interface interface = trait->getInterface();
2086   Method *method = [&]() -> Method * {
2087     for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
2088       if (interfaceMethod.getName() == "inferReturnTypes") {
2089         return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
2090       }
2091     }
2092     assert(0 && "unable to find inferReturnTypes interface method");
2093     return nullptr;
2094   }();
2095   ERROR_IF_PRUNED(method, "inferReturnTypes", op);
2096   auto &body = method->body();
2097   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
2098 
2099   FmtContext fctx;
2100   fctx.withBuilder("odsBuilder");
2101   body << "  ::mlir::Builder odsBuilder(context);\n";
2102 
2103   auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
2104     if (!type.isArg())
2105       return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
2106     auto argIndex = type.getArg();
2107     assert(!op.getArg(argIndex).is<NamedAttribute *>());
2108     auto arg = op.getArgToOperandOrAttribute(argIndex);
2109     if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
2110       return body << "operands[" << arg.operandOrAttributeIndex()
2111                   << "].getType()";
2112     return body << "attributes[" << arg.operandOrAttributeIndex()
2113                 << "].getType()";
2114   };
2115 
2116   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2117     body << "  inferredReturnTypes[" << i << "] = ";
2118     auto types = op.getSameTypeAsResult(i);
2119     emitType(types[0]) << ";\n";
2120     if (types.size() == 1)
2121       continue;
2122     // TODO: We could verify equality here, but skipping that for verification.
2123   }
2124   body << "  return ::mlir::success();";
2125 }
2126 
2127 void OpEmitter::genParser() {
2128   if (!hasStringAttribute(def, "parser") ||
2129       hasStringAttribute(def, "assemblyFormat"))
2130     return;
2131 
2132   SmallVector<MethodParameter> paramList;
2133   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
2134   paramList.emplace_back("::mlir::OperationState &", "result");
2135   auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
2136                                          std::move(paramList));
2137   ERROR_IF_PRUNED(method, "parse", op);
2138 
2139   FmtContext fctx;
2140   fctx.addSubst("cppClass", opClass.getClassName());
2141   auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
2142   method->body() << "  " << tgfmt(parser, &fctx);
2143 }
2144 
2145 void OpEmitter::genPrinter() {
2146   if (hasStringAttribute(def, "assemblyFormat"))
2147     return;
2148 
2149   auto *valueInit = def.getValueInit("printer");
2150   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
2151   if (!stringInit)
2152     return;
2153 
2154   auto *method = opClass.addMethod(
2155       "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
2156   ERROR_IF_PRUNED(method, "print", op);
2157   FmtContext fctx;
2158   fctx.addSubst("cppClass", opClass.getClassName());
2159   auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
2160   method->body() << "  " << tgfmt(printer, &fctx);
2161 }
2162 
2163 /// Generate verification on native traits requiring attributes.
2164 static void genNativeTraitAttrVerifier(MethodBody &body,
2165                                        const OpOrAdaptorHelper &emitHelper) {
2166   // Check that the variadic segment sizes attribute exists and contains the
2167   // expected number of elements.
2168   //
2169   // {0}: Attribute name.
2170   // {1}: Expected number of elements.
2171   // {2}: "operand" or "result".
2172   // {3}: Attribute getter call.
2173   // {4}: Emit error prefix.
2174   const char *const checkAttrSizedValueSegmentsCode = R"(
2175   {
2176     auto sizeAttr = {3}.dyn_cast<::mlir::DenseIntElementsAttr>();
2177     if (!sizeAttr)
2178       return {4}"missing segment sizes attribute '{0}'");
2179     auto numElements =
2180         sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
2181     if (numElements != {1})
2182       return {4}"'{0}' attribute for specifying {2} segments must have {1} "
2183                 "elements, but got ") << numElements;
2184   }
2185   )";
2186 
2187   // Verify a few traits first so that we can use getODSOperands() and
2188   // getODSResults() in the rest of the verifier.
2189   auto &op = emitHelper.getOp();
2190   for (auto &trait : op.getTraits()) {
2191     auto *t = dyn_cast<tblgen::NativeTrait>(&trait);
2192     if (!t)
2193       continue;
2194     std::string traitName = t->getFullyQualifiedTraitName();
2195     if (traitName == "::mlir::OpTrait::AttrSizedOperandSegments") {
2196       StringRef attrName = "operand_segment_sizes";
2197       body << formatv(checkAttrSizedValueSegmentsCode, attrName,
2198                       op.getNumOperands(), "operand",
2199                       emitHelper.getAttr(attrName),
2200                       emitHelper.emitErrorPrefix());
2201     } else if (traitName == "::mlir::OpTrait::AttrSizedResultSegments") {
2202       StringRef attrName = "result_segment_sizes";
2203       body << formatv(
2204           checkAttrSizedValueSegmentsCode, attrName, op.getNumResults(),
2205           "result", emitHelper.getAttr(attrName), emitHelper.emitErrorPrefix());
2206     }
2207   }
2208 }
2209 
2210 void OpEmitter::genVerifier() {
2211   auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
2212   ERROR_IF_PRUNED(method, "verify", op);
2213   auto &body = method->body();
2214 
2215   OpOrAdaptorHelper emitHelper(op, /*isOp=*/true);
2216   genNativeTraitAttrVerifier(body, emitHelper);
2217 
2218   auto *valueInit = def.getValueInit("verifier");
2219   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
2220   bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
2221   populateSubstitutions(emitHelper, verifyCtx);
2222 
2223   genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
2224   genOperandResultVerifier(body, op.getOperands(), "operand");
2225   genOperandResultVerifier(body, op.getResults(), "result");
2226 
2227   for (auto &trait : op.getTraits()) {
2228     if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
2229       body << tgfmt("  if (!($0))\n    "
2230                     "return emitOpError(\"failed to verify that $1\");\n",
2231                     &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
2232                     t->getSummary());
2233     }
2234   }
2235 
2236   genRegionVerifier(body);
2237   genSuccessorVerifier(body);
2238 
2239   if (hasCustomVerify) {
2240     FmtContext fctx;
2241     fctx.addSubst("cppClass", opClass.getClassName());
2242     auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
2243     body << "  " << tgfmt(printer, &fctx);
2244   } else {
2245     body << "  return ::mlir::success();\n";
2246   }
2247 }
2248 
2249 void OpEmitter::genOperandResultVerifier(MethodBody &body,
2250                                          Operator::const_value_range values,
2251                                          StringRef valueKind) {
2252   // Check that an optional value is at most 1 element.
2253   //
2254   // {0}: Value index.
2255   // {1}: "operand" or "result"
2256   const char *const verifyOptional = R"(
2257     if (valueGroup{0}.size() > 1) {
2258       return emitOpError("{1} group starting at #") << index
2259           << " requires 0 or 1 element, but found " << valueGroup{0}.size();
2260     }
2261 )";
2262   // Check the types of a range of values.
2263   //
2264   // {0}: Value index.
2265   // {1}: Type constraint function.
2266   // {2}: "operand" or "result"
2267   const char *const verifyValues = R"(
2268     for (auto v : valueGroup{0}) {
2269       if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
2270         return ::mlir::failure();
2271     }
2272 )";
2273 
2274   const auto canSkip = [](const NamedTypeConstraint &value) {
2275     return !value.hasPredicate() && !value.isOptional() &&
2276            !value.isVariadicOfVariadic();
2277   };
2278   if (values.empty() || llvm::all_of(values, canSkip))
2279     return;
2280 
2281   FmtContext fctx;
2282 
2283   body << "  {\n    unsigned index = 0; (void)index;\n";
2284 
2285   for (const auto &staticValue : llvm::enumerate(values)) {
2286     const NamedTypeConstraint &value = staticValue.value();
2287 
2288     bool hasPredicate = value.hasPredicate();
2289     bool isOptional = value.isOptional();
2290     bool isVariadicOfVariadic = value.isVariadicOfVariadic();
2291     if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
2292       continue;
2293     body << formatv("    auto valueGroup{2} = getODS{0}{1}s({2});\n",
2294                     // Capitalize the first letter to match the function name
2295                     valueKind.substr(0, 1).upper(), valueKind.substr(1),
2296                     staticValue.index());
2297 
2298     // If the constraint is optional check that the value group has at most 1
2299     // value.
2300     if (isOptional) {
2301       body << formatv(verifyOptional, staticValue.index(), valueKind);
2302     } else if (isVariadicOfVariadic) {
2303       body << formatv(
2304           "    if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
2305           "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
2306           "      return ::mlir::failure();\n",
2307           value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
2308           staticValue.index());
2309     }
2310 
2311     // Otherwise, if there is no predicate there is nothing left to do.
2312     if (!hasPredicate)
2313       continue;
2314     // Emit a loop to check all the dynamic values in the pack.
2315     StringRef constraintFn =
2316         staticVerifierEmitter.getTypeConstraintFn(value.constraint);
2317     body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind);
2318   }
2319 
2320   body << "  }\n";
2321 }
2322 
2323 void OpEmitter::genRegionVerifier(MethodBody &body) {
2324   /// Code to verify a region.
2325   ///
2326   /// {0}: Getter for the regions.
2327   /// {1}: The region constraint.
2328   /// {2}: The region's name.
2329   /// {3}: The region description.
2330   const char *const verifyRegion = R"(
2331     for (auto &region : {0})
2332       if (::mlir::failed({1}(*this, region, "{2}", index++)))
2333         return ::mlir::failure();
2334 )";
2335   /// Get a single region.
2336   ///
2337   /// {0}: The region's index.
2338   const char *const getSingleRegion =
2339       "::llvm::makeMutableArrayRef((*this)->getRegion({0}))";
2340 
2341   // If we have no regions, there is nothing more to do.
2342   const auto canSkip = [](const NamedRegion &region) {
2343     return region.constraint.getPredicate().isNull();
2344   };
2345   auto regions = op.getRegions();
2346   if (regions.empty() && llvm::all_of(regions, canSkip))
2347     return;
2348 
2349   body << "  {\n    unsigned index = 0; (void)index;\n";
2350   for (const auto &it : llvm::enumerate(regions)) {
2351     const auto &region = it.value();
2352     if (canSkip(region))
2353       continue;
2354 
2355     auto getRegion = region.isVariadic()
2356                          ? formatv("{0}()", op.getGetterName(region.name)).str()
2357                          : formatv(getSingleRegion, it.index()).str();
2358     auto constraintFn =
2359         staticVerifierEmitter.getRegionConstraintFn(region.constraint);
2360     body << formatv(verifyRegion, getRegion, constraintFn, region.name);
2361   }
2362   body << "  }\n";
2363 }
2364 
2365 void OpEmitter::genSuccessorVerifier(MethodBody &body) {
2366   const char *const verifySuccessor = R"(
2367     for (auto *successor : {0})
2368       if (::mlir::failed({1}(*this, successor, "{2}", index++)))
2369         return ::mlir::failure();
2370 )";
2371   /// Get a single successor.
2372   ///
2373   /// {0}: The successor's name.
2374   const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())";
2375 
2376   // If we have no successors, there is nothing more to do.
2377   const auto canSkip = [](const NamedSuccessor &successor) {
2378     return successor.constraint.getPredicate().isNull();
2379   };
2380   auto successors = op.getSuccessors();
2381   if (successors.empty() && llvm::all_of(successors, canSkip))
2382     return;
2383 
2384   body << "  {\n    unsigned index = 0; (void)index;\n";
2385 
2386   for (auto &it : llvm::enumerate(successors)) {
2387     const auto &successor = it.value();
2388     if (canSkip(successor))
2389       continue;
2390 
2391     auto getSuccessor =
2392         formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor,
2393                 successor.name, it.index())
2394             .str();
2395     auto constraintFn =
2396         staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint);
2397     body << formatv(verifySuccessor, getSuccessor, constraintFn,
2398                     successor.name);
2399   }
2400   body << "  }\n";
2401 }
2402 
2403 /// Add a size count trait to the given operation class.
2404 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2405                               int numTotal, int numVariadic) {
2406   if (numVariadic != 0) {
2407     if (numTotal == numVariadic)
2408       opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2409     else
2410       opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2411                        Twine(numTotal - numVariadic) + ">::Impl");
2412     return;
2413   }
2414   switch (numTotal) {
2415   case 0:
2416     opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2417     break;
2418   case 1:
2419     opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2420     break;
2421   default:
2422     opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2423                      ">::Impl");
2424     break;
2425   }
2426 }
2427 
2428 void OpEmitter::genTraits() {
2429   // Add region size trait.
2430   unsigned numRegions = op.getNumRegions();
2431   unsigned numVariadicRegions = op.getNumVariadicRegions();
2432   addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2433 
2434   // Add result size traits.
2435   int numResults = op.getNumResults();
2436   int numVariadicResults = op.getNumVariableLengthResults();
2437   addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2438 
2439   // For single result ops with a known specific type, generate a OneTypedResult
2440   // trait.
2441   if (numResults == 1 && numVariadicResults == 0) {
2442     auto cppName = op.getResults().begin()->constraint.getCPPClassName();
2443     opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
2444   }
2445 
2446   // Add successor size trait.
2447   unsigned numSuccessors = op.getNumSuccessors();
2448   unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2449   addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2450 
2451   // Add variadic size trait and normal op traits.
2452   int numOperands = op.getNumOperands();
2453   int numVariadicOperands = op.getNumVariableLengthOperands();
2454 
2455   // Add operand size trait.
2456   if (numVariadicOperands != 0) {
2457     if (numOperands == numVariadicOperands)
2458       opClass.addTrait("::mlir::OpTrait::VariadicOperands");
2459     else
2460       opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
2461                        Twine(numOperands - numVariadicOperands) + ">::Impl");
2462   } else {
2463     switch (numOperands) {
2464     case 0:
2465       opClass.addTrait("::mlir::OpTrait::ZeroOperands");
2466       break;
2467     case 1:
2468       opClass.addTrait("::mlir::OpTrait::OneOperand");
2469       break;
2470     default:
2471       opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
2472                        ">::Impl");
2473       break;
2474     }
2475   }
2476 
2477   // Add the native and interface traits.
2478   for (const auto &trait : op.getTraits()) {
2479     if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
2480       opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2481     else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
2482       opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2483   }
2484 }
2485 
2486 void OpEmitter::genOpNameGetter() {
2487   auto *method = opClass.addStaticMethod<Method::Constexpr>(
2488       "::llvm::StringLiteral", "getOperationName");
2489   ERROR_IF_PRUNED(method, "getOperationName", op);
2490   method->body() << "  return ::llvm::StringLiteral(\"" << op.getOperationName()
2491                  << "\");";
2492 }
2493 
2494 void OpEmitter::genOpAsmInterface() {
2495   // If the user only has one results or specifically added the Asm trait,
2496   // then don't generate it for them. We specifically only handle multi result
2497   // operations, because the name of a single result in the common case is not
2498   // interesting(generally 'result'/'output'/etc.).
2499   // TODO: We could also add a flag to allow operations to opt in to this
2500   // generation, even if they only have a single operation.
2501   int numResults = op.getNumResults();
2502   if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2503     return;
2504 
2505   SmallVector<StringRef, 4> resultNames(numResults);
2506   for (int i = 0; i != numResults; ++i)
2507     resultNames[i] = op.getResultName(i);
2508 
2509   // Don't add the trait if none of the results have a valid name.
2510   if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2511     return;
2512   opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2513 
2514   // Generate the right accessor for the number of results.
2515   auto *method = opClass.addMethod(
2516       "void", "getAsmResultNames",
2517       MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
2518   ERROR_IF_PRUNED(method, "getAsmResultNames", op);
2519   auto &body = method->body();
2520   for (int i = 0; i != numResults; ++i) {
2521     body << "  auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2522          << "  if (!llvm::empty(resultGroup" << i << "))\n"
2523          << "    setNameFn(*resultGroup" << i << ".begin(), \""
2524          << resultNames[i] << "\");\n";
2525   }
2526 }
2527 
2528 //===----------------------------------------------------------------------===//
2529 // OpOperandAdaptor emitter
2530 //===----------------------------------------------------------------------===//
2531 
namespace {
// Helper class to emit Op operand adaptors to an output stream. An operand
// adaptor holds an op's operands (a ValueRange), its attribute dictionary and
// its regions, and provides named getters identical to those defined in the
// Op.
class OpOperandAdaptorEmitter {
public:
  // Writes the C++ declaration of `op`'s adaptor class to `os`.
  static void emitDecl(const Operator &op,
                       StaticVerifierFunctionEmitter &staticVerifierEmitter,
                       raw_ostream &os);
  // Writes the definitions of `op`'s adaptor class members to `os`.
  static void emitDef(const Operator &op,
                      StaticVerifierFunctionEmitter &staticVerifierEmitter,
                      raw_ostream &os);

private:
  // Builds the complete adaptor class: fields, constructors, getters and the
  // verify method. Only reachable through emitDecl/emitDef.
  explicit OpOperandAdaptorEmitter(
      const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The op whose adaptor is being generated.
  const Operator &op;
  // The adaptor class under construction.
  Class adaptor;
  // Emitter of the shared static constraint functions; handed to the
  // attribute verifier when generating verify().
  StaticVerifierFunctionEmitter &staticVerifierEmitter;
};
} // namespace
2558 
2559 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
2560     const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter)
2561     : op(op), adaptor(op.getAdaptorName()),
2562       staticVerifierEmitter(staticVerifierEmitter) {
2563   adaptor.addField("::mlir::ValueRange", "odsOperands");
2564   adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
2565   adaptor.addField("::mlir::RegionRange", "odsRegions");
2566   const auto *attrSizedOperands =
2567       op.getTrait("::m::OpTrait::AttrSizedOperandSegments");
2568   {
2569     SmallVector<MethodParameter> paramList;
2570     paramList.emplace_back("::mlir::ValueRange", "values");
2571     paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2572                            attrSizedOperands ? "" : "nullptr");
2573     paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
2574     auto *constructor = adaptor.addConstructor(std::move(paramList));
2575 
2576     constructor->addMemberInitializer("odsOperands", "values");
2577     constructor->addMemberInitializer("odsAttrs", "attrs");
2578     constructor->addMemberInitializer("odsRegions", "regions");
2579   }
2580 
2581   {
2582     auto *constructor = adaptor.addConstructor(
2583         MethodParameter(op.getCppClassName() + " &", "op"));
2584     constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2585     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2586     constructor->addMemberInitializer("odsRegions", "op->getRegions()");
2587   }
2588 
2589   {
2590     auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands");
2591     ERROR_IF_PRUNED(m, "getOperands", op);
2592     m->body() << "  return odsOperands;";
2593   }
2594   std::string attr = "operand_segment_sizes";
2595   std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, attr);
2596   generateNamedOperandGetters(op, adaptor,
2597                               /*isAdaptor=*/true, sizeAttrInit,
2598                               /*rangeType=*/"::mlir::ValueRange",
2599                               /*rangeBeginCall=*/"odsOperands.begin()",
2600                               /*rangeSizeCall=*/"odsOperands.size()",
2601                               /*getOperandCallPattern=*/"odsOperands[{0}]");
2602 
2603   FmtContext fctx;
2604   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2605 
2606   // Generate named accessor with Attribute return type.
2607   auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
2608                                      Attribute attr) {
2609     auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr");
2610     ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
2611     auto &body = method->body();
2612     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
2613          << "\n  " << attr.getStorageType() << " attr = "
2614          << "odsAttrs.get(\"" << name << "\").";
2615     if (attr.hasDefaultValue() || attr.isOptional())
2616       body << "dyn_cast_or_null<";
2617     else
2618       body << "cast<";
2619     body << attr.getStorageType() << ">();\n";
2620 
2621     if (attr.hasDefaultValue()) {
2622       // Use the default value if attribute is not set.
2623       // TODO: this is inefficient, we are recreating the attribute for every
2624       // call. This should be set instead.
2625       std::string defaultValue = std::string(
2626           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2627       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
2628     }
2629     body << "  return attr;\n";
2630   };
2631 
2632   {
2633     auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
2634     ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
2635     m->body() << "  return odsAttrs;";
2636   }
2637   for (auto &namedAttr : op.getAttributes()) {
2638     const auto &name = namedAttr.name;
2639     const auto &attr = namedAttr.attr;
2640     if (attr.isDerivedAttr())
2641       continue;
2642     for (const auto &emitName : op.getGetterNames(name)) {
2643       emitAttrWithStorageType(name, emitName, attr);
2644       emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
2645     }
2646   }
2647 
2648   unsigned numRegions = op.getNumRegions();
2649   if (numRegions > 0) {
2650     auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions");
2651     ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
2652     m->body() << "  return odsRegions;";
2653   }
2654   for (unsigned i = 0; i < numRegions; ++i) {
2655     const auto &region = op.getRegion(i);
2656     if (region.name.empty())
2657       continue;
2658 
2659     // Generate the accessors for a variadic region.
2660     for (StringRef name : op.getGetterNames(region.name)) {
2661       if (region.isVariadic()) {
2662         auto *m = adaptor.addMethod("::mlir::RegionRange", name);
2663         ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
2664         m->body() << formatv("  return odsRegions.drop_front({0});", i);
2665         continue;
2666       }
2667 
2668       auto *m = adaptor.addMethod("::mlir::Region &", name);
2669       ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
2670       m->body() << formatv("  return *odsRegions[{0}];", i);
2671     }
2672   }
2673 
2674   // Add verification function.
2675   addVerification();
2676   adaptor.finalize();
2677 }
2678 
2679 void OpOperandAdaptorEmitter::addVerification() {
2680   auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify",
2681                                    MethodParameter("::mlir::Location", "loc"));
2682   ERROR_IF_PRUNED(method, "verify", op);
2683   auto &body = method->body();
2684 
2685   OpOrAdaptorHelper emitHelper(op, /*isOp=*/false);
2686   genNativeTraitAttrVerifier(body, emitHelper);
2687   FmtContext verifyCtx;
2688   populateSubstitutions(emitHelper, verifyCtx);
2689 
2690   genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
2691 
2692   body << "  return ::mlir::success();";
2693 }
2694 
2695 void OpOperandAdaptorEmitter::emitDecl(
2696     const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter,
2697     raw_ostream &os) {
2698   OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
2699 }
2700 
2701 void OpOperandAdaptorEmitter::emitDef(
2702     const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter,
2703     raw_ostream &os) {
2704   OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
2705 }
2706 
2707 // Emits the opcode enum and op classes.
2708 static void emitOpClasses(const RecordKeeper &recordKeeper,
2709                           const std::vector<Record *> &defs, raw_ostream &os,
2710                           bool emitDecl) {
2711   // First emit forward declaration for each class, this allows them to refer
2712   // to each others in traits for example.
2713   if (emitDecl) {
2714     os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
2715     os << "#undef GET_OP_FWD_DEFINES\n";
2716     for (auto *def : defs) {
2717       Operator op(*def);
2718       NamespaceEmitter emitter(os, op.getCppNamespace());
2719       os << "class " << op.getCppClassName() << ";\n";
2720     }
2721     os << "#endif\n\n";
2722   }
2723 
2724   IfDefScope scope("GET_OP_CLASSES", os);
2725   if (defs.empty())
2726     return;
2727 
2728   // Generate all of the locally instantiated methods first.
2729   StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
2730   os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
2731   staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
2732 
2733   for (auto *def : defs) {
2734     Operator op(*def);
2735     if (emitDecl) {
2736       {
2737         NamespaceEmitter emitter(os, op.getCppNamespace());
2738         os << formatv(opCommentHeader, op.getQualCppClassName(),
2739                       "declarations");
2740         OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
2741         OpEmitter::emitDecl(op, os, staticVerifierEmitter);
2742       }
2743       // Emit the TypeID explicit specialization to have a single definition.
2744       if (!op.getCppNamespace().empty())
2745         os << "DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
2746            << "::" << op.getCppClassName() << ")\n\n";
2747     } else {
2748       {
2749         NamespaceEmitter emitter(os, op.getCppNamespace());
2750         os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2751         OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
2752         OpEmitter::emitDef(op, os, staticVerifierEmitter);
2753       }
2754       // Emit the TypeID explicit specialization to have a single definition.
2755       if (!op.getCppNamespace().empty())
2756         os << "DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
2757            << "::" << op.getCppClassName() << ")\n\n";
2758     }
2759   }
2760 }
2761 
2762 // Emits a comma-separated list of the ops.
2763 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2764   IfDefScope scope("GET_OP_LIST", os);
2765 
2766   interleave(
2767       // TODO: We are constructing the Operator wrapper instance just for
2768       // getting it's qualified class name here. Reduce the overhead by having a
2769       // lightweight version of Operator class just for that purpose.
2770       defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
2771       [&os]() { os << ",\n"; });
2772 }
2773 
2774 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2775   emitSourceFileHeader("Op Declarations", os);
2776 
2777   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
2778   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
2779 
2780   return false;
2781 }
2782 
2783 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2784   emitSourceFileHeader("Op Definitions", os);
2785 
2786   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
2787   emitOpList(defs, os);
2788   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
2789 
2790   return false;
2791 }
2792 
2793 static mlir::GenRegistration
2794     genOpDecls("gen-op-decls", "Generate op declarations",
2795                [](const RecordKeeper &records, raw_ostream &os) {
2796                  return emitOpDecls(records, os);
2797                });
2798 
2799 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
2800                                        [](const RecordKeeper &records,
2801                                           raw_ostream &os) {
2802                                          return emitOpDefs(records, os);
2803                                        });
2804