1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Pass/Pass.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 13 14 using namespace mlir; 15 16 /// Custom constraint invoked from PDL. 17 static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter, 18 Operation *rootOp) { 19 return success(rootOp->getName().getStringRef() == "test.op"); 20 } 21 static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter, 22 Operation *root, 23 Operation *rootCopy) { 24 return customSingleEntityConstraint(rewriter, rootCopy); 25 } 26 static LogicalResult customMultiEntityVariadicConstraint( 27 PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) { 28 if (operandValues.size() != 2 || typeValues.size() != 2) 29 return failure(); 30 return success(); 31 } 32 33 // Custom creator invoked from PDL. 34 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { 35 return rewriter.create(OperationState(op->getLoc(), "test.success")); 36 } 37 static auto customVariadicResultCreate(PatternRewriter &rewriter, 38 Operation *root) { 39 return std::make_pair(root->getOperands(), root->getOperands().getTypes()); 40 } 41 static Type customCreateType(PatternRewriter &rewriter) { 42 return rewriter.getF32Type(); 43 } 44 static std::string customCreateStrAttr(PatternRewriter &rewriter) { 45 return "test.str"; 46 } 47 48 /// Custom rewriter invoked from PDL. 49 static void customRewriter(PatternRewriter &rewriter, Operation *root, 50 Value input) { 51 rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"), 52 input); 53 rewriter.eraseOp(root); 54 } 55 56 namespace { 57 struct TestPDLByteCodePass 58 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { 59 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass) 60 61 StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } 62 StringRef getDescription() const final { 63 return "Test PDL ByteCode functionality"; 64 } 65 void getDependentDialects(DialectRegistry ®istry) const override { 66 // Mark the pdl_interp dialect as a dependent. This is needed, because we 67 // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. 68 registry.insert<pdl_interp::PDLInterpDialect>(); 69 } 70 void runOnOperation() final { 71 ModuleOp module = getOperation(); 72 73 // The test cases are encompassed via two modules, one containing the 74 // patterns and one containing the operations to rewrite. 75 ModuleOp patternModule = module.lookupSymbol<ModuleOp>( 76 StringAttr::get(module->getContext(), "patterns")); 77 ModuleOp irModule = module.lookupSymbol<ModuleOp>( 78 StringAttr::get(module->getContext(), "ir")); 79 if (!patternModule || !irModule) 80 return; 81 82 RewritePatternSet patternList(module->getContext()); 83 84 // Register ahead of time to test when functions are registered without a 85 // pattern. 86 patternList.getPDLPatterns().registerConstraintFunction( 87 "multi_entity_constraint", customMultiEntityConstraint); 88 patternList.getPDLPatterns().registerConstraintFunction( 89 "single_entity_constraint", customSingleEntityConstraint); 90 91 // Process the pattern module. 92 patternModule.getOperation()->remove(); 93 PDLPatternModule pdlPattern(patternModule); 94 95 // Note: This constraint was already registered, but we re-register here to 96 // ensure that duplication registration is allowed (the duplicate mapping 97 // will be ignored). This tests that we support separating the registration 98 // of library functions from the construction of patterns, and also that we 99 // allow multiple patterns to depend on the same library functions (without 100 // asserting/crashing). 101 pdlPattern.registerConstraintFunction("multi_entity_constraint", 102 customMultiEntityConstraint); 103 pdlPattern.registerConstraintFunction("multi_entity_var_constraint", 104 customMultiEntityVariadicConstraint); 105 pdlPattern.registerRewriteFunction("creator", customCreate); 106 pdlPattern.registerRewriteFunction("var_creator", 107 customVariadicResultCreate); 108 pdlPattern.registerRewriteFunction("type_creator", customCreateType); 109 pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr); 110 pdlPattern.registerRewriteFunction("rewriter", customRewriter); 111 patternList.add(std::move(pdlPattern)); 112 113 // Invoke the pattern driver with the provided patterns. 114 (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), 115 std::move(patternList)); 116 } 117 }; 118 } // namespace 119 120 namespace mlir { 121 namespace test { 122 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } 123 } // namespace test 124 } // namespace mlir 125