1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Pattern.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/Debug.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 #define DEBUG_TYPE "mlir-tblgen-pattern" 23 24 using namespace mlir; 25 using namespace tblgen; 26 27 using llvm::formatv; 28 29 //===----------------------------------------------------------------------===// 30 // DagLeaf 31 //===----------------------------------------------------------------------===// 32 33 bool DagLeaf::isUnspecified() const { 34 return dyn_cast_or_null<llvm::UnsetInit>(def); 35 } 36 37 bool DagLeaf::isOperandMatcher() const { 38 // Operand matchers specify a type constraint. 39 return isSubClassOf("TypeConstraint"); 40 } 41 42 bool DagLeaf::isAttrMatcher() const { 43 // Attribute matchers specify an attribute constraint. 44 return isSubClassOf("AttrConstraint"); 45 } 46 47 bool DagLeaf::isNativeCodeCall() const { 48 return isSubClassOf("NativeCodeCall"); 49 } 50 51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } 52 53 bool DagLeaf::isEnumAttrCase() const { 54 return isSubClassOf("EnumAttrCaseInfo"); 55 } 56 57 bool DagLeaf::isStringAttr() const { 58 return isa<llvm::StringInit>(def); 59 } 60 61 Constraint DagLeaf::getAsConstraint() const { 62 assert((isOperandMatcher() || isAttrMatcher()) && 63 "the DAG leaf must be operand or attribute"); 64 return Constraint(cast<llvm::DefInit>(def)->getDef()); 65 } 66 67 ConstantAttr DagLeaf::getAsConstantAttr() const { 68 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 69 return ConstantAttr(cast<llvm::DefInit>(def)); 70 } 71 72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const { 73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 74 return EnumAttrCase(cast<llvm::DefInit>(def)); 75 } 76 77 std::string DagLeaf::getConditionTemplate() const { 78 return getAsConstraint().getConditionTemplate(); 79 } 80 81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const { 82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 84 } 85 86 int DagLeaf::getNumReturnsOfNativeCode() const { 87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 88 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns"); 89 } 90 91 std::string DagLeaf::getStringAttr() const { 92 assert(isStringAttr() && "the DAG leaf must be string attribute"); 93 return def->getAsUnquotedString(); 94 } 95 bool DagLeaf::isSubClassOf(StringRef superclass) const { 96 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 97 return defInit->getDef()->isSubClassOf(superclass); 98 return false; 99 } 100 101 void DagLeaf::print(raw_ostream &os) const { 102 if (def) 103 def->print(os); 104 } 105 106 //===----------------------------------------------------------------------===// 107 // DagNode 108 //===----------------------------------------------------------------------===// 109 110 bool DagNode::isNativeCodeCall() const { 111 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 112 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 113 return false; 114 } 115 116 bool DagNode::isOperation() const { 117 return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); 118 } 119 120 llvm::StringRef DagNode::getNativeCodeTemplate() const { 121 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 122 return cast<llvm::DefInit>(node->getOperator()) 123 ->getDef() 124 ->getValueAsString("expression"); 125 } 126 127 int DagNode::getNumReturnsOfNativeCode() const { 128 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 129 return cast<llvm::DefInit>(node->getOperator()) 130 ->getDef() 131 ->getValueAsInt("numReturns"); 132 } 133 134 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } 135 136 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { 137 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 138 auto it = mapper->find(opDef); 139 if (it != mapper->end()) 140 return *it->second; 141 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 142 .first->second; 143 } 144 145 int DagNode::getNumOps() const { 146 int count = isReplaceWithValue() ? 0 : 1; 147 for (int i = 0, e = getNumArgs(); i != e; ++i) { 148 if (auto child = getArgAsNestedDag(i)) 149 count += child.getNumOps(); 150 } 151 return count; 152 } 153 154 int DagNode::getNumArgs() const { return node->getNumArgs(); } 155 156 bool DagNode::isNestedDagArg(unsigned index) const { 157 return isa<llvm::DagInit>(node->getArg(index)); 158 } 159 160 DagNode DagNode::getArgAsNestedDag(unsigned index) const { 161 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 162 } 163 164 DagLeaf DagNode::getArgAsLeaf(unsigned index) const { 165 assert(!isNestedDagArg(index)); 166 return DagLeaf(node->getArg(index)); 167 } 168 169 StringRef DagNode::getArgName(unsigned index) const { 170 return node->getArgNameStr(index); 171 } 172 173 bool DagNode::isReplaceWithValue() const { 174 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 175 return dagOpDef->getName() == "replaceWithValue"; 176 } 177 178 bool DagNode::isLocationDirective() const { 179 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 180 return dagOpDef->getName() == "location"; 181 } 182 183 void DagNode::print(raw_ostream &os) const { 184 if (node) 185 node->print(os); 186 } 187 188 //===----------------------------------------------------------------------===// 189 // SymbolInfoMap 190 //===----------------------------------------------------------------------===// 191 192 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { 193 StringRef name, indexStr; 194 int idx = -1; 195 std::tie(name, indexStr) = symbol.rsplit("__"); 196 197 if (indexStr.consumeInteger(10, idx)) { 198 // The second part is not an index; we return the whole symbol as-is. 199 return symbol; 200 } 201 if (index) { 202 *index = idx; 203 } 204 return name; 205 } 206 207 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, 208 Optional<DagAndConstant> dagAndConstant) 209 : op(op), kind(kind), dagAndConstant(dagAndConstant) {} 210 211 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 212 switch (kind) { 213 case Kind::Attr: 214 case Kind::Operand: 215 case Kind::Value: 216 return 1; 217 case Kind::Result: 218 return op->getNumResults(); 219 case Kind::MultipleValues: 220 return getSize(); 221 } 222 llvm_unreachable("unknown kind"); 223 } 224 225 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { 226 return alternativeName.hasValue() ? alternativeName.getValue() : name.str(); 227 } 228 229 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 230 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 231 switch (kind) { 232 case Kind::Attr: { 233 if (op) { 234 auto type = op->getArg(getArgIndex()) 235 .get<NamedAttribute *>() 236 ->attr.getStorageType(); 237 return std::string(formatv("{0} {1};\n", type, name)); 238 } 239 // TODO(suderman): Use a more exact type when available. 240 return std::string(formatv("Attribute {0};\n", name)); 241 } 242 case Kind::Operand: { 243 // Use operand range for captured operands (to support potential variadic 244 // operands). 245 return std::string( 246 formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n", 247 getVarName(name))); 248 } 249 case Kind::Value: { 250 return std::string(formatv("::mlir::Value {0};\n", name)); 251 } 252 case Kind::MultipleValues: { 253 // This is for the variable used in the source pattern. Each named value in 254 // source pattern will only be bound to a Value. The others in the result 255 // pattern may be associated with multiple Values as we will use `auto` to 256 // do the type inference. 257 return std::string(formatv( 258 "::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name)); 259 } 260 case Kind::Result: { 261 // Use the op itself for captured results. 262 return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name)); 263 } 264 } 265 llvm_unreachable("unknown kind"); 266 } 267 268 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 269 StringRef name, int index, const char *fmt, const char *separator) const { 270 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 271 switch (kind) { 272 case Kind::Attr: { 273 assert(index < 0); 274 auto repl = formatv(fmt, name); 275 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 276 return std::string(repl); 277 } 278 case Kind::Operand: { 279 assert(index < 0); 280 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>(); 281 // If this operand is variadic, then return a range. Otherwise, return the 282 // value itself. 283 if (operand->isVariableLength()) { 284 auto repl = formatv(fmt, name); 285 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 286 return std::string(repl); 287 } 288 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 289 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 290 return std::string(repl); 291 } 292 case Kind::Result: { 293 // If `index` is greater than zero, then we are referencing a specific 294 // result of a multi-result op. The result can still be variadic. 295 if (index >= 0) { 296 std::string v = 297 std::string(formatv("{0}.getODSResults({1})", name, index)); 298 if (!op->getResult(index).isVariadic()) 299 v = std::string(formatv("(*{0}.begin())", v)); 300 auto repl = formatv(fmt, v); 301 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 302 return std::string(repl); 303 } 304 305 // If this op has no result at all but still we bind a symbol to it, it 306 // means we want to capture the op itself. 307 if (op->getNumResults() == 0) { 308 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 309 return std::string(name); 310 } 311 312 // We are referencing all results of the multi-result op. A specific result 313 // can either be a value or a range. Then join them with `separator`. 314 SmallVector<std::string, 4> values; 315 values.reserve(op->getNumResults()); 316 317 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 318 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 319 if (!op->getResult(i).isVariadic()) { 320 v = std::string(formatv("(*{0}.begin())", v)); 321 } 322 values.push_back(std::string(formatv(fmt, v))); 323 } 324 auto repl = llvm::join(values, separator); 325 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 326 return repl; 327 } 328 case Kind::Value: { 329 assert(index < 0); 330 assert(op == nullptr); 331 auto repl = formatv(fmt, name); 332 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 333 return std::string(repl); 334 } 335 case Kind::MultipleValues: { 336 assert(op == nullptr); 337 assert(index < getSize()); 338 if (index >= 0) { 339 std::string repl = 340 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 341 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 342 return repl; 343 } 344 // If it doesn't specify certain element, unpack them all. 345 auto repl = 346 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 347 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 348 return std::string(repl); 349 } 350 } 351 llvm_unreachable("unknown kind"); 352 } 353 354 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( 355 StringRef name, int index, const char *fmt, const char *separator) const { 356 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 357 switch (kind) { 358 case Kind::Attr: 359 case Kind::Operand: { 360 assert(index < 0 && "only allowed for symbol bound to result"); 361 auto repl = formatv(fmt, name); 362 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 363 return std::string(repl); 364 } 365 case Kind::Result: { 366 if (index >= 0) { 367 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 368 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 369 return std::string(repl); 370 } 371 372 // We are referencing all results of the multi-result op. Each result should 373 // have a value range, and then join them with `separator`. 374 SmallVector<std::string, 4> values; 375 values.reserve(op->getNumResults()); 376 377 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 378 values.push_back(std::string( 379 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 380 } 381 auto repl = llvm::join(values, separator); 382 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 383 return repl; 384 } 385 case Kind::Value: { 386 assert(index < 0 && "only allowed for symbol bound to result"); 387 assert(op == nullptr); 388 auto repl = formatv(fmt, formatv("{{{0}}", name)); 389 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 390 return std::string(repl); 391 } 392 case Kind::MultipleValues: { 393 assert(op == nullptr); 394 assert(index < getSize()); 395 if (index >= 0) { 396 std::string repl = 397 formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); 398 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 399 return repl; 400 } 401 auto repl = 402 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); 403 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); 404 return std::string(repl); 405 } 406 } 407 llvm_unreachable("unknown kind"); 408 } 409 410 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, 411 const Operator &op, int argIndex) { 412 StringRef name = getValuePackName(symbol); 413 if (name != symbol) { 414 auto error = formatv( 415 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 416 PrintFatalError(loc, error); 417 } 418 419 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 420 ? SymbolInfo::getAttr(&op, argIndex) 421 : SymbolInfo::getOperand(node, &op, argIndex); 422 423 std::string key = symbol.str(); 424 if (symbolInfoMap.count(key)) { 425 // Only non unique name for the operand is supported. 426 if (symInfo.kind != SymbolInfo::Kind::Operand) { 427 return false; 428 } 429 430 // Cannot add new operand if there is already non operand with the same 431 // name. 432 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) { 433 return false; 434 } 435 } 436 437 symbolInfoMap.emplace(key, symInfo); 438 return true; 439 } 440 441 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 442 std::string name = getValuePackName(symbol).str(); 443 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); 444 445 return symbolInfoMap.count(inserted->first) == 1; 446 } 447 448 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { 449 std::string name = getValuePackName(symbol).str(); 450 if (numValues > 1) 451 return bindMultipleValues(name, numValues); 452 return bindValue(name); 453 } 454 455 bool SymbolInfoMap::bindValue(StringRef symbol) { 456 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); 457 return symbolInfoMap.count(inserted->first) == 1; 458 } 459 460 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { 461 std::string name = getValuePackName(symbol).str(); 462 auto inserted = 463 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues)); 464 return symbolInfoMap.count(inserted->first) == 1; 465 } 466 467 bool SymbolInfoMap::bindAttr(StringRef symbol) { 468 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr()); 469 return symbolInfoMap.count(inserted->first) == 1; 470 } 471 472 bool SymbolInfoMap::contains(StringRef symbol) const { 473 return find(symbol) != symbolInfoMap.end(); 474 } 475 476 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { 477 std::string name = getValuePackName(key).str(); 478 479 return symbolInfoMap.find(name); 480 } 481 482 SymbolInfoMap::const_iterator 483 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, 484 int argIndex) const { 485 std::string name = getValuePackName(key).str(); 486 auto range = symbolInfoMap.equal_range(name); 487 488 const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex); 489 490 for (auto it = range.first; it != range.second; ++it) 491 if (it->second.dagAndConstant == symbolInfo.dagAndConstant) 492 return it; 493 494 return symbolInfoMap.end(); 495 } 496 497 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator> 498 SymbolInfoMap::getRangeOfEqualElements(StringRef key) { 499 std::string name = getValuePackName(key).str(); 500 501 return symbolInfoMap.equal_range(name); 502 } 503 504 int SymbolInfoMap::count(StringRef key) const { 505 std::string name = getValuePackName(key).str(); 506 return symbolInfoMap.count(name); 507 } 508 509 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 510 StringRef name = getValuePackName(symbol); 511 if (name != symbol) { 512 // If there is a trailing index inside symbol, it references just one 513 // static value. 514 return 1; 515 } 516 // Otherwise, find how many it represents by querying the symbol's info. 517 return find(name)->second.getStaticValueCount(); 518 } 519 520 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, 521 const char *fmt, 522 const char *separator) const { 523 int index = -1; 524 StringRef name = getValuePackName(symbol, &index); 525 526 auto it = symbolInfoMap.find(name.str()); 527 if (it == symbolInfoMap.end()) { 528 auto error = formatv("referencing unbound symbol '{0}'", symbol); 529 PrintFatalError(loc, error); 530 } 531 532 return it->second.getValueAndRangeUse(name, index, fmt, separator); 533 } 534 535 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, 536 const char *separator) const { 537 int index = -1; 538 StringRef name = getValuePackName(symbol, &index); 539 540 auto it = symbolInfoMap.find(name.str()); 541 if (it == symbolInfoMap.end()) { 542 auto error = formatv("referencing unbound symbol '{0}'", symbol); 543 PrintFatalError(loc, error); 544 } 545 546 return it->second.getAllRangeUse(name, index, fmt, separator); 547 } 548 549 void SymbolInfoMap::assignUniqueAlternativeNames() { 550 llvm::StringSet<> usedNames; 551 552 for (auto symbolInfoIt = symbolInfoMap.begin(); 553 symbolInfoIt != symbolInfoMap.end();) { 554 auto range = symbolInfoMap.equal_range(symbolInfoIt->first); 555 auto startRange = range.first; 556 auto endRange = range.second; 557 558 auto operandName = symbolInfoIt->first; 559 int startSearchIndex = 0; 560 for (++startRange; startRange != endRange; ++startRange) { 561 // Current operand name is not unique, find a unique one 562 // and set the alternative name. 563 for (int i = startSearchIndex;; ++i) { 564 std::string alternativeName = operandName + std::to_string(i); 565 if (!usedNames.contains(alternativeName) && 566 symbolInfoMap.count(alternativeName) == 0) { 567 usedNames.insert(alternativeName); 568 startRange->second.alternativeName = alternativeName; 569 startSearchIndex = i + 1; 570 571 break; 572 } 573 } 574 } 575 576 symbolInfoIt = endRange; 577 } 578 } 579 580 //===----------------------------------------------------------------------===// 581 // Pattern 582 //==----------------------------------------------------------------------===// 583 584 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 585 : def(*def), recordOpMap(mapper) {} 586 587 DagNode Pattern::getSourcePattern() const { 588 return DagNode(def.getValueAsDag("sourcePattern")); 589 } 590 591 int Pattern::getNumResultPatterns() const { 592 auto *results = def.getValueAsListInit("resultPatterns"); 593 return results->size(); 594 } 595 596 DagNode Pattern::getResultPattern(unsigned index) const { 597 auto *results = def.getValueAsListInit("resultPatterns"); 598 return DagNode(cast<llvm::DagInit>(results->getElement(index))); 599 } 600 601 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { 602 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 603 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 604 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 605 606 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n"); 607 infoMap.assignUniqueAlternativeNames(); 608 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n"); 609 } 610 611 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { 612 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 613 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 614 auto pattern = getResultPattern(i); 615 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 616 } 617 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 618 } 619 620 const Operator &Pattern::getSourceRootOp() { 621 return getSourcePattern().getDialectOp(recordOpMap); 622 } 623 624 Operator &Pattern::getDialectOp(DagNode node) { 625 return node.getDialectOp(recordOpMap); 626 } 627 628 std::vector<AppliedConstraint> Pattern::getConstraints() const { 629 auto *listInit = def.getValueAsListInit("constraints"); 630 std::vector<AppliedConstraint> ret; 631 ret.reserve(listInit->size()); 632 633 for (auto it : *listInit) { 634 auto *dagInit = dyn_cast<llvm::DagInit>(it); 635 if (!dagInit) 636 PrintFatalError(&def, "all elements in Pattern multi-entity " 637 "constraints should be DAG nodes"); 638 639 std::vector<std::string> entities; 640 entities.reserve(dagInit->arg_size()); 641 for (auto *argName : dagInit->getArgNames()) { 642 if (!argName) { 643 PrintFatalError( 644 &def, 645 "operands to additional constraints can only be symbol references"); 646 } 647 entities.push_back(std::string(argName->getValue())); 648 } 649 650 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 651 dagInit->getNameStr(), std::move(entities)); 652 } 653 return ret; 654 } 655 656 int Pattern::getBenefit() const { 657 // The initial benefit value is a heuristic with number of ops in the source 658 // pattern. 659 int initBenefit = getSourcePattern().getNumOps(); 660 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 661 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 662 PrintFatalError(&def, 663 "The 'addBenefit' takes and only takes one integer value"); 664 } 665 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 666 } 667 668 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { 669 std::vector<std::pair<StringRef, unsigned>> result; 670 result.reserve(def.getLoc().size()); 671 for (auto loc : def.getLoc()) { 672 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 673 assert(buf && "invalid source location"); 674 result.emplace_back( 675 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 676 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 677 } 678 return result; 679 } 680 681 void Pattern::verifyBind(bool result, StringRef symbolName) { 682 if (!result) { 683 auto err = formatv("symbol '{0}' bound more than once", symbolName); 684 PrintFatalError(&def, err); 685 } 686 } 687 688 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 689 bool isSrcPattern) { 690 auto treeName = tree.getSymbol(); 691 auto numTreeArgs = tree.getNumArgs(); 692 693 if (tree.isNativeCodeCall()) { 694 if (!treeName.empty()) { 695 if (!isSrcPattern) { 696 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " 697 << treeName << '\n'); 698 verifyBind( 699 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()), 700 treeName); 701 } else { 702 PrintFatalError(&def, 703 formatv("binding symbol '{0}' to NativecodeCall in " 704 "MatchPattern is not supported", 705 treeName)); 706 } 707 } 708 709 for (int i = 0; i != numTreeArgs; ++i) { 710 if (auto treeArg = tree.getArgAsNestedDag(i)) { 711 // This DAG node argument is a DAG node itself. Go inside recursively. 712 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 713 continue; 714 } 715 716 if (!isSrcPattern) 717 continue; 718 719 // We can only bind symbols to arguments in source pattern. Those 720 // symbols are referenced in result patterns. 721 auto treeArgName = tree.getArgName(i); 722 723 // `$_` is a special symbol meaning ignore the current argument. 724 if (!treeArgName.empty() && treeArgName != "_") { 725 DagLeaf leaf = tree.getArgAsLeaf(i); 726 727 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), 728 if (leaf.isUnspecified()) { 729 // This is case of $c, a Value without any constraints. 730 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 731 } else { 732 auto constraint = leaf.getAsConstraint(); 733 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || 734 leaf.isConstantAttr() || 735 constraint.getKind() == Constraint::Kind::CK_Attr; 736 737 if (isAttr) { 738 // This is case of $a, a binding to a certain attribute. 739 verifyBind(infoMap.bindAttr(treeArgName), treeArgName); 740 continue; 741 } 742 743 // This is case of $b, a binding to a certain type. 744 verifyBind(infoMap.bindValue(treeArgName), treeArgName); 745 } 746 } 747 } 748 749 return; 750 } 751 752 if (tree.isOperation()) { 753 auto &op = getDialectOp(tree); 754 auto numOpArgs = op.getNumArgs(); 755 756 // The pattern might have the last argument specifying the location. 757 bool hasLocDirective = false; 758 if (numTreeArgs != 0) { 759 if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) 760 hasLocDirective = lastArg.isLocationDirective(); 761 } 762 763 if (numOpArgs != numTreeArgs - hasLocDirective) { 764 auto err = formatv("op '{0}' argument number mismatch: " 765 "{1} in pattern vs. {2} in definition", 766 op.getOperationName(), numTreeArgs, numOpArgs); 767 PrintFatalError(&def, err); 768 } 769 770 // The name attached to the DAG node's operator is for representing the 771 // results generated from this op. It should be remembered as bound results. 772 if (!treeName.empty()) { 773 LLVM_DEBUG(llvm::dbgs() 774 << "found symbol bound to op result: " << treeName << '\n'); 775 verifyBind(infoMap.bindOpResult(treeName, op), treeName); 776 } 777 778 for (int i = 0; i != numTreeArgs; ++i) { 779 if (auto treeArg = tree.getArgAsNestedDag(i)) { 780 // This DAG node argument is a DAG node itself. Go inside recursively. 781 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 782 continue; 783 } 784 785 if (isSrcPattern) { 786 // We can only bind symbols to op arguments in source pattern. Those 787 // symbols are referenced in result patterns. 788 auto treeArgName = tree.getArgName(i); 789 // `$_` is a special symbol meaning ignore the current argument. 790 if (!treeArgName.empty() && treeArgName != "_") { 791 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 792 << treeArgName << '\n'); 793 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i), 794 treeArgName); 795 } 796 } 797 } 798 return; 799 } 800 801 if (!treeName.empty()) { 802 PrintFatalError( 803 &def, formatv("binding symbol '{0}' to non-operation/native code call " 804 "unsupported right now", 805 treeName)); 806 } 807 return; 808 } 809