1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python 10 // binding classes wrapping a generic operation API. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/GenInfo.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "llvm/ADT/StringSet.h" 17 #include "llvm/Support/CommandLine.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 using namespace mlir; 23 using namespace mlir::tblgen; 24 25 /// File header and includes. 26 /// {0} is the dialect namespace. 27 constexpr const char *fileHeader = R"Py( 28 # Autogenerated by mlir-tblgen; don't manually edit. 29 30 from ._ods_common import _cext as _ods_cext 31 from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values 32 _ods_ir = _ods_cext.ir 33 34 try: 35 from . import _{0}_ops_ext as _ods_ext_module 36 except ImportError: 37 _ods_ext_module = None 38 39 import builtins 40 41 )Py"; 42 43 /// Template for dialect class: 44 /// {0} is the dialect namespace. 
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Class header emitted for each operation:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Class-level declaration of the operand or result segment spec:
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Class-level declaration of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Accessor for a single element at a fixed position:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This template is only used when the optional element is the sole
/// variable-length group, so the actual element count is {2} when the element
/// is present and {2} - 1 when it is absent; the accessor returns None in the
/// latter case.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The template deliberately has no trailing newline: one of the "second
/// part" templates is emitted right after it and supplies the return line.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Accessor for a group whose size is stored in a segment-sizes attribute:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Return suffix used for an optional element in the attribute-sized case:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Getter for an operation attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Getter for a unit operation attribute. Returns True if the unit attribute
/// is present and False otherwise (unit attributes have meaning by mere
/// presence):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Setter for a mandatory operation attribute; the generated code raises
/// ValueError when given None:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Setter for an optional operation attribute; setting it to None removes the
/// attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
236 constexpr const char *unitAttributeSetterTemplate = R"Py( 237 @{0}.setter 238 def {0}(self, value): 239 if bool(value): 240 self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get() 241 elif "{1}" in self.operation.attributes: 242 del self.operation.attributes["{1}"] 243 )Py"; 244 245 /// Template for a deleter of an optional or a unit operation attribute, removes 246 /// the attribute from the operation: 247 /// {0} is the name of the attribute sanitized for Python; 248 /// {1} is the original name of the attribute. 249 constexpr const char *attributeDeleterTemplate = R"Py( 250 @{0}.deleter 251 def {0}(self): 252 del self.operation.attributes["{1}"] 253 )Py"; 254 255 constexpr const char *regionAccessorTemplate = R"PY( 256 @builtins.property 257 def {0}(self): 258 return self.regions[{1}] 259 )PY"; 260 261 static llvm::cl::OptionCategory 262 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 263 264 static llvm::cl::opt<std::string> 265 clDialectName("bind-dialect", 266 llvm::cl::desc("The dialect to run the generator for"), 267 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 268 269 using AttributeClasses = DenseMap<StringRef, StringRef>; 270 271 /// Checks whether `str` is a Python keyword. 272 static bool isPythonKeyword(StringRef str) { 273 static llvm::StringSet<> keywords( 274 {"and", "as", "assert", "break", "class", "continue", 275 "def", "del", "elif", "else", "except", "finally", 276 "for", "from", "global", "if", "import", "in", 277 "is", "lambda", "nonlocal", "not", "or", "pass", 278 "raise", "return", "try", "while", "with", "yield"}); 279 return keywords.contains(str); 280 } 281 282 /// Checks whether `str` would shadow a generated variable or attribute 283 /// part of the OpView API. 
284 static bool isODSReserved(StringRef str) { 285 static llvm::StringSet<> reserved( 286 {"attributes", "create", "context", "ip", "operands", "print", "get_asm", 287 "loc", "verify", "regions", "results", "self", "operation", 288 "DIALECT_NAMESPACE", "OPERATION_NAME"}); 289 return str.startswith("_ods_") || str.endswith("_ods") || 290 reserved.contains(str); 291 } 292 293 /// Modifies the `name` in a way that it becomes suitable for Python bindings 294 /// (does not change the `name` if it already is suitable) and returns the 295 /// modified version. 296 static std::string sanitizeName(StringRef name) { 297 if (isPythonKeyword(name) || isODSReserved(name)) 298 return (name + "_").str(); 299 return name.str(); 300 } 301 302 static std::string attrSizedTraitForKind(const char *kind) { 303 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 304 llvm::StringRef(kind).take_front().upper(), 305 llvm::StringRef(kind).drop_front()); 306 } 307 308 /// Emits accessors to "elements" of an Op definition. Currently, the supported 309 /// elements are operands and results, indicated by `kind`, which must be either 310 /// `operand` or `result` and is used verbatim in the emitted code. 311 static void emitElementAccessors( 312 const Operator &op, raw_ostream &os, const char *kind, 313 llvm::function_ref<unsigned(const Operator &)> getNumVariadic, 314 llvm::function_ref<int(const Operator &)> getNumElements, 315 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 316 getElement) { 317 assert(llvm::is_contained( 318 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 319 "unsupported kind"); 320 321 // Traits indicating how to process variadic elements. 
322 std::string sameSizeTrait = 323 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 324 llvm::StringRef(kind).take_front().upper(), 325 llvm::StringRef(kind).drop_front()); 326 std::string attrSizedTrait = attrSizedTraitForKind(kind); 327 328 unsigned numVariadic = getNumVariadic(op); 329 330 // If there is only one variadic element group, its size can be inferred from 331 // the total number of elements. If there are none, the generation is 332 // straightforward. 333 if (numVariadic <= 1) { 334 bool seenVariableLength = false; 335 for (int i = 0, e = getNumElements(op); i < e; ++i) { 336 const NamedTypeConstraint &element = getElement(op, i); 337 if (element.isVariableLength()) 338 seenVariableLength = true; 339 if (element.name.empty()) 340 continue; 341 if (element.isVariableLength()) { 342 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 343 : opOneVariadicTemplate, 344 sanitizeName(element.name), kind, 345 getNumElements(op), i); 346 } else if (seenVariableLength) { 347 os << llvm::formatv(opSingleAfterVariableTemplate, 348 sanitizeName(element.name), kind, 349 getNumElements(op), i); 350 } else { 351 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 352 i); 353 } 354 } 355 return; 356 } 357 358 // Handle the operations where variadic groups have the same size. 359 if (op.getTrait(sameSizeTrait)) { 360 int numPrecedingSimple = 0; 361 int numPrecedingVariadic = 0; 362 for (int i = 0, e = getNumElements(op); i < e; ++i) { 363 const NamedTypeConstraint &element = getElement(op, i); 364 if (!element.name.empty()) { 365 os << llvm::formatv(opVariadicEqualPrefixTemplate, 366 sanitizeName(element.name), kind, numVariadic, 367 numPrecedingSimple, numPrecedingVariadic); 368 os << llvm::formatv(element.isVariableLength() 369 ? 
opVariadicEqualVariadicTemplate 370 : opVariadicEqualSimpleTemplate, 371 kind); 372 } 373 if (element.isVariableLength()) 374 ++numPrecedingVariadic; 375 else 376 ++numPrecedingSimple; 377 } 378 return; 379 } 380 381 // Handle the operations where the size of groups (variadic or not) is 382 // provided as an attribute. For non-variadic elements, make sure to return 383 // an element rather than a singleton container. 384 if (op.getTrait(attrSizedTrait)) { 385 for (int i = 0, e = getNumElements(op); i < e; ++i) { 386 const NamedTypeConstraint &element = getElement(op, i); 387 if (element.name.empty()) 388 continue; 389 std::string trailing; 390 if (!element.isVariableLength()) 391 trailing = "[0]"; 392 else if (element.isOptional()) 393 trailing = std::string( 394 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 395 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 396 kind, i, trailing); 397 } 398 return; 399 } 400 401 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 402 } 403 404 /// Free function helpers accessing Operator components. 405 static int getNumOperands(const Operator &op) { return op.getNumOperands(); } 406 static const NamedTypeConstraint &getOperand(const Operator &op, int i) { 407 return op.getOperand(i); 408 } 409 static int getNumResults(const Operator &op) { return op.getNumResults(); } 410 static const NamedTypeConstraint &getResult(const Operator &op, int i) { 411 return op.getResult(i); 412 } 413 414 /// Emits accessors to Op operands. 415 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 416 auto getNumVariadic = [](const Operator &oper) { 417 return oper.getNumVariableLengthOperands(); 418 }; 419 emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, 420 getOperand); 421 } 422 423 /// Emits accessors Op results. 
424 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 425 auto getNumVariadic = [](const Operator &oper) { 426 return oper.getNumVariableLengthResults(); 427 }; 428 emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, 429 getResult); 430 } 431 432 /// Emits accessors to Op attributes. 433 static void emitAttributeAccessors(const Operator &op, 434 const AttributeClasses &attributeClasses, 435 raw_ostream &os) { 436 for (const auto &namedAttr : op.getAttributes()) { 437 // Skip "derived" attributes because they are just C++ functions that we 438 // don't currently expose. 439 if (namedAttr.attr.isDerivedAttr()) 440 continue; 441 442 if (namedAttr.name.empty()) 443 continue; 444 445 std::string sanitizedName = sanitizeName(namedAttr.name); 446 447 // Unit attributes are handled specially. 448 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 449 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, 450 namedAttr.name); 451 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, 452 namedAttr.name); 453 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 454 namedAttr.name); 455 continue; 456 } 457 458 // Other kinds of attributes need a mapping to a Python type. 
459 if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) 460 continue; 461 462 StringRef pythonType = 463 attributeClasses.lookup(namedAttr.attr.getStorageType()); 464 if (namedAttr.attr.isOptional()) { 465 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, 466 pythonType, namedAttr.name); 467 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, 468 namedAttr.name); 469 os << llvm::formatv(attributeDeleterTemplate, sanitizedName, 470 namedAttr.name); 471 } else { 472 os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, 473 namedAttr.name); 474 os << llvm::formatv(attributeSetterTemplate, sanitizedName, 475 namedAttr.name); 476 // Non-optional attributes cannot be deleted. 477 } 478 } 479 } 480 481 /// Template for the default auto-generated builder. 482 /// {0} is a comma-separated list of builder arguments, including the trailing 483 /// `loc` and `ip`; 484 /// {1} is the code populating `operands`, `results` and `attributes`, 485 /// `successors` fields. 486 constexpr const char *initTemplate = R"Py( 487 def __init__(self, {0}): 488 operands = [] 489 results = [] 490 attributes = {{} 491 regions = None 492 {1} 493 super().__init__(self.build_generic( 494 attributes=attributes, results=results, operands=operands, 495 successors=_ods_successors, regions=regions, loc=loc, ip=ip)) 496 )Py"; 497 498 /// Template for appending a single element to the operand/result list. 499 /// {0} is the field name. 500 constexpr const char *singleOperandAppendTemplate = 501 "operands.append(_get_op_result_or_value({0}))"; 502 constexpr const char *singleResultAppendTemplate = "results.append({0})"; 503 504 /// Template for appending an optional element to the operand/result list. 505 /// {0} is the field name. 
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Templates for adding a list of elements to the operand/result list, either
/// by extending the list with the elements or (the "Pack" variant) by
/// appending the whole sequence as one entry.
///   {0} is the field name.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Builder line setting a mandatory attribute.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Builder line setting an optional attribute when the argument is not None.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Builder line setting a unit attribute when the argument is truthy.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Builder line initializing the successors list.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Builder line adding to the list of successors.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";

/// Returns true if the SameOperandsAndResultType trait can be used to infer
/// result types of the given operation.
546 static bool hasSameArgumentAndResultTypes(const Operator &op) { 547 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && 548 op.getNumVariableLengthResults() == 0; 549 } 550 551 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer 552 /// result types of the given operation. 553 static bool hasFirstAttrDerivedResultTypes(const Operator &op) { 554 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && 555 op.getNumVariableLengthResults() == 0; 556 } 557 558 /// Returns true if the InferTypeOpInterface can be used to infer result types 559 /// of the given operation. 560 static bool hasInferTypeInterface(const Operator &op) { 561 return op.getTrait("::mlir::InferTypeOpInterface::Trait") && 562 op.getNumRegions() == 0; 563 } 564 565 /// Returns true if there is a trait or interface that can be used to infer 566 /// result types of the given operation. 567 static bool canInferType(const Operator &op) { 568 return hasSameArgumentAndResultTypes(op) || 569 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); 570 } 571 572 /// Populates `builderArgs` with result names if the builder is expected to 573 /// accept them as arguments. 574 static void 575 populateBuilderArgsResults(const Operator &op, 576 llvm::SmallVectorImpl<std::string> &builderArgs) { 577 if (canInferType(op)) 578 return; 579 580 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 581 std::string name = op.getResultName(i).str(); 582 if (name.empty()) { 583 if (op.getNumResults() == 1) { 584 // Special case for one result, make the default name be 'result' 585 // to properly match the built-in result accessor. 
586 name = "result"; 587 } else { 588 name = llvm::formatv("_gen_res_{0}", i); 589 } 590 } 591 name = sanitizeName(name); 592 builderArgs.push_back(name); 593 } 594 } 595 596 /// Populates `builderArgs` with the Python-compatible names of builder function 597 /// arguments using intermixed attributes and operands in the same order as they 598 /// appear in the `arguments` field of the op definition. Additionally, 599 /// `operandNames` is populated with names of operands in their order of 600 /// appearance. 601 static void 602 populateBuilderArgs(const Operator &op, 603 llvm::SmallVectorImpl<std::string> &builderArgs, 604 llvm::SmallVectorImpl<std::string> &operandNames, 605 llvm::SmallVectorImpl<std::string> &successorArgNames) { 606 607 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 608 std::string name = op.getArgName(i).str(); 609 if (name.empty()) 610 name = llvm::formatv("_gen_arg_{0}", i); 611 name = sanitizeName(name); 612 builderArgs.push_back(name); 613 if (!op.getArg(i).is<NamedAttribute *>()) 614 operandNames.push_back(name); 615 } 616 617 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { 618 NamedSuccessor successor = op.getSuccessor(i); 619 std::string name = std::string(successor.name); 620 if (name.empty()) 621 name = llvm::formatv("_gen_successor_{0}", i); 622 name = sanitizeName(name); 623 builderArgs.push_back(name); 624 successorArgNames.push_back(name); 625 } 626 } 627 628 /// Populates `builderLines` with additional lines that are required in the 629 /// builder to set up operation attributes. `argNames` is expected to contain 630 /// the names of builder arguments that correspond to op arguments, i.e. to the 631 /// operands and attributes in the same order as they appear in the `arguments` 632 /// field. 
633 static void 634 populateBuilderLinesAttr(const Operator &op, 635 llvm::ArrayRef<std::string> argNames, 636 llvm::SmallVectorImpl<std::string> &builderLines) { 637 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 638 Argument arg = op.getArg(i); 639 auto *attribute = arg.dyn_cast<NamedAttribute *>(); 640 if (!attribute) 641 continue; 642 643 // Unit attributes are handled specially. 644 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { 645 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, 646 attribute->name, argNames[i])); 647 continue; 648 } 649 650 builderLines.push_back(llvm::formatv(attribute->attr.isOptional() 651 ? initOptionalAttributeTemplate 652 : initAttributeTemplate, 653 attribute->name, argNames[i])); 654 } 655 } 656 657 /// Populates `builderLines` with additional lines that are required in the 658 /// builder to set up successors. successorArgNames is expected to correspond 659 /// to the Python argument name for each successor on the op. 660 static void populateBuilderLinesSuccessors( 661 const Operator &op, llvm::ArrayRef<std::string> successorArgNames, 662 llvm::SmallVectorImpl<std::string> &builderLines) { 663 if (successorArgNames.empty()) { 664 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); 665 return; 666 } 667 668 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); 669 for (int i = 0, e = successorArgNames.size(); i < e; ++i) { 670 auto &argName = successorArgNames[i]; 671 const NamedSuccessor &successor = op.getSuccessor(i); 672 builderLines.push_back( 673 llvm::formatv(addSuccessorTemplate, 674 successor.isVariadic() ? "extend" : "append", argName)); 675 } 676 } 677 678 /// Populates `builderLines` with additional lines that are required in the 679 /// builder to set up op operands. 
680 static void 681 populateBuilderLinesOperand(const Operator &op, 682 llvm::ArrayRef<std::string> names, 683 llvm::SmallVectorImpl<std::string> &builderLines) { 684 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; 685 686 // For each element, find or generate a name. 687 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 688 const NamedTypeConstraint &element = op.getOperand(i); 689 std::string name = names[i]; 690 691 // Choose the formatting string based on the element kind. 692 llvm::StringRef formatString; 693 if (!element.isVariableLength()) { 694 formatString = singleOperandAppendTemplate; 695 } else if (element.isOptional()) { 696 formatString = optionalAppendOperandTemplate; 697 } else { 698 assert(element.isVariadic() && "unhandled element group type"); 699 // If emitting with sizedSegments, then we add the actual list-typed 700 // element. Otherwise, we extend the actual operands. 701 if (sizedSegments) { 702 formatString = multiOperandAppendPackTemplate; 703 } else { 704 formatString = multiOperandAppendTemplate; 705 } 706 } 707 708 builderLines.push_back(llvm::formatv(formatString.data(), name)); 709 } 710 } 711 712 /// Python code template for deriving the operation result types from its 713 /// attribute: 714 /// - {0} is the name of the attribute from which to derive the types. 715 constexpr const char *deriveTypeFromAttrTemplate = 716 R"PY(_ods_result_type_source_attr = attributes["{0}"] 717 _ods_derived_result_type = ( 718 _ods_ir.TypeAttr(_ods_result_type_source_attr).value 719 if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else 720 _ods_result_type_source_attr.type))PY"; 721 722 /// Python code template appending {0} type {1} times to the results list. 
723 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; 724 725 /// Python code template for inferring the operation results using the 726 /// corresponding interface: 727 /// - {0} is the name of the class for which the types are inferred. 728 constexpr const char *inferTypeInterfaceTemplate = 729 R"PY(_ods_context = _ods_get_default_loc_context(loc) 730 results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( 731 operands=operands, 732 attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), 733 context=_ods_context, 734 loc=loc) 735 )PY"; 736 737 /// Appends the given multiline string as individual strings into 738 /// `builderLines`. 739 static void appendLineByLine(StringRef string, 740 llvm::SmallVectorImpl<std::string> &builderLines) { 741 742 std::pair<StringRef, StringRef> split = std::make_pair(string, string); 743 do { 744 split = split.second.split('\n'); 745 builderLines.push_back(split.first.str()); 746 } while (!split.second.empty()); 747 } 748 749 /// Populates `builderLines` with additional lines that are required in the 750 /// builder to set up op results. 
751 static void 752 populateBuilderLinesResult(const Operator &op, 753 llvm::ArrayRef<std::string> names, 754 llvm::SmallVectorImpl<std::string> &builderLines) { 755 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; 756 757 if (hasSameArgumentAndResultTypes(op)) { 758 builderLines.push_back(llvm::formatv( 759 appendSameResultsTemplate, "operands[0].type", op.getNumResults())); 760 return; 761 } 762 763 if (hasFirstAttrDerivedResultTypes(op)) { 764 const NamedAttribute &firstAttr = op.getAttribute(0); 765 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " 766 "from which the type is derived"); 767 appendLineByLine( 768 llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), 769 builderLines); 770 builderLines.push_back(llvm::formatv(appendSameResultsTemplate, 771 "_ods_derived_result_type", 772 op.getNumResults())); 773 return; 774 } 775 776 if (hasInferTypeInterface(op)) { 777 appendLineByLine( 778 llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), 779 builderLines); 780 return; 781 } 782 783 // For each element, find or generate a name. 784 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 785 const NamedTypeConstraint &element = op.getResult(i); 786 std::string name = names[i]; 787 788 // Choose the formatting string based on the element kind. 789 llvm::StringRef formatString; 790 if (!element.isVariableLength()) { 791 formatString = singleResultAppendTemplate; 792 } else if (element.isOptional()) { 793 formatString = optionalAppendResultTemplate; 794 } else { 795 assert(element.isVariadic() && "unhandled element group type"); 796 // If emitting with sizedSegments, then we add the actual list-typed 797 // element. Otherwise, we extend the actual operands. 
798 if (sizedSegments) { 799 formatString = singleResultAppendTemplate; 800 } else { 801 formatString = multiResultAppendTemplate; 802 } 803 } 804 805 builderLines.push_back(llvm::formatv(formatString.data(), name)); 806 } 807 } 808 809 /// If the operation has variadic regions, adds a builder argument to specify 810 /// the number of those regions and builder lines to forward it to the generic 811 /// constructor. 812 static void 813 populateBuilderRegions(const Operator &op, 814 llvm::SmallVectorImpl<std::string> &builderArgs, 815 llvm::SmallVectorImpl<std::string> &builderLines) { 816 if (op.hasNoVariadicRegions()) 817 return; 818 819 // This is currently enforced when Operator is constructed. 820 assert(op.getNumVariadicRegions() == 1 && 821 op.getRegion(op.getNumRegions() - 1).isVariadic() && 822 "expected the last region to be varidic"); 823 824 const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); 825 std::string name = 826 ("num_" + region.name.take_front().lower() + region.name.drop_front()) 827 .str(); 828 builderArgs.push_back(name); 829 builderLines.push_back( 830 llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); 831 } 832 833 /// Emits a default builder constructing an operation from the list of its 834 /// result types, followed by a list of its operands. 835 static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { 836 // If we are asked to skip default builders, comply. 
837 if (op.skipDefaultBuilders()) 838 return; 839 840 llvm::SmallVector<std::string> builderArgs; 841 llvm::SmallVector<std::string> builderLines; 842 llvm::SmallVector<std::string> operandArgNames; 843 llvm::SmallVector<std::string> successorArgNames; 844 builderArgs.reserve(op.getNumOperands() + op.getNumResults() + 845 op.getNumNativeAttributes() + op.getNumSuccessors()); 846 populateBuilderArgsResults(op, builderArgs); 847 size_t numResultArgs = builderArgs.size(); 848 populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); 849 850 populateBuilderLinesOperand(op, operandArgNames, builderLines); 851 populateBuilderLinesAttr( 852 op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), 853 builderLines); 854 populateBuilderLinesResult( 855 op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), 856 builderLines); 857 populateBuilderLinesSuccessors(op, successorArgNames, builderLines); 858 populateBuilderRegions(op, builderArgs, builderLines); 859 860 builderArgs.push_back("*"); 861 builderArgs.push_back("loc=None"); 862 builderArgs.push_back("ip=None"); 863 os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), 864 llvm::join(builderLines, "\n ")); 865 } 866 867 static void constructAttributeMapping(const llvm::RecordKeeper &records, 868 AttributeClasses &attributeClasses) { 869 for (const llvm::Record *rec : 870 records.getAllDerivedDefinitions("PythonAttr")) { 871 attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), 872 rec->getValueAsString("pythonType").trim()); 873 } 874 } 875 876 static void emitSegmentSpec( 877 const Operator &op, const char *kind, 878 llvm::function_ref<int(const Operator &)> getNumElements, 879 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 880 getElement, 881 raw_ostream &os) { 882 std::string segmentSpec("["); 883 for (int i = 0, e = getNumElements(op); i < e; ++i) { 884 const NamedTypeConstraint &element = getElement(op, i); 885 if 
(element.isVariableLength()) { 886 segmentSpec.append("-1,"); 887 } else if (element.isOptional()) { 888 segmentSpec.append("0,"); 889 } else { 890 segmentSpec.append("1,"); 891 } 892 } 893 segmentSpec.append("]"); 894 895 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); 896 } 897 898 static void emitRegionAttributes(const Operator &op, raw_ostream &os) { 899 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). 900 // Note that the base OpView class defines this as (0, True). 901 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); 902 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, 903 op.hasNoVariadicRegions() ? "True" : "False"); 904 } 905 906 /// Emits named accessors to regions. 907 static void emitRegionAccessors(const Operator &op, raw_ostream &os) { 908 for (auto en : llvm::enumerate(op.getRegions())) { 909 const NamedRegion ®ion = en.value(); 910 if (region.name.empty()) 911 continue; 912 913 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && 914 "expected only the last region to be variadic"); 915 os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), 916 std::to_string(en.index()) + 917 (region.isVariadic() ? ":" : "")); 918 } 919 } 920 921 /// Emits bindings for a specific Op to the given output stream. 922 static void emitOpBindings(const Operator &op, 923 const AttributeClasses &attributeClasses, 924 raw_ostream &os) { 925 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 926 op.getOperationName()); 927 928 // Sized segments. 
929 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { 930 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); 931 } 932 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { 933 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); 934 } 935 936 emitRegionAttributes(op, os); 937 emitDefaultOpBuilder(op, os); 938 emitOperandAccessors(op, os); 939 emitAttributeAccessors(op, attributeClasses, os); 940 emitResultAccessors(op, os); 941 emitRegionAccessors(op, os); 942 } 943 944 /// Emits bindings for the dialect specified in the command line, including file 945 /// headers and utilities. Returns `false` on success to comply with Tablegen 946 /// registration requirements. 947 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 948 if (clDialectName.empty()) 949 llvm::PrintFatalError("dialect name not provided"); 950 951 AttributeClasses attributeClasses; 952 constructAttributeMapping(records, attributeClasses); 953 954 os << llvm::formatv(fileHeader, clDialectName.getValue()); 955 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 956 957 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 958 Operator op(rec); 959 if (op.getDialectName() == clDialectName.getValue()) 960 emitOpBindings(op, attributeClasses, os); 961 } 962 return false; 963 } 964 965 static GenRegistration 966 genPythonBindings("gen-python-op-bindings", 967 "Generate Python bindings for MLIR Ops", &emitAllOps); 968