1fd407e1fSAlex Zinenko //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2fd407e1fSAlex Zinenko //
3fd407e1fSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fd407e1fSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5fd407e1fSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fd407e1fSAlex Zinenko //
7fd407e1fSAlex Zinenko //===----------------------------------------------------------------------===//
8fd407e1fSAlex Zinenko //
9fd407e1fSAlex Zinenko // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10fd407e1fSAlex Zinenko // binding classes wrapping a generic operation API.
11fd407e1fSAlex Zinenko //
12fd407e1fSAlex Zinenko //===----------------------------------------------------------------------===//
13fd407e1fSAlex Zinenko 
14*989d2b51SMatthias Springer #include "mlir/Support/LogicalResult.h"
15fd407e1fSAlex Zinenko #include "mlir/TableGen/GenInfo.h"
16fd407e1fSAlex Zinenko #include "mlir/TableGen/Operator.h"
17fd407e1fSAlex Zinenko #include "llvm/ADT/StringSet.h"
18fd407e1fSAlex Zinenko #include "llvm/Support/CommandLine.h"
19fd407e1fSAlex Zinenko #include "llvm/Support/FormatVariadic.h"
20fd407e1fSAlex Zinenko #include "llvm/TableGen/Error.h"
21fd407e1fSAlex Zinenko #include "llvm/TableGen/Record.h"
22fd407e1fSAlex Zinenko 
23fd407e1fSAlex Zinenko using namespace mlir;
24fd407e1fSAlex Zinenko using namespace mlir::tblgen;
25fd407e1fSAlex Zinenko 
/// File header and imports of the generated Python module.
///   {0} is the dialect namespace.
/// The generated module imports the ODS helper functions under `_ods_`-prefixed
/// aliases, tries to import the optional `_{0}_ops_ext` extension module
/// (falling back to `_ods_ext_module = None` if it is missing), and imports
/// `builtins` so that generated code can refer to `builtins.property` even if
/// an op class defines a member named `property`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";
43fd407e1fSAlex Zinenko 
/// Template for the dialect class, registered with the Python extension via
/// `_ods_cext.register_dialect`:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";
53fd407e1fSAlex Zinenko 
/// Template for the module of a dialect extension. Instead of defining a new
/// dialect class, it re-imports the `_Dialect` class from the generated
/// `_{0}_ops_gen` module:
///   {0} is the name used to form that module name (presumably the namespace
///       of the dialect being extended -- confirm against the caller).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
573f71765aSAlex Zinenko 
/// Template for operation class. The class derives from `OpView`, is
/// registered with `_Dialect`, and may be extended by the optional
/// `_ods_ext_module`:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
67fd407e1fSAlex Zinenko 
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expects a non-sequence operand/result)
///   0 = optional element (expects a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";
8071b6b010SStella Laurenzo 
/// Template for class level declarations of the _ODS_REGIONS spec, a Python
/// tuple `(min_regions, has_no_variadic_regions)`:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
8771b6b010SStella Laurenzo 
/// Template for single-element accessor whose position is known statically
/// (no variable-length group precedes it):
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";
97fd407e1fSAlex Zinenko 
/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent). The length
/// of the variable group is recovered from the actual element count, and the
/// element index is shifted by that length minus one.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
111fd407e1fSAlex Zinenko 
/// Template for an optional element accessor; the accessor returns None when
/// the element is absent:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
125fd407e1fSAlex Zinenko 
/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Every other group holds exactly one element, so the variadic group length
/// is the actual element count minus ({2} - 1); the accessor returns a slice.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
137fd407e1fSAlex Zinenko 
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The generated code binds `start` (index of the first element of the current
/// group) and `pg` (length of each variadic group). The second part of the
/// template (single-element or variadic variant) uses both to produce the
/// return value. The element list must be read through `self.operation`:
/// a bare `operation` is not defined in the generated method.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
148fd407e1fSAlex Zinenko 
/// Second part of the template for equally-sized case, accessing a single
/// element at the `start` index computed by the first part:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";
155fd407e1fSAlex Zinenko 
/// Second part of the template for equally-sized case, accessing a variadic
/// group as the slice of length `pg` beginning at `start` (both computed by
/// the first part):
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
162fd407e1fSAlex Zinenko 
/// Template for an attribute-sized group accessor. Group sizes are read from
/// the `{1}_segment_sizes` attribute of the operation:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";
177fd407e1fSAlex Zinenko 
/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; yields the element if its segment is non-empty and
/// None otherwise:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
183fd407e1fSAlex Zinenko 
/// Template for an operation attribute getter; wraps the stored attribute in
/// its Python type:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";
193c5a6712fSAlex Zinenko 
/// Template for an optional operation attribute getter; returns None when the
/// attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";
205c5a6712fSAlex Zinenko 
/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
216c5a6712fSAlex Zinenko 
/// Template for an operation attribute setter. Mandatory attributes cannot be
/// cleared, so assigning None raises ValueError:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";
227029e199dSAlex Zinenko 
/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (and is a no-op if it is already absent):
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
240029e199dSAlex Zinenko 
/// Template for a setter of a unit operation attribute. A truthy value stores
/// a fresh `UnitAttr`; setting to None or False removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
253029e199dSAlex Zinenko 
/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation (unlike the setters, it does not check
/// whether the attribute is present first):
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
263029e199dSAlex Zinenko 
/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the position of the region in the operation's region list.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
26918fbd5feSAlex Zinenko 
270fd407e1fSAlex Zinenko static llvm::cl::OptionCategory
271fd407e1fSAlex Zinenko     clOpPythonBindingCat("Options for -gen-python-op-bindings");
272fd407e1fSAlex Zinenko 
273fd407e1fSAlex Zinenko static llvm::cl::opt<std::string>
274fd407e1fSAlex Zinenko     clDialectName("bind-dialect",
275fd407e1fSAlex Zinenko                   llvm::cl::desc("The dialect to run the generator for"),
276fd407e1fSAlex Zinenko                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
277fd407e1fSAlex Zinenko 
2783f71765aSAlex Zinenko static llvm::cl::opt<std::string> clDialectExtensionName(
2793f71765aSAlex Zinenko     "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
2803f71765aSAlex Zinenko     llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
2813f71765aSAlex Zinenko 
282c5a6712fSAlex Zinenko using AttributeClasses = DenseMap<StringRef, StringRef>;
283c5a6712fSAlex Zinenko 
284fd407e1fSAlex Zinenko /// Checks whether `str` is a Python keyword.
isPythonKeyword(StringRef str)285fd407e1fSAlex Zinenko static bool isPythonKeyword(StringRef str) {
286fd407e1fSAlex Zinenko   static llvm::StringSet<> keywords(
287fd407e1fSAlex Zinenko       {"and",   "as",     "assert",   "break", "class",  "continue",
288fd407e1fSAlex Zinenko        "def",   "del",    "elif",     "else",  "except", "finally",
289fd407e1fSAlex Zinenko        "for",   "from",   "global",   "if",    "import", "in",
290fd407e1fSAlex Zinenko        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
291fd407e1fSAlex Zinenko        "raise", "return", "try",      "while", "with",   "yield"});
292fd407e1fSAlex Zinenko   return keywords.contains(str);
2937dadcd02SMehdi Amini }
294fd407e1fSAlex Zinenko 
2958a1f1a10SStella Laurenzo /// Checks whether `str` would shadow a generated variable or attribute
2968a1f1a10SStella Laurenzo /// part of the OpView API.
isODSReserved(StringRef str)2978a1f1a10SStella Laurenzo static bool isODSReserved(StringRef str) {
2988a1f1a10SStella Laurenzo   static llvm::StringSet<> reserved(
2998a1f1a10SStella Laurenzo       {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
300fd226c9bSStella Laurenzo        "loc", "verify", "regions", "results", "self", "operation",
3018a1f1a10SStella Laurenzo        "DIALECT_NAMESPACE", "OPERATION_NAME"});
3028a1f1a10SStella Laurenzo   return str.startswith("_ods_") || str.endswith("_ods") ||
3038a1f1a10SStella Laurenzo          reserved.contains(str);
3048a1f1a10SStella Laurenzo }
3058a1f1a10SStella Laurenzo 
306fd407e1fSAlex Zinenko /// Modifies the `name` in a way that it becomes suitable for Python bindings
307fd407e1fSAlex Zinenko /// (does not change the `name` if it already is suitable) and returns the
308fd407e1fSAlex Zinenko /// modified version.
sanitizeName(StringRef name)309fd407e1fSAlex Zinenko static std::string sanitizeName(StringRef name) {
3108a1f1a10SStella Laurenzo   if (isPythonKeyword(name) || isODSReserved(name))
311fd407e1fSAlex Zinenko     return (name + "_").str();
312fd407e1fSAlex Zinenko   return name.str();
313fd407e1fSAlex Zinenko }
314fd407e1fSAlex Zinenko 
attrSizedTraitForKind(const char * kind)315f9265de8SAlex Zinenko static std::string attrSizedTraitForKind(const char *kind) {
316f9265de8SAlex Zinenko   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
317f9265de8SAlex Zinenko                        llvm::StringRef(kind).take_front().upper(),
318f9265de8SAlex Zinenko                        llvm::StringRef(kind).drop_front());
319f9265de8SAlex Zinenko }
320f9265de8SAlex Zinenko 
321fd407e1fSAlex Zinenko /// Emits accessors to "elements" of an Op definition. Currently, the supported
322fd407e1fSAlex Zinenko /// elements are operands and results, indicated by `kind`, which must be either
323fd407e1fSAlex Zinenko /// `operand` or `result` and is used verbatim in the emitted code.
emitElementAccessors(const Operator & op,raw_ostream & os,const char * kind,llvm::function_ref<unsigned (const Operator &)> getNumVariableLength,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement)324fd407e1fSAlex Zinenko static void emitElementAccessors(
325fd407e1fSAlex Zinenko     const Operator &op, raw_ostream &os, const char *kind,
32654c99842SMichal Terepeta     llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
327fd407e1fSAlex Zinenko     llvm::function_ref<int(const Operator &)> getNumElements,
328fd407e1fSAlex Zinenko     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
329fd407e1fSAlex Zinenko         getElement) {
330fd407e1fSAlex Zinenko   assert(llvm::is_contained(
331fd407e1fSAlex Zinenko              llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
332fd407e1fSAlex Zinenko          "unsupported kind");
333fd407e1fSAlex Zinenko 
334fd407e1fSAlex Zinenko   // Traits indicating how to process variadic elements.
335fd407e1fSAlex Zinenko   std::string sameSizeTrait =
336fd407e1fSAlex Zinenko       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
337fd407e1fSAlex Zinenko                     llvm::StringRef(kind).take_front().upper(),
338fd407e1fSAlex Zinenko                     llvm::StringRef(kind).drop_front());
339f9265de8SAlex Zinenko   std::string attrSizedTrait = attrSizedTraitForKind(kind);
340fd407e1fSAlex Zinenko 
34154c99842SMichal Terepeta   unsigned numVariableLength = getNumVariableLength(op);
342fd407e1fSAlex Zinenko 
34354c99842SMichal Terepeta   // If there is only one variable-length element group, its size can be
34454c99842SMichal Terepeta   // inferred from the total number of elements. If there are none, the
34554c99842SMichal Terepeta   // generation is straightforward.
34654c99842SMichal Terepeta   if (numVariableLength <= 1) {
347fd407e1fSAlex Zinenko     bool seenVariableLength = false;
348fd407e1fSAlex Zinenko     for (int i = 0, e = getNumElements(op); i < e; ++i) {
349fd407e1fSAlex Zinenko       const NamedTypeConstraint &element = getElement(op, i);
350fd407e1fSAlex Zinenko       if (element.isVariableLength())
351fd407e1fSAlex Zinenko         seenVariableLength = true;
352fd407e1fSAlex Zinenko       if (element.name.empty())
353fd407e1fSAlex Zinenko         continue;
354fd407e1fSAlex Zinenko       if (element.isVariableLength()) {
355fd407e1fSAlex Zinenko         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
356fd407e1fSAlex Zinenko                                                  : opOneVariadicTemplate,
357fd407e1fSAlex Zinenko                             sanitizeName(element.name), kind,
358fd407e1fSAlex Zinenko                             getNumElements(op), i);
359fd407e1fSAlex Zinenko       } else if (seenVariableLength) {
360fd407e1fSAlex Zinenko         os << llvm::formatv(opSingleAfterVariableTemplate,
361fd407e1fSAlex Zinenko                             sanitizeName(element.name), kind,
362fd407e1fSAlex Zinenko                             getNumElements(op), i);
363fd407e1fSAlex Zinenko       } else {
364fd407e1fSAlex Zinenko         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
365fd407e1fSAlex Zinenko                             i);
366fd407e1fSAlex Zinenko       }
367fd407e1fSAlex Zinenko     }
368fd407e1fSAlex Zinenko     return;
369fd407e1fSAlex Zinenko   }
370fd407e1fSAlex Zinenko 
371fd407e1fSAlex Zinenko   // Handle the operations where variadic groups have the same size.
372fd407e1fSAlex Zinenko   if (op.getTrait(sameSizeTrait)) {
373fd407e1fSAlex Zinenko     int numPrecedingSimple = 0;
374fd407e1fSAlex Zinenko     int numPrecedingVariadic = 0;
375fd407e1fSAlex Zinenko     for (int i = 0, e = getNumElements(op); i < e; ++i) {
376fd407e1fSAlex Zinenko       const NamedTypeConstraint &element = getElement(op, i);
377fd407e1fSAlex Zinenko       if (!element.name.empty()) {
378fd407e1fSAlex Zinenko         os << llvm::formatv(opVariadicEqualPrefixTemplate,
37954c99842SMichal Terepeta                             sanitizeName(element.name), kind, numVariableLength,
380fd407e1fSAlex Zinenko                             numPrecedingSimple, numPrecedingVariadic);
381fd407e1fSAlex Zinenko         os << llvm::formatv(element.isVariableLength()
382fd407e1fSAlex Zinenko                                 ? opVariadicEqualVariadicTemplate
383fd407e1fSAlex Zinenko                                 : opVariadicEqualSimpleTemplate,
384fd407e1fSAlex Zinenko                             kind);
385fd407e1fSAlex Zinenko       }
386fd407e1fSAlex Zinenko       if (element.isVariableLength())
387fd407e1fSAlex Zinenko         ++numPrecedingVariadic;
388fd407e1fSAlex Zinenko       else
389fd407e1fSAlex Zinenko         ++numPrecedingSimple;
390fd407e1fSAlex Zinenko     }
391fd407e1fSAlex Zinenko     return;
392fd407e1fSAlex Zinenko   }
393fd407e1fSAlex Zinenko 
394fd407e1fSAlex Zinenko   // Handle the operations where the size of groups (variadic or not) is
395fd407e1fSAlex Zinenko   // provided as an attribute. For non-variadic elements, make sure to return
396fd407e1fSAlex Zinenko   // an element rather than a singleton container.
397fd407e1fSAlex Zinenko   if (op.getTrait(attrSizedTrait)) {
398fd407e1fSAlex Zinenko     for (int i = 0, e = getNumElements(op); i < e; ++i) {
399fd407e1fSAlex Zinenko       const NamedTypeConstraint &element = getElement(op, i);
400fd407e1fSAlex Zinenko       if (element.name.empty())
401fd407e1fSAlex Zinenko         continue;
402fd407e1fSAlex Zinenko       std::string trailing;
403fd407e1fSAlex Zinenko       if (!element.isVariableLength())
404fd407e1fSAlex Zinenko         trailing = "[0]";
405fd407e1fSAlex Zinenko       else if (element.isOptional())
406fd407e1fSAlex Zinenko         trailing = std::string(
407fd407e1fSAlex Zinenko             llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
408fd407e1fSAlex Zinenko       os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
409fd407e1fSAlex Zinenko                           kind, i, trailing);
410fd407e1fSAlex Zinenko     }
411fd407e1fSAlex Zinenko     return;
412fd407e1fSAlex Zinenko   }
413fd407e1fSAlex Zinenko 
414fd407e1fSAlex Zinenko   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
415fd407e1fSAlex Zinenko }
416fd407e1fSAlex Zinenko 
417f9265de8SAlex Zinenko /// Free function helpers accessing Operator components.
getNumOperands(const Operator & op)418f9265de8SAlex Zinenko static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
getOperand(const Operator & op,int i)419f9265de8SAlex Zinenko static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
420f9265de8SAlex Zinenko   return op.getOperand(i);
421f9265de8SAlex Zinenko }
getNumResults(const Operator & op)422f9265de8SAlex Zinenko static int getNumResults(const Operator &op) { return op.getNumResults(); }
getResult(const Operator & op,int i)423f9265de8SAlex Zinenko static const NamedTypeConstraint &getResult(const Operator &op, int i) {
424f9265de8SAlex Zinenko   return op.getResult(i);
425f9265de8SAlex Zinenko }
426f9265de8SAlex Zinenko 
427c5a6712fSAlex Zinenko /// Emits accessors to Op operands.
emitOperandAccessors(const Operator & op,raw_ostream & os)428fd407e1fSAlex Zinenko static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
42954c99842SMichal Terepeta   auto getNumVariableLengthOperands = [](const Operator &oper) {
430fd407e1fSAlex Zinenko     return oper.getNumVariableLengthOperands();
431fd407e1fSAlex Zinenko   };
43254c99842SMichal Terepeta   emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
43354c99842SMichal Terepeta                        getNumOperands, getOperand);
434fd407e1fSAlex Zinenko }
435fd407e1fSAlex Zinenko 
436c5a6712fSAlex Zinenko /// Emits accessors Op results.
emitResultAccessors(const Operator & op,raw_ostream & os)437fd407e1fSAlex Zinenko static void emitResultAccessors(const Operator &op, raw_ostream &os) {
43854c99842SMichal Terepeta   auto getNumVariableLengthResults = [](const Operator &oper) {
439fd407e1fSAlex Zinenko     return oper.getNumVariableLengthResults();
440fd407e1fSAlex Zinenko   };
44154c99842SMichal Terepeta   emitElementAccessors(op, os, "result", getNumVariableLengthResults,
44254c99842SMichal Terepeta                        getNumResults, getResult);
443f9265de8SAlex Zinenko }
444f9265de8SAlex Zinenko 
445c5a6712fSAlex Zinenko /// Emits accessors to Op attributes.
emitAttributeAccessors(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)446c5a6712fSAlex Zinenko static void emitAttributeAccessors(const Operator &op,
447c5a6712fSAlex Zinenko                                    const AttributeClasses &attributeClasses,
448c5a6712fSAlex Zinenko                                    raw_ostream &os) {
449c5a6712fSAlex Zinenko   for (const auto &namedAttr : op.getAttributes()) {
450c5a6712fSAlex Zinenko     // Skip "derived" attributes because they are just C++ functions that we
451c5a6712fSAlex Zinenko     // don't currently expose.
452c5a6712fSAlex Zinenko     if (namedAttr.attr.isDerivedAttr())
453c5a6712fSAlex Zinenko       continue;
454c5a6712fSAlex Zinenko 
455c5a6712fSAlex Zinenko     if (namedAttr.name.empty())
456c5a6712fSAlex Zinenko       continue;
457c5a6712fSAlex Zinenko 
458029e199dSAlex Zinenko     std::string sanitizedName = sanitizeName(namedAttr.name);
459029e199dSAlex Zinenko 
460c5a6712fSAlex Zinenko     // Unit attributes are handled specially.
461c5a6712fSAlex Zinenko     if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
462029e199dSAlex Zinenko       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
463029e199dSAlex Zinenko                           namedAttr.name);
464029e199dSAlex Zinenko       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
465029e199dSAlex Zinenko                           namedAttr.name);
466029e199dSAlex Zinenko       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
467029e199dSAlex Zinenko                           namedAttr.name);
468c5a6712fSAlex Zinenko       continue;
469c5a6712fSAlex Zinenko     }
470c5a6712fSAlex Zinenko 
471c5a6712fSAlex Zinenko     // Other kinds of attributes need a mapping to a Python type.
472c5a6712fSAlex Zinenko     if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
473c5a6712fSAlex Zinenko       continue;
474c5a6712fSAlex Zinenko 
475029e199dSAlex Zinenko     StringRef pythonType =
476029e199dSAlex Zinenko         attributeClasses.lookup(namedAttr.attr.getStorageType());
477029e199dSAlex Zinenko     if (namedAttr.attr.isOptional()) {
478029e199dSAlex Zinenko       os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
479029e199dSAlex Zinenko                           pythonType, namedAttr.name);
480029e199dSAlex Zinenko       os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
481c5a6712fSAlex Zinenko                           namedAttr.name);
482029e199dSAlex Zinenko       os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
483029e199dSAlex Zinenko                           namedAttr.name);
484029e199dSAlex Zinenko     } else {
485029e199dSAlex Zinenko       os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType,
486029e199dSAlex Zinenko                           namedAttr.name);
487029e199dSAlex Zinenko       os << llvm::formatv(attributeSetterTemplate, sanitizedName,
488029e199dSAlex Zinenko                           namedAttr.name);
489029e199dSAlex Zinenko       // Non-optional attributes cannot be deleted.
490029e199dSAlex Zinenko     }
491c5a6712fSAlex Zinenko   }
492c5a6712fSAlex Zinenko }
493c5a6712fSAlex Zinenko 
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating the `operands`, `results`, `attributes` and
///       `regions` locals. It must also define `_ods_successors`, which this
///       template passes to `build_generic` but never initializes itself.
/// Note that `{{` is the formatv escape for a literal `{`, so the generated
/// Python reads `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";
510f9265de8SAlex Zinenko 
/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
/// Operands are passed through `_get_op_result_or_value` (imported from
/// `_ods_common`) before being stored; results are stored as given.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The plain variants skip the element entirely when it is None. The
/// attribute-sized-segments variant for operands always appends one entry
/// (None for an absent operand), so that every operand group keeps its own
/// slot in the list.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The `extend` variants flatten the list into the operand/result list. The
/// `Pack` variant appends the whole list as a single entry; it is used for
/// operands when the op has attribute-sized operand segments.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
534f9265de8SAlex Zinenko 
/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder. The
/// attribute is only set when the builder argument is not None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting an attribute with a default value in the operation
/// builder. The default is used when the builder argument is None.
///   {0} is the attribute name;
///   {1} is the builder argument name;
///   {2} is the default value, as Python source code.
constexpr const char *initDefaultValuedAttributeTemplate =
    R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py";

/// Template for asserting that an attribute value was provided when calling a
/// builder. This emits a Python `assert` statement, which the Python
/// interpreter skips when run with optimizations (`-O`).
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *assertAttributeValueSpecified =
    R"Py(assert {1} is not None, "attribute {0} must be specified")Py";

/// Template for setting a unit attribute in the operation builder. The
/// attribute is added only when the builder argument is truthy; the UnitAttr
/// is created in the context returned by `_ods_get_default_loc_context(loc)`.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";
564c5a6712fSAlex Zinenko 
/// Template to initialize the successors list in the builder. A line from this
/// template is always emitted (with `None` when the op has no successor
/// arguments), because the builder template passes `_ods_successors` to
/// `build_generic` unconditionally.
///   {0} is the value to initialize the successors list to (`None` or `[]`).
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' for a single successor, 'extend' for a
///       variadic one);
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
5748e6c55c9SStella Laurenzo 
5752995d29bSAlex Zinenko /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
5762995d29bSAlex Zinenko /// result types of the given operation.
hasSameArgumentAndResultTypes(const Operator & op)5772995d29bSAlex Zinenko static bool hasSameArgumentAndResultTypes(const Operator &op) {
5782995d29bSAlex Zinenko   return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
5792995d29bSAlex Zinenko          op.getNumVariableLengthResults() == 0;
5802995d29bSAlex Zinenko }
5812995d29bSAlex Zinenko 
5822995d29bSAlex Zinenko /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
5832995d29bSAlex Zinenko /// result types of the given operation.
hasFirstAttrDerivedResultTypes(const Operator & op)5842995d29bSAlex Zinenko static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
5852995d29bSAlex Zinenko   return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
5862995d29bSAlex Zinenko          op.getNumVariableLengthResults() == 0;
5872995d29bSAlex Zinenko }
5882995d29bSAlex Zinenko 
5892995d29bSAlex Zinenko /// Returns true if the InferTypeOpInterface can be used to infer result types
5902995d29bSAlex Zinenko /// of the given operation.
hasInferTypeInterface(const Operator & op)5912995d29bSAlex Zinenko static bool hasInferTypeInterface(const Operator &op) {
5922995d29bSAlex Zinenko   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
5932995d29bSAlex Zinenko          op.getNumRegions() == 0;
5942995d29bSAlex Zinenko }
5952995d29bSAlex Zinenko 
5962995d29bSAlex Zinenko /// Returns true if there is a trait or interface that can be used to infer
5972995d29bSAlex Zinenko /// result types of the given operation.
canInferType(const Operator & op)5982995d29bSAlex Zinenko static bool canInferType(const Operator &op) {
5992995d29bSAlex Zinenko   return hasSameArgumentAndResultTypes(op) ||
6002995d29bSAlex Zinenko          hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
6012995d29bSAlex Zinenko }
6022995d29bSAlex Zinenko 
6032995d29bSAlex Zinenko /// Populates `builderArgs` with result names if the builder is expected to
6042995d29bSAlex Zinenko /// accept them as arguments.
605c5a6712fSAlex Zinenko static void
populateBuilderArgsResults(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs)6062995d29bSAlex Zinenko populateBuilderArgsResults(const Operator &op,
6072995d29bSAlex Zinenko                            llvm::SmallVectorImpl<std::string> &builderArgs) {
6082995d29bSAlex Zinenko   if (canInferType(op))
6092995d29bSAlex Zinenko     return;
6102995d29bSAlex Zinenko 
611c5a6712fSAlex Zinenko   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
612c5a6712fSAlex Zinenko     std::string name = op.getResultName(i).str();
613fd226c9bSStella Laurenzo     if (name.empty()) {
614fd226c9bSStella Laurenzo       if (op.getNumResults() == 1) {
615fd226c9bSStella Laurenzo         // Special case for one result, make the default name be 'result'
616fd226c9bSStella Laurenzo         // to properly match the built-in result accessor.
617fd226c9bSStella Laurenzo         name = "result";
618fd226c9bSStella Laurenzo       } else {
619c5a6712fSAlex Zinenko         name = llvm::formatv("_gen_res_{0}", i);
620fd226c9bSStella Laurenzo       }
621fd226c9bSStella Laurenzo     }
622c5a6712fSAlex Zinenko     name = sanitizeName(name);
623c5a6712fSAlex Zinenko     builderArgs.push_back(name);
624c5a6712fSAlex Zinenko   }
6252995d29bSAlex Zinenko }
6262995d29bSAlex Zinenko 
6272995d29bSAlex Zinenko /// Populates `builderArgs` with the Python-compatible names of builder function
6282995d29bSAlex Zinenko /// arguments using intermixed attributes and operands in the same order as they
6292995d29bSAlex Zinenko /// appear in the `arguments` field of the op definition. Additionally,
6302995d29bSAlex Zinenko /// `operandNames` is populated with names of operands in their order of
6312995d29bSAlex Zinenko /// appearance.
6322995d29bSAlex Zinenko static void
populateBuilderArgs(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & operandNames,llvm::SmallVectorImpl<std::string> & successorArgNames)6332995d29bSAlex Zinenko populateBuilderArgs(const Operator &op,
6342995d29bSAlex Zinenko                     llvm::SmallVectorImpl<std::string> &builderArgs,
6352995d29bSAlex Zinenko                     llvm::SmallVectorImpl<std::string> &operandNames,
6362995d29bSAlex Zinenko                     llvm::SmallVectorImpl<std::string> &successorArgNames) {
6372995d29bSAlex Zinenko 
638c5a6712fSAlex Zinenko   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
639c5a6712fSAlex Zinenko     std::string name = op.getArgName(i).str();
640c5a6712fSAlex Zinenko     if (name.empty())
641c5a6712fSAlex Zinenko       name = llvm::formatv("_gen_arg_{0}", i);
642c5a6712fSAlex Zinenko     name = sanitizeName(name);
643c5a6712fSAlex Zinenko     builderArgs.push_back(name);
644c5a6712fSAlex Zinenko     if (!op.getArg(i).is<NamedAttribute *>())
645c5a6712fSAlex Zinenko       operandNames.push_back(name);
646c5a6712fSAlex Zinenko   }
6479b79f50bSJeremy Furtek }
6489b79f50bSJeremy Furtek 
6499b79f50bSJeremy Furtek /// Populates `builderArgs` with the Python-compatible names of builder function
6509b79f50bSJeremy Furtek /// successor arguments. Additionally, `successorArgNames` is also populated.
populateBuilderArgsSuccessors(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & successorArgNames)6519b79f50bSJeremy Furtek static void populateBuilderArgsSuccessors(
6529b79f50bSJeremy Furtek     const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
6539b79f50bSJeremy Furtek     llvm::SmallVectorImpl<std::string> &successorArgNames) {
6548e6c55c9SStella Laurenzo 
6558e6c55c9SStella Laurenzo   for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
6568e6c55c9SStella Laurenzo     NamedSuccessor successor = op.getSuccessor(i);
6578e6c55c9SStella Laurenzo     std::string name = std::string(successor.name);
6588e6c55c9SStella Laurenzo     if (name.empty())
6598e6c55c9SStella Laurenzo       name = llvm::formatv("_gen_successor_{0}", i);
6608e6c55c9SStella Laurenzo     name = sanitizeName(name);
6618e6c55c9SStella Laurenzo     builderArgs.push_back(name);
6628e6c55c9SStella Laurenzo     successorArgNames.push_back(name);
6638e6c55c9SStella Laurenzo   }
664c5a6712fSAlex Zinenko }
665c5a6712fSAlex Zinenko 
666*989d2b51SMatthias Springer /// Generates Python code for the default value of the given attribute.
getAttributeDefaultValue(Attribute attr)667*989d2b51SMatthias Springer static FailureOr<std::string> getAttributeDefaultValue(Attribute attr) {
668*989d2b51SMatthias Springer   assert(attr.hasDefaultValue() && "expected attribute with default value");
669*989d2b51SMatthias Springer   StringRef storageType = attr.getStorageType().trim();
670*989d2b51SMatthias Springer   StringRef defaultValCpp = attr.getDefaultValue().trim();
671*989d2b51SMatthias Springer 
672*989d2b51SMatthias Springer   // A list of commonly used attribute types and default values for which
673*989d2b51SMatthias Springer   // we can generate Python code. Extend as needed.
674*989d2b51SMatthias Springer   if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}"))
675*989d2b51SMatthias Springer     return std::string("_ods_ir.ArrayAttr.get([])");
676*989d2b51SMatthias Springer 
677*989d2b51SMatthias Springer   // No match: Cannot generate Python code.
678*989d2b51SMatthias Springer   return failure();
679*989d2b51SMatthias Springer }
680*989d2b51SMatthias Springer 
681c5a6712fSAlex Zinenko /// Populates `builderLines` with additional lines that are required in the
682c5a6712fSAlex Zinenko /// builder to set up operation attributes. `argNames` is expected to contain
683c5a6712fSAlex Zinenko /// the names of builder arguments that correspond to op arguments, i.e. to the
684c5a6712fSAlex Zinenko /// operands and attributes in the same order as they appear in the `arguments`
685c5a6712fSAlex Zinenko /// field.
686c5a6712fSAlex Zinenko static void
populateBuilderLinesAttr(const Operator & op,llvm::ArrayRef<std::string> argNames,llvm::SmallVectorImpl<std::string> & builderLines)687c5a6712fSAlex Zinenko populateBuilderLinesAttr(const Operator &op,
688c5a6712fSAlex Zinenko                          llvm::ArrayRef<std::string> argNames,
689c5a6712fSAlex Zinenko                          llvm::SmallVectorImpl<std::string> &builderLines) {
690c5a6712fSAlex Zinenko   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
691c5a6712fSAlex Zinenko     Argument arg = op.getArg(i);
692c5a6712fSAlex Zinenko     auto *attribute = arg.dyn_cast<NamedAttribute *>();
693c5a6712fSAlex Zinenko     if (!attribute)
694c5a6712fSAlex Zinenko       continue;
695c5a6712fSAlex Zinenko 
696c5a6712fSAlex Zinenko     // Unit attributes are handled specially.
697c5a6712fSAlex Zinenko     if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
698c5a6712fSAlex Zinenko       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
699c5a6712fSAlex Zinenko                                            attribute->name, argNames[i]));
700c5a6712fSAlex Zinenko       continue;
701c5a6712fSAlex Zinenko     }
702c5a6712fSAlex Zinenko 
703*989d2b51SMatthias Springer     // Attributes with default value are handled specially.
704*989d2b51SMatthias Springer     if (attribute->attr.hasDefaultValue()) {
705*989d2b51SMatthias Springer       // In case we cannot generate Python code for the default value, the
706*989d2b51SMatthias Springer       // attribute must be specified by the user.
707*989d2b51SMatthias Springer       FailureOr<std::string> defaultValPy =
708*989d2b51SMatthias Springer           getAttributeDefaultValue(attribute->attr);
709*989d2b51SMatthias Springer       if (succeeded(defaultValPy)) {
710*989d2b51SMatthias Springer         builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate,
711*989d2b51SMatthias Springer                                              attribute->name, argNames[i],
712*989d2b51SMatthias Springer                                              *defaultValPy));
713*989d2b51SMatthias Springer       } else {
714*989d2b51SMatthias Springer         builderLines.push_back(llvm::formatv(assertAttributeValueSpecified,
715*989d2b51SMatthias Springer                                              attribute->name, argNames[i]));
716*989d2b51SMatthias Springer         builderLines.push_back(
717*989d2b51SMatthias Springer             llvm::formatv(initAttributeTemplate, attribute->name, argNames[i]));
718*989d2b51SMatthias Springer       }
719*989d2b51SMatthias Springer       continue;
720*989d2b51SMatthias Springer     }
721*989d2b51SMatthias Springer 
722c5a6712fSAlex Zinenko     builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
723c5a6712fSAlex Zinenko                                              ? initOptionalAttributeTemplate
724c5a6712fSAlex Zinenko                                              : initAttributeTemplate,
725c5a6712fSAlex Zinenko                                          attribute->name, argNames[i]));
726c5a6712fSAlex Zinenko   }
727c5a6712fSAlex Zinenko }
728c5a6712fSAlex Zinenko 
729c5a6712fSAlex Zinenko /// Populates `builderLines` with additional lines that are required in the
7308e6c55c9SStella Laurenzo /// builder to set up successors. successorArgNames is expected to correspond
7318e6c55c9SStella Laurenzo /// to the Python argument name for each successor on the op.
populateBuilderLinesSuccessors(const Operator & op,llvm::ArrayRef<std::string> successorArgNames,llvm::SmallVectorImpl<std::string> & builderLines)7328e6c55c9SStella Laurenzo static void populateBuilderLinesSuccessors(
7338e6c55c9SStella Laurenzo     const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
7348e6c55c9SStella Laurenzo     llvm::SmallVectorImpl<std::string> &builderLines) {
7358e6c55c9SStella Laurenzo   if (successorArgNames.empty()) {
7368e6c55c9SStella Laurenzo     builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
7378e6c55c9SStella Laurenzo     return;
7388e6c55c9SStella Laurenzo   }
7398e6c55c9SStella Laurenzo 
7408e6c55c9SStella Laurenzo   builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
7418e6c55c9SStella Laurenzo   for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
7428e6c55c9SStella Laurenzo     auto &argName = successorArgNames[i];
7438e6c55c9SStella Laurenzo     const NamedSuccessor &successor = op.getSuccessor(i);
7448e6c55c9SStella Laurenzo     builderLines.push_back(
7458e6c55c9SStella Laurenzo         llvm::formatv(addSuccessorTemplate,
7468e6c55c9SStella Laurenzo                       successor.isVariadic() ? "extend" : "append", argName));
7478e6c55c9SStella Laurenzo   }
7488e6c55c9SStella Laurenzo }
7498e6c55c9SStella Laurenzo 
7508e6c55c9SStella Laurenzo /// Populates `builderLines` with additional lines that are required in the
751b164f23cSAlex Zinenko /// builder to set up op operands.
752b164f23cSAlex Zinenko static void
populateBuilderLinesOperand(const Operator & op,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines)753b164f23cSAlex Zinenko populateBuilderLinesOperand(const Operator &op,
754b164f23cSAlex Zinenko                             llvm::ArrayRef<std::string> names,
755b164f23cSAlex Zinenko                             llvm::SmallVectorImpl<std::string> &builderLines) {
756b164f23cSAlex Zinenko   bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
757f9265de8SAlex Zinenko 
758f9265de8SAlex Zinenko   // For each element, find or generate a name.
759b164f23cSAlex Zinenko   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
760b164f23cSAlex Zinenko     const NamedTypeConstraint &element = op.getOperand(i);
761c5a6712fSAlex Zinenko     std::string name = names[i];
762f9265de8SAlex Zinenko 
763f9265de8SAlex Zinenko     // Choose the formatting string based on the element kind.
76471b6b010SStella Laurenzo     llvm::StringRef formatString;
765f9265de8SAlex Zinenko     if (!element.isVariableLength()) {
766b164f23cSAlex Zinenko       formatString = singleOperandAppendTemplate;
767f9265de8SAlex Zinenko     } else if (element.isOptional()) {
7686981e5ecSAlex Zinenko       if (sizedSegments) {
7696981e5ecSAlex Zinenko         formatString = optionalAppendAttrSizedOperandsTemplate;
7706981e5ecSAlex Zinenko       } else {
771b164f23cSAlex Zinenko         formatString = optionalAppendOperandTemplate;
7726981e5ecSAlex Zinenko       }
773f9265de8SAlex Zinenko     } else {
774f9265de8SAlex Zinenko       assert(element.isVariadic() && "unhandled element group type");
775b164f23cSAlex Zinenko       // If emitting with sizedSegments, then we add the actual list-typed
776b164f23cSAlex Zinenko       // element. Otherwise, we extend the actual operands.
77771b6b010SStella Laurenzo       if (sizedSegments) {
778b164f23cSAlex Zinenko         formatString = multiOperandAppendPackTemplate;
77971b6b010SStella Laurenzo       } else {
780b164f23cSAlex Zinenko         formatString = multiOperandAppendTemplate;
78171b6b010SStella Laurenzo       }
782f9265de8SAlex Zinenko     }
783f9265de8SAlex Zinenko 
784b164f23cSAlex Zinenko     builderLines.push_back(llvm::formatv(formatString.data(), name));
785b164f23cSAlex Zinenko   }
786b164f23cSAlex Zinenko }
787b164f23cSAlex Zinenko 
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// The generated code binds `_ods_derived_result_type` as follows:
///   - if the attribute is a TypeAttr, to the type it wraps;
///   - otherwise, to the attribute's own `.type`.
/// It reads the attribute from the `attributes` dict, so it must be emitted
/// after the attribute-populating lines.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";

/// Python code template appending {0} type {1} times to the results list.
///   - {0} is a Python expression evaluating to a type;
///   - {1} is the number of results.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
8002995d29bSAlex Zinenko 
/// Python code template for inferring the operation results using the
/// corresponding interface:
///   - {0} is the op class name (the caller passes `op.getCppClassName()`).
/// The generated code replaces `results` with the inferred types. It passes
/// the already-populated `operands`, and `attributes` wrapped in a DictAttr,
/// so it must be emitted after the operand and attribute lines.
constexpr const char *inferTypeInterfaceTemplate =
    R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
    operands=operands,
    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
    context=_ods_context,
    loc=loc)
)PY";
8122995d29bSAlex Zinenko 
8132995d29bSAlex Zinenko /// Appends the given multiline string as individual strings into
8142995d29bSAlex Zinenko /// `builderLines`.
appendLineByLine(StringRef string,llvm::SmallVectorImpl<std::string> & builderLines)8152995d29bSAlex Zinenko static void appendLineByLine(StringRef string,
8162995d29bSAlex Zinenko                              llvm::SmallVectorImpl<std::string> &builderLines) {
8172995d29bSAlex Zinenko 
8182995d29bSAlex Zinenko   std::pair<StringRef, StringRef> split = std::make_pair(string, string);
8192995d29bSAlex Zinenko   do {
8202995d29bSAlex Zinenko     split = split.second.split('\n');
8212995d29bSAlex Zinenko     builderLines.push_back(split.first.str());
8222995d29bSAlex Zinenko   } while (!split.second.empty());
8232995d29bSAlex Zinenko }
8242995d29bSAlex Zinenko 
825b164f23cSAlex Zinenko /// Populates `builderLines` with additional lines that are required in the
826b164f23cSAlex Zinenko /// builder to set up op results.
827b164f23cSAlex Zinenko static void
populateBuilderLinesResult(const Operator & op,llvm::ArrayRef<std::string> names,llvm::SmallVectorImpl<std::string> & builderLines)828b164f23cSAlex Zinenko populateBuilderLinesResult(const Operator &op,
829b164f23cSAlex Zinenko                            llvm::ArrayRef<std::string> names,
830b164f23cSAlex Zinenko                            llvm::SmallVectorImpl<std::string> &builderLines) {
831b164f23cSAlex Zinenko   bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
832b164f23cSAlex Zinenko 
8332995d29bSAlex Zinenko   if (hasSameArgumentAndResultTypes(op)) {
8342995d29bSAlex Zinenko     builderLines.push_back(llvm::formatv(
8352995d29bSAlex Zinenko         appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
8362995d29bSAlex Zinenko     return;
8372995d29bSAlex Zinenko   }
8382995d29bSAlex Zinenko 
8392995d29bSAlex Zinenko   if (hasFirstAttrDerivedResultTypes(op)) {
8402995d29bSAlex Zinenko     const NamedAttribute &firstAttr = op.getAttribute(0);
8412995d29bSAlex Zinenko     assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
8422995d29bSAlex Zinenko                                       "from which the type is derived");
8432995d29bSAlex Zinenko     appendLineByLine(
8442995d29bSAlex Zinenko         llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
8452995d29bSAlex Zinenko         builderLines);
8462995d29bSAlex Zinenko     builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
8472995d29bSAlex Zinenko                                          "_ods_derived_result_type",
8482995d29bSAlex Zinenko                                          op.getNumResults()));
8492995d29bSAlex Zinenko     return;
8502995d29bSAlex Zinenko   }
8512995d29bSAlex Zinenko 
8522995d29bSAlex Zinenko   if (hasInferTypeInterface(op)) {
8532995d29bSAlex Zinenko     appendLineByLine(
8542995d29bSAlex Zinenko         llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
8552995d29bSAlex Zinenko         builderLines);
8562995d29bSAlex Zinenko     return;
8572995d29bSAlex Zinenko   }
8582995d29bSAlex Zinenko 
859b164f23cSAlex Zinenko   // For each element, find or generate a name.
860b164f23cSAlex Zinenko   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
861b164f23cSAlex Zinenko     const NamedTypeConstraint &element = op.getResult(i);
862b164f23cSAlex Zinenko     std::string name = names[i];
863b164f23cSAlex Zinenko 
864b164f23cSAlex Zinenko     // Choose the formatting string based on the element kind.
865b164f23cSAlex Zinenko     llvm::StringRef formatString;
866b164f23cSAlex Zinenko     if (!element.isVariableLength()) {
867b164f23cSAlex Zinenko       formatString = singleResultAppendTemplate;
868b164f23cSAlex Zinenko     } else if (element.isOptional()) {
869b164f23cSAlex Zinenko       formatString = optionalAppendResultTemplate;
870b164f23cSAlex Zinenko     } else {
871b164f23cSAlex Zinenko       assert(element.isVariadic() && "unhandled element group type");
872b164f23cSAlex Zinenko       // If emitting with sizedSegments, then we add the actual list-typed
873b164f23cSAlex Zinenko       // element. Otherwise, we extend the actual operands.
874b164f23cSAlex Zinenko       if (sizedSegments) {
875b164f23cSAlex Zinenko         formatString = singleResultAppendTemplate;
876b164f23cSAlex Zinenko       } else {
877b164f23cSAlex Zinenko         formatString = multiResultAppendTemplate;
878b164f23cSAlex Zinenko       }
879b164f23cSAlex Zinenko     }
880b164f23cSAlex Zinenko 
881b164f23cSAlex Zinenko     builderLines.push_back(llvm::formatv(formatString.data(), name));
882f9265de8SAlex Zinenko   }
883f9265de8SAlex Zinenko }
884f9265de8SAlex Zinenko 
88518fbd5feSAlex Zinenko /// If the operation has variadic regions, adds a builder argument to specify
88618fbd5feSAlex Zinenko /// the number of those regions and builder lines to forward it to the generic
88718fbd5feSAlex Zinenko /// constructor.
88818fbd5feSAlex Zinenko static void
populateBuilderRegions(const Operator & op,llvm::SmallVectorImpl<std::string> & builderArgs,llvm::SmallVectorImpl<std::string> & builderLines)88918fbd5feSAlex Zinenko populateBuilderRegions(const Operator &op,
89018fbd5feSAlex Zinenko                        llvm::SmallVectorImpl<std::string> &builderArgs,
89118fbd5feSAlex Zinenko                        llvm::SmallVectorImpl<std::string> &builderLines) {
89218fbd5feSAlex Zinenko   if (op.hasNoVariadicRegions())
89318fbd5feSAlex Zinenko     return;
89418fbd5feSAlex Zinenko 
89518fbd5feSAlex Zinenko   // This is currently enforced when Operator is constructed.
89618fbd5feSAlex Zinenko   assert(op.getNumVariadicRegions() == 1 &&
89718fbd5feSAlex Zinenko          op.getRegion(op.getNumRegions() - 1).isVariadic() &&
89818fbd5feSAlex Zinenko          "expected the last region to be varidic");
89918fbd5feSAlex Zinenko 
90018fbd5feSAlex Zinenko   const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
90118fbd5feSAlex Zinenko   std::string name =
90218fbd5feSAlex Zinenko       ("num_" + region.name.take_front().lower() + region.name.drop_front())
90318fbd5feSAlex Zinenko           .str();
90418fbd5feSAlex Zinenko   builderArgs.push_back(name);
90518fbd5feSAlex Zinenko   builderLines.push_back(
90618fbd5feSAlex Zinenko       llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
90718fbd5feSAlex Zinenko }
90818fbd5feSAlex Zinenko 
909f9265de8SAlex Zinenko /// Emits a default builder constructing an operation from the list of its
910f9265de8SAlex Zinenko /// result types, followed by a list of its operands.
emitDefaultOpBuilder(const Operator & op,raw_ostream & os)911f9265de8SAlex Zinenko static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
912f9265de8SAlex Zinenko   // If we are asked to skip default builders, comply.
913f9265de8SAlex Zinenko   if (op.skipDefaultBuilders())
914f9265de8SAlex Zinenko     return;
915f9265de8SAlex Zinenko 
9168e6c55c9SStella Laurenzo   llvm::SmallVector<std::string> builderArgs;
9178e6c55c9SStella Laurenzo   llvm::SmallVector<std::string> builderLines;
9188e6c55c9SStella Laurenzo   llvm::SmallVector<std::string> operandArgNames;
9198e6c55c9SStella Laurenzo   llvm::SmallVector<std::string> successorArgNames;
920c5a6712fSAlex Zinenko   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
9218e6c55c9SStella Laurenzo                       op.getNumNativeAttributes() + op.getNumSuccessors());
9222995d29bSAlex Zinenko   populateBuilderArgsResults(op, builderArgs);
9232995d29bSAlex Zinenko   size_t numResultArgs = builderArgs.size();
9248e6c55c9SStella Laurenzo   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
9259b79f50bSJeremy Furtek   size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
9269b79f50bSJeremy Furtek   populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
9278e6c55c9SStella Laurenzo 
928b164f23cSAlex Zinenko   populateBuilderLinesOperand(op, operandArgNames, builderLines);
929c5a6712fSAlex Zinenko   populateBuilderLinesAttr(
9302995d29bSAlex Zinenko       op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
9312995d29bSAlex Zinenko       builderLines);
9322995d29bSAlex Zinenko   populateBuilderLinesResult(
9332995d29bSAlex Zinenko       op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
934c5a6712fSAlex Zinenko       builderLines);
9358e6c55c9SStella Laurenzo   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
93618fbd5feSAlex Zinenko   populateBuilderRegions(op, builderArgs, builderLines);
937f9265de8SAlex Zinenko 
9389b79f50bSJeremy Furtek   // Layout of builderArgs vector elements:
9399b79f50bSJeremy Furtek   // [ result_args  operand_attr_args successor_args regions ]
9409b79f50bSJeremy Furtek 
9419b79f50bSJeremy Furtek   // Determine whether the argument corresponding to a given index into the
9429b79f50bSJeremy Furtek   // builderArgs vector is a python keyword argument or not.
9439b79f50bSJeremy Furtek   auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
9449b79f50bSJeremy Furtek     // All result, successor, and region arguments are positional arguments.
9459b79f50bSJeremy Furtek     if ((builderArgIndex < numResultArgs) ||
9469b79f50bSJeremy Furtek         (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
9479b79f50bSJeremy Furtek       return false;
9489b79f50bSJeremy Furtek     // Keyword arguments:
9499b79f50bSJeremy Furtek     // - optional named attributes (including unit attributes)
9509b79f50bSJeremy Furtek     // - default-valued named attributes
9519b79f50bSJeremy Furtek     // - optional operands
9529b79f50bSJeremy Furtek     Argument a = op.getArg(builderArgIndex - numResultArgs);
9539b79f50bSJeremy Furtek     if (auto *nattr = a.dyn_cast<NamedAttribute *>())
9549b79f50bSJeremy Furtek       return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
955eacfd047SMehdi Amini     if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
9569b79f50bSJeremy Furtek       return ntype->isOptional();
9579b79f50bSJeremy Furtek     return false;
9589b79f50bSJeremy Furtek   };
9599b79f50bSJeremy Furtek 
9609b79f50bSJeremy Furtek   // StringRefs in functionArgs refer to strings allocated by builderArgs.
9619b79f50bSJeremy Furtek   llvm::SmallVector<llvm::StringRef> functionArgs;
9629b79f50bSJeremy Furtek 
9639b79f50bSJeremy Furtek   // Add positional arguments.
9649b79f50bSJeremy Furtek   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
9659b79f50bSJeremy Furtek     if (!isKeywordArgFn(i))
9669b79f50bSJeremy Furtek       functionArgs.push_back(builderArgs[i]);
9679b79f50bSJeremy Furtek   }
9689b79f50bSJeremy Furtek 
9699b79f50bSJeremy Furtek   // Add a bare '*' to indicate that all following arguments must be keyword
9709b79f50bSJeremy Furtek   // arguments.
9719b79f50bSJeremy Furtek   functionArgs.push_back("*");
9729b79f50bSJeremy Furtek 
9739b79f50bSJeremy Furtek   // Add a default 'None' value to each keyword arg string, and then add to the
9749b79f50bSJeremy Furtek   // function args list.
9759b79f50bSJeremy Furtek   for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
9769b79f50bSJeremy Furtek     if (isKeywordArgFn(i)) {
9779b79f50bSJeremy Furtek       builderArgs[i].append("=None");
9789b79f50bSJeremy Furtek       functionArgs.push_back(builderArgs[i]);
9799b79f50bSJeremy Furtek     }
9809b79f50bSJeremy Furtek   }
9819b79f50bSJeremy Furtek   functionArgs.push_back("loc=None");
9829b79f50bSJeremy Furtek   functionArgs.push_back("ip=None");
9839b79f50bSJeremy Furtek   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
984f9265de8SAlex Zinenko                       llvm::join(builderLines, "\n    "));
985fd407e1fSAlex Zinenko }
986fd407e1fSAlex Zinenko 
constructAttributeMapping(const llvm::RecordKeeper & records,AttributeClasses & attributeClasses)987c5a6712fSAlex Zinenko static void constructAttributeMapping(const llvm::RecordKeeper &records,
988c5a6712fSAlex Zinenko                                       AttributeClasses &attributeClasses) {
989c5a6712fSAlex Zinenko   for (const llvm::Record *rec :
990c5a6712fSAlex Zinenko        records.getAllDerivedDefinitions("PythonAttr")) {
991c5a6712fSAlex Zinenko     attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
992c5a6712fSAlex Zinenko                                  rec->getValueAsString("pythonType").trim());
993c5a6712fSAlex Zinenko   }
994c5a6712fSAlex Zinenko }
995c5a6712fSAlex Zinenko 
emitSegmentSpec(const Operator & op,const char * kind,llvm::function_ref<int (const Operator &)> getNumElements,llvm::function_ref<const NamedTypeConstraint & (const Operator &,int)> getElement,raw_ostream & os)99671b6b010SStella Laurenzo static void emitSegmentSpec(
99771b6b010SStella Laurenzo     const Operator &op, const char *kind,
99871b6b010SStella Laurenzo     llvm::function_ref<int(const Operator &)> getNumElements,
99971b6b010SStella Laurenzo     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
100071b6b010SStella Laurenzo         getElement,
100171b6b010SStella Laurenzo     raw_ostream &os) {
100271b6b010SStella Laurenzo   std::string segmentSpec("[");
100371b6b010SStella Laurenzo   for (int i = 0, e = getNumElements(op); i < e; ++i) {
100471b6b010SStella Laurenzo     const NamedTypeConstraint &element = getElement(op, i);
10056981e5ecSAlex Zinenko     if (element.isOptional()) {
100671b6b010SStella Laurenzo       segmentSpec.append("0,");
10076981e5ecSAlex Zinenko     } else if (element.isVariadic()) {
10086981e5ecSAlex Zinenko       segmentSpec.append("-1,");
100971b6b010SStella Laurenzo     } else {
101071b6b010SStella Laurenzo       segmentSpec.append("1,");
101171b6b010SStella Laurenzo     }
101271b6b010SStella Laurenzo   }
101371b6b010SStella Laurenzo   segmentSpec.append("]");
101471b6b010SStella Laurenzo 
101571b6b010SStella Laurenzo   os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
101671b6b010SStella Laurenzo }
101771b6b010SStella Laurenzo 
emitRegionAttributes(const Operator & op,raw_ostream & os)101871b6b010SStella Laurenzo static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
101971b6b010SStella Laurenzo   // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
102071b6b010SStella Laurenzo   // Note that the base OpView class defines this as (0, True).
102171b6b010SStella Laurenzo   unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
102271b6b010SStella Laurenzo   os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
102371b6b010SStella Laurenzo                       op.hasNoVariadicRegions() ? "True" : "False");
102471b6b010SStella Laurenzo }
102571b6b010SStella Laurenzo 
102618fbd5feSAlex Zinenko /// Emits named accessors to regions.
emitRegionAccessors(const Operator & op,raw_ostream & os)102718fbd5feSAlex Zinenko static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
102889de9cc8SMehdi Amini   for (const auto &en : llvm::enumerate(op.getRegions())) {
102918fbd5feSAlex Zinenko     const NamedRegion &region = en.value();
103018fbd5feSAlex Zinenko     if (region.name.empty())
103118fbd5feSAlex Zinenko       continue;
103218fbd5feSAlex Zinenko 
103318fbd5feSAlex Zinenko     assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
103418fbd5feSAlex Zinenko            "expected only the last region to be variadic");
103518fbd5feSAlex Zinenko     os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
103618fbd5feSAlex Zinenko                         std::to_string(en.index()) +
103718fbd5feSAlex Zinenko                             (region.isVariadic() ? ":" : ""));
103818fbd5feSAlex Zinenko   }
103918fbd5feSAlex Zinenko }
104018fbd5feSAlex Zinenko 
1041fd407e1fSAlex Zinenko /// Emits bindings for a specific Op to the given output stream.
emitOpBindings(const Operator & op,const AttributeClasses & attributeClasses,raw_ostream & os)1042c5a6712fSAlex Zinenko static void emitOpBindings(const Operator &op,
1043c5a6712fSAlex Zinenko                            const AttributeClasses &attributeClasses,
1044c5a6712fSAlex Zinenko                            raw_ostream &os) {
1045fd407e1fSAlex Zinenko   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
1046fd407e1fSAlex Zinenko                       op.getOperationName());
104771b6b010SStella Laurenzo 
104871b6b010SStella Laurenzo   // Sized segments.
104971b6b010SStella Laurenzo   if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
105071b6b010SStella Laurenzo     emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
105171b6b010SStella Laurenzo   }
105271b6b010SStella Laurenzo   if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
105371b6b010SStella Laurenzo     emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
105471b6b010SStella Laurenzo   }
105571b6b010SStella Laurenzo 
105671b6b010SStella Laurenzo   emitRegionAttributes(op, os);
1057f9265de8SAlex Zinenko   emitDefaultOpBuilder(op, os);
1058fd407e1fSAlex Zinenko   emitOperandAccessors(op, os);
1059c5a6712fSAlex Zinenko   emitAttributeAccessors(op, attributeClasses, os);
1060fd407e1fSAlex Zinenko   emitResultAccessors(op, os);
106118fbd5feSAlex Zinenko   emitRegionAccessors(op, os);
1062fd407e1fSAlex Zinenko }
1063fd407e1fSAlex Zinenko 
1064fd407e1fSAlex Zinenko /// Emits bindings for the dialect specified in the command line, including file
1065fd407e1fSAlex Zinenko /// headers and utilities. Returns `false` on success to comply with Tablegen
1066fd407e1fSAlex Zinenko /// registration requirements.
emitAllOps(const llvm::RecordKeeper & records,raw_ostream & os)1067fd407e1fSAlex Zinenko static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1068fd407e1fSAlex Zinenko   if (clDialectName.empty())
1069fd407e1fSAlex Zinenko     llvm::PrintFatalError("dialect name not provided");
1070fd407e1fSAlex Zinenko 
1071c5a6712fSAlex Zinenko   AttributeClasses attributeClasses;
1072c5a6712fSAlex Zinenko   constructAttributeMapping(records, attributeClasses);
1073c5a6712fSAlex Zinenko 
10743f71765aSAlex Zinenko   bool isExtension = !clDialectExtensionName.empty();
10753f71765aSAlex Zinenko   os << llvm::formatv(fileHeader, isExtension
10763f71765aSAlex Zinenko                                       ? clDialectExtensionName.getValue()
10773f71765aSAlex Zinenko                                       : clDialectName.getValue());
10783f71765aSAlex Zinenko   if (isExtension)
10793f71765aSAlex Zinenko     os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
10803f71765aSAlex Zinenko   else
1081fd407e1fSAlex Zinenko     os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
1082922b26cdSMehdi Amini 
1083fd407e1fSAlex Zinenko   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
1084fd407e1fSAlex Zinenko     Operator op(rec);
1085fd407e1fSAlex Zinenko     if (op.getDialectName() == clDialectName.getValue())
1086c5a6712fSAlex Zinenko       emitOpBindings(op, attributeClasses, os);
1087fd407e1fSAlex Zinenko   }
1088fd407e1fSAlex Zinenko   return false;
1089fd407e1fSAlex Zinenko }
1090fd407e1fSAlex Zinenko 
1091fd407e1fSAlex Zinenko static GenRegistration
1092fd407e1fSAlex Zinenko     genPythonBindings("gen-python-op-bindings",
1093fd407e1fSAlex Zinenko                       "Generate Python bindings for MLIR Ops", &emitAllOps);
1094