//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// File header and imports, emitted once at the top of the generated Python
/// module (before the dialect class). Imported names are aliased with an
/// `_ods_` prefix; that prefix is reserved so that generated accessor names
/// cannot shadow them.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

import array as _ods_array
from . import _cext as _ods_cext
from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir
)Py";

/// Template for the dialect class. It is registered through
/// `_ods_cext.register_dialect`; the operation class template refers to it by
/// the name `_Dialect` when registering each op.
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for the operation class header. The builder and the accessors are
/// emitted right after it as indented class members:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor, used when the optional element
/// is the only variable-length group of its kind:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Every other group holds exactly one element, so the element list has {2}
/// entries when the optional element is present and {2} - 1 entries when it is
/// absent. The whole conditional expression stays on one line because Python
/// does not allow a bare line break inside an expression.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The generated code binds `start` and `pg`, which one of the two templates
/// below then uses as the index of the current group's first element and as
/// the group length respectively. The template deliberately has no trailing
/// newline; the second part starts with one.
/// NOTE(review): the Python helper is given neither the total number of
/// non-variadic groups nor the number following the current group; confirm it
/// can still compute the per-group length when non-variadic groups exist.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case. It yields the only element of the group, or None when
/// the group is empty:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter, which wraps the stored
/// attribute in its Python class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter, which returns None
/// when the attribute is not set on the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter. The generated setter raises a
/// ValueError for None because the attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute. Setting it to None
/// removes the attribute, and does nothing if the attribute is already absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute. A truthy value stores a
/// fresh UnitAttr; a falsy value (e.g. None or False) removes the attribute if
/// it is present:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation. Unlike the setters, the generated deleter
/// does not check for presence first:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
222 constexpr const char *attributeDeleterTemplate = R"Py( 223 @{0}.deleter 224 def {0}(self): 225 del self.operation.attributes["{1}"] 226 )Py"; 227 228 static llvm::cl::OptionCategory 229 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 230 231 static llvm::cl::opt<std::string> 232 clDialectName("bind-dialect", 233 llvm::cl::desc("The dialect to run the generator for"), 234 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 235 236 using AttributeClasses = DenseMap<StringRef, StringRef>; 237 238 /// Checks whether `str` is a Python keyword. 239 static bool isPythonKeyword(StringRef str) { 240 static llvm::StringSet<> keywords( 241 {"and", "as", "assert", "break", "class", "continue", 242 "def", "del", "elif", "else", "except", "finally", 243 "for", "from", "global", "if", "import", "in", 244 "is", "lambda", "nonlocal", "not", "or", "pass", 245 "raise", "return", "try", "while", "with", "yield"}); 246 return keywords.contains(str); 247 }; 248 249 /// Checks whether `str` would shadow a generated variable or attribute 250 /// part of the OpView API. 251 static bool isODSReserved(StringRef str) { 252 static llvm::StringSet<> reserved( 253 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 254 "loc", "verify", "regions", "result", "results", "self", "operation", 255 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 256 return str.startswith("_ods_") || str.endswith("_ods") || 257 reserved.contains(str); 258 } 259 260 /// Modifies the `name` in a way that it becomes suitable for Python bindings 261 /// (does not change the `name` if it already is suitable) and returns the 262 /// modified version. 
263 static std::string sanitizeName(StringRef name) { 264 if (isPythonKeyword(name) || isODSReserved(name)) 265 return (name + "_").str(); 266 return name.str(); 267 } 268 269 static std::string attrSizedTraitForKind(const char *kind) { 270 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 271 llvm::StringRef(kind).take_front().upper(), 272 llvm::StringRef(kind).drop_front()); 273 } 274 275 /// Emits accessors to "elements" of an Op definition. Currently, the supported 276 /// elements are operands and results, indicated by `kind`, which must be either 277 /// `operand` or `result` and is used verbatim in the emitted code. 278 static void emitElementAccessors( 279 const Operator &op, raw_ostream &os, const char *kind, 280 llvm::function_ref<unsigned(const Operator &)> getNumVariadic, 281 llvm::function_ref<int(const Operator &)> getNumElements, 282 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 283 getElement) { 284 assert(llvm::is_contained( 285 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 286 "unsupported kind"); 287 288 // Traits indicating how to process variadic elements. 289 std::string sameSizeTrait = 290 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 291 llvm::StringRef(kind).take_front().upper(), 292 llvm::StringRef(kind).drop_front()); 293 std::string attrSizedTrait = attrSizedTraitForKind(kind); 294 295 unsigned numVariadic = getNumVariadic(op); 296 297 // If there is only one variadic element group, its size can be inferred from 298 // the total number of elements. If there are none, the generation is 299 // straightforward. 
300 if (numVariadic <= 1) { 301 bool seenVariableLength = false; 302 for (int i = 0, e = getNumElements(op); i < e; ++i) { 303 const NamedTypeConstraint &element = getElement(op, i); 304 if (element.isVariableLength()) 305 seenVariableLength = true; 306 if (element.name.empty()) 307 continue; 308 if (element.isVariableLength()) { 309 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 310 : opOneVariadicTemplate, 311 sanitizeName(element.name), kind, 312 getNumElements(op), i); 313 } else if (seenVariableLength) { 314 os << llvm::formatv(opSingleAfterVariableTemplate, 315 sanitizeName(element.name), kind, 316 getNumElements(op), i); 317 } else { 318 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 319 i); 320 } 321 } 322 return; 323 } 324 325 // Handle the operations where variadic groups have the same size. 326 if (op.getTrait(sameSizeTrait)) { 327 int numPrecedingSimple = 0; 328 int numPrecedingVariadic = 0; 329 for (int i = 0, e = getNumElements(op); i < e; ++i) { 330 const NamedTypeConstraint &element = getElement(op, i); 331 if (!element.name.empty()) { 332 os << llvm::formatv(opVariadicEqualPrefixTemplate, 333 sanitizeName(element.name), kind, numVariadic, 334 numPrecedingSimple, numPrecedingVariadic); 335 os << llvm::formatv(element.isVariableLength() 336 ? opVariadicEqualVariadicTemplate 337 : opVariadicEqualSimpleTemplate, 338 kind); 339 } 340 if (element.isVariableLength()) 341 ++numPrecedingVariadic; 342 else 343 ++numPrecedingSimple; 344 } 345 return; 346 } 347 348 // Handle the operations where the size of groups (variadic or not) is 349 // provided as an attribute. For non-variadic elements, make sure to return 350 // an element rather than a singleton container. 
351 if (op.getTrait(attrSizedTrait)) { 352 for (int i = 0, e = getNumElements(op); i < e; ++i) { 353 const NamedTypeConstraint &element = getElement(op, i); 354 if (element.name.empty()) 355 continue; 356 std::string trailing; 357 if (!element.isVariableLength()) 358 trailing = "[0]"; 359 else if (element.isOptional()) 360 trailing = std::string( 361 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 362 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 363 kind, i, trailing); 364 } 365 return; 366 } 367 368 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 369 } 370 371 /// Free function helpers accessing Operator components. 372 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 373 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 374 return op.getOperand(i); 375 } 376 static int getNumResults(const Operator &op) { return op.getNumResults(); } 377 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 378 return op.getResult(i); 379 } 380 381 /// Emits accessors to Op operands. 382 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 383 auto getNumVariadic = [](const Operator &oper) { 384 return oper.getNumVariableLengthOperands(); 385 }; 386 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, 387 getOperand); 388 } 389 390 /// Emits accessors Op results. 391 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 392 auto getNumVariadic = [](const Operator &oper) { 393 return oper.getNumVariableLengthResults(); 394 }; 395 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, 396 getResult); 397 } 398 399 /// Emits accessors to Op attributes. 
400 static void emitAttributeAccessors(const Operator &op, 401 const AttributeClasses &attributeClasses, 402 raw_ostream &os) { 403 for (const auto &namedAttr : op.getAttributes()) { 404 // Skip "derived" attributes because they are just C++ functions that we 405 // don't currently expose. 406 if (namedAttr.attr.isDerivedAttr()) 407 continue; 408 409 if (namedAttr.name.empty()) 410 continue; 411 412 std::string sanitizedName = sanitizeName(namedAttr.name); 413 414 // Unit attributes are handled specially. 415 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 416 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 417 namedAttr.name); 418 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 419 namedAttr.name); 420 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 421 namedAttr.name); 422 continue; 423 } 424 425 // Other kinds of attributes need a mapping to a Python type. 426 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 427 continue; 428 429 StringRef pythonType = 430 attributeClasses.lookup(namedAttr.attr.getStorageType()); 431 if (namedAttr.attr.isOptional()) { 432 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 433 pythonType, namedAttr.name); 434 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 435 namedAttr.name); 436 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 437 namedAttr.name); 438 } else { 439 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 440 namedAttr.name); 441 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 442 namedAttr.name); 443 // Non-optional attributes cannot be deleted. 444 } 445 } 446 } 447 448 /// Template for the default auto-generated builder. 449 /// {0} is the operation name; 450 /// {1} is a comma-separated list of builder arguments, including the trailing 451 /// `loc` and `ip`; 452 /// {2} is the code populating `operands`, `results` and `attributes` fields. 
constexpr const char *initTemplate = R"Py(
  def __init__(self, {1}):
    operands = []
    results = []
    attributes = {{}
    {2}
    super().__init__(_ods_ir.Operation.create(
      "{0}", attributes=attributes, operands=operands, results=results,
      loc=loc, ip=ip))
)Py";

/// Template for appending a single element to the operand/result list.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *singleElementAppendTemplate = "{0}s.append({1})";

/// Template for appending an optional element to the operand/result list; the
/// element is skipped when the argument is None.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalAppendTemplate =
    "if {1} is not None: {0}s.append({1})";

/// Template for appending a variadic element to the operand/result list. The
/// argument is unpacked with `*`, so any iterable is accepted here.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]";

/// Template for setting up the segment sizes buffer.
///   {0} is either 'operand' or 'result'.
/// NOTE(review): the `'L'` array typecode is a C unsigned long, whose width is
/// platform-dependent; confirm it matches the element type expected for the
/// segment sizes attribute.
constexpr const char *segmentDeclarationTemplate =
    "{0}_segment_sizes_ods = _ods_array.array('L')";

/// Template for attaching segment sizes to the attribute list. The context is
/// obtained through `_ods_get_default_loc_context(loc)`.
///   {0} is either 'operand' or 'result'.
constexpr const char *segmentAttributeTemplate =
    R"Py(attributes["{0}_segment_sizes"] = _ods_ir.DenseElementsAttr.get({0}_segment_sizes_ods,
      context=_ods_get_default_loc_context(loc)))Py";

/// Template for appending the unit size to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name (only emitted as a trailing Python comment).
constexpr const char *singleElementSegmentTemplate =
    "{0}_segment_sizes_ods.append(1) # {1}";

/// Template for appending 0/1 for an optional element to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
constexpr const char *optionalSegmentTemplate =
    "{0}_segment_sizes_ods.append(0 if {1} is None else 1)";

/// Template for appending the length of a variadic group to the segment sizes.
///   {0} is either 'operand' or 'result';
///   {1} is the field name.
/// NOTE(review): `len()` needs a sized collection, while the matching append
/// template accepts any iterable; a generator argument would fail here.
constexpr const char *variadicSegmentTemplate =
    "{0}_segment_sizes_ods.append(len({1}))";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder; the
/// attribute is left out when the argument is None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder. The
/// attribute is added only when the builder argument is truthy. The UnitAttr
/// is created in the context returned by `_ods_get_default_loc_context(loc)`.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments, first the results, then the intermixed attributes and operands in
/// the same order as they appear in the `arguments` field of the op definition.
/// Additionally, `operandNames` is populated with names of operands in their
/// order of appearance.
527 static void 528 populateBuilderArgs(const Operator &op, 529 llvm::SmallVectorImpl<std::string> &builderArgs, 530 llvm::SmallVectorImpl<std::string> &operandNames) { 531 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 532 std::string name = op.getResultName(i).str(); 533 if (name.empty()) 534 name = llvm::formatv("_gen_res_{0}", i); 535 name = sanitizeName(name); 536 builderArgs.push_back(name); 537 } 538 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 539 std::string name = op.getArgName(i).str(); 540 if (name.empty()) 541 name = llvm::formatv("_gen_arg_{0}", i); 542 name = sanitizeName(name); 543 builderArgs.push_back(name); 544 if (!op.getArg(i).is<NamedAttribute *>()) 545 operandNames.push_back(name); 546 } 547 } 548 549 /// Populates `builderLines` with additional lines that are required in the 550 /// builder to set up operation attributes. `argNames` is expected to contain 551 /// the names of builder arguments that correspond to op arguments, i.e. to the 552 /// operands and attributes in the same order as they appear in the `arguments` 553 /// field. 554 static void 555 populateBuilderLinesAttr(const Operator &op, 556 llvm::ArrayRef<std::string> argNames, 557 llvm::SmallVectorImpl<std::string> &builderLines) { 558 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 559 Argument arg = op.getArg(i); 560 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 561 if (!attribute) 562 continue; 563 564 // Unit attributes are handled specially. 565 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 566 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 567 attribute->name, argNames[i])); 568 continue; 569 } 570 571 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 572 ? initOptionalAttributeTemplate 573 : initAttributeTemplate, 574 attribute->name, argNames[i])); 575 } 576 } 577 578 /// Populates `builderLines` with additional lines that are required in the 579 /// builder. 
`kind` must be either "operand" or "result". `names` contains the 580 /// names of init arguments that correspond to the elements. 581 static void populateBuilderLines( 582 const Operator &op, const char *kind, llvm::ArrayRef<std::string> names, 583 llvm::SmallVectorImpl<std::string> &builderLines, 584 llvm::function_ref<int(const Operator &)> getNumElements, 585 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 586 getElement) { 587 // The segment sizes buffer only has to be populated if there attr-sized 588 // segments trait is present. 589 bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; 590 if (includeSegments) 591 builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind)); 592 593 // For each element, find or generate a name. 594 for (int i = 0, e = getNumElements(op); i < e; ++i) { 595 const NamedTypeConstraint &element = getElement(op, i); 596 std::string name = names[i]; 597 598 // Choose the formatting string based on the element kind. 599 llvm::StringRef formatString, segmentFormatString; 600 if (!element.isVariableLength()) { 601 formatString = singleElementAppendTemplate; 602 segmentFormatString = singleElementSegmentTemplate; 603 } else if (element.isOptional()) { 604 formatString = optionalAppendTemplate; 605 segmentFormatString = optionalSegmentTemplate; 606 } else { 607 assert(element.isVariadic() && "unhandled element group type"); 608 formatString = variadicAppendTemplate; 609 segmentFormatString = variadicSegmentTemplate; 610 } 611 612 // Add the lines. 
613 builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); 614 if (includeSegments) 615 builderLines.push_back( 616 llvm::formatv(segmentFormatString.data(), kind, name)); 617 } 618 619 if (includeSegments) 620 builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind)); 621 } 622 623 /// Emits a default builder constructing an operation from the list of its 624 /// result types, followed by a list of its operands. 625 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 626 // If we are asked to skip default builders, comply. 627 if (op.skipDefaultBuilders()) 628 return; 629 630 llvm::SmallVector<std::string, 8> builderArgs; 631 llvm::SmallVector<std::string, 8> builderLines; 632 llvm::SmallVector<std::string, 4> operandArgNames; 633 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 634 op.getNumNativeAttributes()); 635 populateBuilderArgs(op, builderArgs, operandArgNames); 636 populateBuilderLines( 637 op, "result", 638 llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), 639 builderLines, getNumResults, getResult); 640 populateBuilderLines(op, "operand", operandArgNames, builderLines, 641 getNumOperands, getOperand); 642 populateBuilderLinesAttr( 643 op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), 644 builderLines); 645 646 builderArgs.push_back("loc=None"); 647 builderArgs.push_back("ip=None"); 648 os << llvm::formatv(initTemplate, op.getOperationName(), 649 llvm::join(builderArgs, ", "), 650 llvm::join(builderLines, "\n ")); 651 } 652 653 static void constructAttributeMapping(const llvm::RecordKeeper &records, 654 AttributeClasses &attributeClasses) { 655 for (const llvm::Record *rec : 656 records.getAllDerivedDefinitions("PythonAttr")) { 657 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 658 rec->getValueAsString("pythonType").trim()); 659 } 660 } 661 662 /// Emits bindings for a specific Op to the given output stream. 
663 static void emitOpBindings(const Operator &op, 664 const AttributeClasses &attributeClasses, 665 raw_ostream &os) { 666 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 667 op.getOperationName()); 668 emitDefaultOpBuilder(op, os); 669 emitOperandAccessors(op, os); 670 emitAttributeAccessors(op, attributeClasses, os); 671 emitResultAccessors(op, os); 672 } 673 674 /// Emits bindings for the dialect specified in the command line, including file 675 /// headers and utilities. Returns `false` on success to comply with Tablegen 676 /// registration requirements. 677 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 678 if (clDialectName.empty()) 679 llvm::PrintFatalError("dialect name not provided"); 680 681 AttributeClasses attributeClasses; 682 constructAttributeMapping(records, attributeClasses); 683 684 os << fileHeader; 685 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 686 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 687 Operator op(rec); 688 if (op.getDialectName() == clDialectName.getValue()) 689 emitOpBindings(op, attributeClasses, os); 690 } 691 return false; 692 } 693 694 static GenRegistration 695 genPythonBindings("gen-python-op-bindings", 696 "Generate Python bindings for MLIR Ops", &emitAllOps); 697