//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// File header and includes.
///   {0} is the dialect namespace.
///
/// This is the preamble of every generated module. It imports the shared
/// helpers from `_ods_common` under underscore-prefixed aliases. It then tries
/// to import an optional hand-written `_{0}_ops_ext` module; on ImportError,
/// `_ods_ext_module` is set to None instead. `builtins` is imported because the
/// accessor templates in this file decorate with `@builtins.property`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Python snippet for dialect *extension* modules. Rather than defining a new
/// `_Dialect` class, it re-imports the one from an already generated module, so
/// that `@_ods_cext.register_operation(_Dialect)` below registers against it.
/// Placeholders:
///   {0} - selects the `_{0}_ops_gen` module `_Dialect` is imported from.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Python class wrapping a single operation. It is registered with `_Dialect`,
/// and the optional extension module is handed to `_ods_extend_opview_class`.
/// Placeholders:
///   {0} - the Python class name;
///   {1} - the operation name, stored as OPERATION_NAME.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Class-level segment spec for operands or results.
/// Placeholders:
///   {0} - either "OPERAND" or "RESULT";
///   {1} - the spec: None (the default) or a list with one integer per
///         element group, where
///            1 = single element (a non-sequence operand/result is expected),
///            0 = optional element (a value or None is expected),
///           -1 = variadic group (a sequence is expected).
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Class-level `_ODS_REGIONS` spec.
/// Placeholders:
///   {0} - the minimum number of regions;
///   {1} - the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Accessor for an element at a fixed position, i.e. one that no
/// variable-length group precedes.
/// Placeholders:
///   {0} - the accessor name;
///   {1} - either 'operand' or 'result';
///   {2} - the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Accessor for a single element that follows a variable-length group. The
/// length of that group is recovered as `len(elements) - num_groups + 1`. This
/// holds both for a single variadic group (length >= 0) and for a single
/// optional element (length 0 when absent).
/// Placeholders:
///   {0} - the accessor name;
///   {1} - either 'operand' or 'result';
///   {2} - the total number of element groups;
///   {3} - the position of the current group in the group list.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Accessor for an optional element that is the only variable-length group.
/// The element is absent exactly when `len(operation.{1}s)` is smaller than
/// the total number of groups; the accessor then returns None.
/// Placeholders:
///   {0} - the accessor name;
///   {1} - either 'operand' or 'result';
///   {2} - the total number of element groups;
///   {3} - the position of the current group in the group list.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Accessor returning the slice that forms the only variadic group.
/// Placeholders:
///   {0} - the accessor name;
///   {1} - either 'operand' or 'result';
///   {2} - the total number of element groups;
///   {3} - the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for an equally-sized variadic group accessor.
/// It emits the property header and computes `start`/`pg`. It has no trailing
/// newline: one of the two "second part" templates below must follow it to
/// supply the `return` statement.
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The element list must be read through `self.operation`. A bare `operation`
/// is not defined inside the generated method and would raise NameError.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for the equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for the equally-sized case, accessing a
/// variadic group of `pg` elements:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Return suffix for an optional element in the attribute-sized case. It
/// yields the single element, or None when the element's segment is empty.
/// Placeholders:
///   {0} - either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Getter for a mandatory attribute; wraps the stored attribute in its Python
/// class.
/// Placeholders:
///   {0} - the attribute name sanitized for Python;
///   {1} - the Python type of the attribute;
///   {2} - the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Getter for an optional attribute; returns None when the attribute is not
/// set on the operation.
/// Placeholders:
///   {0} - the attribute name sanitized for Python;
///   {1} - the Python type of the attribute;
///   {2} - the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Getter for a unit attribute. Unit attributes carry meaning by mere presence,
/// so the getter returns True if the attribute is present and False otherwise.
/// Placeholders:
///   {0} - the attribute name sanitized for Python;
///   {1} - the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Setter for a mandatory attribute; assigning None raises ValueError.
/// Placeholders:
///   {0} - the attribute name sanitized for Python;
///   {1} - the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Setter for an optional attribute. Assigning None removes the attribute if
/// it is currently set.
/// Placeholders:
///   {0} - the attribute name sanitized for Python;
///   {1} - the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Setter for a unit attribute. A truthy value stores a fresh `UnitAttr`; a
/// falsy value (such as None or False) removes the attribute if it is set.
/// Placeholders:
///   {0} - the attribute name sanitized for Python;
///   {1} - the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Deleter for an optional or unit attribute; removes the attribute from the
/// operation.
/// Placeholders:
///   {0} - the attribute name sanitized for Python;
///   {1} - the original name of the attribute.
258 constexpr const char *attributeDeleterTemplate = R"Py( 259 @{0}.deleter 260 def {0}(self): 261 del self.operation.attributes["{1}"] 262 )Py"; 263 264 constexpr const char *regionAccessorTemplate = R"PY( 265 @builtins.property 266 def {0}(self): 267 return self.regions[{1}] 268 )PY"; 269 270 static llvm::cl::OptionCategory 271 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 272 273 static llvm::cl::opt<std::string> 274 clDialectName("bind-dialect", 275 llvm::cl::desc("The dialect to run the generator for"), 276 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 277 278 static llvm::cl::opt<std::string> clDialectExtensionName( 279 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"), 280 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 281 282 using AttributeClasses = DenseMap<StringRef, StringRef>; 283 284 /// Checks whether `str` is a Python keyword. 285 static bool isPythonKeyword(StringRef str) { 286 static llvm::StringSet<> keywords( 287 {"and", "as", "assert", "break", "class", "continue", 288 "def", "del", "elif", "else", "except", "finally", 289 "for", "from", "global", "if", "import", "in", 290 "is", "lambda", "nonlocal", "not", "or", "pass", 291 "raise", "return", "try", "while", "with", "yield"}); 292 return keywords.contains(str); 293 } 294 295 /// Checks whether `str` would shadow a generated variable or attribute 296 /// part of the OpView API. 297 static bool isODSReserved(StringRef str) { 298 static llvm::StringSet<> reserved( 299 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 300 "loc", "verify", "regions", "results", "self", "operation", 301 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 302 return str.startswith("_ods_") || str.endswith("_ods") || 303 reserved.contains(str); 304 } 305 306 /// Modifies the `name` in a way that it becomes suitable for Python bindings 307 /// (does not change the `name` if it already is suitable) and returns the 308 /// modified version. 
309 static std::string sanitizeName(StringRef name) { 310 if (isPythonKeyword(name) || isODSReserved(name)) 311 return (name + "_").str(); 312 return name.str(); 313 } 314 315 static std::string attrSizedTraitForKind(const char *kind) { 316 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 317 llvm::StringRef(kind).take_front().upper(), 318 llvm::StringRef(kind).drop_front()); 319 } 320 321 /// Emits accessors to "elements" of an Op definition. Currently, the supported 322 /// elements are operands and results, indicated by `kind`, which must be either 323 /// `operand` or `result` and is used verbatim in the emitted code. 324 static void emitElementAccessors( 325 const Operator &op, raw_ostream &os, const char *kind, 326 llvm::function_ref<unsigned(const Operator &)> getNumVariableLength, 327 llvm::function_ref<int(const Operator &)> getNumElements, 328 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 329 getElement) { 330 assert(llvm::is_contained( 331 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 332 "unsupported kind"); 333 334 // Traits indicating how to process variadic elements. 335 std::string sameSizeTrait = 336 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 337 llvm::StringRef(kind).take_front().upper(), 338 llvm::StringRef(kind).drop_front()); 339 std::string attrSizedTrait = attrSizedTraitForKind(kind); 340 341 unsigned numVariableLength = getNumVariableLength(op); 342 343 // If there is only one variable-length element group, its size can be 344 // inferred from the total number of elements. If there are none, the 345 // generation is straightforward. 
346 if (numVariableLength <= 1) { 347 bool seenVariableLength = false; 348 for (int i = 0, e = getNumElements(op); i < e; ++i) { 349 const NamedTypeConstraint &element = getElement(op, i); 350 if (element.isVariableLength()) 351 seenVariableLength = true; 352 if (element.name.empty()) 353 continue; 354 if (element.isVariableLength()) { 355 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 356 : opOneVariadicTemplate, 357 sanitizeName(element.name), kind, 358 getNumElements(op), i); 359 } else if (seenVariableLength) { 360 os << llvm::formatv(opSingleAfterVariableTemplate, 361 sanitizeName(element.name), kind, 362 getNumElements(op), i); 363 } else { 364 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 365 i); 366 } 367 } 368 return; 369 } 370 371 // Handle the operations where variadic groups have the same size. 372 if (op.getTrait(sameSizeTrait)) { 373 int numPrecedingSimple = 0; 374 int numPrecedingVariadic = 0; 375 for (int i = 0, e = getNumElements(op); i < e; ++i) { 376 const NamedTypeConstraint &element = getElement(op, i); 377 if (!element.name.empty()) { 378 os << llvm::formatv(opVariadicEqualPrefixTemplate, 379 sanitizeName(element.name), kind, numVariableLength, 380 numPrecedingSimple, numPrecedingVariadic); 381 os << llvm::formatv(element.isVariableLength() 382 ? opVariadicEqualVariadicTemplate 383 : opVariadicEqualSimpleTemplate, 384 kind); 385 } 386 if (element.isVariableLength()) 387 ++numPrecedingVariadic; 388 else 389 ++numPrecedingSimple; 390 } 391 return; 392 } 393 394 // Handle the operations where the size of groups (variadic or not) is 395 // provided as an attribute. For non-variadic elements, make sure to return 396 // an element rather than a singleton container. 
397 if (op.getTrait(attrSizedTrait)) { 398 for (int i = 0, e = getNumElements(op); i < e; ++i) { 399 const NamedTypeConstraint &element = getElement(op, i); 400 if (element.name.empty()) 401 continue; 402 std::string trailing; 403 if (!element.isVariableLength()) 404 trailing = "[0]"; 405 else if (element.isOptional()) 406 trailing = std::string( 407 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 408 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 409 kind, i, trailing); 410 } 411 return; 412 } 413 414 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 415 } 416 417 /// Free function helpers accessing Operator components. 418 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 419 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 420 return op.getOperand(i); 421 } 422 static int getNumResults(const Operator &op) { return op.getNumResults(); } 423 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 424 return op.getResult(i); 425 } 426 427 /// Emits accessors to Op operands. 428 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 429 auto getNumVariableLengthOperands = [](const Operator &oper) { 430 return oper.getNumVariableLengthOperands(); 431 }; 432 emitElementAccessors(op, os, "operand", getNumVariableLengthOperands, 433 getNumOperands, getOperand); 434 } 435 436 /// Emits accessors Op results. 437 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 438 auto getNumVariableLengthResults = [](const Operator &oper) { 439 return oper.getNumVariableLengthResults(); 440 }; 441 emitElementAccessors(op, os, "result", getNumVariableLengthResults, 442 getNumResults, getResult); 443 } 444 445 /// Emits accessors to Op attributes. 
446 static void emitAttributeAccessors(const Operator &op, 447 const AttributeClasses &attributeClasses, 448 raw_ostream &os) { 449 for (const auto &namedAttr : op.getAttributes()) { 450 // Skip "derived" attributes because they are just C++ functions that we 451 // don't currently expose. 452 if (namedAttr.attr.isDerivedAttr()) 453 continue; 454 455 if (namedAttr.name.empty()) 456 continue; 457 458 std::string sanitizedName = sanitizeName(namedAttr.name); 459 460 // Unit attributes are handled specially. 461 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 462 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 463 namedAttr.name); 464 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 465 namedAttr.name); 466 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 467 namedAttr.name); 468 continue; 469 } 470 471 // Other kinds of attributes need a mapping to a Python type. 472 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 473 continue; 474 475 StringRef pythonType = 476 attributeClasses.lookup(namedAttr.attr.getStorageType()); 477 if (namedAttr.attr.isOptional()) { 478 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 479 pythonType, namedAttr.name); 480 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 481 namedAttr.name); 482 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 483 namedAttr.name); 484 } else { 485 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 486 namedAttr.name); 487 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 488 namedAttr.name); 489 // Non-optional attributes cannot be deleted. 490 } 491 } 492 } 493 494 /// Template for the default auto-generated builder. 495 /// {0} is a comma-separated list of builder arguments, including the trailing 496 /// `loc` and `ip`; 497 /// {1} is the code populating `operands`, `results` and `attributes`, 498 /// `successors` fields. 
/// Note: `{{` is the formatv escape for a literal `{`, so the generated code
/// reads `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";

/// Builder lines appending one element to the `operands` / `results` list.
/// Placeholders:
///   {0} - the builder argument name.
constexpr const char *singleOperandAppendTemplate =
    R"Py(operands.append(_get_op_result_or_value({0})))Py";
constexpr const char *singleResultAppendTemplate = R"Py(results.append({0}))Py";

/// Builder lines appending an optional element. The plain variants skip a
/// None value. The attribute-sized operand variant appends None explicitly,
/// so the absent operand still occupies its segment.
/// Placeholders:
///   {0} - the builder argument name.
constexpr const char *optionalAppendOperandTemplate =
    R"Py(if {0} is not None: operands.append(_get_op_result_or_value({0})))Py";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    R"Py(if {0} is not None: results.append({0}))Py";

/// Builder lines adding a list of elements. The "Pack" variant appends the
/// whole list as one entry (used with attribute-sized segments); the "extend"
/// variants splice the individual elements in.
/// Placeholders:
///   {0} - the builder argument name.
constexpr const char *multiOperandAppendTemplate =
    R"Py(operands.extend(_get_op_results_or_values({0})))Py";
constexpr const char *multiOperandAppendPackTemplate =
    R"Py(operands.append(_get_op_results_or_values({0})))Py";
constexpr const char *multiResultAppendTemplate = R"Py(results.extend({0}))Py";

/// Builder line setting a mandatory attribute.
/// Placeholders:
///   {0} - the attribute name;
///   {1} - the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Builder line setting an optional attribute only when a value was given.
/// Placeholders:
///   {0} - the attribute name;
///   {1} - the builder argument name.
543 constexpr const char *initOptionalAttributeTemplate = 544 R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; 545 546 /// Template for setting an attribute with a default value in the operation 547 /// builder. 548 /// {0} is the attribute name; 549 /// {1} is the builder argument name; 550 /// {2} is the default value. 551 constexpr const char *initDefaultValuedAttributeTemplate = 552 R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py"; 553 554 /// Template for asserting that an attribute value was provided when calling a 555 /// builder. 556 /// {0} is the attribute name; 557 /// {1} is the builder argument name. 558 constexpr const char *assertAttributeValueSpecified = 559 R"Py(assert {1} is not None, "attribute {0} must be specified")Py"; 560 561 constexpr const char *initUnitAttributeTemplate = 562 R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( 563 _ods_get_default_loc_context(loc)))Py"; 564 565 /// Template to initialize the successors list in the builder if there are any 566 /// successors. 567 /// {0} is the value to initialize the successors list to. 568 constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; 569 570 /// Template to append or extend the list of successors in the builder. 571 /// {0} is the list method ('append' or 'extend'); 572 /// {1} is the value to add. 573 constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; 574 575 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer 576 /// result types of the given operation. 577 static bool hasSameArgumentAndResultTypes(const Operator &op) { 578 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && 579 op.getNumVariableLengthResults() == 0; 580 } 581 582 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer 583 /// result types of the given operation. 
584 static bool hasFirstAttrDerivedResultTypes(const Operator &op) { 585 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && 586 op.getNumVariableLengthResults() == 0; 587 } 588 589 /// Returns true if the InferTypeOpInterface can be used to infer result types 590 /// of the given operation. 591 static bool hasInferTypeInterface(const Operator &op) { 592 return op.getTrait("::mlir::InferTypeOpInterface::Trait") && 593 op.getNumRegions() == 0; 594 } 595 596 /// Returns true if there is a trait or interface that can be used to infer 597 /// result types of the given operation. 598 static bool canInferType(const Operator &op) { 599 return hasSameArgumentAndResultTypes(op) || 600 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); 601 } 602 603 /// Populates `builderArgs` with result names if the builder is expected to 604 /// accept them as arguments. 605 static void 606 populateBuilderArgsResults(const Operator &op, 607 llvm::SmallVectorImpl<std::string> &builderArgs) { 608 if (canInferType(op)) 609 return; 610 611 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 612 std::string name = op.getResultName(i).str(); 613 if (name.empty()) { 614 if (op.getNumResults() == 1) { 615 // Special case for one result, make the default name be 'result' 616 // to properly match the built-in result accessor. 617 name = "result"; 618 } else { 619 name = llvm::formatv("_gen_res_{0}", i); 620 } 621 } 622 name = sanitizeName(name); 623 builderArgs.push_back(name); 624 } 625 } 626 627 /// Populates `builderArgs` with the Python-compatible names of builder function 628 /// arguments using intermixed attributes and operands in the same order as they 629 /// appear in the `arguments` field of the op definition. Additionally, 630 /// `operandNames` is populated with names of operands in their order of 631 /// appearance. 
632 static void 633 populateBuilderArgs(const Operator &op, 634 llvm::SmallVectorImpl<std::string> &builderArgs, 635 llvm::SmallVectorImpl<std::string> &operandNames, 636 llvm::SmallVectorImpl<std::string> &successorArgNames) { 637 638 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 639 std::string name = op.getArgName(i).str(); 640 if (name.empty()) 641 name = llvm::formatv("_gen_arg_{0}", i); 642 name = sanitizeName(name); 643 builderArgs.push_back(name); 644 if (!op.getArg(i).is<NamedAttribute *>()) 645 operandNames.push_back(name); 646 } 647 } 648 649 /// Populates `builderArgs` with the Python-compatible names of builder function 650 /// successor arguments. Additionally, `successorArgNames` is also populated. 651 static void populateBuilderArgsSuccessors( 652 const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs, 653 llvm::SmallVectorImpl<std::string> &successorArgNames) { 654 655 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { 656 NamedSuccessor successor = op.getSuccessor(i); 657 std::string name = std::string(successor.name); 658 if (name.empty()) 659 name = llvm::formatv("_gen_successor_{0}", i); 660 name = sanitizeName(name); 661 builderArgs.push_back(name); 662 successorArgNames.push_back(name); 663 } 664 } 665 666 /// Generates Python code for the default value of the given attribute. 667 static FailureOr<std::string> getAttributeDefaultValue(Attribute attr) { 668 assert(attr.hasDefaultValue() && "expected attribute with default value"); 669 StringRef storageType = attr.getStorageType().trim(); 670 StringRef defaultValCpp = attr.getDefaultValue().trim(); 671 672 // A list of commonly used attribute types and default values for which 673 // we can generate Python code. Extend as needed. 674 if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}")) 675 return std::string("_ods_ir.ArrayAttr.get([])"); 676 677 // No match: Cannot generate Python code. 
678 return failure(); 679 } 680 681 /// Populates `builderLines` with additional lines that are required in the 682 /// builder to set up operation attributes. `argNames` is expected to contain 683 /// the names of builder arguments that correspond to op arguments, i.e. to the 684 /// operands and attributes in the same order as they appear in the `arguments` 685 /// field. 686 static void 687 populateBuilderLinesAttr(const Operator &op, 688 llvm::ArrayRef<std::string> argNames, 689 llvm::SmallVectorImpl<std::string> &builderLines) { 690 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 691 Argument arg = op.getArg(i); 692 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 693 if (!attribute) 694 continue; 695 696 // Unit attributes are handled specially. 697 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 698 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 699 attribute->name, argNames[i])); 700 continue; 701 } 702 703 // Attributes with default value are handled specially. 704 if (attribute->attr.hasDefaultValue()) { 705 // In case we cannot generate Python code for the default value, the 706 // attribute must be specified by the user. 707 FailureOr<std::string> defaultValPy = 708 getAttributeDefaultValue(attribute->attr); 709 if (succeeded(defaultValPy)) { 710 builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate, 711 attribute->name, argNames[i], 712 *defaultValPy)); 713 } else { 714 builderLines.push_back(llvm::formatv(assertAttributeValueSpecified, 715 attribute->name, argNames[i])); 716 builderLines.push_back( 717 llvm::formatv(initAttributeTemplate, attribute->name, argNames[i])); 718 } 719 continue; 720 } 721 722 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 723 ? 
initOptionalAttributeTemplate 724 : initAttributeTemplate, 725 attribute->name, argNames[i])); 726 } 727 } 728 729 /// Populates `builderLines` with additional lines that are required in the 730 /// builder to set up successors. successorArgNames is expected to correspond 731 /// to the Python argument name for each successor on the op. 732 static void populateBuilderLinesSuccessors( 733 const Operator &op, llvm::ArrayRef<std::string> successorArgNames, 734 llvm::SmallVectorImpl<std::string> &builderLines) { 735 if (successorArgNames.empty()) { 736 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); 737 return; 738 } 739 740 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); 741 for (int i = 0, e = successorArgNames.size(); i < e; ++i) { 742 auto &argName = successorArgNames[i]; 743 const NamedSuccessor &successor = op.getSuccessor(i); 744 builderLines.push_back( 745 llvm::formatv(addSuccessorTemplate, 746 successor.isVariadic() ? "extend" : "append", argName)); 747 } 748 } 749 750 /// Populates `builderLines` with additional lines that are required in the 751 /// builder to set up op operands. 752 static void 753 populateBuilderLinesOperand(const Operator &op, 754 llvm::ArrayRef<std::string> names, 755 llvm::SmallVectorImpl<std::string> &builderLines) { 756 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; 757 758 // For each element, find or generate a name. 759 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 760 const NamedTypeConstraint &element = op.getOperand(i); 761 std::string name = names[i]; 762 763 // Choose the formatting string based on the element kind. 
764 llvm::StringRef formatString; 765 if (!element.isVariableLength()) { 766 formatString = singleOperandAppendTemplate; 767 } else if (element.isOptional()) { 768 if (sizedSegments) { 769 formatString = optionalAppendAttrSizedOperandsTemplate; 770 } else { 771 formatString = optionalAppendOperandTemplate; 772 } 773 } else { 774 assert(element.isVariadic() && "unhandled element group type"); 775 // If emitting with sizedSegments, then we add the actual list-typed 776 // element. Otherwise, we extend the actual operands. 777 if (sizedSegments) { 778 formatString = multiOperandAppendPackTemplate; 779 } else { 780 formatString = multiOperandAppendTemplate; 781 } 782 } 783 784 builderLines.push_back(llvm::formatv(formatString.data(), name)); 785 } 786 } 787 788 /// Python code template for deriving the operation result types from its 789 /// attribute: 790 /// - {0} is the name of the attribute from which to derive the types. 791 constexpr const char *deriveTypeFromAttrTemplate = 792 R"PY(_ods_result_type_source_attr = attributes["{0}"] 793 _ods_derived_result_type = ( 794 _ods_ir.TypeAttr(_ods_result_type_source_attr).value 795 if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else 796 _ods_result_type_source_attr.type))PY"; 797 798 /// Python code template appending {0} type {1} times to the results list. 799 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; 800 801 /// Python code template for inferring the operation results using the 802 /// corresponding interface: 803 /// - {0} is the name of the class for which the types are inferred. 
804 constexpr const char *inferTypeInterfaceTemplate = 805 R"PY(_ods_context = _ods_get_default_loc_context(loc) 806 results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( 807 operands=operands, 808 attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), 809 context=_ods_context, 810 loc=loc) 811 )PY"; 812 813 /// Appends the given multiline string as individual strings into 814 /// `builderLines`. 815 static void appendLineByLine(StringRef string, 816 llvm::SmallVectorImpl<std::string> &builderLines) { 817 818 std::pair<StringRef, StringRef> split = std::make_pair(string, string); 819 do { 820 split = split.second.split('\n'); 821 builderLines.push_back(split.first.str()); 822 } while (!split.second.empty()); 823 } 824 825 /// Populates `builderLines` with additional lines that are required in the 826 /// builder to set up op results. 827 static void 828 populateBuilderLinesResult(const Operator &op, 829 llvm::ArrayRef<std::string> names, 830 llvm::SmallVectorImpl<std::string> &builderLines) { 831 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; 832 833 if (hasSameArgumentAndResultTypes(op)) { 834 builderLines.push_back(llvm::formatv( 835 appendSameResultsTemplate, "operands[0].type", op.getNumResults())); 836 return; 837 } 838 839 if (hasFirstAttrDerivedResultTypes(op)) { 840 const NamedAttribute &firstAttr = op.getAttribute(0); 841 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " 842 "from which the type is derived"); 843 appendLineByLine( 844 llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), 845 builderLines); 846 builderLines.push_back(llvm::formatv(appendSameResultsTemplate, 847 "_ods_derived_result_type", 848 op.getNumResults())); 849 return; 850 } 851 852 if (hasInferTypeInterface(op)) { 853 appendLineByLine( 854 llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), 855 builderLines); 856 return; 857 } 858 859 // For each element, find or generate a 
name. 860 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 861 const NamedTypeConstraint &element = op.getResult(i); 862 std::string name = names[i]; 863 864 // Choose the formatting string based on the element kind. 865 llvm::StringRef formatString; 866 if (!element.isVariableLength()) { 867 formatString = singleResultAppendTemplate; 868 } else if (element.isOptional()) { 869 formatString = optionalAppendResultTemplate; 870 } else { 871 assert(element.isVariadic() && "unhandled element group type"); 872 // If emitting with sizedSegments, then we add the actual list-typed 873 // element. Otherwise, we extend the actual operands. 874 if (sizedSegments) { 875 formatString = singleResultAppendTemplate; 876 } else { 877 formatString = multiResultAppendTemplate; 878 } 879 } 880 881 builderLines.push_back(llvm::formatv(formatString.data(), name)); 882 } 883 } 884 885 /// If the operation has variadic regions, adds a builder argument to specify 886 /// the number of those regions and builder lines to forward it to the generic 887 /// constructor. 888 static void 889 populateBuilderRegions(const Operator &op, 890 llvm::SmallVectorImpl<std::string> &builderArgs, 891 llvm::SmallVectorImpl<std::string> &builderLines) { 892 if (op.hasNoVariadicRegions()) 893 return; 894 895 // This is currently enforced when Operator is constructed. 896 assert(op.getNumVariadicRegions() == 1 && 897 op.getRegion(op.getNumRegions() - 1).isVariadic() && 898 "expected the last region to be varidic"); 899 900 const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); 901 std::string name = 902 ("num_" + region.name.take_front().lower() + region.name.drop_front()) 903 .str(); 904 builderArgs.push_back(name); 905 builderLines.push_back( 906 llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); 907 } 908 909 /// Emits a default builder constructing an operation from the list of its 910 /// result types, followed by a list of its operands. 
911 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 912 // If we are asked to skip default builders, comply. 913 if (op.skipDefaultBuilders()) 914 return; 915 916 llvm::SmallVector<std::string> builderArgs; 917 llvm::SmallVector<std::string> builderLines; 918 llvm::SmallVector<std::string> operandArgNames; 919 llvm::SmallVector<std::string> successorArgNames; 920 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 921 op.getNumNativeAttributes() + op.getNumSuccessors()); 922 populateBuilderArgsResults(op, builderArgs); 923 size_t numResultArgs = builderArgs.size(); 924 populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); 925 size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; 926 populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); 927 928 populateBuilderLinesOperand(op, operandArgNames, builderLines); 929 populateBuilderLinesAttr( 930 op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), 931 builderLines); 932 populateBuilderLinesResult( 933 op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), 934 builderLines); 935 populateBuilderLinesSuccessors(op, successorArgNames, builderLines); 936 populateBuilderRegions(op, builderArgs, builderLines); 937 938 // Layout of builderArgs vector elements: 939 // [ result_args operand_attr_args successor_args regions ] 940 941 // Determine whether the argument corresponding to a given index into the 942 // builderArgs vector is a python keyword argument or not. 943 auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { 944 // All result, successor, and region arguments are positional arguments. 
945 if ((builderArgIndex < numResultArgs) || 946 (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) 947 return false; 948 // Keyword arguments: 949 // - optional named attributes (including unit attributes) 950 // - default-valued named attributes 951 // - optional operands 952 Argument a = op.getArg(builderArgIndex - numResultArgs); 953 if (auto *nattr = a.dyn_cast<NamedAttribute *>()) 954 return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); 955 if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>()) 956 return ntype->isOptional(); 957 return false; 958 }; 959 960 // StringRefs in functionArgs refer to strings allocated by builderArgs. 961 llvm::SmallVector<llvm::StringRef> functionArgs; 962 963 // Add positional arguments. 964 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { 965 if (!isKeywordArgFn(i)) 966 functionArgs.push_back(builderArgs[i]); 967 } 968 969 // Add a bare '*' to indicate that all following arguments must be keyword 970 // arguments. 971 functionArgs.push_back("*"); 972 973 // Add a default 'None' value to each keyword arg string, and then add to the 974 // function args list. 
975 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { 976 if (isKeywordArgFn(i)) { 977 builderArgs[i].append("=None"); 978 functionArgs.push_back(builderArgs[i]); 979 } 980 } 981 functionArgs.push_back("loc=None"); 982 functionArgs.push_back("ip=None"); 983 os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), 984 llvm::join(builderLines, "\n ")); 985 } 986 987 static void constructAttributeMapping(const llvm::RecordKeeper &records, 988 AttributeClasses &attributeClasses) { 989 for (const llvm::Record *rec : 990 records.getAllDerivedDefinitions("PythonAttr")) { 991 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 992 rec->getValueAsString("pythonType").trim()); 993 } 994 } 995 996 static void emitSegmentSpec( 997 const Operator &op, const char *kind, 998 llvm::function_ref<int(const Operator &)> getNumElements, 999 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 1000 getElement, 1001 raw_ostream &os) { 1002 std::string segmentSpec("["); 1003 for (int i = 0, e = getNumElements(op); i < e; ++i) { 1004 const NamedTypeConstraint &element = getElement(op, i); 1005 if (element.isOptional()) { 1006 segmentSpec.append("0,"); 1007 } else if (element.isVariadic()) { 1008 segmentSpec.append("-1,"); 1009 } else { 1010 segmentSpec.append("1,"); 1011 } 1012 } 1013 segmentSpec.append("]"); 1014 1015 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); 1016 } 1017 1018 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 1019 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 1020 // Note that the base OpView class defines this as (0, True). 1021 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 1022 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, 1023 op.hasNoVariadicRegions() ? "True" : "False"); 1024 } 1025 1026 /// Emits named accessors to regions. 
1027 static void emitRegionAccessors(const Operator &op, raw_ostream &os) { 1028 for (const auto &en : llvm::enumerate(op.getRegions())) { 1029 const NamedRegion ®ion = en.value(); 1030 if (region.name.empty()) 1031 continue; 1032 1033 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && 1034 "expected only the last region to be variadic"); 1035 os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), 1036 std::to_string(en.index()) + 1037 (region.isVariadic() ? ":" : "")); 1038 } 1039 } 1040 1041 /// Emits bindings for a specific Op to the given output stream. 1042 static void emitOpBindings(const Operator &op, 1043 const AttributeClasses &attributeClasses, 1044 raw_ostream &os) { 1045 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 1046 op.getOperationName()); 1047 1048 // Sized segments. 1049 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 1050 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 1051 } 1052 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 1053 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 1054 } 1055 1056 emitRegionAttributes(op, os); 1057 emitDefaultOpBuilder(op, os); 1058 emitOperandAccessors(op, os); 1059 emitAttributeAccessors(op, attributeClasses, os); 1060 emitResultAccessors(op, os); 1061 emitRegionAccessors(op, os); 1062 } 1063 1064 /// Emits bindings for the dialect specified in the command line, including file 1065 /// headers and utilities. Returns `false` on success to comply with Tablegen 1066 /// registration requirements. 1067 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 1068 if (clDialectName.empty()) 1069 llvm::PrintFatalError("dialect name not provided"); 1070 1071 AttributeClasses attributeClasses; 1072 constructAttributeMapping(records, attributeClasses); 1073 1074 bool isExtension = !clDialectExtensionName.empty(); 1075 os << llvm::formatv(fileHeader, isExtension 1076 ? 
clDialectExtensionName.getValue() 1077 : clDialectName.getValue()); 1078 if (isExtension) 1079 os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue()); 1080 else 1081 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 1082 1083 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 1084 Operator op(rec); 1085 if (op.getDialectName() == clDialectName.getValue()) 1086 emitOpBindings(op, attributeClasses, os); 1087 } 1088 return false; 1089 } 1090 1091 static GenRegistration 1092 genPythonBindings("gen-python-op-bindings", 1093 "Generate Python bindings for MLIR Ops", &emitAllOps); 1094