1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 
14 using namespace mlir;
15 
16 /// Custom constraint invoked from PDL.
17 static LogicalResult customSingleEntityConstraint(PDLValue value,
18                                                   PatternRewriter &rewriter) {
19   Operation *rootOp = value.cast<Operation *>();
20   return success(rootOp->getName().getStringRef() == "test.op");
21 }
22 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
23                                                  PatternRewriter &rewriter) {
24   return customSingleEntityConstraint(values[1], rewriter);
25 }
26 static LogicalResult
27 customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
28                                     PatternRewriter &rewriter) {
29   if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
30     return failure();
31   ValueRange operandValues = values[0].cast<ValueRange>();
32   TypeRange typeValues = values[1].cast<TypeRange>();
33   if (operandValues.size() != 2 || typeValues.size() != 2)
34     return failure();
35   return success();
36 }
37 
38 // Custom creator invoked from PDL.
39 static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
40                          PDLResultList &results) {
41   results.push_back(rewriter.create(
42       OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
43 }
44 
45 static void customVariadicResultCreate(ArrayRef<PDLValue> args,
46                                        PatternRewriter &rewriter,
47                                        PDLResultList &results) {
48   Operation *root = args[0].cast<Operation *>();
49   results.push_back(root->getOperands());
50   results.push_back(root->getOperands().getTypes());
51 }
52 static void customCreateType(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
53                              PDLResultList &results) {
54   results.push_back(rewriter.getF32Type());
55 }
56 
57 /// Custom rewriter invoked from PDL.
58 static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
59                            PDLResultList &results) {
60   Operation *root = args[0].cast<Operation *>();
61   OperationState successOpState(root->getLoc(), "test.success");
62   successOpState.addOperands(args[1].cast<Value>());
63   rewriter.create(successOpState);
64   rewriter.eraseOp(root);
65 }
66 
67 namespace {
68 struct TestPDLByteCodePass
69     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
70   StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
71   StringRef getDescription() const final {
72     return "Test PDL ByteCode functionality";
73   }
74   void getDependentDialects(DialectRegistry &registry) const override {
75     // Mark the pdl_interp dialect as a dependent. This is needed, because we
76     // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
77     registry.insert<pdl_interp::PDLInterpDialect>();
78   }
79   void runOnOperation() final {
80     ModuleOp module = getOperation();
81 
82     // The test cases are encompassed via two modules, one containing the
83     // patterns and one containing the operations to rewrite.
84     ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
85         StringAttr::get(module->getContext(), "patterns"));
86     ModuleOp irModule = module.lookupSymbol<ModuleOp>(
87         StringAttr::get(module->getContext(), "ir"));
88     if (!patternModule || !irModule)
89       return;
90 
91     RewritePatternSet patternList(module->getContext());
92 
93     // Register ahead of time to test when functions are registered without a
94     // pattern.
95     patternList.getPDLPatterns().registerConstraintFunction(
96         "multi_entity_constraint", customMultiEntityConstraint);
97     patternList.getPDLPatterns().registerConstraintFunction(
98         "single_entity_constraint", customSingleEntityConstraint);
99 
100     // Process the pattern module.
101     patternModule.getOperation()->remove();
102     PDLPatternModule pdlPattern(patternModule);
103 
104     // Note: This constraint was already registered, but we re-register here to
105     // ensure that duplication registration is allowed (the duplicate mapping
106     // will be ignored). This tests that we support separating the registration
107     // of library functions from the construction of patterns, and also that we
108     // allow multiple patterns to depend on the same library functions (without
109     // asserting/crashing).
110     pdlPattern.registerConstraintFunction("multi_entity_constraint",
111                                           customMultiEntityConstraint);
112     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
113                                           customMultiEntityVariadicConstraint);
114     pdlPattern.registerRewriteFunction("creator", customCreate);
115     pdlPattern.registerRewriteFunction("var_creator",
116                                        customVariadicResultCreate);
117     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
118     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
119     patternList.add(std::move(pdlPattern));
120 
121     // Invoke the pattern driver with the provided patterns.
122     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
123                                        std::move(patternList));
124   }
125 };
126 } // namespace
127 
128 namespace mlir {
129 namespace test {
130 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
131 } // namespace test
132 } // namespace mlir
133