1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/Pass.h"
10 #include "mlir/Pass/PassManager.h"
11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
12 
13 using namespace mlir;
14 
15 /// Custom constraint invoked from PDL.
16 static LogicalResult customSingleEntityConstraint(PDLValue value,
17                                                   ArrayAttr constantParams,
18                                                   PatternRewriter &rewriter) {
19   Operation *rootOp = value.cast<Operation *>();
20   return success(rootOp->getName().getStringRef() == "test.op");
21 }
22 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
23                                                  ArrayAttr constantParams,
24                                                  PatternRewriter &rewriter) {
25   return customSingleEntityConstraint(values[1], constantParams, rewriter);
26 }
27 static LogicalResult
28 customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
29                                     ArrayAttr constantParams,
30                                     PatternRewriter &rewriter) {
31   if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
32     return failure();
33   ValueRange operandValues = values[0].cast<ValueRange>();
34   TypeRange typeValues = values[1].cast<TypeRange>();
35   if (operandValues.size() != 2 || typeValues.size() != 2)
36     return failure();
37   return success();
38 }
39 
40 // Custom creator invoked from PDL.
41 static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
42                          PatternRewriter &rewriter, PDLResultList &results) {
43   results.push_back(rewriter.createOperation(
44       OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
45 }
46 static void customVariadicResultCreate(ArrayRef<PDLValue> args,
47                                        ArrayAttr constantParams,
48                                        PatternRewriter &rewriter,
49                                        PDLResultList &results) {
50   Operation *root = args[0].cast<Operation *>();
51   results.push_back(root->getOperands());
52   results.push_back(root->getOperands().getTypes());
53 }
54 static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams,
55                              PatternRewriter &rewriter,
56                              PDLResultList &results) {
57   results.push_back(rewriter.getF32Type());
58 }
59 
60 /// Custom rewriter invoked from PDL.
61 static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
62                            PatternRewriter &rewriter, PDLResultList &results) {
63   Operation *root = args[0].cast<Operation *>();
64   OperationState successOpState(root->getLoc(), "test.success");
65   successOpState.addOperands(args[1].cast<Value>());
66   successOpState.addAttribute("constantParams", constantParams);
67   rewriter.createOperation(successOpState);
68   rewriter.eraseOp(root);
69 }
70 
71 namespace {
72 struct TestPDLByteCodePass
73     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
74   StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
75   StringRef getDescription() const final {
76     return "Test PDL ByteCode functionality";
77   }
78   void runOnOperation() final {
79     ModuleOp module = getOperation();
80 
81     // The test cases are encompassed via two modules, one containing the
82     // patterns and one containing the operations to rewrite.
83     ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
84         StringAttr::get(module->getContext(), "patterns"));
85     ModuleOp irModule = module.lookupSymbol<ModuleOp>(
86         StringAttr::get(module->getContext(), "ir"));
87     if (!patternModule || !irModule)
88       return;
89 
90     // Process the pattern module.
91     patternModule.getOperation()->remove();
92     PDLPatternModule pdlPattern(patternModule);
93     pdlPattern.registerConstraintFunction("multi_entity_constraint",
94                                           customMultiEntityConstraint);
95     pdlPattern.registerConstraintFunction("single_entity_constraint",
96                                           customSingleEntityConstraint);
97     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
98                                           customMultiEntityVariadicConstraint);
99     pdlPattern.registerRewriteFunction("creator", customCreate);
100     pdlPattern.registerRewriteFunction("var_creator",
101                                        customVariadicResultCreate);
102     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
103     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
104 
105     RewritePatternSet patternList(std::move(pdlPattern));
106 
107     // Invoke the pattern driver with the provided patterns.
108     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
109                                        std::move(patternList));
110   }
111 };
112 } // namespace
113 
114 namespace mlir {
115 namespace test {
116 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
117 } // namespace test
118 } // namespace mlir
119