1 //===- Operator.cpp - Operator class --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Operator.h" 14 #include "mlir/TableGen/Predicate.h" 15 #include "mlir/TableGen/Trait.h" 16 #include "mlir/TableGen/Type.h" 17 #include "llvm/ADT/EquivalenceClasses.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/Sequence.h" 20 #include "llvm/ADT/SmallPtrSet.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/Support/Debug.h" 24 #include "llvm/Support/ErrorHandling.h" 25 #include "llvm/Support/FormatVariadic.h" 26 #include "llvm/TableGen/Error.h" 27 #include "llvm/TableGen/Record.h" 28 29 #define DEBUG_TYPE "mlir-tblgen-operator" 30 31 using namespace mlir; 32 using namespace mlir::tblgen; 33 34 using llvm::DagInit; 35 using llvm::DefInit; 36 using llvm::Record; 37 38 Operator::Operator(const llvm::Record &def) 39 : dialect(def.getValueAsDef("opDialect")), def(def) { 40 // The first `_` in the op's TableGen def name is treated as separating the 41 // dialect prefix and the op class name. The dialect prefix will be ignored if 42 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered 43 // as part of the class name. 44 StringRef prefix; 45 std::tie(prefix, cppClassName) = def.getName().split('_'); 46 if (prefix.empty()) { 47 // Class name with a leading underscore and without dialect prefix 48 cppClassName = def.getName(); 49 } else if (cppClassName.empty()) { 50 // Class name without dialect prefix 51 cppClassName = prefix; 52 } 53 54 cppNamespace = def.getValueAsString("cppNamespace"); 55 56 populateOpStructure(); 57 assertInvariants(); 58 } 59 60 std::string Operator::getOperationName() const { 61 auto prefix = dialect.getName(); 62 auto opName = def.getValueAsString("opName"); 63 if (prefix.empty()) 64 return std::string(opName); 65 return std::string(llvm::formatv("{0}.{1}", prefix, opName)); 66 } 67 68 std::string Operator::getAdaptorName() const { 69 return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); 70 } 71 72 void Operator::assertInvariants() const { 73 // Check that the name of arguments/results/regions/successors don't overlap. 74 DenseMap<StringRef, StringRef> existingNames; 75 auto checkName = [&](StringRef name, StringRef entity) { 76 if (name.empty()) 77 return; 78 auto insertion = existingNames.insert({name, entity}); 79 if (insertion.second) 80 return; 81 if (entity == insertion.first->second) 82 PrintFatalError(getLoc(), "op has a conflict with two " + entity + 83 " having the same name '" + name + "'"); 84 PrintFatalError(getLoc(), "op has a conflict with " + 85 insertion.first->second + " and " + entity + 86 " both having an entry with the name '" + 87 name + "'"); 88 }; 89 // Check operands amongst themselves. 90 for (int i : llvm::seq<int>(0, getNumOperands())) 91 checkName(getOperand(i).name, "operands"); 92 93 // Check results amongst themselves and against operands. 94 for (int i : llvm::seq<int>(0, getNumResults())) 95 checkName(getResult(i).name, "results"); 96 97 // Check regions amongst themselves and against operands and results. 98 for (int i : llvm::seq<int>(0, getNumRegions())) 99 checkName(getRegion(i).name, "regions"); 100 101 // Check successors amongst themselves and against operands, results, and 102 // regions. 103 for (int i : llvm::seq<int>(0, getNumSuccessors())) 104 checkName(getSuccessor(i).name, "successors"); 105 } 106 107 StringRef Operator::getDialectName() const { return dialect.getName(); } 108 109 StringRef Operator::getCppClassName() const { return cppClassName; } 110 111 std::string Operator::getQualCppClassName() const { 112 if (cppNamespace.empty()) 113 return std::string(cppClassName); 114 return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName)); 115 } 116 117 StringRef Operator::getCppNamespace() const { return cppNamespace; } 118 119 int Operator::getNumResults() const { 120 DagInit *results = def.getValueAsDag("results"); 121 return results->getNumArgs(); 122 } 123 124 StringRef Operator::getExtraClassDeclaration() const { 125 constexpr auto attr = "extraClassDeclaration"; 126 if (def.isValueUnset(attr)) 127 return {}; 128 return def.getValueAsString(attr); 129 } 130 131 const llvm::Record &Operator::getDef() const { return def; } 132 133 bool Operator::skipDefaultBuilders() const { 134 return def.getValueAsBit("skipDefaultBuilders"); 135 } 136 137 auto Operator::result_begin() -> value_iterator { return results.begin(); } 138 139 auto Operator::result_end() -> value_iterator { return results.end(); } 140 141 auto Operator::getResults() -> value_range { 142 return {result_begin(), result_end()}; 143 } 144 145 TypeConstraint Operator::getResultTypeConstraint(int index) const { 146 DagInit *results = def.getValueAsDag("results"); 147 return TypeConstraint(cast<DefInit>(results->getArg(index))); 148 } 149 150 StringRef Operator::getResultName(int index) const { 151 DagInit *results = def.getValueAsDag("results"); 152 return results->getArgNameStr(index); 153 } 154 155 auto Operator::getResultDecorators(int index) const -> var_decorator_range { 156 Record *result = 157 cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef(); 158 if (!result->isSubClassOf("OpVariable")) 159 return var_decorator_range(nullptr, nullptr); 160 return *result->getValueAsListInit("decorators"); 161 } 162 163 unsigned Operator::getNumVariableLengthResults() const { 164 return llvm::count_if(results, [](const NamedTypeConstraint &c) { 165 return c.constraint.isVariableLength(); 166 }); 167 } 168 169 unsigned Operator::getNumVariableLengthOperands() const { 170 return llvm::count_if(operands, [](const NamedTypeConstraint &c) { 171 return c.constraint.isVariableLength(); 172 }); 173 } 174 175 bool Operator::hasSingleVariadicArg() const { 176 return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() && 177 getOperand(0).isVariadic(); 178 } 179 180 Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); } 181 182 Operator::arg_iterator Operator::arg_end() const { return arguments.end(); } 183 184 Operator::arg_range Operator::getArgs() const { 185 return {arg_begin(), arg_end()}; 186 } 187 188 StringRef Operator::getArgName(int index) const { 189 DagInit *argumentValues = def.getValueAsDag("arguments"); 190 return argumentValues->getArgNameStr(index); 191 } 192 193 auto Operator::getArgDecorators(int index) const -> var_decorator_range { 194 Record *arg = 195 cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef(); 196 if (!arg->isSubClassOf("OpVariable")) 197 return var_decorator_range(nullptr, nullptr); 198 return *arg->getValueAsListInit("decorators"); 199 } 200 201 const Trait *Operator::getTrait(StringRef trait) const { 202 for (const auto &t : traits) { 203 if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) { 204 if (traitDef->getFullyQualifiedTraitName() == trait) 205 return traitDef; 206 } else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) { 207 if (traitDef->getFullyQualifiedTraitName() == trait) 208 return traitDef; 209 } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) { 210 if (traitDef->getFullyQualifiedTraitName() == trait) 211 return traitDef; 212 } 213 } 214 return nullptr; 215 } 216 217 auto Operator::region_begin() const -> const_region_iterator { 218 return regions.begin(); 219 } 220 auto Operator::region_end() const -> const_region_iterator { 221 return regions.end(); 222 } 223 auto Operator::getRegions() const 224 -> llvm::iterator_range<const_region_iterator> { 225 return {region_begin(), region_end()}; 226 } 227 228 unsigned Operator::getNumRegions() const { return regions.size(); } 229 230 const NamedRegion &Operator::getRegion(unsigned index) const { 231 return regions[index]; 232 } 233 234 unsigned Operator::getNumVariadicRegions() const { 235 return llvm::count_if(regions, 236 [](const NamedRegion &c) { return c.isVariadic(); }); 237 } 238 239 auto Operator::successor_begin() const -> const_successor_iterator { 240 return successors.begin(); 241 } 242 auto Operator::successor_end() const -> const_successor_iterator { 243 return successors.end(); 244 } 245 auto Operator::getSuccessors() const 246 -> llvm::iterator_range<const_successor_iterator> { 247 return {successor_begin(), successor_end()}; 248 } 249 250 unsigned Operator::getNumSuccessors() const { return successors.size(); } 251 252 const NamedSuccessor &Operator::getSuccessor(unsigned index) const { 253 return successors[index]; 254 } 255 256 unsigned Operator::getNumVariadicSuccessors() const { 257 return llvm::count_if(successors, 258 [](const NamedSuccessor &c) { return c.isVariadic(); }); 259 } 260 261 auto Operator::trait_begin() const -> const_trait_iterator { 262 return traits.begin(); 263 } 264 auto Operator::trait_end() const -> const_trait_iterator { 265 return traits.end(); 266 } 267 auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> { 268 return {trait_begin(), trait_end()}; 269 } 270 271 auto Operator::attribute_begin() const -> attribute_iterator { 272 return attributes.begin(); 273 } 274 auto Operator::attribute_end() const -> attribute_iterator { 275 return attributes.end(); 276 } 277 auto Operator::getAttributes() const 278 -> llvm::iterator_range<attribute_iterator> { 279 return {attribute_begin(), attribute_end()}; 280 } 281 282 auto Operator::operand_begin() -> value_iterator { return operands.begin(); } 283 auto Operator::operand_end() -> value_iterator { return operands.end(); } 284 auto Operator::getOperands() -> value_range { 285 return {operand_begin(), operand_end()}; 286 } 287 288 auto Operator::getArg(int index) const -> Argument { return arguments[index]; } 289 290 // Mapping from result index to combined argument and result index. Arguments 291 // are indexed to match getArg index, while the result indexes are mapped to 292 // avoid overlap. 293 static int resultIndex(int i) { return -1 - i; } 294 295 bool Operator::isVariadic() const { 296 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results), 297 [](const NamedTypeConstraint &op) { return op.isVariadic(); }); 298 } 299 300 void Operator::populateTypeInferenceInfo( 301 const llvm::StringMap<int> &argumentsAndResultsIndex) { 302 // If the type inference op interface is not registered, then do not attempt 303 // to determine if the result types an be inferred. 304 auto &recordKeeper = def.getRecords(); 305 auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface); 306 allResultsHaveKnownTypes = false; 307 if (!inferTrait) 308 return; 309 310 // If there are no results, the skip this else the build method generated 311 // overlaps with another autogenerated builder. 312 if (getNumResults() == 0) 313 return; 314 315 // Skip for ops with variadic operands/results. 316 // TODO: This can be relaxed. 317 if (isVariadic()) 318 return; 319 320 // Skip cases currently being custom generated. 321 // TODO: Remove special cases. 322 if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) 323 return; 324 325 // We create equivalence classes of argument/result types where arguments 326 // and results are mapped into the same index space and indices corresponding 327 // to the same type are in the same equivalence class. 328 llvm::EquivalenceClasses<int> ecs; 329 resultTypeMapping.resize(getNumResults()); 330 // Captures the argument whose type matches a given result type. Preference 331 // towards capturing operands first before attributes. 332 auto captureMapping = [&](int i) { 333 bool found = false; 334 ecs.insert(resultIndex(i)); 335 auto mi = ecs.findLeader(resultIndex(i)); 336 for (auto me = ecs.member_end(); mi != me; ++mi) { 337 if (*mi < 0) { 338 auto tc = getResultTypeConstraint(i); 339 if (tc.getBuilderCall().hasValue()) { 340 resultTypeMapping[i].emplace_back(tc); 341 found = true; 342 } 343 continue; 344 } 345 346 if (getArg(*mi).is<NamedAttribute *>()) { 347 // TODO: Handle attributes. 348 continue; 349 } else { 350 resultTypeMapping[i].emplace_back(*mi); 351 found = true; 352 } 353 } 354 return found; 355 }; 356 357 for (const Trait &trait : traits) { 358 const llvm::Record &def = trait.getDef(); 359 // If the infer type op interface was manually added, then treat it as 360 // intention that the op needs special handling. 361 // TODO: Reconsider whether to always generate, this is more conservative 362 // and keeps existing behavior so starting that way for now. 363 if (def.isSubClassOf( 364 llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) 365 return; 366 if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait)) 367 if (&traitDef->getDef() == inferTrait) 368 return; 369 370 if (!def.isSubClassOf("AllTypesMatch")) 371 continue; 372 373 auto values = def.getValueAsListOfStrings("values"); 374 auto root = argumentsAndResultsIndex.lookup(values.front()); 375 for (StringRef str : values) 376 ecs.unionSets(argumentsAndResultsIndex.lookup(str), root); 377 } 378 379 // Verifies that all output types have a corresponding known input type 380 // and chooses matching operand or attribute (in that order) that 381 // matches it. 382 allResultsHaveKnownTypes = 383 all_of(llvm::seq<int>(0, getNumResults()), captureMapping); 384 385 // If the types could be computed, then add type inference trait. 386 if (allResultsHaveKnownTypes) 387 traits.push_back(Trait::create(inferTrait->getDefInit())); 388 } 389 390 void Operator::populateOpStructure() { 391 auto &recordKeeper = def.getRecords(); 392 auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); 393 auto *attrClass = recordKeeper.getClass("Attr"); 394 auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr"); 395 auto *opVarClass = recordKeeper.getClass("OpVariable"); 396 numNativeAttributes = 0; 397 398 DagInit *argumentValues = def.getValueAsDag("arguments"); 399 unsigned numArgs = argumentValues->getNumArgs(); 400 401 // Mapping from name of to argument or result index. Arguments are indexed 402 // to match getArg index, while the results are negatively indexed. 403 llvm::StringMap<int> argumentsAndResultsIndex; 404 405 // Handle operands and native attributes. 406 for (unsigned i = 0; i != numArgs; ++i) { 407 auto *arg = argumentValues->getArg(i); 408 auto givenName = argumentValues->getArgNameStr(i); 409 auto *argDefInit = dyn_cast<DefInit>(arg); 410 if (!argDefInit) 411 PrintFatalError(def.getLoc(), 412 Twine("undefined type for argument #") + Twine(i)); 413 Record *argDef = argDefInit->getDef(); 414 if (argDef->isSubClassOf(opVarClass)) 415 argDef = argDef->getValueAsDef("constraint"); 416 417 if (argDef->isSubClassOf(typeConstraintClass)) { 418 operands.push_back( 419 NamedTypeConstraint{givenName, TypeConstraint(argDef)}); 420 } else if (argDef->isSubClassOf(attrClass)) { 421 if (givenName.empty()) 422 PrintFatalError(argDef->getLoc(), "attributes must be named"); 423 if (argDef->isSubClassOf(derivedAttrClass)) 424 PrintFatalError(argDef->getLoc(), 425 "derived attributes not allowed in argument list"); 426 attributes.push_back({givenName, Attribute(argDef)}); 427 ++numNativeAttributes; 428 } else { 429 PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving " 430 "from TypeConstraint or Attr are allowed"); 431 } 432 if (!givenName.empty()) 433 argumentsAndResultsIndex[givenName] = i; 434 } 435 436 // Handle derived attributes. 437 for (const auto &val : def.getValues()) { 438 if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) { 439 if (!record->isSubClassOf(attrClass)) 440 continue; 441 if (!record->isSubClassOf(derivedAttrClass)) 442 PrintFatalError(def.getLoc(), 443 "unexpected Attr where only DerivedAttr is allowed"); 444 445 if (record->getClasses().size() != 1) { 446 PrintFatalError( 447 def.getLoc(), 448 "unsupported attribute modelling, only single class expected"); 449 } 450 attributes.push_back( 451 {cast<llvm::StringInit>(val.getNameInit())->getValue(), 452 Attribute(cast<DefInit>(val.getValue()))}); 453 } 454 } 455 456 // Populate `arguments`. This must happen after we've finalized `operands` and 457 // `attributes` because we will put their elements' pointers in `arguments`. 458 // SmallVector may perform re-allocation under the hood when adding new 459 // elements. 460 int operandIndex = 0, attrIndex = 0; 461 for (unsigned i = 0; i != numArgs; ++i) { 462 Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef(); 463 if (argDef->isSubClassOf(opVarClass)) 464 argDef = argDef->getValueAsDef("constraint"); 465 466 if (argDef->isSubClassOf(typeConstraintClass)) { 467 attrOrOperandMapping.push_back( 468 {OperandOrAttribute::Kind::Operand, operandIndex}); 469 arguments.emplace_back(&operands[operandIndex++]); 470 } else { 471 assert(argDef->isSubClassOf(attrClass)); 472 attrOrOperandMapping.push_back( 473 {OperandOrAttribute::Kind::Attribute, attrIndex}); 474 arguments.emplace_back(&attributes[attrIndex++]); 475 } 476 } 477 478 auto *resultsDag = def.getValueAsDag("results"); 479 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator()); 480 if (!outsOp || outsOp->getDef()->getName() != "outs") { 481 PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); 482 } 483 484 // Handle results. 485 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { 486 auto name = resultsDag->getArgNameStr(i); 487 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i)); 488 if (!resultInit) { 489 PrintFatalError(def.getLoc(), 490 Twine("undefined type for result #") + Twine(i)); 491 } 492 auto *resultDef = resultInit->getDef(); 493 if (resultDef->isSubClassOf(opVarClass)) 494 resultDef = resultDef->getValueAsDef("constraint"); 495 results.push_back({name, TypeConstraint(resultDef)}); 496 if (!name.empty()) 497 argumentsAndResultsIndex[name] = resultIndex(i); 498 499 // We currently only support VariadicOfVariadic operands. 500 if (results.back().constraint.isVariadicOfVariadic()) { 501 PrintFatalError( 502 def.getLoc(), 503 "'VariadicOfVariadic' results are currently not supported"); 504 } 505 } 506 507 // Handle successors 508 auto *successorsDag = def.getValueAsDag("successors"); 509 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator()); 510 if (!successorsOp || successorsOp->getDef()->getName() != "successor") { 511 PrintFatalError(def.getLoc(), 512 "'successors' must have 'successor' directive"); 513 } 514 515 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { 516 auto name = successorsDag->getArgNameStr(i); 517 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i)); 518 if (!successorInit) { 519 PrintFatalError(def.getLoc(), 520 Twine("undefined kind for successor #") + Twine(i)); 521 } 522 Successor successor(successorInit->getDef()); 523 524 // Only support variadic successors if it is the last one for now. 525 if (i != e - 1 && successor.isVariadic()) 526 PrintFatalError(def.getLoc(), "only the last successor can be variadic"); 527 successors.push_back({name, successor}); 528 } 529 530 // Create list of traits, skipping over duplicates: appending to lists in 531 // tablegen is easy, making them unique less so, so dedupe here. 532 if (auto *traitList = def.getValueAsListInit("traits")) { 533 // This is uniquing based on pointers of the trait. 534 SmallPtrSet<const llvm::Init *, 32> traitSet; 535 traits.reserve(traitSet.size()); 536 537 std::function<void(llvm::ListInit *)> insert; 538 insert = [&](llvm::ListInit *traitList) { 539 for (auto *traitInit : *traitList) { 540 auto *def = cast<DefInit>(traitInit)->getDef(); 541 if (def->isSubClassOf("OpTraitList")) { 542 insert(def->getValueAsListInit("traits")); 543 continue; 544 } 545 // Keep traits in the same order while skipping over duplicates. 546 if (traitSet.insert(traitInit).second) 547 traits.push_back(Trait::create(traitInit)); 548 } 549 }; 550 insert(traitList); 551 } 552 553 populateTypeInferenceInfo(argumentsAndResultsIndex); 554 555 // Handle regions 556 auto *regionsDag = def.getValueAsDag("regions"); 557 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator()); 558 if (!regionsOp || regionsOp->getDef()->getName() != "region") { 559 PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); 560 } 561 562 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { 563 auto name = regionsDag->getArgNameStr(i); 564 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i)); 565 if (!regionInit) { 566 PrintFatalError(def.getLoc(), 567 Twine("undefined kind for region #") + Twine(i)); 568 } 569 Region region(regionInit->getDef()); 570 if (region.isVariadic()) { 571 // Only support variadic regions if it is the last one for now. 572 if (i != e - 1) 573 PrintFatalError(def.getLoc(), "only the last region can be variadic"); 574 if (name.empty()) 575 PrintFatalError(def.getLoc(), "variadic regions must be named"); 576 } 577 578 regions.push_back({name, region}); 579 } 580 581 // Populate the builders. 582 auto *builderList = 583 dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders")); 584 if (builderList && !builderList->empty()) { 585 for (llvm::Init *init : builderList->getValues()) 586 builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc()); 587 } else if (skipDefaultBuilders()) { 588 PrintFatalError( 589 def.getLoc(), 590 "default builders are skipped and no custom builders provided"); 591 } 592 593 LLVM_DEBUG(print(llvm::dbgs())); 594 } 595 596 auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> { 597 assert(allResultTypesKnown()); 598 return resultTypeMapping[index]; 599 } 600 601 ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); } 602 603 bool Operator::hasDescription() const { 604 return def.getValue("description") != nullptr; 605 } 606 607 StringRef Operator::getDescription() const { 608 return def.getValueAsString("description"); 609 } 610 611 bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; } 612 613 StringRef Operator::getSummary() const { 614 return def.getValueAsString("summary"); 615 } 616 617 bool Operator::hasAssemblyFormat() const { 618 auto *valueInit = def.getValueInit("assemblyFormat"); 619 return isa<llvm::StringInit>(valueInit); 620 } 621 622 StringRef Operator::getAssemblyFormat() const { 623 return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat")) 624 .Case<llvm::StringInit>([&](auto *init) { return init->getValue(); }); 625 } 626 627 void Operator::print(llvm::raw_ostream &os) const { 628 os << "op '" << getOperationName() << "'\n"; 629 for (Argument arg : arguments) { 630 if (auto *attr = arg.dyn_cast<NamedAttribute *>()) 631 os << "[attribute] " << attr->name << '\n'; 632 else 633 os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n'; 634 } 635 } 636 637 auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) 638 -> VariableDecorator { 639 return VariableDecorator(cast<llvm::DefInit>(init)->getDef()); 640 } 641 642 auto Operator::getArgToOperandOrAttribute(int index) const 643 -> OperandOrAttribute { 644 return attrOrOperandMapping[index]; 645 } 646 647 // Helper to return the names for accessor. 648 static SmallVector<std::string, 2> 649 getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) { 650 Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix(); 651 std::string prefix; 652 if (prefixType != Dialect::EmitPrefix::Raw) 653 prefix = isGetter ? "get" : "set"; 654 655 SmallVector<std::string, 2> names; 656 bool rawToo = prefixType == Dialect::EmitPrefix::Both; 657 658 // Whether to skip generating prefixed form for argument. This just does some 659 // basic checks. 660 // 661 // There are a little bit more invasive checks possible for cases where not 662 // all ops have the trait that would cause overlap. For many cases here, 663 // renaming would be better (e.g., we can only guard in limited manner against 664 // methods from traits and interfaces here, so avoiding these in op definition 665 // is safer). 666 auto skip = [&](StringRef newName) { 667 bool shouldSkip = newName == "getAttributeNames" || 668 newName == "getAttributes" || newName == "getOperation" || 669 newName == "getType"; 670 if (newName == "getOperands") { 671 // To reduce noise, skip generating the prefixed form and the warning if 672 // $operands correspond to single variadic argument. 673 if (op.getNumOperands() == 1 && op.getNumVariableLengthOperands() == 1) 674 return true; 675 shouldSkip = true; 676 } 677 if (newName == "getRegions") { 678 if (op.getNumRegions() == 1 && op.getNumVariadicRegions() == 1) 679 return true; 680 shouldSkip = true; 681 } 682 if (!shouldSkip) 683 return false; 684 685 // This note could be avoided where the final function generated would 686 // have been identical. But preferably in the op definition avoiding using 687 // the generic name and then getting a more specialize type is better. 688 PrintNote(op.getLoc(), 689 "Skipping generation of prefixed accessor `" + newName + 690 "` as it overlaps with default one; generating raw form (`" + 691 name + "`) still"); 692 return true; 693 }; 694 695 if (!prefix.empty()) { 696 names.push_back( 697 prefix + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true)); 698 // Skip cases which would overlap with default ones for now. 699 if (skip(names.back())) { 700 rawToo = true; 701 names.clear(); 702 } else if (rawToo) { 703 LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName() 704 << "::" << name << "\")\n" 705 << "WITH_GETTER(\"" << op.getQualCppClassName() 706 << "Adaptor::" << name << "\")\n";); 707 } 708 } 709 710 if (prefix.empty() || rawToo) 711 names.push_back(name.str()); 712 return names; 713 } 714 715 SmallVector<std::string, 2> Operator::getGetterNames(StringRef name) const { 716 return getGetterOrSetterNames(/*isGetter=*/true, *this, name); 717 } 718 719 SmallVector<std::string, 2> Operator::getSetterNames(StringRef name) const { 720 return getGetterOrSetterNames(/*isGetter=*/false, *this, name); 721 } 722