1abfd1a8bSRiver Riddle //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle 
9b4130e9eSStanislav Funiak #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10abfd1a8bSRiver Riddle #include "mlir/Pass/Pass.h"
11abfd1a8bSRiver Riddle #include "mlir/Pass/PassManager.h"
12abfd1a8bSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13abfd1a8bSRiver Riddle 
14abfd1a8bSRiver Riddle using namespace mlir;
15abfd1a8bSRiver Riddle 
16abfd1a8bSRiver Riddle /// Custom constraint invoked from PDL.
customSingleEntityConstraint(PatternRewriter & rewriter,Operation * rootOp)17*ea64828aSRiver Riddle static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
18*ea64828aSRiver Riddle                                                   Operation *rootOp) {
19abfd1a8bSRiver Riddle   return success(rootOp->getName().getStringRef() == "test.op");
20abfd1a8bSRiver Riddle }
customMultiEntityConstraint(PatternRewriter & rewriter,Operation * root,Operation * rootCopy)21*ea64828aSRiver Riddle static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
22*ea64828aSRiver Riddle                                                  Operation *root,
23*ea64828aSRiver Riddle                                                  Operation *rootCopy) {
24*ea64828aSRiver Riddle   return customSingleEntityConstraint(rewriter, rootCopy);
25abfd1a8bSRiver Riddle }
customMultiEntityVariadicConstraint(PatternRewriter & rewriter,ValueRange operandValues,TypeRange typeValues)26*ea64828aSRiver Riddle static LogicalResult customMultiEntityVariadicConstraint(
27*ea64828aSRiver Riddle     PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
2885ab413bSRiver Riddle   if (operandValues.size() != 2 || typeValues.size() != 2)
2985ab413bSRiver Riddle     return failure();
3085ab413bSRiver Riddle   return success();
3185ab413bSRiver Riddle }
32abfd1a8bSRiver Riddle 
33abfd1a8bSRiver Riddle // Custom creator invoked from PDL.
customCreate(PatternRewriter & rewriter,Operation * op)34*ea64828aSRiver Riddle static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
35*ea64828aSRiver Riddle   return rewriter.create(OperationState(op->getLoc(), "test.success"));
36abfd1a8bSRiver Riddle }
customVariadicResultCreate(PatternRewriter & rewriter,Operation * root)37*ea64828aSRiver Riddle static auto customVariadicResultCreate(PatternRewriter &rewriter,
38*ea64828aSRiver Riddle                                        Operation *root) {
39*ea64828aSRiver Riddle   return std::make_pair(root->getOperands(), root->getOperands().getTypes());
4085ab413bSRiver Riddle }
customCreateType(PatternRewriter & rewriter)41*ea64828aSRiver Riddle static Type customCreateType(PatternRewriter &rewriter) {
42*ea64828aSRiver Riddle   return rewriter.getF32Type();
43*ea64828aSRiver Riddle }
customCreateStrAttr(PatternRewriter & rewriter)44*ea64828aSRiver Riddle static std::string customCreateStrAttr(PatternRewriter &rewriter) {
45*ea64828aSRiver Riddle   return "test.str";
4685ab413bSRiver Riddle }
47abfd1a8bSRiver Riddle 
48abfd1a8bSRiver Riddle /// Custom rewriter invoked from PDL.
customRewriter(PatternRewriter & rewriter,Operation * root,Value input)49*ea64828aSRiver Riddle static void customRewriter(PatternRewriter &rewriter, Operation *root,
50*ea64828aSRiver Riddle                            Value input) {
51*ea64828aSRiver Riddle   rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
52*ea64828aSRiver Riddle                   input);
53abfd1a8bSRiver Riddle   rewriter.eraseOp(root);
54abfd1a8bSRiver Riddle }
55abfd1a8bSRiver Riddle 
56abfd1a8bSRiver Riddle namespace {
57abfd1a8bSRiver Riddle struct TestPDLByteCodePass
58abfd1a8bSRiver Riddle     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anondf724a480111::TestPDLByteCodePass595e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
605e50dd04SRiver Riddle 
61b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
getDescription__anondf724a480111::TestPDLByteCodePass62b5e22e6dSMehdi Amini   StringRef getDescription() const final {
63b5e22e6dSMehdi Amini     return "Test PDL ByteCode functionality";
64b5e22e6dSMehdi Amini   }
getDependentDialects__anondf724a480111::TestPDLByteCodePass65b4130e9eSStanislav Funiak   void getDependentDialects(DialectRegistry &registry) const override {
66b4130e9eSStanislav Funiak     // Mark the pdl_interp dialect as a dependent. This is needed, because we
67b4130e9eSStanislav Funiak     // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
68b4130e9eSStanislav Funiak     registry.insert<pdl_interp::PDLInterpDialect>();
69b4130e9eSStanislav Funiak   }
runOnOperation__anondf724a480111::TestPDLByteCodePass70abfd1a8bSRiver Riddle   void runOnOperation() final {
71abfd1a8bSRiver Riddle     ModuleOp module = getOperation();
72abfd1a8bSRiver Riddle 
73abfd1a8bSRiver Riddle     // The test cases are encompassed via two modules, one containing the
74abfd1a8bSRiver Riddle     // patterns and one containing the operations to rewrite.
7541d4aa7dSChris Lattner     ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
7641d4aa7dSChris Lattner         StringAttr::get(module->getContext(), "patterns"));
7741d4aa7dSChris Lattner     ModuleOp irModule = module.lookupSymbol<ModuleOp>(
7841d4aa7dSChris Lattner         StringAttr::get(module->getContext(), "ir"));
79abfd1a8bSRiver Riddle     if (!patternModule || !irModule)
80abfd1a8bSRiver Riddle       return;
81abfd1a8bSRiver Riddle 
8206c3b9c7SRiver Riddle     RewritePatternSet patternList(module->getContext());
8306c3b9c7SRiver Riddle 
8406c3b9c7SRiver Riddle     // Register ahead of time to test when functions are registered without a
8506c3b9c7SRiver Riddle     // pattern.
8606c3b9c7SRiver Riddle     patternList.getPDLPatterns().registerConstraintFunction(
8706c3b9c7SRiver Riddle         "multi_entity_constraint", customMultiEntityConstraint);
8806c3b9c7SRiver Riddle     patternList.getPDLPatterns().registerConstraintFunction(
8906c3b9c7SRiver Riddle         "single_entity_constraint", customSingleEntityConstraint);
9006c3b9c7SRiver Riddle 
91abfd1a8bSRiver Riddle     // Process the pattern module.
92abfd1a8bSRiver Riddle     patternModule.getOperation()->remove();
93abfd1a8bSRiver Riddle     PDLPatternModule pdlPattern(patternModule);
9406c3b9c7SRiver Riddle 
9506c3b9c7SRiver Riddle     // Note: This constraint was already registered, but we re-register here to
9606c3b9c7SRiver Riddle     // ensure that duplication registration is allowed (the duplicate mapping
9706c3b9c7SRiver Riddle     // will be ignored). This tests that we support separating the registration
9806c3b9c7SRiver Riddle     // of library functions from the construction of patterns, and also that we
9906c3b9c7SRiver Riddle     // allow multiple patterns to depend on the same library functions (without
10006c3b9c7SRiver Riddle     // asserting/crashing).
101abfd1a8bSRiver Riddle     pdlPattern.registerConstraintFunction("multi_entity_constraint",
102abfd1a8bSRiver Riddle                                           customMultiEntityConstraint);
10385ab413bSRiver Riddle     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
10485ab413bSRiver Riddle                                           customMultiEntityVariadicConstraint);
10502c4c0d5SRiver Riddle     pdlPattern.registerRewriteFunction("creator", customCreate);
10685ab413bSRiver Riddle     pdlPattern.registerRewriteFunction("var_creator",
10785ab413bSRiver Riddle                                        customVariadicResultCreate);
10885ab413bSRiver Riddle     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
109*ea64828aSRiver Riddle     pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
110abfd1a8bSRiver Riddle     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
11106c3b9c7SRiver Riddle     patternList.add(std::move(pdlPattern));
112abfd1a8bSRiver Riddle 
113abfd1a8bSRiver Riddle     // Invoke the pattern driver with the provided patterns.
114abfd1a8bSRiver Riddle     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
115abfd1a8bSRiver Riddle                                        std::move(patternList));
116abfd1a8bSRiver Riddle   }
117abfd1a8bSRiver Riddle };
118be0a7e9fSMehdi Amini } // namespace
119abfd1a8bSRiver Riddle 
120abfd1a8bSRiver Riddle namespace mlir {
121abfd1a8bSRiver Riddle namespace test {
registerTestPDLByteCodePass()122b5e22e6dSMehdi Amini void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
123abfd1a8bSRiver Riddle } // namespace test
124abfd1a8bSRiver Riddle } // namespace mlir
125