1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 
14 using namespace mlir;
15 
16 /// Custom constraint invoked from PDL.
customSingleEntityConstraint(PatternRewriter & rewriter,Operation * rootOp)17 static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
18                                                   Operation *rootOp) {
19   return success(rootOp->getName().getStringRef() == "test.op");
20 }
customMultiEntityConstraint(PatternRewriter & rewriter,Operation * root,Operation * rootCopy)21 static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
22                                                  Operation *root,
23                                                  Operation *rootCopy) {
24   return customSingleEntityConstraint(rewriter, rootCopy);
25 }
customMultiEntityVariadicConstraint(PatternRewriter & rewriter,ValueRange operandValues,TypeRange typeValues)26 static LogicalResult customMultiEntityVariadicConstraint(
27     PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
28   if (operandValues.size() != 2 || typeValues.size() != 2)
29     return failure();
30   return success();
31 }
32 
33 // Custom creator invoked from PDL.
customCreate(PatternRewriter & rewriter,Operation * op)34 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
35   return rewriter.create(OperationState(op->getLoc(), "test.success"));
36 }
customVariadicResultCreate(PatternRewriter & rewriter,Operation * root)37 static auto customVariadicResultCreate(PatternRewriter &rewriter,
38                                        Operation *root) {
39   return std::make_pair(root->getOperands(), root->getOperands().getTypes());
40 }
customCreateType(PatternRewriter & rewriter)41 static Type customCreateType(PatternRewriter &rewriter) {
42   return rewriter.getF32Type();
43 }
customCreateStrAttr(PatternRewriter & rewriter)44 static std::string customCreateStrAttr(PatternRewriter &rewriter) {
45   return "test.str";
46 }
47 
48 /// Custom rewriter invoked from PDL.
customRewriter(PatternRewriter & rewriter,Operation * root,Value input)49 static void customRewriter(PatternRewriter &rewriter, Operation *root,
50                            Value input) {
51   rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
52                   input);
53   rewriter.eraseOp(root);
54 }
55 
56 namespace {
57 struct TestPDLByteCodePass
58     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anondf724a480111::TestPDLByteCodePass59   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
60 
61   StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
getDescription__anondf724a480111::TestPDLByteCodePass62   StringRef getDescription() const final {
63     return "Test PDL ByteCode functionality";
64   }
getDependentDialects__anondf724a480111::TestPDLByteCodePass65   void getDependentDialects(DialectRegistry &registry) const override {
66     // Mark the pdl_interp dialect as a dependent. This is needed, because we
67     // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
68     registry.insert<pdl_interp::PDLInterpDialect>();
69   }
runOnOperation__anondf724a480111::TestPDLByteCodePass70   void runOnOperation() final {
71     ModuleOp module = getOperation();
72 
73     // The test cases are encompassed via two modules, one containing the
74     // patterns and one containing the operations to rewrite.
75     ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
76         StringAttr::get(module->getContext(), "patterns"));
77     ModuleOp irModule = module.lookupSymbol<ModuleOp>(
78         StringAttr::get(module->getContext(), "ir"));
79     if (!patternModule || !irModule)
80       return;
81 
82     RewritePatternSet patternList(module->getContext());
83 
84     // Register ahead of time to test when functions are registered without a
85     // pattern.
86     patternList.getPDLPatterns().registerConstraintFunction(
87         "multi_entity_constraint", customMultiEntityConstraint);
88     patternList.getPDLPatterns().registerConstraintFunction(
89         "single_entity_constraint", customSingleEntityConstraint);
90 
91     // Process the pattern module.
92     patternModule.getOperation()->remove();
93     PDLPatternModule pdlPattern(patternModule);
94 
95     // Note: This constraint was already registered, but we re-register here to
96     // ensure that duplication registration is allowed (the duplicate mapping
97     // will be ignored). This tests that we support separating the registration
98     // of library functions from the construction of patterns, and also that we
99     // allow multiple patterns to depend on the same library functions (without
100     // asserting/crashing).
101     pdlPattern.registerConstraintFunction("multi_entity_constraint",
102                                           customMultiEntityConstraint);
103     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
104                                           customMultiEntityVariadicConstraint);
105     pdlPattern.registerRewriteFunction("creator", customCreate);
106     pdlPattern.registerRewriteFunction("var_creator",
107                                        customVariadicResultCreate);
108     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
109     pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
110     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
111     patternList.add(std::move(pdlPattern));
112 
113     // Invoke the pattern driver with the provided patterns.
114     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
115                                        std::move(patternList));
116   }
117 };
118 } // namespace
119 
120 namespace mlir {
121 namespace test {
registerTestPDLByteCodePass()122 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
123 } // namespace test
124 } // namespace mlir
125