1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Pattern.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/Debug.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 #define DEBUG_TYPE "mlir-tblgen-pattern" 23 24 using namespace mlir; 25 26 using llvm::formatv; 27 using mlir::tblgen::Operator; 28 29 //===----------------------------------------------------------------------===// 30 // DagLeaf 31 //===----------------------------------------------------------------------===// 32 33 bool tblgen::DagLeaf::isUnspecified() const { 34 return dyn_cast_or_null<llvm::UnsetInit>(def); 35 } 36 37 bool tblgen::DagLeaf::isOperandMatcher() const { 38 // Operand matchers specify a type constraint. 39 return isSubClassOf("TypeConstraint"); 40 } 41 42 bool tblgen::DagLeaf::isAttrMatcher() const { 43 // Attribute matchers specify an attribute constraint. 44 return isSubClassOf("AttrConstraint"); 45 } 46 47 bool tblgen::DagLeaf::isNativeCodeCall() const { 48 return isSubClassOf("NativeCodeCall"); 49 } 50 51 bool tblgen::DagLeaf::isConstantAttr() const { 52 return isSubClassOf("ConstantAttr"); 53 } 54 55 bool tblgen::DagLeaf::isEnumAttrCase() const { 56 return isSubClassOf("EnumAttrCaseInfo"); 57 } 58 59 bool tblgen::DagLeaf::isStringAttr() const { 60 return isa<llvm::StringInit>(def) || isa<llvm::CodeInit>(def); 61 } 62 63 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { 64 assert((isOperandMatcher() || isAttrMatcher()) && 65 "the DAG leaf must be operand or attribute"); 66 return Constraint(cast<llvm::DefInit>(def)->getDef()); 67 } 68 69 tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { 70 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 71 return ConstantAttr(cast<llvm::DefInit>(def)); 72 } 73 74 tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const { 75 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 76 return EnumAttrCase(cast<llvm::DefInit>(def)); 77 } 78 79 std::string tblgen::DagLeaf::getConditionTemplate() const { 80 return getAsConstraint().getConditionTemplate(); 81 } 82 83 llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { 84 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 85 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 86 } 87 88 std::string tblgen::DagLeaf::getStringAttr() const { 89 assert(isStringAttr() && "the DAG leaf must be string attribute"); 90 return def->getAsUnquotedString(); 91 } 92 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { 93 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 94 return defInit->getDef()->isSubClassOf(superclass); 95 return false; 96 } 97 98 void tblgen::DagLeaf::print(raw_ostream &os) const { 99 if (def) 100 def->print(os); 101 } 102 103 //===----------------------------------------------------------------------===// 104 // DagNode 105 //===----------------------------------------------------------------------===// 106 107 bool tblgen::DagNode::isNativeCodeCall() const { 108 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 109 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 110 return false; 111 } 112 113 bool tblgen::DagNode::isOperation() const { 114 return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); 115 } 116 117 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { 118 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 119 return cast<llvm::DefInit>(node->getOperator()) 120 ->getDef() 121 ->getValueAsString("expression"); 122 } 123 124 llvm::StringRef tblgen::DagNode::getSymbol() const { 125 return node->getNameStr(); 126 } 127 128 Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { 129 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 130 auto it = mapper->find(opDef); 131 if (it != mapper->end()) 132 return *it->second; 133 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 134 .first->second; 135 } 136 137 int tblgen::DagNode::getNumOps() const { 138 int count = isReplaceWithValue() ? 0 : 1; 139 for (int i = 0, e = getNumArgs(); i != e; ++i) { 140 if (auto child = getArgAsNestedDag(i)) 141 count += child.getNumOps(); 142 } 143 return count; 144 } 145 146 int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } 147 148 bool tblgen::DagNode::isNestedDagArg(unsigned index) const { 149 return isa<llvm::DagInit>(node->getArg(index)); 150 } 151 152 tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { 153 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 154 } 155 156 tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { 157 assert(!isNestedDagArg(index)); 158 return DagLeaf(node->getArg(index)); 159 } 160 161 StringRef tblgen::DagNode::getArgName(unsigned index) const { 162 return node->getArgNameStr(index); 163 } 164 165 bool tblgen::DagNode::isReplaceWithValue() const { 166 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 167 return dagOpDef->getName() == "replaceWithValue"; 168 } 169 170 bool tblgen::DagNode::isLocationDirective() const { 171 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 172 return dagOpDef->getName() == "location"; 173 } 174 175 void tblgen::DagNode::print(raw_ostream &os) const { 176 if (node) 177 node->print(os); 178 } 179 180 //===----------------------------------------------------------------------===// 181 // SymbolInfoMap 182 //===----------------------------------------------------------------------===// 183 184 StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, 185 int *index) { 186 StringRef name, indexStr; 187 int idx = -1; 188 std::tie(name, indexStr) = symbol.rsplit("__"); 189 190 if (indexStr.consumeInteger(10, idx)) { 191 // The second part is not an index; we return the whole symbol as-is. 192 return symbol; 193 } 194 if (index) { 195 *index = idx; 196 } 197 return name; 198 } 199 200 tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, 201 SymbolInfo::Kind kind, 202 Optional<int> index) 203 : op(op), kind(kind), argIndex(index) {} 204 205 int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 206 switch (kind) { 207 case Kind::Attr: 208 case Kind::Operand: 209 case Kind::Value: 210 return 1; 211 case Kind::Result: 212 return op->getNumResults(); 213 } 214 llvm_unreachable("unknown kind"); 215 } 216 217 std::string 218 tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 219 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 220 switch (kind) { 221 case Kind::Attr: { 222 auto type = 223 op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); 224 return std::string(formatv("{0} {1};\n", type, name)); 225 } 226 case Kind::Operand: { 227 // Use operand range for captured operands (to support potential variadic 228 // operands). 229 return std::string( 230 formatv("Operation::operand_range {0}(op0->getOperands());\n", name)); 231 } 232 case Kind::Value: { 233 return std::string(formatv("ArrayRef<Value> {0};\n", name)); 234 } 235 case Kind::Result: { 236 // Use the op itself for captured results. 237 return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name)); 238 } 239 } 240 llvm_unreachable("unknown kind"); 241 } 242 243 std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 244 StringRef name, int index, const char *fmt, const char *separator) const { 245 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 246 switch (kind) { 247 case Kind::Attr: { 248 assert(index < 0); 249 auto repl = formatv(fmt, name); 250 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 251 return std::string(repl); 252 } 253 case Kind::Operand: { 254 assert(index < 0); 255 auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>(); 256 // If this operand is variadic, then return a range. Otherwise, return the 257 // value itself. 258 if (operand->isVariableLength()) { 259 auto repl = formatv(fmt, name); 260 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 261 return std::string(repl); 262 } 263 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 264 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 265 return std::string(repl); 266 } 267 case Kind::Result: { 268 // If `index` is greater than zero, then we are referencing a specific 269 // result of a multi-result op. The result can still be variadic. 270 if (index >= 0) { 271 std::string v = 272 std::string(formatv("{0}.getODSResults({1})", name, index)); 273 if (!op->getResult(index).isVariadic()) 274 v = std::string(formatv("(*{0}.begin())", v)); 275 auto repl = formatv(fmt, v); 276 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 277 return std::string(repl); 278 } 279 280 // If this op has no result at all but still we bind a symbol to it, it 281 // means we want to capture the op itself. 282 if (op->getNumResults() == 0) { 283 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 284 return std::string(name); 285 } 286 287 // We are referencing all results of the multi-result op. A specific result 288 // can either be a value or a range. Then join them with `separator`. 289 SmallVector<std::string, 4> values; 290 values.reserve(op->getNumResults()); 291 292 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 293 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 294 if (!op->getResult(i).isVariadic()) { 295 v = std::string(formatv("(*{0}.begin())", v)); 296 } 297 values.push_back(std::string(formatv(fmt, v))); 298 } 299 auto repl = llvm::join(values, separator); 300 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 301 return repl; 302 } 303 case Kind::Value: { 304 assert(index < 0); 305 assert(op == nullptr); 306 auto repl = formatv(fmt, name); 307 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 308 return std::string(repl); 309 } 310 } 311 llvm_unreachable("unknown kind"); 312 } 313 314 std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse( 315 StringRef name, int index, const char *fmt, const char *separator) const { 316 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 317 switch (kind) { 318 case Kind::Attr: 319 case Kind::Operand: { 320 assert(index < 0 && "only allowed for symbol bound to result"); 321 auto repl = formatv(fmt, name); 322 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 323 return std::string(repl); 324 } 325 case Kind::Result: { 326 if (index >= 0) { 327 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 328 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 329 return std::string(repl); 330 } 331 332 // We are referencing all results of the multi-result op. Each result should 333 // have a value range, and then join them with `separator`. 334 SmallVector<std::string, 4> values; 335 values.reserve(op->getNumResults()); 336 337 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 338 values.push_back(std::string( 339 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 340 } 341 auto repl = llvm::join(values, separator); 342 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 343 return repl; 344 } 345 case Kind::Value: { 346 assert(index < 0 && "only allowed for symbol bound to result"); 347 assert(op == nullptr); 348 auto repl = formatv(fmt, formatv("{{{0}}", name)); 349 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 350 return std::string(repl); 351 } 352 } 353 llvm_unreachable("unknown kind"); 354 } 355 356 bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, 357 int argIndex) { 358 StringRef name = getValuePackName(symbol); 359 if (name != symbol) { 360 auto error = formatv( 361 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 362 PrintFatalError(loc, error); 363 } 364 365 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 366 ? SymbolInfo::getAttr(&op, argIndex) 367 : SymbolInfo::getOperand(&op, argIndex); 368 369 return symbolInfoMap.insert({symbol, symInfo}).second; 370 } 371 372 bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 373 StringRef name = getValuePackName(symbol); 374 return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; 375 } 376 377 bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { 378 return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; 379 } 380 381 bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { 382 return find(symbol) != symbolInfoMap.end(); 383 } 384 385 tblgen::SymbolInfoMap::const_iterator 386 tblgen::SymbolInfoMap::find(StringRef key) const { 387 StringRef name = getValuePackName(key); 388 return symbolInfoMap.find(name); 389 } 390 391 int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 392 StringRef name = getValuePackName(symbol); 393 if (name != symbol) { 394 // If there is a trailing index inside symbol, it references just one 395 // static value. 396 return 1; 397 } 398 // Otherwise, find how many it represents by querying the symbol's info. 399 return find(name)->getValue().getStaticValueCount(); 400 } 401 402 std::string 403 tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt, 404 const char *separator) const { 405 int index = -1; 406 StringRef name = getValuePackName(symbol, &index); 407 408 auto it = symbolInfoMap.find(name); 409 if (it == symbolInfoMap.end()) { 410 auto error = formatv("referencing unbound symbol '{0}'", symbol); 411 PrintFatalError(loc, error); 412 } 413 414 return it->getValue().getValueAndRangeUse(name, index, fmt, separator); 415 } 416 417 std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol, 418 const char *fmt, 419 const char *separator) const { 420 int index = -1; 421 StringRef name = getValuePackName(symbol, &index); 422 423 auto it = symbolInfoMap.find(name); 424 if (it == symbolInfoMap.end()) { 425 auto error = formatv("referencing unbound symbol '{0}'", symbol); 426 PrintFatalError(loc, error); 427 } 428 429 return it->getValue().getAllRangeUse(name, index, fmt, separator); 430 } 431 432 //===----------------------------------------------------------------------===// 433 // Pattern 434 //==----------------------------------------------------------------------===// 435 436 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 437 : def(*def), recordOpMap(mapper) {} 438 439 tblgen::DagNode tblgen::Pattern::getSourcePattern() const { 440 return tblgen::DagNode(def.getValueAsDag("sourcePattern")); 441 } 442 443 int tblgen::Pattern::getNumResultPatterns() const { 444 auto *results = def.getValueAsListInit("resultPatterns"); 445 return results->size(); 446 } 447 448 tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { 449 auto *results = def.getValueAsListInit("resultPatterns"); 450 return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); 451 } 452 453 void tblgen::Pattern::collectSourcePatternBoundSymbols( 454 tblgen::SymbolInfoMap &infoMap) { 455 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 456 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 457 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 458 } 459 460 void tblgen::Pattern::collectResultPatternBoundSymbols( 461 tblgen::SymbolInfoMap &infoMap) { 462 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 463 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 464 auto pattern = getResultPattern(i); 465 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 466 } 467 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 468 } 469 470 const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { 471 return getSourcePattern().getDialectOp(recordOpMap); 472 } 473 474 tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { 475 return node.getDialectOp(recordOpMap); 476 } 477 478 std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const { 479 auto *listInit = def.getValueAsListInit("constraints"); 480 std::vector<tblgen::AppliedConstraint> ret; 481 ret.reserve(listInit->size()); 482 483 for (auto it : *listInit) { 484 auto *dagInit = dyn_cast<llvm::DagInit>(it); 485 if (!dagInit) 486 PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " 487 "constraints should be DAG nodes"); 488 489 std::vector<std::string> entities; 490 entities.reserve(dagInit->arg_size()); 491 for (auto *argName : dagInit->getArgNames()) { 492 if (!argName) { 493 PrintFatalError( 494 def.getLoc(), 495 "operands to additional constraints can only be symbol references"); 496 } 497 entities.push_back(std::string(argName->getValue())); 498 } 499 500 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 501 dagInit->getNameStr(), std::move(entities)); 502 } 503 return ret; 504 } 505 506 int tblgen::Pattern::getBenefit() const { 507 // The initial benefit value is a heuristic with number of ops in the source 508 // pattern. 509 int initBenefit = getSourcePattern().getNumOps(); 510 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 511 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 512 PrintFatalError(def.getLoc(), 513 "The 'addBenefit' takes and only takes one integer value"); 514 } 515 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 516 } 517 518 std::vector<tblgen::Pattern::IdentifierLine> 519 tblgen::Pattern::getLocation() const { 520 std::vector<std::pair<StringRef, unsigned>> result; 521 result.reserve(def.getLoc().size()); 522 for (auto loc : def.getLoc()) { 523 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 524 assert(buf && "invalid source location"); 525 result.emplace_back( 526 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 527 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 528 } 529 return result; 530 } 531 532 void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 533 bool isSrcPattern) { 534 auto treeName = tree.getSymbol(); 535 if (!tree.isOperation()) { 536 if (!treeName.empty()) { 537 PrintFatalError( 538 def.getLoc(), 539 formatv("binding symbol '{0}' to non-operation unsupported right now", 540 treeName)); 541 } 542 return; 543 } 544 545 auto &op = getDialectOp(tree); 546 auto numOpArgs = op.getNumArgs(); 547 auto numTreeArgs = tree.getNumArgs(); 548 549 // The pattern might have the last argument specifying the location. 550 bool hasLocDirective = false; 551 if (numTreeArgs != 0) { 552 if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) 553 hasLocDirective = lastArg.isLocationDirective(); 554 } 555 556 if (numOpArgs != numTreeArgs - hasLocDirective) { 557 auto err = formatv("op '{0}' argument number mismatch: " 558 "{1} in pattern vs. {2} in definition", 559 op.getOperationName(), numTreeArgs, numOpArgs); 560 PrintFatalError(def.getLoc(), err); 561 } 562 563 // The name attached to the DAG node's operator is for representing the 564 // results generated from this op. It should be remembered as bound results. 565 if (!treeName.empty()) { 566 LLVM_DEBUG(llvm::dbgs() 567 << "found symbol bound to op result: " << treeName << '\n'); 568 if (!infoMap.bindOpResult(treeName, op)) 569 PrintFatalError(def.getLoc(), 570 formatv("symbol '{0}' bound more than once", treeName)); 571 } 572 573 for (int i = 0; i != numTreeArgs; ++i) { 574 if (auto treeArg = tree.getArgAsNestedDag(i)) { 575 // This DAG node argument is a DAG node itself. Go inside recursively. 576 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 577 } else if (isSrcPattern) { 578 // We can only bind symbols to op arguments in source pattern. Those 579 // symbols are referenced in result patterns. 580 auto treeArgName = tree.getArgName(i); 581 // `$_` is a special symbol meaning ignore the current argument. 582 if (!treeArgName.empty() && treeArgName != "_") { 583 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 584 << treeArgName << '\n'); 585 if (!infoMap.bindOpArgument(treeArgName, op, i)) { 586 auto err = formatv("symbol '{0}' bound more than once", treeArgName); 587 PrintFatalError(def.getLoc(), err); 588 } 589 } 590 } 591 } 592 } 593