1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
/// Python module prologue emitted at the top of the generated file. It imports
/// the `_ods_common` helpers under private aliases, binds `_ods_ir`, and tries
/// to import the optional hand-written extension module, falling back to
/// `_ods_ext_module = None` when that import fails.
///   {0} is the dialect namespace (the extension module is `_{0}_ops_ext`).
/// NOTE(review): `builtins` is presumably imported so that generated classes
/// can write `@builtins.property` even when an accessor named `property`
/// shadows the builtin inside a class body — confirm.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for the dialect class, decorated with
/// `_ods_cext.register_dialect`:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";
52 
/// Template for the header of an operation class. The class derives from
/// `_ods_ir.OpView`, is registered with `_Dialect`, and is passed through
/// `_ods_extend_opview_class` together with the optional extension module:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs:
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = variadic element (expect a sequence of operands/results)
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
82 
/// Template for a single-element accessor at a statically known position,
/// i.e. one that is not preceded by any variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Every group except the single variable-length one holds exactly one
/// element. The variable group's runtime length is therefore the actual
/// element count minus ({2} - 1), and this element is shifted by that length
/// minus one.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
106 
/// Template for an optional element accessor, used when the optional element
/// is the only variable-length group. Every other group holds exactly one
/// element, so the optional element is present iff the operation has exactly
/// {2} elements. The accessor returns None when there are fewer:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
117 
/// Template for the variadic group accessor in the single variadic group case.
/// Every other group holds exactly one element, so the group's runtime length
/// is the total element count minus ({2} - 1). The group starts at {3}:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
129 
/// First part of the template for an equally-sized variadic group accessor.
/// It binds `start` (position of the accessed element or group) and `pg`
/// (length of each variadic group) from `_ods_equally_sized_accessor`. One of
/// the "second part" templates must follow to emit the return statement:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The elements must be read through `self.operation`. A bare `operation` is
/// not defined inside the generated property and would raise NameError.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
140 
/// Second part of the template for the equally-sized case, accessing a single
/// element at the `start` position bound by the prefix template:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for the equally-sized case, accessing a
/// variadic group of `pg` elements starting at `start`, both bound by the
/// prefix template:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor. The group bounds are read
/// from the `{1}_segment_sizes` attribute by `_ods_segmented_accessor`:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case. It yields the element, or None when the segment is
/// empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
175 
/// Template for an operation attribute getter that wraps the raw attribute in
/// its Python attribute class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter, which returns None
/// when the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute. It returns True if the
/// unit attribute is present and False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
208 
/// Template for a mandatory operation attribute setter. Mandatory attributes
/// cannot be cleared, so assigning None raises ValueError:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute. Setting it to
/// None removes the attribute if it is present:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute. A truthy value adds a
/// `UnitAttr`; None, False or any other falsy value removes the attribute if
/// it is present:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, which
/// removes the attribute from the operation:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor that returns the region at a fixed index:
///   {0} is the name of the accessor;
///   {1} is the index of the region in the op's region list.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
261 
/// Command-line option category for the Python op binding generator.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

/// `-bind-dialect=<name>`: the dialect to generate bindings for.
/// NOTE(review): the op filtering based on this value happens outside this
/// view — confirm how an empty value is handled.
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// Maps the C++ storage type of an attribute (e.g. "::mlir::UnitAttr"-style
/// class names) to the name of the Python class used to wrap it in generated
/// attribute getters.
using AttributeClasses = DenseMap<StringRef, StringRef>;
271 
272 /// Checks whether `str` is a Python keyword.
273 static bool isPythonKeyword(StringRef str) {
274   static llvm::StringSet<> keywords(
275       {"and",   "as",     "assert",   "break", "class",  "continue",
276        "def",   "del",    "elif",     "else",  "except", "finally",
277        "for",   "from",   "global",   "if",    "import", "in",
278        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
279        "raise", "return", "try",      "while", "with",   "yield"});
280   return keywords.contains(str);
281 }
282 
283 /// Checks whether `str` would shadow a generated variable or attribute
284 /// part of the OpView API.
285 static bool isODSReserved(StringRef str) {
286   static llvm::StringSet<> reserved(
287       {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
288        "loc", "verify", "regions", "results", "self", "operation",
289        "DIALECT_NAMESPACE", "OPERATION_NAME"});
290   return str.startswith("_ods_") || str.endswith("_ods") ||
291          reserved.contains(str);
292 }
293 
294 /// Modifies the `name` in a way that it becomes suitable for Python bindings
295 /// (does not change the `name` if it already is suitable) and returns the
296 /// modified version.
297 static std::string sanitizeName(StringRef name) {
298   if (isPythonKeyword(name) || isODSReserved(name))
299     return (name + "_").str();
300   return name.str();
301 }
302 
303 static std::string attrSizedTraitForKind(const char *kind) {
304   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
305                        llvm::StringRef(kind).take_front().upper(),
306                        llvm::StringRef(kind).drop_front());
307 }
308 
309 /// Emits accessors to "elements" of an Op definition. Currently, the supported
310 /// elements are operands and results, indicated by `kind`, which must be either
311 /// `operand` or `result` and is used verbatim in the emitted code.
312 static void emitElementAccessors(
313     const Operator &op, raw_ostream &os, const char *kind,
314     llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
315     llvm::function_ref<int(const Operator &)> getNumElements,
316     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
317         getElement) {
318   assert(llvm::is_contained(
319              llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
320          "unsupported kind");
321 
322   // Traits indicating how to process variadic elements.
323   std::string sameSizeTrait =
324       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
325                     llvm::StringRef(kind).take_front().upper(),
326                     llvm::StringRef(kind).drop_front());
327   std::string attrSizedTrait = attrSizedTraitForKind(kind);
328 
329   unsigned numVariadic = getNumVariadic(op);
330 
331   // If there is only one variadic element group, its size can be inferred from
332   // the total number of elements. If there are none, the generation is
333   // straightforward.
334   if (numVariadic <= 1) {
335     bool seenVariableLength = false;
336     for (int i = 0, e = getNumElements(op); i < e; ++i) {
337       const NamedTypeConstraint &element = getElement(op, i);
338       if (element.isVariableLength())
339         seenVariableLength = true;
340       if (element.name.empty())
341         continue;
342       if (element.isVariableLength()) {
343         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
344                                                  : opOneVariadicTemplate,
345                             sanitizeName(element.name), kind,
346                             getNumElements(op), i);
347       } else if (seenVariableLength) {
348         os << llvm::formatv(opSingleAfterVariableTemplate,
349                             sanitizeName(element.name), kind,
350                             getNumElements(op), i);
351       } else {
352         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
353                             i);
354       }
355     }
356     return;
357   }
358 
359   // Handle the operations where variadic groups have the same size.
360   if (op.getTrait(sameSizeTrait)) {
361     int numPrecedingSimple = 0;
362     int numPrecedingVariadic = 0;
363     for (int i = 0, e = getNumElements(op); i < e; ++i) {
364       const NamedTypeConstraint &element = getElement(op, i);
365       if (!element.name.empty()) {
366         os << llvm::formatv(opVariadicEqualPrefixTemplate,
367                             sanitizeName(element.name), kind, numVariadic,
368                             numPrecedingSimple, numPrecedingVariadic);
369         os << llvm::formatv(element.isVariableLength()
370                                 ? opVariadicEqualVariadicTemplate
371                                 : opVariadicEqualSimpleTemplate,
372                             kind);
373       }
374       if (element.isVariableLength())
375         ++numPrecedingVariadic;
376       else
377         ++numPrecedingSimple;
378     }
379     return;
380   }
381 
382   // Handle the operations where the size of groups (variadic or not) is
383   // provided as an attribute. For non-variadic elements, make sure to return
384   // an element rather than a singleton container.
385   if (op.getTrait(attrSizedTrait)) {
386     for (int i = 0, e = getNumElements(op); i < e; ++i) {
387       const NamedTypeConstraint &element = getElement(op, i);
388       if (element.name.empty())
389         continue;
390       std::string trailing;
391       if (!element.isVariableLength())
392         trailing = "[0]";
393       else if (element.isOptional())
394         trailing = std::string(
395             llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
396       os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
397                           kind, i, trailing);
398     }
399     return;
400   }
401 
402   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
403 }
404 
405 /// Free function helpers accessing Operator components.
406 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
407 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
408   return op.getOperand(i);
409 }
410 static int getNumResults(const Operator &op) { return op.getNumResults(); }
411 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
412   return op.getResult(i);
413 }
414 
415 /// Emits accessors to Op operands.
416 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
417   auto getNumVariadic = [](const Operator &oper) {
418     return oper.getNumVariableLengthOperands();
419   };
420   emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
421                        getOperand);
422 }
423 
424 /// Emits accessors Op results.
425 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
426   auto getNumVariadic = [](const Operator &oper) {
427     return oper.getNumVariableLengthResults();
428   };
429   emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
430                        getResult);
431 }
432 
433 /// Emits accessors to Op attributes.
434 static void emitAttributeAccessors(const Operator &op,
435                                    const AttributeClasses &attributeClasses,
436                                    raw_ostream &os) {
437   for (const auto &namedAttr : op.getAttributes()) {
438     // Skip "derived" attributes because they are just C++ functions that we
439     // don't currently expose.
440     if (namedAttr.attr.isDerivedAttr())
441       continue;
442 
443     if (namedAttr.name.empty())
444       continue;
445 
446     std::string sanitizedName = sanitizeName(namedAttr.name);
447 
448     // Unit attributes are handled specially.
449     if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
450       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
451                           namedAttr.name);
452       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
453                           namedAttr.name);
454       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
455                           namedAttr.name);
456       continue;
457     }
458 
459     // Other kinds of attributes need a mapping to a Python type.
460     if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
461       continue;
462 
463     StringRef pythonType =
464         attributeClasses.lookup(namedAttr.attr.getStorageType());
465     if (namedAttr.attr.isOptional()) {
466       os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
467                           pythonType, namedAttr.name);
468       os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
469                           namedAttr.name);
470       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
471                           namedAttr.name);
472     } else {
473       os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
474                           namedAttr.name);
475       os << llvm::formatv(attributeSetterTemplate, sanitizedName,
476                           namedAttr.name);
477       // Non-optional attributes cannot be deleted.
478     }
479   }
480 }
481 
/// Template for the default auto-generated builder (`__init__`).
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results`, `attributes` and
///       `_ods_successors` (the latter is read by the `build_generic` call, so
///       {1} must define it). {1} is pasted after a 4-space indent, so any
///       continuation lines it contains must carry that indent themselves.
/// `{{}` is formatv's escape for a literal `{`, so the emitted line reads
/// `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";
498 
/// Templates for appending a single element to the operand/result list.
/// Operands are normalized through `_get_op_result_or_value`.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Templates for appending an optional element to the operand/result list.
/// Normally an absent (None) element is skipped. The AttrSizedOperands
/// variant instead appends None itself in that position.
///   {0} is the field name.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Templates for adding a list of elements to the operand/result list. The
/// `extend` variants flatten the elements into the list. The `Pack` variant
/// appends the whole (normalized) list as a single entry.
///   {0} is the field name.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder; the
/// attribute is set only when the argument is not None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder; the
/// attribute is added only when the builder argument is truthy.
///   {0} is the attribute name;
///   {1} is the builder argument name.
/// The second line continues the parenthesized call, so its indentation is not
/// significant to Python.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
547 
548 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
549 /// result types of the given operation.
550 static bool hasSameArgumentAndResultTypes(const Operator &op) {
551   return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
552          op.getNumVariableLengthResults() == 0;
553 }
554 
555 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
556 /// result types of the given operation.
557 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
558   return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
559          op.getNumVariableLengthResults() == 0;
560 }
561 
562 /// Returns true if the InferTypeOpInterface can be used to infer result types
563 /// of the given operation.
564 static bool hasInferTypeInterface(const Operator &op) {
565   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
566          op.getNumRegions() == 0;
567 }
568 
569 /// Returns true if there is a trait or interface that can be used to infer
570 /// result types of the given operation.
571 static bool canInferType(const Operator &op) {
572   return hasSameArgumentAndResultTypes(op) ||
573          hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
574 }
575 
576 /// Populates `builderArgs` with result names if the builder is expected to
577 /// accept them as arguments.
578 static void
579 populateBuilderArgsResults(const Operator &op,
580                            llvm::SmallVectorImpl<std::string> &builderArgs) {
581   if (canInferType(op))
582     return;
583 
584   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
585     std::string name = op.getResultName(i).str();
586     if (name.empty()) {
587       if (op.getNumResults() == 1) {
588         // Special case for one result, make the default name be 'result'
589         // to properly match the built-in result accessor.
590         name = "result";
591       } else {
592         name = llvm::formatv("_gen_res_{0}", i);
593       }
594     }
595     name = sanitizeName(name);
596     builderArgs.push_back(name);
597   }
598 }
599 
600 /// Populates `builderArgs` with the Python-compatible names of builder function
601 /// arguments using intermixed attributes and operands in the same order as they
602 /// appear in the `arguments` field of the op definition. Additionally,
603 /// `operandNames` is populated with names of operands in their order of
604 /// appearance.
605 static void
606 populateBuilderArgs(const Operator &op,
607                     llvm::SmallVectorImpl<std::string> &builderArgs,
608                     llvm::SmallVectorImpl<std::string> &operandNames,
609                     llvm::SmallVectorImpl<std::string> &successorArgNames) {
610 
611   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
612     std::string name = op.getArgName(i).str();
613     if (name.empty())
614       name = llvm::formatv("_gen_arg_{0}", i);
615     name = sanitizeName(name);
616     builderArgs.push_back(name);
617     if (!op.getArg(i).is<NamedAttribute *>())
618       operandNames.push_back(name);
619   }
620 
621   for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
622     NamedSuccessor successor = op.getSuccessor(i);
623     std::string name = std::string(successor.name);
624     if (name.empty())
625       name = llvm::formatv("_gen_successor_{0}", i);
626     name = sanitizeName(name);
627     builderArgs.push_back(name);
628     successorArgNames.push_back(name);
629   }
630 }
631 
632 /// Populates `builderLines` with additional lines that are required in the
633 /// builder to set up operation attributes. `argNames` is expected to contain
634 /// the names of builder arguments that correspond to op arguments, i.e. to the
635 /// operands and attributes in the same order as they appear in the `arguments`
636 /// field.
637 static void
638 populateBuilderLinesAttr(const Operator &op,
639                          llvm::ArrayRef<std::string> argNames,
640                          llvm::SmallVectorImpl<std::string> &builderLines) {
641   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
642     Argument arg = op.getArg(i);
643     auto *attribute = arg.dyn_cast<NamedAttribute *>();
644     if (!attribute)
645       continue;
646 
647     // Unit attributes are handled specially.
648     if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
649       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
650                                            attribute->name, argNames[i]));
651       continue;
652     }
653 
654     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
655                                              ? initOptionalAttributeTemplate
656                                              : initAttributeTemplate,
657                                          attribute->name, argNames[i]));
658   }
659 }
660 
661 /// Populates `builderLines` with additional lines that are required in the
662 /// builder to set up successors. successorArgNames is expected to correspond
663 /// to the Python argument name for each successor on the op.
664 static void populateBuilderLinesSuccessors(
665     const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
666     llvm::SmallVectorImpl<std::string> &builderLines) {
667   if (successorArgNames.empty()) {
668     builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
669     return;
670   }
671 
672   builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
673   for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
674     auto &argName = successorArgNames[i];
675     const NamedSuccessor &successor = op.getSuccessor(i);
676     builderLines.push_back(
677         llvm::formatv(addSuccessorTemplate,
678                       successor.isVariadic() ? "extend" : "append", argName));
679   }
680 }
681 
682 /// Populates `builderLines` with additional lines that are required in the
683 /// builder to set up op operands.
684 static void
685 populateBuilderLinesOperand(const Operator &op,
686                             llvm::ArrayRef<std::string> names,
687                             llvm::SmallVectorImpl<std::string> &builderLines) {
688   bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
689 
690   // For each element, find or generate a name.
691   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
692     const NamedTypeConstraint &element = op.getOperand(i);
693     std::string name = names[i];
694 
695     // Choose the formatting string based on the element kind.
696     llvm::StringRef formatString;
697     if (!element.isVariableLength()) {
698       formatString = singleOperandAppendTemplate;
699     } else if (element.isOptional()) {
700       if (sizedSegments) {
701         formatString = optionalAppendAttrSizedOperandsTemplate;
702       } else {
703         formatString = optionalAppendOperandTemplate;
704       }
705     } else {
706       assert(element.isVariadic() && "unhandled element group type");
707       // If emitting with sizedSegments, then we add the actual list-typed
708       // element. Otherwise, we extend the actual operands.
709       if (sizedSegments) {
710         formatString = multiOperandAppendPackTemplate;
711       } else {
712         formatString = multiOperandAppendTemplate;
713       }
714     }
715 
716     builderLines.push_back(llvm::formatv(formatString.data(), name));
717   }
718 }
719 
/// Python code template for deriving the operation result type from an
/// attribute. It uses the type payload if the attribute is a TypeAttr and the
/// attribute's own type otherwise, and stores the outcome in
/// `_ods_derived_result_type`:
///   - {0} is the name of the attribute from which to derive the types.
/// The lines are unindented; the builder code splits them with
/// appendLineByLine.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";

/// Python code template appending the type {0} to the results list {1} times.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";

/// Python code template for inferring the operation results using the
/// corresponding interface. It reassigns `results` to the inferred types:
///   - {0} is the name of the class for which the types are inferred.
constexpr const char *inferTypeInterfaceTemplate =
    R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
    operands=operands,
    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
    context=_ods_context,
    loc=loc)
)PY";
744 
745 /// Appends the given multiline string as individual strings into
746 /// `builderLines`.
747 static void appendLineByLine(StringRef string,
748                              llvm::SmallVectorImpl<std::string> &builderLines) {
749 
750   std::pair<StringRef, StringRef> split = std::make_pair(string, string);
751   do {
752     split = split.second.split('\n');
753     builderLines.push_back(split.first.str());
754   } while (!split.second.empty());
755 }
756 
757 /// Populates `builderLines` with additional lines that are required in the
758 /// builder to set up op results.
759 static void
760 populateBuilderLinesResult(const Operator &op,
761                            llvm::ArrayRef<std::string> names,
762                            llvm::SmallVectorImpl<std::string> &builderLines) {
763   bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
764 
765   if (hasSameArgumentAndResultTypes(op)) {
766     builderLines.push_back(llvm::formatv(
767         appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
768     return;
769   }
770 
771   if (hasFirstAttrDerivedResultTypes(op)) {
772     const NamedAttribute &firstAttr = op.getAttribute(0);
773     assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
774                                       "from which the type is derived");
775     appendLineByLine(
776         llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
777         builderLines);
778     builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
779                                          "_ods_derived_result_type",
780                                          op.getNumResults()));
781     return;
782   }
783 
784   if (hasInferTypeInterface(op)) {
785     appendLineByLine(
786         llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
787         builderLines);
788     return;
789   }
790 
791   // For each element, find or generate a name.
792   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
793     const NamedTypeConstraint &element = op.getResult(i);
794     std::string name = names[i];
795 
796     // Choose the formatting string based on the element kind.
797     llvm::StringRef formatString;
798     if (!element.isVariableLength()) {
799       formatString = singleResultAppendTemplate;
800     } else if (element.isOptional()) {
801       formatString = optionalAppendResultTemplate;
802     } else {
803       assert(element.isVariadic() && "unhandled element group type");
804       // If emitting with sizedSegments, then we add the actual list-typed
805       // element. Otherwise, we extend the actual operands.
806       if (sizedSegments) {
807         formatString = singleResultAppendTemplate;
808       } else {
809         formatString = multiResultAppendTemplate;
810       }
811     }
812 
813     builderLines.push_back(llvm::formatv(formatString.data(), name));
814   }
815 }
816 
817 /// If the operation has variadic regions, adds a builder argument to specify
818 /// the number of those regions and builder lines to forward it to the generic
819 /// constructor.
820 static void
821 populateBuilderRegions(const Operator &op,
822                        llvm::SmallVectorImpl<std::string> &builderArgs,
823                        llvm::SmallVectorImpl<std::string> &builderLines) {
824   if (op.hasNoVariadicRegions())
825     return;
826 
827   // This is currently enforced when Operator is constructed.
828   assert(op.getNumVariadicRegions() == 1 &&
829          op.getRegion(op.getNumRegions() - 1).isVariadic() &&
830          "expected the last region to be varidic");
831 
832   const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
833   std::string name =
834       ("num_" + region.name.take_front().lower() + region.name.drop_front())
835           .str();
836   builderArgs.push_back(name);
837   builderLines.push_back(
838       llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
839 }
840 
841 /// Emits a default builder constructing an operation from the list of its
842 /// result types, followed by a list of its operands.
843 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
844   // If we are asked to skip default builders, comply.
845   if (op.skipDefaultBuilders())
846     return;
847 
848   llvm::SmallVector<std::string> builderArgs;
849   llvm::SmallVector<std::string> builderLines;
850   llvm::SmallVector<std::string> operandArgNames;
851   llvm::SmallVector<std::string> successorArgNames;
852   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
853                       op.getNumNativeAttributes() + op.getNumSuccessors());
854   populateBuilderArgsResults(op, builderArgs);
855   size_t numResultArgs = builderArgs.size();
856   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
857 
858   populateBuilderLinesOperand(op, operandArgNames, builderLines);
859   populateBuilderLinesAttr(
860       op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
861       builderLines);
862   populateBuilderLinesResult(
863       op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
864       builderLines);
865   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
866   populateBuilderRegions(op, builderArgs, builderLines);
867 
868   builderArgs.push_back("*");
869   builderArgs.push_back("loc=None");
870   builderArgs.push_back("ip=None");
871   os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
872                       llvm::join(builderLines, "\n    "));
873 }
874 
875 static void constructAttributeMapping(const llvm::RecordKeeper &records,
876                                       AttributeClasses &attributeClasses) {
877   for (const llvm::Record *rec :
878        records.getAllDerivedDefinitions("PythonAttr")) {
879     attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
880                                  rec->getValueAsString("pythonType").trim());
881   }
882 }
883 
884 static void emitSegmentSpec(
885     const Operator &op, const char *kind,
886     llvm::function_ref<int(const Operator &)> getNumElements,
887     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
888         getElement,
889     raw_ostream &os) {
890   std::string segmentSpec("[");
891   for (int i = 0, e = getNumElements(op); i < e; ++i) {
892     const NamedTypeConstraint &element = getElement(op, i);
893     if (element.isOptional()) {
894       segmentSpec.append("0,");
895     } else if (element.isVariadic()) {
896       segmentSpec.append("-1,");
897     } else {
898       segmentSpec.append("1,");
899     }
900   }
901   segmentSpec.append("]");
902 
903   os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
904 }
905 
906 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
907   // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
908   // Note that the base OpView class defines this as (0, True).
909   unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
910   os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
911                       op.hasNoVariadicRegions() ? "True" : "False");
912 }
913 
914 /// Emits named accessors to regions.
915 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
916   for (auto en : llvm::enumerate(op.getRegions())) {
917     const NamedRegion &region = en.value();
918     if (region.name.empty())
919       continue;
920 
921     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
922            "expected only the last region to be variadic");
923     os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
924                         std::to_string(en.index()) +
925                             (region.isVariadic() ? ":" : ""));
926   }
927 }
928 
929 /// Emits bindings for a specific Op to the given output stream.
930 static void emitOpBindings(const Operator &op,
931                            const AttributeClasses &attributeClasses,
932                            raw_ostream &os) {
933   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
934                       op.getOperationName());
935 
936   // Sized segments.
937   if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
938     emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
939   }
940   if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
941     emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
942   }
943 
944   emitRegionAttributes(op, os);
945   emitDefaultOpBuilder(op, os);
946   emitOperandAccessors(op, os);
947   emitAttributeAccessors(op, attributeClasses, os);
948   emitResultAccessors(op, os);
949   emitRegionAccessors(op, os);
950 }
951 
952 /// Emits bindings for the dialect specified in the command line, including file
953 /// headers and utilities. Returns `false` on success to comply with Tablegen
954 /// registration requirements.
955 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
956   if (clDialectName.empty())
957     llvm::PrintFatalError("dialect name not provided");
958 
959   AttributeClasses attributeClasses;
960   constructAttributeMapping(records, attributeClasses);
961 
962   os << llvm::formatv(fileHeader, clDialectName.getValue());
963   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
964 
965   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
966     Operator op(rec);
967     if (op.getDialectName() == clDialectName.getValue())
968       emitOpBindings(op, attributeClasses, os);
969   }
970   return false;
971 }
972 
973 static GenRegistration
974     genPythonBindings("gen-python-op-bindings",
975                       "Generate Python bindings for MLIR Ops", &emitAllOps);
976