1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::pdl_to_pdl_interp;
25 
26 //===----------------------------------------------------------------------===//
27 // PatternLowering
28 //===----------------------------------------------------------------------===//
29 
namespace {
/// This class generates operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations.
///
/// All matching logic is emitted into a single matcher function, while the
/// rewrite logic reached by each successful match is emitted as a separate
/// function inside the rewriter module.
struct PatternLowering {
public:
  /// Creates a lowering that populates `matcherFunc` and inserts generated
  /// rewriter functions into `rewriterModule`.
  PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);

  /// Generate code for matching and rewriting based on the pattern operations
  /// within the module.
  void lower(ModuleOp module);

private:
  /// Scoped mapping from a predicate position to the interpreter value that
  /// holds it. A new scope is opened for every matcher node that is generated.
  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;

  /// Generate interpreter operations for the tree rooted at the given matcher
  /// node, in the specified region.
  Block *generateMatcher(MatcherNode &node, Region &region);

  /// Get or create an access to the provided positional value in the current
  /// block. This operation may mutate the provided block pointer if nested
  /// regions (i.e., pdl_interp.iterate) are required.
  Value getValueAt(Block *&currentBlock, Position *pos);

  /// Create the interpreter predicate operations. This operation may mutate the
  /// provided current block pointer if nested regions (iterates) are required.
  void generate(BoolNode *boolNode, Block *&currentBlock, Value val);

  /// Create the interpreter switch / predicate operations, with several case
  /// destinations. This operation never mutates the provided current block
  /// pointer, because the switch operation does not need Values beyond `val`.
  void generate(SwitchNode *switchNode, Block *currentBlock, Value val);

  /// Create the interpreter operations to record a successful pattern match
  /// using the contained root operation. This operation may mutate the current
  /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
  void generate(SuccessNode *successNode, Block *&currentBlock);

  /// Generate a rewriter function for the given pattern operation, and returns
  /// a reference to that function.
  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
                                 SmallVectorImpl<Position *> &usedMatchValues);

  /// Generate the rewriter code for the given operation. Each overload emits
  /// the interpreter equivalent of one PDL rewrite operation. `rewriteValues`
  /// maps PDL values to the interpreter values produced so far.
  /// `mapRewriteValue` returns the interpreter value for a PDL value, creating
  /// it (or a new rewriter argument) if it does not exist yet.
  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::AttributeOp attrOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::EraseOp eraseOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::OperationOp operationOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ReplaceOp replaceOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultsOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypeOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypesOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);

  /// Generate the values used for resolving the result types of an operation
  /// created within a dag rewriter region.
  void generateOperationResultTypeRewriter(
      pdl::OperationOp op, SmallVectorImpl<Value> &types,
      DenseMap<Value, Value> &rewriteValues,
      function_ref<Value(Value)> mapRewriteValue);

  /// A builder to use when generating interpreter operations.
  OpBuilder builder;

  /// The matcher function used for all match related logic within PDL patterns.
  FuncOp matcherFunc;

  /// The rewriter module containing all rewrite-related logic within PDL
  /// patterns.
  ModuleOp rewriterModule;

  /// The symbol table of the rewriter module used for insertion.
  SymbolTable rewriterSymbolTable;

  /// A scoped map connecting a position with the corresponding interpreter
  /// value.
  ValueMap values;

  /// A stack of blocks used as the failure destination for matcher nodes that
  /// don't have an explicit failure path. The innermost (back) entry is the
  /// current failure destination.
  SmallVector<Block *, 8> failureBlockStack;

  /// A mapping between values defined in a pattern match, and the corresponding
  /// positional value.
  DenseMap<Value, Position *> valueToPosition;

  /// The set of operation values whose location will be used for newly
  /// generated operations. Its contents are passed as the matched operations
  /// of every emitted `pdl_interp.record_match`.
  SetVector<Value> locOps;
};
} // end anonymous namespace
139 
140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
141     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
142       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
143 
144 void PatternLowering::lower(ModuleOp module) {
145   PredicateUniquer predicateUniquer;
146   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
147 
148   // Define top-level scope for the arguments to the matcher function.
149   ValueMapScope topLevelValueScope(values);
150 
151   // Insert the root operation, i.e. argument to the matcher, at the root
152   // position.
153   Block *matcherEntryBlock = matcherFunc.addEntryBlock();
154   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
155 
156   // Generate a root matcher node from the provided PDL module.
157   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
158       module, predicateBuilder, valueToPosition);
159   Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
160   assert(failureBlockStack.empty() && "failed to empty the stack");
161 
162   // After generation, merged the first matched block into the entry.
163   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
164                                             firstMatcherBlock->getOperations());
165   firstMatcherBlock->erase();
166 }
167 
168 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
169   // Push a new scope for the values used by this matcher.
170   Block *block = &region.emplaceBlock();
171   ValueMapScope scope(values);
172 
173   // If this is the return node, simply insert the corresponding interpreter
174   // finalize.
175   if (isa<ExitNode>(node)) {
176     builder.setInsertionPointToEnd(block);
177     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
178     return block;
179   }
180 
181   // Get the next block in the match sequence.
182   // This is intentionally executed first, before we get the value for the
183   // position associated with the node, so that we preserve an "there exist"
184   // semantics: if getting a value requires an upward traversal (going from a
185   // value to its consumers), we want to perform the check on all the consumers
186   // before we pass control to the failure node.
187   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
188   Block *failureBlock;
189   if (failureNode) {
190     failureBlock = generateMatcher(*failureNode, region);
191     failureBlockStack.push_back(failureBlock);
192   } else {
193     assert(!failureBlockStack.empty() && "expected valid failure block");
194     failureBlock = failureBlockStack.back();
195   }
196 
197   // If this node contains a position, get the corresponding value for this
198   // block.
199   Block *currentBlock = block;
200   Position *position = node.getPosition();
201   Value val = position ? getValueAt(currentBlock, position) : Value();
202 
203   // If this value corresponds to an operation, record that we are going to use
204   // its location as part of a fused location.
205   bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
206   if (isOperationValue)
207     locOps.insert(val);
208 
209   // Dispatch to the correct method based on derived node type.
210   TypeSwitch<MatcherNode *>(&node)
211       .Case<BoolNode, SwitchNode>(
212           [&](auto *derivedNode) { generate(derivedNode, currentBlock, val); })
213       .Case([&](SuccessNode *successNode) {
214         generate(successNode, currentBlock);
215       });
216 
217   // Pop all the failure blocks that were inserted due to nesting of
218   // pdl_interp.iterate.
219   while (failureBlockStack.back() != failureBlock) {
220     failureBlockStack.pop_back();
221     assert(!failureBlockStack.empty() && "unable to locate failure block");
222   }
223 
224   // Pop the new failure block.
225   if (failureNode)
226     failureBlockStack.pop_back();
227 
228   if (isOperationValue)
229     locOps.remove(val);
230 
231   return block;
232 }
233 
/// Returns the interpreter value for `pos`. If the position has not been
/// materialized in the current value scope, the parent position is
/// materialized first (recursively). The accessor for `pos` is then appended
/// to the end of `currentBlock`, and the result is cached in `values`.
///
/// For an upward operation position (walking from a value to its users), this
/// emits a `pdl_interp.foreach` over the users and updates `currentBlock` to
/// the block inside the loop where matching continues. The loop's `continue`
/// block is pushed onto `failureBlockStack`, so later failures advance to the
/// next user. The push is not undone here; generateMatcher pops such blocks
/// after its node is generated.
///
/// NOTE(review): recursion assumes every position that is not yet cached has a
/// parent; the root position is pre-seeded by `lower`.
Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
  if (Value val = values.lookup(pos))
    return val;

  // Get the value for the parent position.
  Value parentVal = getValueAt(currentBlock, pos->getParent());

  // TODO: Use a location from the position.
  Location loc = parentVal.getLoc();
  builder.setInsertionPointToEnd(currentBlock);
  Value value;
  switch (pos->getKind()) {
  case Predicates::OperationPos: {
    auto *operationPos = cast<OperationPosition>(pos);
    if (!operationPos->isUpward()) {
      // Standard (downward) traversal which directly follows the defining op.
      value = builder.create<pdl_interp::GetDefiningOpOp>(
          loc, builder.getType<pdl::OperationType>(), parentVal);
      break;
    }

    // The first operation retrieves the representative value of a range.
    // This applies only when the parent is a range of values.
    if (parentVal.getType().isa<pdl::RangeType>())
      value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
    else
      value = parentVal;

    // The second operation retrieves the users.
    value = builder.create<pdl_interp::GetUsersOp>(loc, value);

    // The third operation iterates over them.
    assert(!failureBlockStack.empty() && "expected valid failure block");
    auto foreach = builder.create<pdl_interp::ForEachOp>(
        loc, value, failureBlockStack.back(), /*initLoop=*/true);
    value = foreach.getLoopVariable();

    // Create the success and continuation blocks.
    // The continuation block is placed before the success block and only
    // holds a `pdl_interp.continue`, i.e. it advances to the next user.
    Block *successBlock = builder.createBlock(&foreach.region());
    Block *continueBlock = builder.createBlock(successBlock);
    builder.create<pdl_interp::ContinueOp>(loc);
    failureBlockStack.push_back(continueBlock);

    // The fourth operation extracts the operand(s) of the user at the specified
    // index (which can be None, indicating all operands).
    // This goes into the first block of the loop region (presumably the entry
    // block created by the ForEachOp builder with `initLoop=true`).
    builder.setInsertionPointToStart(&foreach.region().front());
    Value operands = builder.create<pdl_interp::GetOperandsOp>(
        loc, parentVal.getType(), value, operationPos->getIndex());

    // The fifth operation compares the operands to the parent value / range.
    builder.create<pdl_interp::AreEqualOp>(loc, parentVal, operands,
                                           successBlock, continueBlock);
    // Matching continues inside the loop, for the user that passed the check.
    currentBlock = successBlock;
    break;
  }
  case Predicates::OperandPos: {
    auto *operandPos = cast<OperandPosition>(pos);
    value = builder.create<pdl_interp::GetOperandOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        operandPos->getOperandNumber());
    break;
  }
  case Predicates::OperandGroupPos: {
    auto *operandPos = cast<OperandGroupPosition>(pos);
    Type valueTy = builder.getType<pdl::ValueType>();
    value = builder.create<pdl_interp::GetOperandsOp>(
        loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
        parentVal, operandPos->getOperandGroupNumber());
    break;
  }
  case Predicates::AttributePos: {
    auto *attrPos = cast<AttributePosition>(pos);
    value = builder.create<pdl_interp::GetAttributeOp>(
        loc, builder.getType<pdl::AttributeType>(), parentVal,
        attrPos->getName().strref());
    break;
  }
  case Predicates::TypePos: {
    // The type of an attribute and the type of a value use different ops.
    if (parentVal.getType().isa<pdl::AttributeType>())
      value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
    else
      value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
    break;
  }
  case Predicates::ResultPos: {
    auto *resPos = cast<ResultPosition>(pos);
    value = builder.create<pdl_interp::GetResultOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        resPos->getResultNumber());
    break;
  }
  case Predicates::ResultGroupPos: {
    auto *resPos = cast<ResultGroupPosition>(pos);
    Type valueTy = builder.getType<pdl::ValueType>();
    value = builder.create<pdl_interp::GetResultsOp>(
        loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
        parentVal, resPos->getResultGroupNumber());
    break;
  }
  default:
    llvm_unreachable("Generating unknown Position getter");
    break;
  }

  // Cache the value in the innermost value scope.
  values.insert(pos, value);
  return value;
}
341 
342 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
343                                Value val) {
344   Location loc = val.getLoc();
345   Qualifier *question = boolNode->getQuestion();
346   Qualifier *answer = boolNode->getAnswer();
347   Region *region = currentBlock->getParent();
348 
349   // Execute the getValue queries first, so that we create success
350   // matcher in the correct (possibly nested) region.
351   SmallVector<Value> args;
352   if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
353     args = {getValueAt(currentBlock, equalToQuestion->getValue())};
354   } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
355     for (Position *position : std::get<1>(cstQuestion->getValue()))
356       args.push_back(getValueAt(currentBlock, position));
357   }
358 
359   // Generate the matcher in the current (potentially nested) region
360   // and get the failure successor.
361   Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
362   Block *failure = failureBlockStack.back();
363 
364   // Finally, create the predicate.
365   builder.setInsertionPointToEnd(currentBlock);
366   Predicates::Kind kind = question->getKind();
367   switch (kind) {
368   case Predicates::IsNotNullQuestion:
369     builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
370     break;
371   case Predicates::OperationNameQuestion: {
372     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
373     builder.create<pdl_interp::CheckOperationNameOp>(
374         loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
375     break;
376   }
377   case Predicates::TypeQuestion: {
378     auto *ans = cast<TypeAnswer>(answer);
379     if (val.getType().isa<pdl::RangeType>())
380       builder.create<pdl_interp::CheckTypesOp>(
381           loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
382     else
383       builder.create<pdl_interp::CheckTypeOp>(
384           loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
385     break;
386   }
387   case Predicates::AttributeQuestion: {
388     auto *ans = cast<AttributeAnswer>(answer);
389     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
390                                                  success, failure);
391     break;
392   }
393   case Predicates::OperandCountAtLeastQuestion:
394   case Predicates::OperandCountQuestion:
395     builder.create<pdl_interp::CheckOperandCountOp>(
396         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
397         /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
398         success, failure);
399     break;
400   case Predicates::ResultCountAtLeastQuestion:
401   case Predicates::ResultCountQuestion:
402     builder.create<pdl_interp::CheckResultCountOp>(
403         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
404         /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
405         success, failure);
406     break;
407   case Predicates::EqualToQuestion: {
408     bool trueAnswer = isa<TrueAnswer>(answer);
409     builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
410                                            trueAnswer ? success : failure,
411                                            trueAnswer ? failure : success);
412     break;
413   }
414   case Predicates::ConstraintQuestion: {
415     auto value = cast<ConstraintQuestion>(question)->getValue();
416     builder.create<pdl_interp::ApplyConstraintOp>(
417         loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(),
418         success, failure);
419     break;
420   }
421   default:
422     llvm_unreachable("Generating unknown Predicate operation");
423   }
424 }
425 
426 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
427 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
428                            llvm::MapVector<Qualifier *, Block *> &dests) {
429   std::vector<ValT> values;
430   std::vector<Block *> blocks;
431   values.reserve(dests.size());
432   blocks.reserve(dests.size());
433   for (const auto &it : dests) {
434     blocks.push_back(it.second);
435     values.push_back(cast<PredT>(it.first)->getValue());
436   }
437   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
438 }
439 
/// Emits the interpreter dispatch for `switchNode` on `val` at the end of
/// `currentBlock`. Each child matcher is generated in the region of
/// `currentBlock`. The innermost block on `failureBlockStack` is the default
/// destination.
///
/// For exact questions a single `pdl_interp.switch_*` op is emitted. For the
/// `at_least` count questions a chain of count checks is built instead (see
/// the comment inside).
void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
                               Value val) {
  Qualifier *question = switchNode->getQuestion();
  Region *region = currentBlock->getParent();
  Block *defaultDest = failureBlockStack.back();

  // If the switch question is not an exact answer, i.e. for the `at_least`
  // cases, we generate a special block sequence.
  Predicates::Kind kind = question->getKind();
  if (kind == Predicates::OperandCountAtLeastQuestion ||
      kind == Predicates::ResultCountAtLeastQuestion) {
    // Order the children such that the cases are in reverse numerical order.
    SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
        llvm::seq<unsigned>(0, switchNode->getChildren().size()));
    llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
      return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
             cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
    });

    // Build the destination for each child using the next highest child as a
    // a failure destination. This essentially creates the following control
    // flow:
    //
    // if (operand_count < 1)
    //   goto failure
    // if (child1.match())
    //   ...
    //
    // if (operand_count < 2)
    //   goto failure
    // if (child2.match())
    //   ...
    //
    // failure:
    //   ...
    //
    failureBlockStack.push_back(defaultDest);
    Location loc = val.getLoc();
    for (unsigned idx : sortedChildren) {
      auto &child = switchNode->getChild(idx);
      Block *childBlock = generateMatcher(*child.second, *region);
      Block *predicateBlock = builder.createBlock(childBlock);
      builder.setInsertionPointToEnd(predicateBlock);
      unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
      switch (kind) {
      case Predicates::OperandCountAtLeastQuestion:
        builder.create<pdl_interp::CheckOperandCountOp>(
            loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
        break;
      case Predicates::ResultCountAtLeastQuestion:
        builder.create<pdl_interp::CheckResultCountOp>(
            loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
        break;
      default:
        llvm_unreachable("Generating invalid AtLeast operation");
      }
      // The next child (with a smaller count) falls back to this child's
      // predicate block when its own matcher fails.
      failureBlockStack.back() = predicateBlock;
    }
    // The predicate block of the smallest count heads the chain; merge it into
    // `currentBlock` so the chain starts there.
    Block *firstPredicateBlock = failureBlockStack.pop_back_val();
    currentBlock->getOperations().splice(currentBlock->end(),
                                         firstPredicateBlock->getOperations());
    firstPredicateBlock->erase();
    return;
  }

  // Otherwise, generate each of the children and generate an interpreter
  // switch.
  llvm::MapVector<Qualifier *, Block *> children;
  for (auto &it : switchNode->getChildren())
    children.insert({it.first, generateMatcher(*it.second, *region)});
  builder.setInsertionPointToEnd(currentBlock);

  switch (question->getKind()) {
  case Predicates::OperandCountQuestion:
    return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
                          int32_t>(val, defaultDest, builder, children);
  case Predicates::ResultCountQuestion:
    return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
                          int32_t>(val, defaultDest, builder, children);
  case Predicates::OperationNameQuestion:
    return createSwitchOp<pdl_interp::SwitchOperationNameOp,
                          OperationNameAnswer>(val, defaultDest, builder,
                                               children);
  case Predicates::TypeQuestion:
    if (val.getType().isa<pdl::RangeType>()) {
      return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
          val, defaultDest, builder, children);
    }
    return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
        val, defaultDest, builder, children);
  case Predicates::AttributeQuestion:
    return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
        val, defaultDest, builder, children);
  default:
    llvm_unreachable("Generating unknown switch predicate.");
  }
}
537 
538 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
539   pdl::PatternOp pattern = successNode->getPattern();
540   Value root = successNode->getRoot();
541 
542   // Generate a rewriter for the pattern this success node represents, and track
543   // any values used from the match region.
544   SmallVector<Position *, 8> usedMatchValues;
545   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
546 
547   // Process any values used in the rewrite that are defined in the match.
548   std::vector<Value> mappedMatchValues;
549   mappedMatchValues.reserve(usedMatchValues.size());
550   for (Position *position : usedMatchValues)
551     mappedMatchValues.push_back(getValueAt(currentBlock, position));
552 
553   // Collect the set of operations generated by the rewriter.
554   SmallVector<StringRef, 4> generatedOps;
555   for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
556     generatedOps.push_back(*op.name());
557   ArrayAttr generatedOpsAttr;
558   if (!generatedOps.empty())
559     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
560 
561   // Grab the root kind if present.
562   StringAttr rootKindAttr;
563   if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
564     if (Optional<StringRef> rootKind = rootOp.name())
565       rootKindAttr = builder.getStringAttr(*rootKind);
566 
567   builder.setInsertionPointToEnd(currentBlock);
568   builder.create<pdl_interp::RecordMatchOp>(
569       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
570       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
571       failureBlockStack.back());
572 }
573 
/// Creates a new rewriter function for `pattern` and returns a symbol
/// reference to it, nested under the rewriter module's name. The function is
/// named `pdl_generated_rewriter` and inserted via `rewriterSymbolTable`,
/// which presumably uniques the name on conflict.
///
/// Constant attributes and types used by the rewrite are materialized inside
/// the rewriter. Every other match-side value becomes a new function argument,
/// and its position is appended to `usedMatchValues` in argument order. The
/// caller can therefore pass the matched values in the same order.
SymbolRefAttr PatternLowering::generateRewriter(
    pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
  FuncOp rewriterFunc =
      FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
                     builder.getFunctionType(llvm::None, llvm::None));
  rewriterSymbolTable.insert(rewriterFunc);

  // Generate the rewriter function body.
  builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());

  // Map an input operand of the pattern to a generated interpreter value.
  // Results are memoized in `rewriteValues`, so each value is mapped once.
  DenseMap<Value, Value> rewriteValues;
  auto mapRewriteValue = [&](Value oldValue) {
    Value &newValue = rewriteValues[oldValue];
    if (newValue)
      return newValue;

    // Prefer materializing constants directly when possible.
    Operation *oldOp = oldValue.getDefiningOp();
    if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
      if (Attribute value = attrOp.valueAttr()) {
        return newValue = builder.create<pdl_interp::CreateAttributeOp>(
                   attrOp.getLoc(), value);
      }
    } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
      if (TypeAttr type = typeOp.typeAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypeOp>(
                   typeOp.getLoc(), type);
      }
    } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
      if (ArrayAttr type = typeOp.typesAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypesOp>(
                   typeOp.getLoc(), typeOp.getType(), type);
      }
    }

    // Otherwise, add this as an input to the rewriter.
    Position *inputPos = valueToPosition.lookup(oldValue);
    assert(inputPos && "expected value to be a pattern input");
    usedMatchValues.push_back(inputPos);
    return newValue = rewriterFunc.front().addArgument(oldValue.getType());
  };

  // If this is a custom rewriter, simply dispatch to the registered rewrite
  // method.
  pdl::RewriteOp rewriter = pattern.getRewriter();
  if (StringAttr rewriteName = rewriter.nameAttr()) {
    SmallVector<Value> args;
    if (rewriter.root())
      args.push_back(mapRewriteValue(rewriter.root()));
    auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
    args.append(mappedArgs.begin(), mappedArgs.end());
    builder.create<pdl_interp::ApplyRewriteOp>(
        rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
        rewriter.externalConstParamsAttr());
  } else {
    // Otherwise this is a dag rewriter defined using PDL operations.
    for (Operation &rewriteOp : *rewriter.getBody()) {
      llvm::TypeSwitch<Operation *>(&rewriteOp)
          .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
                pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
                pdl::TypeOp, pdl::TypesOp>([&](auto op) {
            this->generateRewriter(op, rewriteValues, mapRewriteValue);
          });
    }
  }

  // Update the signature of the rewrite function.
  // Arguments were appended on demand above, so the type is fixed up last.
  rewriterFunc.setType(builder.getFunctionType(
      llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
      /*results=*/llvm::None));

  builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
  return SymbolRefAttr::get(
      builder.getContext(),
      pdl_interp::PDLInterpDialect::getRewriterModuleName(),
      SymbolRefAttr::get(rewriterFunc));
}
652 
653 void PatternLowering::generateRewriter(
654     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
655     function_ref<Value(Value)> mapRewriteValue) {
656   SmallVector<Value, 2> arguments;
657   for (Value argument : rewriteOp.args())
658     arguments.push_back(mapRewriteValue(argument));
659   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
660       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
661       arguments, rewriteOp.constParamsAttr());
662   for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
663     rewriteValues[std::get<0>(it)] = std::get<1>(it);
664 }
665 
666 void PatternLowering::generateRewriter(
667     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
668     function_ref<Value(Value)> mapRewriteValue) {
669   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
670       attrOp.getLoc(), attrOp.valueAttr());
671   rewriteValues[attrOp] = newAttr;
672 }
673 
674 void PatternLowering::generateRewriter(
675     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
676     function_ref<Value(Value)> mapRewriteValue) {
677   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
678                                       mapRewriteValue(eraseOp.operation()));
679 }
680 
681 void PatternLowering::generateRewriter(
682     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
683     function_ref<Value(Value)> mapRewriteValue) {
684   SmallVector<Value, 4> operands;
685   for (Value operand : operationOp.operands())
686     operands.push_back(mapRewriteValue(operand));
687 
688   SmallVector<Value, 4> attributes;
689   for (Value attr : operationOp.attributes())
690     attributes.push_back(mapRewriteValue(attr));
691 
692   SmallVector<Value, 2> types;
693   generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
694                                       mapRewriteValue);
695 
696   // Create the new operation.
697   Location loc = operationOp.getLoc();
698   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
699       loc, *operationOp.name(), types, operands, attributes,
700       operationOp.attributeNames());
701   rewriteValues[operationOp.op()] = createdOp;
702 
703   // Generate accesses for any results that have their types constrained.
704   // Handle the case where there is a single range representing all of the
705   // result types.
706   OperandRange resultTys = operationOp.types();
707   if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
708     Value &type = rewriteValues[resultTys[0]];
709     if (!type) {
710       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
711       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
712     }
713     return;
714   }
715 
716   // Otherwise, populate the individual results.
717   bool seenVariableLength = false;
718   Type valueTy = builder.getType<pdl::ValueType>();
719   Type valueRangeTy = pdl::RangeType::get(valueTy);
720   for (auto it : llvm::enumerate(resultTys)) {
721     Value &type = rewriteValues[it.value()];
722     if (type)
723       continue;
724     bool isVariadic = it.value().getType().isa<pdl::RangeType>();
725     seenVariableLength |= isVariadic;
726 
727     // After a variable length result has been seen, we need to use result
728     // groups because the exact index of the result is not statically known.
729     Value resultVal;
730     if (seenVariableLength)
731       resultVal = builder.create<pdl_interp::GetResultsOp>(
732           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
733     else
734       resultVal = builder.create<pdl_interp::GetResultOp>(
735           loc, valueTy, createdOp, it.index());
736     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
737   }
738 }
739 
740 void PatternLowering::generateRewriter(
741     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
742     function_ref<Value(Value)> mapRewriteValue) {
743   SmallVector<Value, 4> replOperands;
744 
745   // If the replacement was another operation, get its results. `pdl` allows
746   // for using an operation for simplicitly, but the interpreter isn't as
747   // user facing.
748   if (Value replOp = replaceOp.replOperation()) {
749     // Don't use replace if we know the replaced operation has no results.
750     auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
751     if (!opOp || !opOp.types().empty()) {
752       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
753           replOp.getLoc(), mapRewriteValue(replOp)));
754     }
755   } else {
756     for (Value operand : replaceOp.replValues())
757       replOperands.push_back(mapRewriteValue(operand));
758   }
759 
760   // If there are no replacement values, just create an erase instead.
761   if (replOperands.empty()) {
762     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
763                                         mapRewriteValue(replaceOp.operation()));
764     return;
765   }
766 
767   builder.create<pdl_interp::ReplaceOp>(
768       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
769 }
770 
771 void PatternLowering::generateRewriter(
772     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
773     function_ref<Value(Value)> mapRewriteValue) {
774   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
775       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
776       mapRewriteValue(resultOp.parent()), resultOp.index());
777 }
778 
779 void PatternLowering::generateRewriter(
780     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
781     function_ref<Value(Value)> mapRewriteValue) {
782   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
783       resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
784       resultOp.index());
785 }
786 
787 void PatternLowering::generateRewriter(
788     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
789     function_ref<Value(Value)> mapRewriteValue) {
790   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
791   // type.
792   if (TypeAttr typeAttr = typeOp.typeAttr()) {
793     rewriteValues[typeOp] =
794         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
795   }
796 }
797 
798 void PatternLowering::generateRewriter(
799     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
800     function_ref<Value(Value)> mapRewriteValue) {
801   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
802   // type.
803   if (ArrayAttr typeAttr = typeOp.typesAttr()) {
804     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
805         typeOp.getLoc(), typeOp.getType(), typeAttr);
806   }
807 }
808 
809 void PatternLowering::generateOperationResultTypeRewriter(
810     pdl::OperationOp op, SmallVectorImpl<Value> &types,
811     DenseMap<Value, Value> &rewriteValues,
812     function_ref<Value(Value)> mapRewriteValue) {
813   // Look for an operation that was replaced by `op`. The result types will be
814   // inferred from the results that were replaced.
815   Block *rewriterBlock = op->getBlock();
816   Value replacedOp;
817   for (OpOperand &use : op.op().getUses()) {
818     // Check that the use corresponds to a ReplaceOp and that it is the
819     // replacement value, not the operation being replaced.
820     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
821     if (!replOpUser || use.getOperandNumber() == 0)
822       continue;
823     // Make sure the replaced operation was defined before this one.
824     Value replOpVal = replOpUser.operation();
825     Operation *replacedOp = replOpVal.getDefiningOp();
826     if (replacedOp->getBlock() == rewriterBlock &&
827         !replacedOp->isBeforeInBlock(op))
828       continue;
829 
830     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
831         replacedOp->getLoc(), mapRewriteValue(replOpVal));
832     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
833         replacedOp->getLoc(), replacedOpResults));
834     return;
835   }
836 
837   // Check if the operation has type inference support.
838   if (op.hasTypeInference()) {
839     types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
840     return;
841   }
842 
843   // Otherwise, handle inference for each of the result types individually.
844   OperandRange resultTypeValues = op.types();
845   types.reserve(resultTypeValues.size());
846   for (auto it : llvm::enumerate(resultTypeValues)) {
847     Value resultType = it.value();
848 
849     // Check for an already translated value.
850     if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
851       types.push_back(existingRewriteValue);
852       continue;
853     }
854 
855     // Check for an input from the matcher.
856     if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
857       types.push_back(mapRewriteValue(resultType));
858       continue;
859     }
860 
861     // The verifier asserts that the result types of each pdl.operation can be
862     // inferred. If we reach here, there is a bug either in the logic above or
863     // in the verifier for pdl.operation.
864     op->emitOpError() << "unable to infer result type for operation";
865     llvm_unreachable("unable to infer result type for operation");
866   }
867 }
868 
869 //===----------------------------------------------------------------------===//
870 // Conversion Pass
871 //===----------------------------------------------------------------------===//
872 
873 namespace {
874 struct PDLToPDLInterpPass
875     : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
876   void runOnOperation() final;
877 };
878 } // namespace
879 
880 /// Convert the given module containing PDL pattern operations into a PDL
881 /// Interpreter operations.
882 void PDLToPDLInterpPass::runOnOperation() {
883   ModuleOp module = getOperation();
884 
885   // Create the main matcher function This function contains all of the match
886   // related functionality from patterns in the module.
887   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
888   FuncOp matcherFunc = builder.create<FuncOp>(
889       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
890       builder.getFunctionType(builder.getType<pdl::OperationType>(),
891                               /*results=*/llvm::None),
892       /*attrs=*/llvm::None);
893 
894   // Create a nested module to hold the functions invoked for rewriting the IR
895   // after a successful match.
896   ModuleOp rewriterModule = builder.create<ModuleOp>(
897       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
898 
899   // Generate the code for the patterns within the module.
900   PatternLowering generator(matcherFunc, rewriterModule);
901   generator.lower(module);
902 
903   // After generation, delete all of the pattern operations.
904   for (pdl::PatternOp pattern :
905        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
906     pattern.erase();
907 }
908 
909 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
910   return std::make_unique<PDLToPDLInterpPass>();
911 }
912