1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h"
20 
21 #define DEBUG_TYPE "mlir-tblgen-pattern"
22 
23 using namespace mlir;
24 
25 using llvm::formatv;
26 using mlir::tblgen::Operator;
27 
28 //===----------------------------------------------------------------------===//
29 // DagLeaf
30 //===----------------------------------------------------------------------===//
31 
32 bool tblgen::DagLeaf::isUnspecified() const {
33   return dyn_cast_or_null<llvm::UnsetInit>(def);
34 }
35 
36 bool tblgen::DagLeaf::isOperandMatcher() const {
37   // Operand matchers specify a type constraint.
38   return isSubClassOf("TypeConstraint");
39 }
40 
41 bool tblgen::DagLeaf::isAttrMatcher() const {
42   // Attribute matchers specify an attribute constraint.
43   return isSubClassOf("AttrConstraint");
44 }
45 
46 bool tblgen::DagLeaf::isNativeCodeCall() const {
47   return isSubClassOf("NativeCodeCall");
48 }
49 
50 bool tblgen::DagLeaf::isConstantAttr() const {
51   return isSubClassOf("ConstantAttr");
52 }
53 
54 bool tblgen::DagLeaf::isEnumAttrCase() const {
55   return isSubClassOf("EnumAttrCaseInfo");
56 }
57 
58 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
59   assert((isOperandMatcher() || isAttrMatcher()) &&
60          "the DAG leaf must be operand or attribute");
61   return Constraint(cast<llvm::DefInit>(def)->getDef());
62 }
63 
64 tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
65   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
66   return ConstantAttr(cast<llvm::DefInit>(def));
67 }
68 
69 tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
70   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
71   return EnumAttrCase(cast<llvm::DefInit>(def));
72 }
73 
74 std::string tblgen::DagLeaf::getConditionTemplate() const {
75   return getAsConstraint().getConditionTemplate();
76 }
77 
78 llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const {
79   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
80   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
81 }
82 
83 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
84   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
85     return defInit->getDef()->isSubClassOf(superclass);
86   return false;
87 }
88 
89 void tblgen::DagLeaf::print(raw_ostream &os) const {
90   if (def)
91     def->print(os);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // DagNode
96 //===----------------------------------------------------------------------===//
97 
98 bool tblgen::DagNode::isNativeCodeCall() const {
99   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
100     return defInit->getDef()->isSubClassOf("NativeCodeCall");
101   return false;
102 }
103 
104 bool tblgen::DagNode::isOperation() const {
105   return !(isNativeCodeCall() || isReplaceWithValue());
106 }
107 
108 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
109   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
110   return cast<llvm::DefInit>(node->getOperator())
111       ->getDef()
112       ->getValueAsString("expression");
113 }
114 
115 llvm::StringRef tblgen::DagNode::getSymbol() const {
116   return node->getNameStr();
117 }
118 
119 Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
120   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
121   auto it = mapper->find(opDef);
122   if (it != mapper->end())
123     return *it->second;
124   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
125               .first->second;
126 }
127 
128 int tblgen::DagNode::getNumOps() const {
129   int count = isReplaceWithValue() ? 0 : 1;
130   for (int i = 0, e = getNumArgs(); i != e; ++i) {
131     if (auto child = getArgAsNestedDag(i))
132       count += child.getNumOps();
133   }
134   return count;
135 }
136 
137 int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
138 
139 bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
140   return isa<llvm::DagInit>(node->getArg(index));
141 }
142 
143 tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
144   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
145 }
146 
147 tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
148   assert(!isNestedDagArg(index));
149   return DagLeaf(node->getArg(index));
150 }
151 
152 StringRef tblgen::DagNode::getArgName(unsigned index) const {
153   return node->getArgNameStr(index);
154 }
155 
156 bool tblgen::DagNode::isReplaceWithValue() const {
157   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
158   return dagOpDef->getName() == "replaceWithValue";
159 }
160 
161 void tblgen::DagNode::print(raw_ostream &os) const {
162   if (node)
163     node->print(os);
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // SymbolInfoMap
168 //===----------------------------------------------------------------------===//
169 
170 StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
171                                                   int *index) {
172   StringRef name, indexStr;
173   int idx = -1;
174   std::tie(name, indexStr) = symbol.rsplit("__");
175 
176   if (indexStr.consumeInteger(10, idx)) {
177     // The second part is not an index; we return the whole symbol as-is.
178     return symbol;
179   }
180   if (index) {
181     *index = idx;
182   }
183   return name;
184 }
185 
186 tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
187                                               SymbolInfo::Kind kind,
188                                               Optional<int> index)
189     : op(op), kind(kind), argIndex(index) {}
190 
191 int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
192   switch (kind) {
193   case Kind::Attr:
194   case Kind::Operand:
195   case Kind::Value:
196     return 1;
197   case Kind::Result:
198     return op->getNumResults();
199   }
200   llvm_unreachable("unknown kind");
201 }
202 
203 std::string
204 tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
205   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
206   switch (kind) {
207   case Kind::Attr: {
208     auto type =
209         op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
210     return std::string(formatv("{0} {1};\n", type, name));
211   }
212   case Kind::Operand: {
213     // Use operand range for captured operands (to support potential variadic
214     // operands).
215     return std::string(
216         formatv("Operation::operand_range {0}(op0->getOperands());\n", name));
217   }
218   case Kind::Value: {
219     return std::string(formatv("ArrayRef<Value> {0};\n", name));
220   }
221   case Kind::Result: {
222     // Use the op itself for captured results.
223     return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
224   }
225   }
226   llvm_unreachable("unknown kind");
227 }
228 
229 std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
230     StringRef name, int index, const char *fmt, const char *separator) const {
231   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
232   switch (kind) {
233   case Kind::Attr: {
234     assert(index < 0);
235     auto repl = formatv(fmt, name);
236     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
237     return std::string(repl);
238   }
239   case Kind::Operand: {
240     assert(index < 0);
241     auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
242     // If this operand is variadic, then return a range. Otherwise, return the
243     // value itself.
244     if (operand->isVariadic()) {
245       auto repl = formatv(fmt, name);
246       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
247       return std::string(repl);
248     }
249     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
250     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
251     return std::string(repl);
252   }
253   case Kind::Result: {
254     // If `index` is greater than zero, then we are referencing a specific
255     // result of a multi-result op. The result can still be variadic.
256     if (index >= 0) {
257       std::string v =
258           std::string(formatv("{0}.getODSResults({1})", name, index));
259       if (!op->getResult(index).isVariadic())
260         v = std::string(formatv("(*{0}.begin())", v));
261       auto repl = formatv(fmt, v);
262       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
263       return std::string(repl);
264     }
265 
266     // If this op has no result at all but still we bind a symbol to it, it
267     // means we want to capture the op itself.
268     if (op->getNumResults() == 0) {
269       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
270       return std::string(name);
271     }
272 
273     // We are referencing all results of the multi-result op. A specific result
274     // can either be a value or a range. Then join them with `separator`.
275     SmallVector<std::string, 4> values;
276     values.reserve(op->getNumResults());
277 
278     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
279       std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
280       if (!op->getResult(i).isVariadic()) {
281         v = std::string(formatv("(*{0}.begin())", v));
282       }
283       values.push_back(std::string(formatv(fmt, v)));
284     }
285     auto repl = llvm::join(values, separator);
286     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
287     return repl;
288   }
289   case Kind::Value: {
290     assert(index < 0);
291     assert(op == nullptr);
292     auto repl = formatv(fmt, name);
293     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
294     return std::string(repl);
295   }
296   }
297   llvm_unreachable("unknown kind");
298 }
299 
300 std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
301     StringRef name, int index, const char *fmt, const char *separator) const {
302   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
303   switch (kind) {
304   case Kind::Attr:
305   case Kind::Operand: {
306     assert(index < 0 && "only allowed for symbol bound to result");
307     auto repl = formatv(fmt, name);
308     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
309     return std::string(repl);
310   }
311   case Kind::Result: {
312     if (index >= 0) {
313       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
314       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
315       return std::string(repl);
316     }
317 
318     // We are referencing all results of the multi-result op. Each result should
319     // have a value range, and then join them with `separator`.
320     SmallVector<std::string, 4> values;
321     values.reserve(op->getNumResults());
322 
323     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
324       values.push_back(std::string(
325           formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
326     }
327     auto repl = llvm::join(values, separator);
328     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
329     return repl;
330   }
331   case Kind::Value: {
332     assert(index < 0 && "only allowed for symbol bound to result");
333     assert(op == nullptr);
334     auto repl = formatv(fmt, formatv("{{{0}}", name));
335     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
336     return std::string(repl);
337   }
338   }
339   llvm_unreachable("unknown kind");
340 }
341 
342 bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
343                                            int argIndex) {
344   StringRef name = getValuePackName(symbol);
345   if (name != symbol) {
346     auto error = formatv(
347         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
348     PrintFatalError(loc, error);
349   }
350 
351   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
352                      ? SymbolInfo::getAttr(&op, argIndex)
353                      : SymbolInfo::getOperand(&op, argIndex);
354 
355   return symbolInfoMap.insert({symbol, symInfo}).second;
356 }
357 
358 bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
359   StringRef name = getValuePackName(symbol);
360   return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
361 }
362 
363 bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
364   return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
365 }
366 
367 bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
368   return find(symbol) != symbolInfoMap.end();
369 }
370 
371 tblgen::SymbolInfoMap::const_iterator
372 tblgen::SymbolInfoMap::find(StringRef key) const {
373   StringRef name = getValuePackName(key);
374   return symbolInfoMap.find(name);
375 }
376 
377 int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
378   StringRef name = getValuePackName(symbol);
379   if (name != symbol) {
380     // If there is a trailing index inside symbol, it references just one
381     // static value.
382     return 1;
383   }
384   // Otherwise, find how many it represents by querying the symbol's info.
385   return find(name)->getValue().getStaticValueCount();
386 }
387 
388 std::string
389 tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt,
390                                            const char *separator) const {
391   int index = -1;
392   StringRef name = getValuePackName(symbol, &index);
393 
394   auto it = symbolInfoMap.find(name);
395   if (it == symbolInfoMap.end()) {
396     auto error = formatv("referencing unbound symbol '{0}'", symbol);
397     PrintFatalError(loc, error);
398   }
399 
400   return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
401 }
402 
403 std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol,
404                                                   const char *fmt,
405                                                   const char *separator) const {
406   int index = -1;
407   StringRef name = getValuePackName(symbol, &index);
408 
409   auto it = symbolInfoMap.find(name);
410   if (it == symbolInfoMap.end()) {
411     auto error = formatv("referencing unbound symbol '{0}'", symbol);
412     PrintFatalError(loc, error);
413   }
414 
415   return it->getValue().getAllRangeUse(name, index, fmt, separator);
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // Pattern
420 //==----------------------------------------------------------------------===//
421 
422 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
423     : def(*def), recordOpMap(mapper) {}
424 
425 tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
426   return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
427 }
428 
429 int tblgen::Pattern::getNumResultPatterns() const {
430   auto *results = def.getValueAsListInit("resultPatterns");
431   return results->size();
432 }
433 
434 tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
435   auto *results = def.getValueAsListInit("resultPatterns");
436   return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
437 }
438 
439 void tblgen::Pattern::collectSourcePatternBoundSymbols(
440     tblgen::SymbolInfoMap &infoMap) {
441   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
442   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
443   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
444 }
445 
446 void tblgen::Pattern::collectResultPatternBoundSymbols(
447     tblgen::SymbolInfoMap &infoMap) {
448   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
449   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
450     auto pattern = getResultPattern(i);
451     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
452   }
453   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
454 }
455 
456 const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
457   return getSourcePattern().getDialectOp(recordOpMap);
458 }
459 
460 tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
461   return node.getDialectOp(recordOpMap);
462 }
463 
464 std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
465   auto *listInit = def.getValueAsListInit("constraints");
466   std::vector<tblgen::AppliedConstraint> ret;
467   ret.reserve(listInit->size());
468 
469   for (auto it : *listInit) {
470     auto *dagInit = dyn_cast<llvm::DagInit>(it);
471     if (!dagInit)
472       PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
473                                     "constraints should be DAG nodes");
474 
475     std::vector<std::string> entities;
476     entities.reserve(dagInit->arg_size());
477     for (auto *argName : dagInit->getArgNames()) {
478       if (!argName) {
479         PrintFatalError(
480             def.getLoc(),
481             "operands to additional constraints can only be symbol references");
482       }
483       entities.push_back(std::string(argName->getValue()));
484     }
485 
486     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
487                      dagInit->getNameStr(), std::move(entities));
488   }
489   return ret;
490 }
491 
492 int tblgen::Pattern::getBenefit() const {
493   // The initial benefit value is a heuristic with number of ops in the source
494   // pattern.
495   int initBenefit = getSourcePattern().getNumOps();
496   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
497   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
498     PrintFatalError(def.getLoc(),
499                     "The 'addBenefit' takes and only takes one integer value");
500   }
501   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
502 }
503 
504 std::vector<tblgen::Pattern::IdentifierLine>
505 tblgen::Pattern::getLocation() const {
506   std::vector<std::pair<StringRef, unsigned>> result;
507   result.reserve(def.getLoc().size());
508   for (auto loc : def.getLoc()) {
509     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
510     assert(buf && "invalid source location");
511     result.emplace_back(
512         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
513         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
514   }
515   return result;
516 }
517 
518 void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
519                                           bool isSrcPattern) {
520   auto treeName = tree.getSymbol();
521   if (!tree.isOperation()) {
522     if (!treeName.empty()) {
523       PrintFatalError(
524           def.getLoc(),
525           formatv("binding symbol '{0}' to non-operation unsupported right now",
526                   treeName));
527     }
528     return;
529   }
530 
531   auto &op = getDialectOp(tree);
532   auto numOpArgs = op.getNumArgs();
533   auto numTreeArgs = tree.getNumArgs();
534 
535   if (numOpArgs != numTreeArgs) {
536     auto err = formatv("op '{0}' argument number mismatch: "
537                        "{1} in pattern vs. {2} in definition",
538                        op.getOperationName(), numTreeArgs, numOpArgs);
539     PrintFatalError(def.getLoc(), err);
540   }
541 
542   // The name attached to the DAG node's operator is for representing the
543   // results generated from this op. It should be remembered as bound results.
544   if (!treeName.empty()) {
545     LLVM_DEBUG(llvm::dbgs()
546                << "found symbol bound to op result: " << treeName << '\n');
547     if (!infoMap.bindOpResult(treeName, op))
548       PrintFatalError(def.getLoc(),
549                       formatv("symbol '{0}' bound more than once", treeName));
550   }
551 
552   for (int i = 0; i != numTreeArgs; ++i) {
553     if (auto treeArg = tree.getArgAsNestedDag(i)) {
554       // This DAG node argument is a DAG node itself. Go inside recursively.
555       collectBoundSymbols(treeArg, infoMap, isSrcPattern);
556     } else if (isSrcPattern) {
557       // We can only bind symbols to op arguments in source pattern. Those
558       // symbols are referenced in result patterns.
559       auto treeArgName = tree.getArgName(i);
560       // `$_` is a special symbol meaning ignore the current argument.
561       if (!treeArgName.empty() && treeArgName != "_") {
562         LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
563                                 << treeArgName << '\n');
564         if (!infoMap.bindOpArgument(treeArgName, op, i)) {
565           auto err = formatv("symbol '{0}' bound more than once", treeArgName);
566           PrintFatalError(def.getLoc(), err);
567         }
568       }
569     }
570   }
571 }
572