//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// File header and includes.
///   {0} is the dialect namespace.
/// The generated module imports the `_cext` native module and the generator
/// helpers under `_ods_`-prefixed aliases, then tries to import the optional
/// dialect-specific module `_{0}`; if that import fails, `_ods_ext_module` is
/// set to None.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from . import _cext as _ods_cext
from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir

try:
  from . import _{0} as _ods_ext_module
except ImportError:
  _ods_ext_module = None

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
/// The generated op classes refer to this `_Dialect` class in their
/// `register_operation` decorator.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
/// The class derives from OpView, is registered with `_Dialect`, and is passed
/// through `_ods_extend_opview_class(_ods_ext_module)`.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element
///   -1 = operand/result is a sequence corresponding to a variadic
/// Only emitted for ops carrying the AttrSized{Operand,Result}Segments trait.
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Used when no variable-length group precedes the element, so its position
/// in the flat operand/result list equals its group position.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor, used when the optional element
/// is the only variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Every other group holds exactly one element. The optional element is
/// therefore present iff there are exactly {2} elements; with fewer ({2} - 1)
/// the accessor returns None. The `return` statement must stay on one line,
/// because Python does not allow a conditional expression to continue on the
/// next line outside of brackets.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The group's length is whatever remains after each of the other groups has
/// taken exactly one element.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The template intentionally ends without a newline. It is always followed by
/// opVariadicEqualSimpleTemplate or opVariadicEqualVariadicTemplate, which
/// supply the `return` statement using `start` and `pg`.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
// Note: the two "second part" templates below start with a newline and are
// indented to method-body level, because they complete the `def` opened by
// opVariadicEqualPrefixTemplate.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// Group boundaries are read at runtime from the `{1}_segment_sizes`
/// attribute.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
/// Yields the first element of the range, or None if the range is empty.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
// The generated getter returns None when the attribute is absent.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
/// Assigning None raises ValueError, because a mandatory attribute cannot be
/// removed.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
235 constexpr const char *unitAttributeSetterTemplate = R"Py( 236 @{0}.setter 237 def {0}(self, value): 238 if bool(value): 239 self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get() 240 elif "{1}" in self.operation.attributes: 241 del self.operation.attributes["{1}"] 242 )Py"; 243 244 /// Template for a deleter of an optional or a unit operation attribute, removes 245 /// the attribute from the operation: 246 /// {0} is the name of the attribute sanitized for Python; 247 /// {1} is the original name of the attribute. 248 constexpr const char *attributeDeleterTemplate = R"Py( 249 @{0}.deleter 250 def {0}(self): 251 del self.operation.attributes["{1}"] 252 )Py"; 253 254 static llvm::cl::OptionCategory 255 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 256 257 static llvm::cl::opt<std::string> 258 clDialectName("bind-dialect", 259 llvm::cl::desc("The dialect to run the generator for"), 260 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 261 262 using AttributeClasses = DenseMap<StringRef, StringRef>; 263 264 /// Checks whether `str` is a Python keyword. 265 static bool isPythonKeyword(StringRef str) { 266 static llvm::StringSet<> keywords( 267 {"and", "as", "assert", "break", "class", "continue", 268 "def", "del", "elif", "else", "except", "finally", 269 "for", "from", "global", "if", "import", "in", 270 "is", "lambda", "nonlocal", "not", "or", "pass", 271 "raise", "return", "try", "while", "with", "yield"}); 272 return keywords.contains(str); 273 } 274 275 /// Checks whether `str` would shadow a generated variable or attribute 276 /// part of the OpView API. 
277 static bool isODSReserved(StringRef str) { 278 static llvm::StringSet<> reserved( 279 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 280 "loc", "verify", "regions", "result", "results", "self", "operation", 281 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 282 return str.startswith("_ods_") || str.endswith("_ods") || 283 reserved.contains(str); 284 } 285 286 /// Modifies the `name` in a way that it becomes suitable for Python bindings 287 /// (does not change the `name` if it already is suitable) and returns the 288 /// modified version. 289 static std::string sanitizeName(StringRef name) { 290 if (isPythonKeyword(name) || isODSReserved(name)) 291 return (name + "_").str(); 292 return name.str(); 293 } 294 295 static std::string attrSizedTraitForKind(const char *kind) { 296 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 297 llvm::StringRef(kind).take_front().upper(), 298 llvm::StringRef(kind).drop_front()); 299 } 300 301 /// Emits accessors to "elements" of an Op definition. Currently, the supported 302 /// elements are operands and results, indicated by `kind`, which must be either 303 /// `operand` or `result` and is used verbatim in the emitted code. 304 static void emitElementAccessors( 305 const Operator &op, raw_ostream &os, const char *kind, 306 llvm::function_ref<unsigned(const Operator &)> getNumVariadic, 307 llvm::function_ref<int(const Operator &)> getNumElements, 308 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 309 getElement) { 310 assert(llvm::is_contained( 311 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 312 "unsupported kind"); 313 314 // Traits indicating how to process variadic elements. 
315 std::string sameSizeTrait = 316 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 317 llvm::StringRef(kind).take_front().upper(), 318 llvm::StringRef(kind).drop_front()); 319 std::string attrSizedTrait = attrSizedTraitForKind(kind); 320 321 unsigned numVariadic = getNumVariadic(op); 322 323 // If there is only one variadic element group, its size can be inferred from 324 // the total number of elements. If there are none, the generation is 325 // straightforward. 326 if (numVariadic <= 1) { 327 bool seenVariableLength = false; 328 for (int i = 0, e = getNumElements(op); i < e; ++i) { 329 const NamedTypeConstraint &element = getElement(op, i); 330 if (element.isVariableLength()) 331 seenVariableLength = true; 332 if (element.name.empty()) 333 continue; 334 if (element.isVariableLength()) { 335 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 336 : opOneVariadicTemplate, 337 sanitizeName(element.name), kind, 338 getNumElements(op), i); 339 } else if (seenVariableLength) { 340 os << llvm::formatv(opSingleAfterVariableTemplate, 341 sanitizeName(element.name), kind, 342 getNumElements(op), i); 343 } else { 344 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 345 i); 346 } 347 } 348 return; 349 } 350 351 // Handle the operations where variadic groups have the same size. 352 if (op.getTrait(sameSizeTrait)) { 353 int numPrecedingSimple = 0; 354 int numPrecedingVariadic = 0; 355 for (int i = 0, e = getNumElements(op); i < e; ++i) { 356 const NamedTypeConstraint &element = getElement(op, i); 357 if (!element.name.empty()) { 358 os << llvm::formatv(opVariadicEqualPrefixTemplate, 359 sanitizeName(element.name), kind, numVariadic, 360 numPrecedingSimple, numPrecedingVariadic); 361 os << llvm::formatv(element.isVariableLength() 362 ? 
opVariadicEqualVariadicTemplate 363 : opVariadicEqualSimpleTemplate, 364 kind); 365 } 366 if (element.isVariableLength()) 367 ++numPrecedingVariadic; 368 else 369 ++numPrecedingSimple; 370 } 371 return; 372 } 373 374 // Handle the operations where the size of groups (variadic or not) is 375 // provided as an attribute. For non-variadic elements, make sure to return 376 // an element rather than a singleton container. 377 if (op.getTrait(attrSizedTrait)) { 378 for (int i = 0, e = getNumElements(op); i < e; ++i) { 379 const NamedTypeConstraint &element = getElement(op, i); 380 if (element.name.empty()) 381 continue; 382 std::string trailing; 383 if (!element.isVariableLength()) 384 trailing = "[0]"; 385 else if (element.isOptional()) 386 trailing = std::string( 387 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 388 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 389 kind, i, trailing); 390 } 391 return; 392 } 393 394 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 395 } 396 397 /// Free function helpers accessing Operator components. 398 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 399 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 400 return op.getOperand(i); 401 } 402 static int getNumResults(const Operator &op) { return op.getNumResults(); } 403 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 404 return op.getResult(i); 405 } 406 407 /// Emits accessors to Op operands. 408 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 409 auto getNumVariadic = [](const Operator &oper) { 410 return oper.getNumVariableLengthOperands(); 411 }; 412 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, 413 getOperand); 414 } 415 416 /// Emits accessors Op results. 
417 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 418 auto getNumVariadic = [](const Operator &oper) { 419 return oper.getNumVariableLengthResults(); 420 }; 421 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, 422 getResult); 423 } 424 425 /// Emits accessors to Op attributes. 426 static void emitAttributeAccessors(const Operator &op, 427 const AttributeClasses &attributeClasses, 428 raw_ostream &os) { 429 for (const auto &namedAttr : op.getAttributes()) { 430 // Skip "derived" attributes because they are just C++ functions that we 431 // don't currently expose. 432 if (namedAttr.attr.isDerivedAttr()) 433 continue; 434 435 if (namedAttr.name.empty()) 436 continue; 437 438 std::string sanitizedName = sanitizeName(namedAttr.name); 439 440 // Unit attributes are handled specially. 441 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 442 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 443 namedAttr.name); 444 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 445 namedAttr.name); 446 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 447 namedAttr.name); 448 continue; 449 } 450 451 // Other kinds of attributes need a mapping to a Python type. 
452 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 453 continue; 454 455 StringRef pythonType = 456 attributeClasses.lookup(namedAttr.attr.getStorageType()); 457 if (namedAttr.attr.isOptional()) { 458 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 459 pythonType, namedAttr.name); 460 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 461 namedAttr.name); 462 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 463 namedAttr.name); 464 } else { 465 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 466 namedAttr.name); 467 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 468 namedAttr.name); 469 // Non-optional attributes cannot be deleted. 470 } 471 } 472 } 473 474 /// Template for the default auto-generated builder. 475 /// {0} is a comma-separated list of builder arguments, including the trailing 476 /// `loc` and `ip`; 477 /// {1} is the code populating `operands`, `results` and `attributes` fields. 478 constexpr const char *initTemplate = R"Py( 479 def __init__(self, {0}): 480 operands = [] 481 results = [] 482 attributes = {{} 483 {1} 484 super().__init__(self._ods_build_default( 485 attributes=attributes, operands=operands, results=results, 486 loc=loc, ip=ip)) 487 )Py"; 488 489 /// Template for appending a single element to the operand/result list. 490 /// {0} is either 'operand' or 'result'; 491 /// {1} is the field name. 492 constexpr const char *singleElementAppendTemplate = "{0}s.append({1})"; 493 494 /// Template for appending an optional element to the operand/result list. 495 /// {0} is either 'operand' or 'result'; 496 /// {1} is the field name. 497 constexpr const char *optionalAppendTemplate = 498 "if {1} is not None: {0}s.append({1})"; 499 500 /// Template for appending a a list of elements to the operand/result list. 501 /// {0} is either 'operand' or 'result'; 502 /// {1} is the field name. 
constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder. The
/// attribute is added only when the builder argument is truthy.
///   {0} is the attribute name;
///   {1} is the builder argument name.
/// The context for the UnitAttr is obtained through
/// `_ods_get_default_loc_context(loc)`, where `loc` is the builder's location
/// argument.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments, first the results, then the intermixed attributes and operands in
/// the same order as they appear in the `arguments` field of the op definition.
/// Additionally, `operandNames` is populated with names of operands in their
/// order of appearance.
/// Unnamed results and arguments get synthesized `_gen_res_<i>` /
/// `_gen_arg_<i>` names so that every position has an argument name.
static void
populateBuilderArgs(const Operator &op,
                    llvm::SmallVectorImpl<std::string> &builderArgs,
                    llvm::SmallVectorImpl<std::string> &operandNames) {
  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
    std::string name = op.getResultName(i).str();
    if (name.empty())
      name = llvm::formatv("_gen_res_{0}", i);
    name = sanitizeName(name);
    builderArgs.push_back(name);
  }
  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
    std::string name = op.getArgName(i).str();
    if (name.empty())
      name = llvm::formatv("_gen_arg_{0}", i);
    name = sanitizeName(name);
    builderArgs.push_back(name);
    // Arguments that are not attributes are operands.
    if (!op.getArg(i).is<NamedAttribute *>())
      operandNames.push_back(name);
  }
}

/// Populates `builderLines` with additional lines that are required in the
/// builder to set up operation attributes.
`argNames` is expected to contain 549 /// the names of builder arguments that correspond to op arguments, i.e. to the 550 /// operands and attributes in the same order as they appear in the `arguments` 551 /// field. 552 static void 553 populateBuilderLinesAttr(const Operator &op, 554 llvm::ArrayRef<std::string> argNames, 555 llvm::SmallVectorImpl<std::string> &builderLines) { 556 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 557 Argument arg = op.getArg(i); 558 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 559 if (!attribute) 560 continue; 561 562 // Unit attributes are handled specially. 563 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 564 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 565 attribute->name, argNames[i])); 566 continue; 567 } 568 569 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 570 ? initOptionalAttributeTemplate 571 : initAttributeTemplate, 572 attribute->name, argNames[i])); 573 } 574 } 575 576 /// Populates `builderLines` with additional lines that are required in the 577 /// builder. `kind` must be either "operand" or "result". `names` contains the 578 /// names of init arguments that correspond to the elements. 579 static void populateBuilderLines( 580 const Operator &op, const char *kind, llvm::ArrayRef<std::string> names, 581 llvm::SmallVectorImpl<std::string> &builderLines, 582 llvm::function_ref<int(const Operator &)> getNumElements, 583 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 584 getElement) { 585 bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; 586 587 // For each element, find or generate a name. 588 for (int i = 0, e = getNumElements(op); i < e; ++i) { 589 const NamedTypeConstraint &element = getElement(op, i); 590 std::string name = names[i]; 591 592 // Choose the formatting string based on the element kind. 
593 llvm::StringRef formatString; 594 if (!element.isVariableLength()) { 595 formatString = singleElementAppendTemplate; 596 } else if (element.isOptional()) { 597 formatString = optionalAppendTemplate; 598 } else { 599 assert(element.isVariadic() && "unhandled element group type"); 600 // If emitting with sizedSegments, then we add the actual list typed 601 // element using the singleElementAppendTemplate. Otherwise, we extend 602 // the actual operands. 603 if (sizedSegments) { 604 // Append the list as is. 605 formatString = singleElementAppendTemplate; 606 } else { 607 // Append the list elements. 608 formatString = multiElementAppendTemplate; 609 } 610 } 611 612 // Add the lines. 613 builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); 614 } 615 } 616 617 /// Emits a default builder constructing an operation from the list of its 618 /// result types, followed by a list of its operands. 619 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 620 // If we are asked to skip default builders, comply. 
621 if (op.skipDefaultBuilders()) 622 return; 623 624 llvm::SmallVector<std::string, 8> builderArgs; 625 llvm::SmallVector<std::string, 8> builderLines; 626 llvm::SmallVector<std::string, 4> operandArgNames; 627 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 628 op.getNumNativeAttributes()); 629 populateBuilderArgs(op, builderArgs, operandArgNames); 630 populateBuilderLines( 631 op, "result", 632 llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), 633 builderLines, getNumResults, getResult); 634 populateBuilderLines(op, "operand", operandArgNames, builderLines, 635 getNumOperands, getOperand); 636 populateBuilderLinesAttr( 637 op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), 638 builderLines); 639 640 builderArgs.push_back("loc=None"); 641 builderArgs.push_back("ip=None"); 642 os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), 643 llvm::join(builderLines, "\n ")); 644 } 645 646 static void constructAttributeMapping(const llvm::RecordKeeper &records, 647 AttributeClasses &attributeClasses) { 648 for (const llvm::Record *rec : 649 records.getAllDerivedDefinitions("PythonAttr")) { 650 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 651 rec->getValueAsString("pythonType").trim()); 652 } 653 } 654 655 static void emitSegmentSpec( 656 const Operator &op, const char *kind, 657 llvm::function_ref<int(const Operator &)> getNumElements, 658 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 659 getElement, 660 raw_ostream &os) { 661 std::string segmentSpec("["); 662 for (int i = 0, e = getNumElements(op); i < e; ++i) { 663 const NamedTypeConstraint &element = getElement(op, i); 664 if (element.isVariableLength()) { 665 segmentSpec.append("-1,"); 666 } else if (element.isOptional()) { 667 segmentSpec.append("0,"); 668 } else { 669 segmentSpec.append("1,"); 670 } 671 } 672 segmentSpec.append("]"); 673 674 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, 
segmentSpec); 675 } 676 677 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 678 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 679 // Note that the base OpView class defines this as (0, True). 680 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 681 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, 682 op.hasNoVariadicRegions() ? "True" : "False"); 683 } 684 685 /// Emits bindings for a specific Op to the given output stream. 686 static void emitOpBindings(const Operator &op, 687 const AttributeClasses &attributeClasses, 688 raw_ostream &os) { 689 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 690 op.getOperationName()); 691 692 // Sized segments. 693 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 694 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 695 } 696 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 697 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 698 } 699 700 emitRegionAttributes(op, os); 701 emitDefaultOpBuilder(op, os); 702 emitOperandAccessors(op, os); 703 emitAttributeAccessors(op, attributeClasses, os); 704 emitResultAccessors(op, os); 705 } 706 707 /// Emits bindings for the dialect specified in the command line, including file 708 /// headers and utilities. Returns `false` on success to comply with Tablegen 709 /// registration requirements. 
710 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 711 if (clDialectName.empty()) 712 llvm::PrintFatalError("dialect name not provided"); 713 714 AttributeClasses attributeClasses; 715 constructAttributeMapping(records, attributeClasses); 716 717 os << llvm::formatv(fileHeader, clDialectName.getValue()); 718 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 719 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 720 Operator op(rec); 721 if (op.getDialectName() == clDialectName.getValue()) 722 emitOpBindings(op, attributeClasses, os); 723 } 724 return false; 725 } 726 727 static GenRegistration 728 genPythonBindings("gen-python-op-bindings", 729 "Generate Python bindings for MLIR Ops", &emitAllOps); 730