1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python 10 // binding classes wrapping a generic operation API. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/GenInfo.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "llvm/ADT/StringSet.h" 17 #include "llvm/Support/CommandLine.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 using namespace mlir; 23 using namespace mlir::tblgen; 24 25 /// File header and includes. 26 /// {0} is the dialect namespace. 27 constexpr const char *fileHeader = R"Py( 28 # Autogenerated by mlir-tblgen; don't manually edit. 29 30 from ._ods_common import _cext as _ods_cext 31 from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values 32 _ods_ir = _ods_cext.ir 33 34 try: 35 from . import _{0}_ops_ext as _ods_ext_module 36 except ImportError: 37 _ods_ext_module = None 38 39 import builtins 40 41 )Py"; 42 43 /// Template for dialect class: 44 /// {0} is the dialect namespace. 
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Header of a generated op class. The class is registered against the
/// `_Dialect` class emitted above and passed through the extension-module
/// decorator before any accessors are appended to its body:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Class-level attribute describing how operands or results are grouped into
/// segments:
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// A segment spec is either None (the default) or a list of integers, one per
/// group, where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///  -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Class-level `_ODS_REGIONS` tuple:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Property returning one element at a fixed position (used when no
/// variable-length group precedes it):
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The group length is whatever remains after every other group has
/// contributed exactly one element.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
// NOTE: the accessor reads `self.operation`; the previous version emitted a
// bare `operation`, which is an undefined name inside the generated property
// and raised NameError on every access.
// The template intentionally ends without a newline: one of the two "second
// part" templates below is appended directly to finish the property body.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element (`start` comes from the prefix template):
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group of `pg` elements starting at `start` (both from the prefix template):
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case (an empty segment maps to None):
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Getter for an optional attribute: yields None when the operation does not
/// carry the attribute, otherwise wraps it in the given Python class:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Getter for a unit attribute. Unit attributes carry meaning only by their
/// presence, so the property is simply a membership test (True/False):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Setter for a mandatory attribute; assigning None raises ValueError:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
227 constexpr const char *optionalAttributeSetterTemplate = R"Py( 228 @{0}.setter 229 def {0}(self, value): 230 if value is not None: 231 self.operation.attributes["{1}"] = value 232 elif "{1}" in self.operation.attributes: 233 del self.operation.attributes["{1}"] 234 )Py"; 235 236 /// Template for a setter of a unit operation attribute, setting to None or 237 /// False removes the attribute: 238 /// {0} is the name of the attribute sanitized for Python; 239 /// {1} is the original name of the attribute. 240 constexpr const char *unitAttributeSetterTemplate = R"Py( 241 @{0}.setter 242 def {0}(self, value): 243 if bool(value): 244 self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get() 245 elif "{1}" in self.operation.attributes: 246 del self.operation.attributes["{1}"] 247 )Py"; 248 249 /// Template for a deleter of an optional or a unit operation attribute, removes 250 /// the attribute from the operation: 251 /// {0} is the name of the attribute sanitized for Python; 252 /// {1} is the original name of the attribute. 253 constexpr const char *attributeDeleterTemplate = R"Py( 254 @{0}.deleter 255 def {0}(self): 256 del self.operation.attributes["{1}"] 257 )Py"; 258 259 constexpr const char *regionAccessorTemplate = R"PY( 260 @builtins.property 261 def {0}(self): 262 return self.regions[{1}] 263 )PY"; 264 265 static llvm::cl::OptionCategory 266 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 267 268 static llvm::cl::opt<std::string> 269 clDialectName("bind-dialect", 270 llvm::cl::desc("The dialect to run the generator for"), 271 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 272 273 using AttributeClasses = DenseMap<StringRef, StringRef>; 274 275 /// Checks whether `str` is a Python keyword. 
276 static bool isPythonKeyword(StringRef str) { 277 static llvm::StringSet<> keywords( 278 {"and", "as", "assert", "break", "class", "continue", 279 "def", "del", "elif", "else", "except", "finally", 280 "for", "from", "global", "if", "import", "in", 281 "is", "lambda", "nonlocal", "not", "or", "pass", 282 "raise", "return", "try", "while", "with", "yield"}); 283 return keywords.contains(str); 284 } 285 286 /// Checks whether `str` would shadow a generated variable or attribute 287 /// part of the OpView API. 288 static bool isODSReserved(StringRef str) { 289 static llvm::StringSet<> reserved( 290 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 291 "loc", "verify", "regions", "results", "self", "operation", 292 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 293 return str.startswith("_ods_") || str.endswith("_ods") || 294 reserved.contains(str); 295 } 296 297 /// Modifies the `name` in a way that it becomes suitable for Python bindings 298 /// (does not change the `name` if it already is suitable) and returns the 299 /// modified version. 300 static std::string sanitizeName(StringRef name) { 301 if (isPythonKeyword(name) || isODSReserved(name)) 302 return (name + "_").str(); 303 return name.str(); 304 } 305 306 static std::string attrSizedTraitForKind(const char *kind) { 307 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 308 llvm::StringRef(kind).take_front().upper(), 309 llvm::StringRef(kind).drop_front()); 310 } 311 312 /// Emits accessors to "elements" of an Op definition. Currently, the supported 313 /// elements are operands and results, indicated by `kind`, which must be either 314 /// `operand` or `result` and is used verbatim in the emitted code. 
315 static void emitElementAccessors( 316 const Operator &op, raw_ostream &os, const char *kind, 317 llvm::function_ref<unsigned(const Operator &)> getNumVariableLength, 318 llvm::function_ref<int(const Operator &)> getNumElements, 319 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 320 getElement) { 321 assert(llvm::is_contained( 322 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 323 "unsupported kind"); 324 325 // Traits indicating how to process variadic elements. 326 std::string sameSizeTrait = 327 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 328 llvm::StringRef(kind).take_front().upper(), 329 llvm::StringRef(kind).drop_front()); 330 std::string attrSizedTrait = attrSizedTraitForKind(kind); 331 332 unsigned numVariableLength = getNumVariableLength(op); 333 334 // If there is only one variable-length element group, its size can be 335 // inferred from the total number of elements. If there are none, the 336 // generation is straightforward. 337 if (numVariableLength <= 1) { 338 bool seenVariableLength = false; 339 for (int i = 0, e = getNumElements(op); i < e; ++i) { 340 const NamedTypeConstraint &element = getElement(op, i); 341 if (element.isVariableLength()) 342 seenVariableLength = true; 343 if (element.name.empty()) 344 continue; 345 if (element.isVariableLength()) { 346 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 347 : opOneVariadicTemplate, 348 sanitizeName(element.name), kind, 349 getNumElements(op), i); 350 } else if (seenVariableLength) { 351 os << llvm::formatv(opSingleAfterVariableTemplate, 352 sanitizeName(element.name), kind, 353 getNumElements(op), i); 354 } else { 355 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 356 i); 357 } 358 } 359 return; 360 } 361 362 // Handle the operations where variadic groups have the same size. 
363 if (op.getTrait(sameSizeTrait)) { 364 int numPrecedingSimple = 0; 365 int numPrecedingVariadic = 0; 366 for (int i = 0, e = getNumElements(op); i < e; ++i) { 367 const NamedTypeConstraint &element = getElement(op, i); 368 if (!element.name.empty()) { 369 os << llvm::formatv(opVariadicEqualPrefixTemplate, 370 sanitizeName(element.name), kind, numVariableLength, 371 numPrecedingSimple, numPrecedingVariadic); 372 os << llvm::formatv(element.isVariableLength() 373 ? opVariadicEqualVariadicTemplate 374 : opVariadicEqualSimpleTemplate, 375 kind); 376 } 377 if (element.isVariableLength()) 378 ++numPrecedingVariadic; 379 else 380 ++numPrecedingSimple; 381 } 382 return; 383 } 384 385 // Handle the operations where the size of groups (variadic or not) is 386 // provided as an attribute. For non-variadic elements, make sure to return 387 // an element rather than a singleton container. 388 if (op.getTrait(attrSizedTrait)) { 389 for (int i = 0, e = getNumElements(op); i < e; ++i) { 390 const NamedTypeConstraint &element = getElement(op, i); 391 if (element.name.empty()) 392 continue; 393 std::string trailing; 394 if (!element.isVariableLength()) 395 trailing = "[0]"; 396 else if (element.isOptional()) 397 trailing = std::string( 398 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 399 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 400 kind, i, trailing); 401 } 402 return; 403 } 404 405 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 406 } 407 408 /// Free function helpers accessing Operator components. 
409 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 410 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 411 return op.getOperand(i); 412 } 413 static int getNumResults(const Operator &op) { return op.getNumResults(); } 414 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 415 return op.getResult(i); 416 } 417 418 /// Emits accessors to Op operands. 419 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 420 auto getNumVariableLengthOperands = [](const Operator &oper) { 421 return oper.getNumVariableLengthOperands(); 422 }; 423 emitElementAccessors(op, os, "operand", getNumVariableLengthOperands, 424 getNumOperands, getOperand); 425 } 426 427 /// Emits accessors Op results. 428 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 429 auto getNumVariableLengthResults = [](const Operator &oper) { 430 return oper.getNumVariableLengthResults(); 431 }; 432 emitElementAccessors(op, os, "result", getNumVariableLengthResults, 433 getNumResults, getResult); 434 } 435 436 /// Emits accessors to Op attributes. 437 static void emitAttributeAccessors(const Operator &op, 438 const AttributeClasses &attributeClasses, 439 raw_ostream &os) { 440 for (const auto &namedAttr : op.getAttributes()) { 441 // Skip "derived" attributes because they are just C++ functions that we 442 // don't currently expose. 443 if (namedAttr.attr.isDerivedAttr()) 444 continue; 445 446 if (namedAttr.name.empty()) 447 continue; 448 449 std::string sanitizedName = sanitizeName(namedAttr.name); 450 451 // Unit attributes are handled specially. 
452 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 453 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 454 namedAttr.name); 455 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 456 namedAttr.name); 457 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 458 namedAttr.name); 459 continue; 460 } 461 462 // Other kinds of attributes need a mapping to a Python type. 463 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 464 continue; 465 466 StringRef pythonType = 467 attributeClasses.lookup(namedAttr.attr.getStorageType()); 468 if (namedAttr.attr.isOptional()) { 469 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 470 pythonType, namedAttr.name); 471 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 472 namedAttr.name); 473 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 474 namedAttr.name); 475 } else { 476 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 477 namedAttr.name); 478 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 479 namedAttr.name); 480 // Non-optional attributes cannot be deleted. 481 } 482 } 483 } 484 485 /// Template for the default auto-generated builder. 486 /// {0} is a comma-separated list of builder arguments, including the trailing 487 /// `loc` and `ip`; 488 /// {1} is the code populating `operands`, `results` and `attributes`, 489 /// `successors` fields. 490 constexpr const char *initTemplate = R"Py( 491 def __init__(self, {0}): 492 operands = [] 493 results = [] 494 attributes = {{} 495 regions = None 496 {1} 497 super().__init__(self.build_generic( 498 attributes=attributes, results=results, operands=operands, 499 successors=_ods_successors, regions=regions, loc=loc, ip=ip)) 500 )Py"; 501 502 /// Template for appending a single element to the operand/result list. 503 /// {0} is the field name. 
504 constexpr const char *singleOperandAppendTemplate = 505 "operands.append(_get_op_result_or_value({0}))"; 506 constexpr const char *singleResultAppendTemplate = "results.append({0})"; 507 508 /// Template for appending an optional element to the operand/result list. 509 /// {0} is the field name. 510 constexpr const char *optionalAppendOperandTemplate = 511 "if {0} is not None: operands.append(_get_op_result_or_value({0}))"; 512 constexpr const char *optionalAppendAttrSizedOperandsTemplate = 513 "operands.append(_get_op_result_or_value({0}) if {0} is not None else " 514 "None)"; 515 constexpr const char *optionalAppendResultTemplate = 516 "if {0} is not None: results.append({0})"; 517 518 /// Template for appending a list of elements to the operand/result list. 519 /// {0} is the field name. 520 constexpr const char *multiOperandAppendTemplate = 521 "operands.extend(_get_op_results_or_values({0}))"; 522 constexpr const char *multiOperandAppendPackTemplate = 523 "operands.append(_get_op_results_or_values({0}))"; 524 constexpr const char *multiResultAppendTemplate = "results.extend({0})"; 525 526 /// Template for setting an attribute in the operation builder. 527 /// {0} is the attribute name; 528 /// {1} is the builder argument name. 529 constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py"; 530 531 /// Template for setting an optional attribute in the operation builder. 532 /// {0} is the attribute name; 533 /// {1} is the builder argument name. 534 constexpr const char *initOptionalAttributeTemplate = 535 R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; 536 537 constexpr const char *initUnitAttributeTemplate = 538 R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( 539 _ods_get_default_loc_context(loc)))Py"; 540 541 /// Template to initialize the successors list in the builder if there are any 542 /// successors. 543 /// {0} is the value to initialize the successors list to. 
544 constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; 545 546 /// Template to append or extend the list of successors in the builder. 547 /// {0} is the list method ('append' or 'extend'); 548 /// {1} is the value to add. 549 constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; 550 551 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer 552 /// result types of the given operation. 553 static bool hasSameArgumentAndResultTypes(const Operator &op) { 554 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && 555 op.getNumVariableLengthResults() == 0; 556 } 557 558 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer 559 /// result types of the given operation. 560 static bool hasFirstAttrDerivedResultTypes(const Operator &op) { 561 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && 562 op.getNumVariableLengthResults() == 0; 563 } 564 565 /// Returns true if the InferTypeOpInterface can be used to infer result types 566 /// of the given operation. 567 static bool hasInferTypeInterface(const Operator &op) { 568 return op.getTrait("::mlir::InferTypeOpInterface::Trait") && 569 op.getNumRegions() == 0; 570 } 571 572 /// Returns true if there is a trait or interface that can be used to infer 573 /// result types of the given operation. 574 static bool canInferType(const Operator &op) { 575 return hasSameArgumentAndResultTypes(op) || 576 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); 577 } 578 579 /// Populates `builderArgs` with result names if the builder is expected to 580 /// accept them as arguments. 
581 static void 582 populateBuilderArgsResults(const Operator &op, 583 llvm::SmallVectorImpl<std::string> &builderArgs) { 584 if (canInferType(op)) 585 return; 586 587 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 588 std::string name = op.getResultName(i).str(); 589 if (name.empty()) { 590 if (op.getNumResults() == 1) { 591 // Special case for one result, make the default name be 'result' 592 // to properly match the built-in result accessor. 593 name = "result"; 594 } else { 595 name = llvm::formatv("_gen_res_{0}", i); 596 } 597 } 598 name = sanitizeName(name); 599 builderArgs.push_back(name); 600 } 601 } 602 603 /// Populates `builderArgs` with the Python-compatible names of builder function 604 /// arguments using intermixed attributes and operands in the same order as they 605 /// appear in the `arguments` field of the op definition. Additionally, 606 /// `operandNames` is populated with names of operands in their order of 607 /// appearance. 608 static void 609 populateBuilderArgs(const Operator &op, 610 llvm::SmallVectorImpl<std::string> &builderArgs, 611 llvm::SmallVectorImpl<std::string> &operandNames, 612 llvm::SmallVectorImpl<std::string> &successorArgNames) { 613 614 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 615 std::string name = op.getArgName(i).str(); 616 if (name.empty()) 617 name = llvm::formatv("_gen_arg_{0}", i); 618 name = sanitizeName(name); 619 builderArgs.push_back(name); 620 if (!op.getArg(i).is<NamedAttribute *>()) 621 operandNames.push_back(name); 622 } 623 } 624 625 /// Populates `builderArgs` with the Python-compatible names of builder function 626 /// successor arguments. Additionally, `successorArgNames` is also populated. 
627 static void populateBuilderArgsSuccessors( 628 const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs, 629 llvm::SmallVectorImpl<std::string> &successorArgNames) { 630 631 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { 632 NamedSuccessor successor = op.getSuccessor(i); 633 std::string name = std::string(successor.name); 634 if (name.empty()) 635 name = llvm::formatv("_gen_successor_{0}", i); 636 name = sanitizeName(name); 637 builderArgs.push_back(name); 638 successorArgNames.push_back(name); 639 } 640 } 641 642 /// Populates `builderLines` with additional lines that are required in the 643 /// builder to set up operation attributes. `argNames` is expected to contain 644 /// the names of builder arguments that correspond to op arguments, i.e. to the 645 /// operands and attributes in the same order as they appear in the `arguments` 646 /// field. 647 static void 648 populateBuilderLinesAttr(const Operator &op, 649 llvm::ArrayRef<std::string> argNames, 650 llvm::SmallVectorImpl<std::string> &builderLines) { 651 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 652 Argument arg = op.getArg(i); 653 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 654 if (!attribute) 655 continue; 656 657 // Unit attributes are handled specially. 658 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 659 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 660 attribute->name, argNames[i])); 661 continue; 662 } 663 664 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 665 ? initOptionalAttributeTemplate 666 : initAttributeTemplate, 667 attribute->name, argNames[i])); 668 } 669 } 670 671 /// Populates `builderLines` with additional lines that are required in the 672 /// builder to set up successors. successorArgNames is expected to correspond 673 /// to the Python argument name for each successor on the op. 
674 static void populateBuilderLinesSuccessors( 675 const Operator &op, llvm::ArrayRef<std::string> successorArgNames, 676 llvm::SmallVectorImpl<std::string> &builderLines) { 677 if (successorArgNames.empty()) { 678 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); 679 return; 680 } 681 682 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); 683 for (int i = 0, e = successorArgNames.size(); i < e; ++i) { 684 auto &argName = successorArgNames[i]; 685 const NamedSuccessor &successor = op.getSuccessor(i); 686 builderLines.push_back( 687 llvm::formatv(addSuccessorTemplate, 688 successor.isVariadic() ? "extend" : "append", argName)); 689 } 690 } 691 692 /// Populates `builderLines` with additional lines that are required in the 693 /// builder to set up op operands. 694 static void 695 populateBuilderLinesOperand(const Operator &op, 696 llvm::ArrayRef<std::string> names, 697 llvm::SmallVectorImpl<std::string> &builderLines) { 698 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; 699 700 // For each element, find or generate a name. 701 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 702 const NamedTypeConstraint &element = op.getOperand(i); 703 std::string name = names[i]; 704 705 // Choose the formatting string based on the element kind. 706 llvm::StringRef formatString; 707 if (!element.isVariableLength()) { 708 formatString = singleOperandAppendTemplate; 709 } else if (element.isOptional()) { 710 if (sizedSegments) { 711 formatString = optionalAppendAttrSizedOperandsTemplate; 712 } else { 713 formatString = optionalAppendOperandTemplate; 714 } 715 } else { 716 assert(element.isVariadic() && "unhandled element group type"); 717 // If emitting with sizedSegments, then we add the actual list-typed 718 // element. Otherwise, we extend the actual operands. 
719 if (sizedSegments) { 720 formatString = multiOperandAppendPackTemplate; 721 } else { 722 formatString = multiOperandAppendTemplate; 723 } 724 } 725 726 builderLines.push_back(llvm::formatv(formatString.data(), name)); 727 } 728 } 729 730 /// Python code template for deriving the operation result types from its 731 /// attribute: 732 /// - {0} is the name of the attribute from which to derive the types. 733 constexpr const char *deriveTypeFromAttrTemplate = 734 R"PY(_ods_result_type_source_attr = attributes["{0}"] 735 _ods_derived_result_type = ( 736 _ods_ir.TypeAttr(_ods_result_type_source_attr).value 737 if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else 738 _ods_result_type_source_attr.type))PY"; 739 740 /// Python code template appending {0} type {1} times to the results list. 741 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; 742 743 /// Python code template for inferring the operation results using the 744 /// corresponding interface: 745 /// - {0} is the name of the class for which the types are inferred. 746 constexpr const char *inferTypeInterfaceTemplate = 747 R"PY(_ods_context = _ods_get_default_loc_context(loc) 748 results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( 749 operands=operands, 750 attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), 751 context=_ods_context, 752 loc=loc) 753 )PY"; 754 755 /// Appends the given multiline string as individual strings into 756 /// `builderLines`. 757 static void appendLineByLine(StringRef string, 758 llvm::SmallVectorImpl<std::string> &builderLines) { 759 760 std::pair<StringRef, StringRef> split = std::make_pair(string, string); 761 do { 762 split = split.second.split('\n'); 763 builderLines.push_back(split.first.str()); 764 } while (!split.second.empty()); 765 } 766 767 /// Populates `builderLines` with additional lines that are required in the 768 /// builder to set up op results. 
769 static void 770 populateBuilderLinesResult(const Operator &op, 771 llvm::ArrayRef<std::string> names, 772 llvm::SmallVectorImpl<std::string> &builderLines) { 773 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; 774 775 if (hasSameArgumentAndResultTypes(op)) { 776 builderLines.push_back(llvm::formatv( 777 appendSameResultsTemplate, "operands[0].type", op.getNumResults())); 778 return; 779 } 780 781 if (hasFirstAttrDerivedResultTypes(op)) { 782 const NamedAttribute &firstAttr = op.getAttribute(0); 783 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " 784 "from which the type is derived"); 785 appendLineByLine( 786 llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), 787 builderLines); 788 builderLines.push_back(llvm::formatv(appendSameResultsTemplate, 789 "_ods_derived_result_type", 790 op.getNumResults())); 791 return; 792 } 793 794 if (hasInferTypeInterface(op)) { 795 appendLineByLine( 796 llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), 797 builderLines); 798 return; 799 } 800 801 // For each element, find or generate a name. 802 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 803 const NamedTypeConstraint &element = op.getResult(i); 804 std::string name = names[i]; 805 806 // Choose the formatting string based on the element kind. 807 llvm::StringRef formatString; 808 if (!element.isVariableLength()) { 809 formatString = singleResultAppendTemplate; 810 } else if (element.isOptional()) { 811 formatString = optionalAppendResultTemplate; 812 } else { 813 assert(element.isVariadic() && "unhandled element group type"); 814 // If emitting with sizedSegments, then we add the actual list-typed 815 // element. Otherwise, we extend the actual operands. 
816 if (sizedSegments) { 817 formatString = singleResultAppendTemplate; 818 } else { 819 formatString = multiResultAppendTemplate; 820 } 821 } 822 823 builderLines.push_back(llvm::formatv(formatString.data(), name)); 824 } 825 } 826 827 /// If the operation has variadic regions, adds a builder argument to specify 828 /// the number of those regions and builder lines to forward it to the generic 829 /// constructor. 830 static void 831 populateBuilderRegions(const Operator &op, 832 llvm::SmallVectorImpl<std::string> &builderArgs, 833 llvm::SmallVectorImpl<std::string> &builderLines) { 834 if (op.hasNoVariadicRegions()) 835 return; 836 837 // This is currently enforced when Operator is constructed. 838 assert(op.getNumVariadicRegions() == 1 && 839 op.getRegion(op.getNumRegions() - 1).isVariadic() && 840 "expected the last region to be varidic"); 841 842 const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); 843 std::string name = 844 ("num_" + region.name.take_front().lower() + region.name.drop_front()) 845 .str(); 846 builderArgs.push_back(name); 847 builderLines.push_back( 848 llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); 849 } 850 851 /// Emits a default builder constructing an operation from the list of its 852 /// result types, followed by a list of its operands. 853 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 854 // If we are asked to skip default builders, comply. 
855 if (op.skipDefaultBuilders()) 856 return; 857 858 llvm::SmallVector<std::string> builderArgs; 859 llvm::SmallVector<std::string> builderLines; 860 llvm::SmallVector<std::string> operandArgNames; 861 llvm::SmallVector<std::string> successorArgNames; 862 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 863 op.getNumNativeAttributes() + op.getNumSuccessors()); 864 populateBuilderArgsResults(op, builderArgs); 865 size_t numResultArgs = builderArgs.size(); 866 populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); 867 size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; 868 populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); 869 870 populateBuilderLinesOperand(op, operandArgNames, builderLines); 871 populateBuilderLinesAttr( 872 op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), 873 builderLines); 874 populateBuilderLinesResult( 875 op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), 876 builderLines); 877 populateBuilderLinesSuccessors(op, successorArgNames, builderLines); 878 populateBuilderRegions(op, builderArgs, builderLines); 879 880 // Layout of builderArgs vector elements: 881 // [ result_args operand_attr_args successor_args regions ] 882 883 // Determine whether the argument corresponding to a given index into the 884 // builderArgs vector is a python keyword argument or not. 885 auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { 886 // All result, successor, and region arguments are positional arguments. 
887 if ((builderArgIndex < numResultArgs) || 888 (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) 889 return false; 890 // Keyword arguments: 891 // - optional named attributes (including unit attributes) 892 // - default-valued named attributes 893 // - optional operands 894 Argument a = op.getArg(builderArgIndex - numResultArgs); 895 if (auto *nattr = a.dyn_cast<NamedAttribute *>()) 896 return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); 897 if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>()) 898 return ntype->isOptional(); 899 else 900 return false; 901 }; 902 903 // StringRefs in functionArgs refer to strings allocated by builderArgs. 904 llvm::SmallVector<llvm::StringRef> functionArgs; 905 906 // Add positional arguments. 907 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { 908 if (!isKeywordArgFn(i)) 909 functionArgs.push_back(builderArgs[i]); 910 } 911 912 // Add a bare '*' to indicate that all following arguments must be keyword 913 // arguments. 914 functionArgs.push_back("*"); 915 916 // Add a default 'None' value to each keyword arg string, and then add to the 917 // function args list. 
918 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { 919 if (isKeywordArgFn(i)) { 920 builderArgs[i].append("=None"); 921 functionArgs.push_back(builderArgs[i]); 922 } 923 } 924 functionArgs.push_back("loc=None"); 925 functionArgs.push_back("ip=None"); 926 os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), 927 llvm::join(builderLines, "\n ")); 928 } 929 930 static void constructAttributeMapping(const llvm::RecordKeeper &records, 931 AttributeClasses &attributeClasses) { 932 for (const llvm::Record *rec : 933 records.getAllDerivedDefinitions("PythonAttr")) { 934 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 935 rec->getValueAsString("pythonType").trim()); 936 } 937 } 938 939 static void emitSegmentSpec( 940 const Operator &op, const char *kind, 941 llvm::function_ref<int(const Operator &)> getNumElements, 942 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 943 getElement, 944 raw_ostream &os) { 945 std::string segmentSpec("["); 946 for (int i = 0, e = getNumElements(op); i < e; ++i) { 947 const NamedTypeConstraint &element = getElement(op, i); 948 if (element.isOptional()) { 949 segmentSpec.append("0,"); 950 } else if (element.isVariadic()) { 951 segmentSpec.append("-1,"); 952 } else { 953 segmentSpec.append("1,"); 954 } 955 } 956 segmentSpec.append("]"); 957 958 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); 959 } 960 961 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 962 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 963 // Note that the base OpView class defines this as (0, True). 964 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 965 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, 966 op.hasNoVariadicRegions() ? "True" : "False"); 967 } 968 969 /// Emits named accessors to regions. 
970 static void emitRegionAccessors(const Operator &op, raw_ostream &os) { 971 for (const auto &en : llvm::enumerate(op.getRegions())) { 972 const NamedRegion ®ion = en.value(); 973 if (region.name.empty()) 974 continue; 975 976 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && 977 "expected only the last region to be variadic"); 978 os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), 979 std::to_string(en.index()) + 980 (region.isVariadic() ? ":" : "")); 981 } 982 } 983 984 /// Emits bindings for a specific Op to the given output stream. 985 static void emitOpBindings(const Operator &op, 986 const AttributeClasses &attributeClasses, 987 raw_ostream &os) { 988 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 989 op.getOperationName()); 990 991 // Sized segments. 992 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 993 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 994 } 995 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 996 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 997 } 998 999 emitRegionAttributes(op, os); 1000 emitDefaultOpBuilder(op, os); 1001 emitOperandAccessors(op, os); 1002 emitAttributeAccessors(op, attributeClasses, os); 1003 emitResultAccessors(op, os); 1004 emitRegionAccessors(op, os); 1005 } 1006 1007 /// Emits bindings for the dialect specified in the command line, including file 1008 /// headers and utilities. Returns `false` on success to comply with Tablegen 1009 /// registration requirements. 
1010 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 1011 if (clDialectName.empty()) 1012 llvm::PrintFatalError("dialect name not provided"); 1013 1014 AttributeClasses attributeClasses; 1015 constructAttributeMapping(records, attributeClasses); 1016 1017 os << llvm::formatv(fileHeader, clDialectName.getValue()); 1018 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 1019 1020 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 1021 Operator op(rec); 1022 if (op.getDialectName() == clDialectName.getValue()) 1023 emitOpBindings(op, attributeClasses, os); 1024 } 1025 return false; 1026 } 1027 1028 static GenRegistration 1029 genPythonBindings("gen-python-op-bindings", 1030 "Generate Python bindings for MLIR Ops", &emitAllOps); 1031