1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 19 // Pattern. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/TableGen/Pattern.h" 24 #include "llvm/ADT/Twine.h" 25 #include "llvm/Support/Debug.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/TableGen/Error.h" 28 #include "llvm/TableGen/Record.h" 29 30 #define DEBUG_TYPE "mlir-tblgen-pattern" 31 32 using namespace mlir; 33 34 using llvm::formatv; 35 using mlir::tblgen::Operator; 36 37 //===----------------------------------------------------------------------===// 38 // DagLeaf 39 //===----------------------------------------------------------------------===// 40 41 bool tblgen::DagLeaf::isUnspecified() const { 42 return dyn_cast_or_null<llvm::UnsetInit>(def); 43 } 44 45 bool tblgen::DagLeaf::isOperandMatcher() const { 46 // Operand matchers specify a type constraint. 47 return isSubClassOf("TypeConstraint"); 48 } 49 50 bool tblgen::DagLeaf::isAttrMatcher() const { 51 // Attribute matchers specify an attribute constraint. 52 return isSubClassOf("AttrConstraint"); 53 } 54 55 bool tblgen::DagLeaf::isNativeCodeCall() const { 56 return isSubClassOf("NativeCodeCall"); 57 } 58 59 bool tblgen::DagLeaf::isConstantAttr() const { 60 return isSubClassOf("ConstantAttr"); 61 } 62 63 bool tblgen::DagLeaf::isEnumAttrCase() const { 64 return isSubClassOf("EnumAttrCaseInfo"); 65 } 66 67 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { 68 assert((isOperandMatcher() || isAttrMatcher()) && 69 "the DAG leaf must be operand or attribute"); 70 return Constraint(cast<llvm::DefInit>(def)->getDef()); 71 } 72 73 tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { 74 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 75 return ConstantAttr(cast<llvm::DefInit>(def)); 76 } 77 78 tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const { 79 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 80 return EnumAttrCase(cast<llvm::DefInit>(def)); 81 } 82 83 std::string tblgen::DagLeaf::getConditionTemplate() const { 84 return getAsConstraint().getConditionTemplate(); 85 } 86 87 llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { 88 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 89 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 90 } 91 92 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { 93 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 94 return defInit->getDef()->isSubClassOf(superclass); 95 return false; 96 } 97 98 void tblgen::DagLeaf::print(raw_ostream &os) const { 99 if (def) 100 def->print(os); 101 } 102 103 //===----------------------------------------------------------------------===// 104 // DagNode 105 //===----------------------------------------------------------------------===// 106 107 bool tblgen::DagNode::isNativeCodeCall() const { 108 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 109 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 110 return false; 111 } 112 113 bool tblgen::DagNode::isOperation() const { 114 return !(isNativeCodeCall() || isReplaceWithValue()); 115 } 116 117 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { 118 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 119 return cast<llvm::DefInit>(node->getOperator()) 120 ->getDef() 121 ->getValueAsString("expression"); 122 } 123 124 llvm::StringRef tblgen::DagNode::getSymbol() const { 125 return node->getNameStr(); 126 } 127 128 Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { 129 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 130 auto it = mapper->find(opDef); 131 if (it != mapper->end()) 132 return *it->second; 133 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 134 .first->second; 135 } 136 137 int tblgen::DagNode::getNumOps() const { 138 int count = isReplaceWithValue() ? 0 : 1; 139 for (int i = 0, e = getNumArgs(); i != e; ++i) { 140 if (auto child = getArgAsNestedDag(i)) 141 count += child.getNumOps(); 142 } 143 return count; 144 } 145 146 int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } 147 148 bool tblgen::DagNode::isNestedDagArg(unsigned index) const { 149 return isa<llvm::DagInit>(node->getArg(index)); 150 } 151 152 tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { 153 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 154 } 155 156 tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { 157 assert(!isNestedDagArg(index)); 158 return DagLeaf(node->getArg(index)); 159 } 160 161 StringRef tblgen::DagNode::getArgName(unsigned index) const { 162 return node->getArgNameStr(index); 163 } 164 165 bool tblgen::DagNode::isReplaceWithValue() const { 166 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 167 return dagOpDef->getName() == "replaceWithValue"; 168 } 169 170 void tblgen::DagNode::print(raw_ostream &os) const { 171 if (node) 172 node->print(os); 173 } 174 175 //===----------------------------------------------------------------------===// 176 // SymbolInfoMap 177 //===----------------------------------------------------------------------===// 178 179 StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, 180 int *index) { 181 StringRef name, indexStr; 182 int idx = -1; 183 std::tie(name, indexStr) = symbol.rsplit("__"); 184 185 if (indexStr.consumeInteger(10, idx)) { 186 // The second part is not an index; we return the whole symbol as-is. 187 return symbol; 188 } 189 if (index) { 190 *index = idx; 191 } 192 return name; 193 } 194 195 tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, 196 SymbolInfo::Kind kind, 197 Optional<int> index) 198 : op(op), kind(kind), argIndex(index) {} 199 200 int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 201 switch (kind) { 202 case Kind::Attr: 203 case Kind::Operand: 204 case Kind::Value: 205 return 1; 206 case Kind::Result: 207 return op->getNumResults(); 208 } 209 llvm_unreachable("unknown kind"); 210 } 211 212 std::string 213 tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 214 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); 215 switch (kind) { 216 case Kind::Attr: { 217 auto type = 218 op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); 219 return formatv("{0} {1};\n", type, name); 220 } 221 case Kind::Operand: { 222 // Use operand range for captured operands (to support potential variadic 223 // operands). 224 return formatv("Operation::operand_range {0}(op0->getOperands());\n", name); 225 } 226 case Kind::Value: { 227 return formatv("ArrayRef<Value *> {0};\n", name); 228 } 229 case Kind::Result: { 230 // Use the op itself for captured results. 231 return formatv("{0} {1};\n", op->getQualCppClassName(), name); 232 } 233 } 234 llvm_unreachable("unknown kind"); 235 } 236 237 std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 238 StringRef name, int index, const char *fmt, const char *separator) const { 239 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); 240 switch (kind) { 241 case Kind::Attr: { 242 assert(index < 0); 243 auto repl = formatv(fmt, name); 244 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); 245 return repl; 246 } 247 case Kind::Operand: { 248 assert(index < 0); 249 auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>(); 250 // If this operand is variadic, then return a range. Otherwise, return the 251 // value itself. 252 if (operand->isVariadic()) { 253 auto repl = formatv(fmt, name); 254 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); 255 return repl; 256 } 257 auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); 258 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); 259 return repl; 260 } 261 case Kind::Result: { 262 // If `index` is greater than zero, then we are referencing a specific 263 // result of a multi-result op. The result can still be variadic. 264 if (index >= 0) { 265 std::string v = formatv("{0}.getODSResults({1})", name, index); 266 if (!op->getResult(index).isVariadic()) 267 v = formatv("(*{0}.begin())", v); 268 auto repl = formatv(fmt, v); 269 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 270 return repl; 271 } 272 273 // If this op has no result at all but still we bind a symbol to it, it 274 // means we want to capture the op itself. 275 if (op->getNumResults() == 0) { 276 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); 277 return name; 278 } 279 280 // We are referencing all results of the multi-result op. A specific result 281 // can either be a value or a range. Then join them with `separator`. 282 SmallVector<std::string, 4> values; 283 values.reserve(op->getNumResults()); 284 285 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 286 std::string v = formatv("{0}.getODSResults({1})", name, i); 287 if (!op->getResult(i).isVariadic()) { 288 v = formatv("(*{0}.begin())", v); 289 } 290 values.push_back(formatv(fmt, v)); 291 } 292 auto repl = llvm::join(values, separator); 293 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 294 return repl; 295 } 296 case Kind::Value: { 297 assert(index < 0); 298 assert(op == nullptr); 299 auto repl = formatv(fmt, name); 300 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 301 return repl; 302 } 303 } 304 llvm_unreachable("unknown kind"); 305 } 306 307 std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse( 308 StringRef name, int index, const char *fmt, const char *separator) const { 309 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); 310 switch (kind) { 311 case Kind::Attr: 312 case Kind::Operand: { 313 assert(index < 0 && "only allowed for symbol bound to result"); 314 auto repl = formatv(fmt, name); 315 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); 316 return repl; 317 } 318 case Kind::Result: { 319 if (index >= 0) { 320 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 321 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); 322 return repl; 323 } 324 325 // We are referencing all results of the multi-result op. Each result should 326 // have a value range, and then join them with `separator`. 327 SmallVector<std::string, 4> values; 328 values.reserve(op->getNumResults()); 329 330 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 331 values.push_back( 332 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))); 333 } 334 auto repl = llvm::join(values, separator); 335 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); 336 return repl; 337 } 338 case Kind::Value: { 339 assert(index < 0 && "only allowed for symbol bound to result"); 340 assert(op == nullptr); 341 auto repl = formatv(fmt, formatv("{{{0}}", name)); 342 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); 343 return repl; 344 } 345 } 346 llvm_unreachable("unknown kind"); 347 } 348 349 bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, 350 int argIndex) { 351 StringRef name = getValuePackName(symbol); 352 if (name != symbol) { 353 auto error = formatv( 354 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 355 PrintFatalError(loc, error); 356 } 357 358 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 359 ? SymbolInfo::getAttr(&op, argIndex) 360 : SymbolInfo::getOperand(&op, argIndex); 361 362 return symbolInfoMap.insert({symbol, symInfo}).second; 363 } 364 365 bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 366 StringRef name = getValuePackName(symbol); 367 return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; 368 } 369 370 bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { 371 return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; 372 } 373 374 bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { 375 return find(symbol) != symbolInfoMap.end(); 376 } 377 378 tblgen::SymbolInfoMap::const_iterator 379 tblgen::SymbolInfoMap::find(StringRef key) const { 380 StringRef name = getValuePackName(key); 381 return symbolInfoMap.find(name); 382 } 383 384 int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 385 StringRef name = getValuePackName(symbol); 386 if (name != symbol) { 387 // If there is a trailing index inside symbol, it references just one 388 // static value. 389 return 1; 390 } 391 // Otherwise, find how many it represents by querying the symbol's info. 392 return find(name)->getValue().getStaticValueCount(); 393 } 394 395 std::string 396 tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt, 397 const char *separator) const { 398 int index = -1; 399 StringRef name = getValuePackName(symbol, &index); 400 401 auto it = symbolInfoMap.find(name); 402 if (it == symbolInfoMap.end()) { 403 auto error = formatv("referencing unbound symbol '{0}'", symbol); 404 PrintFatalError(loc, error); 405 } 406 407 return it->getValue().getValueAndRangeUse(name, index, fmt, separator); 408 } 409 410 std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol, 411 const char *fmt, 412 const char *separator) const { 413 int index = -1; 414 StringRef name = getValuePackName(symbol, &index); 415 416 auto it = symbolInfoMap.find(name); 417 if (it == symbolInfoMap.end()) { 418 auto error = formatv("referencing unbound symbol '{0}'", symbol); 419 PrintFatalError(loc, error); 420 } 421 422 return it->getValue().getAllRangeUse(name, index, fmt, separator); 423 } 424 425 //===----------------------------------------------------------------------===// 426 // Pattern 427 //==----------------------------------------------------------------------===// 428 429 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 430 : def(*def), recordOpMap(mapper) {} 431 432 tblgen::DagNode tblgen::Pattern::getSourcePattern() const { 433 return tblgen::DagNode(def.getValueAsDag("sourcePattern")); 434 } 435 436 int tblgen::Pattern::getNumResultPatterns() const { 437 auto *results = def.getValueAsListInit("resultPatterns"); 438 return results->size(); 439 } 440 441 tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { 442 auto *results = def.getValueAsListInit("resultPatterns"); 443 return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); 444 } 445 446 void tblgen::Pattern::collectSourcePatternBoundSymbols( 447 tblgen::SymbolInfoMap &infoMap) { 448 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); 449 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 450 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); 451 } 452 453 void tblgen::Pattern::collectResultPatternBoundSymbols( 454 tblgen::SymbolInfoMap &infoMap) { 455 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); 456 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 457 auto pattern = getResultPattern(i); 458 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 459 } 460 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); 461 } 462 463 const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { 464 return getSourcePattern().getDialectOp(recordOpMap); 465 } 466 467 tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { 468 return node.getDialectOp(recordOpMap); 469 } 470 471 std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const { 472 auto *listInit = def.getValueAsListInit("constraints"); 473 std::vector<tblgen::AppliedConstraint> ret; 474 ret.reserve(listInit->size()); 475 476 for (auto it : *listInit) { 477 auto *dagInit = dyn_cast<llvm::DagInit>(it); 478 if (!dagInit) 479 PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " 480 "constraints should be DAG nodes"); 481 482 std::vector<std::string> entities; 483 entities.reserve(dagInit->arg_size()); 484 for (auto *argName : dagInit->getArgNames()) { 485 if (!argName) { 486 PrintFatalError( 487 def.getLoc(), 488 "operands to additional constraints can only be symbol references"); 489 } 490 entities.push_back(argName->getValue()); 491 } 492 493 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 494 dagInit->getNameStr(), std::move(entities)); 495 } 496 return ret; 497 } 498 499 int tblgen::Pattern::getBenefit() const { 500 // The initial benefit value is a heuristic with number of ops in the source 501 // pattern. 502 int initBenefit = getSourcePattern().getNumOps(); 503 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 504 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 505 PrintFatalError(def.getLoc(), 506 "The 'addBenefit' takes and only takes one integer value"); 507 } 508 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 509 } 510 511 std::vector<tblgen::Pattern::IdentifierLine> 512 tblgen::Pattern::getLocation() const { 513 std::vector<std::pair<StringRef, unsigned>> result; 514 result.reserve(def.getLoc().size()); 515 for (auto loc : def.getLoc()) { 516 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 517 assert(buf && "invalid source location"); 518 result.emplace_back( 519 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 520 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 521 } 522 return result; 523 } 524 525 void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 526 bool isSrcPattern) { 527 auto treeName = tree.getSymbol(); 528 if (!tree.isOperation()) { 529 if (!treeName.empty()) { 530 PrintFatalError( 531 def.getLoc(), 532 formatv("binding symbol '{0}' to non-operation unsupported right now", 533 treeName)); 534 } 535 return; 536 } 537 538 auto &op = getDialectOp(tree); 539 auto numOpArgs = op.getNumArgs(); 540 auto numTreeArgs = tree.getNumArgs(); 541 542 if (numOpArgs != numTreeArgs) { 543 auto err = formatv("op '{0}' argument number mismatch: " 544 "{1} in pattern vs. {2} in definition", 545 op.getOperationName(), numTreeArgs, numOpArgs); 546 PrintFatalError(def.getLoc(), err); 547 } 548 549 // The name attached to the DAG node's operator is for representing the 550 // results generated from this op. It should be remembered as bound results. 551 if (!treeName.empty()) { 552 LLVM_DEBUG(llvm::dbgs() 553 << "found symbol bound to op result: " << treeName << '\n'); 554 if (!infoMap.bindOpResult(treeName, op)) 555 PrintFatalError(def.getLoc(), 556 formatv("symbol '{0}' bound more than once", treeName)); 557 } 558 559 for (int i = 0; i != numTreeArgs; ++i) { 560 if (auto treeArg = tree.getArgAsNestedDag(i)) { 561 // This DAG node argument is a DAG node itself. Go inside recursively. 562 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 563 } else if (isSrcPattern) { 564 // We can only bind symbols to op arguments in source pattern. Those 565 // symbols are referenced in result patterns. 566 auto treeArgName = tree.getArgName(i); 567 // `$_` is a special symbol meaning ignore the current argument. 568 if (!treeArgName.empty() && treeArgName != "_") { 569 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " 570 << treeArgName << '\n'); 571 if (!infoMap.bindOpArgument(treeArgName, op, i)) { 572 auto err = formatv("symbol '{0}' bound more than once", treeArgName); 573 PrintFatalError(def.getLoc(), err); 574 } 575 } 576 } 577 } 578 } 579