1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Pass/Pass.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 13 14 using namespace mlir; 15 16 /// Custom constraint invoked from PDL. 17 static LogicalResult customSingleEntityConstraint(PDLValue value, 18 PatternRewriter &rewriter) { 19 Operation *rootOp = value.cast<Operation *>(); 20 return success(rootOp->getName().getStringRef() == "test.op"); 21 } 22 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, 23 PatternRewriter &rewriter) { 24 return customSingleEntityConstraint(values[1], rewriter); 25 } 26 static LogicalResult 27 customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values, 28 PatternRewriter &rewriter) { 29 if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) 30 return failure(); 31 ValueRange operandValues = values[0].cast<ValueRange>(); 32 TypeRange typeValues = values[1].cast<TypeRange>(); 33 if (operandValues.size() != 2 || typeValues.size() != 2) 34 return failure(); 35 return success(); 36 } 37 38 // Custom creator invoked from PDL. 39 static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter, 40 PDLResultList &results) { 41 results.push_back(rewriter.create( 42 OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"))); 43 } 44 45 static void customVariadicResultCreate(ArrayRef<PDLValue> args, 46 PatternRewriter &rewriter, 47 PDLResultList &results) { 48 Operation *root = args[0].cast<Operation *>(); 49 results.push_back(root->getOperands()); 50 results.push_back(root->getOperands().getTypes()); 51 } 52 static void customCreateType(ArrayRef<PDLValue> args, PatternRewriter &rewriter, 53 PDLResultList &results) { 54 results.push_back(rewriter.getF32Type()); 55 } 56 57 /// Custom rewriter invoked from PDL. 58 static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter, 59 PDLResultList &results) { 60 Operation *root = args[0].cast<Operation *>(); 61 OperationState successOpState(root->getLoc(), "test.success"); 62 successOpState.addOperands(args[1].cast<Value>()); 63 rewriter.create(successOpState); 64 rewriter.eraseOp(root); 65 } 66 67 namespace { 68 struct TestPDLByteCodePass 69 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { 70 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass) 71 72 StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } 73 StringRef getDescription() const final { 74 return "Test PDL ByteCode functionality"; 75 } 76 void getDependentDialects(DialectRegistry ®istry) const override { 77 // Mark the pdl_interp dialect as a dependent. This is needed, because we 78 // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. 79 registry.insert<pdl_interp::PDLInterpDialect>(); 80 } 81 void runOnOperation() final { 82 ModuleOp module = getOperation(); 83 84 // The test cases are encompassed via two modules, one containing the 85 // patterns and one containing the operations to rewrite. 86 ModuleOp patternModule = module.lookupSymbol<ModuleOp>( 87 StringAttr::get(module->getContext(), "patterns")); 88 ModuleOp irModule = module.lookupSymbol<ModuleOp>( 89 StringAttr::get(module->getContext(), "ir")); 90 if (!patternModule || !irModule) 91 return; 92 93 RewritePatternSet patternList(module->getContext()); 94 95 // Register ahead of time to test when functions are registered without a 96 // pattern. 97 patternList.getPDLPatterns().registerConstraintFunction( 98 "multi_entity_constraint", customMultiEntityConstraint); 99 patternList.getPDLPatterns().registerConstraintFunction( 100 "single_entity_constraint", customSingleEntityConstraint); 101 102 // Process the pattern module. 103 patternModule.getOperation()->remove(); 104 PDLPatternModule pdlPattern(patternModule); 105 106 // Note: This constraint was already registered, but we re-register here to 107 // ensure that duplication registration is allowed (the duplicate mapping 108 // will be ignored). This tests that we support separating the registration 109 // of library functions from the construction of patterns, and also that we 110 // allow multiple patterns to depend on the same library functions (without 111 // asserting/crashing). 112 pdlPattern.registerConstraintFunction("multi_entity_constraint", 113 customMultiEntityConstraint); 114 pdlPattern.registerConstraintFunction("multi_entity_var_constraint", 115 customMultiEntityVariadicConstraint); 116 pdlPattern.registerRewriteFunction("creator", customCreate); 117 pdlPattern.registerRewriteFunction("var_creator", 118 customVariadicResultCreate); 119 pdlPattern.registerRewriteFunction("type_creator", customCreateType); 120 pdlPattern.registerRewriteFunction("rewriter", customRewriter); 121 patternList.add(std::move(pdlPattern)); 122 123 // Invoke the pattern driver with the provided patterns. 124 (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), 125 std::move(patternList)); 126 } 127 }; 128 } // namespace 129 130 namespace mlir { 131 namespace test { 132 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } 133 } // namespace test 134 } // namespace mlir 135