1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
/// File header and Python imports emitted at the top of every generated
/// module. The generated code imports `_cext` and the `_segmented_accessor` /
/// `_equally_sized_accessor` helpers from its own Python package (relative
/// `from .` imports), and aliases `_cext.ir` as `_ir`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from . import _cext
from . import _segmented_accessor, _equally_sized_accessor
_ir = _cext.ir
)Py";
33 
/// Template for the dialect class, registered through
/// `_cext.register_dialect`:
///   {0} is the dialect namespace.
/// Op classes produced from `opClassTemplate` name `_Dialect` in their
/// decorator, so this class has to appear in the module before any of them.
constexpr const char *dialectClassTemplate = R"Py(
@_cext.register_dialect
class _Dialect(_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";
43 
/// Template for an operation class header, registered with `_Dialect`:
///   {0} is the Python class name;
///   {1} is the operation name.
/// Operand and result accessors (two-space indented class members) are
/// appended directly after this header.
constexpr const char *opClassTemplate = R"Py(
@_cext.register_operation(_Dialect)
class {0}(_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
52 
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Only valid while no variable-length group precedes the element, so that its
/// position in the flat element list is known statically.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";
62 
/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
/// The variable-length group's size is the total element count minus the
/// ({2} - 1) other, single-element groups. The element index is then {3},
/// shifted by that size minus the one group slot it replaces.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + variadic_group_length - 1]
)Py";
76 
/// Template for an optional element accessor, used when the optional element
/// is the only variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// When the optional element is present, the element list holds exactly {2}
/// entries (one per group). When it is absent, the list is one entry shorter
/// and the accessor yields None.
/// The conditional expression must stay on a single line: Python does not
/// allow it to be split across lines outside of brackets.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
88 
/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Every other group holds exactly one element, so the variadic group's length
/// is the total element count minus ({2} - 1). The accessor returns the slice
/// of that length starting at {3}.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + variadic_group_length]
)Py";
100 
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// Binds the Python locals `start` and `pg`. The second part of the template
/// uses them as the group's first index and length (`[start:start + pg]`) and
/// must be emitted immediately after this one. This part deliberately ends
/// without a trailing newline because the second part begins with one.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
111 
/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
/// Reads the `start` local bound by the first part of the template, which must
/// be emitted right before this one.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";
118 
/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
/// Reads the `start` and `pg` locals bound by the first part of the template,
/// which must be emitted right before this one.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
125 
/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix: `[0]` for a single element, empty for a variadic
///       group, and the expansion of opVariadicSegmentOptionalTrailingTemplate
///       for an optional element.
/// Group sizes come from the `{1}_segment_sizes` attribute of the operation
/// (e.g. `operand_segment_sizes`).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";
140 
/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result'.
/// Appended to the `{0}_range` local of opVariadicSegmentTemplate. It yields
/// the single element, or None when the segment is empty (element absent).
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
146 
147 static llvm::cl::OptionCategory
148     clOpPythonBindingCat("Options for -gen-python-op-bindings");
149 
150 static llvm::cl::opt<std::string>
151     clDialectName("bind-dialect",
152                   llvm::cl::desc("The dialect to run the generator for"),
153                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
154 
155 /// Checks whether `str` is a Python keyword.
156 static bool isPythonKeyword(StringRef str) {
157   static llvm::StringSet<> keywords(
158       {"and",   "as",     "assert",   "break", "class",  "continue",
159        "def",   "del",    "elif",     "else",  "except", "finally",
160        "for",   "from",   "global",   "if",    "import", "in",
161        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
162        "raise", "return", "try",      "while", "with",   "yield"});
163   return keywords.contains(str);
164 };
165 
166 /// Modifies the `name` in a way that it becomes suitable for Python bindings
167 /// (does not change the `name` if it already is suitable) and returns the
168 /// modified version.
169 static std::string sanitizeName(StringRef name) {
170   if (isPythonKeyword(name))
171     return (name + "_").str();
172   return name.str();
173 }
174 
175 /// Emits accessors to "elements" of an Op definition. Currently, the supported
176 /// elements are operands and results, indicated by `kind`, which must be either
177 /// `operand` or `result` and is used verbatim in the emitted code.
178 static void emitElementAccessors(
179     const Operator &op, raw_ostream &os, const char *kind,
180     llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
181     llvm::function_ref<int(const Operator &)> getNumElements,
182     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
183         getElement) {
184   assert(llvm::is_contained(
185              llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
186          "unsupported kind");
187 
188   // Traits indicating how to process variadic elements.
189   std::string sameSizeTrait =
190       llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
191                     llvm::StringRef(kind).take_front().upper(),
192                     llvm::StringRef(kind).drop_front());
193   std::string attrSizedTrait =
194       llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
195                     llvm::StringRef(kind).take_front().upper(),
196                     llvm::StringRef(kind).drop_front());
197 
198   unsigned numVariadic = getNumVariadic(op);
199 
200   // If there is only one variadic element group, its size can be inferred from
201   // the total number of elements. If there are none, the generation is
202   // straightforward.
203   if (numVariadic <= 1) {
204     bool seenVariableLength = false;
205     for (int i = 0, e = getNumElements(op); i < e; ++i) {
206       const NamedTypeConstraint &element = getElement(op, i);
207       if (element.isVariableLength())
208         seenVariableLength = true;
209       if (element.name.empty())
210         continue;
211       if (element.isVariableLength()) {
212         os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
213                                                  : opOneVariadicTemplate,
214                             sanitizeName(element.name), kind,
215                             getNumElements(op), i);
216       } else if (seenVariableLength) {
217         os << llvm::formatv(opSingleAfterVariableTemplate,
218                             sanitizeName(element.name), kind,
219                             getNumElements(op), i);
220       } else {
221         os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
222                             i);
223       }
224     }
225     return;
226   }
227 
228   // Handle the operations where variadic groups have the same size.
229   if (op.getTrait(sameSizeTrait)) {
230     int numPrecedingSimple = 0;
231     int numPrecedingVariadic = 0;
232     for (int i = 0, e = getNumElements(op); i < e; ++i) {
233       const NamedTypeConstraint &element = getElement(op, i);
234       if (!element.name.empty()) {
235         os << llvm::formatv(opVariadicEqualPrefixTemplate,
236                             sanitizeName(element.name), kind, numVariadic,
237                             numPrecedingSimple, numPrecedingVariadic);
238         os << llvm::formatv(element.isVariableLength()
239                                 ? opVariadicEqualVariadicTemplate
240                                 : opVariadicEqualSimpleTemplate,
241                             kind);
242       }
243       if (element.isVariableLength())
244         ++numPrecedingVariadic;
245       else
246         ++numPrecedingSimple;
247     }
248     return;
249   }
250 
251   // Handle the operations where the size of groups (variadic or not) is
252   // provided as an attribute. For non-variadic elements, make sure to return
253   // an element rather than a singleton container.
254   if (op.getTrait(attrSizedTrait)) {
255     for (int i = 0, e = getNumElements(op); i < e; ++i) {
256       const NamedTypeConstraint &element = getElement(op, i);
257       if (element.name.empty())
258         continue;
259       std::string trailing;
260       if (!element.isVariableLength())
261         trailing = "[0]";
262       else if (element.isOptional())
263         trailing = std::string(
264             llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
265       os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
266                           kind, i, trailing);
267     }
268     return;
269   }
270 
271   llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
272 }
273 
274 /// Emits accessor to Op operands.
275 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
276   auto getNumVariadic = [](const Operator &oper) {
277     return oper.getNumVariableLengthOperands();
278   };
279   auto getNumElements = [](const Operator &oper) {
280     return oper.getNumOperands();
281   };
282   auto getElement = [](const Operator &oper,
283                        int i) -> const NamedTypeConstraint & {
284     return oper.getOperand(i);
285   };
286   emitElementAccessors(op, os, "operand", getNumVariadic, getNumElements,
287                        getElement);
288 }
289 
290 /// Emits access or Op results.
291 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
292   auto getNumVariadic = [](const Operator &oper) {
293     return oper.getNumVariableLengthResults();
294   };
295   auto getNumElements = [](const Operator &oper) {
296     return oper.getNumResults();
297   };
298   auto getElement = [](const Operator &oper,
299                        int i) -> const NamedTypeConstraint & {
300     return oper.getResult(i);
301   };
302   emitElementAccessors(op, os, "result", getNumVariadic, getNumElements,
303                        getElement);
304 }
305 
306 /// Emits bindings for a specific Op to the given output stream.
307 static void emitOpBindings(const Operator &op, raw_ostream &os) {
308   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
309                       op.getOperationName());
310   emitOperandAccessors(op, os);
311   emitResultAccessors(op, os);
312 }
313 
314 /// Emits bindings for the dialect specified in the command line, including file
315 /// headers and utilities. Returns `false` on success to comply with Tablegen
316 /// registration requirements.
317 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
318   if (clDialectName.empty())
319     llvm::PrintFatalError("dialect name not provided");
320 
321   os << fileHeader;
322   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
323   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
324     Operator op(rec);
325     if (op.getDialectName() == clDialectName.getValue())
326       emitOpBindings(op, os);
327   }
328   return false;
329 }
330 
331 static GenRegistration
332     genPythonBindings("gen-python-op-bindings",
333                       "Generate Python bindings for MLIR Ops", &emitAllOps);
334