1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 
17 using namespace mlir;
18 
19 // Native function for testing NativeCodeCall
20 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
21   return choice.getValue() ? input1 : input2;
22 }
23 
24 static void createOpI(PatternRewriter &rewriter, Value input) {
25   rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
26 }
27 
28 static void handleNoResultOp(PatternRewriter &rewriter,
29                              OpSymbolBindingNoResult op) {
30   // Turn the no result op to a one-result op.
31   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
32                                     op.operand());
33 }
34 
35 namespace {
36 #include "TestPatterns.inc"
37 } // end anonymous namespace
38 
39 //===----------------------------------------------------------------------===//
40 // Canonicalizer Driver.
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 struct FoldingPattern : public RewritePattern {
45 public:
46   FoldingPattern(MLIRContext *context)
47       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
48                        /*benefit=*/1, context) {}
49 
50   LogicalResult matchAndRewrite(Operation *op,
51                                 PatternRewriter &rewriter) const override {
52     // Exercice OperationFolder API for a single-result operation that is folded
53     // upon construction. The operation being created through the folder has an
54     // in-place folder, and it should be still present in the output.
55     // Furthermore, the folder should not crash when attempting to recover the
56     // (unchanged) opeation result.
57     OperationFolder folder(op->getContext());
58     Value result = folder.create<TestOpInPlaceFold>(
59         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
60         rewriter.getI32IntegerAttr(0));
61     assert(result);
62     rewriter.replaceOp(op, result);
63     return success();
64   }
65 };
66 
67 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
68   void runOnFunction() override {
69     mlir::OwningRewritePatternList patterns;
70     populateWithGenerated(&getContext(), &patterns);
71 
72     // Verify named pattern is generated with expected name.
73     patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
74 
75     applyPatternsAndFoldGreedily(getFunction(), patterns);
76   }
77 };
78 } // end anonymous namespace
79 
80 //===----------------------------------------------------------------------===//
81 // ReturnType Driver.
82 //===----------------------------------------------------------------------===//
83 
84 namespace {
85 // Generate ops for each instance where the type can be successfully inferred.
86 template <typename OpTy>
87 static void invokeCreateWithInferredReturnType(Operation *op) {
88   auto *context = op->getContext();
89   auto fop = op->getParentOfType<FuncOp>();
90   auto location = UnknownLoc::get(context);
91   OpBuilder b(op);
92   b.setInsertionPointAfter(op);
93 
94   // Use permutations of 2 args as operands.
95   assert(fop.getNumArguments() >= 2);
96   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
97     for (int j = 0; j < e; ++j) {
98       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
99       SmallVector<Type, 2> inferredReturnTypes;
100       if (succeeded(OpTy::inferReturnTypes(
101               context, llvm::None, values, op->getAttrDictionary(),
102               op->getRegions(), inferredReturnTypes))) {
103         OperationState state(location, OpTy::getOperationName());
104         // TODO(jpienaar): Expand to regions.
105         OpTy::build(b, state, values, op->getAttrs());
106         (void)b.createOperation(state);
107       }
108     }
109   }
110 }
111 
112 static void reifyReturnShape(Operation *op) {
113   OpBuilder b(op);
114 
115   // Use permutations of 2 args as operands.
116   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
117   SmallVector<Value, 2> shapes;
118   if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
119     return;
120   for (auto it : llvm::enumerate(shapes))
121     op->emitRemark() << "value " << it.index() << ": "
122                      << it.value().getDefiningOp();
123 }
124 
125 struct TestReturnTypeDriver
126     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
127   void runOnFunction() override {
128     if (getFunction().getName() == "testCreateFunctions") {
129       std::vector<Operation *> ops;
130       // Collect ops to avoid triggering on inserted ops.
131       for (auto &op : getFunction().getBody().front())
132         ops.push_back(&op);
133       // Generate test patterns for each, but skip terminator.
134       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
135         // Test create method of each of the Op classes below. The resultant
136         // output would be in reverse order underneath `op` from which
137         // the attributes and regions are used.
138         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
139         invokeCreateWithInferredReturnType<
140             OpWithShapedTypeInferTypeInterfaceOp>(op);
141       };
142       return;
143     }
144     if (getFunction().getName() == "testReifyFunctions") {
145       std::vector<Operation *> ops;
146       // Collect ops to avoid triggering on inserted ops.
147       for (auto &op : getFunction().getBody().front())
148         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
149           ops.push_back(&op);
150       // Generate test patterns for each, but skip terminator.
151       for (auto *op : ops)
152         reifyReturnShape(op);
153     }
154   }
155 };
156 } // end anonymous namespace
157 
158 namespace {
159 struct TestDerivedAttributeDriver
160     : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
161   void runOnFunction() override;
162 };
163 } // end anonymous namespace
164 
165 void TestDerivedAttributeDriver::runOnFunction() {
166   getFunction().walk([](DerivedAttributeOpInterface dOp) {
167     auto dAttr = dOp.materializeDerivedAttributes();
168     if (!dAttr)
169       return;
170     for (auto d : dAttr)
171       dOp.emitRemark() << d.first << " = " << d.second;
172   });
173 }
174 
175 //===----------------------------------------------------------------------===//
176 // Legalization Driver.
177 //===----------------------------------------------------------------------===//
178 
179 namespace {
180 //===----------------------------------------------------------------------===//
181 // Region-Block Rewrite Testing
182 
183 /// This pattern is a simple pattern that inlines the first region of a given
184 /// operation into the parent region.
185 struct TestRegionRewriteBlockMovement : public ConversionPattern {
186   TestRegionRewriteBlockMovement(MLIRContext *ctx)
187       : ConversionPattern("test.region", 1, ctx) {}
188 
189   LogicalResult
190   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
191                   ConversionPatternRewriter &rewriter) const final {
192     // Inline this region into the parent region.
193     auto &parentRegion = *op->getParentRegion();
194     if (op->getAttr("legalizer.should_clone"))
195       rewriter.cloneRegionBefore(op->getRegion(0), parentRegion,
196                                  parentRegion.end());
197     else
198       rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
199                                   parentRegion.end());
200 
201     // Drop this operation.
202     rewriter.eraseOp(op);
203     return success();
204   }
205 };
206 /// This pattern is a simple pattern that generates a region containing an
207 /// illegal operation.
208 struct TestRegionRewriteUndo : public RewritePattern {
209   TestRegionRewriteUndo(MLIRContext *ctx)
210       : RewritePattern("test.region_builder", 1, ctx) {}
211 
212   LogicalResult matchAndRewrite(Operation *op,
213                                 PatternRewriter &rewriter) const final {
214     // Create the region operation with an entry block containing arguments.
215     OperationState newRegion(op->getLoc(), "test.region");
216     newRegion.addRegion();
217     auto *regionOp = rewriter.createOperation(newRegion);
218     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
219     entryBlock->addArgument(rewriter.getIntegerType(64));
220 
221     // Add an explicitly illegal operation to ensure the conversion fails.
222     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
223     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
224 
225     // Drop this operation.
226     rewriter.eraseOp(op);
227     return success();
228   }
229 };
230 /// A simple pattern that creates a block at the end of the parent region of the
231 /// matched operation.
232 struct TestCreateBlock : public RewritePattern {
233   TestCreateBlock(MLIRContext *ctx)
234       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
235 
236   LogicalResult matchAndRewrite(Operation *op,
237                                 PatternRewriter &rewriter) const final {
238     Region &region = *op->getParentRegion();
239     Type i32Type = rewriter.getIntegerType(32);
240     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
241     rewriter.create<TerminatorOp>(op->getLoc());
242     rewriter.replaceOp(op, {});
243     return success();
244   }
245 };
246 
247 /// A simple pattern that creates a block containing an invalid operaiton in
248 /// order to trigger the block creation undo mechanism.
249 struct TestCreateIllegalBlock : public RewritePattern {
250   TestCreateIllegalBlock(MLIRContext *ctx)
251       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
252 
253   LogicalResult matchAndRewrite(Operation *op,
254                                 PatternRewriter &rewriter) const final {
255     Region &region = *op->getParentRegion();
256     Type i32Type = rewriter.getIntegerType(32);
257     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
258     // Create an illegal op to ensure the conversion fails.
259     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
260     rewriter.create<TerminatorOp>(op->getLoc());
261     rewriter.replaceOp(op, {});
262     return success();
263   }
264 };
265 
266 /// A simple pattern that tests the undo mechanism when replacing the uses of a
267 /// block argument.
268 struct TestUndoBlockArgReplace : public ConversionPattern {
269   TestUndoBlockArgReplace(MLIRContext *ctx)
270       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
271 
272   LogicalResult
273   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
274                   ConversionPatternRewriter &rewriter) const final {
275     auto illegalOp =
276         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
277     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
278                                         illegalOp);
279     rewriter.updateRootInPlace(op, [] {});
280     return success();
281   }
282 };
283 
284 /// A rewrite pattern that tests the undo mechanism when erasing a block.
285 struct TestUndoBlockErase : public ConversionPattern {
286   TestUndoBlockErase(MLIRContext *ctx)
287       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
288 
289   LogicalResult
290   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
291                   ConversionPatternRewriter &rewriter) const final {
292     Block *secondBlock = &*std::next(op->getRegion(0).begin());
293     rewriter.setInsertionPointToStart(secondBlock);
294     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
295     rewriter.eraseBlock(secondBlock);
296     rewriter.updateRootInPlace(op, [] {});
297     return success();
298   }
299 };
300 
301 //===----------------------------------------------------------------------===//
302 // Type-Conversion Rewrite Testing
303 
304 /// This patterns erases a region operation that has had a type conversion.
305 struct TestDropOpSignatureConversion : public ConversionPattern {
306   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
307       : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
308   }
309   LogicalResult
310   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
311                   ConversionPatternRewriter &rewriter) const override {
312     Region &region = op->getRegion(0);
313     Block *entry = &region.front();
314 
315     // Convert the original entry arguments.
316     TypeConverter::SignatureConversion result(entry->getNumArguments());
317     if (failed(
318             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
319       return failure();
320 
321     // Convert the region signature and just drop the operation.
322     rewriter.applySignatureConversion(&region, result);
323     rewriter.eraseOp(op);
324     return success();
325   }
326 
327   /// The type converter to use when rewriting the signature.
328   TypeConverter &converter;
329 };
330 /// This pattern simply updates the operands of the given operation.
331 struct TestPassthroughInvalidOp : public ConversionPattern {
332   TestPassthroughInvalidOp(MLIRContext *ctx)
333       : ConversionPattern("test.invalid", 1, ctx) {}
334   LogicalResult
335   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
336                   ConversionPatternRewriter &rewriter) const final {
337     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
338                                              llvm::None);
339     return success();
340   }
341 };
342 /// This pattern handles the case of a split return value.
343 struct TestSplitReturnType : public ConversionPattern {
344   TestSplitReturnType(MLIRContext *ctx)
345       : ConversionPattern("test.return", 1, ctx) {}
346   LogicalResult
347   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
348                   ConversionPatternRewriter &rewriter) const final {
349     // Check for a return of F32.
350     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
351       return failure();
352 
353     // Check if the first operation is a cast operation, if it is we use the
354     // results directly.
355     auto *defOp = operands[0].getDefiningOp();
356     if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
357       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
358       return success();
359     }
360 
361     // Otherwise, fail to match.
362     return failure();
363   }
364 };
365 
366 //===----------------------------------------------------------------------===//
367 // Multi-Level Type-Conversion Rewrite Testing
368 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
369   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
370       : ConversionPattern("test.type_producer", 1, ctx) {}
371   LogicalResult
372   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
373                   ConversionPatternRewriter &rewriter) const final {
374     // If the type is I32, change the type to F32.
375     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
376       return failure();
377     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
378     return success();
379   }
380 };
381 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
382   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
383       : ConversionPattern("test.type_producer", 1, ctx) {}
384   LogicalResult
385   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
386                   ConversionPatternRewriter &rewriter) const final {
387     // If the type is F32, change the type to F64.
388     if (!Type(*op->result_type_begin()).isF32())
389       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
390     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
391     return success();
392   }
393 };
394 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
395   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
396       : ConversionPattern("test.type_producer", 10, ctx) {}
397   LogicalResult
398   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
399                   ConversionPatternRewriter &rewriter) const final {
400     // Always convert to B16, even though it is not a legal type. This tests
401     // that values are unmapped correctly.
402     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
403     return success();
404   }
405 };
406 struct TestUpdateConsumerType : public ConversionPattern {
407   TestUpdateConsumerType(MLIRContext *ctx)
408       : ConversionPattern("test.type_consumer", 1, ctx) {}
409   LogicalResult
410   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
411                   ConversionPatternRewriter &rewriter) const final {
412     // Verify that the incoming operand has been successfully remapped to F64.
413     if (!operands[0].getType().isF64())
414       return failure();
415     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
416     return success();
417   }
418 };
419 
420 //===----------------------------------------------------------------------===//
421 // Non-Root Replacement Rewrite Testing
422 /// This pattern generates an invalid operation, but replaces it before the
423 /// pattern is finished. This checks that we don't need to legalize the
424 /// temporary op.
425 struct TestNonRootReplacement : public RewritePattern {
426   TestNonRootReplacement(MLIRContext *ctx)
427       : RewritePattern("test.replace_non_root", 1, ctx) {}
428 
429   LogicalResult matchAndRewrite(Operation *op,
430                                 PatternRewriter &rewriter) const final {
431     auto resultType = *op->result_type_begin();
432     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
433     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
434 
435     rewriter.replaceOp(illegalOp, {legalOp});
436     rewriter.replaceOp(op, {illegalOp});
437     return success();
438   }
439 };
440 
441 //===----------------------------------------------------------------------===//
442 // Recursive Rewrite Testing
443 /// This pattern is applied to the same operation multiple times, but has a
444 /// bounded recursion.
445 struct TestBoundedRecursiveRewrite
446     : public OpRewritePattern<TestRecursiveRewriteOp> {
447   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
448 
449   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
450                                 PatternRewriter &rewriter) const final {
451     // Decrement the depth of the op in-place.
452     rewriter.updateRootInPlace(op, [&] {
453       op.setAttr("depth",
454                  rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1));
455     });
456     return success();
457   }
458 
459   /// The conversion target handles bounding the recursion of this pattern.
460   bool hasBoundedRewriteRecursion() const final { return true; }
461 };
462 
463 struct TestNestedOpCreationUndoRewrite
464     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
465   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
466 
467   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
468                                 PatternRewriter &rewriter) const final {
469     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
470     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
471     return success();
472   };
473 };
474 } // namespace
475 
476 namespace {
/// Type converter used by the legalization test driver. It registers one type
/// conversion rule set (convertType) and two materialization hooks
/// (materializeCast and materializeOneToOneCast).
struct TestTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;
  TestTypeConverter() {
    addConversion(convertType);
    addMaterialization(materializeCast);
    addMaterialization(materializeOneToOneCast);
  }

  /// Conversion rules: signless i16 is dropped (no result types), signless
  /// i64 -> f64, i42 (any signedness) -> signless i43, f32 -> (f16, f16), and
  /// every other type maps to itself. This never fails.
  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
    // Drop I16 types.
    if (t.isSignlessInteger(16))
      return success();

    // Convert I64 to F64.
    if (t.isSignlessInteger(64)) {
      results.push_back(FloatType::getF64(t.getContext()));
      return success();
    }

    // Convert I42 to I43.
    if (t.isInteger(42)) {
      results.push_back(IntegerType::get(43, t.getContext()));
      return success();
    }

    // Split F32 into F16,F16.
    if (t.isF32()) {
      results.assign(2, FloatType::getF16(t.getContext()));
      return success();
    }

    // Otherwise, convert the type directly.
    results.push_back(t);
    return success();
  }

  /// Hook for materializing a conversion. This is necessary because we generate
  /// 1->N type mappings. A single input is forwarded unchanged (no cast op is
  /// created). Multiple inputs are combined by a TestCastOp of `resultType`.
  static Optional<Value> materializeCast(PatternRewriter &rewriter,
                                         Type resultType, ValueRange inputs,
                                         Location loc) {
    if (inputs.size() == 1)
      return inputs[0];
    return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
  }

  /// Materialize a one-to-one cast producing a 42-bit integer `resultType` from
  /// exactly one input, via a TestCastOp. This pairs with the i42 -> i43 rule
  /// in convertType. For any other request it returns llvm::None, meaning this
  /// hook does not materialize anything.
  static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
                                                 IntegerType resultType,
                                                 ValueRange inputs,
                                                 Location loc) {
    if (resultType.getWidth() == 42 && inputs.size() == 1)
      return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
    return llvm::None;
  }
};
533 
/// Module pass that runs the legalization test patterns against a test
/// ConversionTarget, using one of the conversion modes in ConversionMode.
struct TestLegalizePatternDriver
    : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
  /// The mode of conversion to use with the driver:
  ///  - Analysis: applyAnalysisConversion, then a remark on each legalizable op.
  ///  - Full: applyFullConversion.
  ///  - Partial: applyPartialConversion, then a remark on each op that could
  ///    not be legalized.
  enum class ConversionMode { Analysis, Full, Partial };

  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}

  void runOnOperation() override {
    TestTypeConverter converter;
    mlir::OwningRewritePatternList patterns;
    // DRR-generated patterns first, then the hand-written test patterns.
    populateWithGenerated(&getContext(), &patterns);
    patterns.insert<
        TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
        TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
        TestPassthroughInvalidOp, TestSplitReturnType,
        TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
        TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
        TestNonRootReplacement, TestBoundedRecursiveRewrite,
        TestNestedOpCreationUndoRewrite>(&getContext());
    // Patterns that need the type converter.
    patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
    mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                              converter);
    mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
                                              converter);

    // Define the conversion target used for the test.
    ConversionTarget target(getContext());
    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
    target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
                      TerminatorOp>();
    target
        .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
      // Don't allow F32 operands.
      return llvm::none_of(op.getOperandTypes(),
                           [](Type type) { return type.isF32(); });
    });
    // A function is legal once the converter considers its signature legal.
    target.addDynamicallyLegalOp<FuncOp>(
        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });

    // Expect the type_producer/type_consumer operations to only operate on f64.
    target.addDynamicallyLegalOp<TestTypeProducerOp>(
        [](TestTypeProducerOp op) { return op.getType().isF64(); });
    target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
      return op.getOperand().getType().isF64();
    });

    // Check support for marking certain operations as recursively legal.
    target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
      return static_cast<bool>(
          op->getAttrOfType<UnitAttr>("test.recursively_legal"));
    });

    // Mark the bound recursion operation as dynamically legal.
    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
        [](TestRecursiveRewriteOp op) { return op.depth() == 0; });

    // Handle a partial conversion.
    if (mode == ConversionMode::Partial) {
      DenseSet<Operation *> unlegalizedOps;
      // The LogicalResult is discarded; the outcome is reported through the
      // remarks below (and any diagnostics emitted during conversion).
      (void)applyPartialConversion(getOperation(), target, patterns, &converter,
                                   &unlegalizedOps);
      // Emit remarks for each operation that could not be legalized.
      for (auto *op : unlegalizedOps)
        op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
      return;
    }

    // Handle a full conversion.
    if (mode == ConversionMode::Full) {
      // Check support for marking unknown operations as dynamically legal.
      target.markUnknownOpDynamicallyLegal([](Operation *op) {
        return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
      });

      // As above, the LogicalResult is intentionally discarded and the pass
      // is not marked as failed.
      (void)applyFullConversion(getOperation(), target, patterns, &converter);
      return;
    }

    // Otherwise, handle an analysis conversion.
    assert(mode == ConversionMode::Analysis);

    // Analyze the convertible operations.
    DenseSet<Operation *> legalizedOps;
    if (failed(applyAnalysisConversion(getOperation(), target, patterns,
                                       legalizedOps, &converter)))
      return signalPassFailure();

    // Emit remarks for each legalizable operation.
    for (auto *op : legalizedOps)
      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
  }

  /// The mode of conversion to use.
  ConversionMode mode;
};
630 } // end anonymous namespace
631 
/// Command-line option `-test-legalize-mode`. It selects the ConversionMode
/// handed to TestLegalizePatternDriver when the `test-legalize-patterns` pass
/// is created. It defaults to a partial conversion.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
644 
645 //===----------------------------------------------------------------------===//
646 // ConversionPatternRewriter::getRemappedValue testing. This method is used
647 // to get the remapped value of an original value that was replaced using
648 // ConversionPatternRewriter.
649 namespace {
650 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
651 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
652 /// operand twice.
653 ///
654 /// Example:
655 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
656 /// is replaced with:
657 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
658 struct OneVResOneVOperandOp1Converter
659     : public OpConversionPattern<OneVResOneVOperandOp1> {
660   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
661 
662   LogicalResult
663   matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
664                   ConversionPatternRewriter &rewriter) const override {
665     auto origOps = op.getOperands();
666     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
667            "One operand expected");
668     Value origOp = *origOps.begin();
669     SmallVector<Value, 2> remappedOperands;
670     // Replicate the remapped original operand twice. Note that we don't used
671     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
672     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
673     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
674 
675     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
676                                                        remappedOperands);
677     return success();
678   }
679 };
680 
681 struct TestRemappedValue
682     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
683   void runOnFunction() override {
684     mlir::OwningRewritePatternList patterns;
685     patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
686 
687     mlir::ConversionTarget target(getContext());
688     target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
689     // We make OneVResOneVOperandOp1 legal only when it has more that one
690     // operand. This will trigger the conversion that will replace one-operand
691     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
692     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
693         [](Operation *op) -> bool {
694           return std::distance(op->operand_begin(), op->operand_end()) > 1;
695         });
696 
697     if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
698       signalPassFailure();
699     }
700   }
701 };
702 } // end anonymous namespace
703 
704 namespace mlir {
705 void registerPatternsTestPass() {
706   mlir::PassRegistration<TestReturnTypeDriver>("test-return-type",
707                                                "Run return type functions");
708 
709   mlir::PassRegistration<TestDerivedAttributeDriver>(
710       "test-derived-attr", "Run test derived attributes");
711 
712   mlir::PassRegistration<TestPatternDriver>("test-patterns",
713                                             "Run test dialect patterns");
714 
715   mlir::PassRegistration<TestLegalizePatternDriver>(
716       "test-legalize-patterns", "Run test dialect legalization patterns", [] {
717         return std::make_unique<TestLegalizePatternDriver>(
718             legalizerConversionMode);
719       });
720 
721   PassRegistration<TestRemappedValue>(
722       "test-remapped-value",
723       "Test public remapped value mechanism in ConversionPatternRewriter");
724 }
725 } // namespace mlir
726