//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// File header and includes.
///   {0} is the dialect namespace.
/// The header imports the shared helpers from `_ods_common` under private
/// aliases, and tries to import the hand-written `_{0}_ops_ext` module,
/// binding `_ods_ext_module` to None when that import fails. `builtins` is
/// imported so that generated accessors can be decorated with
/// `@builtins.property` (presumably so that a generated member named
/// `property` cannot shadow the decorator -- TODO confirm).
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
/// The generated `_Dialect` class is registered with the native extension via
/// `@_ods_cext.register_dialect`.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for a dialect extension module, which re-imports the `_Dialect`
/// class from the main generated `_{0}_ops_gen` module instead of defining
/// its own:
///   {0} is presumably the dialect namespace (the call site is not visible
///       here) -- TODO confirm.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
/// The class is registered against `_Dialect`, and `_ods_ext_module` (None when
/// no extension module was found) is passed to `_ods_extend_opview_class`.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs (indented to sit inside the generated class body).
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///  -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec
/// (indented to sit inside the generated class body):
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
/// Every other group holds exactly one element, so the variable-length group
/// has `len(elements) - {2} + 1` elements; the element index is the group
/// position shifted by that length minus the one slot the variable group
/// already occupies in the group count.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The group length is deduced the same way as for a single element following
/// a variable-length group: every other group holds exactly one element.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// It binds `start` and `pg`, which the second part of the template consumes
/// as `[start]` or `[start:start + pg]`. The element list must be reached
/// through `self.operation`: a bare `operation` is not bound inside the
/// generated property and would raise NameError on access.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// The group range is obtained from `_ods_segmented_accessor` using the
/// "{1}_segment_sizes" attribute; {3} then post-processes that range.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; it yields the first element of the group, or None if
/// the group is empty:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter; the raw attribute is wrapped by
/// calling the Python type {1} on it:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter, returning None when
/// the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; assigning None raises
/// ValueError because the attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (if present):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute. Any truthy value sets
/// a fresh `UnitAttr`; any falsy value (e.g. None or False) removes the
/// attribute if present:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
257 constexpr const char *attributeDeleterTemplate = R"Py( 258 @{0}.deleter 259 def {0}(self): 260 del self.operation.attributes["{1}"] 261 )Py"; 262 263 constexpr const char *regionAccessorTemplate = R"PY( 264 @builtins.property 265 def {0}(self): 266 return self.regions[{1}] 267 )PY"; 268 269 static llvm::cl::OptionCategory 270 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 271 272 static llvm::cl::opt<std::string> 273 clDialectName("bind-dialect", 274 llvm::cl::desc("The dialect to run the generator for"), 275 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 276 277 static llvm::cl::opt<std::string> clDialectExtensionName( 278 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"), 279 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 280 281 using AttributeClasses = DenseMap<StringRef, StringRef>; 282 283 /// Checks whether `str` is a Python keyword. 284 static bool isPythonKeyword(StringRef str) { 285 static llvm::StringSet<> keywords( 286 {"and", "as", "assert", "break", "class", "continue", 287 "def", "del", "elif", "else", "except", "finally", 288 "for", "from", "global", "if", "import", "in", 289 "is", "lambda", "nonlocal", "not", "or", "pass", 290 "raise", "return", "try", "while", "with", "yield"}); 291 return keywords.contains(str); 292 } 293 294 /// Checks whether `str` would shadow a generated variable or attribute 295 /// part of the OpView API. 296 static bool isODSReserved(StringRef str) { 297 static llvm::StringSet<> reserved( 298 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 299 "loc", "verify", "regions", "results", "self", "operation", 300 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 301 return str.startswith("_ods_") || str.endswith("_ods") || 302 reserved.contains(str); 303 } 304 305 /// Modifies the `name` in a way that it becomes suitable for Python bindings 306 /// (does not change the `name` if it already is suitable) and returns the 307 /// modified version. 
308 static std::string sanitizeName(StringRef name) { 309 if (isPythonKeyword(name) || isODSReserved(name)) 310 return (name + "_").str(); 311 return name.str(); 312 } 313 314 static std::string attrSizedTraitForKind(const char *kind) { 315 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 316 llvm::StringRef(kind).take_front().upper(), 317 llvm::StringRef(kind).drop_front()); 318 } 319 320 /// Emits accessors to "elements" of an Op definition. Currently, the supported 321 /// elements are operands and results, indicated by `kind`, which must be either 322 /// `operand` or `result` and is used verbatim in the emitted code. 323 static void emitElementAccessors( 324 const Operator &op, raw_ostream &os, const char *kind, 325 llvm::function_ref<unsigned(const Operator &)> getNumVariableLength, 326 llvm::function_ref<int(const Operator &)> getNumElements, 327 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 328 getElement) { 329 assert(llvm::is_contained( 330 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 331 "unsupported kind"); 332 333 // Traits indicating how to process variadic elements. 334 std::string sameSizeTrait = 335 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 336 llvm::StringRef(kind).take_front().upper(), 337 llvm::StringRef(kind).drop_front()); 338 std::string attrSizedTrait = attrSizedTraitForKind(kind); 339 340 unsigned numVariableLength = getNumVariableLength(op); 341 342 // If there is only one variable-length element group, its size can be 343 // inferred from the total number of elements. If there are none, the 344 // generation is straightforward. 
345 if (numVariableLength <= 1) { 346 bool seenVariableLength = false; 347 for (int i = 0, e = getNumElements(op); i < e; ++i) { 348 const NamedTypeConstraint &element = getElement(op, i); 349 if (element.isVariableLength()) 350 seenVariableLength = true; 351 if (element.name.empty()) 352 continue; 353 if (element.isVariableLength()) { 354 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 355 : opOneVariadicTemplate, 356 sanitizeName(element.name), kind, 357 getNumElements(op), i); 358 } else if (seenVariableLength) { 359 os << llvm::formatv(opSingleAfterVariableTemplate, 360 sanitizeName(element.name), kind, 361 getNumElements(op), i); 362 } else { 363 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 364 i); 365 } 366 } 367 return; 368 } 369 370 // Handle the operations where variadic groups have the same size. 371 if (op.getTrait(sameSizeTrait)) { 372 int numPrecedingSimple = 0; 373 int numPrecedingVariadic = 0; 374 for (int i = 0, e = getNumElements(op); i < e; ++i) { 375 const NamedTypeConstraint &element = getElement(op, i); 376 if (!element.name.empty()) { 377 os << llvm::formatv(opVariadicEqualPrefixTemplate, 378 sanitizeName(element.name), kind, numVariableLength, 379 numPrecedingSimple, numPrecedingVariadic); 380 os << llvm::formatv(element.isVariableLength() 381 ? opVariadicEqualVariadicTemplate 382 : opVariadicEqualSimpleTemplate, 383 kind); 384 } 385 if (element.isVariableLength()) 386 ++numPrecedingVariadic; 387 else 388 ++numPrecedingSimple; 389 } 390 return; 391 } 392 393 // Handle the operations where the size of groups (variadic or not) is 394 // provided as an attribute. For non-variadic elements, make sure to return 395 // an element rather than a singleton container. 
396 if (op.getTrait(attrSizedTrait)) { 397 for (int i = 0, e = getNumElements(op); i < e; ++i) { 398 const NamedTypeConstraint &element = getElement(op, i); 399 if (element.name.empty()) 400 continue; 401 std::string trailing; 402 if (!element.isVariableLength()) 403 trailing = "[0]"; 404 else if (element.isOptional()) 405 trailing = std::string( 406 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 407 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 408 kind, i, trailing); 409 } 410 return; 411 } 412 413 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 414 } 415 416 /// Free function helpers accessing Operator components. 417 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 418 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 419 return op.getOperand(i); 420 } 421 static int getNumResults(const Operator &op) { return op.getNumResults(); } 422 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 423 return op.getResult(i); 424 } 425 426 /// Emits accessors to Op operands. 427 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 428 auto getNumVariableLengthOperands = [](const Operator &oper) { 429 return oper.getNumVariableLengthOperands(); 430 }; 431 emitElementAccessors(op, os, "operand", getNumVariableLengthOperands, 432 getNumOperands, getOperand); 433 } 434 435 /// Emits accessors Op results. 436 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 437 auto getNumVariableLengthResults = [](const Operator &oper) { 438 return oper.getNumVariableLengthResults(); 439 }; 440 emitElementAccessors(op, os, "result", getNumVariableLengthResults, 441 getNumResults, getResult); 442 } 443 444 /// Emits accessors to Op attributes. 
445 static void emitAttributeAccessors(const Operator &op, 446 const AttributeClasses &attributeClasses, 447 raw_ostream &os) { 448 for (const auto &namedAttr : op.getAttributes()) { 449 // Skip "derived" attributes because they are just C++ functions that we 450 // don't currently expose. 451 if (namedAttr.attr.isDerivedAttr()) 452 continue; 453 454 if (namedAttr.name.empty()) 455 continue; 456 457 std::string sanitizedName = sanitizeName(namedAttr.name); 458 459 // Unit attributes are handled specially. 460 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 461 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 462 namedAttr.name); 463 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 464 namedAttr.name); 465 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 466 namedAttr.name); 467 continue; 468 } 469 470 // Other kinds of attributes need a mapping to a Python type. 471 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 472 continue; 473 474 StringRef pythonType = 475 attributeClasses.lookup(namedAttr.attr.getStorageType()); 476 if (namedAttr.attr.isOptional()) { 477 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 478 pythonType, namedAttr.name); 479 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 480 namedAttr.name); 481 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 482 namedAttr.name); 483 } else { 484 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 485 namedAttr.name); 486 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 487 namedAttr.name); 488 // Non-optional attributes cannot be deleted. 489 } 490 } 491 } 492 493 /// Template for the default auto-generated builder. 494 /// {0} is a comma-separated list of builder arguments, including the trailing 495 /// `loc` and `ip`; 496 /// {1} is the code populating `operands`, `results` and `attributes`, 497 /// `successors` fields. 
/// Note: `{{` is llvm::formatv's escape for a literal `{`, so the generated
/// line reads `attributes = {}`. The template reads `_ods_successors`, which
/// must be bound by the builder lines substituted for {1}.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";

/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
/// Operands are passed through `_get_op_result_or_value` (imported in the
/// generated file header); results are appended as-is.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The plain variants skip the element when it is None. The attribute-sized
/// operand variant always appends an entry, using None for an absent operand
/// (presumably so that the operand list keeps one entry per segment -- TODO
/// confirm against `build_generic`).
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The `extend` variants splice the elements into the flat list; the `Pack`
/// variant appends the whole list as a single entry.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
542 constexpr const char *initOptionalAttributeTemplate = 543 R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; 544 545 constexpr const char *initUnitAttributeTemplate = 546 R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( 547 _ods_get_default_loc_context(loc)))Py"; 548 549 /// Template to initialize the successors list in the builder if there are any 550 /// successors. 551 /// {0} is the value to initialize the successors list to. 552 constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; 553 554 /// Template to append or extend the list of successors in the builder. 555 /// {0} is the list method ('append' or 'extend'); 556 /// {1} is the value to add. 557 constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; 558 559 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer 560 /// result types of the given operation. 561 static bool hasSameArgumentAndResultTypes(const Operator &op) { 562 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && 563 op.getNumVariableLengthResults() == 0; 564 } 565 566 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer 567 /// result types of the given operation. 568 static bool hasFirstAttrDerivedResultTypes(const Operator &op) { 569 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && 570 op.getNumVariableLengthResults() == 0; 571 } 572 573 /// Returns true if the InferTypeOpInterface can be used to infer result types 574 /// of the given operation. 575 static bool hasInferTypeInterface(const Operator &op) { 576 return op.getTrait("::mlir::InferTypeOpInterface::Trait") && 577 op.getNumRegions() == 0; 578 } 579 580 /// Returns true if there is a trait or interface that can be used to infer 581 /// result types of the given operation. 
582 static bool canInferType(const Operator &op) { 583 return hasSameArgumentAndResultTypes(op) || 584 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); 585 } 586 587 /// Populates `builderArgs` with result names if the builder is expected to 588 /// accept them as arguments. 589 static void 590 populateBuilderArgsResults(const Operator &op, 591 llvm::SmallVectorImpl<std::string> &builderArgs) { 592 if (canInferType(op)) 593 return; 594 595 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 596 std::string name = op.getResultName(i).str(); 597 if (name.empty()) { 598 if (op.getNumResults() == 1) { 599 // Special case for one result, make the default name be 'result' 600 // to properly match the built-in result accessor. 601 name = "result"; 602 } else { 603 name = llvm::formatv("_gen_res_{0}", i); 604 } 605 } 606 name = sanitizeName(name); 607 builderArgs.push_back(name); 608 } 609 } 610 611 /// Populates `builderArgs` with the Python-compatible names of builder function 612 /// arguments using intermixed attributes and operands in the same order as they 613 /// appear in the `arguments` field of the op definition. Additionally, 614 /// `operandNames` is populated with names of operands in their order of 615 /// appearance. 616 static void 617 populateBuilderArgs(const Operator &op, 618 llvm::SmallVectorImpl<std::string> &builderArgs, 619 llvm::SmallVectorImpl<std::string> &operandNames, 620 llvm::SmallVectorImpl<std::string> &successorArgNames) { 621 622 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 623 std::string name = op.getArgName(i).str(); 624 if (name.empty()) 625 name = llvm::formatv("_gen_arg_{0}", i); 626 name = sanitizeName(name); 627 builderArgs.push_back(name); 628 if (!op.getArg(i).is<NamedAttribute *>()) 629 operandNames.push_back(name); 630 } 631 } 632 633 /// Populates `builderArgs` with the Python-compatible names of builder function 634 /// successor arguments. Additionally, `successorArgNames` is also populated. 
635 static void populateBuilderArgsSuccessors( 636 const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs, 637 llvm::SmallVectorImpl<std::string> &successorArgNames) { 638 639 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { 640 NamedSuccessor successor = op.getSuccessor(i); 641 std::string name = std::string(successor.name); 642 if (name.empty()) 643 name = llvm::formatv("_gen_successor_{0}", i); 644 name = sanitizeName(name); 645 builderArgs.push_back(name); 646 successorArgNames.push_back(name); 647 } 648 } 649 650 /// Populates `builderLines` with additional lines that are required in the 651 /// builder to set up operation attributes. `argNames` is expected to contain 652 /// the names of builder arguments that correspond to op arguments, i.e. to the 653 /// operands and attributes in the same order as they appear in the `arguments` 654 /// field. 655 static void 656 populateBuilderLinesAttr(const Operator &op, 657 llvm::ArrayRef<std::string> argNames, 658 llvm::SmallVectorImpl<std::string> &builderLines) { 659 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 660 Argument arg = op.getArg(i); 661 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 662 if (!attribute) 663 continue; 664 665 // Unit attributes are handled specially. 666 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 667 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 668 attribute->name, argNames[i])); 669 continue; 670 } 671 672 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 673 ? initOptionalAttributeTemplate 674 : initAttributeTemplate, 675 attribute->name, argNames[i])); 676 } 677 } 678 679 /// Populates `builderLines` with additional lines that are required in the 680 /// builder to set up successors. successorArgNames is expected to correspond 681 /// to the Python argument name for each successor on the op. 
682 static void populateBuilderLinesSuccessors( 683 const Operator &op, llvm::ArrayRef<std::string> successorArgNames, 684 llvm::SmallVectorImpl<std::string> &builderLines) { 685 if (successorArgNames.empty()) { 686 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); 687 return; 688 } 689 690 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); 691 for (int i = 0, e = successorArgNames.size(); i < e; ++i) { 692 auto &argName = successorArgNames[i]; 693 const NamedSuccessor &successor = op.getSuccessor(i); 694 builderLines.push_back( 695 llvm::formatv(addSuccessorTemplate, 696 successor.isVariadic() ? "extend" : "append", argName)); 697 } 698 } 699 700 /// Populates `builderLines` with additional lines that are required in the 701 /// builder to set up op operands. 702 static void 703 populateBuilderLinesOperand(const Operator &op, 704 llvm::ArrayRef<std::string> names, 705 llvm::SmallVectorImpl<std::string> &builderLines) { 706 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; 707 708 // For each element, find or generate a name. 709 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 710 const NamedTypeConstraint &element = op.getOperand(i); 711 std::string name = names[i]; 712 713 // Choose the formatting string based on the element kind. 714 llvm::StringRef formatString; 715 if (!element.isVariableLength()) { 716 formatString = singleOperandAppendTemplate; 717 } else if (element.isOptional()) { 718 if (sizedSegments) { 719 formatString = optionalAppendAttrSizedOperandsTemplate; 720 } else { 721 formatString = optionalAppendOperandTemplate; 722 } 723 } else { 724 assert(element.isVariadic() && "unhandled element group type"); 725 // If emitting with sizedSegments, then we add the actual list-typed 726 // element. Otherwise, we extend the actual operands. 
727 if (sizedSegments) { 728 formatString = multiOperandAppendPackTemplate; 729 } else { 730 formatString = multiOperandAppendTemplate; 731 } 732 } 733 734 builderLines.push_back(llvm::formatv(formatString.data(), name)); 735 } 736 } 737 738 /// Python code template for deriving the operation result types from its 739 /// attribute: 740 /// - {0} is the name of the attribute from which to derive the types. 741 constexpr const char *deriveTypeFromAttrTemplate = 742 R"PY(_ods_result_type_source_attr = attributes["{0}"] 743 _ods_derived_result_type = ( 744 _ods_ir.TypeAttr(_ods_result_type_source_attr).value 745 if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else 746 _ods_result_type_source_attr.type))PY"; 747 748 /// Python code template appending {0} type {1} times to the results list. 749 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; 750 751 /// Python code template for inferring the operation results using the 752 /// corresponding interface: 753 /// - {0} is the name of the class for which the types are inferred. 754 constexpr const char *inferTypeInterfaceTemplate = 755 R"PY(_ods_context = _ods_get_default_loc_context(loc) 756 results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( 757 operands=operands, 758 attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), 759 context=_ods_context, 760 loc=loc) 761 )PY"; 762 763 /// Appends the given multiline string as individual strings into 764 /// `builderLines`. 765 static void appendLineByLine(StringRef string, 766 llvm::SmallVectorImpl<std::string> &builderLines) { 767 768 std::pair<StringRef, StringRef> split = std::make_pair(string, string); 769 do { 770 split = split.second.split('\n'); 771 builderLines.push_back(split.first.str()); 772 } while (!split.second.empty()); 773 } 774 775 /// Populates `builderLines` with additional lines that are required in the 776 /// builder to set up op results. 
777 static void 778 populateBuilderLinesResult(const Operator &op, 779 llvm::ArrayRef<std::string> names, 780 llvm::SmallVectorImpl<std::string> &builderLines) { 781 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; 782 783 if (hasSameArgumentAndResultTypes(op)) { 784 builderLines.push_back(llvm::formatv( 785 appendSameResultsTemplate, "operands[0].type", op.getNumResults())); 786 return; 787 } 788 789 if (hasFirstAttrDerivedResultTypes(op)) { 790 const NamedAttribute &firstAttr = op.getAttribute(0); 791 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " 792 "from which the type is derived"); 793 appendLineByLine( 794 llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), 795 builderLines); 796 builderLines.push_back(llvm::formatv(appendSameResultsTemplate, 797 "_ods_derived_result_type", 798 op.getNumResults())); 799 return; 800 } 801 802 if (hasInferTypeInterface(op)) { 803 appendLineByLine( 804 llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), 805 builderLines); 806 return; 807 } 808 809 // For each element, find or generate a name. 810 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 811 const NamedTypeConstraint &element = op.getResult(i); 812 std::string name = names[i]; 813 814 // Choose the formatting string based on the element kind. 815 llvm::StringRef formatString; 816 if (!element.isVariableLength()) { 817 formatString = singleResultAppendTemplate; 818 } else if (element.isOptional()) { 819 formatString = optionalAppendResultTemplate; 820 } else { 821 assert(element.isVariadic() && "unhandled element group type"); 822 // If emitting with sizedSegments, then we add the actual list-typed 823 // element. Otherwise, we extend the actual operands. 
824 if (sizedSegments) { 825 formatString = singleResultAppendTemplate; 826 } else { 827 formatString = multiResultAppendTemplate; 828 } 829 } 830 831 builderLines.push_back(llvm::formatv(formatString.data(), name)); 832 } 833 } 834 835 /// If the operation has variadic regions, adds a builder argument to specify 836 /// the number of those regions and builder lines to forward it to the generic 837 /// constructor. 838 static void 839 populateBuilderRegions(const Operator &op, 840 llvm::SmallVectorImpl<std::string> &builderArgs, 841 llvm::SmallVectorImpl<std::string> &builderLines) { 842 if (op.hasNoVariadicRegions()) 843 return; 844 845 // This is currently enforced when Operator is constructed. 846 assert(op.getNumVariadicRegions() == 1 && 847 op.getRegion(op.getNumRegions() - 1).isVariadic() && 848 "expected the last region to be varidic"); 849 850 const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); 851 std::string name = 852 ("num_" + region.name.take_front().lower() + region.name.drop_front()) 853 .str(); 854 builderArgs.push_back(name); 855 builderLines.push_back( 856 llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); 857 } 858 859 /// Emits a default builder constructing an operation from the list of its 860 /// result types, followed by a list of its operands. 861 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 862 // If we are asked to skip default builders, comply. 
863 if (op.skipDefaultBuilders()) 864 return; 865 866 llvm::SmallVector<std::string> builderArgs; 867 llvm::SmallVector<std::string> builderLines; 868 llvm::SmallVector<std::string> operandArgNames; 869 llvm::SmallVector<std::string> successorArgNames; 870 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 871 op.getNumNativeAttributes() + op.getNumSuccessors()); 872 populateBuilderArgsResults(op, builderArgs); 873 size_t numResultArgs = builderArgs.size(); 874 populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); 875 size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; 876 populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); 877 878 populateBuilderLinesOperand(op, operandArgNames, builderLines); 879 populateBuilderLinesAttr( 880 op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), 881 builderLines); 882 populateBuilderLinesResult( 883 op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), 884 builderLines); 885 populateBuilderLinesSuccessors(op, successorArgNames, builderLines); 886 populateBuilderRegions(op, builderArgs, builderLines); 887 888 // Layout of builderArgs vector elements: 889 // [ result_args operand_attr_args successor_args regions ] 890 891 // Determine whether the argument corresponding to a given index into the 892 // builderArgs vector is a python keyword argument or not. 893 auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { 894 // All result, successor, and region arguments are positional arguments. 
895 if ((builderArgIndex < numResultArgs) || 896 (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) 897 return false; 898 // Keyword arguments: 899 // - optional named attributes (including unit attributes) 900 // - default-valued named attributes 901 // - optional operands 902 Argument a = op.getArg(builderArgIndex - numResultArgs); 903 if (auto *nattr = a.dyn_cast<NamedAttribute *>()) 904 return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); 905 if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>()) 906 return ntype->isOptional(); 907 return false; 908 }; 909 910 // StringRefs in functionArgs refer to strings allocated by builderArgs. 911 llvm::SmallVector<llvm::StringRef> functionArgs; 912 913 // Add positional arguments. 914 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { 915 if (!isKeywordArgFn(i)) 916 functionArgs.push_back(builderArgs[i]); 917 } 918 919 // Add a bare '*' to indicate that all following arguments must be keyword 920 // arguments. 921 functionArgs.push_back("*"); 922 923 // Add a default 'None' value to each keyword arg string, and then add to the 924 // function args list. 
925 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { 926 if (isKeywordArgFn(i)) { 927 builderArgs[i].append("=None"); 928 functionArgs.push_back(builderArgs[i]); 929 } 930 } 931 functionArgs.push_back("loc=None"); 932 functionArgs.push_back("ip=None"); 933 os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), 934 llvm::join(builderLines, "\n ")); 935 } 936 937 static void constructAttributeMapping(const llvm::RecordKeeper &records, 938 AttributeClasses &attributeClasses) { 939 for (const llvm::Record *rec : 940 records.getAllDerivedDefinitions("PythonAttr")) { 941 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 942 rec->getValueAsString("pythonType").trim()); 943 } 944 } 945 946 static void emitSegmentSpec( 947 const Operator &op, const char *kind, 948 llvm::function_ref<int(const Operator &)> getNumElements, 949 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 950 getElement, 951 raw_ostream &os) { 952 std::string segmentSpec("["); 953 for (int i = 0, e = getNumElements(op); i < e; ++i) { 954 const NamedTypeConstraint &element = getElement(op, i); 955 if (element.isOptional()) { 956 segmentSpec.append("0,"); 957 } else if (element.isVariadic()) { 958 segmentSpec.append("-1,"); 959 } else { 960 segmentSpec.append("1,"); 961 } 962 } 963 segmentSpec.append("]"); 964 965 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); 966 } 967 968 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 969 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 970 // Note that the base OpView class defines this as (0, True). 971 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 972 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, 973 op.hasNoVariadicRegions() ? "True" : "False"); 974 } 975 976 /// Emits named accessors to regions. 
977 static void emitRegionAccessors(const Operator &op, raw_ostream &os) { 978 for (const auto &en : llvm::enumerate(op.getRegions())) { 979 const NamedRegion ®ion = en.value(); 980 if (region.name.empty()) 981 continue; 982 983 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && 984 "expected only the last region to be variadic"); 985 os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), 986 std::to_string(en.index()) + 987 (region.isVariadic() ? ":" : "")); 988 } 989 } 990 991 /// Emits bindings for a specific Op to the given output stream. 992 static void emitOpBindings(const Operator &op, 993 const AttributeClasses &attributeClasses, 994 raw_ostream &os) { 995 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 996 op.getOperationName()); 997 998 // Sized segments. 999 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 1000 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 1001 } 1002 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 1003 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 1004 } 1005 1006 emitRegionAttributes(op, os); 1007 emitDefaultOpBuilder(op, os); 1008 emitOperandAccessors(op, os); 1009 emitAttributeAccessors(op, attributeClasses, os); 1010 emitResultAccessors(op, os); 1011 emitRegionAccessors(op, os); 1012 } 1013 1014 /// Emits bindings for the dialect specified in the command line, including file 1015 /// headers and utilities. Returns `false` on success to comply with Tablegen 1016 /// registration requirements. 1017 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 1018 if (clDialectName.empty()) 1019 llvm::PrintFatalError("dialect name not provided"); 1020 1021 AttributeClasses attributeClasses; 1022 constructAttributeMapping(records, attributeClasses); 1023 1024 bool isExtension = !clDialectExtensionName.empty(); 1025 os << llvm::formatv(fileHeader, isExtension 1026 ? 
clDialectExtensionName.getValue() 1027 : clDialectName.getValue()); 1028 if (isExtension) 1029 os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue()); 1030 else 1031 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 1032 1033 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 1034 Operator op(rec); 1035 if (op.getDialectName() == clDialectName.getValue()) 1036 emitOpBindings(op, attributeClasses, os); 1037 } 1038 return false; 1039 } 1040 1041 static GenRegistration 1042 genPythonBindings("gen-python-op-bindings", 1043 "Generate Python bindings for MLIR Ops", &emitAllOps); 1044