1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include <utility>
15
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
25
26 using namespace mlir;
27 using namespace tblgen;
28
29 using llvm::formatv;
30
31 //===----------------------------------------------------------------------===//
32 // DagLeaf
33 //===----------------------------------------------------------------------===//
34
isUnspecified() const35 bool DagLeaf::isUnspecified() const {
36 return isa_and_nonnull<llvm::UnsetInit>(def);
37 }
38
isOperandMatcher() const39 bool DagLeaf::isOperandMatcher() const {
40 // Operand matchers specify a type constraint.
41 return isSubClassOf("TypeConstraint");
42 }
43
isAttrMatcher() const44 bool DagLeaf::isAttrMatcher() const {
45 // Attribute matchers specify an attribute constraint.
46 return isSubClassOf("AttrConstraint");
47 }
48
isNativeCodeCall() const49 bool DagLeaf::isNativeCodeCall() const {
50 return isSubClassOf("NativeCodeCall");
51 }
52
isConstantAttr() const53 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
54
isEnumAttrCase() const55 bool DagLeaf::isEnumAttrCase() const {
56 return isSubClassOf("EnumAttrCaseInfo");
57 }
58
isStringAttr() const59 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
60
getAsConstraint() const61 Constraint DagLeaf::getAsConstraint() const {
62 assert((isOperandMatcher() || isAttrMatcher()) &&
63 "the DAG leaf must be operand or attribute");
64 return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66
getAsConstantAttr() const67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69 return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71
getAsEnumAttrCase() const72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74 return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76
getConditionTemplate() const77 std::string DagLeaf::getConditionTemplate() const {
78 return getAsConstraint().getConditionTemplate();
79 }
80
getNativeCodeTemplate() const81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85
getNumReturnsOfNativeCode() const86 int DagLeaf::getNumReturnsOfNativeCode() const {
87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
89 }
90
getStringAttr() const91 std::string DagLeaf::getStringAttr() const {
92 assert(isStringAttr() && "the DAG leaf must be string attribute");
93 return def->getAsUnquotedString();
94 }
isSubClassOf(StringRef superclass) const95 bool DagLeaf::isSubClassOf(StringRef superclass) const {
96 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
97 return defInit->getDef()->isSubClassOf(superclass);
98 return false;
99 }
100
print(raw_ostream & os) const101 void DagLeaf::print(raw_ostream &os) const {
102 if (def)
103 def->print(os);
104 }
105
106 //===----------------------------------------------------------------------===//
107 // DagNode
108 //===----------------------------------------------------------------------===//
109
isNativeCodeCall() const110 bool DagNode::isNativeCodeCall() const {
111 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
112 return defInit->getDef()->isSubClassOf("NativeCodeCall");
113 return false;
114 }
115
isOperation() const116 bool DagNode::isOperation() const {
117 return !isNativeCodeCall() && !isReplaceWithValue() &&
118 !isLocationDirective() && !isReturnTypeDirective() && !isEither();
119 }
120
getNativeCodeTemplate() const121 llvm::StringRef DagNode::getNativeCodeTemplate() const {
122 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
123 return cast<llvm::DefInit>(node->getOperator())
124 ->getDef()
125 ->getValueAsString("expression");
126 }
127
getNumReturnsOfNativeCode() const128 int DagNode::getNumReturnsOfNativeCode() const {
129 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
130 return cast<llvm::DefInit>(node->getOperator())
131 ->getDef()
132 ->getValueAsInt("numReturns");
133 }
134
getSymbol() const135 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
136
getDialectOp(RecordOperatorMap * mapper) const137 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
138 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
139 auto it = mapper->find(opDef);
140 if (it != mapper->end())
141 return *it->second;
142 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
143 .first->second;
144 }
145
getNumOps() const146 int DagNode::getNumOps() const {
147 // We want to get number of operations recursively involved in the DAG tree.
148 // All other directives should be excluded.
149 int count = isOperation() ? 1 : 0;
150 for (int i = 0, e = getNumArgs(); i != e; ++i) {
151 if (auto child = getArgAsNestedDag(i))
152 count += child.getNumOps();
153 }
154 return count;
155 }
156
getNumArgs() const157 int DagNode::getNumArgs() const { return node->getNumArgs(); }
158
isNestedDagArg(unsigned index) const159 bool DagNode::isNestedDagArg(unsigned index) const {
160 return isa<llvm::DagInit>(node->getArg(index));
161 }
162
getArgAsNestedDag(unsigned index) const163 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
164 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
165 }
166
getArgAsLeaf(unsigned index) const167 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
168 assert(!isNestedDagArg(index));
169 return DagLeaf(node->getArg(index));
170 }
171
getArgName(unsigned index) const172 StringRef DagNode::getArgName(unsigned index) const {
173 return node->getArgNameStr(index);
174 }
175
isReplaceWithValue() const176 bool DagNode::isReplaceWithValue() const {
177 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
178 return dagOpDef->getName() == "replaceWithValue";
179 }
180
isLocationDirective() const181 bool DagNode::isLocationDirective() const {
182 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
183 return dagOpDef->getName() == "location";
184 }
185
isReturnTypeDirective() const186 bool DagNode::isReturnTypeDirective() const {
187 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
188 return dagOpDef->getName() == "returnType";
189 }
190
isEither() const191 bool DagNode::isEither() const {
192 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
193 return dagOpDef->getName() == "either";
194 }
195
print(raw_ostream & os) const196 void DagNode::print(raw_ostream &os) const {
197 if (node)
198 node->print(os);
199 }
200
201 //===----------------------------------------------------------------------===//
202 // SymbolInfoMap
203 //===----------------------------------------------------------------------===//
204
getValuePackName(StringRef symbol,int * index)205 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
206 StringRef name, indexStr;
207 int idx = -1;
208 std::tie(name, indexStr) = symbol.rsplit("__");
209
210 if (indexStr.consumeInteger(10, idx)) {
211 // The second part is not an index; we return the whole symbol as-is.
212 return symbol;
213 }
214 if (index) {
215 *index = idx;
216 }
217 return name;
218 }
219
SymbolInfo(const Operator * op,SymbolInfo::Kind kind,Optional<DagAndConstant> dagAndConstant)220 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
221 Optional<DagAndConstant> dagAndConstant)
222 : op(op), kind(kind), dagAndConstant(std::move(dagAndConstant)) {}
223
getStaticValueCount() const224 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
225 switch (kind) {
226 case Kind::Attr:
227 case Kind::Operand:
228 case Kind::Value:
229 return 1;
230 case Kind::Result:
231 return op->getNumResults();
232 case Kind::MultipleValues:
233 return getSize();
234 }
235 llvm_unreachable("unknown kind");
236 }
237
getVarName(StringRef name) const238 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
239 return alternativeName ? *alternativeName : name.str();
240 }
241
getVarTypeStr(StringRef name) const242 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
243 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
244 switch (kind) {
245 case Kind::Attr: {
246 if (op)
247 return op->getArg(getArgIndex())
248 .get<NamedAttribute *>()
249 ->attr.getStorageType()
250 .str();
251 // TODO(suderman): Use a more exact type when available.
252 return "::mlir::Attribute";
253 }
254 case Kind::Operand: {
255 // Use operand range for captured operands (to support potential variadic
256 // operands).
257 return "::mlir::Operation::operand_range";
258 }
259 case Kind::Value: {
260 return "::mlir::Value";
261 }
262 case Kind::MultipleValues: {
263 return "::mlir::ValueRange";
264 }
265 case Kind::Result: {
266 // Use the op itself for captured results.
267 return op->getQualCppClassName();
268 }
269 }
270 llvm_unreachable("unknown kind");
271 }
272
getVarDecl(StringRef name) const273 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
274 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
275 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
276 return std::string(
277 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
278 }
279
getArgDecl(StringRef name) const280 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
281 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
282 return std::string(
283 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
284 }
285
getValueAndRangeUse(StringRef name,int index,const char * fmt,const char * separator) const286 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
287 StringRef name, int index, const char *fmt, const char *separator) const {
288 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
289 switch (kind) {
290 case Kind::Attr: {
291 assert(index < 0);
292 auto repl = formatv(fmt, name);
293 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
294 return std::string(repl);
295 }
296 case Kind::Operand: {
297 assert(index < 0);
298 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
299 // If this operand is variadic, then return a range. Otherwise, return the
300 // value itself.
301 if (operand->isVariableLength()) {
302 auto repl = formatv(fmt, name);
303 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
304 return std::string(repl);
305 }
306 auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
307 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
308 return std::string(repl);
309 }
310 case Kind::Result: {
311 // If `index` is greater than zero, then we are referencing a specific
312 // result of a multi-result op. The result can still be variadic.
313 if (index >= 0) {
314 std::string v =
315 std::string(formatv("{0}.getODSResults({1})", name, index));
316 if (!op->getResult(index).isVariadic())
317 v = std::string(formatv("(*{0}.begin())", v));
318 auto repl = formatv(fmt, v);
319 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
320 return std::string(repl);
321 }
322
323 // If this op has no result at all but still we bind a symbol to it, it
324 // means we want to capture the op itself.
325 if (op->getNumResults() == 0) {
326 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
327 return formatv(fmt, name);
328 }
329
330 // We are referencing all results of the multi-result op. A specific result
331 // can either be a value or a range. Then join them with `separator`.
332 SmallVector<std::string, 4> values;
333 values.reserve(op->getNumResults());
334
335 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
336 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
337 if (!op->getResult(i).isVariadic()) {
338 v = std::string(formatv("(*{0}.begin())", v));
339 }
340 values.push_back(std::string(formatv(fmt, v)));
341 }
342 auto repl = llvm::join(values, separator);
343 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
344 return repl;
345 }
346 case Kind::Value: {
347 assert(index < 0);
348 assert(op == nullptr);
349 auto repl = formatv(fmt, name);
350 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
351 return std::string(repl);
352 }
353 case Kind::MultipleValues: {
354 assert(op == nullptr);
355 assert(index < getSize());
356 if (index >= 0) {
357 std::string repl =
358 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
359 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
360 return repl;
361 }
362 // If it doesn't specify certain element, unpack them all.
363 auto repl =
364 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
365 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
366 return std::string(repl);
367 }
368 }
369 llvm_unreachable("unknown kind");
370 }
371
getAllRangeUse(StringRef name,int index,const char * fmt,const char * separator) const372 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
373 StringRef name, int index, const char *fmt, const char *separator) const {
374 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
375 switch (kind) {
376 case Kind::Attr:
377 case Kind::Operand: {
378 assert(index < 0 && "only allowed for symbol bound to result");
379 auto repl = formatv(fmt, name);
380 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
381 return std::string(repl);
382 }
383 case Kind::Result: {
384 if (index >= 0) {
385 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
386 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
387 return std::string(repl);
388 }
389
390 // We are referencing all results of the multi-result op. Each result should
391 // have a value range, and then join them with `separator`.
392 SmallVector<std::string, 4> values;
393 values.reserve(op->getNumResults());
394
395 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
396 values.push_back(std::string(
397 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
398 }
399 auto repl = llvm::join(values, separator);
400 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
401 return repl;
402 }
403 case Kind::Value: {
404 assert(index < 0 && "only allowed for symbol bound to result");
405 assert(op == nullptr);
406 auto repl = formatv(fmt, formatv("{{{0}}", name));
407 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
408 return std::string(repl);
409 }
410 case Kind::MultipleValues: {
411 assert(op == nullptr);
412 assert(index < getSize());
413 if (index >= 0) {
414 std::string repl =
415 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
416 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
417 return repl;
418 }
419 auto repl =
420 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
421 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
422 return std::string(repl);
423 }
424 }
425 llvm_unreachable("unknown kind");
426 }
427
bindOpArgument(DagNode node,StringRef symbol,const Operator & op,int argIndex)428 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
429 const Operator &op, int argIndex) {
430 StringRef name = getValuePackName(symbol);
431 if (name != symbol) {
432 auto error = formatv(
433 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
434 PrintFatalError(loc, error);
435 }
436
437 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
438 ? SymbolInfo::getAttr(&op, argIndex)
439 : SymbolInfo::getOperand(node, &op, argIndex);
440
441 std::string key = symbol.str();
442 if (symbolInfoMap.count(key)) {
443 // Only non unique name for the operand is supported.
444 if (symInfo.kind != SymbolInfo::Kind::Operand) {
445 return false;
446 }
447
448 // Cannot add new operand if there is already non operand with the same
449 // name.
450 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
451 return false;
452 }
453 }
454
455 symbolInfoMap.emplace(key, symInfo);
456 return true;
457 }
458
bindOpResult(StringRef symbol,const Operator & op)459 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
460 std::string name = getValuePackName(symbol).str();
461 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
462
463 return symbolInfoMap.count(inserted->first) == 1;
464 }
465
bindValues(StringRef symbol,int numValues)466 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
467 std::string name = getValuePackName(symbol).str();
468 if (numValues > 1)
469 return bindMultipleValues(name, numValues);
470 return bindValue(name);
471 }
472
bindValue(StringRef symbol)473 bool SymbolInfoMap::bindValue(StringRef symbol) {
474 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
475 return symbolInfoMap.count(inserted->first) == 1;
476 }
477
bindMultipleValues(StringRef symbol,int numValues)478 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
479 std::string name = getValuePackName(symbol).str();
480 auto inserted =
481 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
482 return symbolInfoMap.count(inserted->first) == 1;
483 }
484
bindAttr(StringRef symbol)485 bool SymbolInfoMap::bindAttr(StringRef symbol) {
486 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
487 return symbolInfoMap.count(inserted->first) == 1;
488 }
489
contains(StringRef symbol) const490 bool SymbolInfoMap::contains(StringRef symbol) const {
491 return find(symbol) != symbolInfoMap.end();
492 }
493
find(StringRef key) const494 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
495 std::string name = getValuePackName(key).str();
496
497 return symbolInfoMap.find(name);
498 }
499
500 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,DagNode node,const Operator & op,int argIndex) const501 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
502 int argIndex) const {
503 return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
504 }
505
506 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,const SymbolInfo & symbolInfo) const507 SymbolInfoMap::findBoundSymbol(StringRef key,
508 const SymbolInfo &symbolInfo) const {
509 std::string name = getValuePackName(key).str();
510 auto range = symbolInfoMap.equal_range(name);
511
512 for (auto it = range.first; it != range.second; ++it)
513 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
514 return it;
515
516 return symbolInfoMap.end();
517 }
518
519 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
getRangeOfEqualElements(StringRef key)520 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
521 std::string name = getValuePackName(key).str();
522
523 return symbolInfoMap.equal_range(name);
524 }
525
count(StringRef key) const526 int SymbolInfoMap::count(StringRef key) const {
527 std::string name = getValuePackName(key).str();
528 return symbolInfoMap.count(name);
529 }
530
getStaticValueCount(StringRef symbol) const531 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
532 StringRef name = getValuePackName(symbol);
533 if (name != symbol) {
534 // If there is a trailing index inside symbol, it references just one
535 // static value.
536 return 1;
537 }
538 // Otherwise, find how many it represents by querying the symbol's info.
539 return find(name)->second.getStaticValueCount();
540 }
541
getValueAndRangeUse(StringRef symbol,const char * fmt,const char * separator) const542 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
543 const char *fmt,
544 const char *separator) const {
545 int index = -1;
546 StringRef name = getValuePackName(symbol, &index);
547
548 auto it = symbolInfoMap.find(name.str());
549 if (it == symbolInfoMap.end()) {
550 auto error = formatv("referencing unbound symbol '{0}'", symbol);
551 PrintFatalError(loc, error);
552 }
553
554 return it->second.getValueAndRangeUse(name, index, fmt, separator);
555 }
556
getAllRangeUse(StringRef symbol,const char * fmt,const char * separator) const557 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
558 const char *separator) const {
559 int index = -1;
560 StringRef name = getValuePackName(symbol, &index);
561
562 auto it = symbolInfoMap.find(name.str());
563 if (it == symbolInfoMap.end()) {
564 auto error = formatv("referencing unbound symbol '{0}'", symbol);
565 PrintFatalError(loc, error);
566 }
567
568 return it->second.getAllRangeUse(name, index, fmt, separator);
569 }
570
assignUniqueAlternativeNames()571 void SymbolInfoMap::assignUniqueAlternativeNames() {
572 llvm::StringSet<> usedNames;
573
574 for (auto symbolInfoIt = symbolInfoMap.begin();
575 symbolInfoIt != symbolInfoMap.end();) {
576 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
577 auto startRange = range.first;
578 auto endRange = range.second;
579
580 auto operandName = symbolInfoIt->first;
581 int startSearchIndex = 0;
582 for (++startRange; startRange != endRange; ++startRange) {
583 // Current operand name is not unique, find a unique one
584 // and set the alternative name.
585 for (int i = startSearchIndex;; ++i) {
586 std::string alternativeName = operandName + std::to_string(i);
587 if (!usedNames.contains(alternativeName) &&
588 symbolInfoMap.count(alternativeName) == 0) {
589 usedNames.insert(alternativeName);
590 startRange->second.alternativeName = alternativeName;
591 startSearchIndex = i + 1;
592
593 break;
594 }
595 }
596 }
597
598 symbolInfoIt = endRange;
599 }
600 }
601
602 //===----------------------------------------------------------------------===//
603 // Pattern
//===----------------------------------------------------------------------===//
605
Pattern(const llvm::Record * def,RecordOperatorMap * mapper)606 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
607 : def(*def), recordOpMap(mapper) {}
608
getSourcePattern() const609 DagNode Pattern::getSourcePattern() const {
610 return DagNode(def.getValueAsDag("sourcePattern"));
611 }
612
getNumResultPatterns() const613 int Pattern::getNumResultPatterns() const {
614 auto *results = def.getValueAsListInit("resultPatterns");
615 return results->size();
616 }
617
getResultPattern(unsigned index) const618 DagNode Pattern::getResultPattern(unsigned index) const {
619 auto *results = def.getValueAsListInit("resultPatterns");
620 return DagNode(cast<llvm::DagInit>(results->getElement(index)));
621 }
622
collectSourcePatternBoundSymbols(SymbolInfoMap & infoMap)623 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
624 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
625 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
626 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
627
628 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
629 infoMap.assignUniqueAlternativeNames();
630 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
631 }
632
collectResultPatternBoundSymbols(SymbolInfoMap & infoMap)633 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
634 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
635 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
636 auto pattern = getResultPattern(i);
637 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
638 }
639 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
640 }
641
getSourceRootOp()642 const Operator &Pattern::getSourceRootOp() {
643 return getSourcePattern().getDialectOp(recordOpMap);
644 }
645
getDialectOp(DagNode node)646 Operator &Pattern::getDialectOp(DagNode node) {
647 return node.getDialectOp(recordOpMap);
648 }
649
getConstraints() const650 std::vector<AppliedConstraint> Pattern::getConstraints() const {
651 auto *listInit = def.getValueAsListInit("constraints");
652 std::vector<AppliedConstraint> ret;
653 ret.reserve(listInit->size());
654
655 for (auto *it : *listInit) {
656 auto *dagInit = dyn_cast<llvm::DagInit>(it);
657 if (!dagInit)
658 PrintFatalError(&def, "all elements in Pattern multi-entity "
659 "constraints should be DAG nodes");
660
661 std::vector<std::string> entities;
662 entities.reserve(dagInit->arg_size());
663 for (auto *argName : dagInit->getArgNames()) {
664 if (!argName) {
665 PrintFatalError(
666 &def,
667 "operands to additional constraints can only be symbol references");
668 }
669 entities.emplace_back(argName->getValue());
670 }
671
672 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
673 dagInit->getNameStr(), std::move(entities));
674 }
675 return ret;
676 }
677
getBenefit() const678 int Pattern::getBenefit() const {
679 // The initial benefit value is a heuristic with number of ops in the source
680 // pattern.
681 int initBenefit = getSourcePattern().getNumOps();
682 llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
683 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
684 PrintFatalError(&def,
685 "The 'addBenefit' takes and only takes one integer value");
686 }
687 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
688 }
689
getLocation() const690 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
691 std::vector<std::pair<StringRef, unsigned>> result;
692 result.reserve(def.getLoc().size());
693 for (auto loc : def.getLoc()) {
694 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
695 assert(buf && "invalid source location");
696 result.emplace_back(
697 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
698 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
699 }
700 return result;
701 }
702
verifyBind(bool result,StringRef symbolName)703 void Pattern::verifyBind(bool result, StringRef symbolName) {
704 if (!result) {
705 auto err = formatv("symbol '{0}' bound more than once", symbolName);
706 PrintFatalError(&def, err);
707 }
708 }
709
collectBoundSymbols(DagNode tree,SymbolInfoMap & infoMap,bool isSrcPattern)710 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
711 bool isSrcPattern) {
712 auto treeName = tree.getSymbol();
713 auto numTreeArgs = tree.getNumArgs();
714
715 if (tree.isNativeCodeCall()) {
716 if (!treeName.empty()) {
717 if (!isSrcPattern) {
718 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
719 << treeName << '\n');
720 verifyBind(
721 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
722 treeName);
723 } else {
724 PrintFatalError(&def,
725 formatv("binding symbol '{0}' to NativecodeCall in "
726 "MatchPattern is not supported",
727 treeName));
728 }
729 }
730
731 for (int i = 0; i != numTreeArgs; ++i) {
732 if (auto treeArg = tree.getArgAsNestedDag(i)) {
733 // This DAG node argument is a DAG node itself. Go inside recursively.
734 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
735 continue;
736 }
737
738 if (!isSrcPattern)
739 continue;
740
741 // We can only bind symbols to arguments in source pattern. Those
742 // symbols are referenced in result patterns.
743 auto treeArgName = tree.getArgName(i);
744
745 // `$_` is a special symbol meaning ignore the current argument.
746 if (!treeArgName.empty() && treeArgName != "_") {
747 DagLeaf leaf = tree.getArgAsLeaf(i);
748
749 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
750 if (leaf.isUnspecified()) {
751 // This is case of $c, a Value without any constraints.
752 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
753 } else {
754 auto constraint = leaf.getAsConstraint();
755 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
756 leaf.isConstantAttr() ||
757 constraint.getKind() == Constraint::Kind::CK_Attr;
758
759 if (isAttr) {
760 // This is case of $a, a binding to a certain attribute.
761 verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
762 continue;
763 }
764
765 // This is case of $b, a binding to a certain type.
766 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
767 }
768 }
769 }
770
771 return;
772 }
773
774 if (tree.isOperation()) {
775 auto &op = getDialectOp(tree);
776 auto numOpArgs = op.getNumArgs();
777 int numEither = 0;
778
779 // We need to exclude the trailing directives and `either` directive groups
780 // two operands of the operation.
781 int numDirectives = 0;
782 for (int i = numTreeArgs - 1; i >= 0; --i) {
783 if (auto dagArg = tree.getArgAsNestedDag(i)) {
784 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
785 ++numDirectives;
786 else if (dagArg.isEither())
787 ++numEither;
788 }
789 }
790
791 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
792 auto err =
793 formatv("op '{0}' argument number mismatch: "
794 "{1} in pattern vs. {2} in definition",
795 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
796 PrintFatalError(&def, err);
797 }
798
799 // The name attached to the DAG node's operator is for representing the
800 // results generated from this op. It should be remembered as bound results.
801 if (!treeName.empty()) {
802 LLVM_DEBUG(llvm::dbgs()
803 << "found symbol bound to op result: " << treeName << '\n');
804 verifyBind(infoMap.bindOpResult(treeName, op), treeName);
805 }
806
807 // The operand in `either` DAG should be bound to the operation in the
808 // parent DagNode.
809 auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
810 int &opArgIdx) {
811 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
812 if (DagNode subTree = tree.getArgAsNestedDag(i)) {
813 collectBoundSymbols(subTree, infoMap, isSrcPattern);
814 } else {
815 auto argName = tree.getArgName(i);
816 if (!argName.empty() && argName != "_")
817 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
818 argName);
819 }
820 }
821 };
822
823 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
824 if (auto treeArg = tree.getArgAsNestedDag(i)) {
825 if (treeArg.isEither()) {
826 collectSymbolInEither(tree, treeArg, opArgIdx);
827 } else {
828 // This DAG node argument is a DAG node itself. Go inside recursively.
829 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
830 }
831 continue;
832 }
833
834 if (isSrcPattern) {
835 // We can only bind symbols to op arguments in source pattern. Those
836 // symbols are referenced in result patterns.
837 auto treeArgName = tree.getArgName(i);
838 // `$_` is a special symbol meaning ignore the current argument.
839 if (!treeArgName.empty() && treeArgName != "_") {
840 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
841 << treeArgName << '\n');
842 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
843 treeArgName);
844 }
845 }
846 }
847 return;
848 }
849
850 if (!treeName.empty()) {
851 PrintFatalError(
852 &def, formatv("binding symbol '{0}' to non-operation/native code call "
853 "unsupported right now",
854 treeName));
855 }
856 }
857