1b6eb26fdSRiver Riddle //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2b6eb26fdSRiver Riddle //
3b6eb26fdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b6eb26fdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5b6eb26fdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b6eb26fdSRiver Riddle //
7b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===//
8b6eb26fdSRiver Riddle //
9b6eb26fdSRiver Riddle // This file implements an applicator that applies pattern rewrites based upon a
10b6eb26fdSRiver Riddle // user defined cost model.
11b6eb26fdSRiver Riddle //
12b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===//
13b6eb26fdSRiver Riddle
14b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h"
15abfd1a8bSRiver Riddle #include "ByteCode.h"
16b6eb26fdSRiver Riddle #include "llvm/Support/Debug.h"
17b6eb26fdSRiver Riddle
186176a8f9SFrederik Gossen #define DEBUG_TYPE "pattern-application"
1976f3c2f3SRiver Riddle
20b6eb26fdSRiver Riddle using namespace mlir;
21abfd1a8bSRiver Riddle using namespace mlir::detail;
22abfd1a8bSRiver Riddle
PatternApplicator(const FrozenRewritePatternSet & frozenPatternList)23abfd1a8bSRiver Riddle PatternApplicator::PatternApplicator(
2479d7f618SChris Lattner const FrozenRewritePatternSet &frozenPatternList)
25abfd1a8bSRiver Riddle : frozenPatternList(frozenPatternList) {
26abfd1a8bSRiver Riddle if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27abfd1a8bSRiver Riddle mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28abfd1a8bSRiver Riddle bytecode->initializeMutableState(*mutableByteCodeState);
29abfd1a8bSRiver Riddle }
30abfd1a8bSRiver Riddle }
31e5639b3fSMehdi Amini PatternApplicator::~PatternApplicator() = default;
32b6eb26fdSRiver Riddle
336176a8f9SFrederik Gossen #ifndef NDEBUG
3476f3c2f3SRiver Riddle /// Log a message for a pattern that is impossible to match.
logImpossibleToMatch(const Pattern & pattern)3576f3c2f3SRiver Riddle static void logImpossibleToMatch(const Pattern &pattern) {
3676f3c2f3SRiver Riddle llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
3776f3c2f3SRiver Riddle << "' because it is impossible to match or cannot lead "
3876f3c2f3SRiver Riddle "to legal IR (by cost model)\n";
3976f3c2f3SRiver Riddle }
40b6eb26fdSRiver Riddle
416176a8f9SFrederik Gossen /// Log IR after pattern application.
getDumpRootOp(Operation * op)426176a8f9SFrederik Gossen static Operation *getDumpRootOp(Operation *op) {
436176a8f9SFrederik Gossen return op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>();
446176a8f9SFrederik Gossen }
logSucessfulPatternApplication(Operation * op)456176a8f9SFrederik Gossen static void logSucessfulPatternApplication(Operation *op) {
466176a8f9SFrederik Gossen llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
476176a8f9SFrederik Gossen op->dump();
486176a8f9SFrederik Gossen llvm::dbgs() << "\n\n";
496176a8f9SFrederik Gossen }
506176a8f9SFrederik Gossen #endif
516176a8f9SFrederik Gossen
applyCostModel(CostModel model)52b6eb26fdSRiver Riddle void PatternApplicator::applyCostModel(CostModel model) {
53abfd1a8bSRiver Riddle // Apply the cost model to the bytecode patterns first, and then the native
54abfd1a8bSRiver Riddle // patterns.
55abfd1a8bSRiver Riddle if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
56*e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
57abfd1a8bSRiver Riddle mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
58abfd1a8bSRiver Riddle }
59abfd1a8bSRiver Riddle
6076f3c2f3SRiver Riddle // Copy over the patterns so that we can sort by benefit based on the cost
6176f3c2f3SRiver Riddle // model. Patterns that are already impossible to match are ignored.
62b6eb26fdSRiver Riddle patterns.clear();
6376f3c2f3SRiver Riddle for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
6476f3c2f3SRiver Riddle for (const RewritePattern *pattern : it.second) {
6576f3c2f3SRiver Riddle if (pattern->getBenefit().isImpossibleToMatch())
666176a8f9SFrederik Gossen LLVM_DEBUG(logImpossibleToMatch(*pattern));
67b6eb26fdSRiver Riddle else
6876f3c2f3SRiver Riddle patterns[it.first].push_back(pattern);
6976f3c2f3SRiver Riddle }
7076f3c2f3SRiver Riddle }
7176f3c2f3SRiver Riddle anyOpPatterns.clear();
7276f3c2f3SRiver Riddle for (const RewritePattern &pattern :
7376f3c2f3SRiver Riddle frozenPatternList.getMatchAnyOpNativePatterns()) {
7476f3c2f3SRiver Riddle if (pattern.getBenefit().isImpossibleToMatch())
756176a8f9SFrederik Gossen LLVM_DEBUG(logImpossibleToMatch(pattern));
7676f3c2f3SRiver Riddle else
7776f3c2f3SRiver Riddle anyOpPatterns.push_back(&pattern);
78b6eb26fdSRiver Riddle }
79b6eb26fdSRiver Riddle
80b6eb26fdSRiver Riddle // Sort the patterns using the provided cost model.
813fffffa8SRiver Riddle llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
823fffffa8SRiver Riddle auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
83b6eb26fdSRiver Riddle return benefits[lhs] > benefits[rhs];
84b6eb26fdSRiver Riddle };
853fffffa8SRiver Riddle auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
86b6eb26fdSRiver Riddle // Special case for one pattern in the list, which is the most common case.
87b6eb26fdSRiver Riddle if (list.size() == 1) {
88b6eb26fdSRiver Riddle if (model(*list.front()).isImpossibleToMatch()) {
896176a8f9SFrederik Gossen LLVM_DEBUG(logImpossibleToMatch(*list.front()));
90b6eb26fdSRiver Riddle list.clear();
91b6eb26fdSRiver Riddle }
92b6eb26fdSRiver Riddle return;
93b6eb26fdSRiver Riddle }
94b6eb26fdSRiver Riddle
95b6eb26fdSRiver Riddle // Collect the dynamic benefits for the current pattern list.
96b6eb26fdSRiver Riddle benefits.clear();
973fffffa8SRiver Riddle for (const Pattern *pat : list)
98b6eb26fdSRiver Riddle benefits.try_emplace(pat, model(*pat));
99b6eb26fdSRiver Riddle
100b6eb26fdSRiver Riddle // Sort patterns with highest benefit first, and remove those that are
101b6eb26fdSRiver Riddle // impossible to match.
102b6eb26fdSRiver Riddle std::stable_sort(list.begin(), list.end(), cmp);
1036176a8f9SFrederik Gossen while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
1046176a8f9SFrederik Gossen LLVM_DEBUG(logImpossibleToMatch(*list.back()));
1056176a8f9SFrederik Gossen list.pop_back();
1066176a8f9SFrederik Gossen }
107b6eb26fdSRiver Riddle };
108b6eb26fdSRiver Riddle for (auto &it : patterns)
109b6eb26fdSRiver Riddle processPatternList(it.second);
110b6eb26fdSRiver Riddle processPatternList(anyOpPatterns);
111b6eb26fdSRiver Riddle }
112b6eb26fdSRiver Riddle
walkAllPatterns(function_ref<void (const Pattern &)> walk)113b6eb26fdSRiver Riddle void PatternApplicator::walkAllPatterns(
114b6eb26fdSRiver Riddle function_ref<void(const Pattern &)> walk) {
11576f3c2f3SRiver Riddle for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
11676f3c2f3SRiver Riddle for (const auto &pattern : it.second)
11776f3c2f3SRiver Riddle walk(*pattern);
11876f3c2f3SRiver Riddle for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
1193fffffa8SRiver Riddle walk(it);
120abfd1a8bSRiver Riddle if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
121abfd1a8bSRiver Riddle for (const Pattern &it : bytecode->getPatterns())
122abfd1a8bSRiver Riddle walk(it);
123abfd1a8bSRiver Riddle }
124b6eb26fdSRiver Riddle }
125b6eb26fdSRiver Riddle
matchAndRewrite(Operation * op,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)126b6eb26fdSRiver Riddle LogicalResult PatternApplicator::matchAndRewrite(
127b6eb26fdSRiver Riddle Operation *op, PatternRewriter &rewriter,
128b6eb26fdSRiver Riddle function_ref<bool(const Pattern &)> canApply,
129b6eb26fdSRiver Riddle function_ref<void(const Pattern &)> onFailure,
130b6eb26fdSRiver Riddle function_ref<LogicalResult(const Pattern &)> onSuccess) {
131abfd1a8bSRiver Riddle // Before checking native patterns, first match against the bytecode. This
132abfd1a8bSRiver Riddle // won't automatically perform any rewrites so there is no need to worry about
133abfd1a8bSRiver Riddle // conflicts.
134abfd1a8bSRiver Riddle SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
135abfd1a8bSRiver Riddle const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
136abfd1a8bSRiver Riddle if (bytecode)
137abfd1a8bSRiver Riddle bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
138abfd1a8bSRiver Riddle
139b6eb26fdSRiver Riddle // Check to see if there are patterns matching this specific operation type.
1403fffffa8SRiver Riddle MutableArrayRef<const RewritePattern *> opPatterns;
141b6eb26fdSRiver Riddle auto patternIt = patterns.find(op->getName());
142b6eb26fdSRiver Riddle if (patternIt != patterns.end())
143b6eb26fdSRiver Riddle opPatterns = patternIt->second;
144b6eb26fdSRiver Riddle
145b6eb26fdSRiver Riddle // Process the patterns for that match the specific operation type, and any
146b6eb26fdSRiver Riddle // operation type in an interleaved fashion.
14785ab413bSRiver Riddle unsigned opIt = 0, opE = opPatterns.size();
14885ab413bSRiver Riddle unsigned anyIt = 0, anyE = anyOpPatterns.size();
14985ab413bSRiver Riddle unsigned pdlIt = 0, pdlE = pdlMatches.size();
15085ab413bSRiver Riddle LogicalResult result = failure();
15185ab413bSRiver Riddle do {
152abfd1a8bSRiver Riddle // Find the next pattern with the highest benefit.
153abfd1a8bSRiver Riddle const Pattern *bestPattern = nullptr;
15485ab413bSRiver Riddle unsigned *bestPatternIt = &opIt;
155abfd1a8bSRiver Riddle const PDLByteCode::MatchResult *pdlMatch = nullptr;
15685ab413bSRiver Riddle
157abfd1a8bSRiver Riddle /// Operation specific patterns.
15885ab413bSRiver Riddle if (opIt < opE)
15985ab413bSRiver Riddle bestPattern = opPatterns[opIt];
160abfd1a8bSRiver Riddle /// Operation agnostic patterns.
16185ab413bSRiver Riddle if (anyIt < anyE &&
16285ab413bSRiver Riddle (!bestPattern ||
16385ab413bSRiver Riddle bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
16485ab413bSRiver Riddle bestPatternIt = &anyIt;
16585ab413bSRiver Riddle bestPattern = anyOpPatterns[anyIt];
16685ab413bSRiver Riddle }
167abfd1a8bSRiver Riddle /// PDL patterns.
16885ab413bSRiver Riddle if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
16985ab413bSRiver Riddle pdlMatches[pdlIt].benefit)) {
17085ab413bSRiver Riddle bestPatternIt = &pdlIt;
17185ab413bSRiver Riddle pdlMatch = &pdlMatches[pdlIt];
17285ab413bSRiver Riddle bestPattern = pdlMatch->pattern;
173abfd1a8bSRiver Riddle }
174abfd1a8bSRiver Riddle if (!bestPattern)
175abfd1a8bSRiver Riddle break;
176b6eb26fdSRiver Riddle
17785ab413bSRiver Riddle // Update the pattern iterator on failure so that this pattern isn't
17885ab413bSRiver Riddle // attempted again.
17985ab413bSRiver Riddle ++(*bestPatternIt);
18085ab413bSRiver Riddle
181b6eb26fdSRiver Riddle // Check that the pattern can be applied.
182abfd1a8bSRiver Riddle if (canApply && !canApply(*bestPattern))
183abfd1a8bSRiver Riddle continue;
184b6eb26fdSRiver Riddle
185b6eb26fdSRiver Riddle // Try to match and rewrite this pattern. The patterns are sorted by
186abfd1a8bSRiver Riddle // benefit, so if we match we can immediately rewrite. For PDL patterns, the
187abfd1a8bSRiver Riddle // match has already been performed, we just need to rewrite.
188b6eb26fdSRiver Riddle rewriter.setInsertionPoint(op);
1896176a8f9SFrederik Gossen #ifndef NDEBUG
1906176a8f9SFrederik Gossen // Operation `op` may be invalidated after applying the rewrite pattern.
1916176a8f9SFrederik Gossen Operation *dumpRootOp = getDumpRootOp(op);
1926176a8f9SFrederik Gossen #endif
193abfd1a8bSRiver Riddle if (pdlMatch) {
194abfd1a8bSRiver Riddle bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
19585ab413bSRiver Riddle result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
196abfd1a8bSRiver Riddle } else {
19785ab413bSRiver Riddle const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
198e2a77644SButygin
199e2a77644SButygin LLVM_DEBUG(llvm::dbgs()
200e2a77644SButygin << "Trying to match \"" << pattern->getDebugName() << "\"\n");
20185ab413bSRiver Riddle result = pattern->matchAndRewrite(op, rewriter);
202e2a77644SButygin LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
203e2a77644SButygin << succeeded(result) << "\n");
204e2a77644SButygin
20585ab413bSRiver Riddle if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
20685ab413bSRiver Riddle result = failure();
207abfd1a8bSRiver Riddle }
2086176a8f9SFrederik Gossen if (succeeded(result)) {
2096176a8f9SFrederik Gossen LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
21085ab413bSRiver Riddle break;
2116176a8f9SFrederik Gossen }
212b6eb26fdSRiver Riddle
213abfd1a8bSRiver Riddle // Perform any necessary cleanups.
214b6eb26fdSRiver Riddle if (onFailure)
215abfd1a8bSRiver Riddle onFailure(*bestPattern);
21685ab413bSRiver Riddle } while (true);
21785ab413bSRiver Riddle
21885ab413bSRiver Riddle if (mutableByteCodeState)
21985ab413bSRiver Riddle mutableByteCodeState->cleanupAfterMatchAndRewrite();
22085ab413bSRiver Riddle return result;
221b6eb26fdSRiver Riddle }
222