1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 
17 using namespace mlir;
18 
19 // Native function for testing NativeCodeCall
20 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
21   return choice.getValue() ? input1 : input2;
22 }
23 
24 static void createOpI(PatternRewriter &rewriter, Value input) {
25   rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
26 }
27 
28 static void handleNoResultOp(PatternRewriter &rewriter,
29                              OpSymbolBindingNoResult op) {
30   // Turn the no result op to a one-result op.
31   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
32                                     op.operand());
33 }
34 
35 namespace {
36 #include "TestPatterns.inc"
37 } // end anonymous namespace
38 
39 //===----------------------------------------------------------------------===//
40 // Canonicalizer Driver.
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 struct FoldingPattern : public RewritePattern {
45 public:
46   FoldingPattern(MLIRContext *context)
47       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
48                        /*benefit=*/1, context) {}
49 
50   LogicalResult matchAndRewrite(Operation *op,
51                                 PatternRewriter &rewriter) const override {
52     // Exercice OperationFolder API for a single-result operation that is folded
53     // upon construction. The operation being created through the folder has an
54     // in-place folder, and it should be still present in the output.
55     // Furthermore, the folder should not crash when attempting to recover the
56     // (unchanged) opeation result.
57     OperationFolder folder(op->getContext());
58     Value result = folder.create<TestOpInPlaceFold>(
59         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
60         rewriter.getI32IntegerAttr(0));
61     assert(result);
62     rewriter.replaceOp(op, result);
63     return success();
64   }
65 };
66 
67 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
68   void runOnFunction() override {
69     mlir::OwningRewritePatternList patterns;
70     populateWithGenerated(&getContext(), &patterns);
71 
72     // Verify named pattern is generated with expected name.
73     patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
74 
75     applyPatternsAndFoldGreedily(getFunction(), patterns);
76   }
77 };
78 } // end anonymous namespace
79 
80 //===----------------------------------------------------------------------===//
81 // ReturnType Driver.
82 //===----------------------------------------------------------------------===//
83 
84 namespace {
85 // Generate ops for each instance where the type can be successfully inferred.
86 template <typename OpTy>
87 static void invokeCreateWithInferredReturnType(Operation *op) {
88   auto *context = op->getContext();
89   auto fop = op->getParentOfType<FuncOp>();
90   auto location = UnknownLoc::get(context);
91   OpBuilder b(op);
92   b.setInsertionPointAfter(op);
93 
94   // Use permutations of 2 args as operands.
95   assert(fop.getNumArguments() >= 2);
96   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
97     for (int j = 0; j < e; ++j) {
98       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
99       SmallVector<Type, 2> inferredReturnTypes;
100       if (succeeded(OpTy::inferReturnTypes(
101               context, llvm::None, values, op->getAttrDictionary(),
102               op->getRegions(), inferredReturnTypes))) {
103         OperationState state(location, OpTy::getOperationName());
104         // TODO(jpienaar): Expand to regions.
105         OpTy::build(b, state, values, op->getAttrs());
106         (void)b.createOperation(state);
107       }
108     }
109   }
110 }
111 
112 static void reifyReturnShape(Operation *op) {
113   OpBuilder b(op);
114 
115   // Use permutations of 2 args as operands.
116   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
117   SmallVector<Value, 2> shapes;
118   if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
119     return;
120   for (auto it : llvm::enumerate(shapes))
121     op->emitRemark() << "value " << it.index() << ": "
122                      << it.value().getDefiningOp();
123 }
124 
125 struct TestReturnTypeDriver
126     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
127   void runOnFunction() override {
128     if (getFunction().getName() == "testCreateFunctions") {
129       std::vector<Operation *> ops;
130       // Collect ops to avoid triggering on inserted ops.
131       for (auto &op : getFunction().getBody().front())
132         ops.push_back(&op);
133       // Generate test patterns for each, but skip terminator.
134       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
135         // Test create method of each of the Op classes below. The resultant
136         // output would be in reverse order underneath `op` from which
137         // the attributes and regions are used.
138         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
139         invokeCreateWithInferredReturnType<
140             OpWithShapedTypeInferTypeInterfaceOp>(op);
141       };
142       return;
143     }
144     if (getFunction().getName() == "testReifyFunctions") {
145       std::vector<Operation *> ops;
146       // Collect ops to avoid triggering on inserted ops.
147       for (auto &op : getFunction().getBody().front())
148         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
149           ops.push_back(&op);
150       // Generate test patterns for each, but skip terminator.
151       for (auto *op : ops)
152         reifyReturnShape(op);
153     }
154   }
155 };
156 } // end anonymous namespace
157 
158 namespace {
159 struct TestDerivedAttributeDriver
160     : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
161   void runOnFunction() override;
162 };
163 } // end anonymous namespace
164 
165 void TestDerivedAttributeDriver::runOnFunction() {
166   getFunction().walk([](DerivedAttributeOpInterface dOp) {
167     auto dAttr = dOp.materializeDerivedAttributes();
168     if (!dAttr)
169       return;
170     for (auto d : dAttr)
171       dOp.emitRemark() << d.first << " = " << d.second;
172   });
173 }
174 
175 //===----------------------------------------------------------------------===//
176 // Legalization Driver.
177 //===----------------------------------------------------------------------===//
178 
179 namespace {
180 //===----------------------------------------------------------------------===//
181 // Region-Block Rewrite Testing
182 
183 /// This pattern is a simple pattern that inlines the first region of a given
184 /// operation into the parent region.
185 struct TestRegionRewriteBlockMovement : public ConversionPattern {
186   TestRegionRewriteBlockMovement(MLIRContext *ctx)
187       : ConversionPattern("test.region", 1, ctx) {}
188 
189   LogicalResult
190   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
191                   ConversionPatternRewriter &rewriter) const final {
192     // Inline this region into the parent region.
193     auto &parentRegion = *op->getParentRegion();
194     if (op->getAttr("legalizer.should_clone"))
195       rewriter.cloneRegionBefore(op->getRegion(0), parentRegion,
196                                  parentRegion.end());
197     else
198       rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
199                                   parentRegion.end());
200 
201     // Drop this operation.
202     rewriter.eraseOp(op);
203     return success();
204   }
205 };
206 /// This pattern is a simple pattern that generates a region containing an
207 /// illegal operation.
208 struct TestRegionRewriteUndo : public RewritePattern {
209   TestRegionRewriteUndo(MLIRContext *ctx)
210       : RewritePattern("test.region_builder", 1, ctx) {}
211 
212   LogicalResult matchAndRewrite(Operation *op,
213                                 PatternRewriter &rewriter) const final {
214     // Create the region operation with an entry block containing arguments.
215     OperationState newRegion(op->getLoc(), "test.region");
216     newRegion.addRegion();
217     auto *regionOp = rewriter.createOperation(newRegion);
218     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
219     entryBlock->addArgument(rewriter.getIntegerType(64));
220 
221     // Add an explicitly illegal operation to ensure the conversion fails.
222     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
223     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
224 
225     // Drop this operation.
226     rewriter.eraseOp(op);
227     return success();
228   }
229 };
230 /// A simple pattern that creates a block at the end of the parent region of the
231 /// matched operation.
232 struct TestCreateBlock : public RewritePattern {
233   TestCreateBlock(MLIRContext *ctx)
234       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
235 
236   LogicalResult matchAndRewrite(Operation *op,
237                                 PatternRewriter &rewriter) const final {
238     Region &region = *op->getParentRegion();
239     Type i32Type = rewriter.getIntegerType(32);
240     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
241     rewriter.create<TerminatorOp>(op->getLoc());
242     rewriter.replaceOp(op, {});
243     return success();
244   }
245 };
246 
247 /// A simple pattern that creates a block containing an invalid operaiton in
248 /// order to trigger the block creation undo mechanism.
249 struct TestCreateIllegalBlock : public RewritePattern {
250   TestCreateIllegalBlock(MLIRContext *ctx)
251       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
252 
253   LogicalResult matchAndRewrite(Operation *op,
254                                 PatternRewriter &rewriter) const final {
255     Region &region = *op->getParentRegion();
256     Type i32Type = rewriter.getIntegerType(32);
257     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
258     // Create an illegal op to ensure the conversion fails.
259     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
260     rewriter.create<TerminatorOp>(op->getLoc());
261     rewriter.replaceOp(op, {});
262     return success();
263   }
264 };
265 
266 /// A simple pattern that tests the undo mechanism when replacing the uses of a
267 /// block argument.
268 struct TestUndoBlockArgReplace : public ConversionPattern {
269   TestUndoBlockArgReplace(MLIRContext *ctx)
270       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
271 
272   LogicalResult
273   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
274                   ConversionPatternRewriter &rewriter) const final {
275     auto illegalOp =
276         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
277     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
278                                         illegalOp);
279     rewriter.updateRootInPlace(op, [] {});
280     return success();
281   }
282 };
283 
284 /// A rewrite pattern that tests the undo mechanism when erasing a block.
285 struct TestUndoBlockErase : public ConversionPattern {
286   TestUndoBlockErase(MLIRContext *ctx)
287       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
288 
289   LogicalResult
290   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
291                   ConversionPatternRewriter &rewriter) const final {
292     Block *secondBlock = &*std::next(op->getRegion(0).begin());
293     rewriter.setInsertionPointToStart(secondBlock);
294     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
295     rewriter.eraseBlock(secondBlock);
296     rewriter.updateRootInPlace(op, [] {});
297     return success();
298   }
299 };
300 
301 //===----------------------------------------------------------------------===//
302 // Type-Conversion Rewrite Testing
303 
304 /// This patterns erases a region operation that has had a type conversion.
305 struct TestDropOpSignatureConversion : public ConversionPattern {
306   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
307       : ConversionPattern("test.drop_region_op", 1, converter, ctx) {}
308   LogicalResult
309   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
310                   ConversionPatternRewriter &rewriter) const override {
311     Region &region = op->getRegion(0);
312     Block *entry = &region.front();
313 
314     // Convert the original entry arguments.
315     TypeConverter &converter = *getTypeConverter();
316     TypeConverter::SignatureConversion result(entry->getNumArguments());
317     if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
318                                               result)) ||
319         failed(rewriter.convertRegionTypes(&region, converter, &result)))
320       return failure();
321 
322     // Convert the region signature and just drop the operation.
323     rewriter.eraseOp(op);
324     return success();
325   }
326 };
327 /// This pattern simply updates the operands of the given operation.
328 struct TestPassthroughInvalidOp : public ConversionPattern {
329   TestPassthroughInvalidOp(MLIRContext *ctx)
330       : ConversionPattern("test.invalid", 1, ctx) {}
331   LogicalResult
332   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
333                   ConversionPatternRewriter &rewriter) const final {
334     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
335                                              llvm::None);
336     return success();
337   }
338 };
339 /// This pattern handles the case of a split return value.
340 struct TestSplitReturnType : public ConversionPattern {
341   TestSplitReturnType(MLIRContext *ctx)
342       : ConversionPattern("test.return", 1, ctx) {}
343   LogicalResult
344   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
345                   ConversionPatternRewriter &rewriter) const final {
346     // Check for a return of F32.
347     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
348       return failure();
349 
350     // Check if the first operation is a cast operation, if it is we use the
351     // results directly.
352     auto *defOp = operands[0].getDefiningOp();
353     if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
354       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
355       return success();
356     }
357 
358     // Otherwise, fail to match.
359     return failure();
360   }
361 };
362 
363 //===----------------------------------------------------------------------===//
364 // Multi-Level Type-Conversion Rewrite Testing
365 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
366   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
367       : ConversionPattern("test.type_producer", 1, ctx) {}
368   LogicalResult
369   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
370                   ConversionPatternRewriter &rewriter) const final {
371     // If the type is I32, change the type to F32.
372     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
373       return failure();
374     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
375     return success();
376   }
377 };
378 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
379   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
380       : ConversionPattern("test.type_producer", 1, ctx) {}
381   LogicalResult
382   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
383                   ConversionPatternRewriter &rewriter) const final {
384     // If the type is F32, change the type to F64.
385     if (!Type(*op->result_type_begin()).isF32())
386       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
387     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
388     return success();
389   }
390 };
391 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
392   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
393       : ConversionPattern("test.type_producer", 10, ctx) {}
394   LogicalResult
395   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
396                   ConversionPatternRewriter &rewriter) const final {
397     // Always convert to B16, even though it is not a legal type. This tests
398     // that values are unmapped correctly.
399     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
400     return success();
401   }
402 };
403 struct TestUpdateConsumerType : public ConversionPattern {
404   TestUpdateConsumerType(MLIRContext *ctx)
405       : ConversionPattern("test.type_consumer", 1, ctx) {}
406   LogicalResult
407   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
408                   ConversionPatternRewriter &rewriter) const final {
409     // Verify that the incoming operand has been successfully remapped to F64.
410     if (!operands[0].getType().isF64())
411       return failure();
412     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
413     return success();
414   }
415 };
416 
417 //===----------------------------------------------------------------------===//
418 // Non-Root Replacement Rewrite Testing
419 /// This pattern generates an invalid operation, but replaces it before the
420 /// pattern is finished. This checks that we don't need to legalize the
421 /// temporary op.
422 struct TestNonRootReplacement : public RewritePattern {
423   TestNonRootReplacement(MLIRContext *ctx)
424       : RewritePattern("test.replace_non_root", 1, ctx) {}
425 
426   LogicalResult matchAndRewrite(Operation *op,
427                                 PatternRewriter &rewriter) const final {
428     auto resultType = *op->result_type_begin();
429     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
430     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
431 
432     rewriter.replaceOp(illegalOp, {legalOp});
433     rewriter.replaceOp(op, {illegalOp});
434     return success();
435   }
436 };
437 
438 //===----------------------------------------------------------------------===//
439 // Recursive Rewrite Testing
440 /// This pattern is applied to the same operation multiple times, but has a
441 /// bounded recursion.
442 struct TestBoundedRecursiveRewrite
443     : public OpRewritePattern<TestRecursiveRewriteOp> {
444   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
445 
446   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
447                                 PatternRewriter &rewriter) const final {
448     // Decrement the depth of the op in-place.
449     rewriter.updateRootInPlace(op, [&] {
450       op.setAttr("depth",
451                  rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1));
452     });
453     return success();
454   }
455 
456   /// The conversion target handles bounding the recursion of this pattern.
457   bool hasBoundedRewriteRecursion() const final { return true; }
458 };
459 
460 struct TestNestedOpCreationUndoRewrite
461     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
462   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
463 
464   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
465                                 PatternRewriter &rewriter) const final {
466     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
467     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
468     return success();
469   };
470 };
471 } // namespace
472 
473 namespace {
474 struct TestTypeConverter : public TypeConverter {
475   using TypeConverter::TypeConverter;
476   TestTypeConverter() {
477     addConversion(convertType);
478     addMaterialization(materializeCast);
479     addMaterialization(materializeOneToOneCast);
480   }
481 
482   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
483     // Drop I16 types.
484     if (t.isSignlessInteger(16))
485       return success();
486 
487     // Convert I64 to F64.
488     if (t.isSignlessInteger(64)) {
489       results.push_back(FloatType::getF64(t.getContext()));
490       return success();
491     }
492 
493     // Convert I42 to I43.
494     if (t.isInteger(42)) {
495       results.push_back(IntegerType::get(43, t.getContext()));
496       return success();
497     }
498 
499     // Split F32 into F16,F16.
500     if (t.isF32()) {
501       results.assign(2, FloatType::getF16(t.getContext()));
502       return success();
503     }
504 
505     // Otherwise, convert the type directly.
506     results.push_back(t);
507     return success();
508   }
509 
510   /// Hook for materializing a conversion. This is necessary because we generate
511   /// 1->N type mappings.
512   static Optional<Value> materializeCast(PatternRewriter &rewriter,
513                                          Type resultType, ValueRange inputs,
514                                          Location loc) {
515     if (inputs.size() == 1)
516       return inputs[0];
517     return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
518   }
519 
520   /// Materialize the cast for one-to-one conversion from i64 to f64.
521   static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
522                                                  IntegerType resultType,
523                                                  ValueRange inputs,
524                                                  Location loc) {
525     if (resultType.getWidth() == 42 && inputs.size() == 1)
526       return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
527     return llvm::None;
528   }
529 };
530 
531 struct TestLegalizePatternDriver
532     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
533   /// The mode of conversion to use with the driver.
534   enum class ConversionMode { Analysis, Full, Partial };
535 
536   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
537 
538   void runOnOperation() override {
539     TestTypeConverter converter;
540     mlir::OwningRewritePatternList patterns;
541     populateWithGenerated(&getContext(), &patterns);
542     patterns.insert<
543         TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
544         TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
545         TestPassthroughInvalidOp, TestSplitReturnType,
546         TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
547         TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
548         TestNonRootReplacement, TestBoundedRecursiveRewrite,
549         TestNestedOpCreationUndoRewrite>(&getContext());
550     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
551     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
552                                               converter);
553     mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
554                                               converter);
555 
556     // Define the conversion target used for the test.
557     ConversionTarget target(getContext());
558     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
559     target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
560                       TerminatorOp>();
561     target
562         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
563     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
564       // Don't allow F32 operands.
565       return llvm::none_of(op.getOperandTypes(),
566                            [](Type type) { return type.isF32(); });
567     });
568     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
569       return converter.isSignatureLegal(op.getType()) &&
570              converter.isLegal(&op.getBody());
571     });
572 
573     // Expect the type_producer/type_consumer operations to only operate on f64.
574     target.addDynamicallyLegalOp<TestTypeProducerOp>(
575         [](TestTypeProducerOp op) { return op.getType().isF64(); });
576     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
577       return op.getOperand().getType().isF64();
578     });
579 
580     // Check support for marking certain operations as recursively legal.
581     target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
582       return static_cast<bool>(
583           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
584     });
585 
586     // Mark the bound recursion operation as dynamically legal.
587     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
588         [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
589 
590     // Handle a partial conversion.
591     if (mode == ConversionMode::Partial) {
592       DenseSet<Operation *> unlegalizedOps;
593       (void)applyPartialConversion(getOperation(), target, patterns,
594                                    &unlegalizedOps);
595       // Emit remarks for each legalizable operation.
596       for (auto *op : unlegalizedOps)
597         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
598       return;
599     }
600 
601     // Handle a full conversion.
602     if (mode == ConversionMode::Full) {
603       // Check support for marking unknown operations as dynamically legal.
604       target.markUnknownOpDynamicallyLegal([](Operation *op) {
605         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
606       });
607 
608       (void)applyFullConversion(getOperation(), target, patterns);
609       return;
610     }
611 
612     // Otherwise, handle an analysis conversion.
613     assert(mode == ConversionMode::Analysis);
614 
615     // Analyze the convertible operations.
616     DenseSet<Operation *> legalizedOps;
617     if (failed(applyAnalysisConversion(getOperation(), target, patterns,
618                                        legalizedOps)))
619       return signalPassFailure();
620 
621     // Emit remarks for each legalizable operation.
622     for (auto *op : legalizedOps)
623       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
624   }
625 
626   /// The mode of conversion to use.
627   ConversionMode mode;
628 };
629 } // end anonymous namespace
630 
631 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
632     legalizerConversionMode(
633         "test-legalize-mode",
634         llvm::cl::desc("The legalization mode to use with the test driver"),
635         llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
636         llvm::cl::values(
637             clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
638                        "analysis", "Perform an analysis conversion"),
639             clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
640                        "Perform a full conversion"),
641             clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
642                        "partial", "Perform a partial conversion")));
643 
644 //===----------------------------------------------------------------------===//
645 // ConversionPatternRewriter::getRemappedValue testing. This method is used
646 // to get the remapped value of an original value that was replaced using
647 // ConversionPatternRewriter.
648 namespace {
649 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
650 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
651 /// operand twice.
652 ///
653 /// Example:
654 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
655 /// is replaced with:
656 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
657 struct OneVResOneVOperandOp1Converter
658     : public OpConversionPattern<OneVResOneVOperandOp1> {
659   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
660 
661   LogicalResult
662   matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
663                   ConversionPatternRewriter &rewriter) const override {
664     auto origOps = op.getOperands();
665     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
666            "One operand expected");
667     Value origOp = *origOps.begin();
668     SmallVector<Value, 2> remappedOperands;
669     // Replicate the remapped original operand twice. Note that we don't used
670     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
671     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
672     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
673 
674     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
675                                                        remappedOperands);
676     return success();
677   }
678 };
679 
680 struct TestRemappedValue
681     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
682   void runOnFunction() override {
683     mlir::OwningRewritePatternList patterns;
684     patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
685 
686     mlir::ConversionTarget target(getContext());
687     target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
688     // We make OneVResOneVOperandOp1 legal only when it has more that one
689     // operand. This will trigger the conversion that will replace one-operand
690     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
691     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
692         [](Operation *op) -> bool {
693           return std::distance(op->operand_begin(), op->operand_end()) > 1;
694         });
695 
696     if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
697       signalPassFailure();
698     }
699   }
700 };
701 } // end anonymous namespace
702 
703 //===----------------------------------------------------------------------===//
704 // Test patterns without a specific root operation kind
705 //===----------------------------------------------------------------------===//
706 
707 namespace {
708 /// This pattern matches and removes any operation in the test dialect.
709 struct RemoveTestDialectOps : public RewritePattern {
710   RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
711 
712   LogicalResult matchAndRewrite(Operation *op,
713                                 PatternRewriter &rewriter) const override {
714     if (!isa<TestDialect>(op->getDialect()))
715       return failure();
716     rewriter.eraseOp(op);
717     return success();
718   }
719 };
720 
721 struct TestUnknownRootOpDriver
722     : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
723   void runOnFunction() override {
724     mlir::OwningRewritePatternList patterns;
725     patterns.insert<RemoveTestDialectOps>();
726 
727     mlir::ConversionTarget target(getContext());
728     target.addIllegalDialect<TestDialect>();
729     if (failed(applyPartialConversion(getFunction(), target, patterns)))
730       signalPassFailure();
731   }
732 };
733 } // end anonymous namespace
734 
735 namespace mlir {
736 void registerPatternsTestPass() {
737   PassRegistration<TestReturnTypeDriver>("test-return-type",
738                                          "Run return type functions");
739 
740   PassRegistration<TestDerivedAttributeDriver>("test-derived-attr",
741                                                "Run test derived attributes");
742 
743   PassRegistration<TestPatternDriver>("test-patterns",
744                                       "Run test dialect patterns");
745 
746   PassRegistration<TestLegalizePatternDriver>(
747       "test-legalize-patterns", "Run test dialect legalization patterns", [] {
748         return std::make_unique<TestLegalizePatternDriver>(
749             legalizerConversionMode);
750       });
751 
752   PassRegistration<TestRemappedValue>(
753       "test-remapped-value",
754       "Test public remapped value mechanism in ConversionPatternRewriter");
755 
756   PassRegistration<TestUnknownRootOpDriver>(
757       "test-legalize-unknown-root-patterns",
758       "Test public remapped value mechanism in ConversionPatternRewriter");
759 }
760 } // namespace mlir
761