1 //===- Operator.cpp - Operator class --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Operator.h" 14 #include "mlir/TableGen/OpTrait.h" 15 #include "mlir/TableGen/Predicate.h" 16 #include "mlir/TableGen/Type.h" 17 #include "llvm/ADT/EquivalenceClasses.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/Sequence.h" 20 #include "llvm/ADT/SmallPtrSet.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/FormatVariadic.h" 24 #include "llvm/TableGen/Error.h" 25 #include "llvm/TableGen/Record.h" 26 27 #define DEBUG_TYPE "mlir-tblgen-operator" 28 29 using namespace mlir; 30 31 using llvm::DagInit; 32 using llvm::DefInit; 33 using llvm::Record; 34 35 tblgen::Operator::Operator(const llvm::Record &def) 36 : dialect(def.getValueAsDef("opDialect")), def(def) { 37 // The first `_` in the op's TableGen def name is treated as separating the 38 // dialect prefix and the op class name. The dialect prefix will be ignored if 39 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered 40 // as part of the class name. 41 StringRef prefix; 42 std::tie(prefix, cppClassName) = def.getName().split('_'); 43 if (prefix.empty()) { 44 // Class name with a leading underscore and without dialect prefix 45 cppClassName = def.getName(); 46 } else if (cppClassName.empty()) { 47 // Class name without dialect prefix 48 cppClassName = prefix; 49 } 50 51 populateOpStructure(); 52 } 53 54 std::string tblgen::Operator::getOperationName() const { 55 auto prefix = dialect.getName(); 56 auto opName = def.getValueAsString("opName"); 57 if (prefix.empty()) 58 return std::string(opName); 59 return std::string(llvm::formatv("{0}.{1}", prefix, opName)); 60 } 61 62 std::string tblgen::Operator::getAdaptorName() const { 63 return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); 64 } 65 66 StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); } 67 68 StringRef tblgen::Operator::getCppClassName() const { return cppClassName; } 69 70 std::string tblgen::Operator::getQualCppClassName() const { 71 auto prefix = dialect.getCppNamespace(); 72 if (prefix.empty()) 73 return std::string(cppClassName); 74 return std::string(llvm::formatv("{0}::{1}", prefix, cppClassName)); 75 } 76 77 int tblgen::Operator::getNumResults() const { 78 DagInit *results = def.getValueAsDag("results"); 79 return results->getNumArgs(); 80 } 81 82 StringRef tblgen::Operator::getExtraClassDeclaration() const { 83 constexpr auto attr = "extraClassDeclaration"; 84 if (def.isValueUnset(attr)) 85 return {}; 86 return def.getValueAsString(attr); 87 } 88 89 const llvm::Record &tblgen::Operator::getDef() const { return def; } 90 91 bool tblgen::Operator::skipDefaultBuilders() const { 92 return def.getValueAsBit("skipDefaultBuilders"); 93 } 94 95 auto tblgen::Operator::result_begin() -> value_iterator { 96 return results.begin(); 97 } 98 99 auto tblgen::Operator::result_end() -> value_iterator { return results.end(); } 100 101 auto tblgen::Operator::getResults() -> value_range { 102 return {result_begin(), result_end()}; 103 } 104 105 tblgen::TypeConstraint 106 tblgen::Operator::getResultTypeConstraint(int index) const { 107 DagInit *results = def.getValueAsDag("results"); 108 return TypeConstraint(cast<DefInit>(results->getArg(index))); 109 } 110 111 StringRef tblgen::Operator::getResultName(int index) const { 112 DagInit *results = def.getValueAsDag("results"); 113 return results->getArgNameStr(index); 114 } 115 116 auto tblgen::Operator::getResultDecorators(int index) const 117 -> var_decorator_range { 118 Record *result = 119 cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef(); 120 if (!result->isSubClassOf("OpVariable")) 121 return var_decorator_range(nullptr, nullptr); 122 return *result->getValueAsListInit("decorators"); 123 } 124 125 unsigned tblgen::Operator::getNumVariableLengthResults() const { 126 return llvm::count_if(results, [](const NamedTypeConstraint &c) { 127 return c.constraint.isVariableLength(); 128 }); 129 } 130 131 unsigned tblgen::Operator::getNumVariableLengthOperands() const { 132 return llvm::count_if(operands, [](const NamedTypeConstraint &c) { 133 return c.constraint.isVariableLength(); 134 }); 135 } 136 137 tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const { 138 return arguments.begin(); 139 } 140 141 tblgen::Operator::arg_iterator tblgen::Operator::arg_end() const { 142 return arguments.end(); 143 } 144 145 tblgen::Operator::arg_range tblgen::Operator::getArgs() const { 146 return {arg_begin(), arg_end()}; 147 } 148 149 StringRef tblgen::Operator::getArgName(int index) const { 150 DagInit *argumentValues = def.getValueAsDag("arguments"); 151 return argumentValues->getArgName(index)->getValue(); 152 } 153 154 auto tblgen::Operator::getArgDecorators(int index) const 155 -> var_decorator_range { 156 Record *arg = 157 cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef(); 158 if (!arg->isSubClassOf("OpVariable")) 159 return var_decorator_range(nullptr, nullptr); 160 return *arg->getValueAsListInit("decorators"); 161 } 162 163 const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const { 164 for (const auto &t : traits) { 165 if (const auto *opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) { 166 if (opTrait->getTrait() == trait) 167 return opTrait; 168 } else if (const auto *opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) { 169 if (opTrait->getTrait() == trait) 170 return opTrait; 171 } else if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&t)) { 172 if (opTrait->getTrait() == trait) 173 return opTrait; 174 } 175 } 176 return nullptr; 177 } 178 179 auto tblgen::Operator::region_begin() const -> const_region_iterator { 180 return regions.begin(); 181 } 182 auto tblgen::Operator::region_end() const -> const_region_iterator { 183 return regions.end(); 184 } 185 auto tblgen::Operator::getRegions() const 186 -> llvm::iterator_range<const_region_iterator> { 187 return {region_begin(), region_end()}; 188 } 189 190 unsigned tblgen::Operator::getNumRegions() const { return regions.size(); } 191 192 const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const { 193 return regions[index]; 194 } 195 196 unsigned tblgen::Operator::getNumVariadicRegions() const { 197 return llvm::count_if(regions, 198 [](const NamedRegion &c) { return c.isVariadic(); }); 199 } 200 201 auto tblgen::Operator::successor_begin() const -> const_successor_iterator { 202 return successors.begin(); 203 } 204 auto tblgen::Operator::successor_end() const -> const_successor_iterator { 205 return successors.end(); 206 } 207 auto tblgen::Operator::getSuccessors() const 208 -> llvm::iterator_range<const_successor_iterator> { 209 return {successor_begin(), successor_end()}; 210 } 211 212 unsigned tblgen::Operator::getNumSuccessors() const { 213 return successors.size(); 214 } 215 216 const tblgen::NamedSuccessor & 217 tblgen::Operator::getSuccessor(unsigned index) const { 218 return successors[index]; 219 } 220 221 unsigned tblgen::Operator::getNumVariadicSuccessors() const { 222 return llvm::count_if(successors, 223 [](const NamedSuccessor &c) { return c.isVariadic(); }); 224 } 225 226 auto tblgen::Operator::trait_begin() const -> const_trait_iterator { 227 return traits.begin(); 228 } 229 auto tblgen::Operator::trait_end() const -> const_trait_iterator { 230 return traits.end(); 231 } 232 auto tblgen::Operator::getTraits() const 233 -> llvm::iterator_range<const_trait_iterator> { 234 return {trait_begin(), trait_end()}; 235 } 236 237 auto tblgen::Operator::attribute_begin() const -> attribute_iterator { 238 return attributes.begin(); 239 } 240 auto tblgen::Operator::attribute_end() const -> attribute_iterator { 241 return attributes.end(); 242 } 243 auto tblgen::Operator::getAttributes() const 244 -> llvm::iterator_range<attribute_iterator> { 245 return {attribute_begin(), attribute_end()}; 246 } 247 248 auto tblgen::Operator::operand_begin() -> value_iterator { 249 return operands.begin(); 250 } 251 auto tblgen::Operator::operand_end() -> value_iterator { 252 return operands.end(); 253 } 254 auto tblgen::Operator::getOperands() -> value_range { 255 return {operand_begin(), operand_end()}; 256 } 257 258 auto tblgen::Operator::getArg(int index) const -> Argument { 259 return arguments[index]; 260 } 261 262 // Mapping from result index to combined argument and result index. Arguments 263 // are indexed to match getArg index, while the result indexes are mapped to 264 // avoid overlap. 265 static int resultIndex(int i) { return -1 - i; } 266 267 bool tblgen::Operator::isVariadic() const { 268 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results), 269 [](const NamedTypeConstraint &op) { return op.isVariadic(); }); 270 } 271 272 void tblgen::Operator::populateTypeInferenceInfo( 273 const llvm::StringMap<int> &argumentsAndResultsIndex) { 274 // If the type inference op interface is not registered, then do not attempt 275 // to determine if the result types an be inferred. 276 auto &recordKeeper = def.getRecords(); 277 auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface); 278 allResultsHaveKnownTypes = false; 279 if (!inferTrait) 280 return; 281 282 // If there are no results, the skip this else the build method generated 283 // overlaps with another autogenerated builder. 284 if (getNumResults() == 0) 285 return; 286 287 // Skip for ops with variadic operands/results. 288 // TODO: This can be relaxed. 289 if (isVariadic()) 290 return; 291 292 // Skip cases currently being custom generated. 293 // TODO: Remove special cases. 294 if (getTrait("OpTrait::SameOperandsAndResultType")) 295 return; 296 297 // We create equivalence classes of argument/result types where arguments 298 // and results are mapped into the same index space and indices corresponding 299 // to the same type are in the same equivalence class. 300 llvm::EquivalenceClasses<int> ecs; 301 resultTypeMapping.resize(getNumResults()); 302 // Captures the argument whose type matches a given result type. Preference 303 // towards capturing operands first before attributes. 304 auto captureMapping = [&](int i) { 305 bool found = false; 306 ecs.insert(resultIndex(i)); 307 auto mi = ecs.findLeader(resultIndex(i)); 308 for (auto me = ecs.member_end(); mi != me; ++mi) { 309 if (*mi < 0) { 310 auto tc = getResultTypeConstraint(i); 311 if (tc.getBuilderCall().hasValue()) { 312 resultTypeMapping[i].emplace_back(tc); 313 found = true; 314 } 315 continue; 316 } 317 318 if (getArg(*mi).is<NamedAttribute *>()) { 319 // TODO: Handle attributes. 320 continue; 321 } else { 322 resultTypeMapping[i].emplace_back(*mi); 323 found = true; 324 } 325 } 326 return found; 327 }; 328 329 for (const OpTrait &trait : traits) { 330 const llvm::Record &def = trait.getDef(); 331 // If the infer type op interface was manually added, then treat it as 332 // intention that the op needs special handling. 333 // TODO: Reconsider whether to always generate, this is more conservative 334 // and keeps existing behavior so starting that way for now. 335 if (def.isSubClassOf( 336 llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) 337 return; 338 if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait)) 339 if (opTrait->getTrait().startswith(inferTypeOpInterface)) 340 return; 341 342 if (!def.isSubClassOf("AllTypesMatch")) 343 continue; 344 345 auto values = def.getValueAsListOfStrings("values"); 346 auto root = argumentsAndResultsIndex.lookup(values.front()); 347 for (StringRef str : values) 348 ecs.unionSets(argumentsAndResultsIndex.lookup(str), root); 349 } 350 351 // Verifies that all output types have a corresponding known input type 352 // and chooses matching operand or attribute (in that order) that 353 // matches it. 354 allResultsHaveKnownTypes = 355 all_of(llvm::seq<int>(0, getNumResults()), captureMapping); 356 357 // If the types could be computed, then add type inference trait. 358 if (allResultsHaveKnownTypes) 359 traits.push_back(OpTrait::create(inferTrait->getDefInit())); 360 } 361 362 void tblgen::Operator::populateOpStructure() { 363 auto &recordKeeper = def.getRecords(); 364 auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); 365 auto *attrClass = recordKeeper.getClass("Attr"); 366 auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr"); 367 auto *opVarClass = recordKeeper.getClass("OpVariable"); 368 numNativeAttributes = 0; 369 370 DagInit *argumentValues = def.getValueAsDag("arguments"); 371 unsigned numArgs = argumentValues->getNumArgs(); 372 373 // Mapping from name of to argument or result index. Arguments are indexed 374 // to match getArg index, while the results are negatively indexed. 375 llvm::StringMap<int> argumentsAndResultsIndex; 376 377 // Handle operands and native attributes. 378 for (unsigned i = 0; i != numArgs; ++i) { 379 auto *arg = argumentValues->getArg(i); 380 auto givenName = argumentValues->getArgNameStr(i); 381 auto *argDefInit = dyn_cast<DefInit>(arg); 382 if (!argDefInit) 383 PrintFatalError(def.getLoc(), 384 Twine("undefined type for argument #") + Twine(i)); 385 Record *argDef = argDefInit->getDef(); 386 if (argDef->isSubClassOf(opVarClass)) 387 argDef = argDef->getValueAsDef("constraint"); 388 389 if (argDef->isSubClassOf(typeConstraintClass)) { 390 operands.push_back( 391 NamedTypeConstraint{givenName, TypeConstraint(argDef)}); 392 } else if (argDef->isSubClassOf(attrClass)) { 393 if (givenName.empty()) 394 PrintFatalError(argDef->getLoc(), "attributes must be named"); 395 if (argDef->isSubClassOf(derivedAttrClass)) 396 PrintFatalError(argDef->getLoc(), 397 "derived attributes not allowed in argument list"); 398 attributes.push_back({givenName, Attribute(argDef)}); 399 ++numNativeAttributes; 400 } else { 401 PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving " 402 "from TypeConstraint or Attr are allowed"); 403 } 404 if (!givenName.empty()) 405 argumentsAndResultsIndex[givenName] = i; 406 } 407 408 // Handle derived attributes. 409 for (const auto &val : def.getValues()) { 410 if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) { 411 if (!record->isSubClassOf(attrClass)) 412 continue; 413 if (!record->isSubClassOf(derivedAttrClass)) 414 PrintFatalError(def.getLoc(), 415 "unexpected Attr where only DerivedAttr is allowed"); 416 417 if (record->getClasses().size() != 1) { 418 PrintFatalError( 419 def.getLoc(), 420 "unsupported attribute modelling, only single class expected"); 421 } 422 attributes.push_back( 423 {cast<llvm::StringInit>(val.getNameInit())->getValue(), 424 Attribute(cast<DefInit>(val.getValue()))}); 425 } 426 } 427 428 // Populate `arguments`. This must happen after we've finalized `operands` and 429 // `attributes` because we will put their elements' pointers in `arguments`. 430 // SmallVector may perform re-allocation under the hood when adding new 431 // elements. 432 int operandIndex = 0, attrIndex = 0; 433 for (unsigned i = 0; i != numArgs; ++i) { 434 Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef(); 435 if (argDef->isSubClassOf(opVarClass)) 436 argDef = argDef->getValueAsDef("constraint"); 437 438 if (argDef->isSubClassOf(typeConstraintClass)) { 439 arguments.emplace_back(&operands[operandIndex++]); 440 } else { 441 assert(argDef->isSubClassOf(attrClass)); 442 arguments.emplace_back(&attributes[attrIndex++]); 443 } 444 } 445 446 auto *resultsDag = def.getValueAsDag("results"); 447 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator()); 448 if (!outsOp || outsOp->getDef()->getName() != "outs") { 449 PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); 450 } 451 452 // Handle results. 453 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { 454 auto name = resultsDag->getArgNameStr(i); 455 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i)); 456 if (!resultInit) { 457 PrintFatalError(def.getLoc(), 458 Twine("undefined type for result #") + Twine(i)); 459 } 460 auto *resultDef = resultInit->getDef(); 461 if (resultDef->isSubClassOf(opVarClass)) 462 resultDef = resultDef->getValueAsDef("constraint"); 463 results.push_back({name, TypeConstraint(resultDef)}); 464 if (!name.empty()) 465 argumentsAndResultsIndex[name] = resultIndex(i); 466 } 467 468 // Handle successors 469 auto *successorsDag = def.getValueAsDag("successors"); 470 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator()); 471 if (!successorsOp || successorsOp->getDef()->getName() != "successor") { 472 PrintFatalError(def.getLoc(), 473 "'successors' must have 'successor' directive"); 474 } 475 476 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { 477 auto name = successorsDag->getArgNameStr(i); 478 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i)); 479 if (!successorInit) { 480 PrintFatalError(def.getLoc(), 481 Twine("undefined kind for successor #") + Twine(i)); 482 } 483 Successor successor(successorInit->getDef()); 484 485 // Only support variadic successors if it is the last one for now. 486 if (i != e - 1 && successor.isVariadic()) 487 PrintFatalError(def.getLoc(), "only the last successor can be variadic"); 488 successors.push_back({name, successor}); 489 } 490 491 // Create list of traits, skipping over duplicates: appending to lists in 492 // tablegen is easy, making them unique less so, so dedupe here. 493 if (auto *traitList = def.getValueAsListInit("traits")) { 494 // This is uniquing based on pointers of the trait. 495 SmallPtrSet<const llvm::Init *, 32> traitSet; 496 traits.reserve(traitSet.size()); 497 for (auto *traitInit : *traitList) { 498 // Keep traits in the same order while skipping over duplicates. 499 if (traitSet.insert(traitInit).second) 500 traits.push_back(OpTrait::create(traitInit)); 501 } 502 } 503 504 populateTypeInferenceInfo(argumentsAndResultsIndex); 505 506 // Handle regions 507 auto *regionsDag = def.getValueAsDag("regions"); 508 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator()); 509 if (!regionsOp || regionsOp->getDef()->getName() != "region") { 510 PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); 511 } 512 513 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { 514 auto name = regionsDag->getArgNameStr(i); 515 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i)); 516 if (!regionInit) { 517 PrintFatalError(def.getLoc(), 518 Twine("undefined kind for region #") + Twine(i)); 519 } 520 Region region(regionInit->getDef()); 521 if (region.isVariadic()) { 522 // Only support variadic regions if it is the last one for now. 523 if (i != e - 1) 524 PrintFatalError(def.getLoc(), "only the last region can be variadic"); 525 if (name.empty()) 526 PrintFatalError(def.getLoc(), "variadic regions must be named"); 527 } 528 529 regions.push_back({name, region}); 530 } 531 532 LLVM_DEBUG(print(llvm::dbgs())); 533 } 534 535 auto tblgen::Operator::getSameTypeAsResult(int index) const 536 -> ArrayRef<ArgOrType> { 537 assert(allResultTypesKnown()); 538 return resultTypeMapping[index]; 539 } 540 541 ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); } 542 543 bool tblgen::Operator::hasDescription() const { 544 return def.getValue("description") != nullptr; 545 } 546 547 StringRef tblgen::Operator::getDescription() const { 548 return def.getValueAsString("description"); 549 } 550 551 bool tblgen::Operator::hasSummary() const { 552 return def.getValue("summary") != nullptr; 553 } 554 555 StringRef tblgen::Operator::getSummary() const { 556 return def.getValueAsString("summary"); 557 } 558 559 bool tblgen::Operator::hasAssemblyFormat() const { 560 auto *valueInit = def.getValueInit("assemblyFormat"); 561 return isa<llvm::CodeInit, llvm::StringInit>(valueInit); 562 } 563 564 StringRef tblgen::Operator::getAssemblyFormat() const { 565 return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat")) 566 .Case<llvm::StringInit, llvm::CodeInit>( 567 [&](auto *init) { return init->getValue(); }); 568 } 569 570 void tblgen::Operator::print(llvm::raw_ostream &os) const { 571 os << "op '" << getOperationName() << "'\n"; 572 for (Argument arg : arguments) { 573 if (auto *attr = arg.dyn_cast<NamedAttribute *>()) 574 os << "[attribute] " << attr->name << '\n'; 575 else 576 os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n'; 577 } 578 } 579 580 auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) 581 -> VariableDecorator { 582 return VariableDecorator(cast<llvm::DefInit>(init)->getDef()); 583 } 584