1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13
14 using namespace mlir;
15
16 /// Custom constraint invoked from PDL.
customSingleEntityConstraint(PatternRewriter & rewriter,Operation * rootOp)17 static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
18 Operation *rootOp) {
19 return success(rootOp->getName().getStringRef() == "test.op");
20 }
customMultiEntityConstraint(PatternRewriter & rewriter,Operation * root,Operation * rootCopy)21 static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
22 Operation *root,
23 Operation *rootCopy) {
24 return customSingleEntityConstraint(rewriter, rootCopy);
25 }
customMultiEntityVariadicConstraint(PatternRewriter & rewriter,ValueRange operandValues,TypeRange typeValues)26 static LogicalResult customMultiEntityVariadicConstraint(
27 PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
28 if (operandValues.size() != 2 || typeValues.size() != 2)
29 return failure();
30 return success();
31 }
32
33 // Custom creator invoked from PDL.
customCreate(PatternRewriter & rewriter,Operation * op)34 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
35 return rewriter.create(OperationState(op->getLoc(), "test.success"));
36 }
customVariadicResultCreate(PatternRewriter & rewriter,Operation * root)37 static auto customVariadicResultCreate(PatternRewriter &rewriter,
38 Operation *root) {
39 return std::make_pair(root->getOperands(), root->getOperands().getTypes());
40 }
customCreateType(PatternRewriter & rewriter)41 static Type customCreateType(PatternRewriter &rewriter) {
42 return rewriter.getF32Type();
43 }
customCreateStrAttr(PatternRewriter & rewriter)44 static std::string customCreateStrAttr(PatternRewriter &rewriter) {
45 return "test.str";
46 }
47
48 /// Custom rewriter invoked from PDL.
customRewriter(PatternRewriter & rewriter,Operation * root,Value input)49 static void customRewriter(PatternRewriter &rewriter, Operation *root,
50 Value input) {
51 rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
52 input);
53 rewriter.eraseOp(root);
54 }
55
56 namespace {
57 struct TestPDLByteCodePass
58 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anondf724a480111::TestPDLByteCodePass59 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
60
61 StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
getDescription__anondf724a480111::TestPDLByteCodePass62 StringRef getDescription() const final {
63 return "Test PDL ByteCode functionality";
64 }
getDependentDialects__anondf724a480111::TestPDLByteCodePass65 void getDependentDialects(DialectRegistry ®istry) const override {
66 // Mark the pdl_interp dialect as a dependent. This is needed, because we
67 // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
68 registry.insert<pdl_interp::PDLInterpDialect>();
69 }
runOnOperation__anondf724a480111::TestPDLByteCodePass70 void runOnOperation() final {
71 ModuleOp module = getOperation();
72
73 // The test cases are encompassed via two modules, one containing the
74 // patterns and one containing the operations to rewrite.
75 ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
76 StringAttr::get(module->getContext(), "patterns"));
77 ModuleOp irModule = module.lookupSymbol<ModuleOp>(
78 StringAttr::get(module->getContext(), "ir"));
79 if (!patternModule || !irModule)
80 return;
81
82 RewritePatternSet patternList(module->getContext());
83
84 // Register ahead of time to test when functions are registered without a
85 // pattern.
86 patternList.getPDLPatterns().registerConstraintFunction(
87 "multi_entity_constraint", customMultiEntityConstraint);
88 patternList.getPDLPatterns().registerConstraintFunction(
89 "single_entity_constraint", customSingleEntityConstraint);
90
91 // Process the pattern module.
92 patternModule.getOperation()->remove();
93 PDLPatternModule pdlPattern(patternModule);
94
95 // Note: This constraint was already registered, but we re-register here to
96 // ensure that duplication registration is allowed (the duplicate mapping
97 // will be ignored). This tests that we support separating the registration
98 // of library functions from the construction of patterns, and also that we
99 // allow multiple patterns to depend on the same library functions (without
100 // asserting/crashing).
101 pdlPattern.registerConstraintFunction("multi_entity_constraint",
102 customMultiEntityConstraint);
103 pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
104 customMultiEntityVariadicConstraint);
105 pdlPattern.registerRewriteFunction("creator", customCreate);
106 pdlPattern.registerRewriteFunction("var_creator",
107 customVariadicResultCreate);
108 pdlPattern.registerRewriteFunction("type_creator", customCreateType);
109 pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
110 pdlPattern.registerRewriteFunction("rewriter", customRewriter);
111 patternList.add(std::move(pdlPattern));
112
113 // Invoke the pattern driver with the provided patterns.
114 (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
115 std::move(patternList));
116 }
117 };
118 } // namespace
119
120 namespace mlir {
121 namespace test {
registerTestPDLByteCodePass()122 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
123 } // namespace test
124 } // namespace mlir
125