1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Conversion/StandardToStandard/StandardToStandard.h"
11 #include "mlir/IR/PatternMatch.h"
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Transforms/DialectConversion.h"
14 using namespace mlir;
15 
16 // Native function for testing NativeCodeCall
17 static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
18   return choice.getValue() ? input1 : input2;
19 }
20 
21 static void createOpI(PatternRewriter &rewriter, Value input) {
22   rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
23 }
24 
25 static void handleNoResultOp(PatternRewriter &rewriter,
26                              OpSymbolBindingNoResult op) {
27   // Turn the no result op to a one-result op.
28   rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
29                                     op.operand());
30 }
31 
32 namespace {
33 #include "TestPatterns.inc"
34 } // end anonymous namespace
35 
36 //===----------------------------------------------------------------------===//
37 // Canonicalizer Driver.
38 //===----------------------------------------------------------------------===//
39 
40 namespace {
41 struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
42   void runOnFunction() override {
43     mlir::OwningRewritePatternList patterns;
44     populateWithGenerated(&getContext(), &patterns);
45 
46     // Verify named pattern is generated with expected name.
47     patterns.insert<TestNamedPatternRule>(&getContext());
48 
49     applyPatternsGreedily(getFunction(), patterns);
50   }
51 };
52 } // end anonymous namespace
53 
54 //===----------------------------------------------------------------------===//
55 // ReturnType Driver.
56 //===----------------------------------------------------------------------===//
57 
58 namespace {
59 // Generate ops for each instance where the type can be successfully inferred.
60 template <typename OpTy>
61 static void invokeCreateWithInferredReturnType(Operation *op) {
62   auto *context = op->getContext();
63   auto fop = op->getParentOfType<FuncOp>();
64   auto location = UnknownLoc::get(context);
65   OpBuilder b(op);
66   b.setInsertionPointAfter(op);
67 
68   // Use permutations of 2 args as operands.
69   assert(fop.getNumArguments() >= 2);
70   for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
71     for (int j = 0; j < e; ++j) {
72       std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
73       SmallVector<Type, 2> inferredReturnTypes;
74       if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values,
75                                            op->getAttrs(), op->getRegions(),
76                                            inferredReturnTypes))) {
77         OperationState state(location, OpTy::getOperationName());
78         // TODO(jpienaar): Expand to regions.
79         OpTy::build(&b, state, values, op->getAttrs());
80         (void)b.createOperation(state);
81       }
82     }
83   }
84 }
85 
86 static void reifyReturnShape(Operation *op) {
87   OpBuilder b(op);
88 
89   // Use permutations of 2 args as operands.
90   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
91   SmallVector<Value, 2> shapes;
92   if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
93     return;
94   for (auto it : llvm::enumerate(shapes))
95     op->emitRemark() << "value " << it.index() << ": "
96                      << it.value().getDefiningOp();
97 }
98 
99 struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
100   void runOnFunction() override {
101     if (getFunction().getName() == "testCreateFunctions") {
102       std::vector<Operation *> ops;
103       // Collect ops to avoid triggering on inserted ops.
104       for (auto &op : getFunction().getBody().front())
105         ops.push_back(&op);
106       // Generate test patterns for each, but skip terminator.
107       for (auto *op : llvm::makeArrayRef(ops).drop_back()) {
108         // Test create method of each of the Op classes below. The resultant
109         // output would be in reverse order underneath `op` from which
110         // the attributes and regions are used.
111         invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
112         invokeCreateWithInferredReturnType<
113             OpWithShapedTypeInferTypeInterfaceOp>(op);
114       };
115       return;
116     }
117     if (getFunction().getName() == "testReifyFunctions") {
118       std::vector<Operation *> ops;
119       // Collect ops to avoid triggering on inserted ops.
120       for (auto &op : getFunction().getBody().front())
121         if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
122           ops.push_back(&op);
123       // Generate test patterns for each, but skip terminator.
124       for (auto *op : ops)
125         reifyReturnShape(op);
126     }
127   }
128 };
129 } // end anonymous namespace
130 
131 //===----------------------------------------------------------------------===//
132 // Legalization Driver.
133 //===----------------------------------------------------------------------===//
134 
135 namespace {
136 //===----------------------------------------------------------------------===//
137 // Region-Block Rewrite Testing
138 
139 /// This pattern is a simple pattern that inlines the first region of a given
140 /// operation into the parent region.
141 struct TestRegionRewriteBlockMovement : public ConversionPattern {
142   TestRegionRewriteBlockMovement(MLIRContext *ctx)
143       : ConversionPattern("test.region", 1, ctx) {}
144 
145   LogicalResult
146   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
147                   ConversionPatternRewriter &rewriter) const final {
148     // Inline this region into the parent region.
149     auto &parentRegion = *op->getParentRegion();
150     if (op->getAttr("legalizer.should_clone"))
151       rewriter.cloneRegionBefore(op->getRegion(0), parentRegion,
152                                  parentRegion.end());
153     else
154       rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
155                                   parentRegion.end());
156 
157     // Drop this operation.
158     rewriter.eraseOp(op);
159     return success();
160   }
161 };
162 /// This pattern is a simple pattern that generates a region containing an
163 /// illegal operation.
164 struct TestRegionRewriteUndo : public RewritePattern {
165   TestRegionRewriteUndo(MLIRContext *ctx)
166       : RewritePattern("test.region_builder", 1, ctx) {}
167 
168   LogicalResult matchAndRewrite(Operation *op,
169                                 PatternRewriter &rewriter) const final {
170     // Create the region operation with an entry block containing arguments.
171     OperationState newRegion(op->getLoc(), "test.region");
172     newRegion.addRegion();
173     auto *regionOp = rewriter.createOperation(newRegion);
174     auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
175     entryBlock->addArgument(rewriter.getIntegerType(64));
176 
177     // Add an explicitly illegal operation to ensure the conversion fails.
178     rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
179     rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
180 
181     // Drop this operation.
182     rewriter.eraseOp(op);
183     return success();
184   }
185 };
186 
187 //===----------------------------------------------------------------------===//
188 // Type-Conversion Rewrite Testing
189 
190 /// This patterns erases a region operation that has had a type conversion.
191 struct TestDropOpSignatureConversion : public ConversionPattern {
192   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
193       : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
194   }
195   LogicalResult
196   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
197                   ConversionPatternRewriter &rewriter) const override {
198     Region &region = op->getRegion(0);
199     Block *entry = &region.front();
200 
201     // Convert the original entry arguments.
202     TypeConverter::SignatureConversion result(entry->getNumArguments());
203     for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
204       if (failed(converter.convertSignatureArg(
205               i, entry->getArgument(i).getType(), result)))
206         return failure();
207 
208     // Convert the region signature and just drop the operation.
209     rewriter.applySignatureConversion(&region, result);
210     rewriter.eraseOp(op);
211     return success();
212   }
213 
214   /// The type converter to use when rewriting the signature.
215   TypeConverter &converter;
216 };
217 /// This pattern simply updates the operands of the given operation.
218 struct TestPassthroughInvalidOp : public ConversionPattern {
219   TestPassthroughInvalidOp(MLIRContext *ctx)
220       : ConversionPattern("test.invalid", 1, ctx) {}
221   LogicalResult
222   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
223                   ConversionPatternRewriter &rewriter) const final {
224     rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
225                                              llvm::None);
226     return success();
227   }
228 };
229 /// This pattern handles the case of a split return value.
230 struct TestSplitReturnType : public ConversionPattern {
231   TestSplitReturnType(MLIRContext *ctx)
232       : ConversionPattern("test.return", 1, ctx) {}
233   LogicalResult
234   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
235                   ConversionPatternRewriter &rewriter) const final {
236     // Check for a return of F32.
237     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
238       return failure();
239 
240     // Check if the first operation is a cast operation, if it is we use the
241     // results directly.
242     auto *defOp = operands[0].getDefiningOp();
243     if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
244       rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
245       return success();
246     }
247 
248     // Otherwise, fail to match.
249     return failure();
250   }
251 };
252 
253 //===----------------------------------------------------------------------===//
254 // Multi-Level Type-Conversion Rewrite Testing
255 struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
256   TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
257       : ConversionPattern("test.type_producer", 1, ctx) {}
258   LogicalResult
259   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
260                   ConversionPatternRewriter &rewriter) const final {
261     // If the type is I32, change the type to F32.
262     if (!Type(*op->result_type_begin()).isSignlessInteger(32))
263       return failure();
264     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
265     return success();
266   }
267 };
268 struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
269   TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
270       : ConversionPattern("test.type_producer", 1, ctx) {}
271   LogicalResult
272   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
273                   ConversionPatternRewriter &rewriter) const final {
274     // If the type is F32, change the type to F64.
275     if (!Type(*op->result_type_begin()).isF32())
276       return rewriter.notifyMatchFailure(op, "expected single f32 operand");
277     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
278     return success();
279   }
280 };
281 struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
282   TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
283       : ConversionPattern("test.type_producer", 10, ctx) {}
284   LogicalResult
285   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
286                   ConversionPatternRewriter &rewriter) const final {
287     // Always convert to B16, even though it is not a legal type. This tests
288     // that values are unmapped correctly.
289     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
290     return success();
291   }
292 };
293 struct TestUpdateConsumerType : public ConversionPattern {
294   TestUpdateConsumerType(MLIRContext *ctx)
295       : ConversionPattern("test.type_consumer", 1, ctx) {}
296   LogicalResult
297   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
298                   ConversionPatternRewriter &rewriter) const final {
299     // Verify that the incoming operand has been successfully remapped to F64.
300     if (!operands[0].getType().isF64())
301       return failure();
302     rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
303     return success();
304   }
305 };
306 
307 //===----------------------------------------------------------------------===//
308 // Non-Root Replacement Rewrite Testing
309 /// This pattern generates an invalid operation, but replaces it before the
310 /// pattern is finished. This checks that we don't need to legalize the
311 /// temporary op.
312 struct TestNonRootReplacement : public RewritePattern {
313   TestNonRootReplacement(MLIRContext *ctx)
314       : RewritePattern("test.replace_non_root", 1, ctx) {}
315 
316   LogicalResult matchAndRewrite(Operation *op,
317                                 PatternRewriter &rewriter) const final {
318     auto resultType = *op->result_type_begin();
319     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
320     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
321 
322     rewriter.replaceOp(illegalOp, {legalOp});
323     rewriter.replaceOp(op, {illegalOp});
324     return success();
325   }
326 };
327 } // namespace
328 
329 namespace {
330 struct TestTypeConverter : public TypeConverter {
331   using TypeConverter::TypeConverter;
332   TestTypeConverter() { addConversion(convertType); }
333 
334   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
335     // Drop I16 types.
336     if (t.isSignlessInteger(16))
337       return success();
338 
339     // Convert I64 to F64.
340     if (t.isSignlessInteger(64)) {
341       results.push_back(FloatType::getF64(t.getContext()));
342       return success();
343     }
344 
345     // Split F32 into F16,F16.
346     if (t.isF32()) {
347       results.assign(2, FloatType::getF16(t.getContext()));
348       return success();
349     }
350 
351     // Otherwise, convert the type directly.
352     results.push_back(t);
353     return success();
354   }
355 
356   /// Override the hook to materialize a conversion. This is necessary because
357   /// we generate 1->N type mappings.
358   Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
359                                    ArrayRef<Value> inputs,
360                                    Location loc) override {
361     return rewriter.create<TestCastOp>(loc, resultType, inputs);
362   }
363 };
364 
365 struct TestLegalizePatternDriver
366     : public ModulePass<TestLegalizePatternDriver> {
367   /// The mode of conversion to use with the driver.
368   enum class ConversionMode { Analysis, Full, Partial };
369 
370   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
371 
372   void runOnModule() override {
373     TestTypeConverter converter;
374     mlir::OwningRewritePatternList patterns;
375     populateWithGenerated(&getContext(), &patterns);
376     patterns
377         .insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
378                 TestPassthroughInvalidOp, TestSplitReturnType,
379                 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
380                 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
381                 TestNonRootReplacement>(&getContext());
382     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
383     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
384                                               converter);
385     mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
386                                               converter);
387 
388     // Define the conversion target used for the test.
389     ConversionTarget target(getContext());
390     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
391     target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
392     target
393         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
394     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
395       // Don't allow F32 operands.
396       return llvm::none_of(op.getOperandTypes(),
397                            [](Type type) { return type.isF32(); });
398     });
399     target.addDynamicallyLegalOp<FuncOp>(
400         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
401 
402     // Expect the type_producer/type_consumer operations to only operate on f64.
403     target.addDynamicallyLegalOp<TestTypeProducerOp>(
404         [](TestTypeProducerOp op) { return op.getType().isF64(); });
405     target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
406       return op.getOperand().getType().isF64();
407     });
408 
409     // Check support for marking certain operations as recursively legal.
410     target.markOpRecursivelyLegal<FuncOp, ModuleOp>([](Operation *op) {
411       return static_cast<bool>(
412           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
413     });
414 
415     // Handle a partial conversion.
416     if (mode == ConversionMode::Partial) {
417       (void)applyPartialConversion(getModule(), target, patterns, &converter);
418       return;
419     }
420 
421     // Handle a full conversion.
422     if (mode == ConversionMode::Full) {
423       // Check support for marking unknown operations as dynamically legal.
424       target.markUnknownOpDynamicallyLegal([](Operation *op) {
425         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
426       });
427 
428       (void)applyFullConversion(getModule(), target, patterns, &converter);
429       return;
430     }
431 
432     // Otherwise, handle an analysis conversion.
433     assert(mode == ConversionMode::Analysis);
434 
435     // Analyze the convertible operations.
436     DenseSet<Operation *> legalizedOps;
437     if (failed(applyAnalysisConversion(getModule(), target, patterns,
438                                        legalizedOps, &converter)))
439       return signalPassFailure();
440 
441     // Emit remarks for each legalizable operation.
442     for (auto *op : legalizedOps)
443       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
444   }
445 
446   /// The mode of conversion to use.
447   ConversionMode mode;
448 };
449 } // end anonymous namespace
450 
/// Command-line selector (`-test-legalize-mode=analysis|full|partial`) for the
/// conversion mode of TestLegalizePatternDriver. Defaults to a partial
/// conversion. Its value is read when `test-legalize-patterns` constructs the
/// driver.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
463 
464 //===----------------------------------------------------------------------===//
// ConversionPatternRewriter::getRemappedValue testing. This method is used
// to get the remapped value of an original value that was replaced using
// ConversionPatternRewriter.
468 namespace {
469 /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
470 /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
471 /// operand twice.
472 ///
473 /// Example:
474 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0)
475 /// is replaced with:
476 ///   %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
477 struct OneVResOneVOperandOp1Converter
478     : public OpConversionPattern<OneVResOneVOperandOp1> {
479   using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
480 
481   LogicalResult
482   matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
483                   ConversionPatternRewriter &rewriter) const override {
484     auto origOps = op.getOperands();
485     assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
486            "One operand expected");
487     Value origOp = *origOps.begin();
488     SmallVector<Value, 2> remappedOperands;
489     // Replicate the remapped original operand twice. Note that we don't used
490     // the remapped 'operand' since the goal is testing 'getRemappedValue'.
491     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
492     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
493 
494     rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
495                                                        remappedOperands);
496     return success();
497   }
498 };
499 
500 struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> {
501   void runOnFunction() override {
502     mlir::OwningRewritePatternList patterns;
503     patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
504 
505     mlir::ConversionTarget target(getContext());
506     target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
507     // We make OneVResOneVOperandOp1 legal only when it has more that one
508     // operand. This will trigger the conversion that will replace one-operand
509     // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
510     target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
511         [](Operation *op) -> bool {
512           return std::distance(op->operand_begin(), op->operand_end()) > 1;
513         });
514 
515     if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
516       signalPassFailure();
517     }
518   }
519 };
520 } // end anonymous namespace
521 
522 namespace mlir {
523 void registerPatternsTestPass() {
524   mlir::PassRegistration<TestReturnTypeDriver>("test-return-type",
525                                                "Run return type functions");
526 
527   mlir::PassRegistration<TestPatternDriver>("test-patterns",
528                                             "Run test dialect patterns");
529 
530   mlir::PassRegistration<TestLegalizePatternDriver>(
531       "test-legalize-patterns", "Run test dialect legalization patterns", [] {
532         return std::make_unique<TestLegalizePatternDriver>(
533             legalizerConversionMode);
534       });
535 
536   PassRegistration<TestRemappedValue>(
537       "test-remapped-value",
538       "Test public remapped value mechanism in ConversionPatternRewriter");
539 }
540 } // namespace mlir
541