1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
/// Header emitted once at the top of every generated Python module: a
/// "do not edit" banner followed by the imports the generated code relies on.
/// `array` backs the segment-size buffers built by the default builders, the
/// two accessor helpers are imported from the enclosing package, and `_ir` is
/// a short alias for the native `_cext.ir` module.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

import array
from . import _cext
from . import _segmented_accessor, _equally_sized_accessor
_ir = _cext.ir
)Py";
34 
/// Template for the dialect class, registered with the native extension via
/// the `@_cext.register_dialect` decorator. Op classes emitted later refer to
/// it as `_Dialect`.
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_cext.register_dialect
class _Dialect(_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";
44 
/// Template for the header of an operation class. The class derives from
/// `_ir.OpView` and is registered under `_Dialect`; the builder and the
/// element accessors are appended after this header.
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_cext.register_operation(_Dialect)
class {0}(_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
53 
/// Template for single-element accessor whose index is known at generation
/// time, i.e. no variable-length group precedes the element:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";
63 
/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The length of the (only) variable-length group is the actual number of
/// elements minus the {2} - 1 fixed-size groups. The element's index is its
/// group position shifted by that length minus one. This works for both a
/// single variadic group (non-negative length) and a single optional element
/// (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + variadic_group_length - 1]
)Py";
77 
/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Used when the optional element is the only variable-length group. The
/// element list then holds exactly {2} entries when the element is present
/// and {2} - 1 entries when it is absent. The accessor returns None in the
/// latter case.
/// The conditional expression must stay on one line: the generated Python has
/// no enclosing parentheses that would allow an implicit line continuation.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
89 
/// Template for the variadic group accessor in the single variadic group case.
/// The group's length is the actual number of elements minus the {2} - 1
/// fixed-size groups, and the returned slice starts at the group's position:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + variadic_group_length]
)Py";
101 
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The emitted code binds `start` and `pg`, which the second part of the
/// template uses to index or slice the element list. The template has no
/// trailing newline on purpose: the second part starts with one.
/// The element list must be reached through `self.operation`. A bare
/// `operation` is not defined inside the generated property.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
112 
/// Second part of the template for equally-sized case, accessing a single
/// (non-variadic) element at the `start` index bound by the first part:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";
119 
/// Second part of the template for equally-sized case, accessing a variadic
/// group: returns the `pg` elements that begin at `start`, both bound by the
/// first part:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
126 
/// Template for an attribute-sized group accessor. The group is selected with
/// `_segmented_accessor` using the `<kind>_segment_sizes` attribute:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a suffix appended to the selected range: `[0]` for a
///       single-element group, empty for a variadic group, and the expansion
///       of opVariadicSegmentOptionalTrailingTemplate for an optional element.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";
141 
/// Template for a suffix when accessing an optional element in the
/// attribute-sized case. Appended to `<kind>_range`, it yields the first
/// element of the segment, or None when the segment is empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
147 
148 static llvm::cl::OptionCategory
149     clOpPythonBindingCat("Options for -gen-python-op-bindings");
150 
151 static llvm::cl::opt<std::string>
152     clDialectName("bind-dialect",
153                   llvm::cl::desc("The dialect to run the generator for"),
154                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
155 
156 /// Checks whether `str` is a Python keyword.
157 static bool isPythonKeyword(StringRef str) {
158   static llvm::StringSet<> keywords(
159       {"and",   "as",     "assert",   "break", "class",  "continue",
160        "def",   "del",    "elif",     "else",  "except", "finally",
161        "for",   "from",   "global",   "if",    "import", "in",
162        "is",    "lambda", "nonlocal", "not",   "or",     "pass",
163        "raise", "return", "try",      "while", "with",   "yield"});
164   return keywords.contains(str);
165 };
166 
167 /// Modifies the `name` in a way that it becomes suitable for Python bindings
168 /// (does not change the `name` if it already is suitable) and returns the
169 /// modified version.
170 static std::string sanitizeName(StringRef name) {
171   if (isPythonKeyword(name))
172     return (name + "_").str();
173   return name.str();
174 }
175 
176 static std::string attrSizedTraitForKind(const char *kind) {
177   return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
178                        llvm::StringRef(kind).take_front().upper(),
179                        llvm::StringRef(kind).drop_front());
180 }
181 
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
///
/// The callbacks keep this function independent of the element kind:
/// `getNumVariadic` returns the number of variable-length element groups,
/// `getNumElements` the total number of groups, and `getElement` the ODS
/// constraint of the group at a given position. A Python `@property` is
/// emitted for every named group. Unnamed groups get no accessor but still
/// occupy a position.
///
/// Supported layouts:
///   - At most one variable-length group: positions are computed in Python
///     from the actual number of elements.
///   - Several variable-length groups plus the `SameVariadic<Kind>Size` trait:
///     delegates to `_equally_sized_accessor`.
///   - Several variable-length groups plus the `AttrSized<Kind>Segments`
///     trait: delegates to `_segmented_accessor` with the
///     `<kind>_segment_sizes` attribute.
/// Any other structure is reported through llvm::PrintFatalError.
static void emitElementAccessors(
    const Operator &op, raw_ostream &os, const char *kind,
    llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
    llvm::function_ref<int(const Operator &)> getNumElements,
    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
        getElement) {
  assert(llvm::is_contained(
             llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
         "unsupported kind");

  // Traits indicating how to process variadic elements.
  std::string sameSizeTrait =
      llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
                    llvm::StringRef(kind).take_front().upper(),
                    llvm::StringRef(kind).drop_front());
  std::string attrSizedTrait = attrSizedTraitForKind(kind);

  unsigned numVariadic = getNumVariadic(op);

  // If there is only one variadic element group, its size can be inferred from
  // the total number of elements. If there are none, the generation is
  // straightforward.
  if (numVariadic <= 1) {
    // Becomes true once a variable-length group has been passed. Any later
    // single element no longer has a fixed index and is located relative to
    // the actual number of elements instead.
    bool seenVariableLength = false;
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (element.isVariableLength())
        seenVariableLength = true;
      // Unnamed groups get no accessor, but were still taken into account
      // above because they shift the positions of later groups.
      if (element.name.empty())
        continue;
      if (element.isVariableLength()) {
        os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
                                                 : opOneVariadicTemplate,
                            sanitizeName(element.name), kind,
                            getNumElements(op), i);
      } else if (seenVariableLength) {
        os << llvm::formatv(opSingleAfterVariableTemplate,
                            sanitizeName(element.name), kind,
                            getNumElements(op), i);
      } else {
        os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
                            i);
      }
    }
    return;
  }

  // Handle the operations where variadic groups have the same size.
  if (op.getTrait(sameSizeTrait)) {
    // Number of fixed-size and variable-length groups before the current one;
    // both are passed to the Python-side `_equally_sized_accessor`.
    int numPrecedingSimple = 0;
    int numPrecedingVariadic = 0;
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (!element.name.empty()) {
        os << llvm::formatv(opVariadicEqualPrefixTemplate,
                            sanitizeName(element.name), kind, numVariadic,
                            numPrecedingSimple, numPrecedingVariadic);
        os << llvm::formatv(element.isVariableLength()
                                ? opVariadicEqualVariadicTemplate
                                : opVariadicEqualSimpleTemplate,
                            kind);
      }
      // Count every group, named or not, so later offsets stay correct.
      if (element.isVariableLength())
        ++numPrecedingVariadic;
      else
        ++numPrecedingSimple;
    }
    return;
  }

  // Handle the operations where the size of groups (variadic or not) is
  // provided as an attribute. For non-variadic elements, make sure to return
  // an element rather than a singleton container.
  if (op.getTrait(attrSizedTrait)) {
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (element.name.empty())
        continue;
      // Suffix applied to the selected segment:
      // - `[0]` unwraps a single element.
      // - The optional-trailing expression yields the element or None.
      // - Variadic groups keep the empty suffix and return the segment as is.
      std::string trailing;
      if (!element.isVariableLength())
        trailing = "[0]";
      else if (element.isOptional())
        trailing = std::string(
            llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
      os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
                          kind, i, trailing);
    }
    return;
  }

  llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
}
277 
278 /// Free function helpers accessing Operator components.
279 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
280 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
281   return op.getOperand(i);
282 }
283 static int getNumResults(const Operator &op) { return op.getNumResults(); }
284 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
285   return op.getResult(i);
286 }
287 
288 /// Emits accessor to Op operands.
289 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
290   auto getNumVariadic = [](const Operator &oper) {
291     return oper.getNumVariableLengthOperands();
292   };
293   emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
294                        getOperand);
295 }
296 
297 /// Emits access or Op results.
298 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
299   auto getNumVariadic = [](const Operator &oper) {
300     return oper.getNumVariableLengthResults();
301   };
302   emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
303                        getResult);
304 }
305 
/// Template for the default auto-generated builder.
///   {0} is the operation name;
///   {1} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {2} is the code populating `operands`, `results` and `attributes` fields.
/// `{{` is formatv's escape for a literal `{`, so `{{}` expands to `{}` and the
/// generated code reads `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {1}):
    operands = []
    results = []
    attributes = {{}
    {2}
    super().__init__(_ir.Operation.create(
      "{0}", attributes=attributes, operands=operands, results=results,
      loc=loc, ip=ip))
)Py";
321 
/// Template for appending a single (non-variable-length) element to the
/// operand/result list of the builder.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";
326 
/// Template for appending an optional element to the operand/result list; the
/// element is skipped when the builder argument is None.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalAppendTemplate =
    "if {1} is not None: {0}s.append({1})";
332 
/// Template for appending a variadic element to the operand/result list; the
/// builder argument is unpacked, so any iterable is accepted.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";
337 
/// Template for setting up the segment sizes buffer.
///   {0} is either 'operand' or 'result'.
/// NOTE(review): Python typecode 'L' is C `unsigned long`, whose width is
/// platform-dependent (e.g. 4 bytes on Windows, 8 on LP64 Linux/macOS).
/// Confirm it matches the element type expected for `*_segment_sizes`.
constexpr const char *segmentDeclarationTemplate =
    "{0}_segment_sizes = array.array('L')";
341 
/// Template for attaching segment sizes to the attribute list.
///   {0} is either 'operand' or 'result'.
/// The attribute is created in the context of `loc` when one is given, and in
/// the context of the current location otherwise. `Location` must be
/// qualified with `_ir`: the generated module only imports `_cext` and `_ir`,
/// so a bare `Location` raises NameError when `loc` is None.
constexpr const char *segmentAttributeTemplate =
    R"Py(attributes["{0}_segment_sizes"] = _ir.DenseElementsAttr.get({0}_segment_sizes,
      context=_ir.Location.current.context if loc is None else loc.context))Py";
346 
/// Template for appending the unit size to the segment sizes. A trailing
/// Python comment records which field the entry belongs to.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *singleElementSegmentTemplate =
    "{0}_segment_sizes.append(1) # {1}";
352 
/// Template for appending 0/1 for an optional element to the segment sizes:
/// 0 when the builder argument is None, 1 otherwise.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalSegmentTemplate =
    "{0}_segment_sizes.append(0 if {1} is None else 1)";
358 
/// Template for appending the length of a variadic group to the segment sizes;
/// the builder argument must therefore support `len()`.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *variadicSegmentTemplate =
    "{0}_segment_sizes.append(len({1}))";
364 
365 /// Populates `builderArgs` with the list of `__init__` arguments that
366 /// correspond to either operands or results of `op`, and `builderLines` with
367 /// additional lines that are required in the builder. `kind` must be either
368 /// "operand" or "result". `unnamedTemplate` is used to generate names for
369 /// operands or results that don't have the name in ODS.
370 static void populateBuilderLines(
371     const Operator &op, const char *kind, const char *unnamedTemplate,
372     llvm::SmallVectorImpl<std::string> &builderArgs,
373     llvm::SmallVectorImpl<std::string> &builderLines,
374     llvm::function_ref<int(const Operator &)> getNumElements,
375     llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
376         getElement) {
377   // The segment sizes buffer only has to be populated if there attr-sized
378   // segments trait is present.
379   bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr;
380   if (includeSegments)
381     builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind));
382 
383   // For each element, find or generate a name.
384   for (int i = 0, e = getNumElements(op); i < e; ++i) {
385     const NamedTypeConstraint &element = getElement(op, i);
386     std::string name = element.name.str();
387     if (name.empty())
388       name = llvm::formatv(unnamedTemplate, i).str();
389     name = sanitizeName(name);
390     builderArgs.push_back(name);
391 
392     // Choose the formatting string based on the element kind.
393     llvm::StringRef formatString, segmentFormatString;
394     if (!element.isVariableLength()) {
395       formatString = singleElementAppendTemplate;
396       segmentFormatString = singleElementSegmentTemplate;
397     } else if (element.isOptional()) {
398       formatString = optionalAppendTemplate;
399       segmentFormatString = optionalSegmentTemplate;
400     } else {
401       assert(element.isVariadic() && "unhandled element group type");
402       formatString = variadicAppendTemplate;
403       segmentFormatString = variadicSegmentTemplate;
404     }
405 
406     // Add the lines.
407     builderLines.push_back(llvm::formatv(formatString.data(), kind, name));
408     if (includeSegments)
409       builderLines.push_back(
410           llvm::formatv(segmentFormatString.data(), kind, name));
411   }
412 
413   if (includeSegments)
414     builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind));
415 }
416 
417 /// Emits a default builder constructing an operation from the list of its
418 /// result types, followed by a list of its operands.
419 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
420   // TODO: support attribute types.
421   if (op.getNumNativeAttributes() != 0)
422     return;
423 
424   // If we are asked to skip default builders, comply.
425   if (op.skipDefaultBuilders())
426     return;
427 
428   llvm::SmallVector<std::string, 8> builderArgs;
429   llvm::SmallVector<std::string, 8> builderLines;
430   builderArgs.reserve(op.getNumOperands() + op.getNumResults());
431   populateBuilderLines(op, "result", "_gen_res_{0}", builderArgs, builderLines,
432                        getNumResults, getResult);
433   populateBuilderLines(op, "operand", "_gen_arg_{0}", builderArgs, builderLines,
434                        getNumOperands, getOperand);
435 
436   builderArgs.push_back("loc=None");
437   builderArgs.push_back("ip=None");
438   os << llvm::formatv(initTemplate, op.getOperationName(),
439                       llvm::join(builderArgs, ", "),
440                       llvm::join(builderLines, "\n    "));
441 }
442 
443 /// Emits bindings for a specific Op to the given output stream.
444 static void emitOpBindings(const Operator &op, raw_ostream &os) {
445   os << llvm::formatv(opClassTemplate, op.getCppClassName(),
446                       op.getOperationName());
447   emitDefaultOpBuilder(op, os);
448   emitOperandAccessors(op, os);
449   emitResultAccessors(op, os);
450 }
451 
452 /// Emits bindings for the dialect specified in the command line, including file
453 /// headers and utilities. Returns `false` on success to comply with Tablegen
454 /// registration requirements.
455 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
456   if (clDialectName.empty())
457     llvm::PrintFatalError("dialect name not provided");
458 
459   os << fileHeader;
460   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
461   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
462     Operator op(rec);
463     if (op.getDialectName() == clDialectName.getValue())
464       emitOpBindings(op, os);
465   }
466   return false;
467 }
468 
469 static GenRegistration
470     genPythonBindings("gen-python-op-bindings",
471                       "Generate Python bindings for MLIR Ops", &emitAllOps);
472