1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
10 #include "ByteCode.h"
11 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
convertPDLToPDLInterp(ModuleOp pdlModule)19 static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
20   // Skip the conversion if the module doesn't contain pdl.
21   if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
22     return success();
23 
24   // Simplify the provided PDL module. Note that we can't use the canonicalizer
25   // here because it would create a cyclic dependency.
26   auto simplifyFn = [](Operation *op) {
27     // TODO: Add folding here if ever necessary.
28     if (isOpTriviallyDead(op))
29       op->erase();
30   };
31   pdlModule.getBody()->walk(simplifyFn);
32 
33   /// Lower the PDL pattern module to the interpreter dialect.
34   PassManager pdlPipeline(pdlModule.getContext());
35 #ifdef NDEBUG
36   // We don't want to incur the hit of running the verifier when in release
37   // mode.
38   pdlPipeline.enableVerifier(false);
39 #endif
40   pdlPipeline.addPass(createPDLToPDLInterpPass());
41   if (failed(pdlPipeline.run(pdlModule)))
42     return failure();
43 
44   // Simplify again after running the lowering pipeline.
45   pdlModule.getBody()->walk(simplifyFn);
46   return success();
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // FrozenRewritePatternSet
51 //===----------------------------------------------------------------------===//
52 
/// Construct a frozen pattern set whose shared implementation state is a
/// freshly allocated, default-constructed `Impl`; no patterns are registered.
FrozenRewritePatternSet::FrozenRewritePatternSet()
    : impl(std::make_shared<Impl>()) {}
55 
FrozenRewritePatternSet(RewritePatternSet && patterns,ArrayRef<std::string> disabledPatternLabels,ArrayRef<std::string> enabledPatternLabels)56 FrozenRewritePatternSet::FrozenRewritePatternSet(
57     RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
58     ArrayRef<std::string> enabledPatternLabels)
59     : impl(std::make_shared<Impl>()) {
60   DenseSet<StringRef> disabledPatterns, enabledPatterns;
61   disabledPatterns.insert(disabledPatternLabels.begin(),
62                           disabledPatternLabels.end());
63   enabledPatterns.insert(enabledPatternLabels.begin(),
64                          enabledPatternLabels.end());
65 
66   // Functor used to walk all of the operations registered in the context. This
67   // is useful for patterns that get applied to multiple operations, such as
68   // interface and trait based patterns.
69   std::vector<RegisteredOperationName> opInfos;
70   auto addToOpsWhen =
71       [&](std::unique_ptr<RewritePattern> &pattern,
72           function_ref<bool(RegisteredOperationName)> callbackFn) {
73         if (opInfos.empty())
74           opInfos = pattern->getContext()->getRegisteredOperations();
75         for (RegisteredOperationName info : opInfos)
76           if (callbackFn(info))
77             impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
78         impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
79       };
80 
81   for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
82     // Don't add patterns that haven't been enabled by the user.
83     if (!enabledPatterns.empty()) {
84       auto isEnabledFn = [&](StringRef label) {
85         return enabledPatterns.count(label);
86       };
87       if (!isEnabledFn(pat->getDebugName()) &&
88           llvm::none_of(pat->getDebugLabels(), isEnabledFn))
89         continue;
90     }
91     // Don't add patterns that have been disabled by the user.
92     if (!disabledPatterns.empty()) {
93       auto isDisabledFn = [&](StringRef label) {
94         return disabledPatterns.count(label);
95       };
96       if (isDisabledFn(pat->getDebugName()) ||
97           llvm::any_of(pat->getDebugLabels(), isDisabledFn))
98         continue;
99     }
100 
101     if (Optional<OperationName> rootName = pat->getRootKind()) {
102       impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
103       impl->nativeOpSpecificPatternList.push_back(std::move(pat));
104       continue;
105     }
106     if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
107       addToOpsWhen(pat, [&](RegisteredOperationName info) {
108         return info.hasInterface(*interfaceID);
109       });
110       continue;
111     }
112     if (Optional<TypeID> traitID = pat->getRootTraitID()) {
113       addToOpsWhen(pat, [&](RegisteredOperationName info) {
114         return info.hasTrait(*traitID);
115       });
116       continue;
117     }
118     impl->nativeAnyOpPatterns.push_back(std::move(pat));
119   }
120 
121   // Generate the bytecode for the PDL patterns if any were provided.
122   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
123   ModuleOp pdlModule = pdlPatterns.getModule();
124   if (!pdlModule)
125     return;
126   if (failed(convertPDLToPDLInterp(pdlModule)))
127     llvm::report_fatal_error(
128         "failed to lower PDL pattern module to the PDL Interpreter");
129 
130   // Generate the pdl bytecode.
131   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
132       pdlModule, pdlPatterns.takeConstraintFunctions(),
133       pdlPatterns.takeRewriteFunctions());
134 }
135 
/// Defaulted destructor, defined out of line in this translation unit.
/// NOTE(review): presumably kept here so that types owned through `Impl`
/// (e.g. `detail::PDLByteCode` from "ByteCode.h") only need to be complete in
/// this file -- confirm against the header.
FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;
137