1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 
17 using namespace mlir;
18 
19 // Native function for testing NativeCodeCall
20 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
21   return choice.getValue() ? input1 : input2;
22 }
23 
24 static void createOpI(PatternRewriter &rewriter, Value input) {
25   rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
26 }
27 
28 static void handleNoResultOp(PatternRewriter &rewriter,
29                              OpSymbolBindingNoResult op) {
30   // Turn the no result op to a one-result op.
31   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
32                                     op.operand());
33 }
34 
35 namespace {
36 #include "TestPatterns.inc"
37 } // end anonymous namespace
38 
39 //===----------------------------------------------------------------------===//
40 // Canonicalizer Driver.
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 struct FoldingPattern : public RewritePattern {
45 public:
46   FoldingPattern(MLIRContext *context)
47       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
48                        /*benefit=*/1, context) {}
49 
50   LogicalResult matchAndRewrite(Operation *op,
51                                 PatternRewriter &rewriter) const override {
52     // Exercice OperationFolder API for a single-result operation that is folded
53     // upon construction. The operation being created through the folder has an
54     // in-place folder, and it should be still present in the output.
55     // Furthermore, the folder should not crash when attempting to recover the
56     // (unchanged) opeation result.
57     OperationFolder folder(op->getContext());
58     Value result = folder.create<TestOpInPlaceFold>(
59         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
60         rewriter.getI32IntegerAttr(0));
61     assert(result);
62     rewriter.replaceOp(op, result);
63     return success();
64   }
65 };
66 
67 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
68   void runOnFunction() override {
69     mlir::OwningRewritePatternList patterns;
70     populateWithGenerated(&getContext(), &patterns);
71 
72     // Verify named pattern is generated with expected name.
73     patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
74 
75     applyPatternsAndFoldGreedily(getFunction(), patterns);
76   }
77 };
78 } // end anonymous namespace
79 
80 //===----------------------------------------------------------------------===//
81 // ReturnType Driver.
82 //===----------------------------------------------------------------------===//
83 
84 namespace {
85 // Generate ops for each instance where the type can be successfully inferred.
86 template <typename OpTy>
87 static void invokeCreateWithInferredReturnType(Operation *op) {
88   auto *context = op->getContext();
89   auto fop = op->getParentOfType<FuncOp>();
90   auto location = UnknownLoc::get(context);
91   OpBuilder b(op);
92   b.setInsertionPointAfter(op);
93 
94   // Use permutations of 2 args as operands.
95   assert(fop.getNumArguments() >= 2);
96   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
97     for (int j = 0; j < e; ++j) {
98       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
99       SmallVector<Type, 2> inferredReturnTypes;
100       if (succeeded(OpTy::inferReturnTypes(
101               context, llvm::None, values, op->getAttrDictionary(),
102               op->getRegions(), inferredReturnTypes))) {
103         OperationState state(location, OpTy::getOperationName());
104         // TODO(jpienaar): Expand to regions.
105         OpTy::build(b, state, values, op->getAttrs());
106         (void)b.createOperation(state);
107       }
108     }
109   }
110 }
111 
112 static void reifyReturnShape(Operation *op) {
113   OpBuilder b(op);
114 
115   // Use permutations of 2 args as operands.
116   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
117   SmallVector<Value, 2> shapes;
118   if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
119     return;
120   for (auto it : llvm::enumerate(shapes))
121     op->emitRemark() << "value " << it.index() << ": "
122                      << it.value().getDefiningOp();
123 }
124 
125 struct TestReturnTypeDriver
126     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
127   void runOnFunction() override {
128     if (getFunction().getName() == "testCreateFunctions") {
129       std::vector<Operation *> ops;
130       // Collect ops to avoid triggering on inserted ops.
131       for (auto &op : getFunction().getBody().front())
132         ops.push_back(&op);
133       // Generate test patterns for each, but skip terminator.
134       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
135         // Test create method of each of the Op classes below. The resultant
136         // output would be in reverse order underneath `op` from which
137         // the attributes and regions are used.
138         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
139         invokeCreateWithInferredReturnType<
140             OpWithShapedTypeInferTypeInterfaceOp>(op);
141       };
142       return;
143     }
144     if (getFunction().getName() == "testReifyFunctions") {
145       std::vector<Operation *> ops;
146       // Collect ops to avoid triggering on inserted ops.
147       for (auto &op : getFunction().getBody().front())
148         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
149           ops.push_back(&op);
150       // Generate test patterns for each, but skip terminator.
151       for (auto *op : ops)
152         reifyReturnShape(op);
153     }
154   }
155 };
156 } // end anonymous namespace
157 
158 namespace {
/// Function pass that visits every op implementing DerivedAttributeOpInterface,
/// materializes its derived attributes, and emits one remark per attribute.
struct TestDerivedAttributeDriver
    : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
  void runOnFunction() override;
};
163 } // end anonymous namespace
164 
165 void TestDerivedAttributeDriver::runOnFunction() {
166   getFunction().walk([](DerivedAttributeOpInterface dOp) {
167     auto dAttr = dOp.materializeDerivedAttributes();
168     if (!dAttr)
169       return;
170     for (auto d : dAttr)
171       dOp.emitRemark() << d.first << " = " << d.second;
172   });
173 }
174 
175 //===----------------------------------------------------------------------===//
176 // Legalization Driver.
177 //===----------------------------------------------------------------------===//
178 
179 namespace {
180 //===----------------------------------------------------------------------===//
181 // Region-Block Rewrite Testing
182 
183 /// This pattern is a simple pattern that inlines the first region of a given
184 /// operation into the parent region.
185 struct TestRegionRewriteBlockMovement : public ConversionPattern {
186   TestRegionRewriteBlockMovement(MLIRContext *ctx)
187       : ConversionPattern("test.region", 1, ctx) {}
188 
189   LogicalResult
190   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
191                   ConversionPatternRewriter &rewriter) const final {
192     // Inline this region into the parent region.
193     auto &parentRegion = *op->getParentRegion();
194     if (op->getAttr("legalizer.should_clone"))
195       rewriter.cloneRegionBefore(op->getRegion(0), parentRegion,
196                                  parentRegion.end());
197     else
198       rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
199                                   parentRegion.end());
200 
201     // Drop this operation.
202     rewriter.eraseOp(op);
203     return success();
204   }
205 };
206 /// This pattern is a simple pattern that generates a region containing an
207 /// illegal operation.
208 struct TestRegionRewriteUndo : public RewritePattern {
209   TestRegionRewriteUndo(MLIRContext *ctx)
210       : RewritePattern("test.region_builder", 1, ctx) {}
211 
212   LogicalResult matchAndRewrite(Operation *op,
213                                 PatternRewriter &rewriter) const final {
214     // Create the region operation with an entry block containing arguments.
215     OperationState newRegion(op->getLoc(), "test.region");
216     newRegion.addRegion();
217     auto *regionOp = rewriter.createOperation(newRegion);
218     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
219     entryBlock->addArgument(rewriter.getIntegerType(64));
220 
221     // Add an explicitly illegal operation to ensure the conversion fails.
222     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
223     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
224 
225     // Drop this operation.
226     rewriter.eraseOp(op);
227     return success();
228   }
229 };
230 /// A simple pattern that creates a block at the end of the parent region of the
231 /// matched operation.
232 struct TestCreateBlock : public RewritePattern {
233   TestCreateBlock(MLIRContext *ctx)
234       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
235 
236   LogicalResult matchAndRewrite(Operation *op,
237                                 PatternRewriter &rewriter) const final {
238     Region &region = *op->getParentRegion();
239     Type i32Type = rewriter.getIntegerType(32);
240     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
241     rewriter.create<TerminatorOp>(op->getLoc());
242     rewriter.replaceOp(op, {});
243     return success();
244   }
245 };
246 
247 /// A simple pattern that creates a block containing an invalid operaiton in
248 /// order to trigger the block creation undo mechanism.
249 struct TestCreateIllegalBlock : public RewritePattern {
250   TestCreateIllegalBlock(MLIRContext *ctx)
251       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
252 
253   LogicalResult matchAndRewrite(Operation *op,
254                                 PatternRewriter &rewriter) const final {
255     Region &region = *op->getParentRegion();
256     Type i32Type = rewriter.getIntegerType(32);
257     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
258     // Create an illegal op to ensure the conversion fails.
259     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
260     rewriter.create<TerminatorOp>(op->getLoc());
261     rewriter.replaceOp(op, {});
262     return success();
263   }
264 };
265 
266 /// A simple pattern that tests the undo mechanism when replacing the uses of a
267 /// block argument.
268 struct TestUndoBlockArgReplace : public ConversionPattern {
269   TestUndoBlockArgReplace(MLIRContext *ctx)
270       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
271 
272   LogicalResult
273   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
274                   ConversionPatternRewriter &rewriter) const final {
275     auto illegalOp =
276         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
277     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
278                                         illegalOp);
279     rewriter.updateRootInPlace(op, [] {});
280     return success();
281   }
282 };
283 
284 //===----------------------------------------------------------------------===//
285 // Type-Conversion Rewrite Testing
286 
287 /// This patterns erases a region operation that has had a type conversion.
288 struct TestDropOpSignatureConversion : public ConversionPattern {
289   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
290       : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
291   }
292   LogicalResult
293   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
294                   ConversionPatternRewriter &rewriter) const override {
295     Region &region = op->getRegion(0);
296     Block *entry = &region.front();
297 
298     // Convert the original entry arguments.
299     TypeConverter::SignatureConversion result(entry->getNumArguments());
300     for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
301       if (failed(converter.convertSignatureArg(
302               i, entry->getArgument(i).getType(), result)))
303         return failure();
304 
305     // Convert the region signature and just drop the operation.
306     rewriter.applySignatureConversion(&region, result);
307     rewriter.eraseOp(op);
308     return success();
309   }
310 
311   /// The type converter to use when rewriting the signature.
312   TypeConverter &converter;
313 };
314 /// This pattern simply updates the operands of the given operation.
315 struct TestPassthroughInvalidOp : public ConversionPattern {
316   TestPassthroughInvalidOp(MLIRContext *ctx)
317       : ConversionPattern("test.invalid", 1, ctx) {}
318   LogicalResult
319   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
320                   ConversionPatternRewriter &rewriter) const final {
321     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
322                                              llvm::None);
323     return success();
324   }
325 };
326 /// This pattern handles the case of a split return value.
327 struct TestSplitReturnType : public ConversionPattern {
328   TestSplitReturnType(MLIRContext *ctx)
329       : ConversionPattern("test.return", 1, ctx) {}
330   LogicalResult
331   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
332                   ConversionPatternRewriter &rewriter) const final {
333     // Check for a return of F32.
334     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
335       return failure();
336 
337     // Check if the first operation is a cast operation, if it is we use the
338     // results directly.
339     auto *defOp = operands[0].getDefiningOp();
340     if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
341       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
342       return success();
343     }
344 
345     // Otherwise, fail to match.
346     return failure();
347   }
348 };
349 
350 //===----------------------------------------------------------------------===//
351 // Multi-Level Type-Conversion Rewrite Testing
352 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
353   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
354       : ConversionPattern("test.type_producer", 1, ctx) {}
355   LogicalResult
356   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
357                   ConversionPatternRewriter &rewriter) const final {
358     // If the type is I32, change the type to F32.
359     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
360       return failure();
361     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
362     return success();
363   }
364 };
365 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
366   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
367       : ConversionPattern("test.type_producer", 1, ctx) {}
368   LogicalResult
369   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
370                   ConversionPatternRewriter &rewriter) const final {
371     // If the type is F32, change the type to F64.
372     if (!Type(*op->result_type_begin()).isF32())
373       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
374     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
375     return success();
376   }
377 };
378 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
379   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
380       : ConversionPattern("test.type_producer", 10, ctx) {}
381   LogicalResult
382   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
383                   ConversionPatternRewriter &rewriter) const final {
384     // Always convert to B16, even though it is not a legal type. This tests
385     // that values are unmapped correctly.
386     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
387     return success();
388   }
389 };
390 struct TestUpdateConsumerType : public ConversionPattern {
391   TestUpdateConsumerType(MLIRContext *ctx)
392       : ConversionPattern("test.type_consumer", 1, ctx) {}
393   LogicalResult
394   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
395                   ConversionPatternRewriter &rewriter) const final {
396     // Verify that the incoming operand has been successfully remapped to F64.
397     if (!operands[0].getType().isF64())
398       return failure();
399     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
400     return success();
401   }
402 };
403 
404 //===----------------------------------------------------------------------===//
405 // Non-Root Replacement Rewrite Testing
406 /// This pattern generates an invalid operation, but replaces it before the
407 /// pattern is finished. This checks that we don't need to legalize the
408 /// temporary op.
409 struct TestNonRootReplacement : public RewritePattern {
410   TestNonRootReplacement(MLIRContext *ctx)
411       : RewritePattern("test.replace_non_root", 1, ctx) {}
412 
413   LogicalResult matchAndRewrite(Operation *op,
414                                 PatternRewriter &rewriter) const final {
415     auto resultType = *op->result_type_begin();
416     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
417     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
418 
419     rewriter.replaceOp(illegalOp, {legalOp});
420     rewriter.replaceOp(op, {illegalOp});
421     return success();
422   }
423 };
424 
425 //===----------------------------------------------------------------------===//
426 // Recursive Rewrite Testing
427 /// This pattern is applied to the same operation multiple times, but has a
428 /// bounded recursion.
429 struct TestBoundedRecursiveRewrite
430     : public OpRewritePattern<TestRecursiveRewriteOp> {
431   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
432 
433   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
434                                 PatternRewriter &rewriter) const final {
435     // Decrement the depth of the op in-place.
436     rewriter.updateRootInPlace(op, [&] {
437       op.setAttr("depth",
438                  rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1));
439     });
440     return success();
441   }
442 
443   /// The conversion target handles bounding the recursion of this pattern.
444   bool hasBoundedRewriteRecursion() const final { return true; }
445 };
446 } // namespace
447 
448 namespace {
449 struct TestTypeConverter : public TypeConverter {
450   using TypeConverter::TypeConverter;
451   TestTypeConverter() { addConversion(convertType); }
452 
453   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
454     // Drop I16 types.
455     if (t.isSignlessInteger(16))
456       return success();
457 
458     // Convert I64 to F64.
459     if (t.isSignlessInteger(64)) {
460       results.push_back(FloatType::getF64(t.getContext()));
461       return success();
462     }
463 
464     // Split F32 into F16,F16.
465     if (t.isF32()) {
466       results.assign(2, FloatType::getF16(t.getContext()));
467       return success();
468     }
469 
470     // Otherwise, convert the type directly.
471     results.push_back(t);
472     return success();
473   }
474 
475   /// Override the hook to materialize a conversion. This is necessary because
476   /// we generate 1->N type mappings.
477   Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
478                                    ArrayRef<Value> inputs,
479                                    Location loc) override {
480     return rewriter.create<TestCastOp>(loc, resultType, inputs);
481   }
482 };
483 
484 struct TestLegalizePatternDriver
485     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
486   /// The mode of conversion to use with the driver.
487   enum class ConversionMode { Analysis, Full, Partial };
488 
489   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
490 
491   void runOnOperation() override {
492     TestTypeConverter converter;
493     mlir::OwningRewritePatternList patterns;
494     populateWithGenerated(&getContext(), &patterns);
495     patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
496                     TestCreateBlock, TestCreateIllegalBlock,
497                     TestUndoBlockArgReplace, TestPassthroughInvalidOp,
498                     TestSplitReturnType, TestChangeProducerTypeI32ToF32,
499                     TestChangeProducerTypeF32ToF64,
500                     TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
501                     TestNonRootReplacement, TestBoundedRecursiveRewrite>(
502         &getContext());
503     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
504     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
505                                               converter);
506     mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
507                                               converter);
508 
509     // Define the conversion target used for the test.
510     ConversionTarget target(getContext());
511     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
512     target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
513                       TerminatorOp>();
514     target
515         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
516     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
517       // Don't allow F32 operands.
518       return llvm::none_of(op.getOperandTypes(),
519                            [](Type type) { return type.isF32(); });
520     });
521     target.addDynamicallyLegalOp<FuncOp>(
522         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
523 
524     // Expect the type_producer/type_consumer operations to only operate on f64.
525     target.addDynamicallyLegalOp<TestTypeProducerOp>(
526         [](TestTypeProducerOp op) { return op.getType().isF64(); });
527     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
528       return op.getOperand().getType().isF64();
529     });
530 
531     // Check support for marking certain operations as recursively legal.
532     target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
533       return static_cast<bool>(
534           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
535     });
536 
537     // Mark the bound recursion operation as dynamically legal.
538     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
539         [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
540 
541     // Handle a partial conversion.
542     if (mode == ConversionMode::Partial) {
543       DenseSet<Operation *> unlegalizedOps;
544       (void)applyPartialConversion(getOperation(), target, patterns, &converter,
545                                    &unlegalizedOps);
546       // Emit remarks for each legalizable operation.
547       for (auto *op : unlegalizedOps)
548         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
549       return;
550     }
551 
552     // Handle a full conversion.
553     if (mode == ConversionMode::Full) {
554       // Check support for marking unknown operations as dynamically legal.
555       target.markUnknownOpDynamicallyLegal([](Operation *op) {
556         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
557       });
558 
559       (void)applyFullConversion(getOperation(), target, patterns, &converter);
560       return;
561     }
562 
563     // Otherwise, handle an analysis conversion.
564     assert(mode == ConversionMode::Analysis);
565 
566     // Analyze the convertible operations.
567     DenseSet<Operation *> legalizedOps;
568     if (failed(applyAnalysisConversion(getOperation(), target, patterns,
569                                        legalizedOps, &converter)))
570       return signalPassFailure();
571 
572     // Emit remarks for each legalizable operation.
573     for (auto *op : legalizedOps)
574       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
575   }
576 
577   /// The mode of conversion to use.
578   ConversionMode mode;
579 };
580 } // end anonymous namespace
581 
/// Command-line option (`-test-legalize-mode`) that selects the conversion
/// mode of the `test-legalize-patterns` pass. The default is a partial
/// conversion.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
594 
595 //===----------------------------------------------------------------------===//
596 // ConversionPatternRewriter::getRemappedValue testing. This method is used
597 // to get the remapped value of an original value that was replaced using
598 // ConversionPatternRewriter.
599 namespace {
600 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
601 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
602 /// operand twice.
603 ///
604 /// Example:
605 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
606 /// is replaced with:
607 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
608 struct OneVResOneVOperandOp1Converter
609     : public OpConversionPattern<OneVResOneVOperandOp1> {
610   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
611 
612   LogicalResult
613   matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
614                   ConversionPatternRewriter &rewriter) const override {
615     auto origOps = op.getOperands();
616     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
617            "One operand expected");
618     Value origOp = *origOps.begin();
619     SmallVector<Value, 2> remappedOperands;
620     // Replicate the remapped original operand twice. Note that we don't used
621     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
622     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
623     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
624 
625     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
626                                                        remappedOperands);
627     return success();
628   }
629 };
630 
631 struct TestRemappedValue
632     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
633   void runOnFunction() override {
634     mlir::OwningRewritePatternList patterns;
635     patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
636 
637     mlir::ConversionTarget target(getContext());
638     target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
639     // We make OneVResOneVOperandOp1 legal only when it has more that one
640     // operand. This will trigger the conversion that will replace one-operand
641     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
642     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
643         [](Operation *op) -> bool {
644           return std::distance(op->operand_begin(), op->operand_end()) > 1;
645         });
646 
647     if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
648       signalPassFailure();
649     }
650   }
651 };
652 } // end anonymous namespace
653 
654 namespace mlir {
655 void registerPatternsTestPass() {
656   mlir::PassRegistration<TestReturnTypeDriver>("test-return-type",
657                                                "Run return type functions");
658 
659   mlir::PassRegistration<TestDerivedAttributeDriver>(
660       "test-derived-attr", "Run test derived attributes");
661 
662   mlir::PassRegistration<TestPatternDriver>("test-patterns",
663                                             "Run test dialect patterns");
664 
665   mlir::PassRegistration<TestLegalizePatternDriver>(
666       "test-legalize-patterns", "Run test dialect legalization patterns", [] {
667         return std::make_unique<TestLegalizePatternDriver>(
668             legalizerConversionMode);
669       });
670 
671   PassRegistration<TestRemappedValue>(
672       "test-remapped-value",
673       "Test public remapped value mechanism in ConversionPatternRewriter");
674 }
675 } // namespace mlir
676