1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23 
24 using namespace mlir;
25 using namespace tblgen;
26 
27 using llvm::formatv;
28 
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32 
33 bool DagLeaf::isUnspecified() const {
34   return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36 
37 bool DagLeaf::isOperandMatcher() const {
38   // Operand matchers specify a type constraint.
39   return isSubClassOf("TypeConstraint");
40 }
41 
42 bool DagLeaf::isAttrMatcher() const {
43   // Attribute matchers specify an attribute constraint.
44   return isSubClassOf("AttrConstraint");
45 }
46 
47 bool DagLeaf::isNativeCodeCall() const {
48   return isSubClassOf("NativeCodeCall");
49 }
50 
51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52 
53 bool DagLeaf::isEnumAttrCase() const {
54   return isSubClassOf("EnumAttrCaseInfo");
55 }
56 
57 bool DagLeaf::isStringAttr() const {
58   return isa<llvm::StringInit>(def);
59 }
60 
61 Constraint DagLeaf::getAsConstraint() const {
62   assert((isOperandMatcher() || isAttrMatcher()) &&
63          "the DAG leaf must be operand or attribute");
64   return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66 
67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69   return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71 
72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74   return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76 
77 std::string DagLeaf::getConditionTemplate() const {
78   return getAsConstraint().getConditionTemplate();
79 }
80 
81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85 
86 std::string DagLeaf::getStringAttr() const {
87   assert(isStringAttr() && "the DAG leaf must be string attribute");
88   return def->getAsUnquotedString();
89 }
90 bool DagLeaf::isSubClassOf(StringRef superclass) const {
91   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
92     return defInit->getDef()->isSubClassOf(superclass);
93   return false;
94 }
95 
96 void DagLeaf::print(raw_ostream &os) const {
97   if (def)
98     def->print(os);
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // DagNode
103 //===----------------------------------------------------------------------===//
104 
105 bool DagNode::isNativeCodeCall() const {
106   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
107     return defInit->getDef()->isSubClassOf("NativeCodeCall");
108   return false;
109 }
110 
111 bool DagNode::isOperation() const {
112   return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
113 }
114 
115 llvm::StringRef DagNode::getNativeCodeTemplate() const {
116   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
117   return cast<llvm::DefInit>(node->getOperator())
118       ->getDef()
119       ->getValueAsString("expression");
120 }
121 
122 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
123 
124 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
125   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
126   auto it = mapper->find(opDef);
127   if (it != mapper->end())
128     return *it->second;
129   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
130               .first->second;
131 }
132 
133 int DagNode::getNumOps() const {
134   int count = isReplaceWithValue() ? 0 : 1;
135   for (int i = 0, e = getNumArgs(); i != e; ++i) {
136     if (auto child = getArgAsNestedDag(i))
137       count += child.getNumOps();
138   }
139   return count;
140 }
141 
142 int DagNode::getNumArgs() const { return node->getNumArgs(); }
143 
144 bool DagNode::isNestedDagArg(unsigned index) const {
145   return isa<llvm::DagInit>(node->getArg(index));
146 }
147 
148 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
149   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
150 }
151 
152 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
153   assert(!isNestedDagArg(index));
154   return DagLeaf(node->getArg(index));
155 }
156 
157 StringRef DagNode::getArgName(unsigned index) const {
158   return node->getArgNameStr(index);
159 }
160 
161 bool DagNode::isReplaceWithValue() const {
162   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
163   return dagOpDef->getName() == "replaceWithValue";
164 }
165 
166 bool DagNode::isLocationDirective() const {
167   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
168   return dagOpDef->getName() == "location";
169 }
170 
171 void DagNode::print(raw_ostream &os) const {
172   if (node)
173     node->print(os);
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // SymbolInfoMap
178 //===----------------------------------------------------------------------===//
179 
180 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
181   StringRef name, indexStr;
182   int idx = -1;
183   std::tie(name, indexStr) = symbol.rsplit("__");
184 
185   if (indexStr.consumeInteger(10, idx)) {
186     // The second part is not an index; we return the whole symbol as-is.
187     return symbol;
188   }
189   if (index) {
190     *index = idx;
191   }
192   return name;
193 }
194 
195 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
196                                       Optional<DagAndIndex> dagAndIndex)
197     : op(op), kind(kind), dagAndIndex(dagAndIndex) {}
198 
199 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
200   switch (kind) {
201   case Kind::Attr:
202   case Kind::Operand:
203   case Kind::Value:
204     return 1;
205   case Kind::Result:
206     return op->getNumResults();
207   }
208   llvm_unreachable("unknown kind");
209 }
210 
211 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
212   return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
213 }
214 
215 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
216   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
217   switch (kind) {
218   case Kind::Attr: {
219     if (op) {
220       auto type = op->getArg((*dagAndIndex).second)
221                       .get<NamedAttribute *>()
222                       ->attr.getStorageType();
223       return std::string(formatv("{0} {1};\n", type, name));
224     }
225     // TODO(suderman): Use a more exact type when available.
226     return std::string(formatv("Attribute {0};\n", name));
227   }
228   case Kind::Operand: {
229     // Use operand range for captured operands (to support potential variadic
230     // operands).
231     return std::string(
232         formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
233                 getVarName(name)));
234   }
235   case Kind::Value: {
236     return std::string(formatv("::mlir::Value {0};\n", name));
237   }
238   case Kind::Result: {
239     // Use the op itself for captured results.
240     return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
241   }
242   }
243   llvm_unreachable("unknown kind");
244 }
245 
246 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
247     StringRef name, int index, const char *fmt, const char *separator) const {
248   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
249   switch (kind) {
250   case Kind::Attr: {
251     assert(index < 0);
252     auto repl = formatv(fmt, name);
253     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
254     return std::string(repl);
255   }
256   case Kind::Operand: {
257     assert(index < 0);
258     auto *operand =
259         op->getArg((*dagAndIndex).second).get<NamedTypeConstraint *>();
260     // If this operand is variadic, then return a range. Otherwise, return the
261     // value itself.
262     if (operand->isVariableLength()) {
263       auto repl = formatv(fmt, name);
264       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
265       return std::string(repl);
266     }
267     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
268     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
269     return std::string(repl);
270   }
271   case Kind::Result: {
272     // If `index` is greater than zero, then we are referencing a specific
273     // result of a multi-result op. The result can still be variadic.
274     if (index >= 0) {
275       std::string v =
276           std::string(formatv("{0}.getODSResults({1})", name, index));
277       if (!op->getResult(index).isVariadic())
278         v = std::string(formatv("(*{0}.begin())", v));
279       auto repl = formatv(fmt, v);
280       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
281       return std::string(repl);
282     }
283 
284     // If this op has no result at all but still we bind a symbol to it, it
285     // means we want to capture the op itself.
286     if (op->getNumResults() == 0) {
287       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
288       return std::string(name);
289     }
290 
291     // We are referencing all results of the multi-result op. A specific result
292     // can either be a value or a range. Then join them with `separator`.
293     SmallVector<std::string, 4> values;
294     values.reserve(op->getNumResults());
295 
296     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
297       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
298       if (!op->getResult(i).isVariadic()) {
299         v = std::string(formatv("(*{0}.begin())", v));
300       }
301       values.push_back(std::string(formatv(fmt, v)));
302     }
303     auto repl = llvm::join(values, separator);
304     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
305     return repl;
306   }
307   case Kind::Value: {
308     assert(index < 0);
309     assert(op == nullptr);
310     auto repl = formatv(fmt, name);
311     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
312     return std::string(repl);
313   }
314   }
315   llvm_unreachable("unknown kind");
316 }
317 
318 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
319     StringRef name, int index, const char *fmt, const char *separator) const {
320   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
321   switch (kind) {
322   case Kind::Attr:
323   case Kind::Operand: {
324     assert(index < 0 && "only allowed for symbol bound to result");
325     auto repl = formatv(fmt, name);
326     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
327     return std::string(repl);
328   }
329   case Kind::Result: {
330     if (index >= 0) {
331       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
332       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
333       return std::string(repl);
334     }
335 
336     // We are referencing all results of the multi-result op. Each result should
337     // have a value range, and then join them with `separator`.
338     SmallVector<std::string, 4> values;
339     values.reserve(op->getNumResults());
340 
341     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
342       values.push_back(std::string(
343           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
344     }
345     auto repl = llvm::join(values, separator);
346     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
347     return repl;
348   }
349   case Kind::Value: {
350     assert(index < 0 && "only allowed for symbol bound to result");
351     assert(op == nullptr);
352     auto repl = formatv(fmt, formatv("{{{0}}", name));
353     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
354     return std::string(repl);
355   }
356   }
357   llvm_unreachable("unknown kind");
358 }
359 
360 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
361                                    const Operator &op, int argIndex) {
362   StringRef name = getValuePackName(symbol);
363   if (name != symbol) {
364     auto error = formatv(
365         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
366     PrintFatalError(loc, error);
367   }
368 
369   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
370                      ? SymbolInfo::getAttr(&op, argIndex)
371                      : SymbolInfo::getOperand(node, &op, argIndex);
372 
373   std::string key = symbol.str();
374   if (symbolInfoMap.count(key)) {
375     // Only non unique name for the operand is supported.
376     if (symInfo.kind != SymbolInfo::Kind::Operand) {
377       return false;
378     }
379 
380     // Cannot add new operand if there is already non operand with the same
381     // name.
382     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
383       return false;
384     }
385   }
386 
387   symbolInfoMap.emplace(key, symInfo);
388   return true;
389 }
390 
391 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
392   std::string name = getValuePackName(symbol).str();
393   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
394 
395   return symbolInfoMap.count(inserted->first) == 1;
396 }
397 
398 bool SymbolInfoMap::bindValue(StringRef symbol) {
399   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
400   return symbolInfoMap.count(inserted->first) == 1;
401 }
402 
403 bool SymbolInfoMap::bindAttr(StringRef symbol) {
404   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
405   return symbolInfoMap.count(inserted->first) == 1;
406 }
407 
408 bool SymbolInfoMap::contains(StringRef symbol) const {
409   return find(symbol) != symbolInfoMap.end();
410 }
411 
412 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
413   std::string name = getValuePackName(key).str();
414 
415   return symbolInfoMap.find(name);
416 }
417 
418 SymbolInfoMap::const_iterator
419 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
420                                int argIndex) const {
421   std::string name = getValuePackName(key).str();
422   auto range = symbolInfoMap.equal_range(name);
423 
424   const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
425 
426   for (auto it = range.first; it != range.second; ++it) {
427     if (it->second.dagAndIndex == symbolInfo.dagAndIndex) {
428       return it;
429     }
430   }
431 
432   return symbolInfoMap.end();
433 }
434 
435 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
436 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
437   std::string name = getValuePackName(key).str();
438 
439   return symbolInfoMap.equal_range(name);
440 }
441 
442 int SymbolInfoMap::count(StringRef key) const {
443   std::string name = getValuePackName(key).str();
444   return symbolInfoMap.count(name);
445 }
446 
447 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
448   StringRef name = getValuePackName(symbol);
449   if (name != symbol) {
450     // If there is a trailing index inside symbol, it references just one
451     // static value.
452     return 1;
453   }
454   // Otherwise, find how many it represents by querying the symbol's info.
455   return find(name)->second.getStaticValueCount();
456 }
457 
458 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
459                                                const char *fmt,
460                                                const char *separator) const {
461   int index = -1;
462   StringRef name = getValuePackName(symbol, &index);
463 
464   auto it = symbolInfoMap.find(name.str());
465   if (it == symbolInfoMap.end()) {
466     auto error = formatv("referencing unbound symbol '{0}'", symbol);
467     PrintFatalError(loc, error);
468   }
469 
470   return it->second.getValueAndRangeUse(name, index, fmt, separator);
471 }
472 
473 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
474                                           const char *separator) const {
475   int index = -1;
476   StringRef name = getValuePackName(symbol, &index);
477 
478   auto it = symbolInfoMap.find(name.str());
479   if (it == symbolInfoMap.end()) {
480     auto error = formatv("referencing unbound symbol '{0}'", symbol);
481     PrintFatalError(loc, error);
482   }
483 
484   return it->second.getAllRangeUse(name, index, fmt, separator);
485 }
486 
487 void SymbolInfoMap::assignUniqueAlternativeNames() {
488   llvm::StringSet<> usedNames;
489 
490   for (auto symbolInfoIt = symbolInfoMap.begin();
491        symbolInfoIt != symbolInfoMap.end();) {
492     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
493     auto startRange = range.first;
494     auto endRange = range.second;
495 
496     auto operandName = symbolInfoIt->first;
497     int startSearchIndex = 0;
498     for (++startRange; startRange != endRange; ++startRange) {
499       // Current operand name is not unique, find a unique one
500       // and set the alternative name.
501       for (int i = startSearchIndex;; ++i) {
502         std::string alternativeName = operandName + std::to_string(i);
503         if (!usedNames.contains(alternativeName) &&
504             symbolInfoMap.count(alternativeName) == 0) {
505           usedNames.insert(alternativeName);
506           startRange->second.alternativeName = alternativeName;
507           startSearchIndex = i + 1;
508 
509           break;
510         }
511       }
512     }
513 
514     symbolInfoIt = endRange;
515   }
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // Pattern
520 //==----------------------------------------------------------------------===//
521 
522 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
523     : def(*def), recordOpMap(mapper) {}
524 
525 DagNode Pattern::getSourcePattern() const {
526   return DagNode(def.getValueAsDag("sourcePattern"));
527 }
528 
529 int Pattern::getNumResultPatterns() const {
530   auto *results = def.getValueAsListInit("resultPatterns");
531   return results->size();
532 }
533 
534 DagNode Pattern::getResultPattern(unsigned index) const {
535   auto *results = def.getValueAsListInit("resultPatterns");
536   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
537 }
538 
539 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
540   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
541   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
542   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
543 
544   LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
545   infoMap.assignUniqueAlternativeNames();
546   LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
547 }
548 
549 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
550   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
551   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
552     auto pattern = getResultPattern(i);
553     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
554   }
555   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
556 }
557 
558 const Operator &Pattern::getSourceRootOp() {
559   return getSourcePattern().getDialectOp(recordOpMap);
560 }
561 
562 Operator &Pattern::getDialectOp(DagNode node) {
563   return node.getDialectOp(recordOpMap);
564 }
565 
566 std::vector<AppliedConstraint> Pattern::getConstraints() const {
567   auto *listInit = def.getValueAsListInit("constraints");
568   std::vector<AppliedConstraint> ret;
569   ret.reserve(listInit->size());
570 
571   for (auto it : *listInit) {
572     auto *dagInit = dyn_cast<llvm::DagInit>(it);
573     if (!dagInit)
574       PrintFatalError(&def, "all elements in Pattern multi-entity "
575                             "constraints should be DAG nodes");
576 
577     std::vector<std::string> entities;
578     entities.reserve(dagInit->arg_size());
579     for (auto *argName : dagInit->getArgNames()) {
580       if (!argName) {
581         PrintFatalError(
582             &def,
583             "operands to additional constraints can only be symbol references");
584       }
585       entities.push_back(std::string(argName->getValue()));
586     }
587 
588     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
589                      dagInit->getNameStr(), std::move(entities));
590   }
591   return ret;
592 }
593 
594 int Pattern::getBenefit() const {
595   // The initial benefit value is a heuristic with number of ops in the source
596   // pattern.
597   int initBenefit = getSourcePattern().getNumOps();
598   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
599   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
600     PrintFatalError(&def,
601                     "The 'addBenefit' takes and only takes one integer value");
602   }
603   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
604 }
605 
606 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
607   std::vector<std::pair<StringRef, unsigned>> result;
608   result.reserve(def.getLoc().size());
609   for (auto loc : def.getLoc()) {
610     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
611     assert(buf && "invalid source location");
612     result.emplace_back(
613         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
614         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
615   }
616   return result;
617 }
618 
619 void Pattern::verifyBind(bool result, StringRef symbolName) {
620   if (!result) {
621     auto err = formatv("symbol '{0}' bound more than once", symbolName);
622     PrintFatalError(&def, err);
623   }
624 }
625 
626 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
627                                   bool isSrcPattern) {
628   auto treeName = tree.getSymbol();
629   auto numTreeArgs = tree.getNumArgs();
630 
631   if (tree.isNativeCodeCall()) {
632     if (!treeName.empty()) {
633       if (!isSrcPattern) {
634         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
635                                 << treeName << '\n');
636         verifyBind(infoMap.bindValue(treeName), treeName);
637       } else {
638         PrintFatalError(&def,
639                         formatv("binding symbol '{0}' to NativecodeCall in "
640                                 "MatchPattern is not supported",
641                                 treeName));
642       }
643     }
644 
645     for (int i = 0; i != numTreeArgs; ++i) {
646       if (auto treeArg = tree.getArgAsNestedDag(i)) {
647         // This DAG node argument is a DAG node itself. Go inside recursively.
648         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
649         continue;
650       }
651 
652       if (!isSrcPattern)
653         continue;
654 
655       // We can only bind symbols to arguments in source pattern. Those
656       // symbols are referenced in result patterns.
657       auto treeArgName = tree.getArgName(i);
658 
659       // `$_` is a special symbol meaning ignore the current argument.
660       if (!treeArgName.empty() && treeArgName != "_") {
661         DagLeaf leaf = tree.getArgAsLeaf(i);
662 
663         // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
664         if (leaf.isUnspecified()) {
665           // This is case of $c, a Value without any constraints.
666           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
667         } else {
668           auto constraint = leaf.getAsConstraint();
669           bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
670                         leaf.isConstantAttr() ||
671                         constraint.getKind() == Constraint::Kind::CK_Attr;
672 
673           if (isAttr) {
674             // This is case of $a, a binding to a certain attribute.
675             verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
676             continue;
677           }
678 
679           // This is case of $b, a binding to a certain type.
680           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
681         }
682       }
683     }
684 
685     return;
686   }
687 
688   if (tree.isOperation()) {
689     auto &op = getDialectOp(tree);
690     auto numOpArgs = op.getNumArgs();
691 
692     // The pattern might have the last argument specifying the location.
693     bool hasLocDirective = false;
694     if (numTreeArgs != 0) {
695       if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
696         hasLocDirective = lastArg.isLocationDirective();
697     }
698 
699     if (numOpArgs != numTreeArgs - hasLocDirective) {
700       auto err = formatv("op '{0}' argument number mismatch: "
701                          "{1} in pattern vs. {2} in definition",
702                          op.getOperationName(), numTreeArgs, numOpArgs);
703       PrintFatalError(&def, err);
704     }
705 
706     // The name attached to the DAG node's operator is for representing the
707     // results generated from this op. It should be remembered as bound results.
708     if (!treeName.empty()) {
709       LLVM_DEBUG(llvm::dbgs()
710                  << "found symbol bound to op result: " << treeName << '\n');
711       verifyBind(infoMap.bindOpResult(treeName, op), treeName);
712     }
713 
714     for (int i = 0; i != numTreeArgs; ++i) {
715       if (auto treeArg = tree.getArgAsNestedDag(i)) {
716         // This DAG node argument is a DAG node itself. Go inside recursively.
717         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
718         continue;
719       }
720 
721       if (isSrcPattern) {
722         // We can only bind symbols to op arguments in source pattern. Those
723         // symbols are referenced in result patterns.
724         auto treeArgName = tree.getArgName(i);
725         // `$_` is a special symbol meaning ignore the current argument.
726         if (!treeArgName.empty() && treeArgName != "_") {
727           LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
728                                   << treeArgName << '\n');
729           verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
730                      treeArgName);
731         }
732       }
733     }
734     return;
735   }
736 
737   if (!treeName.empty()) {
738     PrintFatalError(
739         &def, formatv("binding symbol '{0}' to non-operation/native code call "
740                       "unsupported right now",
741                       treeName));
742   }
743   return;
744 }
745