//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// Preamble emitted once at the top of every generated Python module: an
/// "autogenerated" banner followed by the imports the generated code uses
/// (`array` for segment-size buffers, the native `_cext` extension bound as
/// `_ir`, and the accessor helpers for variadic groups).
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

import array
from . import _cext
from . import _segmented_accessor, _equally_sized_accessor
_ir = _cext.ir
)Py";

/// Template for the dialect class, registered through the
/// `_cext.register_dialect` decorator:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_cext.register_dialect
class _Dialect(_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for an operation class, registered against `_Dialect`. Accessors
/// and the builder are appended after it with class-member indentation:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_cext.register_operation(_Dialect)
class {0}(_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for single-element accessor (the element's position is known
/// statically because no variable-length group precedes it):
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Only valid when the optional element is the sole variable-length group:
/// every other group contributes exactly one element, so the optional element
/// is absent iff there are fewer elements than groups. The whole expression is
/// kept on one Python line (an unparenthesized conditional cannot be split).
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The emitted code binds `start` and `pg`, which the second part of the
/// template consumes; the string intentionally ends without a newline because
/// both second parts start with one.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group (`pg` elements starting at `start`):
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (`[0]` for single-element, empty for variadic, and
///       opVariadicSegmentOptionalTrailingTemplate for optional).
133 constexpr const char *opVariadicSegmentTemplate = R"Py( 134 @property 135 def {0}(self): 136 {1}_range = _segmented_accessor( 137 self.operation.{1}s, 138 self.operation.attributes["{1}_segment_sizes"], {2}) 139 return {1}_range{3} 140 )Py"; 141 142 /// Template for a suffix when accessing an optional element in the 143 /// attribute-sized case: 144 /// {0} is either 'operand' or 'result'; 145 constexpr const char *opVariadicSegmentOptionalTrailingTemplate = 146 R"Py([0] if len({0}_range) > 0 else None)Py"; 147 148 static llvm::cl::OptionCategory 149 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 150 151 static llvm::cl::opt<std::string> 152 clDialectName("bind-dialect", 153 llvm::cl::desc("The dialect to run the generator for"), 154 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 155 156 /// Checks whether `str` is a Python keyword. 157 static bool isPythonKeyword(StringRef str) { 158 static llvm::StringSet<> keywords( 159 {"and", "as", "assert", "break", "class", "continue", 160 "def", "del", "elif", "else", "except", "finally", 161 "for", "from", "global", "if", "import", "in", 162 "is", "lambda", "nonlocal", "not", "or", "pass", 163 "raise", "return", "try", "while", "with", "yield"}); 164 return keywords.contains(str); 165 }; 166 167 /// Modifies the `name` in a way that it becomes suitable for Python bindings 168 /// (does not change the `name` if it already is suitable) and returns the 169 /// modified version. 170 static std::string sanitizeName(StringRef name) { 171 if (isPythonKeyword(name)) 172 return (name + "_").str(); 173 return name.str(); 174 } 175 176 static std::string attrSizedTraitForKind(const char *kind) { 177 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 178 llvm::StringRef(kind).take_front().upper(), 179 llvm::StringRef(kind).drop_front()); 180 } 181 182 /// Emits accessors to "elements" of an Op definition. 
Currently, the supported 183 /// elements are operands and results, indicated by `kind`, which must be either 184 /// `operand` or `result` and is used verbatim in the emitted code. 185 static void emitElementAccessors( 186 const Operator &op, raw_ostream &os, const char *kind, 187 llvm::function_ref<unsigned(const Operator &)> getNumVariadic, 188 llvm::function_ref<int(const Operator &)> getNumElements, 189 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 190 getElement) { 191 assert(llvm::is_contained( 192 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 193 "unsupported kind"); 194 195 // Traits indicating how to process variadic elements. 196 std::string sameSizeTrait = 197 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 198 llvm::StringRef(kind).take_front().upper(), 199 llvm::StringRef(kind).drop_front()); 200 std::string attrSizedTrait = attrSizedTraitForKind(kind); 201 202 unsigned numVariadic = getNumVariadic(op); 203 204 // If there is only one variadic element group, its size can be inferred from 205 // the total number of elements. If there are none, the generation is 206 // straightforward. 207 if (numVariadic <= 1) { 208 bool seenVariableLength = false; 209 for (int i = 0, e = getNumElements(op); i < e; ++i) { 210 const NamedTypeConstraint &element = getElement(op, i); 211 if (element.isVariableLength()) 212 seenVariableLength = true; 213 if (element.name.empty()) 214 continue; 215 if (element.isVariableLength()) { 216 os << llvm::formatv(element.isOptional() ? 
opOneOptionalTemplate 217 : opOneVariadicTemplate, 218 sanitizeName(element.name), kind, 219 getNumElements(op), i); 220 } else if (seenVariableLength) { 221 os << llvm::formatv(opSingleAfterVariableTemplate, 222 sanitizeName(element.name), kind, 223 getNumElements(op), i); 224 } else { 225 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 226 i); 227 } 228 } 229 return; 230 } 231 232 // Handle the operations where variadic groups have the same size. 233 if (op.getTrait(sameSizeTrait)) { 234 int numPrecedingSimple = 0; 235 int numPrecedingVariadic = 0; 236 for (int i = 0, e = getNumElements(op); i < e; ++i) { 237 const NamedTypeConstraint &element = getElement(op, i); 238 if (!element.name.empty()) { 239 os << llvm::formatv(opVariadicEqualPrefixTemplate, 240 sanitizeName(element.name), kind, numVariadic, 241 numPrecedingSimple, numPrecedingVariadic); 242 os << llvm::formatv(element.isVariableLength() 243 ? opVariadicEqualVariadicTemplate 244 : opVariadicEqualSimpleTemplate, 245 kind); 246 } 247 if (element.isVariableLength()) 248 ++numPrecedingVariadic; 249 else 250 ++numPrecedingSimple; 251 } 252 return; 253 } 254 255 // Handle the operations where the size of groups (variadic or not) is 256 // provided as an attribute. For non-variadic elements, make sure to return 257 // an element rather than a singleton container. 
258 if (op.getTrait(attrSizedTrait)) { 259 for (int i = 0, e = getNumElements(op); i < e; ++i) { 260 const NamedTypeConstraint &element = getElement(op, i); 261 if (element.name.empty()) 262 continue; 263 std::string trailing; 264 if (!element.isVariableLength()) 265 trailing = "[0]"; 266 else if (element.isOptional()) 267 trailing = std::string( 268 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 269 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 270 kind, i, trailing); 271 } 272 return; 273 } 274 275 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 276 } 277 278 /// Free function helpers accessing Operator components. 279 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 280 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 281 return op.getOperand(i); 282 } 283 static int getNumResults(const Operator &op) { return op.getNumResults(); } 284 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 285 return op.getResult(i); 286 } 287 288 /// Emits accessor to Op operands. 289 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 290 auto getNumVariadic = [](const Operator &oper) { 291 return oper.getNumVariableLengthOperands(); 292 }; 293 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, 294 getOperand); 295 } 296 297 /// Emits access or Op results. 298 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 299 auto getNumVariadic = [](const Operator &oper) { 300 return oper.getNumVariableLengthResults(); 301 }; 302 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, 303 getResult); 304 } 305 306 /// Template for the default auto-generated builder. 
///   {0} is the operation name;
///   {1} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {2} is the code populating `operands`, `results` and `attributes` fields.
/// `{{` is the formatv escape for a literal `{`, so the emitted code reads
/// `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {1}):
    operands = []
    results = []
    attributes = {{}
    {2}
    super().__init__(_ir.Operation.create(
      "{0}", attributes=attributes, operands=operands, results=results,
      loc=loc, ip=ip))
)Py";

/// Template for appending a single element to the operand/result list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";

/// Template for appending an optional element to the operand/result list;
/// nothing is appended when the argument is None.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalAppendTemplate =
    "if {1} is not None: {0}s.append({1})";

/// Template for appending a variadic element to the operand/result list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";

/// Template for setting up the segment sizes buffer.
///   {0} is either 'operand' or 'result'.
constexpr const char *segmentDeclarationTemplate =
    "{0}_segment_sizes = array.array('L')";

/// Template for attaching segment sizes to the attribute list.
///   {0} is either 'operand' or 'result'.
/// The generated module imports the IR API only as `_ir` (see `fileHeader`),
/// so `Location` must be qualified: a bare `Location` raises NameError
/// whenever `loc` is None.
constexpr const char *segmentAttributeTemplate =
    R"Py(attributes["{0}_segment_sizes"] = _ir.DenseElementsAttr.get({0}_segment_sizes,
      context=_ir.Location.current.context if loc is None else loc.context))Py";

/// Template for appending the unit size to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name (emitted as a trailing Python comment).
constexpr const char *singleElementSegmentTemplate =
    "{0}_segment_sizes.append(1) # {1}";

/// Template for appending 0/1 for an optional element to the segment sizes.
354 /// {0} is either 'operand' or 'result'; 355 /// {1} is the field name. 356 constexpr const char *optionalSegmentTemplate = 357 "{0}_segment_sizes.append(0 if {1} is None else 1)"; 358 359 /// Template for appending the length of a variadic group to the segment sizes. 360 /// {0} is either 'operand' or 'result'; 361 /// {1} is the field name. 362 constexpr const char *variadicSegmentTemplate = 363 "{0}_segment_sizes.append(len({1}))"; 364 365 /// Populates `builderArgs` with the list of `__init__` arguments that 366 /// correspond to either operands or results of `op`, and `builderLines` with 367 /// additional lines that are required in the builder. `kind` must be either 368 /// "operand" or "result". `unnamedTemplate` is used to generate names for 369 /// operands or results that don't have the name in ODS. 370 static void populateBuilderLines( 371 const Operator &op, const char *kind, const char *unnamedTemplate, 372 llvm::SmallVectorImpl<std::string> &builderArgs, 373 llvm::SmallVectorImpl<std::string> &builderLines, 374 llvm::function_ref<int(const Operator &)> getNumElements, 375 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 376 getElement) { 377 // The segment sizes buffer only has to be populated if there attr-sized 378 // segments trait is present. 379 bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; 380 if (includeSegments) 381 builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind)); 382 383 // For each element, find or generate a name. 384 for (int i = 0, e = getNumElements(op); i < e; ++i) { 385 const NamedTypeConstraint &element = getElement(op, i); 386 std::string name = element.name.str(); 387 if (name.empty()) 388 name = llvm::formatv(unnamedTemplate, i).str(); 389 name = sanitizeName(name); 390 builderArgs.push_back(name); 391 392 // Choose the formatting string based on the element kind. 
393 llvm::StringRef formatString, segmentFormatString; 394 if (!element.isVariableLength()) { 395 formatString = singleElementAppendTemplate; 396 segmentFormatString = singleElementSegmentTemplate; 397 } else if (element.isOptional()) { 398 formatString = optionalAppendTemplate; 399 segmentFormatString = optionalSegmentTemplate; 400 } else { 401 assert(element.isVariadic() && "unhandled element group type"); 402 formatString = variadicAppendTemplate; 403 segmentFormatString = variadicSegmentTemplate; 404 } 405 406 // Add the lines. 407 builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); 408 if (includeSegments) 409 builderLines.push_back( 410 llvm::formatv(segmentFormatString.data(), kind, name)); 411 } 412 413 if (includeSegments) 414 builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind)); 415 } 416 417 /// Emits a default builder constructing an operation from the list of its 418 /// result types, followed by a list of its operands. 419 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 420 // TODO: support attribute types. 421 if (op.getNumNativeAttributes() != 0) 422 return; 423 424 // If we are asked to skip default builders, comply. 425 if (op.skipDefaultBuilders()) 426 return; 427 428 llvm::SmallVector<std::string, 8> builderArgs; 429 llvm::SmallVector<std::string, 8> builderLines; 430 builderArgs.reserve(op.getNumOperands() + op.getNumResults()); 431 populateBuilderLines(op, "result", "_gen_res_{0}", builderArgs, builderLines, 432 getNumResults, getResult); 433 populateBuilderLines(op, "operand", "_gen_arg_{0}", builderArgs, builderLines, 434 getNumOperands, getOperand); 435 436 builderArgs.push_back("loc=None"); 437 builderArgs.push_back("ip=None"); 438 os << llvm::formatv(initTemplate, op.getOperationName(), 439 llvm::join(builderArgs, ", "), 440 llvm::join(builderLines, "\n ")); 441 } 442 443 /// Emits bindings for a specific Op to the given output stream. 
444 static void emitOpBindings(const Operator &op, raw_ostream &os) { 445 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 446 op.getOperationName()); 447 emitDefaultOpBuilder(op, os); 448 emitOperandAccessors(op, os); 449 emitResultAccessors(op, os); 450 } 451 452 /// Emits bindings for the dialect specified in the command line, including file 453 /// headers and utilities. Returns `false` on success to comply with Tablegen 454 /// registration requirements. 455 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 456 if (clDialectName.empty()) 457 llvm::PrintFatalError("dialect name not provided"); 458 459 os << fileHeader; 460 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 461 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 462 Operator op(rec); 463 if (op.getDialectName() == clDialectName.getValue()) 464 emitOpBindings(op, os); 465 } 466 return false; 467 } 468 469 static GenRegistration 470 genPythonBindings("gen-python-op-bindings", 471 "Generate Python bindings for MLIR Ops", &emitAllOps); 472