1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "mlir/TableGen/Pattern.h"
20 #include "mlir/TableGen/Predicate.h"
21 #include "mlir/TableGen/Type.h"
22 #include "llvm/ADT/FunctionExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/FormatAdapters.h"
29 #include "llvm/Support/PrettyStackTrace.h"
30 #include "llvm/Support/Signals.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Main.h"
33 #include "llvm/TableGen/Record.h"
34 #include "llvm/TableGen/TableGenBackend.h"
35
36 using namespace mlir;
37 using namespace mlir::tblgen;
38
39 using llvm::formatv;
40 using llvm::Record;
41 using llvm::RecordKeeper;
42
43 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
44
45 namespace llvm {
46 template <>
47 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
formatllvm::format_provider48 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
49 raw_ostream &os, StringRef style) {
50 os << v.first << ":" << v.second;
51 }
52 };
53 } // namespace llvm
54
55 //===----------------------------------------------------------------------===//
56 // PatternEmitter
57 //===----------------------------------------------------------------------===//
58
59 namespace {
60
61 class StaticMatcherHelper;
62
63 class PatternEmitter {
64 public:
65 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
66 StaticMatcherHelper &helper);
67
68 // Emits the mlir::RewritePattern struct named `rewriteName`.
69 void emit(StringRef rewriteName);
70
71 // Emits the static function of DAG matcher.
72 void emitStaticMatcher(DagNode tree, std::string funcName);
73
74 private:
75 // Emits the code for matching ops.
76 void emitMatchLogic(DagNode tree, StringRef opName);
77
78 // Emits the code for rewriting ops.
79 void emitRewriteLogic();
80
81 //===--------------------------------------------------------------------===//
82 // Match utilities
83 //===--------------------------------------------------------------------===//
84
85 // Emits C++ statements for matching the DAG structure.
86 void emitMatch(DagNode tree, StringRef name, int depth);
87
88 // Emit C++ function call to static DAG matcher.
89 void emitStaticMatchCall(DagNode tree, StringRef name);
90
91 // Emit C++ function call to static type/attribute constraint function.
92 void emitStaticVerifierCall(StringRef funcName, StringRef opName,
93 StringRef arg, StringRef failureStr);
94
95 // Emits C++ statements for matching using a native code call.
96 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
97
98 // Emits C++ statements for matching the op constrained by the given DAG
99 // `tree` returning the op's variable name.
100 void emitOpMatch(DagNode tree, StringRef opName, int depth);
101
102 // Emits C++ statements for matching the `argIndex`-th argument of the given
103 // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
104 // bound name and the constraint of the operand respectively.
105 void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
106 DagLeaf operandMatcher, StringRef argName,
107 int argIndex);
108
109 // Emits C++ statements for matching the operands which can be matched in
110 // either order.
111 void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
112 StringRef opName, int argIndex, int &operandIndex,
113 int depth);
114
115 // Emits C++ statements for matching the `argIndex`-th argument of the given
116 // DAG `tree` as an attribute.
117 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
118 int depth);
119
120 // Emits C++ for checking a match with a corresponding match failure
121 // diagnostic.
122 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
123 const llvm::formatv_object_base &failureFmt);
124
125 // Emits C++ for checking a match with a corresponding match failure
126 // diagnostics.
127 void emitMatchCheck(StringRef opName, const std::string &matchStr,
128 const std::string &failureStr);
129
130 //===--------------------------------------------------------------------===//
131 // Rewrite utilities
132 //===--------------------------------------------------------------------===//
133
134 // The entry point for handling a result pattern rooted at `resultTree`. This
135 // method dispatches to concrete handlers according to `resultTree`'s kind and
136 // returns a symbol representing the whole value pack. Callers are expected to
137 // further resolve the symbol according to the specific use case.
138 //
139 // `depth` is the nesting level of `resultTree`; 0 means top-level result
140 // pattern. For top-level result pattern, `resultIndex` indicates which result
141 // of the matched root op this pattern is intended to replace, which can be
142 // used to deduce the result type of the op generated from this result
143 // pattern.
144 std::string handleResultPattern(DagNode resultTree, int resultIndex,
145 int depth);
146
147 // Emits the C++ statement to replace the matched DAG with a value built via
148 // calling native C++ code.
149 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
150
151 // Returns the symbol of the old value serving as the replacement.
152 StringRef handleReplaceWithValue(DagNode tree);
153
154 // Trailing directives are used at the end of DAG node argument lists to
155 // specify additional behaviour for op matchers and creators, etc.
156 struct TrailingDirectives {
157 // DAG node containing the `location` directive. Null if there is none.
158 DagNode location;
159
160 // DAG node containing the `returnType` directive. Null if there is none.
161 DagNode returnType;
162
163 // Number of found trailing directives.
164 int numDirectives;
165 };
166
167 // Collect any trailing directives.
168 TrailingDirectives getTrailingDirectives(DagNode tree);
169
170 // Returns the location value to use.
171 std::string getLocation(TrailingDirectives &tail);
172
173 // Returns the location value to use.
174 std::string handleLocationDirective(DagNode tree);
175
176 // Emit return type argument.
177 std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
178
179 // Emits the C++ statement to build a new op out of the given DAG `tree` and
180 // returns the variable name that this op is assigned to. If the root op in
181 // DAG `tree` has a specified name, the created op will be assigned to a
182 // variable of the given name. Otherwise, a unique name will be used as the
183 // result value name.
184 std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
185
186 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
187
188 // Emits a local variable for each value and attribute to be used for creating
189 // an op.
190 void createSeparateLocalVarsForOpArgs(DagNode node,
191 ChildNodeIndexNameMap &childNodeNames);
192
193 // Emits the concrete arguments used to call an op's builder.
194 void supplyValuesForOpArgs(DagNode node,
195 const ChildNodeIndexNameMap &childNodeNames,
196 int depth);
197
198 // Emits the local variables for holding all values as a whole and all named
199 // attributes as a whole to be used for creating an op.
200 void createAggregateLocalVarsForOpArgs(
201 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
202
203 // Returns the C++ expression to construct a constant attribute of the given
204 // `value` for the given attribute kind `attr`.
205 std::string handleConstantAttr(Attribute attr, const Twine &value);
206
207 // Returns the C++ expression to build an argument from the given DAG `leaf`.
208 // `patArgName` is used to bound the argument to the source pattern.
209 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
210
211 //===--------------------------------------------------------------------===//
212 // General utilities
213 //===--------------------------------------------------------------------===//
214
215 // Collects all of the operations within the given dag tree.
216 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
217
218 // Returns a unique symbol for a local variable of the given `op`.
219 std::string getUniqueSymbol(const Operator *op);
220
221 //===--------------------------------------------------------------------===//
222 // Symbol utilities
223 //===--------------------------------------------------------------------===//
224
225 // Returns how many static values the given DAG `node` correspond to.
226 int getNodeValueCount(DagNode node);
227
228 private:
229 // Pattern instantiation location followed by the location of multiclass
230 // prototypes used. This is intended to be used as a whole to
231 // PrintFatalError() on errors.
232 ArrayRef<SMLoc> loc;
233
234 // Op's TableGen Record to wrapper object.
235 RecordOperatorMap *opMap;
236
237 // Handy wrapper for pattern being emitted.
238 Pattern pattern;
239
240 // Map for all bound symbols' info.
241 SymbolInfoMap symbolInfoMap;
242
243 StaticMatcherHelper &staticMatcherHelper;
244
245 // The next unused ID for newly created values.
246 unsigned nextValueId = 0;
247
248 raw_indented_ostream os;
249
250 // Format contexts containing placeholder substitutions.
251 FmtContext fmtCtx;
252 };
253
254 // Tracks DagNode's reference multiple times across patterns. Enables generating
255 // static matcher functions for DagNode's referenced multiple times rather than
256 // inlining them.
257 class StaticMatcherHelper {
258 public:
259 StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
260 RecordOperatorMap &mapper);
261
262 // Determine if we should inline the match logic or delegate to a static
263 // function.
useStaticMatcher(DagNode node)264 bool useStaticMatcher(DagNode node) {
265 return refStats[node] > kStaticMatcherThreshold;
266 }
267
268 // Get the name of the static DAG matcher function corresponding to the node.
getMatcherName(DagNode node)269 std::string getMatcherName(DagNode node) {
270 assert(useStaticMatcher(node));
271 return matcherNames[node];
272 }
273
274 // Get the name of static type/attribute verification function.
275 StringRef getVerifierName(DagLeaf leaf);
276
277 // Collect the `Record`s, i.e., the DRR, so that we can get the information of
278 // the duplicated DAGs.
279 void addPattern(Record *record);
280
281 // Emit all static functions of DAG Matcher.
282 void populateStaticMatchers(raw_ostream &os);
283
284 // Emit all static functions for Constraints.
285 void populateStaticConstraintFunctions(raw_ostream &os);
286
287 private:
288 static constexpr unsigned kStaticMatcherThreshold = 1;
289
290 // Consider two patterns as down below,
291 // DagNode_Root_A DagNode_Root_B
292 // \ \
293 // DagNode_C DagNode_C
294 // \ \
295 // DagNode_D DagNode_D
296 //
297 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
298 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
299 // multiple times so we'll have static matchers for both of them. When we're
300 // emitting the match logic for DagNode_C, we will check if DagNode_D has the
301 // static matcher generated. If so, then we'll generate a call to the
302 // function, inline otherwise. In this case, inlining is not what we want. As
303 // a result, generate the static matcher in topological order to ensure all
304 // the dependent static matchers are generated and we can avoid accidentally
305 // inlining.
306 //
307 // The topological order of all the DagNodes among all patterns.
308 SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
309
310 RecordOperatorMap &opMap;
311
312 // Records of the static function name of each DagNode
313 DenseMap<DagNode, std::string> matcherNames;
314
315 // After collecting all the DagNode in each pattern, `refStats` records the
316 // number of users for each DagNode. We will generate the static matcher for a
317 // DagNode while the number of users exceeds a certain threshold.
318 DenseMap<DagNode, unsigned> refStats;
319
320 // Number of static matcher generated. This is used to generate a unique name
321 // for each DagNode.
322 int staticMatcherCounter = 0;
323
324 // The DagLeaf which contains type or attr constraint.
325 SetVector<DagLeaf> constraints;
326
327 // Static type/attribute verification function emitter.
328 StaticVerifierFunctionEmitter staticVerifierEmitter;
329 };
330
331 } // namespace
332
PatternEmitter(Record * pat,RecordOperatorMap * mapper,raw_ostream & os,StaticMatcherHelper & helper)333 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
334 raw_ostream &os, StaticMatcherHelper &helper)
335 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
336 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
337 fmtCtx.withBuilder("rewriter");
338 }
339
handleConstantAttr(Attribute attr,const Twine & value)340 std::string PatternEmitter::handleConstantAttr(Attribute attr,
341 const Twine &value) {
342 if (!attr.isConstBuildable())
343 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
344 " does not have the 'constBuilderCall' field");
345
346 // TODO: Verify the constants here
347 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
348 }
349
emitStaticMatcher(DagNode tree,std::string funcName)350 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
351 os << formatv(
352 "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
353 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
354 "*, 4> &tblgen_ops",
355 funcName);
356
357 // We pass the reference of the variables that need to be captured. Hence we
358 // need to collect all the symbols in the tree first.
359 pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
360 symbolInfoMap.assignUniqueAlternativeNames();
361 for (const auto &info : symbolInfoMap)
362 os << formatv(", {0}", info.second.getArgDecl(info.first));
363
364 os << ") {\n";
365 os.indent();
366 os << "(void)tblgen_ops;\n";
367
368 // Note that a static matcher is considered at least one step from the match
369 // entry.
370 emitMatch(tree, "op0", /*depth=*/1);
371
372 os << "return ::mlir::success();\n";
373 os.unindent();
374 os << "}\n\n";
375 }
376
377 // Helper function to match patterns.
emitMatch(DagNode tree,StringRef name,int depth)378 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
379 if (tree.isNativeCodeCall()) {
380 emitNativeCodeMatch(tree, name, depth);
381 return;
382 }
383
384 if (tree.isOperation()) {
385 emitOpMatch(tree, name, depth);
386 return;
387 }
388
389 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
390 }
391
emitStaticMatchCall(DagNode tree,StringRef opName)392 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
393 std::string funcName = staticMatcherHelper.getMatcherName(tree);
394 os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName,
395 opName);
396
397 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
398 // one pass.
399
400 // In general, bound symbol should have the unique name in the pattern but
401 // for the operand, binding same symbol to multiple operands imply a
402 // constraint at the same time. In this case, we will rename those operands
403 // with different names. As a result, we need to collect all the symbolInfos
404 // from the DagNode then get the updated name of the local variables from the
405 // global symbolInfoMap.
406
407 // Collect all the bound symbols in the Dag
408 SymbolInfoMap localSymbolMap(loc);
409 pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
410
411 for (const auto &info : localSymbolMap) {
412 auto name = info.first;
413 auto symboInfo = info.second;
414 auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
415 os << formatv(", {0}", ret->second.getVarName(name));
416 }
417
418 os << "))) {\n";
419 os.scope().os << "return ::mlir::failure();\n";
420 os << "}\n";
421 }
422
emitStaticVerifierCall(StringRef funcName,StringRef opName,StringRef arg,StringRef failureStr)423 void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
424 StringRef opName, StringRef arg,
425 StringRef failureStr) {
426 os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
427 funcName, opName, arg, failureStr);
428 os.scope().os << "return ::mlir::failure();\n";
429 os << "}\n";
430 }
431
432 // Helper function to match patterns.
emitNativeCodeMatch(DagNode tree,StringRef opName,int depth)433 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
434 int depth) {
435 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
436 LLVM_DEBUG(tree.print(llvm::dbgs()));
437 LLVM_DEBUG(llvm::dbgs() << '\n');
438
439 // The order of generating static matcher follows the topological order so
440 // that for every dependent DagNode already have their static matcher
441 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
442 // is when we are generating the static matcher for a DagNode itself. In this
443 // case, we need to emit the function body rather than a function call.
444 if (staticMatcherHelper.useStaticMatcher(tree) &&
445 !staticMatcherHelper.getMatcherName(tree).empty()) {
446 emitStaticMatchCall(tree, opName);
447
448 // NativeCodeCall will never be at depth 0 so that we don't need to catch
449 // the root operation as emitOpMatch();
450
451 return;
452 }
453
454 // TODO(suderman): iterate through arguments, determine their types, output
455 // names.
456 SmallVector<std::string, 8> capture;
457
458 raw_indented_ostream::DelimitedScope scope(os);
459
460 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
461 std::string argName = formatv("arg{0}_{1}", depth, i);
462 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
463 if (argTree.isEither())
464 PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
465
466 os << "::mlir::Value " << argName << ";\n";
467 } else {
468 auto leaf = tree.getArgAsLeaf(i);
469 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
470 os << "::mlir::Attribute " << argName << ";\n";
471 } else {
472 os << "::mlir::Value " << argName << ";\n";
473 }
474 }
475
476 capture.push_back(std::move(argName));
477 }
478
479 auto tail = getTrailingDirectives(tree);
480 if (tail.returnType)
481 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
482 auto locToUse = getLocation(tail);
483
484 auto fmt = tree.getNativeCodeTemplate();
485 if (fmt.count("$_self") != 1)
486 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
487 "passing the defining Operation");
488
489 auto nativeCodeCall = std::string(
490 tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
491 static_cast<ArrayRef<std::string>>(capture)));
492
493 emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall),
494 formatv("\"{0} return ::mlir::failure\"", nativeCodeCall));
495
496 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
497 auto name = tree.getArgName(i);
498 if (!name.empty() && name != "_") {
499 os << formatv("{0} = {1};\n", name, capture[i]);
500 }
501 }
502
503 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
504 std::string argName = capture[i];
505
506 // Handle nested DAG construct first
507 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
508 PrintFatalError(
509 loc, formatv("Matching nested tree in NativeCodecall not support for "
510 "{0} as arg {1}",
511 argName, i));
512 }
513
514 DagLeaf leaf = tree.getArgAsLeaf(i);
515
516 // The parameter for native function doesn't bind any constraints.
517 if (leaf.isUnspecified())
518 continue;
519
520 auto constraint = leaf.getAsConstraint();
521
522 std::string self;
523 if (leaf.isAttrMatcher() || leaf.isConstantAttr())
524 self = argName;
525 else
526 self = formatv("{0}.getType()", argName);
527 StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
528 emitStaticVerifierCall(
529 verifier, opName, self,
530 formatv("\"operand {0} of native code call '{1}' failed to satisfy "
531 "constraint: "
532 "'{2}'\"",
533 i, tree.getNativeCodeTemplate(),
534 escapeString(constraint.getSummary()))
535 .str());
536 }
537
538 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
539 }
540
541 // Helper function to match patterns.
emitOpMatch(DagNode tree,StringRef opName,int depth)542 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
543 Operator &op = tree.getDialectOp(opMap);
544 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
545 << op.getOperationName() << "' at depth " << depth
546 << '\n');
547
548 auto getCastedName = [depth]() -> std::string {
549 return formatv("castedOp{0}", depth);
550 };
551
552 // The order of generating static matcher follows the topological order so
553 // that for every dependent DagNode already have their static matcher
554 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
555 // is when we are generating the static matcher for a DagNode itself. In this
556 // case, we need to emit the function body rather than a function call.
557 if (staticMatcherHelper.useStaticMatcher(tree) &&
558 !staticMatcherHelper.getMatcherName(tree).empty()) {
559 emitStaticMatchCall(tree, opName);
560 // In the codegen of rewriter, we suppose that castedOp0 will capture the
561 // root operation. Manually add it if the root DagNode is a static matcher.
562 if (depth == 0)
563 os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
564 "(void){2};\n",
565 opName, op.getQualCppClassName(), getCastedName());
566 return;
567 }
568
569 std::string castedName = getCastedName();
570 os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
571 "(void){0};\n",
572 castedName, opName, op.getQualCppClassName());
573
574 // Skip the operand matching at depth 0 as the pattern rewriter already does.
575 if (depth != 0)
576 emitMatchCheck(opName, /*matchStr=*/castedName,
577 formatv("\"{0} is not {1} type\"", castedName,
578 op.getQualCppClassName()));
579
580 // If the operand's name is set, set to that variable.
581 auto name = tree.getSymbol();
582 if (!name.empty())
583 os << formatv("{0} = {1};\n", name, castedName);
584
585 for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
586 auto opArg = op.getArg(i);
587 std::string argName = formatv("op{0}", depth + 1);
588
589 // Handle nested DAG construct first
590 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
591 if (argTree.isEither()) {
592 emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand,
593 depth);
594 continue;
595 }
596 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
597 if (operand->isVariableLength()) {
598 auto error = formatv("use nested DAG construct to match op {0}'s "
599 "variadic operand #{1} unsupported now",
600 op.getOperationName(), i);
601 PrintFatalError(loc, error);
602 }
603 }
604
605 os << "{\n";
606
607 // Attributes don't count for getODSOperands.
608 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
609 os.indent() << formatv(
610 "auto *{0} = "
611 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
612 argName, castedName, nextOperand);
613 // Null check of operand's definingOp
614 emitMatchCheck(
615 castedName, /*matchStr=*/argName,
616 formatv("\"There's no operation that defines operand {0} of {1}\"",
617 nextOperand++, castedName));
618 emitMatch(argTree, argName, depth + 1);
619 os << formatv("tblgen_ops.push_back({0});\n", argName);
620 os.unindent() << "}\n";
621 continue;
622 }
623
624 // Next handle DAG leaf: operand or attribute
625 if (opArg.is<NamedTypeConstraint *>()) {
626 auto operandName =
627 formatv("{0}.getODSOperands({1})", castedName, nextOperand);
628 emitOperandMatch(tree, castedName, operandName.str(),
629 /*operandMatcher=*/tree.getArgAsLeaf(i),
630 /*argName=*/tree.getArgName(i),
631 /*argIndex=*/i);
632 ++nextOperand;
633 } else if (opArg.is<NamedAttribute *>()) {
634 emitAttributeMatch(tree, opName, i, depth);
635 } else {
636 PrintFatalError(loc, "unhandled case when matching op");
637 }
638 }
639 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
640 << op.getOperationName() << "' at depth " << depth
641 << '\n');
642 }
643
emitOperandMatch(DagNode tree,StringRef opName,StringRef operandName,DagLeaf operandMatcher,StringRef argName,int argIndex)644 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
645 StringRef operandName,
646 DagLeaf operandMatcher, StringRef argName,
647 int argIndex) {
648 Operator &op = tree.getDialectOp(opMap);
649 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
650
651 // If a constraint is specified, we need to generate C++ statements to
652 // check the constraint.
653 if (!operandMatcher.isUnspecified()) {
654 if (!operandMatcher.isOperandMatcher())
655 PrintFatalError(
656 loc, formatv("the {1}-th argument of op '{0}' should be an operand",
657 op.getOperationName(), argIndex + 1));
658
659 // Only need to verify if the matcher's type is different from the one
660 // of op definition.
661 Constraint constraint = operandMatcher.getAsConstraint();
662 if (operand->constraint != constraint) {
663 if (operand->isVariableLength()) {
664 auto error = formatv(
665 "further constrain op {0}'s variadic operand #{1} unsupported now",
666 op.getOperationName(), argIndex);
667 PrintFatalError(loc, error);
668 }
669 auto self = formatv("(*{0}.begin()).getType()", operandName);
670 StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
671 emitStaticVerifierCall(
672 verifier, opName, self.str(),
673 formatv(
674 "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
675 operand - op.operand_begin(), op.getOperationName(),
676 escapeString(constraint.getSummary()))
677 .str());
678 }
679 }
680
681 // Capture the value
682 // `$_` is a special symbol to ignore op argument matching.
683 if (!argName.empty() && argName != "_") {
684 auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
685 os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
686 }
687 }
688
emitEitherOperandMatch(DagNode tree,DagNode eitherArgTree,StringRef opName,int argIndex,int & operandIndex,int depth)689 void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
690 StringRef opName, int argIndex,
691 int &operandIndex, int depth) {
692 constexpr int numEitherArgs = 2;
693 if (eitherArgTree.getNumArgs() != numEitherArgs)
694 PrintFatalError(loc, "`either` only supports grouping two operands");
695
696 Operator &op = tree.getDialectOp(opMap);
697
698 std::string codeBuffer;
699 llvm::raw_string_ostream tblgenOps(codeBuffer);
700
701 std::string lambda = formatv("eitherLambda{0}", depth);
702 os << formatv(
703 "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
704 lambda);
705
706 os.indent();
707
708 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
709 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
710 if (argTree.isEither())
711 PrintFatalError(loc, "either cannot be nested");
712
713 std::string argName = formatv("local_op_{0}", i).str();
714
715 os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
716 i);
717 emitMatchCheck(
718 opName, /*matchStr=*/argName,
719 formatv("\"There's no operation that defines operand {0} of {1}\"",
720 operandIndex++, opName));
721 emitMatch(argTree, argName, depth + 1);
722 // `tblgen_ops` is used to collect the matched operations. In either, we
723 // need to queue the operation only if the matching success. Thus we emit
724 // the code at the end.
725 tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
726 } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
727 emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
728 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
729 /*argName=*/eitherArgTree.getArgName(i), argIndex);
730 ++operandIndex;
731 } else {
732 PrintFatalError(loc, "either can only be applied on operand");
733 }
734 }
735
736 os << tblgenOps.str();
737 os << "return ::mlir::success();\n";
738 os.unindent() << "};\n";
739
740 os << "{\n";
741 os.indent();
742
743 os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
744 operandIndex - 2);
745 os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
746 operandIndex - 1);
747
748 os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
749 "::mlir::failed({0}(eitherOperand1, "
750 "eitherOperand0)))\n",
751 lambda);
752 os.indent() << "return ::mlir::failure();\n";
753
754 os.unindent().unindent() << "}\n";
755 }
756
emitAttributeMatch(DagNode tree,StringRef opName,int argIndex,int depth)757 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
758 int argIndex, int depth) {
759 Operator &op = tree.getDialectOp(opMap);
760 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
761 const auto &attr = namedAttr->attr;
762
763 os << "{\n";
764 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
765 "(void)tblgen_attr;\n",
766 opName, attr.getStorageType(), namedAttr->name);
767
768 // TODO: This should use getter method to avoid duplication.
769 if (attr.hasDefaultValue()) {
770 os << "if (!tblgen_attr) tblgen_attr = "
771 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
772 attr.getDefaultValue()))
773 << ";\n";
774 } else if (attr.isOptional()) {
775 // For a missing attribute that is optional according to definition, we
776 // should just capture a mlir::Attribute() to signal the missing state.
777 // That is precisely what getAttr() returns on missing attributes.
778 } else {
779 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
780 formatv("\"expected op '{0}' to have attribute '{1}' "
781 "of type '{2}'\"",
782 op.getOperationName(), namedAttr->name,
783 attr.getStorageType()));
784 }
785
786 auto matcher = tree.getArgAsLeaf(argIndex);
787 if (!matcher.isUnspecified()) {
788 if (!matcher.isAttrMatcher()) {
789 PrintFatalError(
790 loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
791 op.getOperationName(), argIndex + 1));
792 }
793
794 // If a constraint is specified, we need to generate function call to its
795 // static verifier.
796 StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
797 if (attr.isOptional()) {
798 // Avoid dereferencing null attribute. This is using a simple heuristic to
799 // avoid common cases of attempting to dereference null attribute. This
800 // will return where there is no check if attribute is null unless the
801 // attribute's value is not used.
802 // FIXME: This could be improved as some null dereferences could slip
803 // through.
804 if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") &&
805 StringRef(matcher.getConditionTemplate()).contains("$_self")) {
806 os << "if (!tblgen_attr) return ::mlir::failure();\n";
807 }
808 }
809 emitStaticVerifierCall(
810 verifier, opName, "tblgen_attr",
811 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
812 "'{2}'\"",
813 op.getOperationName(), namedAttr->name,
814 escapeString(matcher.getAsConstraint().getSummary()))
815 .str());
816 }
817
818 // Capture the value
819 auto name = tree.getArgName(argIndex);
820 // `$_` is a special symbol to ignore op argument matching.
821 if (!name.empty() && name != "_") {
822 os << formatv("{0} = tblgen_attr;\n", name);
823 }
824
825 os.unindent() << "}\n";
826 }
827
emitMatchCheck(StringRef opName,const FmtObjectBase & matchFmt,const llvm::formatv_object_base & failureFmt)828 void PatternEmitter::emitMatchCheck(
829 StringRef opName, const FmtObjectBase &matchFmt,
830 const llvm::formatv_object_base &failureFmt) {
831 emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
832 }
833
emitMatchCheck(StringRef opName,const std::string & matchStr,const std::string & failureStr)834 void PatternEmitter::emitMatchCheck(StringRef opName,
835 const std::string &matchStr,
836 const std::string &failureStr) {
837
838 os << "if (!(" << matchStr << "))";
839 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
840 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
841 << failureStr << ";\n});";
842 }
843
emitMatchLogic(DagNode tree,StringRef opName)844 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
845 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
846 int depth = 0;
847 emitMatch(tree, opName, depth);
848
849 for (auto &appliedConstraint : pattern.getConstraints()) {
850 auto &constraint = appliedConstraint.constraint;
851 auto &entities = appliedConstraint.entities;
852
853 auto condition = constraint.getConditionTemplate();
854 if (isa<TypeConstraint>(constraint)) {
855 if (entities.size() != 1)
856 PrintFatalError(loc, "type constraint requires exactly one argument");
857
858 auto self = formatv("({0}.getType())",
859 symbolInfoMap.getValueAndRangeUse(entities.front()));
860 emitMatchCheck(
861 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
862 formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
863 entities.front(), escapeString(constraint.getSummary())));
864
865 } else if (isa<AttrConstraint>(constraint)) {
866 PrintFatalError(
867 loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
868 } else {
869 // TODO: replace formatv arguments with the exact specified
870 // args.
871 if (entities.size() > 4) {
872 PrintFatalError(loc, "only support up to 4-entity constraints now");
873 }
874 SmallVector<std::string, 4> names;
875 int i = 0;
876 for (int e = entities.size(); i < e; ++i)
877 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
878 std::string self = appliedConstraint.self;
879 if (!self.empty())
880 self = symbolInfoMap.getValueAndRangeUse(self);
881 for (; i < 4; ++i)
882 names.push_back("<unused>");
883 emitMatchCheck(opName,
884 tgfmt(condition, &fmtCtx.withSelf(self), names[0],
885 names[1], names[2], names[3]),
886 formatv("\"entities '{0}' failed to satisfy constraint: "
887 "'{1}'\"",
888 llvm::join(entities, ", "),
889 escapeString(constraint.getSummary())));
890 }
891 }
892
893 // Some of the operands could be bound to the same symbol name, we need
894 // to enforce equality constraint on those.
895 // TODO: we should be able to emit equality checks early
896 // and short circuit unnecessary work if vars are not equal.
897 for (auto symbolInfoIt = symbolInfoMap.begin();
898 symbolInfoIt != symbolInfoMap.end();) {
899 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
900 auto startRange = range.first;
901 auto endRange = range.second;
902
903 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
904 for (++startRange; startRange != endRange; ++startRange) {
905 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
906 emitMatchCheck(
907 opName,
908 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
909 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
910 secondOperand));
911 }
912
913 symbolInfoIt = endRange;
914 }
915
916 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
917 }
918
collectOps(DagNode tree,llvm::SmallPtrSetImpl<const Operator * > & ops)919 void PatternEmitter::collectOps(DagNode tree,
920 llvm::SmallPtrSetImpl<const Operator *> &ops) {
921 // Check if this tree is an operation.
922 if (tree.isOperation()) {
923 const Operator &op = tree.getDialectOp(opMap);
924 LLVM_DEBUG(llvm::dbgs()
925 << "found operation " << op.getOperationName() << '\n');
926 ops.insert(&op);
927 }
928
929 // Recurse the arguments of the tree.
930 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
931 if (auto child = tree.getArgAsNestedDag(i))
932 collectOps(child, ops);
933 }
934
/// Emits the C++ definition of a `::mlir::RewritePattern` subclass named
/// `rewriteName` for the pattern held by this emitter.
///
/// The generated struct consists of:
///   * a comment listing the locations the pattern was generated from,
///   * a constructor forwarding the root op name, the pattern benefit and the
///     (name-sorted) list of ops created by the result patterns,
///   * a `matchAndRewrite()` override whose body is produced by
///     emitMatchLogic() followed by emitRewriteLogic().
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern.
  // The locations are printed in reverse order, separated by "instantiating".
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  // Struct header and constructor up to the opening brace of the generated-op
  // name list; the list itself is streamed below.
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
{0}(::mlir::MLIRContext *context)
: ::mlir::RewritePattern("{1}", {2}, context, {{)",
                rewriteName, rootName, pattern.getBenefit());
  // Sort result operators by name.
  // The pointer set itself is not ordered by name, so sort a copy before
  // printing.
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << "}) {}\n";

  // Emit matchAndRewrite() function.
  {
    // Scope object for the struct body; it is released at the end of this
    // block, after the function's closing text has been written.
    auto classScope = os.scope();
    os.printReindented(R"(
::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
::mlir::PatternRewriter &rewriter) const override {)")
        << '\n';
    {
      // Scope object for the function body.
      auto functionScope = os.scope();

      // Register all symbols bound in the source pattern.
      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

      LLVM_DEBUG(llvm::dbgs()
                 << "start creating local variables for capturing matches\n");
      os << "// Variables for capturing values and attributes used while "
            "creating ops\n";
      // Create local variables for storing the arguments and results bound
      // to symbols.
      for (const auto &symbolInfoPair : symbolInfoMap) {
        const auto &symbol = symbolInfoPair.first;
        const auto &info = symbolInfoPair.second;

        os << info.getVarDecl(symbol);
      }
      // TODO: capture ops with consistent numbering so that it can be
      // reused for fused loc.
      os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
      LLVM_DEBUG(llvm::dbgs()
                 << "done creating local variables for capturing matches\n");

      // The root op is always the first entry of `tblgen_ops`.
      os << "// Match\n";
      os << "tblgen_ops.push_back(op0);\n";
      emitMatchLogic(sourceTree, "op0");

      os << "\n// Rewrite\n";
      emitRewriteLogic();

      os << "return ::mlir::success();\n";
    }
    // Written after `functionScope` has been destroyed.
    os << "};\n";
  }
  // Written after `classScope` has been destroyed; terminates the struct.
  os << "};\n\n";
}
1013
emitRewriteLogic()1014 void PatternEmitter::emitRewriteLogic() {
1015 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
1016 const Operator &rootOp = pattern.getSourceRootOp();
1017 int numExpectedResults = rootOp.getNumResults();
1018 int numResultPatterns = pattern.getNumResultPatterns();
1019
1020 // First register all symbols bound to ops generated in result patterns.
1021 pattern.collectResultPatternBoundSymbols(symbolInfoMap);
1022
1023 // Only the last N static values generated are used to replace the matched
1024 // root N-result op. We need to calculate the starting index (of the results
1025 // of the matched op) each result pattern is to replace.
1026 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
1027 // If we don't need to replace any value at all, set the replacement starting
1028 // index as the number of result patterns so we skip all of them when trying
1029 // to replace the matched op's results.
1030 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
1031 for (int i = numResultPatterns - 1; i >= 0; --i) {
1032 auto numValues = getNodeValueCount(pattern.getResultPattern(i));
1033 offsets[i] = offsets[i + 1] - numValues;
1034 if (offsets[i] == 0) {
1035 if (replStartIndex == -1)
1036 replStartIndex = i;
1037 } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
1038 auto error = formatv(
1039 "cannot use the same multi-result op '{0}' to generate both "
1040 "auxiliary values and values to be used for replacing the matched op",
1041 pattern.getResultPattern(i).getSymbol());
1042 PrintFatalError(loc, error);
1043 }
1044 }
1045
1046 if (offsets.front() > 0) {
1047 const char error[] = "no enough values generated to replace the matched op";
1048 PrintFatalError(loc, error);
1049 }
1050
1051 os << "auto odsLoc = rewriter.getFusedLoc({";
1052 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
1053 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
1054 }
1055 os << "}); (void)odsLoc;\n";
1056
1057 // Process auxiliary result patterns.
1058 for (int i = 0; i < replStartIndex; ++i) {
1059 DagNode resultTree = pattern.getResultPattern(i);
1060 auto val = handleResultPattern(resultTree, offsets[i], 0);
1061 // Normal op creation will be streamed to `os` by the above call; but
1062 // NativeCodeCall will only be materialized to `os` if it is used. Here
1063 // we are handling auxiliary patterns so we want the side effect even if
1064 // NativeCodeCall is not replacing matched root op's results.
1065 if (resultTree.isNativeCodeCall() &&
1066 resultTree.getNumReturnsOfNativeCode() == 0)
1067 os << val << ";\n";
1068 }
1069
1070 if (numExpectedResults == 0) {
1071 assert(replStartIndex >= numResultPatterns &&
1072 "invalid auxiliary vs. replacement pattern division!");
1073 // No result to replace. Just erase the op.
1074 os << "rewriter.eraseOp(op0);\n";
1075 } else {
1076 // Process replacement result patterns.
1077 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
1078 for (int i = replStartIndex; i < numResultPatterns; ++i) {
1079 DagNode resultTree = pattern.getResultPattern(i);
1080 auto val = handleResultPattern(resultTree, offsets[i], 0);
1081 os << "\n";
1082 // Resolve each symbol for all range use so that we can loop over them.
1083 // We need an explicit cast to `SmallVector` to capture the cases where
1084 // `{0}` resolves to an `Operation::result_range` as well as cases that
1085 // are not iterable (e.g. vector that gets wrapped in additional braces by
1086 // RewriterGen).
1087 // TODO: Revisit the need for materializing a vector.
1088 os << symbolInfoMap.getAllRangeUse(
1089 val,
1090 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
1091 " tblgen_repl_values.push_back(v);\n}\n",
1092 "\n");
1093 }
1094 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
1095 }
1096
1097 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
1098 }
1099
getUniqueSymbol(const Operator * op)1100 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1101 return std::string(
1102 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
1103 }
1104
handleResultPattern(DagNode resultTree,int resultIndex,int depth)1105 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1106 int resultIndex, int depth) {
1107 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1108 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1109 LLVM_DEBUG(llvm::dbgs() << '\n');
1110
1111 if (resultTree.isLocationDirective()) {
1112 PrintFatalError(loc,
1113 "location directive can only be used with op creation");
1114 }
1115
1116 if (resultTree.isNativeCodeCall())
1117 return handleReplaceWithNativeCodeCall(resultTree, depth);
1118
1119 if (resultTree.isReplaceWithValue())
1120 return handleReplaceWithValue(resultTree).str();
1121
1122 // Normal op creation.
1123 auto symbol = handleOpCreation(resultTree, resultIndex, depth);
1124 if (resultTree.getSymbol().empty()) {
1125 // This is an op not explicitly bound to a symbol in the rewrite rule.
1126 // Register the auto-generated symbol for it.
1127 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
1128 }
1129 return symbol;
1130 }
1131
handleReplaceWithValue(DagNode tree)1132 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1133 assert(tree.isReplaceWithValue());
1134
1135 if (tree.getNumArgs() != 1) {
1136 PrintFatalError(
1137 loc, "replaceWithValue directive must take exactly one argument");
1138 }
1139
1140 if (!tree.getSymbol().empty()) {
1141 PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
1142 }
1143
1144 return tree.getArgName(0);
1145 }
1146
handleLocationDirective(DagNode tree)1147 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1148 assert(tree.isLocationDirective());
1149 auto lookUpArgLoc = [this, &tree](int idx) {
1150 const auto *const lookupFmt = "{0}.getLoc()";
1151 return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt);
1152 };
1153
1154 if (tree.getNumArgs() == 0)
1155 llvm::PrintFatalError(
1156 "At least one argument to location directive required");
1157
1158 if (!tree.getSymbol().empty())
1159 PrintFatalError(loc, "cannot bind symbol to location");
1160
1161 if (tree.getNumArgs() == 1) {
1162 DagLeaf leaf = tree.getArgAsLeaf(0);
1163 if (leaf.isStringAttr())
1164 return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1165 leaf.getStringAttr())
1166 .str();
1167 return lookUpArgLoc(0);
1168 }
1169
1170 std::string ret;
1171 llvm::raw_string_ostream os(ret);
1172 std::string strAttr;
1173 os << "rewriter.getFusedLoc({";
1174 bool first = true;
1175 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1176 DagLeaf leaf = tree.getArgAsLeaf(i);
1177 // Handle the optional string value.
1178 if (leaf.isStringAttr()) {
1179 if (!strAttr.empty())
1180 llvm::PrintFatalError("Only one string attribute may be specified");
1181 strAttr = leaf.getStringAttr();
1182 continue;
1183 }
1184 os << (first ? "" : ", ") << lookUpArgLoc(i);
1185 first = false;
1186 }
1187 os << "}";
1188 if (!strAttr.empty()) {
1189 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1190 }
1191 os << ")";
1192 return os.str();
1193 }
1194
handleReturnTypeArg(DagNode returnType,int i,int depth)1195 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1196 int depth) {
1197 // Nested NativeCodeCall.
1198 if (auto dagNode = returnType.getArgAsNestedDag(i)) {
1199 if (!dagNode.isNativeCodeCall())
1200 PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
1201 "call");
1202 return handleReplaceWithNativeCodeCall(dagNode, depth);
1203 }
1204 // String literal.
1205 auto dagLeaf = returnType.getArgAsLeaf(i);
1206 if (dagLeaf.isStringAttr())
1207 return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
1208 return tgfmt(
1209 "$0.getType()", &fmtCtx,
1210 handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
1211 }
1212
handleOpArgument(DagLeaf leaf,StringRef patArgName)1213 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1214 StringRef patArgName) {
1215 if (leaf.isStringAttr())
1216 PrintFatalError(loc, "raw string not supported as argument");
1217 if (leaf.isConstantAttr()) {
1218 auto constAttr = leaf.getAsConstantAttr();
1219 return handleConstantAttr(constAttr.getAttribute(),
1220 constAttr.getConstantValue());
1221 }
1222 if (leaf.isEnumAttrCase()) {
1223 auto enumCase = leaf.getAsEnumAttrCase();
1224 // This is an enum case backed by an IntegerAttr. We need to get its value
1225 // to build the constant.
1226 std::string val = std::to_string(enumCase.getValue());
1227 return handleConstantAttr(enumCase, val);
1228 }
1229
1230 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1231 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
1232 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1233 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1234 << "' (via symbol ref)\n");
1235 return argName;
1236 }
1237 if (leaf.isNativeCodeCall()) {
1238 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
1239 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1240 << "' (via NativeCodeCall)\n");
1241 return std::string(repl);
1242 }
1243 PrintFatalError(loc, "unhandled case when rewriting op");
1244 }
1245
handleReplaceWithNativeCodeCall(DagNode tree,int depth)1246 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1247 int depth) {
1248 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1249 LLVM_DEBUG(tree.print(llvm::dbgs()));
1250 LLVM_DEBUG(llvm::dbgs() << '\n');
1251
1252 auto fmt = tree.getNativeCodeTemplate();
1253
1254 SmallVector<std::string, 16> attrs;
1255
1256 auto tail = getTrailingDirectives(tree);
1257 if (tail.returnType)
1258 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
1259 auto locToUse = getLocation(tail);
1260
1261 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1262 if (tree.isNestedDagArg(i)) {
1263 attrs.push_back(
1264 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
1265 } else {
1266 attrs.push_back(
1267 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
1268 }
1269 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1270 << " replacement: " << attrs[i] << "\n");
1271 }
1272
1273 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
1274 static_cast<ArrayRef<std::string>>(attrs));
1275
1276 // In general, NativeCodeCall without naming binding don't need this. To
1277 // ensure void helper function has been correctly labeled, i.e., use
1278 // NativeCodeCallVoid, we cache the result to a local variable so that we will
1279 // get a compilation error in the auto-generated file.
1280 // Example.
1281 // // In the td file
1282 // Pat<(...), (NativeCodeCall<Foo> ...)>
1283 //
1284 // ---
1285 //
1286 // // In the auto-generated .cpp
1287 // ...
1288 // // Causes compilation error if Foo() returns void.
1289 // auto nativeVar = Foo();
1290 // ...
1291 if (tree.getNumReturnsOfNativeCode() != 0) {
1292 // Determine the local variable name for return value.
1293 std::string varName =
1294 SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
1295 if (varName.empty()) {
1296 varName = formatv("nativeVar_{0}", nextValueId++);
1297 // Register the local variable for later uses.
1298 symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
1299 }
1300
1301 // Catch the return value of helper function.
1302 os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
1303
1304 if (!tree.getSymbol().empty())
1305 symbol = tree.getSymbol().str();
1306 else
1307 symbol = varName;
1308 }
1309
1310 return symbol;
1311 }
1312
getNodeValueCount(DagNode node)1313 int PatternEmitter::getNodeValueCount(DagNode node) {
1314 if (node.isOperation()) {
1315 // If the op is bound to a symbol in the rewrite rule, query its result
1316 // count from the symbol info map.
1317 auto symbol = node.getSymbol();
1318 if (!symbol.empty()) {
1319 return symbolInfoMap.getStaticValueCount(symbol);
1320 }
1321 // Otherwise this is an unbound op; we will use all its results.
1322 return pattern.getDialectOp(node).getNumResults();
1323 }
1324
1325 if (node.isNativeCodeCall())
1326 return node.getNumReturnsOfNativeCode();
1327
1328 return 1;
1329 }
1330
1331 PatternEmitter::TrailingDirectives
getTrailingDirectives(DagNode tree)1332 PatternEmitter::getTrailingDirectives(DagNode tree) {
1333 TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
1334
1335 // Look backwards through the arguments.
1336 auto numPatArgs = tree.getNumArgs();
1337 for (int i = numPatArgs - 1; i >= 0; --i) {
1338 auto dagArg = tree.getArgAsNestedDag(i);
1339 // A leaf is not a directive. Stop looking.
1340 if (!dagArg)
1341 break;
1342
1343 auto isLocation = dagArg.isLocationDirective();
1344 auto isReturnType = dagArg.isReturnTypeDirective();
1345 // If encountered a DAG node that isn't a trailing directive, stop looking.
1346 if (!(isLocation || isReturnType))
1347 break;
1348 // Save the directive, but error if one of the same type was already
1349 // found.
1350 ++tail.numDirectives;
1351 if (isLocation) {
1352 if (tail.location)
1353 PrintFatalError(loc, "`location` directive can only be specified "
1354 "once");
1355 tail.location = dagArg;
1356 } else if (isReturnType) {
1357 if (tail.returnType)
1358 PrintFatalError(loc, "`returnType` directive can only be specified "
1359 "once");
1360 tail.returnType = dagArg;
1361 }
1362 }
1363
1364 return tail;
1365 }
1366
1367 std::string
getLocation(PatternEmitter::TrailingDirectives & tail)1368 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1369 if (tail.location)
1370 return handleLocationDirective(tail.location);
1371
1372 // If no explicit location is given, use the default, all fused, location.
1373 return "odsLoc";
1374 }
1375
/// Emits the code that creates the op described by the result-pattern node
/// `tree` and returns the symbol naming the pattern's result value.
///
/// `resultIndex` is the index forwarded by the caller; when no explicit return
/// types are given it is the base index into the matched root op's results
/// (`castedOp0.getODSResults(...)`) from which result types are copied.
/// `depth` is the nesting depth of `tree` within the result pattern (0 for a
/// top-level result pattern).
///
/// One of three builder forms is emitted:
///   1. aggregate-parameter builder without result types, for ops carrying the
///      SameOperandsAndResultType or FirstAttrDerivedResultType trait;
///   2. separate-parameter builder without result types, for partial-result
///      uses, nested patterns and negative `resultIndex`;
///   3. aggregate-parameter builder with explicit result types, taken either
///      from a `returnType` directive or from the matched root op.
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                             int depth) {
  LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  Operator &resultOp = tree.getDialectOp(opMap);
  auto numOpArgs = resultOp.getNumArgs();
  auto numPatArgs = tree.getNumArgs();

  auto tail = getTrailingDirectives(tree);
  auto locToUse = getLocation(tail);

  // Trailing directives are not op arguments; exclude them from the count
  // compared against the op definition.
  auto inPattern = numPatArgs - tail.numDirectives;
  if (numOpArgs != inPattern) {
    PrintFatalError(loc,
                    formatv("resultant op '{0}' argument number mismatch: "
                            "{1} in pattern vs. {2} in definition",
                            resultOp.getOperationName(), inPattern, numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes.
  ChildNodeIndexNameMap childNodeNames;

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them and remember the symbol names for them, so that we can
  // use the results in the current node. This happens in a recursive manner.
  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    if (auto child = tree.getArgAsNestedDag(i))
      childNodeNames[i] = handleResultPattern(child, i, depth + 1);
  }

  // The name of the local variable holding this op.
  std::string valuePackName;
  // The symbol for holding the result of this pattern. Note that the result of
  // this pattern is not necessarily the same as the variable created by this
  // pattern because we can use `__N` suffix to refer only a specific result if
  // the generated op is a multi-result op.
  std::string resultValue;
  if (tree.getSymbol().empty()) {
    // No symbol is explicitly bound to this op in the pattern. Generate a
    // unique name.
    valuePackName = resultValue = getUniqueSymbol(&resultOp);
  } else {
    resultValue = std::string(tree.getSymbol());
    // Strip the index to get the name for the value pack and use it to name the
    // local variable for the op.
    valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
  }

  // Create the local variable for this op.
  // This also opens a brace in the generated code; each of the three paths
  // below emits the matching closing brace.
  os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
                valuePackName);

  // Right now ODS doesn't have general type inference support. Except for a
  // few special cases listed below, DRR needs to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
  bool useFirstAttr =
      resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");

  if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
    // We know how to deduce the result type for ops with these traits and we've
    // generated builders taking aggregate parameters. Use those builders to
    // create the ops.

    // First prepare local variables for op arguments used in builder call.
    createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

    // Then create the op.
    os.scope("", "\n}\n").os << formatv(
        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
        valuePackName, resultOp.getQualCppClassName(), locToUse);
    return resultValue;
  }

  // The two names differ only when the bound symbol carried a suffix that
  // getValuePackName() stripped.
  bool usePartialResults = valuePackName != resultValue;

  if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
    // For these cases (broadcastable ops, op results used both as auxiliary
    // values and replacement values, ops in nested patterns, auxiliary ops), we
    // still need to supply the result types when building the op. But because
    // we don't generate a builder automatically with ODS for them, it's the
    // developer's responsibility to make sure such a builder (with result type
    // deduction ability) exists. We go through the separate-parameter builder
    // here given that it's easier for developers to write compared to
    // aggregate-parameter builders.
    createSeparateLocalVarsForOpArgs(tree, childNodeNames);

    os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
                             resultOp.getQualCppClassName(), locToUse);
    supplyValuesForOpArgs(tree, childNodeNames, depth);
    os << "\n );\n}\n";
    return resultValue;
  }

  // If we are provided explicit return types, use them to build the op.
  // However, if depth == 0 and resultIndex >= 0, it means we are replacing
  // the values generated from the source pattern root op. Then we must use the
  // source pattern's value types to determine the value type of the generated
  // op here.
  if (depth == 0 && resultIndex >= 0 && tail.returnType)
    PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
                         "return values replace the source pattern's root op");

  // First prepare local variables for op arguments used in builder call.
  createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);

  // Then prepare the result types. We need to specify the types for all
  // results.
  os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
                         "(void)tblgen_types;\n");
  int numResults = resultOp.getNumResults();
  if (tail.returnType) {
    // One type per argument of the `returnType` directive.
    auto numRetTys = tail.returnType.getNumArgs();
    for (int i = 0; i < numRetTys; ++i) {
      auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
      os << "tblgen_types.push_back(" << varName << ");\n";
    }
  } else {
    if (numResults != 0) {
      // Copy the result types from the source pattern.
      for (int i = 0; i < numResults; ++i)
        os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
                      " tblgen_types.push_back(v.getType());\n}\n",
                      resultIndex + i);
    }
  }
  os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
                "tblgen_values, tblgen_attrs);\n",
                valuePackName, resultOp.getQualCppClassName(), locToUse);
  // Undo the indent() above and close the brace opened with the variable
  // declaration.
  os.unindent() << "}\n";
  return resultValue;
}
1512
createSeparateLocalVarsForOpArgs(DagNode node,ChildNodeIndexNameMap & childNodeNames)1513 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1514 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1515 Operator &resultOp = node.getDialectOp(opMap);
1516
1517 // Now prepare operands used for building this op:
1518 // * If the operand is non-variadic, we create a `Value` local variable.
1519 // * If the operand is variadic, we create a `SmallVector<Value>` local
1520 // variable.
1521
1522 int valueIndex = 0; // An index for uniquing local variable names.
1523 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1524 const auto *operand =
1525 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1526 // We do not need special handling for attributes.
1527 if (!operand)
1528 continue;
1529
1530 raw_indented_ostream::DelimitedScope scope(os);
1531 std::string varName;
1532 if (operand->isVariadic()) {
1533 varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1534 os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName);
1535 std::string range;
1536 if (node.isNestedDagArg(argIndex)) {
1537 range = childNodeNames[argIndex];
1538 } else {
1539 range = std::string(node.getArgName(argIndex));
1540 }
1541 // Resolve the symbol for all range use so that we have a uniform way of
1542 // capturing the values.
1543 range = symbolInfoMap.getValueAndRangeUse(range);
1544 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
1545 varName);
1546 } else {
1547 varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1548 os << formatv("::mlir::Value {0} = ", varName);
1549 if (node.isNestedDagArg(argIndex)) {
1550 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1551 } else {
1552 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1553 auto symbol =
1554 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1555 if (leaf.isNativeCodeCall()) {
1556 os << std::string(
1557 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1558 } else {
1559 os << symbol;
1560 }
1561 }
1562 os << ";\n";
1563 }
1564
1565 // Update to use the newly created local variable for building the op later.
1566 childNodeNames[argIndex] = varName;
1567 }
1568 }
1569
supplyValuesForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1570 void PatternEmitter::supplyValuesForOpArgs(
1571 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1572 Operator &resultOp = node.getDialectOp(opMap);
1573 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1574 argIndex != numOpArgs; ++argIndex) {
1575 // Start each argument on its own line.
1576 os << ",\n ";
1577
1578 Argument opArg = resultOp.getArg(argIndex);
1579 // Handle the case of operand first.
1580 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1581 if (!operand->name.empty())
1582 os << "/*" << operand->name << "=*/";
1583 os << childNodeNames.lookup(argIndex);
1584 continue;
1585 }
1586
1587 // The argument in the op definition.
1588 auto opArgName = resultOp.getArgName(argIndex);
1589 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1590 if (!subTree.isNativeCodeCall())
1591 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1592 "for creating attribute");
1593 os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1594 } else {
1595 auto leaf = node.getArgAsLeaf(argIndex);
1596 // The argument in the result DAG pattern.
1597 auto patArgName = node.getArgName(argIndex);
1598 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1599 // TODO: Refactor out into map to avoid recomputing these.
1600 if (!opArg.is<NamedAttribute *>())
1601 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1602 if (!patArgName.empty())
1603 os << "/*" << patArgName << "=*/";
1604 } else {
1605 os << "/*" << opArgName << "=*/";
1606 }
1607 os << handleOpArgument(leaf, patArgName);
1608 }
1609 }
1610 }
1611
createAggregateLocalVarsForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1612 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1613 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1614 Operator &resultOp = node.getDialectOp(opMap);
1615
1616 auto scope = os.scope();
1617 os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
1618 "tblgen_values; (void)tblgen_values;\n");
1619 os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1620 "tblgen_attrs; (void)tblgen_attrs;\n");
1621
1622 const char *addAttrCmd =
1623 "if (auto tmpAttr = {1}) {\n"
1624 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1625 "tmpAttr);\n}\n";
1626 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1627 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1628 // The argument in the op definition.
1629 auto opArgName = resultOp.getArgName(argIndex);
1630 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1631 if (!subTree.isNativeCodeCall())
1632 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1633 "for creating attribute");
1634 os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1635 } else {
1636 auto leaf = node.getArgAsLeaf(argIndex);
1637 // The argument in the result DAG pattern.
1638 auto patArgName = node.getArgName(argIndex);
1639 os << formatv(addAttrCmd, opArgName,
1640 handleOpArgument(leaf, patArgName));
1641 }
1642 continue;
1643 }
1644
1645 const auto *operand =
1646 resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1647 std::string varName;
1648 if (operand->isVariadic()) {
1649 std::string range;
1650 if (node.isNestedDagArg(argIndex)) {
1651 range = childNodeNames.lookup(argIndex);
1652 } else {
1653 range = std::string(node.getArgName(argIndex));
1654 }
1655 // Resolve the symbol for all range use so that we have a uniform way of
1656 // capturing the values.
1657 range = symbolInfoMap.getValueAndRangeUse(range);
1658 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1659 range);
1660 } else {
1661 os << formatv("tblgen_values.push_back(");
1662 if (node.isNestedDagArg(argIndex)) {
1663 os << symbolInfoMap.getValueAndRangeUse(
1664 childNodeNames.lookup(argIndex));
1665 } else {
1666 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1667 if (leaf.isConstantAttr())
1668 // TODO: Use better location
1669 PrintFatalError(
1670 loc,
1671 "attribute found where value was expected, if attempting to use "
1672 "constant value, construct a constant op with given attribute "
1673 "instead");
1674
1675 auto symbol =
1676 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1677 if (leaf.isNativeCodeCall()) {
1678 os << std::string(
1679 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1680 } else {
1681 os << symbol;
1682 }
1683 }
1684 os << ");\n";
1685 }
1686 }
1687 }
1688
StaticMatcherHelper(raw_ostream & os,const RecordKeeper & recordKeeper,RecordOperatorMap & mapper)1689 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
1690 const RecordKeeper &recordKeeper,
1691 RecordOperatorMap &mapper)
1692 : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
1693
populateStaticMatchers(raw_ostream & os)1694 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1695 // PatternEmitter will use the static matcher if there's one generated. To
1696 // ensure that all the dependent static matchers are generated before emitting
1697 // the matching logic of the DagNode, we use topological order to achieve it.
1698 for (auto &dagInfo : topologicalOrder) {
1699 DagNode node = dagInfo.first;
1700 if (!useStaticMatcher(node))
1701 continue;
1702
1703 std::string funcName =
1704 formatv("static_dag_matcher_{0}", staticMatcherCounter++);
1705 assert(matcherNames.find(node) == matcherNames.end());
1706 PatternEmitter(dagInfo.second, &opMap, os, *this)
1707 .emitStaticMatcher(node, funcName);
1708 matcherNames[node] = funcName;
1709 }
1710 }
1711
populateStaticConstraintFunctions(raw_ostream & os)1712 void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
1713 staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
1714 }
1715
addPattern(Record * record)1716 void StaticMatcherHelper::addPattern(Record *record) {
1717 Pattern pat(record, &opMap);
1718
1719 // While generating the function body of the DAG matcher, it may depends on
1720 // other DAG matchers. To ensure the dependent matchers are ready, we compute
1721 // the topological order for all the DAGs and emit the DAG matchers in this
1722 // order.
1723 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1724 ++refStats[node];
1725
1726 if (refStats[node] != 1)
1727 return;
1728
1729 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1730 if (DagNode sibling = node.getArgAsNestedDag(i))
1731 dfs(sibling);
1732 else {
1733 DagLeaf leaf = node.getArgAsLeaf(i);
1734 if (!leaf.isUnspecified())
1735 constraints.insert(leaf);
1736 }
1737
1738 topologicalOrder.push_back(std::make_pair(node, record));
1739 };
1740
1741 dfs(pat.getSourcePattern());
1742 }
1743
getVerifierName(DagLeaf leaf)1744 StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
1745 if (leaf.isAttrMatcher()) {
1746 Optional<StringRef> constraint =
1747 staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
1748 assert(constraint && "attribute constraint was not uniqued");
1749 return *constraint;
1750 }
1751 assert(leaf.isOperandMatcher());
1752 return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
1753 }
1754
emitRewriters(const RecordKeeper & recordKeeper,raw_ostream & os)1755 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1756 emitSourceFileHeader("Rewriters", os);
1757
1758 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1759
1760 // We put the map here because it can be shared among multiple patterns.
1761 RecordOperatorMap recordOpMap;
1762
1763 // Exam all the patterns and generate static matcher for the duplicated
1764 // DagNode.
1765 StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
1766 for (Record *p : patterns)
1767 staticMatcher.addPattern(p);
1768 staticMatcher.populateStaticConstraintFunctions(os);
1769 staticMatcher.populateStaticMatchers(os);
1770
1771 std::vector<std::string> rewriterNames;
1772 rewriterNames.reserve(patterns.size());
1773
1774 std::string baseRewriterName = "GeneratedConvert";
1775 int rewriterIndex = 0;
1776
1777 for (Record *p : patterns) {
1778 std::string name;
1779 if (p->isAnonymous()) {
1780 // If no name is provided, ensure unique rewriter names simply by
1781 // appending unique suffix.
1782 name = baseRewriterName + llvm::utostr(rewriterIndex++);
1783 } else {
1784 name = std::string(p->getName());
1785 }
1786 LLVM_DEBUG(llvm::dbgs()
1787 << "=== start generating pattern '" << name << "' ===\n");
1788 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
1789 LLVM_DEBUG(llvm::dbgs()
1790 << "=== done generating pattern '" << name << "' ===\n");
1791 rewriterNames.push_back(std::move(name));
1792 }
1793
1794 // Emit function to add the generated matchers to the pattern list.
1795 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1796 "::mlir::RewritePatternSet &patterns) {\n";
1797 for (const auto &name : rewriterNames) {
1798 os << " patterns.add<" << name << ">(patterns.getContext());\n";
1799 }
1800 os << "}\n";
1801 }
1802
1803 static mlir::GenRegistration
1804 genRewriters("gen-rewriters", "Generate pattern rewriters",
__anon7a36dead0702(const RecordKeeper &records, raw_ostream &os) 1805 [](const RecordKeeper &records, raw_ostream &os) {
1806 emitRewriters(records, os);
1807 return false;
1808 });
1809