1fd407e1fSAlex Zinenko //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2fd407e1fSAlex Zinenko //
3fd407e1fSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fd407e1fSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5fd407e1fSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fd407e1fSAlex Zinenko //
7fd407e1fSAlex Zinenko //===----------------------------------------------------------------------===//
8fd407e1fSAlex Zinenko //
9fd407e1fSAlex Zinenko // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10fd407e1fSAlex Zinenko // binding classes wrapping a generic operation API.
11fd407e1fSAlex Zinenko //
12fd407e1fSAlex Zinenko //===----------------------------------------------------------------------===//
13fd407e1fSAlex Zinenko
14*989d2b51SMatthias Springer #include "mlir/Support/LogicalResult.h"
15fd407e1fSAlex Zinenko #include "mlir/TableGen/GenInfo.h"
16fd407e1fSAlex Zinenko #include "mlir/TableGen/Operator.h"
17fd407e1fSAlex Zinenko #include "llvm/ADT/StringSet.h"
18fd407e1fSAlex Zinenko #include "llvm/Support/CommandLine.h"
19fd407e1fSAlex Zinenko #include "llvm/Support/FormatVariadic.h"
20fd407e1fSAlex Zinenko #include "llvm/TableGen/Error.h"
21fd407e1fSAlex Zinenko #include "llvm/TableGen/Record.h"
22fd407e1fSAlex Zinenko
23fd407e1fSAlex Zinenko using namespace mlir;
24fd407e1fSAlex Zinenko using namespace mlir::tblgen;
25fd407e1fSAlex Zinenko
/// Preamble written at the top of every generated Python module. It imports
/// the shared `_ods_common` helpers under `_ods_`-prefixed aliases and tries to
/// load the optional hand-written extension module of the dialect, falling
/// back to `None` when that module does not exist.
///   {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";
43fd407e1fSAlex Zinenko
/// Template for the dialect class registered with the C extension:
///   {0} is the dialect namespace.
/// The class body is indented by two spaces; `pass` keeps the body valid
/// independently of the attribute line.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";
53fd407e1fSAlex Zinenko
/// Template importing an already generated `_Dialect` class instead of
/// defining a new one:
///   {0} is the name used to form the `_{0}_ops_gen` module to import from.
/// NOTE(review): presumably emitted when `-dialect-extension` is given; the
/// use site is outside this excerpt.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
573f71765aSAlex Zinenko
/// Template opening the Python class of one operation:
///   {0} is the Python class name;
///   {1} is the operation name.
/// Only the class header and its first member are emitted here; the remaining
/// class-level templates use the same two-space member indentation so that
/// they nest inside this class.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
67fd407e1fSAlex Zinenko
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
/// The line is indented as a class member.
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";
8071b6b010SStella Laurenzo
/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
/// The line is indented as a class member.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
8771b6b010SStella Laurenzo
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// The property is indented as a class member (two spaces) and its body as a
/// method body (four spaces).
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";
97fd407e1fSAlex Zinenko
/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent): the length
/// of the one variable-length group is recovered from the total element count
/// and used to shift the static position.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
111fd407e1fSAlex Zinenko
/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
125fd407e1fSAlex Zinenko
/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The group length is whatever remains once every other group has
/// contributed exactly one element.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
137fd407e1fSAlex Zinenko
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The element list is read through `self.operation`; a bare `operation` is
/// not a name in scope inside the generated property.
/// The template deliberately has no trailing newline: one of the "second
/// part" templates, which start with a newline, is appended directly.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
148fd407e1fSAlex Zinenko
/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
/// Continues the method body opened by opVariadicEqualPrefixTemplate, hence
/// the four-space indentation and the reliance on its `start` variable.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";
155fd407e1fSAlex Zinenko
/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
/// Continues the method body opened by opVariadicEqualPrefixTemplate, hence
/// the four-space indentation and the reliance on its `start` and `pg`
/// variables.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
162fd407e1fSAlex Zinenko
/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// The group boundaries are read from the `{1}_segment_sizes` attribute of
/// the operation.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";
177fd407e1fSAlex Zinenko
/// Return-expression suffix used by opVariadicSegmentTemplate for an optional
/// element in the attribute-sized case; it yields the sole element of the
/// range, or None when the range is empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    "[0] if len({0}_range) > 0 else None";
183fd407e1fSAlex Zinenko
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
/// The stored attribute is wrapped in the Python type before being returned.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";
193c5a6712fSAlex Zinenko
/// Template for an optional operation attribute getter, which returns None
/// when the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";
205c5a6712fSAlex Zinenko
/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
216c5a6712fSAlex Zinenko
/// Template for an operation attribute setter; a mandatory attribute rejects
/// None with a ValueError:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";
227029e199dSAlex Zinenko
/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (and is a no-op when it is already absent):
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
240029e199dSAlex Zinenko
/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute; any truthy value installs a fresh UnitAttr:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
253029e199dSAlex Zinenko
/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
263029e199dSAlex Zinenko
/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index used to subscript `self.regions`.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
26918fbd5feSAlex Zinenko
270fd407e1fSAlex Zinenko static llvm::cl::OptionCategory
271fd407e1fSAlex Zinenko clOpPythonBindingCat("Options for -gen-python-op-bindings");
272fd407e1fSAlex Zinenko
273fd407e1fSAlex Zinenko static llvm::cl::opt<std::string>
274fd407e1fSAlex Zinenko clDialectName("bind-dialect",
275fd407e1fSAlex Zinenko llvm::cl::desc("The dialect to run the generator for"),
276fd407e1fSAlex Zinenko llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
277fd407e1fSAlex Zinenko
2783f71765aSAlex Zinenko static llvm::cl::opt<std::string> clDialectExtensionName(
2793f71765aSAlex Zinenko "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
2803f71765aSAlex Zinenko llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
2813f71765aSAlex Zinenko
282c5a6712fSAlex Zinenko using AttributeClasses = DenseMap<StringRef, StringRef>;
283c5a6712fSAlex Zinenko
284fd407e1fSAlex Zinenko /// Checks whether `str` is a Python keyword.
isPythonKeyword(StringRef str)285fd407e1fSAlex Zinenko static bool isPythonKeyword(StringRef str) {
286fd407e1fSAlex Zinenko static llvm::StringSet<> keywords(
287fd407e1fSAlex Zinenko {"and", "as", "assert", "break", "class", "continue",
288fd407e1fSAlex Zinenko "def", "del", "elif", "else", "except", "finally",
289fd407e1fSAlex Zinenko "for", "from", "global", "if", "import", "in",
290fd407e1fSAlex Zinenko "is", "lambda", "nonlocal", "not", "or", "pass",
291fd407e1fSAlex Zinenko "raise", "return", "try", "while", "with", "yield"});
292fd407e1fSAlex Zinenko return keywords.contains(str);
2937dadcd02SMehdi Amini }
294fd407e1fSAlex Zinenko
2958a1f1a10SStella Laurenzo /// Checks whether `str` would shadow a generated variable or attribute
2968a1f1a10SStella Laurenzo /// part of the OpView API.
isODSReserved(StringRef str)2978a1f1a10SStella Laurenzo static bool isODSReserved(StringRef str) {
2988a1f1a10SStella Laurenzo static llvm::StringSet<> reserved(
2998a1f1a10SStella Laurenzo {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
300fd226c9bSStella Laurenzo "loc", "verify", "regions", "results", "self", "operation",
3018a1f1a10SStella Laurenzo "DIALECT_NAMESPACE", "OPERATION_NAME"});
3028a1f1a10SStella Laurenzo return str.startswith("_ods_") || str.endswith("_ods") ||
3038a1f1a10SStella Laurenzo reserved.contains(str);
3048a1f1a10SStella Laurenzo }
3058a1f1a10SStella Laurenzo
306fd407e1fSAlex Zinenko /// Modifies the `name` in a way that it becomes suitable for Python bindings
307fd407e1fSAlex Zinenko /// (does not change the `name` if it already is suitable) and returns the
308fd407e1fSAlex Zinenko /// modified version.
sanitizeName(StringRef name)309fd407e1fSAlex Zinenko static std::string sanitizeName(StringRef name) {
3108a1f1a10SStella Laurenzo if (isPythonKeyword(name) || isODSReserved(name))
311fd407e1fSAlex Zinenko return (name + "_").str();
312fd407e1fSAlex Zinenko return name.str();
313fd407e1fSAlex Zinenko }
314fd407e1fSAlex Zinenko
attrSizedTraitForKind(const char * kind)315f9265de8SAlex Zinenko static std::string attrSizedTraitForKind(const char *kind) {
316f9265de8SAlex Zinenko return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
317f9265de8SAlex Zinenko llvm::StringRef(kind).take_front().upper(),
318f9265de8SAlex Zinenko llvm::StringRef(kind).drop_front());
319f9265de8SAlex Zinenko }
320f9265de8SAlex Zinenko
321fd407e1fSAlex Zinenko /// Emits accessors to "elements" of an Op definition. Currently, the supported
322fd407e1fSAlex Zinenko /// elements are operands and results, indicated by `kind`, which must be either
323fd407e1fSAlex Zinenko /// `operand` or `result` and is used verbatim in the emitted code.
emitElementAccessors(const Operator & op,raw_ostream & os,const char * kind,llvm::function_ref<unsigned (const Operator &)> getNumVariableLength,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)324fd407e1fSAlex Zinenko static void emitElementAccessors(
325fd407e1fSAlex Zinenko const Operator &op, raw_ostream &os, const char *kind,
32654c99842SMichal Terepeta llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
327fd407e1fSAlex Zinenko llvm::function_ref<int(const Operator &)> getNumElements,
328fd407e1fSAlex Zinenko llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
329fd407e1fSAlex Zinenko getElement) {
330fd407e1fSAlex Zinenko assert(llvm::is_contained(
331fd407e1fSAlex Zinenko llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
332fd407e1fSAlex Zinenko "unsupported kind");
333fd407e1fSAlex Zinenko
334fd407e1fSAlex Zinenko // Traits indicating how to process variadic elements.
335fd407e1fSAlex Zinenko std::string sameSizeTrait =
336fd407e1fSAlex Zinenko llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
337fd407e1fSAlex Zinenko llvm::StringRef(kind).take_front().upper(),
338fd407e1fSAlex Zinenko llvm::StringRef(kind).drop_front());
339f9265de8SAlex Zinenko std::string attrSizedTrait = attrSizedTraitForKind(kind);
340fd407e1fSAlex Zinenko
34154c99842SMichal Terepeta unsigned numVariableLength = getNumVariableLength(op);
342fd407e1fSAlex Zinenko
34354c99842SMichal Terepeta // If there is only one variable-length element group, its size can be
34454c99842SMichal Terepeta // inferred from the total number of elements. If there are none, the
34554c99842SMichal Terepeta // generation is straightforward.
34654c99842SMichal Terepeta if (numVariableLength <= 1) {
347fd407e1fSAlex Zinenko bool seenVariableLength = false;
348fd407e1fSAlex Zinenko for (int i = 0, e = getNumElements(op); i < e; ++i) {
349fd407e1fSAlex Zinenko const NamedTypeConstraint &element = getElement(op, i);
350fd407e1fSAlex Zinenko if (element.isVariableLength())
351fd407e1fSAlex Zinenko seenVariableLength = true;
352fd407e1fSAlex Zinenko if (element.name.empty())
353fd407e1fSAlex Zinenko continue;
354fd407e1fSAlex Zinenko if (element.isVariableLength()) {
355fd407e1fSAlex Zinenko os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
356fd407e1fSAlex Zinenko : opOneVariadicTemplate,
357fd407e1fSAlex Zinenko sanitizeName(element.name), kind,
358fd407e1fSAlex Zinenko getNumElements(op), i);
359fd407e1fSAlex Zinenko } else if (seenVariableLength) {
360fd407e1fSAlex Zinenko os << llvm::formatv(opSingleAfterVariableTemplate,
361fd407e1fSAlex Zinenko sanitizeName(element.name), kind,
362fd407e1fSAlex Zinenko getNumElements(op), i);
363fd407e1fSAlex Zinenko } else {
364fd407e1fSAlex Zinenko os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
365fd407e1fSAlex Zinenko i);
366fd407e1fSAlex Zinenko }
367fd407e1fSAlex Zinenko }
368fd407e1fSAlex Zinenko return;
369fd407e1fSAlex Zinenko }
370fd407e1fSAlex Zinenko
371fd407e1fSAlex Zinenko // Handle the operations where variadic groups have the same size.
372fd407e1fSAlex Zinenko if (op.getTrait(sameSizeTrait)) {
373fd407e1fSAlex Zinenko int numPrecedingSimple = 0;
374fd407e1fSAlex Zinenko int numPrecedingVariadic = 0;
375fd407e1fSAlex Zinenko for (int i = 0, e = getNumElements(op); i < e; ++i) {
376fd407e1fSAlex Zinenko const NamedTypeConstraint &element = getElement(op, i);
377fd407e1fSAlex Zinenko if (!element.name.empty()) {
378fd407e1fSAlex Zinenko os << llvm::formatv(opVariadicEqualPrefixTemplate,
37954c99842SMichal Terepeta sanitizeName(element.name), kind, numVariableLength,
380fd407e1fSAlex Zinenko numPrecedingSimple, numPrecedingVariadic);
381fd407e1fSAlex Zinenko os << llvm::formatv(element.isVariableLength()
382fd407e1fSAlex Zinenko ? opVariadicEqualVariadicTemplate
383fd407e1fSAlex Zinenko : opVariadicEqualSimpleTemplate,
384fd407e1fSAlex Zinenko kind);
385fd407e1fSAlex Zinenko }
386fd407e1fSAlex Zinenko if (element.isVariableLength())
387fd407e1fSAlex Zinenko ++numPrecedingVariadic;
388fd407e1fSAlex Zinenko else
389fd407e1fSAlex Zinenko ++numPrecedingSimple;
390fd407e1fSAlex Zinenko }
391fd407e1fSAlex Zinenko return;
392fd407e1fSAlex Zinenko }
393fd407e1fSAlex Zinenko
394fd407e1fSAlex Zinenko // Handle the operations where the size of groups (variadic or not) is
395fd407e1fSAlex Zinenko // provided as an attribute. For non-variadic elements, make sure to return
396fd407e1fSAlex Zinenko // an element rather than a singleton container.
397fd407e1fSAlex Zinenko if (op.getTrait(attrSizedTrait)) {
398fd407e1fSAlex Zinenko for (int i = 0, e = getNumElements(op); i < e; ++i) {
399fd407e1fSAlex Zinenko const NamedTypeConstraint &element = getElement(op, i);
400fd407e1fSAlex Zinenko if (element.name.empty())
401fd407e1fSAlex Zinenko continue;
402fd407e1fSAlex Zinenko std::string trailing;
403fd407e1fSAlex Zinenko if (!element.isVariableLength())
404fd407e1fSAlex Zinenko trailing = "[0]";
405fd407e1fSAlex Zinenko else if (element.isOptional())
406fd407e1fSAlex Zinenko trailing = std::string(
407fd407e1fSAlex Zinenko llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
408fd407e1fSAlex Zinenko os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
409fd407e1fSAlex Zinenko kind, i, trailing);
410fd407e1fSAlex Zinenko }
411fd407e1fSAlex Zinenko return;
412fd407e1fSAlex Zinenko }
413fd407e1fSAlex Zinenko
414fd407e1fSAlex Zinenko llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
415fd407e1fSAlex Zinenko }
416fd407e1fSAlex Zinenko
417f9265de8SAlex Zinenko /// Free function helpers accessing Operator components.
getNumOperands(const Operator & op)418f9265de8SAlex Zinenko static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
getOperand(const Operator & op,int i)419f9265de8SAlex Zinenko static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
420f9265de8SAlex Zinenko return op.getOperand(i);
421f9265de8SAlex Zinenko }
getNumResults(const Operator & op)422f9265de8SAlex Zinenko static int getNumResults(const Operator &op) { return op.getNumResults(); }
getResult(const Operator & op,int i)423f9265de8SAlex Zinenko static const NamedTypeConstraint &getResult(const Operator &op, int i) {
424f9265de8SAlex Zinenko return op.getResult(i);
425f9265de8SAlex Zinenko }
426f9265de8SAlex Zinenko
427c5a6712fSAlex Zinenko /// Emits accessors to Op operands.
emitOperandAccessors(const Operator & op,raw_ostream & os)428fd407e1fSAlex Zinenko static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
42954c99842SMichal Terepeta auto getNumVariableLengthOperands = [](const Operator &oper) {
430fd407e1fSAlex Zinenko return oper.getNumVariableLengthOperands();
431fd407e1fSAlex Zinenko };
43254c99842SMichal Terepeta emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
43354c99842SMichal Terepeta getNumOperands, getOperand);
434fd407e1fSAlex Zinenko }
435fd407e1fSAlex Zinenko
436c5a6712fSAlex Zinenko /// Emits accessors Op results.
emitResultAccessors(const Operator & op,raw_ostream & os)437fd407e1fSAlex Zinenko static void emitResultAccessors(const Operator &op, raw_ostream &os) {
43854c99842SMichal Terepeta auto getNumVariableLengthResults = [](const Operator &oper) {
439fd407e1fSAlex Zinenko return oper.getNumVariableLengthResults();
440fd407e1fSAlex Zinenko };
44154c99842SMichal Terepeta emitElementAccessors(op, os, "result", getNumVariableLengthResults,
44254c99842SMichal Terepeta getNumResults, getResult);
443f9265de8SAlex Zinenko }
444f9265de8SAlex Zinenko
445c5a6712fSAlex Zinenko /// Emits accessors to Op attributes.
emitAttributeAccessors(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)446c5a6712fSAlex Zinenko static void emitAttributeAccessors(const Operator &op,
447c5a6712fSAlex Zinenko const AttributeClasses &attributeClasses,
448c5a6712fSAlex Zinenko raw_ostream &os) {
449c5a6712fSAlex Zinenko for (const auto &namedAttr : op.getAttributes()) {
450c5a6712fSAlex Zinenko // Skip "derived" attributes because they are just C++ functions that we
451c5a6712fSAlex Zinenko // don't currently expose.
452c5a6712fSAlex Zinenko if (namedAttr.attr.isDerivedAttr())
453c5a6712fSAlex Zinenko continue;
454c5a6712fSAlex Zinenko
455c5a6712fSAlex Zinenko if (namedAttr.name.empty())
456c5a6712fSAlex Zinenko continue;
457c5a6712fSAlex Zinenko
458029e199dSAlex Zinenko std::string sanitizedName = sanitizeName(namedAttr.name);
459029e199dSAlex Zinenko
460c5a6712fSAlex Zinenko // Unit attributes are handled specially.
461c5a6712fSAlex Zinenko if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
462029e199dSAlex Zinenko os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
463029e199dSAlex Zinenko namedAttr.name);
464029e199dSAlex Zinenko os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
465029e199dSAlex Zinenko namedAttr.name);
466029e199dSAlex Zinenko os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
467029e199dSAlex Zinenko namedAttr.name);
468c5a6712fSAlex Zinenko continue;
469c5a6712fSAlex Zinenko }
470c5a6712fSAlex Zinenko
471c5a6712fSAlex Zinenko // Other kinds of attributes need a mapping to a Python type.
472c5a6712fSAlex Zinenko if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
473c5a6712fSAlex Zinenko continue;
474c5a6712fSAlex Zinenko
475029e199dSAlex Zinenko StringRef pythonType =
476029e199dSAlex Zinenko attributeClasses.lookup(namedAttr.attr.getStorageType());
477029e199dSAlex Zinenko if (namedAttr.attr.isOptional()) {
478029e199dSAlex Zinenko os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
479029e199dSAlex Zinenko pythonType, namedAttr.name);
480029e199dSAlex Zinenko os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
481c5a6712fSAlex Zinenko namedAttr.name);
482029e199dSAlex Zinenko os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
483029e199dSAlex Zinenko namedAttr.name);
484029e199dSAlex Zinenko } else {
485029e199dSAlex Zinenko os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
486029e199dSAlex Zinenko namedAttr.name);
487029e199dSAlex Zinenko os << llvm::formatv(attributeSetterTemplate, sanitizedName,
488029e199dSAlex Zinenko namedAttr.name);
489029e199dSAlex Zinenko // Non-optional attributes cannot be deleted.
490029e199dSAlex Zinenko }
491c5a6712fSAlex Zinenko }
492c5a6712fSAlex Zinenko }
493c5a6712fSAlex Zinenko
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields.
/// `{{}` is the formatv escape that produces the Python literal `{}`.
/// `_ods_successors` is not defined by this template, so the code substituted
/// for {1} has to define it.
/// NOTE(review): multi-line code substituted for {1} must carry the same
/// four-space body indentation on its continuation lines; confirm at the use
/// site, which is outside this excerpt.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";
510f9265de8SAlex Zinenko
/// Statements appending one mandatory element to the operand/result list.
///   {0} is the field name.
/// The operand form passes the value through `_get_op_result_or_value`; the
/// result form appends it unchanged.
constexpr const char *singleOperandAppendTemplate =
    R"Py(operands.append(_get_op_result_or_value({0})))Py";
constexpr const char *singleResultAppendTemplate =
    R"Py(results.append({0}))Py";
516f9265de8SAlex Zinenko
/// Templates for appending an optional element to the operand/result list.
/// {0} is the field name.
/// The plain operand and result variants append nothing when the value is
/// `None`. The attr-sized operand variant always appends exactly one entry,
/// using `None` as the placeholder for an absent operand.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";
526f9265de8SAlex Zinenko
/// Templates for appending a list of elements to the operand/result list.
/// {0} is the field name.
/// The `extend` variants splice the individual elements into the target list;
/// the "pack" operand variant appends the converted list as one nested entry.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
534f9265de8SAlex Zinenko
/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
/// The attribute is stored unconditionally.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
/// Nothing is stored when the argument is `None`.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting an attribute with a default value in the operation
/// builder.
/// {0} is the attribute name;
/// {1} is the builder argument name;
/// {2} is the default value, i.e. the Python expression stored when the
/// argument is `None`.
constexpr const char *initDefaultValuedAttributeTemplate =
    R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py";

/// Template for asserting that an attribute value was provided when calling a
/// builder.
/// {0} is the attribute name (used in the assertion message);
/// {1} is the builder argument name.
constexpr const char *assertAttributeValueSpecified =
    R"Py(assert {1} is not None, "attribute {0} must be specified")Py";
560*989d2b51SMatthias Springer
561c5a6712fSAlex Zinenko constexpr const char *initUnitAttributeTemplate =
5628a1f1a10SStella Laurenzo R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
5638a1f1a10SStella Laurenzo _ods_get_default_loc_context(loc)))Py";
564c5a6712fSAlex Zinenko
/// Template to initialize the successors list in the builder if there are any
/// successors.
/// {0} is the value to initialize the successors list to.
/// The generated variable is named `_ods_successors`.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
/// {0} is the list method ('append' or 'extend');
/// {1} is the value to add.
/// Operates on the `_ods_successors` variable created by the template above.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
5748e6c55c9SStella Laurenzo
5752995d29bSAlex Zinenko /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
5762995d29bSAlex Zinenko /// result types of the given operation.
hasSameArgumentAndResultTypes(const Operator & op)5772995d29bSAlex Zinenko static bool hasSameArgumentAndResultTypes(const Operator &op) {
5782995d29bSAlex Zinenko return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
5792995d29bSAlex Zinenko op.getNumVariableLengthResults() == 0;
5802995d29bSAlex Zinenko }
5812995d29bSAlex Zinenko
5822995d29bSAlex Zinenko /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
5832995d29bSAlex Zinenko /// result types of the given operation.
hasFirstAttrDerivedResultTypes(const Operator & op)5842995d29bSAlex Zinenko static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
5852995d29bSAlex Zinenko return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
5862995d29bSAlex Zinenko op.getNumVariableLengthResults() == 0;
5872995d29bSAlex Zinenko }
5882995d29bSAlex Zinenko
5892995d29bSAlex Zinenko /// Returns true if the InferTypeOpInterface can be used to infer result types
5902995d29bSAlex Zinenko /// of the given operation.
hasInferTypeInterface(const Operator & op)5912995d29bSAlex Zinenko static bool hasInferTypeInterface(const Operator &op) {
5922995d29bSAlex Zinenko return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
5932995d29bSAlex Zinenko op.getNumRegions() == 0;
5942995d29bSAlex Zinenko }
5952995d29bSAlex Zinenko
5962995d29bSAlex Zinenko /// Returns true if there is a trait or interface that can be used to infer
5972995d29bSAlex Zinenko /// result types of the given operation.
canInferType(const Operator & op)5982995d29bSAlex Zinenko static bool canInferType(const Operator &op) {
5992995d29bSAlex Zinenko return hasSameArgumentAndResultTypes(op) ||
6002995d29bSAlex Zinenko hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
6012995d29bSAlex Zinenko }
6022995d29bSAlex Zinenko
6032995d29bSAlex Zinenko /// Populates `builderArgs` with result names if the builder is expected to
6042995d29bSAlex Zinenko /// accept them as arguments.
605c5a6712fSAlex Zinenko static void
populateBuilderArgsResults(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs)6062995d29bSAlex Zinenko populateBuilderArgsResults(const Operator &op,
6072995d29bSAlex Zinenko llvm::SmallVectorImpl<std::string> &builderArgs) {
6082995d29bSAlex Zinenko if (canInferType(op))
6092995d29bSAlex Zinenko return;
6102995d29bSAlex Zinenko
611c5a6712fSAlex Zinenko for (int i = 0, e = op.getNumResults(); i < e; ++i) {
612c5a6712fSAlex Zinenko std::string name = op.getResultName(i).str();
613fd226c9bSStella Laurenzo if (name.empty()) {
614fd226c9bSStella Laurenzo if (op.getNumResults() == 1) {
615fd226c9bSStella Laurenzo // Special case for one result, make the default name be 'result'
616fd226c9bSStella Laurenzo // to properly match the built-in result accessor.
617fd226c9bSStella Laurenzo name = "result";
618fd226c9bSStella Laurenzo } else {
619c5a6712fSAlex Zinenko name = llvm::formatv("_gen_res_{0}", i);
620fd226c9bSStella Laurenzo }
621fd226c9bSStella Laurenzo }
622c5a6712fSAlex Zinenko name = sanitizeName(name);
623c5a6712fSAlex Zinenko builderArgs.push_back(name);
624c5a6712fSAlex Zinenko }
6252995d29bSAlex Zinenko }
6262995d29bSAlex Zinenko
6272995d29bSAlex Zinenko /// Populates `builderArgs` with the Python-compatible names of builder function
6282995d29bSAlex Zinenko /// arguments using intermixed attributes and operands in the same order as they
6292995d29bSAlex Zinenko /// appear in the `arguments` field of the op definition. Additionally,
6302995d29bSAlex Zinenko /// `operandNames` is populated with names of operands in their order of
6312995d29bSAlex Zinenko /// appearance.
6322995d29bSAlex Zinenko static void
populateBuilderArgs(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & operandNames,llvm::SmallVectorImpl<std::string> & successorArgNames)6332995d29bSAlex Zinenko populateBuilderArgs(const Operator &op,
6342995d29bSAlex Zinenko llvm::SmallVectorImpl<std::string> &builderArgs,
6352995d29bSAlex Zinenko llvm::SmallVectorImpl<std::string> &operandNames,
6362995d29bSAlex Zinenko llvm::SmallVectorImpl<std::string> &successorArgNames) {
6372995d29bSAlex Zinenko
638c5a6712fSAlex Zinenko for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
639c5a6712fSAlex Zinenko std::string name = op.getArgName(i).str();
640c5a6712fSAlex Zinenko if (name.empty())
641c5a6712fSAlex Zinenko name = llvm::formatv("_gen_arg_{0}", i);
642c5a6712fSAlex Zinenko name = sanitizeName(name);
643c5a6712fSAlex Zinenko builderArgs.push_back(name);
644c5a6712fSAlex Zinenko if (!op.getArg(i).is<NamedAttribute *>())
645c5a6712fSAlex Zinenko operandNames.push_back(name);
646c5a6712fSAlex Zinenko }
6479b79f50bSJeremy Furtek }
6489b79f50bSJeremy Furtek
6499b79f50bSJeremy Furtek /// Populates `builderArgs` with the Python-compatible names of builder function
6509b79f50bSJeremy Furtek /// successor arguments. Additionally, `successorArgNames` is also populated.
populateBuilderArgsSuccessors(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & successorArgNames)6519b79f50bSJeremy Furtek static void populateBuilderArgsSuccessors(
6529b79f50bSJeremy Furtek const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
6539b79f50bSJeremy Furtek llvm::SmallVectorImpl<std::string> &successorArgNames) {
6548e6c55c9SStella Laurenzo
6558e6c55c9SStella Laurenzo for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
6568e6c55c9SStella Laurenzo NamedSuccessor successor = op.getSuccessor(i);
6578e6c55c9SStella Laurenzo std::string name = std::string(successor.name);
6588e6c55c9SStella Laurenzo if (name.empty())
6598e6c55c9SStella Laurenzo name = llvm::formatv("_gen_successor_{0}", i);
6608e6c55c9SStella Laurenzo name = sanitizeName(name);
6618e6c55c9SStella Laurenzo builderArgs.push_back(name);
6628e6c55c9SStella Laurenzo successorArgNames.push_back(name);
6638e6c55c9SStella Laurenzo }
664c5a6712fSAlex Zinenko }
665c5a6712fSAlex Zinenko
666*989d2b51SMatthias Springer /// Generates Python code for the default value of the given attribute.
getAttributeDefaultValue(Attribute attr)667*989d2b51SMatthias Springer static FailureOr<std::string> getAttributeDefaultValue(Attribute attr) {
668*989d2b51SMatthias Springer assert(attr.hasDefaultValue() && "expected attribute with default value");
669*989d2b51SMatthias Springer StringRef storageType = attr.getStorageType().trim();
670*989d2b51SMatthias Springer StringRef defaultValCpp = attr.getDefaultValue().trim();
671*989d2b51SMatthias Springer
672*989d2b51SMatthias Springer // A list of commonly used attribute types and default values for which
673*989d2b51SMatthias Springer // we can generate Python code. Extend as needed.
674*989d2b51SMatthias Springer if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}"))
675*989d2b51SMatthias Springer return std::string("_ods_ir.ArrayAttr.get([])");
676*989d2b51SMatthias Springer
677*989d2b51SMatthias Springer // No match: Cannot generate Python code.
678*989d2b51SMatthias Springer return failure();
679*989d2b51SMatthias Springer }
680*989d2b51SMatthias Springer
681c5a6712fSAlex Zinenko /// Populates `builderLines` with additional lines that are required in the
682c5a6712fSAlex Zinenko /// builder to set up operation attributes. `argNames` is expected to contain
683c5a6712fSAlex Zinenko /// the names of builder arguments that correspond to op arguments, i.e. to the
684c5a6712fSAlex Zinenko /// operands and attributes in the same order as they appear in the `arguments`
685c5a6712fSAlex Zinenko /// field.
686c5a6712fSAlex Zinenko static void
populateBuilderLinesAttr(const Operator & op,llvm::ArrayRef<std::string> argNames,llvm::SmallVectorImpl<std::string> & builderLines)687c5a6712fSAlex Zinenko populateBuilderLinesAttr(const Operator &op,
688c5a6712fSAlex Zinenko llvm::ArrayRef<std::string> argNames,
689c5a6712fSAlex Zinenko llvm::SmallVectorImpl<std::string> &builderLines) {
690c5a6712fSAlex Zinenko for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
691c5a6712fSAlex Zinenko Argument arg = op.getArg(i);
692c5a6712fSAlex Zinenko auto *attribute = arg.dyn_cast<NamedAttribute *>();
693c5a6712fSAlex Zinenko if (!attribute)
694c5a6712fSAlex Zinenko continue;
695c5a6712fSAlex Zinenko
696c5a6712fSAlex Zinenko // Unit attributes are handled specially.
697c5a6712fSAlex Zinenko if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
698c5a6712fSAlex Zinenko builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
699c5a6712fSAlex Zinenko attribute->name, argNames[i]));
700c5a6712fSAlex Zinenko continue;
701c5a6712fSAlex Zinenko }
702c5a6712fSAlex Zinenko
703*989d2b51SMatthias Springer // Attributes with default value are handled specially.
704*989d2b51SMatthias Springer if (attribute->attr.hasDefaultValue()) {
705*989d2b51SMatthias Springer // In case we cannot generate Python code for the default value, the
706*989d2b51SMatthias Springer // attribute must be specified by the user.
707*989d2b51SMatthias Springer FailureOr<std::string> defaultValPy =
708*989d2b51SMatthias Springer getAttributeDefaultValue(attribute->attr);
709*989d2b51SMatthias Springer if (succeeded(defaultValPy)) {
710*989d2b51SMatthias Springer builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate,
711*989d2b51SMatthias Springer attribute->name, argNames[i],
712*989d2b51SMatthias Springer *defaultValPy));
713*989d2b51SMatthias Springer } else {
714*989d2b51SMatthias Springer builderLines.push_back(llvm::formatv(assertAttributeValueSpecified,
715*989d2b51SMatthias Springer attribute->name, argNames[i]));
716*989d2b51SMatthias Springer builderLines.push_back(
717*989d2b51SMatthias Springer llvm::formatv(initAttributeTemplate, attribute->name, argNames[i]));
718*989d2b51SMatthias Springer }
719*989d2b51SMatthias Springer continue;
720*989d2b51SMatthias Springer }
721*989d2b51SMatthias Springer
722c5a6712fSAlex Zinenko builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
723c5a6712fSAlex Zinenko ? initOptionalAttributeTemplate
724c5a6712fSAlex Zinenko : initAttributeTemplate,
725c5a6712fSAlex Zinenko attribute->name, argNames[i]));
726c5a6712fSAlex Zinenko }
727c5a6712fSAlex Zinenko }
728c5a6712fSAlex Zinenko
729c5a6712fSAlex Zinenko /// Populates `builderLines` with additional lines that are required in the
7308e6c55c9SStella Laurenzo /// builder to set up successors. successorArgNames is expected to correspond
7318e6c55c9SStella Laurenzo /// to the Python argument name for each successor on the op.
populateBuilderLinesSuccessors(const Operator & op,llvm::ArrayRef<std::string> successorArgNames,llvm::SmallVectorImpl<std::string> & builderLines)7328e6c55c9SStella Laurenzo static void populateBuilderLinesSuccessors(
7338e6c55c9SStella Laurenzo const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
7348e6c55c9SStella Laurenzo llvm::SmallVectorImpl<std::string> &builderLines) {
7358e6c55c9SStella Laurenzo if (successorArgNames.empty()) {
7368e6c55c9SStella Laurenzo builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
7378e6c55c9SStella Laurenzo return;
7388e6c55c9SStella Laurenzo }
7398e6c55c9SStella Laurenzo
7408e6c55c9SStella Laurenzo builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
7418e6c55c9SStella Laurenzo for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
7428e6c55c9SStella Laurenzo auto &argName = successorArgNames[i];
7438e6c55c9SStella Laurenzo const NamedSuccessor &successor = op.getSuccessor(i);
7448e6c55c9SStella Laurenzo builderLines.push_back(
7458e6c55c9SStella Laurenzo llvm::formatv(addSuccessorTemplate,
7468e6c55c9SStella Laurenzo successor.isVariadic() ? "extend" : "append", argName));
7478e6c55c9SStella Laurenzo }
7488e6c55c9SStella Laurenzo }
7498e6c55c9SStella Laurenzo
7508e6c55c9SStella Laurenzo /// Populates `builderLines` with additional lines that are required in the
751b164f23cSAlex Zinenko /// builder to set up op operands.
752b164f23cSAlex Zinenko static void
populateBuilderLinesOperand(const Operator & op,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines)753b164f23cSAlex Zinenko populateBuilderLinesOperand(const Operator &op,
754b164f23cSAlex Zinenko llvm::ArrayRef<std::string> names,
755b164f23cSAlex Zinenko llvm::SmallVectorImpl<std::string> &builderLines) {
756b164f23cSAlex Zinenko bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
757f9265de8SAlex Zinenko
758f9265de8SAlex Zinenko // For each element, find or generate a name.
759b164f23cSAlex Zinenko for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
760b164f23cSAlex Zinenko const NamedTypeConstraint &element = op.getOperand(i);
761c5a6712fSAlex Zinenko std::string name = names[i];
762f9265de8SAlex Zinenko
763f9265de8SAlex Zinenko // Choose the formatting string based on the element kind.
76471b6b010SStella Laurenzo llvm::StringRef formatString;
765f9265de8SAlex Zinenko if (!element.isVariableLength()) {
766b164f23cSAlex Zinenko formatString = singleOperandAppendTemplate;
767f9265de8SAlex Zinenko } else if (element.isOptional()) {
7686981e5ecSAlex Zinenko if (sizedSegments) {
7696981e5ecSAlex Zinenko formatString = optionalAppendAttrSizedOperandsTemplate;
7706981e5ecSAlex Zinenko } else {
771b164f23cSAlex Zinenko formatString = optionalAppendOperandTemplate;
7726981e5ecSAlex Zinenko }
773f9265de8SAlex Zinenko } else {
774f9265de8SAlex Zinenko assert(element.isVariadic() && "unhandled element group type");
775b164f23cSAlex Zinenko // If emitting with sizedSegments, then we add the actual list-typed
776b164f23cSAlex Zinenko // element. Otherwise, we extend the actual operands.
77771b6b010SStella Laurenzo if (sizedSegments) {
778b164f23cSAlex Zinenko formatString = multiOperandAppendPackTemplate;
77971b6b010SStella Laurenzo } else {
780b164f23cSAlex Zinenko formatString = multiOperandAppendTemplate;
78171b6b010SStella Laurenzo }
782f9265de8SAlex Zinenko }
783f9265de8SAlex Zinenko
784b164f23cSAlex Zinenko builderLines.push_back(llvm::formatv(formatString.data(), name));
785b164f23cSAlex Zinenko }
786b164f23cSAlex Zinenko }
787b164f23cSAlex Zinenko
7882995d29bSAlex Zinenko /// Python code template for deriving the operation result types from its
7892995d29bSAlex Zinenko /// attribute:
7902995d29bSAlex Zinenko /// - {0} is the name of the attribute from which to derive the types.
7912995d29bSAlex Zinenko constexpr const char *deriveTypeFromAttrTemplate =
7922995d29bSAlex Zinenko R"PY(_ods_result_type_source_attr = attributes["{0}"]
7932995d29bSAlex Zinenko _ods_derived_result_type = (
7942995d29bSAlex Zinenko _ods_ir.TypeAttr(_ods_result_type_source_attr).value
7952995d29bSAlex Zinenko if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
7962995d29bSAlex Zinenko _ods_result_type_source_attr.type))PY";
7972995d29bSAlex Zinenko
/// Python code template appending {0} type {1} times to the results list.
/// {0} is the Python expression producing the type;
/// {1} is the repetition count.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
8002995d29bSAlex Zinenko
8012995d29bSAlex Zinenko /// Python code template for inferring the operation results using the
8022995d29bSAlex Zinenko /// corresponding interface:
8032995d29bSAlex Zinenko /// - {0} is the name of the class for which the types are inferred.
8042995d29bSAlex Zinenko constexpr const char *inferTypeInterfaceTemplate =
8052995d29bSAlex Zinenko R"PY(_ods_context = _ods_get_default_loc_context(loc)
8062995d29bSAlex Zinenko results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
8072995d29bSAlex Zinenko operands=operands,
8082995d29bSAlex Zinenko attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
8092995d29bSAlex Zinenko context=_ods_context,
8102995d29bSAlex Zinenko loc=loc)
8112995d29bSAlex Zinenko )PY";
8122995d29bSAlex Zinenko
8132995d29bSAlex Zinenko /// Appends the given multiline string as individual strings into
8142995d29bSAlex Zinenko /// `builderLines`.
appendLineByLine(StringRef string,llvm::SmallVectorImpl<std::string> & builderLines)8152995d29bSAlex Zinenko static void appendLineByLine(StringRef string,
8162995d29bSAlex Zinenko llvm::SmallVectorImpl<std::string> &builderLines) {
8172995d29bSAlex Zinenko
8182995d29bSAlex Zinenko std::pair<StringRef, StringRef> split = std::make_pair(string, string);
8192995d29bSAlex Zinenko do {
8202995d29bSAlex Zinenko split = split.second.split('\n');
8212995d29bSAlex Zinenko builderLines.push_back(split.first.str());
8222995d29bSAlex Zinenko } while (!split.second.empty());
8232995d29bSAlex Zinenko }
8242995d29bSAlex Zinenko
825b164f23cSAlex Zinenko /// Populates `builderLines` with additional lines that are required in the
826b164f23cSAlex Zinenko /// builder to set up op results.
827b164f23cSAlex Zinenko static void
populateBuilderLinesResult(const Operator & op,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines)828b164f23cSAlex Zinenko populateBuilderLinesResult(const Operator &op,
829b164f23cSAlex Zinenko llvm::ArrayRef<std::string> names,
830b164f23cSAlex Zinenko llvm::SmallVectorImpl<std::string> &builderLines) {
831b164f23cSAlex Zinenko bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
832b164f23cSAlex Zinenko
8332995d29bSAlex Zinenko if (hasSameArgumentAndResultTypes(op)) {
8342995d29bSAlex Zinenko builderLines.push_back(llvm::formatv(
8352995d29bSAlex Zinenko appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
8362995d29bSAlex Zinenko return;
8372995d29bSAlex Zinenko }
8382995d29bSAlex Zinenko
8392995d29bSAlex Zinenko if (hasFirstAttrDerivedResultTypes(op)) {
8402995d29bSAlex Zinenko const NamedAttribute &firstAttr = op.getAttribute(0);
8412995d29bSAlex Zinenko assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
8422995d29bSAlex Zinenko "from which the type is derived");
8432995d29bSAlex Zinenko appendLineByLine(
8442995d29bSAlex Zinenko llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
8452995d29bSAlex Zinenko builderLines);
8462995d29bSAlex Zinenko builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
8472995d29bSAlex Zinenko "_ods_derived_result_type",
8482995d29bSAlex Zinenko op.getNumResults()));
8492995d29bSAlex Zinenko return;
8502995d29bSAlex Zinenko }
8512995d29bSAlex Zinenko
8522995d29bSAlex Zinenko if (hasInferTypeInterface(op)) {
8532995d29bSAlex Zinenko appendLineByLine(
8542995d29bSAlex Zinenko llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
8552995d29bSAlex Zinenko builderLines);
8562995d29bSAlex Zinenko return;
8572995d29bSAlex Zinenko }
8582995d29bSAlex Zinenko
859b164f23cSAlex Zinenko // For each element, find or generate a name.
860b164f23cSAlex Zinenko for (int i = 0, e = op.getNumResults(); i < e; ++i) {
861b164f23cSAlex Zinenko const NamedTypeConstraint &element = op.getResult(i);
862b164f23cSAlex Zinenko std::string name = names[i];
863b164f23cSAlex Zinenko
864b164f23cSAlex Zinenko // Choose the formatting string based on the element kind.
865b164f23cSAlex Zinenko llvm::StringRef formatString;
866b164f23cSAlex Zinenko if (!element.isVariableLength()) {
867b164f23cSAlex Zinenko formatString = singleResultAppendTemplate;
868b164f23cSAlex Zinenko } else if (element.isOptional()) {
869b164f23cSAlex Zinenko formatString = optionalAppendResultTemplate;
870b164f23cSAlex Zinenko } else {
871b164f23cSAlex Zinenko assert(element.isVariadic() && "unhandled element group type");
872b164f23cSAlex Zinenko // If emitting with sizedSegments, then we add the actual list-typed
873b164f23cSAlex Zinenko // element. Otherwise, we extend the actual operands.
874b164f23cSAlex Zinenko if (sizedSegments) {
875b164f23cSAlex Zinenko formatString = singleResultAppendTemplate;
876b164f23cSAlex Zinenko } else {
877b164f23cSAlex Zinenko formatString = multiResultAppendTemplate;
878b164f23cSAlex Zinenko }
879b164f23cSAlex Zinenko }
880b164f23cSAlex Zinenko
881b164f23cSAlex Zinenko builderLines.push_back(llvm::formatv(formatString.data(), name));
882f9265de8SAlex Zinenko }
883f9265de8SAlex Zinenko }
884f9265de8SAlex Zinenko
88518fbd5feSAlex Zinenko /// If the operation has variadic regions, adds a builder argument to specify
88618fbd5feSAlex Zinenko /// the number of those regions and builder lines to forward it to the generic
88718fbd5feSAlex Zinenko /// constructor.
88818fbd5feSAlex Zinenko static void
populateBuilderRegions(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & builderLines)88918fbd5feSAlex Zinenko populateBuilderRegions(const Operator &op,
89018fbd5feSAlex Zinenko llvm::SmallVectorImpl<std::string> &builderArgs,
89118fbd5feSAlex Zinenko llvm::SmallVectorImpl<std::string> &builderLines) {
89218fbd5feSAlex Zinenko if (op.hasNoVariadicRegions())
89318fbd5feSAlex Zinenko return;
89418fbd5feSAlex Zinenko
89518fbd5feSAlex Zinenko // This is currently enforced when Operator is constructed.
89618fbd5feSAlex Zinenko assert(op.getNumVariadicRegions() == 1 &&
89718fbd5feSAlex Zinenko op.getRegion(op.getNumRegions() - 1).isVariadic() &&
89818fbd5feSAlex Zinenko "expected the last region to be varidic");
89918fbd5feSAlex Zinenko
90018fbd5feSAlex Zinenko const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1);
90118fbd5feSAlex Zinenko std::string name =
90218fbd5feSAlex Zinenko ("num_" + region.name.take_front().lower() + region.name.drop_front())
90318fbd5feSAlex Zinenko .str();
90418fbd5feSAlex Zinenko builderArgs.push_back(name);
90518fbd5feSAlex Zinenko builderLines.push_back(
90618fbd5feSAlex Zinenko llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
90718fbd5feSAlex Zinenko }
90818fbd5feSAlex Zinenko
909f9265de8SAlex Zinenko /// Emits a default builder constructing an operation from the list of its
910f9265de8SAlex Zinenko /// result types, followed by a list of its operands.
emitDefaultOpBuilder(const Operator & op,raw_ostream & os)911f9265de8SAlex Zinenko static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
912f9265de8SAlex Zinenko // If we are asked to skip default builders, comply.
913f9265de8SAlex Zinenko if (op.skipDefaultBuilders())
914f9265de8SAlex Zinenko return;
915f9265de8SAlex Zinenko
9168e6c55c9SStella Laurenzo llvm::SmallVector<std::string> builderArgs;
9178e6c55c9SStella Laurenzo llvm::SmallVector<std::string> builderLines;
9188e6c55c9SStella Laurenzo llvm::SmallVector<std::string> operandArgNames;
9198e6c55c9SStella Laurenzo llvm::SmallVector<std::string> successorArgNames;
920c5a6712fSAlex Zinenko builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
9218e6c55c9SStella Laurenzo op.getNumNativeAttributes() + op.getNumSuccessors());
9222995d29bSAlex Zinenko populateBuilderArgsResults(op, builderArgs);
9232995d29bSAlex Zinenko size_t numResultArgs = builderArgs.size();
9248e6c55c9SStella Laurenzo populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
9259b79f50bSJeremy Furtek size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
9269b79f50bSJeremy Furtek populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
9278e6c55c9SStella Laurenzo
928b164f23cSAlex Zinenko populateBuilderLinesOperand(op, operandArgNames, builderLines);
929c5a6712fSAlex Zinenko populateBuilderLinesAttr(
9302995d29bSAlex Zinenko op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
9312995d29bSAlex Zinenko builderLines);
9322995d29bSAlex Zinenko populateBuilderLinesResult(
9332995d29bSAlex Zinenko op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
934c5a6712fSAlex Zinenko builderLines);
9358e6c55c9SStella Laurenzo populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
93618fbd5feSAlex Zinenko populateBuilderRegions(op, builderArgs, builderLines);
937f9265de8SAlex Zinenko
9389b79f50bSJeremy Furtek // Layout of builderArgs vector elements:
9399b79f50bSJeremy Furtek // [ result_args operand_attr_args successor_args regions ]
9409b79f50bSJeremy Furtek
9419b79f50bSJeremy Furtek // Determine whether the argument corresponding to a given index into the
9429b79f50bSJeremy Furtek // builderArgs vector is a python keyword argument or not.
9439b79f50bSJeremy Furtek auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
9449b79f50bSJeremy Furtek // All result, successor, and region arguments are positional arguments.
9459b79f50bSJeremy Furtek if ((builderArgIndex < numResultArgs) ||
9469b79f50bSJeremy Furtek (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
9479b79f50bSJeremy Furtek return false;
9489b79f50bSJeremy Furtek // Keyword arguments:
9499b79f50bSJeremy Furtek // - optional named attributes (including unit attributes)
9509b79f50bSJeremy Furtek // - default-valued named attributes
9519b79f50bSJeremy Furtek // - optional operands
9529b79f50bSJeremy Furtek Argument a = op.getArg(builderArgIndex - numResultArgs);
9539b79f50bSJeremy Furtek if (auto *nattr = a.dyn_cast<NamedAttribute *>())
9549b79f50bSJeremy Furtek return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
955eacfd047SMehdi Amini if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
9569b79f50bSJeremy Furtek return ntype->isOptional();
9579b79f50bSJeremy Furtek return false;
9589b79f50bSJeremy Furtek };
9599b79f50bSJeremy Furtek
9609b79f50bSJeremy Furtek // StringRefs in functionArgs refer to strings allocated by builderArgs.
9619b79f50bSJeremy Furtek llvm::SmallVector<llvm::StringRef> functionArgs;
9629b79f50bSJeremy Furtek
9639b79f50bSJeremy Furtek // Add positional arguments.
9649b79f50bSJeremy Furtek for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
9659b79f50bSJeremy Furtek if (!isKeywordArgFn(i))
9669b79f50bSJeremy Furtek functionArgs.push_back(builderArgs[i]);
9679b79f50bSJeremy Furtek }
9689b79f50bSJeremy Furtek
9699b79f50bSJeremy Furtek // Add a bare '*' to indicate that all following arguments must be keyword
9709b79f50bSJeremy Furtek // arguments.
9719b79f50bSJeremy Furtek functionArgs.push_back("*");
9729b79f50bSJeremy Furtek
9739b79f50bSJeremy Furtek // Add a default 'None' value to each keyword arg string, and then add to the
9749b79f50bSJeremy Furtek // function args list.
9759b79f50bSJeremy Furtek for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
9769b79f50bSJeremy Furtek if (isKeywordArgFn(i)) {
9779b79f50bSJeremy Furtek builderArgs[i].append("=None");
9789b79f50bSJeremy Furtek functionArgs.push_back(builderArgs[i]);
9799b79f50bSJeremy Furtek }
9809b79f50bSJeremy Furtek }
9819b79f50bSJeremy Furtek functionArgs.push_back("loc=None");
9829b79f50bSJeremy Furtek functionArgs.push_back("ip=None");
9839b79f50bSJeremy Furtek os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
984f9265de8SAlex Zinenko llvm::join(builderLines, "\n "));
985fd407e1fSAlex Zinenko }
986fd407e1fSAlex Zinenko
constructAttributeMapping(const llvm::RecordKeeper & records,AttributeClasses & attributeClasses)987c5a6712fSAlex Zinenko static void constructAttributeMapping(const llvm::RecordKeeper &records,
988c5a6712fSAlex Zinenko AttributeClasses &attributeClasses) {
989c5a6712fSAlex Zinenko for (const llvm::Record *rec :
990c5a6712fSAlex Zinenko records.getAllDerivedDefinitions("PythonAttr")) {
991c5a6712fSAlex Zinenko attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
992c5a6712fSAlex Zinenko rec->getValueAsString("pythonType").trim());
993c5a6712fSAlex Zinenko }
994c5a6712fSAlex Zinenko }
995c5a6712fSAlex Zinenko
emitSegmentSpec(const Operator & op,const char * kind,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement,raw_ostream & os)99671b6b010SStella Laurenzo static void emitSegmentSpec(
99771b6b010SStella Laurenzo const Operator &op, const char *kind,
99871b6b010SStella Laurenzo llvm::function_ref<int(const Operator &)> getNumElements,
99971b6b010SStella Laurenzo llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
100071b6b010SStella Laurenzo getElement,
100171b6b010SStella Laurenzo raw_ostream &os) {
100271b6b010SStella Laurenzo std::string segmentSpec("[");
100371b6b010SStella Laurenzo for (int i = 0, e = getNumElements(op); i < e; ++i) {
100471b6b010SStella Laurenzo const NamedTypeConstraint &element = getElement(op, i);
10056981e5ecSAlex Zinenko if (element.isOptional()) {
100671b6b010SStella Laurenzo segmentSpec.append("0,");
10076981e5ecSAlex Zinenko } else if (element.isVariadic()) {
10086981e5ecSAlex Zinenko segmentSpec.append("-1,");
100971b6b010SStella Laurenzo } else {
101071b6b010SStella Laurenzo segmentSpec.append("1,");
101171b6b010SStella Laurenzo }
101271b6b010SStella Laurenzo }
101371b6b010SStella Laurenzo segmentSpec.append("]");
101471b6b010SStella Laurenzo
101571b6b010SStella Laurenzo os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
101671b6b010SStella Laurenzo }
101771b6b010SStella Laurenzo
emitRegionAttributes(const Operator & op,raw_ostream & os)101871b6b010SStella Laurenzo static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
101971b6b010SStella Laurenzo // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
102071b6b010SStella Laurenzo // Note that the base OpView class defines this as (0, True).
102171b6b010SStella Laurenzo unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
102271b6b010SStella Laurenzo os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
102371b6b010SStella Laurenzo op.hasNoVariadicRegions() ? "True" : "False");
102471b6b010SStella Laurenzo }
102571b6b010SStella Laurenzo
102618fbd5feSAlex Zinenko /// Emits named accessors to regions.
emitRegionAccessors(const Operator & op,raw_ostream & os)102718fbd5feSAlex Zinenko static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
102889de9cc8SMehdi Amini for (const auto &en : llvm::enumerate(op.getRegions())) {
102918fbd5feSAlex Zinenko const NamedRegion ®ion = en.value();
103018fbd5feSAlex Zinenko if (region.name.empty())
103118fbd5feSAlex Zinenko continue;
103218fbd5feSAlex Zinenko
103318fbd5feSAlex Zinenko assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
103418fbd5feSAlex Zinenko "expected only the last region to be variadic");
103518fbd5feSAlex Zinenko os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
103618fbd5feSAlex Zinenko std::to_string(en.index()) +
103718fbd5feSAlex Zinenko (region.isVariadic() ? ":" : ""));
103818fbd5feSAlex Zinenko }
103918fbd5feSAlex Zinenko }
104018fbd5feSAlex Zinenko
1041fd407e1fSAlex Zinenko /// Emits bindings for a specific Op to the given output stream.
emitOpBindings(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)1042c5a6712fSAlex Zinenko static void emitOpBindings(const Operator &op,
1043c5a6712fSAlex Zinenko const AttributeClasses &attributeClasses,
1044c5a6712fSAlex Zinenko raw_ostream &os) {
1045fd407e1fSAlex Zinenko os << llvm::formatv(opClassTemplate, op.getCppClassName(),
1046fd407e1fSAlex Zinenko op.getOperationName());
104771b6b010SStella Laurenzo
104871b6b010SStella Laurenzo // Sized segments.
104971b6b010SStella Laurenzo if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
105071b6b010SStella Laurenzo emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
105171b6b010SStella Laurenzo }
105271b6b010SStella Laurenzo if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
105371b6b010SStella Laurenzo emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
105471b6b010SStella Laurenzo }
105571b6b010SStella Laurenzo
105671b6b010SStella Laurenzo emitRegionAttributes(op, os);
1057f9265de8SAlex Zinenko emitDefaultOpBuilder(op, os);
1058fd407e1fSAlex Zinenko emitOperandAccessors(op, os);
1059c5a6712fSAlex Zinenko emitAttributeAccessors(op, attributeClasses, os);
1060fd407e1fSAlex Zinenko emitResultAccessors(op, os);
106118fbd5feSAlex Zinenko emitRegionAccessors(op, os);
1062fd407e1fSAlex Zinenko }
1063fd407e1fSAlex Zinenko
1064fd407e1fSAlex Zinenko /// Emits bindings for the dialect specified in the command line, including file
1065fd407e1fSAlex Zinenko /// headers and utilities. Returns `false` on success to comply with Tablegen
1066fd407e1fSAlex Zinenko /// registration requirements.
emitAllOps(const llvm::RecordKeeper & records,raw_ostream & os)1067fd407e1fSAlex Zinenko static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1068fd407e1fSAlex Zinenko if (clDialectName.empty())
1069fd407e1fSAlex Zinenko llvm::PrintFatalError("dialect name not provided");
1070fd407e1fSAlex Zinenko
1071c5a6712fSAlex Zinenko AttributeClasses attributeClasses;
1072c5a6712fSAlex Zinenko constructAttributeMapping(records, attributeClasses);
1073c5a6712fSAlex Zinenko
10743f71765aSAlex Zinenko bool isExtension = !clDialectExtensionName.empty();
10753f71765aSAlex Zinenko os << llvm::formatv(fileHeader, isExtension
10763f71765aSAlex Zinenko ? clDialectExtensionName.getValue()
10773f71765aSAlex Zinenko : clDialectName.getValue());
10783f71765aSAlex Zinenko if (isExtension)
10793f71765aSAlex Zinenko os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
10803f71765aSAlex Zinenko else
1081fd407e1fSAlex Zinenko os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
1082922b26cdSMehdi Amini
1083fd407e1fSAlex Zinenko for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
1084fd407e1fSAlex Zinenko Operator op(rec);
1085fd407e1fSAlex Zinenko if (op.getDialectName() == clDialectName.getValue())
1086c5a6712fSAlex Zinenko emitOpBindings(op, attributeClasses, os);
1087fd407e1fSAlex Zinenko }
1088fd407e1fSAlex Zinenko return false;
1089fd407e1fSAlex Zinenko }
1090fd407e1fSAlex Zinenko
1091fd407e1fSAlex Zinenko static GenRegistration
1092fd407e1fSAlex Zinenko genPythonBindings("gen-python-op-bindings",
1093fd407e1fSAlex Zinenko "Generate Python bindings for MLIR Ops", &emitAllOps);
1094