1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
/// File header and includes for the generated Python module.
///   {0} is the dialect namespace.
/// The emitted code imports the shared ODS helpers under `_ods_`-prefixed
/// aliases, tries to import an optional `_{0}_ops_ext` extension module
/// (binding `_ods_ext_module` to None when the import fails), and imports
/// `builtins`, which the accessor templates reference as `builtins.property`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";
42 
/// Template for dialect class:
///   {0} is the dialect namespace.
/// The emitted class is always named `_Dialect` and is decorated with
/// `_ods_cext.register_dialect`; the operation class template refers to it by
/// that name.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
/// Only the decorated class header and the `OPERATION_NAME` attribute are
/// emitted here; the class body is completed by the member templates below,
/// which are all indented by two spaces for that reason.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
62 
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   -1 = operand/result is a sequence corresponding to a variadic
/// The emitted class attribute is named `_ODS_OPERAND_SEGMENTS` or
/// `_ODS_RESULT_SEGMENTS` depending on {0}.
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
/// The two values are emitted as a Python 2-tuple.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
81 
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// The emitted property indexes `self.operation.{1}s` directly with {2}.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
/// The emitted code computes the length of the one variable-length group as
/// the total element count minus the {2} - 1 remaining groups, then shifts the
/// static position {3} by that length minus one.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
105 
/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This template is only used when the optional element is the sole
/// variable-length group, so every other group contributes exactly one
/// element: the element list has {2} entries when the optional element is
/// present and {2} - 1 entries when it is absent. The accessor therefore
/// returns None only when the list is shorter than {2}.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
116 
/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The emitted code derives the group length from the total element count
/// minus the {2} - 1 other groups and returns the slice of that length
/// starting at {3}.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
128 
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The emitted code binds `start` and `pg`, which the second-part templates
/// use to index or slice the element list. The element list is read through
/// `self.operation`, since no bare `operation` name exists inside the
/// generated property. The template has no trailing newline because the
/// second part begins with one.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
139 
/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
/// Relies on the `start` variable bound by the first part of the template.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
/// Relies on the `start` and `pg` variables bound by the first part of the
/// template; the returned slice contains `pg` elements.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
153 
/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// The emitted code reads the group sizes from the operation attribute named
/// `{1}_segment_sizes` and stores the selected range in a local named
/// `{1}_range`, to which the suffix {3} is applied.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
/// The suffix refers to the `{0}_range` local created by
/// opVariadicSegmentTemplate and yields None when that range is empty.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
174 
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
/// The emitted property wraps the stored attribute by calling {1} on it.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
/// The emitted property returns None when the attribute is not present on the
/// operation.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
207 
/// Template for an operation attribute setter:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// The emitted setter raises ValueError when given None, because the
/// attribute is mandatory.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// Setting None when the attribute is already absent is a no-op.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// The value is tested with `bool(value)`, so any falsy value removes the
/// attribute and any truthy value stores a fresh `UnitAttr`.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// Unlike the setters above, the emitted deleter does not first check that
/// the attribute is present.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
254 
/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index used to subscript `self.regions`.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
260 
/// Command-line option category under which the options of this generator
/// are registered.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

/// The `bind-dialect` command-line option: the dialect to run the generator
/// for. Defaults to the empty string.
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// String-to-string map. In this file it is queried with an attribute's C++
/// storage type to obtain the Python type used in the generated getters.
using AttributeClasses = DenseMap<StringRef, StringRef>;
270 
271 /// Checks whether `str` is a Python keyword.
272 static bool isPythonKeyword(StringRef str) {
273   static llvm::StringSet<> keywords(
274       {"and",   "as",     "assert",   "break", "class",  "continue",
275        "def",   "del",    "elif",     "else",  "except", "finally",
276        "for",   "from",   "global",   "if",    "import", "in",
277        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
278        "raise", "return", "try",      "while", "with",   "yield"});
279   return keywords.contains(str);
280 }
281 
282 /// Checks whether `str` would shadow a generated variable or attribute
283 /// part of the OpView API.
284 static bool isODSReserved(StringRef str) {
285   static llvm::StringSet<> reserved(
286       {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
287        "loc", "verify", "regions", "results", "self", "operation",
288        "DIALECT_NAMESPACE", "OPERATION_NAME"});
289   return str.startswith("_ods_") || str.endswith("_ods") ||
290          reserved.contains(str);
291 }
292 
293 /// Modifies the `name` in a way that it becomes suitable for Python bindings
294 /// (does not change the `name` if it already is suitable) and returns the
295 /// modified version.
296 static std::string sanitizeName(StringRef name) {
297   if (isPythonKeyword(name) || isODSReserved(name))
298     return (name + "_").str();
299   return name.str();
300 }
301 
302 static std::string attrSizedTraitForKind(const char *kind) {
303   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
304                        llvm::StringRef(kind).take_front().upper(),
305                        llvm::StringRef(kind).drop_front());
306 }
307 
308 /// Emits accessors to "elements" of an Op definition. Currently, the supported
309 /// elements are operands and results, indicated by `kind`, which must be either
310 /// `operand` or `result` and is used verbatim in the emitted code.
311 static void emitElementAccessors(
312     const Operator &op, raw_ostream &os, const char *kind,
313     llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
314     llvm::function_ref<int(const Operator &)> getNumElements,
315     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
316         getElement) {
317   assert(llvm::is_contained(
318              llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
319          "unsupported kind");
320 
321   // Traits indicating how to process variadic elements.
322   std::string sameSizeTrait =
323       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
324                     llvm::StringRef(kind).take_front().upper(),
325                     llvm::StringRef(kind).drop_front());
326   std::string attrSizedTrait = attrSizedTraitForKind(kind);
327 
328   unsigned numVariadic = getNumVariadic(op);
329 
330   // If there is only one variadic element group, its size can be inferred from
331   // the total number of elements. If there are none, the generation is
332   // straightforward.
333   if (numVariadic <= 1) {
334     bool seenVariableLength = false;
335     for (int i = 0, e = getNumElements(op); i < e; ++i) {
336       const NamedTypeConstraint &element = getElement(op, i);
337       if (element.isVariableLength())
338         seenVariableLength = true;
339       if (element.name.empty())
340         continue;
341       if (element.isVariableLength()) {
342         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
343                                                  : opOneVariadicTemplate,
344                             sanitizeName(element.name), kind,
345                             getNumElements(op), i);
346       } else if (seenVariableLength) {
347         os << llvm::formatv(opSingleAfterVariableTemplate,
348                             sanitizeName(element.name), kind,
349                             getNumElements(op), i);
350       } else {
351         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
352                             i);
353       }
354     }
355     return;
356   }
357 
358   // Handle the operations where variadic groups have the same size.
359   if (op.getTrait(sameSizeTrait)) {
360     int numPrecedingSimple = 0;
361     int numPrecedingVariadic = 0;
362     for (int i = 0, e = getNumElements(op); i < e; ++i) {
363       const NamedTypeConstraint &element = getElement(op, i);
364       if (!element.name.empty()) {
365         os << llvm::formatv(opVariadicEqualPrefixTemplate,
366                             sanitizeName(element.name), kind, numVariadic,
367                             numPrecedingSimple, numPrecedingVariadic);
368         os << llvm::formatv(element.isVariableLength()
369                                 ? opVariadicEqualVariadicTemplate
370                                 : opVariadicEqualSimpleTemplate,
371                             kind);
372       }
373       if (element.isVariableLength())
374         ++numPrecedingVariadic;
375       else
376         ++numPrecedingSimple;
377     }
378     return;
379   }
380 
381   // Handle the operations where the size of groups (variadic or not) is
382   // provided as an attribute. For non-variadic elements, make sure to return
383   // an element rather than a singleton container.
384   if (op.getTrait(attrSizedTrait)) {
385     for (int i = 0, e = getNumElements(op); i < e; ++i) {
386       const NamedTypeConstraint &element = getElement(op, i);
387       if (element.name.empty())
388         continue;
389       std::string trailing;
390       if (!element.isVariableLength())
391         trailing = "[0]";
392       else if (element.isOptional())
393         trailing = std::string(
394             llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
395       os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
396                           kind, i, trailing);
397     }
398     return;
399   }
400 
401   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
402 }
403 
/// Free function helpers accessing Operator components. They are passed as
/// the `getNumElements` and `getElement` callbacks of emitElementAccessors.

/// Returns the number of operands of `op`.
static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
/// Returns the `i`-th operand of `op`.
static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
  return op.getOperand(i);
}
/// Returns the number of results of `op`.
static int getNumResults(const Operator &op) { return op.getNumResults(); }
/// Returns the `i`-th result of `op`.
static const NamedTypeConstraint &getResult(const Operator &op, int i) {
  return op.getResult(i);
}
413 
414 /// Emits accessors to Op operands.
415 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
416   auto getNumVariadic = [](const Operator &oper) {
417     return oper.getNumVariableLengthOperands();
418   };
419   emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
420                        getOperand);
421 }
422 
423 /// Emits accessors Op results.
424 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
425   auto getNumVariadic = [](const Operator &oper) {
426     return oper.getNumVariableLengthResults();
427   };
428   emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
429                        getResult);
430 }
431 
432 /// Emits accessors to Op attributes.
433 static void emitAttributeAccessors(const Operator &op,
434                                    const AttributeClasses &attributeClasses,
435                                    raw_ostream &os) {
436   for (const auto &namedAttr : op.getAttributes()) {
437     // Skip "derived" attributes because they are just C++ functions that we
438     // don't currently expose.
439     if (namedAttr.attr.isDerivedAttr())
440       continue;
441 
442     if (namedAttr.name.empty())
443       continue;
444 
445     std::string sanitizedName = sanitizeName(namedAttr.name);
446 
447     // Unit attributes are handled specially.
448     if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
449       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
450                           namedAttr.name);
451       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
452                           namedAttr.name);
453       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
454                           namedAttr.name);
455       continue;
456     }
457 
458     // Other kinds of attributes need a mapping to a Python type.
459     if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
460       continue;
461 
462     StringRef pythonType =
463         attributeClasses.lookup(namedAttr.attr.getStorageType());
464     if (namedAttr.attr.isOptional()) {
465       os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
466                           pythonType, namedAttr.name);
467       os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
468                           namedAttr.name);
469       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
470                           namedAttr.name);
471     } else {
472       os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
473                           namedAttr.name);
474       os << llvm::formatv(attributeSetterTemplate, sanitizedName,
475                           namedAttr.name);
476       // Non-optional attributes cannot be deleted.
477     }
478   }
479 }
480 
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields.
/// The `{{` sequence is the formatv escape for a literal `{`, so `{{}` emits
/// an empty Python dict. The emitted `__init__` references `_ods_successors`,
/// `loc` and `ip` without defining them: `_ods_successors` must be assigned by
/// the code substituted for {1}, and `loc`/`ip` must be among the arguments
/// substituted for {0}.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";
497 
/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
/// Operands are passed through `_get_op_result_or_value` before being
/// appended; results are appended as given.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// Nothing is appended when the builder argument is None.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The plain operand/result variants flatten the list into the target with
/// `extend`. The "Pack" variant instead appends the whole list as one nested
/// entry; the operand builder selects it when the op has attribute-sized
/// operand segments.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
518 
/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
/// The attribute is left unset when the builder argument is None.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
/// The attribute is created only when `bool({1})` is true, using the context
/// obtained from `_ods_get_default_loc_context(loc)`. The emitted statement
/// spans two lines.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
///   {0} is the value to initialize the successors list to.
/// The assigned name, `_ods_successors`, is the one the builder template
/// passes to `build_generic`.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
543 
544 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
545 /// result types of the given operation.
546 static bool hasSameArgumentAndResultTypes(const Operator &op) {
547   return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
548          op.getNumVariableLengthResults() == 0;
549 }
550 
551 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
552 /// result types of the given operation.
553 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
554   return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
555          op.getNumVariableLengthResults() == 0;
556 }
557 
558 /// Returns true if the InferTypeOpInterface can be used to infer result types
559 /// of the given operation.
560 static bool hasInferTypeInterface(const Operator &op) {
561   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
562          op.getNumRegions() == 0;
563 }
564 
565 /// Returns true if there is a trait or interface that can be used to infer
566 /// result types of the given operation.
567 static bool canInferType(const Operator &op) {
568   return hasSameArgumentAndResultTypes(op) ||
569          hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
570 }
571 
572 /// Populates `builderArgs` with result names if the builder is expected to
573 /// accept them as arguments.
574 static void
575 populateBuilderArgsResults(const Operator &op,
576                            llvm::SmallVectorImpl<std::string> &builderArgs) {
577   if (canInferType(op))
578     return;
579 
580   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
581     std::string name = op.getResultName(i).str();
582     if (name.empty()) {
583       if (op.getNumResults() == 1) {
584         // Special case for one result, make the default name be 'result'
585         // to properly match the built-in result accessor.
586         name = "result";
587       } else {
588         name = llvm::formatv("_gen_res_{0}", i);
589       }
590     }
591     name = sanitizeName(name);
592     builderArgs.push_back(name);
593   }
594 }
595 
596 /// Populates `builderArgs` with the Python-compatible names of builder function
597 /// arguments using intermixed attributes and operands in the same order as they
598 /// appear in the `arguments` field of the op definition. Additionally,
599 /// `operandNames` is populated with names of operands in their order of
600 /// appearance.
601 static void
602 populateBuilderArgs(const Operator &op,
603                     llvm::SmallVectorImpl<std::string> &builderArgs,
604                     llvm::SmallVectorImpl<std::string> &operandNames,
605                     llvm::SmallVectorImpl<std::string> &successorArgNames) {
606 
607   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
608     std::string name = op.getArgName(i).str();
609     if (name.empty())
610       name = llvm::formatv("_gen_arg_{0}", i);
611     name = sanitizeName(name);
612     builderArgs.push_back(name);
613     if (!op.getArg(i).is<NamedAttribute *>())
614       operandNames.push_back(name);
615   }
616 
617   for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
618     NamedSuccessor successor = op.getSuccessor(i);
619     std::string name = std::string(successor.name);
620     if (name.empty())
621       name = llvm::formatv("_gen_successor_{0}", i);
622     name = sanitizeName(name);
623     builderArgs.push_back(name);
624     successorArgNames.push_back(name);
625   }
626 }
627 
628 /// Populates `builderLines` with additional lines that are required in the
629 /// builder to set up operation attributes. `argNames` is expected to contain
630 /// the names of builder arguments that correspond to op arguments, i.e. to the
631 /// operands and attributes in the same order as they appear in the `arguments`
632 /// field.
633 static void
634 populateBuilderLinesAttr(const Operator &op,
635                          llvm::ArrayRef<std::string> argNames,
636                          llvm::SmallVectorImpl<std::string> &builderLines) {
637   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
638     Argument arg = op.getArg(i);
639     auto *attribute = arg.dyn_cast<NamedAttribute *>();
640     if (!attribute)
641       continue;
642 
643     // Unit attributes are handled specially.
644     if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
645       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
646                                            attribute->name, argNames[i]));
647       continue;
648     }
649 
650     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
651                                              ? initOptionalAttributeTemplate
652                                              : initAttributeTemplate,
653                                          attribute->name, argNames[i]));
654   }
655 }
656 
657 /// Populates `builderLines` with additional lines that are required in the
658 /// builder to set up successors. successorArgNames is expected to correspond
659 /// to the Python argument name for each successor on the op.
660 static void populateBuilderLinesSuccessors(
661     const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
662     llvm::SmallVectorImpl<std::string> &builderLines) {
663   if (successorArgNames.empty()) {
664     builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
665     return;
666   }
667 
668   builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
669   for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
670     auto &argName = successorArgNames[i];
671     const NamedSuccessor &successor = op.getSuccessor(i);
672     builderLines.push_back(
673         llvm::formatv(addSuccessorTemplate,
674                       successor.isVariadic() ? "extend" : "append", argName));
675   }
676 }
677 
678 /// Populates `builderLines` with additional lines that are required in the
679 /// builder to set up op operands.
680 static void
681 populateBuilderLinesOperand(const Operator &op,
682                             llvm::ArrayRef<std::string> names,
683                             llvm::SmallVectorImpl<std::string> &builderLines) {
684   bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
685 
686   // For each element, find or generate a name.
687   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
688     const NamedTypeConstraint &element = op.getOperand(i);
689     std::string name = names[i];
690 
691     // Choose the formatting string based on the element kind.
692     llvm::StringRef formatString;
693     if (!element.isVariableLength()) {
694       formatString = singleOperandAppendTemplate;
695     } else if (element.isOptional()) {
696       formatString = optionalAppendOperandTemplate;
697     } else {
698       assert(element.isVariadic() && "unhandled element group type");
699       // If emitting with sizedSegments, then we add the actual list-typed
700       // element. Otherwise, we extend the actual operands.
701       if (sizedSegments) {
702         formatString = multiOperandAppendPackTemplate;
703       } else {
704         formatString = multiOperandAppendTemplate;
705       }
706     }
707 
708     builderLines.push_back(llvm::formatv(formatString.data(), name));
709   }
710 }
711 
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// The emitted code reads the attribute from the builder's `attributes` dict
/// and binds `_ods_derived_result_type` to the wrapped type when the
/// attribute is a TypeAttr, or to the attribute's own `.type` otherwise. The
/// template spans several lines with no leading indentation.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";

/// Python code template appending {0} type {1} times to the results list.
///   - {0} is a Python expression evaluating to the type;
///   - {1} is the repetition count.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";

/// Python code template for inferring the operation results using the
/// corresponding interface:
///   - {0} is the name of the class for which the types are inferred.
/// The emitted code rebinds `results` to the inferred types, passing the
/// builder's `operands`, `attributes` and `loc`. The template ends with a
/// newline.
constexpr const char *inferTypeInterfaceTemplate =
    R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
    operands=operands,
    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
    context=_ods_context,
    loc=loc)
)PY";
736 
737 /// Appends the given multiline string as individual strings into
738 /// `builderLines`.
739 static void appendLineByLine(StringRef string,
740                              llvm::SmallVectorImpl<std::string> &builderLines) {
741 
742   std::pair<StringRef, StringRef> split = std::make_pair(string, string);
743   do {
744     split = split.second.split('\n');
745     builderLines.push_back(split.first.str());
746   } while (!split.second.empty());
747 }
748 
749 /// Populates `builderLines` with additional lines that are required in the
750 /// builder to set up op results.
751 static void
752 populateBuilderLinesResult(const Operator &op,
753                            llvm::ArrayRef<std::string> names,
754                            llvm::SmallVectorImpl<std::string> &builderLines) {
755   bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
756 
757   if (hasSameArgumentAndResultTypes(op)) {
758     builderLines.push_back(llvm::formatv(
759         appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
760     return;
761   }
762 
763   if (hasFirstAttrDerivedResultTypes(op)) {
764     const NamedAttribute &firstAttr = op.getAttribute(0);
765     assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
766                                       "from which the type is derived");
767     appendLineByLine(
768         llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
769         builderLines);
770     builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
771                                          "_ods_derived_result_type",
772                                          op.getNumResults()));
773     return;
774   }
775 
776   if (hasInferTypeInterface(op)) {
777     appendLineByLine(
778         llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
779         builderLines);
780     return;
781   }
782 
783   // For each element, find or generate a name.
784   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
785     const NamedTypeConstraint &element = op.getResult(i);
786     std::string name = names[i];
787 
788     // Choose the formatting string based on the element kind.
789     llvm::StringRef formatString;
790     if (!element.isVariableLength()) {
791       formatString = singleResultAppendTemplate;
792     } else if (element.isOptional()) {
793       formatString = optionalAppendResultTemplate;
794     } else {
795       assert(element.isVariadic() && "unhandled element group type");
796       // If emitting with sizedSegments, then we add the actual list-typed
797       // element. Otherwise, we extend the actual operands.
798       if (sizedSegments) {
799         formatString = singleResultAppendTemplate;
800       } else {
801         formatString = multiResultAppendTemplate;
802       }
803     }
804 
805     builderLines.push_back(llvm::formatv(formatString.data(), name));
806   }
807 }
808 
809 /// If the operation has variadic regions, adds a builder argument to specify
810 /// the number of those regions and builder lines to forward it to the generic
811 /// constructor.
812 static void
813 populateBuilderRegions(const Operator &op,
814                        llvm::SmallVectorImpl<std::string> &builderArgs,
815                        llvm::SmallVectorImpl<std::string> &builderLines) {
816   if (op.hasNoVariadicRegions())
817     return;
818 
819   // This is currently enforced when Operator is constructed.
820   assert(op.getNumVariadicRegions() == 1 &&
821          op.getRegion(op.getNumRegions() - 1).isVariadic() &&
822          "expected the last region to be varidic");
823 
824   const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
825   std::string name =
826       ("num_" + region.name.take_front().lower() + region.name.drop_front())
827           .str();
828   builderArgs.push_back(name);
829   builderLines.push_back(
830       llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
831 }
832 
833 /// Emits a default builder constructing an operation from the list of its
834 /// result types, followed by a list of its operands.
835 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
836   // If we are asked to skip default builders, comply.
837   if (op.skipDefaultBuilders())
838     return;
839 
840   llvm::SmallVector<std::string> builderArgs;
841   llvm::SmallVector<std::string> builderLines;
842   llvm::SmallVector<std::string> operandArgNames;
843   llvm::SmallVector<std::string> successorArgNames;
844   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
845                       op.getNumNativeAttributes() + op.getNumSuccessors());
846   populateBuilderArgsResults(op, builderArgs);
847   size_t numResultArgs = builderArgs.size();
848   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
849 
850   populateBuilderLinesOperand(op, operandArgNames, builderLines);
851   populateBuilderLinesAttr(
852       op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
853       builderLines);
854   populateBuilderLinesResult(
855       op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
856       builderLines);
857   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
858   populateBuilderRegions(op, builderArgs, builderLines);
859 
860   builderArgs.push_back("*");
861   builderArgs.push_back("loc=None");
862   builderArgs.push_back("ip=None");
863   os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),
864                       llvm::join(builderLines, "\n    "));
865 }
866 
867 static void constructAttributeMapping(const llvm::RecordKeeper &records,
868                                       AttributeClasses &attributeClasses) {
869   for (const llvm::Record *rec :
870        records.getAllDerivedDefinitions("PythonAttr")) {
871     attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
872                                  rec->getValueAsString("pythonType").trim());
873   }
874 }
875 
876 static void emitSegmentSpec(
877     const Operator &op, const char *kind,
878     llvm::function_ref<int(const Operator &)> getNumElements,
879     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
880         getElement,
881     raw_ostream &os) {
882   std::string segmentSpec("[");
883   for (int i = 0, e = getNumElements(op); i < e; ++i) {
884     const NamedTypeConstraint &element = getElement(op, i);
885     if (element.isVariableLength()) {
886       segmentSpec.append("-1,");
887     } else if (element.isOptional()) {
888       segmentSpec.append("0,");
889     } else {
890       segmentSpec.append("1,");
891     }
892   }
893   segmentSpec.append("]");
894 
895   os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
896 }
897 
898 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
899   // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
900   // Note that the base OpView class defines this as (0, True).
901   unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
902   os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
903                       op.hasNoVariadicRegions() ? "True" : "False");
904 }
905 
906 /// Emits named accessors to regions.
907 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
908   for (auto en : llvm::enumerate(op.getRegions())) {
909     const NamedRegion &region = en.value();
910     if (region.name.empty())
911       continue;
912 
913     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
914            "expected only the last region to be variadic");
915     os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
916                         std::to_string(en.index()) +
917                             (region.isVariadic() ? ":" : ""));
918   }
919 }
920 
921 /// Emits bindings for a specific Op to the given output stream.
922 static void emitOpBindings(const Operator &op,
923                            const AttributeClasses &attributeClasses,
924                            raw_ostream &os) {
925   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
926                       op.getOperationName());
927 
928   // Sized segments.
929   if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
930     emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
931   }
932   if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
933     emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
934   }
935 
936   emitRegionAttributes(op, os);
937   emitDefaultOpBuilder(op, os);
938   emitOperandAccessors(op, os);
939   emitAttributeAccessors(op, attributeClasses, os);
940   emitResultAccessors(op, os);
941   emitRegionAccessors(op, os);
942 }
943 
944 /// Emits bindings for the dialect specified in the command line, including file
945 /// headers and utilities. Returns `false` on success to comply with Tablegen
946 /// registration requirements.
947 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
948   if (clDialectName.empty())
949     llvm::PrintFatalError("dialect name not provided");
950 
951   AttributeClasses attributeClasses;
952   constructAttributeMapping(records, attributeClasses);
953 
954   os << llvm::formatv(fileHeader, clDialectName.getValue());
955   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
956 
957   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
958     Operator op(rec);
959     if (op.getDialectName() == clDialectName.getValue())
960       emitOpBindings(op, attributeClasses, os);
961   }
962   return false;
963 }
964 
/// Static registration object that associates the `gen-python-op-bindings`
/// name and its description with the `emitAllOps` entry point.
// NOTE(review): presumably this is what exposes the generator as an
// mlir-tblgen command-line option -- confirm against GenInfo.h.
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);
968