//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle 
9abfd1a8bSRiver Riddle #include "mlir/Pass/Pass.h"
10abfd1a8bSRiver Riddle #include "mlir/Pass/PassManager.h"
11abfd1a8bSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
12abfd1a8bSRiver Riddle 
13abfd1a8bSRiver Riddle using namespace mlir;
14abfd1a8bSRiver Riddle 
15abfd1a8bSRiver Riddle /// Custom constraint invoked from PDL.
16abfd1a8bSRiver Riddle static LogicalResult customSingleEntityConstraint(PDLValue value,
17abfd1a8bSRiver Riddle                                                   ArrayAttr constantParams,
18abfd1a8bSRiver Riddle                                                   PatternRewriter &rewriter) {
19abfd1a8bSRiver Riddle   Operation *rootOp = value.cast<Operation *>();
20abfd1a8bSRiver Riddle   return success(rootOp->getName().getStringRef() == "test.op");
21abfd1a8bSRiver Riddle }
22abfd1a8bSRiver Riddle static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
23abfd1a8bSRiver Riddle                                                  ArrayAttr constantParams,
24abfd1a8bSRiver Riddle                                                  PatternRewriter &rewriter) {
25abfd1a8bSRiver Riddle   return customSingleEntityConstraint(values[1], constantParams, rewriter);
26abfd1a8bSRiver Riddle }
2785ab413bSRiver Riddle static LogicalResult
2885ab413bSRiver Riddle customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
2985ab413bSRiver Riddle                                     ArrayAttr constantParams,
3085ab413bSRiver Riddle                                     PatternRewriter &rewriter) {
3185ab413bSRiver Riddle   if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
3285ab413bSRiver Riddle     return failure();
3385ab413bSRiver Riddle   ValueRange operandValues = values[0].cast<ValueRange>();
3485ab413bSRiver Riddle   TypeRange typeValues = values[1].cast<TypeRange>();
3585ab413bSRiver Riddle   if (operandValues.size() != 2 || typeValues.size() != 2)
3685ab413bSRiver Riddle     return failure();
3785ab413bSRiver Riddle   return success();
3885ab413bSRiver Riddle }
39abfd1a8bSRiver Riddle 
40abfd1a8bSRiver Riddle // Custom creator invoked from PDL.
4102c4c0d5SRiver Riddle static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
4202c4c0d5SRiver Riddle                          PatternRewriter &rewriter, PDLResultList &results) {
4302c4c0d5SRiver Riddle   results.push_back(rewriter.createOperation(
4402c4c0d5SRiver Riddle       OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
45abfd1a8bSRiver Riddle }
4685ab413bSRiver Riddle static void customVariadicResultCreate(ArrayRef<PDLValue> args,
4785ab413bSRiver Riddle                                        ArrayAttr constantParams,
4885ab413bSRiver Riddle                                        PatternRewriter &rewriter,
4985ab413bSRiver Riddle                                        PDLResultList &results) {
5085ab413bSRiver Riddle   Operation *root = args[0].cast<Operation *>();
5185ab413bSRiver Riddle   results.push_back(root->getOperands());
5285ab413bSRiver Riddle   results.push_back(root->getOperands().getTypes());
5385ab413bSRiver Riddle }
5485ab413bSRiver Riddle static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams,
5585ab413bSRiver Riddle                              PatternRewriter &rewriter,
5685ab413bSRiver Riddle                              PDLResultList &results) {
5785ab413bSRiver Riddle   results.push_back(rewriter.getF32Type());
5885ab413bSRiver Riddle }
59abfd1a8bSRiver Riddle 
60abfd1a8bSRiver Riddle /// Custom rewriter invoked from PDL.
6102c4c0d5SRiver Riddle static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
6202c4c0d5SRiver Riddle                            PatternRewriter &rewriter, PDLResultList &results) {
6302c4c0d5SRiver Riddle   Operation *root = args[0].cast<Operation *>();
64abfd1a8bSRiver Riddle   OperationState successOpState(root->getLoc(), "test.success");
6502c4c0d5SRiver Riddle   successOpState.addOperands(args[1].cast<Value>());
66abfd1a8bSRiver Riddle   successOpState.addAttribute("constantParams", constantParams);
67abfd1a8bSRiver Riddle   rewriter.createOperation(successOpState);
68abfd1a8bSRiver Riddle   rewriter.eraseOp(root);
69abfd1a8bSRiver Riddle }
70abfd1a8bSRiver Riddle 
71abfd1a8bSRiver Riddle namespace {
72abfd1a8bSRiver Riddle struct TestPDLByteCodePass
73abfd1a8bSRiver Riddle     : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
74b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
75b5e22e6dSMehdi Amini   StringRef getDescription() const final {
76b5e22e6dSMehdi Amini     return "Test PDL ByteCode functionality";
77b5e22e6dSMehdi Amini   }
78abfd1a8bSRiver Riddle   void runOnOperation() final {
79abfd1a8bSRiver Riddle     ModuleOp module = getOperation();
80abfd1a8bSRiver Riddle 
81abfd1a8bSRiver Riddle     // The test cases are encompassed via two modules, one containing the
82abfd1a8bSRiver Riddle     // patterns and one containing the operations to rewrite.
83*41d4aa7dSChris Lattner     ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
84*41d4aa7dSChris Lattner         StringAttr::get(module->getContext(), "patterns"));
85*41d4aa7dSChris Lattner     ModuleOp irModule = module.lookupSymbol<ModuleOp>(
86*41d4aa7dSChris Lattner         StringAttr::get(module->getContext(), "ir"));
87abfd1a8bSRiver Riddle     if (!patternModule || !irModule)
88abfd1a8bSRiver Riddle       return;
89abfd1a8bSRiver Riddle 
90abfd1a8bSRiver Riddle     // Process the pattern module.
91abfd1a8bSRiver Riddle     patternModule.getOperation()->remove();
92abfd1a8bSRiver Riddle     PDLPatternModule pdlPattern(patternModule);
93abfd1a8bSRiver Riddle     pdlPattern.registerConstraintFunction("multi_entity_constraint",
94abfd1a8bSRiver Riddle                                           customMultiEntityConstraint);
95abfd1a8bSRiver Riddle     pdlPattern.registerConstraintFunction("single_entity_constraint",
96abfd1a8bSRiver Riddle                                           customSingleEntityConstraint);
9785ab413bSRiver Riddle     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
9885ab413bSRiver Riddle                                           customMultiEntityVariadicConstraint);
9902c4c0d5SRiver Riddle     pdlPattern.registerRewriteFunction("creator", customCreate);
10085ab413bSRiver Riddle     pdlPattern.registerRewriteFunction("var_creator",
10185ab413bSRiver Riddle                                        customVariadicResultCreate);
10285ab413bSRiver Riddle     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
103abfd1a8bSRiver Riddle     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
104abfd1a8bSRiver Riddle 
105dc4e913bSChris Lattner     RewritePatternSet patternList(std::move(pdlPattern));
106abfd1a8bSRiver Riddle 
107abfd1a8bSRiver Riddle     // Invoke the pattern driver with the provided patterns.
108abfd1a8bSRiver Riddle     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
109abfd1a8bSRiver Riddle                                        std::move(patternList));
110abfd1a8bSRiver Riddle   }
111abfd1a8bSRiver Riddle };
112abfd1a8bSRiver Riddle } // end anonymous namespace
113abfd1a8bSRiver Riddle 
114abfd1a8bSRiver Riddle namespace mlir {
115abfd1a8bSRiver Riddle namespace test {
116b5e22e6dSMehdi Amini void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
117abfd1a8bSRiver Riddle } // namespace test
118abfd1a8bSRiver Riddle } // namespace mlir
119