//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// File header and includes.
/// {0} is the dialect namespace.
///
/// The header imports the shared helpers from `_ods_common` under
/// `_ods_`-prefixed aliases (plus `_get_op_result(s)_or_value(s)`), binds the
/// IR module as `_ods_ir`, and tries to import an optional hand-written
/// extension module `_{0}_ops_ext`; when that import fails, the generated code
/// sees `_ods_ext_module = None`. `builtins` is imported so that accessors can
/// use `builtins.property` even if an op defines a member named `property`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for dialect class:
/// {0} is the dialect namespace.
// The generated dialect class is registered with the native extension through
// the `register_dialect` decorator; ops below refer to it as `_Dialect`.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for operation class:
/// {0} is the Python class name;
/// {1} is the operation name.
/// The class is registered against `_Dialect` and may be extended by the
/// optional `_ods_ext_module` imported in the file header.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
/// {0} is either "OPERAND" or "RESULT"
/// {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///  -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
/// {0} is the minimum number of regions
/// {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position in the element list.
/// Used when no variable-length group precedes the element, so its position
/// is static.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// Used when the optional element is the only variable-length group. It is
/// present exactly when the element count equals the number of groups ({2}).
/// Any smaller count means it was omitted, and the accessor returns None.
/// (The count can never exceed {2}, so a `> {2}` test would always yield None.)
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// The group length is whatever remains after the {2} - 1 single elements.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of variadic groups;
/// {3} is the number of non-variadic groups preceding the current group;
/// {4} is the number of variadic groups preceding the current group.
// The accessor is a method of the generated OpView subclass, so the operation
// must be reached through `self`; a bare `operation` is an unbound name in the
// generated Python and raises NameError on access.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
/// {0} is either 'operand' or 'result'.
/// `pg` is the per-group length computed by the prefix template.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position of the group in the group list;
/// {3} is a return suffix (expected [0] for single-element, empty for
/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
/// {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the Python type of the attribute;
/// {2} is the original name of the attribute.
// Mandatory attributes are assumed present, so the value is wrapped directly
// in the mapped Python attribute class.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the Python type of the attribute;
/// {2} is the original name of the attribute.
/// Returns None when the attribute is not set on the operation.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
/// {0} is the name of the attribute sanitized for Python,
/// {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// Assigning None raises ValueError because the attribute is mandatory.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
// Removing an absent attribute is a no-op thanks to the `in` check.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// Any truthy value stores a fresh `UnitAttr`.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// NOTE(review): unlike the setters, this does not check for presence first.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a named region accessor:
/// {0} is the region name sanitized for Python;
/// {1} is the region index, or "<index>:" (a slice) for a trailing variadic
/// region.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";

static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

// `-bind-dialect=<name>`: the dialect to generate bindings for. The empty
// default presumably must be overridden by the user; the check is not visible
// here - TODO confirm in the generator entry point.
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

// Maps a (trimmed) C++ attribute storage type to the Python class wrapping it.
using AttributeClasses = DenseMap<StringRef, StringRef>;

/// Checks whether `str` is a Python keyword.
273 static bool isPythonKeyword(StringRef str) { 274 static llvm::StringSet<> keywords( 275 {"and", "as", "assert", "break", "class", "continue", 276 "def", "del", "elif", "else", "except", "finally", 277 "for", "from", "global", "if", "import", "in", 278 "is", "lambda", "nonlocal", "not", "or", "pass", 279 "raise", "return", "try", "while", "with", "yield"}); 280 return keywords.contains(str); 281 } 282 283 /// Checks whether `str` would shadow a generated variable or attribute 284 /// part of the OpView API. 285 static bool isODSReserved(StringRef str) { 286 static llvm::StringSet<> reserved( 287 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 288 "loc", "verify", "regions", "results", "self", "operation", 289 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 290 return str.startswith("_ods_") || str.endswith("_ods") || 291 reserved.contains(str); 292 } 293 294 /// Modifies the `name` in a way that it becomes suitable for Python bindings 295 /// (does not change the `name` if it already is suitable) and returns the 296 /// modified version. 297 static std::string sanitizeName(StringRef name) { 298 if (isPythonKeyword(name) || isODSReserved(name)) 299 return (name + "_").str(); 300 return name.str(); 301 } 302 303 static std::string attrSizedTraitForKind(const char *kind) { 304 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 305 llvm::StringRef(kind).take_front().upper(), 306 llvm::StringRef(kind).drop_front()); 307 } 308 309 /// Emits accessors to "elements" of an Op definition. Currently, the supported 310 /// elements are operands and results, indicated by `kind`, which must be either 311 /// `operand` or `result` and is used verbatim in the emitted code. 
312 static void emitElementAccessors( 313 const Operator &op, raw_ostream &os, const char *kind, 314 llvm::function_ref<unsigned(const Operator &)> getNumVariadic, 315 llvm::function_ref<int(const Operator &)> getNumElements, 316 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 317 getElement) { 318 assert(llvm::is_contained( 319 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 320 "unsupported kind"); 321 322 // Traits indicating how to process variadic elements. 323 std::string sameSizeTrait = 324 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 325 llvm::StringRef(kind).take_front().upper(), 326 llvm::StringRef(kind).drop_front()); 327 std::string attrSizedTrait = attrSizedTraitForKind(kind); 328 329 unsigned numVariadic = getNumVariadic(op); 330 331 // If there is only one variadic element group, its size can be inferred from 332 // the total number of elements. If there are none, the generation is 333 // straightforward. 334 if (numVariadic <= 1) { 335 bool seenVariableLength = false; 336 for (int i = 0, e = getNumElements(op); i < e; ++i) { 337 const NamedTypeConstraint &element = getElement(op, i); 338 if (element.isVariableLength()) 339 seenVariableLength = true; 340 if (element.name.empty()) 341 continue; 342 if (element.isVariableLength()) { 343 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 344 : opOneVariadicTemplate, 345 sanitizeName(element.name), kind, 346 getNumElements(op), i); 347 } else if (seenVariableLength) { 348 os << llvm::formatv(opSingleAfterVariableTemplate, 349 sanitizeName(element.name), kind, 350 getNumElements(op), i); 351 } else { 352 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 353 i); 354 } 355 } 356 return; 357 } 358 359 // Handle the operations where variadic groups have the same size. 
360 if (op.getTrait(sameSizeTrait)) { 361 int numPrecedingSimple = 0; 362 int numPrecedingVariadic = 0; 363 for (int i = 0, e = getNumElements(op); i < e; ++i) { 364 const NamedTypeConstraint &element = getElement(op, i); 365 if (!element.name.empty()) { 366 os << llvm::formatv(opVariadicEqualPrefixTemplate, 367 sanitizeName(element.name), kind, numVariadic, 368 numPrecedingSimple, numPrecedingVariadic); 369 os << llvm::formatv(element.isVariableLength() 370 ? opVariadicEqualVariadicTemplate 371 : opVariadicEqualSimpleTemplate, 372 kind); 373 } 374 if (element.isVariableLength()) 375 ++numPrecedingVariadic; 376 else 377 ++numPrecedingSimple; 378 } 379 return; 380 } 381 382 // Handle the operations where the size of groups (variadic or not) is 383 // provided as an attribute. For non-variadic elements, make sure to return 384 // an element rather than a singleton container. 385 if (op.getTrait(attrSizedTrait)) { 386 for (int i = 0, e = getNumElements(op); i < e; ++i) { 387 const NamedTypeConstraint &element = getElement(op, i); 388 if (element.name.empty()) 389 continue; 390 std::string trailing; 391 if (!element.isVariableLength()) 392 trailing = "[0]"; 393 else if (element.isOptional()) 394 trailing = std::string( 395 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 396 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 397 kind, i, trailing); 398 } 399 return; 400 } 401 402 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 403 } 404 405 /// Free function helpers accessing Operator components. 
406 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 407 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 408 return op.getOperand(i); 409 } 410 static int getNumResults(const Operator &op) { return op.getNumResults(); } 411 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 412 return op.getResult(i); 413 } 414 415 /// Emits accessors to Op operands. 416 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 417 auto getNumVariadic = [](const Operator &oper) { 418 return oper.getNumVariableLengthOperands(); 419 }; 420 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, 421 getOperand); 422 } 423 424 /// Emits accessors Op results. 425 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 426 auto getNumVariadic = [](const Operator &oper) { 427 return oper.getNumVariableLengthResults(); 428 }; 429 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, 430 getResult); 431 } 432 433 /// Emits accessors to Op attributes. 434 static void emitAttributeAccessors(const Operator &op, 435 const AttributeClasses &attributeClasses, 436 raw_ostream &os) { 437 for (const auto &namedAttr : op.getAttributes()) { 438 // Skip "derived" attributes because they are just C++ functions that we 439 // don't currently expose. 440 if (namedAttr.attr.isDerivedAttr()) 441 continue; 442 443 if (namedAttr.name.empty()) 444 continue; 445 446 std::string sanitizedName = sanitizeName(namedAttr.name); 447 448 // Unit attributes are handled specially. 
449 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 450 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 451 namedAttr.name); 452 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 453 namedAttr.name); 454 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 455 namedAttr.name); 456 continue; 457 } 458 459 // Other kinds of attributes need a mapping to a Python type. 460 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 461 continue; 462 463 StringRef pythonType = 464 attributeClasses.lookup(namedAttr.attr.getStorageType()); 465 if (namedAttr.attr.isOptional()) { 466 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 467 pythonType, namedAttr.name); 468 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 469 namedAttr.name); 470 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 471 namedAttr.name); 472 } else { 473 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 474 namedAttr.name); 475 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 476 namedAttr.name); 477 // Non-optional attributes cannot be deleted. 478 } 479 } 480 } 481 482 /// Template for the default auto-generated builder. 483 /// {0} is a comma-separated list of builder arguments, including the trailing 484 /// `loc` and `ip`; 485 /// {1} is the code populating `operands`, `results` and `attributes`, 486 /// `successors` fields. 487 constexpr const char *initTemplate = R"Py( 488 def __init__(self, {0}): 489 operands = [] 490 results = [] 491 attributes = {{} 492 regions = None 493 {1} 494 super().__init__(self.build_generic( 495 attributes=attributes, results=results, operands=operands, 496 successors=_ods_successors, regions=regions, loc=loc, ip=ip)) 497 )Py"; 498 499 /// Template for appending a single element to the operand/result list. 500 /// {0} is the field name. 
// Operands go through `_get_op_result_or_value` so that an op or op view
// passed by the caller is converted to a value. Results are appended as-is.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
/// {0} is the field name.
/// The AttrSized variant always appends, using None for an absent operand.
/// It is selected for ops with attribute-sized operand segments, presumably so
/// that the positional segment layout is preserved - TODO confirm.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
/// {0} is the field name.
/// The "Pack" variant appends the whole list as a single entry instead of
/// flattening it (used with attribute-sized segments).
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for setting an attribute in the operation builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder when the
/// argument is truthy; the UnitAttr is created in the context derived from
/// `loc`.
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
/// {0} is the value to initialize the successors list to.
541 constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; 542 543 /// Template to append or extend the list of successors in the builder. 544 /// {0} is the list method ('append' or 'extend'); 545 /// {1} is the value to add. 546 constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; 547 548 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer 549 /// result types of the given operation. 550 static bool hasSameArgumentAndResultTypes(const Operator &op) { 551 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && 552 op.getNumVariableLengthResults() == 0; 553 } 554 555 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer 556 /// result types of the given operation. 557 static bool hasFirstAttrDerivedResultTypes(const Operator &op) { 558 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && 559 op.getNumVariableLengthResults() == 0; 560 } 561 562 /// Returns true if the InferTypeOpInterface can be used to infer result types 563 /// of the given operation. 564 static bool hasInferTypeInterface(const Operator &op) { 565 return op.getTrait("::mlir::InferTypeOpInterface::Trait") && 566 op.getNumRegions() == 0; 567 } 568 569 /// Returns true if there is a trait or interface that can be used to infer 570 /// result types of the given operation. 571 static bool canInferType(const Operator &op) { 572 return hasSameArgumentAndResultTypes(op) || 573 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); 574 } 575 576 /// Populates `builderArgs` with result names if the builder is expected to 577 /// accept them as arguments. 
578 static void 579 populateBuilderArgsResults(const Operator &op, 580 llvm::SmallVectorImpl<std::string> &builderArgs) { 581 if (canInferType(op)) 582 return; 583 584 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 585 std::string name = op.getResultName(i).str(); 586 if (name.empty()) { 587 if (op.getNumResults() == 1) { 588 // Special case for one result, make the default name be 'result' 589 // to properly match the built-in result accessor. 590 name = "result"; 591 } else { 592 name = llvm::formatv("_gen_res_{0}", i); 593 } 594 } 595 name = sanitizeName(name); 596 builderArgs.push_back(name); 597 } 598 } 599 600 /// Populates `builderArgs` with the Python-compatible names of builder function 601 /// arguments using intermixed attributes and operands in the same order as they 602 /// appear in the `arguments` field of the op definition. Additionally, 603 /// `operandNames` is populated with names of operands in their order of 604 /// appearance. 605 static void 606 populateBuilderArgs(const Operator &op, 607 llvm::SmallVectorImpl<std::string> &builderArgs, 608 llvm::SmallVectorImpl<std::string> &operandNames, 609 llvm::SmallVectorImpl<std::string> &successorArgNames) { 610 611 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 612 std::string name = op.getArgName(i).str(); 613 if (name.empty()) 614 name = llvm::formatv("_gen_arg_{0}", i); 615 name = sanitizeName(name); 616 builderArgs.push_back(name); 617 if (!op.getArg(i).is<NamedAttribute *>()) 618 operandNames.push_back(name); 619 } 620 621 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { 622 NamedSuccessor successor = op.getSuccessor(i); 623 std::string name = std::string(successor.name); 624 if (name.empty()) 625 name = llvm::formatv("_gen_successor_{0}", i); 626 name = sanitizeName(name); 627 builderArgs.push_back(name); 628 successorArgNames.push_back(name); 629 } 630 } 631 632 /// Populates `builderLines` with additional lines that are required in the 633 /// builder to set up operation 
attributes. `argNames` is expected to contain 634 /// the names of builder arguments that correspond to op arguments, i.e. to the 635 /// operands and attributes in the same order as they appear in the `arguments` 636 /// field. 637 static void 638 populateBuilderLinesAttr(const Operator &op, 639 llvm::ArrayRef<std::string> argNames, 640 llvm::SmallVectorImpl<std::string> &builderLines) { 641 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 642 Argument arg = op.getArg(i); 643 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 644 if (!attribute) 645 continue; 646 647 // Unit attributes are handled specially. 648 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 649 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 650 attribute->name, argNames[i])); 651 continue; 652 } 653 654 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 655 ? initOptionalAttributeTemplate 656 : initAttributeTemplate, 657 attribute->name, argNames[i])); 658 } 659 } 660 661 /// Populates `builderLines` with additional lines that are required in the 662 /// builder to set up successors. successorArgNames is expected to correspond 663 /// to the Python argument name for each successor on the op. 664 static void populateBuilderLinesSuccessors( 665 const Operator &op, llvm::ArrayRef<std::string> successorArgNames, 666 llvm::SmallVectorImpl<std::string> &builderLines) { 667 if (successorArgNames.empty()) { 668 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); 669 return; 670 } 671 672 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); 673 for (int i = 0, e = successorArgNames.size(); i < e; ++i) { 674 auto &argName = successorArgNames[i]; 675 const NamedSuccessor &successor = op.getSuccessor(i); 676 builderLines.push_back( 677 llvm::formatv(addSuccessorTemplate, 678 successor.isVariadic() ? 
"extend" : "append", argName)); 679 } 680 } 681 682 /// Populates `builderLines` with additional lines that are required in the 683 /// builder to set up op operands. 684 static void 685 populateBuilderLinesOperand(const Operator &op, 686 llvm::ArrayRef<std::string> names, 687 llvm::SmallVectorImpl<std::string> &builderLines) { 688 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; 689 690 // For each element, find or generate a name. 691 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 692 const NamedTypeConstraint &element = op.getOperand(i); 693 std::string name = names[i]; 694 695 // Choose the formatting string based on the element kind. 696 llvm::StringRef formatString; 697 if (!element.isVariableLength()) { 698 formatString = singleOperandAppendTemplate; 699 } else if (element.isOptional()) { 700 if (sizedSegments) { 701 formatString = optionalAppendAttrSizedOperandsTemplate; 702 } else { 703 formatString = optionalAppendOperandTemplate; 704 } 705 } else { 706 assert(element.isVariadic() && "unhandled element group type"); 707 // If emitting with sizedSegments, then we add the actual list-typed 708 // element. Otherwise, we extend the actual operands. 709 if (sizedSegments) { 710 formatString = multiOperandAppendPackTemplate; 711 } else { 712 formatString = multiOperandAppendTemplate; 713 } 714 } 715 716 builderLines.push_back(llvm::formatv(formatString.data(), name)); 717 } 718 } 719 720 /// Python code template for deriving the operation result types from its 721 /// attribute: 722 /// - {0} is the name of the attribute from which to derive the types. 
723 constexpr const char *deriveTypeFromAttrTemplate = 724 R"PY(_ods_result_type_source_attr = attributes["{0}"] 725 _ods_derived_result_type = ( 726 _ods_ir.TypeAttr(_ods_result_type_source_attr).value 727 if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else 728 _ods_result_type_source_attr.type))PY"; 729 730 /// Python code template appending {0} type {1} times to the results list. 731 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; 732 733 /// Python code template for inferring the operation results using the 734 /// corresponding interface: 735 /// - {0} is the name of the class for which the types are inferred. 736 constexpr const char *inferTypeInterfaceTemplate = 737 R"PY(_ods_context = _ods_get_default_loc_context(loc) 738 results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( 739 operands=operands, 740 attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), 741 context=_ods_context, 742 loc=loc) 743 )PY"; 744 745 /// Appends the given multiline string as individual strings into 746 /// `builderLines`. 747 static void appendLineByLine(StringRef string, 748 llvm::SmallVectorImpl<std::string> &builderLines) { 749 750 std::pair<StringRef, StringRef> split = std::make_pair(string, string); 751 do { 752 split = split.second.split('\n'); 753 builderLines.push_back(split.first.str()); 754 } while (!split.second.empty()); 755 } 756 757 /// Populates `builderLines` with additional lines that are required in the 758 /// builder to set up op results. 
759 static void 760 populateBuilderLinesResult(const Operator &op, 761 llvm::ArrayRef<std::string> names, 762 llvm::SmallVectorImpl<std::string> &builderLines) { 763 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; 764 765 if (hasSameArgumentAndResultTypes(op)) { 766 builderLines.push_back(llvm::formatv( 767 appendSameResultsTemplate, "operands[0].type", op.getNumResults())); 768 return; 769 } 770 771 if (hasFirstAttrDerivedResultTypes(op)) { 772 const NamedAttribute &firstAttr = op.getAttribute(0); 773 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " 774 "from which the type is derived"); 775 appendLineByLine( 776 llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), 777 builderLines); 778 builderLines.push_back(llvm::formatv(appendSameResultsTemplate, 779 "_ods_derived_result_type", 780 op.getNumResults())); 781 return; 782 } 783 784 if (hasInferTypeInterface(op)) { 785 appendLineByLine( 786 llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), 787 builderLines); 788 return; 789 } 790 791 // For each element, find or generate a name. 792 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 793 const NamedTypeConstraint &element = op.getResult(i); 794 std::string name = names[i]; 795 796 // Choose the formatting string based on the element kind. 797 llvm::StringRef formatString; 798 if (!element.isVariableLength()) { 799 formatString = singleResultAppendTemplate; 800 } else if (element.isOptional()) { 801 formatString = optionalAppendResultTemplate; 802 } else { 803 assert(element.isVariadic() && "unhandled element group type"); 804 // If emitting with sizedSegments, then we add the actual list-typed 805 // element. Otherwise, we extend the actual operands. 
806 if (sizedSegments) { 807 formatString = singleResultAppendTemplate; 808 } else { 809 formatString = multiResultAppendTemplate; 810 } 811 } 812 813 builderLines.push_back(llvm::formatv(formatString.data(), name)); 814 } 815 } 816 817 /// If the operation has variadic regions, adds a builder argument to specify 818 /// the number of those regions and builder lines to forward it to the generic 819 /// constructor. 820 static void 821 populateBuilderRegions(const Operator &op, 822 llvm::SmallVectorImpl<std::string> &builderArgs, 823 llvm::SmallVectorImpl<std::string> &builderLines) { 824 if (op.hasNoVariadicRegions()) 825 return; 826 827 // This is currently enforced when Operator is constructed. 828 assert(op.getNumVariadicRegions() == 1 && 829 op.getRegion(op.getNumRegions() - 1).isVariadic() && 830 "expected the last region to be varidic"); 831 832 const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); 833 std::string name = 834 ("num_" + region.name.take_front().lower() + region.name.drop_front()) 835 .str(); 836 builderArgs.push_back(name); 837 builderLines.push_back( 838 llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); 839 } 840 841 /// Emits a default builder constructing an operation from the list of its 842 /// result types, followed by a list of its operands. 843 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 844 // If we are asked to skip default builders, comply. 
845 if (op.skipDefaultBuilders()) 846 return; 847 848 llvm::SmallVector<std::string> builderArgs; 849 llvm::SmallVector<std::string> builderLines; 850 llvm::SmallVector<std::string> operandArgNames; 851 llvm::SmallVector<std::string> successorArgNames; 852 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 853 op.getNumNativeAttributes() + op.getNumSuccessors()); 854 populateBuilderArgsResults(op, builderArgs); 855 size_t numResultArgs = builderArgs.size(); 856 populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); 857 858 populateBuilderLinesOperand(op, operandArgNames, builderLines); 859 populateBuilderLinesAttr( 860 op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), 861 builderLines); 862 populateBuilderLinesResult( 863 op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), 864 builderLines); 865 populateBuilderLinesSuccessors(op, successorArgNames, builderLines); 866 populateBuilderRegions(op, builderArgs, builderLines); 867 868 builderArgs.push_back("*"); 869 builderArgs.push_back("loc=None"); 870 builderArgs.push_back("ip=None"); 871 os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), 872 llvm::join(builderLines, "\n ")); 873 } 874 875 static void constructAttributeMapping(const llvm::RecordKeeper &records, 876 AttributeClasses &attributeClasses) { 877 for (const llvm::Record *rec : 878 records.getAllDerivedDefinitions("PythonAttr")) { 879 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 880 rec->getValueAsString("pythonType").trim()); 881 } 882 } 883 884 static void emitSegmentSpec( 885 const Operator &op, const char *kind, 886 llvm::function_ref<int(const Operator &)> getNumElements, 887 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 888 getElement, 889 raw_ostream &os) { 890 std::string segmentSpec("["); 891 for (int i = 0, e = getNumElements(op); i < e; ++i) { 892 const NamedTypeConstraint &element = getElement(op, i); 893 if 
(element.isOptional()) { 894 segmentSpec.append("0,"); 895 } else if (element.isVariadic()) { 896 segmentSpec.append("-1,"); 897 } else { 898 segmentSpec.append("1,"); 899 } 900 } 901 segmentSpec.append("]"); 902 903 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); 904 } 905 906 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 907 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 908 // Note that the base OpView class defines this as (0, True). 909 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 910 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, 911 op.hasNoVariadicRegions() ? "True" : "False"); 912 } 913 914 /// Emits named accessors to regions. 915 static void emitRegionAccessors(const Operator &op, raw_ostream &os) { 916 for (auto en : llvm::enumerate(op.getRegions())) { 917 const NamedRegion ®ion = en.value(); 918 if (region.name.empty()) 919 continue; 920 921 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && 922 "expected only the last region to be variadic"); 923 os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), 924 std::to_string(en.index()) + 925 (region.isVariadic() ? ":" : "")); 926 } 927 } 928 929 /// Emits bindings for a specific Op to the given output stream. 930 static void emitOpBindings(const Operator &op, 931 const AttributeClasses &attributeClasses, 932 raw_ostream &os) { 933 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 934 op.getOperationName()); 935 936 // Sized segments. 
937 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 938 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 939 } 940 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 941 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 942 } 943 944 emitRegionAttributes(op, os); 945 emitDefaultOpBuilder(op, os); 946 emitOperandAccessors(op, os); 947 emitAttributeAccessors(op, attributeClasses, os); 948 emitResultAccessors(op, os); 949 emitRegionAccessors(op, os); 950 } 951 952 /// Emits bindings for the dialect specified in the command line, including file 953 /// headers and utilities. Returns `false` on success to comply with Tablegen 954 /// registration requirements. 955 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 956 if (clDialectName.empty()) 957 llvm::PrintFatalError("dialect name not provided"); 958 959 AttributeClasses attributeClasses; 960 constructAttributeMapping(records, attributeClasses); 961 962 os << llvm::formatv(fileHeader, clDialectName.getValue()); 963 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 964 965 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 966 Operator op(rec); 967 if (op.getDialectName() == clDialectName.getValue()) 968 emitOpBindings(op, attributeClasses, os); 969 } 970 return false; 971 } 972 973 static GenRegistration 974 genPythonBindings("gen-python-op-bindings", 975 "Generate Python bindings for MLIR Ops", &emitAllOps); 976