1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Pattern.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/Debug.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 #define DEBUG_TYPE "mlir-tblgen-pattern" 23 24 using namespace mlir; 25 using namespace tblgen; 26 27 using llvm::formatv; 28 29 //===----------------------------------------------------------------------===// 30 // DagLeaf 31 //===----------------------------------------------------------------------===// 32 33 bool DagLeaf::isUnspecified() const { 34 return dyn_cast_or_null<llvm::UnsetInit>(def); 35 } 36 37 bool DagLeaf::isOperandMatcher() const { 38 // Operand matchers specify a type constraint. 39 return isSubClassOf("TypeConstraint"); 40 } 41 42 bool DagLeaf::isAttrMatcher() const { 43 // Attribute matchers specify an attribute constraint. 44 return isSubClassOf("AttrConstraint"); 45 } 46 47 bool DagLeaf::isNativeCodeCall() const { 48 return isSubClassOf("NativeCodeCall"); 49 } 50 51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } 52 53 bool DagLeaf::isEnumAttrCase() const { 54 return isSubClassOf("EnumAttrCaseInfo"); 55 } 56 57 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); } 58 59 Constraint DagLeaf::getAsConstraint() const { 60 assert((isOperandMatcher() || isAttrMatcher()) && 61 "the DAG leaf must be operand or attribute"); 62 return Constraint(cast<llvm::DefInit>(def)->getDef()); 63 } 64 65 ConstantAttr DagLeaf::getAsConstantAttr() const { 66 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 67 return ConstantAttr(cast<llvm::DefInit>(def)); 68 } 69 70 EnumAttrCase DagLeaf::getAsEnumAttrCase() const { 71 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 72 return EnumAttrCase(cast<llvm::DefInit>(def)); 73 } 74 75 std::string DagLeaf::getConditionTemplate() const { 76 return getAsConstraint().getConditionTemplate(); 77 } 78 79 llvm::StringRef DagLeaf::getNativeCodeTemplate() const { 80 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 81 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 82 } 83 84 int DagLeaf::getNumReturnsOfNativeCode() const { 85 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 86 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns"); 87 } 88 89 std::string DagLeaf::getStringAttr() const { 90 assert(isStringAttr() && "the DAG leaf must be string attribute"); 91 return def->getAsUnquotedString(); 92 } 93 bool DagLeaf::isSubClassOf(StringRef superclass) const { 94 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 95 return defInit->getDef()->isSubClassOf(superclass); 96 return false; 97 } 98 99 void DagLeaf::print(raw_ostream &os) const { 100 if (def) 101 def->print(os); 102 } 103 104 //===----------------------------------------------------------------------===// 105 // DagNode 106 //===----------------------------------------------------------------------===// 107 108 bool DagNode::isNativeCodeCall() const { 109 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 110 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 111 return false; 112 } 113 114 bool DagNode::isOperation() const { 115 return !isNativeCodeCall() && !isReplaceWithValue() && 116 !isLocationDirective() && !isReturnTypeDirective(); 117 } 118 119 llvm::StringRef DagNode::getNativeCodeTemplate() const { 120 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 121 return cast<llvm::DefInit>(node->getOperator()) 122 ->getDef() 123 ->getValueAsString("expression"); 124 } 125 126 int DagNode::getNumReturnsOfNativeCode() const { 127 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 128 return cast<llvm::DefInit>(node->getOperator()) 129 ->getDef() 130 ->getValueAsInt("numReturns"); 131 } 132 133 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } 134 135 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { 136 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 137 auto it = mapper->find(opDef); 138 if (it != mapper->end()) 139 return *it->second; 140 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 141 .first->second; 142 } 143 144 int DagNode::getNumOps() const { 145 int count = isReplaceWithValue() ? 0 : 1; 146 for (int i = 0, e = getNumArgs(); i != e; ++i) { 147 if (auto child = getArgAsNestedDag(i)) 148 count += child.getNumOps(); 149 } 150 return count; 151 } 152 153 int DagNode::getNumArgs() const { return node->getNumArgs(); } 154 155 bool DagNode::isNestedDagArg(unsigned index) const { 156 return isa<llvm::DagInit>(node->getArg(index)); 157 } 158 159 DagNode DagNode::getArgAsNestedDag(unsigned index) const { 160 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 161 } 162 163 DagLeaf DagNode::getArgAsLeaf(unsigned index) const { 164 assert(!isNestedDagArg(index)); 165 return DagLeaf(node->getArg(index)); 166 } 167 168 StringRef DagNode::getArgName(unsigned index) const { 169 return node->getArgNameStr(index); 170 } 171 172 bool DagNode::isReplaceWithValue() const { 173 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 174 return dagOpDef->getName() == "replaceWithValue"; 175 } 176 177 bool DagNode::isLocationDirective() const { 178 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 179 return dagOpDef->getName() == "location"; 180 } 181 182 bool DagNode::isReturnTypeDirective() const { 183 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 184 return dagOpDef->getName() == "returnType"; 185 } 186 187 void DagNode::print(raw_ostream &os) const { 188 if (node) 189 node->print(os); 190 } 191 192 //===----------------------------------------------------------------------===// 193 // SymbolInfoMap 194 //===----------------------------------------------------------------------===// 195 196 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { 197 StringRef name, indexStr; 198 int idx = -1; 199 std::tie(name, indexStr) = symbol.rsplit("__"); 200 201 if (indexStr.consumeInteger(10, idx)) { 202 // The second part is not an index; we return the whole symbol as-is. 203 return symbol; 204 } 205 if (index) { 206 *index = idx; 207 } 208 return name; 209 } 210 211 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, 212 Optional<DagAndConstant> dagAndConstant) 213 : op(op), kind(kind), dagAndConstant(dagAndConstant) {} 214 215 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 216 switch (kind) { 217 case Kind::Attr: 218 case Kind::Operand: 219 case Kind::Value: 220 return 1; 221 case Kind::Result: 222 return op->getNumResults(); 223 case Kind::MultipleValues: 224 return getSize(); 225 } 226 llvm_unreachable("unknown kind"); 227 } 228 229 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { 230 return alternativeName.hasValue() ? alternativeName.getValue() : name.str(); 231 } 232 233 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { 234 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': "); 235 switch (kind) { 236 case Kind::Attr: { 237 if (op) 238 return op->getArg(getArgIndex()) 239 .get<NamedAttribute *>() 240 ->attr.getStorageType() 241 .str(); 242 // TODO(suderman): Use a more exact type when available. 243 return "Attribute"; 244 } 245 case Kind::Operand: { 246 // Use operand range for captured operands (to support potential variadic 247 // operands). 248 return "::mlir::Operation::operand_range"; 249 } 250 case Kind::Value: { 251 return "::mlir::Value"; 252 } 253 case Kind::MultipleValues: { 254 return "::mlir::ValueRange"; 255 } 256 case Kind::Result: { 257 // Use the op itself for captured results. 258 return op->getQualCppClassName(); 259 } 260 } 261 llvm_unreachable("unknown kind"); 262 } 263 264 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 265 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 266 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : ""; 267 return std::string( 268 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit)); 269 } 270 271 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const { 272 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': "); 273 return std::string( 274 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name))); 275 } 276 277 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 278 StringRef name, int index, const char *fmt, const char *separator) const { 279 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 280 switch (kind) { 281 case Kind::Attr: { 282 assert(index < 0); 283 auto repl = formatv(fmt, name); 284 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 285 return std::string(repl); 286 } 287 case Kind::Operand: { 288 assert(index < 0); 289 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>(); 290 // If this operand is variadic, then return a range. Otherwise, return the 291 // value itself. 292 if (operand->isVariableLength()) { 293 auto repl = formatv(fmt, name); 294 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 295 return std::string(repl); 296 } 297 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 298 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 299 return std::string(repl); 300 } 301 case Kind::Result: { 302 // If `index` is greater than zero, then we are referencing a specific 303 // result of a multi-result op. The result can still be variadic. 304 if (index >= 0) { 305 std::string v = 306 std::string(formatv("{0}.getODSResults({1})", name, index)); 307 if (!op->getResult(index).isVariadic()) 308 v = std::string(formatv("(*{0}.begin())", v)); 309 auto repl = formatv(fmt, v); 310 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 311 return std::string(repl); 312 } 313 314 // If this op has no result at all but still we bind a symbol to it, it 315 // means we want to capture the op itself. 316 if (op->getNumResults() == 0) { 317 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 318 return std::string(name); 319 } 320 321 // We are referencing all results of the multi-result op. A specific result 322 // can either be a value or a range. Then join them with `separator`. 323 SmallVector<std::string, 4> values; 324 values.reserve(op->getNumResults()); 325 326 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 327 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 328 if (!op->getResult(i).isVariadic()) { 329 v = std::string(formatv("(*{0}.begin())", v)); 330 } 331 values.push_back(std::string(formatv(fmt, v))); 332 } 333 auto repl = llvm::join(values, separator); 334 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 335 return repl; 336 } 337 case Kind::Value: { 338 assert(index < 0); 339 assert(op == nullptr); 340 auto repl = formatv(fmt, name); 341 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 342 return std::string(repl); 343 } 344 case Kind::MultipleValues: { 345 assert(op == nullptr); 346 assert(index < getSize()); 347 if (index >= 0) { 348 std::string repl = 349 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 350 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 351 return repl; 352 } 353 // If it doesn't specify certain element, unpack them all. 354 auto repl = 355 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 356 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 357 return std::string(repl); 358 } 359 } 360 llvm_unreachable("unknown kind"); 361 } 362 363 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( 364 StringRef name, int index, const char *fmt, const char *separator) const { 365 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 366 switch (kind) { 367 case Kind::Attr: 368 case Kind::Operand: { 369 assert(index < 0 && "only allowed for symbol bound to result"); 370 auto repl = formatv(fmt, name); 371 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 372 return std::string(repl); 373 } 374 case Kind::Result: { 375 if (index >= 0) { 376 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 377 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 378 return std::string(repl); 379 } 380 381 // We are referencing all results of the multi-result op. Each result should 382 // have a value range, and then join them with `separator`. 383 SmallVector<std::string, 4> values; 384 values.reserve(op->getNumResults()); 385 386 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 387 values.push_back(std::string( 388 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 389 } 390 auto repl = llvm::join(values, separator); 391 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 392 return repl; 393 } 394 case Kind::Value: { 395 assert(index < 0 && "only allowed for symbol bound to result"); 396 assert(op == nullptr); 397 auto repl = formatv(fmt, formatv("{{{0}}", name)); 398 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 399 return std::string(repl); 400 } 401 case Kind::MultipleValues: { 402 assert(op == nullptr); 403 assert(index < getSize()); 404 if (index >= 0) { 405 std::string repl = 406 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 407 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 408 return repl; 409 } 410 auto repl = 411 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 412 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 413 return std::string(repl); 414 } 415 } 416 llvm_unreachable("unknown kind"); 417 } 418 419 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, 420 const Operator &op, int argIndex) { 421 StringRef name = getValuePackName(symbol); 422 if (name != symbol) { 423 auto error = formatv( 424 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 425 PrintFatalError(loc, error); 426 } 427 428 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 429 ? SymbolInfo::getAttr(&op, argIndex) 430 : SymbolInfo::getOperand(node, &op, argIndex); 431 432 std::string key = symbol.str(); 433 if (symbolInfoMap.count(key)) { 434 // Only non unique name for the operand is supported. 435 if (symInfo.kind != SymbolInfo::Kind::Operand) { 436 return false; 437 } 438 439 // Cannot add new operand if there is already non operand with the same 440 // name. 441 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) { 442 return false; 443 } 444 } 445 446 symbolInfoMap.emplace(key, symInfo); 447 return true; 448 } 449 450 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 451 std::string name = getValuePackName(symbol).str(); 452 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); 453 454 return symbolInfoMap.count(inserted->first) == 1; 455 } 456 457 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { 458 std::string name = getValuePackName(symbol).str(); 459 if (numValues > 1) 460 return bindMultipleValues(name, numValues); 461 return bindValue(name); 462 } 463 464 bool SymbolInfoMap::bindValue(StringRef symbol) { 465 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); 466 return symbolInfoMap.count(inserted->first) == 1; 467 } 468 469 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { 470 std::string name = getValuePackName(symbol).str(); 471 auto inserted = 472 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues)); 473 return symbolInfoMap.count(inserted->first) == 1; 474 } 475 476 bool SymbolInfoMap::bindAttr(StringRef symbol) { 477 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr()); 478 return symbolInfoMap.count(inserted->first) == 1; 479 } 480 481 bool SymbolInfoMap::contains(StringRef symbol) const { 482 return find(symbol) != symbolInfoMap.end(); 483 } 484 485 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { 486 std::string name = getValuePackName(key).str(); 487 488 return symbolInfoMap.find(name); 489 } 490 491 SymbolInfoMap::const_iterator 492 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, 493 int argIndex) const { 494 return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex)); 495 } 496 497 SymbolInfoMap::const_iterator 498 SymbolInfoMap::findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const { 499 std::string name = getValuePackName(key).str(); 500 auto range = symbolInfoMap.equal_range(name); 501 502 for (auto it = range.first; it != range.second; ++it) 503 if (it->second.dagAndConstant == symbolInfo.dagAndConstant) 504 return it; 505 506 return symbolInfoMap.end(); 507 } 508 509 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> 510 SymbolInfoMap::getRangeOfEqualElements(StringRef key) { 511 std::string name = getValuePackName(key).str(); 512 513 return symbolInfoMap.equal_range(name); 514 } 515 516 int SymbolInfoMap::count(StringRef key) const { 517 std::string name = getValuePackName(key).str(); 518 return symbolInfoMap.count(name); 519 } 520 521 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 522 StringRef name = getValuePackName(symbol); 523 if (name != symbol) { 524 // If there is a trailing index inside symbol, it references just one 525 // static value. 526 return 1; 527 } 528 // Otherwise, find how many it represents by querying the symbol's info. 529 return find(name)->second.getStaticValueCount(); 530 } 531 532 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, 533 const char *fmt, 534 const char *separator) const { 535 int index = -1; 536 StringRef name = getValuePackName(symbol, &index); 537 538 auto it = symbolInfoMap.find(name.str()); 539 if (it == symbolInfoMap.end()) { 540 auto error = formatv("referencing unbound symbol '{0}'", symbol); 541 PrintFatalError(loc, error); 542 } 543 544 return it->second.getValueAndRangeUse(name, index, fmt, separator); 545 } 546 547 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, 548 const char *separator) const { 549 int index = -1; 550 StringRef name = getValuePackName(symbol, &index); 551 552 auto it = symbolInfoMap.find(name.str()); 553 if (it == symbolInfoMap.end()) { 554 auto error = formatv("referencing unbound symbol '{0}'", symbol); 555 PrintFatalError(loc, error); 556 } 557 558 return it->second.getAllRangeUse(name, index, fmt, separator); 559 } 560 561 void SymbolInfoMap::assignUniqueAlternativeNames() { 562 llvm::StringSet<> usedNames; 563 564 for (auto symbolInfoIt = symbolInfoMap.begin(); 565 symbolInfoIt != symbolInfoMap.end();) { 566 auto range = symbolInfoMap.equal_range(symbolInfoIt->first); 567 auto startRange = range.first; 568 auto endRange = range.second; 569 570 auto operandName = symbolInfoIt->first; 571 int startSearchIndex = 0; 572 for (++startRange; startRange != endRange; ++startRange) { 573 // Current operand name is not unique, find a unique one 574 // and set the alternative name. 575 for (int i = startSearchIndex;; ++i) { 576 std::string alternativeName = operandName + std::to_string(i); 577 if (!usedNames.contains(alternativeName) && 578 symbolInfoMap.count(alternativeName) == 0) { 579 usedNames.insert(alternativeName); 580 startRange->second.alternativeName = alternativeName; 581 startSearchIndex = i + 1; 582 583 break; 584 } 585 } 586 } 587 588 symbolInfoIt = endRange; 589 } 590 } 591 592 //===----------------------------------------------------------------------===// 593 // Pattern 594 //==----------------------------------------------------------------------===// 595 596 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 597 : def(*def), recordOpMap(mapper) {} 598 599 DagNode Pattern::getSourcePattern() const { 600 return DagNode(def.getValueAsDag("sourcePattern")); 601 } 602 603 int Pattern::getNumResultPatterns() const { 604 auto *results = def.getValueAsListInit("resultPatterns"); 605 return results->size(); 606 } 607 608 DagNode Pattern::getResultPattern(unsigned index) const { 609 auto *results = def.getValueAsListInit("resultPatterns"); 610 return DagNode(cast<llvm::DagInit>(results->getElement(index))); 611 } 612 613 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { 614 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 615 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 616 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 617 618 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n"); 619 infoMap.assignUniqueAlternativeNames(); 620 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n"); 621 } 622 623 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { 624 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 625 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 626 auto pattern = getResultPattern(i); 627 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 628 } 629 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 630 } 631 632 const Operator &Pattern::getSourceRootOp() { 633 return getSourcePattern().getDialectOp(recordOpMap); 634 } 635 636 Operator &Pattern::getDialectOp(DagNode node) { 637 return node.getDialectOp(recordOpMap); 638 } 639 640 std::vector<AppliedConstraint> Pattern::getConstraints() const { 641 auto *listInit = def.getValueAsListInit("constraints"); 642 std::vector<AppliedConstraint> ret; 643 ret.reserve(listInit->size()); 644 645 for (auto it : *listInit) { 646 auto *dagInit = dyn_cast<llvm::DagInit>(it); 647 if (!dagInit) 648 PrintFatalError(&def, "all elements in Pattern multi-entity " 649 "constraints should be DAG nodes"); 650 651 std::vector<std::string> entities; 652 entities.reserve(dagInit->arg_size()); 653 for (auto *argName : dagInit->getArgNames()) { 654 if (!argName) { 655 PrintFatalError( 656 &def, 657 "operands to additional constraints can only be symbol references"); 658 } 659 entities.push_back(std::string(argName->getValue())); 660 } 661 662 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 663 dagInit->getNameStr(), std::move(entities)); 664 } 665 return ret; 666 } 667 668 int Pattern::getBenefit() const { 669 // The initial benefit value is a heuristic with number of ops in the source 670 // pattern. 671 int initBenefit = getSourcePattern().getNumOps(); 672 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 673 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 674 PrintFatalError(&def, 675 "The 'addBenefit' takes and only takes one integer value"); 676 } 677 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 678 } 679 680 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { 681 std::vector<std::pair<StringRef, unsigned>> result; 682 result.reserve(def.getLoc().size()); 683 for (auto loc : def.getLoc()) { 684 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 685 assert(buf && "invalid source location"); 686 result.emplace_back( 687 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 688 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 689 } 690 return result; 691 } 692 693 void Pattern::verifyBind(bool result, StringRef symbolName) { 694 if (!result) { 695 auto err = formatv("symbol '{0}' bound more than once", symbolName); 696 PrintFatalError(&def, err); 697 } 698 } 699 700 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 701 bool isSrcPattern) { 702 auto treeName = tree.getSymbol(); 703 auto numTreeArgs = tree.getNumArgs(); 704 705 if (tree.isNativeCodeCall()) { 706 if (!treeName.empty()) { 707 if (!isSrcPattern) { 708 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " 709 << treeName << '\n'); 710 verifyBind( 711 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()), 712 treeName); 713 } else { 714 PrintFatalError(&def, 715 formatv("binding symbol '{0}' to NativecodeCall in " 716 "MatchPattern is not supported", 717 treeName)); 718 } 719 } 720 721 for (int i = 0; i != numTreeArgs; ++i) { 722 if (auto treeArg = tree.getArgAsNestedDag(i)) { 723 // This DAG node argument is a DAG node itself. Go inside recursively. 724 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 725 continue; 726 } 727 728 if (!isSrcPattern) 729 continue; 730 731 // We can only bind symbols to arguments in source pattern. Those 732 // symbols are referenced in result patterns. 733 auto treeArgName = tree.getArgName(i); 734 735 // `$_` is a special symbol meaning ignore the current argument. 736 if (!treeArgName.empty() && treeArgName != "_") { 737 DagLeaf leaf = tree.getArgAsLeaf(i); 738 739 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), 740 if (leaf.isUnspecified()) { 741 // This is case of $c, a Value without any constraints. 742 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 743 } else { 744 auto constraint = leaf.getAsConstraint(); 745 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || 746 leaf.isConstantAttr() || 747 constraint.getKind() == Constraint::Kind::CK_Attr; 748 749 if (isAttr) { 750 // This is case of $a, a binding to a certain attribute. 751 verifyBind(infoMap.bindAttr(treeArgName), treeArgName); 752 continue; 753 } 754 755 // This is case of $b, a binding to a certain type. 756 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 757 } 758 } 759 } 760 761 return; 762 } 763 764 if (tree.isOperation()) { 765 auto &op = getDialectOp(tree); 766 auto numOpArgs = op.getNumArgs(); 767 768 // The pattern might have trailing directives. 769 int numDirectives = 0; 770 for (int i = numTreeArgs - 1; i >= 0; --i) { 771 if (auto dagArg = tree.getArgAsNestedDag(i)) { 772 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) 773 ++numDirectives; 774 else 775 break; 776 } 777 } 778 779 if (numOpArgs != numTreeArgs - numDirectives) { 780 auto err = formatv("op '{0}' argument number mismatch: " 781 "{1} in pattern vs. {2} in definition", 782 op.getOperationName(), numTreeArgs, numOpArgs); 783 PrintFatalError(&def, err); 784 } 785 786 // The name attached to the DAG node's operator is for representing the 787 // results generated from this op. It should be remembered as bound results. 788 if (!treeName.empty()) { 789 LLVM_DEBUG(llvm::dbgs() 790 << "found symbol bound to op result: " << treeName << '\n'); 791 verifyBind(infoMap.bindOpResult(treeName, op), treeName); 792 } 793 794 for (int i = 0; i != numTreeArgs; ++i) { 795 if (auto treeArg = tree.getArgAsNestedDag(i)) { 796 // This DAG node argument is a DAG node itself. Go inside recursively. 797 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 798 continue; 799 } 800 801 if (isSrcPattern) { 802 // We can only bind symbols to op arguments in source pattern. Those 803 // symbols are referenced in result patterns. 804 auto treeArgName = tree.getArgName(i); 805 // `$_` is a special symbol meaning ignore the current argument. 806 if (!treeArgName.empty() && treeArgName != "_") { 807 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 808 << treeArgName << '\n'); 809 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i), 810 treeArgName); 811 } 812 } 813 } 814 return; 815 } 816 817 if (!treeName.empty()) { 818 PrintFatalError( 819 &def, formatv("binding symbol '{0}' to non-operation/native code call " 820 "unsupported right now", 821 treeName)); 822 } 823 return; 824 } 825