1 //===- Operator.cpp - Operator class --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Operator.h" 14 #include "mlir/TableGen/OpTrait.h" 15 #include "mlir/TableGen/Predicate.h" 16 #include "mlir/TableGen/Type.h" 17 #include "llvm/ADT/EquivalenceClasses.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/Sequence.h" 20 #include "llvm/ADT/SmallPtrSet.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/Support/Debug.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include "llvm/TableGen/Error.h" 26 #include "llvm/TableGen/Record.h" 27 28 #define DEBUG_TYPE "mlir-tblgen-operator" 29 30 using namespace mlir; 31 using namespace mlir::tblgen; 32 33 using llvm::DagInit; 34 using llvm::DefInit; 35 using llvm::Record; 36 37 Operator::Operator(const llvm::Record &def) 38 : dialect(def.getValueAsDef("opDialect")), def(def) { 39 // The first `_` in the op's TableGen def name is treated as separating the 40 // dialect prefix and the op class name. The dialect prefix will be ignored if 41 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered 42 // as part of the class name. 43 StringRef prefix; 44 std::tie(prefix, cppClassName) = def.getName().split('_'); 45 if (prefix.empty()) { 46 // Class name with a leading underscore and without dialect prefix 47 cppClassName = def.getName(); 48 } else if (cppClassName.empty()) { 49 // Class name without dialect prefix 50 cppClassName = prefix; 51 } 52 53 populateOpStructure(); 54 } 55 56 std::string Operator::getOperationName() const { 57 auto prefix = dialect.getName(); 58 auto opName = def.getValueAsString("opName"); 59 if (prefix.empty()) 60 return std::string(opName); 61 return std::string(llvm::formatv("{0}.{1}", prefix, opName)); 62 } 63 64 std::string Operator::getAdaptorName() const { 65 return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); 66 } 67 68 StringRef Operator::getDialectName() const { return dialect.getName(); } 69 70 StringRef Operator::getCppClassName() const { return cppClassName; } 71 72 std::string Operator::getQualCppClassName() const { 73 auto prefix = dialect.getCppNamespace(); 74 if (prefix.empty()) 75 return std::string(cppClassName); 76 return std::string(llvm::formatv("{0}::{1}", prefix, cppClassName)); 77 } 78 79 int Operator::getNumResults() const { 80 DagInit *results = def.getValueAsDag("results"); 81 return results->getNumArgs(); 82 } 83 84 StringRef Operator::getExtraClassDeclaration() const { 85 constexpr auto attr = "extraClassDeclaration"; 86 if (def.isValueUnset(attr)) 87 return {}; 88 return def.getValueAsString(attr); 89 } 90 91 const llvm::Record &Operator::getDef() const { return def; } 92 93 bool Operator::skipDefaultBuilders() const { 94 return def.getValueAsBit("skipDefaultBuilders"); 95 } 96 97 auto Operator::result_begin() -> value_iterator { return results.begin(); } 98 99 auto Operator::result_end() -> value_iterator { return results.end(); } 100 101 auto Operator::getResults() -> value_range { 102 return {result_begin(), result_end()}; 103 } 104 105 TypeConstraint Operator::getResultTypeConstraint(int index) const { 106 DagInit *results = def.getValueAsDag("results"); 107 return TypeConstraint(cast<DefInit>(results->getArg(index))); 108 } 109 110 StringRef Operator::getResultName(int index) const { 111 DagInit *results = def.getValueAsDag("results"); 112 return results->getArgNameStr(index); 113 } 114 115 auto Operator::getResultDecorators(int index) const -> var_decorator_range { 116 Record *result = 117 cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef(); 118 if (!result->isSubClassOf("OpVariable")) 119 return var_decorator_range(nullptr, nullptr); 120 return *result->getValueAsListInit("decorators"); 121 } 122 123 unsigned Operator::getNumVariableLengthResults() const { 124 return llvm::count_if(results, [](const NamedTypeConstraint &c) { 125 return c.constraint.isVariableLength(); 126 }); 127 } 128 129 unsigned Operator::getNumVariableLengthOperands() const { 130 return llvm::count_if(operands, [](const NamedTypeConstraint &c) { 131 return c.constraint.isVariableLength(); 132 }); 133 } 134 135 bool Operator::hasSingleVariadicArg() const { 136 return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() && 137 getOperand(0).isVariadic(); 138 } 139 140 Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); } 141 142 Operator::arg_iterator Operator::arg_end() const { return arguments.end(); } 143 144 Operator::arg_range Operator::getArgs() const { 145 return {arg_begin(), arg_end()}; 146 } 147 148 StringRef Operator::getArgName(int index) const { 149 DagInit *argumentValues = def.getValueAsDag("arguments"); 150 return argumentValues->getArgName(index)->getValue(); 151 } 152 153 auto Operator::getArgDecorators(int index) const -> var_decorator_range { 154 Record *arg = 155 cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef(); 156 if (!arg->isSubClassOf("OpVariable")) 157 return var_decorator_range(nullptr, nullptr); 158 return *arg->getValueAsListInit("decorators"); 159 } 160 161 const OpTrait *Operator::getTrait(StringRef trait) const { 162 for (const auto &t : traits) { 163 if (const auto *opTrait = dyn_cast<NativeOpTrait>(&t)) { 164 if (opTrait->getTrait() == trait) 165 return opTrait; 166 } else if (const auto *opTrait = dyn_cast<InternalOpTrait>(&t)) { 167 if (opTrait->getTrait() == trait) 168 return opTrait; 169 } else if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&t)) { 170 if (opTrait->getTrait() == trait) 171 return opTrait; 172 } 173 } 174 return nullptr; 175 } 176 177 auto Operator::region_begin() const -> const_region_iterator { 178 return regions.begin(); 179 } 180 auto Operator::region_end() const -> const_region_iterator { 181 return regions.end(); 182 } 183 auto Operator::getRegions() const 184 -> llvm::iterator_range<const_region_iterator> { 185 return {region_begin(), region_end()}; 186 } 187 188 unsigned Operator::getNumRegions() const { return regions.size(); } 189 190 const NamedRegion &Operator::getRegion(unsigned index) const { 191 return regions[index]; 192 } 193 194 unsigned Operator::getNumVariadicRegions() const { 195 return llvm::count_if(regions, 196 [](const NamedRegion &c) { return c.isVariadic(); }); 197 } 198 199 auto Operator::successor_begin() const -> const_successor_iterator { 200 return successors.begin(); 201 } 202 auto Operator::successor_end() const -> const_successor_iterator { 203 return successors.end(); 204 } 205 auto Operator::getSuccessors() const 206 -> llvm::iterator_range<const_successor_iterator> { 207 return {successor_begin(), successor_end()}; 208 } 209 210 unsigned Operator::getNumSuccessors() const { return successors.size(); } 211 212 const NamedSuccessor &Operator::getSuccessor(unsigned index) const { 213 return successors[index]; 214 } 215 216 unsigned Operator::getNumVariadicSuccessors() const { 217 return llvm::count_if(successors, 218 [](const NamedSuccessor &c) { return c.isVariadic(); }); 219 } 220 221 auto Operator::trait_begin() const -> const_trait_iterator { 222 return traits.begin(); 223 } 224 auto Operator::trait_end() const -> const_trait_iterator { 225 return traits.end(); 226 } 227 auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> { 228 return {trait_begin(), trait_end()}; 229 } 230 231 auto Operator::attribute_begin() const -> attribute_iterator { 232 return attributes.begin(); 233 } 234 auto Operator::attribute_end() const -> attribute_iterator { 235 return attributes.end(); 236 } 237 auto Operator::getAttributes() const 238 -> llvm::iterator_range<attribute_iterator> { 239 return {attribute_begin(), attribute_end()}; 240 } 241 242 auto Operator::operand_begin() -> value_iterator { return operands.begin(); } 243 auto Operator::operand_end() -> value_iterator { return operands.end(); } 244 auto Operator::getOperands() -> value_range { 245 return {operand_begin(), operand_end()}; 246 } 247 248 auto Operator::getArg(int index) const -> Argument { return arguments[index]; } 249 250 // Mapping from result index to combined argument and result index. Arguments 251 // are indexed to match getArg index, while the result indexes are mapped to 252 // avoid overlap. 253 static int resultIndex(int i) { return -1 - i; } 254 255 bool Operator::isVariadic() const { 256 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results), 257 [](const NamedTypeConstraint &op) { return op.isVariadic(); }); 258 } 259 260 void Operator::populateTypeInferenceInfo( 261 const llvm::StringMap<int> &argumentsAndResultsIndex) { 262 // If the type inference op interface is not registered, then do not attempt 263 // to determine if the result types an be inferred. 264 auto &recordKeeper = def.getRecords(); 265 auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface); 266 allResultsHaveKnownTypes = false; 267 if (!inferTrait) 268 return; 269 270 // If there are no results, the skip this else the build method generated 271 // overlaps with another autogenerated builder. 272 if (getNumResults() == 0) 273 return; 274 275 // Skip for ops with variadic operands/results. 276 // TODO: This can be relaxed. 277 if (isVariadic()) 278 return; 279 280 // Skip cases currently being custom generated. 281 // TODO: Remove special cases. 282 if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) 283 return; 284 285 // We create equivalence classes of argument/result types where arguments 286 // and results are mapped into the same index space and indices corresponding 287 // to the same type are in the same equivalence class. 288 llvm::EquivalenceClasses<int> ecs; 289 resultTypeMapping.resize(getNumResults()); 290 // Captures the argument whose type matches a given result type. Preference 291 // towards capturing operands first before attributes. 292 auto captureMapping = [&](int i) { 293 bool found = false; 294 ecs.insert(resultIndex(i)); 295 auto mi = ecs.findLeader(resultIndex(i)); 296 for (auto me = ecs.member_end(); mi != me; ++mi) { 297 if (*mi < 0) { 298 auto tc = getResultTypeConstraint(i); 299 if (tc.getBuilderCall().hasValue()) { 300 resultTypeMapping[i].emplace_back(tc); 301 found = true; 302 } 303 continue; 304 } 305 306 if (getArg(*mi).is<NamedAttribute *>()) { 307 // TODO: Handle attributes. 308 continue; 309 } else { 310 resultTypeMapping[i].emplace_back(*mi); 311 found = true; 312 } 313 } 314 return found; 315 }; 316 317 for (const OpTrait &trait : traits) { 318 const llvm::Record &def = trait.getDef(); 319 // If the infer type op interface was manually added, then treat it as 320 // intention that the op needs special handling. 321 // TODO: Reconsider whether to always generate, this is more conservative 322 // and keeps existing behavior so starting that way for now. 323 if (def.isSubClassOf( 324 llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) 325 return; 326 if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&trait)) 327 if (&opTrait->getDef() == inferTrait) 328 return; 329 330 if (!def.isSubClassOf("AllTypesMatch")) 331 continue; 332 333 auto values = def.getValueAsListOfStrings("values"); 334 auto root = argumentsAndResultsIndex.lookup(values.front()); 335 for (StringRef str : values) 336 ecs.unionSets(argumentsAndResultsIndex.lookup(str), root); 337 } 338 339 // Verifies that all output types have a corresponding known input type 340 // and chooses matching operand or attribute (in that order) that 341 // matches it. 342 allResultsHaveKnownTypes = 343 all_of(llvm::seq<int>(0, getNumResults()), captureMapping); 344 345 // If the types could be computed, then add type inference trait. 346 if (allResultsHaveKnownTypes) 347 traits.push_back(OpTrait::create(inferTrait->getDefInit())); 348 } 349 350 void Operator::populateOpStructure() { 351 auto &recordKeeper = def.getRecords(); 352 auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); 353 auto *attrClass = recordKeeper.getClass("Attr"); 354 auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr"); 355 auto *opVarClass = recordKeeper.getClass("OpVariable"); 356 numNativeAttributes = 0; 357 358 DagInit *argumentValues = def.getValueAsDag("arguments"); 359 unsigned numArgs = argumentValues->getNumArgs(); 360 361 // Mapping from name of to argument or result index. Arguments are indexed 362 // to match getArg index, while the results are negatively indexed. 363 llvm::StringMap<int> argumentsAndResultsIndex; 364 365 // Handle operands and native attributes. 366 for (unsigned i = 0; i != numArgs; ++i) { 367 auto *arg = argumentValues->getArg(i); 368 auto givenName = argumentValues->getArgNameStr(i); 369 auto *argDefInit = dyn_cast<DefInit>(arg); 370 if (!argDefInit) 371 PrintFatalError(def.getLoc(), 372 Twine("undefined type for argument #") + Twine(i)); 373 Record *argDef = argDefInit->getDef(); 374 if (argDef->isSubClassOf(opVarClass)) 375 argDef = argDef->getValueAsDef("constraint"); 376 377 if (argDef->isSubClassOf(typeConstraintClass)) { 378 operands.push_back( 379 NamedTypeConstraint{givenName, TypeConstraint(argDef)}); 380 } else if (argDef->isSubClassOf(attrClass)) { 381 if (givenName.empty()) 382 PrintFatalError(argDef->getLoc(), "attributes must be named"); 383 if (argDef->isSubClassOf(derivedAttrClass)) 384 PrintFatalError(argDef->getLoc(), 385 "derived attributes not allowed in argument list"); 386 attributes.push_back({givenName, Attribute(argDef)}); 387 ++numNativeAttributes; 388 } else { 389 PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving " 390 "from TypeConstraint or Attr are allowed"); 391 } 392 if (!givenName.empty()) 393 argumentsAndResultsIndex[givenName] = i; 394 } 395 396 // Handle derived attributes. 397 for (const auto &val : def.getValues()) { 398 if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) { 399 if (!record->isSubClassOf(attrClass)) 400 continue; 401 if (!record->isSubClassOf(derivedAttrClass)) 402 PrintFatalError(def.getLoc(), 403 "unexpected Attr where only DerivedAttr is allowed"); 404 405 if (record->getClasses().size() != 1) { 406 PrintFatalError( 407 def.getLoc(), 408 "unsupported attribute modelling, only single class expected"); 409 } 410 attributes.push_back( 411 {cast<llvm::StringInit>(val.getNameInit())->getValue(), 412 Attribute(cast<DefInit>(val.getValue()))}); 413 } 414 } 415 416 // Populate `arguments`. This must happen after we've finalized `operands` and 417 // `attributes` because we will put their elements' pointers in `arguments`. 418 // SmallVector may perform re-allocation under the hood when adding new 419 // elements. 420 int operandIndex = 0, attrIndex = 0; 421 for (unsigned i = 0; i != numArgs; ++i) { 422 Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef(); 423 if (argDef->isSubClassOf(opVarClass)) 424 argDef = argDef->getValueAsDef("constraint"); 425 426 if (argDef->isSubClassOf(typeConstraintClass)) { 427 attrOrOperandMapping.push_back( 428 {OperandOrAttribute::Kind::Operand, operandIndex}); 429 arguments.emplace_back(&operands[operandIndex++]); 430 } else { 431 assert(argDef->isSubClassOf(attrClass)); 432 attrOrOperandMapping.push_back( 433 {OperandOrAttribute::Kind::Attribute, attrIndex}); 434 arguments.emplace_back(&attributes[attrIndex++]); 435 } 436 } 437 438 auto *resultsDag = def.getValueAsDag("results"); 439 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator()); 440 if (!outsOp || outsOp->getDef()->getName() != "outs") { 441 PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); 442 } 443 444 // Handle results. 445 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { 446 auto name = resultsDag->getArgNameStr(i); 447 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i)); 448 if (!resultInit) { 449 PrintFatalError(def.getLoc(), 450 Twine("undefined type for result #") + Twine(i)); 451 } 452 auto *resultDef = resultInit->getDef(); 453 if (resultDef->isSubClassOf(opVarClass)) 454 resultDef = resultDef->getValueAsDef("constraint"); 455 results.push_back({name, TypeConstraint(resultDef)}); 456 if (!name.empty()) 457 argumentsAndResultsIndex[name] = resultIndex(i); 458 } 459 460 // Handle successors 461 auto *successorsDag = def.getValueAsDag("successors"); 462 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator()); 463 if (!successorsOp || successorsOp->getDef()->getName() != "successor") { 464 PrintFatalError(def.getLoc(), 465 "'successors' must have 'successor' directive"); 466 } 467 468 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { 469 auto name = successorsDag->getArgNameStr(i); 470 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i)); 471 if (!successorInit) { 472 PrintFatalError(def.getLoc(), 473 Twine("undefined kind for successor #") + Twine(i)); 474 } 475 Successor successor(successorInit->getDef()); 476 477 // Only support variadic successors if it is the last one for now. 478 if (i != e - 1 && successor.isVariadic()) 479 PrintFatalError(def.getLoc(), "only the last successor can be variadic"); 480 successors.push_back({name, successor}); 481 } 482 483 // Create list of traits, skipping over duplicates: appending to lists in 484 // tablegen is easy, making them unique less so, so dedupe here. 485 if (auto *traitList = def.getValueAsListInit("traits")) { 486 // This is uniquing based on pointers of the trait. 487 SmallPtrSet<const llvm::Init *, 32> traitSet; 488 traits.reserve(traitSet.size()); 489 for (auto *traitInit : *traitList) { 490 // Keep traits in the same order while skipping over duplicates. 491 if (traitSet.insert(traitInit).second) 492 traits.push_back(OpTrait::create(traitInit)); 493 } 494 } 495 496 populateTypeInferenceInfo(argumentsAndResultsIndex); 497 498 // Handle regions 499 auto *regionsDag = def.getValueAsDag("regions"); 500 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator()); 501 if (!regionsOp || regionsOp->getDef()->getName() != "region") { 502 PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); 503 } 504 505 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { 506 auto name = regionsDag->getArgNameStr(i); 507 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i)); 508 if (!regionInit) { 509 PrintFatalError(def.getLoc(), 510 Twine("undefined kind for region #") + Twine(i)); 511 } 512 Region region(regionInit->getDef()); 513 if (region.isVariadic()) { 514 // Only support variadic regions if it is the last one for now. 515 if (i != e - 1) 516 PrintFatalError(def.getLoc(), "only the last region can be variadic"); 517 if (name.empty()) 518 PrintFatalError(def.getLoc(), "variadic regions must be named"); 519 } 520 521 regions.push_back({name, region}); 522 } 523 524 LLVM_DEBUG(print(llvm::dbgs())); 525 } 526 527 auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> { 528 assert(allResultTypesKnown()); 529 return resultTypeMapping[index]; 530 } 531 532 ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); } 533 534 bool Operator::hasDescription() const { 535 return def.getValue("description") != nullptr; 536 } 537 538 StringRef Operator::getDescription() const { 539 return def.getValueAsString("description"); 540 } 541 542 bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; } 543 544 StringRef Operator::getSummary() const { 545 return def.getValueAsString("summary"); 546 } 547 548 bool Operator::hasAssemblyFormat() const { 549 auto *valueInit = def.getValueInit("assemblyFormat"); 550 return isa<llvm::CodeInit, llvm::StringInit>(valueInit); 551 } 552 553 StringRef Operator::getAssemblyFormat() const { 554 return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat")) 555 .Case<llvm::StringInit, llvm::CodeInit>( 556 [&](auto *init) { return init->getValue(); }); 557 } 558 559 void Operator::print(llvm::raw_ostream &os) const { 560 os << "op '" << getOperationName() << "'\n"; 561 for (Argument arg : arguments) { 562 if (auto *attr = arg.dyn_cast<NamedAttribute *>()) 563 os << "[attribute] " << attr->name << '\n'; 564 else 565 os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n'; 566 } 567 } 568 569 Operator::NamespaceEmitter::NamespaceEmitter(raw_ostream &os, Operator &op) 570 : os(os) { 571 auto dialect = op.getDialect(); 572 if (!dialect) 573 return; 574 llvm::SplitString(dialect.getCppNamespace(), namespaces, "::"); 575 for (StringRef ns : namespaces) 576 os << "namespace " << ns << " {\n"; 577 } 578 579 Operator::NamespaceEmitter::~NamespaceEmitter() { 580 for (StringRef ns : llvm::reverse(namespaces)) 581 os << "} // namespace " << ns << "\n"; 582 } 583 584 auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) 585 -> VariableDecorator { 586 return VariableDecorator(cast<llvm::DefInit>(init)->getDef()); 587 } 588 589 auto Operator::getArgToOperandOrAttribute(int index) const 590 -> OperandOrAttribute { 591 return attrOrOperandMapping[index]; 592 } 593