179d7f618SChris Lattner //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
279d7f618SChris Lattner //
379d7f618SChris Lattner // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
479d7f618SChris Lattner // See https://llvm.org/LICENSE.txt for license information.
579d7f618SChris Lattner // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
679d7f618SChris Lattner //
779d7f618SChris Lattner //===----------------------------------------------------------------------===//
879d7f618SChris Lattner 
979d7f618SChris Lattner #include "mlir/Rewrite/FrozenRewritePatternSet.h"
1079d7f618SChris Lattner #include "ByteCode.h"
1179d7f618SChris Lattner #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
1279d7f618SChris Lattner #include "mlir/Dialect/PDL/IR/PDLOps.h"
1379d7f618SChris Lattner #include "mlir/Interfaces/SideEffectInterfaces.h"
1479d7f618SChris Lattner #include "mlir/Pass/Pass.h"
1579d7f618SChris Lattner #include "mlir/Pass/PassManager.h"
1679d7f618SChris Lattner 
1779d7f618SChris Lattner using namespace mlir;
1879d7f618SChris Lattner 
convertPDLToPDLInterp(ModuleOp pdlModule)1979d7f618SChris Lattner static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
2079d7f618SChris Lattner   // Skip the conversion if the module doesn't contain pdl.
2179d7f618SChris Lattner   if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
2279d7f618SChris Lattner     return success();
2379d7f618SChris Lattner 
2479d7f618SChris Lattner   // Simplify the provided PDL module. Note that we can't use the canonicalizer
2579d7f618SChris Lattner   // here because it would create a cyclic dependency.
2679d7f618SChris Lattner   auto simplifyFn = [](Operation *op) {
2779d7f618SChris Lattner     // TODO: Add folding here if ever necessary.
2879d7f618SChris Lattner     if (isOpTriviallyDead(op))
2979d7f618SChris Lattner       op->erase();
3079d7f618SChris Lattner   };
3179d7f618SChris Lattner   pdlModule.getBody()->walk(simplifyFn);
3279d7f618SChris Lattner 
3379d7f618SChris Lattner   /// Lower the PDL pattern module to the interpreter dialect.
3479d7f618SChris Lattner   PassManager pdlPipeline(pdlModule.getContext());
3579d7f618SChris Lattner #ifdef NDEBUG
3679d7f618SChris Lattner   // We don't want to incur the hit of running the verifier when in release
3779d7f618SChris Lattner   // mode.
3879d7f618SChris Lattner   pdlPipeline.enableVerifier(false);
3979d7f618SChris Lattner #endif
4079d7f618SChris Lattner   pdlPipeline.addPass(createPDLToPDLInterpPass());
4179d7f618SChris Lattner   if (failed(pdlPipeline.run(pdlModule)))
4279d7f618SChris Lattner     return failure();
4379d7f618SChris Lattner 
4479d7f618SChris Lattner   // Simplify again after running the lowering pipeline.
4579d7f618SChris Lattner   pdlModule.getBody()->walk(simplifyFn);
4679d7f618SChris Lattner   return success();
4779d7f618SChris Lattner }
4879d7f618SChris Lattner 
4979d7f618SChris Lattner //===----------------------------------------------------------------------===//
5079d7f618SChris Lattner // FrozenRewritePatternSet
5179d7f618SChris Lattner //===----------------------------------------------------------------------===//
5279d7f618SChris Lattner 
FrozenRewritePatternSet()5379d7f618SChris Lattner FrozenRewritePatternSet::FrozenRewritePatternSet()
5479d7f618SChris Lattner     : impl(std::make_shared<Impl>()) {}
5579d7f618SChris Lattner 
FrozenRewritePatternSet(RewritePatternSet && patterns,ArrayRef<std::string> disabledPatternLabels,ArrayRef<std::string> enabledPatternLabels)560289a269SRiver Riddle FrozenRewritePatternSet::FrozenRewritePatternSet(
570289a269SRiver Riddle     RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
580289a269SRiver Riddle     ArrayRef<std::string> enabledPatternLabels)
5979d7f618SChris Lattner     : impl(std::make_shared<Impl>()) {
600289a269SRiver Riddle   DenseSet<StringRef> disabledPatterns, enabledPatterns;
610289a269SRiver Riddle   disabledPatterns.insert(disabledPatternLabels.begin(),
620289a269SRiver Riddle                           disabledPatternLabels.end());
630289a269SRiver Riddle   enabledPatterns.insert(enabledPatternLabels.begin(),
640289a269SRiver Riddle                          enabledPatternLabels.end());
650289a269SRiver Riddle 
6676f3c2f3SRiver Riddle   // Functor used to walk all of the operations registered in the context. This
6776f3c2f3SRiver Riddle   // is useful for patterns that get applied to multiple operations, such as
6876f3c2f3SRiver Riddle   // interface and trait based patterns.
69edc6c0ecSRiver Riddle   std::vector<RegisteredOperationName> opInfos;
70edc6c0ecSRiver Riddle   auto addToOpsWhen =
71edc6c0ecSRiver Riddle       [&](std::unique_ptr<RewritePattern> &pattern,
72edc6c0ecSRiver Riddle           function_ref<bool(RegisteredOperationName)> callbackFn) {
73edc6c0ecSRiver Riddle         if (opInfos.empty())
74edc6c0ecSRiver Riddle           opInfos = pattern->getContext()->getRegisteredOperations();
75edc6c0ecSRiver Riddle         for (RegisteredOperationName info : opInfos)
76edc6c0ecSRiver Riddle           if (callbackFn(info))
77edc6c0ecSRiver Riddle             impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
7876f3c2f3SRiver Riddle         impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
7976f3c2f3SRiver Riddle       };
8076f3c2f3SRiver Riddle 
8176f3c2f3SRiver Riddle   for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
820289a269SRiver Riddle     // Don't add patterns that haven't been enabled by the user.
830289a269SRiver Riddle     if (!enabledPatterns.empty()) {
840289a269SRiver Riddle       auto isEnabledFn = [&](StringRef label) {
850289a269SRiver Riddle         return enabledPatterns.count(label);
860289a269SRiver Riddle       };
870289a269SRiver Riddle       if (!isEnabledFn(pat->getDebugName()) &&
880289a269SRiver Riddle           llvm::none_of(pat->getDebugLabels(), isEnabledFn))
890289a269SRiver Riddle         continue;
900289a269SRiver Riddle     }
910289a269SRiver Riddle     // Don't add patterns that have been disabled by the user.
920289a269SRiver Riddle     if (!disabledPatterns.empty()) {
930289a269SRiver Riddle       auto isDisabledFn = [&](StringRef label) {
940289a269SRiver Riddle         return disabledPatterns.count(label);
950289a269SRiver Riddle       };
960289a269SRiver Riddle       if (isDisabledFn(pat->getDebugName()) ||
970289a269SRiver Riddle           llvm::any_of(pat->getDebugLabels(), isDisabledFn))
980289a269SRiver Riddle         continue;
990289a269SRiver Riddle     }
1000289a269SRiver Riddle 
10176f3c2f3SRiver Riddle     if (Optional<OperationName> rootName = pat->getRootKind()) {
10276f3c2f3SRiver Riddle       impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
10376f3c2f3SRiver Riddle       impl->nativeOpSpecificPatternList.push_back(std::move(pat));
10476f3c2f3SRiver Riddle       continue;
10576f3c2f3SRiver Riddle     }
10676f3c2f3SRiver Riddle     if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
107edc6c0ecSRiver Riddle       addToOpsWhen(pat, [&](RegisteredOperationName info) {
108edc6c0ecSRiver Riddle         return info.hasInterface(*interfaceID);
10976f3c2f3SRiver Riddle       });
11076f3c2f3SRiver Riddle       continue;
11176f3c2f3SRiver Riddle     }
11276f3c2f3SRiver Riddle     if (Optional<TypeID> traitID = pat->getRootTraitID()) {
113edc6c0ecSRiver Riddle       addToOpsWhen(pat, [&](RegisteredOperationName info) {
114edc6c0ecSRiver Riddle         return info.hasTrait(*traitID);
11576f3c2f3SRiver Riddle       });
11676f3c2f3SRiver Riddle       continue;
11776f3c2f3SRiver Riddle     }
11876f3c2f3SRiver Riddle     impl->nativeAnyOpPatterns.push_back(std::move(pat));
11976f3c2f3SRiver Riddle   }
12079d7f618SChris Lattner 
12179d7f618SChris Lattner   // Generate the bytecode for the PDL patterns if any were provided.
12279d7f618SChris Lattner   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
12379d7f618SChris Lattner   ModuleOp pdlModule = pdlPatterns.getModule();
12479d7f618SChris Lattner   if (!pdlModule)
12579d7f618SChris Lattner     return;
12679d7f618SChris Lattner   if (failed(convertPDLToPDLInterp(pdlModule)))
12779d7f618SChris Lattner     llvm::report_fatal_error(
12879d7f618SChris Lattner         "failed to lower PDL pattern module to the PDL Interpreter");
12979d7f618SChris Lattner 
13079d7f618SChris Lattner   // Generate the pdl bytecode.
13179d7f618SChris Lattner   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
13279d7f618SChris Lattner       pdlModule, pdlPatterns.takeConstraintFunctions(),
13379d7f618SChris Lattner       pdlPatterns.takeRewriteFunctions());
13479d7f618SChris Lattner }
13579d7f618SChris Lattner 
136*e5639b3fSMehdi Amini FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;
137