1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
/// File header and includes, emitted once at the top of the generated Python
/// module. The shared helpers from `_ods_common` are imported under private,
/// underscore-prefixed aliases, and the optional hand-written `_{0}_ops_ext`
/// module is imported when available (`_ods_ext_module` is None otherwise).
///   {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";
42 
/// Template for the dialect class, registered with the native extension
/// through `_ods_cext.register_dialect`:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for the header of an operation class. The class derives from
/// `_ods_ir.OpView`, is registered with `_Dialect`, and is passed through
/// `_ods_extend_opview_class` together with the optional extension module.
/// The member templates in this file are indented by two spaces so that they
/// land inside this class body:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
62 
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic group
///        (expect a sequence of values)
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for the class level declaration of the _ODS_REGIONS spec, a
/// two-element tuple:
///   {0} is the minimum number of regions;
///   {1} is a Python bool literal (`True`/`False`) stating whether the op has
///       no variadic regions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
82 
/// Template for a single-element accessor whose position is fixed because no
/// variable-length group precedes it:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for a single-element accessor placed after a variable-length
/// group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The length of the (only) variable-length group is recovered as
/// `len(elements) - {2} + 1`, and the static position {3} is shifted by that
/// length minus one. This works for both a single variadic group (non-negative
/// length) and a single optional element (zero length if the element is
/// absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
106 
/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups, in which case the accessor returns
/// None.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case.
/// The group length is inferred from the total element count and the group is
/// returned as a slice:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
132 
/// First part of the template for an equally-sized variadic group accessor.
/// It binds `start` (the position of the group's first element) and `pg` (the
/// number of elements per variadic group); one of the "second part" templates
/// is appended to turn these into the returned value:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The element list is read through `self.operation`: inside the generated
/// property only `self` is in scope, so a bare `operation` would be a NameError.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
143 
/// Second part of the template for the equally-sized case, accessing a single
/// (non-variadic) element located at `start`, which the first part computes:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for the equally-sized case, accessing a
/// variadic group: the `pg` elements beginning at `start`, both of which the
/// first part computes:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
157 
/// Template for an attribute-sized group accessor. The segment is located with
/// `_ods_segmented_accessor` using the `{1}_segment_sizes` attribute:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix: `[0]` for a single element, empty for a variadic
///       group, and the formatted opVariadicSegmentOptionalTrailingTemplate
///       for an optional element.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case. It yields the segment's first element, or None when
/// the segment is empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
178 
/// Template for an operation attribute getter, which wraps the raw attribute
/// in the given Python attribute class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter, returning None when
/// the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";
200 
/// Template for a getter of a unit operation attribute. It returns True if the
/// unit attribute is present and False otherwise, because unit attributes
/// carry meaning by mere presence:
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
211 
/// Template for a setter of a mandatory operation attribute. Assigning None is
/// rejected with a ValueError, because a mandatory attribute cannot be removed:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute. Setting it to
/// None removes the attribute if it is present:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute. A truthy value stores
/// a fresh UnitAttr; None or any other falsy value removes the attribute if it
/// is present:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
248 
/// Template for a deleter of an optional or a unit operation attribute. It
/// removes the attribute from the operation unconditionally; unlike the
/// setters, it does not check for presence first:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor property:
///    {0} is the name of the accessor;
///    {1} is the index of the region in `self.regions`.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
264 
/// Command-line option category that groups the flags of this generator.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

/// `-bind-dialect`: the dialect to run the generator for (defaults to the
/// empty string).
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
272 
/// Maps a C++ attribute storage type name to the name of the Python class
/// used to wrap attributes of that storage type in generated getters.
using AttributeClasses = DenseMap<StringRef, StringRef>;
274 
275 /// Checks whether `str` is a Python keyword.
276 static bool isPythonKeyword(StringRef str) {
277   static llvm::StringSet<> keywords(
278       {"and",   "as",     "assert",   "break", "class",  "continue",
279        "def",   "del",    "elif",     "else",  "except", "finally",
280        "for",   "from",   "global",   "if",    "import", "in",
281        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
282        "raise", "return", "try",      "while", "with",   "yield"});
283   return keywords.contains(str);
284 }
285 
286 /// Checks whether `str` would shadow a generated variable or attribute
287 /// part of the OpView API.
288 static bool isODSReserved(StringRef str) {
289   static llvm::StringSet<> reserved(
290       {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
291        "loc", "verify", "regions", "results", "self", "operation",
292        "DIALECT_NAMESPACE", "OPERATION_NAME"});
293   return str.startswith("_ods_") || str.endswith("_ods") ||
294          reserved.contains(str);
295 }
296 
297 /// Modifies the `name` in a way that it becomes suitable for Python bindings
298 /// (does not change the `name` if it already is suitable) and returns the
299 /// modified version.
300 static std::string sanitizeName(StringRef name) {
301   if (isPythonKeyword(name) || isODSReserved(name))
302     return (name + "_").str();
303   return name.str();
304 }
305 
306 static std::string attrSizedTraitForKind(const char *kind) {
307   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
308                        llvm::StringRef(kind).take_front().upper(),
309                        llvm::StringRef(kind).drop_front());
310 }
311 
312 /// Emits accessors to "elements" of an Op definition. Currently, the supported
313 /// elements are operands and results, indicated by `kind`, which must be either
314 /// `operand` or `result` and is used verbatim in the emitted code.
315 static void emitElementAccessors(
316     const Operator &op, raw_ostream &os, const char *kind,
317     llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
318     llvm::function_ref<int(const Operator &)> getNumElements,
319     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
320         getElement) {
321   assert(llvm::is_contained(
322              llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
323          "unsupported kind");
324 
325   // Traits indicating how to process variadic elements.
326   std::string sameSizeTrait =
327       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
328                     llvm::StringRef(kind).take_front().upper(),
329                     llvm::StringRef(kind).drop_front());
330   std::string attrSizedTrait = attrSizedTraitForKind(kind);
331 
332   unsigned numVariableLength = getNumVariableLength(op);
333 
334   // If there is only one variable-length element group, its size can be
335   // inferred from the total number of elements. If there are none, the
336   // generation is straightforward.
337   if (numVariableLength <= 1) {
338     bool seenVariableLength = false;
339     for (int i = 0, e = getNumElements(op); i < e; ++i) {
340       const NamedTypeConstraint &element = getElement(op, i);
341       if (element.isVariableLength())
342         seenVariableLength = true;
343       if (element.name.empty())
344         continue;
345       if (element.isVariableLength()) {
346         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
347                                                  : opOneVariadicTemplate,
348                             sanitizeName(element.name), kind,
349                             getNumElements(op), i);
350       } else if (seenVariableLength) {
351         os << llvm::formatv(opSingleAfterVariableTemplate,
352                             sanitizeName(element.name), kind,
353                             getNumElements(op), i);
354       } else {
355         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
356                             i);
357       }
358     }
359     return;
360   }
361 
362   // Handle the operations where variadic groups have the same size.
363   if (op.getTrait(sameSizeTrait)) {
364     int numPrecedingSimple = 0;
365     int numPrecedingVariadic = 0;
366     for (int i = 0, e = getNumElements(op); i < e; ++i) {
367       const NamedTypeConstraint &element = getElement(op, i);
368       if (!element.name.empty()) {
369         os << llvm::formatv(opVariadicEqualPrefixTemplate,
370                             sanitizeName(element.name), kind, numVariableLength,
371                             numPrecedingSimple, numPrecedingVariadic);
372         os << llvm::formatv(element.isVariableLength()
373                                 ? opVariadicEqualVariadicTemplate
374                                 : opVariadicEqualSimpleTemplate,
375                             kind);
376       }
377       if (element.isVariableLength())
378         ++numPrecedingVariadic;
379       else
380         ++numPrecedingSimple;
381     }
382     return;
383   }
384 
385   // Handle the operations where the size of groups (variadic or not) is
386   // provided as an attribute. For non-variadic elements, make sure to return
387   // an element rather than a singleton container.
388   if (op.getTrait(attrSizedTrait)) {
389     for (int i = 0, e = getNumElements(op); i < e; ++i) {
390       const NamedTypeConstraint &element = getElement(op, i);
391       if (element.name.empty())
392         continue;
393       std::string trailing;
394       if (!element.isVariableLength())
395         trailing = "[0]";
396       else if (element.isOptional())
397         trailing = std::string(
398             llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
399       os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
400                           kind, i, trailing);
401     }
402     return;
403   }
404 
405   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
406 }
407 
408 /// Free function helpers accessing Operator components.
409 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
410 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
411   return op.getOperand(i);
412 }
413 static int getNumResults(const Operator &op) { return op.getNumResults(); }
414 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
415   return op.getResult(i);
416 }
417 
418 /// Emits accessors to Op operands.
419 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
420   auto getNumVariableLengthOperands = [](const Operator &oper) {
421     return oper.getNumVariableLengthOperands();
422   };
423   emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
424                        getNumOperands, getOperand);
425 }
426 
427 /// Emits accessors Op results.
428 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
429   auto getNumVariableLengthResults = [](const Operator &oper) {
430     return oper.getNumVariableLengthResults();
431   };
432   emitElementAccessors(op, os, "result", getNumVariableLengthResults,
433                        getNumResults, getResult);
434 }
435 
436 /// Emits accessors to Op attributes.
437 static void emitAttributeAccessors(const Operator &op,
438                                    const AttributeClasses &attributeClasses,
439                                    raw_ostream &os) {
440   for (const auto &namedAttr : op.getAttributes()) {
441     // Skip "derived" attributes because they are just C++ functions that we
442     // don't currently expose.
443     if (namedAttr.attr.isDerivedAttr())
444       continue;
445 
446     if (namedAttr.name.empty())
447       continue;
448 
449     std::string sanitizedName = sanitizeName(namedAttr.name);
450 
451     // Unit attributes are handled specially.
452     if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
453       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
454                           namedAttr.name);
455       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
456                           namedAttr.name);
457       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
458                           namedAttr.name);
459       continue;
460     }
461 
462     // Other kinds of attributes need a mapping to a Python type.
463     if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
464       continue;
465 
466     StringRef pythonType =
467         attributeClasses.lookup(namedAttr.attr.getStorageType());
468     if (namedAttr.attr.isOptional()) {
469       os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
470                           pythonType, namedAttr.name);
471       os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
472                           namedAttr.name);
473       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
474                           namedAttr.name);
475     } else {
476       os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
477                           namedAttr.name);
478       os << llvm::formatv(attributeSetterTemplate, sanitizedName,
479                           namedAttr.name);
480       // Non-optional attributes cannot be deleted.
481     }
482   }
483 }
484 
/// Template for the default auto-generated builder (`__init__`).
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating the `operands`, `results` and `attributes`
///       lists/dict and defining `_ods_successors`, which the
///       `build_generic` call below reads.
/// `{{` is formatv's escape for a literal `{`, so the generated code reads
/// `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";
501 
/// Templates for appending a single element to the operand/result list. The
/// operand is passed through `_get_op_result_or_value` first.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Templates for appending an optional element to the operand/result list.
/// The plain variants skip the element when it is None. The attribute-sized
/// operand variant always appends, using None as a placeholder, so that every
/// segment keeps an entry.
///   {0} is the field name.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Templates for adding a list of elements to the operand/result list. The
/// `extend` variants flatten the list into the operands/results. The "Pack"
/// variant appends the whole list as one entry (used with attribute-sized
/// segments).
///   {0} is the field name.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
525 
/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder; the
/// attribute is only set when the argument is not None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder. When the
/// argument is truthy, a UnitAttr is created with the context returned by
/// `_ods_get_default_loc_context(loc)`. The call spans two lines, which Python
/// accepts because the continuation is inside parentheses.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder (to None when
/// the op has no successors).
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
550 
551 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
552 /// result types of the given operation.
553 static bool hasSameArgumentAndResultTypes(const Operator &op) {
554   return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
555          op.getNumVariableLengthResults() == 0;
556 }
557 
558 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
559 /// result types of the given operation.
560 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
561   return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
562          op.getNumVariableLengthResults() == 0;
563 }
564 
565 /// Returns true if the InferTypeOpInterface can be used to infer result types
566 /// of the given operation.
567 static bool hasInferTypeInterface(const Operator &op) {
568   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
569          op.getNumRegions() == 0;
570 }
571 
572 /// Returns true if there is a trait or interface that can be used to infer
573 /// result types of the given operation.
574 static bool canInferType(const Operator &op) {
575   return hasSameArgumentAndResultTypes(op) ||
576          hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
577 }
578 
579 /// Populates `builderArgs` with result names if the builder is expected to
580 /// accept them as arguments.
581 static void
582 populateBuilderArgsResults(const Operator &op,
583                            llvm::SmallVectorImpl<std::string> &builderArgs) {
584   if (canInferType(op))
585     return;
586 
587   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
588     std::string name = op.getResultName(i).str();
589     if (name.empty()) {
590       if (op.getNumResults() == 1) {
591         // Special case for one result, make the default name be 'result'
592         // to properly match the built-in result accessor.
593         name = "result";
594       } else {
595         name = llvm::formatv("_gen_res_{0}", i);
596       }
597     }
598     name = sanitizeName(name);
599     builderArgs.push_back(name);
600   }
601 }
602 
603 /// Populates `builderArgs` with the Python-compatible names of builder function
604 /// arguments using intermixed attributes and operands in the same order as they
605 /// appear in the `arguments` field of the op definition. Additionally,
606 /// `operandNames` is populated with names of operands in their order of
607 /// appearance.
608 static void
609 populateBuilderArgs(const Operator &op,
610                     llvm::SmallVectorImpl<std::string> &builderArgs,
611                     llvm::SmallVectorImpl<std::string> &operandNames,
612                     llvm::SmallVectorImpl<std::string> &successorArgNames) {
613 
614   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
615     std::string name = op.getArgName(i).str();
616     if (name.empty())
617       name = llvm::formatv("_gen_arg_{0}", i);
618     name = sanitizeName(name);
619     builderArgs.push_back(name);
620     if (!op.getArg(i).is<NamedAttribute *>())
621       operandNames.push_back(name);
622   }
623 }
624 
625 /// Populates `builderArgs` with the Python-compatible names of builder function
626 /// successor arguments. Additionally, `successorArgNames` is also populated.
627 static void populateBuilderArgsSuccessors(
628     const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
629     llvm::SmallVectorImpl<std::string> &successorArgNames) {
630 
631   for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
632     NamedSuccessor successor = op.getSuccessor(i);
633     std::string name = std::string(successor.name);
634     if (name.empty())
635       name = llvm::formatv("_gen_successor_{0}", i);
636     name = sanitizeName(name);
637     builderArgs.push_back(name);
638     successorArgNames.push_back(name);
639   }
640 }
641 
642 /// Populates `builderLines` with additional lines that are required in the
643 /// builder to set up operation attributes. `argNames` is expected to contain
644 /// the names of builder arguments that correspond to op arguments, i.e. to the
645 /// operands and attributes in the same order as they appear in the `arguments`
646 /// field.
647 static void
648 populateBuilderLinesAttr(const Operator &op,
649                          llvm::ArrayRef<std::string> argNames,
650                          llvm::SmallVectorImpl<std::string> &builderLines) {
651   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
652     Argument arg = op.getArg(i);
653     auto *attribute = arg.dyn_cast<NamedAttribute *>();
654     if (!attribute)
655       continue;
656 
657     // Unit attributes are handled specially.
658     if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
659       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
660                                            attribute->name, argNames[i]));
661       continue;
662     }
663 
664     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
665                                              ? initOptionalAttributeTemplate
666                                              : initAttributeTemplate,
667                                          attribute->name, argNames[i]));
668   }
669 }
670 
671 /// Populates `builderLines` with additional lines that are required in the
672 /// builder to set up successors. successorArgNames is expected to correspond
673 /// to the Python argument name for each successor on the op.
674 static void populateBuilderLinesSuccessors(
675     const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
676     llvm::SmallVectorImpl<std::string> &builderLines) {
677   if (successorArgNames.empty()) {
678     builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
679     return;
680   }
681 
682   builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
683   for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
684     auto &argName = successorArgNames[i];
685     const NamedSuccessor &successor = op.getSuccessor(i);
686     builderLines.push_back(
687         llvm::formatv(addSuccessorTemplate,
688                       successor.isVariadic() ? "extend" : "append", argName));
689   }
690 }
691 
692 /// Populates `builderLines` with additional lines that are required in the
693 /// builder to set up op operands.
694 static void
695 populateBuilderLinesOperand(const Operator &op,
696                             llvm::ArrayRef<std::string> names,
697                             llvm::SmallVectorImpl<std::string> &builderLines) {
698   bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
699 
700   // For each element, find or generate a name.
701   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
702     const NamedTypeConstraint &element = op.getOperand(i);
703     std::string name = names[i];
704 
705     // Choose the formatting string based on the element kind.
706     llvm::StringRef formatString;
707     if (!element.isVariableLength()) {
708       formatString = singleOperandAppendTemplate;
709     } else if (element.isOptional()) {
710       if (sizedSegments) {
711         formatString = optionalAppendAttrSizedOperandsTemplate;
712       } else {
713         formatString = optionalAppendOperandTemplate;
714       }
715     } else {
716       assert(element.isVariadic() && "unhandled element group type");
717       // If emitting with sizedSegments, then we add the actual list-typed
718       // element. Otherwise, we extend the actual operands.
719       if (sizedSegments) {
720         formatString = multiOperandAppendPackTemplate;
721       } else {
722         formatString = multiOperandAppendTemplate;
723       }
724     }
725 
726     builderLines.push_back(llvm::formatv(formatString.data(), name));
727   }
728 }
729 
/// Python code template for deriving the operation result type from one of
/// its attributes. If the attribute is a TypeAttr, its value is used;
/// otherwise the attribute's own `.type` is used. The result is bound to
/// `_ods_derived_result_type`:
///   - {0} is the name of the attribute from which to derive the types.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";

/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";

/// Python code template for inferring the operation results using the
/// corresponding interface. The inferred types replace `results`, and the
/// attributes are packed into a DictAttr in the context obtained from `loc`:
///   - {0} is the name of the class for which the types are inferred.
constexpr const char *inferTypeInterfaceTemplate =
    R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
    operands=operands,
    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
    context=_ods_context,
    loc=loc)
)PY";
754 
755 /// Appends the given multiline string as individual strings into
756 /// `builderLines`.
757 static void appendLineByLine(StringRef string,
758                              llvm::SmallVectorImpl<std::string> &builderLines) {
759 
760   std::pair<StringRef, StringRef> split = std::make_pair(string, string);
761   do {
762     split = split.second.split('\n');
763     builderLines.push_back(split.first.str());
764   } while (!split.second.empty());
765 }
766 
767 /// Populates `builderLines` with additional lines that are required in the
768 /// builder to set up op results.
769 static void
770 populateBuilderLinesResult(const Operator &op,
771                            llvm::ArrayRef<std::string> names,
772                            llvm::SmallVectorImpl<std::string> &builderLines) {
773   bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
774 
775   if (hasSameArgumentAndResultTypes(op)) {
776     builderLines.push_back(llvm::formatv(
777         appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
778     return;
779   }
780 
781   if (hasFirstAttrDerivedResultTypes(op)) {
782     const NamedAttribute &firstAttr = op.getAttribute(0);
783     assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
784                                       "from which the type is derived");
785     appendLineByLine(
786         llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
787         builderLines);
788     builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
789                                          "_ods_derived_result_type",
790                                          op.getNumResults()));
791     return;
792   }
793 
794   if (hasInferTypeInterface(op)) {
795     appendLineByLine(
796         llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
797         builderLines);
798     return;
799   }
800 
801   // For each element, find or generate a name.
802   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
803     const NamedTypeConstraint &element = op.getResult(i);
804     std::string name = names[i];
805 
806     // Choose the formatting string based on the element kind.
807     llvm::StringRef formatString;
808     if (!element.isVariableLength()) {
809       formatString = singleResultAppendTemplate;
810     } else if (element.isOptional()) {
811       formatString = optionalAppendResultTemplate;
812     } else {
813       assert(element.isVariadic() && "unhandled element group type");
814       // If emitting with sizedSegments, then we add the actual list-typed
815       // element. Otherwise, we extend the actual operands.
816       if (sizedSegments) {
817         formatString = singleResultAppendTemplate;
818       } else {
819         formatString = multiResultAppendTemplate;
820       }
821     }
822 
823     builderLines.push_back(llvm::formatv(formatString.data(), name));
824   }
825 }
826 
827 /// If the operation has variadic regions, adds a builder argument to specify
828 /// the number of those regions and builder lines to forward it to the generic
829 /// constructor.
830 static void
831 populateBuilderRegions(const Operator &op,
832                        llvm::SmallVectorImpl<std::string> &builderArgs,
833                        llvm::SmallVectorImpl<std::string> &builderLines) {
834   if (op.hasNoVariadicRegions())
835     return;
836 
837   // This is currently enforced when Operator is constructed.
838   assert(op.getNumVariadicRegions() == 1 &&
839          op.getRegion(op.getNumRegions() - 1).isVariadic() &&
840          "expected the last region to be varidic");
841 
842   const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
843   std::string name =
844       ("num_" + region.name.take_front().lower() + region.name.drop_front())
845           .str();
846   builderArgs.push_back(name);
847   builderLines.push_back(
848       llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
849 }
850 
851 /// Emits a default builder constructing an operation from the list of its
852 /// result types, followed by a list of its operands.
853 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
854   // If we are asked to skip default builders, comply.
855   if (op.skipDefaultBuilders())
856     return;
857 
858   llvm::SmallVector<std::string> builderArgs;
859   llvm::SmallVector<std::string> builderLines;
860   llvm::SmallVector<std::string> operandArgNames;
861   llvm::SmallVector<std::string> successorArgNames;
862   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
863                       op.getNumNativeAttributes() + op.getNumSuccessors());
864   populateBuilderArgsResults(op, builderArgs);
865   size_t numResultArgs = builderArgs.size();
866   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
867   size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
868   populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
869 
870   populateBuilderLinesOperand(op, operandArgNames, builderLines);
871   populateBuilderLinesAttr(
872       op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
873       builderLines);
874   populateBuilderLinesResult(
875       op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
876       builderLines);
877   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
878   populateBuilderRegions(op, builderArgs, builderLines);
879 
880   // Layout of builderArgs vector elements:
881   // [ result_args  operand_attr_args successor_args regions ]
882 
883   // Determine whether the argument corresponding to a given index into the
884   // builderArgs vector is a python keyword argument or not.
885   auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
886     // All result, successor, and region arguments are positional arguments.
887     if ((builderArgIndex < numResultArgs) ||
888         (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
889       return false;
890     // Keyword arguments:
891     // - optional named attributes (including unit attributes)
892     // - default-valued named attributes
893     // - optional operands
894     Argument a = op.getArg(builderArgIndex - numResultArgs);
895     if (auto *nattr = a.dyn_cast<NamedAttribute *>())
896       return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
897     if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
898       return ntype->isOptional();
899     else
900       return false;
901   };
902 
903   // StringRefs in functionArgs refer to strings allocated by builderArgs.
904   llvm::SmallVector<llvm::StringRef> functionArgs;
905 
906   // Add positional arguments.
907   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
908     if (!isKeywordArgFn(i))
909       functionArgs.push_back(builderArgs[i]);
910   }
911 
912   // Add a bare '*' to indicate that all following arguments must be keyword
913   // arguments.
914   functionArgs.push_back("*");
915 
916   // Add a default 'None' value to each keyword arg string, and then add to the
917   // function args list.
918   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
919     if (isKeywordArgFn(i)) {
920       builderArgs[i].append("=None");
921       functionArgs.push_back(builderArgs[i]);
922     }
923   }
924   functionArgs.push_back("loc=None");
925   functionArgs.push_back("ip=None");
926   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
927                       llvm::join(builderLines, "\n    "));
928 }
929 
930 static void constructAttributeMapping(const llvm::RecordKeeper &records,
931                                       AttributeClasses &attributeClasses) {
932   for (const llvm::Record *rec :
933        records.getAllDerivedDefinitions("PythonAttr")) {
934     attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
935                                  rec->getValueAsString("pythonType").trim());
936   }
937 }
938 
939 static void emitSegmentSpec(
940     const Operator &op, const char *kind,
941     llvm::function_ref<int(const Operator &)> getNumElements,
942     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
943         getElement,
944     raw_ostream &os) {
945   std::string segmentSpec("[");
946   for (int i = 0, e = getNumElements(op); i < e; ++i) {
947     const NamedTypeConstraint &element = getElement(op, i);
948     if (element.isOptional()) {
949       segmentSpec.append("0,");
950     } else if (element.isVariadic()) {
951       segmentSpec.append("-1,");
952     } else {
953       segmentSpec.append("1,");
954     }
955   }
956   segmentSpec.append("]");
957 
958   os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
959 }
960 
961 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
962   // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
963   // Note that the base OpView class defines this as (0, True).
964   unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
965   os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
966                       op.hasNoVariadicRegions() ? "True" : "False");
967 }
968 
969 /// Emits named accessors to regions.
970 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
971   for (const auto &en : llvm::enumerate(op.getRegions())) {
972     const NamedRegion &region = en.value();
973     if (region.name.empty())
974       continue;
975 
976     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
977            "expected only the last region to be variadic");
978     os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
979                         std::to_string(en.index()) +
980                             (region.isVariadic() ? ":" : ""));
981   }
982 }
983 
984 /// Emits bindings for a specific Op to the given output stream.
985 static void emitOpBindings(const Operator &op,
986                            const AttributeClasses &attributeClasses,
987                            raw_ostream &os) {
988   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
989                       op.getOperationName());
990 
991   // Sized segments.
992   if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
993     emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
994   }
995   if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
996     emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
997   }
998 
999   emitRegionAttributes(op, os);
1000   emitDefaultOpBuilder(op, os);
1001   emitOperandAccessors(op, os);
1002   emitAttributeAccessors(op, attributeClasses, os);
1003   emitResultAccessors(op, os);
1004   emitRegionAccessors(op, os);
1005 }
1006 
1007 /// Emits bindings for the dialect specified in the command line, including file
1008 /// headers and utilities. Returns `false` on success to comply with Tablegen
1009 /// registration requirements.
1010 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1011   if (clDialectName.empty())
1012     llvm::PrintFatalError("dialect name not provided");
1013 
1014   AttributeClasses attributeClasses;
1015   constructAttributeMapping(records, attributeClasses);
1016 
1017   os << llvm::formatv(fileHeader, clDialectName.getValue());
1018   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
1019 
1020   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
1021     Operator op(rec);
1022     if (op.getDialectName() == clDialectName.getValue())
1023       emitOpBindings(op, attributeClasses, os);
1024   }
1025   return false;
1026 }
1027 
/// Registers emitAllOps with mlir-tblgen as the generator named
/// "gen-python-op-bindings", described as "Generate Python bindings for MLIR
/// Ops".
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);
1031