1 //===- Operator.cpp - Operator class --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Operator.h" 14 #include "mlir/TableGen/Predicate.h" 15 #include "mlir/TableGen/Trait.h" 16 #include "mlir/TableGen/Type.h" 17 #include "llvm/ADT/EquivalenceClasses.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/Sequence.h" 20 #include "llvm/ADT/SmallPtrSet.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/Support/Debug.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include "llvm/TableGen/Error.h" 26 #include "llvm/TableGen/Record.h" 27 28 #define DEBUG_TYPE "mlir-tblgen-operator" 29 30 using namespace mlir; 31 using namespace mlir::tblgen; 32 33 using llvm::DagInit; 34 using llvm::DefInit; 35 using llvm::Record; 36 37 Operator::Operator(const llvm::Record &def) 38 : dialect(def.getValueAsDef("opDialect")), def(def) { 39 // The first `_` in the op's TableGen def name is treated as separating the 40 // dialect prefix and the op class name. The dialect prefix will be ignored if 41 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered 42 // as part of the class name. 43 StringRef prefix; 44 std::tie(prefix, cppClassName) = def.getName().split('_'); 45 if (prefix.empty()) { 46 // Class name with a leading underscore and without dialect prefix 47 cppClassName = def.getName(); 48 } else if (cppClassName.empty()) { 49 // Class name without dialect prefix 50 cppClassName = prefix; 51 } 52 53 cppNamespace = def.getValueAsString("cppNamespace"); 54 55 populateOpStructure(); 56 } 57 58 std::string Operator::getOperationName() const { 59 auto prefix = dialect.getName(); 60 auto opName = def.getValueAsString("opName"); 61 if (prefix.empty()) 62 return std::string(opName); 63 return std::string(llvm::formatv("{0}.{1}", prefix, opName)); 64 } 65 66 std::string Operator::getAdaptorName() const { 67 return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); 68 } 69 70 StringRef Operator::getDialectName() const { return dialect.getName(); } 71 72 StringRef Operator::getCppClassName() const { return cppClassName; } 73 74 std::string Operator::getQualCppClassName() const { 75 if (cppNamespace.empty()) 76 return std::string(cppClassName); 77 return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName)); 78 } 79 80 StringRef Operator::getCppNamespace() const { return cppNamespace; } 81 82 int Operator::getNumResults() const { 83 DagInit *results = def.getValueAsDag("results"); 84 return results->getNumArgs(); 85 } 86 87 StringRef Operator::getExtraClassDeclaration() const { 88 constexpr auto attr = "extraClassDeclaration"; 89 if (def.isValueUnset(attr)) 90 return {}; 91 return def.getValueAsString(attr); 92 } 93 94 const llvm::Record &Operator::getDef() const { return def; } 95 96 bool Operator::skipDefaultBuilders() const { 97 return def.getValueAsBit("skipDefaultBuilders"); 98 } 99 100 auto Operator::result_begin() -> value_iterator { return results.begin(); } 101 102 auto Operator::result_end() -> value_iterator { return results.end(); } 103 104 auto Operator::getResults() -> value_range { 105 return {result_begin(), result_end()}; 106 } 107 108 TypeConstraint Operator::getResultTypeConstraint(int index) const { 109 DagInit *results = def.getValueAsDag("results"); 110 return TypeConstraint(cast<DefInit>(results->getArg(index))); 111 } 112 113 StringRef Operator::getResultName(int index) const { 114 DagInit *results = def.getValueAsDag("results"); 115 return results->getArgNameStr(index); 116 } 117 118 auto Operator::getResultDecorators(int index) const -> var_decorator_range { 119 Record *result = 120 cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef(); 121 if (!result->isSubClassOf("OpVariable")) 122 return var_decorator_range(nullptr, nullptr); 123 return *result->getValueAsListInit("decorators"); 124 } 125 126 unsigned Operator::getNumVariableLengthResults() const { 127 return llvm::count_if(results, [](const NamedTypeConstraint &c) { 128 return c.constraint.isVariableLength(); 129 }); 130 } 131 132 unsigned Operator::getNumVariableLengthOperands() const { 133 return llvm::count_if(operands, [](const NamedTypeConstraint &c) { 134 return c.constraint.isVariableLength(); 135 }); 136 } 137 138 bool Operator::hasSingleVariadicArg() const { 139 return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() && 140 getOperand(0).isVariadic(); 141 } 142 143 Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); } 144 145 Operator::arg_iterator Operator::arg_end() const { return arguments.end(); } 146 147 Operator::arg_range Operator::getArgs() const { 148 return {arg_begin(), arg_end()}; 149 } 150 151 StringRef Operator::getArgName(int index) const { 152 DagInit *argumentValues = def.getValueAsDag("arguments"); 153 return argumentValues->getArgNameStr(index); 154 } 155 156 auto Operator::getArgDecorators(int index) const -> var_decorator_range { 157 Record *arg = 158 cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef(); 159 if (!arg->isSubClassOf("OpVariable")) 160 return var_decorator_range(nullptr, nullptr); 161 return *arg->getValueAsListInit("decorators"); 162 } 163 164 const Trait *Operator::getTrait(StringRef trait) const { 165 for (const auto &t : traits) { 166 if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) { 167 if (traitDef->getFullyQualifiedTraitName() == trait) 168 return traitDef; 169 } else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) { 170 if (traitDef->getFullyQualifiedTraitName() == trait) 171 return traitDef; 172 } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) { 173 if (traitDef->getFullyQualifiedTraitName() == trait) 174 return traitDef; 175 } 176 } 177 return nullptr; 178 } 179 180 auto Operator::region_begin() const -> const_region_iterator { 181 return regions.begin(); 182 } 183 auto Operator::region_end() const -> const_region_iterator { 184 return regions.end(); 185 } 186 auto Operator::getRegions() const 187 -> llvm::iterator_range<const_region_iterator> { 188 return {region_begin(), region_end()}; 189 } 190 191 unsigned Operator::getNumRegions() const { return regions.size(); } 192 193 const NamedRegion &Operator::getRegion(unsigned index) const { 194 return regions[index]; 195 } 196 197 unsigned Operator::getNumVariadicRegions() const { 198 return llvm::count_if(regions, 199 [](const NamedRegion &c) { return c.isVariadic(); }); 200 } 201 202 auto Operator::successor_begin() const -> const_successor_iterator { 203 return successors.begin(); 204 } 205 auto Operator::successor_end() const -> const_successor_iterator { 206 return successors.end(); 207 } 208 auto Operator::getSuccessors() const 209 -> llvm::iterator_range<const_successor_iterator> { 210 return {successor_begin(), successor_end()}; 211 } 212 213 unsigned Operator::getNumSuccessors() const { return successors.size(); } 214 215 const NamedSuccessor &Operator::getSuccessor(unsigned index) const { 216 return successors[index]; 217 } 218 219 unsigned Operator::getNumVariadicSuccessors() const { 220 return llvm::count_if(successors, 221 [](const NamedSuccessor &c) { return c.isVariadic(); }); 222 } 223 224 auto Operator::trait_begin() const -> const_trait_iterator { 225 return traits.begin(); 226 } 227 auto Operator::trait_end() const -> const_trait_iterator { 228 return traits.end(); 229 } 230 auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> { 231 return {trait_begin(), trait_end()}; 232 } 233 234 auto Operator::attribute_begin() const -> attribute_iterator { 235 return attributes.begin(); 236 } 237 auto Operator::attribute_end() const -> attribute_iterator { 238 return attributes.end(); 239 } 240 auto Operator::getAttributes() const 241 -> llvm::iterator_range<attribute_iterator> { 242 return {attribute_begin(), attribute_end()}; 243 } 244 245 auto Operator::operand_begin() -> value_iterator { return operands.begin(); } 246 auto Operator::operand_end() -> value_iterator { return operands.end(); } 247 auto Operator::getOperands() -> value_range { 248 return {operand_begin(), operand_end()}; 249 } 250 251 auto Operator::getArg(int index) const -> Argument { return arguments[index]; } 252 253 // Mapping from result index to combined argument and result index. Arguments 254 // are indexed to match getArg index, while the result indexes are mapped to 255 // avoid overlap. 256 static int resultIndex(int i) { return -1 - i; } 257 258 bool Operator::isVariadic() const { 259 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results), 260 [](const NamedTypeConstraint &op) { return op.isVariadic(); }); 261 } 262 263 void Operator::populateTypeInferenceInfo( 264 const llvm::StringMap<int> &argumentsAndResultsIndex) { 265 // If the type inference op interface is not registered, then do not attempt 266 // to determine if the result types an be inferred. 267 auto &recordKeeper = def.getRecords(); 268 auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface); 269 allResultsHaveKnownTypes = false; 270 if (!inferTrait) 271 return; 272 273 // If there are no results, the skip this else the build method generated 274 // overlaps with another autogenerated builder. 275 if (getNumResults() == 0) 276 return; 277 278 // Skip for ops with variadic operands/results. 279 // TODO: This can be relaxed. 280 if (isVariadic()) 281 return; 282 283 // Skip cases currently being custom generated. 284 // TODO: Remove special cases. 285 if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) 286 return; 287 288 // We create equivalence classes of argument/result types where arguments 289 // and results are mapped into the same index space and indices corresponding 290 // to the same type are in the same equivalence class. 291 llvm::EquivalenceClasses<int> ecs; 292 resultTypeMapping.resize(getNumResults()); 293 // Captures the argument whose type matches a given result type. Preference 294 // towards capturing operands first before attributes. 295 auto captureMapping = [&](int i) { 296 bool found = false; 297 ecs.insert(resultIndex(i)); 298 auto mi = ecs.findLeader(resultIndex(i)); 299 for (auto me = ecs.member_end(); mi != me; ++mi) { 300 if (*mi < 0) { 301 auto tc = getResultTypeConstraint(i); 302 if (tc.getBuilderCall().hasValue()) { 303 resultTypeMapping[i].emplace_back(tc); 304 found = true; 305 } 306 continue; 307 } 308 309 if (getArg(*mi).is<NamedAttribute *>()) { 310 // TODO: Handle attributes. 311 continue; 312 } else { 313 resultTypeMapping[i].emplace_back(*mi); 314 found = true; 315 } 316 } 317 return found; 318 }; 319 320 for (const Trait &trait : traits) { 321 const llvm::Record &def = trait.getDef(); 322 // If the infer type op interface was manually added, then treat it as 323 // intention that the op needs special handling. 324 // TODO: Reconsider whether to always generate, this is more conservative 325 // and keeps existing behavior so starting that way for now. 326 if (def.isSubClassOf( 327 llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) 328 return; 329 if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait)) 330 if (&traitDef->getDef() == inferTrait) 331 return; 332 333 if (!def.isSubClassOf("AllTypesMatch")) 334 continue; 335 336 auto values = def.getValueAsListOfStrings("values"); 337 auto root = argumentsAndResultsIndex.lookup(values.front()); 338 for (StringRef str : values) 339 ecs.unionSets(argumentsAndResultsIndex.lookup(str), root); 340 } 341 342 // Verifies that all output types have a corresponding known input type 343 // and chooses matching operand or attribute (in that order) that 344 // matches it. 345 allResultsHaveKnownTypes = 346 all_of(llvm::seq<int>(0, getNumResults()), captureMapping); 347 348 // If the types could be computed, then add type inference trait. 349 if (allResultsHaveKnownTypes) 350 traits.push_back(Trait::create(inferTrait->getDefInit())); 351 } 352 353 void Operator::populateOpStructure() { 354 auto &recordKeeper = def.getRecords(); 355 auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); 356 auto *attrClass = recordKeeper.getClass("Attr"); 357 auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr"); 358 auto *opVarClass = recordKeeper.getClass("OpVariable"); 359 numNativeAttributes = 0; 360 361 DagInit *argumentValues = def.getValueAsDag("arguments"); 362 unsigned numArgs = argumentValues->getNumArgs(); 363 364 // Mapping from name of to argument or result index. Arguments are indexed 365 // to match getArg index, while the results are negatively indexed. 366 llvm::StringMap<int> argumentsAndResultsIndex; 367 368 // Handle operands and native attributes. 369 for (unsigned i = 0; i != numArgs; ++i) { 370 auto *arg = argumentValues->getArg(i); 371 auto givenName = argumentValues->getArgNameStr(i); 372 auto *argDefInit = dyn_cast<DefInit>(arg); 373 if (!argDefInit) 374 PrintFatalError(def.getLoc(), 375 Twine("undefined type for argument #") + Twine(i)); 376 Record *argDef = argDefInit->getDef(); 377 if (argDef->isSubClassOf(opVarClass)) 378 argDef = argDef->getValueAsDef("constraint"); 379 380 if (argDef->isSubClassOf(typeConstraintClass)) { 381 operands.push_back( 382 NamedTypeConstraint{givenName, TypeConstraint(argDef)}); 383 } else if (argDef->isSubClassOf(attrClass)) { 384 if (givenName.empty()) 385 PrintFatalError(argDef->getLoc(), "attributes must be named"); 386 if (argDef->isSubClassOf(derivedAttrClass)) 387 PrintFatalError(argDef->getLoc(), 388 "derived attributes not allowed in argument list"); 389 attributes.push_back({givenName, Attribute(argDef)}); 390 ++numNativeAttributes; 391 } else { 392 PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving " 393 "from TypeConstraint or Attr are allowed"); 394 } 395 if (!givenName.empty()) 396 argumentsAndResultsIndex[givenName] = i; 397 } 398 399 // Handle derived attributes. 400 for (const auto &val : def.getValues()) { 401 if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) { 402 if (!record->isSubClassOf(attrClass)) 403 continue; 404 if (!record->isSubClassOf(derivedAttrClass)) 405 PrintFatalError(def.getLoc(), 406 "unexpected Attr where only DerivedAttr is allowed"); 407 408 if (record->getClasses().size() != 1) { 409 PrintFatalError( 410 def.getLoc(), 411 "unsupported attribute modelling, only single class expected"); 412 } 413 attributes.push_back( 414 {cast<llvm::StringInit>(val.getNameInit())->getValue(), 415 Attribute(cast<DefInit>(val.getValue()))}); 416 } 417 } 418 419 // Populate `arguments`. This must happen after we've finalized `operands` and 420 // `attributes` because we will put their elements' pointers in `arguments`. 421 // SmallVector may perform re-allocation under the hood when adding new 422 // elements. 423 int operandIndex = 0, attrIndex = 0; 424 for (unsigned i = 0; i != numArgs; ++i) { 425 Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef(); 426 if (argDef->isSubClassOf(opVarClass)) 427 argDef = argDef->getValueAsDef("constraint"); 428 429 if (argDef->isSubClassOf(typeConstraintClass)) { 430 attrOrOperandMapping.push_back( 431 {OperandOrAttribute::Kind::Operand, operandIndex}); 432 arguments.emplace_back(&operands[operandIndex++]); 433 } else { 434 assert(argDef->isSubClassOf(attrClass)); 435 attrOrOperandMapping.push_back( 436 {OperandOrAttribute::Kind::Attribute, attrIndex}); 437 arguments.emplace_back(&attributes[attrIndex++]); 438 } 439 } 440 441 auto *resultsDag = def.getValueAsDag("results"); 442 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator()); 443 if (!outsOp || outsOp->getDef()->getName() != "outs") { 444 PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); 445 } 446 447 // Handle results. 448 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { 449 auto name = resultsDag->getArgNameStr(i); 450 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i)); 451 if (!resultInit) { 452 PrintFatalError(def.getLoc(), 453 Twine("undefined type for result #") + Twine(i)); 454 } 455 auto *resultDef = resultInit->getDef(); 456 if (resultDef->isSubClassOf(opVarClass)) 457 resultDef = resultDef->getValueAsDef("constraint"); 458 results.push_back({name, TypeConstraint(resultDef)}); 459 if (!name.empty()) 460 argumentsAndResultsIndex[name] = resultIndex(i); 461 462 // We currently only support VariadicOfVariadic operands. 463 if (results.back().constraint.isVariadicOfVariadic()) { 464 PrintFatalError( 465 def.getLoc(), 466 "'VariadicOfVariadic' results are currently not supported"); 467 } 468 } 469 470 // Handle successors 471 auto *successorsDag = def.getValueAsDag("successors"); 472 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator()); 473 if (!successorsOp || successorsOp->getDef()->getName() != "successor") { 474 PrintFatalError(def.getLoc(), 475 "'successors' must have 'successor' directive"); 476 } 477 478 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { 479 auto name = successorsDag->getArgNameStr(i); 480 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i)); 481 if (!successorInit) { 482 PrintFatalError(def.getLoc(), 483 Twine("undefined kind for successor #") + Twine(i)); 484 } 485 Successor successor(successorInit->getDef()); 486 487 // Only support variadic successors if it is the last one for now. 488 if (i != e - 1 && successor.isVariadic()) 489 PrintFatalError(def.getLoc(), "only the last successor can be variadic"); 490 successors.push_back({name, successor}); 491 } 492 493 // Create list of traits, skipping over duplicates: appending to lists in 494 // tablegen is easy, making them unique less so, so dedupe here. 495 if (auto *traitList = def.getValueAsListInit("traits")) { 496 // This is uniquing based on pointers of the trait. 497 SmallPtrSet<const llvm::Init *, 32> traitSet; 498 traits.reserve(traitSet.size()); 499 500 std::function<void(llvm::ListInit *)> insert; 501 insert = [&](llvm::ListInit *traitList) { 502 for (auto *traitInit : *traitList) { 503 auto *def = cast<DefInit>(traitInit)->getDef(); 504 if (def->isSubClassOf("OpTraitList")) { 505 insert(def->getValueAsListInit("traits")); 506 continue; 507 } 508 // Keep traits in the same order while skipping over duplicates. 509 if (traitSet.insert(traitInit).second) 510 traits.push_back(Trait::create(traitInit)); 511 } 512 }; 513 insert(traitList); 514 } 515 516 populateTypeInferenceInfo(argumentsAndResultsIndex); 517 518 // Handle regions 519 auto *regionsDag = def.getValueAsDag("regions"); 520 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator()); 521 if (!regionsOp || regionsOp->getDef()->getName() != "region") { 522 PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); 523 } 524 525 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { 526 auto name = regionsDag->getArgNameStr(i); 527 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i)); 528 if (!regionInit) { 529 PrintFatalError(def.getLoc(), 530 Twine("undefined kind for region #") + Twine(i)); 531 } 532 Region region(regionInit->getDef()); 533 if (region.isVariadic()) { 534 // Only support variadic regions if it is the last one for now. 535 if (i != e - 1) 536 PrintFatalError(def.getLoc(), "only the last region can be variadic"); 537 if (name.empty()) 538 PrintFatalError(def.getLoc(), "variadic regions must be named"); 539 } 540 541 regions.push_back({name, region}); 542 } 543 544 // Populate the builders. 545 auto *builderList = 546 dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders")); 547 if (builderList && !builderList->empty()) { 548 for (llvm::Init *init : builderList->getValues()) 549 builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc()); 550 } else if (skipDefaultBuilders()) { 551 PrintFatalError( 552 def.getLoc(), 553 "default builders are skipped and no custom builders provided"); 554 } 555 556 LLVM_DEBUG(print(llvm::dbgs())); 557 } 558 559 auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> { 560 assert(allResultTypesKnown()); 561 return resultTypeMapping[index]; 562 } 563 564 ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); } 565 566 bool Operator::hasDescription() const { 567 return def.getValue("description") != nullptr; 568 } 569 570 StringRef Operator::getDescription() const { 571 return def.getValueAsString("description"); 572 } 573 574 bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; } 575 576 StringRef Operator::getSummary() const { 577 return def.getValueAsString("summary"); 578 } 579 580 bool Operator::hasAssemblyFormat() const { 581 auto *valueInit = def.getValueInit("assemblyFormat"); 582 return isa<llvm::StringInit>(valueInit); 583 } 584 585 StringRef Operator::getAssemblyFormat() const { 586 return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat")) 587 .Case<llvm::StringInit>([&](auto *init) { return init->getValue(); }); 588 } 589 590 void Operator::print(llvm::raw_ostream &os) const { 591 os << "op '" << getOperationName() << "'\n"; 592 for (Argument arg : arguments) { 593 if (auto *attr = arg.dyn_cast<NamedAttribute *>()) 594 os << "[attribute] " << attr->name << '\n'; 595 else 596 os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n'; 597 } 598 } 599 600 auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) 601 -> VariableDecorator { 602 return VariableDecorator(cast<llvm::DefInit>(init)->getDef()); 603 } 604 605 auto Operator::getArgToOperandOrAttribute(int index) const 606 -> OperandOrAttribute { 607 return attrOrOperandMapping[index]; 608 } 609