1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23 
24 using namespace mlir;
25 using namespace tblgen;
26 
27 using llvm::formatv;
28 
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32 
33 bool DagLeaf::isUnspecified() const {
34   return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36 
37 bool DagLeaf::isOperandMatcher() const {
38   // Operand matchers specify a type constraint.
39   return isSubClassOf("TypeConstraint");
40 }
41 
42 bool DagLeaf::isAttrMatcher() const {
43   // Attribute matchers specify an attribute constraint.
44   return isSubClassOf("AttrConstraint");
45 }
46 
47 bool DagLeaf::isNativeCodeCall() const {
48   return isSubClassOf("NativeCodeCall");
49 }
50 
51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52 
53 bool DagLeaf::isEnumAttrCase() const {
54   return isSubClassOf("EnumAttrCaseInfo");
55 }
56 
57 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
58 
59 Constraint DagLeaf::getAsConstraint() const {
60   assert((isOperandMatcher() || isAttrMatcher()) &&
61          "the DAG leaf must be operand or attribute");
62   return Constraint(cast<llvm::DefInit>(def)->getDef());
63 }
64 
65 ConstantAttr DagLeaf::getAsConstantAttr() const {
66   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
67   return ConstantAttr(cast<llvm::DefInit>(def));
68 }
69 
70 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
71   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
72   return EnumAttrCase(cast<llvm::DefInit>(def));
73 }
74 
75 std::string DagLeaf::getConditionTemplate() const {
76   return getAsConstraint().getConditionTemplate();
77 }
78 
79 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
80   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
81   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
82 }
83 
84 int DagLeaf::getNumReturnsOfNativeCode() const {
85   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
86   return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
87 }
88 
89 std::string DagLeaf::getStringAttr() const {
90   assert(isStringAttr() && "the DAG leaf must be string attribute");
91   return def->getAsUnquotedString();
92 }
93 bool DagLeaf::isSubClassOf(StringRef superclass) const {
94   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
95     return defInit->getDef()->isSubClassOf(superclass);
96   return false;
97 }
98 
99 void DagLeaf::print(raw_ostream &os) const {
100   if (def)
101     def->print(os);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // DagNode
106 //===----------------------------------------------------------------------===//
107 
108 bool DagNode::isNativeCodeCall() const {
109   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
110     return defInit->getDef()->isSubClassOf("NativeCodeCall");
111   return false;
112 }
113 
114 bool DagNode::isOperation() const {
115   return !isNativeCodeCall() && !isReplaceWithValue() &&
116          !isLocationDirective() && !isReturnTypeDirective();
117 }
118 
119 llvm::StringRef DagNode::getNativeCodeTemplate() const {
120   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
121   return cast<llvm::DefInit>(node->getOperator())
122       ->getDef()
123       ->getValueAsString("expression");
124 }
125 
126 int DagNode::getNumReturnsOfNativeCode() const {
127   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
128   return cast<llvm::DefInit>(node->getOperator())
129       ->getDef()
130       ->getValueAsInt("numReturns");
131 }
132 
133 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
134 
135 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
136   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
137   auto it = mapper->find(opDef);
138   if (it != mapper->end())
139     return *it->second;
140   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
141               .first->second;
142 }
143 
144 int DagNode::getNumOps() const {
145   int count = isReplaceWithValue() ? 0 : 1;
146   for (int i = 0, e = getNumArgs(); i != e; ++i) {
147     if (auto child = getArgAsNestedDag(i))
148       count += child.getNumOps();
149   }
150   return count;
151 }
152 
153 int DagNode::getNumArgs() const { return node->getNumArgs(); }
154 
155 bool DagNode::isNestedDagArg(unsigned index) const {
156   return isa<llvm::DagInit>(node->getArg(index));
157 }
158 
159 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
160   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
161 }
162 
163 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
164   assert(!isNestedDagArg(index));
165   return DagLeaf(node->getArg(index));
166 }
167 
168 StringRef DagNode::getArgName(unsigned index) const {
169   return node->getArgNameStr(index);
170 }
171 
172 bool DagNode::isReplaceWithValue() const {
173   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
174   return dagOpDef->getName() == "replaceWithValue";
175 }
176 
177 bool DagNode::isLocationDirective() const {
178   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
179   return dagOpDef->getName() == "location";
180 }
181 
182 bool DagNode::isReturnTypeDirective() const {
183   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
184   return dagOpDef->getName() == "returnType";
185 }
186 
187 void DagNode::print(raw_ostream &os) const {
188   if (node)
189     node->print(os);
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // SymbolInfoMap
194 //===----------------------------------------------------------------------===//
195 
196 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
197   StringRef name, indexStr;
198   int idx = -1;
199   std::tie(name, indexStr) = symbol.rsplit("__");
200 
201   if (indexStr.consumeInteger(10, idx)) {
202     // The second part is not an index; we return the whole symbol as-is.
203     return symbol;
204   }
205   if (index) {
206     *index = idx;
207   }
208   return name;
209 }
210 
211 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
212                                       Optional<DagAndConstant> dagAndConstant)
213     : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
214 
215 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
216   switch (kind) {
217   case Kind::Attr:
218   case Kind::Operand:
219   case Kind::Value:
220     return 1;
221   case Kind::Result:
222     return op->getNumResults();
223   case Kind::MultipleValues:
224     return getSize();
225   }
226   llvm_unreachable("unknown kind");
227 }
228 
229 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
230   return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
231 }
232 
233 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
234   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
235   switch (kind) {
236   case Kind::Attr: {
237     if (op) {
238       auto type = op->getArg(getArgIndex())
239                       .get<NamedAttribute *>()
240                       ->attr.getStorageType();
241       return std::string(formatv("{0} {1};\n", type, name));
242     }
243     // TODO(suderman): Use a more exact type when available.
244     return std::string(formatv("Attribute {0};\n", name));
245   }
246   case Kind::Operand: {
247     // Use operand range for captured operands (to support potential variadic
248     // operands).
249     return std::string(
250         formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
251                 getVarName(name)));
252   }
253   case Kind::Value: {
254     return std::string(formatv("::mlir::Value {0};\n", name));
255   }
256   case Kind::MultipleValues: {
257     // This is for the variable used in the source pattern. Each named value in
258     // source pattern will only be bound to a Value. The others in the result
259     // pattern may be associated with multiple Values as we will use `auto` to
260     // do the type inference.
261     return std::string(formatv(
262         "::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
263   }
264   case Kind::Result: {
265     // Use the op itself for captured results.
266     return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
267   }
268   }
269   llvm_unreachable("unknown kind");
270 }
271 
272 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
273     StringRef name, int index, const char *fmt, const char *separator) const {
274   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
275   switch (kind) {
276   case Kind::Attr: {
277     assert(index < 0);
278     auto repl = formatv(fmt, name);
279     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
280     return std::string(repl);
281   }
282   case Kind::Operand: {
283     assert(index < 0);
284     auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
285     // If this operand is variadic, then return a range. Otherwise, return the
286     // value itself.
287     if (operand->isVariableLength()) {
288       auto repl = formatv(fmt, name);
289       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
290       return std::string(repl);
291     }
292     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
293     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
294     return std::string(repl);
295   }
296   case Kind::Result: {
297     // If `index` is greater than zero, then we are referencing a specific
298     // result of a multi-result op. The result can still be variadic.
299     if (index >= 0) {
300       std::string v =
301           std::string(formatv("{0}.getODSResults({1})", name, index));
302       if (!op->getResult(index).isVariadic())
303         v = std::string(formatv("(*{0}.begin())", v));
304       auto repl = formatv(fmt, v);
305       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
306       return std::string(repl);
307     }
308 
309     // If this op has no result at all but still we bind a symbol to it, it
310     // means we want to capture the op itself.
311     if (op->getNumResults() == 0) {
312       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
313       return std::string(name);
314     }
315 
316     // We are referencing all results of the multi-result op. A specific result
317     // can either be a value or a range. Then join them with `separator`.
318     SmallVector<std::string, 4> values;
319     values.reserve(op->getNumResults());
320 
321     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
322       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
323       if (!op->getResult(i).isVariadic()) {
324         v = std::string(formatv("(*{0}.begin())", v));
325       }
326       values.push_back(std::string(formatv(fmt, v)));
327     }
328     auto repl = llvm::join(values, separator);
329     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
330     return repl;
331   }
332   case Kind::Value: {
333     assert(index < 0);
334     assert(op == nullptr);
335     auto repl = formatv(fmt, name);
336     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
337     return std::string(repl);
338   }
339   case Kind::MultipleValues: {
340     assert(op == nullptr);
341     assert(index < getSize());
342     if (index >= 0) {
343       std::string repl =
344           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
345       LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
346       return repl;
347     }
348     // If it doesn't specify certain element, unpack them all.
349     auto repl =
350         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
351     LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
352     return std::string(repl);
353   }
354   }
355   llvm_unreachable("unknown kind");
356 }
357 
358 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
359     StringRef name, int index, const char *fmt, const char *separator) const {
360   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
361   switch (kind) {
362   case Kind::Attr:
363   case Kind::Operand: {
364     assert(index < 0 && "only allowed for symbol bound to result");
365     auto repl = formatv(fmt, name);
366     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
367     return std::string(repl);
368   }
369   case Kind::Result: {
370     if (index >= 0) {
371       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
372       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
373       return std::string(repl);
374     }
375 
376     // We are referencing all results of the multi-result op. Each result should
377     // have a value range, and then join them with `separator`.
378     SmallVector<std::string, 4> values;
379     values.reserve(op->getNumResults());
380 
381     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
382       values.push_back(std::string(
383           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
384     }
385     auto repl = llvm::join(values, separator);
386     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
387     return repl;
388   }
389   case Kind::Value: {
390     assert(index < 0 && "only allowed for symbol bound to result");
391     assert(op == nullptr);
392     auto repl = formatv(fmt, formatv("{{{0}}", name));
393     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
394     return std::string(repl);
395   }
396   case Kind::MultipleValues: {
397     assert(op == nullptr);
398     assert(index < getSize());
399     if (index >= 0) {
400       std::string repl =
401           formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
402       LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
403       return repl;
404     }
405     auto repl =
406         formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
407     LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
408     return std::string(repl);
409   }
410   }
411   llvm_unreachable("unknown kind");
412 }
413 
414 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
415                                    const Operator &op, int argIndex) {
416   StringRef name = getValuePackName(symbol);
417   if (name != symbol) {
418     auto error = formatv(
419         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
420     PrintFatalError(loc, error);
421   }
422 
423   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
424                      ? SymbolInfo::getAttr(&op, argIndex)
425                      : SymbolInfo::getOperand(node, &op, argIndex);
426 
427   std::string key = symbol.str();
428   if (symbolInfoMap.count(key)) {
429     // Only non unique name for the operand is supported.
430     if (symInfo.kind != SymbolInfo::Kind::Operand) {
431       return false;
432     }
433 
434     // Cannot add new operand if there is already non operand with the same
435     // name.
436     if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
437       return false;
438     }
439   }
440 
441   symbolInfoMap.emplace(key, symInfo);
442   return true;
443 }
444 
445 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
446   std::string name = getValuePackName(symbol).str();
447   auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
448 
449   return symbolInfoMap.count(inserted->first) == 1;
450 }
451 
452 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
453   std::string name = getValuePackName(symbol).str();
454   if (numValues > 1)
455     return bindMultipleValues(name, numValues);
456   return bindValue(name);
457 }
458 
459 bool SymbolInfoMap::bindValue(StringRef symbol) {
460   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
461   return symbolInfoMap.count(inserted->first) == 1;
462 }
463 
464 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
465   std::string name = getValuePackName(symbol).str();
466   auto inserted =
467       symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
468   return symbolInfoMap.count(inserted->first) == 1;
469 }
470 
471 bool SymbolInfoMap::bindAttr(StringRef symbol) {
472   auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
473   return symbolInfoMap.count(inserted->first) == 1;
474 }
475 
476 bool SymbolInfoMap::contains(StringRef symbol) const {
477   return find(symbol) != symbolInfoMap.end();
478 }
479 
480 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
481   std::string name = getValuePackName(key).str();
482 
483   return symbolInfoMap.find(name);
484 }
485 
486 SymbolInfoMap::const_iterator
487 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
488                                int argIndex) const {
489   std::string name = getValuePackName(key).str();
490   auto range = symbolInfoMap.equal_range(name);
491 
492   const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
493 
494   for (auto it = range.first; it != range.second; ++it)
495     if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
496       return it;
497 
498   return symbolInfoMap.end();
499 }
500 
501 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
502 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
503   std::string name = getValuePackName(key).str();
504 
505   return symbolInfoMap.equal_range(name);
506 }
507 
508 int SymbolInfoMap::count(StringRef key) const {
509   std::string name = getValuePackName(key).str();
510   return symbolInfoMap.count(name);
511 }
512 
513 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
514   StringRef name = getValuePackName(symbol);
515   if (name != symbol) {
516     // If there is a trailing index inside symbol, it references just one
517     // static value.
518     return 1;
519   }
520   // Otherwise, find how many it represents by querying the symbol's info.
521   return find(name)->second.getStaticValueCount();
522 }
523 
524 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
525                                                const char *fmt,
526                                                const char *separator) const {
527   int index = -1;
528   StringRef name = getValuePackName(symbol, &index);
529 
530   auto it = symbolInfoMap.find(name.str());
531   if (it == symbolInfoMap.end()) {
532     auto error = formatv("referencing unbound symbol '{0}'", symbol);
533     PrintFatalError(loc, error);
534   }
535 
536   return it->second.getValueAndRangeUse(name, index, fmt, separator);
537 }
538 
539 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
540                                           const char *separator) const {
541   int index = -1;
542   StringRef name = getValuePackName(symbol, &index);
543 
544   auto it = symbolInfoMap.find(name.str());
545   if (it == symbolInfoMap.end()) {
546     auto error = formatv("referencing unbound symbol '{0}'", symbol);
547     PrintFatalError(loc, error);
548   }
549 
550   return it->second.getAllRangeUse(name, index, fmt, separator);
551 }
552 
553 void SymbolInfoMap::assignUniqueAlternativeNames() {
554   llvm::StringSet<> usedNames;
555 
556   for (auto symbolInfoIt = symbolInfoMap.begin();
557        symbolInfoIt != symbolInfoMap.end();) {
558     auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
559     auto startRange = range.first;
560     auto endRange = range.second;
561 
562     auto operandName = symbolInfoIt->first;
563     int startSearchIndex = 0;
564     for (++startRange; startRange != endRange; ++startRange) {
565       // Current operand name is not unique, find a unique one
566       // and set the alternative name.
567       for (int i = startSearchIndex;; ++i) {
568         std::string alternativeName = operandName + std::to_string(i);
569         if (!usedNames.contains(alternativeName) &&
570             symbolInfoMap.count(alternativeName) == 0) {
571           usedNames.insert(alternativeName);
572           startRange->second.alternativeName = alternativeName;
573           startSearchIndex = i + 1;
574 
575           break;
576         }
577       }
578     }
579 
580     symbolInfoIt = endRange;
581   }
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // Pattern
//===----------------------------------------------------------------------===//
587 
588 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
589     : def(*def), recordOpMap(mapper) {}
590 
591 DagNode Pattern::getSourcePattern() const {
592   return DagNode(def.getValueAsDag("sourcePattern"));
593 }
594 
595 int Pattern::getNumResultPatterns() const {
596   auto *results = def.getValueAsListInit("resultPatterns");
597   return results->size();
598 }
599 
600 DagNode Pattern::getResultPattern(unsigned index) const {
601   auto *results = def.getValueAsListInit("resultPatterns");
602   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
603 }
604 
605 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
606   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
607   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
608   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
609 
610   LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
611   infoMap.assignUniqueAlternativeNames();
612   LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
613 }
614 
615 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
616   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
617   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
618     auto pattern = getResultPattern(i);
619     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
620   }
621   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
622 }
623 
624 const Operator &Pattern::getSourceRootOp() {
625   return getSourcePattern().getDialectOp(recordOpMap);
626 }
627 
628 Operator &Pattern::getDialectOp(DagNode node) {
629   return node.getDialectOp(recordOpMap);
630 }
631 
632 std::vector<AppliedConstraint> Pattern::getConstraints() const {
633   auto *listInit = def.getValueAsListInit("constraints");
634   std::vector<AppliedConstraint> ret;
635   ret.reserve(listInit->size());
636 
637   for (auto it : *listInit) {
638     auto *dagInit = dyn_cast<llvm::DagInit>(it);
639     if (!dagInit)
640       PrintFatalError(&def, "all elements in Pattern multi-entity "
641                             "constraints should be DAG nodes");
642 
643     std::vector<std::string> entities;
644     entities.reserve(dagInit->arg_size());
645     for (auto *argName : dagInit->getArgNames()) {
646       if (!argName) {
647         PrintFatalError(
648             &def,
649             "operands to additional constraints can only be symbol references");
650       }
651       entities.push_back(std::string(argName->getValue()));
652     }
653 
654     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
655                      dagInit->getNameStr(), std::move(entities));
656   }
657   return ret;
658 }
659 
660 int Pattern::getBenefit() const {
661   // The initial benefit value is a heuristic with number of ops in the source
662   // pattern.
663   int initBenefit = getSourcePattern().getNumOps();
664   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
665   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
666     PrintFatalError(&def,
667                     "The 'addBenefit' takes and only takes one integer value");
668   }
669   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
670 }
671 
672 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
673   std::vector<std::pair<StringRef, unsigned>> result;
674   result.reserve(def.getLoc().size());
675   for (auto loc : def.getLoc()) {
676     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
677     assert(buf && "invalid source location");
678     result.emplace_back(
679         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
680         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
681   }
682   return result;
683 }
684 
685 void Pattern::verifyBind(bool result, StringRef symbolName) {
686   if (!result) {
687     auto err = formatv("symbol '{0}' bound more than once", symbolName);
688     PrintFatalError(&def, err);
689   }
690 }
691 
692 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
693                                   bool isSrcPattern) {
694   auto treeName = tree.getSymbol();
695   auto numTreeArgs = tree.getNumArgs();
696 
697   if (tree.isNativeCodeCall()) {
698     if (!treeName.empty()) {
699       if (!isSrcPattern) {
700         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
701                                 << treeName << '\n');
702         verifyBind(
703             infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
704             treeName);
705       } else {
706         PrintFatalError(&def,
707                         formatv("binding symbol '{0}' to NativecodeCall in "
708                                 "MatchPattern is not supported",
709                                 treeName));
710       }
711     }
712 
713     for (int i = 0; i != numTreeArgs; ++i) {
714       if (auto treeArg = tree.getArgAsNestedDag(i)) {
715         // This DAG node argument is a DAG node itself. Go inside recursively.
716         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
717         continue;
718       }
719 
720       if (!isSrcPattern)
721         continue;
722 
723       // We can only bind symbols to arguments in source pattern. Those
724       // symbols are referenced in result patterns.
725       auto treeArgName = tree.getArgName(i);
726 
727       // `$_` is a special symbol meaning ignore the current argument.
728       if (!treeArgName.empty() && treeArgName != "_") {
729         DagLeaf leaf = tree.getArgAsLeaf(i);
730 
731         // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
732         if (leaf.isUnspecified()) {
733           // This is case of $c, a Value without any constraints.
734           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
735         } else {
736           auto constraint = leaf.getAsConstraint();
737           bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
738                         leaf.isConstantAttr() ||
739                         constraint.getKind() == Constraint::Kind::CK_Attr;
740 
741           if (isAttr) {
742             // This is case of $a, a binding to a certain attribute.
743             verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
744             continue;
745           }
746 
747           // This is case of $b, a binding to a certain type.
748           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
749         }
750       }
751     }
752 
753     return;
754   }
755 
756   if (tree.isOperation()) {
757     auto &op = getDialectOp(tree);
758     auto numOpArgs = op.getNumArgs();
759 
760     // The pattern might have trailing directives.
761     int numDirectives = 0;
762     for (int i = numTreeArgs - 1; i >= 0; --i) {
763       if (auto dagArg = tree.getArgAsNestedDag(i)) {
764         if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
765           ++numDirectives;
766         else
767           break;
768       }
769     }
770 
771     if (numOpArgs != numTreeArgs - numDirectives) {
772       auto err = formatv("op '{0}' argument number mismatch: "
773                          "{1} in pattern vs. {2} in definition",
774                          op.getOperationName(), numTreeArgs, numOpArgs);
775       PrintFatalError(&def, err);
776     }
777 
778     // The name attached to the DAG node's operator is for representing the
779     // results generated from this op. It should be remembered as bound results.
780     if (!treeName.empty()) {
781       LLVM_DEBUG(llvm::dbgs()
782                  << "found symbol bound to op result: " << treeName << '\n');
783       verifyBind(infoMap.bindOpResult(treeName, op), treeName);
784     }
785 
786     for (int i = 0; i != numTreeArgs; ++i) {
787       if (auto treeArg = tree.getArgAsNestedDag(i)) {
788         // This DAG node argument is a DAG node itself. Go inside recursively.
789         collectBoundSymbols(treeArg, infoMap, isSrcPattern);
790         continue;
791       }
792 
793       if (isSrcPattern) {
794         // We can only bind symbols to op arguments in source pattern. Those
795         // symbols are referenced in result patterns.
796         auto treeArgName = tree.getArgName(i);
797         // `$_` is a special symbol meaning ignore the current argument.
798         if (!treeArgName.empty() && treeArgName != "_") {
799           LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
800                                   << treeArgName << '\n');
801           verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
802                      treeArgName);
803         }
804       }
805     }
806     return;
807   }
808 
809   if (!treeName.empty()) {
810     PrintFatalError(
811         &def, formatv("binding symbol '{0}' to non-operation/native code call "
812                       "unsupported right now",
813                       treeName));
814   }
815   return;
816 }
817