1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Pattern.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/Debug.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 #define DEBUG_TYPE "mlir-tblgen-pattern" 23 24 using namespace mlir; 25 26 using llvm::formatv; 27 using mlir::tblgen::Operator; 28 29 //===----------------------------------------------------------------------===// 30 // DagLeaf 31 //===----------------------------------------------------------------------===// 32 33 bool tblgen::DagLeaf::isUnspecified() const { 34 return dyn_cast_or_null<llvm::UnsetInit>(def); 35 } 36 37 bool tblgen::DagLeaf::isOperandMatcher() const { 38 // Operand matchers specify a type constraint. 39 return isSubClassOf("TypeConstraint"); 40 } 41 42 bool tblgen::DagLeaf::isAttrMatcher() const { 43 // Attribute matchers specify an attribute constraint. 44 return isSubClassOf("AttrConstraint"); 45 } 46 47 bool tblgen::DagLeaf::isNativeCodeCall() const { 48 return isSubClassOf("NativeCodeCall"); 49 } 50 51 bool tblgen::DagLeaf::isConstantAttr() const { 52 return isSubClassOf("ConstantAttr"); 53 } 54 55 bool tblgen::DagLeaf::isEnumAttrCase() const { 56 return isSubClassOf("EnumAttrCaseInfo"); 57 } 58 59 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { 60 assert((isOperandMatcher() || isAttrMatcher()) && 61 "the DAG leaf must be operand or attribute"); 62 return Constraint(cast<llvm::DefInit>(def)->getDef()); 63 } 64 65 tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { 66 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 67 return ConstantAttr(cast<llvm::DefInit>(def)); 68 } 69 70 tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const { 71 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 72 return EnumAttrCase(cast<llvm::DefInit>(def)); 73 } 74 75 std::string tblgen::DagLeaf::getConditionTemplate() const { 76 return getAsConstraint().getConditionTemplate(); 77 } 78 79 llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { 80 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 81 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 82 } 83 84 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { 85 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 86 return defInit->getDef()->isSubClassOf(superclass); 87 return false; 88 } 89 90 void tblgen::DagLeaf::print(raw_ostream &os) const { 91 if (def) 92 def->print(os); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // DagNode 97 //===----------------------------------------------------------------------===// 98 99 bool tblgen::DagNode::isNativeCodeCall() const { 100 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 101 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 102 return false; 103 } 104 105 bool tblgen::DagNode::isOperation() const { 106 return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); 107 } 108 109 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { 110 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 111 return cast<llvm::DefInit>(node->getOperator()) 112 ->getDef() 113 ->getValueAsString("expression"); 114 } 115 116 llvm::StringRef tblgen::DagNode::getSymbol() const { 117 return node->getNameStr(); 118 } 119 120 Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { 121 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 122 auto it = mapper->find(opDef); 123 if (it != mapper->end()) 124 return *it->second; 125 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 126 .first->second; 127 } 128 129 int tblgen::DagNode::getNumOps() const { 130 int count = isReplaceWithValue() ? 0 : 1; 131 for (int i = 0, e = getNumArgs(); i != e; ++i) { 132 if (auto child = getArgAsNestedDag(i)) 133 count += child.getNumOps(); 134 } 135 return count; 136 } 137 138 int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } 139 140 bool tblgen::DagNode::isNestedDagArg(unsigned index) const { 141 return isa<llvm::DagInit>(node->getArg(index)); 142 } 143 144 tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { 145 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 146 } 147 148 tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { 149 assert(!isNestedDagArg(index)); 150 return DagLeaf(node->getArg(index)); 151 } 152 153 StringRef tblgen::DagNode::getArgName(unsigned index) const { 154 return node->getArgNameStr(index); 155 } 156 157 bool tblgen::DagNode::isReplaceWithValue() const { 158 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 159 return dagOpDef->getName() == "replaceWithValue"; 160 } 161 162 bool tblgen::DagNode::isLocationDirective() const { 163 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 164 return dagOpDef->getName() == "location"; 165 } 166 167 void tblgen::DagNode::print(raw_ostream &os) const { 168 if (node) 169 node->print(os); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // SymbolInfoMap 174 //===----------------------------------------------------------------------===// 175 176 StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, 177 int *index) { 178 StringRef name, indexStr; 179 int idx = -1; 180 std::tie(name, indexStr) = symbol.rsplit("__"); 181 182 if (indexStr.consumeInteger(10, idx)) { 183 // The second part is not an index; we return the whole symbol as-is. 184 return symbol; 185 } 186 if (index) { 187 *index = idx; 188 } 189 return name; 190 } 191 192 tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, 193 SymbolInfo::Kind kind, 194 Optional<int> index) 195 : op(op), kind(kind), argIndex(index) {} 196 197 int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 198 switch (kind) { 199 case Kind::Attr: 200 case Kind::Operand: 201 case Kind::Value: 202 return 1; 203 case Kind::Result: 204 return op->getNumResults(); 205 } 206 llvm_unreachable("unknown kind"); 207 } 208 209 std::string 210 tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 211 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 212 switch (kind) { 213 case Kind::Attr: { 214 auto type = 215 op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); 216 return std::string(formatv("{0} {1};\n", type, name)); 217 } 218 case Kind::Operand: { 219 // Use operand range for captured operands (to support potential variadic 220 // operands). 221 return std::string( 222 formatv("Operation::operand_range {0}(op0->getOperands());\n", name)); 223 } 224 case Kind::Value: { 225 return std::string(formatv("ArrayRef<Value> {0};\n", name)); 226 } 227 case Kind::Result: { 228 // Use the op itself for captured results. 229 return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name)); 230 } 231 } 232 llvm_unreachable("unknown kind"); 233 } 234 235 std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 236 StringRef name, int index, const char *fmt, const char *separator) const { 237 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 238 switch (kind) { 239 case Kind::Attr: { 240 assert(index < 0); 241 auto repl = formatv(fmt, name); 242 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 243 return std::string(repl); 244 } 245 case Kind::Operand: { 246 assert(index < 0); 247 auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>(); 248 // If this operand is variadic, then return a range. Otherwise, return the 249 // value itself. 250 if (operand->isVariadic()) { 251 auto repl = formatv(fmt, name); 252 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 253 return std::string(repl); 254 } 255 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 256 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 257 return std::string(repl); 258 } 259 case Kind::Result: { 260 // If `index` is greater than zero, then we are referencing a specific 261 // result of a multi-result op. The result can still be variadic. 262 if (index >= 0) { 263 std::string v = 264 std::string(formatv("{0}.getODSResults({1})", name, index)); 265 if (!op->getResult(index).isVariadic()) 266 v = std::string(formatv("(*{0}.begin())", v)); 267 auto repl = formatv(fmt, v); 268 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 269 return std::string(repl); 270 } 271 272 // If this op has no result at all but still we bind a symbol to it, it 273 // means we want to capture the op itself. 274 if (op->getNumResults() == 0) { 275 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 276 return std::string(name); 277 } 278 279 // We are referencing all results of the multi-result op. A specific result 280 // can either be a value or a range. Then join them with `separator`. 281 SmallVector<std::string, 4> values; 282 values.reserve(op->getNumResults()); 283 284 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 285 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 286 if (!op->getResult(i).isVariadic()) { 287 v = std::string(formatv("(*{0}.begin())", v)); 288 } 289 values.push_back(std::string(formatv(fmt, v))); 290 } 291 auto repl = llvm::join(values, separator); 292 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 293 return repl; 294 } 295 case Kind::Value: { 296 assert(index < 0); 297 assert(op == nullptr); 298 auto repl = formatv(fmt, name); 299 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 300 return std::string(repl); 301 } 302 } 303 llvm_unreachable("unknown kind"); 304 } 305 306 std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse( 307 StringRef name, int index, const char *fmt, const char *separator) const { 308 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 309 switch (kind) { 310 case Kind::Attr: 311 case Kind::Operand: { 312 assert(index < 0 && "only allowed for symbol bound to result"); 313 auto repl = formatv(fmt, name); 314 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 315 return std::string(repl); 316 } 317 case Kind::Result: { 318 if (index >= 0) { 319 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 320 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 321 return std::string(repl); 322 } 323 324 // We are referencing all results of the multi-result op. Each result should 325 // have a value range, and then join them with `separator`. 326 SmallVector<std::string, 4> values; 327 values.reserve(op->getNumResults()); 328 329 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 330 values.push_back(std::string( 331 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 332 } 333 auto repl = llvm::join(values, separator); 334 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 335 return repl; 336 } 337 case Kind::Value: { 338 assert(index < 0 && "only allowed for symbol bound to result"); 339 assert(op == nullptr); 340 auto repl = formatv(fmt, formatv("{{{0}}", name)); 341 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 342 return std::string(repl); 343 } 344 } 345 llvm_unreachable("unknown kind"); 346 } 347 348 bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, 349 int argIndex) { 350 StringRef name = getValuePackName(symbol); 351 if (name != symbol) { 352 auto error = formatv( 353 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 354 PrintFatalError(loc, error); 355 } 356 357 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 358 ? SymbolInfo::getAttr(&op, argIndex) 359 : SymbolInfo::getOperand(&op, argIndex); 360 361 return symbolInfoMap.insert({symbol, symInfo}).second; 362 } 363 364 bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 365 StringRef name = getValuePackName(symbol); 366 return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; 367 } 368 369 bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { 370 return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; 371 } 372 373 bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { 374 return find(symbol) != symbolInfoMap.end(); 375 } 376 377 tblgen::SymbolInfoMap::const_iterator 378 tblgen::SymbolInfoMap::find(StringRef key) const { 379 StringRef name = getValuePackName(key); 380 return symbolInfoMap.find(name); 381 } 382 383 int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 384 StringRef name = getValuePackName(symbol); 385 if (name != symbol) { 386 // If there is a trailing index inside symbol, it references just one 387 // static value. 388 return 1; 389 } 390 // Otherwise, find how many it represents by querying the symbol's info. 391 return find(name)->getValue().getStaticValueCount(); 392 } 393 394 std::string 395 tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt, 396 const char *separator) const { 397 int index = -1; 398 StringRef name = getValuePackName(symbol, &index); 399 400 auto it = symbolInfoMap.find(name); 401 if (it == symbolInfoMap.end()) { 402 auto error = formatv("referencing unbound symbol '{0}'", symbol); 403 PrintFatalError(loc, error); 404 } 405 406 return it->getValue().getValueAndRangeUse(name, index, fmt, separator); 407 } 408 409 std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol, 410 const char *fmt, 411 const char *separator) const { 412 int index = -1; 413 StringRef name = getValuePackName(symbol, &index); 414 415 auto it = symbolInfoMap.find(name); 416 if (it == symbolInfoMap.end()) { 417 auto error = formatv("referencing unbound symbol '{0}'", symbol); 418 PrintFatalError(loc, error); 419 } 420 421 return it->getValue().getAllRangeUse(name, index, fmt, separator); 422 } 423 424 //===----------------------------------------------------------------------===// 425 // Pattern 426 //==----------------------------------------------------------------------===// 427 428 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 429 : def(*def), recordOpMap(mapper) {} 430 431 tblgen::DagNode tblgen::Pattern::getSourcePattern() const { 432 return tblgen::DagNode(def.getValueAsDag("sourcePattern")); 433 } 434 435 int tblgen::Pattern::getNumResultPatterns() const { 436 auto *results = def.getValueAsListInit("resultPatterns"); 437 return results->size(); 438 } 439 440 tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { 441 auto *results = def.getValueAsListInit("resultPatterns"); 442 return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); 443 } 444 445 void tblgen::Pattern::collectSourcePatternBoundSymbols( 446 tblgen::SymbolInfoMap &infoMap) { 447 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 448 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 449 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 450 } 451 452 void tblgen::Pattern::collectResultPatternBoundSymbols( 453 tblgen::SymbolInfoMap &infoMap) { 454 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 455 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 456 auto pattern = getResultPattern(i); 457 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 458 } 459 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 460 } 461 462 const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { 463 return getSourcePattern().getDialectOp(recordOpMap); 464 } 465 466 tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { 467 return node.getDialectOp(recordOpMap); 468 } 469 470 std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const { 471 auto *listInit = def.getValueAsListInit("constraints"); 472 std::vector<tblgen::AppliedConstraint> ret; 473 ret.reserve(listInit->size()); 474 475 for (auto it : *listInit) { 476 auto *dagInit = dyn_cast<llvm::DagInit>(it); 477 if (!dagInit) 478 PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " 479 "constraints should be DAG nodes"); 480 481 std::vector<std::string> entities; 482 entities.reserve(dagInit->arg_size()); 483 for (auto *argName : dagInit->getArgNames()) { 484 if (!argName) { 485 PrintFatalError( 486 def.getLoc(), 487 "operands to additional constraints can only be symbol references"); 488 } 489 entities.push_back(std::string(argName->getValue())); 490 } 491 492 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 493 dagInit->getNameStr(), std::move(entities)); 494 } 495 return ret; 496 } 497 498 int tblgen::Pattern::getBenefit() const { 499 // The initial benefit value is a heuristic with number of ops in the source 500 // pattern. 501 int initBenefit = getSourcePattern().getNumOps(); 502 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 503 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 504 PrintFatalError(def.getLoc(), 505 "The 'addBenefit' takes and only takes one integer value"); 506 } 507 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 508 } 509 510 std::vector<tblgen::Pattern::IdentifierLine> 511 tblgen::Pattern::getLocation() const { 512 std::vector<std::pair<StringRef, unsigned>> result; 513 result.reserve(def.getLoc().size()); 514 for (auto loc : def.getLoc()) { 515 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 516 assert(buf && "invalid source location"); 517 result.emplace_back( 518 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 519 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 520 } 521 return result; 522 } 523 524 void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 525 bool isSrcPattern) { 526 auto treeName = tree.getSymbol(); 527 if (!tree.isOperation()) { 528 if (!treeName.empty()) { 529 PrintFatalError( 530 def.getLoc(), 531 formatv("binding symbol '{0}' to non-operation unsupported right now", 532 treeName)); 533 } 534 return; 535 } 536 537 auto &op = getDialectOp(tree); 538 auto numOpArgs = op.getNumArgs(); 539 auto numTreeArgs = tree.getNumArgs(); 540 541 // The pattern might have the last argument specifying the location. 542 bool hasLocDirective = false; 543 if (numTreeArgs != 0) { 544 if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) 545 hasLocDirective = lastArg.isLocationDirective(); 546 } 547 548 if (numOpArgs != numTreeArgs - hasLocDirective) { 549 auto err = formatv("op '{0}' argument number mismatch: " 550 "{1} in pattern vs. {2} in definition", 551 op.getOperationName(), numTreeArgs, numOpArgs); 552 PrintFatalError(def.getLoc(), err); 553 } 554 555 // The name attached to the DAG node's operator is for representing the 556 // results generated from this op. It should be remembered as bound results. 557 if (!treeName.empty()) { 558 LLVM_DEBUG(llvm::dbgs() 559 << "found symbol bound to op result: " << treeName << '\n'); 560 if (!infoMap.bindOpResult(treeName, op)) 561 PrintFatalError(def.getLoc(), 562 formatv("symbol '{0}' bound more than once", treeName)); 563 } 564 565 for (int i = 0; i != numTreeArgs; ++i) { 566 if (auto treeArg = tree.getArgAsNestedDag(i)) { 567 // This DAG node argument is a DAG node itself. Go inside recursively. 568 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 569 } else if (isSrcPattern) { 570 // We can only bind symbols to op arguments in source pattern. Those 571 // symbols are referenced in result patterns. 572 auto treeArgName = tree.getArgName(i); 573 // `$_` is a special symbol meaning ignore the current argument. 574 if (!treeArgName.empty() && treeArgName != "_") { 575 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 576 << treeArgName << '\n'); 577 if (!infoMap.bindOpArgument(treeArgName, op, i)) { 578 auto err = formatv("symbol '{0}' bound more than once", treeArgName); 579 PrintFatalError(def.getLoc(), err); 580 } 581 } 582 } 583 } 584 } 585