1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Pattern.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/Debug.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 #define DEBUG_TYPE "mlir-tblgen-pattern" 23 24 using namespace mlir; 25 using namespace tblgen; 26 27 using llvm::formatv; 28 29 //===----------------------------------------------------------------------===// 30 // DagLeaf 31 //===----------------------------------------------------------------------===// 32 33 bool DagLeaf::isUnspecified() const { 34 return dyn_cast_or_null<llvm::UnsetInit>(def); 35 } 36 37 bool DagLeaf::isOperandMatcher() const { 38 // Operand matchers specify a type constraint. 39 return isSubClassOf("TypeConstraint"); 40 } 41 42 bool DagLeaf::isAttrMatcher() const { 43 // Attribute matchers specify an attribute constraint. 44 return isSubClassOf("AttrConstraint"); 45 } 46 47 bool DagLeaf::isNativeCodeCall() const { 48 return isSubClassOf("NativeCodeCall"); 49 } 50 51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } 52 53 bool DagLeaf::isEnumAttrCase() const { 54 return isSubClassOf("EnumAttrCaseInfo"); 55 } 56 57 bool DagLeaf::isStringAttr() const { 58 return isa<llvm::StringInit, llvm::CodeInit>(def); 59 } 60 61 Constraint DagLeaf::getAsConstraint() const { 62 assert((isOperandMatcher() || isAttrMatcher()) && 63 "the DAG leaf must be operand or attribute"); 64 return Constraint(cast<llvm::DefInit>(def)->getDef()); 65 } 66 67 ConstantAttr DagLeaf::getAsConstantAttr() const { 68 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 69 return ConstantAttr(cast<llvm::DefInit>(def)); 70 } 71 72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const { 73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 74 return EnumAttrCase(cast<llvm::DefInit>(def)); 75 } 76 77 std::string DagLeaf::getConditionTemplate() const { 78 return getAsConstraint().getConditionTemplate(); 79 } 80 81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const { 82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 84 } 85 86 std::string DagLeaf::getStringAttr() const { 87 assert(isStringAttr() && "the DAG leaf must be string attribute"); 88 return def->getAsUnquotedString(); 89 } 90 bool DagLeaf::isSubClassOf(StringRef superclass) const { 91 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 92 return defInit->getDef()->isSubClassOf(superclass); 93 return false; 94 } 95 96 void DagLeaf::print(raw_ostream &os) const { 97 if (def) 98 def->print(os); 99 } 100 101 //===----------------------------------------------------------------------===// 102 // DagNode 103 //===----------------------------------------------------------------------===// 104 105 bool DagNode::isNativeCodeCall() const { 106 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 107 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 108 return false; 109 } 110 111 bool DagNode::isOperation() const { 112 return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); 113 } 114 115 llvm::StringRef DagNode::getNativeCodeTemplate() const { 116 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 117 return cast<llvm::DefInit>(node->getOperator()) 118 ->getDef() 119 ->getValueAsString("expression"); 120 } 121 122 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } 123 124 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { 125 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 126 auto it = mapper->find(opDef); 127 if (it != mapper->end()) 128 return *it->second; 129 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 130 .first->second; 131 } 132 133 int DagNode::getNumOps() const { 134 int count = isReplaceWithValue() ? 0 : 1; 135 for (int i = 0, e = getNumArgs(); i != e; ++i) { 136 if (auto child = getArgAsNestedDag(i)) 137 count += child.getNumOps(); 138 } 139 return count; 140 } 141 142 int DagNode::getNumArgs() const { return node->getNumArgs(); } 143 144 bool DagNode::isNestedDagArg(unsigned index) const { 145 return isa<llvm::DagInit>(node->getArg(index)); 146 } 147 148 DagNode DagNode::getArgAsNestedDag(unsigned index) const { 149 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 150 } 151 152 DagLeaf DagNode::getArgAsLeaf(unsigned index) const { 153 assert(!isNestedDagArg(index)); 154 return DagLeaf(node->getArg(index)); 155 } 156 157 StringRef DagNode::getArgName(unsigned index) const { 158 return node->getArgNameStr(index); 159 } 160 161 bool DagNode::isReplaceWithValue() const { 162 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 163 return dagOpDef->getName() == "replaceWithValue"; 164 } 165 166 bool DagNode::isLocationDirective() const { 167 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 168 return dagOpDef->getName() == "location"; 169 } 170 171 void DagNode::print(raw_ostream &os) const { 172 if (node) 173 node->print(os); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // SymbolInfoMap 178 //===----------------------------------------------------------------------===// 179 180 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { 181 StringRef name, indexStr; 182 int idx = -1; 183 std::tie(name, indexStr) = symbol.rsplit("__"); 184 185 if (indexStr.consumeInteger(10, idx)) { 186 // The second part is not an index; we return the whole symbol as-is. 187 return symbol; 188 } 189 if (index) { 190 *index = idx; 191 } 192 return name; 193 } 194 195 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, 196 Optional<int> index) 197 : op(op), kind(kind), argIndex(index) {} 198 199 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 200 switch (kind) { 201 case Kind::Attr: 202 case Kind::Operand: 203 case Kind::Value: 204 return 1; 205 case Kind::Result: 206 return op->getNumResults(); 207 } 208 llvm_unreachable("unknown kind"); 209 } 210 211 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 212 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 213 switch (kind) { 214 case Kind::Attr: { 215 auto type = 216 op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); 217 return std::string(formatv("{0} {1};\n", type, name)); 218 } 219 case Kind::Operand: { 220 // Use operand range for captured operands (to support potential variadic 221 // operands). 222 return std::string( 223 formatv("Operation::operand_range {0}(op0->getOperands());\n", name)); 224 } 225 case Kind::Value: { 226 return std::string(formatv("ArrayRef<Value> {0};\n", name)); 227 } 228 case Kind::Result: { 229 // Use the op itself for captured results. 230 return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name)); 231 } 232 } 233 llvm_unreachable("unknown kind"); 234 } 235 236 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 237 StringRef name, int index, const char *fmt, const char *separator) const { 238 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 239 switch (kind) { 240 case Kind::Attr: { 241 assert(index < 0); 242 auto repl = formatv(fmt, name); 243 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 244 return std::string(repl); 245 } 246 case Kind::Operand: { 247 assert(index < 0); 248 auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>(); 249 // If this operand is variadic, then return a range. Otherwise, return the 250 // value itself. 251 if (operand->isVariableLength()) { 252 auto repl = formatv(fmt, name); 253 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 254 return std::string(repl); 255 } 256 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 257 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 258 return std::string(repl); 259 } 260 case Kind::Result: { 261 // If `index` is greater than zero, then we are referencing a specific 262 // result of a multi-result op. The result can still be variadic. 263 if (index >= 0) { 264 std::string v = 265 std::string(formatv("{0}.getODSResults({1})", name, index)); 266 if (!op->getResult(index).isVariadic()) 267 v = std::string(formatv("(*{0}.begin())", v)); 268 auto repl = formatv(fmt, v); 269 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 270 return std::string(repl); 271 } 272 273 // If this op has no result at all but still we bind a symbol to it, it 274 // means we want to capture the op itself. 275 if (op->getNumResults() == 0) { 276 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 277 return std::string(name); 278 } 279 280 // We are referencing all results of the multi-result op. A specific result 281 // can either be a value or a range. Then join them with `separator`. 282 SmallVector<std::string, 4> values; 283 values.reserve(op->getNumResults()); 284 285 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 286 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); 287 if (!op->getResult(i).isVariadic()) { 288 v = std::string(formatv("(*{0}.begin())", v)); 289 } 290 values.push_back(std::string(formatv(fmt, v))); 291 } 292 auto repl = llvm::join(values, separator); 293 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 294 return repl; 295 } 296 case Kind::Value: { 297 assert(index < 0); 298 assert(op == nullptr); 299 auto repl = formatv(fmt, name); 300 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 301 return std::string(repl); 302 } 303 } 304 llvm_unreachable("unknown kind"); 305 } 306 307 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( 308 StringRef name, int index, const char *fmt, const char *separator) const { 309 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 310 switch (kind) { 311 case Kind::Attr: 312 case Kind::Operand: { 313 assert(index < 0 && "only allowed for symbol bound to result"); 314 auto repl = formatv(fmt, name); 315 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 316 return std::string(repl); 317 } 318 case Kind::Result: { 319 if (index >= 0) { 320 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 321 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 322 return std::string(repl); 323 } 324 325 // We are referencing all results of the multi-result op. Each result should 326 // have a value range, and then join them with `separator`. 327 SmallVector<std::string, 4> values; 328 values.reserve(op->getNumResults()); 329 330 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 331 values.push_back(std::string( 332 formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); 333 } 334 auto repl = llvm::join(values, separator); 335 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 336 return repl; 337 } 338 case Kind::Value: { 339 assert(index < 0 && "only allowed for symbol bound to result"); 340 assert(op == nullptr); 341 auto repl = formatv(fmt, formatv("{{{0}}", name)); 342 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 343 return std::string(repl); 344 } 345 } 346 llvm_unreachable("unknown kind"); 347 } 348 349 bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, 350 int argIndex) { 351 StringRef name = getValuePackName(symbol); 352 if (name != symbol) { 353 auto error = formatv( 354 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 355 PrintFatalError(loc, error); 356 } 357 358 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 359 ? SymbolInfo::getAttr(&op, argIndex) 360 : SymbolInfo::getOperand(&op, argIndex); 361 362 return symbolInfoMap.insert({symbol, symInfo}).second; 363 } 364 365 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 366 StringRef name = getValuePackName(symbol); 367 return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; 368 } 369 370 bool SymbolInfoMap::bindValue(StringRef symbol) { 371 return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; 372 } 373 374 bool SymbolInfoMap::contains(StringRef symbol) const { 375 return find(symbol) != symbolInfoMap.end(); 376 } 377 378 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { 379 StringRef name = getValuePackName(key); 380 return symbolInfoMap.find(name); 381 } 382 383 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 384 StringRef name = getValuePackName(symbol); 385 if (name != symbol) { 386 // If there is a trailing index inside symbol, it references just one 387 // static value. 388 return 1; 389 } 390 // Otherwise, find how many it represents by querying the symbol's info. 391 return find(name)->getValue().getStaticValueCount(); 392 } 393 394 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, 395 const char *fmt, 396 const char *separator) const { 397 int index = -1; 398 StringRef name = getValuePackName(symbol, &index); 399 400 auto it = symbolInfoMap.find(name); 401 if (it == symbolInfoMap.end()) { 402 auto error = formatv("referencing unbound symbol '{0}'", symbol); 403 PrintFatalError(loc, error); 404 } 405 406 return it->getValue().getValueAndRangeUse(name, index, fmt, separator); 407 } 408 409 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, 410 const char *separator) const { 411 int index = -1; 412 StringRef name = getValuePackName(symbol, &index); 413 414 auto it = symbolInfoMap.find(name); 415 if (it == symbolInfoMap.end()) { 416 auto error = formatv("referencing unbound symbol '{0}'", symbol); 417 PrintFatalError(loc, error); 418 } 419 420 return it->getValue().getAllRangeUse(name, index, fmt, separator); 421 } 422 423 //===----------------------------------------------------------------------===// 424 // Pattern 425 //==----------------------------------------------------------------------===// 426 427 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 428 : def(*def), recordOpMap(mapper) {} 429 430 DagNode Pattern::getSourcePattern() const { 431 return DagNode(def.getValueAsDag("sourcePattern")); 432 } 433 434 int Pattern::getNumResultPatterns() const { 435 auto *results = def.getValueAsListInit("resultPatterns"); 436 return results->size(); 437 } 438 439 DagNode Pattern::getResultPattern(unsigned index) const { 440 auto *results = def.getValueAsListInit("resultPatterns"); 441 return DagNode(cast<llvm::DagInit>(results->getElement(index))); 442 } 443 444 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { 445 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 446 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 447 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 448 } 449 450 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { 451 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 452 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 453 auto pattern = getResultPattern(i); 454 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 455 } 456 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 457 } 458 459 const Operator &Pattern::getSourceRootOp() { 460 return getSourcePattern().getDialectOp(recordOpMap); 461 } 462 463 Operator &Pattern::getDialectOp(DagNode node) { 464 return node.getDialectOp(recordOpMap); 465 } 466 467 std::vector<AppliedConstraint> Pattern::getConstraints() const { 468 auto *listInit = def.getValueAsListInit("constraints"); 469 std::vector<AppliedConstraint> ret; 470 ret.reserve(listInit->size()); 471 472 for (auto it : *listInit) { 473 auto *dagInit = dyn_cast<llvm::DagInit>(it); 474 if (!dagInit) 475 PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " 476 "constraints should be DAG nodes"); 477 478 std::vector<std::string> entities; 479 entities.reserve(dagInit->arg_size()); 480 for (auto *argName : dagInit->getArgNames()) { 481 if (!argName) { 482 PrintFatalError( 483 def.getLoc(), 484 "operands to additional constraints can only be symbol references"); 485 } 486 entities.push_back(std::string(argName->getValue())); 487 } 488 489 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 490 dagInit->getNameStr(), std::move(entities)); 491 } 492 return ret; 493 } 494 495 int Pattern::getBenefit() const { 496 // The initial benefit value is a heuristic with number of ops in the source 497 // pattern. 498 int initBenefit = getSourcePattern().getNumOps(); 499 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 500 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 501 PrintFatalError(def.getLoc(), 502 "The 'addBenefit' takes and only takes one integer value"); 503 } 504 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 505 } 506 507 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { 508 std::vector<std::pair<StringRef, unsigned>> result; 509 result.reserve(def.getLoc().size()); 510 for (auto loc : def.getLoc()) { 511 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 512 assert(buf && "invalid source location"); 513 result.emplace_back( 514 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 515 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 516 } 517 return result; 518 } 519 520 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 521 bool isSrcPattern) { 522 auto treeName = tree.getSymbol(); 523 if (!tree.isOperation()) { 524 if (!treeName.empty()) { 525 PrintFatalError( 526 def.getLoc(), 527 formatv("binding symbol '{0}' to non-operation unsupported right now", 528 treeName)); 529 } 530 return; 531 } 532 533 auto &op = getDialectOp(tree); 534 auto numOpArgs = op.getNumArgs(); 535 auto numTreeArgs = tree.getNumArgs(); 536 537 // The pattern might have the last argument specifying the location. 538 bool hasLocDirective = false; 539 if (numTreeArgs != 0) { 540 if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) 541 hasLocDirective = lastArg.isLocationDirective(); 542 } 543 544 if (numOpArgs != numTreeArgs - hasLocDirective) { 545 auto err = formatv("op '{0}' argument number mismatch: " 546 "{1} in pattern vs. {2} in definition", 547 op.getOperationName(), numTreeArgs, numOpArgs); 548 PrintFatalError(def.getLoc(), err); 549 } 550 551 // The name attached to the DAG node's operator is for representing the 552 // results generated from this op. It should be remembered as bound results. 553 if (!treeName.empty()) { 554 LLVM_DEBUG(llvm::dbgs() 555 << "found symbol bound to op result: " << treeName << '\n'); 556 if (!infoMap.bindOpResult(treeName, op)) 557 PrintFatalError(def.getLoc(), 558 formatv("symbol '{0}' bound more than once", treeName)); 559 } 560 561 for (int i = 0; i != numTreeArgs; ++i) { 562 if (auto treeArg = tree.getArgAsNestedDag(i)) { 563 // This DAG node argument is a DAG node itself. Go inside recursively. 564 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 565 } else if (isSrcPattern) { 566 // We can only bind symbols to op arguments in source pattern. Those 567 // symbols are referenced in result patterns. 568 auto treeArgName = tree.getArgName(i); 569 // `$_` is a special symbol meaning ignore the current argument. 570 if (!treeArgName.empty() && treeArgName != "_") { 571 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 572 << treeArgName << '\n'); 573 if (!infoMap.bindOpArgument(treeArgName, op, i)) { 574 auto err = formatv("symbol '{0}' bound more than once", treeArgName); 575 PrintFatalError(def.getLoc(), err); 576 } 577 } 578 } 579 } 580 } 581