1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::pdl_to_pdl_interp;
23 
24 //===----------------------------------------------------------------------===//
25 // PatternLowering
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 /// This class generators operations within the PDL Interpreter dialect from a
30 /// given module containing PDL pattern operations.
31 struct PatternLowering {
32 public:
33   PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);
34 
35   /// Generate code for matching and rewriting based on the pattern operations
36   /// within the module.
37   void lower(ModuleOp module);
38 
39 private:
40   using ValueMap = llvm::ScopedHashTable<Position *, Value>;
41   using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
42 
43   /// Generate interpreter operations for the tree rooted at the given matcher
44   /// node.
45   Block *generateMatcher(MatcherNode &node);
46 
47   /// Get or create an access to the provided positional value within the
48   /// current block.
49   Value getValueAt(Block *cur, Position *pos);
50 
51   /// Create an interpreter predicate operation, branching to the provided true
52   /// and false destinations.
53   void generatePredicate(Block *currentBlock, Qualifier *question,
54                          Qualifier *answer, Value val, Block *trueDest,
55                          Block *falseDest);
56 
57   /// Create an interpreter switch predicate operation, with a provided default
58   /// and several case destinations.
59   void generateSwitch(SwitchNode *switchNode, Block *currentBlock,
60                       Qualifier *question, Value val, Block *defaultDest);
61 
62   /// Create the interpreter operations to record a successful pattern match.
63   void generateRecordMatch(Block *currentBlock, Block *nextBlock,
64                            pdl::PatternOp pattern);
65 
66   /// Generate a rewriter function for the given pattern operation, and returns
67   /// a reference to that function.
68   SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
69                                  SmallVectorImpl<Position *> &usedMatchValues);
70 
71   /// Generate the rewriter code for the given operation.
72   void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
73                         DenseMap<Value, Value> &rewriteValues,
74                         function_ref<Value(Value)> mapRewriteValue);
75   void generateRewriter(pdl::AttributeOp attrOp,
76                         DenseMap<Value, Value> &rewriteValues,
77                         function_ref<Value(Value)> mapRewriteValue);
78   void generateRewriter(pdl::EraseOp eraseOp,
79                         DenseMap<Value, Value> &rewriteValues,
80                         function_ref<Value(Value)> mapRewriteValue);
81   void generateRewriter(pdl::OperationOp operationOp,
82                         DenseMap<Value, Value> &rewriteValues,
83                         function_ref<Value(Value)> mapRewriteValue);
84   void generateRewriter(pdl::ReplaceOp replaceOp,
85                         DenseMap<Value, Value> &rewriteValues,
86                         function_ref<Value(Value)> mapRewriteValue);
87   void generateRewriter(pdl::ResultOp resultOp,
88                         DenseMap<Value, Value> &rewriteValues,
89                         function_ref<Value(Value)> mapRewriteValue);
90   void generateRewriter(pdl::ResultsOp resultOp,
91                         DenseMap<Value, Value> &rewriteValues,
92                         function_ref<Value(Value)> mapRewriteValue);
93   void generateRewriter(pdl::TypeOp typeOp,
94                         DenseMap<Value, Value> &rewriteValues,
95                         function_ref<Value(Value)> mapRewriteValue);
96   void generateRewriter(pdl::TypesOp typeOp,
97                         DenseMap<Value, Value> &rewriteValues,
98                         function_ref<Value(Value)> mapRewriteValue);
99 
100   /// Generate the values used for resolving the result types of an operation
101   /// created within a dag rewriter region.
102   void generateOperationResultTypeRewriter(
103       pdl::OperationOp op, SmallVectorImpl<Value> &types,
104       DenseMap<Value, Value> &rewriteValues,
105       function_ref<Value(Value)> mapRewriteValue);
106 
107   /// A builder to use when generating interpreter operations.
108   OpBuilder builder;
109 
110   /// The matcher function used for all match related logic within PDL patterns.
111   FuncOp matcherFunc;
112 
113   /// The rewriter module containing the all rewrite related logic within PDL
114   /// patterns.
115   ModuleOp rewriterModule;
116 
117   /// The symbol table of the rewriter module used for insertion.
118   SymbolTable rewriterSymbolTable;
119 
120   /// A scoped map connecting a position with the corresponding interpreter
121   /// value.
122   ValueMap values;
123 
124   /// A stack of blocks used as the failure destination for matcher nodes that
125   /// don't have an explicit failure path.
126   SmallVector<Block *, 8> failureBlockStack;
127 
128   /// A mapping between values defined in a pattern match, and the corresponding
129   /// positional value.
130   DenseMap<Value, Position *> valueToPosition;
131 
132   /// The set of operation values whose whose location will be used for newly
133   /// generated operations.
134   SetVector<Value> locOps;
135 };
136 } // end anonymous namespace
137 
138 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
139     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
140       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
141 
142 void PatternLowering::lower(ModuleOp module) {
143   PredicateUniquer predicateUniquer;
144   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
145 
146   // Define top-level scope for the arguments to the matcher function.
147   ValueMapScope topLevelValueScope(values);
148 
149   // Insert the root operation, i.e. argument to the matcher, at the root
150   // position.
151   Block *matcherEntryBlock = matcherFunc.addEntryBlock();
152   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
153 
154   // Generate a root matcher node from the provided PDL module.
155   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
156       module, predicateBuilder, valueToPosition);
157   Block *firstMatcherBlock = generateMatcher(*root);
158 
159   // After generation, merged the first matched block into the entry.
160   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
161                                             firstMatcherBlock->getOperations());
162   firstMatcherBlock->erase();
163 }
164 
165 Block *PatternLowering::generateMatcher(MatcherNode &node) {
166   // Push a new scope for the values used by this matcher.
167   Block *block = matcherFunc.addBlock();
168   ValueMapScope scope(values);
169 
170   // If this is the return node, simply insert the corresponding interpreter
171   // finalize.
172   if (isa<ExitNode>(node)) {
173     builder.setInsertionPointToEnd(block);
174     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
175     return block;
176   }
177 
178   // If this node contains a position, get the corresponding value for this
179   // block.
180   Position *position = node.getPosition();
181   Value val = position ? getValueAt(block, position) : Value();
182 
183   // Get the next block in the match sequence.
184   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
185   Block *nextBlock;
186   if (failureNode) {
187     nextBlock = generateMatcher(*failureNode);
188     failureBlockStack.push_back(nextBlock);
189   } else {
190     assert(!failureBlockStack.empty() && "expected valid failure block");
191     nextBlock = failureBlockStack.back();
192   }
193 
194   // If this value corresponds to an operation, record that we are going to use
195   // its location as part of a fused location.
196   bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
197   if (isOperationValue)
198     locOps.insert(val);
199 
200   // Generate code for a boolean predicate node.
201   if (auto *boolNode = dyn_cast<BoolNode>(&node)) {
202     auto *child = generateMatcher(*boolNode->getSuccessNode());
203     generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val,
204                       child, nextBlock);
205 
206     // Generate code for a switch node.
207   } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
208     generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock);
209 
210     // Generate code for a success node.
211   } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
212     generateRecordMatch(block, nextBlock, successNode->getPattern());
213   }
214 
215   if (failureNode)
216     failureBlockStack.pop_back();
217   if (isOperationValue)
218     locOps.remove(val);
219   return block;
220 }
221 
222 Value PatternLowering::getValueAt(Block *cur, Position *pos) {
223   if (Value val = values.lookup(pos))
224     return val;
225 
226   // Get the value for the parent position.
227   Value parentVal = getValueAt(cur, pos->getParent());
228 
229   // TODO: Use a location from the position.
230   Location loc = parentVal.getLoc();
231   builder.setInsertionPointToEnd(cur);
232   Value value;
233   switch (pos->getKind()) {
234   case Predicates::OperationPos:
235     value = builder.create<pdl_interp::GetDefiningOpOp>(
236         loc, builder.getType<pdl::OperationType>(), parentVal);
237     break;
238   case Predicates::OperandPos: {
239     auto *operandPos = cast<OperandPosition>(pos);
240     value = builder.create<pdl_interp::GetOperandOp>(
241         loc, builder.getType<pdl::ValueType>(), parentVal,
242         operandPos->getOperandNumber());
243     break;
244   }
245   case Predicates::OperandGroupPos: {
246     auto *operandPos = cast<OperandGroupPosition>(pos);
247     Type valueTy = builder.getType<pdl::ValueType>();
248     value = builder.create<pdl_interp::GetOperandsOp>(
249         loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
250         parentVal, operandPos->getOperandGroupNumber());
251     break;
252   }
253   case Predicates::AttributePos: {
254     auto *attrPos = cast<AttributePosition>(pos);
255     value = builder.create<pdl_interp::GetAttributeOp>(
256         loc, builder.getType<pdl::AttributeType>(), parentVal,
257         attrPos->getName().strref());
258     break;
259   }
260   case Predicates::TypePos: {
261     if (parentVal.getType().isa<pdl::AttributeType>())
262       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
263     else
264       value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
265     break;
266   }
267   case Predicates::ResultPos: {
268     auto *resPos = cast<ResultPosition>(pos);
269     value = builder.create<pdl_interp::GetResultOp>(
270         loc, builder.getType<pdl::ValueType>(), parentVal,
271         resPos->getResultNumber());
272     break;
273   }
274   case Predicates::ResultGroupPos: {
275     auto *resPos = cast<ResultGroupPosition>(pos);
276     Type valueTy = builder.getType<pdl::ValueType>();
277     value = builder.create<pdl_interp::GetResultsOp>(
278         loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
279         parentVal, resPos->getResultGroupNumber());
280     break;
281   }
282   default:
283     llvm_unreachable("Generating unknown Position getter");
284     break;
285   }
286   values.insert(pos, value);
287   return value;
288 }
289 
290 void PatternLowering::generatePredicate(Block *currentBlock,
291                                         Qualifier *question, Qualifier *answer,
292                                         Value val, Block *trueDest,
293                                         Block *falseDest) {
294   builder.setInsertionPointToEnd(currentBlock);
295   Location loc = val.getLoc();
296   Predicates::Kind kind = question->getKind();
297   switch (kind) {
298   case Predicates::IsNotNullQuestion:
299     builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
300     break;
301   case Predicates::OperationNameQuestion: {
302     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
303     builder.create<pdl_interp::CheckOperationNameOp>(
304         loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest);
305     break;
306   }
307   case Predicates::TypeQuestion: {
308     auto *ans = cast<TypeAnswer>(answer);
309     if (val.getType().isa<pdl::RangeType>())
310       builder.create<pdl_interp::CheckTypesOp>(
311           loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest);
312     else
313       builder.create<pdl_interp::CheckTypeOp>(
314           loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest);
315     break;
316   }
317   case Predicates::AttributeQuestion: {
318     auto *ans = cast<AttributeAnswer>(answer);
319     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
320                                                  trueDest, falseDest);
321     break;
322   }
323   case Predicates::OperandCountAtLeastQuestion:
324   case Predicates::OperandCountQuestion:
325     builder.create<pdl_interp::CheckOperandCountOp>(
326         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
327         /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
328         trueDest, falseDest);
329     break;
330   case Predicates::ResultCountAtLeastQuestion:
331   case Predicates::ResultCountQuestion:
332     builder.create<pdl_interp::CheckResultCountOp>(
333         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
334         /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
335         trueDest, falseDest);
336     break;
337   case Predicates::EqualToQuestion: {
338     auto *equalToQuestion = cast<EqualToQuestion>(question);
339     builder.create<pdl_interp::AreEqualOp>(
340         loc, val, getValueAt(currentBlock, equalToQuestion->getValue()),
341         trueDest, falseDest);
342     break;
343   }
344   case Predicates::ConstraintQuestion: {
345     auto *cstQuestion = cast<ConstraintQuestion>(question);
346     SmallVector<Value, 2> args;
347     for (Position *position : std::get<1>(cstQuestion->getValue()))
348       args.push_back(getValueAt(currentBlock, position));
349     builder.create<pdl_interp::ApplyConstraintOp>(
350         loc, std::get<0>(cstQuestion->getValue()), args,
351         std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest,
352         falseDest);
353     break;
354   }
355   default:
356     llvm_unreachable("Generating unknown Predicate operation");
357   }
358 }
359 
360 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
361 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
362                            llvm::MapVector<Qualifier *, Block *> &dests) {
363   std::vector<ValT> values;
364   std::vector<Block *> blocks;
365   values.reserve(dests.size());
366   blocks.reserve(dests.size());
367   for (const auto &it : dests) {
368     blocks.push_back(it.second);
369     values.push_back(cast<PredT>(it.first)->getValue());
370   }
371   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
372 }
373 
374 void PatternLowering::generateSwitch(SwitchNode *switchNode,
375                                      Block *currentBlock, Qualifier *question,
376                                      Value val, Block *defaultDest) {
377   // If the switch question is not an exact answer, i.e. for the `at_least`
378   // cases, we generate a special block sequence.
379   Predicates::Kind kind = question->getKind();
380   if (kind == Predicates::OperandCountAtLeastQuestion ||
381       kind == Predicates::ResultCountAtLeastQuestion) {
382     // Order the children such that the cases are in reverse numerical order.
383     SmallVector<unsigned> sortedChildren(
384         llvm::seq<unsigned>(0, switchNode->getChildren().size()));
385     llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
386       return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
387              cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
388     });
389 
390     // Build the destination for each child using the next highest child as a
391     // a failure destination. This essentially creates the following control
392     // flow:
393     //
394     // if (operand_count < 1)
395     //   goto failure
396     // if (child1.match())
397     //   ...
398     //
399     // if (operand_count < 2)
400     //   goto failure
401     // if (child2.match())
402     //   ...
403     //
404     // failure:
405     //   ...
406     //
407     failureBlockStack.push_back(defaultDest);
408     for (unsigned idx : sortedChildren) {
409       auto &child = switchNode->getChild(idx);
410       Block *childBlock = generateMatcher(*child.second);
411       Block *predicateBlock = builder.createBlock(childBlock);
412       generatePredicate(predicateBlock, question, child.first, val, childBlock,
413                         defaultDest);
414       failureBlockStack.back() = predicateBlock;
415     }
416     Block *firstPredicateBlock = failureBlockStack.pop_back_val();
417     currentBlock->getOperations().splice(currentBlock->end(),
418                                          firstPredicateBlock->getOperations());
419     firstPredicateBlock->erase();
420     return;
421   }
422 
423   // Otherwise, generate each of the children and generate an interpreter
424   // switch.
425   llvm::MapVector<Qualifier *, Block *> children;
426   for (auto &it : switchNode->getChildren())
427     children.insert({it.first, generateMatcher(*it.second)});
428   builder.setInsertionPointToEnd(currentBlock);
429 
430   switch (question->getKind()) {
431   case Predicates::OperandCountQuestion:
432     return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
433                           int32_t>(val, defaultDest, builder, children);
434   case Predicates::ResultCountQuestion:
435     return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
436                           int32_t>(val, defaultDest, builder, children);
437   case Predicates::OperationNameQuestion:
438     return createSwitchOp<pdl_interp::SwitchOperationNameOp,
439                           OperationNameAnswer>(val, defaultDest, builder,
440                                                children);
441   case Predicates::TypeQuestion:
442     if (val.getType().isa<pdl::RangeType>()) {
443       return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
444           val, defaultDest, builder, children);
445     }
446     return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
447         val, defaultDest, builder, children);
448   case Predicates::AttributeQuestion:
449     return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
450         val, defaultDest, builder, children);
451   default:
452     llvm_unreachable("Generating unknown switch predicate.");
453   }
454 }
455 
456 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock,
457                                           pdl::PatternOp pattern) {
458   // Generate a rewriter for the pattern this success node represents, and track
459   // any values used from the match region.
460   SmallVector<Position *, 8> usedMatchValues;
461   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
462 
463   // Process any values used in the rewrite that are defined in the match.
464   std::vector<Value> mappedMatchValues;
465   mappedMatchValues.reserve(usedMatchValues.size());
466   for (Position *position : usedMatchValues)
467     mappedMatchValues.push_back(getValueAt(currentBlock, position));
468 
469   // Collect the set of operations generated by the rewriter.
470   SmallVector<StringRef, 4> generatedOps;
471   for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
472     generatedOps.push_back(*op.name());
473   ArrayAttr generatedOpsAttr;
474   if (!generatedOps.empty())
475     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
476 
477   // Grab the root kind if present.
478   StringAttr rootKindAttr;
479   if (Optional<StringRef> rootKind = pattern.getRootKind())
480     rootKindAttr = builder.getStringAttr(*rootKind);
481 
482   builder.setInsertionPointToEnd(currentBlock);
483   builder.create<pdl_interp::RecordMatchOp>(
484       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
485       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
486       nextBlock);
487 }
488 
489 SymbolRefAttr PatternLowering::generateRewriter(
490     pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
491   FuncOp rewriterFunc =
492       FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
493                      builder.getFunctionType(llvm::None, llvm::None));
494   rewriterSymbolTable.insert(rewriterFunc);
495 
496   // Generate the rewriter function body.
497   builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
498 
499   // Map an input operand of the pattern to a generated interpreter value.
500   DenseMap<Value, Value> rewriteValues;
501   auto mapRewriteValue = [&](Value oldValue) {
502     Value &newValue = rewriteValues[oldValue];
503     if (newValue)
504       return newValue;
505 
506     // Prefer materializing constants directly when possible.
507     Operation *oldOp = oldValue.getDefiningOp();
508     if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
509       if (Attribute value = attrOp.valueAttr()) {
510         return newValue = builder.create<pdl_interp::CreateAttributeOp>(
511                    attrOp.getLoc(), value);
512       }
513     } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
514       if (TypeAttr type = typeOp.typeAttr()) {
515         return newValue = builder.create<pdl_interp::CreateTypeOp>(
516                    typeOp.getLoc(), type);
517       }
518     } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
519       if (ArrayAttr type = typeOp.typesAttr()) {
520         return newValue = builder.create<pdl_interp::CreateTypesOp>(
521                    typeOp.getLoc(), typeOp.getType(), type);
522       }
523     }
524 
525     // Otherwise, add this as an input to the rewriter.
526     Position *inputPos = valueToPosition.lookup(oldValue);
527     assert(inputPos && "expected value to be a pattern input");
528     usedMatchValues.push_back(inputPos);
529     return newValue = rewriterFunc.front().addArgument(oldValue.getType());
530   };
531 
532   // If this is a custom rewriter, simply dispatch to the registered rewrite
533   // method.
534   pdl::RewriteOp rewriter = pattern.getRewriter();
535   if (StringAttr rewriteName = rewriter.nameAttr()) {
536     auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
537     SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root()));
538     args.append(mappedArgs.begin(), mappedArgs.end());
539     builder.create<pdl_interp::ApplyRewriteOp>(
540         rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
541         rewriter.externalConstParamsAttr());
542   } else {
543     // Otherwise this is a dag rewriter defined using PDL operations.
544     for (Operation &rewriteOp : *rewriter.getBody()) {
545       llvm::TypeSwitch<Operation *>(&rewriteOp)
546           .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
547                 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
548                 pdl::TypeOp, pdl::TypesOp>([&](auto op) {
549             this->generateRewriter(op, rewriteValues, mapRewriteValue);
550           });
551     }
552   }
553 
554   // Update the signature of the rewrite function.
555   rewriterFunc.setType(builder.getFunctionType(
556       llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
557       /*results=*/llvm::None));
558 
559   builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
560   return builder.getSymbolRefAttr(
561       pdl_interp::PDLInterpDialect::getRewriterModuleName(),
562       builder.getSymbolRefAttr(rewriterFunc));
563 }
564 
565 void PatternLowering::generateRewriter(
566     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
567     function_ref<Value(Value)> mapRewriteValue) {
568   SmallVector<Value, 2> arguments;
569   for (Value argument : rewriteOp.args())
570     arguments.push_back(mapRewriteValue(argument));
571   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
572       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
573       arguments, rewriteOp.constParamsAttr());
574   for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
575     rewriteValues[std::get<0>(it)] = std::get<1>(it);
576 }
577 
578 void PatternLowering::generateRewriter(
579     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
580     function_ref<Value(Value)> mapRewriteValue) {
581   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
582       attrOp.getLoc(), attrOp.valueAttr());
583   rewriteValues[attrOp] = newAttr;
584 }
585 
586 void PatternLowering::generateRewriter(
587     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
588     function_ref<Value(Value)> mapRewriteValue) {
589   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
590                                       mapRewriteValue(eraseOp.operation()));
591 }
592 
593 void PatternLowering::generateRewriter(
594     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
595     function_ref<Value(Value)> mapRewriteValue) {
596   SmallVector<Value, 4> operands;
597   for (Value operand : operationOp.operands())
598     operands.push_back(mapRewriteValue(operand));
599 
600   SmallVector<Value, 4> attributes;
601   for (Value attr : operationOp.attributes())
602     attributes.push_back(mapRewriteValue(attr));
603 
604   SmallVector<Value, 2> types;
605   generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
606                                       mapRewriteValue);
607 
608   // Create the new operation.
609   Location loc = operationOp.getLoc();
610   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
611       loc, *operationOp.name(), types, operands, attributes,
612       operationOp.attributeNames());
613   rewriteValues[operationOp.op()] = createdOp;
614 
615   // Generate accesses for any results that have their types constrained.
616   // Handle the case where there is a single range representing all of the
617   // result types.
618   OperandRange resultTys = operationOp.types();
619   if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
620     Value &type = rewriteValues[resultTys[0]];
621     if (!type) {
622       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
623       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
624     }
625     return;
626   }
627 
628   // Otherwise, populate the individual results.
629   bool seenVariableLength = false;
630   Type valueTy = builder.getType<pdl::ValueType>();
631   Type valueRangeTy = pdl::RangeType::get(valueTy);
632   for (auto it : llvm::enumerate(resultTys)) {
633     Value &type = rewriteValues[it.value()];
634     if (type)
635       continue;
636     bool isVariadic = it.value().getType().isa<pdl::RangeType>();
637     seenVariableLength |= isVariadic;
638 
639     // After a variable length result has been seen, we need to use result
640     // groups because the exact index of the result is not statically known.
641     Value resultVal;
642     if (seenVariableLength)
643       resultVal = builder.create<pdl_interp::GetResultsOp>(
644           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
645     else
646       resultVal = builder.create<pdl_interp::GetResultOp>(
647           loc, valueTy, createdOp, it.index());
648     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
649   }
650 }
651 
652 void PatternLowering::generateRewriter(
653     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
654     function_ref<Value(Value)> mapRewriteValue) {
655   SmallVector<Value, 4> replOperands;
656 
657   // If the replacement was another operation, get its results. `pdl` allows
658   // for using an operation for simplicitly, but the interpreter isn't as
659   // user facing.
660   if (Value replOp = replaceOp.replOperation()) {
661     // Don't use replace if we know the replaced operation has no results.
662     auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
663     if (!opOp || !opOp.types().empty()) {
664       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
665           replOp.getLoc(), mapRewriteValue(replOp)));
666     }
667   } else {
668     for (Value operand : replaceOp.replValues())
669       replOperands.push_back(mapRewriteValue(operand));
670   }
671 
672   // If there are no replacement values, just create an erase instead.
673   if (replOperands.empty()) {
674     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
675                                         mapRewriteValue(replaceOp.operation()));
676     return;
677   }
678 
679   builder.create<pdl_interp::ReplaceOp>(
680       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
681 }
682 
683 void PatternLowering::generateRewriter(
684     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
685     function_ref<Value(Value)> mapRewriteValue) {
686   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
687       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
688       mapRewriteValue(resultOp.parent()), resultOp.index());
689 }
690 
691 void PatternLowering::generateRewriter(
692     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
693     function_ref<Value(Value)> mapRewriteValue) {
694   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
695       resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
696       resultOp.index());
697 }
698 
699 void PatternLowering::generateRewriter(
700     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
701     function_ref<Value(Value)> mapRewriteValue) {
702   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
703   // type.
704   if (TypeAttr typeAttr = typeOp.typeAttr()) {
705     rewriteValues[typeOp] =
706         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
707   }
708 }
709 
710 void PatternLowering::generateRewriter(
711     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
712     function_ref<Value(Value)> mapRewriteValue) {
713   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
714   // type.
715   if (ArrayAttr typeAttr = typeOp.typesAttr()) {
716     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
717         typeOp.getLoc(), typeOp.getType(), typeAttr);
718   }
719 }
720 
721 void PatternLowering::generateOperationResultTypeRewriter(
722     pdl::OperationOp op, SmallVectorImpl<Value> &types,
723     DenseMap<Value, Value> &rewriteValues,
724     function_ref<Value(Value)> mapRewriteValue) {
725   // Look for an operation that was replaced by `op`. The result types will be
726   // inferred from the results that were replaced.
727   Block *rewriterBlock = op->getBlock();
728   Value replacedOp;
729   for (OpOperand &use : op.op().getUses()) {
730     // Check that the use corresponds to a ReplaceOp and that it is the
731     // replacement value, not the operation being replaced.
732     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
733     if (!replOpUser || use.getOperandNumber() == 0)
734       continue;
735     // Make sure the replaced operation was defined before this one.
736     Value replOpVal = replOpUser.operation();
737     Operation *replacedOp = replOpVal.getDefiningOp();
738     if (replacedOp->getBlock() == rewriterBlock &&
739         !replacedOp->isBeforeInBlock(op))
740       continue;
741 
742     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
743         replacedOp->getLoc(), mapRewriteValue(replOpVal));
744     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
745         replacedOp->getLoc(), replacedOpResults));
746     return;
747   }
748 
749   // Check if the operation has type inference support.
750   if (op.hasTypeInference()) {
751     types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
752     return;
753   }
754 
755   // Otherwise, handle inference for each of the result types individually.
756   OperandRange resultTypeValues = op.types();
757   types.reserve(resultTypeValues.size());
758   for (auto it : llvm::enumerate(resultTypeValues)) {
759     Value resultType = it.value();
760 
761     // Check for an already translated value.
762     if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
763       types.push_back(existingRewriteValue);
764       continue;
765     }
766 
767     // Check for an input from the matcher.
768     if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
769       types.push_back(mapRewriteValue(resultType));
770       continue;
771     }
772 
773     // The verifier asserts that the result types of each pdl.operation can be
774     // inferred. If we reach here, there is a bug either in the logic above or
775     // in the verifier for pdl.operation.
776     op->emitOpError() << "unable to infer result type for operation";
777     llvm_unreachable("unable to infer result type for operation");
778   }
779 }
780 
781 //===----------------------------------------------------------------------===//
782 // Conversion Pass
783 //===----------------------------------------------------------------------===//
784 
785 namespace {
786 struct PDLToPDLInterpPass
787     : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
788   void runOnOperation() final;
789 };
790 } // namespace
791 
792 /// Convert the given module containing PDL pattern operations into a PDL
793 /// Interpreter operations.
794 void PDLToPDLInterpPass::runOnOperation() {
795   ModuleOp module = getOperation();
796 
797   // Create the main matcher function This function contains all of the match
798   // related functionality from patterns in the module.
799   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
800   FuncOp matcherFunc = builder.create<FuncOp>(
801       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
802       builder.getFunctionType(builder.getType<pdl::OperationType>(),
803                               /*results=*/llvm::None),
804       /*attrs=*/llvm::None);
805 
806   // Create a nested module to hold the functions invoked for rewriting the IR
807   // after a successful match.
808   ModuleOp rewriterModule = builder.create<ModuleOp>(
809       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
810 
811   // Generate the code for the patterns within the module.
812   PatternLowering generator(matcherFunc, rewriterModule);
813   generator.lower(module);
814 
815   // After generation, delete all of the pattern operations.
816   for (pdl::PatternOp pattern :
817        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
818     pattern.erase();
819 }
820 
821 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
822   return std::make_unique<PDLToPDLInterpPass>();
823 }
824