1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 
17 using namespace mlir;
18 
19 // Native function for testing NativeCodeCall
20 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
21   return choice.getValue() ? input1 : input2;
22 }
23 
24 static void createOpI(PatternRewriter &rewriter, Value input) {
25   rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
26 }
27 
28 static void handleNoResultOp(PatternRewriter &rewriter,
29                              OpSymbolBindingNoResult op) {
30   // Turn the no result op to a one-result op.
31   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
32                                     op.operand());
33 }
34 
namespace {
// Patterns generated from the test dialect's declarative rewrite rules. The
// inclusion is wrapped in an anonymous namespace to keep the generated code
// file-local; it presumably also provides `populateWithGenerated` and refers
// to the native helpers defined above.
#include "TestPatterns.inc"
} // end anonymous namespace
38 
39 //===----------------------------------------------------------------------===//
40 // Canonicalizer Driver.
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 struct FoldingPattern : public RewritePattern {
45 public:
46   FoldingPattern(MLIRContext *context)
47       : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
48                        /*benefit=*/1, context) {}
49 
50   LogicalResult matchAndRewrite(Operation *op,
51                                 PatternRewriter &rewriter) const override {
52     // Exercice OperationFolder API for a single-result operation that is folded
53     // upon construction. The operation being created through the folder has an
54     // in-place folder, and it should be still present in the output.
55     // Furthermore, the folder should not crash when attempting to recover the
56     // (unchanged) opeation result.
57     OperationFolder folder(op->getContext());
58     Value result = folder.create<TestOpInPlaceFold>(
59         rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
60         rewriter.getI32IntegerAttr(0));
61     assert(result);
62     rewriter.replaceOp(op, result);
63     return success();
64   }
65 };
66 
67 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
68   void runOnFunction() override {
69     mlir::OwningRewritePatternList patterns;
70     populateWithGenerated(&getContext(), &patterns);
71 
72     // Verify named pattern is generated with expected name.
73     patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
74 
75     applyPatternsAndFoldGreedily(getFunction(), patterns);
76   }
77 };
78 } // end anonymous namespace
79 
80 //===----------------------------------------------------------------------===//
81 // ReturnType Driver.
82 //===----------------------------------------------------------------------===//
83 
84 namespace {
85 // Generate ops for each instance where the type can be successfully inferred.
86 template <typename OpTy>
87 static void invokeCreateWithInferredReturnType(Operation *op) {
88   auto *context = op->getContext();
89   auto fop = op->getParentOfType<FuncOp>();
90   auto location = UnknownLoc::get(context);
91   OpBuilder b(op);
92   b.setInsertionPointAfter(op);
93 
94   // Use permutations of 2 args as operands.
95   assert(fop.getNumArguments() >= 2);
96   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
97     for (int j = 0; j < e; ++j) {
98       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
99       SmallVector<Type, 2> inferredReturnTypes;
100       if (succeeded(OpTy::inferReturnTypes(
101               context, llvm::None, values, op->getAttrDictionary(),
102               op->getRegions(), inferredReturnTypes))) {
103         OperationState state(location, OpTy::getOperationName());
104         // TODO(jpienaar): Expand to regions.
105         OpTy::build(b, state, values, op->getAttrs());
106         (void)b.createOperation(state);
107       }
108     }
109   }
110 }
111 
112 static void reifyReturnShape(Operation *op) {
113   OpBuilder b(op);
114 
115   // Use permutations of 2 args as operands.
116   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
117   SmallVector<Value, 2> shapes;
118   if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
119     return;
120   for (auto it : llvm::enumerate(shapes))
121     op->emitRemark() << "value " << it.index() << ": "
122                      << it.value().getDefiningOp();
123 }
124 
125 struct TestReturnTypeDriver
126     : public PassWrapper<TestReturnTypeDriver, FunctionPass> {
127   void runOnFunction() override {
128     if (getFunction().getName() == "testCreateFunctions") {
129       std::vector<Operation *> ops;
130       // Collect ops to avoid triggering on inserted ops.
131       for (auto &op : getFunction().getBody().front())
132         ops.push_back(&op);
133       // Generate test patterns for each, but skip terminator.
134       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
135         // Test create method of each of the Op classes below. The resultant
136         // output would be in reverse order underneath `op` from which
137         // the attributes and regions are used.
138         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
139         invokeCreateWithInferredReturnType<
140             OpWithShapedTypeInferTypeInterfaceOp>(op);
141       };
142       return;
143     }
144     if (getFunction().getName() == "testReifyFunctions") {
145       std::vector<Operation *> ops;
146       // Collect ops to avoid triggering on inserted ops.
147       for (auto &op : getFunction().getBody().front())
148         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
149           ops.push_back(&op);
150       // Generate test patterns for each, but skip terminator.
151       for (auto *op : ops)
152         reifyReturnShape(op);
153     }
154   }
155 };
156 } // end anonymous namespace
157 
158 namespace {
159 struct TestDerivedAttributeDriver
160     : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
161   void runOnFunction() override;
162 };
163 } // end anonymous namespace
164 
165 void TestDerivedAttributeDriver::runOnFunction() {
166   getFunction().walk([](DerivedAttributeOpInterface dOp) {
167     auto dAttr = dOp.materializeDerivedAttributes();
168     if (!dAttr)
169       return;
170     for (auto d : dAttr)
171       dOp.emitRemark() << d.first << " = " << d.second;
172   });
173 }
174 
175 //===----------------------------------------------------------------------===//
176 // Legalization Driver.
177 //===----------------------------------------------------------------------===//
178 
179 namespace {
180 //===----------------------------------------------------------------------===//
181 // Region-Block Rewrite Testing
182 
183 /// This pattern is a simple pattern that inlines the first region of a given
184 /// operation into the parent region.
185 struct TestRegionRewriteBlockMovement : public ConversionPattern {
186   TestRegionRewriteBlockMovement(MLIRContext *ctx)
187       : ConversionPattern("test.region", 1, ctx) {}
188 
189   LogicalResult
190   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
191                   ConversionPatternRewriter &rewriter) const final {
192     // Inline this region into the parent region.
193     auto &parentRegion = *op->getParentRegion();
194     if (op->getAttr("legalizer.should_clone"))
195       rewriter.cloneRegionBefore(op->getRegion(0), parentRegion,
196                                  parentRegion.end());
197     else
198       rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
199                                   parentRegion.end());
200 
201     // Drop this operation.
202     rewriter.eraseOp(op);
203     return success();
204   }
205 };
206 /// This pattern is a simple pattern that generates a region containing an
207 /// illegal operation.
208 struct TestRegionRewriteUndo : public RewritePattern {
209   TestRegionRewriteUndo(MLIRContext *ctx)
210       : RewritePattern("test.region_builder", 1, ctx) {}
211 
212   LogicalResult matchAndRewrite(Operation *op,
213                                 PatternRewriter &rewriter) const final {
214     // Create the region operation with an entry block containing arguments.
215     OperationState newRegion(op->getLoc(), "test.region");
216     newRegion.addRegion();
217     auto *regionOp = rewriter.createOperation(newRegion);
218     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
219     entryBlock->addArgument(rewriter.getIntegerType(64));
220 
221     // Add an explicitly illegal operation to ensure the conversion fails.
222     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
223     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
224 
225     // Drop this operation.
226     rewriter.eraseOp(op);
227     return success();
228   }
229 };
230 /// A simple pattern that creates a block at the end of the parent region of the
231 /// matched operation.
232 struct TestCreateBlock : public RewritePattern {
233   TestCreateBlock(MLIRContext *ctx)
234       : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
235 
236   LogicalResult matchAndRewrite(Operation *op,
237                                 PatternRewriter &rewriter) const final {
238     Region &region = *op->getParentRegion();
239     Type i32Type = rewriter.getIntegerType(32);
240     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
241     rewriter.create<TerminatorOp>(op->getLoc());
242     rewriter.replaceOp(op, {});
243     return success();
244   }
245 };
246 
247 /// A simple pattern that creates a block containing an invalid operaiton in
248 /// order to trigger the block creation undo mechanism.
249 struct TestCreateIllegalBlock : public RewritePattern {
250   TestCreateIllegalBlock(MLIRContext *ctx)
251       : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
252 
253   LogicalResult matchAndRewrite(Operation *op,
254                                 PatternRewriter &rewriter) const final {
255     Region &region = *op->getParentRegion();
256     Type i32Type = rewriter.getIntegerType(32);
257     rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
258     // Create an illegal op to ensure the conversion fails.
259     rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
260     rewriter.create<TerminatorOp>(op->getLoc());
261     rewriter.replaceOp(op, {});
262     return success();
263   }
264 };
265 
266 /// A simple pattern that tests the undo mechanism when replacing the uses of a
267 /// block argument.
268 struct TestUndoBlockArgReplace : public ConversionPattern {
269   TestUndoBlockArgReplace(MLIRContext *ctx)
270       : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
271 
272   LogicalResult
273   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
274                   ConversionPatternRewriter &rewriter) const final {
275     auto illegalOp =
276         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
277     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
278                                         illegalOp);
279     rewriter.updateRootInPlace(op, [] {});
280     return success();
281   }
282 };
283 
284 /// A rewrite pattern that tests the undo mechanism when erasing a block.
285 struct TestUndoBlockErase : public ConversionPattern {
286   TestUndoBlockErase(MLIRContext *ctx)
287       : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
288 
289   LogicalResult
290   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
291                   ConversionPatternRewriter &rewriter) const final {
292     Block *secondBlock = &*std::next(op->getRegion(0).begin());
293     rewriter.setInsertionPointToStart(secondBlock);
294     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
295     rewriter.eraseBlock(secondBlock);
296     rewriter.updateRootInPlace(op, [] {});
297     return success();
298   }
299 };
300 
301 //===----------------------------------------------------------------------===//
302 // Type-Conversion Rewrite Testing
303 
304 /// This patterns erases a region operation that has had a type conversion.
305 struct TestDropOpSignatureConversion : public ConversionPattern {
306   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
307       : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
308   }
309   LogicalResult
310   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
311                   ConversionPatternRewriter &rewriter) const override {
312     Region &region = op->getRegion(0);
313     Block *entry = &region.front();
314 
315     // Convert the original entry arguments.
316     TypeConverter::SignatureConversion result(entry->getNumArguments());
317     for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
318       if (failed(converter.convertSignatureArg(
319               i, entry->getArgument(i).getType(), result)))
320         return failure();
321 
322     // Convert the region signature and just drop the operation.
323     rewriter.applySignatureConversion(&region, result);
324     rewriter.eraseOp(op);
325     return success();
326   }
327 
328   /// The type converter to use when rewriting the signature.
329   TypeConverter &converter;
330 };
331 /// This pattern simply updates the operands of the given operation.
332 struct TestPassthroughInvalidOp : public ConversionPattern {
333   TestPassthroughInvalidOp(MLIRContext *ctx)
334       : ConversionPattern("test.invalid", 1, ctx) {}
335   LogicalResult
336   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
337                   ConversionPatternRewriter &rewriter) const final {
338     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
339                                              llvm::None);
340     return success();
341   }
342 };
343 /// This pattern handles the case of a split return value.
344 struct TestSplitReturnType : public ConversionPattern {
345   TestSplitReturnType(MLIRContext *ctx)
346       : ConversionPattern("test.return", 1, ctx) {}
347   LogicalResult
348   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
349                   ConversionPatternRewriter &rewriter) const final {
350     // Check for a return of F32.
351     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
352       return failure();
353 
354     // Check if the first operation is a cast operation, if it is we use the
355     // results directly.
356     auto *defOp = operands[0].getDefiningOp();
357     if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
358       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
359       return success();
360     }
361 
362     // Otherwise, fail to match.
363     return failure();
364   }
365 };
366 
367 //===----------------------------------------------------------------------===//
368 // Multi-Level Type-Conversion Rewrite Testing
369 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
370   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
371       : ConversionPattern("test.type_producer", 1, ctx) {}
372   LogicalResult
373   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
374                   ConversionPatternRewriter &rewriter) const final {
375     // If the type is I32, change the type to F32.
376     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
377       return failure();
378     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
379     return success();
380   }
381 };
382 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
383   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
384       : ConversionPattern("test.type_producer", 1, ctx) {}
385   LogicalResult
386   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
387                   ConversionPatternRewriter &rewriter) const final {
388     // If the type is F32, change the type to F64.
389     if (!Type(*op->result_type_begin()).isF32())
390       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
391     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
392     return success();
393   }
394 };
395 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
396   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
397       : ConversionPattern("test.type_producer", 10, ctx) {}
398   LogicalResult
399   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
400                   ConversionPatternRewriter &rewriter) const final {
401     // Always convert to B16, even though it is not a legal type. This tests
402     // that values are unmapped correctly.
403     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
404     return success();
405   }
406 };
407 struct TestUpdateConsumerType : public ConversionPattern {
408   TestUpdateConsumerType(MLIRContext *ctx)
409       : ConversionPattern("test.type_consumer", 1, ctx) {}
410   LogicalResult
411   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
412                   ConversionPatternRewriter &rewriter) const final {
413     // Verify that the incoming operand has been successfully remapped to F64.
414     if (!operands[0].getType().isF64())
415       return failure();
416     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
417     return success();
418   }
419 };
420 
421 //===----------------------------------------------------------------------===//
422 // Non-Root Replacement Rewrite Testing
423 /// This pattern generates an invalid operation, but replaces it before the
424 /// pattern is finished. This checks that we don't need to legalize the
425 /// temporary op.
426 struct TestNonRootReplacement : public RewritePattern {
427   TestNonRootReplacement(MLIRContext *ctx)
428       : RewritePattern("test.replace_non_root", 1, ctx) {}
429 
430   LogicalResult matchAndRewrite(Operation *op,
431                                 PatternRewriter &rewriter) const final {
432     auto resultType = *op->result_type_begin();
433     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
434     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
435 
436     rewriter.replaceOp(illegalOp, {legalOp});
437     rewriter.replaceOp(op, {illegalOp});
438     return success();
439   }
440 };
441 
442 //===----------------------------------------------------------------------===//
443 // Recursive Rewrite Testing
444 /// This pattern is applied to the same operation multiple times, but has a
445 /// bounded recursion.
446 struct TestBoundedRecursiveRewrite
447     : public OpRewritePattern<TestRecursiveRewriteOp> {
448   using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
449 
450   LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
451                                 PatternRewriter &rewriter) const final {
452     // Decrement the depth of the op in-place.
453     rewriter.updateRootInPlace(op, [&] {
454       op.setAttr("depth",
455                  rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1));
456     });
457     return success();
458   }
459 
460   /// The conversion target handles bounding the recursion of this pattern.
461   bool hasBoundedRewriteRecursion() const final { return true; }
462 };
463 
464 struct TestNestedOpCreationUndoRewrite
465     : public OpRewritePattern<IllegalOpWithRegionAnchor> {
466   using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
467 
468   LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
469                                 PatternRewriter &rewriter) const final {
470     // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
471     rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
472     return success();
473   };
474 };
475 } // namespace
476 
477 namespace {
478 struct TestTypeConverter : public TypeConverter {
479   using TypeConverter::TypeConverter;
480   TestTypeConverter() { addConversion(convertType); }
481 
482   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
483     // Drop I16 types.
484     if (t.isSignlessInteger(16))
485       return success();
486 
487     // Convert I64 to F64.
488     if (t.isSignlessInteger(64)) {
489       results.push_back(FloatType::getF64(t.getContext()));
490       return success();
491     }
492 
493     // Split F32 into F16,F16.
494     if (t.isF32()) {
495       results.assign(2, FloatType::getF16(t.getContext()));
496       return success();
497     }
498 
499     // Otherwise, convert the type directly.
500     results.push_back(t);
501     return success();
502   }
503 
504   /// Override the hook to materialize a conversion. This is necessary because
505   /// we generate 1->N type mappings.
506   Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
507                                    ArrayRef<Value> inputs,
508                                    Location loc) override {
509     return rewriter.create<TestCastOp>(loc, resultType, inputs);
510   }
511 };
512 
513 struct TestLegalizePatternDriver
514     : public PassWrapper<TestLegalizePatternDriver, OperationPass<ModuleOp>> {
515   /// The mode of conversion to use with the driver.
516   enum class ConversionMode { Analysis, Full, Partial };
517 
518   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
519 
520   void runOnOperation() override {
521     TestTypeConverter converter;
522     mlir::OwningRewritePatternList patterns;
523     populateWithGenerated(&getContext(), &patterns);
524     patterns.insert<
525         TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
526         TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
527         TestPassthroughInvalidOp, TestSplitReturnType,
528         TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
529         TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
530         TestNonRootReplacement, TestBoundedRecursiveRewrite,
531         TestNestedOpCreationUndoRewrite>(&getContext());
532     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
533     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
534                                               converter);
535     mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
536                                               converter);
537 
538     // Define the conversion target used for the test.
539     ConversionTarget target(getContext());
540     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
541     target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
542                       TerminatorOp>();
543     target
544         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
545     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
546       // Don't allow F32 operands.
547       return llvm::none_of(op.getOperandTypes(),
548                            [](Type type) { return type.isF32(); });
549     });
550     target.addDynamicallyLegalOp<FuncOp>(
551         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
552 
553     // Expect the type_producer/type_consumer operations to only operate on f64.
554     target.addDynamicallyLegalOp<TestTypeProducerOp>(
555         [](TestTypeProducerOp op) { return op.getType().isF64(); });
556     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
557       return op.getOperand().getType().isF64();
558     });
559 
560     // Check support for marking certain operations as recursively legal.
561     target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
562       return static_cast<bool>(
563           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
564     });
565 
566     // Mark the bound recursion operation as dynamically legal.
567     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
568         [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
569 
570     // Handle a partial conversion.
571     if (mode == ConversionMode::Partial) {
572       DenseSet<Operation *> unlegalizedOps;
573       (void)applyPartialConversion(getOperation(), target, patterns, &converter,
574                                    &unlegalizedOps);
575       // Emit remarks for each legalizable operation.
576       for (auto *op : unlegalizedOps)
577         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
578       return;
579     }
580 
581     // Handle a full conversion.
582     if (mode == ConversionMode::Full) {
583       // Check support for marking unknown operations as dynamically legal.
584       target.markUnknownOpDynamicallyLegal([](Operation *op) {
585         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
586       });
587 
588       (void)applyFullConversion(getOperation(), target, patterns, &converter);
589       return;
590     }
591 
592     // Otherwise, handle an analysis conversion.
593     assert(mode == ConversionMode::Analysis);
594 
595     // Analyze the convertible operations.
596     DenseSet<Operation *> legalizedOps;
597     if (failed(applyAnalysisConversion(getOperation(), target, patterns,
598                                        legalizedOps, &converter)))
599       return signalPassFailure();
600 
601     // Emit remarks for each legalizable operation.
602     for (auto *op : legalizedOps)
603       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
604   }
605 
606   /// The mode of conversion to use.
607   ConversionMode mode;
608 };
609 } // end anonymous namespace
610 
/// Command-line option `-test-legalize-mode` selecting which conversion flavor
/// TestLegalizePatternDriver runs: `analysis`, `full`, or `partial` (the
/// default set by `llvm::cl::init`).
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
623 
624 //===----------------------------------------------------------------------===//
625 // ConversionPatternRewriter::getRemappedValue testing. This method is used
626 // to get the remapped value of an original value that was replaced using
627 // ConversionPatternRewriter.
628 namespace {
629 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
630 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
631 /// operand twice.
632 ///
633 /// Example:
634 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
635 /// is replaced with:
636 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
637 struct OneVResOneVOperandOp1Converter
638     : public OpConversionPattern<OneVResOneVOperandOp1> {
639   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
640 
641   LogicalResult
642   matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
643                   ConversionPatternRewriter &rewriter) const override {
644     auto origOps = op.getOperands();
645     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
646            "One operand expected");
647     Value origOp = *origOps.begin();
648     SmallVector<Value, 2> remappedOperands;
649     // Replicate the remapped original operand twice. Note that we don't used
650     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
651     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
652     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
653 
654     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
655                                                        remappedOperands);
656     return success();
657   }
658 };
659 
660 struct TestRemappedValue
661     : public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
662   void runOnFunction() override {
663     mlir::OwningRewritePatternList patterns;
664     patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
665 
666     mlir::ConversionTarget target(getContext());
667     target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
668     // We make OneVResOneVOperandOp1 legal only when it has more that one
669     // operand. This will trigger the conversion that will replace one-operand
670     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
671     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
672         [](Operation *op) -> bool {
673           return std::distance(op->operand_begin(), op->operand_end()) > 1;
674         });
675 
676     if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
677       signalPassFailure();
678     }
679   }
680 };
681 } // end anonymous namespace
682 
683 namespace mlir {
684 void registerPatternsTestPass() {
685   mlir::PassRegistration<TestReturnTypeDriver>("test-return-type",
686                                                "Run return type functions");
687 
688   mlir::PassRegistration<TestDerivedAttributeDriver>(
689       "test-derived-attr", "Run test derived attributes");
690 
691   mlir::PassRegistration<TestPatternDriver>("test-patterns",
692                                             "Run test dialect patterns");
693 
694   mlir::PassRegistration<TestLegalizePatternDriver>(
695       "test-legalize-patterns", "Run test dialect legalization patterns", [] {
696         return std::make_unique<TestLegalizePatternDriver>(
697             legalizerConversionMode);
698       });
699 
700   PassRegistration<TestRemappedValue>(
701       "test-remapped-value",
702       "Test public remapped value mechanism in ConversionPatternRewriter");
703 }
704 } // namespace mlir
705