1abfd1a8bSRiver Riddle //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// 2abfd1a8bSRiver Riddle // 3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6abfd1a8bSRiver Riddle // 7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 8abfd1a8bSRiver Riddle 9abfd1a8bSRiver Riddle #include "mlir/Pass/Pass.h" 10abfd1a8bSRiver Riddle #include "mlir/Pass/PassManager.h" 11abfd1a8bSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 12abfd1a8bSRiver Riddle 13abfd1a8bSRiver Riddle using namespace mlir; 14abfd1a8bSRiver Riddle 15abfd1a8bSRiver Riddle /// Custom constraint invoked from PDL. 16abfd1a8bSRiver Riddle static LogicalResult customSingleEntityConstraint(PDLValue value, 17abfd1a8bSRiver Riddle ArrayAttr constantParams, 18abfd1a8bSRiver Riddle PatternRewriter &rewriter) { 19abfd1a8bSRiver Riddle Operation *rootOp = value.cast<Operation *>(); 20abfd1a8bSRiver Riddle return success(rootOp->getName().getStringRef() == "test.op"); 21abfd1a8bSRiver Riddle } 22abfd1a8bSRiver Riddle static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, 23abfd1a8bSRiver Riddle ArrayAttr constantParams, 24abfd1a8bSRiver Riddle PatternRewriter &rewriter) { 25abfd1a8bSRiver Riddle return customSingleEntityConstraint(values[1], constantParams, rewriter); 26abfd1a8bSRiver Riddle } 27*85ab413bSRiver Riddle static LogicalResult 28*85ab413bSRiver Riddle customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values, 29*85ab413bSRiver Riddle ArrayAttr constantParams, 30*85ab413bSRiver Riddle PatternRewriter &rewriter) { 31*85ab413bSRiver Riddle if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) 32*85ab413bSRiver Riddle return failure(); 33*85ab413bSRiver Riddle ValueRange operandValues = values[0].cast<ValueRange>(); 34*85ab413bSRiver Riddle TypeRange typeValues = values[1].cast<TypeRange>(); 35*85ab413bSRiver Riddle if (operandValues.size() != 2 || typeValues.size() != 2) 36*85ab413bSRiver Riddle return failure(); 37*85ab413bSRiver Riddle return success(); 38*85ab413bSRiver Riddle } 39abfd1a8bSRiver Riddle 40abfd1a8bSRiver Riddle // Custom creator invoked from PDL. 4102c4c0d5SRiver Riddle static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, 4202c4c0d5SRiver Riddle PatternRewriter &rewriter, PDLResultList &results) { 4302c4c0d5SRiver Riddle results.push_back(rewriter.createOperation( 4402c4c0d5SRiver Riddle OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"))); 45abfd1a8bSRiver Riddle } 46*85ab413bSRiver Riddle static void customVariadicResultCreate(ArrayRef<PDLValue> args, 47*85ab413bSRiver Riddle ArrayAttr constantParams, 48*85ab413bSRiver Riddle PatternRewriter &rewriter, 49*85ab413bSRiver Riddle PDLResultList &results) { 50*85ab413bSRiver Riddle Operation *root = args[0].cast<Operation *>(); 51*85ab413bSRiver Riddle results.push_back(root->getOperands()); 52*85ab413bSRiver Riddle results.push_back(root->getOperands().getTypes()); 53*85ab413bSRiver Riddle } 54*85ab413bSRiver Riddle static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams, 55*85ab413bSRiver Riddle PatternRewriter &rewriter, 56*85ab413bSRiver Riddle PDLResultList &results) { 57*85ab413bSRiver Riddle results.push_back(rewriter.getF32Type()); 58*85ab413bSRiver Riddle } 59abfd1a8bSRiver Riddle 60abfd1a8bSRiver Riddle /// Custom rewriter invoked from PDL. 6102c4c0d5SRiver Riddle static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams, 6202c4c0d5SRiver Riddle PatternRewriter &rewriter, PDLResultList &results) { 6302c4c0d5SRiver Riddle Operation *root = args[0].cast<Operation *>(); 64abfd1a8bSRiver Riddle OperationState successOpState(root->getLoc(), "test.success"); 6502c4c0d5SRiver Riddle successOpState.addOperands(args[1].cast<Value>()); 66abfd1a8bSRiver Riddle successOpState.addAttribute("constantParams", constantParams); 67abfd1a8bSRiver Riddle rewriter.createOperation(successOpState); 68abfd1a8bSRiver Riddle rewriter.eraseOp(root); 69abfd1a8bSRiver Riddle } 70abfd1a8bSRiver Riddle 71abfd1a8bSRiver Riddle namespace { 72abfd1a8bSRiver Riddle struct TestPDLByteCodePass 73abfd1a8bSRiver Riddle : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { 74abfd1a8bSRiver Riddle void runOnOperation() final { 75abfd1a8bSRiver Riddle ModuleOp module = getOperation(); 76abfd1a8bSRiver Riddle 77abfd1a8bSRiver Riddle // The test cases are encompassed via two modules, one containing the 78abfd1a8bSRiver Riddle // patterns and one containing the operations to rewrite. 79abfd1a8bSRiver Riddle ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns"); 80abfd1a8bSRiver Riddle ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir"); 81abfd1a8bSRiver Riddle if (!patternModule || !irModule) 82abfd1a8bSRiver Riddle return; 83abfd1a8bSRiver Riddle 84abfd1a8bSRiver Riddle // Process the pattern module. 85abfd1a8bSRiver Riddle patternModule.getOperation()->remove(); 86abfd1a8bSRiver Riddle PDLPatternModule pdlPattern(patternModule); 87abfd1a8bSRiver Riddle pdlPattern.registerConstraintFunction("multi_entity_constraint", 88abfd1a8bSRiver Riddle customMultiEntityConstraint); 89abfd1a8bSRiver Riddle pdlPattern.registerConstraintFunction("single_entity_constraint", 90abfd1a8bSRiver Riddle customSingleEntityConstraint); 91*85ab413bSRiver Riddle pdlPattern.registerConstraintFunction("multi_entity_var_constraint", 92*85ab413bSRiver Riddle customMultiEntityVariadicConstraint); 9302c4c0d5SRiver Riddle pdlPattern.registerRewriteFunction("creator", customCreate); 94*85ab413bSRiver Riddle pdlPattern.registerRewriteFunction("var_creator", 95*85ab413bSRiver Riddle customVariadicResultCreate); 96*85ab413bSRiver Riddle pdlPattern.registerRewriteFunction("type_creator", customCreateType); 97abfd1a8bSRiver Riddle pdlPattern.registerRewriteFunction("rewriter", customRewriter); 98abfd1a8bSRiver Riddle 99abfd1a8bSRiver Riddle OwningRewritePatternList patternList(std::move(pdlPattern)); 100abfd1a8bSRiver Riddle 101abfd1a8bSRiver Riddle // Invoke the pattern driver with the provided patterns. 102abfd1a8bSRiver Riddle (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), 103abfd1a8bSRiver Riddle std::move(patternList)); 104abfd1a8bSRiver Riddle } 105abfd1a8bSRiver Riddle }; 106abfd1a8bSRiver Riddle } // end anonymous namespace 107abfd1a8bSRiver Riddle 108abfd1a8bSRiver Riddle namespace mlir { 109abfd1a8bSRiver Riddle namespace test { 110abfd1a8bSRiver Riddle void registerTestPDLByteCodePass() { 111abfd1a8bSRiver Riddle PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass", 112abfd1a8bSRiver Riddle "Test PDL ByteCode functionality"); 113abfd1a8bSRiver Riddle } 114abfd1a8bSRiver Riddle } // namespace test 115abfd1a8bSRiver Riddle } // namespace mlir 116