1abfd1a8bSRiver Riddle //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle
9b4130e9eSStanislav Funiak #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10abfd1a8bSRiver Riddle #include "mlir/Pass/Pass.h"
11abfd1a8bSRiver Riddle #include "mlir/Pass/PassManager.h"
12abfd1a8bSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13abfd1a8bSRiver Riddle
14abfd1a8bSRiver Riddle using namespace mlir;
15abfd1a8bSRiver Riddle
16abfd1a8bSRiver Riddle /// Custom constraint invoked from PDL.
customSingleEntityConstraint(PatternRewriter & rewriter,Operation * rootOp)17*ea64828aSRiver Riddle static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
18*ea64828aSRiver Riddle Operation *rootOp) {
19abfd1a8bSRiver Riddle return success(rootOp->getName().getStringRef() == "test.op");
20abfd1a8bSRiver Riddle }
customMultiEntityConstraint(PatternRewriter & rewriter,Operation * root,Operation * rootCopy)21*ea64828aSRiver Riddle static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
22*ea64828aSRiver Riddle Operation *root,
23*ea64828aSRiver Riddle Operation *rootCopy) {
24*ea64828aSRiver Riddle return customSingleEntityConstraint(rewriter, rootCopy);
25abfd1a8bSRiver Riddle }
customMultiEntityVariadicConstraint(PatternRewriter & rewriter,ValueRange operandValues,TypeRange typeValues)26*ea64828aSRiver Riddle static LogicalResult customMultiEntityVariadicConstraint(
27*ea64828aSRiver Riddle PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
2885ab413bSRiver Riddle if (operandValues.size() != 2 || typeValues.size() != 2)
2985ab413bSRiver Riddle return failure();
3085ab413bSRiver Riddle return success();
3185ab413bSRiver Riddle }
32abfd1a8bSRiver Riddle
33abfd1a8bSRiver Riddle // Custom creator invoked from PDL.
customCreate(PatternRewriter & rewriter,Operation * op)34*ea64828aSRiver Riddle static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
35*ea64828aSRiver Riddle return rewriter.create(OperationState(op->getLoc(), "test.success"));
36abfd1a8bSRiver Riddle }
customVariadicResultCreate(PatternRewriter & rewriter,Operation * root)37*ea64828aSRiver Riddle static auto customVariadicResultCreate(PatternRewriter &rewriter,
38*ea64828aSRiver Riddle Operation *root) {
39*ea64828aSRiver Riddle return std::make_pair(root->getOperands(), root->getOperands().getTypes());
4085ab413bSRiver Riddle }
customCreateType(PatternRewriter & rewriter)41*ea64828aSRiver Riddle static Type customCreateType(PatternRewriter &rewriter) {
42*ea64828aSRiver Riddle return rewriter.getF32Type();
43*ea64828aSRiver Riddle }
customCreateStrAttr(PatternRewriter & rewriter)44*ea64828aSRiver Riddle static std::string customCreateStrAttr(PatternRewriter &rewriter) {
45*ea64828aSRiver Riddle return "test.str";
4685ab413bSRiver Riddle }
47abfd1a8bSRiver Riddle
48abfd1a8bSRiver Riddle /// Custom rewriter invoked from PDL.
customRewriter(PatternRewriter & rewriter,Operation * root,Value input)49*ea64828aSRiver Riddle static void customRewriter(PatternRewriter &rewriter, Operation *root,
50*ea64828aSRiver Riddle Value input) {
51*ea64828aSRiver Riddle rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
52*ea64828aSRiver Riddle input);
53abfd1a8bSRiver Riddle rewriter.eraseOp(root);
54abfd1a8bSRiver Riddle }
55abfd1a8bSRiver Riddle
56abfd1a8bSRiver Riddle namespace {
57abfd1a8bSRiver Riddle struct TestPDLByteCodePass
58abfd1a8bSRiver Riddle : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anondf724a480111::TestPDLByteCodePass595e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
605e50dd04SRiver Riddle
61b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
getDescription__anondf724a480111::TestPDLByteCodePass62b5e22e6dSMehdi Amini StringRef getDescription() const final {
63b5e22e6dSMehdi Amini return "Test PDL ByteCode functionality";
64b5e22e6dSMehdi Amini }
getDependentDialects__anondf724a480111::TestPDLByteCodePass65b4130e9eSStanislav Funiak void getDependentDialects(DialectRegistry ®istry) const override {
66b4130e9eSStanislav Funiak // Mark the pdl_interp dialect as a dependent. This is needed, because we
67b4130e9eSStanislav Funiak // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
68b4130e9eSStanislav Funiak registry.insert<pdl_interp::PDLInterpDialect>();
69b4130e9eSStanislav Funiak }
runOnOperation__anondf724a480111::TestPDLByteCodePass70abfd1a8bSRiver Riddle void runOnOperation() final {
71abfd1a8bSRiver Riddle ModuleOp module = getOperation();
72abfd1a8bSRiver Riddle
73abfd1a8bSRiver Riddle // The test cases are encompassed via two modules, one containing the
74abfd1a8bSRiver Riddle // patterns and one containing the operations to rewrite.
7541d4aa7dSChris Lattner ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
7641d4aa7dSChris Lattner StringAttr::get(module->getContext(), "patterns"));
7741d4aa7dSChris Lattner ModuleOp irModule = module.lookupSymbol<ModuleOp>(
7841d4aa7dSChris Lattner StringAttr::get(module->getContext(), "ir"));
79abfd1a8bSRiver Riddle if (!patternModule || !irModule)
80abfd1a8bSRiver Riddle return;
81abfd1a8bSRiver Riddle
8206c3b9c7SRiver Riddle RewritePatternSet patternList(module->getContext());
8306c3b9c7SRiver Riddle
8406c3b9c7SRiver Riddle // Register ahead of time to test when functions are registered without a
8506c3b9c7SRiver Riddle // pattern.
8606c3b9c7SRiver Riddle patternList.getPDLPatterns().registerConstraintFunction(
8706c3b9c7SRiver Riddle "multi_entity_constraint", customMultiEntityConstraint);
8806c3b9c7SRiver Riddle patternList.getPDLPatterns().registerConstraintFunction(
8906c3b9c7SRiver Riddle "single_entity_constraint", customSingleEntityConstraint);
9006c3b9c7SRiver Riddle
91abfd1a8bSRiver Riddle // Process the pattern module.
92abfd1a8bSRiver Riddle patternModule.getOperation()->remove();
93abfd1a8bSRiver Riddle PDLPatternModule pdlPattern(patternModule);
9406c3b9c7SRiver Riddle
9506c3b9c7SRiver Riddle // Note: This constraint was already registered, but we re-register here to
9606c3b9c7SRiver Riddle // ensure that duplication registration is allowed (the duplicate mapping
9706c3b9c7SRiver Riddle // will be ignored). This tests that we support separating the registration
9806c3b9c7SRiver Riddle // of library functions from the construction of patterns, and also that we
9906c3b9c7SRiver Riddle // allow multiple patterns to depend on the same library functions (without
10006c3b9c7SRiver Riddle // asserting/crashing).
101abfd1a8bSRiver Riddle pdlPattern.registerConstraintFunction("multi_entity_constraint",
102abfd1a8bSRiver Riddle customMultiEntityConstraint);
10385ab413bSRiver Riddle pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
10485ab413bSRiver Riddle customMultiEntityVariadicConstraint);
10502c4c0d5SRiver Riddle pdlPattern.registerRewriteFunction("creator", customCreate);
10685ab413bSRiver Riddle pdlPattern.registerRewriteFunction("var_creator",
10785ab413bSRiver Riddle customVariadicResultCreate);
10885ab413bSRiver Riddle pdlPattern.registerRewriteFunction("type_creator", customCreateType);
109*ea64828aSRiver Riddle pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
110abfd1a8bSRiver Riddle pdlPattern.registerRewriteFunction("rewriter", customRewriter);
11106c3b9c7SRiver Riddle patternList.add(std::move(pdlPattern));
112abfd1a8bSRiver Riddle
113abfd1a8bSRiver Riddle // Invoke the pattern driver with the provided patterns.
114abfd1a8bSRiver Riddle (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
115abfd1a8bSRiver Riddle std::move(patternList));
116abfd1a8bSRiver Riddle }
117abfd1a8bSRiver Riddle };
118be0a7e9fSMehdi Amini } // namespace
119abfd1a8bSRiver Riddle
120abfd1a8bSRiver Riddle namespace mlir {
121abfd1a8bSRiver Riddle namespace test {
registerTestPDLByteCodePass()122b5e22e6dSMehdi Amini void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
123abfd1a8bSRiver Riddle } // namespace test
124abfd1a8bSRiver Riddle } // namespace mlir
125