1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
10 #include "ByteCode.h"
11 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16
17 using namespace mlir;
18
convertPDLToPDLInterp(ModuleOp pdlModule)19 static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
20 // Skip the conversion if the module doesn't contain pdl.
21 if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
22 return success();
23
24 // Simplify the provided PDL module. Note that we can't use the canonicalizer
25 // here because it would create a cyclic dependency.
26 auto simplifyFn = [](Operation *op) {
27 // TODO: Add folding here if ever necessary.
28 if (isOpTriviallyDead(op))
29 op->erase();
30 };
31 pdlModule.getBody()->walk(simplifyFn);
32
33 /// Lower the PDL pattern module to the interpreter dialect.
34 PassManager pdlPipeline(pdlModule.getContext());
35 #ifdef NDEBUG
36 // We don't want to incur the hit of running the verifier when in release
37 // mode.
38 pdlPipeline.enableVerifier(false);
39 #endif
40 pdlPipeline.addPass(createPDLToPDLInterpPass());
41 if (failed(pdlPipeline.run(pdlModule)))
42 return failure();
43
44 // Simplify again after running the lowering pipeline.
45 pdlModule.getBody()->walk(simplifyFn);
46 return success();
47 }
48
49 //===----------------------------------------------------------------------===//
50 // FrozenRewritePatternSet
51 //===----------------------------------------------------------------------===//
52
FrozenRewritePatternSet()53 FrozenRewritePatternSet::FrozenRewritePatternSet()
54 : impl(std::make_shared<Impl>()) {}
55
FrozenRewritePatternSet(RewritePatternSet && patterns,ArrayRef<std::string> disabledPatternLabels,ArrayRef<std::string> enabledPatternLabels)56 FrozenRewritePatternSet::FrozenRewritePatternSet(
57 RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
58 ArrayRef<std::string> enabledPatternLabels)
59 : impl(std::make_shared<Impl>()) {
60 DenseSet<StringRef> disabledPatterns, enabledPatterns;
61 disabledPatterns.insert(disabledPatternLabels.begin(),
62 disabledPatternLabels.end());
63 enabledPatterns.insert(enabledPatternLabels.begin(),
64 enabledPatternLabels.end());
65
66 // Functor used to walk all of the operations registered in the context. This
67 // is useful for patterns that get applied to multiple operations, such as
68 // interface and trait based patterns.
69 std::vector<RegisteredOperationName> opInfos;
70 auto addToOpsWhen =
71 [&](std::unique_ptr<RewritePattern> &pattern,
72 function_ref<bool(RegisteredOperationName)> callbackFn) {
73 if (opInfos.empty())
74 opInfos = pattern->getContext()->getRegisteredOperations();
75 for (RegisteredOperationName info : opInfos)
76 if (callbackFn(info))
77 impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
78 impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
79 };
80
81 for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
82 // Don't add patterns that haven't been enabled by the user.
83 if (!enabledPatterns.empty()) {
84 auto isEnabledFn = [&](StringRef label) {
85 return enabledPatterns.count(label);
86 };
87 if (!isEnabledFn(pat->getDebugName()) &&
88 llvm::none_of(pat->getDebugLabels(), isEnabledFn))
89 continue;
90 }
91 // Don't add patterns that have been disabled by the user.
92 if (!disabledPatterns.empty()) {
93 auto isDisabledFn = [&](StringRef label) {
94 return disabledPatterns.count(label);
95 };
96 if (isDisabledFn(pat->getDebugName()) ||
97 llvm::any_of(pat->getDebugLabels(), isDisabledFn))
98 continue;
99 }
100
101 if (Optional<OperationName> rootName = pat->getRootKind()) {
102 impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
103 impl->nativeOpSpecificPatternList.push_back(std::move(pat));
104 continue;
105 }
106 if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
107 addToOpsWhen(pat, [&](RegisteredOperationName info) {
108 return info.hasInterface(*interfaceID);
109 });
110 continue;
111 }
112 if (Optional<TypeID> traitID = pat->getRootTraitID()) {
113 addToOpsWhen(pat, [&](RegisteredOperationName info) {
114 return info.hasTrait(*traitID);
115 });
116 continue;
117 }
118 impl->nativeAnyOpPatterns.push_back(std::move(pat));
119 }
120
121 // Generate the bytecode for the PDL patterns if any were provided.
122 PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
123 ModuleOp pdlModule = pdlPatterns.getModule();
124 if (!pdlModule)
125 return;
126 if (failed(convertPDLToPDLInterp(pdlModule)))
127 llvm::report_fatal_error(
128 "failed to lower PDL pattern module to the PDL Interpreter");
129
130 // Generate the pdl bytecode.
131 impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
132 pdlModule, pdlPatterns.takeConstraintFunctions(),
133 pdlPatterns.takeRewriteFunctions());
134 }
135
/// Destroys the frozen pattern set.
///
/// Defined out of line in this file. Presumably this is so that members owned
/// through `impl` (e.g. the PDL bytecode, whose type comes from the private
/// "ByteCode.h") are destroyed where their definitions are complete.
/// NOTE(review): confirm against the header.
FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;
137