1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "mlir/TableGen/Pattern.h"
20 #include "mlir/TableGen/Predicate.h"
21 #include "mlir/TableGen/Type.h"
22 #include "llvm/ADT/FunctionExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/FormatAdapters.h"
29 #include "llvm/Support/PrettyStackTrace.h"
30 #include "llvm/Support/Signals.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Main.h"
33 #include "llvm/TableGen/Record.h"
34 #include "llvm/TableGen/TableGenBackend.h"
35 
36 using namespace mlir;
37 using namespace mlir::tblgen;
38 
39 using llvm::formatv;
40 using llvm::Record;
41 using llvm::RecordKeeper;
42 
43 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
44 
45 namespace llvm {
46 template <>
47 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
formatllvm::format_provider48   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
49                      raw_ostream &os, StringRef style) {
50     os << v.first << ":" << v.second;
51   }
52 };
53 } // namespace llvm
54 
55 //===----------------------------------------------------------------------===//
56 // PatternEmitter
57 //===----------------------------------------------------------------------===//
58 
59 namespace {
60 
61 class StaticMatcherHelper;
62 
// Generates the C++ match and rewrite code for a single pattern record. All
// output is written through an indenting stream that wraps the `os` handed to
// the constructor.
class PatternEmitter {
public:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
                 StaticMatcherHelper &helper);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emits a static function named `funcName` that matches the DAG `tree`.
  void emitStaticMatcher(DagNode tree, std::string funcName);

private:
  // Emits the code for matching ops.
  void emitMatchLogic(DagNode tree, StringRef opName);

  // Emits the code for rewriting ops.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the DAG structure. `tree` must be either
  // an operation or a NativeCodeCall; anything else is a fatal error.
  void emitMatch(DagNode tree, StringRef name, int depth);

  // Emits a C++ call to the static DAG matcher generated for `tree`.
  void emitStaticMatchCall(DagNode tree, StringRef name);

  // Emits a C++ call to the static type/attribute constraint function
  // `funcName`; the generated code returns failure if that call fails.
  void emitStaticVerifierCall(StringRef funcName, StringRef opName,
                              StringRef arg, StringRef failureStr);

  // Emits C++ statements for matching using a native code call.
  void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `opName` is the name of the variable in the generated code that
  // holds the operation being matched.
  void emitOpMatch(DagNode tree, StringRef opName, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
  // bound name and the constraint of the operand respectively. `argName` is
  // the symbol to capture the operand into; an empty name or `_` captures
  // nothing.
  void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
                        DagLeaf operandMatcher, StringRef argName,
                        int argIndex);

  // Emits C++ statements for matching the operands which can be matched in
  // either order. `operandIndex` is advanced past the operands consumed by
  // the `either` group.
  void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
                              StringRef opName, int argIndex, int &operandIndex,
                              int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
                          int depth);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic.
  void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
                      const llvm::formatv_object_base &failureFmt);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic. Same as above, but takes the already formatted strings.
  void emitMatchCheck(StringRef opName, const std::string &matchStr,
                      const std::string &failureStr);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Emits the C++ statement to replace the matched DAG with a value built via
  // calling native C++ code.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);

  // Returns the symbol of the old value serving as the replacement.
  StringRef handleReplaceWithValue(DagNode tree);

  // Trailing directives are used at the end of DAG node argument lists to
  // specify additional behaviour for op matchers and creators, etc.
  struct TrailingDirectives {
    // DAG node containing the `location` directive. Null if there is none.
    DagNode location;

    // DAG node containing the `returnType` directive. Null if there is none.
    DagNode returnType;

    // Number of found trailing directives.
    int numDirectives;
  };

  // Collect any trailing directives.
  TrailingDirectives getTrailingDirectives(DagNode tree);

  // Returns the location value to use.
  std::string getLocation(TrailingDirectives &tail);

  // Returns the location value to use.
  std::string handleLocationDirective(DagNode tree);

  // Emit return type argument.
  std::string handleReturnTypeArg(DagNode returnType, int i, int depth);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;

  // Emits a local variable for each value and attribute to be used for creating
  // an op.
  void createSeparateLocalVarsForOpArgs(DagNode node,
                                        ChildNodeIndexNameMap &childNodeNames);

  // Emits the concrete arguments used to call an op's builder.
  void supplyValuesForOpArgs(DagNode node,
                             const ChildNodeIndexNameMap &childNodeNames,
                             int depth);

  // Emits the local variables for holding all values as a whole and all named
  // attributes as a whole to be used for creating an op.
  void createAggregateLocalVarsForOpArgs(
      DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`. It is a fatal error if `attr`
  // is not const-buildable.
  std::string handleConstantAttr(Attribute attr, const Twine &value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bind the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // NOTE: the constructor's initializer list follows the declaration order of
  // the members below; keep the two in sync when adding members.

  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // Shared helper that decides whether a DAG node is matched through a static
  // matcher function and provides the names of static helper functions.
  StaticMatcherHelper &staticMatcherHelper;

  // The next unused ID for newly created values.
  unsigned nextValueId = 0;

  // Indenting stream that all generated code is written to.
  raw_indented_ostream os;

  // Format contexts containing placeholder substitutions.
  FmtContext fmtCtx;
};
253 
254 // Tracks DagNode's reference multiple times across patterns. Enables generating
255 // static matcher functions for DagNode's referenced multiple times rather than
256 // inlining them.
257 class StaticMatcherHelper {
258 public:
259   StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
260                       RecordOperatorMap &mapper);
261 
262   // Determine if we should inline the match logic or delegate to a static
263   // function.
useStaticMatcher(DagNode node)264   bool useStaticMatcher(DagNode node) {
265     return refStats[node] > kStaticMatcherThreshold;
266   }
267 
268   // Get the name of the static DAG matcher function corresponding to the node.
getMatcherName(DagNode node)269   std::string getMatcherName(DagNode node) {
270     assert(useStaticMatcher(node));
271     return matcherNames[node];
272   }
273 
274   // Get the name of static type/attribute verification function.
275   StringRef getVerifierName(DagLeaf leaf);
276 
277   // Collect the `Record`s, i.e., the DRR, so that we can get the information of
278   // the duplicated DAGs.
279   void addPattern(Record *record);
280 
281   // Emit all static functions of DAG Matcher.
282   void populateStaticMatchers(raw_ostream &os);
283 
284   // Emit all static functions for Constraints.
285   void populateStaticConstraintFunctions(raw_ostream &os);
286 
287 private:
288   static constexpr unsigned kStaticMatcherThreshold = 1;
289 
290   // Consider two patterns as down below,
291   //   DagNode_Root_A    DagNode_Root_B
292   //       \                 \
293   //     DagNode_C         DagNode_C
294   //         \                 \
295   //       DagNode_D         DagNode_D
296   //
297   // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
298   // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
299   // multiple times so we'll have static matchers for both of them. When we're
300   // emitting the match logic for DagNode_C, we will check if DagNode_D has the
301   // static matcher generated. If so, then we'll generate a call to the
302   // function, inline otherwise. In this case, inlining is not what we want. As
303   // a result, generate the static matcher in topological order to ensure all
304   // the dependent static matchers are generated and we can avoid accidentally
305   // inlining.
306   //
307   // The topological order of all the DagNodes among all patterns.
308   SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
309 
310   RecordOperatorMap &opMap;
311 
312   // Records of the static function name of each DagNode
313   DenseMap<DagNode, std::string> matcherNames;
314 
315   // After collecting all the DagNode in each pattern, `refStats` records the
316   // number of users for each DagNode. We will generate the static matcher for a
317   // DagNode while the number of users exceeds a certain threshold.
318   DenseMap<DagNode, unsigned> refStats;
319 
320   // Number of static matcher generated. This is used to generate a unique name
321   // for each DagNode.
322   int staticMatcherCounter = 0;
323 
324   // The DagLeaf which contains type or attr constraint.
325   SetVector<DagLeaf> constraints;
326 
327   // Static type/attribute verification function emitter.
328   StaticVerifierFunctionEmitter staticVerifierEmitter;
329 };
330 
331 } // namespace
332 
PatternEmitter(Record * pat,RecordOperatorMap * mapper,raw_ostream & os,StaticMatcherHelper & helper)333 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
334                                raw_ostream &os, StaticMatcherHelper &helper)
335     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
336       symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
337   fmtCtx.withBuilder("rewriter");
338 }
339 
handleConstantAttr(Attribute attr,const Twine & value)340 std::string PatternEmitter::handleConstantAttr(Attribute attr,
341                                                const Twine &value) {
342   if (!attr.isConstBuildable())
343     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
344                              " does not have the 'constBuilderCall' field");
345 
346   // TODO: Verify the constants here
347   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
348 }
349 
emitStaticMatcher(DagNode tree,std::string funcName)350 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
351   os << formatv(
352       "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
353       "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
354       "*, 4> &tblgen_ops",
355       funcName);
356 
357   // We pass the reference of the variables that need to be captured. Hence we
358   // need to collect all the symbols in the tree first.
359   pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
360   symbolInfoMap.assignUniqueAlternativeNames();
361   for (const auto &info : symbolInfoMap)
362     os << formatv(", {0}", info.second.getArgDecl(info.first));
363 
364   os << ") {\n";
365   os.indent();
366   os << "(void)tblgen_ops;\n";
367 
368   // Note that a static matcher is considered at least one step from the match
369   // entry.
370   emitMatch(tree, "op0", /*depth=*/1);
371 
372   os << "return ::mlir::success();\n";
373   os.unindent();
374   os << "}\n\n";
375 }
376 
377 // Helper function to match patterns.
emitMatch(DagNode tree,StringRef name,int depth)378 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
379   if (tree.isNativeCodeCall()) {
380     emitNativeCodeMatch(tree, name, depth);
381     return;
382   }
383 
384   if (tree.isOperation()) {
385     emitOpMatch(tree, name, depth);
386     return;
387   }
388 
389   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
390 }
391 
emitStaticMatchCall(DagNode tree,StringRef opName)392 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
393   std::string funcName = staticMatcherHelper.getMatcherName(tree);
394   os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName,
395                 opName);
396 
397   // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
398   // one pass.
399 
400   // In general, bound symbol should have the unique name in the pattern but
401   // for the operand, binding same symbol to multiple operands imply a
402   // constraint at the same time. In this case, we will rename those operands
403   // with different names. As a result, we need to collect all the symbolInfos
404   // from the DagNode then get the updated name of the local variables from the
405   // global symbolInfoMap.
406 
407   // Collect all the bound symbols in the Dag
408   SymbolInfoMap localSymbolMap(loc);
409   pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
410 
411   for (const auto &info : localSymbolMap) {
412     auto name = info.first;
413     auto symboInfo = info.second;
414     auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
415     os << formatv(", {0}", ret->second.getVarName(name));
416   }
417 
418   os << "))) {\n";
419   os.scope().os << "return ::mlir::failure();\n";
420   os << "}\n";
421 }
422 
emitStaticVerifierCall(StringRef funcName,StringRef opName,StringRef arg,StringRef failureStr)423 void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
424                                             StringRef opName, StringRef arg,
425                                             StringRef failureStr) {
426   os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
427                 funcName, opName, arg, failureStr);
428   os.scope().os << "return ::mlir::failure();\n";
429   os << "}\n";
430 }
431 
432 // Helper function to match patterns.
emitNativeCodeMatch(DagNode tree,StringRef opName,int depth)433 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
434                                          int depth) {
435   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
436   LLVM_DEBUG(tree.print(llvm::dbgs()));
437   LLVM_DEBUG(llvm::dbgs() << '\n');
438 
439   // The order of generating static matcher follows the topological order so
440   // that for every dependent DagNode already have their static matcher
441   // generated if needed. The reason we check if `getMatcherName(tree).empty()`
442   // is when we are generating the static matcher for a DagNode itself. In this
443   // case, we need to emit the function body rather than a function call.
444   if (staticMatcherHelper.useStaticMatcher(tree) &&
445       !staticMatcherHelper.getMatcherName(tree).empty()) {
446     emitStaticMatchCall(tree, opName);
447 
448     // NativeCodeCall will never be at depth 0 so that we don't need to catch
449     // the root operation as emitOpMatch();
450 
451     return;
452   }
453 
454   // TODO(suderman): iterate through arguments, determine their types, output
455   // names.
456   SmallVector<std::string, 8> capture;
457 
458   raw_indented_ostream::DelimitedScope scope(os);
459 
460   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
461     std::string argName = formatv("arg{0}_{1}", depth, i);
462     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
463       if (argTree.isEither())
464         PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
465 
466       os << "::mlir::Value " << argName << ";\n";
467     } else {
468       auto leaf = tree.getArgAsLeaf(i);
469       if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
470         os << "::mlir::Attribute " << argName << ";\n";
471       } else {
472         os << "::mlir::Value " << argName << ";\n";
473       }
474     }
475 
476     capture.push_back(std::move(argName));
477   }
478 
479   auto tail = getTrailingDirectives(tree);
480   if (tail.returnType)
481     PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
482   auto locToUse = getLocation(tail);
483 
484   auto fmt = tree.getNativeCodeTemplate();
485   if (fmt.count("$_self") != 1)
486     PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
487                          "passing the defining Operation");
488 
489   auto nativeCodeCall = std::string(
490       tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
491             static_cast<ArrayRef<std::string>>(capture)));
492 
493   emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall),
494                  formatv("\"{0} return ::mlir::failure\"", nativeCodeCall));
495 
496   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
497     auto name = tree.getArgName(i);
498     if (!name.empty() && name != "_") {
499       os << formatv("{0} = {1};\n", name, capture[i]);
500     }
501   }
502 
503   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
504     std::string argName = capture[i];
505 
506     // Handle nested DAG construct first
507     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
508       PrintFatalError(
509           loc, formatv("Matching nested tree in NativeCodecall not support for "
510                        "{0} as arg {1}",
511                        argName, i));
512     }
513 
514     DagLeaf leaf = tree.getArgAsLeaf(i);
515 
516     // The parameter for native function doesn't bind any constraints.
517     if (leaf.isUnspecified())
518       continue;
519 
520     auto constraint = leaf.getAsConstraint();
521 
522     std::string self;
523     if (leaf.isAttrMatcher() || leaf.isConstantAttr())
524       self = argName;
525     else
526       self = formatv("{0}.getType()", argName);
527     StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
528     emitStaticVerifierCall(
529         verifier, opName, self,
530         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
531                 "constraint: "
532                 "'{2}'\"",
533                 i, tree.getNativeCodeTemplate(),
534                 escapeString(constraint.getSummary()))
535             .str());
536   }
537 
538   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
539 }
540 
541 // Helper function to match patterns.
emitOpMatch(DagNode tree,StringRef opName,int depth)542 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
543   Operator &op = tree.getDialectOp(opMap);
544   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
545                           << op.getOperationName() << "' at depth " << depth
546                           << '\n');
547 
548   auto getCastedName = [depth]() -> std::string {
549     return formatv("castedOp{0}", depth);
550   };
551 
552   // The order of generating static matcher follows the topological order so
553   // that for every dependent DagNode already have their static matcher
554   // generated if needed. The reason we check if `getMatcherName(tree).empty()`
555   // is when we are generating the static matcher for a DagNode itself. In this
556   // case, we need to emit the function body rather than a function call.
557   if (staticMatcherHelper.useStaticMatcher(tree) &&
558       !staticMatcherHelper.getMatcherName(tree).empty()) {
559     emitStaticMatchCall(tree, opName);
560     // In the codegen of rewriter, we suppose that castedOp0 will capture the
561     // root operation. Manually add it if the root DagNode is a static matcher.
562     if (depth == 0)
563       os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
564                     "(void){2};\n",
565                     opName, op.getQualCppClassName(), getCastedName());
566     return;
567   }
568 
569   std::string castedName = getCastedName();
570   os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
571                 "(void){0};\n",
572                 castedName, opName, op.getQualCppClassName());
573 
574   // Skip the operand matching at depth 0 as the pattern rewriter already does.
575   if (depth != 0)
576     emitMatchCheck(opName, /*matchStr=*/castedName,
577                    formatv("\"{0} is not {1} type\"", castedName,
578                            op.getQualCppClassName()));
579 
580   // If the operand's name is set, set to that variable.
581   auto name = tree.getSymbol();
582   if (!name.empty())
583     os << formatv("{0} = {1};\n", name, castedName);
584 
585   for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
586     auto opArg = op.getArg(i);
587     std::string argName = formatv("op{0}", depth + 1);
588 
589     // Handle nested DAG construct first
590     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
591       if (argTree.isEither()) {
592         emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand,
593                                depth);
594         continue;
595       }
596       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
597         if (operand->isVariableLength()) {
598           auto error = formatv("use nested DAG construct to match op {0}'s "
599                                "variadic operand #{1} unsupported now",
600                                op.getOperationName(), i);
601           PrintFatalError(loc, error);
602         }
603       }
604 
605       os << "{\n";
606 
607       // Attributes don't count for getODSOperands.
608       // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
609       os.indent() << formatv(
610           "auto *{0} = "
611           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
612           argName, castedName, nextOperand);
613       // Null check of operand's definingOp
614       emitMatchCheck(
615           castedName, /*matchStr=*/argName,
616           formatv("\"There's no operation that defines operand {0} of {1}\"",
617                   nextOperand++, castedName));
618       emitMatch(argTree, argName, depth + 1);
619       os << formatv("tblgen_ops.push_back({0});\n", argName);
620       os.unindent() << "}\n";
621       continue;
622     }
623 
624     // Next handle DAG leaf: operand or attribute
625     if (opArg.is<NamedTypeConstraint *>()) {
626       auto operandName =
627           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
628       emitOperandMatch(tree, castedName, operandName.str(),
629                        /*operandMatcher=*/tree.getArgAsLeaf(i),
630                        /*argName=*/tree.getArgName(i),
631                        /*argIndex=*/i);
632       ++nextOperand;
633     } else if (opArg.is<NamedAttribute *>()) {
634       emitAttributeMatch(tree, opName, i, depth);
635     } else {
636       PrintFatalError(loc, "unhandled case when matching op");
637     }
638   }
639   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
640                           << op.getOperationName() << "' at depth " << depth
641                           << '\n');
642 }
643 
emitOperandMatch(DagNode tree,StringRef opName,StringRef operandName,DagLeaf operandMatcher,StringRef argName,int argIndex)644 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
645                                       StringRef operandName,
646                                       DagLeaf operandMatcher, StringRef argName,
647                                       int argIndex) {
648   Operator &op = tree.getDialectOp(opMap);
649   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
650 
651   // If a constraint is specified, we need to generate C++ statements to
652   // check the constraint.
653   if (!operandMatcher.isUnspecified()) {
654     if (!operandMatcher.isOperandMatcher())
655       PrintFatalError(
656           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
657                        op.getOperationName(), argIndex + 1));
658 
659     // Only need to verify if the matcher's type is different from the one
660     // of op definition.
661     Constraint constraint = operandMatcher.getAsConstraint();
662     if (operand->constraint != constraint) {
663       if (operand->isVariableLength()) {
664         auto error = formatv(
665             "further constrain op {0}'s variadic operand #{1} unsupported now",
666             op.getOperationName(), argIndex);
667         PrintFatalError(loc, error);
668       }
669       auto self = formatv("(*{0}.begin()).getType()", operandName);
670       StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
671       emitStaticVerifierCall(
672           verifier, opName, self.str(),
673           formatv(
674               "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
675               operand - op.operand_begin(), op.getOperationName(),
676               escapeString(constraint.getSummary()))
677               .str());
678     }
679   }
680 
681   // Capture the value
682   // `$_` is a special symbol to ignore op argument matching.
683   if (!argName.empty() && argName != "_") {
684     auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
685     os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
686   }
687 }
688 
// Emits the match code for an `either` directive that groups two operands of
// the op described by `tree`.
//
// The generated code has two parts: a lambda named `eitherLambda<depth>` that
// takes two operand ranges (`v0`, `v1`) and matches the two grouped
// sub-patterns against them, followed by a block that calls the lambda with the
// two operands in both orders and fails the match only if both calls fail.
//
// `eitherArgTree` is the `either` DAG node, `opName` is the name of the
// generated variable for the op being matched, and `argIndex` is the index of
// the first grouped argument in the op definition. `operandIndex` is advanced
// once for each grouped operand. `depth` is used to name the lambda and is
// passed, incremented, to nested matches.
void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
                                            StringRef opName, int argIndex,
                                            int &operandIndex, int depth) {
  constexpr int numEitherArgs = 2;
  if (eitherArgTree.getNumArgs() != numEitherArgs)
    PrintFatalError(loc, "`either` only supports grouping two operands");

  Operator &op = tree.getDialectOp(opMap);

  // Side buffer for the `tblgen_ops.push_back(...)` statements. It is written
  // into the lambda body right before the lambda's successful return.
  std::string codeBuffer;
  llvm::raw_string_ostream tblgenOps(codeBuffer);

  // Open the lambda that tries one particular operand order.
  std::string lambda = formatv("eitherLambda{0}", depth);
  os << formatv(
      "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
      lambda);

  os.indent();

  // Emit the match for each of the two grouped arguments; `argIndex` follows
  // the position of the argument in the op definition.
  for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
    if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
      // The argument is a nested DAG: match the op defining the operand.
      if (argTree.isEither())
        PrintFatalError(loc, "either cannot be nested");

      std::string argName = formatv("local_op_{0}", i).str();

      os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
                    i);
      emitMatchCheck(
          opName, /*matchStr=*/argName,
          formatv("\"There's no operation that defines operand {0} of {1}\"",
                  operandIndex++, opName));
      emitMatch(argTree, argName, depth + 1);
      // `tblgen_ops` is used to collect the matched operations. In either, we
      // need to queue the operation only if the matching succeeds. Thus we emit
      // the code at the end.
      tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
      // The argument is a leaf for an operand: match it directly against the
      // lambda parameter `v<i>`.
      emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                       /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
                       /*argName=*/eitherArgTree.getArgName(i), argIndex);
      ++operandIndex;
    } else {
      PrintFatalError(loc, "either can only be applied on operand");
    }
  }

  // Flush the deferred `tblgen_ops` updates and close the lambda.
  os << tblgenOps.str();
  os << "return ::mlir::success();\n";
  os.unindent() << "};\n";

  // Emit the block that invokes the lambda with both operand orders.
  os << "{\n";
  os.indent();

  // `operandIndex` was advanced past both grouped operands by the loop above,
  // so they are the two operands right before it.
  os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
                operandIndex - 2);
  os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
                operandIndex - 1);

  os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
                "::mlir::failed({0}(eitherOperand1, "
                "eitherOperand0)))\n",
                lambda);
  os.indent() << "return ::mlir::failure();\n";

  os.unindent().unindent() << "}\n";
}
756 
emitAttributeMatch(DagNode tree,StringRef opName,int argIndex,int depth)757 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
758                                         int argIndex, int depth) {
759   Operator &op = tree.getDialectOp(opMap);
760   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
761   const auto &attr = namedAttr->attr;
762 
763   os << "{\n";
764   os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
765                          "(void)tblgen_attr;\n",
766                          opName, attr.getStorageType(), namedAttr->name);
767 
768   // TODO: This should use getter method to avoid duplication.
769   if (attr.hasDefaultValue()) {
770     os << "if (!tblgen_attr) tblgen_attr = "
771        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
772                             attr.getDefaultValue()))
773        << ";\n";
774   } else if (attr.isOptional()) {
775     // For a missing attribute that is optional according to definition, we
776     // should just capture a mlir::Attribute() to signal the missing state.
777     // That is precisely what getAttr() returns on missing attributes.
778   } else {
779     emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
780                    formatv("\"expected op '{0}' to have attribute '{1}' "
781                            "of type '{2}'\"",
782                            op.getOperationName(), namedAttr->name,
783                            attr.getStorageType()));
784   }
785 
786   auto matcher = tree.getArgAsLeaf(argIndex);
787   if (!matcher.isUnspecified()) {
788     if (!matcher.isAttrMatcher()) {
789       PrintFatalError(
790           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
791                        op.getOperationName(), argIndex + 1));
792     }
793 
794     // If a constraint is specified, we need to generate function call to its
795     // static verifier.
796     StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
797     if (attr.isOptional()) {
798       // Avoid dereferencing null attribute. This is using a simple heuristic to
799       // avoid common cases of attempting to dereference null attribute. This
800       // will return where there is no check if attribute is null unless the
801       // attribute's value is not used.
802       // FIXME: This could be improved as some null dereferences could slip
803       // through.
804       if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") &&
805           StringRef(matcher.getConditionTemplate()).contains("$_self")) {
806         os << "if (!tblgen_attr) return ::mlir::failure();\n";
807       }
808     }
809     emitStaticVerifierCall(
810         verifier, opName, "tblgen_attr",
811         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
812                 "'{2}'\"",
813                 op.getOperationName(), namedAttr->name,
814                 escapeString(matcher.getAsConstraint().getSummary()))
815             .str());
816   }
817 
818   // Capture the value
819   auto name = tree.getArgName(argIndex);
820   // `$_` is a special symbol to ignore op argument matching.
821   if (!name.empty() && name != "_") {
822     os << formatv("{0} = tblgen_attr;\n", name);
823   }
824 
825   os.unindent() << "}\n";
826 }
827 
emitMatchCheck(StringRef opName,const FmtObjectBase & matchFmt,const llvm::formatv_object_base & failureFmt)828 void PatternEmitter::emitMatchCheck(
829     StringRef opName, const FmtObjectBase &matchFmt,
830     const llvm::formatv_object_base &failureFmt) {
831   emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
832 }
833 
emitMatchCheck(StringRef opName,const std::string & matchStr,const std::string & failureStr)834 void PatternEmitter::emitMatchCheck(StringRef opName,
835                                     const std::string &matchStr,
836                                     const std::string &failureStr) {
837 
838   os << "if (!(" << matchStr << "))";
839   os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
840                               << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
841                               << failureStr << ";\n});";
842 }
843 
emitMatchLogic(DagNode tree,StringRef opName)844 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
845   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
846   int depth = 0;
847   emitMatch(tree, opName, depth);
848 
849   for (auto &appliedConstraint : pattern.getConstraints()) {
850     auto &constraint = appliedConstraint.constraint;
851     auto &entities = appliedConstraint.entities;
852 
853     auto condition = constraint.getConditionTemplate();
854     if (isa<TypeConstraint>(constraint)) {
855       if (entities.size() != 1)
856         PrintFatalError(loc, "type constraint requires exactly one argument");
857 
858       auto self = formatv("({0}.getType())",
859                           symbolInfoMap.getValueAndRangeUse(entities.front()));
860       emitMatchCheck(
861           opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
862           formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
863                   entities.front(), escapeString(constraint.getSummary())));
864 
865     } else if (isa<AttrConstraint>(constraint)) {
866       PrintFatalError(
867           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
868     } else {
869       // TODO: replace formatv arguments with the exact specified
870       // args.
871       if (entities.size() > 4) {
872         PrintFatalError(loc, "only support up to 4-entity constraints now");
873       }
874       SmallVector<std::string, 4> names;
875       int i = 0;
876       for (int e = entities.size(); i < e; ++i)
877         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
878       std::string self = appliedConstraint.self;
879       if (!self.empty())
880         self = symbolInfoMap.getValueAndRangeUse(self);
881       for (; i < 4; ++i)
882         names.push_back("<unused>");
883       emitMatchCheck(opName,
884                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
885                            names[1], names[2], names[3]),
886                      formatv("\"entities '{0}' failed to satisfy constraint: "
887                              "'{1}'\"",
888                              llvm::join(entities, ", "),
889                              escapeString(constraint.getSummary())));
890     }
891   }
892 
893   // Some of the operands could be bound to the same symbol name, we need
894   // to enforce equality constraint on those.
895   // TODO: we should be able to emit equality checks early
896   // and short circuit unnecessary work if vars are not equal.
897   for (auto symbolInfoIt = symbolInfoMap.begin();
898        symbolInfoIt != symbolInfoMap.end();) {
899     auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
900     auto startRange = range.first;
901     auto endRange = range.second;
902 
903     auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
904     for (++startRange; startRange != endRange; ++startRange) {
905       auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
906       emitMatchCheck(
907           opName,
908           formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
909           formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
910                   secondOperand));
911     }
912 
913     symbolInfoIt = endRange;
914   }
915 
916   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
917 }
918 
collectOps(DagNode tree,llvm::SmallPtrSetImpl<const Operator * > & ops)919 void PatternEmitter::collectOps(DagNode tree,
920                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
921   // Check if this tree is an operation.
922   if (tree.isOperation()) {
923     const Operator &op = tree.getDialectOp(opMap);
924     LLVM_DEBUG(llvm::dbgs()
925                << "found operation " << op.getOperationName() << '\n');
926     ops.insert(&op);
927   }
928 
929   // Recurse the arguments of the tree.
930   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
931     if (auto child = tree.getArgAsNestedDag(i))
932       collectOps(child, ops);
933 }
934 
// Emits the C++ definition for this pattern: a struct named `rewriteName` that
// derives from `::mlir::RewritePattern`. The struct gets a constructor taking
// the MLIR context and a `matchAndRewrite` override whose body consists of the
// variable declarations for bound symbols, the match logic and the rewrite
// logic.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
    collectOps(pattern.getResultPattern(i), resultOps);
  }
  LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");

  // Emit RewritePattern for Pattern. A leading comment lists the locations the
  // pattern was generated from, printed in reverse order.
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  // Struct header and constructor: the base class is initialized with the root
  // op name, the pattern benefit, the context and (emitted below) the names of
  // the ops used in the result patterns.
  os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
  {0}(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("{1}", {2}, context, {{)",
                rewriteName, rootName, pattern.getBenefit());
  // Sort result operators by name.
  llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
                                                         resultOps.end());
  llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
    return lhs->getOperationName() < rhs->getOperationName();
  });
  llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << "}) {}\n";

  // Emit matchAndRewrite() function.
  {
    auto classScope = os.scope();
    os.printReindented(R"(
    ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
        ::mlir::PatternRewriter &rewriter) const override {)")
        << '\n';
    {
      auto functionScope = os.scope();

      // Register all symbols bound in the source pattern.
      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

      LLVM_DEBUG(llvm::dbgs()
                 << "start creating local variables for capturing matches\n");
      os << "// Variables for capturing values and attributes used while "
            "creating ops\n";
      // Create local variables for storing the arguments and results bound
      // to symbols.
      for (const auto &symbolInfoPair : symbolInfoMap) {
        const auto &symbol = symbolInfoPair.first;
        const auto &info = symbolInfoPair.second;

        os << info.getVarDecl(symbol);
      }
      // TODO: capture ops with consistent numbering so that it can be
      // reused for fused loc.
      os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
      LLVM_DEBUG(llvm::dbgs()
                 << "done creating local variables for capturing matches\n");

      // The root op `op0` is always the first entry of `tblgen_ops`.
      os << "// Match\n";
      os << "tblgen_ops.push_back(op0);\n";
      emitMatchLogic(sourceTree, "op0");

      os << "\n// Rewrite\n";
      emitRewriteLogic();

      os << "return ::mlir::success();\n";
    }
    // Closes the generated matchAndRewrite() function.
    os << "};\n";
  }
  // Closes the generated struct.
  os << "};\n\n";
}
1013 
emitRewriteLogic()1014 void PatternEmitter::emitRewriteLogic() {
1015   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
1016   const Operator &rootOp = pattern.getSourceRootOp();
1017   int numExpectedResults = rootOp.getNumResults();
1018   int numResultPatterns = pattern.getNumResultPatterns();
1019 
1020   // First register all symbols bound to ops generated in result patterns.
1021   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
1022 
1023   // Only the last N static values generated are used to replace the matched
1024   // root N-result op. We need to calculate the starting index (of the results
1025   // of the matched op) each result pattern is to replace.
1026   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
1027   // If we don't need to replace any value at all, set the replacement starting
1028   // index as the number of result patterns so we skip all of them when trying
1029   // to replace the matched op's results.
1030   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
1031   for (int i = numResultPatterns - 1; i >= 0; --i) {
1032     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
1033     offsets[i] = offsets[i + 1] - numValues;
1034     if (offsets[i] == 0) {
1035       if (replStartIndex == -1)
1036         replStartIndex = i;
1037     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
1038       auto error = formatv(
1039           "cannot use the same multi-result op '{0}' to generate both "
1040           "auxiliary values and values to be used for replacing the matched op",
1041           pattern.getResultPattern(i).getSymbol());
1042       PrintFatalError(loc, error);
1043     }
1044   }
1045 
1046   if (offsets.front() > 0) {
1047     const char error[] = "no enough values generated to replace the matched op";
1048     PrintFatalError(loc, error);
1049   }
1050 
1051   os << "auto odsLoc = rewriter.getFusedLoc({";
1052   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
1053     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
1054   }
1055   os << "}); (void)odsLoc;\n";
1056 
1057   // Process auxiliary result patterns.
1058   for (int i = 0; i < replStartIndex; ++i) {
1059     DagNode resultTree = pattern.getResultPattern(i);
1060     auto val = handleResultPattern(resultTree, offsets[i], 0);
1061     // Normal op creation will be streamed to `os` by the above call; but
1062     // NativeCodeCall will only be materialized to `os` if it is used. Here
1063     // we are handling auxiliary patterns so we want the side effect even if
1064     // NativeCodeCall is not replacing matched root op's results.
1065     if (resultTree.isNativeCodeCall() &&
1066         resultTree.getNumReturnsOfNativeCode() == 0)
1067       os << val << ";\n";
1068   }
1069 
1070   if (numExpectedResults == 0) {
1071     assert(replStartIndex >= numResultPatterns &&
1072            "invalid auxiliary vs. replacement pattern division!");
1073     // No result to replace. Just erase the op.
1074     os << "rewriter.eraseOp(op0);\n";
1075   } else {
1076     // Process replacement result patterns.
1077     os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
1078     for (int i = replStartIndex; i < numResultPatterns; ++i) {
1079       DagNode resultTree = pattern.getResultPattern(i);
1080       auto val = handleResultPattern(resultTree, offsets[i], 0);
1081       os << "\n";
1082       // Resolve each symbol for all range use so that we can loop over them.
1083       // We need an explicit cast to `SmallVector` to capture the cases where
1084       // `{0}` resolves to an `Operation::result_range` as well as cases that
1085       // are not iterable (e.g. vector that gets wrapped in additional braces by
1086       // RewriterGen).
1087       // TODO: Revisit the need for materializing a vector.
1088       os << symbolInfoMap.getAllRangeUse(
1089           val,
1090           "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
1091           "  tblgen_repl_values.push_back(v);\n}\n",
1092           "\n");
1093     }
1094     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
1095   }
1096 
1097   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
1098 }
1099 
getUniqueSymbol(const Operator * op)1100 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1101   return std::string(
1102       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
1103 }
1104 
handleResultPattern(DagNode resultTree,int resultIndex,int depth)1105 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1106                                                 int resultIndex, int depth) {
1107   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1108   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1109   LLVM_DEBUG(llvm::dbgs() << '\n');
1110 
1111   if (resultTree.isLocationDirective()) {
1112     PrintFatalError(loc,
1113                     "location directive can only be used with op creation");
1114   }
1115 
1116   if (resultTree.isNativeCodeCall())
1117     return handleReplaceWithNativeCodeCall(resultTree, depth);
1118 
1119   if (resultTree.isReplaceWithValue())
1120     return handleReplaceWithValue(resultTree).str();
1121 
1122   // Normal op creation.
1123   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
1124   if (resultTree.getSymbol().empty()) {
1125     // This is an op not explicitly bound to a symbol in the rewrite rule.
1126     // Register the auto-generated symbol for it.
1127     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
1128   }
1129   return symbol;
1130 }
1131 
handleReplaceWithValue(DagNode tree)1132 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1133   assert(tree.isReplaceWithValue());
1134 
1135   if (tree.getNumArgs() != 1) {
1136     PrintFatalError(
1137         loc, "replaceWithValue directive must take exactly one argument");
1138   }
1139 
1140   if (!tree.getSymbol().empty()) {
1141     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
1142   }
1143 
1144   return tree.getArgName(0);
1145 }
1146 
handleLocationDirective(DagNode tree)1147 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1148   assert(tree.isLocationDirective());
1149   auto lookUpArgLoc = [this, &tree](int idx) {
1150     const auto *const lookupFmt = "{0}.getLoc()";
1151     return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt);
1152   };
1153 
1154   if (tree.getNumArgs() == 0)
1155     llvm::PrintFatalError(
1156         "At least one argument to location directive required");
1157 
1158   if (!tree.getSymbol().empty())
1159     PrintFatalError(loc, "cannot bind symbol to location");
1160 
1161   if (tree.getNumArgs() == 1) {
1162     DagLeaf leaf = tree.getArgAsLeaf(0);
1163     if (leaf.isStringAttr())
1164       return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1165                      leaf.getStringAttr())
1166           .str();
1167     return lookUpArgLoc(0);
1168   }
1169 
1170   std::string ret;
1171   llvm::raw_string_ostream os(ret);
1172   std::string strAttr;
1173   os << "rewriter.getFusedLoc({";
1174   bool first = true;
1175   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1176     DagLeaf leaf = tree.getArgAsLeaf(i);
1177     // Handle the optional string value.
1178     if (leaf.isStringAttr()) {
1179       if (!strAttr.empty())
1180         llvm::PrintFatalError("Only one string attribute may be specified");
1181       strAttr = leaf.getStringAttr();
1182       continue;
1183     }
1184     os << (first ? "" : ", ") << lookUpArgLoc(i);
1185     first = false;
1186   }
1187   os << "}";
1188   if (!strAttr.empty()) {
1189     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1190   }
1191   os << ")";
1192   return os.str();
1193 }
1194 
handleReturnTypeArg(DagNode returnType,int i,int depth)1195 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1196                                                 int depth) {
1197   // Nested NativeCodeCall.
1198   if (auto dagNode = returnType.getArgAsNestedDag(i)) {
1199     if (!dagNode.isNativeCodeCall())
1200       PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
1201                            "call");
1202     return handleReplaceWithNativeCodeCall(dagNode, depth);
1203   }
1204   // String literal.
1205   auto dagLeaf = returnType.getArgAsLeaf(i);
1206   if (dagLeaf.isStringAttr())
1207     return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
1208   return tgfmt(
1209       "$0.getType()", &fmtCtx,
1210       handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
1211 }
1212 
handleOpArgument(DagLeaf leaf,StringRef patArgName)1213 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1214                                              StringRef patArgName) {
1215   if (leaf.isStringAttr())
1216     PrintFatalError(loc, "raw string not supported as argument");
1217   if (leaf.isConstantAttr()) {
1218     auto constAttr = leaf.getAsConstantAttr();
1219     return handleConstantAttr(constAttr.getAttribute(),
1220                               constAttr.getConstantValue());
1221   }
1222   if (leaf.isEnumAttrCase()) {
1223     auto enumCase = leaf.getAsEnumAttrCase();
1224     // This is an enum case backed by an IntegerAttr. We need to get its value
1225     // to build the constant.
1226     std::string val = std::to_string(enumCase.getValue());
1227     return handleConstantAttr(enumCase, val);
1228   }
1229 
1230   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1231   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
1232   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1233     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1234                             << "' (via symbol ref)\n");
1235     return argName;
1236   }
1237   if (leaf.isNativeCodeCall()) {
1238     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
1239     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1240                             << "' (via NativeCodeCall)\n");
1241     return std::string(repl);
1242   }
1243   PrintFatalError(loc, "unhandled case when rewriting op");
1244 }
1245 
handleReplaceWithNativeCodeCall(DagNode tree,int depth)1246 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1247                                                             int depth) {
1248   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1249   LLVM_DEBUG(tree.print(llvm::dbgs()));
1250   LLVM_DEBUG(llvm::dbgs() << '\n');
1251 
1252   auto fmt = tree.getNativeCodeTemplate();
1253 
1254   SmallVector<std::string, 16> attrs;
1255 
1256   auto tail = getTrailingDirectives(tree);
1257   if (tail.returnType)
1258     PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
1259   auto locToUse = getLocation(tail);
1260 
1261   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1262     if (tree.isNestedDagArg(i)) {
1263       attrs.push_back(
1264           handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
1265     } else {
1266       attrs.push_back(
1267           handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
1268     }
1269     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1270                             << " replacement: " << attrs[i] << "\n");
1271   }
1272 
1273   std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
1274                              static_cast<ArrayRef<std::string>>(attrs));
1275 
1276   // In general, NativeCodeCall without naming binding don't need this. To
1277   // ensure void helper function has been correctly labeled, i.e., use
1278   // NativeCodeCallVoid, we cache the result to a local variable so that we will
1279   // get a compilation error in the auto-generated file.
1280   // Example.
1281   //   // In the td file
1282   //   Pat<(...), (NativeCodeCall<Foo> ...)>
1283   //
1284   //   ---
1285   //
1286   //   // In the auto-generated .cpp
1287   //   ...
1288   //   // Causes compilation error if Foo() returns void.
1289   //   auto nativeVar = Foo();
1290   //   ...
1291   if (tree.getNumReturnsOfNativeCode() != 0) {
1292     // Determine the local variable name for return value.
1293     std::string varName =
1294         SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
1295     if (varName.empty()) {
1296       varName = formatv("nativeVar_{0}", nextValueId++);
1297       // Register the local variable for later uses.
1298       symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
1299     }
1300 
1301     // Catch the return value of helper function.
1302     os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
1303 
1304     if (!tree.getSymbol().empty())
1305       symbol = tree.getSymbol().str();
1306     else
1307       symbol = varName;
1308   }
1309 
1310   return symbol;
1311 }
1312 
getNodeValueCount(DagNode node)1313 int PatternEmitter::getNodeValueCount(DagNode node) {
1314   if (node.isOperation()) {
1315     // If the op is bound to a symbol in the rewrite rule, query its result
1316     // count from the symbol info map.
1317     auto symbol = node.getSymbol();
1318     if (!symbol.empty()) {
1319       return symbolInfoMap.getStaticValueCount(symbol);
1320     }
1321     // Otherwise this is an unbound op; we will use all its results.
1322     return pattern.getDialectOp(node).getNumResults();
1323   }
1324 
1325   if (node.isNativeCodeCall())
1326     return node.getNumReturnsOfNativeCode();
1327 
1328   return 1;
1329 }
1330 
1331 PatternEmitter::TrailingDirectives
getTrailingDirectives(DagNode tree)1332 PatternEmitter::getTrailingDirectives(DagNode tree) {
1333   TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
1334 
1335   // Look backwards through the arguments.
1336   auto numPatArgs = tree.getNumArgs();
1337   for (int i = numPatArgs - 1; i >= 0; --i) {
1338     auto dagArg = tree.getArgAsNestedDag(i);
1339     // A leaf is not a directive. Stop looking.
1340     if (!dagArg)
1341       break;
1342 
1343     auto isLocation = dagArg.isLocationDirective();
1344     auto isReturnType = dagArg.isReturnTypeDirective();
1345     // If encountered a DAG node that isn't a trailing directive, stop looking.
1346     if (!(isLocation || isReturnType))
1347       break;
1348     // Save the directive, but error if one of the same type was already
1349     // found.
1350     ++tail.numDirectives;
1351     if (isLocation) {
1352       if (tail.location)
1353         PrintFatalError(loc, "`location` directive can only be specified "
1354                              "once");
1355       tail.location = dagArg;
1356     } else if (isReturnType) {
1357       if (tail.returnType)
1358         PrintFatalError(loc, "`returnType` directive can only be specified "
1359                              "once");
1360       tail.returnType = dagArg;
1361     }
1362   }
1363 
1364   return tail;
1365 }
1366 
1367 std::string
getLocation(PatternEmitter::TrailingDirectives & tail)1368 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1369   if (tail.location)
1370     return handleLocationDirective(tail.location);
1371 
1372   // If no explicit location is given, use the default, all fused, location.
1373   return "odsLoc";
1374 }
1375 
handleOpCreation(DagNode tree,int resultIndex,int depth)1376 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1377                                              int depth) {
1378   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1379   LLVM_DEBUG(tree.print(llvm::dbgs()));
1380   LLVM_DEBUG(llvm::dbgs() << '\n');
1381 
1382   Operator &resultOp = tree.getDialectOp(opMap);
1383   auto numOpArgs = resultOp.getNumArgs();
1384   auto numPatArgs = tree.getNumArgs();
1385 
1386   auto tail = getTrailingDirectives(tree);
1387   auto locToUse = getLocation(tail);
1388 
1389   auto inPattern = numPatArgs - tail.numDirectives;
1390   if (numOpArgs != inPattern) {
1391     PrintFatalError(loc,
1392                     formatv("resultant op '{0}' argument number mismatch: "
1393                             "{1} in pattern vs. {2} in definition",
1394                             resultOp.getOperationName(), inPattern, numOpArgs));
1395   }
1396 
1397   // A map to collect all nested DAG child nodes' names, with operand index as
1398   // the key. This includes both bound and unbound child nodes.
1399   ChildNodeIndexNameMap childNodeNames;
1400 
1401   // First go through all the child nodes who are nested DAG constructs to
1402   // create ops for them and remember the symbol names for them, so that we can
1403   // use the results in the current node. This happens in a recursive manner.
1404   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1405     if (auto child = tree.getArgAsNestedDag(i))
1406       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1407   }
1408 
1409   // The name of the local variable holding this op.
1410   std::string valuePackName;
1411   // The symbol for holding the result of this pattern. Note that the result of
1412   // this pattern is not necessarily the same as the variable created by this
1413   // pattern because we can use `__N` suffix to refer only a specific result if
1414   // the generated op is a multi-result op.
1415   std::string resultValue;
1416   if (tree.getSymbol().empty()) {
1417     // No symbol is explicitly bound to this op in the pattern. Generate a
1418     // unique name.
1419     valuePackName = resultValue = getUniqueSymbol(&resultOp);
1420   } else {
1421     resultValue = std::string(tree.getSymbol());
1422     // Strip the index to get the name for the value pack and use it to name the
1423     // local variable for the op.
1424     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1425   }
1426 
1427   // Create the local variable for this op.
1428   os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1429                 valuePackName);
1430 
1431   // Right now ODS don't have general type inference support. Except a few
1432   // special cases listed below, DRR needs to supply types for all results
1433   // when building an op.
1434   bool isSameOperandsAndResultType =
1435       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1436   bool useFirstAttr =
1437       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1438 
1439   if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1440     // We know how to deduce the result type for ops with these traits and we've
1441     // generated builders taking aggregate parameters. Use those builders to
1442     // create the ops.
1443 
1444     // First prepare local variables for op arguments used in builder call.
1445     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1446 
1447     // Then create the op.
1448     os.scope("", "\n}\n").os << formatv(
1449         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1450         valuePackName, resultOp.getQualCppClassName(), locToUse);
1451     return resultValue;
1452   }
1453 
1454   bool usePartialResults = valuePackName != resultValue;
1455 
1456   if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1457     // For these cases (broadcastable ops, op results used both as auxiliary
1458     // values and replacement values, ops in nested patterns, auxiliary ops), we
1459     // still need to supply the result types when building the op. But because
1460     // we don't generate a builder automatically with ODS for them, it's the
1461     // developer's responsibility to make sure such a builder (with result type
1462     // deduction ability) exists. We go through the separate-parameter builder
1463     // here given that it's easier for developers to write compared to
1464     // aggregate-parameter builders.
1465     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1466 
1467     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1468                              resultOp.getQualCppClassName(), locToUse);
1469     supplyValuesForOpArgs(tree, childNodeNames, depth);
1470     os << "\n  );\n}\n";
1471     return resultValue;
1472   }
1473 
1474   // If we are provided explicit return types, use them to build the op.
1475   // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1476   // the values generated from the source pattern root op. Then we must use the
1477   // source pattern's value types to determine the value type of the generated
1478   // op here.
1479   if (depth == 0 && resultIndex >= 0 && tail.returnType)
1480     PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
1481                          "return values replace the source pattern's root op");
1482 
1483   // First prepare local variables for op arguments used in builder call.
1484   createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1485 
1486   // Then prepare the result types. We need to specify the types for all
1487   // results.
1488   os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
1489                          "(void)tblgen_types;\n");
1490   int numResults = resultOp.getNumResults();
1491   if (tail.returnType) {
1492     auto numRetTys = tail.returnType.getNumArgs();
1493     for (int i = 0; i < numRetTys; ++i) {
1494       auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
1495       os << "tblgen_types.push_back(" << varName << ");\n";
1496     }
1497   } else {
1498     if (numResults != 0) {
1499       // Copy the result types from the source pattern.
1500       for (int i = 0; i < numResults; ++i)
1501         os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1502                       "  tblgen_types.push_back(v.getType());\n}\n",
1503                       resultIndex + i);
1504     }
1505   }
1506   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1507                 "tblgen_values, tblgen_attrs);\n",
1508                 valuePackName, resultOp.getQualCppClassName(), locToUse);
1509   os.unindent() << "}\n";
1510   return resultValue;
1511 }
1512 
createSeparateLocalVarsForOpArgs(DagNode node,ChildNodeIndexNameMap & childNodeNames)1513 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1514     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1515   Operator &resultOp = node.getDialectOp(opMap);
1516 
1517   // Now prepare operands used for building this op:
1518   // * If the operand is non-variadic, we create a `Value` local variable.
1519   // * If the operand is variadic, we create a `SmallVector<Value>` local
1520   //   variable.
1521 
1522   int valueIndex = 0; // An index for uniquing local variable names.
1523   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1524     const auto *operand =
1525         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1526     // We do not need special handling for attributes.
1527     if (!operand)
1528       continue;
1529 
1530     raw_indented_ostream::DelimitedScope scope(os);
1531     std::string varName;
1532     if (operand->isVariadic()) {
1533       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1534       os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName);
1535       std::string range;
1536       if (node.isNestedDagArg(argIndex)) {
1537         range = childNodeNames[argIndex];
1538       } else {
1539         range = std::string(node.getArgName(argIndex));
1540       }
1541       // Resolve the symbol for all range use so that we have a uniform way of
1542       // capturing the values.
1543       range = symbolInfoMap.getValueAndRangeUse(range);
1544       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
1545                     varName);
1546     } else {
1547       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1548       os << formatv("::mlir::Value {0} = ", varName);
1549       if (node.isNestedDagArg(argIndex)) {
1550         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1551       } else {
1552         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1553         auto symbol =
1554             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1555         if (leaf.isNativeCodeCall()) {
1556           os << std::string(
1557               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1558         } else {
1559           os << symbol;
1560         }
1561       }
1562       os << ";\n";
1563     }
1564 
1565     // Update to use the newly created local variable for building the op later.
1566     childNodeNames[argIndex] = varName;
1567   }
1568 }
1569 
supplyValuesForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1570 void PatternEmitter::supplyValuesForOpArgs(
1571     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1572   Operator &resultOp = node.getDialectOp(opMap);
1573   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1574        argIndex != numOpArgs; ++argIndex) {
1575     // Start each argument on its own line.
1576     os << ",\n    ";
1577 
1578     Argument opArg = resultOp.getArg(argIndex);
1579     // Handle the case of operand first.
1580     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1581       if (!operand->name.empty())
1582         os << "/*" << operand->name << "=*/";
1583       os << childNodeNames.lookup(argIndex);
1584       continue;
1585     }
1586 
1587     // The argument in the op definition.
1588     auto opArgName = resultOp.getArgName(argIndex);
1589     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1590       if (!subTree.isNativeCodeCall())
1591         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1592                              "for creating attribute");
1593       os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1594     } else {
1595       auto leaf = node.getArgAsLeaf(argIndex);
1596       // The argument in the result DAG pattern.
1597       auto patArgName = node.getArgName(argIndex);
1598       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1599         // TODO: Refactor out into map to avoid recomputing these.
1600         if (!opArg.is<NamedAttribute *>())
1601           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1602         if (!patArgName.empty())
1603           os << "/*" << patArgName << "=*/";
1604       } else {
1605         os << "/*" << opArgName << "=*/";
1606       }
1607       os << handleOpArgument(leaf, patArgName);
1608     }
1609   }
1610 }
1611 
createAggregateLocalVarsForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1612 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1613     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1614   Operator &resultOp = node.getDialectOp(opMap);
1615 
1616   auto scope = os.scope();
1617   os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
1618                 "tblgen_values; (void)tblgen_values;\n");
1619   os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1620                 "tblgen_attrs; (void)tblgen_attrs;\n");
1621 
1622   const char *addAttrCmd =
1623       "if (auto tmpAttr = {1}) {\n"
1624       "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1625       "tmpAttr);\n}\n";
1626   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1627     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1628       // The argument in the op definition.
1629       auto opArgName = resultOp.getArgName(argIndex);
1630       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1631         if (!subTree.isNativeCodeCall())
1632           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1633                                "for creating attribute");
1634         os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1635       } else {
1636         auto leaf = node.getArgAsLeaf(argIndex);
1637         // The argument in the result DAG pattern.
1638         auto patArgName = node.getArgName(argIndex);
1639         os << formatv(addAttrCmd, opArgName,
1640                       handleOpArgument(leaf, patArgName));
1641       }
1642       continue;
1643     }
1644 
1645     const auto *operand =
1646         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1647     std::string varName;
1648     if (operand->isVariadic()) {
1649       std::string range;
1650       if (node.isNestedDagArg(argIndex)) {
1651         range = childNodeNames.lookup(argIndex);
1652       } else {
1653         range = std::string(node.getArgName(argIndex));
1654       }
1655       // Resolve the symbol for all range use so that we have a uniform way of
1656       // capturing the values.
1657       range = symbolInfoMap.getValueAndRangeUse(range);
1658       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
1659                     range);
1660     } else {
1661       os << formatv("tblgen_values.push_back(");
1662       if (node.isNestedDagArg(argIndex)) {
1663         os << symbolInfoMap.getValueAndRangeUse(
1664             childNodeNames.lookup(argIndex));
1665       } else {
1666         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1667         if (leaf.isConstantAttr())
1668           // TODO: Use better location
1669           PrintFatalError(
1670               loc,
1671               "attribute found where value was expected, if attempting to use "
1672               "constant value, construct a constant op with given attribute "
1673               "instead");
1674 
1675         auto symbol =
1676             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1677         if (leaf.isNativeCodeCall()) {
1678           os << std::string(
1679               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1680         } else {
1681           os << symbol;
1682         }
1683       }
1684       os << ");\n";
1685     }
1686   }
1687 }
1688 
StaticMatcherHelper(raw_ostream & os,const RecordKeeper & recordKeeper,RecordOperatorMap & mapper)1689 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
1690                                          const RecordKeeper &recordKeeper,
1691                                          RecordOperatorMap &mapper)
1692     : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
1693 
populateStaticMatchers(raw_ostream & os)1694 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1695   // PatternEmitter will use the static matcher if there's one generated. To
1696   // ensure that all the dependent static matchers are generated before emitting
1697   // the matching logic of the DagNode, we use topological order to achieve it.
1698   for (auto &dagInfo : topologicalOrder) {
1699     DagNode node = dagInfo.first;
1700     if (!useStaticMatcher(node))
1701       continue;
1702 
1703     std::string funcName =
1704         formatv("static_dag_matcher_{0}", staticMatcherCounter++);
1705     assert(matcherNames.find(node) == matcherNames.end());
1706     PatternEmitter(dagInfo.second, &opMap, os, *this)
1707         .emitStaticMatcher(node, funcName);
1708     matcherNames[node] = funcName;
1709   }
1710 }
1711 
populateStaticConstraintFunctions(raw_ostream & os)1712 void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
1713   staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
1714 }
1715 
addPattern(Record * record)1716 void StaticMatcherHelper::addPattern(Record *record) {
1717   Pattern pat(record, &opMap);
1718 
1719   // While generating the function body of the DAG matcher, it may depends on
1720   // other DAG matchers. To ensure the dependent matchers are ready, we compute
1721   // the topological order for all the DAGs and emit the DAG matchers in this
1722   // order.
1723   llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1724     ++refStats[node];
1725 
1726     if (refStats[node] != 1)
1727       return;
1728 
1729     for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1730       if (DagNode sibling = node.getArgAsNestedDag(i))
1731         dfs(sibling);
1732       else {
1733         DagLeaf leaf = node.getArgAsLeaf(i);
1734         if (!leaf.isUnspecified())
1735           constraints.insert(leaf);
1736       }
1737 
1738     topologicalOrder.push_back(std::make_pair(node, record));
1739   };
1740 
1741   dfs(pat.getSourcePattern());
1742 }
1743 
getVerifierName(DagLeaf leaf)1744 StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
1745   if (leaf.isAttrMatcher()) {
1746     Optional<StringRef> constraint =
1747         staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
1748     assert(constraint && "attribute constraint was not uniqued");
1749     return *constraint;
1750   }
1751   assert(leaf.isOperandMatcher());
1752   return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
1753 }
1754 
emitRewriters(const RecordKeeper & recordKeeper,raw_ostream & os)1755 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1756   emitSourceFileHeader("Rewriters", os);
1757 
1758   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1759 
1760   // We put the map here because it can be shared among multiple patterns.
1761   RecordOperatorMap recordOpMap;
1762 
1763   // Exam all the patterns and generate static matcher for the duplicated
1764   // DagNode.
1765   StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
1766   for (Record *p : patterns)
1767     staticMatcher.addPattern(p);
1768   staticMatcher.populateStaticConstraintFunctions(os);
1769   staticMatcher.populateStaticMatchers(os);
1770 
1771   std::vector<std::string> rewriterNames;
1772   rewriterNames.reserve(patterns.size());
1773 
1774   std::string baseRewriterName = "GeneratedConvert";
1775   int rewriterIndex = 0;
1776 
1777   for (Record *p : patterns) {
1778     std::string name;
1779     if (p->isAnonymous()) {
1780       // If no name is provided, ensure unique rewriter names simply by
1781       // appending unique suffix.
1782       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1783     } else {
1784       name = std::string(p->getName());
1785     }
1786     LLVM_DEBUG(llvm::dbgs()
1787                << "=== start generating pattern '" << name << "' ===\n");
1788     PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
1789     LLVM_DEBUG(llvm::dbgs()
1790                << "=== done generating pattern '" << name << "' ===\n");
1791     rewriterNames.push_back(std::move(name));
1792   }
1793 
1794   // Emit function to add the generated matchers to the pattern list.
1795   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1796         "::mlir::RewritePatternSet &patterns) {\n";
1797   for (const auto &name : rewriterNames) {
1798     os << "  patterns.add<" << name << ">(patterns.getContext());\n";
1799   }
1800   os << "}\n";
1801 }
1802 
1803 static mlir::GenRegistration
1804     genRewriters("gen-rewriters", "Generate pattern rewriters",
__anon7a36dead0702(const RecordKeeper &records, raw_ostream &os) 1805                  [](const RecordKeeper &records, raw_ostream &os) {
1806                    emitRewriters(records, os);
1807                    return false;
1808                  });
1809