//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; /// Custom constraint invoked from PDL. static LogicalResult customSingleEntityConstraint(PDLValue value, ArrayAttr constantParams, PatternRewriter &rewriter) { Operation *rootOp = value.cast(); return success(rootOp->getName().getStringRef() == "test.op"); } static LogicalResult customMultiEntityConstraint(ArrayRef values, ArrayAttr constantParams, PatternRewriter &rewriter) { return customSingleEntityConstraint(values[1], constantParams, rewriter); } static LogicalResult customMultiEntityVariadicConstraint(ArrayRef values, ArrayAttr constantParams, PatternRewriter &rewriter) { if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) return failure(); ValueRange operandValues = values[0].cast(); TypeRange typeValues = values[1].cast(); if (operandValues.size() != 2 || typeValues.size() != 2) return failure(); return success(); } // Custom creator invoked from PDL. static void customCreate(ArrayRef args, ArrayAttr constantParams, PatternRewriter &rewriter, PDLResultList &results) { results.push_back(rewriter.createOperation( OperationState(args[0].cast()->getLoc(), "test.success"))); } static void customVariadicResultCreate(ArrayRef args, ArrayAttr constantParams, PatternRewriter &rewriter, PDLResultList &results) { Operation *root = args[0].cast(); results.push_back(root->getOperands()); results.push_back(root->getOperands().getTypes()); } static void customCreateType(ArrayRef args, ArrayAttr constantParams, PatternRewriter &rewriter, PDLResultList &results) { results.push_back(rewriter.getF32Type()); } /// Custom rewriter invoked from PDL. static void customRewriter(ArrayRef args, ArrayAttr constantParams, PatternRewriter &rewriter, PDLResultList &results) { Operation *root = args[0].cast(); OperationState successOpState(root->getLoc(), "test.success"); successOpState.addOperands(args[1].cast()); successOpState.addAttribute("constantParams", constantParams); rewriter.createOperation(successOpState); rewriter.eraseOp(root); } namespace { struct TestPDLByteCodePass : public PassWrapper> { StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } StringRef getDescription() const final { return "Test PDL ByteCode functionality"; } void runOnOperation() final { ModuleOp module = getOperation(); // The test cases are encompassed via two modules, one containing the // patterns and one containing the operations to rewrite. ModuleOp patternModule = module.lookupSymbol( StringAttr::get(module->getContext(), "patterns")); ModuleOp irModule = module.lookupSymbol( StringAttr::get(module->getContext(), "ir")); if (!patternModule || !irModule) return; // Process the pattern module. patternModule.getOperation()->remove(); PDLPatternModule pdlPattern(patternModule); pdlPattern.registerConstraintFunction("multi_entity_constraint", customMultiEntityConstraint); pdlPattern.registerConstraintFunction("single_entity_constraint", customSingleEntityConstraint); pdlPattern.registerConstraintFunction("multi_entity_var_constraint", customMultiEntityVariadicConstraint); pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("var_creator", customVariadicResultCreate); pdlPattern.registerRewriteFunction("type_creator", customCreateType); pdlPattern.registerRewriteFunction("rewriter", customRewriter); RewritePatternSet patternList(std::move(pdlPattern)); // Invoke the pattern driver with the provided patterns. (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), std::move(patternList)); } }; } // end anonymous namespace namespace mlir { namespace test { void registerTestPDLByteCodePass() { PassRegistration(); } } // namespace test } // namespace mlir