//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// File header and imports, emitted once at the top of the generated Python
/// module. The generated code tries to import an optional extension module
/// named after the dialect and falls back to `None` if it is absent.
///   {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for dialect class, emitted once per generated module right after
/// the file header:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for operation class, emitted once per op before any of its
/// class-level declarations, builder and accessors:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor, used for elements that are not
/// preceded by any variable-length group so that their position is fixed:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This template is only used when the optional element is the sole
/// variable-length group, so the element list holds exactly {2} entries when
/// the optional element is present and {2} - 1 entries when it is absent.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The template deliberately has no trailing newline: it is always followed
/// by one of the "second part" templates, which start with a newline.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter, which returns None
/// when the attribute is not present on the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter, which rejects None because a
/// mandatory attribute cannot be removed:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
236 constexpr const char *unitAttributeSetterTemplate = R"Py( 237 @{0}.setter 238 def {0}(self, value): 239 if bool(value): 240 self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get() 241 elif "{1}" in self.operation.attributes: 242 del self.operation.attributes["{1}"] 243 )Py"; 244 245 /// Template for a deleter of an optional or a unit operation attribute, removes 246 /// the attribute from the operation: 247 /// {0} is the name of the attribute sanitized for Python; 248 /// {1} is the original name of the attribute. 249 constexpr const char *attributeDeleterTemplate = R"Py( 250 @{0}.deleter 251 def {0}(self): 252 del self.operation.attributes["{1}"] 253 )Py"; 254 255 static llvm::cl::OptionCategory 256 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 257 258 static llvm::cl::opt<std::string> 259 clDialectName("bind-dialect", 260 llvm::cl::desc("The dialect to run the generator for"), 261 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 262 263 using AttributeClasses = DenseMap<StringRef, StringRef>; 264 265 /// Checks whether `str` is a Python keyword. 266 static bool isPythonKeyword(StringRef str) { 267 static llvm::StringSet<> keywords( 268 {"and", "as", "assert", "break", "class", "continue", 269 "def", "del", "elif", "else", "except", "finally", 270 "for", "from", "global", "if", "import", "in", 271 "is", "lambda", "nonlocal", "not", "or", "pass", 272 "raise", "return", "try", "while", "with", "yield"}); 273 return keywords.contains(str); 274 } 275 276 /// Checks whether `str` would shadow a generated variable or attribute 277 /// part of the OpView API. 
278 static bool isODSReserved(StringRef str) { 279 static llvm::StringSet<> reserved( 280 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 281 "loc", "verify", "regions", "results", "self", "operation", 282 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 283 return str.startswith("_ods_") || str.endswith("_ods") || 284 reserved.contains(str); 285 } 286 287 /// Modifies the `name` in a way that it becomes suitable for Python bindings 288 /// (does not change the `name` if it already is suitable) and returns the 289 /// modified version. 290 static std::string sanitizeName(StringRef name) { 291 if (isPythonKeyword(name) || isODSReserved(name)) 292 return (name + "_").str(); 293 return name.str(); 294 } 295 296 static std::string attrSizedTraitForKind(const char *kind) { 297 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 298 llvm::StringRef(kind).take_front().upper(), 299 llvm::StringRef(kind).drop_front()); 300 } 301 302 /// Emits accessors to "elements" of an Op definition. Currently, the supported 303 /// elements are operands and results, indicated by `kind`, which must be either 304 /// `operand` or `result` and is used verbatim in the emitted code. 305 static void emitElementAccessors( 306 const Operator &op, raw_ostream &os, const char *kind, 307 llvm::function_ref<unsigned(const Operator &)> getNumVariadic, 308 llvm::function_ref<int(const Operator &)> getNumElements, 309 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 310 getElement) { 311 assert(llvm::is_contained( 312 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 313 "unsupported kind"); 314 315 // Traits indicating how to process variadic elements. 
316 std::string sameSizeTrait = 317 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 318 llvm::StringRef(kind).take_front().upper(), 319 llvm::StringRef(kind).drop_front()); 320 std::string attrSizedTrait = attrSizedTraitForKind(kind); 321 322 unsigned numVariadic = getNumVariadic(op); 323 324 // If there is only one variadic element group, its size can be inferred from 325 // the total number of elements. If there are none, the generation is 326 // straightforward. 327 if (numVariadic <= 1) { 328 bool seenVariableLength = false; 329 for (int i = 0, e = getNumElements(op); i < e; ++i) { 330 const NamedTypeConstraint &element = getElement(op, i); 331 if (element.isVariableLength()) 332 seenVariableLength = true; 333 if (element.name.empty()) 334 continue; 335 if (element.isVariableLength()) { 336 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 337 : opOneVariadicTemplate, 338 sanitizeName(element.name), kind, 339 getNumElements(op), i); 340 } else if (seenVariableLength) { 341 os << llvm::formatv(opSingleAfterVariableTemplate, 342 sanitizeName(element.name), kind, 343 getNumElements(op), i); 344 } else { 345 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 346 i); 347 } 348 } 349 return; 350 } 351 352 // Handle the operations where variadic groups have the same size. 353 if (op.getTrait(sameSizeTrait)) { 354 int numPrecedingSimple = 0; 355 int numPrecedingVariadic = 0; 356 for (int i = 0, e = getNumElements(op); i < e; ++i) { 357 const NamedTypeConstraint &element = getElement(op, i); 358 if (!element.name.empty()) { 359 os << llvm::formatv(opVariadicEqualPrefixTemplate, 360 sanitizeName(element.name), kind, numVariadic, 361 numPrecedingSimple, numPrecedingVariadic); 362 os << llvm::formatv(element.isVariableLength() 363 ? 
opVariadicEqualVariadicTemplate 364 : opVariadicEqualSimpleTemplate, 365 kind); 366 } 367 if (element.isVariableLength()) 368 ++numPrecedingVariadic; 369 else 370 ++numPrecedingSimple; 371 } 372 return; 373 } 374 375 // Handle the operations where the size of groups (variadic or not) is 376 // provided as an attribute. For non-variadic elements, make sure to return 377 // an element rather than a singleton container. 378 if (op.getTrait(attrSizedTrait)) { 379 for (int i = 0, e = getNumElements(op); i < e; ++i) { 380 const NamedTypeConstraint &element = getElement(op, i); 381 if (element.name.empty()) 382 continue; 383 std::string trailing; 384 if (!element.isVariableLength()) 385 trailing = "[0]"; 386 else if (element.isOptional()) 387 trailing = std::string( 388 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 389 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 390 kind, i, trailing); 391 } 392 return; 393 } 394 395 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 396 } 397 398 /// Free function helpers accessing Operator components. 399 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 400 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 401 return op.getOperand(i); 402 } 403 static int getNumResults(const Operator &op) { return op.getNumResults(); } 404 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 405 return op.getResult(i); 406 } 407 408 /// Emits accessors to Op operands. 409 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 410 auto getNumVariadic = [](const Operator &oper) { 411 return oper.getNumVariableLengthOperands(); 412 }; 413 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, 414 getOperand); 415 } 416 417 /// Emits accessors Op results. 
418 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 419 auto getNumVariadic = [](const Operator &oper) { 420 return oper.getNumVariableLengthResults(); 421 }; 422 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, 423 getResult); 424 } 425 426 /// Emits accessors to Op attributes. 427 static void emitAttributeAccessors(const Operator &op, 428 const AttributeClasses &attributeClasses, 429 raw_ostream &os) { 430 for (const auto &namedAttr : op.getAttributes()) { 431 // Skip "derived" attributes because they are just C++ functions that we 432 // don't currently expose. 433 if (namedAttr.attr.isDerivedAttr()) 434 continue; 435 436 if (namedAttr.name.empty()) 437 continue; 438 439 std::string sanitizedName = sanitizeName(namedAttr.name); 440 441 // Unit attributes are handled specially. 442 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 443 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 444 namedAttr.name); 445 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 446 namedAttr.name); 447 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 448 namedAttr.name); 449 continue; 450 } 451 452 // Other kinds of attributes need a mapping to a Python type. 
453 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 454 continue; 455 456 StringRef pythonType = 457 attributeClasses.lookup(namedAttr.attr.getStorageType()); 458 if (namedAttr.attr.isOptional()) { 459 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 460 pythonType, namedAttr.name); 461 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 462 namedAttr.name); 463 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 464 namedAttr.name); 465 } else { 466 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 467 namedAttr.name); 468 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 469 namedAttr.name); 470 // Non-optional attributes cannot be deleted. 471 } 472 } 473 } 474 475 /// Template for the default auto-generated builder. 476 /// {0} is a comma-separated list of builder arguments, including the trailing 477 /// `loc` and `ip`; 478 /// {1} is the code populating `operands`, `results` and `attributes`, 479 /// `successors` fields. 480 constexpr const char *initTemplate = R"Py( 481 def __init__(self, {0}): 482 operands = [] 483 results = [] 484 attributes = {{} 485 {1} 486 super().__init__(self.build_generic( 487 attributes=attributes, results=results, operands=operands, 488 successors=_ods_successors, loc=loc, ip=ip)) 489 )Py"; 490 491 /// Template for appending a single element to the operand/result list. 492 /// {0} is either 'operand' or 'result'; 493 /// {1} is the field name. 494 constexpr const char *singleElementAppendTemplate = "{0}s.append({1})"; 495 496 /// Template for appending an optional element to the operand/result list. 497 /// {0} is either 'operand' or 'result'; 498 /// {1} is the field name. 499 constexpr const char *optionalAppendTemplate = 500 "if {1} is not None: {0}s.append({1})"; 501 502 /// Template for appending a a list of elements to the operand/result list. 503 /// {0} is either 'operand' or 'result'; 504 /// {1} is the field name. 
constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder. The
/// attribute is only added when the builder argument is truthy.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";

/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments, first the results, then the intermixed attributes and operands in
/// the same order as they appear in the `arguments` field of the op definition,
/// and finally the successors. Additionally, `operandNames` is populated with
/// names of operands in their order of appearance, and `successorArgNames`
/// with the names of the successor arguments.
537 static void 538 populateBuilderArgs(const Operator &op, 539 llvm::SmallVectorImpl<std::string> &builderArgs, 540 llvm::SmallVectorImpl<std::string> &operandNames, 541 llvm::SmallVectorImpl<std::string> &successorArgNames) { 542 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 543 std::string name = op.getResultName(i).str(); 544 if (name.empty()) { 545 if (op.getNumResults() == 1) { 546 // Special case for one result, make the default name be 'result' 547 // to properly match the built-in result accessor. 548 name = "result"; 549 } else { 550 name = llvm::formatv("_gen_res_{0}", i); 551 } 552 } 553 name = sanitizeName(name); 554 builderArgs.push_back(name); 555 } 556 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 557 std::string name = op.getArgName(i).str(); 558 if (name.empty()) 559 name = llvm::formatv("_gen_arg_{0}", i); 560 name = sanitizeName(name); 561 builderArgs.push_back(name); 562 if (!op.getArg(i).is<NamedAttribute *>()) 563 operandNames.push_back(name); 564 } 565 566 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { 567 NamedSuccessor successor = op.getSuccessor(i); 568 std::string name = std::string(successor.name); 569 if (name.empty()) 570 name = llvm::formatv("_gen_successor_{0}", i); 571 name = sanitizeName(name); 572 builderArgs.push_back(name); 573 successorArgNames.push_back(name); 574 } 575 } 576 577 /// Populates `builderLines` with additional lines that are required in the 578 /// builder to set up operation attributes. `argNames` is expected to contain 579 /// the names of builder arguments that correspond to op arguments, i.e. to the 580 /// operands and attributes in the same order as they appear in the `arguments` 581 /// field. 
582 static void 583 populateBuilderLinesAttr(const Operator &op, 584 llvm::ArrayRef<std::string> argNames, 585 llvm::SmallVectorImpl<std::string> &builderLines) { 586 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 587 Argument arg = op.getArg(i); 588 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 589 if (!attribute) 590 continue; 591 592 // Unit attributes are handled specially. 593 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 594 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 595 attribute->name, argNames[i])); 596 continue; 597 } 598 599 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 600 ? initOptionalAttributeTemplate 601 : initAttributeTemplate, 602 attribute->name, argNames[i])); 603 } 604 } 605 606 /// Populates `builderLines` with additional lines that are required in the 607 /// builder to set up successors. successorArgNames is expected to correspond 608 /// to the Python argument name for each successor on the op. 609 static void populateBuilderLinesSuccessors( 610 const Operator &op, llvm::ArrayRef<std::string> successorArgNames, 611 llvm::SmallVectorImpl<std::string> &builderLines) { 612 if (successorArgNames.empty()) { 613 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); 614 return; 615 } 616 617 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); 618 for (int i = 0, e = successorArgNames.size(); i < e; ++i) { 619 auto &argName = successorArgNames[i]; 620 const NamedSuccessor &successor = op.getSuccessor(i); 621 builderLines.push_back( 622 llvm::formatv(addSuccessorTemplate, 623 successor.isVariadic() ? "extend" : "append", argName)); 624 } 625 } 626 627 /// Populates `builderLines` with additional lines that are required in the 628 /// builder. `kind` must be either "operand" or "result". `names` contains the 629 /// names of init arguments that correspond to the elements. 
630 static void populateBuilderLines( 631 const Operator &op, const char *kind, llvm::ArrayRef<std::string> names, 632 llvm::SmallVectorImpl<std::string> &builderLines, 633 llvm::function_ref<int(const Operator &)> getNumElements, 634 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 635 getElement) { 636 bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; 637 638 // For each element, find or generate a name. 639 for (int i = 0, e = getNumElements(op); i < e; ++i) { 640 const NamedTypeConstraint &element = getElement(op, i); 641 std::string name = names[i]; 642 643 // Choose the formatting string based on the element kind. 644 llvm::StringRef formatString; 645 if (!element.isVariableLength()) { 646 formatString = singleElementAppendTemplate; 647 } else if (element.isOptional()) { 648 formatString = optionalAppendTemplate; 649 } else { 650 assert(element.isVariadic() && "unhandled element group type"); 651 // If emitting with sizedSegments, then we add the actual list typed 652 // element using the singleElementAppendTemplate. Otherwise, we extend 653 // the actual operands. 654 if (sizedSegments) { 655 // Append the list as is. 656 formatString = singleElementAppendTemplate; 657 } else { 658 // Append the list elements. 659 formatString = multiElementAppendTemplate; 660 } 661 } 662 663 // Add the lines. 664 builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); 665 } 666 } 667 668 /// Emits a default builder constructing an operation from the list of its 669 /// result types, followed by a list of its operands. 670 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 671 // If we are asked to skip default builders, comply. 
672 if (op.skipDefaultBuilders()) 673 return; 674 675 llvm::SmallVector<std::string> builderArgs; 676 llvm::SmallVector<std::string> builderLines; 677 llvm::SmallVector<std::string> operandArgNames; 678 llvm::SmallVector<std::string> successorArgNames; 679 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 680 op.getNumNativeAttributes() + op.getNumSuccessors()); 681 populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); 682 683 populateBuilderLines( 684 op, "result", 685 llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), 686 builderLines, getNumResults, getResult); 687 populateBuilderLines(op, "operand", operandArgNames, builderLines, 688 getNumOperands, getOperand); 689 populateBuilderLinesAttr( 690 op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), 691 builderLines); 692 populateBuilderLinesSuccessors(op, successorArgNames, builderLines); 693 694 builderArgs.push_back("*"); 695 builderArgs.push_back("loc=None"); 696 builderArgs.push_back("ip=None"); 697 os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), 698 llvm::join(builderLines, "\n ")); 699 } 700 701 static void constructAttributeMapping(const llvm::RecordKeeper &records, 702 AttributeClasses &attributeClasses) { 703 for (const llvm::Record *rec : 704 records.getAllDerivedDefinitions("PythonAttr")) { 705 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 706 rec->getValueAsString("pythonType").trim()); 707 } 708 } 709 710 static void emitSegmentSpec( 711 const Operator &op, const char *kind, 712 llvm::function_ref<int(const Operator &)> getNumElements, 713 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 714 getElement, 715 raw_ostream &os) { 716 std::string segmentSpec("["); 717 for (int i = 0, e = getNumElements(op); i < e; ++i) { 718 const NamedTypeConstraint &element = getElement(op, i); 719 if (element.isVariableLength()) { 720 segmentSpec.append("-1,"); 721 } else if 
(element.isOptional()) { 722 segmentSpec.append("0,"); 723 } else { 724 segmentSpec.append("1,"); 725 } 726 } 727 segmentSpec.append("]"); 728 729 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); 730 } 731 732 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 733 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 734 // Note that the base OpView class defines this as (0, True). 735 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 736 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, 737 op.hasNoVariadicRegions() ? "True" : "False"); 738 } 739 740 /// Emits bindings for a specific Op to the given output stream. 741 static void emitOpBindings(const Operator &op, 742 const AttributeClasses &attributeClasses, 743 raw_ostream &os) { 744 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 745 op.getOperationName()); 746 747 // Sized segments. 748 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 749 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 750 } 751 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 752 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 753 } 754 755 emitRegionAttributes(op, os); 756 emitDefaultOpBuilder(op, os); 757 emitOperandAccessors(op, os); 758 emitAttributeAccessors(op, attributeClasses, os); 759 emitResultAccessors(op, os); 760 } 761 762 /// Emits bindings for the dialect specified in the command line, including file 763 /// headers and utilities. Returns `false` on success to comply with Tablegen 764 /// registration requirements. 
765 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 766 if (clDialectName.empty()) 767 llvm::PrintFatalError("dialect name not provided"); 768 769 AttributeClasses attributeClasses; 770 constructAttributeMapping(records, attributeClasses); 771 772 os << llvm::formatv(fileHeader, clDialectName.getValue()); 773 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 774 775 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 776 Operator op(rec); 777 if (op.getDialectName() == clDialectName.getValue()) 778 emitOpBindings(op, attributeClasses, os); 779 } 780 return false; 781 } 782 783 static GenRegistration 784 genPythonBindings("gen-python-op-bindings", 785 "Generate Python bindings for MLIR Ops", &emitAllOps); 786