1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
10 #include "ByteCode.h"
11 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
19 static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
20   // Skip the conversion if the module doesn't contain pdl.
21   if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
22     return success();
23 
24   // Simplify the provided PDL module. Note that we can't use the canonicalizer
25   // here because it would create a cyclic dependency.
26   auto simplifyFn = [](Operation *op) {
27     // TODO: Add folding here if ever necessary.
28     if (isOpTriviallyDead(op))
29       op->erase();
30   };
31   pdlModule.getBody()->walk(simplifyFn);
32 
33   /// Lower the PDL pattern module to the interpreter dialect.
34   PassManager pdlPipeline(pdlModule.getContext());
35 #ifdef NDEBUG
36   // We don't want to incur the hit of running the verifier when in release
37   // mode.
38   pdlPipeline.enableVerifier(false);
39 #endif
40   pdlPipeline.addPass(createPDLToPDLInterpPass());
41   if (failed(pdlPipeline.run(pdlModule)))
42     return failure();
43 
44   // Simplify again after running the lowering pipeline.
45   pdlModule.getBody()->walk(simplifyFn);
46   return success();
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // FrozenRewritePatternSet
51 //===----------------------------------------------------------------------===//
52 
53 FrozenRewritePatternSet::FrozenRewritePatternSet()
54     : impl(std::make_shared<Impl>()) {}
55 
56 FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns)
57     : impl(std::make_shared<Impl>()) {
58   // Functor used to walk all of the operations registered in the context. This
59   // is useful for patterns that get applied to multiple operations, such as
60   // interface and trait based patterns.
61   std::vector<AbstractOperation *> abstractOps;
62   auto addToOpsWhen = [&](std::unique_ptr<RewritePattern> &pattern,
63                           function_ref<bool(AbstractOperation *)> callbackFn) {
64     if (abstractOps.empty())
65       abstractOps = pattern->getContext()->getRegisteredOperations();
66     for (AbstractOperation *absOp : abstractOps) {
67       if (callbackFn(absOp)) {
68         OperationName opName(absOp);
69         impl->nativeOpSpecificPatternMap[opName].push_back(pattern.get());
70       }
71     }
72     impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
73   };
74 
75   for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
76     if (Optional<OperationName> rootName = pat->getRootKind()) {
77       impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
78       impl->nativeOpSpecificPatternList.push_back(std::move(pat));
79       continue;
80     }
81     if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
82       addToOpsWhen(pat, [&](AbstractOperation *absOp) {
83         return absOp->hasInterface(*interfaceID);
84       });
85       continue;
86     }
87     if (Optional<TypeID> traitID = pat->getRootTraitID()) {
88       addToOpsWhen(pat, [&](AbstractOperation *absOp) {
89         return absOp->hasTrait(*traitID);
90       });
91       continue;
92     }
93     impl->nativeAnyOpPatterns.push_back(std::move(pat));
94   }
95 
96   // Generate the bytecode for the PDL patterns if any were provided.
97   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
98   ModuleOp pdlModule = pdlPatterns.getModule();
99   if (!pdlModule)
100     return;
101   if (failed(convertPDLToPDLInterp(pdlModule)))
102     llvm::report_fatal_error(
103         "failed to lower PDL pattern module to the PDL Interpreter");
104 
105   // Generate the pdl bytecode.
106   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
107       pdlModule, pdlPatterns.takeConstraintFunctions(),
108       pdlPatterns.takeRewriteFunctions());
109 }
110 
111 FrozenRewritePatternSet::~FrozenRewritePatternSet() {}
112