1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23 
24 using namespace mlir;
25 using namespace tblgen;
26 
27 using llvm::formatv;
28 
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32 
33 bool DagLeaf::isUnspecified() const {
34   return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36 
37 bool DagLeaf::isOperandMatcher() const {
38   // Operand matchers specify a type constraint.
39   return isSubClassOf("TypeConstraint");
40 }
41 
42 bool DagLeaf::isAttrMatcher() const {
43   // Attribute matchers specify an attribute constraint.
44   return isSubClassOf("AttrConstraint");
45 }
46 
47 bool DagLeaf::isNativeCodeCall() const {
48   return isSubClassOf("NativeCodeCall");
49 }
50 
51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52 
53 bool DagLeaf::isEnumAttrCase() const {
54   return isSubClassOf("EnumAttrCaseInfo");
55 }
56 
57 bool DagLeaf::isStringAttr() const {
58   return isa<llvm::StringInit, llvm::CodeInit>(def);
59 }
60 
61 Constraint DagLeaf::getAsConstraint() const {
62   assert((isOperandMatcher() || isAttrMatcher()) &&
63          "the DAG leaf must be operand or attribute");
64   return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66 
67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69   return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71 
72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74   return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76 
77 std::string DagLeaf::getConditionTemplate() const {
78   return getAsConstraint().getConditionTemplate();
79 }
80 
81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85 
86 std::string DagLeaf::getStringAttr() const {
87   assert(isStringAttr() && "the DAG leaf must be string attribute");
88   return def->getAsUnquotedString();
89 }
90 bool DagLeaf::isSubClassOf(StringRef superclass) const {
91   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
92     return defInit->getDef()->isSubClassOf(superclass);
93   return false;
94 }
95 
96 void DagLeaf::print(raw_ostream &os) const {
97   if (def)
98     def->print(os);
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // DagNode
103 //===----------------------------------------------------------------------===//
104 
105 bool DagNode::isNativeCodeCall() const {
106   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
107     return defInit->getDef()->isSubClassOf("NativeCodeCall");
108   return false;
109 }
110 
111 bool DagNode::isOperation() const {
112   return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
113 }
114 
115 llvm::StringRef DagNode::getNativeCodeTemplate() const {
116   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
117   return cast<llvm::DefInit>(node->getOperator())
118       ->getDef()
119       ->getValueAsString("expression");
120 }
121 
122 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
123 
124 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
125   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
126   auto it = mapper->find(opDef);
127   if (it != mapper->end())
128     return *it->second;
129   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
130               .first->second;
131 }
132 
133 int DagNode::getNumOps() const {
134   int count = isReplaceWithValue() ? 0 : 1;
135   for (int i = 0, e = getNumArgs(); i != e; ++i) {
136     if (auto child = getArgAsNestedDag(i))
137       count += child.getNumOps();
138   }
139   return count;
140 }
141 
142 int DagNode::getNumArgs() const { return node->getNumArgs(); }
143 
144 bool DagNode::isNestedDagArg(unsigned index) const {
145   return isa<llvm::DagInit>(node->getArg(index));
146 }
147 
148 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
149   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
150 }
151 
152 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
153   assert(!isNestedDagArg(index));
154   return DagLeaf(node->getArg(index));
155 }
156 
157 StringRef DagNode::getArgName(unsigned index) const {
158   return node->getArgNameStr(index);
159 }
160 
161 bool DagNode::isReplaceWithValue() const {
162   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
163   return dagOpDef->getName() == "replaceWithValue";
164 }
165 
166 bool DagNode::isLocationDirective() const {
167   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
168   return dagOpDef->getName() == "location";
169 }
170 
171 void DagNode::print(raw_ostream &os) const {
172   if (node)
173     node->print(os);
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // SymbolInfoMap
178 //===----------------------------------------------------------------------===//
179 
180 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
181   StringRef name, indexStr;
182   int idx = -1;
183   std::tie(name, indexStr) = symbol.rsplit("__");
184 
185   if (indexStr.consumeInteger(10, idx)) {
186     // The second part is not an index; we return the whole symbol as-is.
187     return symbol;
188   }
189   if (index) {
190     *index = idx;
191   }
192   return name;
193 }
194 
195 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
196                                       Optional<int> index)
197     : op(op), kind(kind), argIndex(index) {}
198 
199 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
200   switch (kind) {
201   case Kind::Attr:
202   case Kind::Operand:
203   case Kind::Value:
204     return 1;
205   case Kind::Result:
206     return op->getNumResults();
207   }
208   llvm_unreachable("unknown kind");
209 }
210 
211 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
212   return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
213 }
214 
215 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
216   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
217   switch (kind) {
218   case Kind::Attr: {
219     auto type =
220         op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
221     return std::string(formatv("{0} {1};\n", type, name));
222   }
223   case Kind::Operand: {
224     // Use operand range for captured operands (to support potential variadic
225     // operands).
226     return std::string(
227         formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
228                 getVarName(name)));
229   }
230   case Kind::Value: {
231     return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
232   }
233   case Kind::Result: {
234     // Use the op itself for captured results.
235     return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
236   }
237   }
238   llvm_unreachable("unknown kind");
239 }
240 
241 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
242     StringRef name, int index, const char *fmt, const char *separator) const {
243   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
244   switch (kind) {
245   case Kind::Attr: {
246     assert(index < 0);
247     auto repl = formatv(fmt, name);
248     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
249     return std::string(repl);
250   }
251   case Kind::Operand: {
252     assert(index < 0);
253     auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
254     // If this operand is variadic, then return a range. Otherwise, return the
255     // value itself.
256     if (operand->isVariableLength()) {
257       auto repl = formatv(fmt, name);
258       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
259       return std::string(repl);
260     }
261     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
262     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
263     return std::string(repl);
264   }
265   case Kind::Result: {
266     // If `index` is greater than zero, then we are referencing a specific
267     // result of a multi-result op. The result can still be variadic.
268     if (index >= 0) {
269       std::string v =
270           std::string(formatv("{0}.getODSResults({1})", name, index));
271       if (!op->getResult(index).isVariadic())
272         v = std::string(formatv("(*{0}.begin())", v));
273       auto repl = formatv(fmt, v);
274       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
275       return std::string(repl);
276     }
277 
278     // If this op has no result at all but still we bind a symbol to it, it
279     // means we want to capture the op itself.
280     if (op->getNumResults() == 0) {
281       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
282       return std::string(name);
283     }
284 
285     // We are referencing all results of the multi-result op. A specific result
286     // can either be a value or a range. Then join them with `separator`.
287     SmallVector<std::string, 4> values;
288     values.reserve(op->getNumResults());
289 
290     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
291       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
292       if (!op->getResult(i).isVariadic()) {
293         v = std::string(formatv("(*{0}.begin())", v));
294       }
295       values.push_back(std::string(formatv(fmt, v)));
296     }
297     auto repl = llvm::join(values, separator);
298     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
299     return repl;
300   }
301   case Kind::Value: {
302     assert(index < 0);
303     assert(op == nullptr);
304     auto repl = formatv(fmt, name);
305     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
306     return std::string(repl);
307   }
308   }
309   llvm_unreachable("unknown kind");
310 }
311 
312 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
313     StringRef name, int index, const char *fmt, const char *separator) const {
314   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
315   switch (kind) {
316   case Kind::Attr:
317   case Kind::Operand: {
318     assert(index < 0 && "only allowed for symbol bound to result");
319     auto repl = formatv(fmt, name);
320     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
321     return std::string(repl);
322   }
323   case Kind::Result: {
324     if (index >= 0) {
325       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
326       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
327       return std::string(repl);
328     }
329 
330     // We are referencing all results of the multi-result op. Each result should
331     // have a value range, and then join them with `separator`.
332     SmallVector<std::string, 4> values;
333     values.reserve(op->getNumResults());
334 
335     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
336       values.push_back(std::string(
337           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
338     }
339     auto repl = llvm::join(values, separator);
340     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
341     return repl;
342   }
343   case Kind::Value: {
344     assert(index < 0 && "only allowed for symbol bound to result");
345     assert(op == nullptr);
346     auto repl = formatv(fmt, formatv("{{{0}}", name));
347     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
348     return std::string(repl);
349   }
350   }
351   llvm_unreachable("unknown kind");
352 }
353 
354 bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
355                                    int argIndex) {
356   StringRef name = getValuePackName(symbol);
357   if (name != symbol) {
358     auto error = formatv(
359         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
360     PrintFatalError(loc, error);
361   }
362 
363   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
364                      ? SymbolInfo::getAttr(&op, argIndex)
365                      : SymbolInfo::getOperand(&op, argIndex);
366 
367   std::string key = symbol.str();
368   if (auto numberOfEntries = symbolInfoMap.count(key)) {
369     // Only non unique name for the operand is supported.
370     if (symInfo.kind != SymbolInfo::Kind::Operand) {
371       return false;
372     }
373 
374     // Cannot add new operand if there is already non operand with the same
375     // name.
376     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
377       return false;
378     }
379   }
380 
381   symbolInfoMap.emplace(key, symInfo);
382   return true;
383 }
384 
385 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
386   StringRef name = getValuePackName(symbol);
387   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
388 
389   return symbolInfoMap.count(inserted->first) == 1;
390 }
391 
392 bool SymbolInfoMap::bindValue(StringRef symbol) {
393   auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getValue());
394   return symbolInfoMap.count(inserted->first) == 1;
395 }
396 
397 bool SymbolInfoMap::contains(StringRef symbol) const {
398   return find(symbol) != symbolInfoMap.end();
399 }
400 
401 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
402   std::string name = getValuePackName(key).str();
403 
404   return symbolInfoMap.find(name);
405 }
406 
407 SymbolInfoMap::const_iterator
408 SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
409                                int argIndex) const {
410   std::string name = getValuePackName(key).str();
411   auto range = symbolInfoMap.equal_range(name);
412 
413   for (auto it = range.first; it != range.second; ++it) {
414     if (it->second.op == &op && it->second.argIndex == argIndex) {
415       return it;
416     }
417   }
418 
419   return symbolInfoMap.end();
420 }
421 
422 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
423 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
424   std::string name = getValuePackName(key).str();
425 
426   return symbolInfoMap.equal_range(name);
427 }
428 
429 int SymbolInfoMap::count(StringRef key) const {
430   std::string name = getValuePackName(key).str();
431   return symbolInfoMap.count(name);
432 }
433 
434 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
435   StringRef name = getValuePackName(symbol);
436   if (name != symbol) {
437     // If there is a trailing index inside symbol, it references just one
438     // static value.
439     return 1;
440   }
441   // Otherwise, find how many it represents by querying the symbol's info.
442   return find(name)->second.getStaticValueCount();
443 }
444 
445 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
446                                                const char *fmt,
447                                                const char *separator) const {
448   int index = -1;
449   StringRef name = getValuePackName(symbol, &index);
450 
451   auto it = symbolInfoMap.find(name.str());
452   if (it == symbolInfoMap.end()) {
453     auto error = formatv("referencing unbound symbol '{0}'", symbol);
454     PrintFatalError(loc, error);
455   }
456 
457   return it->second.getValueAndRangeUse(name, index, fmt, separator);
458 }
459 
460 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
461                                           const char *separator) const {
462   int index = -1;
463   StringRef name = getValuePackName(symbol, &index);
464 
465   auto it = symbolInfoMap.find(name.str());
466   if (it == symbolInfoMap.end()) {
467     auto error = formatv("referencing unbound symbol '{0}'", symbol);
468     PrintFatalError(loc, error);
469   }
470 
471   return it->second.getAllRangeUse(name, index, fmt, separator);
472 }
473 
474 void SymbolInfoMap::assignUniqueAlternativeNames() {
475   llvm::StringSet<> usedNames;
476 
477   for (auto symbolInfoIt = symbolInfoMap.begin();
478        symbolInfoIt != symbolInfoMap.end();) {
479     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
480     auto startRange = range.first;
481     auto endRange = range.second;
482 
483     auto operandName = symbolInfoIt->first;
484     int startSearchIndex = 0;
485     for (++startRange; startRange != endRange; ++startRange) {
486       // Current operand name is not unique, find a unique one
487       // and set the alternative name.
488       for (int i = startSearchIndex;; ++i) {
489         std::string alternativeName = operandName + std::to_string(i);
490         if (!usedNames.contains(alternativeName) &&
491             symbolInfoMap.count(alternativeName) == 0) {
492           usedNames.insert(alternativeName);
493           startRange->second.alternativeName = alternativeName;
494           startSearchIndex = i + 1;
495 
496           break;
497         }
498       }
499     }
500 
501     symbolInfoIt = endRange;
502   }
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // Pattern
//===----------------------------------------------------------------------===//
508 
509 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
510     : def(*def), recordOpMap(mapper) {}
511 
512 DagNode Pattern::getSourcePattern() const {
513   return DagNode(def.getValueAsDag("sourcePattern"));
514 }
515 
516 int Pattern::getNumResultPatterns() const {
517   auto *results = def.getValueAsListInit("resultPatterns");
518   return results->size();
519 }
520 
521 DagNode Pattern::getResultPattern(unsigned index) const {
522   auto *results = def.getValueAsListInit("resultPatterns");
523   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
524 }
525 
526 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
527   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
528   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
529   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
530 
531   LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
532   infoMap.assignUniqueAlternativeNames();
533   LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
534 }
535 
536 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
537   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
538   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
539     auto pattern = getResultPattern(i);
540     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
541   }
542   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
543 }
544 
545 const Operator &Pattern::getSourceRootOp() {
546   return getSourcePattern().getDialectOp(recordOpMap);
547 }
548 
549 Operator &Pattern::getDialectOp(DagNode node) {
550   return node.getDialectOp(recordOpMap);
551 }
552 
553 std::vector<AppliedConstraint> Pattern::getConstraints() const {
554   auto *listInit = def.getValueAsListInit("constraints");
555   std::vector<AppliedConstraint> ret;
556   ret.reserve(listInit->size());
557 
558   for (auto it : *listInit) {
559     auto *dagInit = dyn_cast<llvm::DagInit>(it);
560     if (!dagInit)
561       PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
562                                     "constraints should be DAG nodes");
563 
564     std::vector<std::string> entities;
565     entities.reserve(dagInit->arg_size());
566     for (auto *argName : dagInit->getArgNames()) {
567       if (!argName) {
568         PrintFatalError(
569             def.getLoc(),
570             "operands to additional constraints can only be symbol references");
571       }
572       entities.push_back(std::string(argName->getValue()));
573     }
574 
575     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
576                      dagInit->getNameStr(), std::move(entities));
577   }
578   return ret;
579 }
580 
581 int Pattern::getBenefit() const {
582   // The initial benefit value is a heuristic with number of ops in the source
583   // pattern.
584   int initBenefit = getSourcePattern().getNumOps();
585   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
586   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
587     PrintFatalError(def.getLoc(),
588                     "The 'addBenefit' takes and only takes one integer value");
589   }
590   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
591 }
592 
593 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
594   std::vector<std::pair<StringRef, unsigned>> result;
595   result.reserve(def.getLoc().size());
596   for (auto loc : def.getLoc()) {
597     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
598     assert(buf && "invalid source location");
599     result.emplace_back(
600         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
601         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
602   }
603   return result;
604 }
605 
606 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
607                                   bool isSrcPattern) {
608   auto treeName = tree.getSymbol();
609   if (!tree.isOperation()) {
610     if (!treeName.empty()) {
611       PrintFatalError(
612           def.getLoc(),
613           formatv("binding symbol '{0}' to non-operation unsupported right now",
614                   treeName));
615     }
616     return;
617   }
618 
619   auto &op = getDialectOp(tree);
620   auto numOpArgs = op.getNumArgs();
621   auto numTreeArgs = tree.getNumArgs();
622 
623   // The pattern might have the last argument specifying the location.
624   bool hasLocDirective = false;
625   if (numTreeArgs != 0) {
626     if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
627       hasLocDirective = lastArg.isLocationDirective();
628   }
629 
630   if (numOpArgs != numTreeArgs - hasLocDirective) {
631     auto err = formatv("op '{0}' argument number mismatch: "
632                        "{1} in pattern vs. {2} in definition",
633                        op.getOperationName(), numTreeArgs, numOpArgs);
634     PrintFatalError(def.getLoc(), err);
635   }
636 
637   // The name attached to the DAG node's operator is for representing the
638   // results generated from this op. It should be remembered as bound results.
639   if (!treeName.empty()) {
640     LLVM_DEBUG(llvm::dbgs()
641                << "found symbol bound to op result: " << treeName << '\n');
642     if (!infoMap.bindOpResult(treeName, op))
643       PrintFatalError(def.getLoc(),
644                       formatv("symbol '{0}' bound more than once", treeName));
645   }
646 
647   for (int i = 0; i != numTreeArgs; ++i) {
648     if (auto treeArg = tree.getArgAsNestedDag(i)) {
649       // This DAG node argument is a DAG node itself. Go inside recursively.
650       collectBoundSymbols(treeArg, infoMap, isSrcPattern);
651     } else if (isSrcPattern) {
652       // We can only bind symbols to op arguments in source pattern. Those
653       // symbols are referenced in result patterns.
654       auto treeArgName = tree.getArgName(i);
655       // `$_` is a special symbol meaning ignore the current argument.
656       if (!treeArgName.empty() && treeArgName != "_") {
657         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
658                                 << treeArgName << '\n');
659         if (!infoMap.bindOpArgument(treeArgName, op, i)) {
660           auto err = formatv("symbol '{0}' bound more than once", treeArgName);
661           PrintFatalError(def.getLoc(), err);
662         }
663       }
664     }
665   }
666 }
667