1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Pattern.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/Debug.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 #define DEBUG_TYPE "mlir-tblgen-pattern" 23 24 using namespace mlir; 25 using namespace tblgen; 26 27 using llvm::formatv; 28 29 //===----------------------------------------------------------------------===// 30 // DagLeaf 31 //===----------------------------------------------------------------------===// 32 33 bool DagLeaf::isUnspecified() const { 34 return dyn_cast_or_null<llvm::UnsetInit>(def); 35 } 36 37 bool DagLeaf::isOperandMatcher() const { 38 // Operand matchers specify a type constraint. 39 return isSubClassOf("TypeConstraint"); 40 } 41 42 bool DagLeaf::isAttrMatcher() const { 43 // Attribute matchers specify an attribute constraint. 44 return isSubClassOf("AttrConstraint"); 45 } 46 47 bool DagLeaf::isNativeCodeCall() const { 48 return isSubClassOf("NativeCodeCall"); 49 } 50 51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } 52 53 bool DagLeaf::isEnumAttrCase() const { 54 return isSubClassOf("EnumAttrCaseInfo"); 55 } 56 57 bool DagLeaf::isStringAttr() const { 58 return isa<llvm::StringInit>(def); 59 } 60 61 Constraint DagLeaf::getAsConstraint() const { 62 assert((isOperandMatcher() || isAttrMatcher()) && 63 "the DAG leaf must be operand or attribute"); 64 return Constraint(cast<llvm::DefInit>(def)->getDef()); 65 } 66 67 ConstantAttr DagLeaf::getAsConstantAttr() const { 68 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 69 return ConstantAttr(cast<llvm::DefInit>(def)); 70 } 71 72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const { 73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 74 return EnumAttrCase(cast<llvm::DefInit>(def)); 75 } 76 77 std::string DagLeaf::getConditionTemplate() const { 78 return getAsConstraint().getConditionTemplate(); 79 } 80 81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const { 82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 84 } 85 86 std::string DagLeaf::getStringAttr() const { 87 assert(isStringAttr() && "the DAG leaf must be string attribute"); 88 return def->getAsUnquotedString(); 89 } 90 bool DagLeaf::isSubClassOf(StringRef superclass) const { 91 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 92 return defInit->getDef()->isSubClassOf(superclass); 93 return false; 94 } 95 96 void DagLeaf::print(raw_ostream &os) const { 97 if (def) 98 def->print(os); 99 } 100 101 //===----------------------------------------------------------------------===// 102 // DagNode 103 //===----------------------------------------------------------------------===// 104 105 bool DagNode::isNativeCodeCall() const { 106 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 107 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 108 return false; 109 } 110 111 bool DagNode::isOperation() const { 112 return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); 113 } 114 115 llvm::StringRef DagNode::getNativeCodeTemplate() const { 116 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 117 return cast<llvm::DefInit>(node->getOperator()) 118 ->getDef() 119 ->getValueAsString("expression"); 120 } 121 122 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } 123 124 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { 125 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 126 auto it = mapper->find(opDef); 127 if (it != mapper->end()) 128 return *it->second; 129 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 130 .first->second; 131 } 132 133 int DagNode::getNumOps() const { 134 int count = isReplaceWithValue() ? 0 : 1; 135 for (int i = 0, e = getNumArgs(); i != e; ++i) { 136 if (auto child = getArgAsNestedDag(i)) 137 count += child.getNumOps(); 138 } 139 return count; 140 } 141 142 int DagNode::getNumArgs() const { return node->getNumArgs(); } 143 144 bool DagNode::isNestedDagArg(unsigned index) const { 145 return isa<llvm::DagInit>(node->getArg(index)); 146 } 147 148 DagNode DagNode::getArgAsNestedDag(unsigned index) const { 149 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 150 } 151 152 DagLeaf DagNode::getArgAsLeaf(unsigned index) const { 153 assert(!isNestedDagArg(index)); 154 return DagLeaf(node->getArg(index)); 155 } 156 157 StringRef DagNode::getArgName(unsigned index) const { 158 return node->getArgNameStr(index); 159 } 160 161 bool DagNode::isReplaceWithValue() const { 162 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 163 return dagOpDef->getName() == "replaceWithValue"; 164 } 165 166 bool DagNode::isLocationDirective() const { 167 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 168 return dagOpDef->getName() == "location"; 169 } 170 171 void DagNode::print(raw_ostream &os) const { 172 if (node) 173 node->print(os); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // SymbolInfoMap 178 //===----------------------------------------------------------------------===// 179 180 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { 181 StringRef name, indexStr; 182 int idx = -1; 183 std::tie(name, indexStr) = symbol.rsplit("__"); 184 185 if (indexStr.consumeInteger(10, idx)) { 186 // The second part is not an index; we return the whole symbol as-is. 187 return symbol; 188 } 189 if (index) { 190 *index = idx; 191 } 192 return name; 193 } 194 195 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, 196 Optional<DagAndIndex> dagAndIndex) 197 : op(op), kind(kind), dagAndIndex(dagAndIndex) {} 198 199 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 200 switch (kind) { 201 case Kind::Attr: 202 case Kind::Operand: 203 case Kind::Value: 204 return 1; 205 case Kind::Result: 206 return op->getNumResults(); 207 } 208 llvm_unreachable("unknown kind"); 209 } 210 211 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { 212 return alternativeName.hasValue() ? alternativeName.getValue() : name.str(); 213 } 214 215 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 216 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 217 switch (kind) { 218 case Kind::Attr: { 219 if (op) { 220 auto type = op->getArg((*dagAndIndex).second) 221 .get<NamedAttribute *>() 222 ->attr.getStorageType(); 223 return std::string(formatv("{0} {1};\n", type, name)); 224 } 225 // TODO(suderman): Use a more exact type when available. 226 return std::string(formatv("Attribute {0};\n", name)); 227 } 228 case Kind::Operand: { 229 // Use operand range for captured operands (to support potential variadic 230 // operands). 231 return std::string( 232 formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n", 233 getVarName(name))); 234 } 235 case Kind::Value: { 236 return std::string(formatv("::mlir::Value {0};\n", name)); 237 } 238 case Kind::Result: { 239 // Use the op itself for captured results. 240 return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name)); 241 } 242 } 243 llvm_unreachable("unknown kind"); 244 } 245 246 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 247 StringRef name, int index, const char *fmt, const char *separator) const { 248 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 249 switch (kind) { 250 case Kind::Attr: { 251 assert(index < 0); 252 auto repl = formatv(fmt, name); 253 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 254 return std::string(repl); 255 } 256 case Kind::Operand: { 257 assert(index < 0); 258 auto *operand = 259 op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>(); 260 // If this operand is variadic, then return a range. Otherwise, return the 261 // value itself. 262 if (operand->isVariableLength()) { 263 auto repl = formatv(fmt, name); 264 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 265 return std::string(repl); 266 } 267 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 268 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 269 return std::string(repl); 270 } 271 case Kind::Result: { 272 // If `index` is greater than zero, then we are referencing a specific 273 // result of a multi-result op. The result can still be variadic. 274 if (index >= 0) { 275 std::string v = 276 std::string(formatv("{0}.getODSResults({1})", name, index)); 277 if (!op->getResult(index).isVariadic()) 278 v = std::string(formatv("(*{0}.begin())", v)); 279 auto repl = formatv(fmt, v); 280 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 281 return std::string(repl); 282 } 283 284 // If this op has no result at all but still we bind a symbol to it, it 285 // means we want to capture the op itself. 286 if (op->getNumResults() == 0) { 287 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 288 return std::string(name); 289 } 290 291 // We are referencing all results of the multi-result op. A specific result 292 // can either be a value or a range. Then join them with `separator`. 293 SmallVector<std::string, 4> values; 294 values.reserve(op->getNumResults()); 295 296 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 297 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 298 if (!op->getResult(i).isVariadic()) { 299 v = std::string(formatv("(*{0}.begin())", v)); 300 } 301 values.push_back(std::string(formatv(fmt, v))); 302 } 303 auto repl = llvm::join(values, separator); 304 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 305 return repl; 306 } 307 case Kind::Value: { 308 assert(index < 0); 309 assert(op == nullptr); 310 auto repl = formatv(fmt, name); 311 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 312 return std::string(repl); 313 } 314 } 315 llvm_unreachable("unknown kind"); 316 } 317 318 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( 319 StringRef name, int index, const char *fmt, const char *separator) const { 320 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 321 switch (kind) { 322 case Kind::Attr: 323 case Kind::Operand: { 324 assert(index < 0 && "only allowed for symbol bound to result"); 325 auto repl = formatv(fmt, name); 326 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 327 return std::string(repl); 328 } 329 case Kind::Result: { 330 if (index >= 0) { 331 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 332 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 333 return std::string(repl); 334 } 335 336 // We are referencing all results of the multi-result op. Each result should 337 // have a value range, and then join them with `separator`. 338 SmallVector<std::string, 4> values; 339 values.reserve(op->getNumResults()); 340 341 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 342 values.push_back(std::string( 343 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 344 } 345 auto repl = llvm::join(values, separator); 346 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 347 return repl; 348 } 349 case Kind::Value: { 350 assert(index < 0 && "only allowed for symbol bound to result"); 351 assert(op == nullptr); 352 auto repl = formatv(fmt, formatv("{{{0}}", name)); 353 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 354 return std::string(repl); 355 } 356 } 357 llvm_unreachable("unknown kind"); 358 } 359 360 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, 361 const Operator &op, int argIndex) { 362 StringRef name = getValuePackName(symbol); 363 if (name != symbol) { 364 auto error = formatv( 365 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 366 PrintFatalError(loc, error); 367 } 368 369 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 370 ? SymbolInfo::getAttr(&op, argIndex) 371 : SymbolInfo::getOperand(node, &op, argIndex); 372 373 std::string key = symbol.str(); 374 if (symbolInfoMap.count(key)) { 375 // Only non unique name for the operand is supported. 376 if (symInfo.kind != SymbolInfo::Kind::Operand) { 377 return false; 378 } 379 380 // Cannot add new operand if there is already non operand with the same 381 // name. 382 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) { 383 return false; 384 } 385 } 386 387 symbolInfoMap.emplace(key, symInfo); 388 return true; 389 } 390 391 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 392 std::string name = getValuePackName(symbol).str(); 393 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); 394 395 return symbolInfoMap.count(inserted->first) == 1; 396 } 397 398 bool SymbolInfoMap::bindValue(StringRef symbol) { 399 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); 400 return symbolInfoMap.count(inserted->first) == 1; 401 } 402 403 bool SymbolInfoMap::bindAttr(StringRef symbol) { 404 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr()); 405 return symbolInfoMap.count(inserted->first) == 1; 406 } 407 408 bool SymbolInfoMap::contains(StringRef symbol) const { 409 return find(symbol) != symbolInfoMap.end(); 410 } 411 412 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { 413 std::string name = getValuePackName(key).str(); 414 415 return symbolInfoMap.find(name); 416 } 417 418 SymbolInfoMap::const_iterator 419 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, 420 int argIndex) const { 421 std::string name = getValuePackName(key).str(); 422 auto range = symbolInfoMap.equal_range(name); 423 424 const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex); 425 426 for (auto it = range.first; it != range.second; ++it) { 427 if (it->second.dagAndIndex == symbolInfo.dagAndIndex) { 428 return it; 429 } 430 } 431 432 return symbolInfoMap.end(); 433 } 434 435 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> 436 SymbolInfoMap::getRangeOfEqualElements(StringRef key) { 437 std::string name = getValuePackName(key).str(); 438 439 return symbolInfoMap.equal_range(name); 440 } 441 442 int SymbolInfoMap::count(StringRef key) const { 443 std::string name = getValuePackName(key).str(); 444 return symbolInfoMap.count(name); 445 } 446 447 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 448 StringRef name = getValuePackName(symbol); 449 if (name != symbol) { 450 // If there is a trailing index inside symbol, it references just one 451 // static value. 452 return 1; 453 } 454 // Otherwise, find how many it represents by querying the symbol's info. 455 return find(name)->second.getStaticValueCount(); 456 } 457 458 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, 459 const char *fmt, 460 const char *separator) const { 461 int index = -1; 462 StringRef name = getValuePackName(symbol, &index); 463 464 auto it = symbolInfoMap.find(name.str()); 465 if (it == symbolInfoMap.end()) { 466 auto error = formatv("referencing unbound symbol '{0}'", symbol); 467 PrintFatalError(loc, error); 468 } 469 470 return it->second.getValueAndRangeUse(name, index, fmt, separator); 471 } 472 473 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, 474 const char *separator) const { 475 int index = -1; 476 StringRef name = getValuePackName(symbol, &index); 477 478 auto it = symbolInfoMap.find(name.str()); 479 if (it == symbolInfoMap.end()) { 480 auto error = formatv("referencing unbound symbol '{0}'", symbol); 481 PrintFatalError(loc, error); 482 } 483 484 return it->second.getAllRangeUse(name, index, fmt, separator); 485 } 486 487 void SymbolInfoMap::assignUniqueAlternativeNames() { 488 llvm::StringSet<> usedNames; 489 490 for (auto symbolInfoIt = symbolInfoMap.begin(); 491 symbolInfoIt != symbolInfoMap.end();) { 492 auto range = symbolInfoMap.equal_range(symbolInfoIt->first); 493 auto startRange = range.first; 494 auto endRange = range.second; 495 496 auto operandName = symbolInfoIt->first; 497 int startSearchIndex = 0; 498 for (++startRange; startRange != endRange; ++startRange) { 499 // Current operand name is not unique, find a unique one 500 // and set the alternative name. 501 for (int i = startSearchIndex;; ++i) { 502 std::string alternativeName = operandName + std::to_string(i); 503 if (!usedNames.contains(alternativeName) && 504 symbolInfoMap.count(alternativeName) == 0) { 505 usedNames.insert(alternativeName); 506 startRange->second.alternativeName = alternativeName; 507 startSearchIndex = i + 1; 508 509 break; 510 } 511 } 512 } 513 514 symbolInfoIt = endRange; 515 } 516 } 517 518 //===----------------------------------------------------------------------===// 519 // Pattern 520 //==----------------------------------------------------------------------===// 521 522 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 523 : def(*def), recordOpMap(mapper) {} 524 525 DagNode Pattern::getSourcePattern() const { 526 return DagNode(def.getValueAsDag("sourcePattern")); 527 } 528 529 int Pattern::getNumResultPatterns() const { 530 auto *results = def.getValueAsListInit("resultPatterns"); 531 return results->size(); 532 } 533 534 DagNode Pattern::getResultPattern(unsigned index) const { 535 auto *results = def.getValueAsListInit("resultPatterns"); 536 return DagNode(cast<llvm::DagInit>(results->getElement(index))); 537 } 538 539 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { 540 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 541 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 542 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 543 544 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n"); 545 infoMap.assignUniqueAlternativeNames(); 546 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n"); 547 } 548 549 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { 550 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 551 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 552 auto pattern = getResultPattern(i); 553 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 554 } 555 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 556 } 557 558 const Operator &Pattern::getSourceRootOp() { 559 return getSourcePattern().getDialectOp(recordOpMap); 560 } 561 562 Operator &Pattern::getDialectOp(DagNode node) { 563 return node.getDialectOp(recordOpMap); 564 } 565 566 std::vector<AppliedConstraint> Pattern::getConstraints() const { 567 auto *listInit = def.getValueAsListInit("constraints"); 568 std::vector<AppliedConstraint> ret; 569 ret.reserve(listInit->size()); 570 571 for (auto it : *listInit) { 572 auto *dagInit = dyn_cast<llvm::DagInit>(it); 573 if (!dagInit) 574 PrintFatalError(&def, "all elements in Pattern multi-entity " 575 "constraints should be DAG nodes"); 576 577 std::vector<std::string> entities; 578 entities.reserve(dagInit->arg_size()); 579 for (auto *argName : dagInit->getArgNames()) { 580 if (!argName) { 581 PrintFatalError( 582 &def, 583 "operands to additional constraints can only be symbol references"); 584 } 585 entities.push_back(std::string(argName->getValue())); 586 } 587 588 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 589 dagInit->getNameStr(), std::move(entities)); 590 } 591 return ret; 592 } 593 594 int Pattern::getBenefit() const { 595 // The initial benefit value is a heuristic with number of ops in the source 596 // pattern. 597 int initBenefit = getSourcePattern().getNumOps(); 598 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 599 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 600 PrintFatalError(&def, 601 "The 'addBenefit' takes and only takes one integer value"); 602 } 603 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 604 } 605 606 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { 607 std::vector<std::pair<StringRef, unsigned>> result; 608 result.reserve(def.getLoc().size()); 609 for (auto loc : def.getLoc()) { 610 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 611 assert(buf && "invalid source location"); 612 result.emplace_back( 613 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 614 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 615 } 616 return result; 617 } 618 619 void Pattern::verifyBind(bool result, StringRef symbolName) { 620 if (!result) { 621 auto err = formatv("symbol '{0}' bound more than once", symbolName); 622 PrintFatalError(&def, err); 623 } 624 } 625 626 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 627 bool isSrcPattern) { 628 auto treeName = tree.getSymbol(); 629 auto numTreeArgs = tree.getNumArgs(); 630 631 if (tree.isNativeCodeCall()) { 632 if (!treeName.empty()) { 633 if (!isSrcPattern) { 634 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " 635 << treeName << '\n'); 636 verifyBind(infoMap.bindValue(treeName), treeName); 637 } else { 638 PrintFatalError(&def, 639 formatv("binding symbol '{0}' to NativecodeCall in " 640 "MatchPattern is not supported", 641 treeName)); 642 } 643 } 644 645 for (int i = 0; i != numTreeArgs; ++i) { 646 if (auto treeArg = tree.getArgAsNestedDag(i)) { 647 // This DAG node argument is a DAG node itself. Go inside recursively. 648 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 649 continue; 650 } 651 652 if (!isSrcPattern) 653 continue; 654 655 // We can only bind symbols to arguments in source pattern. Those 656 // symbols are referenced in result patterns. 657 auto treeArgName = tree.getArgName(i); 658 659 // `$_` is a special symbol meaning ignore the current argument. 660 if (!treeArgName.empty() && treeArgName != "_") { 661 DagLeaf leaf = tree.getArgAsLeaf(i); 662 663 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), 664 if (leaf.isUnspecified()) { 665 // This is case of $c, a Value without any constraints. 666 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 667 } else { 668 auto constraint = leaf.getAsConstraint(); 669 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || 670 leaf.isConstantAttr() || 671 constraint.getKind() == Constraint::Kind::CK_Attr; 672 673 if (isAttr) { 674 // This is case of $a, a binding to a certain attribute. 675 verifyBind(infoMap.bindAttr(treeArgName), treeArgName); 676 continue; 677 } 678 679 // This is case of $b, a binding to a certain type. 680 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 681 } 682 } 683 } 684 685 return; 686 } 687 688 if (tree.isOperation()) { 689 auto &op = getDialectOp(tree); 690 auto numOpArgs = op.getNumArgs(); 691 692 // The pattern might have the last argument specifying the location. 693 bool hasLocDirective = false; 694 if (numTreeArgs != 0) { 695 if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) 696 hasLocDirective = lastArg.isLocationDirective(); 697 } 698 699 if (numOpArgs != numTreeArgs - hasLocDirective) { 700 auto err = formatv("op '{0}' argument number mismatch: " 701 "{1} in pattern vs. {2} in definition", 702 op.getOperationName(), numTreeArgs, numOpArgs); 703 PrintFatalError(&def, err); 704 } 705 706 // The name attached to the DAG node's operator is for representing the 707 // results generated from this op. It should be remembered as bound results. 708 if (!treeName.empty()) { 709 LLVM_DEBUG(llvm::dbgs() 710 << "found symbol bound to op result: " << treeName << '\n'); 711 verifyBind(infoMap.bindOpResult(treeName, op), treeName); 712 } 713 714 for (int i = 0; i != numTreeArgs; ++i) { 715 if (auto treeArg = tree.getArgAsNestedDag(i)) { 716 // This DAG node argument is a DAG node itself. Go inside recursively. 717 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 718 continue; 719 } 720 721 if (isSrcPattern) { 722 // We can only bind symbols to op arguments in source pattern. Those 723 // symbols are referenced in result patterns. 724 auto treeArgName = tree.getArgName(i); 725 // `$_` is a special symbol meaning ignore the current argument. 726 if (!treeArgName.empty() && treeArgName != "_") { 727 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 728 << treeArgName << '\n'); 729 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i), 730 treeArgName); 731 } 732 } 733 } 734 return; 735 } 736 737 if (!treeName.empty()) { 738 PrintFatalError( 739 &def, formatv("binding symbol '{0}' to non-operation/native code call " 740 "unsupported right now", 741 treeName)); 742 } 743 return; 744 } 745