1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Support/LogicalResult.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "llvm/ADT/StringSet.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/TableGen/Error.h"
21 #include "llvm/TableGen/Record.h"
22 
23 using namespace mlir;
24 using namespace mlir::tblgen;
25 
/// File header and includes emitted at the top of every generated module.
/// Imports the shared helpers from `_ods_common` under `_ods_`-prefixed (or
/// `_get_`-prefixed) aliases, tries to import the hand-written `_{0}_ops_ext`
/// module (falling back to None when it cannot be imported), and imports
/// `builtins` for the `@builtins.property` decorators used by the accessor
/// templates below.
///   {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for the dialect class, registered through
/// `_ods_cext.register_dialect`:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template that imports `_Dialect` from an already generated module instead
/// of defining it, so that op classes (which register against `_Dialect`) can
/// be emitted into a separate module. Presumably used when `-dialect-extension`
/// is given; the use site is not in this part of the file (confirm).
///   {0} is the name of the dialect whose `_{0}_ops_gen` module defines
///       `_Dialect`.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
57 
/// Template for an operation class. The class is registered with the module's
/// `_Dialect` class and passed through `_ods_extend_opview_class` together
/// with the (possibly None) extension module:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic group
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for the class level declaration of the _ODS_REGIONS spec, a
/// two-element Python tuple:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
87 
/// Template for a single-element accessor whose position is fixed because no
/// variable-length group precedes it:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent). Every other
/// group holds exactly one element, so the variable-length group holds
/// `len - {2} + 1` elements and shifts each later position by that amount
/// minus one.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if `len(self.operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The group starts at its own position and spans `len - {2} + 1` elements,
/// because every other group holds exactly one element.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
137 
/// First part of the template for equally-sized variadic group accessor. It
/// binds `start` (index of the group's first element) and `pg` (the length of
/// each variadic group), which the second part (opVariadicEqualSimpleTemplate
/// or opVariadicEqualVariadicTemplate) uses in its `return`. There is
/// deliberately no trailing newline: the second part starts with one.
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The element list must be read through `self.operation`: a bare `operation`
/// is not defined inside the generated property and would raise NameError.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
148 
/// Second part of the template for equally-sized case, accessing a single
/// element at the `start` index bound by the first part:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group of `pg` elements beginning at `start` (both bound by the first part):
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor. `_ods_segmented_accessor`
/// receives the element list, the `{1}_segment_sizes` attribute and the group
/// position, and presumably returns that group's elements (the helper lives in
/// `_ods_common`, which is not shown here):
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (`[0]` for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; the accessor yields the element, or None when its
/// segment is empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
183 
/// Template for an operation attribute getter, which wraps the raw attribute
/// in its Python class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter, which returns None
/// when the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter. Assigning None raises
/// ValueError because the attribute is mandatory:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute. Any value for which
/// `bool(value)` is False (e.g. None or False) removes the attribute; any other
/// value sets a fresh UnitAttr:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// NOTE(review): unlike initUnitAttributeTemplate, `UnitAttr.get()` is called
/// here without an explicit context; confirm a current context is always
/// available when the setter runs.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index of the region in the operation's `regions` list.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
269 
/// Category grouping the command-line options of this generator.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

/// `-bind-dialect`: the dialect whose ops get Python bindings (empty by
/// default).
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// `-dialect-extension`: the prefix of the dialect extension (empty by
/// default).
static llvm::cl::opt<std::string> clDialectExtensionName(
    "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
    llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// Maps a (trimmed) C++ attribute storage type name to the name of the Python
/// class that generated getters use to wrap attributes of that type.
using AttributeClasses = DenseMap<StringRef, StringRef>;
283 
284 /// Checks whether `str` is a Python keyword.
285 static bool isPythonKeyword(StringRef str) {
286   static llvm::StringSet<> keywords(
287       {"and",   "as",     "assert",   "break", "class",  "continue",
288        "def",   "del",    "elif",     "else",  "except", "finally",
289        "for",   "from",   "global",   "if",    "import", "in",
290        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
291        "raise", "return", "try",      "while", "with",   "yield"});
292   return keywords.contains(str);
293 }
294 
295 /// Checks whether `str` would shadow a generated variable or attribute
296 /// part of the OpView API.
297 static bool isODSReserved(StringRef str) {
298   static llvm::StringSet<> reserved(
299       {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
300        "loc", "verify", "regions", "results", "self", "operation",
301        "DIALECT_NAMESPACE", "OPERATION_NAME"});
302   return str.startswith("_ods_") || str.endswith("_ods") ||
303          reserved.contains(str);
304 }
305 
306 /// Modifies the `name` in a way that it becomes suitable for Python bindings
307 /// (does not change the `name` if it already is suitable) and returns the
308 /// modified version.
309 static std::string sanitizeName(StringRef name) {
310   if (isPythonKeyword(name) || isODSReserved(name))
311     return (name + "_").str();
312   return name.str();
313 }
314 
315 static std::string attrSizedTraitForKind(const char *kind) {
316   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
317                        llvm::StringRef(kind).take_front().upper(),
318                        llvm::StringRef(kind).drop_front());
319 }
320 
321 /// Emits accessors to "elements" of an Op definition. Currently, the supported
322 /// elements are operands and results, indicated by `kind`, which must be either
323 /// `operand` or `result` and is used verbatim in the emitted code.
324 static void emitElementAccessors(
325     const Operator &op, raw_ostream &os, const char *kind,
326     llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
327     llvm::function_ref<int(const Operator &)> getNumElements,
328     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
329         getElement) {
330   assert(llvm::is_contained(
331              llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
332          "unsupported kind");
333 
334   // Traits indicating how to process variadic elements.
335   std::string sameSizeTrait =
336       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
337                     llvm::StringRef(kind).take_front().upper(),
338                     llvm::StringRef(kind).drop_front());
339   std::string attrSizedTrait = attrSizedTraitForKind(kind);
340 
341   unsigned numVariableLength = getNumVariableLength(op);
342 
343   // If there is only one variable-length element group, its size can be
344   // inferred from the total number of elements. If there are none, the
345   // generation is straightforward.
346   if (numVariableLength <= 1) {
347     bool seenVariableLength = false;
348     for (int i = 0, e = getNumElements(op); i < e; ++i) {
349       const NamedTypeConstraint &element = getElement(op, i);
350       if (element.isVariableLength())
351         seenVariableLength = true;
352       if (element.name.empty())
353         continue;
354       if (element.isVariableLength()) {
355         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
356                                                  : opOneVariadicTemplate,
357                             sanitizeName(element.name), kind,
358                             getNumElements(op), i);
359       } else if (seenVariableLength) {
360         os << llvm::formatv(opSingleAfterVariableTemplate,
361                             sanitizeName(element.name), kind,
362                             getNumElements(op), i);
363       } else {
364         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
365                             i);
366       }
367     }
368     return;
369   }
370 
371   // Handle the operations where variadic groups have the same size.
372   if (op.getTrait(sameSizeTrait)) {
373     int numPrecedingSimple = 0;
374     int numPrecedingVariadic = 0;
375     for (int i = 0, e = getNumElements(op); i < e; ++i) {
376       const NamedTypeConstraint &element = getElement(op, i);
377       if (!element.name.empty()) {
378         os << llvm::formatv(opVariadicEqualPrefixTemplate,
379                             sanitizeName(element.name), kind, numVariableLength,
380                             numPrecedingSimple, numPrecedingVariadic);
381         os << llvm::formatv(element.isVariableLength()
382                                 ? opVariadicEqualVariadicTemplate
383                                 : opVariadicEqualSimpleTemplate,
384                             kind);
385       }
386       if (element.isVariableLength())
387         ++numPrecedingVariadic;
388       else
389         ++numPrecedingSimple;
390     }
391     return;
392   }
393 
394   // Handle the operations where the size of groups (variadic or not) is
395   // provided as an attribute. For non-variadic elements, make sure to return
396   // an element rather than a singleton container.
397   if (op.getTrait(attrSizedTrait)) {
398     for (int i = 0, e = getNumElements(op); i < e; ++i) {
399       const NamedTypeConstraint &element = getElement(op, i);
400       if (element.name.empty())
401         continue;
402       std::string trailing;
403       if (!element.isVariableLength())
404         trailing = "[0]";
405       else if (element.isOptional())
406         trailing = std::string(
407             llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
408       os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
409                           kind, i, trailing);
410     }
411     return;
412   }
413 
414   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
415 }
416 
417 /// Free function helpers accessing Operator components.
418 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
419 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
420   return op.getOperand(i);
421 }
422 static int getNumResults(const Operator &op) { return op.getNumResults(); }
423 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
424   return op.getResult(i);
425 }
426 
427 /// Emits accessors to Op operands.
428 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
429   auto getNumVariableLengthOperands = [](const Operator &oper) {
430     return oper.getNumVariableLengthOperands();
431   };
432   emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
433                        getNumOperands, getOperand);
434 }
435 
436 /// Emits accessors Op results.
437 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
438   auto getNumVariableLengthResults = [](const Operator &oper) {
439     return oper.getNumVariableLengthResults();
440   };
441   emitElementAccessors(op, os, "result", getNumVariableLengthResults,
442                        getNumResults, getResult);
443 }
444 
445 /// Emits accessors to Op attributes.
446 static void emitAttributeAccessors(const Operator &op,
447                                    const AttributeClasses &attributeClasses,
448                                    raw_ostream &os) {
449   for (const auto &namedAttr : op.getAttributes()) {
450     // Skip "derived" attributes because they are just C++ functions that we
451     // don't currently expose.
452     if (namedAttr.attr.isDerivedAttr())
453       continue;
454 
455     if (namedAttr.name.empty())
456       continue;
457 
458     std::string sanitizedName = sanitizeName(namedAttr.name);
459 
460     // Unit attributes are handled specially.
461     if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
462       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
463                           namedAttr.name);
464       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
465                           namedAttr.name);
466       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
467                           namedAttr.name);
468       continue;
469     }
470 
471     // Other kinds of attributes need a mapping to a Python type.
472     if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
473       continue;
474 
475     StringRef pythonType =
476         attributeClasses.lookup(namedAttr.attr.getStorageType());
477     if (namedAttr.attr.isOptional()) {
478       os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
479                           pythonType, namedAttr.name);
480       os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
481                           namedAttr.name);
482       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
483                           namedAttr.name);
484     } else {
485       os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
486                           namedAttr.name);
487       os << llvm::formatv(attributeSetterTemplate, sanitizedName,
488                           namedAttr.name);
489       // Non-optional attributes cannot be deleted.
490     }
491   }
492 }
493 
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating the `operands`, `results` and `attributes`
///       variables. It must also define `_ods_successors`, which is always
///       passed to `build_generic`. It is spliced in at the indentation of the
///       `__init__` body.
/// `{{` is formatv's escape for a literal `{`, so the emitted code reads
/// `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";

/// Templates for appending a single element to the operand/result list.
/// Operands are normalized through `_get_op_result_or_value`.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Templates for appending an optional element to the operand/result list.
/// Usually an absent (None) element is skipped. The attribute-sized operand
/// variant appends None in its place instead.
///   {0} is the field name.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Templates for appending a list of elements to the operand/result list.
/// The `extend` variants splice the elements into the list. The `Pack`
/// variant appends the whole normalized list as a single entry.
///   {0} is the field name.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder; the
/// attribute is only set when the argument is not None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting an attribute with a default value in the operation
/// builder; the default is used when the argument is None.
///   {0} is the attribute name;
///   {1} is the builder argument name;
///   {2} is the default value.
constexpr const char *initDefaultValuedAttributeTemplate =
    R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py";

/// Template for asserting that an attribute value was provided when calling a
/// builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *assertAttributeValueSpecified =
    R"Py(assert {1} is not None, "attribute {0} must be specified")Py";

/// Template for setting a unit attribute in the operation builder when the
/// argument is truthy. The UnitAttr is created with the context returned by
/// `_ods_get_default_loc_context(loc)`.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
574 
575 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
576 /// result types of the given operation.
577 static bool hasSameArgumentAndResultTypes(const Operator &op) {
578   return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
579          op.getNumVariableLengthResults() == 0;
580 }
581 
582 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
583 /// result types of the given operation.
584 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
585   return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
586          op.getNumVariableLengthResults() == 0;
587 }
588 
589 /// Returns true if the InferTypeOpInterface can be used to infer result types
590 /// of the given operation.
591 static bool hasInferTypeInterface(const Operator &op) {
592   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
593          op.getNumRegions() == 0;
594 }
595 
596 /// Returns true if there is a trait or interface that can be used to infer
597 /// result types of the given operation.
598 static bool canInferType(const Operator &op) {
599   return hasSameArgumentAndResultTypes(op) ||
600          hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
601 }
602 
603 /// Populates `builderArgs` with result names if the builder is expected to
604 /// accept them as arguments.
605 static void
606 populateBuilderArgsResults(const Operator &op,
607                            llvm::SmallVectorImpl<std::string> &builderArgs) {
608   if (canInferType(op))
609     return;
610 
611   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
612     std::string name = op.getResultName(i).str();
613     if (name.empty()) {
614       if (op.getNumResults() == 1) {
615         // Special case for one result, make the default name be 'result'
616         // to properly match the built-in result accessor.
617         name = "result";
618       } else {
619         name = llvm::formatv("_gen_res_{0}", i);
620       }
621     }
622     name = sanitizeName(name);
623     builderArgs.push_back(name);
624   }
625 }
626 
627 /// Populates `builderArgs` with the Python-compatible names of builder function
628 /// arguments using intermixed attributes and operands in the same order as they
629 /// appear in the `arguments` field of the op definition. Additionally,
630 /// `operandNames` is populated with names of operands in their order of
631 /// appearance.
632 static void
633 populateBuilderArgs(const Operator &op,
634                     llvm::SmallVectorImpl<std::string> &builderArgs,
635                     llvm::SmallVectorImpl<std::string> &operandNames,
636                     llvm::SmallVectorImpl<std::string> &successorArgNames) {
637 
638   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
639     std::string name = op.getArgName(i).str();
640     if (name.empty())
641       name = llvm::formatv("_gen_arg_{0}", i);
642     name = sanitizeName(name);
643     builderArgs.push_back(name);
644     if (!op.getArg(i).is<NamedAttribute *>())
645       operandNames.push_back(name);
646   }
647 }
648 
649 /// Populates `builderArgs` with the Python-compatible names of builder function
650 /// successor arguments. Additionally, `successorArgNames` is also populated.
651 static void populateBuilderArgsSuccessors(
652     const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
653     llvm::SmallVectorImpl<std::string> &successorArgNames) {
654 
655   for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
656     NamedSuccessor successor = op.getSuccessor(i);
657     std::string name = std::string(successor.name);
658     if (name.empty())
659       name = llvm::formatv("_gen_successor_{0}", i);
660     name = sanitizeName(name);
661     builderArgs.push_back(name);
662     successorArgNames.push_back(name);
663   }
664 }
665 
666 /// Generates Python code for the default value of the given attribute.
667 static FailureOr<std::string> getAttributeDefaultValue(Attribute attr) {
668   assert(attr.hasDefaultValue() && "expected attribute with default value");
669   StringRef storageType = attr.getStorageType().trim();
670   StringRef defaultValCpp = attr.getDefaultValue().trim();
671 
672   // A list of commonly used attribute types and default values for which
673   // we can generate Python code. Extend as needed.
674   if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}"))
675     return std::string("_ods_ir.ArrayAttr.get([])");
676 
677   // No match: Cannot generate Python code.
678   return failure();
679 }
680 
681 /// Populates `builderLines` with additional lines that are required in the
682 /// builder to set up operation attributes. `argNames` is expected to contain
683 /// the names of builder arguments that correspond to op arguments, i.e. to the
684 /// operands and attributes in the same order as they appear in the `arguments`
685 /// field.
686 static void
687 populateBuilderLinesAttr(const Operator &op,
688                          llvm::ArrayRef<std::string> argNames,
689                          llvm::SmallVectorImpl<std::string> &builderLines) {
690   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
691     Argument arg = op.getArg(i);
692     auto *attribute = arg.dyn_cast<NamedAttribute *>();
693     if (!attribute)
694       continue;
695 
696     // Unit attributes are handled specially.
697     if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
698       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
699                                            attribute->name, argNames[i]));
700       continue;
701     }
702 
703     // Attributes with default value are handled specially.
704     if (attribute->attr.hasDefaultValue()) {
705       // In case we cannot generate Python code for the default value, the
706       // attribute must be specified by the user.
707       FailureOr<std::string> defaultValPy =
708           getAttributeDefaultValue(attribute->attr);
709       if (succeeded(defaultValPy)) {
710         builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate,
711                                              attribute->name, argNames[i],
712                                              *defaultValPy));
713       } else {
714         builderLines.push_back(llvm::formatv(assertAttributeValueSpecified,
715                                              attribute->name, argNames[i]));
716         builderLines.push_back(
717             llvm::formatv(initAttributeTemplate, attribute->name, argNames[i]));
718       }
719       continue;
720     }
721 
722     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
723                                              ? initOptionalAttributeTemplate
724                                              : initAttributeTemplate,
725                                          attribute->name, argNames[i]));
726   }
727 }
728 
729 /// Populates `builderLines` with additional lines that are required in the
730 /// builder to set up successors. successorArgNames is expected to correspond
731 /// to the Python argument name for each successor on the op.
732 static void populateBuilderLinesSuccessors(
733     const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
734     llvm::SmallVectorImpl<std::string> &builderLines) {
735   if (successorArgNames.empty()) {
736     builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
737     return;
738   }
739 
740   builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
741   for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
742     auto &argName = successorArgNames[i];
743     const NamedSuccessor &successor = op.getSuccessor(i);
744     builderLines.push_back(
745         llvm::formatv(addSuccessorTemplate,
746                       successor.isVariadic() ? "extend" : "append", argName));
747   }
748 }
749 
750 /// Populates `builderLines` with additional lines that are required in the
751 /// builder to set up op operands.
752 static void
753 populateBuilderLinesOperand(const Operator &op,
754                             llvm::ArrayRef<std::string> names,
755                             llvm::SmallVectorImpl<std::string> &builderLines) {
756   bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
757 
758   // For each element, find or generate a name.
759   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
760     const NamedTypeConstraint &element = op.getOperand(i);
761     std::string name = names[i];
762 
763     // Choose the formatting string based on the element kind.
764     llvm::StringRef formatString;
765     if (!element.isVariableLength()) {
766       formatString = singleOperandAppendTemplate;
767     } else if (element.isOptional()) {
768       if (sizedSegments) {
769         formatString = optionalAppendAttrSizedOperandsTemplate;
770       } else {
771         formatString = optionalAppendOperandTemplate;
772       }
773     } else {
774       assert(element.isVariadic() && "unhandled element group type");
775       // If emitting with sizedSegments, then we add the actual list-typed
776       // element. Otherwise, we extend the actual operands.
777       if (sizedSegments) {
778         formatString = multiOperandAppendPackTemplate;
779       } else {
780         formatString = multiOperandAppendTemplate;
781       }
782     }
783 
784     builderLines.push_back(llvm::formatv(formatString.data(), name));
785   }
786 }
787 
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// The generated code reads `attributes["{0}"]`. It binds
/// `_ods_derived_result_type` to the attribute's value when that attribute is
/// a TypeAttr, and to the attribute's own `.type` otherwise.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";

/// Python code template that extends the results list with {1} copies of the
/// type expression {0}.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";

/// Python code template for inferring the operation results using the
/// corresponding interface:
///   - {0} is the name of the class for which the types are inferred.
/// The generated code passes the already-populated `operands` and
/// `attributes` (wrapped in a DictAttr) to
/// `InferTypeOpInterface({0}).inferReturnTypes`. It uses the context from
/// `_ods_get_default_loc_context(loc)` and assigns the result to `results`.
/// The template ends with a trailing newline.
constexpr const char *inferTypeInterfaceTemplate =
    R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
    operands=operands,
    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
    context=_ods_context,
    loc=loc)
)PY";
812 
813 /// Appends the given multiline string as individual strings into
814 /// `builderLines`.
815 static void appendLineByLine(StringRef string,
816                              llvm::SmallVectorImpl<std::string> &builderLines) {
817 
818   std::pair<StringRef, StringRef> split = std::make_pair(string, string);
819   do {
820     split = split.second.split('\n');
821     builderLines.push_back(split.first.str());
822   } while (!split.second.empty());
823 }
824 
825 /// Populates `builderLines` with additional lines that are required in the
826 /// builder to set up op results.
827 static void
828 populateBuilderLinesResult(const Operator &op,
829                            llvm::ArrayRef<std::string> names,
830                            llvm::SmallVectorImpl<std::string> &builderLines) {
831   bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
832 
833   if (hasSameArgumentAndResultTypes(op)) {
834     builderLines.push_back(llvm::formatv(
835         appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
836     return;
837   }
838 
839   if (hasFirstAttrDerivedResultTypes(op)) {
840     const NamedAttribute &firstAttr = op.getAttribute(0);
841     assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
842                                       "from which the type is derived");
843     appendLineByLine(
844         llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
845         builderLines);
846     builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
847                                          "_ods_derived_result_type",
848                                          op.getNumResults()));
849     return;
850   }
851 
852   if (hasInferTypeInterface(op)) {
853     appendLineByLine(
854         llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
855         builderLines);
856     return;
857   }
858 
859   // For each element, find or generate a name.
860   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
861     const NamedTypeConstraint &element = op.getResult(i);
862     std::string name = names[i];
863 
864     // Choose the formatting string based on the element kind.
865     llvm::StringRef formatString;
866     if (!element.isVariableLength()) {
867       formatString = singleResultAppendTemplate;
868     } else if (element.isOptional()) {
869       formatString = optionalAppendResultTemplate;
870     } else {
871       assert(element.isVariadic() && "unhandled element group type");
872       // If emitting with sizedSegments, then we add the actual list-typed
873       // element. Otherwise, we extend the actual operands.
874       if (sizedSegments) {
875         formatString = singleResultAppendTemplate;
876       } else {
877         formatString = multiResultAppendTemplate;
878       }
879     }
880 
881     builderLines.push_back(llvm::formatv(formatString.data(), name));
882   }
883 }
884 
885 /// If the operation has variadic regions, adds a builder argument to specify
886 /// the number of those regions and builder lines to forward it to the generic
887 /// constructor.
888 static void
889 populateBuilderRegions(const Operator &op,
890                        llvm::SmallVectorImpl<std::string> &builderArgs,
891                        llvm::SmallVectorImpl<std::string> &builderLines) {
892   if (op.hasNoVariadicRegions())
893     return;
894 
895   // This is currently enforced when Operator is constructed.
896   assert(op.getNumVariadicRegions() == 1 &&
897          op.getRegion(op.getNumRegions() - 1).isVariadic() &&
898          "expected the last region to be varidic");
899 
900   const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
901   std::string name =
902       ("num_" + region.name.take_front().lower() + region.name.drop_front())
903           .str();
904   builderArgs.push_back(name);
905   builderLines.push_back(
906       llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
907 }
908 
909 /// Emits a default builder constructing an operation from the list of its
910 /// result types, followed by a list of its operands.
911 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
912   // If we are asked to skip default builders, comply.
913   if (op.skipDefaultBuilders())
914     return;
915 
916   llvm::SmallVector<std::string> builderArgs;
917   llvm::SmallVector<std::string> builderLines;
918   llvm::SmallVector<std::string> operandArgNames;
919   llvm::SmallVector<std::string> successorArgNames;
920   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
921                       op.getNumNativeAttributes() + op.getNumSuccessors());
922   populateBuilderArgsResults(op, builderArgs);
923   size_t numResultArgs = builderArgs.size();
924   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
925   size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
926   populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
927 
928   populateBuilderLinesOperand(op, operandArgNames, builderLines);
929   populateBuilderLinesAttr(
930       op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
931       builderLines);
932   populateBuilderLinesResult(
933       op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
934       builderLines);
935   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
936   populateBuilderRegions(op, builderArgs, builderLines);
937 
938   // Layout of builderArgs vector elements:
939   // [ result_args  operand_attr_args successor_args regions ]
940 
941   // Determine whether the argument corresponding to a given index into the
942   // builderArgs vector is a python keyword argument or not.
943   auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
944     // All result, successor, and region arguments are positional arguments.
945     if ((builderArgIndex < numResultArgs) ||
946         (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
947       return false;
948     // Keyword arguments:
949     // - optional named attributes (including unit attributes)
950     // - default-valued named attributes
951     // - optional operands
952     Argument a = op.getArg(builderArgIndex - numResultArgs);
953     if (auto *nattr = a.dyn_cast<NamedAttribute *>())
954       return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
955     if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
956       return ntype->isOptional();
957     return false;
958   };
959 
960   // StringRefs in functionArgs refer to strings allocated by builderArgs.
961   llvm::SmallVector<llvm::StringRef> functionArgs;
962 
963   // Add positional arguments.
964   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
965     if (!isKeywordArgFn(i))
966       functionArgs.push_back(builderArgs[i]);
967   }
968 
969   // Add a bare '*' to indicate that all following arguments must be keyword
970   // arguments.
971   functionArgs.push_back("*");
972 
973   // Add a default 'None' value to each keyword arg string, and then add to the
974   // function args list.
975   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
976     if (isKeywordArgFn(i)) {
977       builderArgs[i].append("=None");
978       functionArgs.push_back(builderArgs[i]);
979     }
980   }
981   functionArgs.push_back("loc=None");
982   functionArgs.push_back("ip=None");
983   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
984                       llvm::join(builderLines, "\n    "));
985 }
986 
987 static void constructAttributeMapping(const llvm::RecordKeeper &records,
988                                       AttributeClasses &attributeClasses) {
989   for (const llvm::Record *rec :
990        records.getAllDerivedDefinitions("PythonAttr")) {
991     attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
992                                  rec->getValueAsString("pythonType").trim());
993   }
994 }
995 
996 static void emitSegmentSpec(
997     const Operator &op, const char *kind,
998     llvm::function_ref<int(const Operator &)> getNumElements,
999     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
1000         getElement,
1001     raw_ostream &os) {
1002   std::string segmentSpec("[");
1003   for (int i = 0, e = getNumElements(op); i < e; ++i) {
1004     const NamedTypeConstraint &element = getElement(op, i);
1005     if (element.isOptional()) {
1006       segmentSpec.append("0,");
1007     } else if (element.isVariadic()) {
1008       segmentSpec.append("-1,");
1009     } else {
1010       segmentSpec.append("1,");
1011     }
1012   }
1013   segmentSpec.append("]");
1014 
1015   os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
1016 }
1017 
1018 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
1019   // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
1020   // Note that the base OpView class defines this as (0, True).
1021   unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
1022   os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
1023                       op.hasNoVariadicRegions() ? "True" : "False");
1024 }
1025 
1026 /// Emits named accessors to regions.
1027 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
1028   for (const auto &en : llvm::enumerate(op.getRegions())) {
1029     const NamedRegion &region = en.value();
1030     if (region.name.empty())
1031       continue;
1032 
1033     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
1034            "expected only the last region to be variadic");
1035     os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
1036                         std::to_string(en.index()) +
1037                             (region.isVariadic() ? ":" : ""));
1038   }
1039 }
1040 
1041 /// Emits bindings for a specific Op to the given output stream.
1042 static void emitOpBindings(const Operator &op,
1043                            const AttributeClasses &attributeClasses,
1044                            raw_ostream &os) {
1045   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
1046                       op.getOperationName());
1047 
1048   // Sized segments.
1049   if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1050     emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
1051   }
1052   if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1053     emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
1054   }
1055 
1056   emitRegionAttributes(op, os);
1057   emitDefaultOpBuilder(op, os);
1058   emitOperandAccessors(op, os);
1059   emitAttributeAccessors(op, attributeClasses, os);
1060   emitResultAccessors(op, os);
1061   emitRegionAccessors(op, os);
1062 }
1063 
1064 /// Emits bindings for the dialect specified in the command line, including file
1065 /// headers and utilities. Returns `false` on success to comply with Tablegen
1066 /// registration requirements.
1067 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1068   if (clDialectName.empty())
1069     llvm::PrintFatalError("dialect name not provided");
1070 
1071   AttributeClasses attributeClasses;
1072   constructAttributeMapping(records, attributeClasses);
1073 
1074   bool isExtension = !clDialectExtensionName.empty();
1075   os << llvm::formatv(fileHeader, isExtension
1076                                       ? clDialectExtensionName.getValue()
1077                                       : clDialectName.getValue());
1078   if (isExtension)
1079     os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
1080   else
1081     os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
1082 
1083   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
1084     Operator op(rec);
1085     if (op.getDialectName() == clDialectName.getValue())
1086       emitOpBindings(op, attributeClasses, os);
1087   }
1088   return false;
1089 }
1090 
1091 static GenRegistration
1092     genPythonBindings("gen-python-op-bindings",
1093                       "Generate Python bindings for MLIR Ops", &emitAllOps);
1094