1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23 
24 using namespace mlir;
25 using namespace tblgen;
26 
27 using llvm::formatv;
28 
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32 
33 bool DagLeaf::isUnspecified() const {
34   return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36 
37 bool DagLeaf::isOperandMatcher() const {
38   // Operand matchers specify a type constraint.
39   return isSubClassOf("TypeConstraint");
40 }
41 
42 bool DagLeaf::isAttrMatcher() const {
43   // Attribute matchers specify an attribute constraint.
44   return isSubClassOf("AttrConstraint");
45 }
46 
47 bool DagLeaf::isNativeCodeCall() const {
48   return isSubClassOf("NativeCodeCall");
49 }
50 
51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52 
53 bool DagLeaf::isEnumAttrCase() const {
54   return isSubClassOf("EnumAttrCaseInfo");
55 }
56 
57 bool DagLeaf::isStringAttr() const {
58   return isa<llvm::StringInit, llvm::CodeInit>(def);
59 }
60 
61 Constraint DagLeaf::getAsConstraint() const {
62   assert((isOperandMatcher() || isAttrMatcher()) &&
63          "the DAG leaf must be operand or attribute");
64   return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66 
67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69   return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71 
72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74   return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76 
77 std::string DagLeaf::getConditionTemplate() const {
78   return getAsConstraint().getConditionTemplate();
79 }
80 
81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85 
86 std::string DagLeaf::getStringAttr() const {
87   assert(isStringAttr() && "the DAG leaf must be string attribute");
88   return def->getAsUnquotedString();
89 }
90 bool DagLeaf::isSubClassOf(StringRef superclass) const {
91   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
92     return defInit->getDef()->isSubClassOf(superclass);
93   return false;
94 }
95 
96 void DagLeaf::print(raw_ostream &os) const {
97   if (def)
98     def->print(os);
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // DagNode
103 //===----------------------------------------------------------------------===//
104 
105 bool DagNode::isNativeCodeCall() const {
106   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
107     return defInit->getDef()->isSubClassOf("NativeCodeCall");
108   return false;
109 }
110 
111 bool DagNode::isOperation() const {
112   return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
113 }
114 
115 llvm::StringRef DagNode::getNativeCodeTemplate() const {
116   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
117   return cast<llvm::DefInit>(node->getOperator())
118       ->getDef()
119       ->getValueAsString("expression");
120 }
121 
122 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
123 
124 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
125   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
126   auto it = mapper->find(opDef);
127   if (it != mapper->end())
128     return *it->second;
129   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
130               .first->second;
131 }
132 
133 int DagNode::getNumOps() const {
134   int count = isReplaceWithValue() ? 0 : 1;
135   for (int i = 0, e = getNumArgs(); i != e; ++i) {
136     if (auto child = getArgAsNestedDag(i))
137       count += child.getNumOps();
138   }
139   return count;
140 }
141 
142 int DagNode::getNumArgs() const { return node->getNumArgs(); }
143 
144 bool DagNode::isNestedDagArg(unsigned index) const {
145   return isa<llvm::DagInit>(node->getArg(index));
146 }
147 
148 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
149   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
150 }
151 
152 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
153   assert(!isNestedDagArg(index));
154   return DagLeaf(node->getArg(index));
155 }
156 
157 StringRef DagNode::getArgName(unsigned index) const {
158   return node->getArgNameStr(index);
159 }
160 
161 bool DagNode::isReplaceWithValue() const {
162   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
163   return dagOpDef->getName() == "replaceWithValue";
164 }
165 
166 bool DagNode::isLocationDirective() const {
167   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
168   return dagOpDef->getName() == "location";
169 }
170 
171 void DagNode::print(raw_ostream &os) const {
172   if (node)
173     node->print(os);
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // SymbolInfoMap
178 //===----------------------------------------------------------------------===//
179 
180 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
181   StringRef name, indexStr;
182   int idx = -1;
183   std::tie(name, indexStr) = symbol.rsplit("__");
184 
185   if (indexStr.consumeInteger(10, idx)) {
186     // The second part is not an index; we return the whole symbol as-is.
187     return symbol;
188   }
189   if (index) {
190     *index = idx;
191   }
192   return name;
193 }
194 
// Constructs the info for a symbol of the given `kind`. `op` is the operator
// the symbol is bound to; it is null for `Kind::Value` symbols (the users of
// that kind assert `op == nullptr`). `index` is the op argument index, read
// back via `*argIndex` for symbols bound to an attribute or operand.
SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
                                      Optional<int> index)
    : op(op), kind(kind), argIndex(index) {}
198 
199 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
200   switch (kind) {
201   case Kind::Attr:
202   case Kind::Operand:
203   case Kind::Value:
204     return 1;
205   case Kind::Result:
206     return op->getNumResults();
207   }
208   llvm_unreachable("unknown kind");
209 }
210 
211 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
212   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
213   switch (kind) {
214   case Kind::Attr: {
215     auto type =
216         op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
217     return std::string(formatv("{0} {1};\n", type, name));
218   }
219   case Kind::Operand: {
220     // Use operand range for captured operands (to support potential variadic
221     // operands).
222     return std::string(
223         formatv("Operation::operand_range {0}(op0->getOperands());\n", name));
224   }
225   case Kind::Value: {
226     return std::string(formatv("ArrayRef<Value> {0};\n", name));
227   }
228   case Kind::Result: {
229     // Use the op itself for captured results.
230     return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
231   }
232   }
233   llvm_unreachable("unknown kind");
234 }
235 
236 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
237     StringRef name, int index, const char *fmt, const char *separator) const {
238   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
239   switch (kind) {
240   case Kind::Attr: {
241     assert(index < 0);
242     auto repl = formatv(fmt, name);
243     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
244     return std::string(repl);
245   }
246   case Kind::Operand: {
247     assert(index < 0);
248     auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
249     // If this operand is variadic, then return a range. Otherwise, return the
250     // value itself.
251     if (operand->isVariableLength()) {
252       auto repl = formatv(fmt, name);
253       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
254       return std::string(repl);
255     }
256     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
257     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
258     return std::string(repl);
259   }
260   case Kind::Result: {
261     // If `index` is greater than zero, then we are referencing a specific
262     // result of a multi-result op. The result can still be variadic.
263     if (index >= 0) {
264       std::string v =
265           std::string(formatv("{0}.getODSResults({1})", name, index));
266       if (!op->getResult(index).isVariadic())
267         v = std::string(formatv("(*{0}.begin())", v));
268       auto repl = formatv(fmt, v);
269       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
270       return std::string(repl);
271     }
272 
273     // If this op has no result at all but still we bind a symbol to it, it
274     // means we want to capture the op itself.
275     if (op->getNumResults() == 0) {
276       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
277       return std::string(name);
278     }
279 
280     // We are referencing all results of the multi-result op. A specific result
281     // can either be a value or a range. Then join them with `separator`.
282     SmallVector<std::string, 4> values;
283     values.reserve(op->getNumResults());
284 
285     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
286       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
287       if (!op->getResult(i).isVariadic()) {
288         v = std::string(formatv("(*{0}.begin())", v));
289       }
290       values.push_back(std::string(formatv(fmt, v)));
291     }
292     auto repl = llvm::join(values, separator);
293     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
294     return repl;
295   }
296   case Kind::Value: {
297     assert(index < 0);
298     assert(op == nullptr);
299     auto repl = formatv(fmt, name);
300     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
301     return std::string(repl);
302   }
303   }
304   llvm_unreachable("unknown kind");
305 }
306 
307 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
308     StringRef name, int index, const char *fmt, const char *separator) const {
309   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
310   switch (kind) {
311   case Kind::Attr:
312   case Kind::Operand: {
313     assert(index < 0 && "only allowed for symbol bound to result");
314     auto repl = formatv(fmt, name);
315     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
316     return std::string(repl);
317   }
318   case Kind::Result: {
319     if (index >= 0) {
320       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
321       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
322       return std::string(repl);
323     }
324 
325     // We are referencing all results of the multi-result op. Each result should
326     // have a value range, and then join them with `separator`.
327     SmallVector<std::string, 4> values;
328     values.reserve(op->getNumResults());
329 
330     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
331       values.push_back(std::string(
332           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
333     }
334     auto repl = llvm::join(values, separator);
335     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
336     return repl;
337   }
338   case Kind::Value: {
339     assert(index < 0 && "only allowed for symbol bound to result");
340     assert(op == nullptr);
341     auto repl = formatv(fmt, formatv("{{{0}}", name));
342     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
343     return std::string(repl);
344   }
345   }
346   llvm_unreachable("unknown kind");
347 }
348 
349 bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
350                                    int argIndex) {
351   StringRef name = getValuePackName(symbol);
352   if (name != symbol) {
353     auto error = formatv(
354         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
355     PrintFatalError(loc, error);
356   }
357 
358   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
359                      ? SymbolInfo::getAttr(&op, argIndex)
360                      : SymbolInfo::getOperand(&op, argIndex);
361 
362   return symbolInfoMap.insert({symbol, symInfo}).second;
363 }
364 
365 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
366   StringRef name = getValuePackName(symbol);
367   return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
368 }
369 
370 bool SymbolInfoMap::bindValue(StringRef symbol) {
371   return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
372 }
373 
374 bool SymbolInfoMap::contains(StringRef symbol) const {
375   return find(symbol) != symbolInfoMap.end();
376 }
377 
378 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
379   StringRef name = getValuePackName(key);
380   return symbolInfoMap.find(name);
381 }
382 
383 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
384   StringRef name = getValuePackName(symbol);
385   if (name != symbol) {
386     // If there is a trailing index inside symbol, it references just one
387     // static value.
388     return 1;
389   }
390   // Otherwise, find how many it represents by querying the symbol's info.
391   return find(name)->getValue().getStaticValueCount();
392 }
393 
394 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
395                                                const char *fmt,
396                                                const char *separator) const {
397   int index = -1;
398   StringRef name = getValuePackName(symbol, &index);
399 
400   auto it = symbolInfoMap.find(name);
401   if (it == symbolInfoMap.end()) {
402     auto error = formatv("referencing unbound symbol '{0}'", symbol);
403     PrintFatalError(loc, error);
404   }
405 
406   return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
407 }
408 
409 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
410                                           const char *separator) const {
411   int index = -1;
412   StringRef name = getValuePackName(symbol, &index);
413 
414   auto it = symbolInfoMap.find(name);
415   if (it == symbolInfoMap.end()) {
416     auto error = formatv("referencing unbound symbol '{0}'", symbol);
417     PrintFatalError(loc, error);
418   }
419 
420   return it->getValue().getAllRangeUse(name, index, fmt, separator);
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // Pattern
425 //==----------------------------------------------------------------------===//
426 
// Wraps the TableGen record `def` that describes a rewrite pattern. `mapper`
// is the record-to-Operator cache used (via DagNode::getDialectOp) to resolve
// the ops that the pattern references.
Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
    : def(*def), recordOpMap(mapper) {}
429 
430 DagNode Pattern::getSourcePattern() const {
431   return DagNode(def.getValueAsDag("sourcePattern"));
432 }
433 
434 int Pattern::getNumResultPatterns() const {
435   auto *results = def.getValueAsListInit("resultPatterns");
436   return results->size();
437 }
438 
439 DagNode Pattern::getResultPattern(unsigned index) const {
440   auto *results = def.getValueAsListInit("resultPatterns");
441   return DagNode(cast<llvm::DagInit>(results->getElement(index)));
442 }
443 
444 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
445   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
446   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
447   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
448 }
449 
450 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
451   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
452   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
453     auto pattern = getResultPattern(i);
454     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
455   }
456   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
457 }
458 
459 const Operator &Pattern::getSourceRootOp() {
460   return getSourcePattern().getDialectOp(recordOpMap);
461 }
462 
463 Operator &Pattern::getDialectOp(DagNode node) {
464   return node.getDialectOp(recordOpMap);
465 }
466 
467 std::vector<AppliedConstraint> Pattern::getConstraints() const {
468   auto *listInit = def.getValueAsListInit("constraints");
469   std::vector<AppliedConstraint> ret;
470   ret.reserve(listInit->size());
471 
472   for (auto it : *listInit) {
473     auto *dagInit = dyn_cast<llvm::DagInit>(it);
474     if (!dagInit)
475       PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
476                                     "constraints should be DAG nodes");
477 
478     std::vector<std::string> entities;
479     entities.reserve(dagInit->arg_size());
480     for (auto *argName : dagInit->getArgNames()) {
481       if (!argName) {
482         PrintFatalError(
483             def.getLoc(),
484             "operands to additional constraints can only be symbol references");
485       }
486       entities.push_back(std::string(argName->getValue()));
487     }
488 
489     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
490                      dagInit->getNameStr(), std::move(entities));
491   }
492   return ret;
493 }
494 
495 int Pattern::getBenefit() const {
496   // The initial benefit value is a heuristic with number of ops in the source
497   // pattern.
498   int initBenefit = getSourcePattern().getNumOps();
499   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
500   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
501     PrintFatalError(def.getLoc(),
502                     "The 'addBenefit' takes and only takes one integer value");
503   }
504   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
505 }
506 
507 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
508   std::vector<std::pair<StringRef, unsigned>> result;
509   result.reserve(def.getLoc().size());
510   for (auto loc : def.getLoc()) {
511     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
512     assert(buf && "invalid source location");
513     result.emplace_back(
514         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
515         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
516   }
517   return result;
518 }
519 
// Walks the DAG `tree` recursively and records the symbols bound in it into
// `infoMap`.
//
// - A name on an operation node is bound to that op's results.
// - In the source pattern only (`isSrcPattern`), names on non-DAG arguments
//   are bound to the corresponding op arguments; `$_` is skipped.
// - Nested DAG arguments are visited recursively.
//
// Fatal errors are reported for: a name on a non-operation node, a mismatch
// between the pattern's and the op definition's argument counts, and a symbol
// bound more than once.
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
                                  bool isSrcPattern) {
  auto treeName = tree.getSymbol();
  // Non-operation nodes (native code calls, `replaceWithValue`, `location`)
  // bind nothing and are not descended into.
  if (!tree.isOperation()) {
    if (!treeName.empty()) {
      PrintFatalError(
          def.getLoc(),
          formatv("binding symbol '{0}' to non-operation unsupported right now",
                  treeName));
    }
    return;
  }

  auto &op = getDialectOp(tree);
  auto numOpArgs = op.getNumArgs();
  auto numTreeArgs = tree.getNumArgs();

  // The pattern might have the last argument specifying the location.
  bool hasLocDirective = false;
  if (numTreeArgs != 0) {
    if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
      hasLocDirective = lastArg.isLocationDirective();
  }

  // A trailing location directive is not an op argument, so it is excluded
  // from the count. NOTE(review): the message reports `numTreeArgs`, which
  // still includes that directive.
  if (numOpArgs != numTreeArgs - hasLocDirective) {
    auto err = formatv("op '{0}' argument number mismatch: "
                       "{1} in pattern vs. {2} in definition",
                       op.getOperationName(), numTreeArgs, numOpArgs);
    PrintFatalError(def.getLoc(), err);
  }

  // The name attached to the DAG node's operator is for representing the
  // results generated from this op. It should be remembered as bound results.
  if (!treeName.empty()) {
    LLVM_DEBUG(llvm::dbgs()
               << "found symbol bound to op result: " << treeName << '\n');
    if (!infoMap.bindOpResult(treeName, op))
      PrintFatalError(def.getLoc(),
                      formatv("symbol '{0}' bound more than once", treeName));
  }

  // The loop also visits a trailing location directive. It is a nested DAG,
  // so the recursive call stops at the isOperation() check above.
  for (int i = 0; i != numTreeArgs; ++i) {
    if (auto treeArg = tree.getArgAsNestedDag(i)) {
      // This DAG node argument is a DAG node itself. Go inside recursively.
      collectBoundSymbols(treeArg, infoMap, isSrcPattern);
    } else if (isSrcPattern) {
      // We can only bind symbols to op arguments in source pattern. Those
      // symbols are referenced in result patterns.
      auto treeArgName = tree.getArgName(i);
      // `$_` is a special symbol meaning ignore the current argument.
      if (!treeArgName.empty() && treeArgName != "_") {
        LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
                                << treeArgName << '\n');
        if (!infoMap.bindOpArgument(treeArgName, op, i)) {
          auto err = formatv("symbol '{0}' bound more than once", treeArgName);
          PrintFatalError(def.getLoc(), err);
        }
      }
    }
  }
}
581