1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Support/LogicalResult.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "llvm/ADT/StringSet.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/TableGen/Error.h"
21 #include "llvm/TableGen/Record.h"
22
23 using namespace mlir;
24 using namespace mlir::tblgen;
25
/// Header of every generated Python module: imports of the `_ods_common`
/// helpers (aliased with `_ods_` / `_get_` prefixes), the `_ods_ir` alias and a
/// tolerant import of the optional hand-written extension module.
///   {0} is the dialect namespace; it names the `_{0}_ops_ext` extension
///   module. If that module cannot be imported, `_ods_ext_module` is None.
/// NOTE(review): `builtins` is presumably imported so that accessors can use
/// `@builtins.property` even if an op defines a member named `property`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
/// Declares `_Dialect` and registers it through `_ods_cext.register_dialect`;
/// op classes refer to `_Dialect` when registering themselves.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template importing an already-generated `_Dialect` class instead of
/// declaring a new one.
///   {0} is the name embedded in the `_{0}_ops_gen` module to import from.
/// NOTE(review): presumably emitted instead of dialectClassTemplate when
/// `-dialect-extension` is set, with {0} being the extended dialect's
/// namespace — confirm against the caller (not visible in this chunk).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
57
/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
/// The class is registered on `_Dialect` and may be extended by the optional
/// `_ods_ext_module` through `_ods_extend_opview_class`. Class-body members
/// emitted afterwards are indented by two spaces to match `OPERATION_NAME`.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
87
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Only valid when no variable-length group precedes the element, so that its
/// position in the flat element list is fixed.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
/// With exactly one variable-length group, its length is the flat element count
/// minus the ({2} - 1) fixed groups. Every later element is therefore shifted
/// by (length - 1) relative to its group position.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Returns the slice covering the variadic group; its length is computed the
/// same way as in opSingleAfterVariableTemplate.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
137
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The template stops inside the property body after binding `start` and `pg`.
/// One of the two templates below then supplies the `return` statement.
/// Elements must be read through `self.operation`: a bare `operation` name is
/// undefined inside the generated property and would raise NameError.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
/// Returns the element at `start`.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
/// Returns the `pg` elements starting at `start`.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
162
/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///   variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// The group boundaries come from the `{1}_segment_sizes` attribute of the op.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
/// Appended to `return {0}_range`: yields the single element, or None when the
/// segment is empty.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
183
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
/// The raw attribute is wrapped by calling the Python type {1} on it.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
/// Returns None when the attribute is not present on the operation.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Mandatory attributes cannot be cleared, so None raises ValueError.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Any truthy value stores a fresh `UnitAttr`.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index of the region in `self.regions`.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
269
270 static llvm::cl::OptionCategory
271 clOpPythonBindingCat("Options for -gen-python-op-bindings");
272
273 static llvm::cl::opt<std::string>
274 clDialectName("bind-dialect",
275 llvm::cl::desc("The dialect to run the generator for"),
276 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
277
278 static llvm::cl::opt<std::string> clDialectExtensionName(
279 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
280 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
281
282 using AttributeClasses = DenseMap<StringRef, StringRef>;
283
284 /// Checks whether `str` is a Python keyword.
isPythonKeyword(StringRef str)285 static bool isPythonKeyword(StringRef str) {
286 static llvm::StringSet<> keywords(
287 {"and", "as", "assert", "break", "class", "continue",
288 "def", "del", "elif", "else", "except", "finally",
289 "for", "from", "global", "if", "import", "in",
290 "is", "lambda", "nonlocal", "not", "or", "pass",
291 "raise", "return", "try", "while", "with", "yield"});
292 return keywords.contains(str);
293 }
294
295 /// Checks whether `str` would shadow a generated variable or attribute
296 /// part of the OpView API.
isODSReserved(StringRef str)297 static bool isODSReserved(StringRef str) {
298 static llvm::StringSet<> reserved(
299 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
300 "loc", "verify", "regions", "results", "self", "operation",
301 "DIALECT_NAMESPACE", "OPERATION_NAME"});
302 return str.startswith("_ods_") || str.endswith("_ods") ||
303 reserved.contains(str);
304 }
305
306 /// Modifies the `name` in a way that it becomes suitable for Python bindings
307 /// (does not change the `name` if it already is suitable) and returns the
308 /// modified version.
sanitizeName(StringRef name)309 static std::string sanitizeName(StringRef name) {
310 if (isPythonKeyword(name) || isODSReserved(name))
311 return (name + "_").str();
312 return name.str();
313 }
314
attrSizedTraitForKind(const char * kind)315 static std::string attrSizedTraitForKind(const char *kind) {
316 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
317 llvm::StringRef(kind).take_front().upper(),
318 llvm::StringRef(kind).drop_front());
319 }
320
321 /// Emits accessors to "elements" of an Op definition. Currently, the supported
322 /// elements are operands and results, indicated by `kind`, which must be either
323 /// `operand` or `result` and is used verbatim in the emitted code.
emitElementAccessors(const Operator & op,raw_ostream & os,const char * kind,llvm::function_ref<unsigned (const Operator &)> getNumVariableLength,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)324 static void emitElementAccessors(
325 const Operator &op, raw_ostream &os, const char *kind,
326 llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
327 llvm::function_ref<int(const Operator &)> getNumElements,
328 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
329 getElement) {
330 assert(llvm::is_contained(
331 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
332 "unsupported kind");
333
334 // Traits indicating how to process variadic elements.
335 std::string sameSizeTrait =
336 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
337 llvm::StringRef(kind).take_front().upper(),
338 llvm::StringRef(kind).drop_front());
339 std::string attrSizedTrait = attrSizedTraitForKind(kind);
340
341 unsigned numVariableLength = getNumVariableLength(op);
342
343 // If there is only one variable-length element group, its size can be
344 // inferred from the total number of elements. If there are none, the
345 // generation is straightforward.
346 if (numVariableLength <= 1) {
347 bool seenVariableLength = false;
348 for (int i = 0, e = getNumElements(op); i < e; ++i) {
349 const NamedTypeConstraint &element = getElement(op, i);
350 if (element.isVariableLength())
351 seenVariableLength = true;
352 if (element.name.empty())
353 continue;
354 if (element.isVariableLength()) {
355 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
356 : opOneVariadicTemplate,
357 sanitizeName(element.name), kind,
358 getNumElements(op), i);
359 } else if (seenVariableLength) {
360 os << llvm::formatv(opSingleAfterVariableTemplate,
361 sanitizeName(element.name), kind,
362 getNumElements(op), i);
363 } else {
364 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
365 i);
366 }
367 }
368 return;
369 }
370
371 // Handle the operations where variadic groups have the same size.
372 if (op.getTrait(sameSizeTrait)) {
373 int numPrecedingSimple = 0;
374 int numPrecedingVariadic = 0;
375 for (int i = 0, e = getNumElements(op); i < e; ++i) {
376 const NamedTypeConstraint &element = getElement(op, i);
377 if (!element.name.empty()) {
378 os << llvm::formatv(opVariadicEqualPrefixTemplate,
379 sanitizeName(element.name), kind, numVariableLength,
380 numPrecedingSimple, numPrecedingVariadic);
381 os << llvm::formatv(element.isVariableLength()
382 ? opVariadicEqualVariadicTemplate
383 : opVariadicEqualSimpleTemplate,
384 kind);
385 }
386 if (element.isVariableLength())
387 ++numPrecedingVariadic;
388 else
389 ++numPrecedingSimple;
390 }
391 return;
392 }
393
394 // Handle the operations where the size of groups (variadic or not) is
395 // provided as an attribute. For non-variadic elements, make sure to return
396 // an element rather than a singleton container.
397 if (op.getTrait(attrSizedTrait)) {
398 for (int i = 0, e = getNumElements(op); i < e; ++i) {
399 const NamedTypeConstraint &element = getElement(op, i);
400 if (element.name.empty())
401 continue;
402 std::string trailing;
403 if (!element.isVariableLength())
404 trailing = "[0]";
405 else if (element.isOptional())
406 trailing = std::string(
407 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
408 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
409 kind, i, trailing);
410 }
411 return;
412 }
413
414 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
415 }
416
417 /// Free function helpers accessing Operator components.
getNumOperands(const Operator & op)418 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
getOperand(const Operator & op,int i)419 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
420 return op.getOperand(i);
421 }
getNumResults(const Operator & op)422 static int getNumResults(const Operator &op) { return op.getNumResults(); }
getResult(const Operator & op,int i)423 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
424 return op.getResult(i);
425 }
426
427 /// Emits accessors to Op operands.
emitOperandAccessors(const Operator & op,raw_ostream & os)428 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
429 auto getNumVariableLengthOperands = [](const Operator &oper) {
430 return oper.getNumVariableLengthOperands();
431 };
432 emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
433 getNumOperands, getOperand);
434 }
435
436 /// Emits accessors Op results.
emitResultAccessors(const Operator & op,raw_ostream & os)437 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
438 auto getNumVariableLengthResults = [](const Operator &oper) {
439 return oper.getNumVariableLengthResults();
440 };
441 emitElementAccessors(op, os, "result", getNumVariableLengthResults,
442 getNumResults, getResult);
443 }
444
445 /// Emits accessors to Op attributes.
emitAttributeAccessors(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)446 static void emitAttributeAccessors(const Operator &op,
447 const AttributeClasses &attributeClasses,
448 raw_ostream &os) {
449 for (const auto &namedAttr : op.getAttributes()) {
450 // Skip "derived" attributes because they are just C++ functions that we
451 // don't currently expose.
452 if (namedAttr.attr.isDerivedAttr())
453 continue;
454
455 if (namedAttr.name.empty())
456 continue;
457
458 std::string sanitizedName = sanitizeName(namedAttr.name);
459
460 // Unit attributes are handled specially.
461 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
462 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
463 namedAttr.name);
464 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
465 namedAttr.name);
466 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
467 namedAttr.name);
468 continue;
469 }
470
471 // Other kinds of attributes need a mapping to a Python type.
472 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
473 continue;
474
475 StringRef pythonType =
476 attributeClasses.lookup(namedAttr.attr.getStorageType());
477 if (namedAttr.attr.isOptional()) {
478 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
479 pythonType, namedAttr.name);
480 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
481 namedAttr.name);
482 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
483 namedAttr.name);
484 } else {
485 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
486 namedAttr.name);
487 os << llvm::formatv(attributeSetterTemplate, sanitizedName,
488 namedAttr.name);
489 // Non-optional attributes cannot be deleted.
490 }
491 }
492 }
493
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///   `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`, and
///   defining the `_ods_successors` variable read by `build_generic`
///   (populateBuilderLinesSuccessors always emits that definition).
/// `{{}` is formatv's escape for a literal `{}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";

/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
/// Operands are normalized through `_get_op_result_or_value`.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The attribute-sized operand variant appends `None` for an absent operand, so
/// every operand group keeps exactly one entry in `operands`.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The "Pack" variant appends the whole list as one entry. It is used with
/// attribute-sized segments so that group boundaries are preserved. The other
/// variant flattens the list into `operands`.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting an attribute with a default value in the operation
/// builder.
///   {0} is the attribute name;
///   {1} is the builder argument name;
///   {2} is the default value.
constexpr const char *initDefaultValuedAttributeTemplate =
    R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py";

/// Template for asserting that an attribute value was provided when calling a
/// builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *assertAttributeValueSpecified =
    R"Py(assert {1} is not None, "attribute {0} must be specified")Py";

/// Template for setting a unit attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
/// A truthy argument stores a `UnitAttr` created in the context returned by
/// `_ods_get_default_loc_context(loc)`; a falsy one leaves it unset.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
574
575 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
576 /// result types of the given operation.
hasSameArgumentAndResultTypes(const Operator & op)577 static bool hasSameArgumentAndResultTypes(const Operator &op) {
578 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
579 op.getNumVariableLengthResults() == 0;
580 }
581
582 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
583 /// result types of the given operation.
hasFirstAttrDerivedResultTypes(const Operator & op)584 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
585 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
586 op.getNumVariableLengthResults() == 0;
587 }
588
589 /// Returns true if the InferTypeOpInterface can be used to infer result types
590 /// of the given operation.
hasInferTypeInterface(const Operator & op)591 static bool hasInferTypeInterface(const Operator &op) {
592 return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
593 op.getNumRegions() == 0;
594 }
595
596 /// Returns true if there is a trait or interface that can be used to infer
597 /// result types of the given operation.
canInferType(const Operator & op)598 static bool canInferType(const Operator &op) {
599 return hasSameArgumentAndResultTypes(op) ||
600 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
601 }
602
603 /// Populates `builderArgs` with result names if the builder is expected to
604 /// accept them as arguments.
605 static void
populateBuilderArgsResults(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs)606 populateBuilderArgsResults(const Operator &op,
607 llvm::SmallVectorImpl<std::string> &builderArgs) {
608 if (canInferType(op))
609 return;
610
611 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
612 std::string name = op.getResultName(i).str();
613 if (name.empty()) {
614 if (op.getNumResults() == 1) {
615 // Special case for one result, make the default name be 'result'
616 // to properly match the built-in result accessor.
617 name = "result";
618 } else {
619 name = llvm::formatv("_gen_res_{0}", i);
620 }
621 }
622 name = sanitizeName(name);
623 builderArgs.push_back(name);
624 }
625 }
626
627 /// Populates `builderArgs` with the Python-compatible names of builder function
628 /// arguments using intermixed attributes and operands in the same order as they
629 /// appear in the `arguments` field of the op definition. Additionally,
630 /// `operandNames` is populated with names of operands in their order of
631 /// appearance.
632 static void
populateBuilderArgs(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & operandNames,llvm::SmallVectorImpl<std::string> & successorArgNames)633 populateBuilderArgs(const Operator &op,
634 llvm::SmallVectorImpl<std::string> &builderArgs,
635 llvm::SmallVectorImpl<std::string> &operandNames,
636 llvm::SmallVectorImpl<std::string> &successorArgNames) {
637
638 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
639 std::string name = op.getArgName(i).str();
640 if (name.empty())
641 name = llvm::formatv("_gen_arg_{0}", i);
642 name = sanitizeName(name);
643 builderArgs.push_back(name);
644 if (!op.getArg(i).is<NamedAttribute *>())
645 operandNames.push_back(name);
646 }
647 }
648
649 /// Populates `builderArgs` with the Python-compatible names of builder function
650 /// successor arguments. Additionally, `successorArgNames` is also populated.
populateBuilderArgsSuccessors(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & successorArgNames)651 static void populateBuilderArgsSuccessors(
652 const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
653 llvm::SmallVectorImpl<std::string> &successorArgNames) {
654
655 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
656 NamedSuccessor successor = op.getSuccessor(i);
657 std::string name = std::string(successor.name);
658 if (name.empty())
659 name = llvm::formatv("_gen_successor_{0}", i);
660 name = sanitizeName(name);
661 builderArgs.push_back(name);
662 successorArgNames.push_back(name);
663 }
664 }
665
666 /// Generates Python code for the default value of the given attribute.
getAttributeDefaultValue(Attribute attr)667 static FailureOr<std::string> getAttributeDefaultValue(Attribute attr) {
668 assert(attr.hasDefaultValue() && "expected attribute with default value");
669 StringRef storageType = attr.getStorageType().trim();
670 StringRef defaultValCpp = attr.getDefaultValue().trim();
671
672 // A list of commonly used attribute types and default values for which
673 // we can generate Python code. Extend as needed.
674 if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}"))
675 return std::string("_ods_ir.ArrayAttr.get([])");
676
677 // No match: Cannot generate Python code.
678 return failure();
679 }
680
681 /// Populates `builderLines` with additional lines that are required in the
682 /// builder to set up operation attributes. `argNames` is expected to contain
683 /// the names of builder arguments that correspond to op arguments, i.e. to the
684 /// operands and attributes in the same order as they appear in the `arguments`
685 /// field.
686 static void
populateBuilderLinesAttr(const Operator & op,llvm::ArrayRef<std::string> argNames,llvm::SmallVectorImpl<std::string> & builderLines)687 populateBuilderLinesAttr(const Operator &op,
688 llvm::ArrayRef<std::string> argNames,
689 llvm::SmallVectorImpl<std::string> &builderLines) {
690 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
691 Argument arg = op.getArg(i);
692 auto *attribute = arg.dyn_cast<NamedAttribute *>();
693 if (!attribute)
694 continue;
695
696 // Unit attributes are handled specially.
697 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
698 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
699 attribute->name, argNames[i]));
700 continue;
701 }
702
703 // Attributes with default value are handled specially.
704 if (attribute->attr.hasDefaultValue()) {
705 // In case we cannot generate Python code for the default value, the
706 // attribute must be specified by the user.
707 FailureOr<std::string> defaultValPy =
708 getAttributeDefaultValue(attribute->attr);
709 if (succeeded(defaultValPy)) {
710 builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate,
711 attribute->name, argNames[i],
712 *defaultValPy));
713 } else {
714 builderLines.push_back(llvm::formatv(assertAttributeValueSpecified,
715 attribute->name, argNames[i]));
716 builderLines.push_back(
717 llvm::formatv(initAttributeTemplate, attribute->name, argNames[i]));
718 }
719 continue;
720 }
721
722 builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
723 ? initOptionalAttributeTemplate
724 : initAttributeTemplate,
725 attribute->name, argNames[i]));
726 }
727 }
728
729 /// Populates `builderLines` with additional lines that are required in the
730 /// builder to set up successors. successorArgNames is expected to correspond
731 /// to the Python argument name for each successor on the op.
populateBuilderLinesSuccessors(const Operator & op,llvm::ArrayRef<std::string> successorArgNames,llvm::SmallVectorImpl<std::string> & builderLines)732 static void populateBuilderLinesSuccessors(
733 const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
734 llvm::SmallVectorImpl<std::string> &builderLines) {
735 if (successorArgNames.empty()) {
736 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
737 return;
738 }
739
740 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
741 for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
742 auto &argName = successorArgNames[i];
743 const NamedSuccessor &successor = op.getSuccessor(i);
744 builderLines.push_back(
745 llvm::formatv(addSuccessorTemplate,
746 successor.isVariadic() ? "extend" : "append", argName));
747 }
748 }
749
750 /// Populates `builderLines` with additional lines that are required in the
751 /// builder to set up op operands.
752 static void
populateBuilderLinesOperand(const Operator & op,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines)753 populateBuilderLinesOperand(const Operator &op,
754 llvm::ArrayRef<std::string> names,
755 llvm::SmallVectorImpl<std::string> &builderLines) {
756 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
757
758 // For each element, find or generate a name.
759 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
760 const NamedTypeConstraint &element = op.getOperand(i);
761 std::string name = names[i];
762
763 // Choose the formatting string based on the element kind.
764 llvm::StringRef formatString;
765 if (!element.isVariableLength()) {
766 formatString = singleOperandAppendTemplate;
767 } else if (element.isOptional()) {
768 if (sizedSegments) {
769 formatString = optionalAppendAttrSizedOperandsTemplate;
770 } else {
771 formatString = optionalAppendOperandTemplate;
772 }
773 } else {
774 assert(element.isVariadic() && "unhandled element group type");
775 // If emitting with sizedSegments, then we add the actual list-typed
776 // element. Otherwise, we extend the actual operands.
777 if (sizedSegments) {
778 formatString = multiOperandAppendPackTemplate;
779 } else {
780 formatString = multiOperandAppendTemplate;
781 }
782 }
783
784 builderLines.push_back(llvm::formatv(formatString.data(), name));
785 }
786 }
787
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// The emitted code reads `attributes["{0}"]`. If that attribute is a
/// TypeAttr, its wrapped type (`.value`) is used; otherwise the attribute's
/// own `.type` is used. The result is bound to `_ods_derived_result_type`.
/// The template spans several lines and is meant to be split line by line
/// into the builder body.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";
797
/// Python code template appending {0} type {1} times to the results list:
///   - {0} is a Python expression evaluating to a type;
///   - {1} is the number of copies to append.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
800
/// Python code template for inferring the operation results using the
/// corresponding interface:
///   - {0} is the name of the class for which the types are inferred.
/// The emitted code obtains a context from `loc` and calls
/// `InferTypeOpInterface({0}).inferReturnTypes` with the collected `operands`
/// and `attributes` (wrapped in a DictAttr). It then assigns the returned
/// types to `results`, replacing the list rather than appending to it. The
/// template ends with a trailing newline.
constexpr const char *inferTypeInterfaceTemplate =
    R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
    operands=operands,
    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
    context=_ods_context,
    loc=loc)
)PY";
812
813 /// Appends the given multiline string as individual strings into
814 /// `builderLines`.
appendLineByLine(StringRef string,llvm::SmallVectorImpl<std::string> & builderLines)815 static void appendLineByLine(StringRef string,
816 llvm::SmallVectorImpl<std::string> &builderLines) {
817
818 std::pair<StringRef, StringRef> split = std::make_pair(string, string);
819 do {
820 split = split.second.split('\n');
821 builderLines.push_back(split.first.str());
822 } while (!split.second.empty());
823 }
824
825 /// Populates `builderLines` with additional lines that are required in the
826 /// builder to set up op results.
827 static void
populateBuilderLinesResult(const Operator & op,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines)828 populateBuilderLinesResult(const Operator &op,
829 llvm::ArrayRef<std::string> names,
830 llvm::SmallVectorImpl<std::string> &builderLines) {
831 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
832
833 if (hasSameArgumentAndResultTypes(op)) {
834 builderLines.push_back(llvm::formatv(
835 appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
836 return;
837 }
838
839 if (hasFirstAttrDerivedResultTypes(op)) {
840 const NamedAttribute &firstAttr = op.getAttribute(0);
841 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
842 "from which the type is derived");
843 appendLineByLine(
844 llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
845 builderLines);
846 builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
847 "_ods_derived_result_type",
848 op.getNumResults()));
849 return;
850 }
851
852 if (hasInferTypeInterface(op)) {
853 appendLineByLine(
854 llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
855 builderLines);
856 return;
857 }
858
859 // For each element, find or generate a name.
860 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
861 const NamedTypeConstraint &element = op.getResult(i);
862 std::string name = names[i];
863
864 // Choose the formatting string based on the element kind.
865 llvm::StringRef formatString;
866 if (!element.isVariableLength()) {
867 formatString = singleResultAppendTemplate;
868 } else if (element.isOptional()) {
869 formatString = optionalAppendResultTemplate;
870 } else {
871 assert(element.isVariadic() && "unhandled element group type");
872 // If emitting with sizedSegments, then we add the actual list-typed
873 // element. Otherwise, we extend the actual operands.
874 if (sizedSegments) {
875 formatString = singleResultAppendTemplate;
876 } else {
877 formatString = multiResultAppendTemplate;
878 }
879 }
880
881 builderLines.push_back(llvm::formatv(formatString.data(), name));
882 }
883 }
884
885 /// If the operation has variadic regions, adds a builder argument to specify
886 /// the number of those regions and builder lines to forward it to the generic
887 /// constructor.
888 static void
populateBuilderRegions(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & builderLines)889 populateBuilderRegions(const Operator &op,
890 llvm::SmallVectorImpl<std::string> &builderArgs,
891 llvm::SmallVectorImpl<std::string> &builderLines) {
892 if (op.hasNoVariadicRegions())
893 return;
894
895 // This is currently enforced when Operator is constructed.
896 assert(op.getNumVariadicRegions() == 1 &&
897 op.getRegion(op.getNumRegions() - 1).isVariadic() &&
898 "expected the last region to be varidic");
899
900 const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1);
901 std::string name =
902 ("num_" + region.name.take_front().lower() + region.name.drop_front())
903 .str();
904 builderArgs.push_back(name);
905 builderLines.push_back(
906 llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
907 }
908
/// Emits a default builder constructing an operation from the list of its
/// result types, followed by a list of its operands.
///
/// Emits nothing when `op.skipDefaultBuilders()` is set. Otherwise the
/// generated `__init__` parameters are ordered as follows:
///  - every positional argument: results, required operands/attributes,
///    successors, and the region count if the op has variadic regions;
///  - a bare `*`;
///  - every keyword argument, defaulting to `None`: optional or
///    default-valued attributes and optional operands;
///  - finally `loc=None, ip=None`.
static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
  // If we are asked to skip default builders, comply.
  if (op.skipDefaultBuilders())
    return;

  llvm::SmallVector<std::string> builderArgs;
  llvm::SmallVector<std::string> builderLines;
  llvm::SmallVector<std::string> operandArgNames;
  llvm::SmallVector<std::string> successorArgNames;
  // Capacity hint only; the optional region-count argument appended by
  // populateBuilderRegions below is not counted here.
  builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
                      op.getNumNativeAttributes() + op.getNumSuccessors());
  // Record the size of each argument group as it is appended, so the groups
  // can be told apart later (see the layout comment below).
  populateBuilderArgsResults(op, builderArgs);
  size_t numResultArgs = builderArgs.size();
  populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
  size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
  populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);

  populateBuilderLinesOperand(op, operandArgNames, builderLines);
  // Everything after the result arguments is passed here.
  populateBuilderLinesAttr(
      op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
      builderLines);
  populateBuilderLinesResult(
      op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
      builderLines);
  populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
  // May append one more (region count) argument to the end of builderArgs,
  // which is why it runs after all other argument groups are in place.
  populateBuilderRegions(op, builderArgs, builderLines);

  // Layout of builderArgs vector elements:
  // [ result_args operand_attr_args successor_args regions ]

  // Determine whether the argument corresponding to a given index into the
  // builderArgs vector is a python keyword argument or not.
  auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
    // All result, successor, and region arguments are positional arguments.
    if ((builderArgIndex < numResultArgs) ||
        (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
      return false;
    // Keyword arguments:
    // - optional named attributes (including unit attributes)
    // - default-valued named attributes
    // - optional operands
    // This indexing relies on populateBuilderArgs having emitted exactly one
    // entry per op argument, in op argument order.
    Argument a = op.getArg(builderArgIndex - numResultArgs);
    if (auto *nattr = a.dyn_cast<NamedAttribute *>())
      return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
    if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
      return ntype->isOptional();
    return false;
  };

  // StringRefs in functionArgs refer to strings allocated by builderArgs.
  // builderArgs is not resized from here on, so those references stay valid.
  // The in-place "=None" append below only touches keyword entries, and each
  // one is added to functionArgs after it has been modified.
  llvm::SmallVector<llvm::StringRef> functionArgs;

  // Add positional arguments.
  for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
    if (!isKeywordArgFn(i))
      functionArgs.push_back(builderArgs[i]);
  }

  // Add a bare '*' to indicate that all following arguments must be keyword
  // arguments.
  functionArgs.push_back("*");

  // Add a default 'None' value to each keyword arg string, and then add to the
  // function args list.
  for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
    if (isKeywordArgFn(i)) {
      builderArgs[i].append("=None");
      functionArgs.push_back(builderArgs[i]);
    }
  }
  functionArgs.push_back("loc=None");
  functionArgs.push_back("ip=None");
  // Builder lines are joined with a newline plus indentation so that each one
  // lands on its own line of the generated body. The indentation presumably
  // matches where {1} sits in initTemplate; confirm if that template changes.
  os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
                      llvm::join(builderLines, "\n    "));
}
986
constructAttributeMapping(const llvm::RecordKeeper & records,AttributeClasses & attributeClasses)987 static void constructAttributeMapping(const llvm::RecordKeeper &records,
988 AttributeClasses &attributeClasses) {
989 for (const llvm::Record *rec :
990 records.getAllDerivedDefinitions("PythonAttr")) {
991 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
992 rec->getValueAsString("pythonType").trim());
993 }
994 }
995
emitSegmentSpec(const Operator & op,const char * kind,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement,raw_ostream & os)996 static void emitSegmentSpec(
997 const Operator &op, const char *kind,
998 llvm::function_ref<int(const Operator &)> getNumElements,
999 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
1000 getElement,
1001 raw_ostream &os) {
1002 std::string segmentSpec("[");
1003 for (int i = 0, e = getNumElements(op); i < e; ++i) {
1004 const NamedTypeConstraint &element = getElement(op, i);
1005 if (element.isOptional()) {
1006 segmentSpec.append("0,");
1007 } else if (element.isVariadic()) {
1008 segmentSpec.append("-1,");
1009 } else {
1010 segmentSpec.append("1,");
1011 }
1012 }
1013 segmentSpec.append("]");
1014
1015 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
1016 }
1017
emitRegionAttributes(const Operator & op,raw_ostream & os)1018 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
1019 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
1020 // Note that the base OpView class defines this as (0, True).
1021 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
1022 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
1023 op.hasNoVariadicRegions() ? "True" : "False");
1024 }
1025
1026 /// Emits named accessors to regions.
emitRegionAccessors(const Operator & op,raw_ostream & os)1027 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
1028 for (const auto &en : llvm::enumerate(op.getRegions())) {
1029 const NamedRegion ®ion = en.value();
1030 if (region.name.empty())
1031 continue;
1032
1033 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
1034 "expected only the last region to be variadic");
1035 os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
1036 std::to_string(en.index()) +
1037 (region.isVariadic() ? ":" : ""));
1038 }
1039 }
1040
1041 /// Emits bindings for a specific Op to the given output stream.
emitOpBindings(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)1042 static void emitOpBindings(const Operator &op,
1043 const AttributeClasses &attributeClasses,
1044 raw_ostream &os) {
1045 os << llvm::formatv(opClassTemplate, op.getCppClassName(),
1046 op.getOperationName());
1047
1048 // Sized segments.
1049 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1050 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
1051 }
1052 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1053 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
1054 }
1055
1056 emitRegionAttributes(op, os);
1057 emitDefaultOpBuilder(op, os);
1058 emitOperandAccessors(op, os);
1059 emitAttributeAccessors(op, attributeClasses, os);
1060 emitResultAccessors(op, os);
1061 emitRegionAccessors(op, os);
1062 }
1063
1064 /// Emits bindings for the dialect specified in the command line, including file
1065 /// headers and utilities. Returns `false` on success to comply with Tablegen
1066 /// registration requirements.
emitAllOps(const llvm::RecordKeeper & records,raw_ostream & os)1067 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1068 if (clDialectName.empty())
1069 llvm::PrintFatalError("dialect name not provided");
1070
1071 AttributeClasses attributeClasses;
1072 constructAttributeMapping(records, attributeClasses);
1073
1074 bool isExtension = !clDialectExtensionName.empty();
1075 os << llvm::formatv(fileHeader, isExtension
1076 ? clDialectExtensionName.getValue()
1077 : clDialectName.getValue());
1078 if (isExtension)
1079 os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
1080 else
1081 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
1082
1083 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
1084 Operator op(rec);
1085 if (op.getDialectName() == clDialectName.getValue())
1086 emitOpBindings(op, attributeClasses, os);
1087 }
1088 return false;
1089 }
1090
// Registers this backend with the MLIR TableGen driver under the
// "gen-python-op-bindings" generator name; the driver invokes emitAllOps.
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);
1094