//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
// The Python snippets in this file are `llvm::formatv` format strings: `{N}`
// is a positional placeholder and a literal `{` is written as `{{`. Their
// leading whitespace is significant because it becomes Python indentation in
// the generated module.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// File header and includes.
///   {0} is the dialect namespace.
/// The generated module tries to import an optional companion module
/// `_{0}_ops_ext`; if that import raises ImportError, `_ods_ext_module` is set
/// to None instead.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
/// Generated op classes register themselves against the `_Dialect` class
/// defined by this template.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (the operand/result may be absent)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Only used when the optional element is the sole variable-length group, so
/// the operation holds either {2} elements (element present) or {2} - 1
/// elements (element absent); hence the `>=` comparison.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{3}] if len(self.operation.{1}s) >= {2} else None
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// There is deliberately no trailing newline: one of the "second part"
/// templates is emitted right after it to complete the method body.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
/// `start` and `pg` (elements per group) are the locals bound by the first
/// part of the template.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// Group sizes are read at runtime from the `{1}_segment_sizes` attribute.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Assigning None raises ValueError because the attribute is mandatory. The
/// `@{0}.setter` decorator requires the matching getter to have been emitted
/// first.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
234 constexpr const char *unitAttributeSetterTemplate = R"Py( 235 @{0}.setter 236 def {0}(self, value): 237 if bool(value): 238 self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get() 239 elif "{1}" in self.operation.attributes: 240 del self.operation.attributes["{1}"] 241 )Py"; 242 243 /// Template for a deleter of an optional or a unit operation attribute, removes 244 /// the attribute from the operation: 245 /// {0} is the name of the attribute sanitized for Python; 246 /// {1} is the original name of the attribute. 247 constexpr const char *attributeDeleterTemplate = R"Py( 248 @{0}.deleter 249 def {0}(self): 250 del self.operation.attributes["{1}"] 251 )Py"; 252 253 static llvm::cl::OptionCategory 254 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 255 256 static llvm::cl::opt<std::string> 257 clDialectName("bind-dialect", 258 llvm::cl::desc("The dialect to run the generator for"), 259 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 260 261 using AttributeClasses = DenseMap<StringRef, StringRef>; 262 263 /// Checks whether `str` is a Python keyword. 264 static bool isPythonKeyword(StringRef str) { 265 static llvm::StringSet<> keywords( 266 {"and", "as", "assert", "break", "class", "continue", 267 "def", "del", "elif", "else", "except", "finally", 268 "for", "from", "global", "if", "import", "in", 269 "is", "lambda", "nonlocal", "not", "or", "pass", 270 "raise", "return", "try", "while", "with", "yield"}); 271 return keywords.contains(str); 272 } 273 274 /// Checks whether `str` would shadow a generated variable or attribute 275 /// part of the OpView API. 
276 static bool isODSReserved(StringRef str) { 277 static llvm::StringSet<> reserved( 278 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 279 "loc", "verify", "regions", "results", "self", "operation", 280 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 281 return str.startswith("_ods_") || str.endswith("_ods") || 282 reserved.contains(str); 283 } 284 285 /// Modifies the `name` in a way that it becomes suitable for Python bindings 286 /// (does not change the `name` if it already is suitable) and returns the 287 /// modified version. 288 static std::string sanitizeName(StringRef name) { 289 if (isPythonKeyword(name) || isODSReserved(name)) 290 return (name + "_").str(); 291 return name.str(); 292 } 293 294 static std::string attrSizedTraitForKind(const char *kind) { 295 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 296 llvm::StringRef(kind).take_front().upper(), 297 llvm::StringRef(kind).drop_front()); 298 } 299 300 /// Emits accessors to "elements" of an Op definition. Currently, the supported 301 /// elements are operands and results, indicated by `kind`, which must be either 302 /// `operand` or `result` and is used verbatim in the emitted code. 303 static void emitElementAccessors( 304 const Operator &op, raw_ostream &os, const char *kind, 305 llvm::function_ref<unsigned(const Operator &)> getNumVariadic, 306 llvm::function_ref<int(const Operator &)> getNumElements, 307 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 308 getElement) { 309 assert(llvm::is_contained( 310 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 311 "unsupported kind"); 312 313 // Traits indicating how to process variadic elements. 
314 std::string sameSizeTrait = 315 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 316 llvm::StringRef(kind).take_front().upper(), 317 llvm::StringRef(kind).drop_front()); 318 std::string attrSizedTrait = attrSizedTraitForKind(kind); 319 320 unsigned numVariadic = getNumVariadic(op); 321 322 // If there is only one variadic element group, its size can be inferred from 323 // the total number of elements. If there are none, the generation is 324 // straightforward. 325 if (numVariadic <= 1) { 326 bool seenVariableLength = false; 327 for (int i = 0, e = getNumElements(op); i < e; ++i) { 328 const NamedTypeConstraint &element = getElement(op, i); 329 if (element.isVariableLength()) 330 seenVariableLength = true; 331 if (element.name.empty()) 332 continue; 333 if (element.isVariableLength()) { 334 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 335 : opOneVariadicTemplate, 336 sanitizeName(element.name), kind, 337 getNumElements(op), i); 338 } else if (seenVariableLength) { 339 os << llvm::formatv(opSingleAfterVariableTemplate, 340 sanitizeName(element.name), kind, 341 getNumElements(op), i); 342 } else { 343 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 344 i); 345 } 346 } 347 return; 348 } 349 350 // Handle the operations where variadic groups have the same size. 351 if (op.getTrait(sameSizeTrait)) { 352 int numPrecedingSimple = 0; 353 int numPrecedingVariadic = 0; 354 for (int i = 0, e = getNumElements(op); i < e; ++i) { 355 const NamedTypeConstraint &element = getElement(op, i); 356 if (!element.name.empty()) { 357 os << llvm::formatv(opVariadicEqualPrefixTemplate, 358 sanitizeName(element.name), kind, numVariadic, 359 numPrecedingSimple, numPrecedingVariadic); 360 os << llvm::formatv(element.isVariableLength() 361 ? 
opVariadicEqualVariadicTemplate 362 : opVariadicEqualSimpleTemplate, 363 kind); 364 } 365 if (element.isVariableLength()) 366 ++numPrecedingVariadic; 367 else 368 ++numPrecedingSimple; 369 } 370 return; 371 } 372 373 // Handle the operations where the size of groups (variadic or not) is 374 // provided as an attribute. For non-variadic elements, make sure to return 375 // an element rather than a singleton container. 376 if (op.getTrait(attrSizedTrait)) { 377 for (int i = 0, e = getNumElements(op); i < e; ++i) { 378 const NamedTypeConstraint &element = getElement(op, i); 379 if (element.name.empty()) 380 continue; 381 std::string trailing; 382 if (!element.isVariableLength()) 383 trailing = "[0]"; 384 else if (element.isOptional()) 385 trailing = std::string( 386 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 387 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 388 kind, i, trailing); 389 } 390 return; 391 } 392 393 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 394 } 395 396 /// Free function helpers accessing Operator components. 397 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 398 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 399 return op.getOperand(i); 400 } 401 static int getNumResults(const Operator &op) { return op.getNumResults(); } 402 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 403 return op.getResult(i); 404 } 405 406 /// Emits accessors to Op operands. 407 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 408 auto getNumVariadic = [](const Operator &oper) { 409 return oper.getNumVariableLengthOperands(); 410 }; 411 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, 412 getOperand); 413 } 414 415 /// Emits accessors Op results. 
416 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 417 auto getNumVariadic = [](const Operator &oper) { 418 return oper.getNumVariableLengthResults(); 419 }; 420 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, 421 getResult); 422 } 423 424 /// Emits accessors to Op attributes. 425 static void emitAttributeAccessors(const Operator &op, 426 const AttributeClasses &attributeClasses, 427 raw_ostream &os) { 428 for (const auto &namedAttr : op.getAttributes()) { 429 // Skip "derived" attributes because they are just C++ functions that we 430 // don't currently expose. 431 if (namedAttr.attr.isDerivedAttr()) 432 continue; 433 434 if (namedAttr.name.empty()) 435 continue; 436 437 std::string sanitizedName = sanitizeName(namedAttr.name); 438 439 // Unit attributes are handled specially. 440 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 441 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 442 namedAttr.name); 443 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 444 namedAttr.name); 445 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 446 namedAttr.name); 447 continue; 448 } 449 450 // Other kinds of attributes need a mapping to a Python type. 
451 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 452 continue; 453 454 StringRef pythonType = 455 attributeClasses.lookup(namedAttr.attr.getStorageType()); 456 if (namedAttr.attr.isOptional()) { 457 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 458 pythonType, namedAttr.name); 459 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 460 namedAttr.name); 461 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 462 namedAttr.name); 463 } else { 464 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 465 namedAttr.name); 466 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 467 namedAttr.name); 468 // Non-optional attributes cannot be deleted. 469 } 470 } 471 } 472 473 /// Template for the default auto-generated builder. 474 /// {0} is a comma-separated list of builder arguments, including the trailing 475 /// `loc` and `ip`; 476 /// {1} is the code populating `operands`, `results` and `attributes` fields. 477 constexpr const char *initTemplate = R"Py( 478 def __init__(self, {0}): 479 operands = [] 480 results = [] 481 attributes = {{} 482 {1} 483 super().__init__(self.build_generic( 484 attributes=attributes, results=results, operands=operands, 485 loc=loc, ip=ip)) 486 )Py"; 487 488 /// Template for appending a single element to the operand/result list. 489 /// {0} is either 'operand' or 'result'; 490 /// {1} is the field name. 491 constexpr const char *singleElementAppendTemplate = "{0}s.append({1})"; 492 493 /// Template for appending an optional element to the operand/result list. 494 /// {0} is either 'operand' or 'result'; 495 /// {1} is the field name. 496 constexpr const char *optionalAppendTemplate = 497 "if {1} is not None: {0}s.append({1})"; 498 499 /// Template for appending a a list of elements to the operand/result list. 500 /// {0} is either 'operand' or 'result'; 501 /// {1} is the field name. 
502 constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})"; 503 504 /// Template for setting an attribute in the operation builder. 505 /// {0} is the attribute name; 506 /// {1} is the builder argument name. 507 constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py"; 508 509 /// Template for setting an optional attribute in the operation builder. 510 /// {0} is the attribute name; 511 /// {1} is the builder argument name. 512 constexpr const char *initOptionalAttributeTemplate = 513 R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; 514 515 constexpr const char *initUnitAttributeTemplate = 516 R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( 517 _ods_get_default_loc_context(loc)))Py"; 518 519 /// Populates `builderArgs` with the Python-compatible names of builder function 520 /// arguments, first the results, then the intermixed attributes and operands in 521 /// the same order as they appear in the `arguments` field of the op definition. 522 /// Additionally, `operandNames` is populated with names of operands in their 523 /// order of appearance. 524 static void 525 populateBuilderArgs(const Operator &op, 526 llvm::SmallVectorImpl<std::string> &builderArgs, 527 llvm::SmallVectorImpl<std::string> &operandNames) { 528 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 529 std::string name = op.getResultName(i).str(); 530 if (name.empty()) { 531 if (op.getNumResults() == 1) { 532 // Special case for one result, make the default name be 'result' 533 // to properly match the built-in result accessor. 
534 name = "result"; 535 } else { 536 name = llvm::formatv("_gen_res_{0}", i); 537 } 538 } 539 name = sanitizeName(name); 540 builderArgs.push_back(name); 541 } 542 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 543 std::string name = op.getArgName(i).str(); 544 if (name.empty()) 545 name = llvm::formatv("_gen_arg_{0}", i); 546 name = sanitizeName(name); 547 builderArgs.push_back(name); 548 if (!op.getArg(i).is<NamedAttribute *>()) 549 operandNames.push_back(name); 550 } 551 } 552 553 /// Populates `builderLines` with additional lines that are required in the 554 /// builder to set up operation attributes. `argNames` is expected to contain 555 /// the names of builder arguments that correspond to op arguments, i.e. to the 556 /// operands and attributes in the same order as they appear in the `arguments` 557 /// field. 558 static void 559 populateBuilderLinesAttr(const Operator &op, 560 llvm::ArrayRef<std::string> argNames, 561 llvm::SmallVectorImpl<std::string> &builderLines) { 562 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 563 Argument arg = op.getArg(i); 564 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 565 if (!attribute) 566 continue; 567 568 // Unit attributes are handled specially. 569 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 570 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 571 attribute->name, argNames[i])); 572 continue; 573 } 574 575 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 576 ? initOptionalAttributeTemplate 577 : initAttributeTemplate, 578 attribute->name, argNames[i])); 579 } 580 } 581 582 /// Populates `builderLines` with additional lines that are required in the 583 /// builder. `kind` must be either "operand" or "result". `names` contains the 584 /// names of init arguments that correspond to the elements. 
585 static void populateBuilderLines( 586 const Operator &op, const char *kind, llvm::ArrayRef<std::string> names, 587 llvm::SmallVectorImpl<std::string> &builderLines, 588 llvm::function_ref<int(const Operator &)> getNumElements, 589 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 590 getElement) { 591 bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; 592 593 // For each element, find or generate a name. 594 for (int i = 0, e = getNumElements(op); i < e; ++i) { 595 const NamedTypeConstraint &element = getElement(op, i); 596 std::string name = names[i]; 597 598 // Choose the formatting string based on the element kind. 599 llvm::StringRef formatString; 600 if (!element.isVariableLength()) { 601 formatString = singleElementAppendTemplate; 602 } else if (element.isOptional()) { 603 formatString = optionalAppendTemplate; 604 } else { 605 assert(element.isVariadic() && "unhandled element group type"); 606 // If emitting with sizedSegments, then we add the actual list typed 607 // element using the singleElementAppendTemplate. Otherwise, we extend 608 // the actual operands. 609 if (sizedSegments) { 610 // Append the list as is. 611 formatString = singleElementAppendTemplate; 612 } else { 613 // Append the list elements. 614 formatString = multiElementAppendTemplate; 615 } 616 } 617 618 // Add the lines. 619 builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); 620 } 621 } 622 623 /// Emits a default builder constructing an operation from the list of its 624 /// result types, followed by a list of its operands. 625 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 626 // If we are asked to skip default builders, comply. 
627 if (op.skipDefaultBuilders()) 628 return; 629 630 llvm::SmallVector<std::string, 8> builderArgs; 631 llvm::SmallVector<std::string, 8> builderLines; 632 llvm::SmallVector<std::string, 4> operandArgNames; 633 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 634 op.getNumNativeAttributes()); 635 populateBuilderArgs(op, builderArgs, operandArgNames); 636 populateBuilderLines( 637 op, "result", 638 llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), 639 builderLines, getNumResults, getResult); 640 populateBuilderLines(op, "operand", operandArgNames, builderLines, 641 getNumOperands, getOperand); 642 populateBuilderLinesAttr( 643 op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), 644 builderLines); 645 646 builderArgs.push_back("*"); 647 builderArgs.push_back("loc=None"); 648 builderArgs.push_back("ip=None"); 649 os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), 650 llvm::join(builderLines, "\n ")); 651 } 652 653 static void constructAttributeMapping(const llvm::RecordKeeper &records, 654 AttributeClasses &attributeClasses) { 655 for (const llvm::Record *rec : 656 records.getAllDerivedDefinitions("PythonAttr")) { 657 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 658 rec->getValueAsString("pythonType").trim()); 659 } 660 } 661 662 static void emitSegmentSpec( 663 const Operator &op, const char *kind, 664 llvm::function_ref<int(const Operator &)> getNumElements, 665 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 666 getElement, 667 raw_ostream &os) { 668 std::string segmentSpec("["); 669 for (int i = 0, e = getNumElements(op); i < e; ++i) { 670 const NamedTypeConstraint &element = getElement(op, i); 671 if (element.isVariableLength()) { 672 segmentSpec.append("-1,"); 673 } else if (element.isOptional()) { 674 segmentSpec.append("0,"); 675 } else { 676 segmentSpec.append("1,"); 677 } 678 } 679 segmentSpec.append("]"); 680 681 os << 
llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); 682 } 683 684 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 685 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 686 // Note that the base OpView class defines this as (0, True). 687 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 688 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, 689 op.hasNoVariadicRegions() ? "True" : "False"); 690 } 691 692 /// Emits bindings for a specific Op to the given output stream. 693 static void emitOpBindings(const Operator &op, 694 const AttributeClasses &attributeClasses, 695 raw_ostream &os) { 696 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 697 op.getOperationName()); 698 699 // Sized segments. 700 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 701 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 702 } 703 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 704 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 705 } 706 707 emitRegionAttributes(op, os); 708 emitDefaultOpBuilder(op, os); 709 emitOperandAccessors(op, os); 710 emitAttributeAccessors(op, attributeClasses, os); 711 emitResultAccessors(op, os); 712 } 713 714 /// Emits bindings for the dialect specified in the command line, including file 715 /// headers and utilities. Returns `false` on success to comply with Tablegen 716 /// registration requirements. 
717 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 718 if (clDialectName.empty()) 719 llvm::PrintFatalError("dialect name not provided"); 720 721 AttributeClasses attributeClasses; 722 constructAttributeMapping(records, attributeClasses); 723 724 os << llvm::formatv(fileHeader, clDialectName.getValue()); 725 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 726 727 if (clDialectName == "builtin") 728 clDialectName = ""; 729 730 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 731 Operator op(rec); 732 if (op.getDialectName() == clDialectName.getValue()) 733 emitOpBindings(op, attributeClasses, os); 734 } 735 return false; 736 } 737 738 static GenRegistration 739 genPythonBindings("gen-python-op-bindings", 740 "Generate Python bindings for MLIR Ops", &emitAllOps); 741