1 //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Pass/Pass.h" 10 #include "mlir/Pass/PassManager.h" 11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 12 13 using namespace mlir; 14 15 /// Custom constraint invoked from PDL. 16 static LogicalResult customSingleEntityConstraint(PDLValue value, 17 ArrayAttr constantParams, 18 PatternRewriter &rewriter) { 19 Operation *rootOp = value.cast<Operation *>(); 20 return success(rootOp->getName().getStringRef() == "test.op"); 21 } 22 static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, 23 ArrayAttr constantParams, 24 PatternRewriter &rewriter) { 25 return customSingleEntityConstraint(values[1], constantParams, rewriter); 26 } 27 static LogicalResult 28 customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values, 29 ArrayAttr constantParams, 30 PatternRewriter &rewriter) { 31 if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) 32 return failure(); 33 ValueRange operandValues = values[0].cast<ValueRange>(); 34 TypeRange typeValues = values[1].cast<TypeRange>(); 35 if (operandValues.size() != 2 || typeValues.size() != 2) 36 return failure(); 37 return success(); 38 } 39 40 // Custom creator invoked from PDL. 41 static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, 42 PatternRewriter &rewriter, PDLResultList &results) { 43 results.push_back(rewriter.createOperation( 44 OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"))); 45 } 46 static void customVariadicResultCreate(ArrayRef<PDLValue> args, 47 ArrayAttr constantParams, 48 PatternRewriter &rewriter, 49 PDLResultList &results) { 50 Operation *root = args[0].cast<Operation *>(); 51 results.push_back(root->getOperands()); 52 results.push_back(root->getOperands().getTypes()); 53 } 54 static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams, 55 PatternRewriter &rewriter, 56 PDLResultList &results) { 57 results.push_back(rewriter.getF32Type()); 58 } 59 60 /// Custom rewriter invoked from PDL. 61 static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams, 62 PatternRewriter &rewriter, PDLResultList &results) { 63 Operation *root = args[0].cast<Operation *>(); 64 OperationState successOpState(root->getLoc(), "test.success"); 65 successOpState.addOperands(args[1].cast<Value>()); 66 successOpState.addAttribute("constantParams", constantParams); 67 rewriter.createOperation(successOpState); 68 rewriter.eraseOp(root); 69 } 70 71 namespace { 72 struct TestPDLByteCodePass 73 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { 74 StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } 75 StringRef getDescription() const final { 76 return "Test PDL ByteCode functionality"; 77 } 78 void runOnOperation() final { 79 ModuleOp module = getOperation(); 80 81 // The test cases are encompassed via two modules, one containing the 82 // patterns and one containing the operations to rewrite. 83 ModuleOp patternModule = module.lookupSymbol<ModuleOp>( 84 StringAttr::get(module->getContext(), "patterns")); 85 ModuleOp irModule = module.lookupSymbol<ModuleOp>( 86 StringAttr::get(module->getContext(), "ir")); 87 if (!patternModule || !irModule) 88 return; 89 90 // Process the pattern module. 91 patternModule.getOperation()->remove(); 92 PDLPatternModule pdlPattern(patternModule); 93 pdlPattern.registerConstraintFunction("multi_entity_constraint", 94 customMultiEntityConstraint); 95 pdlPattern.registerConstraintFunction("single_entity_constraint", 96 customSingleEntityConstraint); 97 pdlPattern.registerConstraintFunction("multi_entity_var_constraint", 98 customMultiEntityVariadicConstraint); 99 pdlPattern.registerRewriteFunction("creator", customCreate); 100 pdlPattern.registerRewriteFunction("var_creator", 101 customVariadicResultCreate); 102 pdlPattern.registerRewriteFunction("type_creator", customCreateType); 103 pdlPattern.registerRewriteFunction("rewriter", customRewriter); 104 105 RewritePatternSet patternList(std::move(pdlPattern)); 106 107 // Invoke the pattern driver with the provided patterns. 108 (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), 109 std::move(patternList)); 110 } 111 }; 112 } // namespace 113 114 namespace mlir { 115 namespace test { 116 void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } 117 } // namespace test 118 } // namespace mlir 119