1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Pass/Pass.h" 10 #include "mlir/Pass/PassManager.h" 11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 12 13 using namespace mlir; 14 15 /// Custom constraint invoked from PDL. 16 static LogicalResult customSingleEntityConstraint(PDLValue value, 17 ArrayAttr constantParams, 18 PatternRewriter &rewriter) { 19 Operation *rootOp = value.cast<Operation *>(); 20 return success(rootOp->getName().getStringRef() == "test.op"); 21 } 22 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, 23 ArrayAttr constantParams, 24 PatternRewriter &rewriter) { 25 return customSingleEntityConstraint(values[1], constantParams, rewriter); 26 } 27 28 // Custom creator invoked from PDL. 29 static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, 30 PatternRewriter &rewriter) { 31 return rewriter.createOperation( 32 OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")); 33 } 34 35 /// Custom rewriter invoked from PDL. 36 static void customRewriter(Operation *root, ArrayRef<PDLValue> args, 37 ArrayAttr constantParams, 38 PatternRewriter &rewriter) { 39 OperationState successOpState(root->getLoc(), "test.success"); 40 successOpState.addOperands(args[0].cast<Value>()); 41 successOpState.addAttribute("constantParams", constantParams); 42 rewriter.createOperation(successOpState); 43 rewriter.eraseOp(root); 44 } 45 46 namespace { 47 struct TestPDLByteCodePass 48 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { 49 void runOnOperation() final { 50 ModuleOp module = getOperation(); 51 52 // The test cases are encompassed via two modules, one containing the 53 // patterns and one containing the operations to rewrite. 54 ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns"); 55 ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir"); 56 if (!patternModule || !irModule) 57 return; 58 59 // Process the pattern module. 60 patternModule.getOperation()->remove(); 61 PDLPatternModule pdlPattern(patternModule); 62 pdlPattern.registerConstraintFunction("multi_entity_constraint", 63 customMultiEntityConstraint); 64 pdlPattern.registerConstraintFunction("single_entity_constraint", 65 customSingleEntityConstraint); 66 pdlPattern.registerCreateFunction("creator", customCreate); 67 pdlPattern.registerRewriteFunction("rewriter", customRewriter); 68 69 OwningRewritePatternList patternList(std::move(pdlPattern)); 70 71 // Invoke the pattern driver with the provided patterns. 72 (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), 73 std::move(patternList)); 74 } 75 }; 76 } // end anonymous namespace 77 78 namespace mlir { 79 namespace test { 80 void registerTestPDLByteCodePass() { 81 PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass", 82 "Test PDL ByteCode functionality"); 83 } 84 } // namespace test 85 } // namespace mlir 86