1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h"
20 
21 #define DEBUG_TYPE "mlir-tblgen-pattern"
22 
23 using namespace mlir;
24 
25 using llvm::formatv;
26 using mlir::tblgen::Operator;
27 
28 //===----------------------------------------------------------------------===//
29 // DagLeaf
30 //===----------------------------------------------------------------------===//
31 
32 bool tblgen::DagLeaf::isUnspecified() const {
33   return dyn_cast_or_null<llvm::UnsetInit>(def);
34 }
35 
36 bool tblgen::DagLeaf::isOperandMatcher() const {
37   // Operand matchers specify a type constraint.
38   return isSubClassOf("TypeConstraint");
39 }
40 
41 bool tblgen::DagLeaf::isAttrMatcher() const {
42   // Attribute matchers specify an attribute constraint.
43   return isSubClassOf("AttrConstraint");
44 }
45 
46 bool tblgen::DagLeaf::isNativeCodeCall() const {
47   return isSubClassOf("NativeCodeCall");
48 }
49 
50 bool tblgen::DagLeaf::isConstantAttr() const {
51   return isSubClassOf("ConstantAttr");
52 }
53 
54 bool tblgen::DagLeaf::isEnumAttrCase() const {
55   return isSubClassOf("EnumAttrCaseInfo");
56 }
57 
58 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
59   assert((isOperandMatcher() || isAttrMatcher()) &&
60          "the DAG leaf must be operand or attribute");
61   return Constraint(cast<llvm::DefInit>(def)->getDef());
62 }
63 
64 tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
65   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
66   return ConstantAttr(cast<llvm::DefInit>(def));
67 }
68 
69 tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
70   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
71   return EnumAttrCase(cast<llvm::DefInit>(def));
72 }
73 
74 std::string tblgen::DagLeaf::getConditionTemplate() const {
75   return getAsConstraint().getConditionTemplate();
76 }
77 
78 llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const {
79   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
80   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
81 }
82 
83 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
84   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
85     return defInit->getDef()->isSubClassOf(superclass);
86   return false;
87 }
88 
89 void tblgen::DagLeaf::print(raw_ostream &os) const {
90   if (def)
91     def->print(os);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // DagNode
96 //===----------------------------------------------------------------------===//
97 
98 bool tblgen::DagNode::isNativeCodeCall() const {
99   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
100     return defInit->getDef()->isSubClassOf("NativeCodeCall");
101   return false;
102 }
103 
104 bool tblgen::DagNode::isOperation() const {
105   return !(isNativeCodeCall() || isReplaceWithValue());
106 }
107 
108 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
109   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
110   return cast<llvm::DefInit>(node->getOperator())
111       ->getDef()
112       ->getValueAsString("expression");
113 }
114 
115 llvm::StringRef tblgen::DagNode::getSymbol() const {
116   return node->getNameStr();
117 }
118 
119 Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
120   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
121   auto it = mapper->find(opDef);
122   if (it != mapper->end())
123     return *it->second;
124   return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
125               .first->second;
126 }
127 
128 int tblgen::DagNode::getNumOps() const {
129   int count = isReplaceWithValue() ? 0 : 1;
130   for (int i = 0, e = getNumArgs(); i != e; ++i) {
131     if (auto child = getArgAsNestedDag(i))
132       count += child.getNumOps();
133   }
134   return count;
135 }
136 
137 int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
138 
139 bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
140   return isa<llvm::DagInit>(node->getArg(index));
141 }
142 
143 tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
144   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
145 }
146 
147 tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
148   assert(!isNestedDagArg(index));
149   return DagLeaf(node->getArg(index));
150 }
151 
152 StringRef tblgen::DagNode::getArgName(unsigned index) const {
153   return node->getArgNameStr(index);
154 }
155 
156 bool tblgen::DagNode::isReplaceWithValue() const {
157   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
158   return dagOpDef->getName() == "replaceWithValue";
159 }
160 
161 void tblgen::DagNode::print(raw_ostream &os) const {
162   if (node)
163     node->print(os);
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // SymbolInfoMap
168 //===----------------------------------------------------------------------===//
169 
170 StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
171                                                   int *index) {
172   StringRef name, indexStr;
173   int idx = -1;
174   std::tie(name, indexStr) = symbol.rsplit("__");
175 
176   if (indexStr.consumeInteger(10, idx)) {
177     // The second part is not an index; we return the whole symbol as-is.
178     return symbol;
179   }
180   if (index) {
181     *index = idx;
182   }
183   return name;
184 }
185 
186 tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
187                                               SymbolInfo::Kind kind,
188                                               Optional<int> index)
189     : op(op), kind(kind), argIndex(index) {}
190 
191 int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
192   switch (kind) {
193   case Kind::Attr:
194   case Kind::Operand:
195   case Kind::Value:
196     return 1;
197   case Kind::Result:
198     return op->getNumResults();
199   }
200   llvm_unreachable("unknown kind");
201 }
202 
203 std::string
204 tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
205   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
206   switch (kind) {
207   case Kind::Attr: {
208     auto type =
209         op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
210     return formatv("{0} {1};\n", type, name);
211   }
212   case Kind::Operand: {
213     // Use operand range for captured operands (to support potential variadic
214     // operands).
215     return formatv("Operation::operand_range {0}(op0->getOperands());\n", name);
216   }
217   case Kind::Value: {
218     return formatv("ArrayRef<Value> {0};\n", name);
219   }
220   case Kind::Result: {
221     // Use the op itself for captured results.
222     return formatv("{0} {1};\n", op->getQualCppClassName(), name);
223   }
224   }
225   llvm_unreachable("unknown kind");
226 }
227 
228 std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
229     StringRef name, int index, const char *fmt, const char *separator) const {
230   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
231   switch (kind) {
232   case Kind::Attr: {
233     assert(index < 0);
234     auto repl = formatv(fmt, name);
235     LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
236     return repl;
237   }
238   case Kind::Operand: {
239     assert(index < 0);
240     auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
241     // If this operand is variadic, then return a range. Otherwise, return the
242     // value itself.
243     if (operand->isVariadic()) {
244       auto repl = formatv(fmt, name);
245       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
246       return repl;
247     }
248     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
249     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
250     return repl;
251   }
252   case Kind::Result: {
253     // If `index` is greater than zero, then we are referencing a specific
254     // result of a multi-result op. The result can still be variadic.
255     if (index >= 0) {
256       std::string v = formatv("{0}.getODSResults({1})", name, index);
257       if (!op->getResult(index).isVariadic())
258         v = formatv("(*{0}.begin())", v);
259       auto repl = formatv(fmt, v);
260       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
261       return repl;
262     }
263 
264     // If this op has no result at all but still we bind a symbol to it, it
265     // means we want to capture the op itself.
266     if (op->getNumResults() == 0) {
267       LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
268       return name;
269     }
270 
271     // We are referencing all results of the multi-result op. A specific result
272     // can either be a value or a range. Then join them with `separator`.
273     SmallVector<std::string, 4> values;
274     values.reserve(op->getNumResults());
275 
276     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
277       std::string v = formatv("{0}.getODSResults({1})", name, i);
278       if (!op->getResult(i).isVariadic()) {
279         v = formatv("(*{0}.begin())", v);
280       }
281       values.push_back(formatv(fmt, v));
282     }
283     auto repl = llvm::join(values, separator);
284     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
285     return repl;
286   }
287   case Kind::Value: {
288     assert(index < 0);
289     assert(op == nullptr);
290     auto repl = formatv(fmt, name);
291     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
292     return repl;
293   }
294   }
295   llvm_unreachable("unknown kind");
296 }
297 
298 std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
299     StringRef name, int index, const char *fmt, const char *separator) const {
300   LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
301   switch (kind) {
302   case Kind::Attr:
303   case Kind::Operand: {
304     assert(index < 0 && "only allowed for symbol bound to result");
305     auto repl = formatv(fmt, name);
306     LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
307     return repl;
308   }
309   case Kind::Result: {
310     if (index >= 0) {
311       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
312       LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
313       return repl;
314     }
315 
316     // We are referencing all results of the multi-result op. Each result should
317     // have a value range, and then join them with `separator`.
318     SmallVector<std::string, 4> values;
319     values.reserve(op->getNumResults());
320 
321     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
322       values.push_back(
323           formatv(fmt, formatv("{0}.getODSResults({1})", name, i)));
324     }
325     auto repl = llvm::join(values, separator);
326     LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
327     return repl;
328   }
329   case Kind::Value: {
330     assert(index < 0 && "only allowed for symbol bound to result");
331     assert(op == nullptr);
332     auto repl = formatv(fmt, formatv("{{{0}}", name));
333     LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
334     return repl;
335   }
336   }
337   llvm_unreachable("unknown kind");
338 }
339 
340 bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
341                                            int argIndex) {
342   StringRef name = getValuePackName(symbol);
343   if (name != symbol) {
344     auto error = formatv(
345         "symbol '{0}' with trailing index cannot bind to op argument", symbol);
346     PrintFatalError(loc, error);
347   }
348 
349   auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
350                      ? SymbolInfo::getAttr(&op, argIndex)
351                      : SymbolInfo::getOperand(&op, argIndex);
352 
353   return symbolInfoMap.insert({symbol, symInfo}).second;
354 }
355 
356 bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
357   StringRef name = getValuePackName(symbol);
358   return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
359 }
360 
361 bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
362   return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
363 }
364 
365 bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
366   return find(symbol) != symbolInfoMap.end();
367 }
368 
369 tblgen::SymbolInfoMap::const_iterator
370 tblgen::SymbolInfoMap::find(StringRef key) const {
371   StringRef name = getValuePackName(key);
372   return symbolInfoMap.find(name);
373 }
374 
375 int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
376   StringRef name = getValuePackName(symbol);
377   if (name != symbol) {
378     // If there is a trailing index inside symbol, it references just one
379     // static value.
380     return 1;
381   }
382   // Otherwise, find how many it represents by querying the symbol's info.
383   return find(name)->getValue().getStaticValueCount();
384 }
385 
386 std::string
387 tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt,
388                                            const char *separator) const {
389   int index = -1;
390   StringRef name = getValuePackName(symbol, &index);
391 
392   auto it = symbolInfoMap.find(name);
393   if (it == symbolInfoMap.end()) {
394     auto error = formatv("referencing unbound symbol '{0}'", symbol);
395     PrintFatalError(loc, error);
396   }
397 
398   return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
399 }
400 
401 std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol,
402                                                   const char *fmt,
403                                                   const char *separator) const {
404   int index = -1;
405   StringRef name = getValuePackName(symbol, &index);
406 
407   auto it = symbolInfoMap.find(name);
408   if (it == symbolInfoMap.end()) {
409     auto error = formatv("referencing unbound symbol '{0}'", symbol);
410     PrintFatalError(loc, error);
411   }
412 
413   return it->getValue().getAllRangeUse(name, index, fmt, separator);
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // Pattern
418 //==----------------------------------------------------------------------===//
419 
420 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
421     : def(*def), recordOpMap(mapper) {}
422 
423 tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
424   return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
425 }
426 
427 int tblgen::Pattern::getNumResultPatterns() const {
428   auto *results = def.getValueAsListInit("resultPatterns");
429   return results->size();
430 }
431 
432 tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
433   auto *results = def.getValueAsListInit("resultPatterns");
434   return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
435 }
436 
437 void tblgen::Pattern::collectSourcePatternBoundSymbols(
438     tblgen::SymbolInfoMap &infoMap) {
439   LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
440   collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
441   LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
442 }
443 
444 void tblgen::Pattern::collectResultPatternBoundSymbols(
445     tblgen::SymbolInfoMap &infoMap) {
446   LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
447   for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
448     auto pattern = getResultPattern(i);
449     collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
450   }
451   LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
452 }
453 
454 const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
455   return getSourcePattern().getDialectOp(recordOpMap);
456 }
457 
458 tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
459   return node.getDialectOp(recordOpMap);
460 }
461 
462 std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
463   auto *listInit = def.getValueAsListInit("constraints");
464   std::vector<tblgen::AppliedConstraint> ret;
465   ret.reserve(listInit->size());
466 
467   for (auto it : *listInit) {
468     auto *dagInit = dyn_cast<llvm::DagInit>(it);
469     if (!dagInit)
470       PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
471                                     "constraints should be DAG nodes");
472 
473     std::vector<std::string> entities;
474     entities.reserve(dagInit->arg_size());
475     for (auto *argName : dagInit->getArgNames()) {
476       if (!argName) {
477         PrintFatalError(
478             def.getLoc(),
479             "operands to additional constraints can only be symbol references");
480       }
481       entities.push_back(argName->getValue());
482     }
483 
484     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
485                      dagInit->getNameStr(), std::move(entities));
486   }
487   return ret;
488 }
489 
490 int tblgen::Pattern::getBenefit() const {
491   // The initial benefit value is a heuristic with number of ops in the source
492   // pattern.
493   int initBenefit = getSourcePattern().getNumOps();
494   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
495   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
496     PrintFatalError(def.getLoc(),
497                     "The 'addBenefit' takes and only takes one integer value");
498   }
499   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
500 }
501 
502 std::vector<tblgen::Pattern::IdentifierLine>
503 tblgen::Pattern::getLocation() const {
504   std::vector<std::pair<StringRef, unsigned>> result;
505   result.reserve(def.getLoc().size());
506   for (auto loc : def.getLoc()) {
507     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
508     assert(buf && "invalid source location");
509     result.emplace_back(
510         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
511         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
512   }
513   return result;
514 }
515 
// Recursively walks `tree` and records the symbols bound in it into
// `infoMap`:
//  - a name on an op node is bound to that op's results;
//  - in source patterns only (`isSrcPattern`), a name on a non-DAG argument
//    is bound to the corresponding op argument (`$_` is ignored);
//  - nested DAG arguments are visited recursively.
// Non-op nodes (NativeCodeCall / replaceWithValue) are not descended into.
// Binding a symbol to a non-op node, an argument-count mismatch against the
// op definition, and binding the same symbol twice are all reported via
// PrintFatalError.
void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
                                          bool isSrcPattern) {
  auto treeName = tree.getSymbol();
  // Non-op nodes may not carry a symbol, and their arguments are not walked.
  if (!tree.isOperation()) {
    if (!treeName.empty()) {
      PrintFatalError(
          def.getLoc(),
          formatv("binding symbol '{0}' to non-operation unsupported right now",
                  treeName));
    }
    return;
  }

  auto &op = getDialectOp(tree);
  auto numOpArgs = op.getNumArgs();
  auto numTreeArgs = tree.getNumArgs();

  // The pattern must supply exactly as many arguments as the op declares, so
  // argument positions line up with the op's argument indices below.
  if (numOpArgs != numTreeArgs) {
    auto err = formatv("op '{0}' argument number mismatch: "
                       "{1} in pattern vs. {2} in definition",
                       op.getOperationName(), numTreeArgs, numOpArgs);
    PrintFatalError(def.getLoc(), err);
  }

  // The name attached to the DAG node's operator is for representing the
  // results generated from this op. It should be remembered as bound results.
  if (!treeName.empty()) {
    LLVM_DEBUG(llvm::dbgs()
               << "found symbol bound to op result: " << treeName << '\n');
    // bindOpResult returns false when the name is already bound.
    if (!infoMap.bindOpResult(treeName, op))
      PrintFatalError(def.getLoc(),
                      formatv("symbol '{0}' bound more than once", treeName));
  }

  for (int i = 0; i != numTreeArgs; ++i) {
    if (auto treeArg = tree.getArgAsNestedDag(i)) {
      // This DAG node argument is a DAG node itself. Go inside recursively.
      collectBoundSymbols(treeArg, infoMap, isSrcPattern);
    } else if (isSrcPattern) {
      // We can only bind symbols to op arguments in source pattern. Those
      // symbols are referenced in result patterns.
      auto treeArgName = tree.getArgName(i);
      // `$_` is a special symbol meaning ignore the current argument.
      if (!treeArgName.empty() && treeArgName != "_") {
        LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
                                << treeArgName << '\n');
        // bindOpArgument returns false when the name is already bound.
        if (!infoMap.bindOpArgument(treeArgName, op, i)) {
          auto err = formatv("symbol '{0}' bound more than once", treeArgName);
          PrintFatalError(def.getLoc(), err);
        }
      }
    }
  }
}
570