1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Pattern.h" 15 #include "llvm/ADT/Twine.h" 16 #include "llvm/Support/Debug.h" 17 #include "llvm/Support/FormatVariadic.h" 18 #include "llvm/TableGen/Error.h" 19 #include "llvm/TableGen/Record.h" 20 21 #define DEBUG_TYPE "mlir-tblgen-pattern" 22 23 using namespace mlir; 24 25 using llvm::formatv; 26 using mlir::tblgen::Operator; 27 28 //===----------------------------------------------------------------------===// 29 // DagLeaf 30 //===----------------------------------------------------------------------===// 31 32 bool tblgen::DagLeaf::isUnspecified() const { 33 return dyn_cast_or_null<llvm::UnsetInit>(def); 34 } 35 36 bool tblgen::DagLeaf::isOperandMatcher() const { 37 // Operand matchers specify a type constraint. 38 return isSubClassOf("TypeConstraint"); 39 } 40 41 bool tblgen::DagLeaf::isAttrMatcher() const { 42 // Attribute matchers specify an attribute constraint. 43 return isSubClassOf("AttrConstraint"); 44 } 45 46 bool tblgen::DagLeaf::isNativeCodeCall() const { 47 return isSubClassOf("NativeCodeCall"); 48 } 49 50 bool tblgen::DagLeaf::isConstantAttr() const { 51 return isSubClassOf("ConstantAttr"); 52 } 53 54 bool tblgen::DagLeaf::isEnumAttrCase() const { 55 return isSubClassOf("EnumAttrCaseInfo"); 56 } 57 58 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { 59 assert((isOperandMatcher() || isAttrMatcher()) && 60 "the DAG leaf must be operand or attribute"); 61 return Constraint(cast<llvm::DefInit>(def)->getDef()); 62 } 63 64 tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { 65 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 66 return ConstantAttr(cast<llvm::DefInit>(def)); 67 } 68 69 tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const { 70 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 71 return EnumAttrCase(cast<llvm::DefInit>(def)); 72 } 73 74 std::string tblgen::DagLeaf::getConditionTemplate() const { 75 return getAsConstraint().getConditionTemplate(); 76 } 77 78 llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { 79 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 80 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 81 } 82 83 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { 84 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 85 return defInit->getDef()->isSubClassOf(superclass); 86 return false; 87 } 88 89 void tblgen::DagLeaf::print(raw_ostream &os) const { 90 if (def) 91 def->print(os); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // DagNode 96 //===----------------------------------------------------------------------===// 97 98 bool tblgen::DagNode::isNativeCodeCall() const { 99 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 100 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 101 return false; 102 } 103 104 bool tblgen::DagNode::isOperation() const { 105 return !(isNativeCodeCall() || isReplaceWithValue()); 106 } 107 108 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { 109 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 110 return cast<llvm::DefInit>(node->getOperator()) 111 ->getDef() 112 ->getValueAsString("expression"); 113 } 114 115 llvm::StringRef tblgen::DagNode::getSymbol() const { 116 return node->getNameStr(); 117 } 118 119 Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { 120 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 121 auto it = mapper->find(opDef); 122 if (it != mapper->end()) 123 return *it->second; 124 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 125 .first->second; 126 } 127 128 int tblgen::DagNode::getNumOps() const { 129 int count = isReplaceWithValue() ? 0 : 1; 130 for (int i = 0, e = getNumArgs(); i != e; ++i) { 131 if (auto child = getArgAsNestedDag(i)) 132 count += child.getNumOps(); 133 } 134 return count; 135 } 136 137 int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } 138 139 bool tblgen::DagNode::isNestedDagArg(unsigned index) const { 140 return isa<llvm::DagInit>(node->getArg(index)); 141 } 142 143 tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { 144 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 145 } 146 147 tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { 148 assert(!isNestedDagArg(index)); 149 return DagLeaf(node->getArg(index)); 150 } 151 152 StringRef tblgen::DagNode::getArgName(unsigned index) const { 153 return node->getArgNameStr(index); 154 } 155 156 bool tblgen::DagNode::isReplaceWithValue() const { 157 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 158 return dagOpDef->getName() == "replaceWithValue"; 159 } 160 161 void tblgen::DagNode::print(raw_ostream &os) const { 162 if (node) 163 node->print(os); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // SymbolInfoMap 168 //===----------------------------------------------------------------------===// 169 170 StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, 171 int *index) { 172 StringRef name, indexStr; 173 int idx = -1; 174 std::tie(name, indexStr) = symbol.rsplit("__"); 175 176 if (indexStr.consumeInteger(10, idx)) { 177 // The second part is not an index; we return the whole symbol as-is. 178 return symbol; 179 } 180 if (index) { 181 *index = idx; 182 } 183 return name; 184 } 185 186 tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, 187 SymbolInfo::Kind kind, 188 Optional<int> index) 189 : op(op), kind(kind), argIndex(index) {} 190 191 int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 192 switch (kind) { 193 case Kind::Attr: 194 case Kind::Operand: 195 case Kind::Value: 196 return 1; 197 case Kind::Result: 198 return op->getNumResults(); 199 } 200 llvm_unreachable("unknown kind"); 201 } 202 203 std::string 204 tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 205 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 206 switch (kind) { 207 case Kind::Attr: { 208 auto type = 209 op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); 210 return std::string(formatv("{0} {1};\n", type, name)); 211 } 212 case Kind::Operand: { 213 // Use operand range for captured operands (to support potential variadic 214 // operands). 215 return std::string( 216 formatv("Operation::operand_range {0}(op0->getOperands());\n", name)); 217 } 218 case Kind::Value: { 219 return std::string(formatv("ArrayRef<Value> {0};\n", name)); 220 } 221 case Kind::Result: { 222 // Use the op itself for captured results. 223 return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name)); 224 } 225 } 226 llvm_unreachable("unknown kind"); 227 } 228 229 std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 230 StringRef name, int index, const char *fmt, const char *separator) const { 231 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 232 switch (kind) { 233 case Kind::Attr: { 234 assert(index < 0); 235 auto repl = formatv(fmt, name); 236 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 237 return std::string(repl); 238 } 239 case Kind::Operand: { 240 assert(index < 0); 241 auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>(); 242 // If this operand is variadic, then return a range. Otherwise, return the 243 // value itself. 244 if (operand->isVariadic()) { 245 auto repl = formatv(fmt, name); 246 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 247 return std::string(repl); 248 } 249 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 250 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 251 return std::string(repl); 252 } 253 case Kind::Result: { 254 // If `index` is greater than zero, then we are referencing a specific 255 // result of a multi-result op. The result can still be variadic. 256 if (index >= 0) { 257 std::string v = 258 std::string(formatv("{0}.getODSResults({1})", name, index)); 259 if (!op->getResult(index).isVariadic()) 260 v = std::string(formatv("(*{0}.begin())", v)); 261 auto repl = formatv(fmt, v); 262 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 263 return std::string(repl); 264 } 265 266 // If this op has no result at all but still we bind a symbol to it, it 267 // means we want to capture the op itself. 268 if (op->getNumResults() == 0) { 269 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 270 return std::string(name); 271 } 272 273 // We are referencing all results of the multi-result op. A specific result 274 // can either be a value or a range. Then join them with `separator`. 275 SmallVector<std::string, 4> values; 276 values.reserve(op->getNumResults()); 277 278 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 279 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 280 if (!op->getResult(i).isVariadic()) { 281 v = std::string(formatv("(*{0}.begin())", v)); 282 } 283 values.push_back(std::string(formatv(fmt, v))); 284 } 285 auto repl = llvm::join(values, separator); 286 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 287 return repl; 288 } 289 case Kind::Value: { 290 assert(index < 0); 291 assert(op == nullptr); 292 auto repl = formatv(fmt, name); 293 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 294 return std::string(repl); 295 } 296 } 297 llvm_unreachable("unknown kind"); 298 } 299 300 std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse( 301 StringRef name, int index, const char *fmt, const char *separator) const { 302 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 303 switch (kind) { 304 case Kind::Attr: 305 case Kind::Operand: { 306 assert(index < 0 && "only allowed for symbol bound to result"); 307 auto repl = formatv(fmt, name); 308 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 309 return std::string(repl); 310 } 311 case Kind::Result: { 312 if (index >= 0) { 313 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 314 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 315 return std::string(repl); 316 } 317 318 // We are referencing all results of the multi-result op. Each result should 319 // have a value range, and then join them with `separator`. 320 SmallVector<std::string, 4> values; 321 values.reserve(op->getNumResults()); 322 323 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 324 values.push_back(std::string( 325 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 326 } 327 auto repl = llvm::join(values, separator); 328 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 329 return repl; 330 } 331 case Kind::Value: { 332 assert(index < 0 && "only allowed for symbol bound to result"); 333 assert(op == nullptr); 334 auto repl = formatv(fmt, formatv("{{{0}}", name)); 335 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 336 return std::string(repl); 337 } 338 } 339 llvm_unreachable("unknown kind"); 340 } 341 342 bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, 343 int argIndex) { 344 StringRef name = getValuePackName(symbol); 345 if (name != symbol) { 346 auto error = formatv( 347 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 348 PrintFatalError(loc, error); 349 } 350 351 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 352 ? SymbolInfo::getAttr(&op, argIndex) 353 : SymbolInfo::getOperand(&op, argIndex); 354 355 return symbolInfoMap.insert({symbol, symInfo}).second; 356 } 357 358 bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 359 StringRef name = getValuePackName(symbol); 360 return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; 361 } 362 363 bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { 364 return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; 365 } 366 367 bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { 368 return find(symbol) != symbolInfoMap.end(); 369 } 370 371 tblgen::SymbolInfoMap::const_iterator 372 tblgen::SymbolInfoMap::find(StringRef key) const { 373 StringRef name = getValuePackName(key); 374 return symbolInfoMap.find(name); 375 } 376 377 int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 378 StringRef name = getValuePackName(symbol); 379 if (name != symbol) { 380 // If there is a trailing index inside symbol, it references just one 381 // static value. 382 return 1; 383 } 384 // Otherwise, find how many it represents by querying the symbol's info. 385 return find(name)->getValue().getStaticValueCount(); 386 } 387 388 std::string 389 tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt, 390 const char *separator) const { 391 int index = -1; 392 StringRef name = getValuePackName(symbol, &index); 393 394 auto it = symbolInfoMap.find(name); 395 if (it == symbolInfoMap.end()) { 396 auto error = formatv("referencing unbound symbol '{0}'", symbol); 397 PrintFatalError(loc, error); 398 } 399 400 return it->getValue().getValueAndRangeUse(name, index, fmt, separator); 401 } 402 403 std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol, 404 const char *fmt, 405 const char *separator) const { 406 int index = -1; 407 StringRef name = getValuePackName(symbol, &index); 408 409 auto it = symbolInfoMap.find(name); 410 if (it == symbolInfoMap.end()) { 411 auto error = formatv("referencing unbound symbol '{0}'", symbol); 412 PrintFatalError(loc, error); 413 } 414 415 return it->getValue().getAllRangeUse(name, index, fmt, separator); 416 } 417 418 //===----------------------------------------------------------------------===// 419 // Pattern 420 //==----------------------------------------------------------------------===// 421 422 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 423 : def(*def), recordOpMap(mapper) {} 424 425 tblgen::DagNode tblgen::Pattern::getSourcePattern() const { 426 return tblgen::DagNode(def.getValueAsDag("sourcePattern")); 427 } 428 429 int tblgen::Pattern::getNumResultPatterns() const { 430 auto *results = def.getValueAsListInit("resultPatterns"); 431 return results->size(); 432 } 433 434 tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { 435 auto *results = def.getValueAsListInit("resultPatterns"); 436 return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); 437 } 438 439 void tblgen::Pattern::collectSourcePatternBoundSymbols( 440 tblgen::SymbolInfoMap &infoMap) { 441 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 442 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 443 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 444 } 445 446 void tblgen::Pattern::collectResultPatternBoundSymbols( 447 tblgen::SymbolInfoMap &infoMap) { 448 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 449 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 450 auto pattern = getResultPattern(i); 451 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 452 } 453 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 454 } 455 456 const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { 457 return getSourcePattern().getDialectOp(recordOpMap); 458 } 459 460 tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { 461 return node.getDialectOp(recordOpMap); 462 } 463 464 std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const { 465 auto *listInit = def.getValueAsListInit("constraints"); 466 std::vector<tblgen::AppliedConstraint> ret; 467 ret.reserve(listInit->size()); 468 469 for (auto it : *listInit) { 470 auto *dagInit = dyn_cast<llvm::DagInit>(it); 471 if (!dagInit) 472 PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " 473 "constraints should be DAG nodes"); 474 475 std::vector<std::string> entities; 476 entities.reserve(dagInit->arg_size()); 477 for (auto *argName : dagInit->getArgNames()) { 478 if (!argName) { 479 PrintFatalError( 480 def.getLoc(), 481 "operands to additional constraints can only be symbol references"); 482 } 483 entities.push_back(std::string(argName->getValue())); 484 } 485 486 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 487 dagInit->getNameStr(), std::move(entities)); 488 } 489 return ret; 490 } 491 492 int tblgen::Pattern::getBenefit() const { 493 // The initial benefit value is a heuristic with number of ops in the source 494 // pattern. 495 int initBenefit = getSourcePattern().getNumOps(); 496 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 497 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 498 PrintFatalError(def.getLoc(), 499 "The 'addBenefit' takes and only takes one integer value"); 500 } 501 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 502 } 503 504 std::vector<tblgen::Pattern::IdentifierLine> 505 tblgen::Pattern::getLocation() const { 506 std::vector<std::pair<StringRef, unsigned>> result; 507 result.reserve(def.getLoc().size()); 508 for (auto loc : def.getLoc()) { 509 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 510 assert(buf && "invalid source location"); 511 result.emplace_back( 512 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 513 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 514 } 515 return result; 516 } 517 518 void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 519 bool isSrcPattern) { 520 auto treeName = tree.getSymbol(); 521 if (!tree.isOperation()) { 522 if (!treeName.empty()) { 523 PrintFatalError( 524 def.getLoc(), 525 formatv("binding symbol '{0}' to non-operation unsupported right now", 526 treeName)); 527 } 528 return; 529 } 530 531 auto &op = getDialectOp(tree); 532 auto numOpArgs = op.getNumArgs(); 533 auto numTreeArgs = tree.getNumArgs(); 534 535 if (numOpArgs != numTreeArgs) { 536 auto err = formatv("op '{0}' argument number mismatch: " 537 "{1} in pattern vs. {2} in definition", 538 op.getOperationName(), numTreeArgs, numOpArgs); 539 PrintFatalError(def.getLoc(), err); 540 } 541 542 // The name attached to the DAG node's operator is for representing the 543 // results generated from this op. It should be remembered as bound results. 544 if (!treeName.empty()) { 545 LLVM_DEBUG(llvm::dbgs() 546 << "found symbol bound to op result: " << treeName << '\n'); 547 if (!infoMap.bindOpResult(treeName, op)) 548 PrintFatalError(def.getLoc(), 549 formatv("symbol '{0}' bound more than once", treeName)); 550 } 551 552 for (int i = 0; i != numTreeArgs; ++i) { 553 if (auto treeArg = tree.getArgAsNestedDag(i)) { 554 // This DAG node argument is a DAG node itself. Go inside recursively. 555 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 556 } else if (isSrcPattern) { 557 // We can only bind symbols to op arguments in source pattern. Those 558 // symbols are referenced in result patterns. 559 auto treeArgName = tree.getArgName(i); 560 // `$_` is a special symbol meaning ignore the current argument. 561 if (!treeArgName.empty() && treeArgName != "_") { 562 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 563 << treeArgName << '\n'); 564 if (!infoMap.bindOpArgument(treeArgName, op, i)) { 565 auto err = formatv("symbol '{0}' bound more than once", treeArgName); 566 PrintFatalError(def.getLoc(), err); 567 } 568 } 569 } 570 } 571 } 572