1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Pass/Pass.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 13 14 using namespace mlir; 15 16 /// Custom constraint invoked from PDL. 17 static LogicalResult customSingleEntityConstraint(PDLValue value, 18 ArrayAttr constantParams, 19 PatternRewriter &rewriter) { 20 Operation *rootOp = value.cast<Operation *>(); 21 return success(rootOp->getName().getStringRef() == "test.op"); 22 } 23 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, 24 ArrayAttr constantParams, 25 PatternRewriter &rewriter) { 26 return customSingleEntityConstraint(values[1], constantParams, rewriter); 27 } 28 static LogicalResult 29 customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values, 30 ArrayAttr constantParams, 31 PatternRewriter &rewriter) { 32 if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) 33 return failure(); 34 ValueRange operandValues = values[0].cast<ValueRange>(); 35 TypeRange typeValues = values[1].cast<TypeRange>(); 36 if (operandValues.size() != 2 || typeValues.size() != 2) 37 return failure(); 38 return success(); 39 } 40 41 // Custom creator invoked from PDL. 42 static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, 43 PatternRewriter &rewriter, PDLResultList &results) { 44 results.push_back(rewriter.createOperation( 45 OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"))); 46 } 47 static void customVariadicResultCreate(ArrayRef<PDLValue> args, 48 ArrayAttr constantParams, 49 PatternRewriter &rewriter, 50 PDLResultList &results) { 51 Operation *root = args[0].cast<Operation *>(); 52 results.push_back(root->getOperands()); 53 results.push_back(root->getOperands().getTypes()); 54 } 55 static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams, 56 PatternRewriter &rewriter, 57 PDLResultList &results) { 58 results.push_back(rewriter.getF32Type()); 59 } 60 61 /// Custom rewriter invoked from PDL. 62 static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams, 63 PatternRewriter &rewriter, PDLResultList &results) { 64 Operation *root = args[0].cast<Operation *>(); 65 OperationState successOpState(root->getLoc(), "test.success"); 66 successOpState.addOperands(args[1].cast<Value>()); 67 successOpState.addAttribute("constantParams", constantParams); 68 rewriter.createOperation(successOpState); 69 rewriter.eraseOp(root); 70 } 71 72 namespace { 73 struct TestPDLByteCodePass 74 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { 75 StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } 76 StringRef getDescription() const final { 77 return "Test PDL ByteCode functionality"; 78 } 79 void getDependentDialects(DialectRegistry ®istry) const override { 80 // Mark the pdl_interp dialect as a dependent. This is needed, because we 81 // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. 82 registry.insert<pdl_interp::PDLInterpDialect>(); 83 } 84 void runOnOperation() final { 85 ModuleOp module = getOperation(); 86 87 // The test cases are encompassed via two modules, one containing the 88 // patterns and one containing the operations to rewrite. 89 ModuleOp patternModule = module.lookupSymbol<ModuleOp>( 90 StringAttr::get(module->getContext(), "patterns")); 91 ModuleOp irModule = module.lookupSymbol<ModuleOp>( 92 StringAttr::get(module->getContext(), "ir")); 93 if (!patternModule || !irModule) 94 return; 95 96 RewritePatternSet patternList(module->getContext()); 97 98 // Register ahead of time to test when functions are registered without a 99 // pattern. 100 patternList.getPDLPatterns().registerConstraintFunction( 101 "multi_entity_constraint", customMultiEntityConstraint); 102 patternList.getPDLPatterns().registerConstraintFunction( 103 "single_entity_constraint", customSingleEntityConstraint); 104 105 // Process the pattern module. 106 patternModule.getOperation()->remove(); 107 PDLPatternModule pdlPattern(patternModule); 108 109 // Note: This constraint was already registered, but we re-register here to 110 // ensure that duplication registration is allowed (the duplicate mapping 111 // will be ignored). This tests that we support separating the registration 112 // of library functions from the construction of patterns, and also that we 113 // allow multiple patterns to depend on the same library functions (without 114 // asserting/crashing). 115 pdlPattern.registerConstraintFunction("multi_entity_constraint", 116 customMultiEntityConstraint); 117 pdlPattern.registerConstraintFunction("multi_entity_var_constraint", 118 customMultiEntityVariadicConstraint); 119 pdlPattern.registerRewriteFunction("creator", customCreate); 120 pdlPattern.registerRewriteFunction("var_creator", 121 customVariadicResultCreate); 122 pdlPattern.registerRewriteFunction("type_creator", customCreateType); 123 pdlPattern.registerRewriteFunction("rewriter", customRewriter); 124 patternList.add(std::move(pdlPattern)); 125 126 // Invoke the pattern driver with the provided patterns. 127 (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), 128 std::move(patternList)); 129 } 130 }; 131 } // namespace 132 133 namespace mlir { 134 namespace test { 135 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } 136 } // namespace test 137 } // namespace mlir 138