1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
/// File header and includes emitted at the top of every generated module.
///   {0} is the dialect namespace.
/// The header imports the shared `_ods_common` helpers under `_ods_`-prefixed
/// aliases and tries to import the hand-written `_{0}_ops_ext` module, binding
/// `_ods_ext_module` to None when that import fails. Accessors use
/// `@builtins.property` rather than bare `property`, presumably so that a
/// generated member named `property` cannot shadow the decorator.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
/// The class is registered with the C extension via `register_dialect`.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for a dialect extension module, which reuses the `_Dialect` class
/// of an already generated module instead of defining its own:
///   {0} is the name such that `_{0}_ops_gen` is the module defining
///       `_Dialect` (presumably the extended dialect's namespace — confirm
///       against the caller).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
56 
/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
/// The class is registered under `_Dialect` and passed through
/// `_ods_extend_opview_class` together with the optional extension module.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = variadic element (expect a sequence of operands/results)
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
86 
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Used when no variable-length group precedes the element, so its position
/// in the runtime list is fixed.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent). The
/// variable group's length is the runtime element count minus the {2} - 1
/// other groups. The element is then shifted by that length minus one, i.e.
/// relative to the case where the group holds exactly one element.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The group's length is the runtime element count minus the {2} - 1 other
/// (single-element) groups; the accessor returns that slice.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
136 
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The generated code binds `start` (index of the group's first element) and
/// `pg` (used as the per-group length by the slice below). One of the two
/// templates that follow must be appended to emit the `return` statement.
/// The element list is reached through `self.operation`; a bare `operation`
/// is not bound inside the generated method and would raise NameError.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
161 
/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix: "[0]" for a single-element group, empty for a
///       variadic group, and opVariadicSegmentOptionalTrailingTemplate for an
///       optional element.
/// The group's slice is computed by `_ods_segmented_accessor` from the
/// `{1}_segment_sizes` attribute of the operation.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; yields None when the segment is empty:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
182 
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
/// The raw attribute is wrapped by calling {1} on it.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter; returns None when the
/// attribute is absent from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
215 
/// Template for an operation attribute setter:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// Mandatory attributes reject None (raising ValueError) instead of removing
/// the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (if present):
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// NOTE(review): `UnitAttr.get()` is called without an explicit context, so it
/// presumably relies on an implicitly active context — confirm.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor property returning `self.regions[{1}]`:
///    {0} is the name of the accessor;
///    {1} is presumably the index of the region (the caller is outside this
///        section — confirm).
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
268 
/// Category grouping the command-line options of this generator.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

/// `-bind-dialect`: the dialect whose ops are emitted (empty by default).
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// `-dialect-extension`: the prefix of the dialect extension (empty by
/// default).
static llvm::cl::opt<std::string> clDialectExtensionName(
    "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
    llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// Maps an attribute's C++ storage type name to the Python type used to wrap
/// it in generated attribute getters (see emitAttributeAccessors).
using AttributeClasses = DenseMap<StringRef, StringRef>;
282 
283 /// Checks whether `str` is a Python keyword.
284 static bool isPythonKeyword(StringRef str) {
285   static llvm::StringSet<> keywords(
286       {"and",   "as",     "assert",   "break", "class",  "continue",
287        "def",   "del",    "elif",     "else",  "except", "finally",
288        "for",   "from",   "global",   "if",    "import", "in",
289        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
290        "raise", "return", "try",      "while", "with",   "yield"});
291   return keywords.contains(str);
292 }
293 
294 /// Checks whether `str` would shadow a generated variable or attribute
295 /// part of the OpView API.
296 static bool isODSReserved(StringRef str) {
297   static llvm::StringSet<> reserved(
298       {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
299        "loc", "verify", "regions", "results", "self", "operation",
300        "DIALECT_NAMESPACE", "OPERATION_NAME"});
301   return str.startswith("_ods_") || str.endswith("_ods") ||
302          reserved.contains(str);
303 }
304 
305 /// Modifies the `name` in a way that it becomes suitable for Python bindings
306 /// (does not change the `name` if it already is suitable) and returns the
307 /// modified version.
308 static std::string sanitizeName(StringRef name) {
309   if (isPythonKeyword(name) || isODSReserved(name))
310     return (name + "_").str();
311   return name.str();
312 }
313 
314 static std::string attrSizedTraitForKind(const char *kind) {
315   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
316                        llvm::StringRef(kind).take_front().upper(),
317                        llvm::StringRef(kind).drop_front());
318 }
319 
320 /// Emits accessors to "elements" of an Op definition. Currently, the supported
321 /// elements are operands and results, indicated by `kind`, which must be either
322 /// `operand` or `result` and is used verbatim in the emitted code.
323 static void emitElementAccessors(
324     const Operator &op, raw_ostream &os, const char *kind,
325     llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
326     llvm::function_ref<int(const Operator &)> getNumElements,
327     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
328         getElement) {
329   assert(llvm::is_contained(
330              llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
331          "unsupported kind");
332 
333   // Traits indicating how to process variadic elements.
334   std::string sameSizeTrait =
335       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
336                     llvm::StringRef(kind).take_front().upper(),
337                     llvm::StringRef(kind).drop_front());
338   std::string attrSizedTrait = attrSizedTraitForKind(kind);
339 
340   unsigned numVariableLength = getNumVariableLength(op);
341 
342   // If there is only one variable-length element group, its size can be
343   // inferred from the total number of elements. If there are none, the
344   // generation is straightforward.
345   if (numVariableLength <= 1) {
346     bool seenVariableLength = false;
347     for (int i = 0, e = getNumElements(op); i < e; ++i) {
348       const NamedTypeConstraint &element = getElement(op, i);
349       if (element.isVariableLength())
350         seenVariableLength = true;
351       if (element.name.empty())
352         continue;
353       if (element.isVariableLength()) {
354         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
355                                                  : opOneVariadicTemplate,
356                             sanitizeName(element.name), kind,
357                             getNumElements(op), i);
358       } else if (seenVariableLength) {
359         os << llvm::formatv(opSingleAfterVariableTemplate,
360                             sanitizeName(element.name), kind,
361                             getNumElements(op), i);
362       } else {
363         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
364                             i);
365       }
366     }
367     return;
368   }
369 
370   // Handle the operations where variadic groups have the same size.
371   if (op.getTrait(sameSizeTrait)) {
372     int numPrecedingSimple = 0;
373     int numPrecedingVariadic = 0;
374     for (int i = 0, e = getNumElements(op); i < e; ++i) {
375       const NamedTypeConstraint &element = getElement(op, i);
376       if (!element.name.empty()) {
377         os << llvm::formatv(opVariadicEqualPrefixTemplate,
378                             sanitizeName(element.name), kind, numVariableLength,
379                             numPrecedingSimple, numPrecedingVariadic);
380         os << llvm::formatv(element.isVariableLength()
381                                 ? opVariadicEqualVariadicTemplate
382                                 : opVariadicEqualSimpleTemplate,
383                             kind);
384       }
385       if (element.isVariableLength())
386         ++numPrecedingVariadic;
387       else
388         ++numPrecedingSimple;
389     }
390     return;
391   }
392 
393   // Handle the operations where the size of groups (variadic or not) is
394   // provided as an attribute. For non-variadic elements, make sure to return
395   // an element rather than a singleton container.
396   if (op.getTrait(attrSizedTrait)) {
397     for (int i = 0, e = getNumElements(op); i < e; ++i) {
398       const NamedTypeConstraint &element = getElement(op, i);
399       if (element.name.empty())
400         continue;
401       std::string trailing;
402       if (!element.isVariableLength())
403         trailing = "[0]";
404       else if (element.isOptional())
405         trailing = std::string(
406             llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
407       os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
408                           kind, i, trailing);
409     }
410     return;
411   }
412 
413   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
414 }
415 
416 /// Free function helpers accessing Operator components.
417 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
418 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
419   return op.getOperand(i);
420 }
421 static int getNumResults(const Operator &op) { return op.getNumResults(); }
422 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
423   return op.getResult(i);
424 }
425 
426 /// Emits accessors to Op operands.
427 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
428   auto getNumVariableLengthOperands = [](const Operator &oper) {
429     return oper.getNumVariableLengthOperands();
430   };
431   emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
432                        getNumOperands, getOperand);
433 }
434 
435 /// Emits accessors Op results.
436 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
437   auto getNumVariableLengthResults = [](const Operator &oper) {
438     return oper.getNumVariableLengthResults();
439   };
440   emitElementAccessors(op, os, "result", getNumVariableLengthResults,
441                        getNumResults, getResult);
442 }
443 
444 /// Emits accessors to Op attributes.
445 static void emitAttributeAccessors(const Operator &op,
446                                    const AttributeClasses &attributeClasses,
447                                    raw_ostream &os) {
448   for (const auto &namedAttr : op.getAttributes()) {
449     // Skip "derived" attributes because they are just C++ functions that we
450     // don't currently expose.
451     if (namedAttr.attr.isDerivedAttr())
452       continue;
453 
454     if (namedAttr.name.empty())
455       continue;
456 
457     std::string sanitizedName = sanitizeName(namedAttr.name);
458 
459     // Unit attributes are handled specially.
460     if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
461       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
462                           namedAttr.name);
463       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
464                           namedAttr.name);
465       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
466                           namedAttr.name);
467       continue;
468     }
469 
470     // Other kinds of attributes need a mapping to a Python type.
471     if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
472       continue;
473 
474     StringRef pythonType =
475         attributeClasses.lookup(namedAttr.attr.getStorageType());
476     if (namedAttr.attr.isOptional()) {
477       os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
478                           pythonType, namedAttr.name);
479       os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
480                           namedAttr.name);
481       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
482                           namedAttr.name);
483     } else {
484       os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
485                           namedAttr.name);
486       os << llvm::formatv(attributeSetterTemplate, sanitizedName,
487                           namedAttr.name);
488       // Non-optional attributes cannot be deleted.
489     }
490   }
491 }
492 
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating the `operands`, `results` and `attributes`
///       locals and binding `_ods_successors` (which the final
///       `build_generic` call reads; populateBuilderLinesSuccessors always
///       emits it).
/// The `{{}` below is a formatv escape for a literal `{`, so the generated
/// Python reads `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";
509 
/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
/// Operands are passed through `_get_op_result_or_value` before being
/// appended; results are appended as given.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The AttrSized operand variant always appends an entry (None when the
/// argument is absent), so each operand group contributes exactly one entry to
/// `operands`. The other variants skip absent arguments.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// `multiOperandAppendTemplate` flattens the values into `operands`.
/// `multiOperandAppendPackTemplate`, used with attribute-sized operand
/// segments, appends the whole list as one entry.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder. The
/// attribute is added only when the builder argument is truthy; the UnitAttr
/// is created in the context obtained from `loc` via
/// `_ods_get_default_loc_context`.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the `_ods_successors` local in the builder.
/// populateBuilderLinesSuccessors uses it with `None` when the op has no
/// successors and with `[]` otherwise.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
558 
559 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
560 /// result types of the given operation.
561 static bool hasSameArgumentAndResultTypes(const Operator &op) {
562   return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
563          op.getNumVariableLengthResults() == 0;
564 }
565 
566 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
567 /// result types of the given operation.
568 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
569   return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
570          op.getNumVariableLengthResults() == 0;
571 }
572 
573 /// Returns true if the InferTypeOpInterface can be used to infer result types
574 /// of the given operation.
575 static bool hasInferTypeInterface(const Operator &op) {
576   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
577          op.getNumRegions() == 0;
578 }
579 
580 /// Returns true if there is a trait or interface that can be used to infer
581 /// result types of the given operation.
582 static bool canInferType(const Operator &op) {
583   return hasSameArgumentAndResultTypes(op) ||
584          hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
585 }
586 
587 /// Populates `builderArgs` with result names if the builder is expected to
588 /// accept them as arguments.
589 static void
590 populateBuilderArgsResults(const Operator &op,
591                            llvm::SmallVectorImpl<std::string> &builderArgs) {
592   if (canInferType(op))
593     return;
594 
595   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
596     std::string name = op.getResultName(i).str();
597     if (name.empty()) {
598       if (op.getNumResults() == 1) {
599         // Special case for one result, make the default name be 'result'
600         // to properly match the built-in result accessor.
601         name = "result";
602       } else {
603         name = llvm::formatv("_gen_res_{0}", i);
604       }
605     }
606     name = sanitizeName(name);
607     builderArgs.push_back(name);
608   }
609 }
610 
611 /// Populates `builderArgs` with the Python-compatible names of builder function
612 /// arguments using intermixed attributes and operands in the same order as they
613 /// appear in the `arguments` field of the op definition. Additionally,
614 /// `operandNames` is populated with names of operands in their order of
615 /// appearance.
616 static void
617 populateBuilderArgs(const Operator &op,
618                     llvm::SmallVectorImpl<std::string> &builderArgs,
619                     llvm::SmallVectorImpl<std::string> &operandNames,
620                     llvm::SmallVectorImpl<std::string> &successorArgNames) {
621 
622   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
623     std::string name = op.getArgName(i).str();
624     if (name.empty())
625       name = llvm::formatv("_gen_arg_{0}", i);
626     name = sanitizeName(name);
627     builderArgs.push_back(name);
628     if (!op.getArg(i).is<NamedAttribute *>())
629       operandNames.push_back(name);
630   }
631 }
632 
633 /// Populates `builderArgs` with the Python-compatible names of builder function
634 /// successor arguments. Additionally, `successorArgNames` is also populated.
635 static void populateBuilderArgsSuccessors(
636     const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
637     llvm::SmallVectorImpl<std::string> &successorArgNames) {
638 
639   for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
640     NamedSuccessor successor = op.getSuccessor(i);
641     std::string name = std::string(successor.name);
642     if (name.empty())
643       name = llvm::formatv("_gen_successor_{0}", i);
644     name = sanitizeName(name);
645     builderArgs.push_back(name);
646     successorArgNames.push_back(name);
647   }
648 }
649 
650 /// Populates `builderLines` with additional lines that are required in the
651 /// builder to set up operation attributes. `argNames` is expected to contain
652 /// the names of builder arguments that correspond to op arguments, i.e. to the
653 /// operands and attributes in the same order as they appear in the `arguments`
654 /// field.
655 static void
656 populateBuilderLinesAttr(const Operator &op,
657                          llvm::ArrayRef<std::string> argNames,
658                          llvm::SmallVectorImpl<std::string> &builderLines) {
659   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
660     Argument arg = op.getArg(i);
661     auto *attribute = arg.dyn_cast<NamedAttribute *>();
662     if (!attribute)
663       continue;
664 
665     // Unit attributes are handled specially.
666     if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
667       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
668                                            attribute->name, argNames[i]));
669       continue;
670     }
671 
672     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
673                                              ? initOptionalAttributeTemplate
674                                              : initAttributeTemplate,
675                                          attribute->name, argNames[i]));
676   }
677 }
678 
679 /// Populates `builderLines` with additional lines that are required in the
680 /// builder to set up successors. successorArgNames is expected to correspond
681 /// to the Python argument name for each successor on the op.
682 static void populateBuilderLinesSuccessors(
683     const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
684     llvm::SmallVectorImpl<std::string> &builderLines) {
685   if (successorArgNames.empty()) {
686     builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
687     return;
688   }
689 
690   builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
691   for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
692     auto &argName = successorArgNames[i];
693     const NamedSuccessor &successor = op.getSuccessor(i);
694     builderLines.push_back(
695         llvm::formatv(addSuccessorTemplate,
696                       successor.isVariadic() ? "extend" : "append", argName));
697   }
698 }
699 
700 /// Populates `builderLines` with additional lines that are required in the
701 /// builder to set up op operands.
702 static void
703 populateBuilderLinesOperand(const Operator &op,
704                             llvm::ArrayRef<std::string> names,
705                             llvm::SmallVectorImpl<std::string> &builderLines) {
706   bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
707 
708   // For each element, find or generate a name.
709   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
710     const NamedTypeConstraint &element = op.getOperand(i);
711     std::string name = names[i];
712 
713     // Choose the formatting string based on the element kind.
714     llvm::StringRef formatString;
715     if (!element.isVariableLength()) {
716       formatString = singleOperandAppendTemplate;
717     } else if (element.isOptional()) {
718       if (sizedSegments) {
719         formatString = optionalAppendAttrSizedOperandsTemplate;
720       } else {
721         formatString = optionalAppendOperandTemplate;
722       }
723     } else {
724       assert(element.isVariadic() && "unhandled element group type");
725       // If emitting with sizedSegments, then we add the actual list-typed
726       // element. Otherwise, we extend the actual operands.
727       if (sizedSegments) {
728         formatString = multiOperandAppendPackTemplate;
729       } else {
730         formatString = multiOperandAppendTemplate;
731       }
732     }
733 
734     builderLines.push_back(llvm::formatv(formatString.data(), name));
735   }
736 }
737 
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// The snippet binds `_ods_derived_result_type` to the attribute's value when
/// the attribute is a TypeAttr, and to the attribute's own type otherwise.
/// It spans several lines, so it is split with `appendLineByLine` before being
/// added to the builder body.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";

/// Python code template appending {0} type {1} times to the results list.
///   - {0} is a Python expression evaluating to the type to repeat.
///   - {1} is the repeat count (the op's number of results at the call sites).
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";

/// Python code template for inferring the operation results using the
/// corresponding interface:
///   - {0} is the name of the class for which the types are inferred.
/// The snippet assigns `results` outright rather than appending to it. The raw
/// string ends with a newline; `appendLineByLine` stops once the remainder is
/// empty, so no trailing blank builder line is produced for it.
constexpr const char *inferTypeInterfaceTemplate =
    R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
    operands=operands,
    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
    context=_ods_context,
    loc=loc)
)PY";
762 
763 /// Appends the given multiline string as individual strings into
764 /// `builderLines`.
765 static void appendLineByLine(StringRef string,
766                              llvm::SmallVectorImpl<std::string> &builderLines) {
767 
768   std::pair<StringRef, StringRef> split = std::make_pair(string, string);
769   do {
770     split = split.second.split('\n');
771     builderLines.push_back(split.first.str());
772   } while (!split.second.empty());
773 }
774 
775 /// Populates `builderLines` with additional lines that are required in the
776 /// builder to set up op results.
777 static void
778 populateBuilderLinesResult(const Operator &op,
779                            llvm::ArrayRef<std::string> names,
780                            llvm::SmallVectorImpl<std::string> &builderLines) {
781   bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
782 
783   if (hasSameArgumentAndResultTypes(op)) {
784     builderLines.push_back(llvm::formatv(
785         appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
786     return;
787   }
788 
789   if (hasFirstAttrDerivedResultTypes(op)) {
790     const NamedAttribute &firstAttr = op.getAttribute(0);
791     assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
792                                       "from which the type is derived");
793     appendLineByLine(
794         llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
795         builderLines);
796     builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
797                                          "_ods_derived_result_type",
798                                          op.getNumResults()));
799     return;
800   }
801 
802   if (hasInferTypeInterface(op)) {
803     appendLineByLine(
804         llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
805         builderLines);
806     return;
807   }
808 
809   // For each element, find or generate a name.
810   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
811     const NamedTypeConstraint &element = op.getResult(i);
812     std::string name = names[i];
813 
814     // Choose the formatting string based on the element kind.
815     llvm::StringRef formatString;
816     if (!element.isVariableLength()) {
817       formatString = singleResultAppendTemplate;
818     } else if (element.isOptional()) {
819       formatString = optionalAppendResultTemplate;
820     } else {
821       assert(element.isVariadic() && "unhandled element group type");
822       // If emitting with sizedSegments, then we add the actual list-typed
823       // element. Otherwise, we extend the actual operands.
824       if (sizedSegments) {
825         formatString = singleResultAppendTemplate;
826       } else {
827         formatString = multiResultAppendTemplate;
828       }
829     }
830 
831     builderLines.push_back(llvm::formatv(formatString.data(), name));
832   }
833 }
834 
835 /// If the operation has variadic regions, adds a builder argument to specify
836 /// the number of those regions and builder lines to forward it to the generic
837 /// constructor.
838 static void
839 populateBuilderRegions(const Operator &op,
840                        llvm::SmallVectorImpl<std::string> &builderArgs,
841                        llvm::SmallVectorImpl<std::string> &builderLines) {
842   if (op.hasNoVariadicRegions())
843     return;
844 
845   // This is currently enforced when Operator is constructed.
846   assert(op.getNumVariadicRegions() == 1 &&
847          op.getRegion(op.getNumRegions() - 1).isVariadic() &&
848          "expected the last region to be varidic");
849 
850   const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
851   std::string name =
852       ("num_" + region.name.take_front().lower() + region.name.drop_front())
853           .str();
854   builderArgs.push_back(name);
855   builderLines.push_back(
856       llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
857 }
858 
859 /// Emits a default builder constructing an operation from the list of its
860 /// result types, followed by a list of its operands.
861 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
862   // If we are asked to skip default builders, comply.
863   if (op.skipDefaultBuilders())
864     return;
865 
866   llvm::SmallVector<std::string> builderArgs;
867   llvm::SmallVector<std::string> builderLines;
868   llvm::SmallVector<std::string> operandArgNames;
869   llvm::SmallVector<std::string> successorArgNames;
870   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
871                       op.getNumNativeAttributes() + op.getNumSuccessors());
872   populateBuilderArgsResults(op, builderArgs);
873   size_t numResultArgs = builderArgs.size();
874   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
875   size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
876   populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
877 
878   populateBuilderLinesOperand(op, operandArgNames, builderLines);
879   populateBuilderLinesAttr(
880       op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
881       builderLines);
882   populateBuilderLinesResult(
883       op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
884       builderLines);
885   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
886   populateBuilderRegions(op, builderArgs, builderLines);
887 
888   // Layout of builderArgs vector elements:
889   // [ result_args  operand_attr_args successor_args regions ]
890 
891   // Determine whether the argument corresponding to a given index into the
892   // builderArgs vector is a python keyword argument or not.
893   auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
894     // All result, successor, and region arguments are positional arguments.
895     if ((builderArgIndex < numResultArgs) ||
896         (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
897       return false;
898     // Keyword arguments:
899     // - optional named attributes (including unit attributes)
900     // - default-valued named attributes
901     // - optional operands
902     Argument a = op.getArg(builderArgIndex - numResultArgs);
903     if (auto *nattr = a.dyn_cast<NamedAttribute *>())
904       return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
905     if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
906       return ntype->isOptional();
907     return false;
908   };
909 
910   // StringRefs in functionArgs refer to strings allocated by builderArgs.
911   llvm::SmallVector<llvm::StringRef> functionArgs;
912 
913   // Add positional arguments.
914   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
915     if (!isKeywordArgFn(i))
916       functionArgs.push_back(builderArgs[i]);
917   }
918 
919   // Add a bare '*' to indicate that all following arguments must be keyword
920   // arguments.
921   functionArgs.push_back("*");
922 
923   // Add a default 'None' value to each keyword arg string, and then add to the
924   // function args list.
925   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
926     if (isKeywordArgFn(i)) {
927       builderArgs[i].append("=None");
928       functionArgs.push_back(builderArgs[i]);
929     }
930   }
931   functionArgs.push_back("loc=None");
932   functionArgs.push_back("ip=None");
933   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
934                       llvm::join(builderLines, "\n    "));
935 }
936 
937 static void constructAttributeMapping(const llvm::RecordKeeper &records,
938                                       AttributeClasses &attributeClasses) {
939   for (const llvm::Record *rec :
940        records.getAllDerivedDefinitions("PythonAttr")) {
941     attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
942                                  rec->getValueAsString("pythonType").trim());
943   }
944 }
945 
946 static void emitSegmentSpec(
947     const Operator &op, const char *kind,
948     llvm::function_ref<int(const Operator &)> getNumElements,
949     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
950         getElement,
951     raw_ostream &os) {
952   std::string segmentSpec("[");
953   for (int i = 0, e = getNumElements(op); i < e; ++i) {
954     const NamedTypeConstraint &element = getElement(op, i);
955     if (element.isOptional()) {
956       segmentSpec.append("0,");
957     } else if (element.isVariadic()) {
958       segmentSpec.append("-1,");
959     } else {
960       segmentSpec.append("1,");
961     }
962   }
963   segmentSpec.append("]");
964 
965   os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
966 }
967 
968 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
969   // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
970   // Note that the base OpView class defines this as (0, True).
971   unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
972   os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
973                       op.hasNoVariadicRegions() ? "True" : "False");
974 }
975 
976 /// Emits named accessors to regions.
977 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
978   for (const auto &en : llvm::enumerate(op.getRegions())) {
979     const NamedRegion &region = en.value();
980     if (region.name.empty())
981       continue;
982 
983     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
984            "expected only the last region to be variadic");
985     os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
986                         std::to_string(en.index()) +
987                             (region.isVariadic() ? ":" : ""));
988   }
989 }
990 
991 /// Emits bindings for a specific Op to the given output stream.
992 static void emitOpBindings(const Operator &op,
993                            const AttributeClasses &attributeClasses,
994                            raw_ostream &os) {
995   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
996                       op.getOperationName());
997 
998   // Sized segments.
999   if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1000     emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
1001   }
1002   if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1003     emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
1004   }
1005 
1006   emitRegionAttributes(op, os);
1007   emitDefaultOpBuilder(op, os);
1008   emitOperandAccessors(op, os);
1009   emitAttributeAccessors(op, attributeClasses, os);
1010   emitResultAccessors(op, os);
1011   emitRegionAccessors(op, os);
1012 }
1013 
1014 /// Emits bindings for the dialect specified in the command line, including file
1015 /// headers and utilities. Returns `false` on success to comply with Tablegen
1016 /// registration requirements.
1017 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1018   if (clDialectName.empty())
1019     llvm::PrintFatalError("dialect name not provided");
1020 
1021   AttributeClasses attributeClasses;
1022   constructAttributeMapping(records, attributeClasses);
1023 
1024   bool isExtension = !clDialectExtensionName.empty();
1025   os << llvm::formatv(fileHeader, isExtension
1026                                       ? clDialectExtensionName.getValue()
1027                                       : clDialectName.getValue());
1028   if (isExtension)
1029     os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
1030   else
1031     os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
1032 
1033   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
1034     Operator op(rec);
1035     if (op.getDialectName() == clDialectName.getValue())
1036       emitOpBindings(op, attributeClasses, os);
1037   }
1038   return false;
1039 }
1040 
/// Registers this backend under the `gen-python-op-bindings` generator name;
/// `emitAllOps` produces its output.
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);
1044