1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 
14 using namespace mlir;
15 
16 /// Custom constraint invoked from PDL.
17 static LogicalResult customSingleEntityConstraint(PDLValue value,
18                                                   ArrayAttr constantParams,
19                                                   PatternRewriter &rewriter) {
20   Operation *rootOp = value.cast<Operation *>();
21   return success(rootOp->getName().getStringRef() == "test.op");
22 }
23 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
24                                                  ArrayAttr constantParams,
25                                                  PatternRewriter &rewriter) {
26   return customSingleEntityConstraint(values[1], constantParams, rewriter);
27 }
28 static LogicalResult
29 customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
30                                     ArrayAttr constantParams,
31                                     PatternRewriter &rewriter) {
32   if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
33     return failure();
34   ValueRange operandValues = values[0].cast<ValueRange>();
35   TypeRange typeValues = values[1].cast<TypeRange>();
36   if (operandValues.size() != 2 || typeValues.size() != 2)
37     return failure();
38   return success();
39 }
40 
41 // Custom creator invoked from PDL.
42 static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
43                          PatternRewriter &rewriter, PDLResultList &results) {
44   results.push_back(rewriter.createOperation(
45       OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
46 }
47 static void customVariadicResultCreate(ArrayRef<PDLValue> args,
48                                        ArrayAttr constantParams,
49                                        PatternRewriter &rewriter,
50                                        PDLResultList &results) {
51   Operation *root = args[0].cast<Operation *>();
52   results.push_back(root->getOperands());
53   results.push_back(root->getOperands().getTypes());
54 }
55 static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams,
56                              PatternRewriter &rewriter,
57                              PDLResultList &results) {
58   results.push_back(rewriter.getF32Type());
59 }
60 
61 /// Custom rewriter invoked from PDL.
62 static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
63                            PatternRewriter &rewriter, PDLResultList &results) {
64   Operation *root = args[0].cast<Operation *>();
65   OperationState successOpState(root->getLoc(), "test.success");
66   successOpState.addOperands(args[1].cast<Value>());
67   successOpState.addAttribute("constantParams", constantParams);
68   rewriter.createOperation(successOpState);
69   rewriter.eraseOp(root);
70 }
71 
72 namespace {
73 struct TestPDLByteCodePass
74     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
75   StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
76   StringRef getDescription() const final {
77     return "Test PDL ByteCode functionality";
78   }
79   void getDependentDialects(DialectRegistry &registry) const override {
80     // Mark the pdl_interp dialect as a dependent. This is needed, because we
81     // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
82     registry.insert<pdl_interp::PDLInterpDialect>();
83   }
84   void runOnOperation() final {
85     ModuleOp module = getOperation();
86 
87     // The test cases are encompassed via two modules, one containing the
88     // patterns and one containing the operations to rewrite.
89     ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
90         StringAttr::get(module->getContext(), "patterns"));
91     ModuleOp irModule = module.lookupSymbol<ModuleOp>(
92         StringAttr::get(module->getContext(), "ir"));
93     if (!patternModule || !irModule)
94       return;
95 
96     RewritePatternSet patternList(module->getContext());
97 
98     // Register ahead of time to test when functions are registered without a
99     // pattern.
100     patternList.getPDLPatterns().registerConstraintFunction(
101         "multi_entity_constraint", customMultiEntityConstraint);
102     patternList.getPDLPatterns().registerConstraintFunction(
103         "single_entity_constraint", customSingleEntityConstraint);
104 
105     // Process the pattern module.
106     patternModule.getOperation()->remove();
107     PDLPatternModule pdlPattern(patternModule);
108 
109     // Note: This constraint was already registered, but we re-register here to
110     // ensure that duplication registration is allowed (the duplicate mapping
111     // will be ignored). This tests that we support separating the registration
112     // of library functions from the construction of patterns, and also that we
113     // allow multiple patterns to depend on the same library functions (without
114     // asserting/crashing).
115     pdlPattern.registerConstraintFunction("multi_entity_constraint",
116                                           customMultiEntityConstraint);
117     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
118                                           customMultiEntityVariadicConstraint);
119     pdlPattern.registerRewriteFunction("creator", customCreate);
120     pdlPattern.registerRewriteFunction("var_creator",
121                                        customVariadicResultCreate);
122     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
123     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
124     patternList.add(std::move(pdlPattern));
125 
126     // Invoke the pattern driver with the provided patterns.
127     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
128                                        std::move(patternList));
129   }
130 };
131 } // namespace
132 
133 namespace mlir {
134 namespace test {
135 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
136 } // namespace test
137 } // namespace mlir
138